In [7]:
from dataset import IAMProcessedDataset
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
from torch.optim.lr_scheduler import StepLR

from cnn import *

In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transform = transforms.Resize((128, 1024))

dataset = IAMProcessedDataset(lines_path='lines_processed', vocab_path='vocab.txt', transform=transform)

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

train_data_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
test_data_loader = DataLoader(test_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

vocab_size = dataset.num_classes
hidden_size = 64

model = HandwritingRecognitionModel(vocab_size, hidden_size)
model.to(device)

criterion = nn.CTCLoss(blank=dataset.blank_idx, zero_infinity=True)

optimizer = optim.Adam(model.parameters(), lr=0.01)
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

for epoch in range(20):
    # Training Loop
    # Implement the training loop here

    model.train()
    epoch_loss = 0

    for batch_idx, (images, transcriptions_padded, transcription_lengths) in enumerate(train_data_loader):
        # Move data to device
        images = images.to(device)
        transcriptions_padded = transcriptions_padded.to(device)

        # Compute logits
        logits = model(images)  # Shape: (batch_size, width', vocab_size)

        # Get input_lengths (sequence lengths after CNN processing)
        input_lengths = torch.full(
            size=(logits.size(0),), fill_value=logits.size(1), dtype=torch.long
        ).to(device)

        # Calculate CTC loss
        loss = criterion(
            logits.log_softmax(2).permute(1, 0, 2),  # CTC expects (time, batch, classes)
            transcriptions_padded,
            input_lengths,
            torch.tensor(transcription_lengths).to(device),
        )

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate loss
        epoch_loss += loss.item()

        if batch_idx % 10 == 0:
            print(
                f"Epoch [{epoch + 1}/{10}], Batch [{batch_idx}/{len(train_data_loader)}], Loss: {loss.item():.4f}"
            )

    # Step the scheduler
    scheduler.step()

    print(f"Epoch [{epoch + 1}/{10}] completed with average loss: {epoch_loss / len(train_data_loader):.4f}")


Epoch [1/10], Batch [0/308], Loss: 5.0446
Epoch [1/10], Batch [10/308], Loss: 3.1697
Epoch [1/10], Batch [20/308], Loss: 3.1211
Epoch [1/10], Batch [30/308], Loss: 1.4630
Epoch [1/10], Batch [40/308], Loss: 2.2411
Epoch [1/10], Batch [50/308], Loss: 3.1378
Epoch [1/10], Batch [60/308], Loss: 3.2375
Epoch [1/10], Batch [70/308], Loss: 3.2096
Epoch [1/10], Batch [80/308], Loss: 2.3677
Epoch [1/10], Batch [90/308], Loss: 3.1187
Epoch [1/10], Batch [100/308], Loss: 3.1868
Epoch [1/10], Batch [110/308], Loss: 3.0677
Epoch [1/10], Batch [120/308], Loss: 2.2578
Epoch [1/10], Batch [130/308], Loss: 3.0567
Epoch [1/10], Batch [140/308], Loss: 3.1020
Epoch [1/10], Batch [150/308], Loss: 2.5086
Epoch [1/10], Batch [160/308], Loss: 3.1561
Epoch [1/10], Batch [170/308], Loss: 2.3157
Epoch [1/10], Batch [180/308], Loss: 3.2349
Epoch [1/10], Batch [190/308], Loss: 2.2931
Epoch [1/10], Batch [200/308], Loss: 3.2639
Epoch [1/10], Batch [210/308], Loss: 3.1865
Epoch [1/10], Batch [220/308], Loss: 3.1774

KeyboardInterrupt: 

In [5]:
print(type(dataset))

<class 'cnn.IAMProcessedDataset'>


In [9]:
def decode_predictions(logits, idx_to_char):
    """
    Decode predictions using greedy decoding.
    """
    probs = logits.softmax(2)  # Convert logits to probabilities
    preds = torch.argmax(probs, dim=2)  # Get the most likely class for each timestep
    pred_transcriptions = []
    for pred in preds:
        transcription = []
        prev_char = None
        for idx in pred:
            if idx != prev_char and idx != dataset.blank_idx:  # Remove duplicates and blanks
                transcription.append(idx_to_char[idx.item()])
            prev_char = idx
        pred_transcriptions.append("".join(transcription))
    return pred_transcriptions

In [10]:
# Test Loop
model.eval()
test_loss = 0
all_predictions = []
all_ground_truth = []

with torch.no_grad():
    for images, transcriptions_padded, transcription_lengths in test_data_loader:
        # Move data to device
        images = images.to(device)
        transcriptions_padded = transcriptions_padded.to(device)

        # Compute logits
        logits = model(images)  # Shape: (batch_size, width', vocab_size)

        # Decode predictions
        predictions = decode_predictions(logits.cpu(), dataset.idx_to_char)

        # Convert ground truth to readable strings
        ground_truth = []
        for idx, length in enumerate(transcription_lengths):
            transcription = transcriptions_padded[idx][:length].tolist()
            ground_truth.append(decode(transcription, dataset.idx_to_char))

        all_predictions.extend(predictions)
        all_ground_truth.extend(ground_truth)

        # Calculate loss for reporting (optional)
        input_lengths = torch.full(
            size=(logits.size(0),), fill_value=logits.size(1), dtype=torch.long
        ).to(device)
        loss = criterion(
            logits.log_softmax(2).permute(1, 0, 2),  # (time, batch, classes)
            transcriptions_padded,
            input_lengths,
            torch.tensor(transcription_lengths).to(device),
        )
        test_loss += loss.item()

        # Print predictions and ground truth
        for pred, gt in zip(predictions, ground_truth):
            print(f"Prediction: {pred}")
            print(f"Ground Truth: {gt}")
            print("-" * 20)

# Calculate and print accuracy
char_accuracy = calculate_character_accuracy(all_predictions, all_ground_truth)
print(f"Test Loss: {test_loss / len(test_data_loader):.4f}")
print(f"Character-Level Accuracy: {char_accuracy:.4f}")


Prediction: Th . 
Ground Truth: The bismuth was recovered from the eluate as the phosphate. Results did not differ from those obtained by the more convenient method of heating the dissolved chro- mates in 2 N hydrochloric acid for fifteen minutes. The more rigorous method of securing chemical exchange was unnecessary. Lead-210 when present in effluent is likely to be found only at very low concentrations. 
--------------------
Prediction: Th . 
Ground Truth: They intend to sit outside the Ministry of Defence. It is their protest against the H-bomb. They ought to have a pleasant time. The weather forecast is good; except for them, Whitehall should be deserted. And they will have a fine view of St. James's Park, with its placid lake, pelicans, rare ducks, and other wild life. 
--------------------
Prediction: Th . 
Ground Truth: His voice was like his black and pin-stripe, a grey superimposition of respectability over the original colour of his own natural vowels, the result being someho

KeyboardInterrupt: 