# 🚀 CRNN OCR - Algerian License Plates

**Setup:** Enable the GPU accelerator (T4), attach your synthetic and real plate datasets, then update `SYNTHETIC_DIR` and `REAL_DIR` in Cell 2.

In [None]:
!pip install -q albumentations

In [None]:
import os, json, random, numpy as np
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import cv2, albumentations as A

# Seed every RNG used by the notebook (Python, NumPy, PyTorch).
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# ⚠️ UPDATE THESE PATHS!
SYNTHETIC_DIR = '/kaggle/input/YOUR-SYNTHETIC-DATASET'
REAL_DIR = '/kaggle/input/YOUR-REAL-DATASET'
OUTPUT_DIR = '/kaggle/working/checkpoints'
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Network input size and number of output classes.
IMG_HEIGHT = 64
IMG_WIDTH = 200
NUM_CLASSES = 11

# Optimisation hyper-parameters.
BATCH_SIZE = 64
NUM_EPOCHS = 100
LR = 0.001

# Prefer the GPU when one is visible to PyTorch.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f"Device: {DEVICE}")
if DEVICE.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Data loading
def scan_plates(directory):
    """Recursively collect labelled plate images under *directory*.

    The label is taken from the file name: the part of the stem before the
    first underscore. Only stems made of 10 or 11 ASCII digits are accepted;
    accepted labels are left-padded with zeros to 11 characters.

    Args:
        directory: Root folder searched recursively for ``*.jpg`` files.

    Returns:
        A pair ``(images, labels)`` of equally long lists: image paths as
        strings and their 11-character digit labels.
    """
    images, labels = [], []
    # Sort the paths: rglob order depends on the filesystem, which would
    # make the seeded sampling/shuffling downstream non-reproducible.
    for p in sorted(Path(directory).rglob('*.jpg')):
        raw = p.stem.split('_')[0]
        # Validate BEFORE padding. Padding first makes every short (or even
        # empty) stem 11 characters long, so the length check never rejects it.
        # isascii() keeps out non-ASCII digits that str.isdigit() accepts.
        if 10 <= len(raw) <= 11 and raw.isascii() and raw.isdigit():
            images.append(str(p))
            labels.append(raw.zfill(11))
    return images, labels

def load_data():
    """Scan both dataset folders, mix them, and split into train/val/test.

    Synthetic images are capped at four times the number of real images
    (an 80/20 synthetic/real mix at most); the surplus is dropped by random
    sampling. The mixed samples are shuffled and split 70 % / 15 % / rest.

    Returns:
        A dict with keys ``'train'``, ``'val'`` and ``'test'``. Each value is
        a tuple ``(paths, labels, sources)`` of equally long lists, where
        ``sources`` holds ``'synthetic'`` or ``'real'`` per sample.

    Raises:
        ValueError: If no real images were found. The synthetic cap is then
            zero, so there would be nothing to train on.
    """
    synth_i, synth_l = scan_plates(SYNTHETIC_DIR)
    real_i, real_l = scan_plates(REAL_DIR)
    print(f"Synthetic: {len(synth_i):,}, Real: {len(real_i):,}")

    if not real_i:
        # Without this guard the unpacking of an empty zip below fails with
        # an unhelpful "not enough values to unpack" ValueError.
        raise ValueError(
            f"No real plate images found under {REAL_DIR!r} "
            f"({len(synth_i)} synthetic under {SYNTHETIC_DIR!r}). "
            "Synthetic data is capped at 4x the real count, so at least one "
            "real image is required - check the dataset paths."
        )

    # 80/20 synthetic/real means at most 4 synthetic per real image.
    # Integer arithmetic avoids float rounding in `n * 0.8 / 0.2`.
    needed = min(len(real_i) * 4, len(synth_i))
    if needed < len(synth_i):
        idx = random.sample(range(len(synth_i)), needed)
        synth_i = [synth_i[i] for i in idx]
        synth_l = [synth_l[i] for i in idx]

    all_i = synth_i + real_i
    all_l = synth_l + real_l
    all_s = ['synthetic'] * len(synth_i) + ['real'] * len(real_i)

    # Shuffle the three parallel lists together so they stay aligned.
    combined = list(zip(all_i, all_l, all_s))
    random.shuffle(combined)
    all_i, all_l, all_s = (list(column) for column in zip(*combined))

    t_end = int(len(all_i) * 0.7)
    v_end = t_end + int(len(all_i) * 0.15)

    def _part(start, stop):
        return all_i[start:stop], all_l[start:stop], all_s[start:stop]

    return {
        'train': _part(0, t_end),
        'val': _part(t_end, v_end),
        # The held-out remainder used to be dropped silently.
        'test': _part(v_end, len(all_i)),
    }

# Augmentation pipelines. Every pipeline ends by resizing the image to the
# network input size (IMG_HEIGHT x IMG_WIDTH).

# Strongest pipeline: rotation, perspective warp, brightness/contrast jitter,
# one of two blurs, and Gaussian noise.
heavy_aug = A.Compose([
    A.Rotate(limit=15, p=0.7),
    A.Perspective(scale=(0.02, 0.05), p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.8),
    A.OneOf(
        [
            A.GaussianBlur(blur_limit=(3, 7), p=1),
            A.MotionBlur(blur_limit=5, p=1),
        ],
        p=0.6,
    ),
    A.GaussNoise(p=0.5),
    A.Resize(height=IMG_HEIGHT, width=IMG_WIDTH),
])

# Milder pipeline: smaller rotation and brightness/contrast jitter only.
medium_aug = A.Compose([
    A.Rotate(limit=10, p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.6),
    A.Resize(height=IMG_HEIGHT, width=IMG_WIDTH),
])

# Evaluation pipeline: resize only.
no_aug = A.Compose([
    A.Resize(height=IMG_HEIGHT, width=IMG_WIDTH),
])

print("✅ Data functions ready")

In [None]:
# Dataset
class LPDataset(Dataset):
    """License-plate dataset yielding ``(image, label, label_length)``.

    ``paths``, ``labels`` and ``sources`` are parallel sequences. In training
    mode, samples whose source is ``'real'`` go through ``heavy_aug`` and all
    others through ``medium_aug``; otherwise only ``no_aug`` is applied.
    """

    def __init__(self, paths, labels, sources, train=True):
        self.paths = paths
        self.labels = labels
        self.sources = sources
        self.train = train
        # Digit character -> class index ('0' -> 0, ..., '9' -> 9).
        self.c2i = {str(digit): digit for digit in range(10)}

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, i):
        bgr = cv2.imread(self.paths[i])
        if bgr is not None:
            img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        else:
            # Unreadable file: fall back to an all-black image.
            img = np.zeros((IMG_HEIGHT, IMG_WIDTH, 3), dtype=np.uint8)

        # Choose the augmentation pipeline for this sample.
        if not self.train:
            pipeline = no_aug
        elif self.sources[i] == 'real':
            pipeline = heavy_aug
        else:
            pipeline = medium_aug
        img = pipeline(image=img)['image']

        # Scale to [0, 1] and move channels first (HWC -> CHW).
        scaled = img.astype(np.float32) / 255.0
        tensor = torch.from_numpy(scaled).permute(2, 0, 1)

        encoded = [self.c2i[ch] for ch in self.labels[i]]
        return tensor, torch.LongTensor(encoded), len(encoded)

def collate_fn(batch):
    """Merge ``(image, label, length)`` samples into batch tensors.

    Images are stacked along a new leading dimension, the per-sample label
    tensors are concatenated into one 1-D tensor, and the lengths become a
    ``LongTensor`` with one entry per sample.
    """
    images, targets, lengths = map(list, zip(*batch))
    stacked = torch.stack(images)
    flat_targets = torch.cat(targets)
    return stacked, flat_targets, torch.LongTensor(lengths)

print("✅ Dataset ready")

In [None]:
# CRNN Model
class CRNN(nn.Module):
    """Convolutional-recurrent network emitting per-column log-probabilities.

    A six-stage conv stack feeds a 2-layer bidirectional LSTM followed by a
    linear classifier. ``forward`` returns log-softmax scores with the
    sequence dimension first, the layout ``nn.CTCLoss`` consumes.

    Args:
        h: Input image height; the LSTM input size is ``512 * (h // 16)``.
        w: Input image width (accepted but not used by the layers).
        nc: Number of output classes.
        hs: LSTM hidden size per direction.
    """

    def __init__(self, h=64, w=200, nc=11, hs=256):
        super().__init__()
        # (in_channels, out_channels, pooling window or None) per conv stage.
        # The last two pools use (2, 1), halving height but keeping width.
        stages = [
            (3, 64, 2),
            (64, 128, 2),
            (128, 256, None),
            (256, 256, (2, 1)),
            (256, 512, None),
            (512, 512, (2, 1)),
        ]
        layers = []
        for c_in, c_out, pool in stages:
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU(inplace=True))
            if pool is not None:
                layers.append(nn.MaxPool2d(kernel_size=pool, stride=pool))
        self.cnn = nn.Sequential(*layers)

        self.rnn = nn.LSTM(
            input_size=512 * (h // 16),
            hidden_size=hs,
            num_layers=2,
            bias=True,
            batch_first=True,
            bidirectional=True,
        )
        self.fc = nn.Linear(hs * 2, nc)

    def forward(self, x):
        feats = self.cnn(x)
        # (B, C, H, W) -> (B, W, C*H): one feature vector per image column.
        seq = feats.permute(0, 3, 1, 2).flatten(2)
        seq, _ = self.rnn(seq)
        # (B, W, nc) -> (W, B, nc): sequence dimension first.
        logits = self.fc(seq).transpose(0, 1)
        return F.log_softmax(logits, dim=2)

print("✅ Model ready")

In [None]:
# Training functions
def train_epoch(model, loader, crit, opt, dev):
    """Run one optimisation pass over *loader* and return the mean batch loss.

    Each batch is ``(images, concatenated_targets, target_lengths)``. The
    model output is read as ``(T, B, ...)`` and every sample is given the
    full sequence length ``T`` as its input length for the criterion.
    Gradients are clipped to a norm of 5.0 before each optimiser step.
    """
    model.train()
    running_loss = 0
    for images, targets, target_lens in tqdm(loader, desc='Train'):
        images = images.to(dev)
        targets = targets.to(dev)

        log_probs = model(images)
        seq_len, batch_size = log_probs.size(0), log_probs.size(1)
        input_lens = torch.full((batch_size,), seq_len, dtype=torch.long)
        loss = crit(log_probs, targets, input_lens, target_lens)

        opt.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        opt.step()

        running_loss += loss.item()
    return running_loss / len(loader)

def _ctc_greedy_decode(indices, blank):
    """Collapse consecutive repeats, then drop blanks, from a best-path sequence."""
    decoded = []
    prev = None
    for idx in indices:
        if idx != blank and idx != prev:
            decoded.append(idx)
        prev = idx
    return decoded


def validate(model, loader, crit, dev):
    """Evaluate *model* on *loader*.

    Args:
        model: Network whose output is read as ``(T, B, num_classes)``.
        loader: Yields ``(images, concatenated_targets, target_lengths)``.
        crit: Loss called as ``crit(out, targets, input_lengths, target_lengths)``.
            Its ``blank`` attribute, if present, is used as the blank index
            for decoding (10 otherwise).
        dev: Device the images and targets are moved to.

    Returns:
        ``(mean_batch_loss, exact_match_accuracy)``. A sample counts as
        correct only when the greedy-decoded sequence equals the target
        exactly, length included.

    Raises:
        ZeroDivisionError: If *loader* yields no batches.
    """
    model.eval()
    # Decode with the same blank index the loss was configured with.
    blank = getattr(crit, 'blank', 10)
    total, correct, count = 0, 0, 0
    with torch.no_grad():
        for imgs, labs, lens in tqdm(loader, desc='Val'):
            imgs, labs = imgs.to(dev), labs.to(dev)
            out = model(imgs)
            T, B = out.size(0), out.size(1)
            loss = crit(out, labs, torch.full((B,), T, dtype=torch.long), lens)
            total += loss.item()

            # One device->host transfer per batch instead of one per sample.
            preds = out.argmax(2).T.cpu().tolist()  # B lists of T class indices
            targets = labs.cpu().tolist()           # all targets, concatenated
            offset = 0
            for pred, n in zip(preds, (int(x) for x in lens)):
                tgt = targets[offset:offset + n]
                offset += n
                # Compare the full decoded sequence. Truncating it to the
                # target length would count over-long predictions as correct.
                if _ctc_greedy_decode(pred, blank) == tgt:
                    correct += 1
                count += 1

    return total/len(loader), correct/count

print("✅ Training functions ready")

In [None]:
# Main training: build the data loaders, model, loss, optimiser and LR
# scheduler, then train for NUM_EPOCHS, checkpointing the best validation
# accuracy to OUTPUT_DIR/best_model.pth.
print("="*60)
print("🚀 TRAINING START")
print("="*60)

splits = load_data()
train_ds = LPDataset(*splits['train'], True)
val_ds = LPDataset(*splits['val'], False)

train_ld = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2,
                      collate_fn=collate_fn, pin_memory=True)
val_ld = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2,
                    collate_fn=collate_fn, pin_memory=True)

print(f"Train: {len(train_ds)}, Val: {len(val_ds)}")

# Build the model from the config constants rather than its own defaults.
model = CRNN(h=IMG_HEIGHT, w=IMG_WIDTH, nc=NUM_CLASSES).to(DEVICE)
print(f"Params: {sum(p.numel() for p in model.parameters()):,}")

# Digits occupy classes 0-9; the last class index is the CTC blank.
crit = nn.CTCLoss(blank=NUM_CLASSES - 1, zero_infinity=True)
opt = optim.Adam(model.parameters(), lr=LR)
# `verbose=True` is deprecated since PyTorch 2.2; LR changes are reported
# explicitly in the loop below instead.
sch = optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=5)

best_acc = 0
history = {'train_loss': [], 'val_loss': [], 'val_acc': []}

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")

    train_loss = train_epoch(model, train_ld, crit, opt, DEVICE)
    val_loss, val_acc = validate(model, val_ld, crit, DEVICE)

    # Step the scheduler on validation loss and report any LR reduction.
    lr_before = opt.param_groups[0]['lr']
    sch.step(val_loss)
    lr_after = opt.param_groups[0]['lr']
    if lr_after != lr_before:
        print(f"LR reduced: {lr_before:.2e} -> {lr_after:.2e}")

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc*100:.2f}%")

    # Keep only the weights with the best validation accuracy so far.
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), f"{OUTPUT_DIR}/best_model.pth")
        print(f"✅ Best: {best_acc*100:.2f}%")

print(f"\n🎉 Done! Best: {best_acc*100:.2f}%")

In [None]:
# Plot results: loss curves on the left, validation accuracy on the right.
fig = plt.figure(figsize=(12,4))

loss_ax = fig.add_subplot(1,2,1)
loss_ax.plot(history['train_loss'], label='Train')
loss_ax.plot(history['val_loss'], label='Val')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.legend()
loss_ax.set_title('Loss')

acc_ax = fig.add_subplot(1,2,2)
acc_ax.plot([x*100 for x in history['val_acc']])
acc_ax.set_xlabel('Epoch')
acc_ax.set_ylabel('Acc (%)')
acc_ax.set_title('Validation Accuracy')

plt.tight_layout()
plt.savefig('/kaggle/working/curves.png')
plt.show()

print(f"Best: {max(history['val_acc'])*100:.2f}%")