In [None]:
# Enable IPython's autoreload extension so edits to imported project modules
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

from svtr.data_pipeline.mnist import ConcatenatedMNISTDataset
from svtr.model import utils
from svtr.model.ctc_decoder import CTCDecoder
from svtr.model.ctc_loss import CTCLoss
from svtr.model.metrics import NormalizedEditDistance
from svtr.model.training import evaluate_metrics

# Device passed to the datasets and used for the loaded models in this notebook.
device = 'cpu'

## Load a batch from the test split

In [None]:
# Test split (train=False) with num_digits=5 per sample.
test_dataset_5 = ConcatenatedMNISTDataset(num_digits=5, train=False, device=device)
# shuffle=False keeps the dataset order, so the first batch is always taken
# from the first 8 dataset indices.
test_loader_5 = DataLoader(
    dataset=test_dataset_5,
    batch_size=8,
    shuffle=False,
)

# Grab only the first batch for the qualitative checks in the following cells.
# next(iter(...)) raises StopIteration on an empty loader instead of silently
# leaving `images`/`labels` undefined, as a `for ... break` loop would.
images, labels = next(iter(test_loader_5))

In [None]:
# Show every image of the current batch in a 2-column grid, titled with its label.
cols = 2
num_images = images.shape[0]
# Ceil division: an odd-sized batch keeps its last image, and a batch smaller
# than `cols` still gets one row (plain floor division would yield 0 rows).
rows = max(1, (num_images + cols - 1) // cols)
fig, axes = plt.subplots(rows, cols, figsize=(10, 6))

# np.ravel gives a flat array of Axes regardless of the grid shape
axes = np.ravel(axes)
for i, ax in enumerate(axes):
    if i >= num_images:
        # Grid slot without a corresponding image: hide the empty axes
        ax.axis('off')
        continue
    # Index 0 on the second axis — presumably the single channel of a
    # (batch, channel, H, W) tensor; TODO confirm
    ax.imshow(images[i, 0])
    ax.set_title(labels[i].numpy())

## Plot the metrics for the trained model

In [None]:
# Training history logged by the experiment, one column per tracked quantity.
# NOTE(review): metrics are read from `experiments/svtr_small`, while the
# checkpoint cell below loads from `experiments/model_small` — confirm both
# paths belong to the same training run.
df_metrics = pd.read_csv('../experiments/svtr_small/metrics.csv')
metrics = ['loss', 'ned', 'acc', 'lr']

fig, axes = plt.subplots(nrows=1, ncols=len(metrics), figsize=(15, 3))
axes = np.ravel(axes)

for i, (ax, metric) in enumerate(zip(axes, metrics)):
    ax.set_title(metric, fontsize=14)
    # The learning rate is a single curve; every other metric has a
    # train_/val_ column pair.
    if metric == 'lr':
        curves = [(metric, 'black', 'lr')]
    else:
        curves = [
            (f'train_{metric}', 'orange', 'train'),
            (f'val_{metric}', 'blue', 'val'),
        ]
    for column, color, legend_name in curves:
        ax.plot(df_metrics[column], c=color, alpha=0.7, label=legend_name)
    ax.grid(ls='--', lw=0.5, c='black', alpha=0.4)
    ax.legend()

In [None]:
# Restore the model from its checkpoint, switch it to eval mode and move it
# to the configured device in one chained expression.
svtr = (
    utils.load_model('../experiments/model_small/checkpoints/ckpt_ep07.pth')
    .eval()
    .to(device)
)
# Decoder built on the vocabulary of the test dataset
decoder = CTCDecoder(vocab=test_dataset_5.vocab)

## Run inference on the in-domain image width (5 characters)

In [None]:
# Fresh loss/metric objects for this evaluation (they are re-created before
# every evaluation in this notebook).
# blank=0: presumably index 0 is the CTC blank symbol — TODO confirm
ctc_loss = CTCLoss(blank=0)
normalized_edit_distance = NormalizedEditDistance(decoder)
# Evaluate the SVTR model on the 5-character test loader. Presumably this
# updates both metric objects in place, since their results are read right
# after — confirm against evaluate_metrics.
evaluate_metrics(svtr, test_loader_5, ctc_loss, normalized_edit_distance)
print(f"Loss: {ctc_loss.compute():.4f}")
# acc is multiplied by 100 to print it as a percentage; ned is printed unscaled
print(f"ned/acc: {normalized_edit_distance.ned_result():.4f}/{normalized_edit_distance.acc_result()*100:.2f}")

In [None]:
# Forward pass on the sample batch. This is pure inference, so gradient
# tracking is disabled to avoid building an autograd graph for `out`.
with torch.no_grad():
    out = svtr(images)
# Display the output shape as the cell result
out.shape

In [None]:
# Decode the model output without converting to text (to_text=False). The
# result unpacks into per-sample transcripts and scores; the transcripts are
# presumably vocabulary indices, since they are compared against `labels`
# in the next cell — TODO confirm.
transcript_indices, scores = decoder(out, to_text=False)

In [None]:
# check correctness of predictions: one boolean per sample, True when the
# decoded sequence equals the ground-truth label sequence element by element
[list(predicted) == list(target) for predicted, target in zip(transcript_indices, labels)]

In [None]:
# Decode the same model output again, this time with to_text=True, and
# display the resulting transcripts as the cell output.
# Note: `scores` from the previous decode call is overwritten here.
transcripts, scores = decoder(out, to_text=True)
transcripts

## Inference on a different input width (10 characters)

If the model generalizes well to wider inputs, the normalized edit distance on 10-character images should be about the same as on the 5-character images above.

In [None]:
# Test split (train=False) with num_digits=10 per sample.
test_dataset_10 = ConcatenatedMNISTDataset(num_digits=10, train=False, device=device)
# shuffle=False keeps the dataset order, so the first batch is always taken
# from the first 8 dataset indices.
test_loader_10 = DataLoader(
    dataset=test_dataset_10,
    batch_size=8,
    shuffle=False,
)

# Replace the sample batch with the first 10-digit batch.
# next(iter(...)) raises StopIteration on an empty loader; a `for ... break`
# loop would instead silently keep the stale `images`/`labels` that an
# earlier cell assigned from the 5-digit loader.
images, labels = next(iter(test_loader_10))

In [None]:
# Show every image of the current batch in a 2-column grid, titled with its label.
cols = 2
num_images = images.shape[0]
# Ceil division: an odd-sized batch keeps its last image, and a batch smaller
# than `cols` still gets one row (plain floor division would yield 0 rows).
rows = max(1, (num_images + cols - 1) // cols)
# Wider figure than for the 5-digit batch
fig, axes = plt.subplots(rows, cols, figsize=(15, 6))

# np.ravel gives a flat array of Axes regardless of the grid shape
axes = np.ravel(axes)
for i, ax in enumerate(axes):
    if i >= num_images:
        # Grid slot without a corresponding image: hide the empty axes
        ax.axis('off')
        continue
    # Index 0 on the second axis — presumably the single channel of a
    # (batch, channel, H, W) tensor; TODO confirm
    ax.imshow(images[i, 0])
    ax.set_title(labels[i].numpy())

In [None]:
# Fresh loss/metric objects so this evaluation does not reuse the objects
# from the 5-character run.
# blank=0: presumably index 0 is the CTC blank symbol — TODO confirm
ctc_loss = CTCLoss(blank=0)
normalized_edit_distance = NormalizedEditDistance(decoder)
# Evaluate the same SVTR model on the 10-character test loader. Presumably
# this updates both metric objects in place, since their results are read
# right after — confirm against evaluate_metrics.
evaluate_metrics(svtr, test_loader_10, ctc_loss, normalized_edit_distance)
print(f"Loss: {ctc_loss.compute():.4f}")
# acc is multiplied by 100 to print it as a percentage; ned is printed unscaled
print(f"ned/acc: {normalized_edit_distance.ned_result():.4f}/{normalized_edit_distance.acc_result()*100:.2f}")

In [None]:
# Forward pass on the current sample batch. This is pure inference, so
# gradient tracking is disabled to avoid building an autograd graph for `out`.
with torch.no_grad():
    out = svtr(images)
# Display the output shape as the cell result
out.shape

In [None]:
# Decode the model output without converting to text (to_text=False). The
# result unpacks into per-sample transcripts and scores; the transcripts are
# presumably vocabulary indices, since they are compared against `labels`
# in the next cell — TODO confirm.
transcript_indices, scores = decoder(out, to_text=False)

In [None]:
# check correctness of predictions: one boolean per sample, True when the
# decoded sequence equals the ground-truth label sequence element by element
[list(decoded) == list(expected) for decoded, expected in zip(transcript_indices, labels)]

In [None]:
# Decode the same model output again, this time with to_text=True, and
# display the resulting transcripts as the cell output.
# Note: `scores` from the previous decode call is overwritten here.
transcripts, scores = decoder(out, to_text=True)
transcripts

## Optionally evaluate CRNN

In [None]:
# Restore the CRNN model from its checkpoint, switch it to eval mode and
# move it to the configured device in one chained expression.
crnn = (
    utils.load_model('../experiments/model_crnn/checkpoints/ckpt_ep07.pth')
    .eval()
    .to(device)
)

In [None]:
# Fresh loss/metric objects so this evaluation does not reuse the objects
# from the SVTR runs. The decoder is the one created from the test dataset
# vocabulary earlier in the notebook.
# blank=0: presumably index 0 is the CTC blank symbol — TODO confirm
ctc_loss = CTCLoss(blank=0)
normalized_edit_distance = NormalizedEditDistance(decoder)
# Evaluate the CRNN model on the 10-character test loader. Presumably this
# updates both metric objects in place, since their results are read right
# after — confirm against evaluate_metrics.
evaluate_metrics(crnn, test_loader_10, ctc_loss, normalized_edit_distance)
print(f"Loss: {ctc_loss.compute():.4f}")
# acc is multiplied by 100 to print it as a percentage; ned is printed unscaled
print(f"ned/acc: {normalized_edit_distance.ned_result():.4f}/{normalized_edit_distance.acc_result()*100:.2f}")