In [None]:
# ============================================================
# HANDWRITTEN TEXT RECOGNITION (CTC-based CNN + BiLSTM model)
# ============================================================

!pip -q install datasets tensorflow==2.15.0 opencv-python pillow --quiet

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from datasets import load_dataset
import itertools
import re

# ========================
# 0) GLOBALS
# ========================
# One seed shared by NumPy and TensorFlow so the NumPy-based index shuffling
# further down and TensorFlow's random ops start from a fixed state.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Target image size in pixels (height, width) after resizing and white padding.
IMG_H = 64
IMG_W = 800
# Default number of samples per batch for DataGenerator.
BATCH_SIZE = 16
# Upper bound on the encoded label length; longer transcriptions are truncated
# when they are encoded.
MAX_LABEL_LEN_CAP = 80

# ========================
# 1) LOAD DATASET (IAM-line only)
# ========================
# Load the dataset by name; samples are later read through their "image" and
# "text" fields.
print("Loading teklia/IAM-line ...")
ds = load_dataset("teklia/IAM-line")
# Only the "train" and "test" splits are used here; a validation subset is
# carved out of "train" further down.
train_split = ds["train"]
test_split  = ds["test"]

print("Train samples:", len(train_split), " Test samples:", len(test_split))

# ========================
# 2) BUILD VOCAB FROM TRAIN TEXTS
# ========================
def collect_texts(split):
    """Return the ``"text"`` field of every sample in ``split``, in order.

    Args:
        split: iterable of samples that can be indexed with ``"text"``.

    Returns:
        A list with one entry per sample; a ``None`` transcription is replaced
        by an empty string.
    """
    raw_texts = (sample["text"] for sample in split)
    return ["" if text is None else text for text in raw_texts]

train_texts = collect_texts(train_split)
test_texts  = collect_texts(test_split)

# Build the vocabulary: one entry per distinct character that occurs in the
# training transcriptions, in sorted order. Newline, carriage return and tab
# are excluded; every other character (including the space) is kept.
all_text = "".join(train_texts)
vocab = sorted(set(all_text))
vocab = [c for c in vocab if c not in ["\n", "\r", "\t"]]
print("Vocab size (from train):", len(vocab))

num_chars = len(vocab)

# Character -> integer id. No mask token is reserved; characters outside the
# vocabulary map to the "[UNK]" out-of-vocabulary entry.
char_to_num = layers.StringLookup(
    vocabulary=vocab,
    mask_token=None,
    oov_token="[UNK]"
)
# The CTC blank gets the first id after everything the lookup can emit, which
# makes it the last of the num_classes output classes.
# NOTE(review): this relies on the Keras CTC helpers treating the last class as
# the blank - confirm against the Keras backend documentation.
blank_index = len(char_to_num.get_vocabulary())
num_classes = blank_index + 1

# Integer id -> character, built from the forward lookup's own vocabulary so
# both directions agree on the id assignment.
num_to_char = layers.StringLookup(
    vocabulary=char_to_num.get_vocabulary(),
    invert=True,
    mask_token=None
)

# ========================
# 3) FIXED IMAGE PREPROCESSING (NO CLIPPING)
# ========================
def preprocess_image(pil_img: Image.Image):
    """Turn a PIL image into a white-padded grayscale array of the target size.

    The image is converted to grayscale and resized with a single scale factor
    (the smaller of the height and width ratios), so the aspect ratio is kept
    and neither dimension is meant to exceed ``IMG_H`` / ``IMG_W``. Pixel values
    are scaled to [0, 1]. Remaining space is filled with white (1.0): split
    evenly above and below, and entirely on the right-hand side.

    Args:
        pil_img: source image.

    Returns:
        A tuple ``(arr, new_w)``: ``arr`` is a float32 array with a trailing
        channel axis of size 1, and ``new_w`` is the width of the resized image
        before horizontal padding.
    """
    gray = pil_img.convert("L")
    src_w, src_h = gray.size

    # A single factor for both axes keeps the aspect ratio.
    scale = min(IMG_H / src_h, IMG_W / src_w)
    new_w = max(1, int(round(src_w * scale)))
    new_h = max(1, int(round(src_h * scale)))

    resized = gray.resize((new_w, new_h), Image.BILINEAR)
    pixels = np.array(resized).astype("float32") / 255.0

    # Amount of white padding still needed on each axis (never negative, so an
    # image that already fills an axis is left untouched on that axis).
    spare_h = max(IMG_H - new_h, 0)
    spare_w = max(IMG_W - new_w, 0)
    top = spare_h // 2
    pixels = np.pad(
        pixels,
        ((top, spare_h - top), (0, spare_w)),
        constant_values=1.0,
    )

    return pixels[..., np.newaxis], new_w


# ========================
# 4) PREPARE DATA SPLITS
# ========================
def prepare_split(split):
    """Preprocess every sample of ``split`` into arrays ready for training.

    Args:
        split: iterable of samples indexable with ``"image"`` and ``"text"``.

    Returns:
        A tuple ``(images, widths, labels)``:
        ``images`` is a float32 array stacking the preprocessed images,
        ``widths`` is an int32 array with each image's resized width (clamped
        to ``IMG_W``), and ``labels`` is a list of transcriptions in which
        ``None`` has been replaced by an empty string.
    """
    pixel_arrays = []
    content_widths = []
    transcripts = []

    for sample in split:
        pixels, resized_w = preprocess_image(sample["image"])
        text = sample["text"]

        pixel_arrays.append(pixels)
        content_widths.append(min(resized_w, IMG_W))
        transcripts.append("" if text is None else text)

    stacked = np.asarray(pixel_arrays, dtype=np.float32)
    widths = np.asarray(content_widths, dtype=np.int32)
    return stacked, widths, transcripts

# Preprocess both splits up front; every image is held in memory as float32.
train_images, train_widths, train_labels = prepare_split(train_split)
test_images,  test_widths,  test_labels  = prepare_split(test_split)

# Split into train/val: shuffle the sample indices once (using the NumPy RNG
# seeded above), then take the first 90% for training and the rest for
# validation.
num_train = int(0.9 * len(train_images))
perm = np.random.permutation(len(train_images))
tr_idx, va_idx = perm[:num_train], perm[num_train:]

# Labels are a plain Python list, so they are gathered with a comprehension
# instead of array indexing.
X_tr, W_tr, y_tr = train_images[tr_idx], train_widths[tr_idx], [train_labels[i] for i in tr_idx]
X_va, W_va, y_va = train_images[va_idx], train_widths[va_idx], [train_labels[i] for i in va_idx]

print("Shapes  Train:", X_tr.shape, " Val:", X_va.shape, " Test:", test_images.shape)

# ========================
# 5) ENCODE LABELS FOR CTC
# ========================
def encode_label(txt, max_len):
    """Encode one transcription as a fixed-length vector of character ids.

    The text is split into Unicode characters, mapped to ids with
    ``char_to_num``, truncated to ``max_len`` entries and right-padded with -1
    up to ``max_len``.

    Args:
        txt: transcription string.
        max_len: length of the returned vector.

    Returns:
        A 1-D int32 tensor of length ``max_len``; -1 marks padding.
    """
    chars = tf.strings.unicode_split(txt, "UTF-8")
    ids = char_to_num(chars)[:max_len]

    n_missing = max_len - tf.shape(ids)[0]
    padded = tf.pad(ids, [[0, n_missing]], constant_values=-1)
    return tf.cast(padded, tf.int32)

# Encoded label length: the longest training transcription, capped at
# MAX_LABEL_LEN_CAP.
max_label_len = min(MAX_LABEL_LEN_CAP, max(map(len, y_tr)))
print("Max label length used:", max_label_len)

def batch_encode(texts):
    """Encode several transcriptions into one stacked tensor.

    Each text is encoded with ``encode_label`` using the module-level
    ``max_label_len``, so all rows have the same length.

    Args:
        texts: iterable of transcription strings.

    Returns:
        A tensor with one encoded row per input text, stacked along axis 0.
    """
    rows = [encode_label(text, max_label_len) for text in texts]
    return tf.stack(rows)

# ========================
# 6) DATA GENERATOR
# ========================
class DataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` serving (image batch, encoded label batch) pairs.

    Args:
        images: array of preprocessed images, indexed along axis 0.
        widths: per-image content widths; stored on the instance but not used
            when batches are built.
        labels: list of transcription strings aligned with ``images``.
        batch_size: number of samples per batch (the final batch of an epoch
            may be smaller).
    """

    def __init__(self, images, widths, labels, batch_size=BATCH_SIZE):
        # Let the base class initialise itself; newer Keras versions expect it.
        super().__init__()
        self.images = images
        self.widths = widths
        self.labels = labels
        self.batch_size = batch_size
        # Order in which samples are served; reshuffled after every epoch.
        self.indexes = np.arange(len(images))

    def __len__(self):
        """Return the number of batches per epoch, including a partial last one."""
        # Ceiling division. Floor division silently dropped up to
        # batch_size - 1 samples each epoch; since the order is reshuffled,
        # a different subset was skipped every time, and a dataset smaller
        # than one batch produced no batches at all.
        return (len(self.images) + self.batch_size - 1) // self.batch_size

    def on_epoch_end(self):
        """Shuffle the sample order in place for the next epoch."""
        np.random.shuffle(self.indexes)

    def __getitem__(self, idx):
        """Return batch number ``idx`` as ``(images, encoded_labels)``."""
        start = idx * self.batch_size
        ids = self.indexes[start:start + self.batch_size]
        batch_imgs = self.images[ids]
        batch_txt = [self.labels[i] for i in ids]
        batch_lbl = batch_encode(batch_txt)
        return batch_imgs, batch_lbl

# Batch providers for the training and validation subsets; both use the default
# batch size.
train_gen = DataGenerator(X_tr, W_tr, y_tr)
val_gen   = DataGenerator(X_va, W_va, y_va)

# ========================
# 7) MODEL (CNN + BiLSTM + CTC)
# ========================
def build_model():
    """Build the CNN + BiLSTM recogniser.

    Four convolution stages reduce the image height by 16 and its width by 4.
    The feature map is then rearranged so that each remaining image column
    becomes one time step, projected by a dense layer, passed through two
    bidirectional LSTM layers, and finally mapped to ``num_classes``
    probabilities per time step by the layer named ``"softmax"``.

    Returns:
        A ``keras.Model`` named ``"htr_cnn_bilstm_ctc"`` taking images of shape
        ``(IMG_H, IMG_W, 1)``.
    """
    image_in = layers.Input(shape=(IMG_H, IMG_W, 1), name="image")

    # (filters, use_batch_norm, pool_size) for each convolution stage.
    # The last two stages pool only along the height to keep time resolution.
    conv_stages = [
        (64, False, (2, 2)),
        (128, False, (2, 2)),
        (256, True, (2, 1)),
        (512, True, (2, 1)),
    ]

    features = image_in
    for n_filters, use_bn, pool in conv_stages:
        features = layers.Conv2D(n_filters, 3, padding="same", activation="relu")(features)
        if use_bn:
            features = layers.BatchNormalization()(features)
        features = layers.MaxPooling2D(pool_size=pool)(features)

    # Height is halved by all four pools (16x), width only by the first two (4x).
    n_steps = IMG_W // 4
    step_features = (IMG_H // 16) * 512

    # (H, W, C) -> (W, H, C) -> (W, H*C): one feature vector per image column.
    features = layers.Permute((2, 1, 3))(features)
    features = layers.Reshape((n_steps, step_features))(features)

    features = layers.Dense(256, activation="relu")(features)
    for _ in range(2):
        features = layers.Bidirectional(layers.LSTM(256, return_sequences=True))(features)

    probs = layers.Dense(num_classes, activation="softmax", name="softmax")(features)
    return keras.Model(image_in, probs, name="htr_cnn_bilstm_ctc")

# Instantiate the network and print its layer-by-layer summary.
model = build_model()
model.summary()

# ========================
# 8) CTC LOSS
# ========================
@tf.function
def ctc_loss(y_true, y_pred):
    """CTC loss for labels that are right-padded with -1.

    Args:
        y_true: integer labels, one row per sample, padded with -1.
        y_pred: model output with batch on axis 0 and time steps on axis 1.

    Returns:
        The value of ``keras.backend.ctc_batch_cost`` for the batch. Every
        sample is given the full number of time steps as its input length.
    """
    labels = tf.cast(y_true, tf.int32)

    # Real label length per sample = number of non-padding entries.
    is_real = tf.not_equal(labels, -1)
    label_len = tf.reduce_sum(tf.cast(is_real, tf.int32), axis=1, keepdims=True)

    # Swap the -1 padding for the blank id so that every entry is a valid class.
    labels = tf.where(is_real, labels, tf.cast(blank_index, tf.int32))

    pred_shape = tf.shape(y_pred)
    input_len = tf.fill([pred_shape[0], 1], pred_shape[1])
    return keras.backend.ctc_batch_cost(labels, y_pred, input_len, label_len)

# Train with Adam at an initial learning rate of 1e-3, using the CTC loss
# defined above.
model.compile(optimizer=keras.optimizers.Adam(1e-3), loss=ctc_loss)

# ========================
# 9) TRAIN
# ========================
# Stop once val_loss has not improved for 10 epochs and roll back to the best
# weights. The learning rate is halved after 5 epochs without improvement;
# ReduceLROnPlateau is given no explicit monitor, so it uses its default metric.
callbacks = [
    keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True, monitor="val_loss"),
    keras.callbacks.ReduceLROnPlateau(patience=5, factor=0.5)
]

# Train for at most 50 epochs; `history` keeps the object returned by fit().
history = model.fit(
    train_gen,
    validation_data=val_gen,
    epochs=50,
    callbacks=callbacks,
    verbose=1
)

# ========================
# 10) INFERENCE + GREEDY DECODE
# ========================
# Inference model from the input image to the output of the "softmax" layer
# (the same layer whose output build_model returns).
pred_model = keras.Model(model.input, model.get_layer("softmax").output)

def decode_batch(pred):
    """Greedy-decode a batch of model outputs into strings.

    Args:
        pred: array of per-time-step class probabilities, batch on axis 0 and
            time steps on axis 1.

    Returns:
        A list with one decoded string per sample. Every sample is decoded over
        the full number of time steps.
    """
    n_samples, n_steps = pred.shape[0], pred.shape[1]
    input_len = np.ones(n_samples) * n_steps
    results, _ = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)

    def ids_to_text(row):
        # Entries equal to -1 are filtered out before the ids are mapped back
        # to characters and joined.
        kept = tf.gather(row, tf.where(row != -1))
        joined = tf.strings.reduce_join(num_to_char(kept))
        return joined.numpy().decode("utf-8")

    return [ids_to_text(row) for row in results[0]]

def collapse_repeats(text):
    """Reduce every run of identical consecutive characters to one character.

    Runs are collapsed regardless of whether they are genuine, so legitimate
    double letters are shortened too (``"hello"`` becomes ``"helo"``).

    Args:
        text: input string.

    Returns:
        The string with consecutive duplicates removed.
    """
    kept = []
    for ch in text:
        if not kept or kept[-1] != ch:
            kept.append(ch)
    return "".join(kept)

# Run the trained model on the first few test images and show each image with
# its ground-truth and predicted transcription.
sample_imgs = test_images[:5]
preds = pred_model.predict(sample_imgs)
decoded_raw = decode_batch(preds)
# Greedy CTC decoding already merges repeated classes before removing blanks,
# so the decoded text is final. Collapsing repeated characters again would
# delete genuine double letters (e.g. "letter" -> "leter").
decoded = list(decoded_raw)

for i, img in enumerate(sample_imgs):
    plt.imshow(img.squeeze(), cmap="gray")
    plt.title(f"GT: {test_labels[i]}\nPred: {decoded[i]}")
    plt.axis("off")
    plt.show()
