In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader

from svtr.data_pipeline.mnist import ConcatenatedMNISTDataset
from svtr.model import model
from svtr.model.ctc_decoder import CTCDecoder

# Device for the test dataset and the model used throughout this notebook.
device = "cpu"

## Load a batch from the test split

In [None]:
test_dataset = ConcatenatedMNISTDataset(num_digits=5, train=False, device=device)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=8,
    shuffle=False,  # iterate the dataset in index order
)

# Take the first batch. Unlike a `for ...: break` loop, `next(iter(...))` fails
# loudly (StopIteration) on an empty loader instead of silently leaving
# `images`/`labels` undefined or stale from a previous kernel run.
images, labels = next(iter(test_loader))

In [None]:
# Show every image of the batch in a grid with `cols` columns, titled with its label.
cols = 2
n_images = images.shape[0]
# Round up: floor division dropped the last image for odd batch sizes and
# produced 0 rows (a plt.subplots error) for a batch of a single image.
rows = (n_images + cols - 1) // cols
# Height scales with the number of rows (10 x 6 for the default batch of 8).
fig, axes = plt.subplots(rows, cols, figsize=(10, 1.5 * rows), squeeze=False)

axes = np.ravel(axes)
for ax, image, label in zip(axes, images, labels):
    # .cpu() keeps this cell working if `device` is changed to a GPU.
    ax.imshow(image[0].cpu())
    ax.set_title(label.cpu().numpy())
# Hide the panels that remain empty when the batch does not fill the last row.
for ax in axes[n_images:]:
    ax.axis('off')

## Plot the training metrics of the trained model

In [None]:
df_metrics = pd.read_csv('../experiments/svtr_tiny/metrics.csv')
metrics = ['loss', 'ned', 'acc', 'lr']

# One panel per metric. 'lr' is a single column; every other metric is drawn
# from its `train_<metric>` and `val_<metric>` columns.
fig, axes = plt.subplots(1, len(metrics), figsize=(15, 3))
axes = np.ravel(axes)

for ax, metric in zip(axes, metrics):
    ax.set_title(metric, fontsize=14)
    if metric != 'lr':
        series = [
            (f'train_{metric}', 'orange', 'train'),
            (f'val_{metric}', 'blue', 'val'),
        ]
    else:
        series = [(metric, 'black', 'lr')]
    for column, color, series_label in series:
        ax.plot(df_metrics[column], c=color, alpha=0.7, label=series_label)
    ax.grid(ls='--', lw=0.5, c='black', alpha=0.4)
    ax.legend()

## Run inference

In [None]:
# Load the checkpoint, switch it to eval mode and move it to `device`
# (each call returns the object used by the next, as in step-by-step reassignment).
svtr = (
    model.load_model('../experiments/svtr_tiny/checkpoints/svtr_ep04.pth')
    .eval()
    .to(device)
)

# Decoder built from the same vocabulary as the test dataset.
decoder = CTCDecoder(vocab=test_dataset.vocab)

In [None]:
# Inference only: disable autograd so no graph is built and `out` does not
# require grad (a tensor that requires grad cannot be converted with .numpy(),
# which the decoder may do internally — NOTE(review): not visible here).
with torch.no_grad():
    out = svtr(images)
out.shape

In [None]:
# Decode without mapping to text; the decoder returns a (transcripts, scores) pair.
decoded_as_indices = decoder(out, to_text=False)
transcript_indices, scores = decoded_as_indices

In [None]:
# Check correctness of predictions: a sample is correct only if its whole decoded
# sequence equals its label. CTC decoding may emit a sequence of a different
# length than the label, so compare per sample rather than stacking everything
# into one array (ragged input makes np.array raise on NumPy >= 1.24).
# Assumes decoded indices and labels share the same index space, as the
# original elementwise comparison did — TODO confirm against CTCDecoder.
target_sequences = labels.cpu().tolist()
is_correct = np.array([
    np.asarray(predicted).tolist() == target
    for predicted, target in zip(transcript_indices, target_sequences)
])
is_correct

In [None]:
# Decode again, this time mapped to text, and display the transcripts.
decoded_as_text = decoder(out, to_text=True)
transcripts, scores = decoded_as_text
transcripts