In [None]:
# Reload imported modules automatically before each cell runs (mode 2 = all
# modules), so edits to the project's .py files are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
from IPython.core.interactiveshell import InteractiveShell
# Display the value of every top-level expression in a cell, not only the last one.
InteractiveShell.ast_node_interactivity = 'all'

import warnings
# Silence all warnings for the rest of the session.
warnings.simplefilter('ignore')

import gc

from os import path
import sys
# Put the parent directory on the import path — presumably so that the
# project's `src` package (imported in the next cell) can be resolved.
sys.path.append(path.abspath('..'))

In [None]:
from torchvision.datasets import MNIST
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch
import itertools
import operator
import pandas as pd
from tqdm.notebook import tqdm

from src.dataset import BarCodeDataset
from src.lightning_module import OCRModule
from src.augmentations import get_transforms
from src.predict_utils import matrix_to_string

from onnxruntime import InferenceSession

In [None]:
# NOTE(review): DEVICE is not referenced by any cell visible in this notebook.
DEVICE = 'cpu:0'
# Characters passed to get_transforms and matrix_to_string as the vocabulary (digits only).
VOCAB = '0123456789'

In [None]:
# Build the preprocessing pipeline used at inference time: postprocessing is
# enabled and augmentations are switched off.
# NOTE(review): 416x96 / text_size=13 presumably match the values the model was
# trained and exported with — confirm against the training config.
transforms = get_transforms(
    width=416, height=96, text_size=13, vocab=VOCAB, postprocessing=True, augmentations=False
)

In [None]:
# Read the annotation table (the validation split, judging by the file name)
# and wrap it in the project's dataset class.
df = pd.read_csv('../data/df_valid.csv')
dataset = BarCodeDataset(df=df, data_folder='../data')
# Show how many samples the dataset holds.
len(dataset)

In [None]:
# Load the exported ONNX graph into an onnxruntime inference session.
model = InferenceSession('../experiments/exp1/ocr.onnx')

In [None]:
from scipy.special import softmax

In [None]:
# Quick look at the raw per-step class indices for a single sample.
# The sample is prepared here so the cell also works on a fresh kernel:
# `transformed_image` is otherwise first assigned in the evaluation loop below.
image, _, _ = dataset[0]
transformed_image = transforms(image=image, text='')['image']

raw_output = model.run(None, {'input': [transformed_image]})[0]
# Swap the first two axes and normalise over the class axis (axis 2) only.
# scipy's softmax defaults to axis=None, which normalises over the whole array.
softmax(raw_output.transpose(1, 0, 2), axis=2).argmax(axis=2)

In [None]:
# Run the ONNX model over every sample of the dataset and collect the
# ground-truth and predicted strings side by side for the accuracy cells below.
gt_texts, pr_texts = [], []

for i in tqdm(range(len(dataset))):
    image, text, _ = dataset[i]

    # Only the image goes through preprocessing; the text target is left empty.
    transformed_image = transforms(image=image, text='')['image']
    onnx_outputs = model.run(None, {'input': [transformed_image]})
    predict = torch.as_tensor(onnx_outputs[0])
    string_pred, _ = matrix_to_string(predict, VOCAB)

    gt_texts.append(text)
    pr_texts.append(string_pred[0])

# Arrays allow the element-wise comparisons used further down.
gt_texts, pr_texts = np.array(gt_texts), np.array(pr_texts)

In [None]:
def postprocess(image: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Turn a raw model output into per-step labels and confidences.

    Args:
        image: Raw 3-D model output (despite the name, this is the prediction
            matrix, not a picture). Axes 0 and 1 are swapped before decoding
            and axis 2 is treated as the class axis. Presumably laid out as
            (time, batch, classes) — TODO confirm against the exported model.

    Returns:
        A ``(labels, confidences)`` pair of 2-D arrays: the index of the most
        probable class along axis 2, and the softmax probability of that class.
    """
    # Normalise over the class axis only. scipy's softmax defaults to
    # axis=None, which normalises over the whole array and would make the
    # returned confidences depend on every other step and sample.
    probas = softmax(x=image.transpose(1, 0, 2), axis=2)
    return probas.argmax(axis=2), probas.max(axis=2)

In [None]:
# Re-run the model on `transformed_image` — when cells run top to bottom this
# is the last sample left over from the evaluation loop — and keep the first
# output of the session as a raw array.
predict = model.run(None, {'input': [transformed_image]})[0]
# Check the shape of the raw output.
predict.shape

In [None]:
# Convert the raw output into per-step class indices and their probabilities.
labels, confidences = postprocess(predict)
# Both arrays are expected to have the same shape.
labels.shape, confidences.shape

In [None]:
# Shapes of the label / confidence sequences for the first entry along axis 0.
labels[0].shape, confidences[0].shape

In [None]:
import operator
import itertools

In [None]:
# NOTE(review): `decode` is neither defined nor imported in the visible cells,
# so this raises NameError on a fresh kernel — define or import it before
# this cell (or drop the cell if it was a scratch experiment).
decode(labels, confidences)

In [None]:
# Collapse runs of identical consecutive labels in the first sequence, keeping
# the (label, confidence) pairs of every run.
# Each group is materialised inside the comprehension: groupby shares one
# underlying iterator, so wrapping the whole thing in list() would leave every
# group already exhausted (and therefore empty) by the time it is displayed.
# NOTE(review): the original cell read undefined names `label` / `confidence`;
# the first entries of `labels` / `confidences` are assumed here — confirm.
[
    (key, list(group))
    for key, group in itertools.groupby(
        zip(labels[0], confidences[0]), operator.itemgetter(0)
    )
]

In [None]:
# Share of samples whose predicted string matches the ground truth exactly.
accuracy = np.mean(gt_texts == pr_texts)
print(f'accuracy = {accuracy}')

In [None]:
# Indices of the samples where the prediction differs from the ground truth.
np.where(gt_texts != pr_texts)[0]

In [None]:
# Inspect a single sample: print the predicted and true strings, then show
# the source image. Change `idx` to look at another sample.
idx = 45
image, text, _ = dataset[idx]
print(f'pred = {pr_texts[idx]}', f'true = {gt_texts[idx]}', sep='\n')
Image.fromarray(image)