In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

import matplotlib.pyplot as plt
import numpy as np
import cv2
import os
from tqdm.notebook import tqdm
import time
from datetime import datetime
import pandas as pd
import json

KeyboardInterrupt: 

In [2]:
class OCRDataset(Dataset):
    def __init__(self, image_paths, image_labels, char_to_id, transform=None):
        self.image_paths = image_paths
        self.image_labels = image_labels
        self.char_to_id = char_to_id
        self.transform = transform
        
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        # Load image
        img_path = self.image_paths[idx]
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        
        # Normalize the image
        img = img / 255.0
        
        # Apply transforms if any
        if self.transform:
            img = self.transform(img)
        
        # Convert to tensor and add channel dimension
        img = torch.FloatTensor(img).unsqueeze(0)
        
        # Get label
        label = torch.LongTensor(self.image_labels[idx])
        
        return img, label

# Collate function to handle variable length sequences
def collate_fn(batch):
    images, labels = zip(*batch)
    
    # Pad images to same height and width if necessary
    max_h = max([img.shape[1] for img in images])
    max_w = max([img.shape[2] for img in images])
    
    padded_images = []
    for img in images:
        pad_h = max_h - img.shape[1]
        pad_w = max_w - img.shape[2]
        padded_img = nn.functional.pad(img, (0, pad_w, 0, pad_h), "constant", 0)
        padded_images.append(padded_img)
    
    # Stack all images into a batch
    images = torch.stack(padded_images)
    
    # Store original label lengths for CTC loss
    label_lengths = torch.LongTensor([len(label) for label in labels])
    
    # Pad labels for batch processing
    labels = pad_sequence(labels, batch_first=True, padding_value=0)
    
    return images, labels, label_lengths

# CTC Loss wrapper
class CTCLoss(nn.Module):
    def __init__(self, blank=0):
        super(CTCLoss, self).__init__()
        self.ctc = nn.CTCLoss(blank=blank, reduction='mean', zero_infinity=True)
    
    def forward(self, log_probs, targets, input_lengths, target_lengths):
        return self.ctc(log_probs, targets, input_lengths, target_lengths)


In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [4]:
class OCRModel(nn.Module):
    def __init__(self, vocab_size, img_channels=1, hidden_size=128, num_lstm_layers=2):  # Reduced from 256 to 128
        super(OCRModel, self).__init__()

        # CNN Feature Extractor (Reduce Width Pooling)
        self.cnn = nn.Sequential(
            nn.Conv2d(img_channels, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),  # (B, 64, 50, 250)

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),  # (B, 128, 25, 125)

            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)),  # (B, 256, 12, 125)

            nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1),  # Reduced from 512 to 384
            nn.BatchNorm2d(384),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))   # (B, 384, 6, 125)
        )

        # Add a 1x1 convolution to reduce channels before LSTM
        self.channel_reducer = nn.Sequential(
            nn.Conv2d(384, 256, kernel_size=1),  # Reduce channels with 1x1 conv
            nn.BatchNorm2d(256),
            nn.ReLU()
        )

        # Compute LSTM input size dynamically
        sample_input = torch.randn(1, img_channels, 100, 500)
        with torch.no_grad():
            cnn_output = self.cnn(sample_input)
            reduced_output = self.channel_reducer(cnn_output)
            _, C, H, W = reduced_output.shape
            self.lstm_input_size = C * H  # Should be smaller than before
            self.width_after_cnn = W

        # LSTM Sequence Model with smaller hidden size
        self.lstm = nn.LSTM(
            input_size=self.lstm_input_size, 
            hidden_size=hidden_size,
            num_layers=num_lstm_layers, 
            batch_first=True, 
            bidirectional=True,
            dropout=0.2  # Add dropout directly in LSTM
        )
        
        # Extra dropout between LSTM and FC
        self.dropout = nn.Dropout(0.2)

        # Fully connected layer (Bidirectional LSTM)
        self.fc = nn.Linear(hidden_size * 2, vocab_size)

    def forward(self, x):
        # CNN forward
        features = self.cnn(x)
        
        # Reduce channels
        features = self.channel_reducer(features)

        # Reshape for LSTM (treat width as time-steps)
        b, c, h, w = features.size()
        features = features.permute(0, 3, 1, 2).contiguous().view(b, w, -1)

        # LSTM forward
        lstm_out, _ = self.lstm(features)
        
        # Apply dropout
        lstm_out = self.dropout(lstm_out)

        # Fully connected
        output = self.fc(lstm_out)
        
        # Apply log_softmax for CTC loss stability
        output = nn.functional.log_softmax(output, dim=2)

        return output

# Initialize weights properly
def init_weights(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight)
    elif isinstance(m, nn.LSTM):
        for name, param in m.named_parameters():
            if 'weight' in name:
                nn.init.orthogonal_(param, gain=0.8)  # Reduced gain factor for stability
            elif 'bias' in name:
                nn.init.constant_(param, 0)  # Initialize biases to zero

In [5]:
df = pd.read_csv("../dataset/mini_qa_images/mini_qa.csv")
with open("../dataset/mini_qa_images/char_mappings.json", "r", encoding="utf-8") as f:
    loaded_mapping = json.load(f)

# Convert keys back to int for id_to_char (JSON keys are saved as strings)
char_to_id = loaded_mapping["char_to_id"]
id_to_char = {int(k): v for k, v in loaded_mapping["id_to_char"].items()}

print("Character mappings loaded successfully!")
print("Loaded char_to_id:", char_to_id)
print("Loaded id_to_char:", id_to_char)

def text_to_ids(text):
    return [char_to_id.get(char, 0) for char in text]  # Defaulting to 0 for unknown characters
def ids_to_text(ids):
    return ''.join([id_to_char.get(id, '?') for id in ids])  # Use '?' for unknown IDs
image_paths = []
image_labels = []
for index, row in df.iterrows():
    # if(index%20==0):
    #     print(f"{index} out of 100")
    text = row['question']
    path = f"../dataset/mini_qa_images/question/{index}.png"
    image_paths.append(path)
    image_labels.append(text_to_ids(text))

Character mappings loaded successfully!
Loaded char_to_id: {' ': 1, '?': 2, '᠂': 3, '᠃': 4, '᠋': 5, '᠌': 6, '᠍': 7, '\u180e': 8, 'ᠠ': 9, 'ᠡ': 10, 'ᠢ': 11, 'ᠣ': 12, 'ᠤ': 13, 'ᠥ': 14, 'ᠦ': 15, 'ᠧ': 16, 'ᠨ': 17, 'ᠩ': 18, 'ᠪ': 19, 'ᠬ': 20, 'ᠭ': 21, 'ᠮ': 22, 'ᠯ': 23, 'ᠰ': 24, 'ᠱ': 25, 'ᠲ': 26, 'ᠳ': 27, 'ᠴ': 28, 'ᠵ': 29, 'ᠶ': 30, 'ᠷ': 31, 'ᠹ': 32, '\u202f': 33, '︖': 34, '？': 35}
Loaded id_to_char: {1: ' ', 2: '?', 3: '᠂', 4: '᠃', 5: '᠋', 6: '᠌', 7: '᠍', 8: '\u180e', 9: 'ᠠ', 10: 'ᠡ', 11: 'ᠢ', 12: 'ᠣ', 13: 'ᠤ', 14: 'ᠥ', 15: 'ᠦ', 16: 'ᠧ', 17: 'ᠨ', 18: 'ᠩ', 19: 'ᠪ', 20: 'ᠬ', 21: 'ᠭ', 22: 'ᠮ', 23: 'ᠯ', 24: 'ᠰ', 25: 'ᠱ', 26: 'ᠲ', 27: 'ᠳ', 28: 'ᠴ', 29: 'ᠵ', 30: 'ᠶ', 31: 'ᠷ', 32: 'ᠹ', 33: '\u202f', 34: '︖', 35: '？'}


In [6]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    batch_count = 0
    
    for images, labels, label_lengths in tqdm(dataloader, desc="Training"):
        # Move data to device
        images = images.to(device)
        labels = labels.to(device)
        label_lengths = label_lengths.to(device)
        
        # Zero the gradients
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(images)
        
        # Calculate input lengths for CTC loss
        batch_size, time_steps, _ = outputs.size()
        input_lengths = torch.full((batch_size,), time_steps, dtype=torch.long).to(device)
        
        # Compute loss
        loss = criterion(outputs.permute(1, 0, 2), labels, input_lengths, label_lengths)
        
        # Backward pass and optimize
        loss.backward()
        # Gradient clipping to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()
        
        running_loss += loss.item()
        batch_count += 1
    
    return running_loss / batch_count


In [7]:
def validate(model, dataloader, criterion, device, id_to_char):
    model.eval()
    running_loss = 0.0
    batch_count = 0
    correct_chars = 0
    total_chars = 0
    
    all_predictions = []
    all_ground_truths = []
    
    with torch.no_grad():
        for images, labels, label_lengths in tqdm(dataloader, desc="Validating"):
            # Move data to device
            images = images.to(device)
            labels = labels.to(device)
            label_lengths = label_lengths.to(device)
            
            # Forward pass
            outputs = model(images)
            
            # Calculate input lengths for CTC loss
            batch_size, time_steps, _ = outputs.size()
            input_lengths = torch.full((batch_size,), time_steps, dtype=torch.long).to(device)
            
            # Compute loss
            loss = criterion(outputs.permute(1, 0, 2), labels, input_lengths, label_lengths)
            
            running_loss += loss.item()
            batch_count += 1
            
            # Decode predictions
            predictions = greedy_decode(outputs, id_to_char)
            
            # Convert labels to strings
            ground_truths = []
            for i in range(len(labels)):
                gt = ''.join([id_to_char[id.item()] for id in labels[i][:label_lengths[i]]])
                ground_truths.append(gt)
            
            # Store for computing metrics
            all_predictions.extend(predictions)
            all_ground_truths.extend(ground_truths)
            
            # Character accuracy
            for pred, gt in zip(predictions, ground_truths):
                correct_chars += sum([p == g for p, g in zip(pred[:min(len(pred), len(gt))], gt[:min(len(pred), len(gt))])])
                total_chars += len(gt)
    
    # Compute metrics
    val_loss = running_loss / batch_count
    char_accuracy = correct_chars / total_chars if total_chars > 0 else 0
    
    # Compute sequence accuracy (exact matches)
    sequence_matches = sum([p == g for p, g in zip(all_predictions, all_ground_truths)])
    sequence_accuracy = sequence_matches / len(all_predictions) if all_predictions else 0
    
    return val_loss, char_accuracy, sequence_accuracy, all_predictions, all_ground_truths


In [8]:
def greedy_decode(output, id_to_char):
    # Get most likely class
    predictions = torch.argmax(output, dim=2).detach().cpu().numpy()
    
    # Convert to characters
    decoded = []
    for pred in predictions:
        # Remove repeating characters
        collapsed = []
        prev = -1
        for p in pred:
            if p != prev:  # CTC decoding - collapse repeats
                collapsed.append(p)
            prev = p
        
        # Remove blanks (assuming blank is 0)
        text = ''.join([id_to_char[p] for p in collapsed if p != 0])
        decoded.append(text)
    
    return decoded

# Visualize predictions

In [9]:
def visualize_predictions(dataloader, model, id_to_char, device, num_samples=5):
    model.eval()
    images, labels, label_lengths = next(iter(dataloader))
    
    with torch.no_grad():
        images = images.to(device)
        outputs = model(images)
        predictions = greedy_decode(outputs, id_to_char)
    
    # Only show up to num_samples
    num_samples = min(num_samples, len(images))
    
    plt.figure(figsize=(15, 5 * num_samples))
    for i in range(num_samples):
        plt.subplot(num_samples, 1, i + 1)
        plt.imshow(images[i, 0].cpu().numpy(), cmap='gray')
        
        # Get ground truth
        gt = ''.join([id_to_char[id.item()] for id in labels[i][:label_lengths[i]]])
        
        plt.title(f'Prediction: {predictions[i]}\nGround Truth: {gt}')
        plt.axis('off')
    
    plt.tight_layout()
    plt.show()


In [10]:
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device, id_to_char, checkpoint_dir='checkpoints'):
    # Create checkpoint directory if it doesn't exist
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    # Initialize tracking variables
    best_val_loss = float('inf')
    history = {
        'train_loss': [],
        'val_loss': [],
        'char_accuracy': [],
        'sequence_accuracy': []
    }
    
    # Start training
    start_time = time.time()
    
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        
        # Train
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
        history['train_loss'].append(train_loss)
        
        # Validate
        val_loss, char_accuracy, sequence_accuracy, predictions, ground_truths = validate(model, val_loader, criterion, device, id_to_char)
        history['val_loss'].append(val_loss)
        history['char_accuracy'].append(char_accuracy)
        history['sequence_accuracy'].append(sequence_accuracy)
        
        # Update learning rate
        scheduler.step(val_loss)
        
        # Print metrics
        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
        print(f"Character Accuracy: {char_accuracy:.4f} | Sequence Accuracy: {sequence_accuracy:.4f}")
        
        # Sample predictions
        print("\nSample Predictions:")
        for i in range(min(5, len(predictions))):
            print(f"Prediction: {predictions[i]}")
            print(f"Ground Truth: {ground_truths[i]}")
            print()
        
        # Save checkpoint if best
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            checkpoint_path = f"{checkpoint_dir}/best_model_{timestamp}_ep{epoch+1}_loss{val_loss:.4f}_acc{char_accuracy:.4f}.pt"
            
            print(f"Saving best model checkpoint to {checkpoint_path}")
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_loss,
                'char_accuracy': char_accuracy,
                'sequence_accuracy': sequence_accuracy,
            }, checkpoint_path)
        
        # Save regular checkpoint every 5 epochs
        if (epoch + 1) % 5 == 0:
            checkpoint_path = f"{checkpoint_dir}/model_epoch{epoch+1}.pt"
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_loss,
                'char_accuracy': char_accuracy,
                'sequence_accuracy': sequence_accuracy,
            }, checkpoint_path)
    
    # Training complete
    total_time = time.time() - start_time
    print(f"\nTraining completed in {total_time/60:.2f} minutes")
    
    return model, history

In [11]:
def plot_history(history):
    plt.figure(figsize=(15, 10))
    
    # Plot loss
    plt.subplot(2, 1, 1)
    plt.plot(history['train_loss'], label='Training Loss')
    plt.plot(history['val_loss'], label='Validation Loss')
    plt.title('Loss Over Time')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    
    # Plot accuracy
    plt.subplot(2, 1, 2)
    plt.plot(history['char_accuracy'], label='Character Accuracy')
    plt.plot(history['sequence_accuracy'], label='Sequence Accuracy')
    plt.title('Accuracy Over Time')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.grid(True)
    
    plt.tight_layout()
    plt.show()


In [12]:
batch_size = 32
learning_rate = 0.001
num_epochs = 1

# Assuming you have these defined:
# image_paths, image_labels, char_to_id, id_to_char

# Split data into train and validation
from sklearn.model_selection import train_test_split

train_paths, val_paths, train_labels, val_labels = train_test_split(
    image_paths, image_labels, test_size=0.2, random_state=42
)

# Create datasets
train_dataset = OCRDataset(train_paths, train_labels, char_to_id)
val_dataset = OCRDataset(val_paths, val_labels, char_to_id)

# Create dataloaders
train_loader = DataLoader(
    train_dataset, 
    batch_size=batch_size, 
    shuffle=True, 
    collate_fn=collate_fn,
    num_workers=2
)

val_loader = DataLoader(
    val_dataset, 
    batch_size=batch_size, 
    shuffle=False, 
    collate_fn=collate_fn,
    num_workers=2
)

# Initialize model
vocab_size = len(char_to_id)
model = OCRModel(vocab_size=vocab_size)
model.apply(init_weights)  # Initialize weights
model = model.to(device)

# Define loss and optimizer
criterion = CTCLoss(blank=0)  # Assuming 0 is blank
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3, verbose=True
)

# Train model
model, history = train_model(
    model, train_loader, val_loader, criterion, optimizer, scheduler, 
    num_epochs, device, id_to_char
)

# Plot history
plot_history(history)

# Visualize some predictions
visualize_predictions(val_loader, model, id_to_char, device)


Epoch 1/1




Training:   0%|          | 0/3 [00:00<?, ?it/s]

RuntimeError: DataLoader worker (pid(s) 22940) exited unexpectedly