# Prepare the dataset

In [None]:
import kagglehub
import shutil

# Download latest version
path = kagglehub.dataset_download("allahhitler/ocr-synthetic-dataset")

print("Path to dataset files:", path)

shutil.copytree(path, "/content/data", dirs_exist_ok=True)

Path to dataset files: /root/.cache/kagglehub/datasets/allahhitler/ocr-synthetic-dataset/versions/1


'/content/data'

# Split the data

In [None]:
import os
from pathlib import Path

%cd data
def split_dataset(image_dir, labels_file, train_ratio=0.8, val_ratio=0.1):
    """
    Split dataset into training, validation, and testing sets.

    Args:
        image_dir (str): Path to directory containing images
        labels_file (str): Path to labels.txt file
        train_ratio (float): Ratio of training data (default: 0.8 = 80%)
        val_ratio (float): Ratio of validation data (default: 0.1 = 10%)
    """

    # Read labels file
    with open(labels_file, 'r') as f:
        lines = f.readlines()

    # Parse labels
    image_labels = {}
    for line in lines:
        if line.strip():
            filename, label = line.strip().split(' ', 1)
            image_labels[filename] = label

    # Get sorted list of image files
    image_files = sorted([f for f in os.listdir(image_dir) if f.endswith('.jpg')])
    total_images = len(image_files)

    # Calculate split points
    train_end = int(total_images * train_ratio)
    val_end = int(total_images * (train_ratio + val_ratio))

    # Create directories
    base_dirs = {
        "train": Path("train"),
        "val": Path("val"),
        "test": Path("test")
    }

    for d in base_dirs.values():
        (d / "images").mkdir(parents=True, exist_ok=True)

    # Initialize label lists
    label_files = {
        "train": [],
        "val": [],
        "test": []
    }

    # Split the dataset
    for i, image_file in enumerate(image_files):
        src_path = Path(image_dir) / image_file

        if i < train_end:
            split = "train"
        elif i < val_end:
            split = "val"
        else:
            split = "test"

        dest_path = base_dirs[split] / "images" / image_file
        shutil.copy2(src_path, dest_path)
        label_files[split].append(f"{image_file} {image_labels[image_file]}")

    # Write label files
    for split, lines in label_files.items():
        with open(base_dirs[split] / "labels.txt", "w") as f:
            f.write("\n".join(lines))

    # Print summary
    print("Dataset split completed!")
    print(f"Total images: {total_images}")
    print(f"Training set: {train_end} images ({train_ratio:.0%})")
    print(f"Validation set: {val_end - train_end} images ({val_ratio:.0%})")
    print(f"Testing set: {total_images - val_end} images ({1 - train_ratio - val_ratio:.0%})")
    print("\nDirectory structure:")
    for split in ["train", "val", "test"]:
        print(f"{split}/")
        print(f"├── labels.txt")
        print(f"└── images/")
    print("\nAll done!")


# Configuration
IMAGE_DIR = "images"       # Directory containing your images
LABELS_FILE = "labels.txt" # Path to labels.txt
TRAIN_RATIO = 0.8          # 80% for training
VAL_RATIO = 0.1            # 10% for validation (remaining 10% for testing)

split_dataset(IMAGE_DIR, LABELS_FILE, TRAIN_RATIO, VAL_RATIO)

%cd ..

!mkdir checkpoints/

/content/data
Dataset split completed!
Total images: 100000
Training set: 80000 images (80%)
Validation set: 10000 images (10%)
Testing set: 10000 images (10%)

Directory structure:
train/
├── labels.txt
└── images/
val/
├── labels.txt
└── images/
test/
├── labels.txt
└── images/

All done!
/content
mkdir: cannot create directory ‘checkpoints/’: File exists


# Configurations

In [None]:
config = {
    'data_dir': 'data',
    'img_width': 100,
    'img_height': 32,
    'map_to_seq_hidden': 64,
    'rnn_hidden': 256,
    'leaky_relu': False,
    'max_label_len': 32,
}

train_config = {
    'epochs': 2,
    'train_batch_size': 32,
    'eval_batch_size': 512,
    'lr': 0.0005,
    'show_interval': 10,
    'valid_interval': 500,
    'save_interval': 500,
    'cpu_workers': 4,
    'valid_max_iter': 100,
    'decode_method': 'beam_search',
    'beam_size': 10,
    'checkpoints_dir': 'checkpoints',
    # New checkpoint-related settings
    'max_checkpoints_to_keep': 3,
    'resume_training': True,           # Auto-resume from latest checkpoint
    'checkpoint_path': None,           # Specific checkpoint to load (optional)
    'save_every_epoch': False,         # Save at end of each epoch
    'save_h5_weights': True,           # Also save H5 for compatibility
}
train_config.update(config)

evaluate_config = {
    'eval_batch_size': 512,
    'cpu_workers': 4,
    'decode_method': 'beam_search',
    'beam_size': 10,
    # Updated checkpoint loading for evaluation
    'checkpoint_path': 'checkpoints',  # Can be specific file or directory
    'resume_training': False,           # Set to False for evaluation
}
evaluate_config.update(config)

In [None]:
config

{'data_dir': 'data',
 'img_width': 100,
 'img_height': 32,
 'map_to_seq_hidden': 64,
 'rnn_hidden': 256,
 'leaky_relu': False,
 'max_label_len': 32}

In [None]:
train_config

{'epochs': 2,
 'train_batch_size': 32,
 'eval_batch_size': 512,
 'lr': 0.0005,
 'show_interval': 10,
 'valid_interval': 500,
 'save_interval': 500,
 'cpu_workers': 4,
 'valid_max_iter': 100,
 'decode_method': 'beam_search',
 'beam_size': 10,
 'checkpoints_dir': 'checkpoints',
 'max_checkpoints_to_keep': 3,
 'resume_training': True,
 'checkpoint_path': None,
 'save_every_epoch': False,
 'save_h5_weights': True,
 'data_dir': 'data',
 'img_width': 100,
 'img_height': 32,
 'map_to_seq_hidden': 64,
 'rnn_hidden': 256,
 'leaky_relu': False,
 'max_label_len': 32}

In [None]:
evaluate_config

{'eval_batch_size': 512,
 'cpu_workers': 4,
 'decode_method': 'beam_search',
 'beam_size': 10,
 'checkpoint_path': 'checkpoints',
 'resume_training': False,
 'data_dir': 'data',
 'img_width': 100,
 'img_height': 32,
 'map_to_seq_hidden': 64,
 'rnn_hidden': 256,
 'leaky_relu': False,
 'max_label_len': 32}

# `Dataset` class

In [None]:
import os
import tensorflow as tf
import numpy as np
from PIL import Image
from enum import Enum
from typing import Optional, List, Tuple, Union

class DatasetMode(Enum):
    TRAIN = 'train'
    VAL = 'val'
    TEST = 'test'

class Synth90kDataset:
    # 0–9, a–z, A–Z  (indices start at 1; 0 is reserved for CTC blank)
    CHARS = ''.join([chr(i) for i in range(ord('0'), ord('9')+1)]) + \
            ''.join([chr(i) for i in range(ord('a'), ord('z')+1)]) + \
            ''.join([chr(i) for i in range(ord('A'), ord('Z')+1)])
    CHAR2LABEL = {c: i+1 for i, c in enumerate(CHARS)}
    LABEL2CHAR = {i+1: c for i, c in enumerate(CHARS)}

    def __init__(
        self,
        root_dir: Optional[str] = None,
        mode: Optional[DatasetMode] = None,
        paths: Optional[List[str]] = None,
        img_height: int = 32,
        img_width: int = 100
    ):
        if root_dir and mode and not paths:
            self.paths, self.texts = self._load_from_raw_files(root_dir, mode)
        elif not root_dir and not mode and paths:
            self.paths = paths
            self.texts = None

        self.img_height = img_height
        self.img_width = img_width

    def _load_from_raw_files(self, root_dir: str, mode: DatasetMode) -> Tuple[List[str], List[str]]:
        """
        Load image file paths and their corresponding text labels from label files.

        This method reads a label file containing filename–text pairs, constructs
        the full image paths, and extracts the associated text labels. Each line
        in the label file must follow the format:

            filename text_label

        Args:
            root_dir (str): Root directory where the dataset is stored.
            mode (DatasetMode): Dataset split to load (e.g., TRAIN, VAL, TEST).

        Returns:
            tuple[list[str], list[str]]:
                - paths: List of full paths to image files in the format
                  "data/{mode}/images/{filename}".
                - texts: List of text labels corresponding to each image path.

        Example:
            If root_dir = "/dataset", mode = TRAIN, and the file
            "/dataset/train/labels.txt" contains:
                img1.jpg hello
                img2.jpg world

            The method will return:
                paths = ["data/train/images/img1.jpg", "data/train/images/img2.jpg"]
                texts = ["hello", "world"]
        """

        paths_file = f"{mode.value}/labels.txt"
        paths: List[str] = []
        texts: List[str] = []

        # Read the labels file and parse each line
        with open(os.path.join(root_dir, paths_file), 'r') as fr:
            for line in fr.readlines():
                # Path contains filenames and text are the labels
                line = line.strip()
                if not line:
                    continue
                # Split at first whitespace only (labels can be single token; this keeps it robust)
                parts = line.split(maxsplit=1)
                if len(parts) != 2:
                    continue
                filename, text = parts[0], parts[1]
                img_path = os.path.join(root_dir, mode.value, 'images', filename)
                paths.append(img_path)
                texts.append(text)
        return paths, texts

    def __len__(self) -> int:
        return len(self.paths)

    def create_tf_dataset(self, batch_size: int = 32, training: bool = True) -> tf.data.Dataset:
      """
      Create a TensorFlow Dataset for efficient data loading and preprocessing.

      This method generates a tf.data.Dataset pipeline that can be used for training
      or validation. It handles both cases with and without text labels, applying
      appropriate batching and preprocessing optimizations.
      """
      def generator():
          """
          Generator function that yields individual preprocessed samples.

          Yields:
              tuple or ndarray: Single processed sample from _process_single_item
          """
          for i in range(len(self.paths)):
              yield self._process_single_item(i)

      # Define output signature using tf.TensorSpec based on whether text labels are available
      if self.texts:
          output_signature = (
              tf.TensorSpec(shape=(self.img_height, self.img_width, 1), dtype=tf.float32),
              tf.TensorSpec(shape=(None,), dtype=tf.int32),
              tf.TensorSpec(shape=(), dtype=tf.int32),
          )
      else:
          output_signature = tf.TensorSpec(
              shape=(self.img_height, self.img_width, 1), dtype=tf.float32
          )

      # Create TensorFlow dataset from generator using tf.data.Dataset.from_generator()
      dataset = tf.data.Dataset.from_generator(generator, output_signature=output_signature)


      # Apply appropriate batching strategy based on text presence
      if self.texts:
          # Call dataset.padded_batch for variable-length text sequences
          # padded batch for variable-length label sequences
          dataset = dataset.padded_batch(
              batch_size,
              padded_shapes=(
                  (self.img_height, self.img_width, 1),  # image
                  (None,),  # target sequence (pad on right)
                  (),      # length (scalar)
              ),
              padding_values=(
                  tf.constant(0.0, dtype=tf.float32), # image padding (0 maps to mid-range after [-1,1])
                  tf.constant(0, dtype=tf.int32), # label padding
                  tf.constant(0, dtype=tf.int32), # length padding (won't be used)
              ),
              drop_remainder=training # only drop last incomplete batch during training
          )
          # Apply training-specific optimizations
          if training:
              # Prefetch batches for improved pipeline performance
              dataset = dataset.prefetch(tf.data.AUTOTUNE)
      else:
          # Simple batching for datasets without text labels
          dataset = dataset.batch(batch_size)

      return dataset

    def _process_single_item(self, index: int) -> Union[Tuple[np.ndarray, np.ndarray, int], np.ndarray]:
      """
      Process a single data item by loading, preprocessing image and text.

      This method handles the complete preprocessing pipeline for a single data sample,
      including image loading, resizing, normalization, and text label conversion.
      It includes error handling for corrupted images and recursive fallback.
      """
      path = self.paths[index]

      try:
          # Load image and convert to grayscale
          image = Image.open(path).convert('L')
      except IOError:
          # Handle corrupted images by recursively trying the next sample
          print('Corrupted image for %d' % index)
          return self._process_single_item((index + 1) % len(self.paths))

      # Expect image with shape (img_height, img_width, 1) and pixel values in [-1, 1] (use bilinear interpolation to resize)
      # Resize with bilinear interpolation to (W, H) in PIL, then normalize to [-1, 1]
      image = image.resize((self.img_width, self.img_height), resample=Image.BILINEAR)
      image = np.asarray(image, dtype=np.float32) / 255.0            # [0,1]
      image = image * 2.0 - 1.0                                      # [-1,1]
      image = np.expand_dims(image, axis=-1)                         # (H, W, 1)

      # Process text labels if available
      if self.texts:
          text = self.texts[index]
          # Expect target as a list of integer labels mapped from each character in text (target_length is the length of the target list)
          # Map characters to indices (skip unknown chars)
          target = [self.CHAR2LABEL[c] for c in text if c in self.CHAR2LABEL]
          target_length = len(target)
          # Return image with text labels and sequence length
          return image, np.array(target, dtype=np.int32), target_length
      else:
          # Return only the preprocessed image when no text labels are available
          return image


In [None]:
# For checking
Synth90kDataset(root_dir="/content/data",
                mode=DatasetMode.TRAIN,
                img_height=config["img_width"],
                img_width=config["img_height"]).create_tf_dataset(batch_size=32).element_spec

(TensorSpec(shape=(32, 100, 32, 1), dtype=tf.float32, name=None),
 TensorSpec(shape=(32, None), dtype=tf.int32, name=None),
 TensorSpec(shape=(32,), dtype=tf.int32, name=None))

# `CRNN` Class

In [None]:
from tensorflow.keras import layers

class CRNN(tf.keras.Model):
    """
    Convolutional Recurrent Neural Network (CRNN) for sequence recognition.

    This model combines CNN feature extraction with RNN sequence modeling,
    commonly used for tasks like text recognition in images. The architecture
    follows the CRNN design with CNN backbone followed by bidirectional LSTMs.

    Args:
        img_channel (int): Number of channels in input images (1 for grayscale, 3 for RGB)
        img_height (int): Height of input images (must be divisible by 16)
        img_width (int): Width of input images (must be divisible by 4)
        num_class (int): Number of output classes (typically characters + 1 for blank)
        map_to_seq_hidden (int, optional): Hidden units in the mapping layer between CNN and RNN.
                                          Defaults to 64.
        rnn_hidden (int, optional): Number of hidden units in each LSTM layer. Defaults to 256.
        leaky_relu (bool, optional): Whether to use LeakyReLU instead of ReLU. Defaults to False.
        **kwargs: Additional keyword arguments passed to the parent class.

    Architecture:
        - CNN Backbone: 7 convolutional layers with pooling for feature extraction
        - Map to Sequence: Dense layer to transform CNN features to sequence format
        - RNN: Two bidirectional LSTM layers for sequence modeling
        - Output: Dense layer for logits

    Input Shape:
        (batch_size, img_height, img_width, img_channel)

    Output Shape:
        (width, batch_size, num_class) - Transposed for CTC loss compatibility

    Note:
        - Input image dimensions must satisfy: height % 16 == 0 and width % 4 == 0
        - Uses bidirectional LSTMs for better context understanding
    """

    def __init__(self, img_channel, img_height, img_width, num_class,
                 map_to_seq_hidden=64, rnn_hidden=256, leaky_relu=False, **kwargs):
        super().__init__(**kwargs)

        # Store model configuration parameters
        self.img_channel = img_channel
        self.img_height = img_height
        self.img_width = img_width
        self.num_class = num_class
        self.map_to_seq_hidden = map_to_seq_hidden
        self.rnn_hidden = rnn_hidden
        self.leaky_relu = leaky_relu

        # Build CNN backbone for feature extraction
        self.cnn = self._build_cnn_backbone(img_channel, leaky_relu)

        # Dense layer to map CNN features to sequence format for RNN
        # Map CNN features (flattened per time step) → hidden
        self.fc_map = layers.Dense(self.map_to_seq_hidden, name="map_to_seq")

        # Two bidirectional LSTM layers for sequence modeling
        self.rnn1 = layers.Bidirectional(
            layers.LSTM(self.rnn_hidden, return_sequences=True), name="bilstm_1")
        self.rnn2 = layers.Bidirectional(
            layers.LSTM(self.rnn_hidden, return_sequences=True), name="bilstm_2")

        # Output layer for class probability prediction
        # (no activation; CTC expects logits)
        self.classifier = layers.Dense(self.num_class, name="logits")


    def _act(self):
        # Returns the correct activation layer depending on whether leaky_relu=True
        return layers.LeakyReLU(alpha=0.2) if self.leaky_relu else layers.ReLU()

    def _build_cnn_backbone(self, img_channel, leaky_relu):
        """
        Build the CNN backbone for feature extraction.

        The CNN consists of 7 convolutional layers with pooling and batch normalization.
        It progressively reduces spatial dimensions while increasing feature depth.

        Args:
            img_channel (int): Number of input image channels
            leaky_relu (bool): Whether to use LeakyReLU activation

        Returns:
            tf.keras.Sequential: CNN backbone model

        Architecture Details:
            - Layers 0-1: 64 and 128 filters with 2x2 max pooling
            - Layers 2-3: 256 filters with special (2,1) pooling
            - Layers 4-5: 512 filters with batch normalization and (2,1) pooling
            - Layer 6: 512 filters with valid padding (reduces width by 1)

        Note:
            - Input height must be divisible by 16 due to pooling strides
            - Input width must be divisible by 4 due to pooling strides
        """
        # Validate input dimensions meet architectural requirements
        assert self.img_height % 16 == 0
        assert self.img_width % 4 == 0

        model = tf.keras.Sequential()

        # Define input shape
        model.add(layers.InputLayer(input_shape=(self.img_height, self.img_width, img_channel)))

        # Layer 0: 64 filters, 3x3 convolution, ReLU/LeakyReLU, 2x2 max pooling
        model.add(layers.Conv2D(64, 3, padding='same', use_bias=True))
        model.add(self._act())
        model.add(layers.MaxPool2D(pool_size=(2,2), strides=(2,2))) # (H/2, W/2, 64)

        # Layer 1: 128 filters, 3x3 convolution, ReLU/LeakyReLU, 2x2 max pooling
        model.add(layers.Conv2D(128, 3, padding='same', use_bias=True))
        model.add(self._act())
        model.add(layers.MaxPool2D(pool_size=(2,2), strides=(2,2))) # (H/4, W/4, 128)

        # Layer 2: 256 filters, 3x3 convolution, ReLU/LeakyReLU (no pooling)
        model.add(layers.Conv2D(256, 3, padding='same', use_bias=True))
        model.add(self._act()) # (H/4, W/4, 256)

        # Layer 3: 256 filters, 3x3 convolution, ReLU/LeakyReLU, (2,1) max pooling
        model.add(layers.Conv2D(256, 3, padding='same', use_bias=True))
        model.add(self._act())
        model.add(layers.MaxPool2D(pool_size=(2,1), strides=(2,1))) # (H/8, W/4, 256)

        # Layer 4: 512 filters with batch normalization, ReLU/LeakyReLU
        model.add(layers.Conv2D(512, 3, padding='same', use_bias=True))
        model.add(layers.BatchNormalization())
        model.add(self._act()) # (H/8, W/4, 512)

        # Layer 5: 512 filters with batch normalization, ReLU/LeakyReLU, (2,1) max pooling
        model.add(layers.Conv2D(512, 3, padding='same', use_bias=True))
        model.add(layers.BatchNormalization())
        model.add(self._act())
        model.add(layers.MaxPool2D(pool_size=(2,1), strides=(2,1))) # (H/16, W/4, 512)

        # Layer 6: 512 filters, 2x2 convolution with valid padding, ReLU/LeakyReLU
        model.add(layers.Conv2D(512, 2, padding='valid', use_bias=True))
        model.add(self._act()) # (H/16 - 1, W/4 - 1, 512)

        return model

    def build(self, input_shape):
        """
        Build the model by explicitly constructing all layers.

        This method ensures all layers are properly built with the correct
        input shapes before the first forward pass.

        Args:
            input_shape: Shape of the input tensor
        """
        # Explicitly build each component to ensure proper weight initialization
        super(CRNN, self).build(input_shape)

    def call(self, images, training=False):
        """
        Forward pass of the CRNN model.

        Processes input images through CNN feature extraction, sequence mapping,
        RNN sequence modeling, and final classification layers.

        Args:
            images (tf.Tensor): Input image tensor of shape
                               (batch_size, height, width, channels)
            training (bool, optional): Whether the model is in training mode.
                                      Defaults to False.

        Returns:
            tf.Tensor: Output tensor of shape (width, batch_size, num_class)
                       Transposed for compatibility with CTC loss functions

        Processing Steps:
            1. CNN feature extraction
            2. Map to sequence with dense layer
            3. Process through two bidirectional LSTM layers
            4. Apply dense layer for classification

        Note: You may need to reshape to match dimension
        """
        # CNN feature extraction: (batch, height, width, channels)
        output = self.cnn(images, training=training)
        # Reshape for sequence processing along width:
        # (B, W', H'*C) — time dimension is width
        B, Hp, Wp, C = tf.unstack(tf.shape(output))

        # transpose to (B, W', H', C) then flatten H' and C
        output = tf.transpose(output, [0, 2, 1, 3]) # (B, W', H', C)
        output = tf.reshape(output, [B, Wp, -1]) # (B, W', H'*C)

        # Map CNN features to sequence format
        output = self.fc_map(output) # (B, W', map_to_seq_hidden)

        # Process through bidirectional LSTM layers
        # RNN layers (keep sequence)
        output = self.rnn1(output, training=training) # (B, W', 2*rnn_hidden)
        output = self.rnn2(output, training=training) # (B, W', 2*rnn_hidden)

        # Final classification layer: (batch, width, num_class)
        # Classifier to logits
        output = self.classifier(output) # (B, W', num_class)
        # Transpose to time-major for CTC: (T, B, C)
        output = tf.transpose(output, [1, 0, 2])
        return output


In [None]:

crnn = CRNN(1, config["img_height"], config["img_width"], len(Synth90kDataset.LABEL2CHAR) + 1,
                map_to_seq_hidden=config['map_to_seq_hidden'],
                rnn_hidden=config['rnn_hidden'],
                leaky_relu=config['leaky_relu'])

crnn.build(input_shape=(None, config["img_height"], config["img_height"], 1))
crnn.summary()

crnn.cnn.summary()

# `CTCLoss` Class

In [None]:
class CTCLoss(tf.keras.losses.Loss):
    """
    Connectionist Temporal Classification (CTC) Loss for sequence recognition tasks.

    This loss function is designed for sequence-to-sequence tasks where the alignment
    between inputs and labels is unknown, such as speech recognition or handwritten
    text recognition. It computes the negative log probability of the correct sequence.

    The CTC loss handles variable-length sequences by considering all possible alignments
    between the input sequence (model predictions) and the target sequence (labels).

    Inherits from:
        tf.keras.losses.Loss

    Typical Use Cases:
        - Optical Character Recognition (OCR)
        - Speech Recognition
        - Handwriting Recognition
        - Any sequence labeling task with unaligned sequences

    Note:
        - Assumes blank index is 0 (standard for CTC)
        - Handles variable-length sequences automatically
        - Uses logits (unnormalized scores) rather than probabilities
        - Computes mean loss across the batch
    """

    def call(self, y_true, y_pred):
        """
        Compute the CTC loss between true labels and predicted logits.

        This method prepares the inputs for TensorFlow's CTC loss function by
        calculating the necessary sequence lengths and formatting the data
        appropriately for the CTC algorithm.

        Args:
            y_true (tf.Tensor): True labels tensor of shape (batch_size, max_label_length)
                               containing integer class labels. Padded with zeros for
                               shorter sequences.
            y_pred (tf.Tensor): Predicted logits tensor of shape
                               (batch_size, sequence_length, num_classes) containing
                               unnormalized scores for each class at each time step.

        Returns:
            tf.Tensor: Scalar tensor representing the mean CTC loss across the batch.

        Note:
            - The blank index (for CTC transitions) is set to 0
            - Label length is computed by counting non-zero elements in y_true
            - Input length is assumed to be the full sequence length of predictions
            - Logits are expected in batch-major format (batch, time, classes)
        """

        # Compute CTC loss using TensorFlow's built-in function `tf.nn.ctc_loss`
        # lengths
        y_true = tf.cast(y_true, tf.int32)
        label_length = tf.math.count_nonzero(y_true, axis=1, dtype=tf.int32) # (B,)
        B = tf.shape(y_pred)[0]
        T = tf.shape(y_pred)[1]
        logit_length = tf.fill([B], T) # (B,)

        # Convert dense padded labels -> SparseTensor (drop pads)
        # (only keep the first label_length tokens of each row)
        ragged = tf.RaggedTensor.from_tensor(y_true, lengths=label_length) # [B, None]
        labels_sparse = ragged.to_sparse() # SparseTensor

        # CTC loss expects logits (NOT softmax). We pass batch-major logits.
        loss = tf.nn.ctc_loss(
            labels=labels_sparse,
            logits=y_pred,
            label_length=label_length,
            logit_length=logit_length,
            logits_time_major=False, # because y_pred is (B, T, C)
            blank_index=0
        )

        # Return mean loss across all samples in the batch
        return tf.reduce_mean(loss)

# CTC decoders

In [None]:
from collections import defaultdict
from scipy.special import logsumexp

# Constants
NINF = -1 * float('inf')  # Negative infinity for log probability calculations
DEFAULT_EMISSION_THRESHOLD = 0.01  # Default threshold for emission probabilities

def _reconstruct(labels, blank=0):
    """
    Reconstruct labels by removing consecutive duplicates and blank tokens.

    This function applies the CTC collapse rules to remove repeated characters
    and blank tokens from a sequence, producing the final decoded output.

    Args:
        labels (list): Sequence of integer labels containing duplicates and blanks
        blank (int, optional): Index of the blank token. Defaults to 0.

    Returns:
        list: Reconstructed sequence with consecutive duplicates removed and blanks filtered out

    Example:
        >>> _reconstruct([0, 1, 1, 0, 2, 2, 2, 0, 3], blank=0)
        [1, 2, 3]
        >>> _reconstruct([1, 1, 2, 2, 1, 1], blank=0)
        [1, 2, 1]

    Note:
        - Implements the standard CTC collapse rule: remove duplicates and blanks
        - Only consecutive duplicates are removed, non-consecutive duplicates remain
    """
    new_labels = []
    previous = None
    # Remove consecutive duplicates
    for l in labels:
        if l != previous:
            new_labels.append(l)
            previous = l
    # Remove blank tokens
    new_labels = [l for l in new_labels if l != blank]
    return new_labels

def greedy_decode(emission_log_prob, blank=0, **kwargs):
    """
    Greedy decoding for CTC sequences.

    This decoder selects the most probable character at each time step independently,
    then applies CTC reconstruction rules. It's fast but may not find the optimal sequence.

    Args:
        emission_log_prob (np.ndarray): Log probability matrix of shape (length, num_classes)
        blank (int, optional): Index of the blank token. Defaults to 0.
        **kwargs: Additional arguments (ignored for greedy decoding)

    Returns:
        list: Decoded sequence of labels

    Time Complexity: O(T × C) where T is sequence length, C is number of classes
    Space Complexity: O(T)

    Example:
        >>> emission = np.log([[0.1, 0.7, 0.2], [0.3, 0.4, 0.3], [0.8, 0.1, 0.1]])
        >>> greedy_decode(emission, blank=0)
        [1, 1, 0]  # After reconstruction: [1]

    Note:
        - Fastest decoding method but suboptimal for difficult sequences
        - Does not consider multiple possible alignments
        - Suitable for when the probability distribution is peaked
    """

    # Select most probable character at each time step
    best = np.argmax(emission_log_prob, axis=1).tolist() #length T
    # Apply CTC reconstruction rules using _reconstruct()
    labels = _reconstruct(best, blank=blank)

    return labels

def beam_search_decode(emission_log_prob, blank=0, **kwargs):
    """
    Beam search decoding for CTC sequences.

    This decoder maintains multiple candidate sequences (beams) and expands them
    at each time step, keeping only the top-k most probable sequences.

    Args:
        emission_log_prob (np.ndarray): Log probability matrix of shape (length, num_classes)
        blank (int, optional): Index of the blank token. Defaults to 0.
        **kwargs: Additional arguments including:
            beam_size (int): Number of beams to maintain
            emission_threshold (float, optional): Log probability threshold for pruning

    Returns:
        list: Most probable decoded sequence

    Time Complexity: O(T × C × B) where T=length, C=classes, B=beam_size
    Space Complexity: O(B × T)

    Example:
        >>> emission = np.log([[0.1, 0.7, 0.2], [0.3, 0.4, 0.3], [0.8, 0.1, 0.1]])
        >>> beam_search_decode(emission, blank=0, beam_size=5)
        [1, 2]  # Best sequence found by beam search

    Note:
        - More accurate than greedy decoding but slower
        - beam_size controls the trade-off between accuracy and speed
        - emission_threshold prunes low-probability paths to improve efficiency
    """
    beam_size = kwargs['beam_size']
    emission_threshold = kwargs.get('emission_threshold', np.log(DEFAULT_EMISSION_THRESHOLD))

    length, class_count = emission_log_prob.shape

    # Initialize with empty sequence and zero log probability
    beams = [([], 0)]  # Each beam is (sequence, accumulated_log_prob)

    # Iterate through each time step
    for t in range(length):
        new_beams = []
        # Expand each existing beam with all possible next characters
        for prefix, accumulated_log_prob in beams:
            for c in range(class_count):
                log_prob = emission_log_prob[t, c]
                # Prune paths with very low emission probability
                if log_prob < emission_threshold:
                    continue
                # Create new sequence and update probability
                new_prefix = prefix + [c]
                new_accu_log_prob = accumulated_log_prob + log_prob
                new_beams.append((new_prefix, new_accu_log_prob))

        # Keep only the top-k beams by probability
        new_beams.sort(key=lambda x: x[1], reverse=True)
        beams = new_beams[:beam_size]

    # Merge beams that produce the same reconstructed sequence
    total_accu_log_prob = {}
    for prefix, accu_log_prob in beams:
        labels = tuple(_reconstruct(prefix, blank))
        # Use logsumexp to combine probabilities of different alignments to same sequence
        total_accu_log_prob[labels] = \
            logsumexp([accu_log_prob, total_accu_log_prob.get(labels, NINF)])

    # Select the sequence with highest probability
    labels_beams = [(list(labels), accu_log_prob)
                    for labels, accu_log_prob in total_accu_log_prob.items()]
    labels_beams.sort(key=lambda x: x[1], reverse=True)
    labels = labels_beams[0][0]

    return labels

def prefix_beam_decode(emission_log_prob, blank=0, **kwargs):
    """
    Prefix beam search decoding for CTC sequences.

    Args:
        emission_log_prob (np.ndarray): Log probability matrix of shape (length, num_classes)
        blank (int, optional): Index of the blank token. Defaults to 0.
        **kwargs: Additional arguments including:
            beam_size (int): Number of beams to maintain
            emission_threshold (float, optional): Log probability threshold for pruning

    Returns:
        list: Most probable decoded sequence

    ------------------------------------------------------------------------
    HOW THIS ALGORITHM WORKS:

    1. Each prefix (partial output sequence) keeps TWO log-probabilities:
         - p_b(prefix): probability of all paths that end with a BLANK
         - p_nb(prefix): probability of all paths that end with a NON-BLANK

    2. At each time step t, for every active prefix:
         (a) Emit BLANK:
             - Stay on the same prefix (no new label appended).
             - p_b'(prefix) += log_p[blank] + logsumexp(p_b, p_nb)

         (b) Emit NON-BLANK label c:
             - If c == last(prefix):
                 # same label as last
                 - repeat without appending:
                     p_nb'(prefix) += log_p[c] + p_nb(prefix)
                 - append same label again (requires blank before):
                     p_nb'(prefix + c) += log_p[c] + p_b(prefix)
             - Else (c != last(prefix)):
                 - append new label:
                     p_nb'(prefix + c) += log_p[c] + logsumexp(p_b, p_nb)

    3. All updates are accumulated with logsumexp because multiple
       paths can land on the same (prefix, ending state).

    4. After processing all labels for time t:
         - Keep only the top `beam_size` prefixes ranked by
           logsumexp(p_b, p_nb).

    5. After the final time step:
         - Choose the prefix with the largest total probability
           logsumexp(p_b, p_nb) as the best decoded sequence.

    Key points:
      - Initialization: p_b(()) = 0, p_nb(()) = -inf
      - Must separate blank vs non-blank ending states
      - Handle repeated labels carefully
      - Use logsumexp for combining probabilities
      - Prune to beam_size at each time step

    Complexity:
      - O(T × C × B) time where T=time steps, C=classes, B=beam size
      - O(B) space (top beams at each step)
    """

    # Initialize beams with empty prefix and (p_b=0, p_nb=-inf)
    # For each time step:
    #   - create defaultdict for new beams
    #   - loop over current beams
    #   - update new beams for blank and each label
    #   - prune to top beam_size
    # Return the prefix with max total probability

    beam_size = kwargs['beam_size']
    emission_threshold = kwargs.get('emission_threshold', np.log(DEFAULT_EMISSION_THRESHOLD))

    length, class_count = emission_log_prob.shape

    # Beams store (prefix, (prob_blank, prob_non_blank))
    beams = [(tuple(), (0, NINF))]  # Start with empty prefix

    for t in range(length):
        new_beams_dict = defaultdict(lambda: (NINF, NINF))

        for prefix, (lp_b, lp_nb) in beams:
            for c in range(class_count):
                log_prob = emission_log_prob[t, c]
                if log_prob < emission_threshold:
                    continue

                end_t = prefix[-1] if prefix else None
                new_lp_b, new_lp_nb = new_beams_dict[prefix]

                if c == blank:
                    # Extending with blank: both blank and non-blank can transition to blank
                    new_lp_b = logsumexp([new_lp_b, log_prob + logsumexp([lp_b, lp_nb])])
                    new_beams_dict[prefix] = (new_lp_b, new_lp_nb)
                    
                    continue

                if c == end_t:
                    # Extending with same character as last: only non-blank can extend
                    new_lp_nb = logsumexp([new_lp_nb, log_prob + lp_nb])
                    new_beams_dict[prefix] = (new_lp_b, new_lp_nb)

                new_prefix = prefix + (c,)
                new_lp_b, new_lp_nb = new_beams_dict[new_prefix]

                if c != end_t:
                    # New character: both blank and non-blank can extend to new character
                    new_lp_nb = logsumexp([new_lp_nb, log_prob + logsumexp([lp_b, lp_nb])])
                    new_beams_dict[new_prefix] = (new_lp_b, new_lp_nb)

                else:
                    # Same character but not consecutive: only blank can extend
                    new_lp_nb = logsumexp([new_lp_nb, log_prob + lp_b])
                    new_beams_dict[new_prefix] = (new_lp_b, new_lp_nb)

        # Select top beams by total probability (blank + non-blank)
        beams = sorted(new_beams_dict.items(), key=lambda x: logsumexp(x[1]), reverse=True)
        beams = beams[:beam_size]

    # Return the best prefix
    labels = list(beams[0][0])
    return labels

def ctc_decode(
    log_probs: np.ndarray,
    label2char: dict | None = None,
    blank: int = 0,
    method: str = "beam_search",
    beam_size: int = 10,
    input_lengths: np.ndarray | None = None,
    return_strings: bool = False,
):
    """
    Decode a batch of log-prob sequences (B, T, C) into label sequences.
    - Iterates over the batch dimension (B), passing (Ti, C) to the decoder.
    - Supports optional per-sample input_lengths (Ti). If None, uses full T.
    - Normalizes different decoder return shapes (tuple, beams, single seq).
    - If label2char and return_strings=True, also returns strings.

    Returns:
        preds  (List[np.ndarray[int]] or List[str] if return_strings=True)
        (If your decoders expose scores and you need them, extend this to also return scores.)
    """
    # Expected shape: (B, T, C)
    emissions = np.asarray(log_probs)
    assert emissions.ndim == 3, f"log_probs must be 3D (B,T,C), got {emissions.shape}"
    B, T, C = emissions.shape

    # Choose decoder
    decoders = {
        "greedy": greedy_decode,
        "beam_search": beam_search_decode,
        "prefix_beam_search": prefix_beam_decode,
    }
    if method not in decoders:
        raise KeyError(f"Unknown decoding method: {method}. Available: {list(decoders.keys())}")
    decoder = decoders[method]

    # Per-sample lengths
    if input_lengths is None:
        input_lengths = np.full((B,), T, dtype=int)
    else:
        input_lengths = np.asarray(input_lengths).astype(int)
        assert input_lengths.shape == (B,), f"input_lengths must be shape (B,), got {input_lengths.shape}"

    preds = []
    for b in range(B):
        Ti = int(input_lengths[b])
        Ti = max(0, min(Ti, T))  # clamp
        emission_bt = emissions[b, :Ti, :]  # (Ti, C)

        # Call decoder. Different implementations may return:
        #  - sequence only,
        #  - (sequence, score),
        #  - [beam1, beam2, ...], possibly with scores elsewhere.
        out = decoder(emission_bt, blank=blank, beam_size=beam_size)

        # Normalize to a single best sequence of ints
        if isinstance(out, tuple) and len(out) >= 1:
            seq = out[0]
        else:
            seq = out

        # If beams (list of sequences), take top-1
        if isinstance(seq, list):
            if len(seq) == 0:
                seq = np.asarray([], dtype=int)
            else:
                # If it's a list of ints -> good; if it's a list of lists -> take first beam
                if isinstance(seq[0], (list, np.ndarray)):
                    seq = np.asarray(seq[0], dtype=int)
                else:
                    seq = np.asarray(seq, dtype=int)
        elif isinstance(seq, np.ndarray):
            seq = seq.astype(int)
        else:
            # Scalar or unknown -> coerce
            seq = np.asarray(seq, dtype=int)

        preds.append(seq)

    if label2char is not None and return_strings:
        preds_str = []
        for seq in preds:
            # only map known labels (skip blank and unseen ids)
            s = "".join(label2char.get(int(i), "") for i in seq if int(i) != blank)
            preds_str.append(s)
        return preds_str

    return preds

# Training

In [None]:
from typing import Any, Dict, List, Iterable, Tuple, Optional
from tqdm import tqdm

def evaluate(
    model: tf.keras.Model,
    dataloader,
    criterion,
    max_iter: int | None = None,
    decode_method: str = "beam_search",
    beam_size: int = 10,
) -> Dict[str, float | List[Tuple[str, str]]]:
    """
    Evaluate the model on a given dataloader.

    Returns:
        - 'loss': float
            Average loss over all evaluated samples.
        - 'acc': float
            Sequence-level (exact match) accuracy.
        - 'wrong_cases': List[Tuple[str, str]]
            Each element is a tuple (ground_truth_text, predicted_text)
            for a sample where the prediction did not match exactly.
    """
    total_count = 0
    total_loss = 0.0
    total_correct = 0
    wrong_cases: List[Tuple[str, str]] = []

    total = max_iter if max_iter < len(list(dataloader)) else len(list(dataloader))
    pbar = tqdm(total=total, desc="Evaluate")

    for i, batch in enumerate(dataloader):
        if max_iter is not None and i >= max_iter:
            break
        # Forward pass: CRNN returns time-major (T,B,C) → make batch-major (B,T,C)
        images, targets, target_lengths = batch
        B = int(images.shape[0])
        logits_TBC = model(images, training=False)
        logits_BTC = tf.transpose(logits_TBC, [1, 0, 2]) # [B,T,C]

        # Compute CTCLoss (batch-major)
        batch_loss = criterion(y_true=targets, y_pred=logits_BTC)

        # Decode predictions (ctc_decode expects log-probs)
        log_probs = tf.nn.log_softmax(logits_BTC, axis=-1).numpy()  # (B,T,C)
        input_lengths = np.full((B,), log_probs.shape[1], dtype=int)
        preds_str = ctc_decode(
            log_probs,
            label2char=Synth90kDataset.LABEL2CHAR,
            blank=0,
            method=decode_method,
            beam_size=beam_size,
            input_lengths=input_lengths,
            return_strings=True,
        )

        # Ground-truth (strip padding via target_lengths)
        gts_str = []
        for b in range(B):
            L = int(target_lengths[b].numpy() if isinstance(target_lengths, tf.Tensor) else target_lengths[b])
            tgt_row = targets[b].numpy() if isinstance(targets, tf.Tensor) else targets[b]
            s = "".join(Synth90kDataset.LABEL2CHAR.get(int(x), "") for x in tgt_row[:L] if int(x) != 0)
            gts_str.append(s)

        # Compare predictions with ground truth & update stats (sequence-level accuracy, only count as correct if the whole sequence is correct)
        for gt, pred in zip(gts_str, preds_str):
            total_count += 1
            if gt == pred:
                total_correct += 1
            else:
                wrong_cases.append((gt, pred))

        # Accumulate loss as sum over samples (so we can divide by total_count later)
        total_loss += float(batch_loss.numpy()) * B

        pbar.update(1)

    pbar.close()

    accuracy = total_correct / total_count if total_count > 0 else 0.0
    avg_loss = total_loss / total_count if total_count > 0 else 0.0
    return {"loss": avg_loss, "acc": accuracy, "wrong_cases": wrong_cases}

In [None]:
from __future__ import annotations
import abc, os
from typing import Any, Dict, Optional, Iterable, Tuple
import tensorflow as tf

class BaseTrainer(abc.ABC):
    """Minimal reusable trainer skeleton (SRP-friendly).

    Subclasses override:
      - build_datasets()
      - build_model()
      - build_optimizer()
      - build_loss()
      - train_step(batch)
      - evaluate_model()

    Optional hooks:
      - on_epoch_start(epoch), on_epoch_end(epoch), on_step_end(step, loss)
    """

    def __init__(self, cfg: Dict[str, Any]):
        self.cfg = cfg
        self.global_step = 1

        self.model: Optional[tf.keras.Model] = None
        self.optimizer: Optional[tf.keras.optimizers.Optimizer] = None
        self.criterion = None

        self.train_ds: Optional[Iterable] = None
        self.valid_ds: Optional[Iterable] = None
        self.metrics = {"train_loss": tf.keras.metrics.Mean(name="train_loss")}

        # Checkpoint management
        self.ckpt = None
        self.ckpt_manager = None

    # ----------------- build phase -----------------
    @abc.abstractmethod
    def build_datasets(self) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
        """Return (train_ds, valid_ds)."""
        raise NotImplementedError

    @abc.abstractmethod
    def build_model(self) -> tf.keras.Model:
        raise NotImplementedError

    @abc.abstractmethod
    def build_optimizer(self) -> tf.keras.optimizers.Optimizer:
        raise NotImplementedError

    @abc.abstractmethod
    def build_loss(self):
        raise NotImplementedError

    # ----------------- train/eval steps -----------------
    @abc.abstractmethod
    @tf.function
    def train_step(self, batch) -> tf.Tensor:
        raise NotImplementedError

    @abc.abstractmethod
    def evaluate_model(self) -> Dict[str, Any]:
        raise NotImplementedError

    # ----------------- optional helpers/hooks -----------------
    def forward(self, images: tf.Tensor, training: bool) -> tf.Tensor:
        return self.model(images, training=training)

    def compute_loss(self, logits: tf.Tensor, batch) -> tf.Tensor:
        raise NotImplementedError

    def on_epoch_start(self, epoch: int) -> None:
        pass

    def on_epoch_end(self, epoch: int) -> None:
        pass

    def on_step_end(self, step: int, loss_value: float) -> None:
        pass

    # ----------------- environment/setup -----------------
    def seed_and_determinism(self) -> None:
        tf.random.set_seed(self.cfg.get("seed", 42))
        try:
            tf.config.experimental.enable_op_determinism()
        except Exception:
            pass

    def setup_checkpoint(self) -> None:
        os.makedirs(self.cfg['checkpoints_dir'], exist_ok=True)
        self.ckpt = tf.train.Checkpoint(
            step=tf.Variable(1, dtype=tf.int64),
            model=self.model,
            optimizer=self.optimizer,
            global_step=tf.Variable(1, dtype=tf.int64),
        )
        self.ckpt_manager = tf.train.CheckpointManager(
            checkpoint=self.ckpt,
            directory=self.cfg['checkpoints_dir'],
            max_to_keep=self.cfg.get('max_checkpoints_to_keep', 3),
        )

    def save_checkpoint(self, step: int, eval_stats: Optional[Dict[str, Any]] = None) -> str:
        if self.ckpt_manager is None:
            raise RuntimeError("Checkpoint manager not initialized. Call setup_checkpoint() first.")
        self.ckpt.step.assign(step)
        self.ckpt.global_step.assign(self.global_step)
        save_path = self.ckpt_manager.save(checkpoint_number=step)
        print(f"[Checkpoint] Saved: {save_path} (global_step: {self.global_step})")

        if self.cfg.get('save_h5_weights', False):
            h5_path = f"{self.cfg['checkpoints_dir']}/model_{step:06d}.weights.h5"
            self.model.save_weights(h5_path)
            print(f"[Checkpoint] H5 weights saved: {h5_path}")
        return save_path

    def restore_checkpoint(self, checkpoint_path: Optional[str] = None) -> bool:
        if self.ckpt is None:
            raise RuntimeError("Checkpoint not initialized. Call setup_checkpoint() first.")
        try:
            if checkpoint_path:
                self.ckpt.restore(checkpoint_path)
                print(f"[Checkpoint] Restored from specific checkpoint: {checkpoint_path}")
            else:
                if self.ckpt_manager.latest_checkpoint:
                    self.ckpt.restore(self.ckpt_manager.latest_checkpoint)
                    self.global_step = int(self.ckpt.global_step.numpy())
                    print(f"[Checkpoint] Restored from latest: {self.ckpt_manager.latest_checkpoint}")
                    print(f"[Checkpoint] Resuming from global_step: {self.global_step}")
                else:
                    print("[Checkpoint] No checkpoint found. Starting from scratch.")
                    return False
            return True
        except Exception as e:
            print(f"[Checkpoint] Failed to restore checkpoint: {e}")
            return False

    def get_checkpoint_info(self) -> Dict[str, Any]:
        if self.ckpt_manager is None:
            return {"available": False}
        return {
            "available": bool(self.ckpt_manager.latest_checkpoint),
            "latest_checkpoint": self.ckpt_manager.latest_checkpoint,
            "checkpoints": self.ckpt_manager.checkpoints if hasattr(self.ckpt_manager, 'checkpoints') else [],
        }

    # ----------------- explicit lifecycle steps -----------------
    def prepare(self) -> None:
        """Seed & build all components; setup checkpoint manager."""
        self.seed_and_determinism()
        self.train_ds, self.valid_ds = self.build_datasets()
        self.model = self.build_model()
        self.optimizer = self.build_optimizer()
        self.criterion = self.build_loss()
        self.setup_checkpoint()

    def restore_from_cfg(self) -> bool:
        """Read restoration policy from cfg and restore."""
        if self.ckpt is None or self.ckpt_manager is None:
            raise RuntimeError("Call prepare() before restore.")
        if self.cfg.get("resume_training", False):
            return self.restore_checkpoint()
        if self.cfg.get("checkpoint_path"):
            return self.restore_checkpoint(self.cfg["checkpoint_path"])
        return False

    def restore_from_path(self, checkpoint_path: str) -> bool:
        """Programmatic restore entrypoint (SRP)."""
        if self.ckpt is None or self.ckpt_manager is None:
            raise RuntimeError("Call prepare() before restore.")
        return self.restore_checkpoint(checkpoint_path)

    # ----------------- training only -----------------
    def train_loop(self) -> None:
        """Pure training loop. Assumes prepare() (and optional restore) already ran."""
        if any(x is None for x in [self.train_ds, self.valid_ds, self.model, self.optimizer, self.criterion]):
            raise RuntimeError("Trainer not prepared. Call prepare() first.")
        epochs = self.cfg["epochs"]
        show_interval = self.cfg["show_interval"]
        valid_interval = self.cfg["valid_interval"]
        save_interval = self.cfg["save_interval"]

        for epoch in range(1, epochs + 1):
            print(f"\n===== Epoch {epoch} =====")
            self.metrics["train_loss"].reset_state()
            self.on_epoch_start(epoch)

            for batch in self.train_ds:
                loss = self.train_step(batch)
                self.metrics["train_loss"](loss)

                if self.global_step % show_interval == 0:
                    avg = float(self.metrics["train_loss"].result().numpy())
                    print(f"[Step {self.global_step}] batch loss: {float(loss.numpy()):.4f} | running avg loss: {avg:.4f}")

                if self.global_step % valid_interval == 0:
                    eval_stats = self.evaluate_model()
                    if self.global_step % save_interval == 0:
                        self.save_checkpoint(self.global_step, eval_stats)

                self.on_step_end(self.global_step, float(loss.numpy()))
                self.global_step += 1

            print(f"Epoch {epoch} train loss: {float(self.metrics['train_loss'].result().numpy()):.4f}")
            self.on_epoch_end(epoch)

            if self.cfg.get('save_every_epoch', False):
                self.save_checkpoint(self.global_step)

    # ----------------- thin facade (optional) -----------------
    def fit(self, *, auto_prepare: bool = True) -> None:
        """
        SRP-compliant: training only.
        - If auto_prepare=True: runs prepare() (no restore) before training.
        """
        if auto_prepare:
            self.prepare()
        self.train_loop()


In [None]:
class CRNNTrainer(BaseTrainer):
    """
    CRNN + CTC trainer with small functions.

    Key TODOs:
      1) build_datasets()
      2) build_model()
      3) build_optimizer()
      4) build_loss()
      5) compute_loss()
      6) train_step()
      7) evaluate_model()
    """

    def __init__(self, cfg: Dict[str, Any]):
        super().__init__(cfg)
        self.img_h = cfg["img_height"]
        self.img_w = cfg["img_width"]
        self.reload_checkpoint = cfg.get("reload_checkpoint", "")

    # ----------------- BUILD PHASE -----------------
    def build_datasets(self) -> Tuple[tf.Data.Dataset, tf.Data.Dataset]:
        """
        Create train/valid tf.data.Datasets.
        Must yield tuples (images, targets, target_lengths).
          images: float32 [B, H, W, 1]
          targets: int32  [B, T_label]
          target_lengths: int32 [B]
        """
        data_dir = self.cfg["data_dir"]
        train_batch = self.cfg["train_batch_size"]
        eval_batch = self.cfg["eval_batch_size"]

        # Prepare datasets with correct shapes/dtypes
        data_dir   = self.cfg["data_dir"]
        train_batch = int(self.cfg["train_batch_size"])
        eval_batch  = int(self.cfg["eval_batch_size"])

        train_set = Synth90kDataset(
            root_dir=data_dir,
            mode=DatasetMode.TRAIN,
            img_height=self.img_h,
            img_width=self.img_w,
        )
        valid_set = Synth90kDataset(
            root_dir=data_dir,
            mode=DatasetMode.VAL,
            img_height=self.img_h,
            img_width=self.img_w,
        )

        train_ds = train_set.create_tf_dataset(batch_size=train_batch, training=True)
        valid_ds = valid_set.create_tf_dataset(batch_size=eval_batch,  training=False)

        return train_ds, valid_ds

    def build_model(self) -> tf.keras.Model:
        """
        Instantiate CRNN with correct num_class and hyperparams.
        Sanity check by running one batch and printing shapes.
        """
        num_class = int(self.cfg.get("num_class",
            self.cfg.get("num_classes", len(Synth90kDataset.CHARS) + 1)))
        print(f"[Build] num_classes={num_class}")
        model = CRNN(
            img_channel=1,
            img_height=self.img_h,
            img_width=self.img_w,
            num_class=num_class,
            map_to_seq_hidden=int(self.cfg.get("map_to_seq_hidden", 64)),
            rnn_hidden=int(self.cfg.get("rnn_hidden", 256)),
            leaky_relu=bool(self.cfg.get("leaky_relu", False)),
        )

        # eager sanity pass on a single batch
        sample_images, _, _ = next(iter(self.train_ds))
        out = model(sample_images)
        print(f"[Sanity] input: {sample_images.shape} -> logits: {out.shape}")

        # model summaries
        model.build(input_shape=(None, self.img_h, self.img_w, 1))
        model.summary()
        if hasattr(model, "cnn"):
            print("\n[Backbone] CNN")
            model.cnn.summary()

        # load weights if specified
        if self.reload_checkpoint:
            weights_path = self.reload_checkpoint.replace(".pt", ".h5")
            try:
                model.load_weights(weights_path)
                print(f"[Init] Loaded weights from {weights_path}")
            except Exception as e:
                print(f"[Init] Failed to load weights from {weights_path}: {e}")

        return model

    def build_optimizer(self) -> tf.keras.optimizers.Optimizer:
        """
        Choose and configure optimizer.
        """
        lr = float(self.cfg["lr"])
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
        return optimizer

    def build_loss(self):
        """
        Return CTCLoss() instance or a thin wrapper.
        """
        criterion = CTCLoss()
        return criterion

    # ----------------- HELPER OVERRIDES -----------------

    def compute_loss(self, logits: tf.Tensor, batch) -> tf.Tensor:
        """
        Implement CTC loss call.

        Inputs:
          logits: float32 [B, T, C] (batch-major)
          batch: (images, targets, target_lengths)
            targets: int32 [B, U] with -1 or padding tokens (depends on dataset)
            target_lengths: int32 [B]

        CTCLoss in this project expects batch-major logits [B,T,C].
        """
        _, targets, _ = batch
        loss = self.criterion(y_true=targets, y_pred=logits)
        return loss

    # ----------------- TRAIN / EVAL -----------------
    @tf.function
    def train_step(self, batch) -> tf.Tensor:
        """
        TODO:
        - run forward pass (training=True)
        - compute loss
        - backprop: grads = tape.gradient(...)
        - clip gradients
        - optimizer.apply_gradients(...)
        - return loss
        """
        images, _, _ = batch
        with tf.GradientTape() as tape:
            # Model returns time-major (T,B,C); transpose to batch-major (B,T,C)
            logits_time_major = self.model(images, training=True)
            logits = tf.transpose(logits_time_major, [1, 0, 2])
            loss = self.compute_loss(logits, batch)

        grads = tape.gradient(loss, self.model.trainable_variables)
        clipped_grads, _ = tf.clip_by_global_norm(grads, 5.0)
        self.optimizer.apply_gradients(zip(clipped_grads, self.model.trainable_variables))
        return loss

    def evaluate_model(self) -> Dict[str, Any]:
        """
        Use provided evaluate() helper OR write a small loop that:
          - runs forward on valid_ds
          - accumulates CTCLoss
          - optionally decodes for accuracy
        """
        stats = evaluate(
            self.model,
            self.valid_ds,
            self.criterion,
            max_iter=self.cfg["valid_max_iter"],
            decode_method=self.cfg["decode_method"],
            beam_size=self.cfg["beam_size"],
        )
        # Keep logs compact
        pstats = {k: (float(v) if isinstance(v, (int, float)) else v) for k, v in stats.items()}
        print(f"[Eval @ step {self.global_step}] {pstats}")
        return stats

In [65]:
trainer = CRNNTrainer(train_config)
trainer.prepare()
trainer.restore_from_cfg()
trainer.fit(auto_prepare=False)

[Build] num_classes=63
[Sanity] input: (32, 32, 100, 1) -> logits: (24, 32, 63)



[Backbone] CNN


[Checkpoint] Restored from latest: checkpoints/ckpt-5000
[Checkpoint] Resuming from global_step: 5000

===== Epoch 1 =====


Evaluate:  75%|███████▌  | 15/20 [06:03<02:01, 24.21s/it]


[Step 5000] batch loss: 7.1063 | running avg loss: 7.1063


Evaluate: 100%|██████████| 20/20 [01:21<00:00,  4.10s/it]


[Eval @ step 5000] {'loss': 5.172192630767822, 'acc': 0.5255, 'wrong_cases': [('INTERROGATOR', 'INTERROCATOY'), ('Multi', 'Miuuler'), ('pissoir', 'Pissoir'), ('25TH', 'PSE'), ('disproofs', 'diswroafs'), ('Mcdaniel', 'Medantel'), ('plowing', 'Plowing'), ('THEREOF', 'THENEOF'), ('Stamen', 'Stamer'), ('PRESUMPTION', 'PAESUMPTHON'), ('lassoed', 'Lassoed'), ('tableware', 'TANLLWARY'), ('BERATED', 'BERFTED'), ('BOOTED', 'BooTeD'), ('smock', 'SMOCK'), ('harmonically', 'harmonieally'), ('passed', 'zussed'), ('JACOBS', 'VACOBS'), ('UPBRAIDED', 'UPERAIDED'), ('kerfuffles', 'Lerfuffles'), ('Hybridism', 'Myiriass'), ('ambassadorway', 'uesudowy'), ('Relinquish', 'Relinguish'), ('UNHURT', 'MTILY'), ('SEMIPRIVATE', 'SEmIPRIWATE'), ('gazillions', 'quiillions'), ('VIBRATES', 'VERATES'), ('PORTRAYING', 'PORTRAVING'), ('legalism', 'tigatiene'), ('Gussying', 'Gussviny'), ('Thrombosis', 'Thrombosts'), ('fillet', 'fiiler'), ('Internalization', 'Intermaization'), ('playoffs', 'Playoffs'), ('Ebert', 'Fbert'),

Evaluate: 100%|██████████| 20/20 [01:21<00:00,  4.10s/it]


[Eval @ step 5500] {'loss': 4.864372613525391, 'acc': 0.5656, 'wrong_cases': [('INTERROGATOR', 'INTERROCATOP'), ('Multi', 'Niuler'), ('25TH', 'DSTEL'), ('Mcdaniel', 'Modantel'), ('plowing', 'Plowing'), ('THEREOF', 'THENEOE'), ('Palpable', 'Pelpable'), ('PRESUMPTION', 'PRESUNPTHON'), ('lassoed', 'Lassoed'), ('Vazquez', 'Vazquer'), ('tableware', 'TARLSVArY'), ('BERATED', 'BERFITED'), ('smock', 'SMOCK'), ('harmonically', 'harmonially'), ('UPBRAIDED', 'UPERAIDED'), ('kerfuffles', 'lerfrffles'), ('Hybridism', 'MyDriaiSD'), ('ambassadorway', 'ULRaLovEY'), ('UNHURT', 'UENINY'), ('SEMIPRIVATE', 'SEIMIPRIVATE'), ('gazillions', 'padillions'), ('VIBRATES', 'VERATES'), ('TANTALIZERS', 'INNITALIZERS'), ('napoleonic', 'mapoleonic'), ('legalism', 'tegaliene'), ('Gussying', 'Guassviny'), ('Thrombosis', 'Thrombosts'), ('fillet', 'fiiler'), ('Internalization', 'Interalitation'), ('Penuche', 'Denuche'), ('Ebert', 'Fbert'), ('DISMISSIVE', 'DiSMTISSTVE'), ('Millrace', 'WAOUrALe'), ('pasteurizers', 'pacteur

Evaluate: 100%|██████████| 20/20 [01:11<00:00,  3.56s/it]


[Eval @ step 6000] {'loss': 4.67182710647583, 'acc': 0.5669, 'wrong_cases': [('INTERROGATOR', 'INTERROGAION'), ('Multi', 'Muktt'), ('25TH', 'BSTE'), ('Mcdaniel', 'Modaniel'), ('plowing', 'Plowing'), ('THEREOF', 'THEREOT'), ('tableware', 'TABLSYARS'), ('BOOTED', 'BOOTeD'), ('smock', 'SMOCK'), ('harmonically', 'harmonieally'), ('passed', 'zassed'), ('JACOBS', 'JacoBs'), ('kerfuffles', 'ferfuffles'), ('Hybridism', 'Hybpidiss'), ('ambassadorway', 'absneudorg'), ('Relinquish', 'Relinguish'), ('UNHURT', 'LNLY'), ('gazillions', 'gagillions'), ('VIBRATES', 'VBRATES'), ('TANTALIZERS', 'TAMDALIZERS'), ('legalism', 'tagahisns'), ('Gussying', 'Gusszing'), ('ISUZU', 'ISUZ'), ('MISSTEPS', 'MISSTepS'), ('Thrombosis', 'Thrombosls'), ('fillet', 'filler'), ('Internalization', 'Internaliation'), ('Penuche', 'Benuche'), ('DISMISSIVE', 'DISISSIYE'), ('Millrace', 'Hsbirbbs'), ('wand', 'ccey'), ('redeems', 'redeens'), ('Subcontracting', 'Sobemtraching'), ('FURROW', 'Cuno'), ('Metamorphoses', 'Metamorphoss'),

Evaluate: 100%|██████████| 20/20 [01:05<00:00,  3.29s/it]


[Eval @ step 6500] {'loss': 4.650373094940186, 'acc': 0.5722, 'wrong_cases': [('INTERROGATOR', 'INTERROGATON'), ('Multi', 'sukt'), ('25TH', 'aSIE'), ('Mcdaniel', 'Modanial'), ('plowing', 'Plowing'), ('THEREOF', 'THENEOF'), ('Stamen', 'Stamenn'), ('PRESUMPTION', 'FNESUMTTHON'), ('tableware', 'TABLsYArS'), ('Cobwebby', 'cobwebby'), ('BOOTED', 'BooTed'), ('smock', 'SMOCK'), ('thinness', 'tninness'), ('passed', 'passeL'), ('kerfuffles', 'Lerfuffles'), ('Dvd', 'Ivd'), ('Hybridism', 'Mybriaism'), ('ambassadorway', 'cnlenudorer'), ('Relinquish', 'Relinguish'), ('UNHURT', 'WENUAS'), ('SEMIPRIVATE', 'SEMUIPRIVATE'), ('gazillions', 'pusillions'), ('VIBRATES', 'VBRATIES'), ('TANTALIZERS', 'TAMTALIZERS'), ('legalism', 'lagalism'), ('fillet', 'filler'), ('Internalization', 'Intermelization'), ('Ebert', 'Fbert'), ('Millrace', 'Msdirabe'), ('pasteurizers', 'pastousizere'), ('wand', 'aued'), ('almshouse', 'slmshouse'), ('redeems', 'redeenas'), ('Subcontracting', 'Suleantrachng'), ('FURROW', 'RUnnous')

Evaluate: 100%|██████████| 20/20 [01:04<00:00,  3.24s/it]


[Eval @ step 7000] {'loss': 4.637065321350097, 'acc': 0.5672, 'wrong_cases': [('INTERROGATOR', 'INTERROGATON'), ('Multi', 'Niuket'), ('25TH', 'OSH'), ('disproofs', 'Aisproofs'), ('Mcdaniel', 'Modantal'), ('plowing', 'Plowing'), ('THEREOF', 'THEREDE'), ('PRESUMPTION', 'FRESUMPTHON'), ('tableware', 'TARLenAre'), ('BERATED', 'BERAITED'), ('Cobwebby', 'cobwvebby'), ('BOOTED', 'BooTeD'), ('smock', 'SMOCK'), ('thinness', 'Ehtnness'), ('harmonically', 'harmonielly'), ('passed', 'eassed'), ('JACOBS', 'JAcOBS'), ('kerfuffles', 'kerfrffles'), ('Hybridism', 'Hybridiess'), ('ambassadorway', 'clereudorey'), ('UNHURT', 'RENRAE'), ('gazillions', 'gagillions'), ('VIBRATES', 'VBRATES'), ('CANTILEVERS', 'GANTILEVERS'), ('legalism', 'lagaliens'), ('Gussying', 'Gussving'), ('Thrombosis', 'Thrombosts'), ('fillet', 'fEilet'), ('Internalization', 'Intermalization'), ('DISMISSIVE', 'DISMISSIME'), ('Millrace', 'MsdIrALs'), ('pasteurizers', 'pastcusizers'), ('wand', 'wed'), ('Subcontracting', 'Sobermtrachng'), 

Evaluate: 100%|██████████| 20/20 [01:34<00:00,  4.72s/it]


[Eval @ step 7500] {'loss': 4.689970115661621, 'acc': 0.5615, 'wrong_cases': [('INTERROGATOR', 'INTERROCATON'), ('Multi', 'Niult'), ('25TH', '33B'), ('Mcdaniel', 'Moclantel'), ('plowing', 'Plowing'), ('THEREOF', 'THEREOY'), ('PRESUMPTION', 'PRCSUMPTION'), ('persian', 'Persian'), ('tableware', 'TAaLeWArY'), ('BERATED', 'BERAITED'), ('BOOTED', 'BOoTeD'), ('smock', 'SMOCK'), ('harmonically', 'harmonieally'), ('passed', 'zussed'), ('UPBRAIDED', 'UPERAIDED'), ('kerfuffles', 'ferfuffles'), ('Hybridism', 'Mybridlem'), ('ambassadorway', 'cnadind'), ('UNHURT', 'WRILL'), ('SEMIPRIVATE', 'SEmIPRIWATE'), ('gazillions', 'gagieltens'), ('VIBRATES', 'VERATES'), ('legalism', 'tigaliem'), ('Gussying', 'Gussving'), ('ISUZU', 'ISUzU'), ('Thrombosis', 'Thrombosts'), ('fillet', 'filles'), ('Ebert', 'Fbert'), ('DISMISSIVE', 'DISMTSSTVE'), ('Millrace', 'MAOrALS'), ('pasteurizers', 'peeteurleers'), ('wand', 'aved'), ('ruthenium', 'ruthentum'), ('Subcontracting', 'Sutentrachng'), ('FURROW', 'Curow'), ('Metamor

Evaluate: 100%|██████████| 20/20 [01:17<00:00,  3.89s/it]


[Eval @ step 8000] {'loss': 4.314776892089844, 'acc': 0.612, 'wrong_cases': [('INTERROGATOR', 'INTERROGATON'), ('Multi', 'Miultr'), ('25TH', 'DSTA'), ('Mcdaniel', 'Modanial'), ('plowing', 'Plowing'), ('THEREOF', 'THEREOR'), ('Stamen', 'Stamet'), ('PRESUMPTION', 'PRCSUMPTION'), ('tableware', 'TARLENAre'), ('Cobwebby', 'cobwebby'), ('smock', 'SMOCK'), ('passed', 'wussed'), ('kerfuffles', 'kerfrffles'), ('Hybridism', 'Myiriaism'), ('ambassadorway', 'cndenedower'), ('UNHURT', 'MINUIT'), ('gazillions', 'pajillions'), ('VIBRATES', 'VIBRATIES'), ('PORTRAYING', 'PORTRAVING'), ('legalism', 'tegalisne'), ('Gussying', 'Gusswing'), ('Thrombosis', 'Thrombosts'), ('fillet', 'filler'), ('Ebert', 'Fbert'), ('DISMISSIVE', 'DISMTSSIVE'), ('Millrace', 'MsdIribe'), ('pasteurizers', 'pacteurizers'), ('wand', 'cvnd'), ('ruthenium', 'ruthentum'), ('almshouse', 'aimshouse'), ('redeems', 'redeens'), ('Subcontracting', 'Sulerntracting'), ('FURROW', 'GURROD'), ('Metamorphoses', 'Metamoroioses'), ('PALMETTOS', 'C

Evaluate: 100%|██████████| 20/20 [01:03<00:00,  3.19s/it]


[Eval @ step 8500] {'loss': 4.159563464355469, 'acc': 0.6137, 'wrong_cases': [('INTERROGATOR', 'INTERROGAIOR'), ('Multi', 'Multt'), ('25TH', 'DSIH'), ('Mcdaniel', 'Modaniel'), ('plowing', 'Plowing'), ('THEREOF', 'THEREO'), ('Stamen', 'Stamer'), ('lassoed', 'Lassoed'), ('tableware', 'TABLEMARE'), ('BOOTED', 'BOOTeD'), ('smock', 'SMOCK'), ('harmonically', 'harmonieally'), ('kerfuffles', 'lerfuffles'), ('Hybridism', 'Hybriaism'), ('ambassadorway', 'cemedorey'), ('UNHURT', '4IN44S'), ('SEMIPRIVATE', 'SEMUPRIATE'), ('gazillions', 'gagillions'), ('VIBRATES', 'VERATES'), ('PORTRAYING', 'PORTRIVING'), ('Dulling', 'Oulling'), ('CANTILEVERS', 'GANTILEVERS'), ('legalism', 'Lagalisn'), ('ISUZU', 'IsUZU'), ('fillet', 'filles'), ('Penuche', 'Jenuche'), ('DISMISSIVE', 'DISMUSSIVE'), ('Millrace', 'Hsdbrils'), ('pasteurizers', 'pastourlzers'), ('PREDILECTION', 'PREDLECTION'), ('wand', 'ced'), ('Subcontracting', 'Soserntracting'), ('FURROW', 'Rumnos'), ('Metamorphoses', 'Metamorptoses'), ('PALMETTOS', '

Evaluate: 100%|██████████| 20/20 [01:04<00:00,  3.24s/it]


[Eval @ step 9000] {'loss': 4.12561001739502, 'acc': 0.6149, 'wrong_cases': [('Multi', 'Miulr'), ('25TH', 'DSIE'), ('Mcdaniel', 'Modanied'), ('plowing', 'Plowing'), ('tableware', 'TABLEVARE'), ('BOOTED', 'BOoTED'), ('smock', 'SMOCK'), ('kerfuffles', 'lerfuffles'), ('Dvd', 'bvd'), ('Hybridism', 'Hsbriaism'), ('ambassadorway', 'conudorer'), ('Relinquish', 'Relinguish'), ('UNHURT', 'UINUAI'), ('gazillions', 'jagillions'), ('VIBRATES', 'VBRATES'), ('Ineradicable', 'Incradicable'), ('legalism', 'lagalism'), ('fillet', 'filler'), ('Ebert', 'Fbert'), ('DISMISSIVE', 'DISMTISSIVE'), ('Millrace', 'Msiimale'), ('pasteurizers', 'pactourizers'), ('wand', 'cool'), ('almshouse', 'aimshouse'), ('Subcontracting', 'Suseamtrachng'), ('FURROW', 'Pumbos'), ('Metamorphoses', 'Metamoroloses'), ('PALMETTOS', 'LACMEMIOS'), ('RETIREES', 'RRIETURBES'), ('LANDWARD', 'HANDWaRD'), ('DVISORY', 'DUISORY'), ('Amphibian', 'Amphibion'), ('RETRIEVER', 'RRETRIEVER'), ('Khoisan', 'Knoisan'), ('Cupcake', 'Gupcake'), ('Mahar

Evaluate: 100%|██████████| 20/20 [01:02<00:00,  3.15s/it]


[Eval @ step 9500] {'loss': 4.067565302276611, 'acc': 0.6141, 'wrong_cases': [('INTERROGATOR', 'INTERROGATOHR'), ('Multi', 'Muker'), ('25TH', '1SH'), ('disproofs', 'disproafs'), ('Mcdaniel', 'Medianiel'), ('plowing', 'Plowing'), ('THEREOF', 'THEREOE'), ('PRESUMPTION', 'FRESUMPTIOH'), ('tableware', 'TARLEvARE'), ('Cobwebby', 'cobwebby'), ('smock', 'SMOCK'), ('Properly', 'Praperly'), ('passed', 'rassed'), ('JACOBS', 'Jacobs'), ('kerfuffles', 'kerfnffles'), ('Hybridism', 'Hybridisins'), ('ambassadorway', 'cnenedowey'), ('Relinquish', 'Relinguish'), ('UNHURT', 'HINIAY'), ('gazillions', 'gagillions'), ('VIBRATES', 'VBRATES'), ('CANTILEVERS', 'GANTILEVERS'), ('legalism', 'legalisn'), ('fillet', 'filler'), ('Millrace', 'MsdIrAls'), ('wand', 'Cced'), ('Subcontracting', 'Suderntraching'), ('FURROW', 'Putrows'), ('Metamorphoses', 'Metamorpioses'), ('PALMETTOS', 'ACMETBOS'), ('RETIREES', 'RETURBES'), ('LANDWARD', 'IANDWARd'), ('DVISORY', 'DUISORY'), ('Amphibian', 'Anchibien'), ('Khoisan', 'Knoisa

In [64]:
new_config = train_config.copy()
new_config['checkpoints_dir'] = 'checkpoints/'
new_trainer = CRNNTrainer(new_config)
new_trainer.prepare()
new_trainer.restore_from_cfg()


[Build] num_classes=63
[Sanity] input: (32, 32, 100, 1) -> logits: (24, 32, 63)



[Backbone] CNN


[Checkpoint] Restored from latest: checkpoints/ckpt-5000
[Checkpoint] Resuming from global_step: 5000


True

# Evaluation on test set

In [66]:
crnn_test = new_trainer.model

test_dataset = Synth90kDataset(root_dir=config['data_dir'], mode=DatasetMode.TEST,
                                    img_height=config["img_height"], img_width=config["img_width"])

test_ds = test_dataset.create_tf_dataset(batch_size=evaluate_config["eval_batch_size"], training=False)

criterion = CTCLoss()

prediction = evaluate(
                      crnn_test,
                      test_ds,
                      criterion,
                      train_config['valid_max_iter'],
                      evaluate_config['decode_method'],
                      evaluate_config['beam_size'])


Evaluate: 100%|██████████| 20/20 [01:06<00:00,  3.32s/it]


In [67]:
prediction['wrong_cases'][:10]

[('DECRIES', 'DELRIES'),
 ('Logier', 'agie'),
 ('modeled', 'modaled'),
 ('Faintly', 'Tuinaly'),
 ('UNIMPOSING', 'UNMPOSING'),
 ('aliens', 'ALIENS'),
 ('Exponents', 'Expunents'),
 ('Mollycoddle', 'Mallycoddle'),
 ('dave', 'davee'),
 ('Pettifog', 'Pettifos')]

In [68]:
prediction['acc']

0.5173

In [72]:
prediction['loss']

5.171909332275391

# Clustering

## Logits extraction on test set

In [79]:
def extract_test_logits(model, test_ds):
    all_Z = []
    all_y = []

    for batch in test_ds:
        # batch may be (images) or (images, targets, target_lengths)
        images = batch[0] if isinstance(batch, (tuple, list)) else batch

        # Model outputs time-major: (T, B, C)
        logits_TBC = model(images, training=False)
        logits_BTC = tf.transpose(logits_TBC, [1, 0, 2])  # (B, T, C)

        # Greedy labels per step from logits (no softmax)
        greedy_B_T = tf.argmax(logits_BTC, axis=-1).numpy()  #(B, T)

        # Flatten across batch/time
        Z_bt = logits_BTC.numpy().reshape(-1, logits_BTC.shape[-1])  # (B*T, C)
        y_bt = greedy_B_T.reshape(-1)  # (B*T,)

        all_Z.append(Z_bt)
        all_y.append(y_bt)

    Z_flat = np.vstack(all_Z) #(M, C)
    y_flat = np.concatenate(all_y)# (M,)
    return Z_flat, y_flat

# run
Z_flat, y_flat = extract_test_logits(trainer.model, test_ds)

## Standardization and Clustering on Logits.

In [78]:
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans

def standardize_and_cluster(Z_flat, K=63, sample_size=50000, seed=42):
    scaler = StandardScaler()
    Z_std = scaler.fit_transform(Z_flat)

    if Z_std.shape[0] > sample_size:
        idx = np.random.default_rng(seed).choice(Z_std.shape[0], sample_size, replace=False)
        Z_used = Z_std[idx]
        y_used = y_flat[idx]
    else:
        Z_used = Z_std
        y_used = y_flat

    km = KMeans(n_clusters=K, n_init=10, random_state=seed)
    cl = km.fit_predict(Z_used)
    return Z_used, y_used, cl, km

Z_used, y_used, cluster_labels, km = standardize_and_cluster(Z_flat, K=63)

## Silhouette and Purity on Test Logits

In [77]:
from sklearn.metrics import silhouette_score

def silhouette_and_purity(Z_used, cluster_labels, y_used, K=63):
    sil = silhouette_score(Z_used, cluster_labels)

    purity_sum = 0
    M = len(y_used)
    for k in range(K):
        mask = (cluster_labels == k)
        if not np.any(mask):
            continue
        vals, counts = np.unique(y_used[mask], return_counts=True)
        purity_sum += counts.max()
    purity = purity_sum / M
    return sil, purity

sil, purity = silhouette_and_purity(Z_used, cluster_labels, y_used, K=63)
print(f"Silhouette: {sil:.4f}, Purity: {purity:.4f}")

Silhouette: 0.1377, Purity: 0.8262
