<a href="https://colab.research.google.com/github/LineIntegralx/CalligraNet/blob/main/Training_Scripts/CNN_Swin_CTC_Training_v2_Finalized.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [21]:
from google.colab import drive
drive.mount('/content/drive')

!pip install -q timm


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [22]:
import os
import random
from pathlib import Path
from collections import Counter

import numpy as np
import pandas as pd
from tqdm import tqdm

import cv2
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import timm


In [23]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

ROOT = Path("/content/drive/MyDrive/EECE693_Project")

DATA_ROOTS = {
    "D0_preprocessed": ROOT / "Preprocessed_HICMA",
    "D1_augmented":    ROOT / "Augmented_HICMA",
    "D2_synth":        ROOT / "HICMA_Plus_Synthetic",
}

CKPT_DIR = ROOT / "SwinCTC_Checkpoints"
CKPT_DIR.mkdir(parents=True, exist_ok=True)
print("Checkpoint dir:", CKPT_DIR)

BATCH_SIZE    = 8
NUM_EPOCHS    = 40
FREEZE_EPOCHS = 5          # epochs with Swin frozen
LR_MAIN       = 3e-4       # CNN stem + CTC head
LR_SWIN       = 1e-4       # Swin fine-tuning
WEIGHT_DECAY  = 1e-2
PATIENCE      = 7          # early stopping on val CER

TARGET_HEIGHT = 256
PAD_DIVISOR   = 32

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if device.type == "cuda":
    torch.cuda.manual_seed_all(SEED)


Device: cuda
Checkpoint dir: /content/drive/MyDrive/EECE693_Project/SwinCTC_Checkpoints


In [24]:
def load_split_dfs(base_dir: Path):
    dfs = {}
    for split in ["train", "val", "test"]:
        df = pd.read_csv(base_dir / f"{split}_labels.csv")
        df["label"] = df["label"].astype(str)
        df["img_name"] = df["img_name"].astype(str).str.strip()
        dfs[split] = df
    return dfs

for name, base in DATA_ROOTS.items():
    print(f"\n=== {name} ===")
    dfs = load_split_dfs(base)
    for split, df in dfs.items():
        print(f"{split}: {len(df)} rows, classes={df['class'].value_counts().to_dict()}")
        print(f"  avg label len = {df['label'].str.len().mean():.1f}")



=== D0_preprocessed ===
train: 4020 rows, classes={'Naskh': 2988, 'Thuluth': 808, 'Diwani': 190, 'Kufic': 21, 'Muhaquaq': 13}
  avg label len = 42.5
val: 502 rows, classes={'Naskh': 373, 'Thuluth': 101, 'Diwani': 23, 'Kufic': 3, 'Muhaquaq': 2}
  avg label len = 41.6
test: 503 rows, classes={'Naskh': 374, 'Thuluth': 101, 'Diwani': 24, 'Muhaquaq': 2, 'Kufic': 2}
  avg label len = 44.5

=== D1_augmented ===
train: 20000 rows, classes={'Naskh': 4000, 'Thuluth': 4000, 'Diwani': 4000, 'Muhaquaq': 4000, 'Kufic': 4000}
  avg label len = 34.3
val: 502 rows, classes={'Naskh': 373, 'Thuluth': 101, 'Diwani': 23, 'Kufic': 3, 'Muhaquaq': 2}
  avg label len = 41.6
test: 503 rows, classes={'Naskh': 374, 'Thuluth': 101, 'Diwani': 24, 'Muhaquaq': 2, 'Kufic': 2}
  avg label len = 44.5

=== D2_synth ===
train: 27316 rows, classes={'Kufic': 6923, 'Diwani': 6296, 'Thuluth': 6097, 'Naskh': 4000, 'Muhaquaq': 4000}
  avg label len = 37.6
val: 1416 rows, classes={'Naskh': 373, 'Kufic': 368, 'Thuluth': 363, 'Di

In [25]:
# Build char vocab from richest dataset (D2_synth)
def build_vocab_from_dataset(base_dir: Path):
    labels = []
    for split in ["train", "val", "test"]:
        df = pd.read_csv(base_dir / f"{split}_labels.csv")
        labels.extend(df["label"].astype(str).tolist())
    all_text = "".join(labels)
    char_counter = Counter(all_text)
    chars = sorted(list(char_counter.keys()))
    print("\nNum unique chars:", len(chars))
    return chars

chars = build_vocab_from_dataset(DATA_ROOTS["D2_synth"])

BLANK_IDX = 0
stoi = {ch: i + 1 for i, ch in enumerate(chars)}   # chars start at 1
itos = {i + 1: ch for i, ch in enumerate(chars)}
vocab_size = len(chars) + 1  # + blank
print("vocab_size (including blank):", vocab_size)

class TextEncoder:
    def __init__(self, stoi, itos, blank_idx=0):
        self.stoi = stoi
        self.itos = itos
        self.blank_idx = blank_idx

    def encode(self, text: str):
        return [self.stoi[c] for c in text if c in self.stoi]

    def decode(self, ids):
        return "".join(self.itos[i] for i in ids if i in self.itos)

text_encoder = TextEncoder(stoi, itos, BLANK_IDX)



Num unique chars: 70
vocab_size (including blank): 71


In [26]:
TARGET_HEIGHT = 256
PAD_DIVISOR   = 32
MAX_WIDTH_RESIZED = 1600  # you can later try 1024 or 2048

class ResizePadTo256:
    def __init__(self,
                 target_height=TARGET_HEIGHT,
                 pad_divisor=PAD_DIVISOR,
                 max_width_resized=MAX_WIDTH_RESIZED):
        self.target_height = target_height
        self.pad_divisor = pad_divisor
        self.max_width_resized = max_width_resized

    def __call__(self, pil_img: Image.Image):
        # 1) to grayscale numpy
        img = pil_img.convert("L")
        img = np.array(img)  # HÃ—W

        h, w = img.shape[:2]

        # 2) resize to fixed height
        scale = self.target_height / float(h)
        new_w = int(round(w * scale))
        img_resized = cv2.resize(
            img, (new_w, self.target_height), interpolation=cv2.INTER_AREA
        )

        # 3) if still too wide, compress width
        if new_w > self.max_width_resized:
            new_w = self.max_width_resized
            img_resized = cv2.resize(
                img_resized, (new_w, self.target_height), interpolation=cv2.INTER_AREA
            )

        # 4) pad width to multiple-of-divisor
        if self.pad_divisor is not None:
            padded_w = int(np.ceil(new_w / self.pad_divisor) * self.pad_divisor)
        else:
            padded_w = new_w

        canvas = np.full((self.target_height, padded_w), 255, dtype=np.uint8)
        x0 = (padded_w - new_w) // 2
        canvas[:, x0:x0+new_w] = img_resized

        # 5) to tensor [1, H, W] in [-1, 1]
        tensor = torch.from_numpy(canvas).float() / 255.0
        tensor = (tensor - 0.5) / 0.5
        return tensor.unsqueeze(0)

resize_pad = ResizePadTo256()
train_transform = resize_pad
eval_transform  = resize_pad


In [27]:
class HICMADataset(Dataset):
    def __init__(self, base_dir: Path, split: str, transform, text_encoder: TextEncoder):
        self.base_dir = base_dir
        self.split = split
        self.transform = transform
        self.text_encoder = text_encoder

        df = pd.read_csv(self.base_dir / f"{split}_labels.csv")
        df["img_name"] = df["img_name"].astype(str).str.strip()
        df["label"] = df["label"].astype(str)
        self.df = df[["img_name", "class", "label"]].copy()

        self.img_dir = self.base_dir / split / "images"

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_name = row["img_name"]
        cls      = row["class"]
        text     = row["label"]

        img_path = self.img_dir / img_name
        image = Image.open(img_path).convert("L")
        image = self.transform(image)               # [1, H, W]

        target = torch.tensor(self.text_encoder.encode(text), dtype=torch.long)

        return {
            "image": image,
            "target": target,
            "text": text,
            "class": cls,
            "img_name": img_name,
        }


In [28]:
def ctc_collate(batch):
    images  = [b["image"] for b in batch]    # each: [1, H_i, W_i]
    targets = [b["target"] for b in batch]

    batch_size = len(images)
    C = images[0].shape[0]
    heights = [img.shape[1] for img in images]
    widths  = [img.shape[2] for img in images]

    max_h = max(heights)
    max_w = max(widths)

    def round_up(x, m=4):
        return ((x + m - 1) // m) * m

    max_h = round_up(max_h, 4)
    max_w = round_up(max_w, 4)

    pad_val = 1.0  # white in normalized space (since we use [-1,1])
    padded = torch.full((batch_size, C, max_h, max_w),
                        pad_val, dtype=images[0].dtype)

    for i, img in enumerate(images):
        h, w = img.shape[1], img.shape[2]
        padded[i, :, :h, :w] = img  # top-left placement

    target_lengths = torch.tensor([t.size(0) for t in targets], dtype=torch.long)
    targets_concat = torch.cat(targets, dim=0)

    meta = {
        "texts": [b["text"] for b in batch],
        "img_names": [b["img_name"] for b in batch],
        "classes": [b["class"] for b in batch],
        "widths": widths,
        "heights": heights,
    }

    return padded, targets_concat, target_lengths, meta

# quick sanity check on one dataset
ds = HICMADataset(DATA_ROOTS["D0_preprocessed"], "train", train_transform, text_encoder)
print("Sample image shape:", ds[0]["image"].shape)


Sample image shape: torch.Size([1, 256, 1600])


In [29]:
class CNNSwinCTC(nn.Module):
    def __init__(self, vocab_size, hidden_dim=256):
        super().__init__()

        # CNN stem: reduce width moderately, keep height
        self.cnn_stem = nn.Sequential(
            # 1 x H x W -> 32 x H x (W/2)
            nn.Conv2d(1, 32, kernel_size=3, stride=(1, 2), padding=1),
            nn.BatchNorm2d(32),
            nn.GELU(),

            # 32 x H x W/2 -> 64 x H x (W/4)
            nn.Conv2d(32, 64, kernel_size=3, stride=(1, 2), padding=1),
            nn.BatchNorm2d(64),
            nn.GELU(),

            # 64 x H x W/4 -> 3 x H x (W/4)
            nn.Conv2d(64, 3, kernel_size=3, stride=(1, 1), padding=1),
            nn.BatchNorm2d(3),
            nn.GELU(),
        )

        # Swin-T backbone, NCHW features
        self.swin = timm.create_model(
            "swin_tiny_patch4_window7_224",
            pretrained=True,
            features_only=True,
            out_indices=[-1],
            in_chans=3,
            img_size=256,
            strict_img_size=False,
        )

        self.hidden_dim = hidden_dim
        self.proj = None  # will be created lazily once we know C
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """
        x: [B, 1, H, W]
        returns:
          log_probs: [T, B, C]
          T_len: int (same for all in batch)
        """
        # CNN stem
        x = self.cnn_stem(x)              # [B, 3, Hc, Wc]

        # Swin features: [B, C, Hs, Ws]
        feat = self.swin(x)[0]
        B, C, Hs, Ws = feat.shape

        # Lazily create projection with correct input dim C
        if self.proj is None:
            self.proj = nn.Linear(C, self.hidden_dim).to(feat.device)

        # Flatten spatial dims: T = Hs * Ws
        feat = feat.permute(0, 2, 3, 1).contiguous()  # [B, Hs, Ws, C]
        feat = feat.view(B, Hs * Ws, C)               # [B, T, C]

        feat = self.proj(feat)                        # [B, T, hidden]
        feat = self.dropout(feat)
        logits = self.classifier(feat)                # [B, T, vocab_size]

        log_probs = F.log_softmax(logits, dim=-1)
        T_len = logits.size(1)
        return log_probs.permute(1, 0, 2), T_len      # [T, B, C], T


In [30]:
def levenshtein(a, b):
    dp = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i in range(len(a) + 1):
        dp[i][0] = i
    for j in range(len(b) + 1):
        dp[0][j] = j
    for i in range(1, len(a) + 1):
        for j in range(1, len(b) + 1):
            cost = 0 if a[i - 1] == b[j - 1] else 1
            dp[i][j] = min(
                dp[i-1][j] + 1,
                dp[i][j-1] + 1,
                dp[i-1][j-1] + cost,
            )
    return dp[-1][-1]

def compute_cer(preds, gts):
    total_dist, total_len = 0, 0
    for p, g in zip(preds, gts):
        total_dist += levenshtein(p, g)
        total_len  += len(g)
    return total_dist / max(total_len, 1)

def compute_wer(preds, gts):
    total_dist, total_len = 0, 0
    for p, g in zip(preds, gts):
        p_words = p.split()
        g_words = g.split()
        total_dist += levenshtein(p_words, g_words)
        total_len  += len(g_words)
    return total_dist / max(total_len, 1)

def greedy_decode(log_probs, text_encoder: TextEncoder):
    """
    log_probs: [T, B, C]
    return: list of predicted strings
    """
    max_ids = log_probs.argmax(dim=-1).transpose(0, 1)  # [B, T]
    pred_strs = []
    for seq in max_ids:
        prev = BLANK_IDX
        ids = []
        for i in seq.tolist():
            if i != prev and i != BLANK_IDX:
                ids.append(i)
            prev = i
        pred_strs.append(text_encoder.decode(ids))
    return pred_strs


In [31]:
def evaluate(model, loader, text_encoder: TextEncoder):
    model.eval()
    total_loss = 0.0
    all_preds, all_gts = [], []

    with torch.no_grad():
        for images, targets, target_lengths, meta in loader:
            images  = images.to(device)
            targets = targets.to(device)

            log_probs, T_len = model(images)  # [T, B, C]
            input_lengths = torch.full(
                (images.size(0),), T_len, dtype=torch.long, device=device
            )

            loss = F.ctc_loss(
                log_probs, targets,
                input_lengths, target_lengths,
                blank=BLANK_IDX, zero_infinity=True
            )
            total_loss += loss.item() * images.size(0)

            pred_strs = greedy_decode(log_probs.cpu(), text_encoder)
            all_preds.extend(pred_strs)
            all_gts.extend(meta["texts"])

    avg_loss = total_loss / len(loader.dataset)
    cer = compute_cer(all_preds, all_gts)
    wer = compute_wer(all_preds, all_gts)
    return avg_loss, cer, wer


In [32]:
def make_dataloaders_for_experiment(base_dir: Path, text_encoder: TextEncoder):
    """Create train/val/test datasets + dataloaders for one experiment."""
    train_ds = HICMADataset(base_dir, "train", train_transform, text_encoder)
    val_ds   = HICMADataset(base_dir, "val",   eval_transform,  text_encoder)
    # Fixed test set from *original* HICMA (D0)
    test_ds  = HICMADataset(DATA_ROOTS["D0_preprocessed"], "test", eval_transform, text_encoder)

    train_loader = DataLoader(
        train_ds, batch_size=BATCH_SIZE, shuffle=True,
        num_workers=0, pin_memory=True, collate_fn=ctc_collate
    )
    val_loader = DataLoader(
        val_ds, batch_size=BATCH_SIZE, shuffle=False,
        num_workers=0, pin_memory=True, collate_fn=ctc_collate
    )
    test_loader = DataLoader(
        test_ds, batch_size=BATCH_SIZE, shuffle=False,
        num_workers=0, pin_memory=True, collate_fn=ctc_collate
    )
    return train_ds, val_ds, test_ds, train_loader, val_loader, test_loader


In [33]:
# Sanity check on shapes and T_len vs target lengths
base_dir = DATA_ROOTS["D0_preprocessed"]

train_ds, val_ds, test_ds, train_loader, val_loader, test_loader = \
    make_dataloaders_for_experiment(base_dir, text_encoder)

model = CNNSwinCTC(vocab_size=vocab_size).to(device)

batch = next(iter(train_loader))
images, targets, target_lengths, meta = batch
images = images.to(device)

with torch.no_grad():
    log_probs, T_len = model(images)

print("images.shape:", images.shape)               # [B, 1, H, W]
print("log_probs.shape:", log_probs.shape)         # [T, B, C]
print("T_len:", T_len)
print("target_lengths (first 8):", target_lengths[:8].tolist())
print("max target length:", target_lengths.max().item())


images.shape: torch.Size([8, 1, 256, 1600])
log_probs.shape: torch.Size([9984, 8, 71])
T_len: 9984
target_lengths (first 8): [35, 42, 20, 27, 34, 52, 41, 50]
max target length: 52


In [34]:
def train_experiment(exp_name: str, base_dir: Path):
    print(f"\n\n########## {exp_name} on {base_dir.name} ##########")

    print("  -> Building dataloaders...")
    train_ds, val_ds, test_ds, train_loader, val_loader, test_loader = \
        make_dataloaders_for_experiment(base_dir, text_encoder)
    print("  -> Dataloaders ready. Building model...")

    model = CNNSwinCTC(vocab_size=vocab_size).to(device)
    print("  -> Model ready. Starting training...")

    # Stage 1: freeze Swin
    for p in model.swin.parameters():
        p.requires_grad = False

    main_params = [p for n, p in model.named_parameters() if not n.startswith("swin.")]
    swin_params = [p for n, p in model.named_parameters() if n.startswith("swin.")]

    optimizer = torch.optim.AdamW(
        [
            {"params": main_params, "lr": LR_MAIN},
            {"params": swin_params, "lr": LR_SWIN},
        ],
        weight_decay=WEIGHT_DECAY
    )

    best_val_cer = float("inf")
    best_val_wer = None
    best_epoch   = -1
    best_state   = None

    for epoch in range(1, NUM_EPOCHS + 1):
        # Unfreeze Swin after FREEZE_EPOCHS
        if epoch == FREEZE_EPOCHS + 1:
            print(">> Unfreezing Swin backbone for fine-tuning.")
            for p in model.swin.parameters():
                p.requires_grad = True

        model.train()
        running_loss = 0.0

        print(f"\nEpoch {epoch:02d}:")
        for batch_idx, (images, targets, target_lengths, meta) in enumerate(train_loader):
            images  = images.to(device)
            targets = targets.to(device)
            target_lengths = target_lengths.to(device)

            optimizer.zero_grad()
            log_probs, T_len = model(images)
            input_lengths = torch.full(
                (images.size(0),), T_len, dtype=torch.long, device=device
            )

            loss = F.ctc_loss(
                log_probs, targets,
                input_lengths, target_lengths,
                blank=BLANK_IDX, zero_infinity=True
            )
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            running_loss += loss.item() * images.size(0)

            if batch_idx % 50 == 0:
                print(f"    batch {batch_idx}/{len(train_loader)}  loss={loss.item():.4f}")

        train_loss = running_loss / len(train_ds)
        val_loss, val_cer, val_wer = evaluate(model, val_loader, text_encoder)
        print(f"  train_loss={train_loss:.4f}  "
              f"val_loss={val_loss:.4f}  "
              f"val_CER={val_cer:.4f}  "
              f"val_WER={val_wer:.4f}")

        # Track best model by val CER (still keeps checkpoints!)
        if val_cer < best_val_cer:
            best_val_cer = val_cer
            best_val_wer = val_wer
            best_epoch   = epoch
            best_state   = model.state_dict()

            ckpt_path = CKPT_DIR / f"{exp_name}_best.pt"
            torch.save({
                "model_state": best_state,
                "epoch": best_epoch,
                "val_cer": best_val_cer,
                "val_wer": best_val_wer,
                "vocab_size": vocab_size,
                "chars": chars,
            }, ckpt_path)
            print(f"  >> New best model saved to: {ckpt_path}")

    # Load best model and evaluate on test set
    if best_state is not None:
        model.load_state_dict(best_state)
    test_loss, test_cer, test_wer = evaluate(model, test_loader, text_encoder)
    print(f"[{exp_name}] BEST epoch={best_epoch}  "
          f"val_CER={best_val_cer:.4f}  val_WER={best_val_wer:.4f}  "
          f"TEST_CER={test_cer:.4f}  TEST_WER={test_wer:.4f}")

    return {
        "exp": exp_name,
        "dataset": base_dir.name,
        "train_size": len(train_ds),
        "val_size": len(val_ds),
        "test_size": len(test_ds),
        "best_epoch": best_epoch,
        "best_val_cer": best_val_cer,
        "best_val_wer": best_val_wer,
        "test_cer": test_cer,
        "test_wer": test_wer,
    }


In [35]:
results = []

exp_order = [
    ("D0_preprocessed", DATA_ROOTS["D0_preprocessed"]),
    ("D1_augmented",    DATA_ROOTS["D1_augmented"]),
    ("D2_synth",        DATA_ROOTS["D2_synth"]),
]

for exp_name, base in exp_order:
    res = train_experiment(exp_name, base)
    results.append(res)

results_df = pd.DataFrame(results)
results_df




########## D0_preprocessed on Preprocessed_HICMA ##########
  -> Building dataloaders...
  -> Dataloaders ready. Building model...
  -> Model ready. Starting training...

Epoch 01:
    batch 0/503  loss=1857.7651
    batch 50/503  loss=1262.1812
    batch 100/503  loss=493.9257
    batch 150/503  loss=508.2015
    batch 200/503  loss=127.7903
    batch 250/503  loss=95.5062
    batch 300/503  loss=74.3637
    batch 350/503  loss=15.5091
    batch 400/503  loss=15.0952
    batch 450/503  loss=12.7883
    batch 500/503  loss=9.1789
  train_loss=274.5385  val_loss=6.3209  val_CER=0.7543  val_WER=1.0131
  >> New best model saved to: /content/drive/MyDrive/EECE693_Project/SwinCTC_Checkpoints/D0_preprocessed_best.pt

Epoch 02:
    batch 0/503  loss=11.0013
    batch 50/503  loss=7.0710
    batch 100/503  loss=6.2885
    batch 150/503  loss=26.0705
    batch 200/503  loss=6.2459
    batch 250/503  loss=4.7795
    batch 300/503  loss=6.9814
    batch 350/503  loss=20.5592
    batch 400/503  

KeyboardInterrupt: 

In [36]:
results = []

res = train_experiment("D2_synth", DATA_ROOTS["D2_synth"])
results.append(res)

results_df = pd.DataFrame(results)
results_df




########## D2_synth on HICMA_Plus_Synthetic ##########
  -> Building dataloaders...
  -> Dataloaders ready. Building model...
  -> Model ready. Starting training...

Epoch 01:
    batch 0/3415  loss=3498.1519
    batch 50/3415  loss=1176.8870
    batch 100/3415  loss=975.6541
    batch 150/3415  loss=457.6769
    batch 200/3415  loss=109.7695
    batch 250/3415  loss=160.3979
    batch 300/3415  loss=77.8753
    batch 350/3415  loss=25.0032
    batch 400/3415  loss=80.7116
    batch 450/3415  loss=29.6122
    batch 500/3415  loss=13.1184
    batch 550/3415  loss=9.3345
    batch 600/3415  loss=34.6656
    batch 650/3415  loss=25.8660
    batch 700/3415  loss=15.4306
    batch 750/3415  loss=28.2290
    batch 800/3415  loss=17.5991
    batch 850/3415  loss=18.7326
    batch 900/3415  loss=11.4186
    batch 950/3415  loss=20.0316
    batch 1000/3415  loss=11.9456
    batch 1050/3415  loss=11.2969
    batch 1100/3415  loss=5.7521
    batch 1150/3415  loss=5.4778
    batch 1200/3415  los

KeyboardInterrupt: 

In [41]:
# ============================
# 1) Paths, device, constants
# ============================
from pathlib import Path
import torch
import pandas as pd

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

ROOT      = Path("/content/drive/MyDrive/EECE693_Project")
CKPT_DIR  = ROOT / "SwinCTC_Checkpoints"
TEST_ROOT = ROOT / "HICMA_Plus_Synthetic"   # base dir for test set

print("Checkpoint dir:", CKPT_DIR)
print("Test base dir :", TEST_ROOT)

TARGET_HEIGHT      = 256
PAD_DIVISOR        = 32
MAX_WIDTH_RESIZED  = 1600

# transform for eval
resize_pad = ResizePadTo256(
    target_height=TARGET_HEIGHT,
    pad_divisor=PAD_DIVISOR,
    max_width_resized=MAX_WIDTH_RESIZED,
)
eval_transform = resize_pad


Device: cuda
Checkpoint dir: /content/drive/MyDrive/EECE693_Project/SwinCTC_Checkpoints
Test base dir : /content/drive/MyDrive/EECE693_Project/HICMA_Plus_Synthetic


In [42]:
# ==========================================
# 2) Load model + TextEncoder from checkpoint
#    (handles lazy self.proj creation)
# ==========================================

def load_model_and_encoder(ckpt_path: Path):
    print(f"\nLoading checkpoint from: {ckpt_path}")
    ckpt = torch.load(ckpt_path, map_location=device)

    chars      = ckpt["chars"]
    vocab_size = ckpt["vocab_size"]
    print("  vocab_size in ckpt:", vocab_size)
    print("  num chars in ckpt :", len(chars))

    BLANK_IDX = 0
    stoi = {ch: i + 1 for i, ch in enumerate(chars)}
    itos = {i + 1: ch for i, ch in enumerate(chars)}
    text_encoder = TextEncoder(stoi, itos, blank_idx=BLANK_IDX)

    # build model
    model = CNNSwinCTC(vocab_size=vocab_size).to(device)

    # dummy forward so that self.proj is created before loading weights
    with torch.no_grad():
        dummy = torch.zeros(1, 1, TARGET_HEIGHT, 512, device=device)
        _ = model(dummy)

    model.load_state_dict(ckpt["model_state"])
    model.eval()
    return model, text_encoder


In [43]:
# ======================================
# 3) Build test loader for a given encoder
# ======================================
from torch.utils.data import DataLoader

def make_test_loader(text_encoder: TextEncoder):
    test_ds = HICMADataset(TEST_ROOT, "test", eval_transform, text_encoder)
    test_loader = DataLoader(
        test_ds,
        batch_size=8,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        collate_fn=ctc_collate,
    )
    print("  Test set size:", len(test_ds))
    return test_ds, test_loader


In [44]:
# ===============================================
# 4) Evaluate one checkpoint on the shared test set
# ===============================================

def evaluate_checkpoint_on_test(exp_name: str):
    ckpt_path = CKPT_DIR / f"{exp_name}_best.pt"
    model, text_encoder = load_model_and_encoder(ckpt_path)

    test_ds, test_loader = make_test_loader(text_encoder)
    test_loss, test_cer, test_wer = evaluate(model, test_loader, text_encoder)
    print(f"[{exp_name}] TEST_loss={test_loss:.4f}  "
          f"TEST_CER={test_cer:.4f}  TEST_WER={test_wer:.4f}")

    return {
        "exp": exp_name,
        "test_loss": test_loss,
        "test_cer": test_cer,
        "test_wer": test_wer,
        "model": model,
        "text_encoder": text_encoder,
        "test_ds": test_ds,
    }


In [45]:
# ===============================================
# 5) Run evaluation for all three checkpoints
# ===============================================

exps = ["D0_preprocessed", "D1_augmented", "D2_synth"]

results_raw = []
models_cache = {}

for exp_name in exps:
    res = evaluate_checkpoint_on_test(exp_name)
    models_cache[exp_name] = {
        "model": res.pop("model"),
        "text_encoder": res.pop("text_encoder"),
        "test_ds": res.pop("test_ds"),
    }
    results_raw.append(res)

results_test_df = pd.DataFrame(results_raw)
results_test_df



Loading checkpoint from: /content/drive/MyDrive/EECE693_Project/SwinCTC_Checkpoints/D0_preprocessed_best.pt
  vocab_size in ckpt: 71
  num chars in ckpt : 70
  Test set size: 1418
[D0_preprocessed] TEST_loss=4.7860  TEST_CER=0.7314  TEST_WER=1.2028

Loading checkpoint from: /content/drive/MyDrive/EECE693_Project/SwinCTC_Checkpoints/D1_augmented_best.pt
  vocab_size in ckpt: 71
  num chars in ckpt : 70
  Test set size: 1418
[D1_augmented] TEST_loss=5.4834  TEST_CER=0.7221  TEST_WER=1.0727

Loading checkpoint from: /content/drive/MyDrive/EECE693_Project/SwinCTC_Checkpoints/D2_synth_best.pt
  vocab_size in ckpt: 71
  num chars in ckpt : 70
  Test set size: 1418
[D2_synth] TEST_loss=2.5905  TEST_CER=0.5323  TEST_WER=0.8554


Unnamed: 0,exp,test_loss,test_cer,test_wer
0,D0_preprocessed,4.785981,0.73138,1.202804
1,D1_augmented,5.483359,0.722073,1.072711
2,D2_synth,2.590474,0.532336,0.855398
