# OCR Baseline using LPRNet (PyTorch)

This notebook implements the OCR pipeline using **LPRNet**, a lightweight Convolutional Neural Network designed for License Plate Recognition.

**Steps:**
1. **Setup**: Import libraries and project modules.
2. **Data**: Load dataset using custom `LPRDataset`.
3. **Model**: Initialize `LPRNet`.
4. **Training**: Train the model using CTC Loss, evaluating on the validation set every 5 epochs.
5. **Inference**: Visualize predictions on samples from the validation set.

## 1. Setup

In [None]:
import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import cv2

# Add src to path if needed
sys.path.append('..')

from src.ocr.lprnet import build_lprnet
from src.ocr.reader import LPRDataset, collate_fn
from src.ocr.decoder import LPRLabelEncoder, CHARS

# Run on the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## 2. Dataset Preparation

In [None]:
# Dataset location on disk, relative to this notebook.
DATASET_DIR = "../datasets/IndonesianLiscenePlateDataset/plate_text_dataset"
IMAGES_DIR = os.path.join(DATASET_DIR, "dataset")
LABEL_FILE = os.path.join(DATASET_DIR, "label.csv")

BATCH_SIZE = 32
IMG_SIZE = (94, 24)  # LPRNet standard input size
SPLIT_SEED = 42  # Fixed seed so the train/val partition is the same on every run

# Initialize Dataset
dataset = LPRDataset(img_dir=IMAGES_DIR, label_file=LABEL_FILE, img_size=IMG_SIZE)

# Split Dataset: 80% for training, the remainder for validation.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
# A dedicated, seeded generator keeps the split reproducible without touching
# the global RNG state that training-time shuffling relies on.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=split_generator
)

# Dataloaders: only the training loader shuffles; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

print(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")

## 3. Model Initialization

In [None]:
# Maximum length of a license plate string.
lpr_max_len = 18
# One class per entry in CHARS, plus one extra class for the CTC blank.
class_num = 1 + len(CHARS)

# Build the network and move it onto the selected device.
model = build_lprnet(
    class_num=class_num,
    lpr_max_len=lpr_max_len,
    dropout_rate=0.5,
)
model.to(device)
print(model)

## 4. Training Loop

In [None]:
# CTC loss; the blank symbol uses the index immediately after the last entry of CHARS.
criterion = nn.CTCLoss(reduction='mean', blank=len(CHARS))

# Adam optimizer whose learning rate is halved after every 10 scheduler steps.
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer, gamma=0.5, step_size=10)

# Total number of passes over the training loader.
num_epochs = 50

def train(model, loader, optimizer, criterion, device):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Network called as ``model(images)``; its output is treated as
            ``(B, C, W)`` and permuted to ``(W, B, C)`` for the loss.
        loader: Iterable of ``(images, labels, lengths)`` batches that also
            supports ``len()``.
        optimizer: Optimizer updating the model's parameters.
        criterion: Loss invoked as
            ``criterion(log_probs, labels, input_lengths, lengths)``.
        device: Device the images and labels are moved to.

    Returns:
        The sum of the batch losses divided by the number of batches, or
        ``0.0`` when the loader yields no batches.
    """
    model.train()

    num_batches = len(loader)
    if num_batches == 0:
        # Nothing to train on; avoid dividing by zero below.
        return 0.0

    epoch_loss = 0.0
    for images, labels, lengths in loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        # Forward pass
        logits = model(images)  # (B, C, W)

        # Transform for CTC Loss: (T, N, C). The result is already part of the
        # autograd graph, so no explicit requires_grad_() call is needed.
        log_probs = logits.permute(2, 0, 1).log_softmax(2)  # (W, B, C)

        # Every sample in the batch uses the full output width as its input length.
        input_lengths = torch.full(
            size=(images.size(0),),
            fill_value=logits.size(2),
            dtype=torch.long,
        )

        loss = criterion(log_probs, labels, input_lengths, lengths)

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    return epoch_loss / num_batches

def evaluate(model, loader, device):
    """Return the exact-match accuracy of greedy-decoded predictions.

    A sample counts as correct only when its decoded string equals the string
    rebuilt from its ground-truth label indices.

    Args:
        model: Network called as ``model(images)``; its output is handed to
            ``LPRLabelEncoder.decode_greedy``.
        loader: Iterable of ``(images, labels, lengths)`` batches. ``labels``
            is read as one flat sequence of indices into ``CHARS`` covering
            the whole batch, and ``lengths`` gives each sample's label length.
        device: Device the images are moved to before the forward pass.

    Returns:
        ``correct / total`` as a float, or ``0.0`` when no samples were seen.
    """
    model.eval()
    encoder = LPRLabelEncoder(CHARS)
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels, lengths in loader:
            images = images.to(device)

            logits = model(images)
            preds = encoder.decode_greedy(logits)  # returns list of strings

            # Labels for the batch are concatenated, so walk through them
            # sample by sample using the per-sample lengths.
            flat_labels = labels.cpu().numpy()
            offset = 0
            for i, length in enumerate(lengths):
                # Plain int keeps the slice bounds independent of whether
                # `lengths` holds Python ints or tensor scalars.
                length = int(length)
                true_indices = flat_labels[offset:offset + length]
                true_text = "".join(CHARS[idx] for idx in true_indices)
                offset += length

                if preds[i] == true_text:
                    correct += 1
                total += 1

    # An empty loader would otherwise raise ZeroDivisionError.
    if total == 0:
        return 0.0
    return correct / total

# Run the full schedule: one training pass and one scheduler step per epoch,
# with a validation pass on every fifth epoch.
for epoch in range(num_epochs):
    loss = train(model, train_loader, optimizer, criterion, device)
    scheduler.step()

    if (epoch + 1) % 5 != 0:
        print(f"Epoch [{epoch+1}/{num_epochs}] Loss: {loss:.4f}")
        continue

    acc = evaluate(model, val_loader, device)
    print(f"Epoch [{epoch+1}/{num_epochs}] Loss: {loss:.4f} Val Acc: {acc:.4f}")

# Persist the weights as they stand after the final epoch.
# NOTE: despite the "best" in the file name, no best-accuracy tracking happens
# above; this is simply the last-epoch model.
os.makedirs("../models", exist_ok=True)
torch.save(model.state_dict(), "../models/lprnet_best.pth")
print("Model saved to ../models/lprnet_best.pth")

## 5. Inference Visualization

In [None]:
def show_results(model, dataset, num=5):
    """Plot randomly chosen samples titled with the model's greedy prediction.

    Indices are drawn with ``np.random.randint``, so the same sample may be
    shown more than once.

    Args:
        model: Network called as ``model(input_img)`` on a single-image batch.
        dataset: Indexable collection whose items unpack into three values,
            the first being the image tensor.
        num: Number of samples to draw and display.
    """
    model.eval()
    encoder = LPRLabelEncoder(CHARS)

    if len(dataset) == 0:
        # np.random.randint(0, 0, ...) would raise a cryptic "low >= high" error.
        print("No samples to display.")
        return

    indices = np.random.randint(0, len(dataset), num)

    for idx in indices:
        img_tensor, _, _ = dataset[idx]
        input_img = img_tensor.unsqueeze(0).to(device)

        with torch.no_grad():
            logits = model(input_img)
            pred_text = encoder.decode_greedy(logits)[0]

        # De-normalize for plotting. This inverts a normalization of the form
        # (pixel - 127.5) * 0.0078125 -- presumably what LPRDataset applies;
        # TODO confirm against the dataset code.
        display_img = img_tensor.numpy().transpose(1, 2, 0)
        display_img = (display_img / 0.0078125) + 127.5
        # Round and clamp before the cast: a bare astype(np.uint8) truncates
        # and wraps values that float error pushes just outside [0, 255].
        display_img = np.clip(np.rint(display_img), 0, 255).astype(np.uint8)

        plt.figure()
        plt.imshow(cv2.cvtColor(display_img, cv2.COLOR_BGR2RGB))
        plt.title(f"Pred: {pred_text}")
        plt.axis('off')
        plt.show()

show_results(model, val_dataset)