# Imports

In [None]:
from google.colab import drive
drive.mount('/content/drive')

import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import torch.nn as nn
import torch.optim as optim
import re
import shutil

!cp -r /content/drive/MyDrive/ccpd_green/ /content/ 
!cp -r /content/drive/MyDrive/ccpd_green/test /content/

# Globals

In [None]:
# CHARACTER TABLES FOR THE LICENSE PLATE TEXT
# The plate text is stored in each filename as a list of indices into these
# tables: first index -> provinces, second -> alphabets, the rest -> ads.

provinces = [
    "皖", "沪", "津", "渝", "冀", "晋", "蒙", "辽", "吉", "黑",
    "苏", "浙", "京", "闽", "赣", "鲁", "豫", "鄂", "湘", "粤",
    "桂", "琼", "川", "贵", "云", "藏", "陕", "甘", "青", "宁",
    "新", "警", "学", "O",
]
alphabets = [
    'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K', 'L', 'M',
    'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'O',
]
ads = [
    'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K', 'L', 'M',
    'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'O',
]

# SHARED VOCABULARY
# Merge the three tables into a single list, keeping only the first
# occurrence of every character so each one gets exactly one index.

MY_DICTIONARY = list(dict.fromkeys(provinces + alphabets + ads))
seen = set(MY_DICTIONARY)

# ENCODING AND DECODING TABLES

char2idx = {}
idx2char = {}
for position, symbol in enumerate(MY_DICTIONARY):
    char2idx[symbol] = position
    idx2char[position] = symbol

# The CTC blank takes the first index after the real characters.
BLANK_IDX = len(idx2char)

# TRANSFORMATIONS

# Preprocessing pipeline applied to every image before it reaches the network.
# The steps are order-dependent: Grayscale and Normalize run on the tensor
# produced by ToTensor.
transform = transforms.Compose([
    transforms.Resize((64, 256)), # More realistic aspect ratio for license plates
    transforms.ToTensor(),
    transforms.Grayscale(num_output_channels=1),  # single channel, matching CRNN's default input_channels=1
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Utils

In [None]:
# CHECKING DATASET STRUCTURE

def validate_dictionary():
    """Print a sanity report on the character vocabulary.

    Reports the size of each source table and of the merged dictionary,
    checks that the merged dictionary contains exactly the characters of the
    three tables, then decodes the plate text from a sample filename and
    verifies that every decoded character can be encoded with ``char2idx``.
    Nothing is returned; all results are printed.
    """
    print("=== validating dict ===")
    print(f"Provinces: {len(provinces)} chars")
    print(f"Alphabets: {len(alphabets)} chars")
    print(f"Ads: {len(ads)} chars")
    print(f"Final dict: {len(MY_DICTIONARY)} unique chars")

    # Compare the merged dictionary against the union of the source tables.
    expected = set(provinces) | set(alphabets) | set(ads)
    actual = set(MY_DICTIONARY)
    missing = expected - actual
    extra = actual - expected
    if not missing and not extra:
        print("all chars are included in the dict")
    else:
        if missing:
            print(f"missing chars: {missing}")
        if extra:
            print(f"there's extra chars!: {extra}")

    # Decode the plate stored in a sample filename (fifth '-'-separated field).
    test_filename = "025-95_113-154&383_386&473-386&473_177&454_154&383_363&402-0_0_22_27_27_33_16-37-15.jpg"
    print(f"\n=== Test Parsing ===")
    print(f"Test filename: {test_filename}")

    plate_field = test_filename.split('-')[4]
    province_idx, alphabet_idx, *ads_idx = plate_field.split("_")
    test_plate = (
        provinces[int(province_idx)]
        + alphabets[int(alphabet_idx)]
        + "".join(ads[int(i)] for i in ads_idx)
    )
    print(f"parsed plate: '{test_plate}'")

    # Every decoded character must have an entry in the encoding table.
    missing_chars = [c for c in test_plate if c not in char2idx]
    if not missing_chars:
        print("the dictionary is ok!")
    else:
        print(f"missing chars in dict: {missing_chars}")

validate_dictionary()

# Data

In [None]:
# DATASET CLASS

class CarPlateDataset(Dataset):
    """License plate dataset whose annotations are encoded in the filenames.

    Every filename is expected to contain seven ``-``-separated fields::

        area-tilt-bbox-vertices-plate-brightness-blurriness.jpg

    ``__getitem__`` returns the (optionally cropped and transformed) image
    together with the plate text as a 1-D ``torch.long`` tensor of dictionary
    indices, ready to be used as a CTC target.

    Args:
        img_dir: Directory containing the image files.
        transform: Optional callable applied to the PIL image.
        cropped: If True, crop each image to the bounding box stored in its
            filename before applying ``transform``.
    """

    # Only regular files with one of these extensions are used as samples.
    # Anything else in the folder (e.g. a `.ipynb_checkpoints` directory or
    # stray metadata files) would otherwise crash `parse_filename`.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')

    def __init__(self, img_dir, transform=None, cropped = False):
        self.img_dir = img_dir
        self.transform = transform
        self.cropped = cropped
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> sample mapping reproducible across runs and machines.
        self.image_names = sorted(
            name for name in os.listdir(img_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(img_dir, name))
        )

    def __len__(self):
        """Return the number of image files found in ``img_dir``."""
        return len(self.image_names)

    @staticmethod
    def _parse_point(text):
        """Parse an ``x&y`` coordinate pair into a tuple of two ints."""
        parts = text.split('&')
        return int(parts[0]), int(parts[1])

    def parse_filename(self, filename):
        """Decode the annotations stored in ``filename``.

        Args:
            filename: Base name of the image file (no directory part).

        Returns:
            dict with the keys:
                ``area``: first field divided by 100 (float).
                ``tilt``: float32 array ``[horizontal, vertical]``.
                ``bbox_coords``: float32 array ``[(left, top), (right, bottom)]``.
                ``vertices``: float32 array of the four plate corners ordered
                    left-bottom, right-bottom, right-up, left-up.
                ``lp``: plate text as a string.
                ``lp_indexes``: list of ``char2idx`` indices for ``lp``
                    (characters missing from ``char2idx`` are skipped with a
                    warning).
                ``brightness``: int.
                ``blurriness``: int (0 when the field has no leading digits).

        Raises:
            IndexError, ValueError: If the filename does not follow the
                expected layout.
        """
        point = CarPlateDataset._parse_point
        fields = filename.split('-')

        # Author's note: the filename stores the plate/image area ratio as a
        # percentage, so dividing by 100 gives a value in the 0-1 range.
        area = float(fields[0]) / 100

        # Tilt: horizontal and vertical degrees separated by '_'.
        tilt_parts = fields[1].split('_')
        tilt_list = np.array([int(tilt_parts[0]), int(tilt_parts[1])], dtype=np.float32)

        # Bounding box: left-up corner, then right-bottom corner.
        bbox_parts = fields[2].split('_')
        bbox_coords_list = np.array(
            [point(bbox_parts[0]), point(bbox_parts[1])], dtype=np.float32
        )

        # Plate vertices are stored starting from the right-bottom corner:
        # right-bottom, left-bottom, left-up, right-up.
        vertex_parts = fields[3].split('_')
        right_bottom = point(vertex_parts[0])
        left_bottom = point(vertex_parts[1])
        left_up = point(vertex_parts[2])
        right_up = point(vertex_parts[3])
        vertices_list = np.array(
            [left_bottom, right_bottom, right_up, left_up], dtype=np.float32
        )

        # Plate text: one index into provinces, one into alphabets, the rest
        # into ads.
        indices = str(fields[4]).split("_")
        plate_text = (
            provinces[int(indices[0])]
            + alphabets[int(indices[1])]
            + "".join(ads[int(i)] for i in indices[2:])
        )

        brightness = int(fields[5])

        # The last field may carry extra characters after the number, so only
        # the leading digits are used.
        blurriness_str = fields[6].replace('.jpg', '')
        match = re.match(r'\d+', blurriness_str)
        if match:
            blurriness = int(match.group())
        else:
            print(f"[WARNING] File '{filename}': blurriness non standard '{fields[6]}', imposto a 0.")
            blurriness = 0

        # Convert the plate text to dictionary indices for CTC training.
        lp_indexes = []
        for c in plate_text:
            if c in char2idx:
                lp_indexes.append(char2idx[c])
            else:
                print(f"[WARNING] Carattere non riconosciuto '{c}' in '{plate_text}'")

        return {
            'area': area,
            'tilt': tilt_list,
            'bbox_coords': bbox_coords_list,
            'vertices': vertices_list,
            'lp': plate_text,
            'lp_indexes': lp_indexes,
            'brightness': brightness,
            'blurriness': blurriness,
        }

    def __getitem__(self, idx):
        """Return ``(image, label_indices)`` for the sample at position ``idx``.

        The image is whatever ``transform`` returns (the PIL image itself when
        no transform is set); the label is a 1-D ``torch.long`` tensor.
        """
        img_name = self.image_names[idx]
        img_path = os.path.join(self.img_dir, img_name)
        metadata = self.parse_filename(img_name)

        # Load the pixels inside a context manager so the file handle is
        # released right away, and normalise the mode so palette/RGBA/CMYK
        # files reach the transform as ordinary 3-channel images.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        # The same dataset serves both baselines: when `cropped` is set the
        # detection step is skipped and the bounding box from the filename is
        # used instead. PIL's crop takes (left, upper, right, lower).
        if self.cropped:
            (left, upper), (right, lower) = metadata['bbox_coords'].astype(int).tolist()
            image = image.crop((left, upper, right, lower))

        if self.transform:
            image = self.transform(image)

        # Image plus the plate indices as a tensor, as expected by the collate
        # function and the CTC loss.
        return image, torch.tensor(metadata['lp_indexes'], dtype=torch.long)

# COLLATE FUNCTION

def ctc_collate_fn(batch):
    """Collate ``(image, label)`` samples into the layout used for CTC loss.

    Labels may have different lengths, so instead of padding them they are
    concatenated into one flat vector and their individual lengths are
    returned alongside, which lets the loss find where each label ends.

    Args:
        batch: Sequence of ``(image_tensor, label_tensor)`` pairs. All images
            must share the same shape; each label is a 1-D tensor.

    Returns:
        Tuple ``(images, labels, label_lengths)``: the images stacked along a
        new first dimension, all labels concatenated into a single 1-D
        tensor, and a ``torch.long`` tensor holding the length of each label.
    """
    image_seq, label_seq = zip(*batch)
    stacked_images = torch.stack(image_seq)
    lengths = torch.tensor(list(map(len, label_seq)), dtype=torch.long)
    flat_labels = torch.cat(label_seq)
    return stacked_images, flat_labels, lengths

In [None]:
# DATASETS AND DATA LOADERS
# Every split is read with cropped=True, i.e. images are cut to the bounding
# box stored in the filename. Only the training loader shuffles its samples.

dataset_train = CarPlateDataset(
    img_dir='/content/ccpd_green/train',
    transform=transform,
    cropped=True,
)
dataloader_train = DataLoader(
    dataset_train,
    batch_size=32,
    shuffle=True,
    collate_fn=ctc_collate_fn,
)

dataset_eval = CarPlateDataset(
    img_dir='/content/ccpd_green/val',
    transform=transform,
    cropped=True,
)
dataloader_eval = DataLoader(
    dataset_eval,
    batch_size=32,
    shuffle=False,
    collate_fn=ctc_collate_fn,
)

dataset_test = CarPlateDataset(
    img_dir='/content/ccpd_green/test',
    transform=transform,
    cropped=True,
)
dataloader_test = DataLoader(
    dataset_test,
    batch_size=32,
    shuffle=False,
    collate_fn=ctc_collate_fn,
)

print(f"Train set size: {len(dataset_train)}")
print(f"Eval set size: {len(dataset_eval)}")
print(f"Test set size: {len(dataset_test)}")

# Network

In [None]:
# MODEL

class CRNN(nn.Module):
    """Convolutional-recurrent network producing per-column scores for CTC.

    A small CNN (three 3x3 convolutions, the first two followed by 2x2
    max-pooling) extracts features and averages the height away; each
    remaining column becomes one time step of a 2-layer bidirectional LSTM,
    followed by a linear layer that outputs ``num_classes`` scores per step.

    Args:
        num_classes: Size of the output layer.
        input_channels: Number of channels of the input images.
    """

    def __init__(self, num_classes, input_channels=1):
        super().__init__()
        # All convolutions use a 3x3 kernel with stride 1 and padding 1, so
        # only the pooling layers change the spatial size.
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        layers = [
            nn.Conv2d(input_channels, 64, **conv_kwargs),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(64, 128, **conv_kwargs),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(128, 256, **conv_kwargs),
            nn.ReLU(),
            # Collapse the height to 1 and leave the width as it is.
            nn.AdaptiveAvgPool2d(output_size=(1, None)),
        ]
        self.cnn = nn.Sequential(*layers)
        self.rnn = nn.LSTM(input_size=256, hidden_size=128, num_layers=2, bidirectional=True)
        # Forward and backward hidden states are concatenated: 2 * 128 features.
        self.fc = nn.Linear(in_features=2 * 128, out_features=num_classes)

    def forward(self, x):
        """Map images ``[B, C_in, H, W_in]`` to scores ``[W, B, num_classes]``."""
        features = self.cnn(x)                            # [B, C, 1, W]
        sequence = features.squeeze(2).permute(2, 0, 1)   # [W, B, C]
        recurrent, _ = self.rnn(sequence)
        return self.fc(recurrent)                         # [W, B, num_classes]

# DECODER FUNCTION

def ctc_greedy_decoder(output, idx2char, blank=BLANK_IDX):
    """Greedy (best-path) CTC decoding of a batch of network outputs.

    For every sample the highest-scoring class is picked at each time step;
    consecutive repeats are then merged, blanks are dropped, and the
    remaining indices are mapped to characters. Only the argmax is used, so
    raw scores, probabilities and log-probabilities all decode the same way.

    Args:
        output: Tensor of shape [seq_len, batch, num_classes].
        idx2char: Mapping from class index to character.
        blank: Index of the CTC blank class.

    Returns:
        List with one decoded string per sample in the batch.
    """
    decoded = []
    for sample_scores in output.permute(1, 0, 2):  # [seq_len, num_classes] per sample
        path = sample_scores.argmax(1).cpu().numpy()
        chars = []
        last_symbol = -1  # never a valid class index, so the first step is never a repeat
        for symbol in path:
            is_repeat = symbol == last_symbol
            last_symbol = symbol
            if is_repeat or symbol == blank:
                continue
            chars.append(idx2char[symbol])
        decoded.append(''.join(chars))
    return decoded

def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader`` with a CTC criterion.

    Args:
        model: Network whose output has shape [seq_len, batch, num_classes].
        dataloader: Iterable yielding ``(images, labels, label_lengths)``
            batches, with labels concatenated into one 1-D tensor.
        optimizer: Optimizer updating ``model``'s parameters.
        criterion: Loss called as
            ``criterion(log_probs, labels, input_lengths, label_lengths)``.
        device: Device the batches are moved to.

    Returns:
        Mean of the per-batch losses as a float, or 0.0 when ``dataloader``
        yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for images, labels, label_lengths in dataloader:
        images = images.to(device)
        labels = labels.to(device)
        label_lengths = label_lengths.to(device)

        optimizer.zero_grad()
        outputs = model(images)  # [W, B, num_classes]
        log_probs = outputs.log_softmax(2)
        # Every sample in the batch uses the full output sequence length.
        input_lengths = torch.full(
            (images.size(0),), log_probs.size(0), dtype=torch.long, device=device
        )
        loss = criterion(log_probs, labels, input_lengths, label_lengths)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Count the batches actually seen rather than calling len(dataloader), so
    # an empty loader gives 0.0 instead of a ZeroDivisionError (mirroring the
    # zero guards in `evaluate`) and iterables without __len__ also work.
    avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
    print(f"Epoch average loss: {avg_loss:.4f}")
    return avg_loss

# EVALUATE FUNCTION

def evaluate(model, dataloader, device, verbose=False):
    """Measure recognition accuracy of ``model`` on ``dataloader``.

    Predictions are obtained with ``ctc_greedy_decoder`` and compared with
    the ground-truth strings rebuilt from the concatenated labels. Two
    summary lines are printed (full-plate/character accuracy, and length
    error rate/province/alphabet accuracy).

    Args:
        model: Network whose output has shape [seq_len, batch, num_classes].
        dataloader: Iterable yielding ``(images, labels, label_lengths)``.
        device: Device the images are moved to.
        verbose: If True, print the wrong predictions of the first batch.

    Returns:
        Tuple ``(acc, acc_char)``: fraction of plates predicted exactly, and
        fraction of ground-truth characters matched position by position.
        Both are 0 when there is nothing to evaluate.
    """
    def _ratio(count, denominator):
        # Guard against empty input.
        return count / denominator if denominator > 0 else 0

    model.eval()

    # Plate-level and character-level tallies.
    n_plates = 0
    n_exact = 0
    n_chars = 0
    n_chars_ok = 0

    # Additional metrics for a more detailed analysis.
    n_length_errors = 0
    n_province_ok = 0
    n_alphabet_ok = 0

    with torch.no_grad():
        for batch_idx, (images, labels, label_lengths) in enumerate(dataloader):
            outputs = model(images.to(device))
            pred_strings = ctc_greedy_decoder(outputs, idx2char)

            # Labels arrive concatenated; cut them back into one string per
            # sample (tensors are moved to CPU memory for numpy).
            flat_labels = labels.cpu().numpy()
            gt_strings = []
            start = 0
            for length in label_lengths.cpu().numpy():
                stop = start + length
                gt_strings.append(''.join(idx2char[i] for i in flat_labels[start:stop]))
                start = stop

            for pred, gt in zip(pred_strings, gt_strings):
                n_plates += 1

                # Full-plate accuracy.
                is_exact = pred == gt
                if is_exact:
                    n_exact += 1

                # Per-character accuracy; zip stops at the shorter string.
                n_chars_ok += sum(p == g for p, g in zip(pred, gt))
                n_chars += len(gt)

                # Length mismatch between prediction and ground truth.
                if len(pred) != len(gt):
                    n_length_errors += 1

                # Province accuracy (first character).
                if pred and gt and pred[0] == gt[0]:
                    n_province_ok += 1

                # Alphabet accuracy (second character).
                if min(len(pred), len(gt)) > 1 and pred[1] == gt[1]:
                    n_alphabet_ok += 1

                # Print error examples on request (first batch only).
                if verbose and batch_idx == 0 and not is_exact:
                    print(f"Pred: '{pred}' | GT: '{gt}'")

    acc = _ratio(n_exact, n_plates)
    acc_char = _ratio(n_chars_ok, n_chars)
    length_error_rate = _ratio(n_length_errors, n_plates)
    province_acc = _ratio(n_province_ok, n_plates)
    alphabet_acc = _ratio(n_alphabet_ok, n_plates)

    print(f"Eval accuracy (full plate): {acc:.4f} | Char accuracy: {acc_char:.4f}")
    print(f"Length error rate: {length_error_rate:.4f} | Province acc: {province_acc:.4f} | Alphabet acc: {alphabet_acc:.4f}")

    return acc, acc_char

# Train

In [None]:
# TRAINING SET UP

# Use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# One output class per dictionary character, plus one for the CTC blank.
model = CRNN(num_classes=len(MY_DICTIONARY)+1).to(device)
print(f"Model created with {len(MY_DICTIONARY)+1} output classes")

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Halve the learning rate when validation accuracy (mode='max') has not
# improved for 5 epochs. The `verbose` argument is deprecated in recent
# PyTorch releases and rejected by newer ones, so it is not passed; the
# training loop prints the current learning rate every epoch.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)
# zero_infinity=True turns infinite losses (and their gradients) into zeros.
ctc_loss = nn.CTCLoss(blank=BLANK_IDX, zero_infinity=True)

In [None]:
# TRAINING PARAMETERS

NUM_EPOCHS = 80 # A 400-epoch run tracked with wandb (see metrics pictures) showed the best epoch at around 75; after that loss and accuracy stay flat.
# Start below any reachable accuracy so the first epoch always writes a
# checkpoint, even while full-plate accuracy is still 0.0. With an initial
# value of 0.0 and a strict '>' comparison, a run stuck at 0.0 would hit
# early stopping without ever saving, and the save cell that follows would
# fail on the missing checkpoint file.
best_val_acc = -1.0
best_epoch = 0
patience_counter = 0
early_stopping_patience = 15  # Stop training if no improvement for 15 epochs

print("Starting training...")
print(f"Training for {NUM_EPOCHS} epochs with early stopping (patience={early_stopping_patience})")

for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}")

    # Training
    train_loss = train_one_epoch(model, dataloader_train, optimizer, ctc_loss, device)

    # Validation: every 10 epochs also print some wrong predictions
    # (see the `verbose` flag of `evaluate`).
    val_acc, val_acc_char = evaluate(model, dataloader_eval, device, verbose=(epoch % 10 == 0))

    # Learning rate scheduling, driven by full-plate validation accuracy.
    scheduler.step(val_acc)
    current_lr = optimizer.param_groups[0]['lr']

    print(f"Train Loss: {train_loss:.4f} | Val Accuracy (full plate): {val_acc:.4f} | Char Accuracy: {val_acc_char:.4f} | LR: {current_lr:.6f}")
    print("-" * 50)

    # Save the model whenever validation accuracy improves.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_epoch = epoch + 1
        patience_counter = 0
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'val_acc': val_acc,
            'val_acc_char': val_acc_char,
            'train_loss': train_loss
        }, "best_crnn_ctc_model.pth")
        print(f"==> New best model saved at epoch {best_epoch} with acc {best_val_acc:.4f}")
    else:
        patience_counter += 1
        print(f"No improvement. Patience: {patience_counter}/{early_stopping_patience}")

    # Early stopping
    if patience_counter >= early_stopping_patience:
        print(f"Early stopping triggered after {epoch + 1} epochs")
        break

print("Training completed!")
print(f"Best model at epoch {best_epoch} with val acc {best_val_acc:.4f}")

In [None]:
# SAVING THE MODEL

# Copy the best checkpoint written by the training loop to Google Drive.

# Target folder on Drive (change to the path you prefer); created on demand.
destination_dir = '/content/drive/MyDrive/SavedModels/'
os.makedirs(destination_dir, exist_ok=True)

# Local checkpoint file and its full destination path on Drive.
source_path = 'best_crnn_ctc_model.pth'
destination_path = os.path.join(destination_dir, 'best_crnn_ctc_model.pth')

shutil.copyfile(source_path, destination_path)

print(f"Model saved to: {destination_path}")

# Evaluate

In [None]:
# TESTING THE MODEL

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = '/content/drive/MyDrive/SavedModels/best_crnn_ctc_model.pth'

# Read the checkpoint from Drive, mapping its tensors onto the current device.
checkpoint = torch.load(model_path, map_location=device)

# Rebuild the network with the same number of classes used for training
# (dictionary size + CTC blank), then restore the saved weights.
model = CRNN(num_classes=len(MY_DICTIONARY)+1).to(device)
model.load_state_dict(checkpoint['model_state_dict'])

# Report what was stored together with the weights.
print(f"Model loaded successfully from epoch {checkpoint['epoch']}")
print(f"Best validation accuracy was: {checkpoint['val_acc']:.4f}")
print(f"Best validation char accuracy was: {checkpoint['val_acc_char']:.4f}")

# Evaluate on the test split.
print("\n" + "="*50)
print("Testing on test set with loaded model:")
print("="*50)
test_acc, test_acc_char = evaluate(
    model,
    dataloader_test,
    device,
    verbose=True,
)

print(f"\nFinal Test Results:")
print(f"Test Accuracy (full plate): {test_acc:.4f}")
print(f"Test Character Accuracy: {test_acc_char:.4f}")