## 4.2 Optical Character Recognition

### 4.2.1 Task 1 : Dataset 

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont

import nltk  # type: ignore

# nltk.download('words')
from nltk.corpus import words  # type: ignore

# Create the dataset

def create_dataset():

    # Set up image and font parameters
    img_width = 256
    img_height = 64

    # Path to a valid font file 
    font_path = "/Library/Fonts/Arial.ttf"  
    font_size = 16.5  # Base font size
    font = ImageFont.truetype(font_path, int(font_size)) 

    
    output_dir = "../../data/interim/5/OCR_dataset/"
    os.makedirs(output_dir, exist_ok=True)

    all_words = words.words() 
    selected_words = all_words[:100000]  
    for i, word in enumerate(selected_words):
        img = Image.new('RGB', (img_width, img_height), color=(255, 255, 255))  # Create a white image
        draw = ImageDraw.Draw(img)
        
        # Calculate the position to center the word
        bbox = draw.textbbox((0, 0), word, font=font)  # Get bounding box of the text
        text_width = bbox[2] - bbox[0]  
        text_height = bbox[3] - bbox[1] 
        text_x = (img_width - text_width) // 2
        text_y = (img_height - text_height) // 2
        
        # Draw the word onto the image
        draw.text((text_x, text_y), word, font=font, fill=(0, 0, 0)) 
        
        img.save(f"{output_dir}word_{i}.png") 

    print("Dataset creation completed.")

from nltk.corpus import words

all_words = words.words()
selected_words = all_words[:100000]

# Write these words to a text file
output_file = "words.txt"

# with open(output_file, "w") as file:
#     for word in selected_words:
#         file.write(word + "\n")

print(f"File '{output_file}' has been created with {len(selected_words)} words.")

# create_dataset()

File 'words.txt' has been created with 100000 words.


### 4.2.2 Task 2 : Architecture 

### 4.2.3 Task 3 : Training


In [None]:
# Constants
IMAGE_WIDTH = 256
IMAGE_HEIGHT = 64
BATCH_SIZE = 32
MAX_WORD_LENGTH = 30
HIDDEN_SIZE = 256
DROPOUT_RATE = 0.2
LEARNING_RATE = 0.001

# Character mapping
ALL_CHARS = "abcdefghijklmnopqrstuvwxyz"
CHAR_TO_IDX = {char: idx + 1 for idx, char in enumerate(ALL_CHARS)}
IDX_TO_CHAR = {idx + 1: char for idx, char in enumerate(ALL_CHARS)}
VOCAB_SIZE = len(CHAR_TO_IDX) + 1

class OCRDataset(Dataset):
    def __init__(self, image_dir, word_list_path, transform=None):
        self.image_dir = image_dir
        with open(word_list_path, 'r') as f:
            self.labels = []
            for line in f:
                word = line.strip()
                if len(word) <= MAX_WORD_LENGTH:
                    self.labels.append(word)
        self.labels = self.labels[:100000]
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, f"word_{idx}.png")
        image = Image.open(img_path).convert('L')
        
        if self.transform:
            image = self.transform(image)
        
        word = self.labels[idx].lower()
        label = torch.zeros(MAX_WORD_LENGTH, dtype=torch.long)
        for i, char in enumerate(word):
            if char in CHAR_TO_IDX:
                label[i] = CHAR_TO_IDX[char]
        
        return image, label, len(word)

class CNNEncoder(nn.Module):
    def __init__(self):
        super(CNNEncoder, self).__init__()
        # Simplified CNN architecture with careful dimension handling
        self.features = nn.Sequential(
            # Input: batch x 1 x 64 x 256
            nn.Conv2d(1, 64, kernel_size=3, padding=1),  # 64 x 64 x 256
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # 64 x 32 x 128
            
            nn.Conv2d(64, 128, kernel_size=3, padding=1),  # 128 x 32 x 128
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # 128 x 16 x 64
            
            nn.Conv2d(128, 256, kernel_size=3, padding=1),  # 256 x 16 x 64
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d((2, 2), (2, 2)),  # 256 x 8 x 32
            
            nn.Dropout2d(DROPOUT_RATE)
        )
        
        # Calculate the exact output size
        self.feature_size = 256 * 8  # Height after CNN
        self.sequence_length = 32     # Width after CNN

    def forward(self, x):
        # Input: batch x 1 x 64 x 256
        x = self.features(x)  # batch x 256 x 8 x 32
        
        # Reshape to (batch, sequence_length, features)
        batch_size = x.size(0)
        x = x.permute(0, 3, 1, 2)  # batch x 32 x 256 x 8
        x = x.contiguous().view(batch_size, self.sequence_length, -1)  # batch x 32 x (256*8)
        
        
        x = x[:, :MAX_WORD_LENGTH, :]
        return x

class RNNDecoder(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNNDecoder, self).__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=1,
            bidirectional=True,
            batch_first=True
        )
        self.fc = nn.Linear(hidden_size * 2, output_size)
        self.dropout = nn.Dropout(DROPOUT_RATE)

    def forward(self, x):
        x, _ = self.lstm(x)
        x = self.dropout(x)
        x = self.fc(x)
        return x

class OCRModel(nn.Module):
    def __init__(self):
        super(OCRModel, self).__init__()
        self.encoder = CNNEncoder()
        self.decoder = RNNDecoder(
            input_size=256 * 8,  # Matches encoder output size
            hidden_size=HIDDEN_SIZE,
            output_size=VOCAB_SIZE
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

def train_model(model, train_loader, criterion, optimizer, device, epoch):
    model.train()
    total_loss = 0
    batch_losses = []
    
    for batch_idx, (images, labels, lengths) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        
        loss = 0
        for t in range(MAX_WORD_LENGTH):
            loss += criterion(outputs[:, t, :], labels[:, t])
        
        loss = loss / MAX_WORD_LENGTH
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        batch_losses.append(loss.item())
        
        if batch_idx % 50 == 0:
            print(f'Epoch: {epoch+1}, Batch: {batch_idx}/{len(train_loader)}, '
                  f'Loss: {loss.item():.4f}')
    
    return total_loss / len(train_loader), batch_losses

def plot_training_progress(train_losses, val_losses, save_path='training_progress.png'):
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Progress')
    plt.legend()
    plt.savefig(save_path)
    plt.close()



In [56]:
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using MPS device")
else:
    device = torch.device("cpu")
    print("MPS device not found, using CPU")

transform = transforms.Compose([
    transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

dataset = OCRDataset(
    image_dir="../../data/interim/5/OCR_dataset/",
    word_list_path="words.txt",
    transform=transform
)

# Print dataset size
print(f"Dataset size: {len(dataset)}")

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size]
)

train_loader = DataLoader(
    train_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=True,
    num_workers=0
)

val_loader = DataLoader(
    val_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=False,
    num_workers=0
)

model = OCRModel().to(device)
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)

num_epochs = 15
best_val_loss = float('inf')
train_losses = []
val_losses = []



for epoch in range(num_epochs):
    train_loss, batch_losses = train_model(
        model, train_loader, criterion, optimizer, device, epoch
    )
    train_losses.append(train_loss)
    
    # Validation
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for images, labels, lengths in val_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            
            for t in range(MAX_WORD_LENGTH):
                val_loss += criterion(outputs[:, t, :], labels[:, t]).item()
    
    val_loss /= len(val_loader) * MAX_WORD_LENGTH
    val_losses.append(val_loss)
    scheduler.step(val_loss)
    
    print(f'Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')
    
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch+1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
        }, 'best_ocr_model_1.pth')
        print(f'Saved best model with validation loss: {val_loss:.4f}')
    
    plot_training_progress(train_losses, val_losses)


Using MPS device
Dataset size: 100000
Epoch: 1, Batch: 0/2500, Loss: 1.6456
Epoch: 1, Batch: 50/2500, Loss: 1.4444
Epoch: 1, Batch: 100/2500, Loss: 1.2795
Epoch: 1, Batch: 150/2500, Loss: 1.5007
Epoch: 1, Batch: 200/2500, Loss: 1.3646
Epoch: 1, Batch: 250/2500, Loss: 1.2748
Epoch: 1, Batch: 300/2500, Loss: 1.5434
Epoch: 1, Batch: 350/2500, Loss: 1.4867
Epoch: 1, Batch: 400/2500, Loss: 1.4037
Epoch: 1, Batch: 450/2500, Loss: 1.6456
Epoch: 1, Batch: 500/2500, Loss: 1.4856
Epoch: 1, Batch: 550/2500, Loss: 1.4267
Epoch: 1, Batch: 600/2500, Loss: 1.4955
Epoch: 1, Batch: 650/2500, Loss: 1.5592
Epoch: 1, Batch: 700/2500, Loss: 1.1766
Epoch: 1, Batch: 750/2500, Loss: 1.3743
Epoch: 1, Batch: 800/2500, Loss: 1.2482
Epoch: 1, Batch: 850/2500, Loss: 1.2720
Epoch: 1, Batch: 900/2500, Loss: 1.6634
Epoch: 1, Batch: 950/2500, Loss: 1.8340
Epoch: 1, Batch: 1000/2500, Loss: 1.3984
Epoch: 1, Batch: 1050/2500, Loss: 1.8862
Epoch: 1, Batch: 1100/2500, Loss: 1.4689
Epoch: 1, Batch: 1150/2500, Loss: 1.4729
E

### Epoch-wise train and validation losses:   
   


Epoch 1: Train Loss: 1.3782, Val Loss: 1.3801  
Saved best model with validation loss: 1.3801  

Epoch 2: Train Loss: 1.1173, Val Loss: 1.0039  
Saved best model with validation loss: 1.0039  

Epoch 3: Train Loss: 0.9307, Val Loss: 0.7896  
Saved best model with validation loss: 0.7896  

Epoch 4: Train Loss: 0.6743, Val Loss: 0.6508  
Saved best model with validation loss: 0.6508  

Epoch 5: Train Loss: 0.4581, Val Loss: 0.4033  
Saved best model with validation loss: 0.4033  

Epoch 6: Train Loss: 0.3439, Val Loss: 0.2867  
Saved best model with validation loss: 0.2867  

Epoch 7: Train Loss: 0.2710, Val Loss: 0.3575  

Epoch 8: Train Loss: 0.2206, Val Loss: 0.3429  

Epoch 9: Train Loss: 0.1850, Val Loss: 0.1917  
Saved best model with validation loss: 0.1917  

Epoch 10: Train Loss: 0.1557, Val Loss: 2.2481  

Epoch 11: Train Loss: 0.1340, Val Loss: 0.2132  

Epoch 12: Train Loss: 0.1183, Val Loss: 0.1964  

Epoch 13: Train Loss: 0.1033, Val Loss: 0.1750  
Saved best model with validation loss: 0.1750  

Epoch 14: Train Loss: 0.0944, Val Loss: 2.5423  

Epoch 15: Train Loss: 0.0845, Val Loss: 0.3006  

### Evaluation 

In [None]:
def calculate_accuracy(predictions, targets, lengths):
    batch_size = predictions.size(0)
    pred_chars = predictions.argmax(dim=2)  # Convert from one-hot to indices
    correct_chars = 0
    total_chars = 0
    
    for i in range(batch_size):
        length = lengths[i]
        correct_chars += (pred_chars[i, :length] == targets[i, :length]).sum().item()
        total_chars += length
    
    return correct_chars / total_chars if total_chars > 0 else 0

def generate_random_baseline(batch_size, max_length, vocab_size):
    return torch.randint(1, vocab_size, (batch_size, max_length))

def clean_prediction(word):
    # Remove non-alphabetic characters
    word = ''.join(c for c in word if c.isalpha())
    
    if not word:
        return word
        
    # Initialize result with first character
    result = [word[0]]
    
    # Track repeated characters
    repeat_count = 1
    max_repeats = 2  
    
    # Process rest of the word
    for i in range(1, len(word)):
        if word[i] == result[-1]:
            repeat_count += 1
            if repeat_count <= max_repeats:
                result.append(word[i])
        else:
            repeat_count = 1
            result.append(word[i])
    
    return ''.join(result)

def decode_prediction(pred_tensor):
    # First decode the tensor to a string
    word = ""
    for idx in pred_tensor:
        if idx.item() == 0:  # Skip padding
            break
        if idx.item() in IDX_TO_CHAR:
            word += IDX_TO_CHAR[idx.item()]
    
    # Clean up the prediction
    return clean_prediction(word)

def evaluate_model(model, val_loader, device, num_examples=20):
    model.eval()
    total_accuracy = 0
    total_baseline_accuracy = 0
    num_batches = 0
    examples = []
    
    with torch.no_grad():
        for images, labels, lengths in val_loader:
            images = images.to(device)
            labels = labels.to(device)
            
            # Model predictions
            outputs = model(images)
            pred_chars = outputs.argmax(dim=2)
            accuracy = calculate_accuracy(outputs, labels, lengths)
            
            # Random baseline
            random_preds = generate_random_baseline(
                labels.size(0), 
                MAX_WORD_LENGTH, 
                VOCAB_SIZE
            ).to(device)
            baseline_accuracy = calculate_accuracy(
                torch.nn.functional.one_hot(random_preds, VOCAB_SIZE).float(), 
                labels, 
                lengths
            )
            
            total_accuracy += accuracy
            total_baseline_accuracy += baseline_accuracy
            num_batches += 1
            
            # Collect example predictions
            if len(examples) < num_examples:
                for i in range(min(num_examples - len(examples), len(lengths))):
                    true_word = decode_prediction(labels[i])
                    pred_word = decode_prediction(pred_chars[i])
                    examples.append({
                        'true': true_word,
                        'predicted': pred_word,
                        'correct_chars': sum(1 for t, p in zip(true_word, pred_word) if t == p)
                    })
    
    avg_accuracy = total_accuracy / num_batches
    avg_baseline_accuracy = total_baseline_accuracy / num_batches
    
    return avg_accuracy, avg_baseline_accuracy, examples

def print_evaluation_results(model, val_loader, device):
    print("\nModel Evaluation Results:")
    print("-" * 50)
    
    accuracy, baseline_accuracy, examples = evaluate_model(model, val_loader, device)
    
    print(f"Model Accuracy (Average Correct Characters): {accuracy:.4f}")
    print(f"Random Baseline Accuracy: {baseline_accuracy:.4f}")
    print(f"Improvement over baseline: {accuracy - baseline_accuracy:.4f}")
    
    print("\nExample Predictions:")
    print("-" * 50)
    for i, example in enumerate(examples, 1):
        print(f"\nExample {i}:")
        print(f"True word:      {example['true']}")
        print(f"Predicted word: {example['predicted']}")
        print(f"Correct characters: {example['correct_chars']}/{len(example['true'])}")
    
    return accuracy, baseline_accuracy, examples



def load_and_evaluate_model():
    # Set device
    if torch.backends.mps.is_available():
        device = torch.device("mps")
        print("Using MPS device")
    else:
        device = torch.device("cpu")
        print("MPS device not found, using CPU")
    
    
    model = OCRModel().to(device)
    
    # Load the trained model
    print("\nLoading best model from checkpoint...")
    checkpoint = torch.load('best_ocr_model_1.pth', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    
    print(f"Model loaded from epoch {checkpoint['epoch']}")
    print(f"Previous training metrics:")
    print(f"- Training loss: {checkpoint['train_loss']:.4f}")
    print(f"- Validation loss: {checkpoint['val_loss']:.4f}")
    
    # Prepare dataset and loader
    transform = transforms.Compose([
        transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])
    
    dataset = OCRDataset(
        image_dir="../../data/interim/5/OCR_dataset/",
        word_list_path="words.txt",
        transform=transform
    )
    
    # Split dataset
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    _, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size]
    )
    
    val_loader = DataLoader(
        val_dataset, 
        batch_size=BATCH_SIZE, 
        shuffle=False,
        num_workers=0
    )
    
    # Evaluate model
    print("\nEvaluating model performance...")
    model.eval()
    accuracy, baseline_accuracy, examples = print_evaluation_results(model, val_loader, device)
    
    # Plot some example images with predictions
    plot_example_predictions(model, val_loader, device)
    
    return model, accuracy, baseline_accuracy, examples

def plot_example_predictions(model, val_loader, device, num_examples=5):
    model.eval()
    plt.figure(figsize=(15, 3*num_examples))
    
    with torch.no_grad():
        images, labels, lengths = next(iter(val_loader))
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        pred_chars = outputs.argmax(dim=2)
        
        for i in range(min(num_examples, len(images))):
            plt.subplot(num_examples, 1, i+1)
            
            # Convert tensor to image
            img = images[i].cpu().squeeze().numpy()
            plt.imshow(img, cmap='gray')
            
            # Get true and predicted words
            true_word = decode_prediction(labels[i])
            pred_word = decode_prediction(pred_chars[i])
            
            plt.title(f'True: "{true_word}" | Predicted: "{pred_word}"')
            plt.axis('off')
    
    plt.tight_layout()
    plt.savefig('example_predictions.png')
    plt.close()

if __name__ == "__main__":
    # Load and evaluate the model
    model, accuracy, baseline_accuracy, examples = load_and_evaluate_model()
    
    print("\nFinal Results Summary:")
    print("-" * 50)
    print(f"Model Accuracy: {accuracy:.4f}")
    print(f"Baseline Accuracy: {baseline_accuracy:.4f}")
    print(f"Improvement over baseline: {accuracy - baseline_accuracy:.4f}")
    
    print("\nExample predictions have been plotted and saved as 'example_predictions.png'")

Using MPS device

Loading best model from checkpoint...
Model loaded from epoch 13
Previous training metrics:
- Training loss: 0.1033
- Validation loss: 0.1750

Evaluating model performance...

Model Evaluation Results:
--------------------------------------------------


  checkpoint = torch.load('best_ocr_model_1.pth', map_location=device)


Model Accuracy (Average Correct Characters): 0.9619
Random Baseline Accuracy: 0.0381
Improvement over baseline: 0.9238

Example Predictions:
--------------------------------------------------

Example 1:
True word:      considerability
Predicted word: considerabilityyeeyy
Correct characters: 15/15

Example 2:
True word:      extratympanic
Predicted word: extratympaniccyyee
Correct characters: 13/13

Example 3:
True word:      bewinged
Predicted word: bewingeddyyeeyy
Correct characters: 8/8

Example 4:
True word:      hipless
Predicted word: hiplessyyeeyy
Correct characters: 7/7

Example 5:
True word:      cosmographer
Predicted word: cosmographeryyeeyy
Correct characters: 12/12

Example 6:
True word:      flagroot
Predicted word: flagroottyyeeyy
Correct characters: 8/8

Example 7:
True word:      filaria
Predicted word: filariaaeyyeeyy
Correct characters: 7/7

Example 8:
True word:      hemispherically
Predicted word: hemisphericallyyeeyy
Correct characters: 15/15

Example 9:
True word

<img src="figures/example_predictions.png" alt="example_predictions" width="800">