In [None]:
# Mount Google Drive (Colab only) so the dataset and checkpoint paths under
# /content/drive used by the later cells resolve.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
import glob
import cv2
import torch
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader, Subset
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from sklearn.metrics import classification_report, confusion_matrix
from PIL import Image
from torch.nn.utils.rnn import pad_sequence
from sklearn.model_selection import train_test_split

In [None]:
# Set directories
# Root folder that is searched recursively for word images (*.png).
dataset_path = "/content/drive/MyDrive/LY_Project/IAM_DATASET/iam_words/words/"
# Whitespace-separated text file parsed by load_labels (image id first, label last).
label_file = "/content/drive/MyDrive/LY_Project/IAM_DATASET/iam_words/final_words.txt"

# Read labels from the label file (final_words.txt)
def load_labels(label_file):
    """Parse a whitespace-separated label file into ``{image_name: label}``.

    Each usable line must contain at least two tokens: the first token is the
    image id (``".png"`` is appended to form the dictionary key) and the last
    token is taken as the transcription. Lines with fewer than two tokens
    (including blank lines) are ignored. If an id occurs more than once, the
    last occurrence wins.

    Args:
        label_file: path to the text file.

    Returns:
        dict mapping ``"<image id>.png"`` to its label string.
    """
    label_dict = {}
    # Explicit encoding so parsing does not depend on the platform's locale default.
    with open(label_file, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.split()
            if len(parts) < 2:
                continue
            # Assumes images carry a .png extension; the last token is the label.
            label_dict[parts[0] + ".png"] = parts[-1]
    return label_dict

labels = load_labels(label_file)

In [None]:
# Show only the size and a small sample; printing the whole dictionary
# (tens of thousands of entries) floods the cell output.
print(f"{len(labels)} labels loaded")
print(list(labels.items())[:5])



In [None]:
# Build the character vocabulary from the transcriptions. Joining the dict
# itself iterates its KEYS (image file names), which yields the alphabet of the
# file names rather than of the labels, so the values must be joined instead.
unique_chars = sorted(set("".join(labels.values())))
print(len(unique_chars))
print(unique_chars)
# Forward and inverse lookup tables between characters and class ids.
char_to_index = {char: i for i, char in enumerate(unique_chars)}
index_to_char = {i: char for char, i in char_to_index.items()}
print(char_to_index)
print(index_to_char)
print(len(char_to_index))
print(len(index_to_char))

23
['-', '.', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'n', 'p', 'u', 'x']
{'-': 0, '.': 1, '0': 2, '1': 3, '2': 4, '3': 5, '4': 6, '5': 7, '6': 8, '7': 9, '8': 10, '9': 11, 'a': 12, 'b': 13, 'c': 14, 'd': 15, 'e': 16, 'f': 17, 'g': 18, 'n': 19, 'p': 20, 'u': 21, 'x': 22}
{0: '-', 1: '.', 2: '0', 3: '1', 4: '2', 5: '3', 6: '4', 7: '5', 8: '6', 9: '7', 10: '8', 11: '9', 12: 'a', 13: 'b', 14: 'c', 15: 'd', 16: 'e', 17: 'f', 18: 'g', 19: 'n', 20: 'p', 21: 'u', 22: 'x'}
23
23


In [None]:
# Gather every transcription, then derive the sorted set of characters that
# occur in them and give each character a consecutive class id.
all_labels = [text for text in labels.values()]
unique_chars = sorted({ch for text in all_labels for ch in text})
char_to_index = dict(zip(unique_chars, range(len(unique_chars))))
print(len(char_to_index))
print(char_to_index)
print(unique_chars)
print(len(unique_chars))

76
{'!': 0, '"': 1, '#': 2, "'": 3, '(': 4, ')': 5, '*': 6, ',': 7, '-': 8, '.': 9, '/': 10, '0': 11, '1': 12, '2': 13, '3': 14, '4': 15, '5': 16, '6': 17, '7': 18, '8': 19, '9': 20, ':': 21, ';': 22, '?': 23, 'A': 24, 'B': 25, 'C': 26, 'D': 27, 'E': 28, 'F': 29, 'G': 30, 'H': 31, 'I': 32, 'J': 33, 'K': 34, 'L': 35, 'M': 36, 'N': 37, 'O': 38, 'P': 39, 'Q': 40, 'R': 41, 'S': 42, 'T': 43, 'U': 44, 'V': 45, 'W': 46, 'X': 47, 'Y': 48, 'Z': 49, 'a': 50, 'b': 51, 'c': 52, 'd': 53, 'e': 54, 'f': 55, 'g': 56, 'h': 57, 'i': 58, 'j': 59, 'k': 60, 'l': 61, 'm': 62, 'n': 63, 'o': 64, 'p': 65, 'q': 66, 'r': 67, 's': 68, 't': 69, 'u': 70, 'v': 71, 'w': 72, 'x': 73, 'y': 74, 'z': 75}
['!', '"', '#', "'", '(', ')', '*', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', '

In [None]:
# Report the number of label entries and the size of each split.
# NOTE(review): val_dataset, test_dataset and train_dataset are only created in
# a later cell (the dataset-split code), so on a fresh kernel this cell raises
# NameError unless that cell has been run first.
print(len(labels))
print(len(val_dataset))
print(len(test_dataset))
print(len(train_dataset))


44564
4456
4457
35650


In [None]:
# Sanity check: the vocabulary list and its index mapping should have the same size.
print(len(unique_chars))
print(len(char_to_index))

76
76


In [None]:
# Round-trip the first few labels through the encoder/decoder to check correctness.
# dataset.labels holds the raw transcription strings, so the text has to be
# encoded with char_to_index first; decoding the string itself would only
# compare each label with itself.
# NOTE(review): `dataset` is created in a later cell; run that cell first.
decoded_labels = []
for i, orig in enumerate(dataset.labels[:5]):
    encoded = torch.tensor([dataset.char_to_index[ch] for ch in orig], dtype=torch.long)
    decoded = dataset.decode_label(encoded)
    decoded_labels.append(decoded)

    print(f"Sample {i + 1}:")
    print(f"  Original Label: {orig}")
    print(f"  Decoded Label: {decoded}")
    print("-" * 30)

    # Fail loudly if the character mapping is not invertible for this label.
    assert decoded == orig, f"Round-trip mismatch: {orig!r} -> {decoded!r}"


Sample 1:
  Original Label: the
  Decoded Label: the
------------------------------
Sample 2:
  Original Label: first
  Decoded Label: first
------------------------------
Sample 3:
  Original Label: In
  Decoded Label: In
------------------------------
Sample 4:
  Original Label: it
  Decoded Label: it
------------------------------
Sample 5:
  Original Label: is
  Decoded Label: is
------------------------------


In [None]:
import torch.nn.functional as F
class ResNet_FPN_BiLSTM(nn.Module):
    """ResNet-50 backbone with FPN-style top-down fusion and a BiLSTM head.

    The fused feature map is pooled to height 1 and a fixed width, read as a
    sequence by a bidirectional LSTM, and projected to per-timestep class
    scores (raw logits, no softmax).

    Args:
        num_classes: size of the per-timestep output.
        seq_len: number of timesteps, i.e. the width the fused feature map is
            pooled to before the BiLSTM. Defaults to 75.
    """

    def __init__(self, num_classes, seq_len=75):
        super().__init__()
        self.seq_len = seq_len

        # Load ResNet-50 backbone
        resnet = models.resnet50(pretrained=True)

        # Stem (conv1, BN, ReLU, maxpool) plus layer1
        self.initial = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1
        )

        # Intermediate ResNet stages whose outputs feed the FPN
        self.resnet_layers = nn.ModuleList([resnet.layer2, resnet.layer3, resnet.layer4])

        # 1x1 lateral convolutions: project every stage to 256 channels
        self.lateral_convs = nn.ModuleList([
            nn.Conv2d(512, 256, kernel_size=1),
            nn.Conv2d(1024, 256, kernel_size=1),
            nn.Conv2d(2048, 256, kernel_size=1)
        ])

        # BiLSTM over the width axis
        self.lstm = nn.LSTM(256, 256, bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(p=0.3)

        # Fully connected layer on the concatenated forward/backward states
        self.fc = nn.Linear(256 * 2, num_classes)

    def forward(self, x):
        """Return logits of shape (batch, seq_len, num_classes) for an image batch."""
        # Backbone stages
        x = self.initial(x)             # -> C=256
        c3 = self.resnet_layers[0](x)   # -> C=512
        c4 = self.resnet_layers[1](c3)  # -> C=1024
        c5 = self.resnet_layers[2](c4)  # -> C=2048

        # FPN top-down fusion: resize the coarser map to the finer one and add
        # the lateral projection.
        p5 = F.relu(self.lateral_convs[2](c5))
        p4 = F.relu(F.interpolate(p5, size=c4.shape[2:], mode='nearest') + self.lateral_convs[1](c4))
        p3 = F.relu(F.interpolate(p4, size=c3.shape[2:], mode='nearest') + self.lateral_convs[0](c3))

        # Collapse height to 1 and fix the width to seq_len: (B, 256, 1, seq_len)
        p3 = F.adaptive_avg_pool2d(p3, (1, self.seq_len))

        # Reshape for the BiLSTM: (B, seq_len, 256)
        x = p3.squeeze(2).permute(0, 2, 1)

        x, _ = self.lstm(x)  # (B, seq_len, 512)
        x = self.dropout(x)

        return self.fc(x)    # (B, seq_len, num_classes)



In [None]:
#Debug code by adding attention layer to the architecture of CNN-RNN-FPN.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

class Attention(nn.Module):
    """Scale every timestep of a BiLSTM output by a learned softmax weight.

    The output keeps the input shape; timesteps are re-weighted, not summed.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # One scalar score per timestep, computed from the concatenated
        # forward/backward hidden states (hence 2 * hidden_size inputs).
        self.attention = nn.Linear(2 * hidden_size, 1)

    def forward(self, lstm_output):
        """Return ``(weighted_output, attn_weights)``.

        lstm_output is (batch, seq_len, hidden_size*2); attn_weights is
        (batch, seq_len, 1) and sums to 1 along the sequence axis.
        """
        scores = self.attention(lstm_output)
        attn_weights = scores.softmax(dim=1)
        # Broadcast the per-timestep weight over the feature dimension.
        return lstm_output * attn_weights, attn_weights

class ResNet_FPN_BiLSTM_Attention(nn.Module):
    """ResNet-50 + FPN-style fusion + BiLSTM, with attention re-weighting.

    Same pipeline as the plain CNN-FPN-BiLSTM model, except that the BiLSTM
    output is scaled per timestep by an ``Attention`` module before the final
    linear projection. Returns raw per-timestep logits.

    Args:
        num_classes: size of the per-timestep output.
        seq_len: number of timesteps, i.e. the width the fused feature map is
            pooled to before the BiLSTM. Defaults to 75.
    """

    def __init__(self, num_classes, seq_len=75):
        super().__init__()
        self.seq_len = seq_len

        # Load ResNet-50 backbone
        resnet = models.resnet50(pretrained=True)

        # Stem (conv1, BN, ReLU, maxpool) plus layer1
        self.initial = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1
        )

        # Intermediate ResNet stages whose outputs feed the FPN
        self.resnet_layers = nn.ModuleList([resnet.layer2, resnet.layer3, resnet.layer4])

        # 1x1 lateral convolutions: project every stage to 256 channels
        self.lateral_convs = nn.ModuleList([
            nn.Conv2d(512, 256, kernel_size=1),
            nn.Conv2d(1024, 256, kernel_size=1),
            nn.Conv2d(2048, 256, kernel_size=1)
        ])

        # BiLSTM over the width axis
        self.lstm = nn.LSTM(256, 256, bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(p=0.3)

        # Attention layer over the BiLSTM output
        self.attention = Attention(hidden_size=256)

        # Fully connected layer on the concatenated forward/backward states
        self.fc = nn.Linear(256 * 2, num_classes)

    def forward(self, x):
        """Return logits of shape (batch, seq_len, num_classes) for an image batch."""
        # Backbone stages
        x = self.initial(x)             # -> C=256
        c3 = self.resnet_layers[0](x)   # -> C=512
        c4 = self.resnet_layers[1](c3)  # -> C=1024
        c5 = self.resnet_layers[2](c4)  # -> C=2048

        # FPN top-down fusion: resize the coarser map to the finer one and add
        # the lateral projection.
        p5 = F.relu(self.lateral_convs[2](c5))
        p4 = F.relu(F.interpolate(p5, size=c4.shape[2:], mode='nearest') + self.lateral_convs[1](c4))
        p3 = F.relu(F.interpolate(p4, size=c3.shape[2:], mode='nearest') + self.lateral_convs[0](c3))

        # Collapse height to 1 and fix the width to seq_len: (B, 256, 1, seq_len)
        p3 = F.adaptive_avg_pool2d(p3, (1, self.seq_len))

        # Reshape for the BiLSTM: (B, seq_len, 256)
        x = p3.squeeze(2).permute(0, 2, 1)

        x, _ = self.lstm(x)  # (B, seq_len, 512)
        x = self.dropout(x)

        # Re-weight timesteps; the attention weights themselves are not returned.
        x, _ = self.attention(x)

        return self.fc(x)    # (B, seq_len, num_classes)


In [None]:
import os
import glob
import torch
from torch.utils.data import Dataset
from PIL import Image
import cv2
from torchvision import transforms
from torch.nn.utils.rnn import pad_sequence

class HandwritingDataset(Dataset):
    """Word-image dataset yielding ``(image, encoded_label)`` pairs.

    Args:
        root_dir: folder searched recursively for ``*.png`` files.
        labels: dict mapping an image file name (basename) to its text label;
            images whose basename is not a key are dropped.
        char_to_index: dict mapping each label character to its class id.
        transform: optional callable applied to the PIL image.
    """

    def __init__(self, root_dir, labels, char_to_index, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        cv2.setNumThreads(0)

        # Sort the matches: glob returns files in filesystem-dependent order,
        # and any seeded, index-based split of this dataset is reproducible
        # across sessions only if the sample order is stable.
        pattern = os.path.join(root_dir, "**", "*.png")
        self.image_paths = []
        self.labels = []
        for img_path in sorted(glob.glob(pattern, recursive=True)):
            img_name = os.path.basename(img_path)
            if img_name in labels:
                self.image_paths.append(img_path)
                self.labels.append(labels[img_name])

        self.char_to_index = char_to_index
        self.index_to_char = {i: c for c, i in char_to_index.items()}
        self.num_classes = len(char_to_index) + 1  # +1 for CTC blank

    def __len__(self):
        """Number of images that have a label."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, encoded_label)``.

        encoded_label is a 1-D LongTensor of class ids, one per character.

        Raises:
            ValueError: if OpenCV cannot read the image file.
            KeyError: if the label contains a character missing from char_to_index.
        """
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        img = cv2.imread(img_path)
        if img is None:
            raise ValueError(f"Failed to load image: {img_path}")

        # OpenCV loads BGR; convert to RGB before handing the image to PIL/transforms.
        image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        if self.transform:
            image = self.transform(image)

        encoded_label = torch.tensor([self.char_to_index[char] for char in label], dtype=torch.long)
        return image, encoded_label

    def decode_label(self, label_tensor):
        """Map a tensor of class ids back to text, skipping ids that are not in the vocabulary."""
        return "".join([self.index_to_char[idx.item()] for idx in label_tensor if idx.item() in self.index_to_char])

def collate_fn(batch, blank_index):
    """Collate ``(image, label)`` samples into one batch.

    Images are stacked into a single tensor along a new leading dimension.
    Labels have different lengths and are returned unpadded as a tuple of
    tensors. ``blank_index`` is accepted but not used.
    """
    image_list, label_seqs = zip(*batch)
    return torch.stack(image_list, dim=0), label_seqs


# Vocabulary: every character that occurs in a label gets a consecutive id;
# the id right after the last character is used as the CTC blank.
unique_chars = sorted({ch for text in labels.values() for ch in text})
char_to_index = {ch: pos for pos, ch in enumerate(unique_chars)}
blank_index = len(char_to_index)

# Preprocessing: grayscale replicated to 3 channels, fixed 224x224 size,
# tensor conversion, then per-channel normalisation.
transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=3),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Full dataset
dataset = HandwritingDataset(dataset_path, labels, char_to_index, transform)

# 80/10/10 split; seed first so the permutation is the same on every run.
torch.manual_seed(42)
total_len = len(dataset)
train_size = int(0.8 * total_len)
val_size = int(0.1 * total_len)
test_size = total_len - train_size - val_size  # whatever remains after train and val

indices = torch.randperm(total_len)
train_idx = indices[:train_size]
val_idx = indices[train_size:train_size + val_size]
test_idx = indices[train_size + val_size:]
train_dataset = Subset(dataset, train_idx)
val_dataset = Subset(dataset, val_idx)
test_dataset = Subset(dataset, test_idx)

# Only the training loader shuffles; both loaders use the same collate wrapper.
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=lambda x: collate_fn(x, blank_index),
    num_workers=2,
    pin_memory=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False,
    collate_fn=lambda x: collate_fn(x, blank_index),
    num_workers=2,
    pin_memory=True,
)


In [None]:
#DEBUG VERSION OF HANDWRITTEN DATASET CODE

import os
import glob
import torch
from torch.utils.data import Dataset, DataLoader, Subset
from PIL import Image
import cv2
from torchvision import transforms
from torch.nn.utils.rnn import pad_sequence
import time  # For timing

class HandwritingDataset(Dataset):
    """Word-image dataset yielding ``(image, encoded_label)`` pairs, with timing output.

    Prints how long indexing took and reports any single image load or
    transform that takes longer than 0.5 seconds.

    Args:
        root_dir: folder searched recursively for ``*.png`` files.
        labels: dict mapping an image file name (basename) to its text label;
            images whose basename is not a key are dropped.
        char_to_index: dict mapping each label character to its class id.
        transform: optional callable applied to the PIL image.
    """

    def __init__(self, root_dir, labels, char_to_index, transform=None):
        start_time = time.time()  # Debug: Dataset loading start
        self.root_dir = root_dir
        self.transform = transform
        cv2.setNumThreads(0)

        # Sort the matches: glob returns files in filesystem-dependent order,
        # and any seeded, index-based split of this dataset is reproducible
        # across sessions only if the sample order is stable.
        pattern = os.path.join(root_dir, "**", "*.png")
        self.image_paths = []
        self.labels = []
        for img_path in sorted(glob.glob(pattern, recursive=True)):
            img_name = os.path.basename(img_path)
            if img_name in labels:
                self.image_paths.append(img_path)
                self.labels.append(labels[img_name])

        self.char_to_index = char_to_index
        self.index_to_char = {i: c for c, i in char_to_index.items()}
        self.num_classes = len(char_to_index) + 1  # +1 for CTC blank

        end_time = time.time()  # Debug: Dataset loading end
        print(f"Dataset initialized with {len(self.image_paths)} images in {end_time - start_time:.2f} seconds")

    def __len__(self):
        """Number of images that have a label."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, encoded_label)``.

        encoded_label is a 1-D LongTensor of class ids, one per character.

        Raises:
            ValueError: if OpenCV cannot read the image file.
            KeyError: if the label contains a character missing from char_to_index.
        """
        start_time = time.time()

        img_path = self.image_paths[idx]
        label = self.labels[idx]

        img = cv2.imread(img_path)
        if img is None:
            raise ValueError(f"Failed to load image: {img_path}")

        load_time = time.time() - start_time
        if load_time > 0.5:  # Debug: Slow image load
            print(f"Slow load: {img_path} took {load_time:.2f} seconds")

        start_transform_time = time.time()

        # OpenCV loads BGR; convert to RGB before handing the image to PIL/transforms.
        image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        if self.transform:
            image = self.transform(image)

        transform_time = time.time() - start_transform_time
        if transform_time > 0.5:  # Debug: Slow transformation
            print(f"Slow transform: {img_path} took {transform_time:.2f} seconds")

        encoded_label = torch.tensor([self.char_to_index[char] for char in label], dtype=torch.long)
        return image, encoded_label

    def decode_label(self, label_tensor):
        """Map a tensor of class ids back to text, skipping ids that are not in the vocabulary."""
        return "".join([self.index_to_char[idx.item()] for idx in label_tensor if idx.item() in self.index_to_char])

def collate_fn(batch, blank_index):
    """Collate ``(image, label)`` samples into one batch.

    Images are stacked into a single tensor along a new leading dimension.
    Labels have different lengths and are returned unpadded as a tuple of
    tensors. ``blank_index`` is accepted but not used.
    """
    image_list, label_seqs = zip(*batch)
    return torch.stack(image_list, dim=0), label_seqs

# Vocabulary: every character that occurs in a label gets a consecutive id;
# the id right after the last character is used as the CTC blank.
unique_chars = sorted({ch for text in labels.values() for ch in text})
char_to_index = {ch: pos for pos, ch in enumerate(unique_chars)}
blank_index = len(char_to_index)

# Preprocessing: grayscale replicated to 3 channels, fixed 224x224 size,
# tensor conversion, then per-channel normalisation.
transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=3),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Build the dataset and report the wall-clock time it took.
start = time.time()
dataset = HandwritingDataset(dataset_path, labels, char_to_index, transform)
print(f"Total dataset loading time: {time.time() - start:.2f} seconds")

# 80/10/10 split; seed first so the permutation is the same on every run.
torch.manual_seed(42)
total_len = len(dataset)
train_size = int(0.8 * total_len)
val_size = int(0.1 * total_len)
test_size = total_len - train_size - val_size  # whatever remains after train and val

indices = torch.randperm(total_len)
train_idx = indices[:train_size]
val_idx = indices[train_size:train_size + val_size]
test_idx = indices[train_size + val_size:]
train_dataset = Subset(dataset, train_idx)
val_dataset = Subset(dataset, val_idx)
test_dataset = Subset(dataset, test_idx)

# Only the training loader shuffles; both loaders use the same collate wrapper.
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=lambda x: collate_fn(x, blank_index),
    num_workers=2,
    pin_memory=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False,
    collate_fn=lambda x: collate_fn(x, blank_index),
    num_workers=2,
    pin_memory=True,
)


Dataset initialized with 44563 images in 108.88 seconds
Total dataset loading time: 108.88 seconds


In [None]:
# Device and model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Import your model here
# from model import ResNet_FPN_BiLSTM
# Output size is the vocabulary plus one extra class for the CTC blank.
model = ResNet_FPN_BiLSTM(len(char_to_index) + 1).to(device)

criterion = nn.CTCLoss(blank=blank_index, reduction='mean', zero_infinity=True)
optimizer = optim.Adam(model.parameters(), lr=0.001)

checkpoint_path = "/content/drive/MyDrive/MediPal_Final_Version/checkpoint_v2.pth"
best_model_path = "/content/drive/MyDrive/MediPal_Final_Version/best_model_v2.pth"

start_epoch = 0
best_val_loss = float("inf")
patience = 5
patience_counter = 0

# Resume training (weights, optimizer and early-stopping state) if a checkpoint exists
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    start_epoch = checkpoint["epoch"] + 1
    best_val_loss = checkpoint["best_val_loss"]
    patience_counter = checkpoint["patience_counter"]
    print(f"Resuming from epoch {start_epoch}")


def _ctc_forward(images, targets):
    """Run the model on one batch and return ``(ctc_loss, log_probs)`` with log_probs shaped (T, B, C)."""
    images = images.to(device)
    outputs = model(images)

    # Every sample uses the full output sequence length; targets are concatenated for CTC.
    input_lengths = torch.full((images.size(0),), outputs.size(1), dtype=torch.long).to(device)
    target_lengths = torch.tensor([len(t) for t in targets], dtype=torch.long).to(device)
    flattened_targets = torch.cat(list(targets)).to(device)

    log_probs = torch.nn.functional.log_softmax(outputs, dim=2).permute(1, 0, 2)
    return criterion(log_probs, flattened_targets, input_lengths, target_lengths), log_probs


def _count_exact_matches(log_probs, targets):
    """Greedy-decode a batch (argmax, collapse repeats, drop blanks) and count exact word matches."""
    matches = 0
    for pred, target in zip(log_probs.argmax(dim=2).permute(1, 0).cpu().numpy(), targets):
        pred_str = []
        prev = -1
        for idx in pred:
            if idx != blank_index and idx != prev:
                pred_str.append(dataset.index_to_char[idx])
            prev = idx
        if ''.join(pred_str) == dataset.decode_label(target):
            matches += 1
    return matches


# Training loop
for epoch in range(start_epoch, 50):
    print(f"\nEpoch {epoch + 1}/50")
    print("-" * 30)
    model.train()
    total_loss = 0
    correct_preds = 0
    total_samples = 0

    # The batch targets are deliberately not named `labels`: that would overwrite
    # the global filename -> text dictionary used by the dataset cells.
    for images, targets in train_loader:
        optimizer.zero_grad()
        loss, log_probs = _ctc_forward(images, targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        # Accuracy
        correct_preds += _count_exact_matches(log_probs, targets)
        total_samples += len(targets)

    train_acc = (correct_preds / total_samples) * 100
    print(f"Train Loss: {total_loss / len(train_loader):.4f}, Accuracy: {train_acc:.2f}%")

    # Validation
    model.eval()
    val_loss = 0
    correct_preds = 0
    total_samples = 0
    with torch.no_grad():
        for images, targets in val_loader:
            loss, log_probs = _ctc_forward(images, targets)
            val_loss += loss.item()

            correct_preds += _count_exact_matches(log_probs, targets)
            total_samples += len(targets)

    val_acc = (correct_preds / total_samples) * 100
    print(f"Val Loss: {val_loss / len(val_loader):.4f}, Accuracy: {val_acc:.2f}%")

    # Update the early-stopping state BEFORE writing the checkpoint, so a resumed
    # run restores the best loss / patience counter of this epoch, not the previous one.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        torch.save(model.state_dict(), best_model_path)
        print("Best model saved.")
    else:
        patience_counter += 1

    # Save checkpoint
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "best_val_loss": best_val_loss,
        "patience_counter": patience_counter
    }, checkpoint_path)

    if patience_counter >= patience:
        print("Early stopping triggered.")
        break

In [None]:
#DEBUG VERSION OF TRAINING LOOP
import os
import torch
import torch.nn as nn
import torch.optim as optim
import time
from datetime import datetime

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Model init (assumes you imported everything)
# Output size is the vocabulary plus one extra class for the CTC blank.
model = ResNet_FPN_BiLSTM_Attention(len(char_to_index) + 1).to(device)

criterion = nn.CTCLoss(blank=blank_index, reduction='mean', zero_infinity=True)
optimizer = optim.Adam(model.parameters(), lr=0.001)

checkpoint_path = "/content/drive/MyDrive/MediPal_Final_Version/checkpoint_v2_CNN_FPN_RNN_Attention.pth"
best_model_path = "/content/drive/MyDrive/MediPal_Final_Version/best_model_v2_CNN_FPN_RNN_Attention.pth"

start_epoch = 0
best_val_loss = float("inf")
patience = 5
patience_counter = 0
# Per-epoch history, stored in the checkpoint so curves survive a resume.
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []

# Resume training if checkpoint exists
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    start_epoch = checkpoint["epoch"] + 1
    best_val_loss = checkpoint["best_val_loss"]
    patience_counter = checkpoint["patience_counter"]
    train_losses = checkpoint.get("train_losses", [])
    val_losses = checkpoint.get("val_losses", [])
    train_accuracies = checkpoint.get("train_accuracies", [])
    val_accuracies = checkpoint.get("val_accuracies", [])
    print(f"Resuming from epoch {start_epoch}")


def _ctc_forward(images, targets):
    """Run the model on one batch and return ``(ctc_loss, log_probs)`` with log_probs shaped (T, B, C)."""
    images = images.to(device)
    outputs = model(images)

    # Every sample uses the full output sequence length; targets are concatenated for CTC.
    input_lengths = torch.full((images.size(0),), outputs.size(1), dtype=torch.long).to(device)
    target_lengths = torch.tensor([len(t) for t in targets], dtype=torch.long).to(device)
    flattened_targets = torch.cat(list(targets)).to(device)

    log_probs = torch.nn.functional.log_softmax(outputs, dim=2).permute(1, 0, 2)
    return criterion(log_probs, flattened_targets, input_lengths, target_lengths), log_probs


def _count_exact_matches(log_probs, targets):
    """Greedy-decode a batch (argmax, collapse repeats, drop blanks) and count exact word matches."""
    matches = 0
    for pred, target in zip(log_probs.argmax(dim=2).permute(1, 0).cpu().numpy(), targets):
        pred_str = []
        prev = -1
        for idx in pred:
            if idx != blank_index and idx != prev:
                pred_str.append(dataset.index_to_char[idx])
            prev = idx
        if ''.join(pred_str) == dataset.decode_label(target):
            matches += 1
    return matches


# Training loop
for epoch in range(start_epoch, 50):
    print(f"\n Epoch {epoch + 1}/50 — {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print("-" * 50)
    model.train()
    total_loss = 0
    correct_preds = 0
    total_samples = 0

    epoch_start = time.time()

    # The batch targets are deliberately not named `labels`: that would overwrite
    # the global filename -> text dictionary used by the dataset cells.
    for batch_idx, (images, targets) in enumerate(train_loader):
        # Progress line every 100 batches and on the final batch.
        if (batch_idx + 1) % 100 == 0 or (batch_idx + 1) == len(train_loader):
            print(f" Batch {batch_idx + 1}/{len(train_loader)}")

        optimizer.zero_grad()
        loss, log_probs = _ctc_forward(images, targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        # Accuracy
        correct_preds += _count_exact_matches(log_probs, targets)
        total_samples += len(targets)

    epoch_duration = time.time() - epoch_start
    train_acc = (correct_preds / total_samples) * 100
    avg_train_loss = total_loss / len(train_loader)
    train_losses.append(avg_train_loss)
    train_accuracies.append(train_acc)
    print(f" Epoch {epoch+1} completed in {epoch_duration:.2f}s")
    print(f" Train Loss: {avg_train_loss:.4f} | Accuracy: {train_acc:.2f}%")

    # Validation
    model.eval()
    val_loss = 0
    correct_preds = 0
    total_samples = 0
    val_start = time.time()

    with torch.no_grad():
        for images, targets in val_loader:
            loss, log_probs = _ctc_forward(images, targets)
            val_loss += loss.item()

            correct_preds += _count_exact_matches(log_probs, targets)
            total_samples += len(targets)

    val_acc = (correct_preds / total_samples) * 100
    avg_val_loss = val_loss / len(val_loader)
    val_losses.append(avg_val_loss)
    val_accuracies.append(val_acc)
    print(f" Validation done in {time.time() - val_start:.2f}s")
    print(f" Val Loss: {avg_val_loss:.4f} |  Val Accuracy: {val_acc:.2f}%")

    # Update the early-stopping state BEFORE writing the checkpoint, so a resumed
    # run restores the best loss / patience counter of this epoch, not the previous one.
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        patience_counter = 0
        torch.save(model.state_dict(), best_model_path)
        print("Best model saved!")
    else:
        patience_counter += 1
        print(f" No improvement. Patience counter: {patience_counter}/{patience}")

    # Save checkpoint
    print(" Saving checkpoint...")
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "best_val_loss": best_val_loss,
        "patience_counter": patience_counter,
        "train_losses": train_losses,
        "val_losses": val_losses,
        "train_accuracies": train_accuracies,
        "val_accuracies": val_accuracies
    }, checkpoint_path)

    if patience_counter >= patience:
        print(" Early stopping triggered.")
        break


Using device: cuda


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 195MB/s]



 Epoch 1/50 — 2025-06-20 05:13:03
--------------------------------------------------
Slow load: /content/drive/MyDrive/LY_Project/IAM_DATASET/iam_words/words/d03/d03-117/d03-117-03-05.png took 0.58 seconds
Slow load: /content/drive/MyDrive/LY_Project/IAM_DATASET/iam_words/words/a03/a03-080/a03-080-05-06.png took 0.75 seconds
Slow load: /content/drive/MyDrive/LY_Project/IAM_DATASET/iam_words/words/c03/c03-084b/c03-084b-01-03.png took 0.65 seconds
Slow load: /content/drive/MyDrive/LY_Project/IAM_DATASET/iam_words/words/b01/b01-004/b01-004-01-05.png took 0.55 seconds
Slow load: /content/drive/MyDrive/LY_Project/IAM_DATASET/iam_words/words/e07/e07-012/e07-012-10-05.png took 0.58 seconds
Slow load: /content/drive/MyDrive/LY_Project/IAM_DATASET/iam_words/words/a01/a01-053x/a01-053x-08-04.png took 0.52 seconds
Slow load: /content/drive/MyDrive/LY_Project/IAM_DATASET/iam_words/words/a01/a01-102/a01-102-05-05.png took 0.81 seconds
Slow load: /content/drive/MyDrive/LY_Project/IAM_DATASET/iam_wo

In [None]:
# Performance Metrics
# NOTE(review): `all_preds` and `unique_labels` are not assigned in any cell
# visible in this notebook, and `all_labels` is only set earlier as the list of
# ground-truth strings for the whole label file. Presumably a test-set
# evaluation cell that fills these is missing; confirm before relying on the
# numbers below.
# Accuracy: fraction of positions where the prediction equals the ground truth.
accuracy = np.mean(np.array(all_preds) == np.array(all_labels))
print(f"Test Accuracy: {accuracy * 100:.2f}%")
print(classification_report(all_labels, all_preds, target_names=unique_labels))

# Confusion Matrix (rows: actual, columns: predicted), drawn as an annotated heatmap
cm = confusion_matrix(all_labels, all_preds)
sns.heatmap(cm, annot=True, fmt='d', xticklabels=unique_labels, yticklabels=unique_labels)
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.show()
