----------------------
## 4.2 Optical Character Recognition
----------------------

In [None]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from PIL import Image, ImageDraw, ImageFont
from nltk.corpus import words
from sklearn.model_selection import train_test_split
import nltk
import matplotlib.pyplot as plt

# Fetch the NLTK "words" corpus so that `words.words()` can be read later on.
nltk.download("words")

In [None]:

# Canvas geometry (pixels) of every rendered sample.
IMG_W = 256
IMG_H = 64
# Point size handed to the TrueType font loader.
FONT_SIZE = 32
# Character used to right-pad labels up to MAX_LEN.
BLANK = "_"
# Fixed label length; longer corpus words are discarded.
MAX_LEN = 30
# Upper bound on the number of words sampled for the dataset.
DATASET_SIZE = 10_000
# Prefer the GPU whenever PyTorch reports one.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

_FONT_PATH = "/usr/share/fonts/truetype/liberation/LiberationSans-Bold.ttf"
_FONT_CACHE = {}


def _get_font(size):
    """Return the rendering font at *size*, reading the font file only once per size."""
    if size not in _FONT_CACHE:
        _FONT_CACHE[size] = ImageFont.truetype(_FONT_PATH, size)
    return _FONT_CACHE[size]


def gen_img_with_text(text):
    """Render *text* in black, centred on a white IMG_W x IMG_H grayscale canvas.

    Args:
        text: The string to draw.

    Returns:
        A NumPy array built from the mode "L" (8-bit grayscale) image; the
        background is 255 and the text is drawn with value 0.

    Raises:
        OSError: If the font file at ``_FONT_PATH`` cannot be opened.
    """
    img = Image.new("L", (IMG_W, IMG_H), color=255)
    draw = ImageDraw.Draw(img)
    # The font is cached: this function is called once per dataset word and
    # re-parsing the same .ttf file every time is pure overhead.
    font = _get_font(FONT_SIZE)

    left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
    text_w, text_h = right - left, bottom - top
    # The bounding box is measured for text anchored at (0, 0), and its own
    # (left, top) corner is generally not (0, 0). Subtract that offset so the
    # inked area, not the anchor point, ends up centred on the canvas.
    x = (IMG_W - text_w) // 2 - left
    y = (IMG_H - text_h) // 2 - top
    # NOTE(review): text wider than IMG_W yields a negative x and is cut off
    # at both edges of the canvas; nothing here shrinks or rejects it.
    draw.text((x, y), text, font=font, fill=0)
    return np.array(img)

def create_uniform_dataset():
    """Build an image/label dataset with words spread evenly over initial letters.

    Words from the NLTK ``words`` corpus that are at most ``MAX_LEN`` characters
    long are bucketed by their lower-cased first character ('a'..'z'). Up to
    ``DATASET_SIZE // 26`` words are then sampled at random from each bucket,
    rendered with ``gen_img_with_text`` and paired with a label right-padded
    with ``BLANK`` to exactly ``MAX_LEN`` characters.

    Returns:
        A tuple ``(imgs, lbls)`` where ``imgs`` is a float array of shape
        ``(N, 1, IMG_H, IMG_W)`` scaled to [0, 1] and ``lbls`` is a list of
        ``N`` padded label strings.
    """
    alphabet_groups = {chr(code): [] for code in range(ord("a"), ord("z") + 1)}
    for word in words.words():
        # Skip empty entries (word[0] would raise IndexError) and over-long words.
        if not word or len(word) > MAX_LEN:
            continue
        # Look the bucket up instead of indexing after an isalpha() check:
        # isalpha() is also true for non-ASCII letters, which have no bucket
        # and would raise KeyError.
        bucket = alphabet_groups.get(word[0].lower())
        if bucket is not None:
            bucket.append(word)

    words_per_group = DATASET_SIZE // len(alphabet_groups)
    selected_words = []
    for word_list in alphabet_groups.values():
        # Small buckets contribute everything they have.
        sample_size = min(len(word_list), words_per_group)
        selected_words.extend(random.sample(word_list, sample_size))

    imgs = [gen_img_with_text(word) for word in selected_words]
    # The image shows the bare word; only the label carries the padding.
    lbls = [word.ljust(MAX_LEN, BLANK) for word in selected_words]

    # Add a channel axis and rescale the 0..255 pixel values to [0, 1].
    imgs = np.array(imgs).reshape(-1, 1, IMG_H, IMG_W) / 255.0
    return imgs, lbls

# Render the sampled corpus once, then hold out 20% of it for validation.
imgs, lbls = create_uniform_dataset()
train_imgs, val_imgs, train_lbls, val_lbls = train_test_split(
    imgs,
    lbls,
    test_size=0.2,
    random_state=42,
)

# Character vocabulary: every distinct symbol found in the padded labels,
# sorted so that the index assignment is fixed for a given set of labels.
all_chars = sorted({ch for lbl in lbls for ch in lbl})
idx_to_char = dict(enumerate(all_chars))
char_to_idx = {ch: position for position, ch in idx_to_char.items()}
NUM_CLASSES = len(char_to_idx)


In [None]:
class OCRDataset(Dataset):
    """Tensor-backed dataset pairing rendered images with index-encoded labels."""

    def __init__(self, imgs, lbls):
        """Store the images as float32 and the labels as rows of class indices.

        Args:
            imgs: Array-like of images, converted with ``torch.tensor``.
            lbls: Iterable of label strings. Each character is looked up in
                the module-level ``char_to_idx`` mapping; an unknown character
                raises ``KeyError``.
        """
        self.imgs = torch.tensor(imgs, dtype=torch.float32)

        encoded_rows = []
        for label in lbls:
            encoded_rows.append([char_to_idx[ch] for ch in label])
        self.lbls = torch.tensor(encoded_rows, dtype=torch.long)

    def __len__(self):
        """Return the number of stored images."""
        return len(self.imgs)

    def __getitem__(self, idx):
        """Return the ``(image, label_indices)`` pair stored at *idx*."""
        image = self.imgs[idx]
        target = self.lbls[idx]
        return image, target

# Wrap each split in a dataset object.
train_dataset = OCRDataset(imgs=train_imgs, lbls=train_lbls)
val_dataset = OCRDataset(imgs=val_imgs, lbls=val_lbls)

# Only the training stream is reshuffled; validation batches keep a fixed order.
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=32, shuffle=False)


In [None]:
class OptimizedCRNN(nn.Module):
    """CNN encoder followed by a bidirectional GRU emitting per-position logits.

    The input image goes through three conv/ReLU/batch-norm/max-pool stages,
    is flattened and projected to a 256-dim vector, and that vector is fed to
    a 2-layer bidirectional GRU at each of ``max_len`` output positions. A
    final linear layer maps every GRU output to ``num_classes`` logits.

    NOTE(review): the whole image is collapsed into one vector that is copied
    unchanged to every time step, so the GRU receives identical input at each
    position and no left-to-right spatial order of the image is preserved.
    """

    def __init__(self, num_classes, img_h=None, img_w=None, max_len=None):
        """Create the layers.

        Args:
            num_classes: Size of the output vocabulary.
            img_h: Input image height; defaults to the module-level ``IMG_H``.
            img_w: Input image width; defaults to the module-level ``IMG_W``.
            max_len: Number of output positions; defaults to ``MAX_LEN``.

        Raises:
            ValueError: If the image is too small to survive the three 2x2
                max-pool stages.
        """
        super().__init__()
        self.img_h = IMG_H if img_h is None else img_h
        self.img_w = IMG_W if img_w is None else img_w
        self.max_len = MAX_LEN if max_len is None else max_len

        # Three 2x2 max-pools floor-divide each spatial dimension by 8 overall.
        feat_h, feat_w = self.img_h // 8, self.img_w // 8
        if feat_h < 1 or feat_w < 1:
            raise ValueError(
                f"image size {self.img_h}x{self.img_w} is too small; "
                "both dimensions must be at least 8"
            )

        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1), nn.ReLU(), nn.BatchNorm2d(32), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.BatchNorm2d(64), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(), nn.BatchNorm2d(128), nn.MaxPool2d(2),
        )
        # Derived from the image size instead of a hard-coded 128 * 8 * 32 so
        # the layer stays correct if the image dimensions are changed.
        self.fc = nn.Linear(128 * feat_h * feat_w, 256)
        self.rnn = nn.GRU(input_size=256, hidden_size=128, num_layers=2, batch_first=True, bidirectional=True)
        # Bidirectional GRU: 2 directions x 128 hidden units = 256 features.
        self.output_layer = nn.Linear(256, num_classes)

    def forward(self, x):
        """Map images ``(batch, 1, img_h, img_w)`` to logits ``(batch, max_len, num_classes)``."""
        x = self.cnn(x)
        x = torch.flatten(x, start_dim=1)
        x = self.fc(x)
        # Present the same image embedding at every output position.
        x = x.unsqueeze(1).repeat(1, self.max_len, 1)
        x, _ = self.rnn(x)
        return self.output_layer(x)

# Build the network for the label vocabulary, then place it on the chosen device.
model = OptimizedCRNN(num_classes=NUM_CLASSES)
model = model.to(DEVICE)

In [None]:
def eval_rand_baseline(val_loader, char_to_idx):
    """Return the per-character accuracy of uniformly random predictions.

    For every batch a random class index is drawn for each label position and
    compared with the true labels.

    Args:
        val_loader: Iterable yielding ``(imgs, lbls)`` batches, where ``lbls``
            is an integer tensor of class indices.
        char_to_idx: Vocabulary mapping; only its size is used.

    Returns:
        The fraction of positions guessed correctly, or ``0.0`` when the
        loader yields no labels.
    """
    num_classes = len(char_to_idx)
    correct, total = 0, 0
    for _, lbls in val_loader:
        # Match the label tensor's own shape and device rather than relying on
        # a global sequence length.
        rand_preds = torch.randint(num_classes, tuple(lbls.shape), device=lbls.device)
        correct += (rand_preds == lbls).sum().item()
        total += lbls.numel()
    # An empty loader would otherwise divide by zero.
    return correct / total if total else 0.0


def _sequence_loss(criterion, outputs, lbls):
    """Return the mean cross-entropy over every (sample, position) pair."""
    # CrossEntropyLoss wants the class axis second: (batch, classes, positions).
    # With mean reduction this equals averaging the per-position losses, without
    # a Python loop over positions.
    return criterion(outputs.transpose(1, 2), lbls)


def _print_sample_predictions(model, val_loader, num_samples=5):
    """Print decoded predictions next to the true text for the first validation batch."""
    batch = next(iter(val_loader), None)
    if batch is None:
        return
    sample_imgs, sample_lbls = batch
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        preds = torch.argmax(model(sample_imgs.to(DEVICE)), dim=2).cpu().tolist()
    blank_idx = char_to_idx.get(BLANK)

    print("\nSample Predictions:")
    # zip() stops at the shorter sequence, so batches with fewer than
    # `num_samples` items are handled without an IndexError.
    for pred_row, true_row in zip(preds[:num_samples], sample_lbls.tolist()):
        # Predicted blanks are dropped wherever they occur; the reference text
        # only has its trailing padding stripped.
        pred_text = "".join(idx_to_char[i] for i in pred_row if i != blank_idx)
        true_text = "".join(idx_to_char[i] for i in true_row).rstrip(BLANK)
        print(f"Prediction: {pred_text:<30} | Actual: {true_text}")


def train_model(model, train_loader, val_loader, epochs=10):
    """Train *model* with Adam and report validation metrics after every epoch.

    Each epoch runs one pass over ``train_loader``, steps a ``StepLR``
    scheduler (learning rate halved every 5 epochs), evaluates on
    ``val_loader`` and prints the losses, the per-character accuracy and a
    random-guess baseline. Every second epoch (starting with the first) a few
    decoded sample predictions are printed as well.

    NOTE(review): the reported accuracy is computed over all label positions,
    including the ``BLANK`` padding, so it overstates how well the actual
    characters are recognised.

    Args:
        model: Network mapping image batches to ``(batch, positions, classes)`` logits.
        train_loader: Iterable of ``(imgs, lbls)`` training batches.
        val_loader: Iterable of ``(imgs, lbls)`` validation batches.
        epochs: Number of passes over the training data.

    Returns:
        None. Progress is reported through ``print``.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        for imgs, lbls in train_loader:
            imgs, lbls = imgs.to(DEVICE), lbls.to(DEVICE)
            optimizer.zero_grad()
            loss = _sequence_loss(criterion, model(imgs), lbls)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        scheduler.step()

        model.eval()
        val_loss, correct_chars, total_chars = 0.0, 0, 0
        with torch.no_grad():
            for imgs, lbls in val_loader:
                imgs, lbls = imgs.to(DEVICE), lbls.to(DEVICE)
                outputs = model(imgs)
                val_loss += _sequence_loss(criterion, outputs, lbls).item()

                preds = torch.argmax(outputs, dim=2)
                correct_chars += (preds == lbls).sum().item()
                total_chars += lbls.numel()

        # Guard the averages so an empty loader reports 0 instead of crashing.
        avg_correct = correct_chars / total_chars if total_chars else 0.0
        avg_train_loss = train_loss / max(len(train_loader), 1)
        avg_val_loss = val_loss / max(len(val_loader), 1)
        rand_baseline_avg = eval_rand_baseline(val_loader, char_to_idx)

        print(f"\n{'Epoch':<20}{epoch+1}/{epochs:<5} | {'Model Avg Correct':<20}{avg_correct:.4f} | "
              f"{'Random Baseline':<20}{rand_baseline_avg:.4f}")
        print(f"{'Train Loss':<20}{avg_train_loss:.4f} | {'Val Loss':<20}{avg_val_loss:.4f}")

        if epoch % 2 == 0:
            _print_sample_predictions(model, val_loader)

# Run the full 25-epoch training schedule on the prepared loaders.
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=25,
)


[nltk_data] Downloading package words to /root/nltk_data...
[nltk_data]   Unzipping corpora/words.zip.



Epoch               1/25    | Model Avg Correct   0.7207 | Random Baseline     0.0182
Train Loss          1.0766 | Val Loss            0.9656

Sample Predictions:
Prediction: oooooooooiis                   | Actual: monumentalism
Prediction: aaai                           | Actual: worth
Prediction: oooooooooiis                   | Actual: ethnobiological
Prediction: aaaaiii                        | Actual: xyphoid
Prediction: aaaaiii                        | Actual: quiritary

Epoch               2/25    | Model Avg Correct   0.7219 | Random Baseline     0.0185
Train Loss          0.9611 | Val Loss            0.9621

Epoch               3/25    | Model Avg Correct   0.7176 | Random Baseline     0.0188
Train Loss          0.9410 | Val Loss            0.9759

Sample Predictions:
Prediction: aaooooiiiie                    | Actual: monumentalism
Prediction: yaae                           | Actual: worth
Prediction: peooooooiiiiie                 | Actual: ethnobiological
Prediction: yaa