Import libraries

In [None]:
import torch
from torch.utils.data import DataLoader
from src.dataset import IAMDataset, transform
from src.model import DTrOCR_RNNT
from src.train import train, evaluate
from src.utils import decode_prediction
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

Select the compute device (CUDA GPU if available, otherwise CPU)

In [None]:
# Pick the compute device once: CUDA when PyTorch reports it as available,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Load the IAM train and test datasets and wrap them in data loaders

In [None]:
# Build the train and test datasets. Both read from the same root directory
# and use the same transform; only the split file differs.
train_dataset = IAMDataset(
    root_dir='data/iam',
    split_file='data/iam/trainset.txt',
    transform=transform,
)
test_dataset = IAMDataset(
    root_dir='data/iam',
    split_file='data/iam/testset.txt',
    transform=transform,
)

# Batches of 16 for both splits; only the training loader shuffles.
train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=16, shuffle=False)

Initialize the model and the optimizer

In [None]:
# The vocabulary size is taken from the training set's character-to-index map.
model = DTrOCR_RNNT(vocab_size=len(train_dataset.char_to_idx))
model = model.to(device)  # place the model on the selected device

# Adam over all model parameters with a learning rate of 1e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

Training loop

In [None]:
epochs = 5  # Reduced for demo

# Each epoch runs one training call followed by one evaluation call on the
# test loader, then reports both losses.
for epoch in range(epochs):
    train_loss, test_loss = (
        train(model, train_loader, optimizer, device),
        evaluate(model, test_loader, device),
    )
    print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")

Save the trained model weights

In [None]:
# Persist only the model's state dict (not the whole model object) to
# 'model.pth' in the current working directory.
torch.save(obj=model.state_dict(), f='model.pth')

Test: reload the saved weights and visualize the prediction for one sample

In [None]:
# Reload the weights written to 'model.pth'. map_location remaps the stored
# tensors onto the device selected for this session, so a checkpoint saved
# from a CUDA run still loads on a CPU-only machine (without it, torch.load
# tries to restore tensors to the device they were saved from).
state_dict = torch.load('model.pth', map_location=device)
model.load_state_dict(state_dict)
model.eval()

# Test inference: decode the first test batch and display its first sample.
with torch.no_grad():  # gradients are not needed for inference
    for imgs, targets, target_lengths in test_loader:
        imgs = imgs.to(device)
        logits = model(imgs)  # Inference mode: called with images only
        predictions = decode_prediction(logits, train_dataset.idx_to_char)

        # Visualize one sample.
        # NOTE(review): squeeze() + cmap='gray' presumably expects a
        # single-channel image tensor -- confirm against the dataset transform.
        img = imgs[0].cpu().squeeze().numpy()
        plt.imshow(img, cmap='gray')
        plt.title(f"Prediction: {predictions[0]}")
        plt.axis('off')
        plt.show()
        break  # Show only one example