In [None]:
!pip install torch>=1.9.0 torchvision>=0.10.0 numpy>=1.20.0 matplotlib>=3.3.0 seaborn>=0.11.0 opencv-python>=4.5.0 pillow>=8.0.0 scikit-learn>=1.0.0 tqdm>=4.60.0

In [None]:
import math, random, string, time
from typing import List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms


In [None]:
import math, random, string, time
from typing import List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms

class MnistSequenceDataset(Dataset):
    """Synthetic multi-digit strings built by concatenating augmented MNIST digits.

    Each item is a pair ``(image, digits)``. ``image`` is a single-channel tensor
    ``[1, img_h, W]`` made of ``L`` random digits (``seq_min <= L <= seq_max``).
    The digits are placed left to right with 1-6 zero-valued columns between them.
    ``digits`` is a ``long`` tensor of the ``L`` labels. ``W`` varies per item, so
    batches need a width-padding collate function.

    Items are sampled on the fly: ``idx`` is ignored and every access draws a new
    random sequence. ``samples`` therefore only sets the nominal epoch length, and
    reading the same index twice returns different data.

    Args:
        root: directory passed to ``torchvision.datasets.MNIST`` (downloaded if missing).
        train: use the MNIST training split (True) or the test split (False).
        seq_min, seq_max: inclusive bounds on the number of digits per item.
        samples: value returned by ``__len__``.
        img_h: output height; each digit is resized to this height, keeping its
            aspect ratio.

    Raises:
        ValueError: if ``not 1 <= seq_min <= seq_max``.
    """

    def __init__(self, root, train=True, seq_min=2, seq_max=6, samples=10000, img_h=32):
        if not 1 <= seq_min <= seq_max:
            raise ValueError(f"need 1 <= seq_min <= seq_max, got seq_min={seq_min}, seq_max={seq_max}")
        self.seq_min, self.seq_max = seq_min, seq_max
        self.img_h = img_h
        self.samples = samples
        base = datasets.MNIST(root, train=train, download=True)
        # Bucket digits straight from the raw uint8 tensor instead of iterating the
        # dataset (one PIL -> tensor conversion per image); float()/255 yields the
        # same [0, 1] values ToTensor() produces. One [N_d, 1, 28, 28] tensor per digit.
        self.by_digit = {
            d: base.data[base.targets == d].unsqueeze(1).float().div(255)
            for d in range(10)
        }
        # Built once and reused: RandomAffine samples fresh parameters on each call.
        self._aug = transforms.RandomAffine(degrees=8, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=5)

    def _rand_digit_img(self, d):
        """Return one randomly chosen, augmented image of digit ``d`` as ``[1, img_h, w']``."""
        bucket = self.by_digit[d]
        x = bucket[random.randrange(len(bucket))]  # [1, 28, 28]
        x = self._aug(x)
        # resize to the fixed height, scaling the width by the same factor (keeps aspect)
        _, h, w = x.shape
        scale = self.img_h / h
        new_w = int(round(w * scale))
        return F.interpolate(x.unsqueeze(0), size=(self.img_h, new_w), mode='bilinear', align_corners=False).squeeze(0)

    def __len__(self):
        return self.samples

    def __getitem__(self, idx):
        """Return a freshly sampled ``(image [1, img_h, W], digits [L])`` pair; ``idx`` is unused."""
        length = random.randint(self.seq_min, self.seq_max)
        digits = [random.randint(0, 9) for _ in range(length)]
        imgs = [self._rand_digit_img(d) for d in digits]
        # 1-6 px zero-valued gaps between neighbouring digits
        gaps = [torch.zeros(1, self.img_h, random.randint(1, 6)) for _ in range(length - 1)]
        pieces = [imgs[0]]
        for gap, im in zip(gaps, imgs[1:]):
            pieces += [gap, im]
        img = torch.cat(pieces, dim=2)  # concatenate along width
        return img, torch.tensor(digits, dtype=torch.long)

In [None]:
def collate_batch(batch):
    """Collate ``(image, digits)`` pairs into a width-padded batch for CTC training.

    Args:
        batch: sequence of ``(img, label)`` pairs. ``img`` is ``[C, H, W_i]`` and all
            images share ``C`` and ``H``. ``label`` is a 1-D sequence of class ids
            (a tensor or a list).

    Returns:
        images: ``[B, C, H, W_max]`` tensor; each image is right-padded with zeros
            to the widest image in the batch.
        targets: 1-D ``long`` tensor of all labels concatenated in batch order
            (the flat target layout ``nn.CTCLoss`` accepts).
        target_lengths: ``[B]`` ``long`` tensor of per-sample label lengths.
    """
    imgs, labels = zip(*batch)
    w_max = max(im.shape[-1] for im in imgs)
    # F.pad's tuple starts with the last dim: (left, right) on width -> pad on the right only
    images = torch.stack([F.pad(im, (0, w_max - im.shape[-1])) for im in imgs], dim=0)
    # as_tensor reuses tensor labels as-is (torch.tensor() on a tensor warns on every call)
    label_tensors = [torch.as_tensor(l, dtype=torch.long) for l in labels]
    targets = torch.cat(label_tensors)
    target_lengths = torch.tensor([len(t) for t in label_tensors], dtype=torch.long)
    return images, targets, target_lengths

In [None]:
class CRNN(nn.Module):
    """Convolutional-recurrent recognizer that emits per-column CTC log-probabilities.

    The model has three stages:

    * A small CNN reduces an ``[B, 1, H, W]`` image to a single row of features:
      the height is pooled away and the width is divided by 4.
    * A bidirectional LSTM runs over those columns.
    * A linear head scores ``num_classes`` symbols plus one extra class, index
      ``num_classes``, reserved for the CTC blank.

    Args:
        num_classes: number of real symbols (blank excluded).
        img_h: input height that the final height-pooling layer is sized for.
        cnn_out: channel count of the last conv layer, i.e. the LSTM input size.
        rnn_hidden: LSTM hidden size per direction.
        rnn_layers: number of stacked LSTM layers.
    """

    def __init__(self, num_classes=10, img_h=32, cnn_out=256, rnn_hidden=256, rnn_layers=2):
        super().__init__()
        backbone = [
            # two 2x2 pools: halve both height and width
            nn.Conv2d(1, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d((2, 2)),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d((2, 2)),
            # (2,1) pool: halve height only, keeping W/4 time steps
            nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(),
            nn.Conv2d(256, 256, 3, padding=1), nn.ReLU(), nn.MaxPool2d((2, 1)),
            nn.Conv2d(256, 512, 3, padding=1), nn.ReLU(),
            nn.BatchNorm2d(512),
            nn.Conv2d(512, cnn_out, 3, padding=1), nn.ReLU(),
            # remaining height is img_h // 8; pool it down to a single row
            nn.MaxPool2d((img_h // 8, 1)),
        ]
        self.cnn = nn.Sequential(*backbone)
        self.bi = nn.LSTM(cnn_out, rnn_hidden, rnn_layers, bidirectional=True)
        self.fc = nn.Linear(2 * rnn_hidden, num_classes + 1)

    def forward(self, x):
        """Map ``[B, 1, img_h, W]`` images to ``[T, B, num_classes + 1]`` log-probs, ``T = W // 4``."""
        columns = self.cnn(x).squeeze(2)       # [B, C, T]
        sequence = columns.permute(2, 0, 1)    # time-major [T, B, C], as the LSTM expects
        encoded, _ = self.bi(sequence)         # [T, B, 2 * rnn_hidden]
        return F.log_softmax(self.fc(encoded), dim=-1)


In [None]:
def ctc_train(num_epochs=5, batch_size=64, samples_train=20000, samples_val=2000, lr=1e-3,
              device='cuda' if torch.cuda.is_available() else 'cpu', seed=None):
    """Train a CRNN digit-string recognizer with CTC loss on synthetic MNIST sequences.

    The function:

    * builds train/validation ``MnistSequenceDataset``s under ``./data``;
    * trains a default ``CRNN`` with AdamW and gradient-norm clipping at 2.0;
    * after every epoch, prints the sample-weighted mean CTC loss of both splits.

    The dataset samples sequences on the fly, so the validation loss is measured
    on freshly drawn sequences each epoch.

    Args:
        num_epochs: number of passes over the training set.
        batch_size: batch size for both loaders.
        samples_train, samples_val: nominal sizes of the two synthetic datasets.
        lr: AdamW learning rate.
        device: device to train on (defaults to CUDA when available).
        seed: if not None, seeds Python's ``random`` and torch's global RNG first.

    Returns:
        ``(model, val_dl)``: the trained model and the validation DataLoader.
    """
    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)

    # Pinned host memory only helps host->CUDA copies; requesting it without CUDA just warns.
    use_cuda = torch.device(device).type == 'cuda'
    loader_kw = dict(batch_size=batch_size, num_workers=2, collate_fn=collate_batch, pin_memory=use_cuda)
    train_ds = MnistSequenceDataset('./data', train=True, samples=samples_train)
    val_ds = MnistSequenceDataset('./data', train=False, samples=samples_val)
    train_dl = DataLoader(train_ds, shuffle=True, **loader_kw)
    val_dl = DataLoader(val_ds, shuffle=False, **loader_kw)

    model = CRNN().to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    ctc_loss = nn.CTCLoss(blank=10, zero_infinity=True)  # classes 0..9, blank=10

    def step(dl, train=True):
        """Run one pass over ``dl`` (optimizing when ``train``); return the mean per-sample loss."""
        model.train(train)
        total, n = 0.0, 0
        # Validation needs no autograd graph: skipping it saves memory and time.
        with torch.set_grad_enabled(train):
            for images, targets, target_lengths in dl:
                images = images.to(device, non_blocking=use_cuda)
                targets = targets.to(device)
                logp = model(images)                # [T, B, C]
                T, B, _ = logp.shape
                # every image in the batch is padded to the same width, so each yields T frames
                input_lengths = torch.full((B,), T, dtype=torch.long, device=device)
                loss = ctc_loss(logp, targets, input_lengths, target_lengths)
                if train:
                    opt.zero_grad(set_to_none=True)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 2.0)
                    opt.step()
                total += loss.item() * B
                n += B
        return total / max(n, 1)

    for epoch in range(1, num_epochs + 1):
        tr = step(train_dl, True)
        va = step(val_dl, False)
        print(f"Epoch {epoch:02d} | train CTC loss: {tr:.3f} | val CTC loss: {va:.3f}")

    return model, val_dl


In [None]:
def greedy_decode(log_probs, blank=10):  # log_probs: [T,B,C]
    """Best-path (greedy) CTC decoding.

    Takes the argmax class at every time step, merges consecutive repeats, then
    drops blanks. A repeated symbol therefore survives only when a blank (or
    another symbol) separates its occurrences.

    Args:
        log_probs: ``[T, B, C]`` per-frame class scores; only the argmax is used.
        blank: index of the CTC blank class (default 10, the class after digits 0-9).

    Returns:
        List of ``B`` lists of ints, one decoded sequence per batch element.
    """
    from itertools import groupby

    # one .tolist() instead of per-element tensor indexing: [B][T] Python ints
    best_path = log_probs.argmax(dim=-1).transpose(0, 1).tolist()
    return [[k for k, _ in groupby(row) if k != blank] for row in best_path]


In [None]:
model, val_dl = ctc_train(num_epochs=5, samples_train=15000, samples_val=1000)

# Sanity check: decode the first validation batch and compare a few predictions to ground truth.
N_SHOW = 5
model.eval()
model_device = next(model.parameters()).device
with torch.no_grad():
    images, targets, target_lengths = next(iter(val_dl))
    hyps = greedy_decode(model(images.to(model_device)).cpu())
    # targets is the flat concatenation of all labels in the batch; split it back per sample
    gts = torch.split(targets, target_lengths.tolist())
    # slicing (not range(N_SHOW)) keeps this safe when the batch has fewer than N_SHOW samples
    for gt, pr in list(zip(gts, hyps))[:N_SHOW]:
        print(f"GT: {''.join(map(str, gt.tolist()))} | PR: {''.join(map(str, pr))}")


Epoch 01 | train CTC loss: 2.102 | val CTC loss: 0.265
Epoch 02 | train CTC loss: 0.169 | val CTC loss: 0.139
Epoch 03 | train CTC loss: 0.087 | val CTC loss: 0.068
Epoch 04 | train CTC loss: 0.064 | val CTC loss: 0.069
Epoch 05 | train CTC loss: 0.058 | val CTC loss: 0.061
GT: 6434 | PR: 6434
GT: 39225 | PR: 39225
GT: 180 | PR: 180
GT: 999 | PR: 999
GT: 69 | PR: 69


In [None]:
import matplotlib.pyplot as plt

# A tiny dataset is enough here: each access draws a fresh random digit string.
ds = MnistSequenceDataset('./data', train=True, samples=10)
img, label = ds[0]   # img: [1, H, W]; label: 1-D tensor of digits

print("Digits:", label.tolist())

fig, ax = plt.subplots()
ax.imshow(img[0].numpy(), cmap='gray')   # drop the channel axis -> [H, W]
ax.set_title("Digits: " + ''.join(str(d) for d in label.tolist()))
ax.axis('off')
plt.show()
