In [1]:

import os
import glob
import random
import numpy as np
from PIL import Image, ImageOps, ImageDraw
from tqdm import tqdm
import editdistance
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T

# ==========================================
# 2. CONFIGURATION
# ==========================================
# Use the GPU when PyTorch can see one; otherwise everything runs on the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"üî• Using Device: {DEVICE}")

# DATASET PATHS
# The label files live at fixed Kaggle paths. glob() is used only as an
# existence check: the patterns contain no wildcards, so the former
# `recursive=True` argument had no effect and has been dropped.
search = glob.glob("/kaggle/input/malayalam-cleaned-dataset/CleanedDataset/rec_gt_train.txt")
if not search:
    raise FileNotFoundError("‚ùå Dataset not found!")
TRAIN_TXT = search[0]
# Relative image paths found in the label files are joined onto this directory.
DATA_ROOT = os.path.dirname(TRAIN_TXT)

val_search = glob.glob("/kaggle/input/malayalam-cleaned-dataset/CleanedDataset/rec_gt_test.txt")
if val_search:
    VAL_TXT = val_search[0]
else:
    # Keep the original fallback so the script still runs, but make it
    # visible: metrics computed on the training list are not a real
    # validation score.
    VAL_TXT = TRAIN_TXT
    print("WARNING: rec_gt_test.txt not found; validating on the TRAINING list, "
          "so reported accuracy/CER will be optimistic.")

# Directory that receives the saved model weights.
CHECKPOINT_DIR = '/kaggle/working/checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# OPTIMIZED HYPERPARAMETERS
BATCH_SIZE = 256        # Saturation for T4 GPUs
NUM_EPOCHS = 40         # OneCycleLR needs defined epochs
MAX_LR = 0.003          # Aggressive start for OneCycle
IMG_H = 32              # Height (pixels) every cached image is resized to

# ==========================================
# 3. ROBUST AUGMENTATION
# ==========================================
class RealWorldAugment:
    """Randomly distort a PIL image with a mild shear/rotation and white streaks.

    Each call skips augmentation entirely with probability ``1 - prob`` and
    returns the input object unchanged. Otherwise the image may be sheared and
    rotated, may receive 1-3 short horizontal white lines in its top rows, and
    is always returned converted to mode ``"L"``.
    """

    def __init__(self, prob=0.5):
        # Probability that augmentation is attempted on a given call.
        self.prob = prob

    def __call__(self, img):
        """Return an augmented image, or ``img`` itself when the call is skipped."""
        if random.random() > self.prob:
            return img

        # 1. Geometric distortion: horizontal shear followed by a small rotation.
        apply_geometry = random.random() > 0.5
        if apply_geometry:
            angle = random.uniform(-2, 2)
            shear = random.uniform(-10, 10)
            affine_coeffs = (1, shear / 100, 0, 0, 1, 0)
            sheared = img.transform(
                img.size, Image.AFFINE, affine_coeffs, resample=Image.BILINEAR
            )
            # NOTE(review): neither transform() nor rotate() is given a fill
            # colour here, so whatever Pillow uses by default fills the
            # uncovered corners - confirm that matches the page background.
            img = sheared.rotate(angle)

        # 2. Line noise: short white horizontal streaks within rows 0-3.
        canvas = img.convert("RGB")
        pen = ImageDraw.Draw(canvas)
        width = canvas.size[0]
        if random.random() > 0.6:
            remaining = random.randint(1, 3)
            while remaining > 0:
                x_start = random.randint(0, width)
                row = random.randint(0, 3)
                x_end = x_start + random.randint(5, 20)
                pen.line([(x_start, row), (x_end, row)], fill=(255, 255, 255), width=1)
                remaining -= 1

        return canvas.convert("L")

# ==========================================
# 4. DATASET (RAM CACHED)
# ==========================================
def build_charset():
    """Collect every character used in the training labels and index them.

    Reads ``TRAIN_TXT`` line by line. A line containing a tab is split on
    tabs, any other line on its first space; the second field is taken as the
    label. Lines that yield fewer than two fields are ignored.

    Returns:
        A tuple ``(itos, stoi)``. ``itos`` is ``['<BLANK>']`` followed by the
        sorted unique label characters; ``stoi`` maps every entry of ``itos``
        to its index.
    """
    print("üî§ Building Charset...")
    seen = set()
    with open(TRAIN_TXT, 'r', encoding='utf-8') as handle:
        for raw in handle:
            stripped = raw.strip()
            if '\t' in raw:
                fields = stripped.split('\t')
            else:
                fields = stripped.split(' ', 1)
            if len(fields) < 2:
                continue
            # Iterating the label string yields its individual characters.
            seen.update(fields[1])

    # Index 0 is reserved for the blank symbol; real characters follow in
    # sorted order.
    itos = ['<BLANK>']
    itos.extend(sorted(seen))
    stoi = dict(zip(itos, range(len(itos))))
    print(f"‚úÖ Vocab Size: {len(itos)}")
    return itos, stoi

# Vocabulary built once at import time from TRAIN_TXT: ITOS is the
# index -> character list (index 0 is '<BLANK>'), STOI the reverse mapping.
ITOS, STOI = build_charset()
# Number of entries in ITOS; passed to CustomCRNN as its class count.
NUM_CLASSES = len(ITOS)

class MalayalamDataset(Dataset):
    """Line-image dataset that decodes every image once and keeps it in RAM.

    The list file holds one sample per line as ``<relative path><sep><label>``
    where the separator is a tab if the line contains one, otherwise the first
    space. Every readable image is converted to grayscale, resized to height
    ``img_h`` (width scaled by the same factor) and cached as a uint8 array.
    Entries whose image is missing or cannot be decoded are skipped and the
    number of skipped entries is reported.

    Args:
        listfile: path of the UTF-8 label file.
        root_dir: directory the relative image paths are joined onto.
        img_h: target image height in pixels.
        augment: when true, ``__getitem__`` applies ``RealWorldAugment`` to
            the image and ``ColorJitter`` to the tensor.
    """

    def __init__(self, listfile, root_dir, img_h=32, augment=False):
        self.samples = []          # label text per cached image
        self.cached_images = []    # uint8 arrays, aligned with self.samples
        self.img_h = img_h
        self.augment = augment
        self.augmentor = RealWorldAugment(prob=0.8)
        self.tensor_aug = T.Compose([T.ColorJitter(brightness=0.4, contrast=0.4)])

        temp_samples = []
        with open(listfile, 'r', encoding='utf-8') as f:
            for line in f:
                parsed = self._parse_line(line)
                if parsed is None:
                    continue
                rel_path, text = parsed
                temp_samples.append((os.path.join(root_dir, rel_path), text))

        print(f"üöÄ Loading {len(temp_samples)} images into RAM...")
        skipped = 0
        first_error = None
        for path, text in tqdm(temp_samples):
            try:
                # The context manager releases the file handle as soon as the
                # pixels are decoded instead of leaving it to the GC.
                with Image.open(path) as src:
                    img = src.convert('L')
                w, h = img.size
                new_w = max(1, int(w * (self.img_h / h)))
                img = img.resize((new_w, self.img_h), Image.BILINEAR)
                pixels = np.array(img, dtype=np.uint8)
            except Exception as exc:
                # Best-effort loading: a missing or corrupt image must not
                # abort the run. Unlike a bare ``except:`` this still lets
                # KeyboardInterrupt/SystemExit through.
                skipped += 1
                if first_error is None:
                    first_error = f"{path}: {exc!r}"
                continue
            self.cached_images.append(pixels)
            self.samples.append(text)

        if skipped:
            print(f"WARNING: skipped {skipped} of {len(temp_samples)} listed images "
                  f"(missing or unreadable); first problem: {first_error}")

    @staticmethod
    def _parse_line(line):
        """Split one list-file line into ``(relative_path, label)``.

        Returns ``None`` when the line has no label field. Only a literal
        leading ``./`` or ``/`` is removed from the path; ``str.lstrip('./')``
        would strip any run of dots and slashes and so corrupt paths such as
        ``../x.png`` or ``.hidden/x.png``.
        """
        stripped = line.strip()
        parts = stripped.split('\t') if '\t' in line else stripped.split(' ', 1)
        if len(parts) < 2:
            return None
        rel_path = parts[0].strip()
        while True:
            if rel_path.startswith('./'):
                rel_path = rel_path[2:]
            elif rel_path.startswith('/'):
                rel_path = rel_path[1:]
            else:
                break
        return rel_path, parts[1]

    def __len__(self):
        """Return the number of images that were successfully cached."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label, text)`` for sample ``idx``.

        ``image`` is a float32 tensor with a leading channel axis whose values
        are ``1 - pixel / 255`` (so white paper becomes 0). ``label`` is a
        long tensor of ``STOI`` indices; characters absent from ``STOI`` are
        dropped from the label but remain in ``text``.
        """
        text = self.samples[idx]
        pixels = self.cached_images[idx]

        if self.augment:
            # Only the augmentor needs a PIL image; the plain path uses the
            # cached array directly and avoids a needless round-trip.
            pixels = np.array(self.augmentor(Image.fromarray(pixels)))

        img_arr = 1.0 - pixels.astype(np.float32) / 255.0
        img_t = torch.from_numpy(img_arr).unsqueeze(0)

        if self.augment:
            img_t = self.tensor_aug(img_t)

        label = [STOI[c] for c in text if c in STOI]
        return img_t, torch.tensor(label, dtype=torch.long), text

def pad_batch(batch):
    """Collate ``(image, label, text)`` samples into one padded batch.

    Images are copied into a zero-filled tensor as wide as the widest image in
    the batch, so narrower images are right-padded with zeros. The channel
    count and height are taken from the images themselves rather than from a
    module-level constant, so datasets built with a non-default ``img_h``
    collate correctly. All images in a batch must share channels and height.

    Args:
        batch: non-empty sequence of ``(image, label, text)`` where ``image``
            is a ``[channels, height, width]`` tensor and ``label`` a 1-D
            tensor.

    Returns:
        ``(padded_imgs, targets, target_lengths, texts)``: the padded
        ``[batch, channels, height, max_width]`` images, all labels
        concatenated into one 1-D tensor, the length of each label as a long
        tensor, and the tuple of raw texts.
    """
    imgs, labels, texts = zip(*batch)
    channels, height = imgs[0].shape[0], imgs[0].shape[1]
    max_w = max(img.shape[2] for img in imgs)
    padded_imgs = torch.zeros(len(imgs), channels, height, max_w, dtype=imgs[0].dtype)
    for i, img in enumerate(imgs):
        padded_imgs[i, :, :, :img.shape[2]] = img
    targets = torch.cat(labels)
    target_lengths = torch.tensor([len(l) for l in labels], dtype=torch.long)
    return padded_imgs, targets, target_lengths, texts

# ==========================================
# 5. MODEL: HIGH-RES CUSTOM CRNN
# ==========================================
class ResNetBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm layers plus an additive shortcut.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both convolutions.
        stride: stride of the first convolution (int or ``(h, w)`` tuple).
        downsample: optional module applied to the input to build the
            shortcut; when ``None`` the input itself is added back.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv2d(in_channels, out_channels, stride=stride, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        shortcut = x if self.downsample is None else self.downsample(x)
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        residual += shortcut
        return self.relu(residual)

class CustomCRNN(nn.Module):
    """CRNN: residual CNN encoder, 2-layer BiLSTM and a per-timestep classifier.

    The network takes single-channel images and returns unnormalised class
    scores shaped ``[batch, time, num_classes]``. Height is halved five times
    (max-pool, layer2, layer3, layer4, last_conv) and the remaining height
    axis is dropped with ``squeeze(2)``, so the input height must reduce to
    exactly 1 - an input height of 32 does. Width is only halved twice
    (max-pool and layer2); ``last_conv`` then trims one more column.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Running channel count; _make_layer reads and updates it.
        self.in_channels = 64

        # Stem: 3x3 conv at full resolution, then a 2x2 max-pool -> H/2, W/2.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # High-res configuration: after layer2 only the height keeps shrinking.
        stage_plan = (
            (64, 1),        # layer1: stays at H/2, W/2
            (128, 2),       # layer2: H/4, W/4
            (256, (2, 1)),  # layer3: H/8, W/4
            (512, (2, 1)),  # layer4: H/16, W/4
        )
        for number, (channels, stride) in enumerate(stage_plan, start=1):
            setattr(self, f"layer{number}", self._make_layer(channels, 2, stride=stride))

        # Unpadded 2x2 conv: H/32 (1 for a 32-pixel input), width shrinks by one.
        self.last_conv = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=2, stride=(2, 1), padding=0),
            nn.BatchNorm2d(512),
            nn.ReLU(),
        )

        # Regularized RNN head. It is stored as a Sequential only to group the
        # modules; forward() unpacks it because the LSTM returns a tuple.
        recurrent = nn.LSTM(
            512, 256, bidirectional=True, batch_first=True, num_layers=2, dropout=0.5
        )
        self.rnn = nn.Sequential(
            recurrent,
            nn.Linear(512, 256),
            nn.ELU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes),
        )

    def _make_layer(self, out_channels, blocks, stride=1):
        """Build a stage of ``blocks`` ResNetBlocks; only the first one strides.

        A 1x1 conv + batch-norm shortcut is created when the first block
        changes the stride or the channel count. Updates ``self.in_channels``.
        """
        needs_projection = stride != 1 or self.in_channels != out_channels
        downsample = None
        if needs_projection:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        first = ResNetBlock(self.in_channels, out_channels, stride, downsample)
        self.in_channels = out_channels
        rest = [ResNetBlock(out_channels, out_channels) for _ in range(1, blocks)]
        return nn.Sequential(first, *rest)

    def forward(self, x):
        """Map a batch of images to ``[batch, time, num_classes]`` scores."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        x = self.last_conv(x)

        # [B, C, 1, W'] -> [B, W', C]: the width axis becomes the time axis.
        x = x.squeeze(2).permute(0, 2, 1)

        lstm, *head = self.rnn
        x, _ = lstm(x)          # hidden/cell states are not needed
        for layer in head:      # Linear -> ELU -> Dropout -> Linear
            x = layer(x)
        return x

# ==========================================
# 6. TRAINING LOOP
# ==========================================
def decode(logits, itos=None):
    """Greedy CTC decoding: best class per step, collapse repeats, drop blanks.

    Args:
        logits: score tensor indexed ``[time, batch, class]``. Raw logits,
            log-probabilities or probabilities all give the same result
            because only the per-step argmax is used.
        itos: optional index -> character list; defaults to the module-level
            ``ITOS``. Index 0 is treated as the CTC blank.

    Returns:
        A list with one decoded string per batch element.
    """
    if itos is None:
        itos = ITOS

    # argmax directly on the scores (softmax is monotonic, so it cannot change
    # the winner) and move everything to Python in a single transfer instead
    # of calling .item() once per timestep.
    best_paths = logits.argmax(2).transpose(0, 1).tolist()

    results = []
    for path in best_paths:
        chars = []
        prev = 0
        for idx in path:
            if idx != 0 and idx != prev:
                chars.append(itos[idx])
            prev = idx
        results.append("".join(chars))
    return results

def run_training():
    """Train a fresh CustomCRNN with AdamW + OneCycleLR, keeping the best weights.

    Loads the training and validation lists into RAM, trains for
    ``NUM_EPOCHS`` epochs with CTC loss, evaluates exact-match accuracy and
    mean per-sample character error rate after every epoch, and writes the
    model's state dict to ``CHECKPOINT_DIR/best_custom_model.pth`` whenever
    accuracy improves. The saved state dict is always the unwrapped model's,
    even when ``nn.DataParallel`` is in use.

    Raises:
        RuntimeError: if either dataset ends up with no loadable image.
    """
    model = CustomCRNN(NUM_CLASSES).to(DEVICE)
    multi_gpu = torch.cuda.device_count() > 1
    if multi_gpu:
        print(f"üî• Multi-GPU: {torch.cuda.device_count()} GPUs")
        model = nn.DataParallel(model)

    print("\nüì¶ Loading Data...")
    train_ds = MalayalamDataset(TRAIN_TXT, DATA_ROOT, img_h=IMG_H, augment=True)
    val_ds = MalayalamDataset(VAL_TXT, DATA_ROOT, img_h=IMG_H, augment=False)

    # The dataset silently drops unreadable images. Fail here with a clear
    # message instead of dividing by zero after a full epoch of training.
    if len(train_ds) == 0 or len(val_ds) == 0:
        raise RuntimeError(
            f"No usable images: train={len(train_ds)}, val={len(val_ds)}. "
            "Check DATA_ROOT and the label files."
        )

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, collate_fn=pad_batch, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, collate_fn=pad_batch, num_workers=4)

    # AdamW + OneCycleLR; the scheduler is stepped once per batch.
    optimizer = optim.AdamW(model.parameters(), lr=MAX_LR, weight_decay=1e-2)
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=MAX_LR, epochs=NUM_EPOCHS,
        steps_per_epoch=len(train_loader), pct_start=0.3, div_factor=25
    )

    criterion = nn.CTCLoss(blank=0, zero_infinity=True)
    best_acc = 0.0

    print("üöÄ Starting Optimized Custom Training...")

    for epoch in range(1, NUM_EPOCHS + 1):
        model.train()
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}")

        for imgs, targets, target_lens, _ in pbar:
            imgs, targets = imgs.to(DEVICE), targets.to(DEVICE)
            optimizer.zero_grad()

            logits = model(imgs)                 # [Batch, Time, Class]
            logits_ctc = logits.transpose(0, 1)  # [Time, Batch, Class]

            # NOTE(review): every sample is given the full padded sequence
            # length, so zero-padded columns count as real input for CTC.
            input_lens = torch.full(
                size=(imgs.size(0),), fill_value=logits.size(1),
                dtype=torch.long, device=DEVICE,
            )
            loss = criterion(logits_ctc.log_softmax(2), targets, input_lens, target_lens)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
            optimizer.step()
            scheduler.step()

            pbar.set_postfix(loss=f"{loss.item():.4f}")

        # VALIDATION
        model.eval()
        val_correct = 0
        count = 0
        val_cer_sum = 0
        with torch.no_grad():
            for imgs, _, _, texts in val_loader:
                imgs = imgs.to(DEVICE)
                logits = model(imgs)
                preds = decode(logits.transpose(0, 1))

                for pred, true_text in zip(preds, texts):
                    if pred == true_text:
                        val_correct += 1
                    dist = editdistance.eval(pred, true_text)
                    # max(1, ...) keeps an empty ground truth from dividing by zero.
                    val_cer_sum += dist / max(1, len(true_text))
                    count += 1

        acc = val_correct / count
        cer = val_cer_sum / count
        lr = optimizer.param_groups[0]['lr']

        print(f"\nüìä Epoch {epoch}: Acc: {acc*100:.2f}% | CER: {cer:.4f} | LR: {lr:.6f}")

        if acc > best_acc:
            best_acc = acc
            # Save without the DataParallel "module." prefix.
            state = model.module.state_dict() if multi_gpu else model.state_dict()
            torch.save(state, os.path.join(CHECKPOINT_DIR, 'best_custom_model.pth'))
            print("üî• New Best Model Saved!")


üî• Using Device: cuda
üî§ Building Charset...
‚úÖ Vocab Size: 96


In [2]:
# ==========================================
# ⚡ FINE-TUNING "SQUEEZE" SCRIPT
# ==========================================
# HYPERPARAMETERS FOR FINE-TUNING
FT_EPOCHS = 15  # Number of fine-tuning epochs
FT_LR = 1e-5  # Very low, constant learning rate
# Weights to start from (read by run_finetuning).
FT_CHECKPOINT = '/kaggle/input/custom-ocr-model/pytorch/default/1/best_custom_model.pth'
# Where run_finetuning writes the weights whenever accuracy beats its best so far.
FT_OUTPUT = '/kaggle/working/checkpoints/best_custom_model_finetuned.pth'

def _ft_strip_module_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading ``module.`` removed from keys."""
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _ft_evaluate(model, loader):
    """Return ``(exact-match accuracy, mean per-sample CER)`` of ``model`` on ``loader``."""
    model.eval()
    correct = 0
    total = 0
    cer_sum = 0.0
    with torch.no_grad():
        for imgs, _, _, texts in loader:
            logits = model(imgs.to(DEVICE))
            for pred, true_text in zip(decode(logits.transpose(0, 1)), texts):
                if pred == true_text:
                    correct += 1
                # max(1, ...) keeps an empty ground truth from dividing by zero.
                cer_sum += editdistance.eval(pred, true_text) / max(1, len(true_text))
                total += 1
    return correct / total, cer_sum / total


def run_finetuning(baseline_acc=0.9302):
    """Fine-tune the checkpoint at ``FT_CHECKPOINT`` with a constant low LR.

    Trains for ``FT_EPOCHS`` epochs with AdamW at ``FT_LR`` (no scheduler),
    validates after every epoch and writes the unwrapped model's state dict
    to ``FT_OUTPUT`` whenever accuracy exceeds the best seen so far.

    Args:
        baseline_acc: accuracy (0-1) a fine-tuned model must beat before it is
            saved. Defaults to the previously hard-coded 0.9302. Pass ``None``
            to measure the loaded checkpoint on the validation set and use
            that figure instead.

    Raises:
        RuntimeError: if either dataset ends up with no loadable image, or if
            the checkpoint's keys/shapes do not match the model.
    """
    print(f"‚ôªÔ∏è Loading Best Model from: {FT_CHECKPOINT}")

    # 1. Initialize model and load weights BEFORE any DataParallel wrapping,
    #    so one code path handles checkpoints saved with or without the
    #    "module." prefix. map_location lets a checkpoint saved on GPU load on
    #    a CPU-only host.
    # NOTE: torch.load unpickles the file - only load checkpoints you trust.
    model = CustomCRNN(NUM_CLASSES).to(DEVICE)
    state_dict = torch.load(FT_CHECKPOINT, map_location=DEVICE)
    model.load_state_dict(_ft_strip_module_prefix(state_dict))

    multi_gpu = torch.cuda.device_count() > 1
    if multi_gpu:
        model = nn.DataParallel(model)

    print("‚úÖ Weights Loaded! Starting Fine-Tuning...")

    # 2. Data loaders (the datasets are rebuilt here, i.e. images are re-read
    #    into RAM; nothing is reused from run_training).
    train_ds = MalayalamDataset(TRAIN_TXT, DATA_ROOT, img_h=IMG_H, augment=True)
    val_ds = MalayalamDataset(VAL_TXT, DATA_ROOT, img_h=IMG_H, augment=False)

    # The dataset silently drops unreadable images. Fail here with a clear
    # message instead of dividing by zero during validation.
    if len(train_ds) == 0 or len(val_ds) == 0:
        raise RuntimeError(
            f"No usable images: train={len(train_ds)}, val={len(val_ds)}. "
            "Check DATA_ROOT and the label files."
        )

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, collate_fn=pad_batch, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, collate_fn=pad_batch, num_workers=4)

    # 3. Optimizer (constant low LR)
    optimizer = optim.AdamW(model.parameters(), lr=FT_LR, weight_decay=1e-2)
    criterion = nn.CTCLoss(blank=0, zero_infinity=True)

    # 4. Threshold a fine-tuned model has to beat before it is saved.
    if baseline_acc is None:
        baseline_acc, baseline_cer = _ft_evaluate(model, val_loader)
        print(f"Baseline before fine-tuning: Acc: {baseline_acc*100:.2f}% | CER: {baseline_cer:.4f}")
    best_acc = baseline_acc

    # The output directory may not exist when this cell runs on its own.
    os.makedirs(os.path.dirname(FT_OUTPUT) or '.', exist_ok=True)

    for epoch in range(1, FT_EPOCHS + 1):
        model.train()
        pbar = tqdm(train_loader, desc=f"Squeeze Epoch {epoch}")

        for imgs, targets, target_lens, _ in pbar:
            imgs, targets = imgs.to(DEVICE), targets.to(DEVICE)
            optimizer.zero_grad()

            logits = model(imgs)                 # [Batch, Time, Class]
            logits_ctc = logits.transpose(0, 1)  # [Time, Batch, Class]
            # NOTE(review): every sample is given the full padded sequence
            # length, so zero-padded columns count as real input for CTC.
            input_lens = torch.full(
                size=(imgs.size(0),), fill_value=logits.size(1),
                dtype=torch.long, device=DEVICE,
            )

            loss = criterion(logits_ctc.log_softmax(2), targets, input_lens, target_lens)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
            optimizer.step()

            pbar.set_postfix(loss=f"{loss.item():.4f}")

        # VALIDATION
        acc, cer = _ft_evaluate(model, val_loader)

        print(f"\nüìä Squeeze {epoch}: Acc: {acc*100:.2f}% | CER: {cer:.4f}")

        if acc > best_acc:
            best_acc = acc
            # Save without the DataParallel "module." prefix.
            state = model.module.state_dict() if multi_gpu else model.state_dict()
            torch.save(state, FT_OUTPUT)
            print(f"üî• Squeezed New Best: {acc*100:.2f}% (Saved!)")

# Entry point: runs the fine-tuning pass when this code executes as the main
# module. Only run_finetuning() is launched here.
if __name__ == "__main__":
    run_finetuning()

♻️ Loading Best Model from: /kaggle/input/custom-ocr-model/pytorch/default/1/best_custom_model.pth
✅ Weights Loaded! Starting Fine-Tuning...
🚀 Loading 85270 images into RAM...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 85270/85270 [12:09<00:00, 116.88it/s]


üöÄ Loading 19635 images into RAM...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19635/19635 [02:48<00:00, 116.61it/s]
Squeeze Epoch 1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 334/334 [06:13<00:00,  1.12s/it, loss=0.0005]



üìä Squeeze 1: Acc: 92.94% | CER: 0.0114


Squeeze Epoch 2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 334/334 [06:47<00:00,  1.22s/it, loss=0.0055]



üìä Squeeze 2: Acc: 92.99% | CER: 0.0114


Squeeze Epoch 3: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 334/334 [06:35<00:00,  1.19s/it, loss=0.0030]



üìä Squeeze 3: Acc: 92.92% | CER: 0.0115


Squeeze Epoch 4: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 334/334 [06:30<00:00,  1.17s/it, loss=0.0013]



üìä Squeeze 4: Acc: 92.96% | CER: 0.0114


Squeeze Epoch 5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 334/334 [06:30<00:00,  1.17s/it, loss=0.0015]



üìä Squeeze 5: Acc: 92.96% | CER: 0.0114


Squeeze Epoch 6: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 334/334 [06:30<00:00,  1.17s/it, loss=0.0001]



üìä Squeeze 6: Acc: 92.93% | CER: 0.0114


Squeeze Epoch 7: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 334/334 [06:27<00:00,  1.16s/it, loss=0.0047]



üìä Squeeze 7: Acc: 93.00% | CER: 0.0113


Squeeze Epoch 8: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 334/334 [06:29<00:00,  1.17s/it, loss=0.0187]



üìä Squeeze 8: Acc: 92.96% | CER: 0.0114


Squeeze Epoch 9: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 334/334 [06:27<00:00,  1.16s/it, loss=0.0008]



üìä Squeeze 9: Acc: 92.97% | CER: 0.0114


Squeeze Epoch 10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 334/334 [06:41<00:00,  1.20s/it, loss=0.0003]



üìä Squeeze 10: Acc: 93.01% | CER: 0.0113


Squeeze Epoch 11: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 334/334 [06:42<00:00,  1.20s/it, loss=0.0045]



üìä Squeeze 11: Acc: 92.97% | CER: 0.0114


Squeeze Epoch 12: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 334/334 [06:28<00:00,  1.16s/it, loss=0.0012]



üìä Squeeze 12: Acc: 92.90% | CER: 0.0115


Squeeze Epoch 13: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 334/334 [06:30<00:00,  1.17s/it, loss=0.0008]



üìä Squeeze 13: Acc: 92.98% | CER: 0.0114


Squeeze Epoch 14: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 334/334 [06:30<00:00,  1.17s/it, loss=0.0017]



üìä Squeeze 14: Acc: 92.96% | CER: 0.0113


Squeeze Epoch 15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 334/334 [06:26<00:00,  1.16s/it, loss=0.0504]



üìä Squeeze 15: Acc: 92.96% | CER: 0.0113
