# 🖊️ Handwritten Text Recognition (HTR) with CRNN + CTC
## IAM Handwriting Dataset | TensorFlow/Keras

---
### 📋 Setup
1. **Add Data**: "Add Data" → Search "iam-handwriting-word-database" → Add
2. **Enable GPU**: Settings → Accelerator → GPU
3. **Run All!**

In [None]:
# Protobuf workaround: select the pure-Python protobuf implementation.
# This has to execute before TensorFlow is imported anywhere in the session.
import os

os.environ.update({'PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION': 'python'})
print("‚úÖ Protobuf fix applied!")

In [None]:
import os
import numpy as np
import cv2
from pathlib import Path
from typing import List, Tuple, Optional
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Model
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping

# Environment report: TensorFlow version, then whether at least one GPU is visible.
print(f"TensorFlow: {tf.__version__}")
print(f"GPU: {bool(tf.config.list_physical_devices('GPU'))}")

## 1️⃣ Explore Dataset Structure

In [None]:
# Locate the dataset under Kaggle's input mount.
DATA_ROOT = '/kaggle/input'

# Read the mount once; a missing mount is treated like an empty one below.
_root_exists = os.path.exists(DATA_ROOT)
_entries = os.listdir(DATA_ROOT) if _root_exists else []

print("üìÇ Available datasets:")
if not _root_exists:
    print("  (No input directory found)")
else:
    for _name in _entries:
        print(f"  üìÅ {_name}")

# Selection rule: the first entry whose name mentions "iam" or "handwriting",
# otherwise simply the first entry, otherwise nothing.
_preferred = [
    _name for _name in _entries
    if 'iam' in _name.lower() or 'handwriting' in _name.lower()
]
_chosen = (_preferred or _entries)[:1]
DATA_DIR = os.path.join(DATA_ROOT, _chosen[0]) if _chosen else None

if DATA_DIR is not None:
    print(f"\nüìç Using: {DATA_DIR}")
else:
    for _message in (
        "\n‚ùå NO DATASET FOUND!",
        "   Please add the dataset using the instructions above",
        "   1. Click '+ Add Data' in the right sidebar",
        "   2. Search 'iam-handwriting-word-database'",
        "   3. Click 'Add'",
    ):
        print(_message)
    DATA_DIR = DATA_ROOT  # Prevent crash

In [None]:
# Explore dataset structure: print a header naming the directory that the
# tree view in this cell is about to list.
print(f"\nüìÇ Contents of {DATA_DIR}:")

def show_tree(path, prefix="", max_depth=3, current_depth=0, max_items=10):
    """Print an indented listing of *path*, descending at most *max_depth* levels.

    Args:
        path: Directory to list.
        prefix: Text printed in front of every line (grows by three spaces
            per nesting level).
        max_depth: Number of directory levels to print; nothing is printed
            once ``current_depth`` reaches it.
        current_depth: Depth of *path* relative to the initial call.
        max_items: Maximum number of entries shown per directory (entries
            are sorted by name first).

    Problems with *path* itself (missing, not a directory, unreadable) are
    reported as a single line instead of raising, and a file whose size
    cannot be read no longer aborts the rest of the listing.
    """
    if current_depth >= max_depth:
        return

    try:
        names = sorted(os.listdir(path))
    except FileNotFoundError:
        print(f"{prefix}(Directory not found)")
        return
    except NotADirectoryError:
        print(f"{prefix}(Not a directory)")
        return
    except PermissionError:
        print(f"{prefix}(Permission denied)")
        return

    for item in names[:max_items]:
        full_path = os.path.join(path, item)
        if os.path.isdir(full_path):
            print(f"{prefix}üìÅ {item}/")
            show_tree(full_path, prefix + "   ", max_depth, current_depth + 1, max_items)
            continue
        try:
            size = os.path.getsize(full_path) / 1024
        except OSError:
            # e.g. a dangling symlink: report it and keep listing the siblings.
            print(f"{prefix}üìÑ {item} (unreadable)")
            continue
        print(f"{prefix}üìÑ {item} ({size:.1f} KB)")

    # Make the truncation visible so the listing is not mistaken for complete.
    hidden = len(names) - max_items
    if hidden > 0:
        print(f"{prefix}... ({hidden} more)")

# Only walk the tree when a dataset directory was actually found on disk.
if not DATA_DIR or not os.path.exists(DATA_DIR):
    print("‚ö†Ô∏è Skipping tree view (no dataset)")
else:
    show_tree(DATA_DIR)

In [None]:
# Locate candidate label files (*.txt) and word-image folders inside the dataset
print("üîç Searching for label files and image directories...\n")

if not DATA_DIR or not os.path.exists(DATA_DIR):
    print("‚ö†Ô∏è Skipping file search (no dataset)")
else:
    _dataset_root = Path(DATA_DIR)

    # Every .txt file anywhere below the dataset root (first ten are shown).
    txt_files = list(_dataset_root.rglob('*.txt'))
    print(f"Found {len(txt_files)} .txt files:")
    for f in txt_files[:10]:
        print(f"  üìÑ {f}")

    # Every directory whose own name contains "word" (first five are shown).
    words_dirs = []
    for _entry in _dataset_root.rglob('*'):
        if _entry.is_dir() and 'word' in _entry.name.lower():
            words_dirs.append(_entry)
    print(f"\nFound {len(words_dirs)} word-related directories:")
    for d in words_dirs[:5]:
        print(f"  üìÅ {d}")

## 2️⃣ Configuration

In [None]:
# Hyper-parameters shared by the cells below.
CONFIG = {
    'img_height': 32,
    'img_width': 128,
    'batch_size': 64,
    'epochs': 50,
    'learning_rate': 0.001,
    'val_split': 0.1,
    'max_samples': None,  # Set to 10000 for quick test
}

# Character set
CHARACTERS = list("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789.,!?'-():;\"/ ")

# Index 0 is the padding value of the label arrays, so characters start at 1.
char_to_num = {char: idx for idx, char in enumerate(CHARACTERS, start=1)}
num_to_char = {idx: char for char, idx in char_to_num.items()}
num_to_char[0] = ''

# Class layout: 0 = padding, 1..len(CHARACTERS) = characters, last = CTC blank.
# keras.backend.ctc_batch_cost reserves the LAST class index for the blank.
# With only len(CHARACTERS) + 1 classes that index was the label of ' ', so
# the blank and the space character shared one class. The extra class gives
# the blank an index of its own, one that is not a key of num_to_char.
NUM_CLASSES = len(CHARACTERS) + 2

print(f"Character set: {NUM_CLASSES} classes")

In [None]:
def encode_label(text):
    """Return the class indices of *text*, dropping characters outside the character set."""
    known_chars = (ch for ch in text if ch in char_to_num)
    return [char_to_num[ch] for ch in known_chars]

def decode_prediction(pred, blank_index=None, index_to_char=None):
    """Greedy (best-path) CTC decoding of one prediction.

    Takes the arg-max class per time step, collapses consecutive repeats,
    drops blank/padding classes and maps the rest to characters.

    Args:
        pred: 2-D array-like indexed as ``[time_step, class]``.
        blank_index: Class index of the CTC blank. Defaults to the last
            class (``pred.shape[-1] - 1``), which is the convention of
            ``keras.backend.ctc_batch_cost`` used for training here. The
            previous code assumed 0, so the real blank was decoded as a
            character.
        index_to_char: Mapping from class index to character. Defaults to
            the notebook-level ``num_to_char``.

    Returns:
        The decoded string (empty for an empty prediction).
    """
    pred = np.asarray(pred)
    if pred.size == 0:
        return ''
    if index_to_char is None:
        index_to_char = num_to_char
    if blank_index is None:
        blank_index = pred.shape[-1] - 1

    chars = []
    prev_idx = -1
    for idx in np.argmax(pred, axis=1):
        idx = int(idx)
        # Index 0 is the label padding value and never a character.
        if idx != prev_idx and idx != blank_index and idx != 0:
            char = index_to_char.get(idx)
            if char:
                chars.append(char)
        # A blank between two equal classes resets the repeat collapsing.
        prev_idx = idx
    return ''.join(chars)

## 3️⃣ Load Dataset

In [None]:
def _find_label_file(data_path, min_size=10000):
    """Return the first label-file candidate larger than *min_size* bytes, or None."""
    for pattern in ('**/words*.txt', '**/labels*.txt', '**/*.txt'):
        for candidate in data_path.glob(pattern):
            if candidate.stat().st_size > min_size:  # Must be reasonably large
                return candidate
    return None


def _has_png(directory, recursive):
    """Tell whether *directory* holds a PNG, stopping at the first hit."""
    finder = directory.rglob if recursive else directory.glob
    return next(finder('*.png'), None) is not None


def _find_images_dir(data_path):
    """Return the directory assumed to hold the word images, or None."""
    # Preferred: a folder named like "words"/"images" with PNGs anywhere below it.
    for d in data_path.rglob('*'):
        if not d.is_dir():
            continue
        name = d.name.lower()
        if ('word' in name or 'image' in name) and _has_png(d, recursive=True):
            return d
    # Otherwise: any folder that directly contains a PNG.
    for d in data_path.rglob('*'):
        if d.is_dir() and _has_png(d, recursive=False):
            return d
    return None


def _iter_label_rows(words_file):
    """Yield ``(word_id, transcription)`` for each usable line of the label file."""
    with open(words_file, 'r', encoding='utf-8', errors='ignore') as f:
        for line in f:
            if line.startswith('#') or line.strip() == '':
                continue
            parts = line.strip().split(' ')
            # Fewer than 9 fields is not a label row; status 'err' rows are excluded.
            if len(parts) < 9 or parts[1] == 'err':
                continue
            yield parts[0], parts[-1]


def _resolve_image(images_dir, word_id):
    """Return the existing image path for *word_id* under *images_dir*, or None."""
    id_parts = word_id.split('-')
    if len(id_parts) < 3:
        return None
    # Standard IAM structure: <a01>/<a01-000u>/<a01-000u-00-00>.png
    folder1 = id_parts[0]
    folder2 = f"{id_parts[0]}-{id_parts[1]}"
    filename = f"{word_id}.png"
    candidates = (
        images_dir / folder1 / folder2 / filename,
        images_dir / folder2 / filename,
        images_dir / filename,
    )
    for candidate in candidates:
        if candidate.exists():
            return candidate
    return None


def load_iam_dataset(data_dir, max_samples=None):
    """Load IAM dataset - handles various folder structures.

    Args:
        data_dir: Root directory of the dataset.
        max_samples: Optional cap on the number of returned samples.

    Returns:
        A list of ``(image_path, transcription)`` tuples with string paths.

    Raises:
        FileNotFoundError: If *data_dir* is missing or no label file is found.

    Samples whose transcription encodes to zero known characters are skipped.
    When no image directory is found, or the standard folder layout matches
    nothing, every PNG below *data_dir* is matched to a label by file stem.
    """
    if not data_dir or not os.path.exists(data_dir):
        raise FileNotFoundError("No dataset directory provided. Please add data first!")

    data_path = Path(data_dir)

    words_file = _find_label_file(data_path)
    if not words_file:
        raise FileNotFoundError(f"No label file found in {data_dir}")
    print(f"üìÑ Using label file: {words_file}")

    images_dir = _find_images_dir(data_path)
    print(f"üìÅ Using images dir: {images_dir}")

    # Parse the label file once; both lookup strategies below reuse the result.
    skipped = 0
    labelled = []
    for word_id, transcription in _iter_label_rows(words_file):
        if len(encode_label(transcription)) == 0:
            skipped += 1
            continue
        labelled.append((word_id, transcription))

    samples = []
    # images_dir may be None (e.g. PNGs sit directly in data_dir); building
    # paths from None would raise, so go straight to the direct search then.
    if images_dir is not None:
        for word_id, transcription in labelled:
            img_path = _resolve_image(images_dir, word_id)
            if img_path is None:
                continue
            samples.append((str(img_path), transcription))
            # Only the first max_samples are kept, so stop checking the disk.
            if max_samples and len(samples) >= max_samples:
                break

    # If no samples found with standard parsing, try finding images directly
    if len(samples) == 0:
        print("‚ö†Ô∏è Standard parsing failed, trying direct image search...")
        all_pngs = list(data_path.rglob('*.png'))
        print(f"Found {len(all_pngs)} PNG files")

        label_dict = dict(labelled)
        for img_path in all_pngs:
            word_id = img_path.stem
            if word_id in label_dict:
                samples.append((str(img_path), label_dict[word_id]))

    if max_samples:
        samples = samples[:max_samples]

    print(f"‚úÖ Loaded {len(samples):,} samples (skipped {skipped:,} zero-length labels)")
    return samples

In [None]:
# Load dataset
print("üìÇ Loading IAM dataset...")
all_samples = []
if DATA_DIR and os.path.exists(DATA_DIR):
    try:
        all_samples = load_iam_dataset(DATA_DIR, max_samples=CONFIG['max_samples'])
    except FileNotFoundError as err:
        # An unusable dataset directory must not leave train_samples and
        # val_samples undefined: the cells below check them for emptiness.
        print(f"Dataset could not be loaded: {err}")

if all_samples:
    # Shuffle and split (fixed seed so the split is reproducible)
    np.random.seed(42)
    np.random.shuffle(all_samples)

    n_val = int(len(all_samples) * CONFIG['val_split'])
    train_samples = all_samples[n_val:]
    val_samples = all_samples[:n_val]

    print(f"\nüìä Train: {len(train_samples):,} | Val: {len(val_samples):,}")
else:
    print("‚ùå No dataset to load. Please fix the error in Section 1.")
    train_samples = []
    val_samples = []

## 4️⃣ Preprocessing

In [None]:
def preprocess_image(image_path, img_height=32, img_width=128):
    """Load a word image as a fixed-size, inverted, normalised array.

    The image is read as grayscale and scaled to *img_height* keeping its
    aspect ratio. The width is capped at *img_width*, so wider images are
    squeezed, and narrower ones are right-padded with white. Intensities
    are mapped so that white becomes 0.0 and black becomes 1.0.

    Args:
        image_path: Path of the image file (``str`` or path-like).
        img_height: Target height in pixels.
        img_width: Target width in pixels.

    Returns:
        A ``float32`` array of shape ``(img_height, img_width, 1)``. An
        unreadable file yields all zeros, which equals a blank white image.
    """
    img = cv2.imread(str(image_path), cv2.IMREAD_GRAYSCALE)
    if img is None:
        return np.zeros((img_height, img_width, 1), dtype=np.float32)

    h, w = img.shape
    # Clamp to at least 1 px: a very tall, narrow image would otherwise
    # truncate to width 0, which cv2.resize rejects.
    new_width = max(1, min(int(w * (img_height / h)), img_width))
    img = cv2.resize(img, (new_width, img_height))

    if new_width < img_width:
        img = np.pad(img, ((0, 0), (0, img_width - new_width)), constant_values=255)

    img = 1.0 - (img.astype(np.float32) / 255.0)
    return np.expand_dims(img, axis=-1)

In [None]:
# Preview: the first (up to) eight training samples after preprocessing
if len(train_samples) == 0:
    print("‚ö†Ô∏è No samples to visualize")
else:
    fig, axes = plt.subplots(2, 4, figsize=(14, 5))
    for ax, sample in zip(axes.flat, train_samples[:8]):
        img_path, label = sample
        img = preprocess_image(img_path, CONFIG['img_height'], CONFIG['img_width'])
        ax.imshow(img.squeeze(), cmap='gray')
        ax.set_title(f'"{label}"', fontsize=10)
        ax.axis('off')
    plt.tight_layout()
    plt.show()

## 5️⃣ Data Generator

In [None]:
class IAMDataGenerator(keras.utils.Sequence):
    """Batch generator feeding the CTC training model.

    Each batch is ``(inputs, dummy_targets)``. ``inputs`` is a dict with the
    keys ``'image'``, ``'label'``, ``'input_length'`` and ``'label_length'``.
    ``dummy_targets`` is an all-zero array with one entry per sample.

    Args:
        samples: Sequence of ``(image_path, text)`` pairs.
        img_height: Image height passed to ``preprocess_image``.
        img_width: Image width passed to ``preprocess_image``.
        batch_size: Maximum number of samples per batch.
        max_label_len: Labels are truncated / zero-padded to this length.
        shuffle: Shuffle the sample order now and after every epoch.
        **kwargs: Forwarded to the Keras base class.
    """

    def __init__(self, samples, img_height, img_width, batch_size, max_label_len=32, shuffle=True, **kwargs):
        # Keras 3 expects the base-class constructor to run.
        super().__init__(**kwargs)
        self.samples = samples
        self.img_height = img_height
        self.img_width = img_width
        self.batch_size = batch_size
        self.max_label_len = max_label_len
        self.shuffle = shuffle
        self.indices = np.arange(len(samples))
        if shuffle:
            np.random.shuffle(self.indices)

    def __len__(self):
        """Number of batches per epoch, counting a trailing partial batch."""
        if self.batch_size <= 0:
            return 0
        # Ceiling division: flooring dropped the remainder, and produced zero
        # batches for a split smaller than one batch.
        return (len(self.samples) + self.batch_size - 1) // self.batch_size

    def __getitem__(self, idx):
        """Build batch number *idx*."""
        batch_idx = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]
        # The last batch may be short, so size the arrays by the real count.
        n = len(batch_idx)

        images = np.zeros((n, self.img_height, self.img_width, 1), dtype=np.float32)
        labels = np.zeros((n, self.max_label_len), dtype=np.int32)
        # CTC input length: width / 2 based on pooling structure (64 time steps for 128 width)
        input_lengths = np.full((n, 1), self.img_width // 2, dtype=np.int32)
        label_lengths = np.zeros((n, 1), dtype=np.int32)

        for i, si in enumerate(batch_idx):
            img_path, text = self.samples[si]
            images[i] = preprocess_image(img_path, self.img_height, self.img_width)
            encoded = encode_label(text)
            label_len = min(len(encoded), self.max_label_len)
            labels[i, :label_len] = encoded[:label_len]
            label_lengths[i] = label_len

        inputs = {
            'image': images,
            'label': labels,
            'input_length': input_lengths,
            'label_length': label_lengths,
        }
        return inputs, np.zeros(n)

    def on_epoch_end(self):
        """Reshuffle the sample order between epochs when shuffling is enabled."""
        if self.shuffle:
            np.random.shuffle(self.indices)

# Build the generators only when there is training data; otherwise leave both as None.
if len(train_samples) == 0:
    train_gen = None
    val_gen = None
    print("‚ö†Ô∏è Generators not created (no data)")
else:
    _gen_args = (CONFIG['img_height'], CONFIG['img_width'], CONFIG['batch_size'])
    train_gen = IAMDataGenerator(train_samples, *_gen_args)
    val_gen = IAMDataGenerator(val_samples, *_gen_args, shuffle=False)
    print(f"‚úÖ Generators ready: {len(train_gen)} train batches, {len(val_gen)} val batches")

## 6️⃣ CRNN Model

In [None]:
def build_model(img_height, img_width, num_classes):
    """Assemble the CRNN and return ``(train_model, pred_model)``.

    ``train_model`` takes the four inputs ``image``, ``label``,
    ``input_length`` and ``label_length`` and outputs the CTC cost.
    ``pred_model`` shares the same layers but maps an image straight to the
    per-time-step softmax over ``num_classes``.

    Only the first pooling layer halves the width; the later ones pool
    along the height only.
    """
    def conv3x3(filters):
        # Every 3x3 convolution in the stack uses ReLU and 'same' padding.
        return layers.Conv2D(filters, 3, activation='relu', padding='same')

    def pool_height_only():
        # Pool along the height only, so the width (time axis) is preserved.
        return layers.MaxPooling2D((2, 1))

    input_img = layers.Input(shape=(img_height, img_width, 1), name='image')
    labels = layers.Input(shape=(None,), dtype='int32', name='label')
    input_length = layers.Input(shape=(1,), dtype='int32', name='input_length')
    label_length = layers.Input(shape=(1,), dtype='int32', name='label_length')

    # --- Convolutional feature extractor ---
    feats = conv3x3(64)(input_img)
    feats = layers.MaxPooling2D((2, 2))(feats)

    feats = conv3x3(128)(feats)
    feats = pool_height_only()(feats)

    feats = conv3x3(256)(feats)
    feats = conv3x3(256)(feats)
    feats = pool_height_only()(feats)

    feats = conv3x3(512)(feats)
    feats = layers.BatchNormalization()(feats)
    feats = conv3x3(512)(feats)
    feats = layers.BatchNormalization()(feats)
    feats = pool_height_only()(feats)

    # A (2, 1) 'valid' convolution folds the remaining two rows into one.
    feats = layers.Conv2D(512, (2, 1), activation='relu', padding='valid')(feats)

    # --- Map to sequence: drop the height axis, keep (width, channels) ---
    seq = layers.Reshape((feats.shape[2], feats.shape[3]))(feats)
    seq = layers.Dense(256, activation='relu')(seq)
    seq = layers.Dropout(0.25)(seq)

    # --- Recurrent head: two stacked bidirectional LSTMs ---
    seq = layers.Bidirectional(layers.LSTM(256, return_sequences=True, dropout=0.2))(seq)
    seq = layers.Bidirectional(layers.LSTM(256, return_sequences=True, dropout=0.2))(seq)
    probs = layers.Dense(num_classes, activation='softmax', name='output')(seq)

    # --- CTC cost computed inside the graph ---
    def ctc_cost(tensors):
        y_true, y_pred, pred_len, true_len = tensors
        return keras.backend.ctc_batch_cost(y_true, y_pred, pred_len, true_len)

    ctc = layers.Lambda(ctc_cost)([labels, probs, input_length, label_length])

    train_model = Model(inputs=[input_img, labels, input_length, label_length], outputs=ctc)
    pred_model = Model(inputs=input_img, outputs=probs)
    return train_model, pred_model

training_model, prediction_model = build_model(
    img_height=CONFIG['img_height'],
    img_width=CONFIG['img_width'],
    num_classes=NUM_CLASSES,
)
# The training graph already outputs the CTC cost, so the compiled "loss"
# just passes the model output through and ignores the dummy targets.
training_model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=CONFIG['learning_rate']),
    loss=lambda y_true, y_pred: y_pred,
)
print(f"‚úÖ Model: {training_model.count_params():,} params")
training_model.summary()

## 7️⃣ Train!

In [None]:
# Train only when a generator exists and yields at least one batch.
if not train_gen or len(train_gen) == 0:
    print("‚ùå Skipping training (no data).")
else:
    # No monitor is passed, so each callback uses its library default.
    _checkpoint = ModelCheckpoint('best_model.weights.h5', save_best_only=True, save_weights_only=True)
    _lr_schedule = ReduceLROnPlateau(factor=0.5, patience=3, min_lr=1e-6)
    _early_stop = EarlyStopping(patience=10, restore_best_weights=True)
    callbacks = [_checkpoint, _lr_schedule, _early_stop]

    print("üöÄ Starting training...")
    history = training_model.fit(
        train_gen,
        validation_data=val_gen,
        epochs=CONFIG['epochs'],
        callbacks=callbacks,
    )

In [None]:
# Plot the training and validation loss curves, if training ran
if 'history' not in locals():
    print("No history to plot")
else:
    for _key, _legend in (('loss', 'Train'), ('val_loss', 'Val')):
        plt.plot(history.history[_key], label=_legend)
    plt.legend()
    plt.title('Loss')
    plt.savefig('loss.png')
    plt.show()

## 8️⃣ Inference

In [None]:
def predict(image_path):
    """Return the decoded text that ``prediction_model`` produces for one image file."""
    image = preprocess_image(image_path, CONFIG['img_height'], CONFIG['img_width'])
    # The model expects a batch axis in front, so wrap the single image.
    batch = np.expand_dims(image, axis=0)
    probabilities = prediction_model.predict(batch, verbose=0)
    return decode_prediction(probabilities[0])

# Test: show up to twelve validation images with predicted (P) and true (L) text
if len(val_samples) > 0:
    fig, axes = plt.subplots(3, 4, figsize=(14, 8))
    shown_samples = val_samples[:12]
    for ax, (path, label) in zip(axes.flat, shown_samples):
        pred = predict(path)
        # Display the image at the size predict() feeds to the model; the
        # function defaults would diverge as soon as CONFIG is changed.
        shown = preprocess_image(path, CONFIG['img_height'], CONFIG['img_width'])
        ax.imshow(shown.squeeze(), cmap='gray')
        ax.set_title(f'P:"{pred}"\nL:"{label}"', fontsize=9, color='green' if pred==label else 'red')
        ax.axis('off')
    # With fewer than twelve samples, hide the frames of the unused axes.
    for ax in list(axes.flat)[len(shown_samples):]:
        ax.axis('off')
    plt.tight_layout()
    plt.savefig('predictions.png')
    plt.show()
else:
    print("No validation samples to test")

In [None]:
# Save models and outputs
import shutil

print("üíæ Saving outputs...")
prediction_model.save('htr_model.keras')

# Copy to /kaggle/working/ if not already there
output_dir = '/kaggle/working'
current_dir = os.getcwd()

# Outside Kaggle the output directory may not exist. Copying into it would
# raise, so in that case the files are left where they were written.
can_copy = os.path.isdir(output_dir)
if not can_copy:
    print(f"  Output directory {output_dir} not found; leaving files in {current_dir}")

for filename in ['htr_model.keras', 'best_model.weights.h5', 'loss.png', 'predictions.png']:
    if not os.path.exists(filename):
        print(f"  ‚ö†Ô∏è {filename} not found")
        continue

    src = os.path.abspath(filename)
    dst = os.path.abspath(os.path.join(output_dir, filename))

    # Only copy if source and destination are different
    if src == dst:
        print(f"  ‚úÖ {filename} (already in output dir)")
    elif can_copy:
        shutil.copy(src, dst)
        print(f"  ‚úÖ Copied {filename}")

# Report the directory that really holds the files.
saved_to = output_dir if can_copy else current_dir
print(f"\n‚úÖ All outputs saved to {saved_to}")