In [1]:
import os

# 1. Install missing library
!pip install editdistance

# 2. Mount Drive
from google.colab import drive
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

# 3. Unzip Data (Adjust the path if your zip file is named differently)
ZIP_PATH = '/content/drive/MyDrive/CleanedDataset.zip'

if os.path.exists(ZIP_PATH):
    if not os.path.exists('/content/CleanedDataset'):
        print("üìÇ Unzipping dataset...")
        !unzip -q {ZIP_PATH} -d /content/
        print("‚úÖ Data Ready.")
else:
    print(f"‚ö†Ô∏è Zip file not found at {ZIP_PATH}. Please check the path!")

# 4. Generate Charset (The Dictionary)
print("üî§ Generating Charset...")
train_file = '/content/CleanedDataset/rec_gt_train.txt'
charset_file = '/content/CleanedDataset/charset.txt'

unique_chars = set()
if os.path.exists(train_file):
    with open(train_file, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.strip().split('\t') if '\t' in line else line.strip().split(' ', 1)
            if len(parts) >= 2:
                for char in parts[1]:
                    unique_chars.add(char)

    with open(charset_file, 'w', encoding='utf-8') as f:
        for char in sorted(list(unique_chars)):
            f.write(char + '\n')
    print(f"‚úÖ Charset created with {len(unique_chars)} characters.")
else:
    print("‚ùå Error: Training file not found.")

Mounted at /content/drive
üìÇ Unzipping dataset...
‚úÖ Data Ready.
üî§ Generating Charset...
‚úÖ Charset created with 95 characters.


In [2]:
import os
import random
import numpy as np
from PIL import Image
from tqdm import tqdm
import editdistance

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T

# ==========================================
# 1. CONFIGURATION & SETUP
# ==========================================
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {DEVICE}")

# Paths
DATA_ROOT = '/content/CleanedDataset'
TRAIN_TXT = os.path.join(DATA_ROOT, 'rec_gt_train.txt')
VAL_TXT   = os.path.join(DATA_ROOT, 'rec_gt_test.txt')
CHARSET_FILE = os.path.join(DATA_ROOT, 'charset.txt')
CHECKPOINT_DIR = '/content/drive/MyDrive/pytorch_ocr_checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Hyperparameters
BATCH_SIZE = 64
NUM_EPOCHS = 50
LEARNING_RATE = 0.001
IMG_H = 32

# Load Charset
if not os.path.exists(CHARSET_FILE):
    raise FileNotFoundError("Run the Charset Generation step first!")

with open(CHARSET_FILE, 'r', encoding='utf-8') as f:
    chars = [line.strip() for line in f if line.strip()]

itos = ['<BLANK>'] + chars
stoi = {c: i for i, c in enumerate(itos)}
NUM_CLASSES = len(itos)
print(f"Vocab size: {NUM_CLASSES}")

# ==========================================
# 2. DATASET CLASS
# ==========================================
class MalayalamDataset(Dataset):
    def __init__(self, listfile, root_dir, img_h=32, augment=False):
        self.samples = []
        self.root_dir = root_dir
        self.img_h = img_h
        self.augment = augment
        self.transform = T.Compose([
            T.RandomAffine(degrees=2, translate=(0.02, 0.02), shear=2),
            T.ColorJitter(brightness=0.2, contrast=0.2)
        ])

        with open(listfile, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line: continue
                if '\t' in line: parts = line.split('\t')
                else: parts = line.split(' ', 1)
                if len(parts) < 2: continue

                img_path = parts[0].replace("./", "")
                self.samples.append((os.path.join(root_dir, img_path), parts[1]))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, text = self.samples[idx]
        try:
            img = Image.open(path).convert('L')
        except:
            return self.__getitem__((idx + 1) % len(self))

        w, h = img.size
        new_w = max(1, int(w * (self.img_h / h)))
        img = img.resize((new_w, self.img_h), Image.BILINEAR)
        img_arr = np.array(img).astype(np.float32) / 255.0
        img_arr = 1.0 - img_arr
        img_t = torch.from_numpy(img_arr).unsqueeze(0)

        if self.augment:
            img_t = self.transform(img_t)

        label = [stoi[c] for c in text if c in stoi]
        return img_t, torch.tensor(label, dtype=torch.long), text

def pad_batch(batch):
    imgs, labels, texts = zip(*batch)
    widths = [img.shape[2] for img in imgs]
    max_w = max(widths)
    padded_imgs = torch.zeros(len(imgs), 1, IMG_H, max_w)
    for i, img in enumerate(imgs):
        w = img.shape[2]
        padded_imgs[i, :, :, :w] = img

    target_lengths = torch.tensor([len(l) for l in labels], dtype=torch.long)
    targets = torch.cat(labels)
    return padded_imgs, targets, target_lengths, texts

# ==========================================
# 3. RESNET MODEL ARCHITECTURE
# ==========================================
class ResNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResNetBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out

class ResNetCRNN(nn.Module):
    def __init__(self, num_classes):
        super(ResNetCRNN, self).__init__()
        self.in_channels = 64
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.layer1 = self._make_layer(64, 2, stride=1)
        self.layer2 = self._make_layer(128, 2, stride=2)
        self.layer3 = self._make_layer(256, 2, stride=(2,1))
        self.layer4 = self._make_layer(512, 2, stride=(2,1))

        self.last_conv = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=2, stride=(2,1), padding=0),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.rnn = nn.Sequential(
            nn.LSTM(512, 256, bidirectional=True, batch_first=True),
            nn.Linear(512, 256),
            nn.ELU(),
            nn.Linear(256, num_classes)
        )

    def _make_layer(self, out_channels, blocks, stride=1):
        downsample = None
        if stride != 1 or self.in_channels != out_channels:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        layers = []
        layers.append(ResNetBlock(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels
        for _ in range(1, blocks):
            layers.append(ResNetBlock(out_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.last_conv(x)
        x = x.squeeze(2).permute(0, 2, 1)
        x, _ = self.rnn[0](x)
        x = self.rnn[1](x)
        x = self.rnn[2](x)
        x = self.rnn[3](x)
        return x.transpose(0, 1)

def decode(logits):
    probs = logits.softmax(2).argmax(2).transpose(0, 1)
    results = []
    for seq in probs:
        res = []
        prev = 0
        for idx in seq:
            idx = idx.item()
            if idx != 0 and idx != prev:
                res.append(itos[idx])
            prev = idx
        results.append("".join(res))
    return results

# ==========================================
# 4. TRAINING LOOP (CORRECTED)
# ==========================================


Using device: cuda
Vocab size: 96


In [None]:
def run_resnet_training():
    train_ds = MalayalamDataset(TRAIN_TXT, DATA_ROOT, img_h=IMG_H, augment=True)
    val_ds = MalayalamDataset(VAL_TXT, DATA_ROOT, img_h=IMG_H)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, collate_fn=pad_batch, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, collate_fn=pad_batch)

    print(f"Loaded {len(train_ds)} Training, {len(val_ds)} Validation")

    model = ResNetCRNN(NUM_CLASSES).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # FIX: Removed 'verbose=True' to fix the TypeError
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)

    criterion = nn.CTCLoss(blank=0, zero_infinity=True)
    best_accuracy = 0.0

    print("üöÄ Starting ResNet-CRNN Training...")

    for epoch in range(1, NUM_EPOCHS + 1):
        model.train()
        loss_accum = 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}")

        for imgs, targets, target_lens, texts in pbar:
            imgs, targets = imgs.to(DEVICE), targets.to(DEVICE)
            logits = model(imgs)
            input_lens = torch.full(size=(imgs.size(0),), fill_value=logits.size(0), dtype=torch.long).to(DEVICE)
            loss = criterion(logits.log_softmax(2), targets, input_lens, target_lens)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
            optimizer.step()
            loss_accum += loss.item()
            pbar.set_postfix(loss=f"{loss.item():.4f}")

        model.eval()
        val_correct = 0
        count = 0
        val_cer_sum = 0

        with torch.no_grad():
            for imgs, _, _, texts in val_loader:
                imgs = imgs.to(DEVICE)
                logits = model(imgs)
                preds = decode(logits)
                for pred, true_text in zip(preds, texts):
                    if pred == true_text: val_correct += 1
                    dist = editdistance.eval(pred, true_text)
                    val_cer_sum += dist / max(1, len(true_text))
                    count += 1

        accuracy = val_correct / count
        avg_cer = val_cer_sum / count

        # Step the scheduler
        scheduler.step(accuracy)

        # We manually print the learning rate since verbose=True is removed
        current_lr = optimizer.param_groups[0]['lr']
        print(f"\nüìä Epoch {epoch}: Val Acc: {accuracy*100:.2f}% | Val CER: {avg_cer:.4f} | LR: {current_lr:.6f}")

        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), os.path.join(CHECKPOINT_DIR, 'best_resnet_model.pth'))
            print(f"üî• Best Model Saved!")

# Execute
run_resnet_training()

In [None]:
# ==========================================
# RESUME TRAINING SCRIPT
# ==========================================
def resume_training():
    # 1. Setup Data & Model
    train_ds = MalayalamDataset(TRAIN_TXT, DATA_ROOT, img_h=IMG_H, augment=True)
    val_ds = MalayalamDataset(VAL_TXT, DATA_ROOT, img_h=IMG_H)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, collate_fn=pad_batch, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, collate_fn=pad_batch)

    print(f"Loaded {len(train_ds)} Training, {len(val_ds)} Validation")

    model = ResNetCRNN(NUM_CLASSES).to(DEVICE)

    # 2. LOAD WEIGHTS FROM DRIVE
    checkpoint_path = os.path.join(CHECKPOINT_DIR, 'best_resnet_model.pth')
    start_epoch = 1
    best_accuracy = 0.0

    if os.path.exists(checkpoint_path):
        print(f"üîÑ Found checkpoint at: {checkpoint_path}")
        print("üì• Loading weights to resume training...")
        model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))

        # We assume we are restarting around epoch 9
        start_epoch = 9
        # Set baseline high so we don't overwrite the good model with a bad first epoch
        best_accuracy = 0.8493
        print(f"‚úÖ Model loaded! Resuming from approx Epoch {start_epoch}")
    else:
        print("‚ö†Ô∏è No checkpoint found! Starting from scratch.")

    # 3. Setup Optimizer
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    # Re-initialize scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)
    criterion = nn.CTCLoss(blank=0, zero_infinity=True)

    print("üöÄ Resuming ResNet-CRNN Training...")

    # Continue for remaining epochs (e.g., 50 - 9 = 41 more epochs)
    for epoch in range(start_epoch, NUM_EPOCHS + 1):
        model.train()
        loss_accum = 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}")

        for imgs, targets, target_lens, texts in pbar:
            imgs, targets = imgs.to(DEVICE), targets.to(DEVICE)
            logits = model(imgs)
            input_lens = torch.full(size=(imgs.size(0),), fill_value=logits.size(0), dtype=torch.long).to(DEVICE)
            loss = criterion(logits.log_softmax(2), targets, input_lens, target_lens)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
            optimizer.step()
            loss_accum += loss.item()
            pbar.set_postfix(loss=f"{loss.item():.4f}")

        # Validation
        model.eval()
        val_correct = 0
        count = 0
        val_cer_sum = 0

        with torch.no_grad():
            for imgs, _, _, texts in val_loader:
                imgs = imgs.to(DEVICE)
                logits = model(imgs)
                preds = decode(logits)
                for pred, true_text in zip(preds, texts):
                    if pred == true_text: val_correct += 1
                    dist = editdistance.eval(pred, true_text)
                    val_cer_sum += dist / max(1, len(true_text))
                    count += 1

        accuracy = val_correct / count
        avg_cer = val_cer_sum / count

        scheduler.step(accuracy)
        current_lr = optimizer.param_groups[0]['lr']

        print(f"\nüìä Epoch {epoch}: Val Acc: {accuracy*100:.2f}% | Val CER: {avg_cer:.4f} | LR: {current_lr:.6f}")

        # Only save if we beat the previous best (0.8493)
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), os.path.join(CHECKPOINT_DIR, 'best_resnet_model.pth'))
            print(f"üî• New Best Model Saved! ({accuracy*100:.2f}%)")

# Execute
resume_training()

Loaded 85270 Training, 19635 Validation
üîÑ Found checkpoint at: /content/drive/MyDrive/pytorch_ocr_checkpoints/best_resnet_model.pth
üì• Loading weights to resume training...
‚úÖ Model loaded! Resuming from approx Epoch 9
üöÄ Resuming ResNet-CRNN Training...


Epoch 9: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1333/1333 [09:57<00:00,  2.23it/s, loss=0.0237]



üìä Epoch 9: Val Acc: 85.83% | Val CER: 0.0225 | LR: 0.001000
üî• New Best Model Saved! (85.83%)


Epoch 10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1333/1333 [09:56<00:00,  2.23it/s, loss=0.0372]



üìä Epoch 10: Val Acc: 85.83% | Val CER: 0.0229 | LR: 0.001000


Epoch 11: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1333/1333 [09:56<00:00,  2.23it/s, loss=0.0611]



üìä Epoch 11: Val Acc: 84.36% | Val CER: 0.0249 | LR: 0.001000


Epoch 12: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1333/1333 [09:58<00:00,  2.23it/s, loss=0.0498]



üìä Epoch 12: Val Acc: 86.03% | Val CER: 0.0226 | LR: 0.001000
üî• New Best Model Saved! (86.03%)


Epoch 13: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1333/1333 [09:57<00:00,  2.23it/s, loss=0.0192]



üìä Epoch 13: Val Acc: 86.00% | Val CER: 0.0225 | LR: 0.001000


Epoch 14: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1333/1333 [09:59<00:00,  2.22it/s, loss=0.0298]



üìä Epoch 14: Val Acc: 87.24% | Val CER: 0.0200 | LR: 0.001000
üî• New Best Model Saved! (87.24%)


Epoch 15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1333/1333 [09:57<00:00,  2.23it/s, loss=0.0236]



üìä Epoch 15: Val Acc: 86.12% | Val CER: 0.0219 | LR: 0.001000


Epoch 16: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1333/1333 [09:58<00:00,  2.23it/s, loss=0.0394]



üìä Epoch 16: Val Acc: 87.47% | Val CER: 0.0197 | LR: 0.001000
üî• New Best Model Saved! (87.47%)


Epoch 17: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1333/1333 [09:56<00:00,  2.24it/s, loss=0.0881]



üìä Epoch 17: Val Acc: 88.08% | Val CER: 0.0194 | LR: 0.001000
üî• New Best Model Saved! (88.08%)


Epoch 18: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1333/1333 [09:56<00:00,  2.23it/s, loss=0.0053]



üìä Epoch 18: Val Acc: 87.79% | Val CER: 0.0196 | LR: 0.001000


Epoch 19: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1333/1333 [09:56<00:00,  2.24it/s, loss=0.0798]



üìä Epoch 19: Val Acc: 89.09% | Val CER: 0.0170 | LR: 0.001000
üî• New Best Model Saved! (89.09%)


Epoch 20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1333/1333 [09:54<00:00,  2.24it/s, loss=0.0465]



üìä Epoch 20: Val Acc: 87.40% | Val CER: 0.0197 | LR: 0.001000


Epoch 21: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1333/1333 [09:56<00:00,  2.24it/s, loss=0.0485]



üìä Epoch 21: Val Acc: 89.08% | Val CER: 0.0174 | LR: 0.001000


Epoch 22: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1333/1333 [09:55<00:00,  2.24it/s, loss=0.0117]



üìä Epoch 22: Val Acc: 88.20% | Val CER: 0.0187 | LR: 0.001000


Epoch 23: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1333/1333 [09:55<00:00,  2.24it/s, loss=0.0226]



üìä Epoch 23: Val Acc: 88.36% | Val CER: 0.0183 | LR: 0.000500


Epoch 24: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1333/1333 [09:56<00:00,  2.23it/s, loss=0.0019]



üìä Epoch 24: Val Acc: 90.20% | Val CER: 0.0155 | LR: 0.000500
üî• New Best Model Saved! (90.20%)


Epoch 25: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1333/1333 [09:57<00:00,  2.23it/s, loss=0.0008]



üìä Epoch 25: Val Acc: 89.48% | Val CER: 0.0168 | LR: 0.000500


Epoch 26: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1333/1333 [09:56<00:00,  2.23it/s, loss=0.0008]



üìä Epoch 26: Val Acc: 90.04% | Val CER: 0.0154 | LR: 0.000500


Epoch 27:  82%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè | 1094/1333 [08:10<01:44,  2.28it/s, loss=0.0065]

In [None]:
import os
import random
import numpy as np
from PIL import Image
from tqdm import tqdm
import editdistance

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.models as models

# ==========================================
# 1. CONFIGURATION
# ==========================================
DATA_ROOT = '/content/CleanedDataset'
TRAIN_TXT = os.path.join(DATA_ROOT, 'rec_gt_train.txt')
VAL_TXT   = os.path.join(DATA_ROOT, 'rec_gt_test.txt')
CHARSET_FILE = os.path.join(DATA_ROOT, 'charset.txt')
CHECKPOINT_DIR = '/content/drive/MyDrive/pytorch_ocr_checkpoints'
BEST_MODEL_PATH = os.path.join(CHECKPOINT_DIR, 'resnet_best.pth') # Loading the ResNet one

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE = 64
NUM_EPOCHS = 40
IMG_H = 32
LEARNING_RATE = 0.0003

print(f"üöÄ Resuming on Device: {DEVICE}")

# Load Charset
if os.path.exists(CHARSET_FILE):
    with open(CHARSET_FILE, 'r', encoding='utf-8') as f:
        chars = [line.strip() for line in f if line.strip()]
    itos = ['<BLANK>'] + chars
    stoi = {c: i for i, c in enumerate(itos)}
    NUM_CLASSES = len(itos)
else:
    print("‚ùå Error: Charset missing. Run Step 1!")
    NUM_CLASSES = 100

# ==========================================
# 2. RE-DEFINE CLASSES
# ==========================================
class MalayalamDataset(Dataset):
    def __init__(self, listfile, root_dir, img_h=32, augment=False):
        self.samples = []
        self.root_dir = root_dir
        self.img_h = img_h
        self.augment = augment

        # Harder Augmentation (Same as before)
        self.transform = T.Compose([
            T.RandomAffine(degrees=3, translate=(0.05, 0.05), scale=(0.9, 1.1), shear=10),
            T.GaussianBlur(kernel_size=3, sigma=(0.1, 1.5)),
            T.ColorJitter(brightness=0.4, contrast=0.4)
        ])

        if os.path.exists(listfile):
            with open(listfile, 'r', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    if not line: continue
                    if '\t' in line: parts = line.split('\t')
                    else: parts = line.split(' ', 1)
                    if len(parts) < 2: continue
                    path = parts[0].replace("./", "")
                    self.samples.append((os.path.join(root_dir, path), parts[1]))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, text = self.samples[idx]
        try:
            img = Image.open(path).convert('L')
        except:
            return self.__getitem__((idx + 1) % len(self))

        w, h = img.size
        new_w = max(1, int(w * (self.img_h / h)))
        img = img.resize((new_w, self.img_h), Image.BILINEAR)
        img_arr = np.array(img).astype(np.float32) / 255.0
        img_arr = 1.0 - img_arr
        img_t = torch.from_numpy(img_arr).unsqueeze(0)

        if self.augment: img_t = self.transform(img_t)
        label = [stoi.get(c, 0) for c in text]
        return img_t, torch.tensor(label, dtype=torch.long), text

def pad_batch(batch):
    imgs, labels, texts = zip(*batch)
    widths = [img.shape[2] for img in imgs]
    max_w = max(widths)
    padded_imgs = torch.zeros(len(imgs), 1, IMG_H, max_w)
    for i, img in enumerate(imgs):
        w = img.shape[2]
        padded_imgs[i, :, :, :w] = img
    target_lengths = torch.tensor([len(l) for l in labels], dtype=torch.long)
    targets = torch.cat(labels)
    return padded_imgs, targets, target_lengths, texts

class CRNN_ResNet(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.cnn = nn.Sequential(*list(resnet.children())[:-2])
        self.rnn = nn.Sequential(
            nn.LSTM(512, 256, bidirectional=True, batch_first=True),
            nn.Linear(512, 256),
            nn.ELU(),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        features = self.cnn(x).squeeze(2).permute(0, 2, 1)
        output = self.rnn[0](features)[0]
        return self.rnn[3](self.rnn[2](self.rnn[1](output))).transpose(0, 1)

def decode(logits):
    probs = logits.softmax(2).argmax(2).transpose(0, 1)
    results = []
    for seq in probs:
        res = []
        prev = 0
        for idx in seq:
            idx = idx.item()
            if idx != 0 and idx != prev: res.append(itos[idx])
            prev = idx
        results.append("".join(res))
    return results

# ==========================================
# 3. RESUME LOGIC
# ==========================================
def resume_training():
    print("‚è≥ Loading Data...")
    train_ds = MalayalamDataset(TRAIN_TXT, DATA_ROOT, img_h=IMG_H, augment=True)
    val_ds = MalayalamDataset(VAL_TXT, DATA_ROOT, img_h=IMG_H)

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, collate_fn=pad_batch, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, collate_fn=pad_batch)

    model = CRNN_ResNet(NUM_CLASSES).to(DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
    criterion = nn.CTCLoss(blank=0, zero_infinity=True)

    # CHECKPOINT LOADING
    best_accuracy = 0.0
    start_epoch = 1

    if os.path.exists(BEST_MODEL_PATH):
        print(f"üîÑ Found checkpoint: {BEST_MODEL_PATH}")
        try:
            model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=DEVICE))
            print("‚úÖ Weights Loaded! Continuing training...")
            # Set baseline to beat (e.g. 50%) so we don't save a bad model immediately
            best_accuracy = 0.50
        except Exception as e:
            print(f"‚ö†Ô∏è Error loading weights: {e}")
            print("Starting fresh...")
    else:
        print("‚ö†Ô∏è No checkpoint found. Starting fresh.")

    print(f"üöÄ Training Loop Restarted...")

    for epoch in range(start_epoch, NUM_EPOCHS + 1):
        model.train()
        loss_accum = 0
        pbar = tqdm(train_loader, desc=f"Resume Epoch {epoch}")

        for imgs, targets, target_lens, texts in pbar:
            imgs, targets = imgs.to(DEVICE), targets.to(DEVICE)
            logits = model(imgs)
            input_lens = torch.full(size=(imgs.size(0),), fill_value=logits.size(0), dtype=torch.long).to(DEVICE)

            loss = criterion(logits.log_softmax(2), targets, input_lens, target_lens)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_accum += loss.item()
            pbar.set_postfix(loss=f"{loss.item():.4f}")

        # VALIDATION
        model.eval()
        val_correct = 0
        count = 0
        spot_true = ""
        spot_pred = ""

        with torch.no_grad():
            for i, (imgs, _, _, texts) in enumerate(val_loader):
                imgs = imgs.to(DEVICE)
                logits = model(imgs)
                preds = decode(logits)
                for j, (pred, true_text) in enumerate(zip(preds, texts)):
                    if pred == true_text: val_correct += 1
                    count += 1
                    if random.random() < 0.05:
                        spot_true = true_text
                        spot_pred = pred

        accuracy = val_correct / count
        print(f"\nüìä Epoch {epoch} - Loss: {loss_accum/len(train_loader):.4f} | Acc: {accuracy*100:.2f}%")

        print(f"   üîç Random Sample: True: {spot_true} | Pred: {spot_pred}")

        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), BEST_MODEL_PATH)
            print(f"üî• Saved New Best Model! ({accuracy*100:.2f}%)")

if __name__ == "__main__":
    resume_training()

üöÄ Resuming on Device: cuda
‚è≥ Loading Data...
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 44.7M/44.7M [00:00<00:00, 182MB/s]


üîÑ Found checkpoint: /content/drive/MyDrive/pytorch_ocr_checkpoints/resnet_best.pth
‚úÖ Weights Loaded! Continuing training...
üöÄ Training Loop Restarted...


Resume Epoch 1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1333/1333 [02:50<00:00,  7.83it/s, loss=0.2368]



üìä Epoch 1 - Loss: 0.2730 | Acc: 34.66%
   üîç Random Sample: True: ‡¥Æ‡¥≤‡¥ô‡µç‡¥ï‡¥∞ | Pred: ‡¥Æ‡¥∂‡µç‡¥¶‡¥∞


Resume Epoch 2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1333/1333 [02:45<00:00,  8.06it/s, loss=0.0849]



üìä Epoch 2 - Loss: 0.2436 | Acc: 35.23%
   üîç Random Sample: True: ‡¥µ‡¥æ‡¥≤‡µç‡¥Ø‡µÅ‡¥µ‡¥ø‡¥®‡µç‡¥±‡µÜ | Pred: ‡¥µ‡¥≤‡µç‡¥™‡µÅ‡¥ø‡¥®‡µç‡¥±‡µÜ


Resume Epoch 3: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1333/1333 [02:44<00:00,  8.09it/s, loss=0.1225]



üìä Epoch 3 - Loss: 0.2211 | Acc: 37.45%
   üîç Random Sample: True: ‡¥∏‡¥æ‡¥ô‡µç‡¥ï‡µá‡¥§‡¥ø‡¥ï‡¥µ‡¥ø‡¥¶‡µç‡¥Ø | Pred: ‡¥∏‡¥æ‡¥ô‡µç‡¥ï‡µá‡¥§‡¥ø‡¥µ‡¥ø‡¥¶‡µç‡¥Ø


Resume Epoch 4: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1333/1333 [02:44<00:00,  8.10it/s, loss=0.3013]



üìä Epoch 4 - Loss: 0.1995 | Acc: 40.01%
   üîç Random Sample: True: ‡¥™‡µÅ‡¥∞‡µã‡¥ó‡¥Æ‡¥ø‡¥ï‡µç‡¥ï‡¥µ‡µÜ | Pred: ‡¥™‡µÅ‡¥∞‡µã‡¥ß‡¥ø‡¥ï‡µç‡¥ï‡¥µ‡µÜ


Resume Epoch 5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1333/1333 [02:42<00:00,  8.23it/s, loss=0.2663]



üìä Epoch 5 - Loss: 0.1850 | Acc: 42.61%
   üîç Random Sample: True: ‡¥â‡¥¶‡¥æ‡¥π‡¥∞‡¥£‡¥Æ‡¥æ‡¥ï‡µç‡¥ï‡¥ø | Pred: ‡¥â‡¥π‡¥¶‡¥æ‡¥π‡¥∞‡¥£‡¥Æ‡¥æ‡¥ï‡¥ø


Resume Epoch 6: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1333/1333 [02:44<00:00,  8.10it/s, loss=0.1337]



üìä Epoch 6 - Loss: 0.1726 | Acc: 42.44%
   üîç Random Sample: True: ‡¥∏‡¥Ç‡¥∞‡¥ï‡µç‡¥∑‡¥£ | Pred: ‡¥∏‡¥Ç‡¥∞‡¥ï‡µç‡¥∑‡¥£


Resume Epoch 7: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1333/1333 [02:47<00:00,  7.98it/s, loss=0.4351]



üìä Epoch 7 - Loss: 0.1611 | Acc: 42.61%
   üîç Random Sample: True: ‡¥™‡¥ø‡¥±‡¥µ‡¥Ç | Pred: ‡¥™‡¥ø‡¥±‡¥µ‡¥Ç


Resume Epoch 8:  40%|‚ñà‚ñà‚ñà‚ñâ      | 531/1333 [01:06<01:29,  8.98it/s, loss=0.1810]