# SECTION 1: INSTALL PACKAGES AND IMPORTS

## Installations

In [None]:
# Install the Arabic text-shaping helpers (python-bidi, arabic-reshaper) and the
# Levenshtein package; all three are imported in the imports cell of this notebook.
!pip install python-bidi arabic-reshaper
!pip install python-Levenshtein

## Importing Packages

In [None]:
import os
import cv2
import re
import math
import time
import numpy as np
import random
import matplotlib.pyplot as plt
from IPython.display import display
from tqdm.auto import tqdm
from PIL import Image
from threading import local
from typing import Optional, Tuple, Dict
from __future__ import annotations



# Deep learning imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CTCLoss
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
import torchvision.transforms as T
# from torch.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torchinfo import summary


import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import train_test_split
import multiprocessing as mp

import editdistance
import Levenshtein



# Import libraries for handling Arabic text
from bidi.algorithm import get_display
import arabic_reshaper

# SECTION 2: CONFIGURATION DICTIONARY

In [None]:
# Central configuration dictionary read by the rest of the notebook
# (dataset location, vocabulary, image size, model, training, augmentation,
# checkpointing and DataLoader settings).
config = {
    # Dataset parameters
    'on_colab': False,   # True -> run the Colab download/Drive logic; False -> use the Kaggle input path
    'use_drive': False, # only consulted when 'on_colab' is True: load the dataset from Google Drive
    # 'dataset_dir': '/content/Arabic_English_OCR_Dataset', # Colab - OLD DATASET
    # 'dataset_dir': '/kaggle/input/arabic-english-ocr-dataset/Arabic_English_OCR_Dataset', # Kaggle - OLD DATASET
    # 'dataset_dir': '/content/output', # Colab - NEW DATASET
    'dataset_dir': '/kaggle/input/arabic-english-ocr-synthatic-dataset-v2/output', # Kaggle - NEW DATASET
    'ar_dir': 'ar',
    'en_dir': 'en',
    'mixed_dir': 'mixed', # New directory for mixed language examples
    'labels_file': 'labels.txt',
    'max_samples': 100000,  # Total desired samples; split equally across ar / en / mixed (max_samples // 3 each)
    'max_text_length': 32,              # Maximum text sequence length
    # Characters a label may contain: ASCII punctuation/letters/digits, Arabic
    # punctuation, Arabic letters and diacritics, Arabic-Indic digits, curly quotes.
    'permissible_chars': set(
                    " !\"#$%&'()*+,-./:;=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]_`abcdefghijklmnopqrstuvwxyz{|}،؛؟٫٬٭"
                    "0123456789ءآأؤإئابةتثجحخدذرزسشصضطظعغفقكلمنهوىيًٌٍَُِّْ٠١٢٣٤٥٦٧٨٩"
                    "‘’“”"
                ),
    'vocab_size': 160,                      # Placeholder; overwritten after dataset analysis / TextProcessor creation
    'train_split_percentage': 0.95,         # Percentage for training set 0.95
    'val_split_percentage': 0.025,          # Percentage for validation set 0.025
    'test_split_percentage': 0.025,         # Percentage for test set 0.025
    'remove_bad_examples': True,            # True: drop examples containing non-permissible chars; False: strip those chars from the text
    'do_analysis': True,                   # To do Analysis on the dataset (text --> lengths & chars, images --> shape)

    # Image parameters - SVTR standard
    'img_height': 64,  # SVTR uses 64 height
    'img_width': 256,  # SVTR uses 256 width, (mean width of the dataset 512) (256,384,512)
    'channels': 3,

    # Model Parameters - SVTR Large
    'embed_dims': [128, 256, 384],
    'd3': 512,
    'heads': [4, 8, 12],                # heads chosen such that embed_dim / num_heads == 32 (nice head dim)
    'mlp_ratio': 2,
    'dropout_rate': 0.1,
    'n_points': 9,
    'offset_scale': 4.0,
    # ## SMALL ##
    'num_blocks': [3, 6, 3],
    'pattern': ['L'] * 6 + ['G'] * 6,
    'local_type': ['non_overlapping', 'non_overlapping', 'deformable'] * 2 + ['conv'] * 6 ,
    'window_sizes': [(7,11)] * 6 + [(3,3)] * 6,
    # ## LARGE ##
    # 'num_blocks': [3, 12, 3],
    # 'pattern': ['L'] * 9 + ['G'] * 9,
    # 'local_type': ['non_overlapping', 'non_overlapping', 'deformable'] * 2 + ['deformable'] * 3 + ['conv'] * 9 ,
    # 'window_sizes': [(7,11)] * 9 + [(3,3)] * 9,


    # Training parameters
    'num_epochs': 28,         
    'early_stopping_patience': 5,
    'learning_rate': 3e-4,
    'weight_decay': 1e-5,
    # TODO: implement warmup
    'warmup_epochs': 5,
    'gradient_clip': 1.0,

    # Augmentation parameters
    'augmentation_strength': 0.8, #(0 to 1)
    'aug_prob': 0.7,
    'rotation_limit': 5,
    'blur_limit': 3,
    'brightness_limit': 0.2,
    'contrast_limit': 0.2,
    # 'dataset_mean': [0.485, 0.456, 0.406],       #--> ImageNet Values
    # 'dataset_std': [0.229, 0.224, 0.225],        #--> ImageNet Values
    # 'dataset_mean': [0.615, 0.617, 0.616],       #--> ar,en dataset Values
    # 'dataset_std': [0.271, 0.276, 0.273],        #--> ar,en dataset Values
    'dataset_mean': [0.683, 0.691, 0.694],        #--> ar,en,mixed dataset Values 
    'dataset_std': [0.320, 0.309, 0.301],         #--> ar,en,mixed dataset


    # Other parameters
    'save_path_directory': './arabic_ocr_checkpoints',
    'load_model_path': "/kaggle/input/svtr_deformable_epoch_26/pytorch/default/1/best_SVTR_deformable_model.pth", # Old model path
    # 'load_model_path': None, # Set to None or update with a new path if you have one
    'exclude_head_on_load': False,
    'beam_size': 4,  # For beam search during inference
    'SEED': 42,
    'best_model_filename': 'best_SVTR_deformable_model.pth', # Added: Filename for the best model checkpoint
    'regular_checkpoint_frequency': 7,                       # Added: Frequency (in epochs) to save regular checkpoints


    # DataLoader parameters
    'dataloader_params': {
        'batch_size': 32, # (kaggle (LARGE 16, SMALL 32), colab (LARGE 4, SMALL 8)) Increased batch size for visualization
        'num_workers': 4 if torch.cuda.is_available() else 0,      # works on kaggle, not colab with gpu
        'pin_memory': torch.cuda.is_available(),
        'persistent_workers': torch.cuda.is_available()
    },
    'device': 'cuda' if torch.cuda.is_available() else 'cpu'
}

# SEED
# Set random seeds for reproducibility
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs with
    `seed`, and switch cuDNN to deterministic mode with benchmarking disabled."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False


# set_seed(config['SEED'])

# Set the start method to 'spawn' or 'forkserver', for num_workers error
# if torch.cuda.is_available():
#     mp.set_start_method('spawn', force=True)
# else:
#     mp.set_start_method('forkserver', force=True)

# Report the selected device and, when CUDA is available, the GPU name and count
print(f"Using device: {config['device']}")
if torch.cuda.is_available():
    _gpu_name = torch.cuda.get_device_name()
    _gpu_count = torch.cuda.device_count()
    print(f"GPU: {_gpu_name}")
    print(f"Available GPUs: {_gpu_count}")

# SECTION 3: DOWNLOAD AND ANALYZE DATASET

## Download

In [None]:
# Resolve where the dataset lives and update config['dataset_dir'] accordingly:
#   - Colab, dataset dir missing, Drive not requested -> download from Kaggle
#   - Colab with 'use_drive'                          -> mount Google Drive
#   - Colab, dataset dir already present              -> nothing to do
#   - not Colab                                       -> use the Kaggle input path
if config['on_colab']:
    print("====================== On Colab Environment ======================")
    # Check if the dataset directory already exists
    # NOTE(review): this tests config['dataset_dir'] as set in the config cell; the
    # value is only switched to '/content/output' at the end of this branch, so a
    # fresh run of the notebook may re-download — confirm this is intended.
    if not os.path.exists(config['dataset_dir']) and not config['use_drive']:
        from google.colab import files
        files.upload()  # Choose the kaggle.json file you downloaded

        print("Download Dataset From Kaggle: ")
        # Create the Kaggle credentials directory and copy the uploaded token into it
        !mkdir -p ~/.kaggle
        !cp kaggle.json ~/.kaggle/
        !chmod 600 ~/.kaggle/kaggle.json

        # Download the dataset archive
        # !kaggle datasets download -d ahmedkamal75/arabic-english-ocr-dataset
        !kaggle datasets download -d ahmedkamal75/arabic-english-ocr-synthatic-dataset-v2

        # Unzip the dataset (Kaggle will provide it as a zip)
        # !unzip arabic-english-ocr-dataset.zip
        !unzip arabic-english-ocr-synthatic-dataset-v2.zip

        # Remove the archive and the credentials file from the working directory
        # !rm -rf arabic-english-ocr-dataset.zip
        !rm -rf arabic-english-ocr-synthatic-dataset-v2.zip
        !rm -rf kaggle.json

        # config['dataset_dir'] = '/content/Arabic_English_OCR_Dataset'
        config['dataset_dir'] = '/content/output'

    elif config['use_drive']:
        print("load Dataset From Drive: ")
        from google.colab import drive
        drive.mount('/content/drive')

        # config['dataset_dir'] = '/content/drive/MyDrive/Arabic_English_OCR_Dataset'
        config['dataset_dir'] = '/content/drive/MyDrive/output'
    else:
        print("Dataset already exists, skipping download.")
else:
    print("====================== Load From Local ======================")
    # NOTE(review): this overwrites whatever 'dataset_dir' was set to in the config
    # cell with the hard-coded Kaggle input path.
    # config['dataset_dir'] = '/kaggle/input/arabic-english-ocr-dataset/Arabic_English_OCR_Dataset'
    config['dataset_dir'] = '/kaggle/input/arabic-english-ocr-synthatic-dataset-v2/output'

## Load Text (labels) & Paths of Images & Visualize Samples Before Preprocessing

In [None]:
ar_dir_path = os.path.join(config['dataset_dir'], config['ar_dir'])
en_dir_path = os.path.join(config['dataset_dir'], config['en_dir'])
mixed_dir_path = os.path.join(config['dataset_dir'], config['mixed_dir'])


def _read_label_file(dir_path, labels_filename, description):
    """Read one `<image_name> <text>` labels file.

    Args:
        dir_path: directory holding both the labels file and the images.
        labels_filename: name of the labels file inside `dir_path`.
        description: human-readable split name used in the "not found" warning.

    Returns:
        (labels_file_path, pairs) where `pairs` is a list of
        (full_image_path, text) tuples; `pairs` is empty when the file is missing.
    """
    labels_file_path = os.path.join(dir_path, labels_filename)
    if not os.path.exists(labels_file_path):
        print(f"Warning: {description} labels file not found at {labels_file_path}")
        return labels_file_path, []

    pairs = []
    with open(labels_file_path, 'r', encoding='utf8') as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                # Blank lines carry no label; skip them without reporting them as malformed
                continue
            try:
                # Split only on the first space: the text itself may contain spaces
                image_name, text = stripped.split(' ', 1)
            except ValueError:
                print(f"Skipping malformed line in {labels_file_path}: {stripped}")
                continue
            pairs.append((os.path.join(dir_path, image_name), text))
    return labels_file_path, pairs


ar_labels_file_path, ar_labels = _read_label_file(ar_dir_path, config['labels_file'], "Arabic")
en_labels_file_path, en_labels = _read_label_file(en_dir_path, config['labels_file'], "English")
mixed_labels_file_path, mixed_labels = _read_label_file(mixed_dir_path, config['labels_file'], "Mixed")


print("After loading: ")
print(f"Total number of examples in ar: {len(ar_labels)}")
print(f"Total number of examples in en: {len(en_labels)}")
print(f"Total number of examples in mixed: {len(mixed_labels)}")


# Filter labels to only include permissible characters
def filter_lebels(text):
    """Return `text` with every character that is not in
    config['permissible_chars'] removed (order of kept characters preserved)."""
    allowed = config['permissible_chars']
    kept = [ch for ch in text if ch in allowed]
    return ''.join(kept)

def is_permissible_ar(text):
    """Return True when every character of `text` is in
    config['permissible_chars'] (an empty text is accepted)."""
    allowed = config['permissible_chars']
    for ch in text:
        if ch not in allowed:
            return False
    return True


def is_permissible_en(text):
    """Return True when every character of `text` is permissible and none of
    them is an Arabic letter/diacritic or an Arabic-Indic digit."""
    # Characters that disqualify a label from being treated as English
    arabic_digits = set("٠١٢٣٤٥٦٧٨٩")
    arabic_chars = set("ءآأؤإئابةتثجحخدذرزسشصضطظعغفقكلمنهوىيًٌٍَُِّْ")
    allowed = config['permissible_chars']

    for ch in text:
        if ch not in allowed:
            return False
    for ch in text:
        if ch in arabic_digits or ch in arabic_chars:
            return False
    return True

def is_permissible_mixed(text):
    """Return True when `text` passes the Arabic check or, failing that, the
    English check."""
    if is_permissible_ar(text):
        return True
    return is_permissible_en(text)

# Either strip non-permissible characters from every label, or (when
# 'remove_bad_examples' is set) drop the offending examples altogether.
if not config['remove_bad_examples']:
    ar_labels = [(path, filter_lebels(label)) for path, label in ar_labels]
    en_labels = [(path, filter_lebels(label)) for path, label in en_labels]
    mixed_labels = [(path, filter_lebels(label)) for path, label in mixed_labels]
else:
    ar_labels = [(path, label) for path, label in ar_labels if is_permissible_ar(label)]
    en_labels = [(path, label) for path, label in en_labels if is_permissible_en(label)]
    mixed_labels = [(path, label) for path, label in mixed_labels if is_permissible_mixed(label)]


print("After filtering: ")
print(f"Total number of examples in ar: {len(ar_labels)}")
print(f"Total number of examples in en: {len(en_labels)}")
print(f"Total number of examples in mixed: {len(mixed_labels)}")


# Cap each category at an equal third of max_samples; slicing past the end of a
# shorter list simply keeps the whole list.
samples_per_category = config['max_samples'] // 3

ar_labels = ar_labels[:samples_per_category]
en_labels = en_labels[:samples_per_category]
mixed_labels = mixed_labels[:samples_per_category]

print("After taking the choosen sample size: ")
print(f"Total number of examples in ar: {len(ar_labels)}")
print(f"Total number of examples in en: {len(en_labels)}")
print(f"Total number of examples in mixed: {len(mixed_labels)}")


# Visualize a sample image and its label
def visualize(labels, lang='en', n=4):
    """Show up to `n` samples from `labels` on a 4-column matplotlib grid.

    Args:
        labels: list of (image_path, text) pairs.
        lang: 'ar', 'en' or 'mixed'; selects the printed banner and, for
            'ar'/'mixed', enables reshaping + bidi reordering of labels that
            contain Arabic letters so they render correctly in the title.
        n: maximum number of samples to draw.
    """
    if not labels:
        print(f"No examples available for {lang} to visualize.")
        return

    shown = min(n, len(labels))
    subset = labels[:shown]

    # Banner text per language/type
    banners = {
        'ar': "Sample Arabic Images",
        'en': "Sample English Images",
        'mixed': "Sample Mixed Images",
    }
    fig_title = banners.get(lang, "Sample Images")

    # Fixed 4-column grid; as many rows as needed to fit the samples
    n_cols = 4
    n_rows = (shown + n_cols - 1) // n_cols

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, n_rows * 5))
    axes = axes.ravel() if isinstance(axes, np.ndarray) else [axes]

    arabic_letters = set("ءآأؤإئابةتثجحخدذرزسشصضطظعغفقكلمنهوىيًٌٍَُِّْ")
    may_contain_arabic = lang in ('ar', 'mixed')

    print(f"====================== {fig_title} ======================")
    for pos, (path, label) in enumerate(subset):
        ax = axes[pos]
        try:
            bgr = cv2.imread(path)
            if bgr is None:
                print(f"Error loading image: {path}")
                ax.axis('off')
                continue

            # matplotlib expects RGB, cv2 delivers BGR
            rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

            # Only labels that actually contain Arabic letters are reshaped and
            # reordered; everything else is shown as-is
            if may_contain_arabic and any(ch in arabic_letters for ch in label):
                caption = get_display(arabic_reshaper.reshape(label))
            else:
                caption = label

            ax.imshow(rgb)
            ax.set_title(f"Label: {caption}", fontsize=10)
            ax.axis('off')
        except Exception as e:
            print(f"An error occurred while loading or displaying the image: {e}")
            ax.axis('off')

    # Blank out the grid cells that received no sample
    for leftover in axes[shown:]:
        leftover.axis('off')

    plt.tight_layout()
    plt.show()


# Preview a few raw samples from each language split
for _banner, _split_labels, _split_lang in (
    ("====================== Arabic ======================", ar_labels, 'ar'),
    ("====================== English ======================", en_labels, 'en'),
    ("====================== Mixed ======================", mixed_labels, 'mixed'),
):
    print(_banner)
    visualize(_split_labels, lang=_split_lang)

## Analyze

In [None]:
# 1. Unified character set
if config['do_analysis']:
    # Union of every character that appears in any label of the three splits
    all_chars = set()
    for label_group in (ar_labels, en_labels, mixed_labels):
        for _, text in label_group:
            all_chars.update(text)

    unified_charset = sorted(all_chars)
    # Provisional value: it is overwritten when the TextProcessor is instantiated
    config['vocab_size'] = len(unified_charset)

    print(f"Unified Character Set (Length): {len(unified_charset)}")
    # Prints the full character set on one line
    print("Unified Character Set (Characters):", "".join(unified_charset))

In [None]:
def _print_text_length_stats(header, lengths):
    """Print the max / mean / min of `lengths` under `header`.

    All three statistics print as 'N/A' when `lengths` is empty, so an empty
    split no longer yields a 'nan' mean together with a NumPy RuntimeWarning.
    """
    if lengths:
        longest = max(lengths)
        average = f"{np.mean(lengths):.2f}"
        shortest = min(lengths)
    else:
        longest = average = shortest = 'N/A'

    print(header)
    print(f"  Max: {longest}")
    print(f"  Mean: {average}")
    print(f"  Min: {shortest}")


if config['do_analysis']:
    # Analyze filtered dataset - Text length statistics (characters per label)
    ar_text_lengths_filtered = [len(text) for _, text in ar_labels]
    en_text_lengths_filtered = [len(text) for _, text in en_labels]
    mixed_text_lengths_filtered = [len(text) for _, text in mixed_labels]

    _print_text_length_stats("\nFiltered Arabic Text Lengths:", ar_text_lengths_filtered)
    _print_text_length_stats("\nFiltered English Text Lengths:", en_text_lengths_filtered)
    _print_text_length_stats("\nFiltered Mixed Text Lengths:", mixed_text_lengths_filtered)

In [None]:
def _collect_image_shapes(labels):
    """Return the `.shape` of every image listed in `labels` that cv2 can read.

    `labels` is a list of (image_path, text) pairs. Images for which
    cv2.imread returns None are skipped; an exception while reading is
    reported and the image is skipped.
    """
    shapes = []
    for image_path, _ in labels:
        try:
            img = cv2.imread(image_path)
            if img is not None:
                shapes.append(img.shape)
        except Exception as e:
            print(f"Error loading image for shape analysis: {image_path} - {e}")
    return shapes


def _shape_stat(values, reducer, spec=''):
    """Format `reducer(values)` with format spec `spec`, or 'N/A' when `values`
    is empty (avoids a 'nan' mean and RuntimeWarning on empty splits)."""
    return format(reducer(values), spec) if values else 'N/A'


if config['do_analysis']:
    # Analyze filtered dataset - Image shape statistics
    ar_image_shapes_filtered = _collect_image_shapes(ar_labels)
    en_image_shapes_filtered = _collect_image_shapes(en_labels)
    mixed_image_shapes_filtered = _collect_image_shapes(mixed_labels)

    # shape[0] is the image height and shape[1] the width
    ar_heights_filtered = [shape[0] for shape in ar_image_shapes_filtered]
    ar_widths_filtered = [shape[1] for shape in ar_image_shapes_filtered]
    en_heights_filtered = [shape[0] for shape in en_image_shapes_filtered]
    en_widths_filtered = [shape[1] for shape in en_image_shapes_filtered]
    mixed_heights_filtered = [shape[0] for shape in mixed_image_shapes_filtered]
    mixed_widths_filtered = [shape[1] for shape in mixed_image_shapes_filtered]

    print("\nFiltered Arabic Image Shapes (Height, Width):")
    print(f"  Max Height: {_shape_stat(ar_heights_filtered, max)}, Max Width: {_shape_stat(ar_widths_filtered, max)}")
    print(f"  Mean Height: {_shape_stat(ar_heights_filtered, np.mean, '.2f')} , Mean Width: {_shape_stat(ar_widths_filtered, np.mean, '.2f')}")
    print(f"  Min Height: {_shape_stat(ar_heights_filtered, min)}, Min Width: {_shape_stat(ar_widths_filtered, min)}")

    print("\nFiltered English Image Shapes (Height, Width):")
    print(f"  Max Height: {_shape_stat(en_heights_filtered, max)}, Max Width: {_shape_stat(en_widths_filtered, max)}")
    print(f"  Mean Height: {_shape_stat(en_heights_filtered, np.mean, '.2f')}, Mean Width: {_shape_stat(en_widths_filtered, np.mean, '.2f')}")
    print(f"  Min Height: {_shape_stat(en_heights_filtered, min)}, Min Width: {_shape_stat(en_widths_filtered, min)}")

    print("\nFiltered Mixed Image Shapes (Height, Width):")
    print(f"  Max Height: {_shape_stat(mixed_heights_filtered, max)}, Max Width: {_shape_stat(mixed_widths_filtered, max)}")
    print(f"  Mean Height: {_shape_stat(mixed_heights_filtered, np.mean, '.2f')}, Mean Width: {_shape_stat(mixed_widths_filtered, np.mean, '.2f')}")
    print(f"  Min Height: {_shape_stat(mixed_heights_filtered, min)}, Min Width: {_shape_stat(mixed_widths_filtered, min)}")

# SECTION 4: TEXT PREPROCESSING AND VOCABULARY CREATION


In [None]:
class TextProcessor:
    """Character-level vocabulary with fixed-length encoding and decoding.

    The vocabulary is five special tokens (PAD, START, END, UNK, BLANK, in that
    order, indices 0-4) followed by the sorted, de-duplicated characters of
    `unified_charset`.

    Args:
        unified_charset: iterable of single characters (string, list or set).
        max_length: length every encoded sequence is padded/truncated to.
    """

    def __init__(self, unified_charset, max_length=64):
        self.max_length = max_length

        # Special tokens
        self.PAD_TOKEN = '<PAD>'
        self.START_TOKEN = '<START>'
        self.END_TOKEN = '<END>'
        self.UNK_TOKEN = '<UNK>'
        self.BLANK_TOKEN = '<BLANK>'

        # De-duplicate before sorting: a repeated character would otherwise get
        # two indices, inflating vocab_size and making char2idx / idx2char disagree.
        charset = sorted(set(unified_charset))
        vocab = [self.PAD_TOKEN, self.START_TOKEN, self.END_TOKEN, self.UNK_TOKEN, self.BLANK_TOKEN] + charset

        # Create mappings
        self.char2idx = {char: idx for idx, char in enumerate(vocab)}
        self.idx2char = {idx: char for idx, char in enumerate(vocab)}
        self.vocab_size = len(vocab)

        print(f"Vocabulary size: {self.vocab_size}")
        # Print the full character list and the special-token indices for verification
        print(f"Characters: {charset}")
        print(f"Special tokens and their indices: PAD={self.char2idx[self.PAD_TOKEN]}, START={self.char2idx[self.START_TOKEN]}, END={self.char2idx[self.END_TOKEN]}, UNK={self.char2idx[self.UNK_TOKEN]}, BLANK={self.char2idx[self.BLANK_TOKEN]}")

    def __len__(self):
        """Return the vocabulary size (special tokens included)."""
        return self.vocab_size

    def encode_text(self, text):
        """Convert `text` to a list of exactly `max_length` indices.

        Non-string input is converted with str(). Characters missing from the
        vocabulary map to UNK; longer texts are truncated and shorter ones are
        right-padded with PAD.
        """
        if not isinstance(text, str):
            text = str(text)

        unk_idx = self.char2idx[self.UNK_TOKEN]
        pad_idx = self.char2idx[self.PAD_TOKEN]

        sequence = [self.char2idx.get(char, unk_idx) for char in text[:self.max_length]]
        sequence.extend([pad_idx] * (self.max_length - len(sequence)))
        return sequence

    def decode_sequence(self, sequence):
        """Convert a sequence of indices back to text.

        PAD, START, END, UNK and BLANK tokens, and indices outside the
        vocabulary, are dropped. Integer-like elements that are not Python ints
        (e.g. the 0-d tensors obtained when iterating over a tensor) are
        converted with int() first; they would otherwise miss the dict lookup
        and be silently discarded.
        """
        special_tokens = {self.PAD_TOKEN, self.START_TOKEN, self.END_TOKEN, self.UNK_TOKEN, self.BLANK_TOKEN}
        chars = []
        for idx in sequence:
            if not isinstance(idx, int):
                try:
                    idx = int(idx)
                except (TypeError, ValueError, RuntimeError):
                    # Not convertible to an index: treat like an unknown entry
                    continue
            char = self.idx2char.get(idx, self.UNK_TOKEN)
            if char not in special_tokens:
                chars.append(char)
        return ''.join(chars)

In [None]:
# Build the character set from the labels and instantiate the TextProcessor
# (computed here independently of the optional analysis cell)
all_chars_for_processor = set()
for label_group in (ar_labels, en_labels, mixed_labels):
    for _, text in label_group:
        all_chars_for_processor.update(text)

unified_charset_for_processor = sorted(all_chars_for_processor)

text_processor = TextProcessor(
    unified_charset_for_processor,
    max_length=config['max_text_length'],
)
# The model's output size must match the processor's vocabulary (special tokens included)
config['vocab_size'] = text_processor.vocab_size

# SECTION 5: DATASET CLASS AND DATALOADER

## Dataset Class

In [None]:
class OCRDataset(Dataset):
    """Dataset of (image_path, text) pairs for OCR training.

    Args:
        labels: list of (image_path, text) tuples.
        text_processor: object exposing `encode_text(text)`; its `max_length`
            attribute, when present, bounds the reported text length.
        transform: optional callable invoked as `transform(image=image)` and
            returning a dict with an 'image' entry.
    """

    def __init__(self, labels, text_processor, transform=None):
        self.labels = labels
        self.text_processor = text_processor
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return a dict with 'image', 'text', 'encoded_text' and 'text_length',
        or None when the image cannot be read (callers must filter None out)."""
        image_path, text = self.labels[idx]

        # Process image
        image = cv2.imread(image_path)
        if image is None:
            print(f"Warning: Could not load image {image_path}. Skipping.")
            return None

        # Safeguards for 4-channel / single-channel arrays.
        # NOTE(review): cv2.imread with default flags is documented to return a
        # 3-channel BGR image, so these branches are presumably never taken, and
        # no BGR->RGB conversion happens here — confirm that the normalization
        # statistics were computed in the same channel order.
        if len(image.shape) == 3 and image.shape[2] == 4:  # RGBA
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
        elif len(image.shape) == 2:  # Grayscale
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

        # Apply transforms
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']

        # Process text
        encoded_text = self.text_processor.encode_text(text)

        # encode_text truncates to max_length, so the reported length must be
        # clamped as well; otherwise it can exceed the encoded target's length.
        max_len = getattr(self.text_processor, 'max_length', None)
        text_length = len(text) if max_len is None else min(len(text), max_len)

        return {
            'image': image,
            'text': text,
            'encoded_text': torch.tensor(encoded_text, dtype=torch.long),
            'text_length': text_length
        }

## Augmentations

In [None]:
def get_train_transforms(dataset_mean, dataset_std, augmentation_strength=0.25):
    """
    Returns augmentation transforms for training based on augmentation_strength (0 to 1).
    augmentation_strength = 0 -> no augmentation (resize + normalize only, like validation)
    augmentation_strength = 1 -> maximum configured augmentation

    Values outside [0, 1] are clamped to that range so that every probability
    derived from the strength stays valid.

    Args:
        dataset_mean: per-channel mean passed to A.Normalize.
        dataset_std: per-channel std passed to A.Normalize.
        augmentation_strength: linear scale applied to probabilities and limits.
    """
    augmentation_strength = min(1.0, max(0.0, float(augmentation_strength)))

    resize = A.Resize(config['img_height'], config['img_width'])
    finalize = [A.Normalize(mean=dataset_mean, std=dataset_std), ToTensorV2()]

    # Strength 0 means "no augmentation": skip building transforms whose scaled
    # parameters would be degenerate (blur limit 0, elastic sigma 0, ...).
    if augmentation_strength == 0:
        return A.Compose([resize, *finalize])

    # Define maximum probabilities and limits for augmentation_strength = 1.0
    max_aug_prob = 0.7
    max_rotation_limit = 10
    max_blur_limit = 5
    max_brightness_limit = 0.3
    max_contrast_limit = 0.3
    max_distortion_limit = 0.5
    max_elastic_alpha = 10
    max_elastic_sigma = 50

    # Scale probabilities and limits linearly based on the augmentation strength
    aug_prob = max_aug_prob * augmentation_strength
    rotation_limit = max_rotation_limit * augmentation_strength
    brightness_limit = max_brightness_limit * augmentation_strength
    contrast_limit = max_contrast_limit * augmentation_strength
    distortion_limit = max_distortion_limit * augmentation_strength
    elastic_alpha = max_elastic_alpha * augmentation_strength
    elastic_sigma = max_elastic_sigma * augmentation_strength

    # Blur kernels must be odd and at least 3: `| 1` bumps an even value to the
    # next odd one (e.g. 4 -> 5, never above the odd max_blur_limit) and max()
    # enforces the lower bound.
    blur_limit = max(3, int(max_blur_limit * augmentation_strength) | 1)

    return A.Compose([
        resize,
        A.OneOf([
            A.GaussianBlur(blur_limit=blur_limit, p=0.5 * augmentation_strength),
            A.MotionBlur(blur_limit=blur_limit, p=0.5 * augmentation_strength),
        ], p=aug_prob),
        A.RandomBrightnessContrast(
            brightness_limit=brightness_limit,
            contrast_limit=contrast_limit,
            p=aug_prob
        ),
        # Use linear interpolation and reflect border mode for geometric transformations
        A.Rotate(limit=rotation_limit, p=aug_prob * 0.6, interpolation=cv2.INTER_LINEAR, border_mode=cv2.BORDER_REFLECT_101),
        A.OneOf([
            A.GridDistortion(num_steps=5, distort_limit=distortion_limit, p=0.5 * augmentation_strength, interpolation=cv2.INTER_LINEAR, border_mode=cv2.BORDER_REFLECT_101),
            A.ElasticTransform(alpha=elastic_alpha, sigma=elastic_sigma, p=0.8 * augmentation_strength, interpolation=cv2.INTER_LINEAR, border_mode=cv2.BORDER_REFLECT_101),
        ], p=aug_prob),
        *finalize,
    ])

def get_val_transforms(dataset_mean, dataset_std, augmentation_strength=0.0):
    """
    Build the deterministic validation pipeline: resize to the configured image
    size, normalize with the given statistics, convert to a tensor.

    `augmentation_strength` is accepted only for signature parity with the
    training factory; it is not used.
    """
    target_height = config['img_height']
    target_width = config['img_width']

    steps = [
        A.Resize(target_height, target_width),
        A.Normalize(mean=dataset_mean, std=dataset_std),
        ToTensorV2(),
    ]
    return A.Compose(steps)

## Get Mean & STD for Aug

In [None]:
def get_mean_std(labels):
    """Compute per-channel mean and std (scaled to [0, 1]) over all images.

    Every image is read once; the mean and variance are derived from the
    running per-channel sum and sum of squares (var = E[x^2] - E[x]^2),
    accumulated in float64.

    Args:
        labels: iterable of (image_path, text) pairs; only the paths are used.

    Returns:
        (mean, std): two lists of three floats, in the channel order that
        cv2.imread delivers (no channel reordering is applied here). Both are
        all zeros when no image could be read.
    """
    channel_sum = np.zeros(3, dtype=np.float64)
    channel_sq_sum = np.zeros(3, dtype=np.float64)
    total_pixels = 0

    for image_path, _ in tqdm(labels, desc="Summing pixels"):
        try:
            image = cv2.imread(image_path)
            if image is None:
                print(f"Warning: Could not load image {image_path}. Skipping for mean/std calculation.")
                continue

            # Safeguards for 4-channel / single-channel arrays
            if len(image.shape) == 3 and image.shape[2] == 4:  # RGBA
                image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
            elif len(image.shape) == 2:  # Grayscale
                image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

            # float64 keeps the running sums accurate
            image = image.astype(np.float64)

            channel_sum += image.sum(axis=(0, 1))
            channel_sq_sum += np.square(image).sum(axis=(0, 1))
            total_pixels += image.shape[0] * image.shape[1]
        except Exception as e:
            print(f"Error processing image {image_path} for mean/std calculation: {e}")

    if total_pixels > 0:
        mean_pixels = channel_sum / total_pixels
        # Clamp tiny negative values caused by floating-point cancellation
        variance_pixels = np.maximum(channel_sq_sum / total_pixels - mean_pixels ** 2, 0.0)
    else:
        mean_pixels = np.zeros(3)
        variance_pixels = np.zeros(3)

    std_pixels = np.sqrt(variance_pixels)

    # Rescale from the 0-255 pixel range to [0, 1]
    mean_pixels = mean_pixels / 255.0
    std_pixels = std_pixels / 255.0

    return mean_pixels.tolist(), std_pixels.tolist()

## Dataloader & Collate_fn

In [None]:
# Combine Arabic, English, and Mixed labels
all_labels = ar_labels + en_labels + mixed_labels

# Shuffle the combined dataset with a dedicated RNG seeded from the config.
# An unseeded shuffle here would change the input order on every run and make
# the splits below differ between runs despite their fixed random_state.
random.Random(config['SEED']).shuffle(all_labels)


# Calculate sizes for train, val, and test
total_size = len(all_labels)
# Use the split percentages from the config
train_size = int(config['train_split_percentage'] * total_size)
val_size = int(config['val_split_percentage'] * total_size)
test_size = total_size - train_size - val_size # Ensure all samples are included


# First split: train and temp (val + test)
train_labels, temp_labels = train_test_split(
    all_labels,
    train_size=train_size,
    random_state=config['SEED'] # Use a fixed random state for reproducibility
)

# Second split: val and test from temp (temp holds exactly val_size + test_size samples)
if val_size == 0 or test_size == 0:
    # train_test_split rejects an empty side, so hand the whole remainder to
    # whichever of the two splits is non-empty.
    val_labels, test_labels = (temp_labels, []) if test_size == 0 else ([], temp_labels)
else:
    val_labels, test_labels = train_test_split(
        temp_labels,
        train_size=val_size, # Absolute count: avoids off-by-one from float rounding of a fraction
        random_state=config['SEED'] # Use a fixed random state for reproducibility
    )


print(f"Total number of samples: {len(all_labels)}")
print(f"Training set size: {len(train_labels)}")
print(f"Validation set size: {len(val_labels)}")
print(f"Test set size: {len(test_labels)}")

def collate_fn(batch):
    """
    Merge a list of dataset samples into one batch dictionary.

    Each sample is expected to be either ``None`` (dropped) or a mapping with the
    keys ``'image'``, ``'text'``, ``'encoded_text'`` and ``'text_length'``.

    Returns:
        ``None`` when no usable sample remains, otherwise a dict with
        ``'images'`` (stacked image tensors), ``'texts'`` (list of raw strings),
        ``'encoded_texts'`` (stacked encoded label tensors) and
        ``'text_lengths'`` (1-D long tensor built from each sample's ``'text_length'``).

    Raises:
        RuntimeError: propagated from ``torch.stack`` when tensors in the batch
        differ in size; a hint is printed first for the size-mismatch case.
    """
    samples = [sample for sample in batch if sample is not None]
    if len(samples) == 0:
        return None

    size_mismatch_marker = "stack expects each tensor to be of the same size"

    # Images are stacked directly, so every sample must already share one shape.
    try:
        images = torch.stack([sample['image'] for sample in samples])
    except RuntimeError as e:
        if size_mismatch_marker in str(e):
            print("Error: Images in the batch have inconsistent sizes after transformations.")
            print("This might be due to incorrect resizing or padding in the dataset or transforms.")
        raise

    texts = [sample['text'] for sample in samples]

    # Encoded labels are stacked as-is, so they must all have the same length.
    try:
        encoded_texts = torch.stack([sample['encoded_text'] for sample in samples])
    except RuntimeError as e:
        if size_mismatch_marker in str(e):
            print("Error: Encoded texts in the batch have inconsistent sizes after TextProcessor.")
            print("This should not happen if max_text_length is applied correctly in encode_text.")
        raise

    # Lengths come from the samples themselves, not from the stacked tensor width.
    text_lengths = torch.tensor(
        [sample['text_length'] for sample in samples],
        dtype=torch.long,
    )

    return dict(
        images=images,
        texts=texts,
        encoded_texts=encoded_texts,
        text_lengths=text_lengths,
    )


# Resolve the normalization statistics: reuse the values stored in the config when
# both are present, otherwise compute them from the full label list and store them back.
stats_are_configured = config['dataset_mean'] is not None and config['dataset_std'] is not None
if stats_are_configured:
    dataset_mean = config['dataset_mean']
    dataset_std = config['dataset_std']
    print(f"\nUsing pre-configured Dataset Mean: {dataset_mean}")
    print(f"Using pre-configured Dataset Std: {dataset_std}")
else:
    print("\nCalculating dataset mean and standard deviation for normalization...")
    dataset_mean, dataset_std = get_mean_std(all_labels)
    config['dataset_mean'] = dataset_mean
    config['dataset_std'] = dataset_std
    print("Calculation complete.")
    print(f"Dataset Mean: {dataset_mean}")
    print(f"Dataset Std: {dataset_std}")


# One dataset per split. Only the training split receives the configured
# augmentation strength; validation and test are built with strength 0.
train_dataset = OCRDataset(
    train_labels,
    text_processor,
    transform=get_train_transforms(
        config['dataset_mean'], config['dataset_std'], config['augmentation_strength']
    ),
)
val_dataset = OCRDataset(
    val_labels,
    text_processor,
    transform=get_val_transforms(config['dataset_mean'], config['dataset_std'], 0.0),
)
test_dataset = OCRDataset(
    test_labels,
    text_processor,
    transform=get_val_transforms(config['dataset_mean'], config['dataset_std'], 0.0),
)

# Keyword arguments shared by all three loaders.
dataloader_params = config['dataloader_params']


def _make_loaders(loader_kwargs):
    """Build the (train, val, test) DataLoaders with shared kwargs; only train is shuffled."""
    return tuple(
        DataLoader(split_dataset, shuffle=do_shuffle, collate_fn=collate_fn, **loader_kwargs)
        for split_dataset, do_shuffle in (
            (train_dataset, True),
            (val_dataset, False),
            (test_dataset, False),
        )
    )


# Build the loaders; when construction fails with a worker-related RuntimeError,
# retry once in single-process mode.
try:
    train_loader, val_loader, test_loader = _make_loaders(dataloader_params)
except RuntimeError as e:
    is_worker_failure = "DataLoader worker (pid(s)" in str(e) and "KeyboardInterrupt" not in str(e)
    if not is_worker_failure:
        # Anything else is not ours to handle.
        raise
    print("\nError: DataLoader failed, possibly due to num_workers issues in this environment.")
    print("Trying again with num_workers=0...")
    # Single-process copy of the parameters; worker-only flags are switched off if present.
    dataloader_params_no_workers = dict(config['dataloader_params'])
    dataloader_params_no_workers['num_workers'] = 0
    for worker_only_flag in ('pin_memory', 'persistent_workers'):
        if worker_only_flag in dataloader_params_no_workers:
            dataloader_params_no_workers[worker_only_flag] = False

    train_loader, val_loader, test_loader = _make_loaders(dataloader_params_no_workers)
    print("DataLoaders created with num_workers=0.")


print(f"Train batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")

# SECTION 6: VISUALIZE SAMPLES FROM DATALOADER

In [None]:
def visualize_batch(dataloader, text_processor, num_samples=8, cols=4):
    """
    Plot the first samples of one batch together with their ground-truth text.

    Args:
        dataloader: iterable yielding batch dicts with ``'images'`` (N, C, H, W)
            and ``'texts'`` (list of str); ``None`` batches are skipped.
        text_processor: not used by this function; kept so existing calls keep working.
        num_samples: maximum number of samples to draw.
        cols: number of columns in the subplot grid (default 4, as before).

    The images are de-normalized with the dataset mean/std. These are looked up in
    the notebook globals first, then in ``config``, and finally fall back to the
    ImageNet statistics with a warning.
    """
    # The collate function may return None for a batch with no loadable sample,
    # so take the first batch that actually holds data.
    batch = next((candidate for candidate in dataloader if candidate is not None), None)
    if batch is None:
        print("Warning: no valid batch available for visualization.")
        return

    # Never try to show more samples than the batch contains.
    actual_batch_size = batch['images'].size(0)
    num_samples_to_show = min(num_samples, actual_batch_size)
    if num_samples_to_show <= 0:
        print("Warning: nothing to visualize (empty batch or num_samples <= 0).")
        return

    images = batch['images'][:num_samples_to_show]
    texts = batch['texts'][:num_samples_to_show]

    # Grid with enough rows to hold every requested sample.
    rows = (num_samples_to_show + cols - 1) // cols

    # squeeze=False always returns a 2-D array of axes, so ravel() is safe for any grid.
    fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, rows * 5), squeeze=False)
    axes = axes.ravel()

    # Look up the normalization statistics; a stored None counts as "missing".
    mean_val = globals().get('dataset_mean')
    std_val = globals().get('dataset_std')
    if mean_val is None or std_val is None:
        mean_val = config.get('dataset_mean')
        std_val = config.get('dataset_std')
    if mean_val is None or std_val is None:
        print("Warning: Dataset mean and std not found. Using default ImageNet values for visualization denormalization.")
        mean_val = [0.485, 0.456, 0.406]
        std_val = [0.229, 0.224, 0.225]
    mean_val = np.array(mean_val)
    std_val = np.array(std_val)

    for i in range(num_samples_to_show):
        # CHW tensor -> HWC array, then undo the normalization and clamp to [0, 1].
        image = images[i].permute(1, 2, 0).cpu().numpy()
        image = image * std_val + mean_val
        image = np.clip(image, 0, 1)

        # Reshape Arabic glyphs and apply the bidi algorithm so titles render correctly.
        reshaped_text = arabic_reshaper.reshape(texts[i])
        display_text = get_display(reshaped_text)

        axes[i].imshow(image)
        axes[i].set_title(f"Text: {display_text}", fontsize=12)
        axes[i].axis('off')

    # Hide the grid cells that received no image.
    for j in range(num_samples_to_show, len(axes)):
        axes[j].axis('off')

    plt.tight_layout()
    plt.show()
    # Release the figure explicitly so repeated calls do not accumulate open figures.
    plt.close(fig)

print("Visualizing training samples...")
visualize_batch(train_loader, text_processor, num_samples=8)

# SECTION 7: SVTR MODEL ARCHITECTURE

## Utilities: Window Partitioning / Reverse (B[N]CHW)

In [None]:
def window_partition_nchw(x, window_size):
    """
    Split an NCHW tensor into non-overlapping windows.

    The input is zero-padded at the bottom and right so that H and W become
    multiples of the window height and width.

    Args:
      x: (B, C, H, W)
      window_size: (wh, ww)
    Returns:
      windows: (B * n_h * n_w, wh * ww, C). Windows are ordered by image, then
        window row, then window column; tokens inside a window are row-major.
      info: (Hp, Wp, pad_h, pad_w, n_h, n_w) - padded size, applied padding and
        window counts, as needed by window_reverse_nchw.
    """
    batch, channels, height, width = x.shape
    wh, ww = window_size

    # Window counts are the ceiling of size / window; the padding fills the gap.
    n_h = -(-height // wh)
    n_w = -(-width // ww)
    Hp, Wp = n_h * wh, n_w * ww
    pad_h, pad_w = Hp - height, Wp - width

    if pad_h > 0 or pad_w > 0:
        # F.pad pairs run from the last dimension backwards: (W_left, W_right, H_top, H_bottom).
        x = F.pad(x, (0, pad_w, 0, pad_h))

    # (B, C, Hp, Wp) -> (B, C, n_h, wh, n_w, ww) -> (B, n_h, n_w, wh, ww, C)
    tiles = x.view(batch, channels, n_h, wh, n_w, ww).permute(0, 2, 4, 3, 5, 1).contiguous()

    # Merge batch with the window grid, and the two in-window axes into one token axis.
    windows = tiles.view(-1, wh * ww, channels)
    return windows, (Hp, Wp, pad_h, pad_w, n_h, n_w)

def window_reverse_nchw(windows, window_size, Hp, Wp, pad_h, pad_w, n_h, n_w, B):
    """
    Reassemble windows produced by window_partition_nchw into an NCHW tensor.

    Args:
      windows: (B * n_h * n_w, wh * ww, C) - flattened windows
      window_size: (wh, ww) - window height and width
      Hp, Wp: padded height and width
      pad_h, pad_w: padding that was added at the bottom and right
      n_h, n_w: number of windows along height and width
      B: batch size
    Returns:
      x: (B, C, Hp - pad_h, Wp - pad_w) - the tensor with the padding cut off
    """
    wh, ww = window_size
    channels = windows.shape[-1]

    # (B*n_h*n_w, wh*ww, C) -> (B, n_h, n_w, wh, ww, C) -> (B, C, n_h, wh, n_w, ww) -> (B, C, Hp, Wp)
    padded = (
        windows.view(B, n_h, n_w, wh, ww, channels)
        .permute(0, 5, 1, 3, 2, 4)
        .contiguous()
        .view(B, channels, Hp, Wp)
    )

    # With zero padding the slice bounds equal Hp / Wp, so this keeps everything.
    return padded[:, :, :Hp - pad_h, :Wp - pad_w].contiguous()

# Round-trip check: partition a random tensor into windows and reverse it again.
dummy_tensor_nchw = torch.randn(2, 64, 32, 32)  # example (B, C, H, W)
window_size_nchw = (7, 7)  # 32 is not a multiple of 7, so padding is exercised

print(f"Original NCHW tensor shape: {dummy_tensor_nchw.shape}")

# Forward direction: windows plus the bookkeeping needed for the reverse step.
windows_nchw, (Hp_nchw, Wp_nchw, pad_h_nchw, pad_w_nchw, n_h_nchw, n_w_nchw) = window_partition_nchw(
    dummy_tensor_nchw, window_size_nchw
)

print(f"Output windows shape after partition: {windows_nchw.shape}")
print(f"Padding and window info: Hp={Hp_nchw}, Wp={Wp_nchw}, pad_h={pad_h_nchw}, pad_w={pad_w_nchw}, n_h={n_h_nchw}, n_w={n_w_nchw}")

# Backward direction: rebuild the (B, C, H, W) tensor from the windows.
reversed_tensor_nchw = window_reverse_nchw(
    windows_nchw,
    window_size_nchw,
    Hp_nchw,
    Wp_nchw,
    pad_h_nchw,
    pad_w_nchw,
    n_h_nchw,
    n_w_nchw,
    dummy_tensor_nchw.shape[0],
)

print(f"Output NCHW tensor shape after reverse: {reversed_tensor_nchw.shape}")

# The input itself was never padded, so slicing to its own H and W keeps it unchanged.
_, _, height_nchw, width_nchw = dummy_tensor_nchw.shape
original_unpadded = dummy_tensor_nchw[:, :, :height_nchw, :width_nchw]

print(f"Shape of original unpadded tensor: {original_unpadded.shape}")
print(f"Shape of reversed tensor: {reversed_tensor_nchw.shape}")

# Compare values with a small tolerance and report the largest deviation.
are_tensors_close = torch.allclose(original_unpadded, reversed_tensor_nchw, atol=1e-6)
print(f"Are the original unpadded and reversed tensors close? {are_tensors_close}")

print("Difference:", (original_unpadded - reversed_tensor_nchw).abs().max())

## Basic modules

In [None]:
class PatchEmbedSVTR(nn.Module):
    """
    Overlapping patch embedding (NCHW input) similar to the SVTR paper.

    Two 3x3 stride-2 convolutions (each followed by BatchNorm and ReLU) downsample
    the image by roughly 4 in both directions; the feature map is then flattened
    into a token sequence and layer-normalized.

    Args:
        img_size: expected input size (H, W); any 2-element sequence is accepted.
        in_chans: number of input channels.
        embed_dim: token dimension D (the first convolution uses embed_dim // 2).

    Attributes:
        patches_resolution: (h, w) token grid produced for ``img_size``.
        num_patches: h * w.
    """

    def __init__(self, img_size=(64, 256), in_chans=3, embed_dim=64):
        super().__init__()
        H, W = img_size
        # Stored as a tuple so the size check in forward() also works when the
        # caller passes a list.
        self.img_size = (H, W)
        # two conv layers with stride 2 each -> downsample by (4,4)
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_chans, embed_dim // 2, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(embed_dim // 2),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(embed_dim),
            nn.ReLU(inplace=True)
        )
        # Token grid after both convolutions. Each layer maps n -> (n - 1) // 2 + 1,
        # which equals n // 2 only for even n; H // 4 would under-count for sizes
        # that are not multiples of 4.
        self.patches_resolution = (
            self._conv_out(self._conv_out(H)),
            self._conv_out(self._conv_out(W)),
        )
        self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
        self.norm = nn.LayerNorm(embed_dim)

    @staticmethod
    def _conv_out(size):
        """Output length of one kernel=3, stride=2, padding=1 convolution."""
        return (size - 1) // 2 + 1

    def forward(self, x):
        """
        Args:
            x: (B, C, H, W) with (H, W) equal to ``img_size``.
        Returns:
            (B, N, D) tokens with N == num_patches.
        """
        H, W = x.shape[-2:]
        assert (H, W) == self.img_size, f"Input size {(H,W)} != expected {self.img_size}"
        x = self.conv1(x)
        x = self.conv2(x)  # (B, embed_dim, h, w)
        x = x.flatten(2).transpose(1, 2)  # (B, h*w, D)
        x = self.norm(x)
        return x  # (B, N, D)

class MLP(nn.Module):
    """
    Two-layer feed-forward block: Linear -> GELU -> Linear -> Dropout.

    Args:
        in_dim: input and output feature size.
        ratio: hidden width as a multiple of ``in_dim`` (truncated to int).
        drop: dropout probability applied to the output of the second layer.
    """

    def __init__(self, in_dim, ratio=2.0, drop=0.0):
        super().__init__()
        hidden_dim = int(in_dim * ratio)
        self.fc1 = nn.Linear(in_features=in_dim, out_features=hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=in_dim)
        self.drop = nn.Dropout(p=drop)

    def forward(self, x):
        """Apply the block to the last dimension of ``x``; the shape is preserved."""
        hidden = self.act(self.fc1(x))
        return self.drop(self.fc2(hidden))

# Smoke test for PatchEmbedSVTR: feed one random image of the configured size.
patch_embed = PatchEmbedSVTR(
    img_size=(config['img_height'], config['img_width']),
    in_chans=config['channels'],
    embed_dim=64,
)
dummy_input_patch = torch.randn(1, config['channels'], config['img_height'], config['img_width'])
output_patch = patch_embed(dummy_input_patch)

print(f"PatchEmbedSVTR Input shape: {dummy_input_patch.shape}")
print(f"PatchEmbedSVTR Output shape: {output_patch.shape}")

# Smoke test for MLP: random input shaped like the patch-embedding output
# (B, num_patches, embed_dim); the MLP width is taken from the last dimension.
embed_dim_from_patches = output_patch.shape[-1]
dummy_input_mlp = torch.randn(output_patch.shape)
mlp = MLP(in_dim=embed_dim_from_patches)
output_mlp = mlp(dummy_input_mlp)

print(f"\nMLP Input shape: {dummy_input_mlp.shape}")
print(f"MLP Output shape: {output_mlp.shape}")

## Global mixing (Transformer-style)


In [None]:
class GlobalMixing(nn.Module):
    """
    Pre-norm Transformer block: self-attention over all tokens followed by an MLP,
    each wrapped in a residual connection.

    Args:
        dim: token dimension D.
        num_heads: number of attention heads.
        mlp_ratio: hidden width of the MLP as a multiple of ``dim``.
        drop: dropout probability used inside the attention and the MLP.
    """

    def __init__(self, dim, num_heads, mlp_ratio=2.0, drop=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(
            embed_dim=dim, num_heads=num_heads, dropout=drop, batch_first=True
        )
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, ratio=mlp_ratio, drop=drop)

    def forward(self, x):
        """
        Args:
            x: (B, N, D) token sequence.
        Returns:
            Tensor of the same shape (B, N, D).
        """
        # Attention sub-layer: query, key and value are the same normalized tokens.
        normed = self.norm1(x)
        attended, _ = self.attn(normed, normed, normed)
        x = x + attended

        # Feed-forward sub-layer (dropout is applied inside the MLP).
        return x + self.mlp(self.norm2(x))

# Smoke test for GlobalMixing with a random (B, N, D) input, the same layout the
# patch embedding produces: batch size, sequence length (num_patches), token dimension.
B, N, D = 2, 2048, 64
num_heads = 8  # example number of attention heads

dummy_input_global = torch.randn(B, N, D)
global_mixing = GlobalMixing(dim=D, num_heads=num_heads)
output_global = global_mixing(dummy_input_global)

print(f"GlobalMixing Input shape: {dummy_input_global.shape}")
print(f"GlobalMixing Output shape: {output_global.shape}")

# The block must be shape-preserving.
assert dummy_input_global.shape == output_global.shape, "Output shape does not match input shape!"
print("Test passed: Output shape matches input shape.")

## Local mixing using Non-overlapping windows + attention


### Non-overlapping local attention

In [None]:
# Non-Overlapping Local Window Attention mixing
class NOLMWA(nn.Module):
    """
    Local mixing block built on non-overlapping windows.

    The token grid is cut into windows, multi-head self-attention runs on every
    window independently (windows are folded into the batch dimension), and the
    windows are stitched back together. An MLP follows; both sub-layers are
    pre-normalized and residual.

    Args:
        dim: token dimension D.
        num_heads: number of attention heads.
        mlp_ratio: hidden width of the MLP as a multiple of ``dim``.
        window_size: (wh, ww) window height and width in tokens.
        drop: dropout probability used inside the attention and the MLP.

    NOTE(review): when h or w is not a multiple of the window size, the grid is
    zero-padded and the padded tokens take part in attention as ordinary keys;
    no padding mask is applied here.
    """

    def __init__(self, dim, num_heads, mlp_ratio=2.0, window_size=(7,11), drop=0.0):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.wh, self.ww = window_size
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(
            embed_dim=dim, num_heads=num_heads, dropout=drop, batch_first=True
        )
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, ratio=mlp_ratio, drop=drop)

    def forward(self, x, h, w, shift=(0,0)):
        """
        Args:
            x: (B, N, D) tokens with N == h * w.
            h, w: height and width of the token grid.
            shift: accepted but not used by this block.
        Returns:
            Tensor of the same shape (B, N, D).
        """
        B, N, D = x.shape
        assert N == h * w, "N must equal h*w"
        window = (self.wh, self.ww)

        # Normalize, then lay the tokens out as a (B, D, h, w) feature map.
        grid = self.norm1(x).transpose(1, 2).reshape(B, D, h, w)

        # Windows become extra batch entries: (B*n_h*n_w, wh*ww, D).
        windows, partition_info = window_partition_nchw(grid, window)
        mixed, _ = self.attn(windows, windows, windows)

        # partition_info is (Hp, Wp, pad_h, pad_w, n_h, n_w), in the order the reverse expects.
        grid = window_reverse_nchw(mixed, window, *partition_info, B)

        # Back to a token sequence, then the two residual connections.
        x = x + grid.view(B, D, h * w).transpose(1, 2).contiguous()
        return x + self.mlp(self.norm2(x))

# Smoke test for NOLMWA on a random token grid (sizes as after patch embedding).
h, w = 32, 64
dummy_input_local = torch.randn(2, h * w, 64)

output_local = NOLMWA(dim=64, num_heads=8, window_size=(7,11))(dummy_input_local, h, w)

print(f"LocalMixingWindow Input shape: {dummy_input_local.shape}")
print(f"LocalMixingWindow Output shape: {output_local.shape}")

# The block must be shape-preserving.
assert dummy_input_local.shape == output_local.shape, "Output shape does not match input shape!"
print("Test passed: Output shape matches input shape.")

### Shifted Window Local Mixing Window Attention

In [None]:
class SWLMWA(nn.Module):
    """
    Local windowed attention with optional cyclic shift and per-window attention mask.
    Implements manual multi-head attention so a (num_windows, ws, ws) mask can be applied.

    Args:
        dim: token dimension D (must be divisible by ``num_heads``).
        num_heads: number of attention heads.
        mlp_ratio: hidden width of the MLP as a multiple of ``dim``.
        window_size: (wh, ww) window height and width in tokens.
        drop: dropout probability used inside the MLP (attention has no dropout).
        shift_size: (shift_h, shift_w) cyclic shift applied before windowing;
            ``None`` is treated as (0, 0).

    Masking:
        Tokens inside one window may only attend to tokens of the same region.
        Regions separate (a) tokens that were wrapped around by the cyclic shift
        from those that were not, per axis, and (b) zero padding added to reach a
        multiple of the window size from real tokens. Without shift and without
        padding no mask is applied.
    """

    def __init__(self, dim, num_heads, mlp_ratio=2.0, window_size=(7,11), drop=0.0, shift_size=(0,0)):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.dim = dim
        self.num_heads = num_heads
        self.wh, self.ww = window_size
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.shift_size = shift_size if shift_size is not None else (0,0)

        # Layer norm before attention and QKV projection
        self.norm1 = nn.LayerNorm(dim)
        # Joint projection to query, key and value (3 * dim outputs)
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        # Output projection after attention
        self.proj = nn.Linear(dim, dim, bias=True)

        # Layer norm before MLP
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, ratio=mlp_ratio, drop=drop)

    @staticmethod
    def _axis_region_ids(length, padded_length, shift, device):
        """
        Label every position of one (padded) axis after a roll by ``-shift``:
        0 = not wrapped, 1 = wrapped around from the start of the axis, 2 = padding.
        """
        ids = torch.zeros(padded_length, dtype=torch.long, device=device)
        wrapped = shift % length if length > 0 else 0
        if wrapped:
            # After rolling by -shift, the last `wrapped` real positions hold
            # content that originally sat at the beginning of the axis.
            ids[length - wrapped:length] = 1
        ids[length:] = 2
        return ids

    @staticmethod
    def _build_attn_mask(h, w, Hp, Wp, window_size, shift_size, device):
        """
        Build the per-window attention mask.

        Args:
            h, w: token grid size before padding.
            Hp, Wp: token grid size after padding to a multiple of the window.
            window_size: (wh, ww).
            shift_size: (shift_h, shift_w) cyclic shift that was applied.
            device: device for the mask.
        Returns:
            ``None`` when there is neither an effective shift nor padding, otherwise a
            bool tensor (n_h * n_w, wh * ww, wh * ww) where True means "forbid attention".
        """
        wh, ww = window_size
        shift_h, shift_w = shift_size
        effective_shift = (h > 0 and shift_h % h != 0) or (w > 0 and shift_w % w != 0)
        if not effective_shift and Hp == h and Wp == w:
            return None

        row_ids = SWLMWA._axis_region_ids(h, Hp, shift_h, device)
        col_ids = SWLMWA._axis_region_ids(w, Wp, shift_w, device)
        # Combine the two axis labels into one region id per token: (Hp, Wp)
        labels = row_ids[:, None] * 3 + col_ids[None, :]

        # Partition the labels exactly like the features: windows ordered by
        # (window row, window column), tokens inside a window row-major.
        n_h, n_w = Hp // wh, Wp // ww
        labels = labels.reshape(n_h, wh, n_w, ww).permute(0, 2, 1, 3).reshape(n_h * n_w, wh * ww)

        # Token pairs from different regions must not attend to each other.
        return labels.unsqueeze(1) != labels.unsqueeze(2)

    def forward(self, x, h, w):
        """
        x: (B, N, D) with N == h*w
        h,w: spatial dims (tokens grid) before windowing
        Returns a tensor of the same shape (B, N, D).
        """
        B, N, D = x.shape
        assert N == h * w, f"N({N}) must equal h({h})*w({w})"
        res = x  # residual around the attention sub-layer

        # pre-norm
        x = self.norm1(x)

        # to NCHW for windowing: (B, D, h, w)
        x2 = x.transpose(1, 2).reshape(B, D, h, w)

        # Cyclic shift BEFORE partitioning (negative roll, matching Swin's order).
        shift_h, shift_w = self.shift_size
        shifted = shift_h != 0 or shift_w != 0
        if shifted:
            x2 = torch.roll(x2, shifts=(-shift_h, -shift_w), dims=(2, 3))

        # windows: (B * n_h * n_w, ws, D) with ws = wh * ww
        windows, (Hp, Wp, pad_h, pad_w, n_h, n_w) = window_partition_nchw(x2, (self.wh, self.ww))
        num_windows_total = windows.shape[0]
        ws = windows.shape[1]
        num_windows_per_image = n_h * n_w

        # Region mask shared by every image in the batch (None when nothing to mask).
        attn_mask = self._build_attn_mask(
            h, w, Hp, Wp, (self.wh, self.ww), self.shift_size, x.device
        )

        # Manual multi-head attention over the windows.
        q, k, v = self.qkv(windows).chunk(3, dim=-1)  # each (num_windows_total, ws, D)

        def reshape_to_heads(t):
            # (num_windows_total, ws, D) -> (num_windows_total, num_heads, ws, head_dim)
            t = t.view(num_windows_total, ws, self.num_heads, self.head_dim)
            return t.permute(0, 2, 1, 3).contiguous()

        q = reshape_to_heads(q)
        k = reshape_to_heads(k)
        v = reshape_to_heads(v)

        # scaled dot-product scores: (num_windows_total, num_heads, ws, ws)
        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        if attn_mask is not None:
            # Expose the per-image window axis so the mask broadcasts over batch and heads.
            attn = attn.reshape(B, num_windows_per_image, self.num_heads, ws, ws)
            # The most negative finite value is dtype-safe (also under half precision);
            # every query can still attend to itself, so no row is fully masked.
            attn = attn.masked_fill(attn_mask[None, :, None, :, :], torch.finfo(attn.dtype).min)
            attn = attn.reshape(num_windows_total, self.num_heads, ws, ws)

        # softmax (attention dropout is intentionally not applied)
        attn = torch.softmax(attn, dim=-1)

        # weighted sum of values, then merge heads -> (num_windows_total, ws, D)
        out = torch.matmul(attn, v)
        out = out.permute(0, 2, 1, 3).contiguous().view(num_windows_total, ws, D)
        out = self.proj(out)

        # reverse windows -> (B, D, h, w) with the padding removed
        x2 = window_reverse_nchw(out, (self.wh, self.ww), Hp, Wp, pad_h, pad_w, n_h, n_w, B)

        # undo the cyclic shift
        if shifted:
            x2 = torch.roll(x2, shifts=(shift_h, shift_w), dims=(2, 3))

        # back to (B, N, D)
        x = x2.reshape(B, D, h * w).transpose(1, 2).contiguous()

        # residual connections around attention and MLP
        x = res + x
        x = x + self.mlp(self.norm2(x))

        return x  # (B, N, D)

# Smoke test for SWLMWA on a random token grid (sizes as after patch embedding).
# The token dimension is written out here instead of reusing a global `D` left
# behind by another cell, so this cell also runs on its own.
h, w = 32, 64
dummy_input_local = torch.randn(2, h * w, 64)
output_local = SWLMWA(dim=64, num_heads=8, window_size=(7,11), shift_size=(3,5))(dummy_input_local, h, w)

print(f"LocalMixingWindow Input shape: {dummy_input_local.shape}")
print(f"LocalMixingWindow Output shape: {output_local.shape}")

# Verify that the output shape is the same as the input shape
assert dummy_input_local.shape == output_local.shape, "Output shape does not match input shape!"
print("Test passed: Output shape matches input shape.")

### Deformable Local Mixing Window Attention

##### V2

In [None]:
class OffsetPredictor(nn.Module):
    """
    Per-token sampling-offset head.

    A single linear layer maps each token to ``num_heads * n_points`` 2-D offsets.
    The raw outputs are passed through tanh and multiplied by ``offset_scale``,
    so every offset component lies in (-offset_scale, offset_scale).
    Output shape: (B, N, H, P, 2), last axis ordered (dx, dy).
    """

    def __init__(self, dim: int, num_heads: int, n_points: int, offset_scale: float = 4.0):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.n_points = n_points
        self.offset_scale = float(offset_scale)
        self.linear = nn.Linear(dim, num_heads * n_points * 2, bias=True)

        # Weight and bias start at exactly zero, so all predicted offsets are
        # zero until training moves them.
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, N, D) tokens.
        Returns:
            (B, N, H, P, 2) offsets.
        """
        batch, num_tokens, _ = x.shape
        raw = self.linear(x).view(batch, num_tokens, self.num_heads, self.n_points, 2)
        return torch.tanh(raw) * self.offset_scale

In [None]:
def make_base_grid(h: int, w: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """
    Build the pixel coordinate of every cell of an h x w grid.

    Returns a (N, 2) tensor with N == h * w. Rows follow row-major order
    (x varies fastest) and each row holds (x, y).
    """
    xs = torch.arange(w, device=device, dtype=dtype)  # 0 .. w-1
    ys = torch.arange(h, device=device, dtype=dtype)  # 0 .. h-1
    # 'xy' indexing yields (h, w) grids directly: x_grid[i, j] == j, y_grid[i, j] == i
    x_grid, y_grid = torch.meshgrid(xs, ys, indexing='xy')
    return torch.stack((x_grid, y_grid), dim=-1).reshape(-1, 2)

In [None]:
class FeatureMapProducer(nn.Module):
    """
    Turn a flattened token sequence into key and value feature maps.

    One 1x1 convolution produces 2 * dim channels, which are split in half:
    the first half is the key map, the second half the value map.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.kv_conv = nn.Conv2d(in_channels=dim, out_channels=2 * dim, kernel_size=1, bias=True)

    def forward(self, tokens: torch.Tensor, h: int, w: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            tokens: (B, N, D) with N == h * w.
            h, w: height and width of the token grid.
        Returns:
            (k_map, v_map), each of shape (B, D, h, w).
        """
        batch, _, channels = tokens.shape
        # Channels first, then unfold the token axis into the 2-D grid.
        feature_map = tokens.permute(0, 2, 1).reshape(batch, channels, h, w)
        k_map, v_map = torch.chunk(self.kv_conv(feature_map), chunks=2, dim=1)
        return k_map, v_map

In [None]:
class FeatureSampler:
    """
    Stateless helper (not an nn.Module) that samples feature maps at fractional
    positions with F.grid_sample.

    - Heads are expected to be folded into the batch axis by the caller, so the
      feature maps arrive as (B*H, C_head, h, w).
    - Sampling uses align_corners=True, so a normalized coordinate of -1 / +1 maps
      exactly onto the first / last pixel centre of the corresponding axis.
    """
    @staticmethod
    def sample_maps_at_points(
        feature_map: torch.Tensor,
        sample_norm: torch.Tensor,
    ) -> torch.Tensor:
        """
        Bilinearly sample feature_map at normalized points.

        Args:
            feature_map: (B * H, C_head, h, w)
            sample_norm: (B * H, N * P, 2) in normalized coords [-1, 1], ordering (x, y) per point.
                Points falling outside [-1, 1] read zeros (padding_mode='zeros').

        Returns:
            sampled: (B * H, C_head, N * P). The caller knows N and P and is
            responsible for reshaping to (B, H, C_head, N, P) or any other layout.

        Raises:
            ValueError: if the inputs do not have the ranks/shapes described above.
        """
        # Validate up front: a mismatching batch size would otherwise be silently
        # absorbed by a reshape and mix points belonging to different samples/heads.
        if feature_map.dim() != 4:
            raise ValueError(
                f"feature_map must be (B*H, C_head, h, w), got shape {tuple(feature_map.shape)}"
            )
        if sample_norm.dim() != 3 or sample_norm.shape[-1] != 2:
            raise ValueError(
                f"sample_norm must be (B*H, N*P, 2), got shape {tuple(sample_norm.shape)}"
            )
        if sample_norm.shape[0] != feature_map.shape[0]:
            raise ValueError(
                "feature_map and sample_norm must share the leading (B*H) dimension, "
                f"got {feature_map.shape[0]} and {sample_norm.shape[0]}"
            )

        # grid_sample wants a grid of shape (N_batch, H_out, W_out, 2); use H_out = N*P, W_out = 1.
        grid = sample_norm.unsqueeze(2)  # (B*H, N*P, 1, 2)

        # grid_sample -> (B*H, C_head, N*P, 1)
        sampled = F.grid_sample(feature_map, grid, mode='bilinear', padding_mode='zeros', align_corners=True)

        # Drop the dummy W_out axis.
        return sampled.squeeze(-1)  # (B*H, C_head, N*P)

In [None]:
class DeformableAttention(nn.Module):
    """
    Readable, modular implementation of deformable attention.

    Key features:
    - Per-query, per-head learnable offsets (n_points per query per head)
    - Bilinear sampling of K and V feature maps using F.grid_sample
    - Dot-product attention between query and sampled keys (softmax over P)
    - Output projected back to embedding dimension

    Parameters:
        dim: input & output embedding dimension
        num_heads: number of attention heads (must divide dim)
        n_points: sampling points per query per head
        offset_scale: scale multiplier for tanh-squashed offsets (in pixels)
        debug: when True, intermediate tensor shapes are printed during forward
    """
    def __init__(self, dim: int, num_heads: int, n_points: int = 9, offset_scale: float = 4.0, debug: bool = False):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.dim = dim
        self.num_heads = num_heads
        self.n_points = n_points
        self.head_dim = dim // num_heads
        self.offset_scale = float(offset_scale)
        self.debug = debug

        # modules
        self.to_q = nn.Linear(dim, dim, bias=True)
        self.offset_predictor = OffsetPredictor(dim, num_heads, n_points, offset_scale=offset_scale)
        self.feature_producer = FeatureMapProducer(dim)
        self.out_proj = nn.Linear(dim, dim, bias=True)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """(B, N, D) -> (B, H, N, head_dim)."""
        B, N, D = x.shape
        return x.view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()

    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        """(B, H, N, head_dim) -> (B, N, D)."""
        B, H, N, hd = x.shape
        return x.permute(0, 2, 1, 3).contiguous().view(B, N, H * hd)

    def forward(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor:
        """
        x: (B, N, D), where N == h * w (tokens in row-major spatial order)
        returns: (B, N, D)
        """
        B, N, D = x.shape
        device = x.device
        dtype = x.dtype
        assert N == h * w, "N must equal h*w"

        if self.debug:
            print(f"[DeformableAttention] input: {x.shape}, h={h}, w={w}")

        # 1) Query projection -> (B, H, N, head_dim)
        q = self._split_heads(self.to_q(x))

        # 2) Offsets -> (B, N, H, P, 2) in pixel units, ordered (dx, dy)
        offsets = self.offset_predictor(x)

        # 3) Feature maps (k_map, v_map), each (B, D, h, w)
        k_map, v_map = self.feature_producer(x, h, w)
        if self.debug:
            print(f"[DeformableAttention] k_map shape: {k_map.shape}, v_map shape: {v_map.shape}")

        # 4) Fold heads into the batch axis: (B, D, h, w) -> (B*H, head_dim, h, w).
        #    reshape (not view) because the chunked K/V maps are not contiguous.
        B_H = B * self.num_heads
        k_map_heads = k_map.reshape(B_H, self.head_dim, h, w)
        v_map_heads = v_map.reshape(B_H, self.head_dim, h, w)

        # 5) Absolute sample positions = base pixel coordinate of each query + its offsets.
        #    base_xy is (N, 2) in (x, y) order, matching the (dx, dy) order of the offsets.
        base_xy = make_base_grid(h, w, device=device, dtype=dtype)
        sample_xy = base_xy.view(1, N, 1, 1, 2) + offsets  # (B, N, H, P, 2)

        # Normalize to [-1, 1] for grid_sample with align_corners=True (x by w-1, y by h-1).
        # A 1-pixel axis would make the divisor 0 and turn the grid into NaN/inf; clamping the
        # divisor to 1 keeps the grid finite, and grid_sample maps any finite coordinate on
        # such an axis onto its single pixel.
        norm = torch.tensor([max(w - 1, 1), max(h - 1, 1)], device=device, dtype=dtype).view(1, 1, 1, 1, 2)
        sample_norm = (sample_xy / norm) * 2.0 - 1.0  # (B, N, H, P, 2)

        # 6) Bring heads next to batch so the points line up with the (B*H) feature maps.
        sample_norm = sample_norm.permute(0, 2, 1, 3, 4).reshape(B_H, N * self.n_points, 2)  # (B*H, N*P, 2)

        if self.debug:
            print(f"[DeformableAttention] sample_norm shape (for grid_sample): {sample_norm.shape}")

        # 7) Bilinear sampling through the shared helper -> (B*H, hd, N*P)
        sampled_k = FeatureSampler.sample_maps_at_points(k_map_heads, sample_norm)
        sampled_v = FeatureSampler.sample_maps_at_points(v_map_heads, sample_norm)

        # (B*H, hd, N*P) -> (B, H, hd, N, P) -> (B, H, N, P, hd)
        sampled_k = sampled_k.view(B, self.num_heads, self.head_dim, N, self.n_points).permute(0, 1, 3, 4, 2).contiguous()
        sampled_v = sampled_v.view(B, self.num_heads, self.head_dim, N, self.n_points).permute(0, 1, 3, 4, 2).contiguous()

        if self.debug:
            print(f"[DeformableAttention] sampled_k shape: {sampled_k.shape}, sampled_v shape: {sampled_v.shape}")

        # 8) Scaled dot product between each query and its P sampled keys -> (B, H, N, P)
        attn_logits = torch.einsum('bhnd,bhnpd->bhnp', q, sampled_k) * (self.head_dim ** -0.5)
        attn_weights = F.softmax(attn_logits, dim=-1)  # softmax over the P sampling points

        # 9) Weighted sum of the sampled values -> (B, H, N, hd)
        out_heads = torch.einsum('bhnp,bhnpd->bhnd', attn_weights, sampled_v)

        # 10) Merge heads and project back to the embedding dimension
        out = self.out_proj(self._merge_heads(out_heads))  # (B, N, D)

        if self.debug:
            print(f"[DeformableAttention] out shape: {out.shape}")

        return out

In [None]:
class DLMWA(nn.Module):
    """
    Pre-norm local mixing block built around DeformableAttention.

    Layout: x + DeformableAttention(LayerNorm(x)), followed by y + MLP(LayerNorm(y)).
    Input and output are both (B, N, D) with N == h * w.
    """
    def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 2.0, n_points: int = 9, offset_scale: float = 4.0, debug: bool = False):
        super().__init__()
        # Attention sub-layer
        self.norm1 = nn.LayerNorm(dim)
        self.attn = DeformableAttention(dim, num_heads, n_points=n_points, offset_scale=offset_scale, debug=debug)
        # Feed-forward sub-layer
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, ratio=mlp_ratio)
        self.debug = debug

    def forward(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor:
        """Apply the attention and MLP sub-layers, each wrapped in a residual connection."""
        verbose = self.debug
        if verbose:
            print(f"[DLMWA] input: {x.shape}")

        # Residual around deformable attention (needs the spatial layout h, w).
        attended = x + self.attn(self.norm1(x), h, w)
        # Residual around the token-wise MLP.
        mixed = attended + self.mlp(self.norm2(attended))

        if verbose:
            print(f"[DLMWA] output: {mixed.shape}")
        return mixed

# Smoke test for DLMWA: push a random (B, h*w, dim) tensor through one block on the
# default device and check that the token shape is preserved.
# NOTE: h, w, B, dim, num_heads, n_points, dummy, block and out become notebook globals.
h, w = 32, 64
B = 2
dim = 64
num_heads = 8
n_points = 9

# Random tokens; N = h * w must match the spatial size passed to the block below.
dummy = torch.randn(B, h * w, dim)
block = DLMWA(dim=dim, num_heads=num_heads, n_points=n_points, debug=False)
out = block(dummy, h, w)
print("Input shape:", dummy.shape)
print("Output shape:", out.shape)
# The block is residual, so the output must have exactly the input shape.
assert out.shape == dummy.shape, "Output shape mismatch!"
print("Smoke test passed.")

## Local mixing using a small depthwise conv (efficient alternative)


In [None]:
class LocalMixingConv(nn.Module):
    """
    Convolutional local mixing block: a cheaper stand-in for window attention.

    Layout: x + pointwise(depthwise(LayerNorm(x))), followed by y + MLP(LayerNorm(y)).
    Input and output are both (B, N, D) with N == h * w.
    """
    def __init__(self, dim, kernel_size=(7,11), mlp_ratio=2.0, drop=0.0):
        super().__init__()
        k_h, k_w = kernel_size
        # Pre-norm for the convolutional branch
        self.norm1 = nn.LayerNorm(dim)
        # Depthwise conv (groups=dim): one spatial filter per channel; padding of k // 2
        # keeps the spatial size for odd kernel sizes.
        self.dw = nn.Conv2d(dim, dim, kernel_size=(k_h, k_w), padding=(k_h//2, k_w//2), groups=dim)
        # Pointwise 1x1 conv: mixes information across channels
        self.pw = nn.Conv2d(dim, dim, kernel_size=1)
        # Pre-norm and MLP of the feed-forward branch
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, ratio=mlp_ratio, drop=drop)

    def forward(self, x, h, w):
        """
        Args:
            x: tokens of shape (B, N, D), row-major over an h x w grid.
            h, w: spatial size of the token grid; N must equal h * w.

        Returns:
            Tensor of shape (B, N, D).
        """
        B, N, D = x.shape
        assert N == h * w, "N must equal h*w"

        # Convolutional branch: normalize, go to NCHW, depthwise then pointwise conv.
        spatial = self.norm1(x).transpose(1, 2).reshape(B, D, h, w)
        spatial = self.pw(self.dw(spatial))

        # Back to token layout (B, N, D) and first residual connection.
        x = x + spatial.view(B, D, h * w).transpose(1, 2).contiguous()

        # Feed-forward branch with the second residual connection.
        return x + self.mlp(self.norm2(x))

# Smoke test for LocalMixingConv: the block is residual, so a (B, N, D) input must
# come back with the same shape.
# The input to LocalMixingConv is (B, N, D), where N = h * w
# NOTE: B, D, h, w, N and the names below become (or overwrite) notebook globals.
B, D = 2, 64 # Example Batch size and Dimension
h, w = 32, 64 # Example spatial dimensions (after patch embedding)
N = h * w # Sequence length
dummy_input_conv = torch.randn(B, N, D)

# Instantiate the LocalMixingConv module (mlp_ratio and drop keep their defaults)
kernel_size = (7, 11) # Example kernel size (odd in both axes, so padding preserves h and w)
local_mixing_conv = LocalMixingConv(dim=D, kernel_size=kernel_size)

# Pass the dummy input through the module
output_conv = local_mixing_conv(dummy_input_conv, h, w)

print(f"LocalMixingConv Input shape: {dummy_input_conv.shape}")
print(f"LocalMixingConv Output shape: {output_conv.shape}")

# Verify that the output shape is the same as the input shape
assert dummy_input_conv.shape == output_conv.shape, "Output shape does not match input shape!"
print("Test passed: Output shape matches input shape.")

## Merge / reduce resolution between stages


In [None]:
class Merging(nn.Module):
    """
    Down-sampling layer between stages (in the spirit of the SVTR merge layer).

    A 3x3 convolution with stride (2, 1) and padding 1 halves the height (rounding up
    for odd heights), keeps the width and changes the channel count from in_dim to
    out_dim; the resulting tokens are layer-normalized.

    Input x: (B, N, D) with h, w provided (N == h * w).
    Returns x_new (B, N_new, D_out) together with the new (h, w).
    """
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.conv = nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=(2,1), padding=1)
        self.norm = nn.LayerNorm(out_dim)

    def forward(self, x, h, w):
        """Return (tokens, new_h, new_w) with tokens of shape (B, new_h * new_w, out_dim)."""
        batch, _, channels = x.shape
        # Tokens -> feature map (B, D, h, w), then strided convolution.
        fmap = self.conv(x.permute(0, 2, 1).reshape(batch, channels, h, w))
        new_h, new_w = fmap.shape[-2:]
        # Feature map -> tokens (B, new_h * new_w, out_dim), normalized over channels.
        tokens = fmap.flatten(start_dim=2).permute(0, 2, 1).contiguous()
        return self.norm(tokens), new_h, new_w

# Smoke test for the Merging module: shapes are only printed here, nothing is asserted.
# With the values below the stride-(2, 1) convolution should turn (h, w) = (32, 64)
# into (16, 64), i.e. an output of shape (2, 1024, 123).
# NOTE: B, in_dim, out_dim, h, w, N and the names below become (or overwrite) notebook globals.
B, in_dim = 2, 12 # Example Batch size and Input Dimension
out_dim = 123 # Example Output Dimension
h, w = 32, 64 # Example spatial dimensions before merging
N = h * w # Sequence length
dummy_input_merge = torch.randn(B, N, in_dim)

# Instantiate the Merging module
merging_layer = Merging(in_dim=in_dim, out_dim=out_dim)

# Pass the dummy input through the module; it also returns the new spatial size
output_merge, new_h, new_w = merging_layer(dummy_input_merge, h, w)

print(f"Merging Input shape: {dummy_input_merge.shape}")
print(f"Merging Output shape: {output_merge.shape}")
print(f"New spatial dimensions (h, w): ({new_h}, {new_w})")

## Combine -> collapse height to 1 and produce sequence (B, W, D_out)


In [None]:
class Combining(nn.Module):
    """
    Collapse the height axis and produce one feature vector per column.

    The token grid is averaged over its height, then every column is passed through
    Linear(in_dim -> out_dim), GELU and dropout.
    """
    def __init__(self, in_dim, out_dim, drop=0.0):
        super().__init__()
        self.fc = nn.Linear(in_dim, out_dim)
        self.act = nn.GELU()
        self.drop = nn.Dropout(drop)

    def forward(self, x, h, w):
        """
        Args:
            x: tokens of shape (B, N, D) with N == h * w, row-major over the h x w grid.
            h, w: spatial size of the token grid.

        Returns:
            Tensor of shape (B, w, out_dim): one vector per column.
        """
        batch, _, channels = x.shape
        # Tokens -> feature map (B, D, h, w)
        fmap = x.permute(0, 2, 1).reshape(batch, channels, h, w)
        # Mean over the height axis -> (B, D, w)
        pooled = fmap.mean(dim=2)
        # Channels last -> (B, w, D), then project / activate / regularize.
        columns = pooled.permute(0, 2, 1).contiguous()
        return self.drop(self.act(self.fc(columns)))  # (B, w, out_dim)

# Smoke test for the Combining module: the height axis must disappear, leaving one
# out_dim-sized vector per column, i.e. an output of shape (B, w, out_dim).
# NOTE: B, in_dim, out_dim, h, w, N and the names below become (or overwrite) notebook globals.
B, in_dim = 2, 12 # Example Batch size and Input Dimension (e.g., output of Merging)
out_dim = 123 # Example Output Dimension
# Assume input comes from a Merging layer that reduced height to 16 and kept width at 64
h, w = 16, 64 # Example spatial dimensions before combining
N = h * w # Sequence length
dummy_input_combine = torch.randn(B, N, in_dim)

# Instantiate the Combining module (dropout keeps its default of 0.0)
combining_layer = Combining(in_dim=in_dim, out_dim=out_dim)

# Pass the dummy input through the module
output_combine = combining_layer(dummy_input_combine, h, w)

print(f"Combining Input shape: {dummy_input_combine.shape}")
print(f"Combining Output shape: {output_combine.shape}")

# Verify the output shape: sequence length equals the width, channels equal out_dim
expected_output_shape = torch.Size([B, w, out_dim])
assert output_combine.shape == expected_output_shape, "Output shape does not match expected shape!"
print("Test passed: Output shape matches expected shape.")

## SVTR full model


In [None]:
class SVTR(nn.Module):
    """
    SVTR-like model architecture (PyTorch implementation).

    Data flow: patch embedding + learnable positional embedding -> stage 1 -> Merging ->
    stage 2 -> Merging -> stage 3 -> Combining (height collapsed by mean) -> linear head.

    Args:
        img_size (tuple): Input image size (height, width). Defaults to (64, 256).
        in_chans (int): Number of input image channels. Defaults to 3.
        vocab_size (int): Size of the output vocabulary (number of characters). Defaults to 100.
        embed_dims (tuple): Embedding dimensions for each stage. Defaults to (64, 128, 256).
        d3 (int): Output dimension of the combining layer before the head. Defaults to 192.
        heads (tuple): Number of attention heads for each stage. Defaults to (2, 4, 8).
        mlp_ratio (float): Ratio to determine hidden dimension in MLP. Defaults to 2.0.
        window_sizes (list(tuple)): One window size per local ('L') block, consumed in block order.
                             Used as the window size of the window-attention variants and as the
                             kernel size of LocalMixingConv. Must have at least as many entries
                             as there are 'L' blocks.
        num_blocks (tuple): Number of blocks in each of the three stages. Defaults to (3, 6, 3).
        pattern (sequence): Required. One entry per block (sum(num_blocks) entries in total).
                            'L' builds a local mixing block; any other value builds a GlobalMixing block.
        local_type (list): Required. Local mixing type for each 'L' block, consumed in block order
                           ('non_overlapping', 'swin', 'deformable', 'conv'). Must have at least as
                           many entries as there are 'L' blocks.
        drop (float): Dropout rate. Defaults to 0.0.
        n_points (int): Number of sampling points for Deformable Attention. Defaults to 9.
        offset_scale (float): Scaling factor for Deformable Attention offsets. Defaults to 4.0.

    Raises:
        AssertionError: if pattern, local_type or window_sizes are missing or too short.
        ValueError: if a local_type entry that is actually used is not one of the known types.
    """
    def __init__(self,
                 img_size=(64, 256),
                 in_chans=3,
                 vocab_size=100,
                 embed_dims=(64, 128, 256),
                 d3=192,
                 heads=(2, 4, 8),
                 mlp_ratio=2.0,
                 window_sizes=[(7, 11)] * 12 + [(3,3)] * 6, # Default window sizes if not provided
                 num_blocks=(3, 6, 3),
                 pattern=None,
                 local_type=None,
                 drop=0.0,
                 n_points=9,
                 offset_scale=4.0):
        super().__init__()

        # pattern length must cover sum(num_blocks). L = local, G = global
        total_blocks = sum(num_blocks)
        assert pattern is not None and len(pattern) == total_blocks, f"Pattern must be a list of length {total_blocks} specifying 'L' or 'G' for each block."

        # Ensure local_type is a list and covers the number of local blocks in the pattern
        assert isinstance(local_type, list), "local_type must be a list specifying block types."
        num_local_blocks = pattern.count('L')
        assert len(local_type) >= num_local_blocks, f"Length of local_type list ({len(local_type)}) must match the number of local blocks in pattern ({num_local_blocks})."

        # Ensure window_sizes list covers the number of local blocks
        assert isinstance(window_sizes, list), "window_sizes must be a list of tuples."
        assert len(window_sizes) >= num_local_blocks, f"Length of window_sizes list ({len(window_sizes)}) must match the number of local blocks in pattern ({num_local_blocks})."

        # Keep private copies: forward() must dispatch on the pattern this model was built
        # with (not on whatever a global config holds later), and the shared default list
        # of window sizes must never be aliased between instances.
        self.pattern = list(pattern)
        self.local_type_list = list(local_type)
        self.window_sizes_list = list(window_sizes)

        self.patch_embed = PatchEmbedSVTR(img_size, in_chans, embed_dim=embed_dims[0])
        self.patches_resolution = self.patch_embed.patches_resolution
        dims = list(embed_dims) # Convert to list
        self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches, dims[0]))  # learnable positional embedding
        self.n_points = n_points # Store for Deformable Attention
        self.offset_scale = offset_scale # Store for Deformable Attention
        self.drop_rate = drop # Store dropout rate

        # Build the three stages. pattern_idx walks over all blocks, local_idx only over
        # the 'L' blocks (it indexes local_type_list and window_sizes_list in lockstep).
        pattern_idx = 0
        local_idx = 0

        stage1, local_idx = self._build_stage(dims[0], heads[0], num_blocks[0], mlp_ratio, pattern_idx, local_idx)
        self.stage1 = stage1
        pattern_idx += num_blocks[0]
        self.merge1 = Merging(dims[0], dims[1])

        stage2, local_idx = self._build_stage(dims[1], heads[1], num_blocks[1], mlp_ratio, pattern_idx, local_idx)
        self.stage2 = stage2
        pattern_idx += num_blocks[1]
        self.merge2 = Merging(dims[1], dims[2])

        stage3, local_idx = self._build_stage(dims[2], heads[2], num_blocks[2], mlp_ratio, pattern_idx, local_idx)
        self.stage3 = stage3

        self.combine = Combining(dims[2], d3, drop=self.drop_rate)
        self.head = nn.Linear(d3, vocab_size)

        # Initialize weights
        self.init_weights()

    def _make_local_block(self, local_type, window_size, dim, num_heads, mlp_ratio, stage_block_idx, pattern_idx):
        """Instantiate one local mixing block of the requested type.

        stage_block_idx is the block's index inside its stage (drives the alternating
        Swin shift); pattern_idx is its global index (used in the error message only).
        """
        if local_type == 'non_overlapping':
            return NOLMWA(dim, num_heads, mlp_ratio, window_size, self.drop_rate)
        if local_type == 'swin':
            # Alternate shift for SWIN: every second block of a stage uses a half-window shift
            do_shift = (stage_block_idx % 2 == 1)
            shift_size = (window_size[0] // 2, window_size[1] // 2) if do_shift else (0, 0)
            return SWLMWA(dim, num_heads, mlp_ratio, window_size, self.drop_rate, shift_size)
        if local_type == 'deformable':
            return DLMWA(dim, num_heads, mlp_ratio, self.n_points, self.offset_scale, debug=False)
        if local_type == 'conv':
            return LocalMixingConv(dim, kernel_size=window_size, mlp_ratio=mlp_ratio, drop=self.drop_rate)
        raise ValueError(f"Unknown local_type '{local_type}' for block {pattern_idx}")

    def _build_stage(self, dim, num_heads, depth, mlp_ratio, pattern_start, local_start):
        """Build the `depth` blocks of one stage; returns (ModuleList, next local index)."""
        blocks = nn.ModuleList()
        local_idx = local_start
        for i in range(depth):
            pattern_idx = pattern_start + i
            if self.pattern[pattern_idx] == 'L':
                blk = self._make_local_block(
                    self.local_type_list[local_idx],
                    self.window_sizes_list[local_idx],
                    dim, num_heads, mlp_ratio, i, pattern_idx,
                )
                local_idx += 1
            else: # Global mixing
                blk = GlobalMixing(dim, num_heads, mlp_ratio, self.drop_rate)
            blocks.append(blk)
        return blocks, local_idx

    def init_weights(self):
        """
        Initialize the weights of the model.

        - Conv2d: Kaiming uniform (fan_out, relu), zero bias.
        - Linear: Xavier uniform, zero bias.
        - BatchNorm2d / LayerNorm: weight 1, bias 0.
        - OffsetPredictor: its linear layer is reset to zeros afterwards so that
          deformable offsets start at zero, as that module intends.
        - Positional embedding: truncated normal (std 0.02).
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.LayerNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, PatchEmbedSVTR):
                # Redundant with the Conv2d branch above (modules() visits the children after
                # their parent and initializes them again); kept so that seeded initialization
                # consumes the random stream exactly as before. getattr avoids an
                # AttributeError if the patch embedding exposes no conv1/conv2.
                for conv in (getattr(m, 'conv1', None), getattr(m, 'conv2', None)):
                    if isinstance(conv, nn.Conv2d):
                        nn.init.kaiming_uniform_(conv.weight, mode='fan_out', nonlinearity='relu')
                        if conv.bias is not None:
                            nn.init.constant_(conv.bias, 0)

        # The generic Linear rule above overwrites OffsetPredictor's zero initialization;
        # restore it in a second pass (a parent is visited before its children, so it
        # cannot be done inside the loop above).
        for m in self.modules():
            if isinstance(m, OffsetPredictor):
                nn.init.constant_(m.linear.weight, 0.0)
                nn.init.constant_(m.linear.bias, 0.0)

        # Initialize positional embedding with truncated normal distribution
        if self.pos_embed is not None:
            nn.init.trunc_normal_(self.pos_embed, std=.02)

    def _run_stage(self, blocks, x, h, w, pattern_start):
        """Apply the blocks of one stage; local ('L') blocks additionally receive (h, w)."""
        for offset, blk in enumerate(blocks):
            if self.pattern[pattern_start + offset] == 'L':
                # Local mixing requires spatial dimensions h, w
                x = blk(x, h, w)
            else:
                # Global mixing operates on the sequence (B, N, D)
                x = blk(x)
        return x

    def forward(self, x):
        """
        Forward pass of the SVTR model.

        Args:
            x (torch.Tensor): Input image tensor (B, C, H, W).

        Returns:
            torch.Tensor: Output logits for each timestep, shape (B, w, vocab_size), where w is
            the feature-map width after patch embedding (Merging keeps the width unchanged).
        """
        # Patch embedding -> tokens (B, N, embed_dims[0]), plus learnable positional embedding
        x = self.patch_embed(x)
        x = x + self.pos_embed
        h, w = self.patch_embed.patches_resolution # Spatial dimensions after patch embedding

        pattern_idx = 0

        # Stage 1: (B, h*w, embed_dims[0])
        x = self._run_stage(self.stage1, x, h, w, pattern_idx)
        pattern_idx += len(self.stage1)

        # Merge 1: halves the height, keeps the width -> (B, h*w, embed_dims[1])
        x, h, w = self.merge1(x, h, w)

        # Stage 2
        x = self._run_stage(self.stage2, x, h, w, pattern_idx)
        pattern_idx += len(self.stage2)

        # Merge 2: halves the height again -> (B, h*w, embed_dims[2])
        x, h, w = self.merge2(x, h, w)

        # Stage 3
        x = self._run_stage(self.stage3, x, h, w, pattern_idx)

        # Combine: collapse the height -> sequence of length w with d3 channels, (B, w, d3)
        x = self.combine(x, h, w)

        # Final linear head (classifier): (B, w, d3) -> (B, w, vocab_size)
        x = self.head(x)

        return x

# Build the SVTR model from the notebook-wide `config` dict, move it to the configured
# device and switch to eval mode for a shape-only smoke test.
# NOTE: `model`, `dummy` and `out` become (or overwrite) notebook globals.
model = SVTR(
    img_size=(config['img_height'], config['img_width']),
    in_chans=config['channels'],
    vocab_size=config['vocab_size'],
    local_type=config['local_type'], # Use the list from config
    embed_dims=config['embed_dims'], # Use embed_dims from config
    heads=config['heads'], # Use heads from config
    mlp_ratio=config['mlp_ratio'], # Use mlp_ratio from config
    window_sizes=config['window_sizes'], # Use window_sizes from config
    num_blocks=config['num_blocks'], # Use num_blocks from config
    pattern=config['pattern'], # Use pattern from config
    drop=config['dropout_rate'], # Use dropout_rate from config
    n_points=config['n_points'], # Use n_points from config
    offset_scale=config['offset_scale'], # Use offset_scale from config
).to(config['device']).eval()

# Random batch of two images with the configured channel count and size; no gradients needed.
dummy = torch.randn(2, config['channels'], config['img_height'], config['img_width'], device=config['device'])
with torch.no_grad():
    out = model(dummy)

print(f"Input shape: {dummy.shape}")
print(f"Output shape: {out.shape}")
# The expected output shape is (batch_size, final_width, vocab_size).
# final_width is taken to be img_width // 4: this assumes the patch embedding downsamples
# the width by 4 (PatchEmbedSVTR is defined elsewhere - TODO confirm); the Merging layers
# use stride (2, 1) and therefore leave the width unchanged.
assert out.shape == torch.zeros((2, config['img_width'] // 4, config['vocab_size'])).shape, "Output shape does not match expected shape!"
print("Test passed: Output shape matches expected shape.")

# Layer-by-layer summary (torchinfo) for the same input shape.
summary(model, input_size=dummy.shape)

# SECTION 8: COMPLETE TRAINING & VALIDATION & TESTING WORKFLOW

## Definition

### Utilities

In [None]:
class OCRMetrics:
    """Accumulates OCR evaluation statistics over any number of batches.

    Call ``update`` once per batch with decoded predictions and ground-truth
    strings, then ``get_metrics`` to read the aggregated results. ``reset``
    clears every counter so the instance can be reused.
    """

    def __init__(self, text_processor):
        """Store the text processor and start from empty counters."""
        self.text_processor = text_processor
        self.reset()

    def reset(self):
        """Zero every running counter and clear the per-sample lists."""
        self.total_samples = 0
        self.correct_samples = 0
        self.total_chars = 0
        self.correct_chars = 0
        self.total_edit_distance = 0
        self.total_word_distance = 0
        self.total_words = 0
        self.correct_words = 0
        self.char_error_rates = []
        self.word_error_rates = []
        self.sequence_accuracies = []

    def update(self, predictions, targets):
        """Fold one batch of (prediction, target) string pairs into the counters.

        Args:
            predictions: Iterable of predicted strings.
            targets: Iterable of ground-truth strings, aligned with ``predictions``.

        Pairs are formed with ``zip``, so surplus items on either side are
        ignored. ``total_samples`` counts only the pairs that were actually
        evaluated, which keeps it in step with the per-sample lists.
        """
        for pred_text, target_text in zip(predictions, targets):
            self.total_samples += 1

            # Character-level: position-by-position matches against the target.
            self.total_chars += len(target_text)
            self.correct_chars += sum(1 for p, t in zip(pred_text, target_text) if p == t)

            # Character Error Rate: edit distance normalised by target length
            # (length clamped to 1 so an empty target does not divide by zero).
            edit_dist = editdistance.eval(pred_text, target_text)
            self.total_edit_distance += edit_dist
            self.char_error_rates.append(edit_dist / max(len(target_text), 1))

            # Word-level: whitespace-split tokens, same scheme as for characters.
            pred_words = pred_text.split()
            target_words = target_text.split()
            self.total_words += len(target_words)
            self.correct_words += sum(1 for p, t in zip(pred_words, target_words) if p == t)

            word_edit_dist = editdistance.eval(pred_words, target_words)
            self.total_word_distance += word_edit_dist
            self.word_error_rates.append(word_edit_dist / max(len(target_words), 1))

            # Sequence accuracy: exact string match.
            is_exact = pred_text == target_text
            if is_exact:
                self.correct_samples += 1
            self.sequence_accuracies.append(1.0 if is_exact else 0.0)

    def get_metrics(self):
        """Return the aggregated metrics as a dict, or ``{}`` if nothing was recorded.

        The corpus-level character/word ratios clamp their denominators to 1,
        mirroring the per-sample CER/WER computation, so a run whose targets
        are all empty yields 0.0 instead of raising ``ZeroDivisionError``.
        """
        if self.total_samples == 0:
            return {}

        char_denominator = max(self.total_chars, 1)
        word_denominator = max(self.total_words, 1)

        metrics = {
            'sequence_accuracy': self.correct_samples / self.total_samples,
            'character_accuracy': self.correct_chars / char_denominator,
            'word_accuracy': self.correct_words / word_denominator,
            'character_error_rate': self.total_edit_distance / char_denominator,
            'word_error_rate': self.total_word_distance / word_denominator,
            'avg_cer': np.mean(self.char_error_rates),
            'avg_wer': np.mean(self.word_error_rates),
            'avg_sequence_accuracy': np.mean(self.sequence_accuracies),
            'total_samples': self.total_samples
        }

        return metrics

In [None]:
def ctc_decode(logits, text_processor, beam_size=4):
    """Decode CTC output to text using greedy (best-path) decoding.

    Args:
        logits: Tensor of per-timestep class scores laid out as [T, B, C].
        text_processor: Object providing ``char2idx``, ``BLANK_TOKEN`` and
            ``decode_sequence(list_of_indices) -> str``.
        beam_size: Accepted for interface compatibility. Beam search is not
            implemented, so every value results in greedy decoding.

    Returns:
        A list with one decoded string per batch element.
    """
    # Look the blank index up once instead of once per token.
    blank_idx = text_processor.char2idx[text_processor.BLANK_TOKEN]

    # Best class per timestep: [T, B] -> [B, T]. A single .tolist() moves the
    # whole result to Python ints in one transfer rather than calling .item()
    # on every element.
    best_paths = torch.argmax(logits, dim=-1).transpose(0, 1).tolist()

    decoded_texts = []
    for path in best_paths:
        # CTC collapse: drop blanks and consecutive repeats of the same token.
        collapsed = []
        prev_token = None
        for token in path:
            if token != blank_idx and token != prev_token:
                collapsed.append(token)
            prev_token = token
        decoded_texts.append(text_processor.decode_sequence(collapsed))

    return decoded_texts

In [None]:
def get_gpu_memory_usage():
    """Return the memory currently allocated through torch.cuda, in GB (bytes / 1024**3).

    Returns 0 when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return 0
    bytes_per_gb = 1024 ** 3
    return torch.cuda.memory_allocated() / bytes_per_gb

In [None]:
class ArabicOCRInference:
    """Robust inference pipeline for Arabic OCR.

    Bundles a model with the text processor used to decode its CTC output.
    Every prediction runs with gradients disabled and with the model switched
    to evaluation mode for the duration of the call; the train/eval flag of
    each sub-module is restored afterwards. This keeps predictions correct
    (no active dropout) even if the model was put back into training mode
    after this object was created.
    """

    def __init__(self, model, text_processor, device, transform=None):
        """
        Args:
            model: ``torch.nn.Module`` producing (B, W, vocab_size) logits.
            text_processor: Passed through to ``ctc_decode``.
            device: Device the input batch is moved to before the forward pass.
            transform: Callable invoked as ``transform(image=array)`` returning a
                dict with an ``'image'`` entry. Defaults to the validation
                transforms built from the dataset statistics in ``config``.
        """
        self.model = model
        self.text_processor = text_processor
        self.device = device
        self.transform = transform or get_val_transforms(config['dataset_mean'], config['dataset_std'])
        self.model.eval()

    def preprocess_image(self, image):
        """Convert one input into a batched tensor.

        Args:
            image: File path (str), PIL image, numpy array, or torch tensor.
                Tensors are assumed to be already normalised and laid out as
                (C, H, W) or (B, C, H, W); they bypass ``self.transform``.

        Returns:
            A tensor with a leading batch dimension.

        Raises:
            ValueError: If a tensor input is neither 3-D nor 4-D.
            TypeError: If the input type is not supported.
        """
        if isinstance(image, str):
            image = Image.open(image)

        if isinstance(image, Image.Image):
            image = np.array(image)

        if isinstance(image, np.ndarray):
            if self.transform:
                # With ToTensorV2 in the transform, the result is already a tensor.
                image = self.transform(image=image)['image']
            if not isinstance(image, torch.Tensor):
                image = torch.from_numpy(image)
            return image.unsqueeze(0)

        if isinstance(image, torch.Tensor):
            if len(image.shape) == 3:  # (C, H, W) -> (1, C, H, W)
                return image.unsqueeze(0)
            if len(image.shape) == 4:  # already batched
                return image
            raise ValueError(f"Unexpected tensor shape: {image.shape}. Expected (C, H, W) or (B, C, H, W).")

        raise TypeError(f"Unsupported image type: {type(image)}. Supported types are str (file path), PIL.Image.Image, np.ndarray, or torch.Tensor.")

    def _forward(self, batch_tensor):
        """Run the model on ``batch_tensor`` in eval mode without gradients; return the logits."""
        # Remember every sub-module's flag so the exact training state is restored,
        # including modules that were deliberately kept in eval mode.
        saved_modes = [(module, module.training) for module in self.model.modules()]
        self.model.eval()
        try:
            with torch.no_grad():
                return self.model(batch_tensor.to(self.device))
        finally:
            for module, was_training in saved_modes:
                module.training = was_training

    def _decode(self, logits, beam_size):
        """Decode (B, W, vocab_size) logits; ``ctc_decode`` expects (T, B, C)."""
        return ctc_decode(logits.transpose(0, 1), self.text_processor, beam_size)

    def predict(self, image, beam_size=4):
        """Predict text from a single image; returns "" if nothing was decoded."""
        logits = self._forward(self.preprocess_image(image))
        decoded_texts = self._decode(logits, beam_size)
        return decoded_texts[0] if decoded_texts else ""

    def predict_batch(self, images, beam_size=4):
        """Predict text for a collection of images.

        Returns:
            One decoded string per image, or ``[]`` for an empty input.

        Raises:
            ValueError: If the preprocessed images do not all share the same
                shape (no padding or resizing is attempted here).
        """
        batch_tensors = [self.preprocess_image(image) for image in images]
        if not batch_tensors:
            return []

        # All tensors must agree on every dimension except the batch one.
        first_shape = batch_tensors[0].shape[1:]
        if not all(t.shape[1:] == first_shape for t in batch_tensors):
            raise ValueError("Images in batch have inconsistent shapes after preprocessing.")

        logits = self._forward(torch.cat(batch_tensors, dim=0))
        return self._decode(logits, beam_size)

    def predict_with_confidence(self, image, beam_size=4):
        """Predict text for a single image together with a confidence score.

        The confidence is the mean, over all timesteps of the first sample, of
        the highest softmax probability at each timestep.

        Returns:
            Tuple ``(text, confidence)``.
        """
        logits = self._forward(self.preprocess_image(image))

        probs = F.softmax(logits, dim=-1)        # (B, W, vocab_size)
        max_probs, _ = torch.max(probs, dim=-1)  # (B, W)
        confidence = torch.mean(max_probs[0]).item() if max_probs.size(0) > 0 else 0.0

        decoded_text = self._decode(logits, beam_size)
        return decoded_text[0] if decoded_text else "", confidence

In [None]:
def plot_training_history(history):
    """Draw the loss curves and one figure per recorded metric.

    Args:
        history: Dict with 'train_loss' and 'val_loss' sequences and
            'train_metrics' / 'val_metrics' lists of per-epoch metric dicts.
            Metric names are taken from the first training entry;
            'total_samples' is not plotted.
    """
    print("\nVisualizing training history...")

    # Loss curves share a single figure.
    plt.figure(figsize=(12, 6))
    for history_key, curve_label in (('train_loss', 'Train Loss'), ('val_loss', 'Validation Loss')):
        plt.plot(history[history_key], label=curve_label)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss over Epochs')
    plt.legend()
    plt.grid(True)
    plt.show()

    if not history['train_metrics']:
        return

    # Counts such as total_samples are not meaningful as a curve over epochs.
    not_plotted = ['total_samples']
    plottable = [key for key in history['train_metrics'][0].keys() if key not in not_plotted]

    for metric_name in plottable:
        plt.figure(figsize=(12, 6))
        train_curve = [epoch_metrics[metric_name] for epoch_metrics in history['train_metrics']]
        val_curve = [epoch_metrics[metric_name] for epoch_metrics in history['val_metrics']]

        plt.plot(train_curve, label=f'Train {metric_name}')
        plt.plot(val_curve, label=f'Validation {metric_name}')
        plt.xlabel('Epoch')
        plt.ylabel(metric_name)
        plt.title(f'Training and Validation {metric_name} over Epochs')
        plt.legend()
        plt.grid(True)
        plt.show()

In [None]:
def display_error_analysis(error_analysis, dataset_name="Test Set", num_examples=10):
    """Print how many predictions were wrong and list the first few of them.

    Args:
        error_analysis: Dict whose 'error_samples' entry is a list of
            (predicted, actual) string pairs.
        dataset_name: Label used in the printed headings.
        num_examples: Upper bound on the number of pairs that are listed.
    """
    separator = '=' * 60
    print(f"\n{separator}")
    print(f"{dataset_name} Detailed Analysis:")
    print(f"{separator}")

    failures = error_analysis['error_samples']
    print(f"Total incorrect samples in {dataset_name}: {len(failures)}")

    shown = min(num_examples, len(failures))
    print(f"First {shown} incorrect predictions:")
    # Text is printed as stored; no RTL reshaping is applied for console output.
    for position, (pred, target) in enumerate(failures[:shown], start=1):
        print(f"  Sample {position}: Predicted='{pred}', Actual='{target}'")

def plot_error_distributions(error_analysis, dataset_name="Test Set"):
    """Overlay histograms of the per-sample CER and WER values.

    Args:
        error_analysis: Dict with 'cer_distribution' and 'wer_distribution'
            sequences of error rates.
        dataset_name: Label used in the printed message and the plot title.
    """
    print(f"\nVisualizing {dataset_name} Error Distributions:")

    plt.figure(figsize=(12, 6))
    histograms = (
        ('cer_distribution', 'CER Distribution'),
        ('wer_distribution', 'WER Distribution'),
    )
    for analysis_key, legend_label in histograms:
        plt.hist(error_analysis[analysis_key], bins=20, alpha=0.7, label=legend_label)

    plt.xlabel('Error Rate')
    plt.ylabel('Frequency')
    plt.title(f'Distribution of CER and WER on {dataset_name}')
    plt.legend()
    plt.grid(True)
    plt.show()

In [None]:
def visualize_predictions(dataloader, text_processor, inference_pipeline, num_samples=12, dataset_mean=None, dataset_std=None):
    """Show images from the first batch of ``dataloader`` with actual and predicted text.

    Args:
        dataloader: Yields dicts with 'images' (normalised image tensors) and
            'texts' (ground-truth strings).
        text_processor: Unused here; kept for interface compatibility.
        inference_pipeline: Object exposing ``predict_batch(list_of_PIL_images)``.
        num_samples: Maximum number of samples to display.
        dataset_mean: Per-channel mean used to undo normalisation; falls back
            to ``config['dataset_mean']`` when None.
        dataset_std: Per-channel std used to undo normalisation; falls back to
            ``config['dataset_std']`` when None.
    """
    print("\nVisualizing samples with predictions...")
    batch = next(iter(dataloader))
    images = batch['images'][:num_samples]
    texts = batch['texts'][:num_samples]

    # The normalisation statistics do not depend on the image, so build them once.
    mean_val = np.array(dataset_mean if dataset_mean is not None else config['dataset_mean'])
    std_val = np.array(dataset_std if dataset_std is not None else config['dataset_std'])

    # The pipeline applies its own transform, so hand it de-normalised uint8 PIL images.
    inference_images = []
    for img_tensor in images:
        img_np = img_tensor.permute(1, 2, 0).cpu().numpy()  # (C, H, W) -> (H, W, C)
        img_np = img_np * std_val + mean_val
        img_np = np.clip(img_np, 0, 1) * 255  # Scale to 0-255
        inference_images.append(Image.fromarray(img_np.astype(np.uint8)))

    predicted_texts = inference_pipeline.predict_batch(inference_images)

    # Never index past the predictions that actually came back.
    num_samples_to_show = min(num_samples, len(predicted_texts))
    if num_samples_to_show <= 0:
        # plt.subplots rejects a grid with zero rows, so there is nothing to draw.
        print("No samples available to visualize.")
        return

    # Fixed 4-column grid; rows grow with the number of samples.
    cols = 4
    rows = (num_samples_to_show + cols - 1) // cols

    fig, axes = plt.subplots(rows, cols, figsize=(20, rows * 5))
    axes = axes.ravel() if isinstance(axes, np.ndarray) else [axes]

    for i in range(num_samples_to_show):
        # Reshape and reorder the text so matplotlib renders Arabic right-to-left.
        display_actual_text_plot = get_display(arabic_reshaper.reshape(texts[i]))
        display_predicted_text_plot = get_display(arabic_reshaper.reshape(predicted_texts[i]))

        axes[i].imshow(inference_images[i])
        axes[i].set_title(f"Actual: {display_actual_text_plot}\nPredicted: {display_predicted_text_plot}", fontsize=10)
        axes[i].axis('off')

    # Hide any unused subplots
    for j in range(num_samples_to_show, len(axes)):
        axes[j].axis('off')

    plt.tight_layout()
    plt.show()

### Epoch functions

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, scheduler, device, text_processor,
                epoch, total_epochs, use_amp=True, grad_clip=1.0):
    """Run one training epoch with CTC loss, optional AMP and gradient clipping.

    Args:
        model: Network returning logits of shape (B, W, vocab_size).
        dataloader: Yields dicts with 'images', 'encoded_texts', 'text_lengths'
            and 'texts'.
        criterion: CTC loss called as
            ``criterion(log_probs, targets, input_lengths, target_lengths)``.
        optimizer: Optimizer updated once per batch.
        scheduler: LR scheduler, stepped once per batch.
        device: Device the batch tensors are moved to.
        text_processor: Used for decoding predictions and by ``OCRMetrics``.
        epoch: Current epoch number (progress-bar label only).
        total_epochs: Total number of epochs (progress-bar label only).
        use_amp: Request mixed precision. It is only enabled when CUDA is also
            available.
        grad_clip: Max gradient norm; clipping is skipped when <= 0.

    Returns:
        Tuple ``(avg_loss, metrics_dict, avg_batch_time)``. Averages are taken
        over the batches that completed; skipped batches (out-of-memory or any
        other per-batch error) are excluded.
    """
    model.train()
    total_loss = 0

    # Clear CUDA cache before starting
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Honor the caller's flag, but never enable AMP without CUDA.
    use_amp = bool(use_amp) and torch.cuda.is_available()
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

    progress_bar = tqdm(dataloader, desc=f"Training Epoch {epoch}/{total_epochs}",
                       leave=True, dynamic_ncols=True)

    # One entry per batch whose loss was accumulated; its length doubles as the
    # count of completed batches.
    batch_times = []
    metrics = OCRMetrics(text_processor)

    for batch_idx, batch in enumerate(progress_bar):
        try:
            start_time = time.time()

            # Get batch data
            images = batch['images'].to(device)
            encoded_texts = batch['encoded_texts'].to(device)
            text_lengths = batch['text_lengths'].to(device)
            original_texts = batch['texts']

            # Drop the reference to the CPU-side batch once tensors are on the device.
            del batch

            optimizer.zero_grad()

            # Forward pass with mixed precision
            with torch.amp.autocast('cuda', enabled=use_amp):
                logits = model(images)  # Shape: (B, W, vocab_size)

                batch_size = images.size(0)
                seq_length = logits.size(1)  # CTC input length is the width W

                # Flatten the padded targets: keep only the first text_lengths[i]
                # tokens of each row.
                max_target_length = encoded_texts.size(1)
                mask = torch.arange(max_target_length).expand(batch_size, max_target_length).to(device) < text_lengths.unsqueeze(1)
                targets = encoded_texts[mask]
                target_lengths = text_lengths
                input_lengths = torch.full(size=(batch_size,), fill_value=seq_length, dtype=torch.long).to(device)

                # CTCLoss expects log-probabilities laid out as (W, B, vocab_size).
                log_probs = F.log_softmax(logits, dim=2).transpose(0, 1)
                loss = criterion(log_probs, targets.long(), input_lengths, target_lengths.long())

            # Backward pass with gradient scaling
            scaler.scale(loss).backward()

            # Gradients must be unscaled before their norm is clipped.
            if grad_clip > 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

            scaler.step(optimizer)
            scaler.update()

            # NOTE(review): the scheduler is stepped once per batch, while
            # train_complete_pipeline builds CosineAnnealingLR with
            # T_max=config['num_epochs']. Confirm which granularity is intended
            # and align T_max or this call accordingly.
            scheduler.step()

            total_loss += loss.item()
            batch_time = time.time() - start_time
            batch_times.append(batch_time)

            # ctc_decode expects (T, B, C), so swap the batch and width axes.
            decoded_texts = ctc_decode(logits.transpose(0, 1), text_processor)
            metrics.update(decoded_texts, original_texts)

            # Explicitly release per-batch tensors.
            del logits, log_probs, targets, mask, input_lengths
            del images, encoded_texts, text_lengths, original_texts

            current_metrics = metrics.get_metrics()
            progress_bar.set_postfix({
                'Loss': f'{loss.item():.4f}',
                'Avg Loss': f'{total_loss/len(batch_times):.4f}',
                'CER': f'{current_metrics.get("avg_cer", 0):.4f}',
                'Acc': f'{current_metrics.get("sequence_accuracy", 0):.4f}',
                'LR': f'{optimizer.param_groups[0]["lr"]:.6f}',
                'Time': f'{batch_time:.3f}s',
                'Mem': f'{get_gpu_memory_usage():.1f}GB' if torch.cuda.is_available() else 'N/A'
            })

            # Clear some memory periodically
            if batch_idx % 50 == 0:
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        except RuntimeError as e:
            if 'out of memory' in str(e).lower():
                print(f"OOM error at batch {batch_idx}. Skipping batch...")
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                continue
            else:
                # Re-raise other RuntimeErrors
                raise
        except Exception as e:
            # Best-effort: report any other per-batch failure and move on.
            print(f"Error processing batch {batch_idx}: {e}. Skipping batch...")
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            continue

    # Average over completed batches only, so skipped batches do not drag the
    # reported loss towards zero.
    completed_batches = len(batch_times)
    avg_loss = total_loss / completed_batches if completed_batches > 0 else 0.0
    avg_batch_time = np.mean(batch_times) if batch_times else 0.0
    final_metrics = metrics.get_metrics()

    return avg_loss, final_metrics, avg_batch_time

In [None]:
def validate_epoch(model, dataloader, criterion, device, text_processor, epoch, total_epochs):
    """Run one validation pass and collect loss and OCR metrics.

    Args:
        model: Network returning logits of shape (B, W, vocab_size).
        dataloader: Yields dicts with 'images', 'encoded_texts', 'text_lengths'
            and 'texts'.
        criterion: CTC loss called as
            ``criterion(log_probs, targets, input_lengths, target_lengths)``.
        device: Device the batch tensors are moved to.
        text_processor: Used for decoding predictions and by ``OCRMetrics``.
        epoch: Current epoch number (progress-bar label only).
        total_epochs: Total number of epochs (progress-bar label only).

    Returns:
        Tuple ``(avg_loss, metrics_dict, avg_batch_time)``. With an empty
        dataloader the averages are 0.0 and the metrics dict is empty.
    """
    model.eval()
    total_loss = 0
    num_batches = len(dataloader)

    metrics = OCRMetrics(text_processor)
    batch_times = []

    progress_bar = tqdm(dataloader, desc=f"Validation Epoch {epoch}/{total_epochs}",
                       leave=True, dynamic_ncols=True)

    with torch.no_grad():
        for batch_idx, batch in enumerate(progress_bar):
            start_time = time.time()

            # Get batch data
            images = batch['images'].to(device)
            encoded_texts = batch['encoded_texts'].to(device)
            text_lengths = batch['text_lengths'].to(device)
            original_texts = batch['texts']

            # Forward pass
            logits = model(images)  # Shape: (B, W, vocab_size)

            batch_size = images.size(0)
            seq_length = logits.size(1)  # CTC input length is the width W

            # Flatten the padded targets: keep only the first text_lengths[i]
            # tokens of each row.
            max_target_length = encoded_texts.size(1)
            mask = torch.arange(max_target_length).expand(batch_size, max_target_length).to(device) < text_lengths.unsqueeze(1)
            targets = encoded_texts[mask]
            target_lengths = text_lengths
            input_lengths = torch.full(size=(batch_size,), fill_value=seq_length, dtype=torch.long).to(device)

            # CTCLoss expects log-probabilities laid out as (W, B, vocab_size).
            log_probs = F.log_softmax(logits, dim=2).transpose(0, 1)

            loss = criterion(log_probs, targets.long(), input_lengths, target_lengths.long())

            total_loss += loss.item()
            batch_time = time.time() - start_time
            batch_times.append(batch_time)

            # ctc_decode expects (T, B, C), so swap the batch and width axes.
            decoded_texts = ctc_decode(logits.transpose(0, 1), text_processor)
            metrics.update(decoded_texts, original_texts)

            # Update progress bar
            current_metrics = metrics.get_metrics()
            progress_bar.set_postfix({
                'Loss': f'{loss.item():.4f}',
                'Avg Loss': f'{total_loss/(batch_idx+1):.4f}',
                'CER': f'{current_metrics.get("avg_cer", 0):.4f}',
                'Acc': f'{current_metrics.get("sequence_accuracy", 0):.4f}',
                'Time': f'{batch_time:.3f}s'
            })

    # Same empty-loader guards as train_epoch: no division by zero, no mean of [].
    avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
    avg_batch_time = np.mean(batch_times) if batch_times else 0.0
    final_metrics = metrics.get_metrics()

    return avg_loss, final_metrics, avg_batch_time

In [None]:
def test_model(model, test_loader, criterion, device, text_processor):
    """Evaluate the model on the test set and collect per-sample error details.

    Args:
        model: Network returning logits of shape (B, W, vocab_size).
        test_loader: Yields dicts with 'images', 'encoded_texts',
            'text_lengths' and 'texts'.
        criterion: CTC loss called as
            ``criterion(log_probs, targets, input_lengths, target_lengths)``.
        device: Device the batch tensors are moved to.
        text_processor: Used for decoding predictions and by ``OCRMetrics``.

    Returns:
        Tuple ``(avg_loss, metrics_dict, error_analysis)`` where
        ``error_analysis`` holds 'correct_samples' and 'error_samples' as lists
        of (prediction, target) pairs plus the per-sample 'cer_distribution'
        and 'wer_distribution'. With an empty loader ``avg_loss`` is 0.0.
    """
    print("Starting comprehensive testing...")
    model.eval()

    metrics = OCRMetrics(text_processor)
    total_loss = 0
    num_batches = len(test_loader)

    # Detailed analysis storage
    error_analysis = {
        'correct_samples': [],
        'error_samples': [],
        'cer_distribution': [],
        'wer_distribution': []
    }

    with torch.no_grad():
        progress_bar = tqdm(test_loader, desc="Testing", leave=True, dynamic_ncols=True)

        for batch in progress_bar:
            images = batch['images'].to(device)
            encoded_texts = batch['encoded_texts'].to(device)
            text_lengths = batch['text_lengths'].to(device)
            original_texts = batch['texts']

            # Forward pass
            logits = model(images)

            batch_size = images.size(0)
            seq_length = logits.size(1)  # CTC input length is the width W

            # Flatten the padded targets: keep only the first text_lengths[i]
            # tokens of each row.
            max_target_length = encoded_texts.size(1)
            mask = torch.arange(max_target_length).expand(batch_size, max_target_length).to(device) < text_lengths.unsqueeze(1)
            targets = encoded_texts[mask]
            target_lengths = text_lengths
            input_lengths = torch.full(size=(batch_size,), fill_value=seq_length, dtype=torch.long).to(device)

            # CTCLoss expects log-probabilities laid out as (W, B, vocab_size).
            log_probs = F.log_softmax(logits, dim=2).transpose(0, 1)

            loss = criterion(log_probs, targets.long(), input_lengths, target_lengths.long())

            total_loss += loss.item()

            # ctc_decode expects (T, B, C), so swap the batch and width axes.
            decoded_texts = ctc_decode(logits.transpose(0, 1), text_processor)
            metrics.update(decoded_texts, original_texts)

            # Per-sample error rates, normalised by target length (clamped to 1).
            for pred, target in zip(decoded_texts, original_texts):
                cer = editdistance.eval(pred, target) / max(len(target), 1)
                wer = editdistance.eval(pred.split(), target.split()) / max(len(target.split()), 1)

                error_analysis['cer_distribution'].append(cer)
                error_analysis['wer_distribution'].append(wer)

                if pred == target:
                    error_analysis['correct_samples'].append((pred, target))
                else:
                    error_analysis['error_samples'].append((pred, target))

    # Guard against an empty loader, matching train_epoch.
    avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
    final_metrics = metrics.get_metrics()

    return avg_loss, final_metrics, error_analysis

In [None]:
def train_complete_pipeline(model, train_loader, val_loader, test_loader, text_processor,
                          config, device, save_dir='./checkpoints'):
    """Run the full train / validate / checkpoint / evaluate cycle for a CTC model.

    Every epoch calls ``train_epoch`` and ``validate_epoch``, appends their
    results to ``history``, writes a "best" checkpoint whenever the validation
    CER improves, optionally writes a periodic checkpoint, and stops early once
    the validation CER has failed to improve for ``early_stopping_patience``
    consecutive epochs. After the loop the best checkpoint found in
    ``save_dir`` (if any) is loaded back into ``model`` and the model is
    evaluated with ``test_model`` on the test loader and on the validation
    loader.

    Args:
        model: Network to train; its parameters are handed to ``AdamW``.
        train_loader: Loader forwarded to ``train_epoch``.
        val_loader: Loader forwarded to ``validate_epoch`` and, after
            training, to ``test_model``.
        test_loader: Loader forwarded to ``test_model`` after training.
        text_processor: Object exposing ``char2idx`` and ``BLANK_TOKEN``; the
            CTC blank index is ``char2idx[BLANK_TOKEN]``. It is also stored
            inside the best checkpoint.
        config: Dict of hyper-parameters. Required keys: ``learning_rate``,
            ``weight_decay``, ``num_epochs``, ``gradient_clip``,
            ``early_stopping_patience``. Optional keys:
            ``best_model_filename`` (default ``'best_model.pth'``),
            ``regular_checkpoint_frequency`` (default ``5``; ``0`` or ``None``
            disables periodic checkpoints) and ``use_amp`` (default ``True``).
        device: Device forwarded to the epoch/evaluation helpers and used as
            ``map_location`` when the best checkpoint is reloaded.
        save_dir: Directory that receives the checkpoints; created if missing.

    Returns:
        Tuple ``(history, test_metrics, error_analysis_test,
        val_metrics_detailed, error_analysis_val, inference_pipeline)``.

    Raises:
        KeyError: If a required ``config`` key is missing. All required keys
            are read before the first epoch so this surfaces immediately.
    """

    # Create save directory
    os.makedirs(save_dir, exist_ok=True)

    # Read the whole configuration up front: a missing key should abort before
    # any training time is spent, not after the first non-improving epoch.
    num_epochs = config['num_epochs']
    patience = config['early_stopping_patience']
    grad_clip = config['gradient_clip']
    use_amp = config.get('use_amp', True)
    best_model_filename = config.get('best_model_filename', 'best_model.pth')
    regular_checkpoint_frequency = config.get('regular_checkpoint_frequency', 5)
    best_checkpoint_path = os.path.join(save_dir, best_model_filename)

    # Initialize components
    criterion = CTCLoss(blank=text_processor.char2idx[text_processor.BLANK_TOKEN],
                       reduction='mean', zero_infinity=True)

    optimizer = AdamW(model.parameters(), lr=config['learning_rate'],
                     weight_decay=config['weight_decay'])

    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs,
                                 eta_min=config['learning_rate'] * 0.01)

    def _training_state(epoch):
        """Return the resumable state that every checkpoint has in common."""
        return {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'config': config,
        }

    # Training history
    history = {
        'train_loss': [], 'train_metrics': [],
        'val_loss': [], 'val_metrics': [],
        'best_val_cer': float('inf'), 'best_epoch': 0,
        'epochs_no_improve': 0 # Counter for early stopping
    }

    print(f"Starting training for {num_epochs} epochs...")
    print(f"Using device: {device}")
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Training loop
    for epoch in range(1, num_epochs + 1):
        print(f"\n{'='*60}")
        print(f"Epoch {epoch}/{num_epochs}")
        print(f"{'='*60}")

        # Training
        train_loss, train_metrics, train_time = train_epoch(
            model, train_loader, criterion, optimizer, scheduler,
            device, text_processor, epoch, num_epochs,
            use_amp=use_amp, grad_clip=grad_clip
        )

        # Validation
        val_loss, val_metrics, val_time = validate_epoch(
            model, val_loader, criterion, device, text_processor,
            epoch, num_epochs
        )

        # Update history
        history['train_loss'].append(train_loss)
        history['train_metrics'].append(train_metrics)
        history['val_loss'].append(val_loss)
        history['val_metrics'].append(val_metrics)

        # Print epoch summary
        print(f"\nEpoch {epoch} Summary:")
        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
        print(f"Train CER: {train_metrics['avg_cer']:.4f} | Val CER: {val_metrics['avg_cer']:.4f}")
        print(f"Train Acc: {train_metrics['sequence_accuracy']:.4f} | Val Acc: {val_metrics['sequence_accuracy']:.4f}")
        print(f"Train Time: {train_time:.2f}s | Val Time: {val_time:.2f}s")

        # Save best model and track the early-stopping counter
        current_val_cer = val_metrics['avg_cer']
        if current_val_cer < history['best_val_cer']:
            history['best_val_cer'] = current_val_cer
            history['best_epoch'] = epoch
            history['epochs_no_improve'] = 0 # Reset counter

            # The best checkpoint additionally carries the score and the text
            # processor so it can be used for inference on its own.
            checkpoint = _training_state(epoch)
            checkpoint['val_cer'] = current_val_cer
            checkpoint['text_processor'] = text_processor

            torch.save(checkpoint, best_checkpoint_path)
            print(f"✓ New best model saved with CER: {current_val_cer:.4f}")
        else:
            history['epochs_no_improve'] += 1
            print(f"Validation CER did not improve. Epochs with no improvement: {history['epochs_no_improve']}/{patience}")

        # Save regular checkpoint based on configurable frequency.
        # A falsy frequency (0 / None) disables periodic checkpoints instead of
        # crashing on the modulo.
        if regular_checkpoint_frequency and epoch % regular_checkpoint_frequency == 0:
            torch.save(_training_state(epoch),
                       os.path.join(save_dir, f'checkpoint_epoch_{epoch}.pth'))

        # Check for early stopping
        if history['epochs_no_improve'] >= patience:
            print(f"\nEarly stopping triggered after {patience} epochs with no improvement on validation CER.")
            break # Exit the training loop

    # --- Post-training Analysis and Visualization ---
    print(f"\n{'='*60}")
    print("Training completed!")
    print(f"{'='*60}")

    # Load best model for final evaluation and inference
    if os.path.exists(best_checkpoint_path):
        try:
            # weights_only=False is required because the checkpoint also holds
            # the pickled config and text processor, not just tensors.
            best_checkpoint = torch.load(best_checkpoint_path, map_location=device, weights_only=False)
            model.load_state_dict(best_checkpoint['model_state_dict'])
            print(f"Loaded best model from epoch {best_checkpoint['epoch']} for final evaluation.")
        except Exception as e:
            # Best effort: a broken checkpoint must not discard the trained model.
            print(f"Error loading best model checkpoint from {best_checkpoint_path}: {e}")
            print("Proceeding with the model from the last training epoch.")
    else:
        print(f"Best model checkpoint not found at {best_checkpoint_path}.")
        print("Proceeding with the model from the last training epoch.")

    # Evaluate on test set
    test_loss, test_metrics, error_analysis_test = test_model(
        model, test_loader, criterion, device, text_processor
    )

    # Evaluate on validation set for detailed analysis
    val_loss_detailed, val_metrics_detailed, error_analysis_val = test_model(
         model, val_loader, criterion, device, text_processor
    )

    # Create inference pipeline with the potentially best loaded model
    inference_pipeline = ArabicOCRInference(model, text_processor, device)

    return history, test_metrics, error_analysis_test, val_metrics_detailed, error_analysis_val, inference_pipeline

## Execution

#### RUN

In [None]:
# Initialize model (every architecture hyper-parameter comes from `config`)
model = SVTR(
    img_size=(config['img_height'], config['img_width']),
    in_chans=config['channels'],
    vocab_size=config['vocab_size'],
    local_type=config['local_type'],
    embed_dims=config['embed_dims'],
    heads=config['heads'],
    mlp_ratio=config['mlp_ratio'],
    window_sizes=config['window_sizes'],
    num_blocks=config['num_blocks'],
    pattern=config['pattern'],
    drop=config['dropout_rate'],
    n_points=config['n_points'],
    offset_scale=config['offset_scale'],
).to(config['device'])


# Optionally warm-start from a previously trained checkpoint.
# `load_model_path` may point at a checkpoint file or at a directory holding one.
if config.get('load_model_path'):
    model_path = config['load_model_path']
    exclude_head = config.get('exclude_head_on_load', True)  # default True for safety
    print(f"Loading model from: {model_path} (exclude_head={exclude_head})")
    try:
        # Step 1: resolve the path to a single checkpoint file (or None).
        if os.path.isdir(model_path):
            print(f"'{model_path}' is a directory. Searching for a model file inside...")
            # os.listdir returns entries in arbitrary order; sorting makes the
            # selected checkpoint reproducible across runs and machines.
            candidates = sorted(f for f in os.listdir(model_path)
                                if f.endswith(('.pth', '.pt')))
            model_file = os.path.join(model_path, candidates[0]) if candidates else None
            if model_file:
                print(f"Found model file: {model_file}")
            else:
                print(f"Error: No .pth or .pt model file found in directory {model_path}")
        else:
            model_file = model_path

        # Step 2: load the weights once, whichever way the file was found.
        if model_file:
            # NOTE: weights_only=False unpickles arbitrary Python objects, so
            # only load checkpoints that come from a trusted source.
            checkpoint = torch.load(model_file, map_location=config['device'], weights_only=False)

            pretrained_state_dict = checkpoint['model_state_dict']
            model_state_dict = model.state_dict()

            # Keep only tensors the current model knows about; optionally drop
            # the prediction head.
            new_state_dict = {
                k: v for k, v in pretrained_state_dict.items()
                if k in model_state_dict and not (exclude_head and k.startswith('head.'))
            }

            if new_state_dict:
                model_state_dict.update(new_state_dict)
                model.load_state_dict(model_state_dict)
                print(f"Restored {len(new_state_dict)} of {len(model_state_dict)} model tensors from the checkpoint.")
                print("Model loaded successfully!")
            else:
                # Nothing matched: reporting success here would hide a silent no-op.
                print("Warning: no checkpoint keys matched the model; weights were left unchanged.")

    except FileNotFoundError:
        print(f"Error: Model file or directory not found at {model_path}")
    except Exception as e:
        # Best effort: a failed warm start falls back to the freshly initialized model.
        print(f"Error loading model: {e}")


# Train the complete pipeline
history, test_metrics, error_analysis_test, val_metrics_detailed, error_analysis_val, inference_pipeline = train_complete_pipeline(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    text_processor=text_processor,
    config=config,
    device=config['device'],
    save_dir=config['save_path_directory']
)

# --- Utilization of pipeline outputs ---

# Visualize training history (loss and metrics)
plot_training_history(history)

# Display final test metrics
print("\nFinal Test Metrics:")
for metric_name, value in test_metrics.items():
    print(f"{metric_name}: {value:.4f}")

# Display detailed error analysis for validation set
display_error_analysis(error_analysis_val, dataset_name="Validation Set")
plot_error_distributions(error_analysis_val, dataset_name="Validation Set")

# Display detailed error analysis for test set
display_error_analysis(error_analysis_test, dataset_name="Test Set")
plot_error_distributions(error_analysis_test, dataset_name="Test Set")

# Visualize samples with predictions from the test loader using the inference pipeline
visualize_predictions(test_loader, text_processor, inference_pipeline, num_samples=12,
                      dataset_mean=config['dataset_mean'], dataset_std=config['dataset_std'])


# Example inference usage with the returned inference_pipeline
print("\nExample Inference using the returned pipeline:")
# Build a fresh list so shuffling never reorders the datasets' own label lists.
all_inference_labels = list(val_dataset.labels) + list(test_dataset.labels)
num_inference_samples = 32  # Adjust this number as needed
random.shuffle(all_inference_labels)
inference_samples = all_inference_labels[:num_inference_samples]

if inference_samples:
    print(f"Performing inference on {len(inference_samples)} random samples from validation and test sets:")
    inference_images_paths = [img_path for img_path, _ in inference_samples]
    actual_texts = [text for _, text in inference_samples]

    predicted_texts = inference_pipeline.predict_batch(inference_images_paths)

    print("\nInference Results:")
    for sample_number, (actual_text, predicted_text) in enumerate(zip(actual_texts, predicted_texts), start=1):
        print(f"Sample {sample_number}:")
        print(f"  Actual:    '{actual_text}'")
        print(f"  Predicted: '{predicted_text}'")
        print(f"  Correct:   {actual_text == predicted_text}")
        print("-" * 20)
else:
    print("No samples available in validation or test dataset for inference example.")
