# OCR pipeline for the IAM Lines dataset

This notebook presents a complete end-to-end pipeline for Optical Character Recognition (OCR) on the IAM Lines dataset, a widely used benchmark for handwritten text recognition. The pipeline covers all key stages of building an OCR system: data exploration and preprocessing, augmentation, dataset preparation, model design (a CRNN with attention), training with CTC loss, and evaluation using beam search with a language model.

# Installing libraries
These commands install the core dependencies required for the OCR pipeline. The `datasets` library provides easy access to the IAM Lines dataset, `editdistance` is used for computing accuracy metrics, `albumentations` supports image augmentation, `pyctcdecode` enables CTC decoding with beam search, and `kenlm` is used to integrate a language model for improved recognition.


In [None]:
!pip install datasets editdistance albumentations
!pip install pyctcdecode
!pip install kenlm

# Parameter and vocabulary configuration

This configuration block defines the core settings for training and preprocessing. It sets hyperparameters such as learning rate, max number of epochs, batch size, weight decay, and dropout, which control how the model learns. It also specifies image dimensions for consistent preprocessing and builds the vocabulary with character-to-index mappings, which are essential for encoding handwritten text into model-readable form and decoding predictions back into text.

In [None]:
import random
import numpy as np
import torch

# Reproducibility: seed Python's, NumPy's and PyTorch's random number generators
# with the same value.
# NOTE(review): torch.cuda.manual_seed presumably seeds only the current CUDA
# device; multi-GPU runs may want torch.cuda.manual_seed_all — confirm.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

# Training hyperparameters
LEARNING_RATE = 3e-4
EPOCHS = 80
WEIGHT_DECAY = 1e-4
BATCH_SIZE = 8
DROPOUT = 0.2

# Image preprocessing: every line image is resized and padded to this canvas.
# NOTE(review): the CNN shape comments further down mention a width of 1024,
# not 1028 — confirm which value is intended.
TARGET_HEIGHT = 128
TARGET_WIDTH = 1028

# Vocabulary: punctuation, digits, upper/lower-case letters and the space
# character; '<BLANK>' is prepended so that the CTC blank gets index 0.
VOCAB = list("!#&'()*+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz ")
VOCAB = ['<BLANK>'] + VOCAB
VOCAB_SIZE = len(VOCAB)

# Create mappings (character -> class index and the inverse)
CHAR_TO_IDX = {char: idx for idx, char in enumerate(VOCAB)}
IDX_TO_CHAR = {idx: char for char, idx in CHAR_TO_IDX.items()}
CHARS = VOCAB[1:]  # Remove <BLANK> for pyctcdecode

import os
import cv2
import tempfile
import logging
import re
from typing import List, Tuple
import numpy as np
from PIL import Image
import editdistance
from collections import Counter
from itertools import chain
import matplotlib.pyplot as plt

def setup_environment():
    """Export the environment variables this notebook relies on.

    Sets ``PYTORCH_CUDA_ALLOC_CONF`` to ``expandable_segments:True`` in the
    current process environment. Returns ``None``.
    """
    allocator_settings = {"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"}
    os.environ.update(allocator_settings)

# Data Analysis 

The functions show_random_samples and analyze_dataset provide essential exploratory data analysis for the IAM dataset. By visualizing random handwritten text samples, we can qualitatively assess the dataset’s quality and variety. The analysis further explores character frequency, line length distribution, image dimensions, unique character sets, and common words, offering valuable insights into the dataset’s structure. This understanding guides vocabulary design, preprocessing strategies, and model configuration choices.

In [None]:
def show_random_samples(dataset, n=5):
    """Display up to ``n`` shuffled line images together with their text.

    Args:
        dataset: Dataset exposing ``shuffle``, ``select`` and ``len``; each
            sample must provide an ``"image"`` and a ``"text"`` entry.
        n: Maximum number of samples to plot. Values larger than the dataset
            are clamped so that ``select`` never receives out-of-range indices.
    """
    count = min(n, len(dataset))
    # Fixed shuffle seed: the same samples are shown on every run.
    samples = dataset.shuffle(seed=42).select(range(count))
    for sample in samples:
        plt.figure(figsize=(8, 2))
        plt.imshow(sample["image"], cmap="gray")
        plt.title(f"Text: {sample['text']}")
        plt.axis("off")
        plt.show()

def analyze_dataset(dataset):
    """Perform exploratory data analysis on the dataset.

    Plots the character frequency, the distribution of line lengths and the
    width/height distribution of (at most) the first 500 images, then prints
    the set of unique characters and the ten most common words.

    Args:
        dataset: Dataset exposing ``select`` and ``len`` whose samples provide
            an ``"image"`` (with a ``.size`` of ``(width, height)``) and a
            ``"text"`` entry.
    """
    # Read every transcription once; all text statistics below reuse this
    # list instead of iterating the dataset again for each statistic.
    texts = [sample["text"] for sample in dataset]
    if not texts:
        print("Dataset is empty; nothing to analyze.")
        return

    # Character frequency analysis
    char_counts = Counter(chain.from_iterable(texts))
    if char_counts:  # all transcriptions could be empty strings
        chars, freqs = zip(*char_counts.most_common())
        plt.figure(figsize=(20, 6))
        plt.bar(chars, freqs)
        plt.title("Most Common Characters in Dataset")
        plt.xlabel("Character")
        plt.ylabel("Frequency")
        plt.grid(True)
        plt.show()

    # Line length distribution
    line_lengths = [len(text) for text in texts]
    plt.figure(figsize=(10, 5))
    plt.hist(line_lengths, bins=30, color="skyblue", edgecolor="black")
    plt.title("Distribution of Line Text Lengths")
    plt.xlabel("Number of Characters")
    plt.ylabel("Frequency")
    plt.grid(True)
    plt.show()
    print(f"Average line length: {np.mean(line_lengths):.2f}")

    # Image size analysis on a prefix of the dataset; clamp the prefix so
    # datasets with fewer than 500 rows do not trigger an index error.
    sample_count = min(500, len(dataset))
    widths, heights = [], []
    for sample in dataset.select(range(sample_count)):
        w, h = sample["image"].size
        widths.append(w)
        heights.append(h)

    plt.figure(figsize=(10, 4))
    plt.hist(widths, bins=30, alpha=0.7, label="Width")
    plt.hist(heights, bins=30, alpha=0.7, label="Height")
    plt.title(f"Image Width and Height Distribution (First {sample_count} Samples)")
    plt.xlabel("Pixels")
    plt.ylabel("Frequency")
    plt.legend()
    plt.grid(True)
    plt.show()
    print(f"Average Width: {np.mean(widths):.1f}, Average Height: {np.mean(heights):.1f}")

    # Unique characters: the keys of the frequency table are exactly the
    # distinct characters seen in the transcriptions.
    unique_chars = sorted(char_counts)
    print("Unique Characters Found:")
    print("".join(unique_chars))
    print(f"Total Unique Characters: {len(unique_chars)}")

    # Word frequency (whitespace-separated tokens)
    word_counts = Counter(chain.from_iterable(text.split() for text in texts))
    print("Top 10 most common words:")
    print(word_counts.most_common(10))
    print(f"Total unique words: {len(word_counts)}")

# Image augmentations

This block defines the data augmentation pipeline using Albumentations, a library for efficient image transformations. The pipeline introduces controlled variations such as slight rotations, brightness/contrast adjustments, elastic deformations, grid distortions, noise, and motion blur. These augmentations simulate the natural variability in handwriting and scanning conditions, making the model more robust and better at generalizing to unseen handwriting styles.

In [None]:
import cv2
import logging
import numpy as np
from PIL import Image
import albumentations as A

# Augmentation pipeline
# Each transform is applied independently with its own probability ``p``;
# the result is read back through the "image" key of the returned dict.
# NOTE(review): Rotate uses BORDER_CONSTANT without an explicit fill value, so
# the exposed corners presumably become black (0), whereas preprocess_image
# pads with white (255) — confirm this is intended.
ALBUMENTATIONS_TRANSFORM = A.Compose([
    A.Rotate(limit=2, border_mode=cv2.BORDER_CONSTANT, p=0.4),
    A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.4),
    A.ElasticTransform(alpha=1, sigma=50, p=0.3),
    A.GridDistortion(num_steps=5, distort_limit=0.1, p=0.3),
    A.GaussNoise(noise_scale_factor=0.1, p=0.2),
    A.MotionBlur(blur_limit=3, p=0.2),
])

# Image preprocessing and showing of preprocessed image samples

This section defines the preprocessing pipeline that prepares raw handwritten line images for training. The preprocess_image function standardizes inputs by converting them to grayscale, handling invalid cases, applying optional contrast enhancement, and supporting several binarization methods (e.g., Otsu, adaptive mean, adaptive Gaussian). It then resizes and pads images to fixed dimensions while preserving aspect ratio, and normalizes pixel values to either [0,1] or [-1,1]. The show_preprocessed_samples helper function visualizes original images alongside their preprocessed versions, making it easier to verify the effects of preprocessing choices.

In [None]:
def preprocess_image(img, target_height=128, target_width=1028,
                     enhance_contrast=False, normalization_range="0_1",
                     log_rescaling=True, binarization=None, binary_threshold=127,
                     adaptive_threshold_block_size=11, adaptive_threshold_c=2):
    """
    Preprocess image for IAM OCR.

    The image is converted to grayscale, optionally contrast-enhanced and
    binarized, resized with its aspect ratio preserved, centred on a white
    canvas of the target size and finally normalized.

    Args:
        img: PIL Image or numpy array (2-D grayscale, ``(H, W, 1)`` or a
            multi-channel array that is converted with ``COLOR_BGR2GRAY``)
        target_height: Target height in pixels
        target_width: Target width in pixels
        enhance_contrast: Whether to apply contrast enhancement
        normalization_range: "0_1" for [0,1] or "-1_1" for [-1,1]
        log_rescaling: Whether to log rescaling operations
        binarization: Binarization method (None, "otsu", "adaptive_mean",
            "adaptive_gaussian" or "simple")
        binary_threshold: Threshold for simple binarization
        adaptive_threshold_block_size: Block size for adaptive thresholding
        adaptive_threshold_c: Constant for adaptive thresholding

    Returns:
        Preprocessed image as float32 numpy array with shape (height, width, 1).
        An empty input yields an all-zero array of that shape.

    Raises:
        ValueError: If ``binarization`` or ``normalization_range`` is not one
            of the supported values.
    """
    # Convert to grayscale numpy array
    if isinstance(img, Image.Image):
        img = np.array(img.convert("L"))
    elif img.size and img.ndim == 3:
        # Empty arrays skip the conversion (cvtColor rejects them) and fall
        # through to the empty-image guard below.
        if img.shape[2] == 1:
            # Already single-channel; cvtColor would reject this layout.
            img = img[:, :, 0]
        else:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    # Handle edge cases
    if img.size == 0:
        logging.warning("Empty image provided, returning zero array")
        return np.zeros((target_height, target_width, 1), dtype=np.float32)

    h, w = img.shape

    # Optional contrast enhancement
    if enhance_contrast:
        img = cv2.convertScaleAbs(img, alpha=1.1, beta=5)

    # Apply binarization
    if binarization == "otsu":
        _, img = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    elif binarization == "adaptive_mean":
        img = cv2.adaptiveThreshold(img, 255, cv2.ADAPTIVE_THRESH_MEAN_C,
                                    cv2.THRESH_BINARY, adaptive_threshold_block_size,
                                    adaptive_threshold_c)
    elif binarization == "adaptive_gaussian":
        img = cv2.adaptiveThreshold(img, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                    cv2.THRESH_BINARY, adaptive_threshold_block_size,
                                    adaptive_threshold_c)
    elif binarization == "simple":
        _, img = cv2.threshold(img, binary_threshold, 255, cv2.THRESH_BINARY)
    elif binarization is not None:
        raise ValueError("Invalid binarization method")

    # Uniform scale factor that makes the image fit inside the target canvas
    scale = min(target_height / h, target_width / w)

    # Truncation can produce 0 for extreme aspect ratios (e.g. a 1-pixel-high
    # strip), which cv2.resize rejects; keep each side within [1, target].
    new_h = max(1, min(target_height, int(h * scale)))
    new_w = max(1, min(target_width, int(w * scale)))

    # Log aspect ratio changes (only when the aspect ratio differs from the
    # target's by more than 10%)
    if log_rescaling:
        original_aspect = w / h
        target_aspect = target_width / target_height
        if abs(original_aspect - target_aspect) / target_aspect > 0.1:
            logging.info(f"Image rescaled from {w}x{h} to {new_w}x{new_h}")

    # Resize and centre on a white canvas
    img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
    padded = np.full((target_height, target_width), 255, dtype=np.uint8)

    y_offset = (target_height - new_h) // 2
    x_offset = (target_width - new_w) // 2
    padded[y_offset:y_offset + new_h, x_offset:x_offset + new_w] = img

    # Normalize
    img = padded.astype(np.float32) / 255.0
    if normalization_range == "-1_1":
        img = (img - 0.5) * 2
    elif normalization_range != "0_1":
        raise ValueError("normalization_range must be '0_1' or '-1_1'")

    return np.expand_dims(img, axis=-1)

def show_preprocessed_samples(dataset, n=5):
    """Show original vs preprocessed images side by side.

    Args:
        dataset: Dataset exposing ``shuffle``, ``select`` and ``len``; each
            sample must provide an ``"image"`` and a ``"text"`` entry.
        n: Maximum number of samples to plot. Values larger than the dataset
            are clamped so that ``select`` never receives out-of-range indices.
    """
    count = min(n, len(dataset))
    samples = dataset.shuffle(seed=42).select(range(count))

    for sample in samples:
        original_img = sample["image"]
        text = sample["text"]

        # Keyword arguments make the chosen preprocessing variant explicit.
        processed_img = preprocess_image(
            original_img,
            target_height=128,
            target_width=1028,
            enhance_contrast=True,
            normalization_range="0_1",
            log_rescaling=True,
            binarization="adaptive_gaussian",
        )
        # Drop the trailing channel axis so imshow receives a 2-D array.
        processed_img_vis = processed_img.squeeze()

        plt.figure(figsize=(12, 3))

        plt.subplot(1, 2, 1)
        plt.imshow(original_img, cmap="gray")
        plt.title("Original")
        plt.axis("off")

        plt.subplot(1, 2, 2)
        plt.imshow(processed_img_vis, cmap="gray")
        plt.title(f"Preprocessed\n{text}")
        plt.axis("off")

        plt.tight_layout()
        plt.show()


# Dataset wrapper class and functions

This part defines how the IAM dataset is wrapped and prepared for training. The IAMDataset class extends PyTorch’s Dataset and handles preprocessing, augmentation, and encoding text labels into numerical indices. It also filters out samples with unsupported characters or overly long text. Each item returns the preprocessed image tensor along with its text and target sequence. The collate_fn function ensures variable-length text sequences are batched properly by concatenating targets and storing their lengths. Finally, create_dataloaders constructs efficient PyTorch dataloaders for both training and validation, handling batching, shuffling, and parallel data loading. This structure ensures seamless integration between the dataset and the model training loop.

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import cv2
import numpy as np
from PIL import Image
from typing import Dict, List, Optional

class IAMDataset(Dataset):
    """Dataset wrapper for IAM handwritten text recognition.

    Wraps a Hugging Face style dataset whose rows provide an ``"image"`` and a
    ``"text"`` entry. Rows whose text is longer than ``max_text_length`` or
    contains characters missing from ``char_to_idx`` are excluded.

    Args:
        hf_dataset: Source dataset; must support ``len``, integer row access,
            ``dataset["text"]`` column access and, when ``max_text_length`` is
            set, ``filter``.
        char_to_idx: Mapping from character to class index.
        idx_to_char: Inverse mapping (stored, not used by this class itself).
        target_height: Height passed to ``preprocess_image``.
        target_width: Width passed to ``preprocess_image``.
        enhance_contrast: Forwarded to ``preprocess_image``.
        normalization_range: Forwarded to ``preprocess_image``.
        max_text_length: Longest transcription to keep; ``None`` (or 0)
            disables the length filter.
        augment: Whether to apply ``ALBUMENTATIONS_TRANSFORM`` before
            preprocessing.
        log_rescaling: Forwarded to ``preprocess_image``.
        binarization: Forwarded to ``preprocess_image``.
        binary_threshold: Forwarded to ``preprocess_image``.
        adaptive_threshold_block_size: Forwarded to ``preprocess_image``.
        adaptive_threshold_c: Forwarded to ``preprocess_image``.
    """

    def __init__(self,
                 hf_dataset,
                 char_to_idx: Dict[str, int],
                 idx_to_char: Dict[int, str],
                 target_height: int = 128,
                 target_width: int = 1028,
                 enhance_contrast: bool = False,
                 normalization_range: str = "0_1",
                 max_text_length: Optional[int] = None,
                 augment: bool = False,
                 log_rescaling: bool = True,
                 binarization=None,
                 binary_threshold=127,
                 adaptive_threshold_block_size=11,
                 adaptive_threshold_c=2):

        self.dataset = hf_dataset
        self.char_to_idx = char_to_idx
        self.idx_to_char = idx_to_char
        self.target_height = target_height
        self.target_width = target_width
        self.enhance_contrast = enhance_contrast
        self.normalization_range = normalization_range
        self.max_text_length = max_text_length
        self.augment = augment
        self.log_rescaling = log_rescaling
        self.binarization = binarization
        self.binary_threshold = binary_threshold
        self.adaptive_threshold_block_size = adaptive_threshold_block_size
        self.adaptive_threshold_c = adaptive_threshold_c

        # Filter by max text length
        if max_text_length:
            self.dataset = self.dataset.filter(
                lambda x: len(x["text"]) <= max_text_length
            )
            print(f"Filtered dataset to {len(self.dataset)} samples")

        # Keep only rows whose text is fully covered by the vocabulary. The
        # text column is read on its own so that no image has to be loaded
        # just to inspect the transcription.
        self.valid_indices = [
            idx
            for idx, text in enumerate(self.dataset["text"])
            if self._can_encode_text(text)
        ]

        print(f"Valid samples: {len(self.valid_indices)}/{len(self.dataset)}")

    def _can_encode_text(self, text: str) -> bool:
        """Check if text can be encoded with current vocabulary."""
        return all(char in self.char_to_idx for char in text)

    def _encode_text(self, text: str) -> List[int]:
        """Encode text to indices, skipping characters outside the vocabulary."""
        return [self.char_to_idx[char] for char in text if char in self.char_to_idx]

    def _preprocess_image(self, img):
        """Preprocess image with optional augmentation.

        The image is converted to a uint8 grayscale array first because the
        augmentation pipeline runs before ``preprocess_image``.
        """
        if isinstance(img, Image.Image):
            img = np.array(img.convert("L"))
        elif len(img.shape) == 3:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        if img.dtype != np.uint8:
            img = img.astype(np.uint8)

        if self.augment:
            img = ALBUMENTATIONS_TRANSFORM(image=img)["image"]

        return preprocess_image(
            img,
            target_height=self.target_height,
            target_width=self.target_width,
            enhance_contrast=self.enhance_contrast,
            normalization_range=self.normalization_range,
            log_rescaling=self.log_rescaling,
            binarization=self.binarization,
            binary_threshold=self.binary_threshold,
            adaptive_threshold_block_size=self.adaptive_threshold_block_size,
            adaptive_threshold_c=self.adaptive_threshold_c
        )

    def __len__(self) -> int:
        """Return the number of samples that passed the filters."""
        return len(self.valid_indices)

    def __getitem__(self, idx: int) -> Dict[str, object]:
        """Return one sample.

        Returns:
            Dict with ``'image'`` (tensor, channel axis first), ``'text'``
            (the raw transcription string), ``'target'`` (long tensor of
            class indices) and ``'target_length'`` (scalar long tensor).
        """
        real_idx = self.valid_indices[idx]
        sample = self.dataset[real_idx]

        image = self._preprocess_image(sample["image"])
        # (H, W, 1) -> (1, H, W)
        image = torch.from_numpy(image).permute(2, 0, 1)

        text = sample["text"]
        target = self._encode_text(text)

        return {
            'image': image,
            'text': text,
            'target': torch.tensor(target, dtype=torch.long),
            'target_length': torch.tensor(len(target), dtype=torch.long)
        }

def collate_fn(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """Merge per-sample dicts into one batch dict for the DataLoader.

    Images and target lengths are stacked along a new leading axis, the raw
    texts are kept as a list, and the variable-length targets are
    concatenated into a single 1-D tensor.
    """
    image_list, text_list, target_list, length_list = [], [], [], []
    for entry in batch:
        image_list.append(entry['image'])
        text_list.append(entry['text'])
        target_list.append(entry['target'])
        length_list.append(entry['target_length'])

    stacked_images = torch.stack(image_list)
    stacked_lengths = torch.stack(length_list)
    flat_targets = torch.cat(target_list)

    return {
        'images': stacked_images,
        'texts': text_list,
        'targets': flat_targets,
        'target_lengths': stacked_lengths,
    }

def create_dataloaders(train_dataset, val_dataset, batch_size=8, num_workers=4):
    """Build the training and validation DataLoaders.

    Both loaders share the batch size, worker count, ``collate_fn`` and
    pinned memory; only the training loader shuffles.

    Returns:
        Tuple ``(train_loader, val_loader)``.
    """
    shared_options = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=True,
    )
    loaders = (
        DataLoader(train_dataset, shuffle=True, **shared_options),
        DataLoader(val_dataset, shuffle=False, **shared_options),
    )
    return loaders

# CRNN architecture with Spatial Attention

The CRNN model combines convolutional layers, optional spatial attention, and recurrent layers to recognize handwritten text. The convolutional backbone with residual blocks extracts hierarchical visual features and compresses the image into a sequence representation. A multi-head attention module can be applied to emphasize informative regions. The extracted features are then processed by stacked bidirectional LSTMs, which capture sequential dependencies across the text line. Finally, a linear classifier maps the sequence outputs to the vocabulary, producing log-probabilities suitable for CTC loss training.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import math

class ResidualBlock(nn.Module):
    """Sum of a main path and a shortcut path, followed by GELU."""

    def __init__(self, main_path, shortcut):
        super().__init__()
        self.main_path = main_path
        self.shortcut = shortcut

    def forward(self, x):
        """Return ``gelu(main_path(x) + shortcut(x))``."""
        transformed = self.main_path(x)
        skipped = self.shortcut(x)
        return F.gelu(transformed + skipped)

class MultiHeadSpatialAttention(nn.Module):
    """Scaled dot-product self-attention over the spatial positions of a feature map.

    The channel axis is split into ``num_heads`` groups of ``head_dim``
    channels; every one of the ``height * width`` positions attends to all
    positions within each head. The result is projected by a 1x1 convolution
    and added to the input.
    """

    def __init__(self, channels, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.channels = channels
        self.head_dim = channels // num_heads

        assert channels % num_heads == 0, "channels must be divisible by num_heads"

        # 1x1 convolutions act as per-position linear projections.
        self.query_conv = nn.Conv2d(channels, channels, kernel_size=1)
        self.key_conv = nn.Conv2d(channels, channels, kernel_size=1)
        self.value_conv = nn.Conv2d(channels, channels, kernel_size=1)
        self.output_conv = nn.Conv2d(channels, channels, kernel_size=1)

        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        """Apply attention to ``x`` of shape (batch, channels, height, width)."""
        n_batch, n_channels, n_rows, n_cols = x.shape
        per_head = (n_batch, self.num_heads, self.head_dim, n_rows * n_cols)

        # queries / values: (batch, heads, positions, head_dim)
        # keys:             (batch, heads, head_dim, positions)
        queries = self.query_conv(x).view(per_head).transpose(2, 3)
        keys = self.key_conv(x).view(per_head)
        values = self.value_conv(x).view(per_head).transpose(2, 3)

        # (batch, heads, positions, positions), normalised over the last axis
        scores = torch.matmul(queries, keys) / math.sqrt(self.head_dim)
        weights = self.dropout(self.softmax(scores))

        # Weighted sum of values, folded back into the input layout
        context = torch.matmul(weights, values).transpose(2, 3).contiguous()
        context = context.view(n_batch, n_channels, n_rows, n_cols)

        # Residual connection
        return self.output_conv(context) + x

class CRNN(nn.Module):
    """CRNN model for handwriting recognition.

    A residual CNN collapses the image height to 1, optional multi-head
    spatial attention refines the resulting feature row, stacked
    bidirectional LSTMs model the sequence, and a linear layer followed by
    log-softmax scores every time step against the vocabulary.

    Args:
        vocab_size: Number of output classes.
            NOTE(review): the default of 80 does not appear to match the
            notebook's vocabulary (which looks like 79 entries including
            '<BLANK>') — confirm callers pass VOCAB_SIZE explicitly.
        hidden_size: LSTM hidden size per direction; also the width of the
            projection that follows each LSTM.
        num_lstm_layers: Number of (LSTM -> linear -> dropout) stages.
        dropout: Dropout probability used after each projection.
        use_attention: Whether to apply MultiHeadSpatialAttention to the CNN
            features.
        attention_heads: Number of attention heads.
    """
    
    def __init__(self, vocab_size=80, hidden_size=256, num_lstm_layers=2, 
                 dropout=0.2, use_attention=True, attention_heads=4):
        super(CRNN, self).__init__()
        
        self.vocab_size = vocab_size
        self.use_attention = use_attention
        self.num_lstm_layers = num_lstm_layers
        self.dropout = dropout
        
        # CNN backbone. The 3x3 convolutions use padding 1 and keep the
        # spatial size; only the pooling / strided layers shrink it. The
        # example sizes below assume a 128x1024 input; with a 1028-wide
        # input the width goes 1028 -> 514 -> 257 instead.
        self.conv_layers = nn.Sequential(
            nn.Conv2d(1, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.GELU(),
            
            # Block 1: 2x2 pooling halves height and width (128x1024 -> 64x512)
            self._make_residual_block(64, 64),
            nn.MaxPool2d(2, 2),
            
            # Block 2: 2x2 pooling halves height and width (64x512 -> 32x256)
            self._make_residual_block(64, 128),
            self._make_residual_block(128, 128),
            nn.MaxPool2d(2, 2),
            
            # Block 3: (2, 1) pooling halves the height only (32x256 -> 16x256)
            self._make_residual_block(128, 256),
            self._make_residual_block(256, 256),
            nn.MaxPool2d((2, 1), (2, 1)),
            
            # Block 4: (2, 1) pooling halves the height only (16x256 -> 8x256)
            self._make_residual_block(256, 512),
            self._make_residual_block(512, 512),
            nn.MaxPool2d((2, 1), (2, 1)),
            
            # Final compression: two height-stride-2 convolutions followed by
            # adaptive average pooling force the height to 1 (8x256 -> 1x256)
            nn.Conv2d(512, 512, (3, 1), (2, 1), (1, 0)),
            nn.BatchNorm2d(512),
            nn.GELU(),
            nn.Conv2d(512, 512, (3, 1), (2, 1), (1, 0)),
            nn.BatchNorm2d(512),
            nn.GELU(),
            nn.AdaptiveAvgPool2d((1, None)),
        )
        
        # Spatial attention
        if self.use_attention:
            self.multihead_attention = MultiHeadSpatialAttention(512, attention_heads)
        
        # RNN layers
        self.rnn_layers = nn.ModuleList()
        # Despite the name, each entry holds the projection that maps the
        # bidirectional LSTM output (2 * hidden_size) back to hidden_size,
        # together with the dropout applied after it.
        self.dropout_layers = nn.ModuleList()
        
        for i in range(num_lstm_layers):
            # The first LSTM consumes the 512 CNN channels, later ones the
            # projected output of the previous stage.
            input_size = 512 if i == 0 else hidden_size
            
            self.rnn_layers.append(
                nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True)
            )
            
            self.dropout_layers.append(nn.ModuleDict({
                'linear': nn.Linear(hidden_size * 2, hidden_size),
                'dropout': nn.Dropout(dropout)
            }))
        
        # Classifier
        self.classifier = nn.Linear(hidden_size, vocab_size)
        self._initialize_weights()
    
    def _make_residual_block(self, in_channels, out_channels, stride=1):
        """Create a residual block of two 3x3 conv + batch-norm layers.

        The shortcut is the identity unless the stride or the channel count
        changes, in which case a 1x1 convolution with batch norm matches the
        shapes of the two paths.
        """
        layers = [
            nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels)
        ]
        main_path = nn.Sequential(*layers)
        
        # An empty Sequential passes its input through unchanged.
        shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
        
        return ResidualBlock(main_path, shortcut)
    
    def _initialize_weights(self):
        """Initialize model weights.

        Conv and linear weights use Kaiming-normal initialization, batch
        norm starts at weight 1 / bias 0, LSTM weights are orthogonal and
        all biases are zero. Every submodule is visited, including the
        attention block's convolutions.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='linear')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='linear')
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.LSTM):
                for name, param in m.named_parameters():
                    if 'weight' in name:
                        init.orthogonal_(param)
                    elif 'bias' in name:
                        init.constant_(param, 0)
    
    def forward(self, x):
        """Compute per-time-step log-probabilities.

        Args:
            x: Image batch of shape (batch, 1, height, width).

        Returns:
            Tensor of shape (batch, time, vocab_size) holding
            log-probabilities, where ``time`` is the feature-map width
            after the CNN.
        """
        # CNN feature extraction
        conv_features = self.conv_layers(x)
        
        # Apply attention
        if self.use_attention:
            conv_features = self.multihead_attention(conv_features)
        
        batch_size, channels, height, width = conv_features.size()
        assert height == 1, f"Height should be 1 after CNN, got {height}"
        
        # (batch, channels, 1, width) -> (batch, width, channels): the width
        # axis becomes the sequence axis expected by the batch_first LSTMs.
        rnn_input = conv_features.squeeze(2).permute(0, 2, 1)
        
        # RNN layers
        rnn_output = rnn_input
        for i in range(self.num_lstm_layers):
            lstm_out, _ = self.rnn_layers[i](rnn_output)
            rnn_output = self.dropout_layers[i]['linear'](lstm_out)
            rnn_output = self.dropout_layers[i]['dropout'](rnn_output)
        
        # Classification
        output = self.classifier(rnn_output)
        output = F.log_softmax(output, dim=2)
        
        return output

# Setting up KenLM language model

This section sets up beam search decoding with a language model to improve transcription quality. First, we build a KenLM n-gram model from a cleaned text corpus that combines IAM training texts and optionally WikiText. The setup_kenlm and create_kenlm_from_corpus functions handle installing KenLM, preparing the corpus, and training an n-gram model. The setup_beam_search_decoder function integrates this language model into a CTC beam search decoder using pyctcdecode, enabling the model to generate more linguistically plausible outputs compared to greedy decoding. Finally, the wbs_decode_batch function applies this decoder to batched log-probabilities, producing refined text predictions with temperature scaling and proper handling of blank tokens.

In [None]:
import os
import tempfile
import re
import numpy as np
import torch
from pyctcdecode import build_ctcdecoder
import kenlm

def setup_kenlm(kenlm_root="/kaggle/working/kenlm"):
    """Install build dependencies, then clone and build KenLM.

    The repository is cloned to, built in and looked up under the same
    ``kenlm_root`` directory, independent of the current working directory.
    If a built ``lmplz`` already exists there, nothing is rebuilt.

    Args:
        kenlm_root: Directory of the KenLM checkout. Defaults to the Kaggle
            working directory used throughout this notebook.

    Returns:
        Path to the ``lmplz`` executable, or ``None`` if it is not present
        after the build attempt.
    """
    import subprocess

    def _run(cmd, cwd=None):
        """Run one setup step; report a failure without aborting."""
        try:
            completed = subprocess.run(cmd, cwd=cwd, check=False)
        except OSError as err:
            print(f"Could not start {cmd[0]!r}: {err}")
            return False
        if completed.returncode != 0:
            print(f"Command {' '.join(cmd)} exited with code {completed.returncode}")
            return False
        return True

    build_dir = os.path.join(kenlm_root, "build")
    lmplz_path = os.path.join(build_dir, "bin", "lmplz")

    if os.path.exists(lmplz_path):
        print(f"lmplz found at: {lmplz_path}")
        return lmplz_path

    print("Installing dependencies...")
    _run(["apt-get", "update", "-qq"])
    _run(["apt-get", "install", "-y", "build-essential", "libboost-all-dev",
          "cmake", "zlib1g-dev", "libbz2-dev", "liblzma-dev"])

    print("Cloning and building KenLM...")
    if not os.path.exists(kenlm_root):
        _run(["git", "clone", "https://github.com/kpu/kenlm.git", kenlm_root])

    # Only build when a checkout is present; creating the build directory
    # after a failed clone would make later runs skip the clone.
    if os.path.isdir(kenlm_root):
        os.makedirs(build_dir, exist_ok=True)
        if _run(["cmake", ".."], cwd=build_dir):
            _run(["make", f"-j{os.cpu_count() or 1}"], cwd=build_dir)

    if os.path.exists(lmplz_path):
        print(f"lmplz found at: {lmplz_path}")
        return lmplz_path

    print("lmplz not found.")
    return None

# Characters the recognizer can emit; everything else is dropped from the corpus.
_KENLM_ALLOWED_CHARS = frozenset(
    "!#&'()*+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz "
)
# Marker tokens that are replaced by a space before the character filter runs.
_KENLM_RESERVED_TOKENS = ('<unk>', '<UNK>', '<s>', '</s>', '<pad>', '<PAD>')


def clean_text_for_kenlm(text):
    """Clean text for KenLM compatibility.

    Reserved marker tokens are replaced by a space, characters outside the
    recognizer vocabulary are removed, and whitespace is collapsed to single
    spaces with no leading or trailing space.

    Args:
        text: Raw text line. Anything that is not a non-empty string yields
            an empty result.

    Returns:
        The cleaned line, or ``""`` if nothing usable remains.
    """
    if not text or not isinstance(text, str):
        return ""

    # Remove problematic tokens
    for token in _KENLM_RESERVED_TOKENS:
        text = text.replace(token, ' ')

    # Turn every whitespace character into a plain space first, so tabs and
    # newlines keep separating words after the character filter.
    text = re.sub(r'\s', ' ', text)

    # Keep only vocabulary characters
    text = ''.join(c for c in text if c in _KENLM_ALLOWED_CHARS)

    # Collapse whitespace last: dropping characters above can leave doubled,
    # leading or trailing spaces behind.
    return ' '.join(text.split())

def create_kenlm_from_corpus(corpus_text, order=3, lmplz_path=None):
    """Create a KenLM ARPA model from a text corpus.

    The corpus is written to a temporary file and piped through ``lmplz``.
    The temporary corpus file is always removed; the model file is removed
    only when building fails.

    Args:
        corpus_text: Training corpus, one sentence per line.
        order: N-gram order passed to ``lmplz -o``.
        lmplz_path: Path to the ``lmplz`` executable. When missing or not
            existing, the known build locations and then ``PATH`` are tried.

    Returns:
        Path of the generated ``.arpa`` file. The caller owns this file.

    Raises:
        FileNotFoundError: If no ``lmplz`` executable can be located.
        RuntimeError: If ``lmplz`` exits with an error or writes no output.
    """
    import shutil
    import subprocess

    # Find lmplz executable (before any temporary file is created)
    if lmplz_path and os.path.exists(lmplz_path):
        lmplz_cmd = lmplz_path
    else:
        candidates = (
            "/kaggle/working/kenlm/build/bin/lmplz",
            "/kaggle/working/kenlm/bin/lmplz",
        )
        lmplz_cmd = next((p for p in candidates if os.path.exists(p)), None)
        if lmplz_cmd is None:
            # Fall back to PATH, but verify that the command really exists.
            lmplz_cmd = shutil.which("lmplz")
        if lmplz_cmd is None:
            raise FileNotFoundError("lmplz executable not found!")

    corpus_path = None
    model_path = None
    try:
        # Write corpus
        print("Writing corpus to temporary file...")
        with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False,
                                         encoding='utf-8') as corpus_file:
            corpus_path = corpus_file.name
            corpus_file.write(corpus_text)

        print(f"Corpus file size: {os.path.getsize(corpus_path)} bytes")

        # Reserve the output path and close the descriptor immediately so
        # no handle is left open on the file lmplz writes to.
        model_fd, model_path = tempfile.mkstemp(suffix='.arpa')
        os.close(model_fd)

        # Build language model; an argument list avoids shell interpretation.
        cmd = [lmplz_cmd, "-o", str(order), "--discount_fallback", "--skip_symbols"]
        print(f"Running: {' '.join(cmd)} < {corpus_path} > {model_path}")

        with open(corpus_path, 'rb') as src, open(model_path, 'wb') as dst:
            completed = subprocess.run(cmd, stdin=src, stdout=dst, check=False)

        if completed.returncode != 0:
            raise RuntimeError(f"lmplz failed with return code {completed.returncode}")

        if not os.path.exists(model_path) or os.path.getsize(model_path) == 0:
            raise RuntimeError("lmplz produced empty model file")

        print(f"Model created successfully: {model_path}")
        print(f"Model file size: {os.path.getsize(model_path)} bytes")

        return model_path

    except Exception:
        # A failed build must not leave a partial model behind.
        if model_path and os.path.exists(model_path):
            os.unlink(model_path)
        raise
    finally:
        if corpus_path and os.path.exists(corpus_path):
            os.unlink(corpus_path)

def setup_beam_search_decoder(vocab, train_data, wiki_data=None, order=3, *,
                              alpha=0.35, beta=0.6, unk_score_offset=-10.0,
                              max_wiki_lines=100000):
    """Setup beam search decoder with language model.

    Builds KenLM, trains an n-gram model on the cleaned training texts
    (optionally extended with WikiText lines) and wraps it in a pyctcdecode
    decoder.

    Args:
        vocab: Model vocabulary whose first entry is the CTC blank token.
        train_data: Iterable of dict-like samples with a ``"text"`` entry.
        wiki_data: Optional iterable of dict-like samples with a ``'text'``
            entry used to enlarge the corpus.
        order: N-gram order of the language model.
        alpha: Language-model weight passed to ``build_ctcdecoder``.
        beta: Word-insertion weight passed to ``build_ctcdecoder``.
        unk_score_offset: Unknown-word score offset passed to
            ``build_ctcdecoder``.
        max_wiki_lines: Maximum number of cleaned WikiText lines to use.

    Returns:
        Tuple ``(decoder, model_path)``.

    Raises:
        RuntimeError: If KenLM could not be set up.
    """
    lmplz_path = setup_kenlm()
    if lmplz_path is None:
        raise RuntimeError("Failed to setup KenLM")

    print("Preparing text corpus...")

    # Process IAM texts
    iam_texts = []
    for item in train_data:
        if item.get("text"):
            cleaned = clean_text_for_kenlm(item["text"])
            if cleaned:
                iam_texts.append(cleaned)

    print(f"Cleaned IAM texts: {len(iam_texts)} samples")

    # Process WikiText if available
    if wiki_data:
        print("Processing WikiText data...")
        wiki_texts = []
        for entry in wiki_data:
            # Stop at the cap instead of cleaning lines that would be
            # discarded afterwards.
            if len(wiki_texts) >= max_wiki_lines:
                break
            raw = entry.get('text')
            if raw and raw.strip():
                cleaned = clean_text_for_kenlm(raw)
                if cleaned:
                    wiki_texts.append(cleaned)

        print(f"Cleaned WikiText: {len(wiki_texts)} samples")
        combined_texts = iam_texts + wiki_texts
    else:
        combined_texts = iam_texts

    # Create corpus
    corpus_text = "\n".join(combined_texts)
    print(f"Final corpus: {len(combined_texts)} lines, {len(corpus_text)} characters")

    # Create KenLM model
    print(f"Training {order}-gram language model...")
    model_path = create_kenlm_from_corpus(corpus_text, order=order, lmplz_path=lmplz_path)

    # Build decoder
    print("Building CTC decoder...")
    # Remove <BLANK> token.
    # NOTE(review): pyctcdecode presumably appends its own blank label at the
    # end when none is supplied, which is why wbs_decode_batch moves the
    # blank column last — confirm against the installed version.
    chars = vocab[1:]
    decoder = build_ctcdecoder(
        labels=chars,
        kenlm_model_path=model_path,
        alpha=alpha,
        beta=beta,
        unk_score_offset=unk_score_offset,
    )

    print("Beam search decoder ready!")
    return decoder, model_path

def wbs_decode_batch(log_probs_btV, decoder, blank_is_first=True, temperature=1.1, chars=None):
    """Decode a batch of network outputs with the beam-search decoder.

    Args:
        log_probs_btV: 3-D tensor of per-step class scores. The first axis is
            iterated item by item and the last axis is the class axis.
        decoder: Object exposing ``decode(matrix)`` that returns the text for
            one item.
        blank_is_first: When true, class column 0 is moved to the last
            position before decoding.
        temperature: Scores are divided by this value and re-normalised with
            ``log_softmax`` unless it equals 1.0.
        chars: Accepted for call-site compatibility; not used here.

    Returns:
        List with one decoded string per batch item.

    Raises:
        RuntimeError: If decoding any item fails; the original exception is
            attached as ``__cause__``.
    """
    # Apply temperature scaling
    if temperature != 1.0:
        log_probs_btV = torch.log_softmax(log_probs_btV / temperature, dim=-1)

    # Convert to numpy
    logits = log_probs_btV.detach().cpu().numpy()

    # Move blank to last position if needed
    if blank_is_first:
        logits = np.concatenate([logits[:, :, 1:], logits[:, :, :1]], axis=2)

    texts = []
    for i, item_scores in enumerate(logits):
        try:
            texts.append(decoder.decode(item_scores))
        except Exception as e:
            # Chain explicitly so the underlying decoder error stays visible
            # to callers that only catch RuntimeError.
            raise RuntimeError(f"Beam search decode failed for batch item {i}: {e}") from e

    return texts

# Evaluation metrics functions

This section defines evaluation metrics for handwriting recognition. The calculate_accuracy function computes both character-level and word-level accuracy by comparing predicted text with the ground truth, accounting for substitutions, insertions, and deletions; each accuracy is 1 minus the corresponding error rate (CER or WER), clipped at 0. The word_edit_distance helper calculates the edit distance between sequences of words, enabling computation of word accuracy (WRA). Additionally, calculate_edit_distance provides a normalized edit distance (CER/WER style): the total character edit distance divided by the summed length of the longer string in each prediction/ground-truth pair, giving a single metric between 0 and 1 that reflects the overall transcription error.

In [None]:
import editdistance

def word_edit_distance(pred_words, gt_words):
    """Return the Levenshtein distance between two sequences of words.

    Insertions, deletions and substitutions each cost 1; words are compared
    with ``==``.
    """
    # Distances from the empty prediction prefix to every ground-truth prefix.
    previous = list(range(len(gt_words) + 1))

    for row, pred_word in enumerate(pred_words, start=1):
        current = [row]  # distance from this prediction prefix to the empty prefix
        for col, gt_word in enumerate(gt_words, start=1):
            if pred_word == gt_word:
                cell = previous[col - 1]
            else:
                cell = 1 + min(previous[col], current[col - 1], previous[col - 1])
            current.append(cell)
        previous = current

    return previous[-1]

def calculate_accuracy(predictions, ground_truths):
    """Return ``(char_accuracy, word_accuracy)`` for paired predictions and references.

    Every item is converted with ``str``. Each accuracy is ``1 - errors / total``
    clipped at 0, where errors are edit distances summed over all pairs and
    the total is the ground-truth character (or whitespace-split word) count.
    ``(0.0, 0.0)`` is returned when either list is empty or their lengths differ.
    """
    hyps = [str(item) for item in predictions]
    refs = [str(item) for item in ground_truths]

    if len(hyps) == 0 or len(refs) == 0 or len(hyps) != len(refs):
        return 0.0, 0.0

    def _to_accuracy(errors, total):
        # With nothing to compare against, only an error-free result counts as perfect.
        if total == 0:
            return 1.0 if errors == 0 else 0.0
        return max(0.0, 1.0 - errors / total)

    char_errors = char_total = 0
    word_errors = word_total = 0
    for hyp, ref in zip(hyps, refs):
        # Character level
        char_errors += editdistance.eval(hyp, ref)
        char_total += len(ref)

        # Word level
        ref_words = ref.split()
        word_errors += word_edit_distance(hyp.split(), ref_words)
        word_total += len(ref_words)

    return _to_accuracy(char_errors, char_total), _to_accuracy(word_errors, word_total)

def calculate_edit_distance(predictions, ground_truths):
    """Return the total edit distance divided by the total comparison length.

    Each pair contributes its edit distance to the numerator and the length
    of its longer member (at least 1) to the denominator.
    """
    distance_sum = 0
    length_sum = 0

    for hyp, ref in zip(predictions, ground_truths):
        distance_sum += editdistance.eval(hyp, ref)
        length_sum += max(len(hyp), len(ref), 1)

    # No pairs at all leaves the length sum at 0; divide by 1 in that case.
    denominator = length_sum if length_sum > 1 else 1
    return distance_sum / denominator

# Training and validation helper functions

This section defines the training and validation routines for the CRNN model. The train_epoch function performs one full epoch of training, computing the CTC loss, performing backpropagation, and decoding every tenth batch with the beam search decoder for interim metric evaluation. The validate_epoch function runs the model in evaluation mode on the validation set, calculating loss and metrics without updating model weights. Both functions return comprehensive metrics including average loss, character-level accuracy, word-level accuracy, and normalized edit distance, allowing detailed monitoring of model performance over time; validate_epoch additionally returns the first ten predictions and their ground truths for inspection.


In [None]:
import time
import os
import json
import random
from pathlib import Path
import torch
import torch.nn as nn
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

def train_epoch(model, train_loader, optimizer, ctc_loss, device, epoch, decoder, chars):
    """Train the model for one epoch and return its loss and sampled metrics.

    Args:
        model: Network called as ``model(images)``; its output is permuted
            with ``(1, 0, 2)`` before the loss (assumes the model emits
            batch-first output and the loss expects time-first; confirm
            against the model definition).
        train_loader: Iterable of batches with the keys ``'images'``,
            ``'targets'``, ``'target_lengths'`` and ``'texts'``.
        optimizer: Optimizer stepped once per batch.
        ctc_loss: Loss callable invoked as
            ``ctc_loss(outputs, targets, input_lengths, target_lengths)``.
        device: Device the batch tensors are moved to.
        epoch: Epoch number, used for the progress bar and diagnostics.
        decoder: Beam-search decoder forwarded to ``wbs_decode_batch``.
        chars: Character list forwarded to ``wbs_decode_batch``.

    Returns:
        Tuple ``(avg_loss, char_acc, word_acc, edit_dist)``. The three metrics
        are computed on every tenth batch only and are all 0.0 when no batch
        could be decoded.
    """
    model.train()
    total_loss = 0
    all_predictions, all_ground_truths = [], []
    num_batches = len(train_loader)
    failed_decodes = 0
    last_decode_error = None

    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch}', leave=False)

    for batch_idx, batch in enumerate(progress_bar):
        images = batch['images'].to(device)
        targets = batch['targets'].to(device)
        target_lengths = batch['target_lengths'].to(device)
        texts = batch['texts']

        optimizer.zero_grad()

        # Forward pass
        outputs = model(images)
        outputs_ctc = outputs.permute(1, 0, 2)

        # Every sample is given the full output length as its input length.
        input_lengths = torch.full((images.size(0),), outputs_ctc.size(0),
                                   dtype=torch.long, device=device)

        # Calculate CTC loss
        loss = ctc_loss(outputs_ctc, targets, input_lengths, target_lengths)

        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss

        # Collect predictions for metrics (every 10 batches)
        if batch_idx % 10 == 0:
            with torch.no_grad():
                try:
                    preds = wbs_decode_batch(outputs, decoder, blank_is_first=True, chars=chars)
                except RuntimeError as err:
                    # Best effort: a failed decode must not interrupt training,
                    # but remember it so the gap in the metrics is reported.
                    failed_decodes += 1
                    last_decode_error = err
                else:
                    all_predictions.extend(preds)
                    all_ground_truths.extend(texts)

        progress_bar.set_postfix({
            'Loss': f'{batch_loss:.4f}',
            'Avg Loss': f'{total_loss/(batch_idx+1):.4f}'
        })

    if failed_decodes:
        print(f"Warning: decoding failed for {failed_decodes} sampled training batch(es) "
              f"in epoch {epoch}; they are excluded from the training metrics "
              f"(last error: {last_decode_error})")

    # Calculate training metrics
    avg_loss = total_loss / num_batches
    if all_predictions:
        char_acc, word_acc = calculate_accuracy(all_predictions, all_ground_truths)
        edit_dist = calculate_edit_distance(all_predictions, all_ground_truths)
    else:
        char_acc, word_acc, edit_dist = 0.0, 0.0, 0.0

    return avg_loss, char_acc, word_acc, edit_dist

def validate_epoch(model, val_loader, ctc_loss, device, decoder, chars):
    """Evaluate the model on ``val_loader`` without updating any weights.

    Returns:
        Tuple ``(avg_loss, char_acc, word_acc, edit_dist, predictions,
        ground_truths)`` where the last two entries hold the first ten decoded
        strings and their references. A batch whose decoding raises
        ``RuntimeError`` is scored as all-empty predictions.
    """
    model.eval()
    loss_sum = 0
    decoded_texts = []
    reference_texts = []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc='Validation', leave=False):
            images = batch['images'].to(device)
            targets = batch['targets'].to(device)
            target_lengths = batch['target_lengths'].to(device)
            texts = batch['texts']

            # Forward pass; the loss receives the output with its first two axes swapped.
            outputs = model(images)
            outputs_ctc = outputs.permute(1, 0, 2)

            batch_size = images.size(0)
            num_steps = outputs_ctc.size(0)
            input_lengths = torch.full((batch_size,), num_steps,
                                       dtype=torch.long, device=device)

            # Accumulate the batch loss
            loss_sum += ctc_loss(outputs_ctc, targets, input_lengths, target_lengths).item()

            # Decode, falling back to empty strings when the decoder fails
            try:
                batch_preds = wbs_decode_batch(outputs, decoder, blank_is_first=True, chars=chars)
            except RuntimeError:
                batch_preds = [""] * len(texts)
            decoded_texts.extend(batch_preds)
            reference_texts.extend(texts)

    # Aggregate metrics over the whole validation set
    avg_loss = loss_sum / len(val_loader)
    char_acc, word_acc = calculate_accuracy(decoded_texts, reference_texts)
    edit_dist = calculate_edit_distance(decoded_texts, reference_texts)

    return avg_loss, char_acc, word_acc, edit_dist, decoded_texts[:10], reference_texts[:10]

# Checkpointing and Training History Visualization

This section provides utilities for saving model checkpoints and visualizing training progress. The save_checkpoint function stores the model state, optimizer state, epoch number, and loss metrics to disk, enabling resuming or inspecting training later. The plot_training_history function creates a comprehensive visualization of the training and validation metrics over epochs, including loss, character accuracy, word accuracy, and edit distance, helping monitor model performance and detect potential overfitting or underfitting trends.


In [None]:
def save_checkpoint(model, optimizer, epoch, train_loss, val_loss, save_dir):
    """Write a checkpoint for ``epoch`` into ``save_dir`` and return its path.

    The file is named ``checkpoint_epoch_<epoch>.pth`` and stores the epoch
    number, the model and optimizer state dicts, and both loss values.
    """
    model_state = model.state_dict()
    optimizer_state = optimizer.state_dict()
    payload = {
        'epoch': epoch,
        'model_state_dict': model_state,
        'optimizer_state_dict': optimizer_state,
        'train_loss': train_loss,
        'val_loss': val_loss,
    }

    destination = Path(save_dir) / f'checkpoint_epoch_{epoch}.pth'
    torch.save(payload, destination)
    return destination

def plot_training_history(history, save_path=None):
    """Draw a 2x2 grid of train/validation curves from ``history``.

    Panels show loss, character accuracy, word accuracy and edit distance.
    The figure is written to ``save_path`` when one is given and is then
    displayed.
    """
    _, axes = plt.subplots(2, 2, figsize=(15, 10))

    # (axis, train key, train label, val key, val label, title, y label)
    panels = [
        (axes[0,0], 'train_loss', 'Train Loss', 'val_loss', 'Val Loss',
         'Loss', 'Loss'),
        (axes[0,1], 'train_char_acc', 'Train Char Acc', 'char_acc', 'Val Char Acc',
         'Character Accuracy', 'Accuracy'),
        (axes[1,0], 'train_word_acc', 'Train Word Acc', 'word_acc', 'Val Word Acc',
         'Word Accuracy', 'Accuracy'),
        (axes[1,1], 'train_edit_dist', 'Train Edit Dist', 'edit_dist', 'Val Edit Dist',
         'Edit Distance', 'Edit Distance'),
    ]

    for ax, train_key, train_label, val_key, val_label, title, y_label in panels:
        ax.plot(history[train_key], label=train_label, color='blue')
        ax.plot(history[val_key], label=val_label, color='red')
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

# Main training function

This section implements the full training pipeline for the CRNN model. The train_model function orchestrates multiple epochs of training and validation, integrates the CTC loss, optimizer, and learning rate scheduler, and monitors metrics such as character accuracy, word accuracy, and edit distance. It also handles early stopping, saving a checkpoint after every epoch while tracking the best one (lowest validation loss), displaying sample predictions, and saving the training history. At the end, the function visualizes the training progress and returns both the recorded metrics and the path to the best-performing model.


In [None]:
def train_model(model, train_loader, val_loader, decoder, chars,
                num_epochs=80, learning_rate=3e-4, weight_decay=1e-4,
                save_dir='checkpoints', patience=7, min_delta=0):
    """Train ``model`` with CTC loss, early stopping and per-epoch checkpoints.

    Args:
        model: Network to train; moved to CUDA when available, else CPU.
        train_loader: Loader passed to ``train_epoch``.
        val_loader: Loader passed to ``validate_epoch``.
        decoder: Beam-search decoder used for the accuracy metrics.
        chars: Character list forwarded to the epoch helpers.
        num_epochs: Maximum number of epochs.
        learning_rate: Initial Adam learning rate.
        weight_decay: Adam weight decay.
        save_dir: Directory for checkpoints, the history JSON and the plot;
            created (including parents) when missing.
        patience: Number of consecutive epochs without improvement after
            which training stops.
        min_delta: Amount by which the validation loss must drop below the
            best value to count as an improvement.

    Returns:
        Tuple ``(history, best_model_path)``. ``history`` maps metric names to
        per-epoch lists. ``best_model_path`` is the checkpoint with the lowest
        validation loss, or ``None`` if no epoch ever improved on the initial
        value (for example when the validation loss is always NaN).
    """

    # Setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    Path(save_dir).mkdir(parents=True, exist_ok=True)

    ctc_loss = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=5, T_mult=2, eta_min=1e-6
    )

    # Training history
    history = {
        'train_loss': [], 'val_loss': [],
        'train_char_acc': [], 'train_word_acc': [], 'train_edit_dist': [],
        'char_acc': [], 'word_acc': [], 'edit_dist': []
    }

    # Early stopping
    best_val_loss = float('inf')
    epochs_without_improvement = 0
    best_model_path = None

    print(f"Starting training on {device}")
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    for epoch in range(1, num_epochs + 1):
        start_time = time.time()

        # Remember the learning rate this epoch actually trains with; the
        # scheduler changes it before the results are printed.
        epoch_lr = optimizer.param_groups[0]['lr']

        # Training
        train_loss, train_char_acc, train_word_acc, train_edit_dist = train_epoch(
            model, train_loader, optimizer, ctc_loss, device, epoch, decoder, chars
        )

        # Validation
        val_loss, char_acc, word_acc, edit_dist, pred_samples, gt_samples = validate_epoch(
            model, val_loader, ctc_loss, device, decoder, chars
        )

        # Advance the schedule by exactly one epoch. Passing ``epoch + 1``
        # here skipped the first scheduler position and made every warm
        # restart happen one epoch early.
        scheduler.step()

        # Record history
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_char_acc'].append(train_char_acc)
        history['train_word_acc'].append(train_word_acc)
        history['train_edit_dist'].append(train_edit_dist)
        history['char_acc'].append(char_acc)
        history['word_acc'].append(word_acc)
        history['edit_dist'].append(edit_dist)

        # Print results
        epoch_time = time.time() - start_time
        print(f"\nEpoch {epoch}/{num_epochs} ({epoch_time:.1f}s)")
        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
        print(f"Train - Char: {train_char_acc:.3f} | Word: {train_word_acc:.3f} | Edit: {train_edit_dist:.3f}")
        print(f"Val   - Char: {char_acc:.3f} | Word: {word_acc:.3f} | Edit: {edit_dist:.3f}")
        print(f"LR: {epoch_lr:.2e}")

        # Show sample predictions
        print("\nSample Predictions:")
        for gt_text, pred_text in list(zip(gt_samples, pred_samples))[:3]:
            print(f"GT:   '{gt_text}'")
            print(f"Pred: '{pred_text}'")
            print()

        # Save checkpoint
        checkpoint_path = save_checkpoint(model, optimizer, epoch, train_loss, val_loss, save_dir)

        # Early stopping
        if val_loss < best_val_loss - min_delta:
            best_val_loss = val_loss
            epochs_without_improvement = 0
            best_model_path = checkpoint_path
            print(f" New best model saved (val_loss: {val_loss:.4f})")
        else:
            epochs_without_improvement += 1
            print(f"No improvement ({epochs_without_improvement}/{patience})")

        if epochs_without_improvement >= patience:
            print(f"\nEarly stopping triggered after {epoch} epochs")
            break

        print("-" * 80)

    # Save training history
    history_path = Path(save_dir) / 'training_history.json'
    with open(history_path, 'w') as f:
        json.dump(history, f, indent=2)

    plot_training_history(history, save_path=Path(save_dir) / 'training_plot.png')

    print(f"\nTraining completed!")
    print(f"Best model: {best_model_path}")
    print(f"Training history saved: {history_path}")

    return history, best_model_path

# Test Sample Evaluation

This section defines a function to evaluate the trained model on test samples. The evaluate_test_samples function randomly selects a subset of examples from the dataset, runs inference using the beam search decoder, and computes the edit distance for each sample. It also calculates corpus-level metrics including character accuracy, word accuracy, and normalized edit distance, providing a quantitative assessment of the model's performance on unseen data. The function returns detailed results for individual samples along with the overall metrics.


In [None]:
def evaluate_test_samples(model, dataset, device, decoder, chars, num_samples=10):
    """Decode a random subset of ``dataset`` and report accuracy metrics.

    Args:
        model: Trained network; switched to evaluation mode.
        dataset: Indexable collection whose items provide ``'image'`` (a
            tensor, given a leading batch axis here) and ``'text'``.
        device: Device the image is moved to before inference.
        decoder: Beam-search decoder forwarded to ``wbs_decode_batch``.
        chars: Character list forwarded to ``wbs_decode_batch``.
        num_samples: Number of random samples to evaluate; clamped to the
            dataset size.

    Returns:
        List with one dict per evaluated sample holding ``'prediction'``,
        ``'ground_truth'``, ``'edit_distance'`` and the corpus-level
        ``'corpus_char_accuracy'`` / ``'corpus_word_accuracy'``.
    """
    model.eval()
    # random.sample raises ValueError when asked for more items than exist,
    # so never request more than the dataset holds.
    sample_count = min(num_samples, len(dataset))
    indices = random.sample(range(len(dataset)), sample_count)
    results = []

    all_predictions = []
    all_ground_truths = []

    with torch.no_grad():
        for idx in indices:
            sample = dataset[idx]
            image = sample['image'].unsqueeze(0).to(device)
            gt = sample['text']
            output = model(image)
            pred = wbs_decode_batch(output, decoder, blank_is_first=True, chars=chars)[0]

            ed = editdistance.eval(pred, gt)

            results.append({
                'prediction': pred,
                'ground_truth': gt,
                'edit_distance': ed
            })

            all_predictions.append(pred)
            all_ground_truths.append(gt)

    # Calculate corpus-level metrics
    char_accuracy, word_accuracy = calculate_accuracy(all_predictions, all_ground_truths)
    edit_distance = calculate_edit_distance(all_predictions, all_ground_truths)

    # Averaging an empty list would yield NaN plus a runtime warning.
    avg_edit_distance = np.mean([r['edit_distance'] for r in results]) if results else 0.0

    print(f"Corpus Char Accuracy: {char_accuracy:.3f}")
    print(f"Corpus Word Accuracy: {word_accuracy:.3f}")
    print(f"Corpus Edit Distance: {edit_distance:.3f}")
    print(f"Avg Edit Distance per sample: {avg_edit_distance:.3f}")

    # Add corpus metrics to results
    for result in results:
        result['corpus_char_accuracy'] = char_accuracy
        result['corpus_word_accuracy'] = word_accuracy

    return results

# Main Training and Evaluation Pipeline

This section orchestrates the entire training and evaluation workflow. The main function handles environment setup, dataset loading, exploratory data analysis, and creation of training, validation, and test datasets. It sets up the beam search decoder with an optional language model, initializes the CRNN model, and runs the complete training loop with checkpointing and early stopping. Finally, it evaluates the trained model on test samples, reporting metrics such as character accuracy, word accuracy, and edit distance, providing a comprehensive end-to-end demonstration of the handwriting recognition pipeline.


In [None]:
import warnings
warnings.filterwarnings('ignore')

from datasets import load_dataset, concatenate_datasets

def main():
    """Run the whole pipeline: data loading, decoder setup, training, evaluation.

    Steps, in order: environment setup, loading the IAM line splits,
    exploratory analysis, building the beam-search decoder (falling back to a
    decoder without a language model when that fails), dataset and dataloader
    creation, model construction, training, and evaluation of the best
    checkpoint on 100 random test samples.
    """

    # Setup environment
    setup_environment()

    # Load dataset
    print("Loading IAM dataset...")
    dataset_all = load_dataset("Teklia/IAM-line")

    train_hf = dataset_all["train"]
    val_hf = dataset_all["validation"]
    test_hf = dataset_all["test"]
    full_dataset = concatenate_datasets([train_hf, val_hf, test_hf])

    print(f"Train split: {len(train_hf)} samples")
    print(f"Val split:   {len(val_hf)} samples")
    print(f"Test split:  {len(test_hf)} samples")
    print(f"Total samples: {len(full_dataset)}")

    # Show vocabulary info
    print(f"Vocab size (incl. blank): {VOCAB_SIZE}")
    print(f"Characters: {''.join(CHARS)}")

    # Exploratory data analysis
    print("\n=== EXPLORATORY DATA ANALYSIS ===")
    show_random_samples(full_dataset, n=5)
    analyze_dataset(full_dataset)

    # Setup beam search decoder
    print("\n=== SETTING UP BEAM SEARCH DECODER ===")
    try:
        print("Loading WikiText...")
        wikitext = load_dataset("wikitext", "wikitext-103-v1", split="train")
        print(f"WikiText loaded: {len(wikitext)} samples")
    except Exception as e:
        # WikiText only enlarges the LM corpus; continue without it.
        wikitext = None
        print(f"WikiText not available: {e}")

    try:
        print("Setting up beam search decoder...")
        decoder, model_path = setup_beam_search_decoder(
            vocab=VOCAB, train_data=train_hf, wiki_data=wikitext, order=3
        )
        print("Beam search decoder setup complete!")
    except Exception as e:
        # Fall back to a decoder without a language model so training can proceed.
        print(f"Setup failed: {e}")
        print("Using basic CTC decoder...")
        from pyctcdecode import build_ctcdecoder
        decoder = build_ctcdecoder(labels=CHARS)
        model_path = None

    # Create datasets
    print("\n=== CREATING DATASETS ===")
    train_dataset = IAMDataset(
        train_hf, CHAR_TO_IDX, IDX_TO_CHAR,
        augment=True, enhance_contrast=True, binarization="adaptive_gaussian"
    )

    val_dataset = IAMDataset(
        val_hf, CHAR_TO_IDX, IDX_TO_CHAR,
        augment=False, enhance_contrast=True, binarization="adaptive_gaussian"
    )

    test_dataset = IAMDataset(
        test_hf, CHAR_TO_IDX, IDX_TO_CHAR,
        augment=False, enhance_contrast=True, binarization="adaptive_gaussian"
    )

    # Create dataloaders
    train_loader, val_loader = create_dataloaders(
        train_dataset, val_dataset, batch_size=BATCH_SIZE, num_workers=2
    )

    # Show preprocessing examples
    show_preprocessed_samples(full_dataset, n=3)

    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")
    print(f"Test samples: {len(test_dataset)}")

    # Test data loading
    print("\nTesting data loading...")
    batch = next(iter(train_loader))
    print(f"Batch images shape: {batch['images'].shape}")
    print(f"Sample text: '{batch['texts'][0]}'")

    # Initialize model
    print("\n=== INITIALIZING MODEL ===")
    model = CRNN(
        vocab_size=VOCAB_SIZE,
        hidden_size=512,
        num_lstm_layers=2,
        dropout=DROPOUT,
        use_attention=True
    )

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}")

    # Train model
    print("\n=== STARTING TRAINING ===")
    history, best_model_path = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        decoder=decoder,
        chars=CHARS,
        num_epochs=EPOCHS,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        save_dir='checkpoints',
        patience=6
    )

    # Final evaluation
    print("\n=== FINAL EVALUATION ===")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if best_model_path is not None:
        # map_location keeps loading working when the checkpoint was written
        # on a device that is not available now.
        checkpoint = torch.load(best_model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        # train_model reports no best checkpoint when the validation loss
        # never improved; evaluate the weights currently in memory instead.
        print("No best checkpoint was recorded; evaluating the model's final weights.")

    model.to(device)

    print("Evaluating on test samples:")
    evaluate_test_samples(model, test_dataset, device, decoder, CHARS, num_samples=100)

# Run the full pipeline only when this code executes as the main module.
if __name__ == "__main__":
    main()