# Install Necessary Libraries for the Project

In [None]:
!pip install torch datasets accelerate trl jiwer

Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting trl
  Downloading trl-0.16.1-py3-none-any.whl.metadata (12 kB)
Collecting jiwer
  Downloading jiwer-3.1.0-py3-none-any.whl.metadata (2.6 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Coll

# Load and Prepare the Dataset

In [1]:
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms  # Add this import
import random
import torch
# Define the Dataset with transform
class HandwrittenMathDataset(Dataset):
    def __init__(self, image_directory, labels_file, transform=None):
        self.image_paths = []
        self.latex_sequences = []
        self.transform = transform  # Store the transform

        with open(labels_file, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                parts = line.split('\t')
                if len(parts) == 2:
                    image_filename, latex_seq = parts
                    image_path = os.path.join(image_directory, image_filename)
                    if os.path.exists(image_path):
                        self.image_paths.append(image_path)
                        self.latex_sequences.append(latex_seq)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert('L')  # PIL Image

        # Apply transform if defined
        if self.transform:
            image = self.transform(image)

        latex_seq = self.latex_sequences[idx]
        return image, latex_seq

from google.colab import drive
drive.mount('/content/drive')
folder_path = '/content/drive/MyDrive/3312_images/'

# Define the transform
transform = transforms.Compose([
    transforms.ToTensor()  # Convert PIL Image to PyTorch Tensor
])

# Create dataset with transform
dataset = HandwrittenMathDataset(
    image_directory=folder_path + "synthetic_images",
    labels_file=folder_path + "synthetic_labels.txt",
    transform=transform  # Add the transform here
)

# train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
# print(f"Dataset size: {len(train_dataset)}")
# print(train_dataset.__getitem__(0))

def create_train_test_split(dataset, test_size=0.2, random_state=42):
    """
    Split the dataset into training and testing sets

    Args:
        dataset (ImageTextDataset): The dataset to split
        test_size (float): Proportion of the dataset to include in the test split
        random_state (int): Random seed for reproducibility

    Returns:
        tuple: (train_dataset, test_dataset)
    """
    # Method 1: Using PyTorch's random_split
    train_size = int((1 - test_size) * len(dataset))
    test_size = len(dataset) - train_size

    train_dataset, test_dataset = random_split(
        dataset,
        [train_size, test_size],
        generator=torch.Generator().manual_seed(random_state)
    )

    return train_dataset, test_dataset
train_dataset, test_dataset = create_train_test_split(dataset, test_size=0.2)

# Create dataloaders
batch_size = 16
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

Mounted at /content/drive


In [2]:
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import defaultdict
def build_vocab(labels_file, vocab_size=500):
    # Collect all characters
    all_chars = defaultdict(int)
    special_tokens = ['<start>', '<end>', '<pad>']
    with open(labels_file, 'r') as f:
        for line in f:
            line = line.strip()
            if line:
                split_line = line.split('\t')
                if(len(split_line) > 2):
                    latex = "\t".join(split_line[1:len(split_line)])
                else:
                    latex = split_line[1]
                for char in latex:
                    all_chars[char] += 1
    # Sort characters by frequency
    sorted_chars = sorted(all_chars.items(), key=lambda x: x[1], reverse=True)
    # Assign indices: 0 for <pad>, 1 for <start>, 2 for <end>, then others
    vocab = {'<pad>': 0, '<start>': 1, '<end>': 2}
    idx = 3
    for char, _ in sorted_chars:
        if char not in vocab and idx < vocab_size:
            vocab[char] = idx
            idx += 1
    return vocab
folder_path = '/content/drive/MyDrive/3312_images/'
labels_file = folder_path + "synthetic_labels.txt"
vocab = build_vocab(labels_file, vocab_size=500)
vocab_size = len(vocab)
print(f"Vocabulary size: {vocab_size}")

Vocabulary size: 95


In [3]:
def string_to_tensor(string_list, vocab):
    max_length = 0
    all_indices = []
    for string in string_list:
        # Convert characters to indices, using vocab.get(char, 0) for unknowns (0 is <pad>)
        indices = [vocab.get('<start>')] + [vocab.get(char, 0) for char in string] + [vocab.get('<end>')]
        all_indices.append(indices)
        max_length = max(max_length, len(indices))
    # Pad sequences to max_length
    padded_indices = [
        seq + [vocab['<pad>']] * (max_length - len(seq)) for seq in all_indices
    ]
    tensor = torch.tensor(padded_indices, dtype=torch.long).t()  # [seq_len, batch_size]
    return tensor

def tensor_to_string(tensor, vocab):
    """Convert a tensor of token indices to strings"""
    # Get index-to-token mapping (reverse of the vocabulary)
    idx_to_token = {idx: token for token, idx in vocab.items()}

    # If tensor is [T, B], convert to [B, T] for batch processing
    if tensor.dim() == 2:
        tensor = tensor.transpose(0, 1)

    batch_texts = []
    for sequence in tensor:
        tokens = [idx_to_token.get(idx.item(), "") for idx in sequence]
        # Stop at end-of-sequence token if present
        if "<eos>" in tokens:
            tokens = tokens[:tokens.index("<eos>")]
        text = "".join(tokens)
        batch_texts.append(text)

    return batch_texts

# Project Code

Step 1.) CNN or Transformer based Image -> Latex conversion

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Define the Encoder (CNN)
class CNNEncoder(nn.Module):
    def __init__(self, encoded_image_size=14):
        super(CNNEncoder, self).__init__()
        # TODO: Should we replace this with a pretrained model instead?
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.encoded_image_size = encoded_image_size

    def forward(self, images):
        x = F.relu(self.conv1(images))   # [B, 32, H, W]
        x = self.pool(x)                 # Downsample
        x = F.relu(self.conv2(x))        # [B, 64, H/2, W/2]
        x = self.pool(x)                 # Downsample further
        print(x.shape)
        batch_size, channels, height, width = x.size()
        x = x.view(batch_size, channels, -1)  # Flatten spatial dimensions: [B, C, N]
        x = x.permute(0, 2, 1)           # [B, N, C] for the transformer encoder input
        return x

# Define the Decoder (Transformer)
class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, d_model=256, num_layers=2, nhead=8, dropout=0.1, max_seq_length=100):
        super(TransformerDecoder, self).__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Embedding(max_seq_length, d_model)
        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, dropout=dropout)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, tgt, memory):
        # tgt: [T, B] and memory: [S, B, E]
        T, B = tgt.size()
        positions = torch.arange(0, T).unsqueeze(1).expand(T, B).to(tgt.device)
        positions = positions.clamp(0, self.pos_embedding.num_embeddings - 1)  # Clamp positions
        tgt_emb = self.embedding(tgt) + self.pos_embedding(positions)
        tgt_emb = self.dropout(tgt_emb)
        # Generate a mask to prevent attention to future tokens
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(T).to(tgt.device)
        output = self.transformer_decoder(tgt_emb, memory, tgt_mask=tgt_mask)
        output = self.fc_out(output)
        return output

# Combine Encoder and Decoder into one Model
class HandwrittenMathToLatexModel(nn.Module):
    def __init__(self, vocab_size, d_model=256):
        super(HandwrittenMathToLatexModel, self).__init__()
        self.encoder = CNNEncoder()
        self.decoder = TransformerDecoder(vocab_size=vocab_size, d_model=d_model)
        # Project encoder output to match decoder d_model if needed
        self.enc_to_dec = nn.Linear(64, d_model)

    def forward(self, images, tgt_seq):
        # images: [B, 1, H, W]
        # tgt_seq: [T, B]
        enc_out = self.encoder(images)  # [B, N, 64]
        enc_out = self.enc_to_dec(enc_out)  # [B, N, d_model]
        # Transformer expects: [S, B, E]
        enc_out = enc_out.permute(1, 0, 2)
        output = self.decoder(tgt_seq, enc_out)  # [T, B, vocab_size]
        return output

# Training Loop
def train(model, dataloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    for images, target_seq in dataloader:
        images = images.to(device)
        # Get the tensor from the tuple returned by string_to_tensor
        target_seq = string_to_tensor(target_seq, vocab).to(device)
        target_seq = target_seq.to(device)
        # Assume target_seq is size [T, B]
        optimizer.zero_grad()
        # Shift target sequence for teacher forcing
        input_seq = target_seq[:-1, :]
        output = model(images, input_seq)
        # Compute loss between output and target_seq[1:,:]
        loss = criterion(output.reshape(-1, output.shape[-1]), target_seq[1:, :].reshape(-1))
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(dataloader)
def test(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for images, target_seq in dataloader:
            images = images.to(device)
            target_seq = string_to_tensor(target_seq, vocab).to(device)
            # Shift target sequence for decoder input (same as in training)
            input_seq = target_seq[:-1, :]
            output = model(images, input_seq)

            # _, predicted_indices = torch.max(output, dim=2)
            # predicted_text = tensor_to_string(predicted_indices, vocab)
            # print(f"Predicted: {predicted_text}")

            # Calculate loss against the full target sequence (offset by 1)
            loss = criterion(output.reshape(-1, output.shape[-1]),
                            target_seq[1:, :].reshape(-1))
            running_loss += loss.item()
    return running_loss / len(dataloader)


# Initialize Hyperparameters
batch_size = 16
learning_rate = 1e-3
num_epochs = 10

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = HandwrittenMathToLatexModel(vocab_size=vocab_size).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=0)  # Adjust ignore_index if needed (for padding)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    train_loss = train(model, train_loader, criterion, optimizer, device)
    test_loss = test(model, test_loader, criterion, device)
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")


torch.Size([16, 64, 64, 64])
torch.Size([16, 64, 64, 64])
torch.Size([16, 64, 64, 64])
torch.Size([16, 64, 64, 64])
torch.Size([16, 64, 64, 64])
torch.Size([16, 64, 64, 64])
torch.Size([16, 64, 64, 64])
torch.Size([16, 64, 64, 64])
torch.Size([16, 64, 64, 64])
torch.Size([16, 64, 64, 64])
torch.Size([16, 64, 64, 64])
torch.Size([16, 64, 64, 64])
torch.Size([16, 64, 64, 64])
torch.Size([16, 64, 64, 64])
torch.Size([16, 64, 64, 64])
torch.Size([16, 64, 64, 64])
torch.Size([16, 64, 64, 64])
torch.Size([16, 64, 64, 64])
torch.Size([16, 64, 64, 64])
torch.Size([16, 64, 64, 64])
torch.Size([16, 64, 64, 64])
torch.Size([16, 64, 64, 64])
torch.Size([16, 64, 64, 64])
torch.Size([16, 64, 64, 64])
torch.Size([16, 64, 64, 64])
torch.Size([16, 64, 64, 64])
torch.Size([16, 64, 64, 64])
torch.Size([16, 64, 64, 64])
torch.Size([16, 64, 64, 64])
torch.Size([16, 64, 64, 64])
torch.Size([16, 64, 64, 64])
torch.Size([16, 64, 64, 64])
torch.Size([16, 64, 64, 64])
torch.Size([16, 64, 64, 64])
torch.Size([16

In [None]:
import math
class ImprovedCNNEncoder(nn.Module):
    def __init__(self, encoded_image_size=14):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),

            nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size))
        )

    def forward(self, images):
        features = self.cnn(images)  # [B, 256, 14, 14]
        features = features.flatten(2).permute(0, 2, 1)  # [B, 196, 256]
        return features

class ImprovedTransformerDecoder(nn.Module):
    def __init__(self, vocab_size, d_model=512, num_layers=4, nhead=8, dropout=0.2, max_seq_length=5000):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout, max_len=max_seq_length)
        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward=d_model*4, dropout=dropout)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, tgt, memory, tgt_mask=None):
        tgt_emb = self.embedding(tgt) * math.sqrt(self.embedding.embedding_dim)
        tgt_emb = self.pos_encoder(tgt_emb)
        output = self.decoder(tgt_emb, memory, tgt_mask=tgt_mask)
        return self.fc_out(output)

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)


class ImprovedHandwrittenMathToLatexModel(nn.Module):
    def __init__(self, vocab_size, d_model=512):
        super().__init__()
        self.encoder = ImprovedCNNEncoder()
        self.decoder = ImprovedTransformerDecoder(vocab_size, d_model)
        self.enc_proj = nn.Linear(256, d_model)

    def forward(self, images, tgt_seq):
        enc_out = self.encoder(images)  # [B, N, 256]
        enc_out = self.enc_proj(enc_out)  # [B, N, d_model]
        enc_out = enc_out.permute(1, 0, 2)  # [N, B, d_model]
        output = self.decoder(tgt_seq, enc_out)
        return output

# Modify hyperparameters
d_model = 512
num_epochs = 10
batch_size = 16
learning_rate = 3e-4
weight_decay = 1e-5

# Use learning rate scheduler
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ImprovedHandwrittenMathToLatexModel(vocab_size=vocab_size).to(device)
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=learning_rate,
                                        steps_per_epoch=len(train_loader), epochs=num_epochs)

# Add label smoothing
criterion = nn.CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)
def train(model, dataloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    for images, target_seq in dataloader:
        images = images.to(device)
        # Get the tensor from the tuple returned by string_to_tensor
        target_seq = string_to_tensor(target_seq, vocab).to(device)
        target_seq = target_seq.to(device)
        # Assume target_seq is size [T, B]
        optimizer.zero_grad()
        # Shift target sequence for teacher forcing
        input_seq = target_seq[:-1, :]
        output = model(images, input_seq)
        # Compute loss between output and target_seq[1:,:]
        loss = criterion(output.reshape(-1, output.shape[-1]), target_seq[1:, :].reshape(-1))
        loss.backward()
        optimizer.step()
        scheduler.step()
        running_loss += loss.item()
    return running_loss / len(dataloader)
def test(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for images, target_seq in dataloader:
            images = images.to(device)
            target_seq = string_to_tensor(target_seq, vocab).to(device)
            # Shift target sequence for decoder input (same as in training)
            input_seq = target_seq[:-1, :]
            output = model(images, input_seq)
            # Calculate loss against the full target sequence (offset by 1)
            loss = criterion(output.reshape(-1, output.shape[-1]),
                            target_seq[1:, :].reshape(-1))
            running_loss += loss.item()
    return running_loss / len(dataloader)

for epoch in range(num_epochs):
    train_loss = train(model, train_loader, criterion, optimizer, device)
    test_loss = test(model, test_loader, criterion, device)
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")


Epoch [1/10], Train Loss: 3.6705, Test Loss: 3.0987
Epoch [2/10], Train Loss: 2.9336, Test Loss: 2.7483
Epoch [3/10], Train Loss: 2.7144, Test Loss: 2.6744
Epoch [4/10], Train Loss: 2.5620, Test Loss: 2.5311
Epoch [5/10], Train Loss: 2.4619, Test Loss: 2.4439
Epoch [6/10], Train Loss: 2.3683, Test Loss: 2.3456
Epoch [7/10], Train Loss: 2.2839, Test Loss: 2.3125
Epoch [8/10], Train Loss: 2.2357, Test Loss: 2.2829
Epoch [9/10], Train Loss: 2.1945, Test Loss: 2.2683
Epoch [10/10], Train Loss: 2.1663, Test Loss: 2.2675


In [None]:
# Define the Encoder (Transformer)
class TransformerEncoder(nn.Module):
    def __init__(self, d_model=256, nhead=8, num_encoder_layers=3, dim_feedforward=1024, dropout=0.1):
        super(TransformerEncoder, self).__init__()
        # Initial embedding layer to convert image patches to embeddings
        self.patch_embedding = nn.Conv2d(1, d_model, kernel_size=4, stride=4)

        # Positional encoding for patches
        self.pos_embedding = nn.Parameter(torch.zeros(1, 4096, d_model))

        # Transformer encoder layers
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)

        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)

    def forward(self, images):
        # images: [B, 1, H, W]
        batch_size = images.size(0)

        # Create patch embeddings [B, d_model, h, w]
        x = self.patch_embedding(images)

        # Reshape to [B, d_model, N] where N is number of patches
        h, w = x.shape[-2:]
        x = x.reshape(batch_size, self.d_model, h*w)

        # Permute to [B, N, d_model] for transformer input
        x = x.permute(0, 2, 1)

        # Add positional embeddings
        x = x + self.pos_embedding[:, :x.size(1), :]

        # Apply dropout
        x = self.dropout(x)

        # Transformer expects: [N, B, E]
        x = x.permute(1, 0, 2)

        # Pass through transformer encoder
        memory = self.transformer_encoder(x)

        return memory  # [N, B, E]

# Define the Decoder (Transformer)
class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, d_model=256, num_layers=2, nhead=8, dropout=0.1, max_seq_length=100):
        super(TransformerDecoder, self).__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Embedding(max_seq_length, d_model)
        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, dropout=dropout)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, tgt, memory):
        # tgt: [T, B] and memory: [S, B, E]
        T, B = tgt.size()
        positions = torch.arange(0, T).unsqueeze(1).expand(T, B).to(tgt.device)
        positions = positions.clamp(0, self.pos_embedding.num_embeddings - 1)  # Clamp positions
        tgt_emb = self.embedding(tgt) + self.pos_embedding(positions)
        tgt_emb = self.dropout(tgt_emb)
        # Generate a mask to prevent attention to future tokens
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(T).to(tgt.device)
        output = self.transformer_decoder(tgt_emb, memory, tgt_mask=tgt_mask)
        output = self.fc_out(output)
        return output

# Combine Encoder and Decoder into one Model
class HandwrittenMathToLatexModel(nn.Module):
    def __init__(self, vocab_size, d_model=256):
        super(HandwrittenMathToLatexModel, self).__init__()
        self.encoder = TransformerEncoder(d_model=d_model)
        self.decoder = TransformerDecoder(vocab_size=vocab_size, d_model=d_model)

    def forward(self, images, tgt_seq):
        # images: [B, 1, H, W]
        # tgt_seq: [T, B]
        memory = self.encoder(images)  # Already in shape [S, B, E]
        output = self.decoder(tgt_seq, memory)  # [T, B, vocab_size]
        return output

# Training Loop
def train(model, dataloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    for images, target_seq in dataloader:
        images = images.to(device)
        # Get the tensor from the tuple returned by string_to_tensor
        target_seq = string_to_tensor(target_seq, vocab).to(device)
        target_seq = target_seq.to(device)
        # Assume target_seq is size [T, B]
        optimizer.zero_grad()
        # Shift target sequence for teacher forcing
        input_seq = target_seq[:-1, :]
        output = model(images, input_seq)
        # Compute loss between output and target_seq[1:,:]
        loss = criterion(output.reshape(-1, output.shape[-1]), target_seq[1:, :].reshape(-1))
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(dataloader)

def test(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for images, target_seq in dataloader:
            images = images.to(device)
            target_seq = string_to_tensor(target_seq, vocab).to(device)
            # Shift target sequence for decoder input (same as in training)
            input_seq = target_seq[:-1, :]
            output = model(images, input_seq)

            # _, predicted_indices = torch.max(output, dim=2)
            # predicted_text = tensor_to_string(predicted_indices, vocab)
            # print(f"Predicted: {predicted_text}")

            # Calculate loss against the full target sequence (offset by 1)
            loss = criterion(output.reshape(-1, output.shape[-1]),
                            target_seq[1:, :].reshape(-1))
            running_loss += loss.item()
    return running_loss / len(dataloader)


# Initialize Hyperparameters
batch_size = 16
learning_rate = 1e-3
num_epochs = 10

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = HandwrittenMathToLatexModel(vocab_size=vocab_size).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=0)  # Adjust ignore_index if needed (for padding)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    train_loss = train(model, train_loader, criterion, optimizer, device)
    test_loss = test(model, test_loader, criterion, device)
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")

Epoch [1/10], Train Loss: 2.9391, Test Loss: 2.5367
Epoch [2/10], Train Loss: 2.4199, Test Loss: 2.2440
Epoch [3/10], Train Loss: 2.1816, Test Loss: 2.0670
Epoch [4/10], Train Loss: 2.0129, Test Loss: 1.9482
Epoch [5/10], Train Loss: 1.9048, Test Loss: 1.8924
Epoch [6/10], Train Loss: 1.8280, Test Loss: 1.8456
Epoch [7/10], Train Loss: 1.7559, Test Loss: 1.8052
Epoch [8/10], Train Loss: 1.6816, Test Loss: 1.7902
Epoch [9/10], Train Loss: 1.6341, Test Loss: 1.7596
Epoch [10/10], Train Loss: 1.5769, Test Loss: 1.7520


**Potential Model Evaluation Method** (from https://github.com/google-research/google-research/blob/master/mathwriting/mathwriting_code_examples.ipynb)

In [None]:
import re

_COMMAND_RE = re.compile(r'\\mathbb\{[a-zA-Z]\}|\\begin\{[a-z]+\}|\\end\{[a-z]+\}|\\operatorname\*?|[a-zA-Z]+|.')

def tokenize_expression(s: str) -> list[str]:
  """Transform a Latex math string into a list of tokens."""
  tokens = []
  while s:
    if s[0] == '\\':
      tokens.append(_COMMAND_RE.match(s).group(0))
    else:
      tokens.append(s[0])
    s = s[len(tokens[-1]) :]
  return tokens

# Example Usage
print(tokenize_expression(r'\frac{\alpha}{2} \not\in\mathbb{R}'))

['\\', 'f', 'r', 'a', 'c', '{', '\\', 'a', 'l', 'p', 'h', 'a', '}', '{', '2', '}', ' ', '\\', 'n', 'o', 't', '\\', 'i', 'n', '\\mathbb{R}']


In [None]:
import jiwer

class TokenizeTransform(jiwer.transforms.AbstractTransform):
    def process_string(self, s: str):
      return tokenize_expression(r'{}'.format(s))
    def process_list(self, tokens: list[str]):
      return [self.process_string(token) for token in tokens]

def compute_cer(truth_and_output: list[tuple[str, str]]):
  """Computes CER given pairs of ground truth and model output."""
  ground_truth, model_output = zip(*truth_and_output)
  return jiwer.cer(truth=list(ground_truth),
            hypothesis=list(model_output),
            reference_transform=TokenizeTransform(),
            hypothesis_transform=TokenizeTransform(),
      )

# Test data to run compute_cer().
# The first element is the model prediction, the second the ground truth.
examples = [
    (r'\sqrt{2}', r'\sqrt{2}'),  # 0 mistakes, 4 tokens
    (r'\frac{1}{2}', r'\frac{i}{2}'),  # 1 mistake, 7 tokens
    (r'\alpha^{2}', 'a^{2}'),  # 1 mistake, 5 tokens
    ('abc', 'def'),  # 3 mistakes, 3 tokens
]

# 5 mistakes for 19 tokens: 26.3% error rate.
print(f"{compute_cer(examples)*100:.1f} %")

28.1 %


Step 2.) Finetune an LLM using GRPO training to correct errors in the Latex syntax

In [None]:
from trl import GRPOConfig, GRPOTrainer
from datasets import Dataset

# Turn the Pandas df from step 1 into a Dataset object
dataset = Dataset.from_pandas(df)

# Create the prompts for GRPO Training
def create_prompt(example):
    example["prompt"] = f"""Please ensure that the following text is valid LaTeX by fixing syntax issues as needed. Here is the potentially invalid LaTeX: {example["predicted_latex"]}. What is the fixed valid LaTeX: """
    return example

dataset = dataset.map(create_prompt)
print(dataset)

# Determine device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Training will run on: {device}")

# Create a unique checkpoint directory for each run using a timestamp
run = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
checkpoint_dir = f'/users/0/brogn002/{run}'
os.makedirs(checkpoint_dir, exist_ok=True)

def reward(completions, **kwargs):
    """Reward function that rewards a similarity score between two strings in the range [0,1]."""
    correct_latex = kwargs["label"]
    rewards = []
    for completion, reference in zip(completions, correct_latex):
      if not completion or not reference:
        rewards.append(0.0)
        continue
      # Do not reward empty strings
      if len(completion) == 0:
            rewards.append(0.0)
            continue
      # Perfect match gets a full reward
      if completion == reference:
          rewards.append(1.0)
          continue
      # Apply RapidFuzz ratio for all cases (handles different lengths well)
      similarity = fuzz.ratio(completion, reference) / 100.0
      # Add additional penalty for length mismatch
      length_penalty = max(0, 1 - (abs(len(completion) - len(reference)) / max(len(reference), 1)))
      # Combined score is a linear combination of similarity and length_penalty
      final_score = (similarity * 0.5) + (length_penalty * 0.5)
      rewards.append(final_score)
    return rewards

training_args = GRPOConfig(
    output_dir=checkpoint_dir,
    logging_steps=50,
    per_device_train_batch_size=4,  # Decrease this to lower vram usage
    num_generations=4,  # Decrease this to lower vram usage
    save_strategy="no",  # Do not save checkpoints (saves storage space)
    bf16=True,  # Enable bf16 mixed precision on A100 GPUs
)

trainer = GRPOTrainer(
    model="microsoft/Phi-4-mini-instruct",
    reward_funcs=reward,
    args=training_args,
    train_dataset=dataset,
)

trainer.train()