## Dataset and Main Libraries

In [1]:
import os
import pandas as pd
from PIL import Image
import torch
import torch.nn as nn
import torchvision.models as models

In [2]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
# Image directory and the caption table read from CSV.
path = '../Handwritten_equations/Handwritten/Dataset'
df = pd.read_csv('caption_data.csv')

# Column1 is turned into an image file name by appending the ".bmp" extension.
df = df.assign(Column1=df["Column1"].astype(str) + ".bmp")
df.head()

Unnamed: 0,Column1,Column2
0,18_em_0.bmp,x _ { k } x x _ { k } + y _ { k } y x _ { k }
1,18_em_10.bmp,2 6
2,18_em_11.bmp,q _ { t } = 2 q
3,18_em_12.bmp,\frac { p e ^ { t } } { 1 - ( 1 - p ) e ^ { t } }
4,18_em_13.bmp,4 ^ { 2 } + 4 ^ { 2 } + \frac { 4 } { 4 }


## Training Function

In [4]:
def train(model, dataloader, optimizer, criterion, tokenizer, device, clip_grad_norm=1.0):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: called as ``model(images, labels, teacher_forcing_ratio=0.5)``; its output is
            sliced as ``[batch, step, vocab]`` logits.
        dataloader: iterable yielding ``(images, labels)`` tensor pairs.
        optimizer: optimizer whose ``zero_grad``/``step`` are called once per scored batch.
        criterion: loss called as ``criterion(logits[N, vocab], targets[N])``.
        tokenizer: object exposing ``token_to_id``; only the ``<PAD>`` id is read (default 0).
        device: device every batch is moved to.
        clip_grad_norm: maximum gradient norm passed to ``clip_grad_norm_``.

    Returns:
        float: loss averaged over the batches that had at least one non-padding target,
        or 0.0 when no batch was scored.
    """
    model.train()
    # Loop-invariant, so it is looked up once rather than per batch.
    pad_token_id = tokenizer.token_to_id.get("<PAD>", 0)

    total_loss = 0.0
    scored_batches = 0
    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        # Forward pass with teacher forcing
        outputs = model(images, labels, teacher_forcing_ratio=0.5)

        # Drop sequence position 0 so the first token of each label is never a target.
        logits = outputs[:, 1:, :]
        targets = labels[:, 1:]

        # Boolean mask [B, seq_len-1] selecting the non-padding targets.
        keep = targets != pad_token_id
        if not keep.any():
            # An empty selection would make a mean-reduced loss NaN and poison the
            # epoch average, so a batch with nothing to score is skipped.
            continue

        # Masked indexing flattens to [N, vocab] / [N] in row-major order.
        loss = criterion(logits[keep], targets[keep])
        loss.backward()

        # Gradient clipping for stability
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)

        optimizer.step()
        total_loss += loss.item()
        scored_batches += 1

    return total_loss / scored_batches if scored_batches else 0.0


## Validation Function

In [5]:
def validate(model, dataloader, criterion, tokenizer, device):
    """Score the model on a dataloader without teacher forcing and return the mean batch loss.

    Args:
        model: called as ``model(images, targets=None, teacher_forcing_ratio=0.0)``; its output
            is sliced as ``[batch, step, vocab]`` logits.
        dataloader: iterable yielding ``(images, labels)`` tensor pairs.
        criterion: loss called as ``criterion(logits[N, vocab], targets[N])``.
        tokenizer: object exposing ``token_to_id``; only the ``<PAD>`` id is read (default 0).
        device: device every batch is moved to.

    Returns:
        float: loss averaged over the batches that had at least one non-padding target,
        or 0.0 when no batch was scored.
    """
    model.eval()
    pad_token_id = tokenizer.token_to_id.get("<PAD>", 0)

    total_loss = 0.0
    scored_batches = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)

            # Forward pass without teacher forcing (autoregressive generation)
            outputs = model(images, targets=None, teacher_forcing_ratio=0.0)

            # The model is not given the labels, so its output length is not guaranteed to
            # match them; only the positions both tensors share can be scored. Without this
            # truncation a length mismatch makes the boolean mask below raise IndexError.
            steps = min(outputs.size(1), labels.size(1))

            # Drop sequence position 0 so the first token of each label is never a target.
            logits = outputs[:, 1:steps, :]
            targets = labels[:, 1:steps]

            # Boolean mask selecting the non-padding targets.
            keep = targets != pad_token_id
            if not keep.any():
                # Nothing to score in this batch.
                continue

            total_loss += criterion(logits[keep], targets[keep]).item()
            scored_batches += 1

    # Average over scored batches only, so skipped batches do not dilute the result.
    return total_loss / scored_batches if scored_batches else 0.0


## Prediction Function

In [6]:
def predict(model, image, tokenizer, device, max_len=50):
    """Greedily decode one image and return the text produced by ``tokenizer.decode``.

    Args:
        model: object exposing ``encoder`` and ``decoder`` callables and ``eval()``.
        image: a single un-batched image tensor; a batch dimension is added here.
        tokenizer: object exposing ``token_to_id`` and ``decode``.
        device: device the image and the start token are moved to.
        max_len: maximum number of decoding steps.

    Returns:
        str: decoded string for the tokens emitted before ``<EOS>`` (or before ``max_len``
        steps are used up); the empty string when no token was emitted.
    """
    model.eval()
    start_id = tokenizer.token_to_id.get("<SOS>", 1)
    stop_id = tokenizer.token_to_id.get("<EOS>", 2)

    with torch.no_grad():
        batch = image.to(device).unsqueeze(0)
        encoder_out = model.encoder(batch)

        # Decoder state: h is the encoder output averaged over dim 1 with a leading
        # dimension added, c is all zeros of the same shape.
        state_h = encoder_out.mean(dim=1).unsqueeze(0)
        state = (state_h, torch.zeros_like(state_h))

        step_input = torch.tensor([start_id]).to(device)
        token_ids = []
        for _ in range(max_len):
            scores, state, _ = model.decoder(step_input, state, encoder_out)
            # The arg-max token is both the candidate output and the next decoder input.
            step_input = scores.argmax(1)
            next_id = step_input.item()
            if next_id == stop_id:
                break
            token_ids.append(next_id)

        if not token_ids:
            return ""
        return tokenizer.decode(token_ids)


## Use of Tokenizer and Creating Dataset

In [9]:
from torchvision import transforms

from Modules.Tokenizer import Tokenizer
from Modules.EquationSeqDataset import EquationSeqDataset


# 1. Build the vocabulary from the caption column
tokenizer = Tokenizer()
tokenizer.fit(df['Column2'])

# Make sure the special tokens used by train/validate/predict are in the vocabulary.
special_tokens = ["<PAD>", "<SOS>", "<EOS>", "<UNK>"]
for token in special_tokens:
    if token not in tokenizer.token_to_id:
        tokenizer.add_token(token)  # or the equivalent tokenizer method

# 2. Image pipeline.
# No horizontal flip: mirroring a handwritten equation reverses its reading order while the
# caption stays left-to-right, so a flipped image no longer matches its label.
# NOTE(review): RandomRotation is random, and this one transform is attached to the single
# `dataset` that is later split into train and test subsets, so the held-out images are
# augmented too. Consider a separate deterministic transform for evaluation.
transform = transforms.Compose([
    transforms.Resize((224, 224)),                # Fixed input size
    transforms.RandomRotation(15),                # Small rotation, up to +/-15 degrees
    transforms.ToTensor(),
    transforms.Normalize(                         # Normalize with ImageNet stats (if using ResNet pretrained)
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

# 3. Instantiate the dataset with the transform above
dataset = EquationSeqDataset(df, path, tokenizer, transform=transform)

## Custom Dataloader

In [11]:
import random
import numpy as np
from torch.utils.data import random_split, DataLoader

from Modules.custom_collate_fn import custom_collate_fn

# 1. Fix random seed for reproducibility and stable split
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) RNGs and make cuDNN deterministic.

    Args:
        seed: integer handed to every seeding function.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(seed)

    # Turn cuDNN auto-tuning off and request deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Seed every RNG first so the random split below is the same on each run of this cell.
set_seed()

# 2. Stratified split (optional, if your labels have classes)
# If your dataset has class imbalance, consider stratified splitting
# via sklearn's StratifiedShuffleSplit or custom logic.

# 3. Random 80/20 split into a training subset and a held-out subset.
train_ratio = 0.8
train_size = int(train_ratio * len(dataset))
test_size = len(dataset) - train_size

train_data, test_data = random_split(dataset, [train_size, test_size])

batch_size = 16

# Settings shared by both loaders; only the shuffling differs.
loader_kwargs = dict(
    batch_size=batch_size,
    pin_memory=True,
    collate_fn=custom_collate_fn,
)

train_loader = DataLoader(train_data, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_data, shuffle=False, **loader_kwargs)

## Model Building

In [12]:
import torch.optim as optim

from Modules.Encoder import CNNEncoder
from Modules.Decoder import RNNDecoder
from Modules.Sequence import Seq2Seq
from Modules.EarlyStopping import EarlyStopping

# Encoder and decoder are built on the target device, then wrapped in the seq2seq model.
encoder = CNNEncoder(output_dim=256).to(device)
decoder = RNNDecoder(
    hidden_dim=256,
    vocab_size=tokenizer.vocab_size(),
).to(device)
model = Seq2Seq(encoder, decoder, device).to(device)

# Cross-entropy that ignores targets equal to the <PAD> id (0 when the tokenizer has none).
pad_id = tokenizer.token_to_id.get("<PAD>", 0)
criterion = nn.CrossEntropyLoss(ignore_index=pad_id).to(device)

# AdamW with a small weight decay.
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

# Tracks a loss value across epochs and checkpoints to 'best_model.pth'.
early_stopping = EarlyStopping(
    patience=4,
    min_delta=0.001,
    save_path='best_model.pth',
)



## Training Loop

In [None]:
import time
from torch import amp
from torch.cuda.amp import GradScaler

epochs = 50
train_loss_values = []
val_loss_values = []  # NOTE(review): never filled - this loop does not run validation

# Mixed precision is only enabled on CUDA; on CPU the scaler and autocast become no-ops
# instead of emitting warnings about an unavailable 'cuda' device type.
use_amp = device.type == 'cuda'
scaler = amp.GradScaler('cuda', enabled=use_amp)

for epoch in range(epochs):
    start = time.time()
    print(f"Epoch {epoch + 1}/{epochs}")

    # --- Training ---
    model.train()
    total_train_loss = 0

    for batch in train_loader:
        images, labels = batch[0], batch[1]
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()

        with amp.autocast(device_type=device.type, enabled=use_amp):
            outputs = model(images, labels, teacher_forcing_ratio=0.5)
            # Drop sequence position 0. This is a *position*, not a token id: slicing by the
            # <SOS> id only worked by coincidence while that id happened to be 1.
            outputs_shifted = outputs[:, 1:, :].contiguous()
            labels_shifted = labels[:, 1:]
            outputs_flat = outputs_shifted.reshape(-1, tokenizer.vocab_size())
            labels_flat = labels_shifted.reshape(-1)
            # Padding targets are excluded by the criterion's ignore_index.
            loss = criterion(outputs_flat, labels_flat)

        scaler.scale(loss).backward()
        # Gradients must be unscaled before clipping so the norm is measured in real units.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Gradient clipping
        scaler.step(optimizer)
        scaler.update()

        total_train_loss += loss.item()

    avg_train_loss = total_train_loss / len(train_loader) if len(train_loader) > 0 else 0.0
    train_loss_values.append(avg_train_loss)
    print(f"Train Loss: {avg_train_loss:.4f}")

    # Report the epoch time before the early-stopping check so the final epoch is timed too.
    end = time.time()
    print(f"Epoch Time: {end - start:.2f} seconds\n")

    # Early stopping
    # NOTE(review): this monitors the *training* loss because no validation loss is computed
    # here; consider feeding it validate(...) on a held-out loader instead.
    early_stopping(avg_train_loss, model)
    if early_stopping.early_stop:
        print("Early stopping triggered. Loading best model weights...")
        # map_location keeps the reload working when the checkpoint device differs.
        model.load_state_dict(torch.load(early_stopping.save_path, map_location=device))
        break

print("Training completed!")

  scaler = GradScaler()


Epoch 1/50


## Prediction on Test Dataset

In [None]:
from difflib import SequenceMatcher

model.eval()
correct = 0
total = 0

# Special-token ids, read the same way as in train/validate/predict.
pad_id = tokenizer.token_to_id.get("<PAD>", 0)
eos_id = tokenizer.token_to_id.get("<EOS>", 2)

with torch.no_grad():
    # This section reports on the held-out split, so it must iterate test_loader.
    for batch in test_loader:
        images, labels = batch[0], batch[1]
        images, labels = images.to(device), labels.to(device)

        # Iterate samples directly; the global `batch_size` is reused by a later cell
        # and must not be overwritten here.
        for image, label in zip(images, labels):
            # The reference ends at the first <EOS>, or at the non-padding length if none.
            eos_positions = (label == eos_id).nonzero(as_tuple=True)[0]
            if eos_positions.numel() > 0:
                end_index = eos_positions[0].item()
            else:
                end_index = (label != pad_id).sum().item()

            # Skip sequence position 0 (a position, not the <SOS> token id).
            true_tokens = label[1:end_index]
            true_latex = tokenizer.decode(true_tokens.cpu().tolist())  # Move to CPU for decode

            # predict() moves the image to `device` itself, so no CPU round trip is needed.
            predicted_latex = predict(model, image, tokenizer, device)

            similarity = SequenceMatcher(None, predicted_latex.strip(), true_latex.strip()).ratio()

            # Adjust threshold or use more sophisticated metrics (BLEU, ROUGE, etc.)
            if similarity >= 0.60:
                correct += 1
            total += 1

accuracy = (correct / total) * 100 if total > 0 else 0.0
print(f"Approximate Match Accuracy: {accuracy:.2f}%")


## Prediction on Train Dataset

In [None]:
from difflib import SequenceMatcher

model.eval()
correct = 0
total = 0

# Special-token ids, read the same way as in train/validate/predict.
pad_id = tokenizer.token_to_id.get("<PAD>", 0)
eos_id = tokenizer.token_to_id.get("<EOS>", 2)

with torch.no_grad():
    for batch in train_loader:
        images, labels = batch[0], batch[1]
        images, labels = images.to(device), labels.to(device)

        # Iterate samples directly; the global `batch_size` is reused by a later cell
        # and must not be overwritten here.
        for image, label in zip(images, labels):
            # The reference ends at the first <EOS>, or at the non-padding length if none.
            eos_positions = (label == eos_id).nonzero(as_tuple=True)[0]
            if eos_positions.numel() > 0:
                end_index = eos_positions[0].item()
            else:
                end_index = (label != pad_id).sum().item()

            # Skip sequence position 0 (a position, not the <SOS> token id).
            true_tokens = label[1:end_index]
            true_latex = tokenizer.decode(true_tokens.cpu().tolist())  # Move to CPU for decode

            # predict() moves the image to `device` itself, so no CPU round trip is needed.
            predicted_latex = predict(model, image, tokenizer, device)

            similarity = SequenceMatcher(None, predicted_latex.strip(), true_latex.strip()).ratio()

            # Adjust threshold or use more sophisticated metrics (BLEU, ROUGE, etc.)
            if similarity >= 0.60:
                correct += 1
            total += 1

accuracy = (correct / total) * 100 if total > 0 else 0.0
print(f"Approximate Match Accuracy: {accuracy:.2f}%")


## Visualisation of Loss

In [None]:
import matplotlib.pyplot as plt
# One x position per recorded epoch: early stopping can end training before the configured
# maximum, and a fixed 50-point axis would then not match the loss list's length.
epochs_run = range(1, len(train_loss_values) + 1)
plt.figure(figsize=(10, 5))
plt.plot(epochs_run, train_loss_values, linestyle='-', color='b', label="Training Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
# Only the training curve is drawn, so the title no longer mentions validation.
plt.title("Training Loss vs. Epoch")
plt.legend()
plt.grid()
plt.show()

## Visualisation of Some Images and Their Predictions

In [None]:
# Carve a small random sample (0.5% of the dataset) out for a qualitative look at predictions.
check = 0.005
check1_size = int(len(dataset) * check)
check2_size = len(dataset) - check1_size
check1, check2 = random_split(dataset, lengths=[check1_size, check2_size])

# Unshuffled loader over the small sample only; the remainder (check2) is not used here.
check_loader = DataLoader(
    check1,
    collate_fn=custom_collate_fn,
    pin_memory=True,
    shuffle=False,
    batch_size=batch_size,
)


In [None]:
model.eval()

# Special-token ids, read the same way as in train/validate/predict.
pad_id = tokenizer.token_to_id.get("<PAD>", 0)
eos_id = tokenizer.token_to_id.get("<EOS>", 2)

with torch.no_grad():
    for batch in check_loader:
        images, labels = batch[0], batch[1]
        images, labels = images.to(device), labels.to(device)

        for image, label in zip(images, labels):
            # The reference ends at the first <EOS>, or at the non-padding length if none.
            eos_positions = (label == eos_id).nonzero(as_tuple=True)[0]
            if eos_positions.numel() > 0:
                end_index = eos_positions[0].item()
            else:
                end_index = (label != pad_id).sum().item()

            # Extract ground-truth token sequence, skipping sequence position 0
            # (a position, not the <SOS> token id).
            true_tokens = label[1:end_index]
            true_latex = tokenizer.decode(true_tokens.cpu().tolist())

            # Predict LaTeX from model; predict() adds the batch dimension itself.
            predicted_latex = predict(model, image, tokenizer, device)

            # Log output
            print(f"True Tokens        : {true_tokens.cpu().tolist()}")
            print(f"Ground Truth LaTeX : {true_latex}")
            print(f"Predicted LaTeX    : {predicted_latex}")
            print('\n')


## Model Saving

In [None]:
import torch
import pickle

# Special-token ids stored next to the weights in the checkpoint.
sos_token_id = tokenizer.token_to_id.get("<SOS>", 1)
eos_token_id = tokenizer.token_to_id.get("<EOS>", 2)

# Assemble the checkpoint payload first, then write it in one call.
checkpoint = {
    'model_state_dict': model.state_dict(),
    'sos_token_id': sos_token_id,
    'eos_token_id': eos_token_id,
}
torch.save(checkpoint, 'model_checkpoint.pth')

# The tokenizer object is pickled to its own file.
with open('tokenizer.pkl', 'wb') as tokenizer_file:
    pickle.dump(tokenizer, tokenizer_file)

print("Model and tokenizer saved.")
