In [None]:
!pip install -q torch torchvision torchaudio pillow tqdm matplotlib numpy pandas

import os, random, string, math, sys, time, urllib.request, shutil
from pathlib import Path
from PIL import Image, ImageFont, ImageDraw, ImageFilter, ImageOps, ImageEnhance
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# ---------------- Config / hyperparameters ----------------
# Seed the Python, NumPy and PyTorch RNGs from a single value.
seed = 42
for _seeder in (random.seed, np.random.seed, torch.manual_seed):
    _seeder(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Number of rendered samples per split (lower these to shorten runtime).
NUM_TRAIN, NUM_VAL, NUM_TEST = 8000, 1500, 1500

# Rendered image size in pixels: width x height.
IMG_W, IMG_H = 160, 32

# Inclusive bounds on word length.
MIN_LEN, MAX_LEN = 3, 12

# Optimisation settings.
BATCH_SIZE, EPOCHS, LR = 64, 12, 1e-3

# Local directory for downloaded .ttf files (created if missing).
FONTS_DIR = Path("fonts")
FONTS_DIR.mkdir(exist_ok=True)

# Alphabet: lowercase ASCII letters followed by the ten digits.
CHARS = string.ascii_lowercase + string.digits

# CTC label mapping: index 0 is reserved for the blank, so the characters
# occupy indices 1..len(CHARS).
BLANK_IDX = 0
char_list = list(CHARS)
idx_to_char = dict(enumerate(char_list, start=1))
char_to_idx = {ch: idx for idx, ch in idx_to_char.items()}

# ---------------- Download fonts ----------------
font_urls = [
    "https://github.com/google/fonts/raw/main/apache/roboto/Roboto-Regular.ttf",
    "https://github.com/google/fonts/raw/main/apache/opensans/OpenSans-Regular.ttf",
    "https://github.com/google/fonts/raw/main/apache/inconsolata/Inconsolata-Regular.ttf",
]
# Seconds to wait on the network; urlretrieve() has no timeout and can hang.
FONT_DOWNLOAD_TIMEOUT = 30

downloaded_fonts = []
for i, url in enumerate(font_urls):
    fname = FONTS_DIR / f"font_{i}.ttf"
    try:
        # Reuse a non-empty file left by a previous run instead of re-downloading.
        if not fname.exists() or fname.stat().st_size == 0:
            with urllib.request.urlopen(url, timeout=FONT_DOWNLOAD_TIMEOUT) as resp, \
                    open(fname, "wb") as out:
                shutil.copyfileobj(resp, out)
        # Confirm the file is a loadable font (not a truncated download or an
        # error page) before the renderer starts picking it at random.
        ImageFont.truetype(str(fname), size=12)
        downloaded_fonts.append(str(fname))
    except Exception as e:
        print(f"Font download failed for {url}: {e}")
        # Remove any partial/invalid file so the next run retries the download.
        fname.unlink(missing_ok=True)

if len(downloaded_fonts) == 0:
    print("No fonts downloaded - will use default PIL font.")
    # A single None entry makes render_text_image fall back to PIL's default font.
    downloaded_fonts = [None]
else:
    print("Downloaded fonts:", downloaded_fonts)

# ---------------- Word pool ----------------
allowed_chars = set(CHARS)
words = []
try:
    # errors="replace" maps undecodable bytes to U+FFFD, which the character
    # filter below rejects, instead of aborting the whole read mid-file.
    with open("/usr/share/dict/words", "r", encoding="utf-8", errors="replace") as f:
        for w in f:
            w = w.strip().lower()
            if MIN_LEN <= len(w) <= MAX_LEN and set(w) <= allowed_chars:
                words.append(w)
except OSError as e:
    # Best effort: the system word list is optional; random strings fill in below.
    print(f"System word list could not be read: {e}")

if len(words) < 500:
    # Too few dictionary words: top up with random strings over the same alphabet.
    for _ in range(8000):
        length = random.randint(MIN_LEN, MAX_LEN)
        words.append("".join(random.choices(CHARS, k=length)))

# De-duplicate into a *sorted* list. Iteration order of a set of str depends on
# the per-process hash seed (PYTHONHASHSEED), so list(set(...)) yields a different
# order on every run and random.choice(words) would not be reproducible even with
# random.seed fixed.
words = sorted(set(words))
print("Word pool size:", len(words))

# ---------------- Renderer (uses textbbox) ----------------
def render_text_image(text, img_w=IMG_W, img_h=IMG_H):
    """Render *text* as a dark-on-light grayscale word image with light augmentation.

    A font is picked at random from ``downloaded_fonts`` (a ``None`` entry means
    PIL's built-in default font). TrueType fonts get a random size between 60%
    and 90% of ``img_h``, reduced when the text is wider than ``img_w - 10``
    pixels. The visible text is centred with a small random jitter and drawn in a
    random dark grey (0-30). The image is then optionally blurred (20%), inverted
    to light-on-dark (about 2.5%), and always contrast-jittered.

    Args:
        text: String to render.
        img_w: Output width in pixels.
        img_h: Output height in pixels.

    Returns:
        A PIL image in mode ``"L"`` of size ``(img_w, img_h)``.
    """
    bg = Image.new("L", (img_w, img_h), color=255)
    draw = ImageDraw.Draw(bg)

    font_path = random.choice(downloaded_fonts)
    if font_path is None:
        font = ImageFont.load_default()
    else:
        size = random.randint(int(img_h * 0.6), int(img_h * 0.9))
        try:
            font = ImageFont.truetype(font_path, size=size)
        except Exception:
            font = ImageFont.load_default()

    bbox = draw.textbbox((0, 0), text, font=font)
    tw = bbox[2] - bbox[0]
    th = bbox[3] - bbox[1]

    # Shrink over-wide text. Only a TrueType font can be re-created at another
    # size, so skip this when the default font was chosen (font_path is None).
    if (tw > img_w - 10 and font_path is not None
            and getattr(font, "size", None) is not None):
        scale = (img_w - 10) / tw
        new_size = max(8, int(font.size * scale))
        try:
            font = ImageFont.truetype(font_path, size=new_size)
            bbox = draw.textbbox((0, 0), text, font=font)
            tw = bbox[2] - bbox[0]
            th = bbox[3] - bbox[1]
        except Exception:
            pass

    # textbbox() was measured with the text anchored at (0, 0). bbox[0]/bbox[1]
    # are the offsets from that anchor to the first inked pixel; with the default
    # "la" (left/ascender) anchor, bbox[1] is usually several pixels. Choose where
    # the *visible* text should go, then subtract the offsets. Drawing at the
    # unadjusted position shifts the glyphs down by bbox[1] and can clip
    # descenders at the bottom edge.
    left = max(2, (img_w - tw) // 2 + random.randint(-4, 4))
    top = (img_h - th) // 2 + random.randint(-2, 2)
    top = min(max(0, top), max(0, img_h - th))
    ink = random.randint(0, 30)
    draw.text((left - bbox[0], top - bbox[1]), text, font=font, fill=ink)

    if random.random() < 0.2:
        bg = bg.filter(ImageFilter.GaussianBlur(radius=random.uniform(0.2, 1.0)))
    # About 0.05 * 0.5 = 2.5% of images are inverted to light-on-dark. The second
    # draw happens only when the first succeeds (short-circuit `and`).
    if random.random() < 0.05 and random.random() >= 0.5:
        bg = ImageOps.invert(bg)
    bg = ImageEnhance.Contrast(bg).enhance(random.uniform(0.9, 1.15))

    return bg

# ---------------- Dataset & collate ----------------
class SyntheticWordsDataset(Dataset):
    """Map-style dataset that renders a random word on every access.

    Args:
        words_pool: Sequence of words to sample from (characters must be keys of
            ``char_to_idx``).
        size: Value reported by ``len()``, i.e. the number of samples per epoch.
    """

    def __init__(self, words_pool, size):
        self.words_pool = words_pool
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        """Return ``(image, labels, word, pil_img)`` for a freshly sampled word.

        ``idx`` is not used: every call draws a new random word and rendering.
        ``image`` is a float32 tensor of shape (1, H, W) with values in [0, 1].
        ``labels`` is a long tensor of 1-based character indices. ``pil_img`` is
        returned as well so examples can be saved later.
        """
        word = random.choice(self.words_pool)
        pil_img = render_text_image(word)
        pixels = np.asarray(pil_img, dtype=np.float32) / 255.0   # (H, W)
        image = torch.from_numpy(pixels)[None]                   # (1, H, W)
        labels = torch.tensor([char_to_idx[c] for c in word], dtype=torch.long)
        return image, labels, word, pil_img

def collate_fn(batch):
    """Collate ``(image, labels, word, pil_img)`` samples into a CTC-ready batch.

    Returns:
        images: tensor (B, 1, H, W) from stacking the per-sample images.
        targets_concat: 1-D long tensor with every sample's labels concatenated.
        target_lengths: 1-D long tensor with each sample's label count.
        target_strs: tuple of the B ground-truth strings.
        pil_images: tuple of the B PIL images.
    """
    images, targets, target_strs, pil_images = zip(*batch)
    stacked = torch.stack(images, dim=0)
    lengths = torch.tensor([seq.numel() for seq in targets], dtype=torch.long)
    flat = torch.cat(targets) if targets else torch.tensor([], dtype=torch.long)
    return stacked, flat, lengths, target_strs, pil_images

# NOTE: train/val/test all sample from the same `words` pool, and __getitem__
# ignores its index and draws a fresh random word + rendering on every access.
# So val/test are re-drawn on each pass, and test words are not held out from
# training.
train_ds = SyntheticWordsDataset(words, NUM_TRAIN)
val_ds   = SyntheticWordsDataset(words, NUM_VAL)
test_ds  = SyntheticWordsDataset(words, NUM_TEST)

# Pinned host memory only speeds up host->GPU copies. On a CPU-only runtime
# DataLoader just warns that pinning is unavailable, so enable it for CUDA only.
_loader_opts = dict(
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
    num_workers=2,
    pin_memory=(device.type == "cuda"),
)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_opts)
val_loader   = DataLoader(val_ds,   shuffle=False, **_loader_opts)
test_loader  = DataLoader(test_ds,  shuffle=False, **_loader_opts)

# ---------------- CRNN model ----------------
class CRNN(nn.Module):
    """Conv + bidirectional-LSTM recogniser producing per-column CTC log-probs.

    Input: tensor (B, 1, H, W) with H == ``img_h`` (single-channel images).
    Output: log-probabilities of shape (W // 8, B, num_classes).

    The CNN halves the height twice (H // 4) and the width three times (W // 8).
    Each remaining column, flattened over channels x height, is one LSTM time step.
    """

    def __init__(self, img_h, num_classes, cnn_out=256, rnn_hidden=256, n_rnn_layers=2):
        super().__init__()

        def conv_bn_relu(c_in, c_out):
            # 3x3 conv with stride 1 / padding 1 keeps H and W, then BN + in-place ReLU.
            return [nn.Conv2d(c_in, c_out, 3, 1, 1), nn.BatchNorm2d(c_out), nn.ReLU(True)]

        # Layer order (and therefore Sequential indices / state_dict keys) is fixed.
        layers = [
            *conv_bn_relu(1, 64),
            nn.MaxPool2d((2, 2)),          # H/2, W/2
            *conv_bn_relu(64, 128),
            nn.MaxPool2d((2, 2)),          # H/4, W/4
            *conv_bn_relu(128, 256),
            *conv_bn_relu(256, 256),
            nn.MaxPool2d((1, 2)),          # H/4, W/8 (width only)
            *conv_bn_relu(256, cnn_out),
        ]
        self.conv = nn.Sequential(*layers)

        feat_h = img_h // 4  # height left after the two (2, 2) poolings
        self.rnn = nn.LSTM(
            input_size=cnn_out * feat_h,
            hidden_size=rnn_hidden,
            num_layers=n_rnn_layers,
            bidirectional=True,
            batch_first=False,
        )
        self.fc = nn.Linear(2 * rnn_hidden, num_classes)

    def forward(self, x):
        feats = self.conv(x)                                        # (B, C, H', W')
        n, ch, fh, fw = feats.size()
        seq = feats.permute(3, 0, 2, 1).reshape(fw, n, fh * ch)     # (W', B, H'*C)
        hidden, _ = self.rnn(seq)                                   # (W', B, 2*hidden)
        return self.fc(hidden).log_softmax(2)                       # CTC wants log-probs

# Output classes: one per character in char_list plus the CTC blank at index 0.
num_classes = len(char_list) + 1
model = CRNN(IMG_H, num_classes=num_classes)
model = model.to(device)
print(model)

# ---------------- Loss & optimizer ----------------
# The CRNN's three width poolings turn IMG_W = 160 into 20 output frames. CTC can
# only align a target when frames >= target length + number of adjacent repeated
# characters. zero_infinity=True zeroes the (infinite) loss and gradient of any
# sample that cannot be aligned instead of letting it poison the batch.
ctc_loss = nn.CTCLoss(blank=BLANK_IDX, reduction="mean", zero_infinity=True)
optimizer = torch.optim.Adam(params=model.parameters(), lr=LR)

# ---------------- Decoder & metrics ----------------
def ctc_greedy_decode(preds_log_softmax):
    """Best-path (greedy) CTC decoding.

    Takes the arg-max class at every time step, collapses consecutive repeats and
    drops blanks (``BLANK_IDX``). Class indices missing from ``idx_to_char``
    decode to the empty string.

    Args:
        preds_log_softmax: tensor of shape (T, B, C) with per-frame class scores.

    Returns:
        A list of B decoded strings.
    """
    best_path = preds_log_softmax.detach().cpu().numpy().argmax(axis=2)  # (T, B)
    decoded = []
    for frames in best_path.T:          # one row of T class ids per batch item
        out_chars = []
        previous = -1
        for cls in frames.tolist():
            if cls != previous and cls != BLANK_IDX:
                out_chars.append(idx_to_char.get(cls, ""))
            previous = cls
        decoded.append("".join(out_chars))
    return decoded

def levenshtein(a, b):
    """Return the Levenshtein edit distance between sequences *a* and *b*.

    This is the minimum number of single-element insertions, deletions and
    substitutions that turn *a* into *b*. It uses the classic dynamic programme
    with only two rows kept: O(len(a) * len(b)) time, O(len(b)) extra space.
    """
    if len(a) == 0 or len(b) == 0:
        return len(a) + len(b)
    above = list(range(len(b) + 1))   # distances from the previous prefix of a
    for i, ca in enumerate(a, start=1):
        row = [i]
        for j, cb in enumerate(b, start=1):
            if ca == cb:
                row.append(above[j - 1])
            else:
                row.append(1 + min(above[j - 1], row[j - 1], above[j]))
        above = row
    return above[-1]

# ---------------- Evaluation helper (returns detailed lists) ----------------
def evaluate_collect(loader, model, collect_examples=False, max_examples=None):
    """Run *model* over *loader* without gradients and compute CTC/recognition metrics.

    Args:
        loader: yields ``(images, targets_concat, target_lengths, target_strs,
            pil_images)`` batches.
        model: network returning (T, B, C) log-probabilities.
        collect_examples: if True, append each sample's PIL image to its record.
        max_examples: if truthy, stop after the first batch at which at least this
            many records exist (whole batches are kept, so a few more may be returned).

    Returns:
        ``(avg_loss, exact_acc, cer, total_chars, total_char_errors, records)``:
        the batch-size-weighted mean CTC loss, the fraction of exact string
        matches, total_char_errors / total_chars, the character count (at least 1
        per sample), the summed edit distance, and one record per sample of the
        form ``(pred, target, edit_distance, len(target)[, pil_image])``.

    The model is left in eval mode.
    """
    model.eval()
    loss_sum = 0.0
    n_seen = n_exact = 0
    char_count = char_errs = 0
    records = []

    with torch.no_grad():
        for images, targets_concat, target_lengths, target_strs, pil_images in loader:
            images = images.to(device)
            batch_n = images.size(0)
            log_probs = model(images)                   # (T, B, C)
            n_frames, _, _ = log_probs.shape
            frame_lens = torch.full((batch_n,), n_frames, dtype=torch.long).to(device)
            batch_loss = ctc_loss(log_probs, targets_concat.to(device), frame_lens,
                                  target_lengths.to(device))
            loss_sum += batch_loss.item() * batch_n
            n_seen += batch_n

            decoded = ctc_greedy_decode(log_probs)
            for k, (pred, target) in enumerate(zip(decoded, target_strs)):
                dist = levenshtein(pred, target)
                char_errs += dist
                char_count += max(1, len(target))   # empty targets still count as 1
                n_exact += (pred == target)
                record = (pred, target, dist, len(target))
                if collect_examples:
                    record += (pil_images[k],)
                records.append(record)

            if max_examples and len(records) >= max_examples:
                break

    denom = max(1, n_seen)
    return (
        loss_sum / denom,
        n_exact / denom,
        char_errs / max(1, char_count),
        char_count,
        char_errs,
        records,
    )

# ---------------- Training loop ----------------
# Per-epoch history lists (read by the plotting cell).
train_losses, val_losses, val_accs, val_cers = [], [], [], []

print("Starting training...")
for epoch in range(1, EPOCHS + 1):
    model.train()
    running_loss, total_seen = 0.0, 0
    progress = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}", leave=False)
    for images, targets_concat, target_lengths, target_strs, pil_images in progress:
        images = images.to(device)
        n = images.size(0)
        optimizer.zero_grad()
        log_probs = model(images)   # (T, B, C)
        n_frames, _, _ = log_probs.shape
        frame_lens = torch.full((n,), n_frames, dtype=torch.long).to(device)
        batch_loss = ctc_loss(log_probs, targets_concat.to(device), frame_lens,
                              target_lengths.to(device))
        batch_loss.backward()
        optimizer.step()
        batch_value = batch_loss.item()
        running_loss += batch_value * n
        total_seen += n
        progress.set_postfix({"batch_loss": f"{batch_value:.4f}"})
    # The loader visits every training sample exactly once per epoch (no
    # drop_last), so total_seen == len(train_ds): this is the per-sample mean.
    epoch_train_loss = running_loss / max(1, total_seen)
    train_losses.append(epoch_train_loss)

    # validation
    val_loss, val_acc, val_cer, *_ = evaluate_collect(val_loader, model, collect_examples=False)
    val_losses.append(val_loss)
    val_accs.append(val_acc)
    val_cers.append(val_cer)

    print(f"Epoch {epoch}  TrainLoss: {epoch_train_loss:.4f}  ValLoss: {val_loss:.4f}  ValAcc: {val_acc*100:.2f}%  ValCER: {val_cer:.4f}")

# ---------------- Plot training curves ----------------
# Left panel: train/val loss per epoch. Right panel: validation exact-match accuracy.
fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))

ax_loss.plot(range(1, len(train_losses) + 1), train_losses, label="train_loss")
ax_loss.plot(range(1, len(val_losses) + 1), val_losses, label="val_loss")
ax_loss.set_xlabel("Epoch")
ax_loss.set_title("Loss")
ax_loss.legend()

val_acc_pct = [a * 100 for a in val_accs]
ax_acc.plot(range(1, len(val_accs) + 1), val_acc_pct, label="val_exact_acc_%")
ax_acc.set_xlabel("Epoch")
ax_acc.set_title("Validation Exact-match Accuracy (%)")
ax_acc.legend()

fig.tight_layout()
plt.show()

# ---------------- Detailed TEST evaluation ----------------
test_loss, test_acc, test_cer, total_chars, total_char_errors, test_records = evaluate_collect(test_loader, model, collect_examples=True)

# "Character accuracy" here is 1 - CER (same numerator and denominator as the CER
# returned by evaluate_collect). The edit distance also counts inserted
# characters, so this value can go negative when predictions are much longer than
# their targets. It is not literally "correct characters / total characters".
char_accuracy = (total_chars - total_char_errors) / max(1, total_chars)

print("\nFINAL EVALUATION ON TEST SET")
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Exact-match Accuracy: {test_acc*100:.2f}%")
print(f"Test CER (char error rate): {test_cer:.4f}")
print(f"Test Character Accuracy: {char_accuracy*100:.2f}% (i.e., 1 - CER)")

# ---------------- Save CSV of predictions ----------------
prediction_columns = ["predicted", "target", "edit_distance", "target_length", "pil_image"]
df = pd.DataFrame(test_records, columns=prediction_columns)
csv_path = "test_predictions.csv"
# PIL images cannot be written to CSV, so the file gets every column except
# pil_image; the full `df` (images included) stays in memory for the example export.
df_to_save = df.drop("pil_image", axis=1)
df_to_save.to_csv(csv_path, index=False)
print(f"\nSaved test predictions to: {csv_path}")
display(df_to_save.head(20))

# ---------------- Histogram of edit distances ----------------
eds = df["edit_distance"].astype(int).values
# One unit-wide bin per distance 0..max, with bars centred on the integer values.
fig, ax = plt.subplots(figsize=(6, 4))
ax.hist(eds, bins=range(0, max(eds) + 2), align="left")
ax.set_xlabel("Edit distance (Levenshtein)")
ax.set_ylabel("Number of samples")
ax.set_title("Histogram of edit distances on test set")
fig.tight_layout()
plt.show()

# ---------------- Exact-match accuracy by word length ----------------
# Per-length exact-match rate: group a boolean "prediction == target" Series by
# target_length and average it. This avoids DataFrameGroupBy.apply, which in
# pandas >= 2.2 warns that it operates on the grouping column. The result is the
# same float Series indexed by sorted target_length.
is_exact = df["predicted"] == df["target"]
grouped = is_exact.groupby(df["target_length"]).mean()
lengths = grouped.index.tolist()
accs_by_len = grouped.values.tolist()
plt.figure(figsize=(7,4))
plt.bar(lengths, [a*100 for a in accs_by_len])
plt.xlabel("Target word length")
plt.ylabel("Exact-match accuracy (%)")
plt.title("Exact-match accuracy by word length (test set)")
plt.xticks(lengths)
plt.tight_layout()
plt.show()

# ---------------- Save worst examples (highest edit distance) ----------------
out_dir = Path("test_examples")
if out_dir.exists():
    shutil.rmtree(out_dir)
out_dir.mkdir(parents=True, exist_ok=True)

# Sort by edit distance (descending) and take the top N worst. The stable sort
# keeps tied rows in test-set order, so the selection is deterministic.
N_worst = 10
df_sorted = (
    df.sort_values(by="edit_distance", ascending=False, kind="stable")
    .head(N_worst)
    .reset_index(drop=True)
)

# Caption font, loaded once for all images. downloaded_fonts is [None] when no
# font could be downloaded, in which case PIL's default font is used.
font_path = downloaded_fonts[0] if downloaded_fonts else None
try:
    font_small = ImageFont.truetype(font_path, size=12) if font_path else ImageFont.load_default()
except Exception:
    font_small = ImageFont.load_default()

caption_xy = (4, IMG_H + 4)
# Zero-padded rank so file names sort in rank order (e.g. worst_02 before worst_10).
rank_width = len(str(N_worst))
# Scratch drawing surface, used only to measure caption width.
_measure_draw = ImageDraw.Draw(Image.new("RGB", (1, 1)))

for i, row in df_sorted.iterrows():
    pil_img = row["pil_image"]
    pred = row["predicted"]
    target = row["target"]
    ed = int(row["edit_distance"])
    text = f"P:{pred}  T:{target}  ED:{ed}"
    # The caption can be wider than the IMG_W-wide word image; widen the canvas
    # so it is not cut off at the right edge.
    caption_right = _measure_draw.textbbox(caption_xy, text, font=font_small)[2]
    canvas_w = max(IMG_W, caption_right + 4)
    # White RGB canvas: word image on top, caption strip (30 px) underneath.
    canvas = Image.new("RGB", (canvas_w, IMG_H + 30), color=(255, 255, 255))
    canvas.paste(pil_img.convert("RGB"), (0, 0))
    draw = ImageDraw.Draw(canvas)
    draw.text(caption_xy, text, fill=(0, 0, 0), font=font_small)
    save_path = out_dir / f"worst_{i+1:0{rank_width}d}_ed{ed}.png"
    canvas.save(save_path)
print(f"\nSaved top {N_worst} worst examples to folder: {out_dir.resolve()}")

# display a few worst examples inline (if in notebook)
from IPython.display import display as ipydisplay


def _worst_rank(path):
    """Sort key for files named "worst_<rank>_ed<d>.png": numeric rank, then name."""
    parts = path.stem.split("_")
    rank = int(parts[1]) if len(parts) > 1 and parts[1].isdigit() else float("inf")
    return (rank, path.name)


print("\nTop worst examples (display):")
# Plain lexicographic sorting would show worst_10 before worst_2, so sort by the
# numeric rank embedded in the file name instead.
for p in sorted(out_dir.iterdir(), key=_worst_rank)[:N_worst]:
    print(p.name)
    # Close each file once it has been rendered instead of leaking the handle.
    with Image.open(p) as im:
        ipydisplay(im)

# ---------------- Show some random examples from test set ----------------
# Up to 20 rows drawn with a fixed random_state so the listing is repeatable.
print("\nSome random test set predictions (predicted -> target):")
sample = df.sample(min(20, len(df)), random_state=seed).reset_index(drop=True)
triples = zip(sample["predicted"], sample["target"], sample["edit_distance"])
for n, (pred, target, ed) in enumerate(triples, start=1):
    print(f"{n:2d}. {pred}  ->  {target}  (ED={ed})")



Mon Nov 24 18:04:10 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   65C    P8             10W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                


A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.0.2 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/usr/local/lib/python3.12/dist-packages/colab_kernel_launcher.py", line 37, in <module>
    ColabKernelApp.launch_instance()
  File "/usr/local/lib/python3.12/dist-packages/traitlets/config/application.py", line 992, in launch_instance
    app.start()
  File "/usr/local/lib/python3.12/dist-packages/ipykernel/kernelapp.py", line 712, in start
    self.io_loop.start()
  File "/usr/local/lib/python3.12/dist-package

AttributeError: _ARRAY_API not found

ImportError: numpy.core.multiarray failed to import