# OCR Baseline using LPRNet (PyTorch)

This notebook implements the OCR pipeline using **LPRNet**, a lightweight Convolutional Neural Network designed for License Plate Recognition.

**Optimized for A100 GPU:**
- **Batch Size**: 1024
- **Mixed Precision**: Enabled (AMP)
- **Workers**: 8
- **Progress Tracking**: TQDM enabled
- **Metrics**: Sequence Acc, Char Acc, Edit Distance

**Steps:**
1. **Setup**: Import libraries and project modules.
2. **Data**: Load dataset using custom `LPRDataset`.
3. **Model**: Initialize `LPRNet`.
4. **Training**: Train the model using CTC Loss with AMP.
5. **Inference**: Visualize predictions on random samples from the validation set.

## 1. Setup

In [None]:
import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler # Mixed Precision
import numpy as np
import matplotlib.pyplot as plt
import cv2
from tqdm import tqdm

# Add src to path if needed
sys.path.append('..')

from src.ocr.lprnet import build_lprnet
from src.ocr.reader import LPRDataset, collate_fn
from src.ocr.decoder import LPRLabelEncoder, CHARS
from src.evaluation.metrics import compute_ocr_metrics

# Prefer the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# When running on CUDA, report the name and total memory of GPU 0.
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

## 2. Dataset Preparation

In [None]:
DATASET_DIR = "../datasets/plate_text_cropped"
IMAGES_DIR = os.path.join(DATASET_DIR, "dataset")
LABEL_FILE = os.path.join(DATASET_DIR, "label.csv")

# HPC Optimization
BATCH_SIZE = 1024 # Increased for A100
NUM_WORKERS = 8   # Parallel data loading
PIN_MEMORY = True # Speed up host-to-device transfer

IMG_SIZE = (94, 24) # LPRNet standard input size

# Fixed seed for the train/val partition. Without it every run draws a
# different validation set, so metrics cannot be compared between runs and a
# reloaded checkpoint may be scored on images it was trained on.
SPLIT_SEED = 42

# Initialize Dataset
dataset = LPRDataset(img_dir=IMAGES_DIR, label_file=LABEL_FILE, img_size=IMG_SIZE)

# Split Dataset (80% train / 20% validation, reproducible via SPLIT_SEED)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=split_generator)

# Dataloaders: only the training loader shuffles; both use the project's collate_fn.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, 
                          collate_fn=collate_fn, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, 
                        collate_fn=collate_fn, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

print(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")

## 3. Model Initialization

In [None]:
lpr_max_len = 18 # Maximum length of license plate
class_num = len(CHARS) + 1 # +1 for blank

model = build_lprnet(lpr_max_len=lpr_max_len, class_num=class_num, dropout_rate=0.5)
# Bind the result instead of leaving a bare `model.to(device)` as the last
# expression: in a notebook cell a trailing expression is echoed, which would
# dump the whole module repr into the output.
model = model.to(device)
# Use print(model) here if the architecture summary is wanted.

## 4. Training Loop (with AMP & TQDM)

In [None]:
# The blank symbol takes the index right after the last real character.
# zero_infinity=True zeroes the loss (and gradients) of samples whose target
# cannot be aligned to the available time steps. Without it one such sample
# makes the batch-mean loss infinite and GradScaler skips the optimizer step
# for the whole batch.
criterion = nn.CTCLoss(blank=len(CHARS), reduction='mean', zero_infinity=True)
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Halve the learning rate every 20 scheduler steps (stepped once per epoch below).
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
scaler = GradScaler() # For Mixed Precision

num_epochs = 100

def train(model, loader, optimizer, criterion, device, scaler, epoch):
    """Run one training epoch with CTC loss and gradient scaling.

    Args:
        model: Network invoked as ``model(images)``.
        loader: Iterable of ``(images, labels, lengths)`` batches; ``len(loader)``
            is used to average the loss.
        optimizer: Optimizer stepped through ``scaler``.
        criterion: Loss called as
            ``criterion(log_probs, labels, input_lengths, lengths)``.
        device: Device (``torch.device`` or string) the batch is moved to.
        scaler: Gradient scaler used for backward/step/update.
        epoch: Zero-based epoch index, shown as ``epoch + 1`` in the progress bar.

    Returns:
        The sum of per-batch losses divided by ``len(loader)``.
    """
    model.train()
    # Autocast only helps on CUDA. Gating it on the target device keeps the CPU
    # fallback from requesting CUDA autocast (which only triggers a warning).
    use_amp = torch.device(device).type == 'cuda'
    epoch_loss = 0
    pbar = tqdm(loader, desc=f"Train Epoch {epoch+1}", leave=False)
    
    for images, labels, lengths in pbar:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        
        optimizer.zero_grad()
        
        with autocast(enabled=use_amp):
            logits = model(images)
            # Move the last logits axis to the front so the tensor has the
            # (time, batch, class) layout CTCLoss expects, then normalise
            # over the class axis.
            log_probs = logits.permute(2, 0, 1).log_softmax(2)
            # Every sample in the batch uses the full output length.
            input_lengths = torch.full(size=(images.size(0),), fill_value=logits.size(2), dtype=torch.long)
            loss = criterion(log_probs, labels, input_lengths, lengths)
        
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        # Read the scalar once: each .item() call forces a device-to-host sync.
        batch_loss = loss.item()
        epoch_loss += batch_loss
        pbar.set_postfix({'loss': f'{batch_loss:.4f}'})
        
    return epoch_loss / len(loader)

def evaluate(model, loader, device):
    """Score the model on every batch of ``loader`` and return OCR metrics.

    Predictions come from ``LPRLabelEncoder.decode_greedy``; ground-truth
    strings are rebuilt from the batch's label tensor, which is treated as a
    flat sequence sliced per sample using ``lengths``. The return value is
    whatever ``compute_ocr_metrics(predictions, targets)`` produces.
    """
    model.eval()
    encoder = LPRLabelEncoder(CHARS)

    predicted_texts = []
    reference_texts = []

    progress = tqdm(loader, desc="Evaluating", leave=False)

    with torch.no_grad():
        for images, labels, lengths in progress:
            images = images.to(device, non_blocking=True)

            with autocast():
                logits = model(images)

            batch_preds = encoder.decode_greedy(logits)

            # Walk the flat label array with a running offset, one slice per sample.
            flat_labels = labels.cpu().numpy()
            offset = 0
            for sample_no, n_chars in enumerate(lengths):
                stop = offset + n_chars
                reference = "".join(CHARS[char_id] for char_id in flat_labels[offset:stop])
                offset = stop

                predicted_texts.append(batch_preds[sample_no])
                reference_texts.append(reference)

    return compute_ocr_metrics(predicted_texts, reference_texts)

# Start Training
print("Starting training on A100...")

# The checkpoint file is named "best", so write it whenever validation
# sequence accuracy improves instead of saving only the final weights.
best_model_path = "../models/lprnet_a100_best.pth"
os.makedirs("../models", exist_ok=True)
best_seq_acc = float("-inf")
checkpoint_written = False

for epoch in range(num_epochs):
    loss = train(model, train_loader, optimizer, criterion, device, scaler, epoch)
    scheduler.step()
    
    # Validate every 5th epoch; other epochs only report the training loss.
    if (epoch + 1) % 5 == 0:
        metrics = evaluate(model, val_loader, device)
        print(f"Epoch [{epoch+1}/{num_epochs}] Loss: {loss:.4f} | "
              f"Seq Acc: {metrics['seq_acc']:.4f} | Char Acc: {metrics['char_acc']:.4f} | "
              f"Edit Dist: {metrics['avg_edit_dist']:.4f}")
        if metrics['seq_acc'] > best_seq_acc:
            best_seq_acc = metrics['seq_acc']
            torch.save(model.state_dict(), best_model_path)
            checkpoint_written = True
    else:
        print(f"Epoch [{epoch+1}/{num_epochs}] Loss: {loss:.4f}")

# Save Model
# Fallback: if no validation pass produced a checkpoint (e.g. fewer than 5
# epochs), persist the final weights so the file always exists.
if not checkpoint_written:
    torch.save(model.state_dict(), best_model_path)
print("Model saved to ../models/lprnet_a100_best.pth")

## 5. Inference Visualization

In [None]:
def show_results(model, dataset, num=5):
    """Plot randomly chosen dataset images with the model's predicted text.

    Args:
        model: Network invoked as ``model(batch)``; switched to eval mode here.
        dataset: Indexable collection whose items unpack as
            ``(image_tensor, _, _)``.
        num: Maximum number of distinct samples to display. Fewer are shown
            when the dataset is smaller; nothing is shown when it is empty.
    """
    model.eval()
    encoder = LPRLabelEncoder(CHARS)
    
    # Draw distinct indices and never more than the dataset holds; sampling
    # with replacement could show the same image twice and fails on an empty
    # dataset.
    sample_count = min(num, len(dataset))
    if sample_count <= 0:
        return
    indices = np.random.choice(len(dataset), size=sample_count, replace=False)
    
    for idx in indices:
        img_tensor, _, _ = dataset[int(idx)]
        input_img = img_tensor.unsqueeze(0).to(device)
        
        with torch.no_grad():
            logits = model(input_img)
            pred_text = encoder.decode_greedy(logits)[0]
            
        # De-normalize for plotting.
        # NOTE(review): assumes the dataset normalised pixels as
        # (x - 127.5) * 0.0078125 and stores images channel-first in BGR order;
        # confirm against LPRDataset.
        display_img = img_tensor.detach().cpu().numpy().transpose(1, 2, 0)
        display_img = (display_img / 0.0078125) + 127.5
        # Round and clip before the cast so out-of-range values cannot wrap around in uint8.
        display_img = np.clip(np.rint(display_img), 0, 255).astype(np.uint8)
        
        plt.figure()
        plt.imshow(cv2.cvtColor(display_img, cv2.COLOR_BGR2RGB))
        plt.title(f"Pred: {pred_text}")
        plt.axis('off')
        plt.show()

# Display predictions for five random samples from the validation split.
show_results(model, val_dataset, num=5)