In [24]:
import torch
from datasets import load_dataset, Image
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torchvision.transforms.functional as f
import numpy as np
import torch.nn.functional as F
import time
import requests
import io

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [25]:
# Data Preprocessing

ds = load_dataset("priyank-m/MJSynth_text_recognition")

labels = ds["train"].unique("label")
label2id = {lab: i for i, lab in enumerate(labels)}

def add_label_id(ex):
    ex["label_id"] = label2id[ex["label"]]
    return ex

ds = ds.map(add_label_id)

def preprocess(img):
    w, h = img.size
    new_w = int(w * (32 / h))
    img = f.resize(img, (32, new_w))

    if new_w < 256:
        pad_right = 256 - new_w
        img = f.pad(img, [0, 0, pad_right, 0], fill=0)
    else:
        img = f.crop(img, 0, 0, 32, 256) 

    return f.to_tensor(img)

def process(batch):
    batch["pixel_values"] = [preprocess(im) for im in batch["image"]]
    return batch

train_set = ds['train'].with_transform(process)
test_set = ds['test'].with_transform(process)
val_set = ds['val'].with_transform(process)

def collate_fn(data):
    x = torch.stack([d["pixel_values"] for d in data])
    y = torch.tensor([d["label_ids"] for d in data], dtype=torch.long)
    texts = [b["label"] for b in data]
    return {"pixel_values": x, "label_ids": y, "labels": texts}

train_dataloader = DataLoader(train_set, batch_size=64, shuffle=True, collate_fn=collate_fn)
test_dataloader = DataLoader(test_set, batch_size=64, shuffle=False, collate_fn=collate_fn)
val_dataloader = DataLoader(val_set, batch_size=64, shuffle=False, collate_fn=collate_fn)

In [26]:
print(f"Train Size: {len(train_set)}\nTest Size: {len(test_set)}\nVal Size: {len(val_set)}\nTotal: {len(test_set) + len(train_set) + len(val_set)}\n")

# batche size = 64
print(f"Training Batches: {len(train_dataloader)}\nTest Batches: {len(test_dataloader)}\nVal Batches: {len(val_dataloader)}\n")

Train Size: 7224600
Test Size: 891924
Val Size: 802733
Total: 8919257

Training Batches: 112885
Test Batches: 13937
Val Batches: 12543



In [54]:
import torch
from torch import nn

# shape = [3, 32, 256]
class CNN(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv4 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.fc1 = nn.Linear(8192, 4096)
        self.fc2 = nn.Linear(2048, 1024)
        self.fc3 = nn.Linear(512, 256)
        self.fc4 = nn.Linear(256, num_classes)

    def forward(self, x):
        out = self.pool1(F.relu(self.conv1(x)))
        out = self.pool2(F.relu(self.conv2(out)))
        out = self.pool3(F.relu(self.conv3(out)))
        out = self.pool4(F.relu(self.conv4(out)))

        print(f"Pre-flatten: {out.shape}")
        out = torch.flatten(out, start_dim=1)
        print(f"Flatten: {out.shape}")

        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = F.relu(self.fc3(out))

        out = self.fc4(out)
        return out

num_classes = 23
model = CNN(3, num_classes).to(device=device)
    

In [55]:
# training loop

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

num_epochs = 15
steps_per_epoch = len(train_dataloader)

for epoch in range(num_epochs):
    epoch_start = time.perf_counter()
    print(f"Epoch {epoch + 1}/{num_epochs}")

    model.train()

    train_loss = 0.0
    train_correct = 0
    train_total = 0

    for i in range(len(train_dataloader)):
        x = train_dataloader.dataset[i]["pixel_values"].to(device)
        y = train_dataloader.dataset[i]["label_id"]

        logits = model(x)
        loss = loss_fn(logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        preds = logits.argmax(dim=1)
        acc = (preds == y).float().mean().item()
        print(f"Loss: {loss.item():.3f}, Acc: {acc:.3f}")

#     model.eval()

#     val_loss = 0.0
#     val_correct = 0
#     val_total = 0

#     with torch.no_grad():
#         for batch in test_dataloader:
#             x = batch["pixel_values"].to(device)
#             y = batch["label"].to(device).long()

#             logits = model(x)

#             val_correct += (logits.argmax(dim=1) == y).sum().item()
#             val_total += y.size(0)

    elapsed_s = time.perf_counter() - epoch_start
#     ms_per_step = (elapsed_s / steps_per_epoch) * 1000

# model.eval()
# test_correct = 0
# test_total = 0

# with torch.no_grad():
#     for batch in test_dataloader:
#         x = batch["pixel_values"].to(device)
#         y = batch["label"].to(device).long()

#         outputs = model(x)
#         test_correct += (outputs.argmax(dim=1) == y).sum().item()
#         test_total += y.size(0)


Epoch 1/15
Pre-flatten: torch.Size([256, 2, 16])
Flatten: torch.Size([256, 32])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (256x32 and 8192x4096)