In [81]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
from PIL import Image


# Simplified Dataset Class
from torch.nn.utils.rnn import pad_sequence

class Dataset(Dataset):
    """Image/label dataset for license-plate text recognition.

    The annotations file is read as comma-separated text whose first line is
    a header; every following non-empty line must hold exactly two fields,
    ``image_name,label``.

    Each item is a pair ``(image, label_tensor)``:

    * ``image`` is the file opened in grayscale mode ('L') and passed through
      ``transform`` when one is given.
    * ``label_tensor`` is a ``torch.long`` tensor of length ``max_seq_len``
      where 'A'-'Z' map to 0-25 and '0'-'9' map to 26-35 (labels are
      upper-cased first).

    Args:
        image_dir: Root directory that contains the image sub-folders.
        annotations_file: Path to the annotations CSV.
        transform: Optional callable applied to the PIL image.
        max_seq_len: Fixed length every label is truncated or padded to.
        num_classes: Number of character classes; stored for use by callers.
        pad_index: Value used to pad labels shorter than ``max_seq_len``.
            The default ``0`` keeps the historical behaviour.

    NOTE(review): with the default ``pad_index=0`` padded positions are
    indistinguishable from the character 'A' (also index 0). Passing
    ``pad_index=-100`` matches the default ``ignore_index`` of
    ``nn.CrossEntropyLoss`` so padded slots do not contribute to the loss.
    """

    def __init__(self, image_dir, annotations_file, transform=None, max_seq_len=10, num_classes=36,
                 pad_index=0):
        self.image_dir = image_dir
        self.transform = transform
        self.max_seq_len = max_seq_len
        self.num_classes = num_classes
        self.pad_index = pad_index

        # Load annotations, skipping the header row.
        with open(annotations_file, 'r') as f:
            lines = f.readlines()[1:]
        # Blank lines (e.g. a trailing newline at the end of the file) would
        # otherwise become [''] rows that inflate __len__ and break unpacking
        # in __getitem__, so they are dropped here.
        self.annotations = [line.strip().split(",") for line in lines if line.strip()]

    def __len__(self):
        """Return the number of annotated samples."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label_tensor)``.

        Raises:
            ValueError: If the label contains a character outside A-Z / 0-9
                (after upper-casing), or the row does not have two fields.
        """
        image_name, label = self.annotations[idx]

        # Images live in a sub-folder named after the part of the file name
        # that precedes the first '['.
        image_path = os.path.join(self.image_dir, image_name.split('[')[0], image_name)

        # Load the image in grayscale mode.
        image = Image.open(image_path).convert('L')
        if self.transform:
            image = self.transform(image)

        # Map label characters to class indices.
        label_indices = []
        for c in label.upper():  # Upper-case so 'a' and 'A' share a class
            if 'A' <= c <= 'Z':  # 'A'-'Z' -> 0-25
                label_indices.append(ord(c) - ord('A'))
            elif '0' <= c <= '9':  # '0'-'9' -> 26-35
                label_indices.append(ord(c) - ord('0') + 26)
            else:
                raise ValueError(f"Unsupported character in label: {c}")

        # Force a fixed length: truncate long labels, pad short ones.
        label_indices = label_indices[:self.max_seq_len]
        label_indices += [self.pad_index] * (self.max_seq_len - len(label_indices))

        return image, torch.tensor(label_indices, dtype=torch.long)




# Load Pre-trained Model
class LicensePlateCNN(nn.Module):
    """Three-stage convolutional network with a two-layer classifier head.

    Takes single-channel images and returns, per image, one flat vector of
    ``num_classes * max_seq_len`` scores. The first fully connected layer
    expects 128 * 4 * 4 features, which is what three 2x2 poolings produce
    from a 32x32 input.

    Args:
        num_classes: Number of scores per character position.
        max_seq_len: Number of character positions.
    """

    def __init__(self, num_classes=36, max_seq_len=10):
        super().__init__()
        flat_features = 128 * 4 * 4  # Match the pre-trained CNN
        head_outputs = num_classes * max_seq_len

        # Attribute names and creation order are kept stable so that
        # checkpoints keyed on them keep loading.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(flat_features, 256)
        self.fc2 = nn.Linear(256, head_outputs)

    def forward(self, x):
        """Return the flat score vector for every image in the batch ``x``."""
        # Each stage: convolution -> ReLU -> 2x2 max-pool.
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(F.relu(conv(x)))

        features = x.flatten(1)  # keep the batch axis, merge the rest
        hidden = F.relu(self.fc1(features))
        return self.fc2(hidden)


# Where the images and their annotation CSV live
image_dir = '/Users/arya/Desktop/CVIP_Proj/output/new'
annotations_file = '/Users/arya/Desktop/CVIP_Proj/new_annotations.csv'

# Preprocessing applied to every image: resize to 32x32, convert to a
# tensor, then normalise with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Build the full dataset, then carve out an 80/20 train/validation split.
dataset = Dataset(image_dir, annotations_file, transform=transform)
train_size = int(len(dataset) * 0.8)
val_size = len(dataset) - train_size  # remainder, so the two sizes always sum to len(dataset)
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, lengths=[train_size, val_size]
)

# Only the training batches are reshuffled each epoch.
train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=16, shuffle=False)

# Initialize Model with Pre-Trained Weights
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = LicensePlateCNN(num_classes=36, max_seq_len=10)
pretrained_weights_path = '/Users/arya/Desktop/CVIP_Proj/OCR/character_cnn.pth'

# Load pre-trained weights (only layers with matching names and shapes).
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted source; on PyTorch versions that
# support it, consider passing weights_only=True.
checkpoint = torch.load(pretrained_weights_path, map_location=device)
model_dict = model.state_dict()
pretrained_dict = {
    k: v for k, v in checkpoint.items()
    if k in model_dict and model_dict[k].shape == v.shape
}

# Report the outcome of the filtering instead of dropping tensors silently,
# so a wrong or incompatible checkpoint is visible before training starts.
randomly_initialised = sorted(k for k in model_dict if k not in pretrained_dict)
ignored_checkpoint_keys = sorted(k for k in checkpoint if k not in pretrained_dict)
print(f"Loaded {len(pretrained_dict)}/{len(model_dict)} tensors from {pretrained_weights_path}")
if randomly_initialised:
    print(f"Kept random initialisation for: {randomly_initialised}")
if ignored_checkpoint_keys:
    print(f"Ignored checkpoint tensors (missing in model or shape mismatch): {ignored_checkpoint_keys}")
if not pretrained_dict:
    print("Warning: no pre-trained tensor matched the model; training starts from scratch.")

# Overlay the matching tensors on the model's own state and load the result.
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

model.to(device)

# Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Training and Validation Loop
def _batch_loss(batch_images, batch_labels, report_mismatch=False):
    """Run the model on one batch and return its cross-entropy loss.

    Model outputs are reshaped to rows of ``dataset.num_classes`` scores and
    labels are flattened, so the loss is computed per character position.

    Raises:
        ValueError: If the number of score rows differs from the number of
            labels. With ``report_mismatch`` the sizes are printed first.
    """
    batch_images = batch_images.to(device)
    targets = batch_labels.to(device).view(-1)
    logits = model(batch_images).view(-1, dataset.num_classes)

    n_rows, n_targets = logits.size(0), targets.size(0)
    if n_rows != n_targets:
        if report_mismatch:
            print(f"Output size: {n_rows}, Label size: {n_targets}")
        raise ValueError(f"Output size {n_rows} and label size {n_targets} do not match.")

    return criterion(logits, targets)


best_val_loss = float('inf')
for epoch in range(50):  # Increase epochs if necessary
    # --- training pass ---
    model.train()
    train_loss = 0
    for batch_images, batch_labels in train_loader:
        loss = _batch_loss(batch_images, batch_labels, report_mismatch=True)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # Reported value is the mean of the per-batch losses.
    print(f"Epoch {epoch + 1}, Train Loss: {train_loss / len(train_loader):.4f}")

    # --- validation pass (no gradients, no weight updates) ---
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch_images, batch_labels in val_loader:
            val_loss += _batch_loss(batch_images, batch_labels).item()

    val_loss /= len(val_loader)
    print(f"Validation Loss: {val_loss:.4f}")

    # Checkpoint only when validation loss strictly improves.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), '/Users/arya/Desktop/CVIP_Proj/OCR/BEST_cnn.pth')
        print("Validation loss improved. Model saved.")



Epoch 1, Train Loss: 3.1067
Validation Loss: 2.2923
Validation loss improved. Model saved.
Epoch 2, Train Loss: 1.7798
Validation Loss: 1.5481
Validation loss improved. Model saved.
Epoch 3, Train Loss: 1.3985
Validation Loss: 1.3069
Validation loss improved. Model saved.
Epoch 4, Train Loss: 1.1480
Validation Loss: 1.0685
Validation loss improved. Model saved.
Epoch 5, Train Loss: 0.9085
Validation Loss: 0.8485
Validation loss improved. Model saved.
Epoch 6, Train Loss: 0.6972
Validation Loss: 0.6722
Validation loss improved. Model saved.
Epoch 7, Train Loss: 0.5325
Validation Loss: 0.5233
Validation loss improved. Model saved.
Epoch 8, Train Loss: 0.4084
Validation Loss: 0.4206
Validation loss improved. Model saved.
Epoch 9, Train Loss: 0.3126
Validation Loss: 0.3313
Validation loss improved. Model saved.
Epoch 10, Train Loss: 0.2402
Validation Loss: 0.2621
Validation loss improved. Model saved.
Epoch 11, Train Loss: 0.1853
Validation Loss: 0.2121
Validation loss improved. Model save