# 1. Setup Environment

This section clones the Arabic-Lip-Reading repository, installs required dependencies, and configures the model path for subsequent steps.


## 1.1 Clone Repository

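In this step, we clone the Arabic-Lip-Reading repository from GitHub to access the dataset and model code required for training. We also add its `model` directory to `sys.path` so that later cells can import project modules such as `utils` and `e2e_vsr`.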


In [1]:
import subprocess
import os
import sys

# Clone the repository using subprocess
repo_url = "https://github.com/Essa-Ramzy/Arabic-Lip-Reading"
repo_name = "Arabic-Lip-Reading"

# Check if the repository already exists
if not os.path.exists(repo_name):
    try:
        print(f"🔄 Executing Command: git clone --progress {repo_url}")
        print("-" * 92)
        
        # Use subprocess.call so Git writes inline progress and we capture exit code
        return_code = subprocess.call(
            ["git", "clone", "--progress", repo_url],
            stderr=subprocess.STDOUT
        )
        
        print("-" * 92)
        if return_code == 0:
            print(f"✅ Successfully cloned {repo_url}")
            print(f"📁 Repository saved to: {os.path.abspath(repo_name)}")
        else:
            print(f"❌ Git clone failed with return code: {return_code}")
            
    except FileNotFoundError:
        print("❌ Error: Git is not installed or not in PATH")
        print("💡 Please install Git first: https://git-scm.com/downloads")
    except Exception as e:
        print(f"❌ Unexpected error: {e}")
        import traceback
        traceback.print_exc()
else:
    print(f"ℹ️  Repository {repo_name} already exists - skipping clone")
    print(f"📁 Location: {os.path.abspath(repo_name)}")

# Add the model directory to the system path
model_path = os.path.join(repo_name, 'model')
if model_path not in sys.path:
    sys.path.append(model_path)
    print(f"📁 Added model path to system path: {model_path}")

🔄 Executing Command: git clone --progress https://github.com/Essa-Ramzy/Arabic-Lip-Reading
--------------------------------------------------------------------------------------------
Cloning into 'Arabic-Lip-Reading'...
remote: Enumerating objects: 119232, done.        
remote: Counting objects: 100% (98/98), done.        
remote: Compressing objects: 100% (51/51), done.        
remote: Total 119232 (delta 69), reused 65 (delta 47), pack-reused 119134 (from 5)        
Receiving objects: 100% (119232/119232), 2.05 GiB | 37.33 MiB/s, done.
Resolving deltas: 100% (6678/6678), done.
Updating files: 100% (61602/61602), done.
--------------------------------------------------------------------------------------------
✅ Successfully cloned https://github.com/Essa-Ramzy/Arabic-Lip-Reading
📁 Repository saved to: /kaggle/working/Arabic-Lip-Reading
📁 Added model path to system path: Arabic-Lip-Reading/model


## 1.2 Install Dependencies

In this step, we use pip to install the two additional Python libraries this notebook depends on for preprocessing, training, and evaluation: kornia and editdistance.


In [None]:
import subprocess
import sys

# Install Python dependencies into the same interpreter that runs this notebook
try:
    print("🔄 Executing Command: pip install kornia editdistance")
    print("-" * 90)
    return_code = subprocess.call(
        [sys.executable, "-m", "pip", "install", "kornia", "editdistance"],
        stderr=subprocess.STDOUT
    )
    print("-" * 90)
    if return_code == 0:
        print("✅ Successfully installed kornia and editdistance")
    else:
        print(f"❌ pip install failed with return code: {return_code}")
except Exception as e:
    print(f"❌ Unexpected error during pip install: {e}")
    import traceback
    traceback.print_exc()

🔄 Executing Command: pip install kornia editdistance
------------------------------------------------------------------------------------------
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.9.1->kornia)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.9.1->kornia)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch>=1.9.1->kornia)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch>=1.9.1->kornia)
  Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch>=1.9.1->kornia)
  Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cusparse-cu12==12.3.1.17

# 2. Data Preparation

This section outlines steps to download and extract the dataset, extract tokens and labels, and prepare PyTorch DataLoaders to serve batches during training and validation.


## 2.1 Configure Environment and Logging

In this subsection, we set the dataset normalization parameters and configure logging to record training progress and debug information. We also seed all random number generators for reproducibility and select the compute device.
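The mean and standard deviation passed to `set_normalization_params` are dataset-wide pixel-intensity statistics. There is apparently one pair for the manually verified subset and another for the full training set. How the repository computed them is not shown here.

The following is a minimal sketch, assuming grayscale frames with intensities already scaled to [0, 1]. It accumulates such statistics in one streaming pass with Welford's algorithm and then applies them as `(x - mean) / std`. The class, function, and frame data are hypothetical.

```python
import math
from typing import Iterable, Sequence


class RunningStats:
    """Hypothetical helper: streaming mean/std via Welford's algorithm (population std)."""

    def __init__(self) -> None:
        self.count = 0
        self.mean = 0.0
        self._m2 = 0.0  # running sum of squared deviations from the mean

    def update(self, values: Iterable[float]) -> None:
        for x in values:
            self.count += 1
            delta = x - self.mean
            self.mean += delta / self.count
            self._m2 += delta * (x - self.mean)

    @property
    def std(self) -> float:
        return math.sqrt(self._m2 / self.count) if self.count else 0.0


def normalize(frame: Sequence[float], mean: float, std: float) -> list:
    """Standardize pixel values as a (x - mean) / std transform would."""
    return [(x - mean) / std for x in frame]


# Hypothetical flattened grayscale frames with intensities in [0, 1]
frames = [[0.1, 0.4, 0.5], [0.3, 0.6, 0.2]]
stats = RunningStats()
for frame in frames:
    stats.update(frame)
print(f"mean={stats.mean:.4f}, std={stats.std:.4f}")
print(normalize(frames[0], stats.mean, stats.std))
```

A single pass matters here because the full set of video frames does not fit in memory at once.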


In [3]:
import os
import gc
import shutil
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, WeightedRandomSampler
import numpy as np
from utils import *
import logging
from datetime import datetime
import traceback
from e2e_vsr import E2EVSR
import random
from math import ceil
from torch.amp import GradScaler, autocast
import subprocess
import zipfile

# Dataset configuration
DATASET_ROOT = 'Arabic-Lip-Reading/dataset'
DATASET_NAME = 'LRC-AR'
WITH_SPACES = True
WITH_DIARITICS = True
MANUAL_ONLY = False

set_normalization_params(mean=0.40589704064965376 if MANUAL_ONLY else 0.40947135433671134,
                         std=0.14899824732759526 if MANUAL_ONLY else 0.15003469454968454)

os.makedirs('Logs', exist_ok=True)
log_filename = f'Logs/training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'

for h in logging.root.handlers[:]:
    logging.root.removeHandler(h)

logging.basicConfig(
    filename=log_filename,
    level=logging.INFO,
    format='%(message)s',
    encoding='utf-8',
    force=True 
)

# Helper to print and log in one call
def log_print(msg):
    print(msg)
    logging.info(msg)

# Setting the seed for reproducibility
seed = 0
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
generator = torch.Generator()
generator.manual_seed(seed)

if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Set the start method to 'spawn' for CUDA safety.
torch.multiprocessing.set_start_method('spawn', force=True)

Updated normalization parameters: MEAN=0.40947135433671134, STD=0.15003469454968454


## 2.2 Download and Extract Dataset

In this subsection, we download the LRC-AR dataset archive from Google Drive if it is not already present, extract it into the dataset directory, and delete the archive. We then scan the transcript CSVs of all training and validation subsets to build the token vocabulary (`mapped_tokens`) used by the data loaders and the model.
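A download can finish with exit code 0 and still leave a damaged archive, in which case extraction fails partway through. The sketch below uses only the standard-library `zipfile` module to validate the archive before extracting and deleting it. The helper name `extract_verified_zip` is hypothetical.

```python
import os
import zipfile


def extract_verified_zip(zip_path: str, dest_dir: str, remove_after: bool = True) -> None:
    """Hypothetical helper: extract zip_path into dest_dir only if the archive is intact."""
    if not zipfile.is_zipfile(zip_path):
        raise zipfile.BadZipFile(f"{zip_path} is not a valid zip archive")
    with zipfile.ZipFile(zip_path, "r") as zf:
        # testzip() returns the name of the first member with a bad CRC, or None
        bad_member = zf.testzip()
        if bad_member is not None:
            raise zipfile.BadZipFile(f"Corrupted member in archive: {bad_member}")
        zf.extractall(dest_dir)
    if remove_after:
        os.remove(zip_path)  # free disk space once the contents are extracted
```

In this notebook's terms, it would be called as `extract_verified_zip(dataset_zip, DATASET_ROOT)` after `gdown` succeeds.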


In [None]:
# Ensure dataset is available
dataset_dir = os.path.join(DATASET_ROOT, DATASET_NAME)
dataset_zip = os.path.join(DATASET_ROOT, f'{DATASET_NAME}.zip')
os.makedirs(DATASET_ROOT, exist_ok=True)
if not os.path.exists(dataset_dir):
    try:
        print("🔄 Executing Command: gdown https://drive.google.com/uc?id=1tX5YYTPbpWnOmj5zYEt8vG76iiq0kO8e")
        print("-" * 93)
        return_code = subprocess.call(
            ["gdown", "https://drive.google.com/uc?id=1tX5YYTPbpWnOmj5zYEt8vG76iiq0kO8e", "-O", dataset_zip],
            stderr=subprocess.STDOUT
        )
        print("-" * 93)
        if return_code == 0:
            print("✅ Downloaded dataset")
            # Extract only after a successful download, then delete the archive to free disk space
            with zipfile.ZipFile(dataset_zip, 'r') as zip_ref:
                zip_ref.extractall(DATASET_ROOT)
            os.remove(dataset_zip)
        else:
            print(f"❌ Download failed with return code: {return_code}")
    except Exception as e:
        print(f"❌ Unexpected error during dataset download: {e}")
        traceback.print_exc()

tokens = set()
# Collect every token that appears in the transcripts of all training and validation subsets
for split in ['Train/Manually_Verified', 'Train/Gemini_Transcribed', 'Val/Manually_Verified']:
    csv_dir = os.path.join(DATASET_ROOT, DATASET_NAME, split, 'Csv')
    for fname in os.listdir(csv_dir):
        path = os.path.join(csv_dir, fname)
        label = extract_label(path, with_spaces=WITH_SPACES, with_diaritics=WITH_DIARITICS)
        tokens.update(label)

# Assign indices starting at 1; index 0 is reserved for the CTC blank token
mapped_tokens = {}
for i, c in enumerate(sorted(tokens, reverse=True), 1):
    mapped_tokens[c] = i

log_print(mapped_tokens)

🔄 Executing Command: gdown https://drive.google.com/uc?id=1tX5YYTPbpWnOmj5zYEt8vG76iiq0kO8e
---------------------------------------------------------------------------------------------
Downloading...
From (original): https://drive.google.com/uc?id=1tX5YYTPbpWnOmj5zYEt8vG76iiq0kO8e
From (redirected): https://drive.google.com/uc?id=1tX5YYTPbpWnOmj5zYEt8vG76iiq0kO8e&confirm=t&uuid=e931e166-7e70-40c5-9242-ad22a7c02168
To: /kaggle/working/Arabic-Lip-Reading/dataset/LRC-AR.zip
100%|██████████| 512M/512M [00:03<00:00, 147MB/s]  
---------------------------------------------------------------------------------------------
✅ Downloaded dataset
{'ٱ': 1, 'يْ': 2, 'يّْ': 3, 'يِّ': 4, 'يُّ': 5, 'يَّ': 6, 'يٌّ': 7, 'يِ': 8, 'يُ': 9, 'يَ': 10, 'يٌ': 11, 'ي': 12, 'ى': 13, 'وْ': 14, 'وِّ': 15, 'وُّ': 16, 'وَّ': 17, 'وِ': 18, 'وُ': 19, 'وَ': 20, 'وً': 21, 'و': 22, 'هْ': 23, 'هُّ': 24, 'هِ': 25, 'هُ': 26, 'هَ': 27, 'نۢ': 28, 'نْ': 29, 'نِّ': 30, 'نُّ': 31, 'نَّ': 32, 'نِ': 33, 'نُ': 34, 'نَ': 35, 'مْ': 
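The printed mapping shows two things:

- Each token is a base letter together with any diacritics attached to it, for example `يّْ`.
- Indices start at 1, because index 0 is reserved for the CTC blank in Section 3.1.

The exact splitting rules live in the repository's `extract_label`. The hypothetical sketch below assumes diacritics are Unicode combining marks. Under that assumption, it reproduces the grouping and the reverse-sorted index assignment with the standard-library `unicodedata` module.

```python
import unicodedata
from typing import Dict, Iterable, List


def split_into_tokens(text: str, with_spaces: bool = True) -> List[str]:
    """Group each base character with the combining marks (diacritics) that follow it.

    Hypothetical stand-in for the repository's extract_label, for illustration only.
    """
    tokens: List[str] = []
    for ch in text:
        if unicodedata.combining(ch) and tokens and tokens[-1] != " ":
            tokens[-1] += ch  # attach the diacritic to the preceding base letter
        elif ch == " ":
            if with_spaces:
                tokens.append(ch)
        else:
            tokens.append(ch)
    return tokens


def build_token_map(labels: Iterable[List[str]]) -> Dict[str, int]:
    """Assign indices from 1 in reverse-sorted order, mirroring the cell above."""
    vocab = set()
    for label in labels:
        vocab.update(label)
    return {tok: i for i, tok in enumerate(sorted(vocab, reverse=True), 1)}


# Arabic phrase written with escapes: letters followed by kasra, sukun, shadda, fatha marks
sample = "\u0628\u0650\u0633\u0652\u0645\u0650 \u0671\u0644\u0644\u0651\u064e\u0647\u0650"
tokens = split_into_tokens(sample)
print(tokens)
print(build_token_map([tokens]))
```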

## 2.3 Prepare Data Splits and Loaders

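This subsection builds the training and validation splits and sets up PyTorch DataLoaders for efficient batch loading.

- **Training split:** combines the manually verified clips with the Gemini-transcribed (auto-labeled) clips, unless `MANUAL_ONLY` is set. In the mixed case, a `WeightedRandomSampler` gives each manually verified clip five times the sampling weight of an auto-labeled one.
- **Validation split:** uses only the manually verified clips from the `Val` directory.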
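With weights of 5.0 for manually verified clips and 1.0 for auto-labeled clips, drawn with replacement and with `num_samples` equal to the training-set size, the expected share of manual clips per epoch follows from the weights alone.

The sketch below computes that share and checks it empirically with the standard-library `random.choices`, which also samples with replacement in proportion to weights. The clip counts are hypothetical, not the real LRC-AR sizes.

```python
import random


def expected_manual_fraction(n_manual: int, n_auto: int, manual_weight: float = 5.0) -> float:
    """Probability that one weighted draw returns a manually verified clip."""
    manual_mass = manual_weight * n_manual
    return manual_mass / (manual_mass + 1.0 * n_auto)


# Hypothetical subset sizes, not the real dataset counts
n_manual, n_auto = 1000, 4000
frac = expected_manual_fraction(n_manual, n_auto)

# Empirical check: draw len(weights) samples with replacement, like the sampler does
is_manual = [True] * n_manual + [False] * n_auto
weights = [5.0 if m else 1.0 for m in is_manual]
rng = random.Random(0)
draws = rng.choices(is_manual, weights=weights, k=len(weights))
print(f"expected {frac:.3f}, observed {sum(draws) / len(draws):.3f}")
```

Because sampling is with replacement, some auto-labeled clips may not appear at all in a given epoch while some manual clips appear several times.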


In [27]:
def load_dataloaders(manual_only=MANUAL_ONLY):
    # Helper to gather video and csv file paths from directories
    def load_paths_from_dir(video_dir, csv_dir):
        videos, labels = [], []
        for fname in sorted(os.listdir(video_dir)):
            if not fname.endswith('.mp4'):
                continue
            base = os.path.splitext(fname)[0]
            videos.append(os.path.join(video_dir, fname))
            labels.append(os.path.join(csv_dir, base + ".csv"))
        return videos, labels
    
    # --- Step 1: Prepare paths using existing dataset directory ---
    X_train, y_train, is_manual = [], [], []
    # Select subsets based on manual_only flag
    subsets = [('Manually_Verified', True)]
    if not manual_only:
        subsets.append(('Gemini_Transcribed', False))
    for subset, manual_flag in subsets:
        vid_dir = os.path.join(DATASET_ROOT, DATASET_NAME, 'Train', subset, 'Video')
        csv_dir = os.path.join(DATASET_ROOT, DATASET_NAME, 'Train', subset, 'Csv')
        v_paths, l_paths = load_paths_from_dir(vid_dir, csv_dir)
        X_train.extend(v_paths)
        y_train.extend(l_paths)
        is_manual.extend([manual_flag] * len(v_paths))
    # Validation from Val/Manually_Verified
    val_vid_dir = os.path.join(DATASET_ROOT, DATASET_NAME, 'Val', 'Manually_Verified', 'Video')
    val_csv_dir = os.path.join(DATASET_ROOT, DATASET_NAME, 'Val', 'Manually_Verified', 'Csv')
    X_val, y_val = load_paths_from_dir(val_vid_dir, val_csv_dir)

    # --- Step 2: Loader Setup ---
    train_transform = VideoAugmentation(is_train=True)
    train_dataset = VideoDataset(X_train, y_train, mapped_tokens, with_spaces=WITH_SPACES, with_diaritics=WITH_DIARITICS, transform=train_transform)
    val_transform = VideoAugmentation(is_train=False)
    val_dataset = VideoDataset(X_val, y_val, mapped_tokens, with_spaces=WITH_SPACES, with_diaritics=WITH_DIARITICS, transform=val_transform)

    if manual_only:
        train_loader = DataLoader(
            train_dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True,
            collate_fn=pad_packed_collate, generator=generator, worker_init_fn=seed_worker
        )
    else:
        weights = [5.0 if manual else 1.0 for manual in is_manual]
        sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True, generator=generator)
        train_loader = DataLoader(
            train_dataset, batch_size=64, sampler=sampler, num_workers=4, pin_memory=True,
            collate_fn=pad_packed_collate, generator=generator, worker_init_fn=seed_worker
        )

    val_loader = DataLoader(
        val_dataset, batch_size=64, shuffle=False, num_workers=4,
        pin_memory=True, collate_fn=pad_packed_collate,
        worker_init_fn=seed_worker
    )

    return train_loader, val_loader

# 3. Model Configuration

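This section covers the model and its training setup:

- builds the vocabulary indices;
- configures the candidate temporal encoder architectures (DenseTCN, MSTCN, Conformer);
- initializes the end-to-end E2EVSR model and loads pretrained weights;
- sets up the optimizer and learning-rate schedule.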


## 3.1 Build Token Mappings

In this subsection, we construct the reverse index-to-token mapping, including the blank, SOS, and EOS tokens. This mapping is needed to decode the model's output index sequences back into readable text.
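The resulting index layout is:

- index 0 is the CTC blank;
- indices 1 to N are the dataset tokens from `mapped_tokens`;
- the last two indices are `<sos>` and `<eos>`.

The vocabulary size is therefore N + 3. The 227 dataset tokens give the 230 printed by the cell.

How the repository's `indices_to_text` treats special tokens is not shown. The hypothetical sketch below skips the blank and `<sos>` and stops at `<eos>`. It also shows the standard greedy CTC collapse (merge repeats, then drop blanks) for frame-level outputs. All function names are illustrative.

```python
from typing import Dict, List, Sequence


def build_idx2char(mapped_tokens: Dict[str, int]) -> List[str]:
    """Mirror the notebook's layout: blank, dataset tokens, <sos>, <eos>."""
    idx2char = [""]  # index 0: CTC blank
    idx2char.extend(mapped_tokens.keys())  # assumes values run 1..N in key order
    idx2char += ["<sos>", "<eos>"]
    return idx2char


def decode_indices(indices: Sequence[int], idx2char: List[str]) -> str:
    """Hypothetical decoder: drop blank and <sos>, stop at the first <eos>."""
    sos, eos = len(idx2char) - 2, len(idx2char) - 1
    out = []
    for i in indices:
        if i == eos:
            break
        if i in (0, sos):
            continue
        out.append(idx2char[i])
    return "".join(out)


def ctc_greedy_collapse(frame_ids: Sequence[int], blank: int = 0) -> List[int]:
    """Merge consecutive repeats, then remove blanks (standard CTC best-path rule)."""
    collapsed, prev = [], None
    for i in frame_ids:
        if i != prev and i != blank:
            collapsed.append(i)
        prev = i
    return collapsed


vocab = build_idx2char({"b": 1, "a": 2})  # toy tokens
print(decode_indices([3, 1, 2, 4, 1], vocab))
print(ctc_greedy_collapse([1, 1, 0, 2, 2, 0, 2]))
```

In the collapse, a blank between two identical labels keeps them as separate tokens. That is how CTC represents doubled letters.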


In [28]:
# Build reverse mapping for decoding
idx2char = [""]  # Blank token for CTC
idx2char.extend(mapped_tokens.keys())
sos_token_idx = len(idx2char)
idx2char.append("<sos>")  # SOS token
eos_token_idx = sos_token_idx + 1
idx2char.append("<eos>")  # EOS token
full_vocab_size = eos_token_idx + 1
accumulation_steps = 4  # Gradient accumulation factor, used by the scheduler and the training loop
log_print(f"Total vocabulary size: {full_vocab_size}")
log_print(f"SOS token index: {sos_token_idx}")
log_print(f"EOS token index: {eos_token_idx}")

Total vocabulary size: 230
SOS token index: 228
EOS token index: 229


## 3.2 Temporal Encoder Options

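In this subsection, we define configuration dictionaries for the three supported temporal encoders (DenseTCN, MSTCN, and Conformer). Each specifies layer blocks, growth rates or channel widths, kernel sizes, dilations, and dropout rates. We then select the Conformer for this run and derive the model's hidden dimension from it.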
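Only the dictionary matching `TEMPORAL_ENCODER` determines the hidden size passed to `E2EVSR`, and the Conformer settings must be internally consistent:

- Multi-head attention splits `attention_dim` evenly across heads, giving 768 / 12 = 64 per head.
- `linear_units = 3072` is the common 4× feed-forward expansion.

A hypothetical sanity check, written as a table lookup instead of an `if`/`elif` chain, could look like this. The function names are illustrative.

```python
from typing import Dict


def resolve_hidden_dim(encoder: str, options: Dict[str, dict]) -> int:
    """Look up the encoder's output width; raise on an unknown encoder name."""
    key_for = {"densetcn": "hidden_dim", "mstcn": "hidden_dim", "conformer": "attention_dim"}
    if encoder not in key_for:
        raise ValueError(f"Unknown TEMPORAL_ENCODER: {encoder}")
    return options[encoder][key_for[encoder]]


def check_attention_config(attention_dim: int, heads: int) -> int:
    """Return the per-head dimension; fail if attention_dim is not divisible by heads."""
    if attention_dim % heads != 0:
        raise ValueError(f"attention_dim={attention_dim} not divisible by heads={heads}")
    return attention_dim // heads


# Subset of the notebook's settings
options = {
    "densetcn": {"hidden_dim": 768},
    "mstcn": {"hidden_dim": 768},
    "conformer": {"attention_dim": 768, "attention_heads": 12},
}
hidden = resolve_hidden_dim("conformer", options)
head_dim = check_attention_config(hidden, options["conformer"]["attention_heads"])
print(hidden, head_dim)
```

The same divisibility check applies to the decoder, which also uses 12 attention heads on `e2e_hidden_dim`.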


In [29]:
# DenseTCN configuration
densetcn_options = {
    'block_config': [4, 4, 4, 4],
    'growth_rate_set': [512, 512, 512, 512],
    'reduced_size': 768,
    'kernel_size_set': [3, 5, 7, 9],
    'dilation_size_set': [1, 2, 4, 8],
    'squeeze_excitation': True,
    'dropout': 0.1,
    'hidden_dim': 768,
}

# MSTCN configuration
mstcn_options = {
    'tcn_type': 'multiscale',
    'hidden_dim': 768,
    'num_channels': [512, 512, 512, 512],
    'kernel_size': [3, 5, 7, 9],                   
    'dropout': 0.1,
    'stride': 1,
    'width_mult': 1.0,
}

# Conformer configuration (our default backbone)
conformer_options = {
    'attention_dim': 768,
    'attention_heads': 12,
    'linear_units': 3072,
    'num_blocks': 12,
    'dropout_rate': 0.1,
    'positional_dropout_rate': 0.1,
    'attention_dropout_rate': 0.1,
    'cnn_module_kernel': 31
}


# Choose temporal encoder type: 'densetcn', 'mstcn', or 'conformer'
TEMPORAL_ENCODER = 'conformer'

# Determine hidden_dim for E2EVSR based on the chosen temporal encoder
if TEMPORAL_ENCODER == 'densetcn':
    e2e_hidden_dim = densetcn_options['hidden_dim']
elif TEMPORAL_ENCODER == 'mstcn':
    e2e_hidden_dim = mstcn_options['hidden_dim']
elif TEMPORAL_ENCODER == 'conformer':
    e2e_hidden_dim = conformer_options['attention_dim']
else:
    raise ValueError(f"Unknown TEMPORAL_ENCODER: {TEMPORAL_ENCODER}")

log_print(f"Selected temporal encoder: {TEMPORAL_ENCODER} with hidden dimension {e2e_hidden_dim}")

Selected temporal encoder: conformer with hidden dimension 768


## 3.3 Initialize Model

Here we initialize the E2EVSR model with the selected temporal encoder and a 6-block attention decoder. We set the CTC loss weight to 0.5 and the label-smoothing factor to 0.1, then move the model to the selected compute device.
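`ctc_weight=0.5` presumably interpolates the two training objectives as L = λ·L_ctc + (1 − λ)·L_att, the standard hybrid CTC/attention formulation. The training loop's `equal_loss` uses the same 0.5/0.5 split. `label_smoothing=0.1` softens the decoder's one-hot targets.

How E2EVSR implements smoothing internally is not shown. The sketch below assumes the ESPnet-style convention: the true token gets probability 1 − ε, and ε is spread over the remaining V − 1 tokens. The function names are illustrative.

```python
import math
from typing import Sequence


def hybrid_loss(ctc_loss: float, att_loss: float, ctc_weight: float = 0.5) -> float:
    """Interpolate CTC and attention losses (standard hybrid CTC/attention objective)."""
    return ctc_weight * ctc_loss + (1.0 - ctc_weight) * att_loss


def smoothed_cross_entropy(logits: Sequence[float], target: int, eps: float = 0.1) -> float:
    """Cross-entropy against a label-smoothed target for one decoder step.

    Assumed convention: probability 1 - eps on the target, eps / (V - 1) on every other class.
    """
    v = len(logits)
    m = max(logits)  # subtract the max for a numerically stable log-softmax
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    log_probs = [x - log_z for x in logits]
    off_target = eps / (v - 1)
    return -sum(
        ((1.0 - eps) if i == target else off_target) * lp for i, lp in enumerate(log_probs)
    )


print(hybrid_loss(2.0, 4.0))
print(smoothed_cross_entropy([2.0, 0.5, -1.0], target=0))
```

Smoothing raises the loss even for a confident correct prediction. This discourages the decoder from pushing probabilities to exactly 0 or 1.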


In [30]:
# Initialize the E2EVSR end-to-end model
log_print("Initializing E2EVSR end-to-end model...")

e2e_model = E2EVSR(
    encoder_type=TEMPORAL_ENCODER,
    vocab_size=full_vocab_size,
    token_list=idx2char,
    sos=sos_token_idx,
    eos=eos_token_idx,
    pad=0,
    enc_options={
        'densetcn_options': densetcn_options,
        'mstcn_options': mstcn_options,
        'conformer_options': conformer_options,
        'hidden_dim': e2e_hidden_dim,
        'frontend3d_dropout_rate': 0.0,
        'resnet_dropout_rate': 0.0
    },
    dec_options={
        'attention_dim': e2e_hidden_dim,
        'attention_heads': 12,
        'linear_units': 3072,
        'num_blocks': 6,
        'dropout_rate': 0.1,
        'positional_dropout_rate': 0.1,
        'self_attention_dropout_rate': 0.1,
        'src_attention_dropout_rate': 0.1,
        'normalize_before': True,
    },
    ctc_weight=0.5,
    label_smoothing=0.1,
).to(device)

log_print(repr(e2e_model))

Initializing E2EVSR end-to-end model...
E2EVSR(
  (frontend): VisualFrontend(
    (frontend3D): Sequential(
      (0): Conv3d(1, 64, kernel_size=(5, 7, 7), stride=(1, 2, 2), padding=(2, 3, 3), bias=False)
      (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
      (3): Dropout3d(p=0.0, inplace=False)
      (4): MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1), dilation=1, ceil_mode=False)
    )
    (resnet_trunk): ResNet(
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu1): SiLU()
          (relu2): SiLU()
          (dropout): Identity()
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, a

## 3.4 Load Pretrained Weights

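We download a pretrained visual speech recognition checkpoint and load its weights into the model. Two adjustments are made:

- The ResNet `trunk` keys are renamed to this model's `resnet_trunk`.
- The vocabulary-dependent layers (decoder embedding, decoder output layer, and CTC projection) are skipped, because their shapes do not match the Arabic vocabulary.

Loading with `strict=False` leaves the skipped layers at their initial values. Training therefore starts from pretrained visual and sequence features rather than from scratch.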
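The loading cell drops the vocabulary-dependent tensors by name. An alternative is to keep only entries whose renamed key exists in the target model with an identical shape, and to report everything that was skipped.

The hypothetical sketch below shows this on plain dictionaries of toy shapes. Real code would compare `tensor.shape` values from the checkpoint against `e2e_model.state_dict()`.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def filter_compatible(
    ckpt_shapes: Dict[str, Shape],
    model_shapes: Dict[str, Shape],
    rename: Tuple[str, str] = ("trunk", "resnet_trunk"),
) -> Tuple[Dict[str, Shape], List[str]]:
    """Rename checkpoint keys, then keep only those matching the model's key and shape."""
    kept: Dict[str, Shape] = {}
    skipped: List[str] = []
    for key, shape in ckpt_shapes.items():
        new_key = key.replace(*rename)
        if model_shapes.get(new_key) == shape:
            kept[new_key] = shape
        else:
            skipped.append(key)  # missing from the model, or shape mismatch
    return kept, skipped


# Toy shapes: the checkpoint's CTC head is sized for a different (hypothetical) vocabulary
ckpt = {"frontend.trunk.layer1.0.conv1.weight": (64, 64, 3, 3), "ctc.ctc_lo.weight": (1000, 768)}
model = {"frontend.resnet_trunk.layer1.0.conv1.weight": (64, 64, 3, 3), "ctc.ctc_lo.weight": (230, 768)}
kept, skipped = filter_compatible(ckpt, model)
print(sorted(kept), skipped)
```

Filtering by shape adapts automatically if the vocabulary changes, at the cost of silently skipping any layer whose shape differs for an unintended reason. Printing `skipped` guards against that.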


In [None]:
# Load pretrained weights (all layers except the vocabulary-dependent heads)
log_print("Loading pretrained frontend weights...")

pretrained_path = 'vsr_trlrs2lrs3vox2avsp_base.pth'
if not os.path.exists(pretrained_path):
    try:
        print("🔄 Executing Command: gdown https://drive.google.com/uc?id=1r1kx7l9sWnDOCnaFHIGvOtzuhFyFA88_")
        print("-" * 93)
        return_code = subprocess.call(
            ["gdown", "https://drive.google.com/uc?id=1r1kx7l9sWnDOCnaFHIGvOtzuhFyFA88_", "-O", pretrained_path],
            stderr=subprocess.STDOUT
        )
        print("-" * 93)
        if return_code == 0:
            print("✅ Successfully downloaded pretrained weights")
        else:
            print(f"❌ Download failed with return code: {return_code}")
    except Exception as e:
        print(f"❌ Unexpected error during pretrained weights download: {e}")
        traceback.print_exc()

# weights_only=False unpickles arbitrary objects; only use it with trusted checkpoints
pretrained_weights = torch.load(pretrained_path, map_location='cpu', weights_only=False)

# Skip layers whose shapes depend on the vocabulary (decoder embedding/output, CTC projection)
excluded_keys = (
    'decoder.embed.0.weight',
    'decoder.output_layer.weight',
    'decoder.output_layer.bias',
    'ctc.ctc_lo.weight',
    'ctc.ctc_lo.bias',
)
pretrained_weights = {
    # The checkpoint names the ResNet 'trunk'; this model calls it 'resnet_trunk'
    k.replace('trunk', 'resnet_trunk'): v for k, v in pretrained_weights.items()
    if not any(excluded in k for excluded in excluded_keys)
}

# strict=False leaves the skipped heads at their initial values
e2e_model.load_state_dict(pretrained_weights, strict=False)
log_print(f"Loaded pretrained weights from {pretrained_path}")

Loading pretrained frontend weights...
🔄 Executing Command: gdown https://drive.google.com/uc?id=1r1kx7l9sWnDOCnaFHIGvOtzuhFyFA88_
---------------------------------------------------------------------------------------------
Downloading...
From (original): https://drive.google.com/uc?id=1r1kx7l9sWnDOCnaFHIGvOtzuhFyFA88_
From (redirected): https://drive.google.com/uc?id=1r1kx7l9sWnDOCnaFHIGvOtzuhFyFA88_&confirm=t&uuid=0b19b0e5-9608-4df4-8ec5-33cb2aee9d8c
To: /kaggle/working/vsr_trlrs2lrs3vox2avsp_base.pth
100%|██████████| 1.00G/1.00G [00:09<00:00, 109MB/s] 
---------------------------------------------------------------------------------------------
✅ Successfully downloaded pretrained weights
Loaded pretrained weights from vsr_trlrs2lrs3vox2avsp_base.pth


## 3.5 Training Hyperparameters and Scheduler

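Here we configure the optimizer and learning-rate schedule:

- **Optimizer:** AdamW with weight decay 0.02 on all parameters except normalization layers and biases, which get no decay.
- **Schedule:** a warmup-cosine scheduler with 5 warmup epochs out of 60.
- **Clipping:** gradient-norm clipping at 1.0 to keep updates stable.

With `alpha = 0.0`, the overshoot factor is 1, so the initial learning rate equals the suggested value.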
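`WarmupCosineScheduler` comes from the repository's `utils` module, and its exact formula is not shown here. It is stepped once per optimizer update, and it is built with `steps_per_epoch` updates per epoch.

A common formulation it may resemble is a linear warmup over `warmup_epochs * steps_per_epoch` updates followed by cosine decay to zero. The sketch below implements that assumed schedule. It is not the repository's implementation, which may, for example, use a nonzero minimum learning rate.

```python
import math


def warmup_cosine_lr(step: int, base_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Assumed schedule: linear warmup to base_lr, then cosine decay to 0."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))


# Values from this notebook: 51 optimizer steps per epoch, 5 warmup and 60 total epochs
steps_per_epoch, warmup_epochs, total_epochs = 51, 5, 60
base_lr = 0.00010340380070341652
warmup_steps = warmup_epochs * steps_per_epoch
total_steps = total_epochs * steps_per_epoch
for s in (0, warmup_steps - 1, total_steps // 2, total_steps):
    print(s, f"{warmup_cosine_lr(s, base_lr, warmup_steps, total_steps):.3e}")
```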


In [33]:
suggested_lr = 0.00010340380070341652
total_epochs = 60
warmup_epochs = 5
alpha = 0.0
overshoot = 1 + alpha * (warmup_epochs / total_epochs)
initial_lr = suggested_lr * overshoot

# Use a smaller weight decay, and disable decay on normalization layers
decay_params = []
no_decay_params = []
for name, param in e2e_model.named_parameters():
    if "ln" in name.lower() or "norm" in name.lower() or "bias" in name.lower():
        no_decay_params.append(param)
    else:
        decay_params.append(param)

optimizer = optim.AdamW(
    [
      {"params": decay_params,    "weight_decay": 0.02},
      {"params": no_decay_params, "weight_decay": 0.0},
    ],
    lr=initial_lr,
    betas=(0.9, 0.98),
    eps=1e-9
)

train_loader, val_loader = load_dataloaders()
steps_per_epoch = ceil(len(train_loader) / accumulation_steps)
scheduler = WarmupCosineScheduler(optimizer, warmup_epochs, total_epochs, steps_per_epoch)

# Add gradient clipping during training
grad_clip_value = 1.0

log_print(f"Initial learning rate: {initial_lr}")
log_print(f"Total epochs: {total_epochs}, Warmup epochs: {warmup_epochs}")
log_print(f"Steps per epoch: {steps_per_epoch}")

Initial learning rate: 0.00010340380070341652
Total epochs: 60, Warmup epochs: 5
Steps per epoch: 51


# 4. Training and Evaluation Helpers

This section provides utility functions for seed management, single-epoch training, model evaluation, and an orchestrated training routine with checkpointing and early stopping mechanisms.


## 4.1 RNG State Management

These functions capture and restore the random number generator states of PyTorch (CPU and CUDA), NumPy, Python's `random` module, and the DataLoader `generator`. This lets training and evaluation resume reproducibly from a checkpoint.
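Capturing the states lets a resumed run draw the same random sequence an uninterrupted run would have drawn, provided the saved dictionary is stored alongside the model and optimizer.

The same round-trip is sketched below with only the standard-library `random` and `pickle` modules. The notebook's version additionally covers `torch`, NumPy, CUDA, and the DataLoader `generator`. The helper names are hypothetical.

```python
import pickle
import random


def get_state() -> dict:
    """Hypothetical minimal counterpart of get_rng_state (Python's random only)."""
    return {"random": random.getstate()}


def set_state(state: dict) -> None:
    random.setstate(state["random"])


random.seed(0)
random.random()  # advance the generator as if training had run for a while
saved = pickle.dumps(get_state())  # e.g. written into a checkpoint file

expected = [random.random() for _ in range(3)]  # what an uninterrupted run would draw
random.seed(123)  # simulate a fresh process with a different state
set_state(pickle.loads(saved))
resumed = [random.random() for _ in range(3)]
print(expected == resumed)
```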


In [34]:
def get_rng_state():
    return {
        'torch': torch.get_rng_state(),
        'numpy': np.random.get_state(),
        'random': random.getstate(),
        'generator': generator.get_state(),
        'cuda': torch.cuda.get_rng_state() if torch.cuda.is_available() else None
    }

def set_rng_state(state):
    torch.set_rng_state(state['torch'].cpu())
    np.random.set_state(state['numpy'])
    random.setstate(state['random'])
    generator.set_state(state['generator'])
    if torch.cuda.is_available() and 'cuda' in state and state['cuda'] is not None:
        torch.cuda.set_rng_state(state['cuda'].cpu())

## 4.2 Single Epoch Training Loop

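This function runs one full epoch of training:

- runs the forward pass under `autocast` mixed precision and scales the loss with `GradScaler` for the backward pass;
- accumulates gradients over `accumulation_steps` batches before each optimizer and scheduler update;
- clips the gradient norm;
- logs batch losses and the learning rate.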
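With `accumulation_steps = 4`, gradients from four consecutive batches are summed before one optimizer and scheduler update. Each batch loss is divided by 4, so the update uses the average gradient.

Section 3.5 sizes the scheduler with `ceil(len(train_loader) / accumulation_steps)` updates per epoch, which counts a final partial group of batches as one update. The hypothetical sketch below lists which batch indices would trigger an update under that counting.

```python
import math
from typing import List


def update_batches(num_batches: int, accumulation_steps: int) -> List[int]:
    """0-based batch indices after which optimizer.step()/scheduler.step() would run.

    Includes the last batch so a trailing partial group is not silently discarded.
    """
    return [
        b for b in range(num_batches)
        if (b + 1) % accumulation_steps == 0 or b + 1 == num_batches
    ]


steps = update_batches(num_batches=10, accumulation_steps=4)
print(steps, len(steps) == math.ceil(10 / 4))
```

If the loop only stepped on full groups, the scheduler would receive fewer steps per epoch than it was configured for. The cosine decay would then not reach its intended end point.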


In [35]:
def train_one_epoch(data_loader, scaler):
    running_loss = 0.0
    equal_loss = 0.0
    e2e_model.train()
    
    # Settings for gradient accumulation
    processed_batches = 0

    for batch_idx, (inputs, input_lengths, labels_flat, label_lengths) in enumerate(data_loader):
        # Log input shape for debugging
        logging.info(f"\nBatch {batch_idx+1} - Input shape: {inputs.shape}")

        inputs = inputs.to(device)
        input_lengths = input_lengths.to(device)
        labels_flat = labels_flat.to(device)
        label_lengths = label_lengths.to(device)

        # Only zero gradients at the start of accumulation cycle
        if processed_batches % accumulation_steps == 0:
            optimizer.zero_grad(set_to_none=True)

        try:
            # Use mixed precision for forward pass
            with autocast(device.type):
                # End-to-end forward (CTC+Attention)
                out = e2e_model(inputs, input_lengths, ys=labels_flat, ys_lengths=label_lengths)
                if batch_idx % 10 == 0:
                    log_print(f"decoder_loss: {out['att_loss']}, ctc_loss: {out['ctc_loss']}")
                # Scale loss by accumulation steps
                loss = out['loss'] / accumulation_steps
            
            # Backward with scaled gradients
            scaler.scale(loss).backward()
            
            # Update weights after a full accumulation cycle, or on the last batch so a partial cycle is not dropped
            if (processed_batches + 1) % accumulation_steps == 0 or (batch_idx + 1) == len(data_loader):
                # Apply gradient clipping before stepping
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(e2e_model.parameters(), max_norm=grad_clip_value)
                
                # Step optimizer and update scaler
                scaler.step(optimizer)
                scaler.update()

                scheduler.step()
            
            processed_batches += 1
            running_loss += loss.item() * accumulation_steps  # Re-scale loss for reporting
            # Convert to Python floats so the autograd graph is not kept alive across batches
            equal_loss += 0.5 * float(out['ctc_loss']) + 0.5 * float(out['att_loss'])

            if batch_idx % 10 == 0:
                lr = optimizer.param_groups[0]['lr']
                logging.info(f"Batch {batch_idx+1}, Loss: {loss.item()*accumulation_steps:.4f}, LR: {lr:.7f}")

            del out, loss
                
        except Exception as e:
            logging.info(f"Error in training loop for batch {batch_idx}: {str(e)}")
            logging.info(f"Error type: {type(e).__name__}")
            logging.info(traceback.format_exc())

            log_print(f"Error in batch {batch_idx}: {str(e)}")
            raise

        del inputs, input_lengths, labels_flat, label_lengths
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            logging.info(f"Memory cleared. Current GPU memory: {torch.cuda.memory_allocated()/1e6:.2f}MB")

    print(f"Final LR: {optimizer.param_groups[0]['lr']}")
    # Average over the loader passed in, not the global train_loader
    num_batches = len(data_loader)
    if num_batches == 0:
        return 0.0, 0.0
    return running_loss / num_batches, equal_loss / num_batches

## 4.3 Model Evaluation

This function evaluates the model on a validation or test loader using beam-search decoding in one of three modes: pure attention, pure CTC, or hybrid CTC/attention. It computes the character error rate (CER) and edit distance for each sample, and logs both aggregate statistics and individual predictions for analysis.
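`compute_cer` is imported from the repository's `utils`, and the function returns the mean of per-sample CERs. As an assumed reference definition, CER is the Levenshtein (edit) distance between the reference and hypothesis token sequences, divided by the reference length. The `editdistance` package installed earlier computes the same Levenshtein distance.

A mean of per-sample CERs weights short utterances more heavily than a corpus-level CER, which is total edits over total reference tokens. The standard-library sketch below computes both.

```python
from typing import List, Sequence, Tuple


def edit_distance(ref: Sequence, hyp: Sequence) -> int:
    """Levenshtein distance (insertions, deletions, substitutions each cost 1)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        curr = [i]
        for j, h in enumerate(hyp, 1):
            curr.append(min(prev[j] + 1,              # deletion
                            curr[j - 1] + 1,          # insertion
                            prev[j - 1] + (r != h)))  # substitution or match
        prev = curr
    return prev[-1]


def cer(ref: Sequence, hyp: Sequence) -> Tuple[float, int]:
    """Assumed CER definition: edit distance / reference length."""
    dist = edit_distance(ref, hyp)
    return dist / max(1, len(ref)), dist


def mean_and_corpus_cer(pairs: List[Tuple[Sequence, Sequence]]) -> Tuple[float, float]:
    """Return (mean of per-sample CERs, corpus-level CER)."""
    per_sample = [cer(r, h)[0] for r, h in pairs]
    total_edits = sum(edit_distance(r, h) for r, h in pairs)
    total_ref = sum(len(r) for r, _ in pairs)
    return sum(per_sample) / len(pairs), total_edits / total_ref


pairs = [([1, 2, 3, 4], [1, 2, 3, 4]), ([5, 6], [5, 7])]
print(mean_and_corpus_cer(pairs))
```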


In [36]:
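def evaluate_model(data_loader, epoch=None, print_samples=True, mode='hybrid'):
    """
    Evaluate the model on the given data loader with beam-search decoding.

    mode selects the CTC weight used during beam search: 'transformer' (0.0, attention only),
    'ctc' (1.0, CTC only), or 'hybrid' (0.3). Returns the average per-sample CER.
    """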
    e2e_model.eval()

    # Track statistics
    total_cer = 0
    sample_count = 0
    all_predictions = []

    # Determine if we should print samples in this epoch
    show_samples = (epoch is None or epoch == 0 or (epoch+1) % 5 == 0) and print_samples
    max_samples_to_print = 10

    # Process all batches in the data loader
    with torch.no_grad():
        for i, (inputs, input_lengths, labels_flat, label_lengths) in enumerate(data_loader):
            inputs = inputs.to(device)
            input_lengths = input_lengths.to(device)
            labels_flat = labels_flat.to(device)
            label_lengths = label_lengths.to(device)
            
            if show_samples and i == 0:
                log_print(f"\nRunning {mode} beam decoding for validation...")
            
            try:
                with autocast(device.type, enabled=False):
                    if mode == 'transformer':
                        # Pure attention decoding
                        all_results = e2e_model.beam_search(inputs, input_lengths, ctc_weight=0.0)
                    elif mode == 'ctc':
                        # Pure CTC decoding
                        all_results = e2e_model.beam_search(inputs, input_lengths, ctc_weight=1.0)
                    elif mode == 'hybrid':
                        # Hybrid decoding with a specific weight
                        all_results = e2e_model.beam_search(inputs, input_lengths, ctc_weight=0.3)
                    else:
                        raise ValueError(f"Unknown decoding mode: {mode}")
                
                logging.info(f"Beam search decoding completed for batch {i+1}")
                logging.info(f"Received {len(all_results)} result sequences using mode {mode} beam search")
                
                # Process each batch item
                for b in range(label_lengths.size(0)):
                    logging.info(f"\nProcessing batch item {b+1}/{label_lengths.size(0)}")
                    sample_count += 1
                    
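                    # Get predicted token indices (empty if beam search returned fewer results than items)
                    pred_indices = all_results[b] if b < len(all_results) else []
                    
                    if len(pred_indices) == 0:
                        log_print("WARNING: Prediction sequence is empty!")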
                        log_print("WARNING: Prediction sequence is empty!")
                    
                    # Get target indices
                    start_idx = sum(label_lengths[:b].cpu().tolist()) if b > 0 else 0
                    end_idx = start_idx + label_lengths[b].item()
                    target_idx = labels_flat[start_idx:end_idx].cpu().numpy()

                    # Log debug information for reference and hypothesis tokens
                    logging.info(f"Reference tokens ({len(target_idx)} tokens): {target_idx}")
                    logging.info(f"Hypothesis tokens ({len(pred_indices)} tokens): {pred_indices}")
                    
                    # Reference sequence
                    ref_seq = target_idx.tolist()
                    # Beam-search hypothesis for this item, used as-is
                    cleaned_seq = list(pred_indices)
                    
                    # Compute CER and edit distance against the reference
                    cer, edit_dist = compute_cer(ref_seq, cleaned_seq)
                    pred_text = indices_to_text(cleaned_seq, idx2char)
                    target_text = indices_to_text(target_idx, idx2char)
                    
                    # Update statistics
                    total_cer += cer
                    
                    # Store prediction details
                    all_predictions.append({
                        'sample_id': sample_count,
                        'pred_text': pred_text,
                        'target_text': target_text,
                        'cer': cer,
                        'edit_distance': edit_dist,
                    })
                    
                    # Log complete info
                    logging.info("-" * 50)
                    logging.info(f"Sample {sample_count}:")
                    try:
                        logging.info(f"Predicted text: {pred_text}")
                        logging.info(f"Target text: {target_text}")
                    except UnicodeEncodeError:
                        logging.info("Predicted text: [Contains characters that can't be displayed in console]")
                        logging.info("Target text: [Contains characters that can't be displayed in console]")
                        logging.info(f"Predicted indices: {pred_indices}")
                        logging.info(f"Target indices: {target_idx}")
                        
                    logging.info(f"Edit distance: {edit_dist}")
                    logging.info(f"CER: {cer:.4f}")
                    logging.info("-" * 50)
                    
                    # Print to console if this is a sample we should show
                    if show_samples and sample_count <= max_samples_to_print:
                        print("-" * 50)
                        print(f"Sample {sample_count}:")
                        try:
                            print(f"Predicted text: {pred_text}")
                            print(f"Target text: {target_text}")
                        except UnicodeEncodeError:
                            print("Predicted text: [Contains characters that can't be displayed in console]")
                            print("Target text: [Contains characters that can't be displayed in console]")
                            
                        print(f"Edit distance: {edit_dist}")
                        print(f"CER: {cer:.4f}")
                        print("-" * 50)

                # Clean up tensors
                del all_results
                
                # Periodically clear cache
                if i % 3 == 0:  # Every 3 batches
                    gc.collect()
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                        logging.info(f"Memory cleared. Current GPU memory: {torch.cuda.memory_allocated()/1e6:.2f}MB")
            
            except Exception as e:
                log_print(f"Error during greedy decoding: {str(e)}")
                log_print(traceback.format_exc())
                raise

            del inputs, input_lengths, labels_flat, label_lengths
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                logging.info(f"Memory cleared. Current GPU memory: {torch.cuda.memory_allocated()/1e6:.2f}MB")
        
        # Calculate average CER
        n_samples = len(data_loader.dataset)
        avg_cer = total_cer / n_samples if n_samples > 0 else float('inf')
        
        # Always print summary statistics to console
        log_print("\n=== Summary Statistics ===")
        log_print(f"Total samples: {n_samples}")
        log_print(f"Average CER: {avg_cer:.4f}\n")
        
        return avg_cer


def evaluate_loss(data_loader):
    """
    Compute the average hybrid CTC/attention loss on the validation set with teacher forcing.

    The CTC weight is pinned to 0.5 during validation so that losses stay comparable
    across curriculum stages; the training weight is restored afterwards.
    """
    e2e_model.eval()
    running_loss = 0.0
    # Save the current (curriculum-dependent) CTC weight
    original_ctc_weight = e2e_model.ctc_weight
    # Use a fixed weight of 0.5 for validation
    e2e_model.ctc_weight = 0.5
    try:
        with torch.no_grad():
            for inputs, input_lengths, labels_flat, label_lengths in data_loader:
                inputs = inputs.to(device)
                input_lengths = input_lengths.to(device)
                labels_flat = labels_flat.to(device)
                label_lengths = label_lengths.to(device)
                out = e2e_model(
                    inputs, input_lengths,
                    ys=labels_flat, ys_lengths=label_lengths
                )
                running_loss += out['loss'].item()
                del inputs, input_lengths, labels_flat, label_lengths, out
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
    finally:
        # Restore the training weight even if evaluation raises
        e2e_model.ctc_weight = original_ctc_weight
    return running_loss / len(data_loader) if len(data_loader) > 0 else 0.0
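The evaluation loop above averages per-sample CER (`total_cer / n_samples`), so every utterance counts equally regardless of its length. The sketch below illustrates that metric, assuming each per-sample `cer` is the Levenshtein distance divided by the target length (that computation lives outside this excerpt). It also shows the corpus-level alternative, which divides total edits by total reference characters. The helper names `edit_distance`, `per_sample_cer`, and `corpus_cer` are hypothetical.

```python
def edit_distance(ref: str, hyp: str) -> int:
    """Levenshtein distance; insertions, deletions, and substitutions each cost 1."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            curr[j] = min(
                prev[j] + 1,             # deletion
                curr[j - 1] + 1,         # insertion
                prev[j - 1] + (r != h),  # substitution (free when characters match)
            )
        prev = curr
    return prev[-1]


def per_sample_cer(pairs):
    """Mean of per-utterance CERs over (pred, target) pairs, like total_cer / n_samples."""
    # Assumption: an empty target is treated as length 1 to avoid division by zero
    cers = [edit_distance(t, p) / max(len(t), 1) for p, t in pairs]
    return sum(cers) / len(cers) if cers else float("inf")


def corpus_cer(pairs):
    """Total edit distance divided by the total number of reference characters."""
    edits = sum(edit_distance(t, p) for p, t in pairs)
    chars = sum(len(t) for _, t in pairs)
    return edits / chars if chars else float("inf")


# (prediction, target) pairs: one missing letter in a short word, one exact match
pairs = [("كتب", "كتاب"), ("السلام عليكم", "السلام عليكم")]
print(per_sample_cer(pairs), corpus_cer(pairs))
```

On these pairs the short, misread word dominates the per-sample mean (0.125), while the corpus CER is 0.0625. Both are valid summaries, but they should not be compared against each other.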

## 4.4 Training Pipeline and Checkpointing

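In this subsection, we define the complete training loop. It runs a three-stage curriculum that lowers the CTC weight from 0.7 (alignment-focused) to 0.5 (balanced) and then decays it further for attention-focused learning. It also trains with mixed precision, cuts the learning rate and eventually stops early when validation stops improving, and writes checkpoints every 5 epochs and on every improvement so that long runs can be resumed. Note that `best_model.pth` is overwritten whenever either the validation CER or the validation loss improves.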
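The curriculum in the training loop below can be read as a pure function of the 0-based epoch index. The following sketch reproduces its stage boundaries (`total_epochs // 3` and `2 * total_epochs // 3`) and CTC weights. The model code is not shown here, so the sketch also assumes that `e2e_model` combines its losses as `ctc_weight * CTC + (1 - ctc_weight) * attention`, the usual hybrid CTC/attention interpolation. Under that assumption, the same pair of losses produces different totals at different weights, which is why `evaluate_loss` pins the weight to 0.5. `curriculum_schedule` and `hybrid_loss` are hypothetical helpers.

```python
def curriculum_schedule(epoch: int, total_epochs: int = 60):
    """Return (stage, ctc_weight) for a 0-based epoch, mirroring train_model."""
    stage_transitions = [total_epochs // 3, 2 * total_epochs // 3]
    if epoch >= stage_transitions[1]:
        # Stage 3: start at 0.4, subtract 0.01 per epoch, never go below 0.2
        return 3, max(0.2, 0.4 - 0.01 * (epoch - stage_transitions[1]))
    if epoch >= stage_transitions[0]:
        return 2, 0.5  # Stage 2: balanced CTC/attention
    return 1, 0.7      # Stage 1: CTC-heavy, to establish alignments early


def hybrid_loss(ctc_loss: float, att_loss: float, ctc_weight: float) -> float:
    """Assumed interpolation: w * CTC + (1 - w) * attention."""
    return ctc_weight * ctc_loss + (1.0 - ctc_weight) * att_loss


for epoch in (0, 20, 40, 59):
    print(epoch + 1, curriculum_schedule(epoch))

# Same raw losses, different totals: stage-dependent weights would make
# validation losses from different stages incomparable.
print(hybrid_loss(2.0, 1.0, 0.7), hybrid_loss(2.0, 1.0, 0.5))
```

With 60 epochs, stage 3 covers 0-based epochs 40 to 59, and the weight only decays to about 0.21. The 0.2 floor therefore matters only for longer runs.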


In [39]:
def train_model(checkpoint_path=None):
    best_val_loss = float('inf')
    best_val_cer = float('inf')
    best_epoch = -1
    start_epoch = 0
    patience_counter = 0
    max_patience = 10  # Epochs without improvement before early stopping
    
    # For mixed precision training
    scaler = GradScaler(device.type)
    
    # Curriculum learning - track stage
    curriculum_stage = 1  # Start with stage 1 (focus on CTC)
    stage_transitions = [total_epochs // 3, 2 * total_epochs // 3]  # 0-based epochs 20 and 40 when total_epochs = 60
    
    # Load checkpoint if provided
    if checkpoint_path and os.path.exists(checkpoint_path):
        log_print(f"Loading checkpoint from {checkpoint_path}...")
        
        try:
            checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
            
            # Load E2E model weights non-strictly (missing/unexpected keys are reported, not fatal; shape mismatches still raise)
            dec_res = e2e_model.load_state_dict(checkpoint['e2e_model_state_dict'], strict=False)
            log_print(f"Loaded e2e_model checkpoint (non-strict): missing {dec_res.missing_keys}, unexpected {dec_res.unexpected_keys}")
            
            # Restore optimizer, scheduler, and AMP grad-scaler states
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            scaler.load_state_dict(checkpoint['scaler_state_dict'])
            
            # Update training state
            start_epoch = checkpoint['epoch'] + 1
            best_val_loss = checkpoint.get('best_val_loss', float('inf'))
            best_val_cer = checkpoint.get('best_val_cer', float('inf'))
            best_epoch = checkpoint.get('best_epoch', -1)
            
            # Determine curriculum stage based on loaded epoch
            if start_epoch >= stage_transitions[1]:
                curriculum_stage = 3
            elif start_epoch >= stage_transitions[0]:
                curriculum_stage = 2
            else:
                curriculum_stage = 1
                
            log_print(f"Resuming at curriculum stage {curriculum_stage}")
            
            # Restore RNG state if available
            if 'rng_state' in checkpoint:
                try:
                    set_rng_state(checkpoint['rng_state'])
                    # Success
                    log_print("RNG state restored successfully")
                except Exception as e:
                    log_print(f"Warning: Could not restore RNG state: {e}. Continuing with current RNG state.")
            
            log_print(f"Checkpoint loaded successfully. Resuming from epoch {start_epoch + 1}")
        
        except Exception as e:
            log_print(f"Error loading checkpoint: {str(e)}")
            log_print("Aborting training due to checkpoint loading failure.")
            raise
        
    else:
        if checkpoint_path:
            log_print(f"Checkpoint file {checkpoint_path} not found. Starting training from scratch.")
        else:
            log_print("No checkpoint specified. Starting training from scratch.")
    
    train_loader, val_loader = load_dataloaders()
    print(f"Starting training for {total_epochs} epochs")
    print(f"Logs will be saved to {log_filename}")
    print(f"Checkpoints will be saved every 5 epochs")
    print("-" * 50)
    
    for epoch in range(start_epoch, total_epochs):
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            logging.info(f"GPU memory before training: {torch.cuda.memory_allocated()/1e6:.2f}MB")
        
        # Update curriculum stage if needed
        if epoch == stage_transitions[0]:
            curriculum_stage = 2
            log_print(f"Moving to curriculum stage 2: Balanced CTC-Attention learning")
        elif epoch == stage_transitions[1]:
            curriculum_stage = 3
            log_print(f"Moving to curriculum stage 3: Focused attention learning")
        
        # Dynamically adjust CTC weight based on curriculum stage
        if curriculum_stage == 1:
            # Stage 1: CTC-focused learning (helps establish alignment)
            e2e_model.ctc_weight = 0.7
        elif curriculum_stage == 2:
            # Stage 2: Balanced CTC-attention learning
            e2e_model.ctc_weight = 0.5
        else:
            # Stage 3: Attention-focused learning
            e2e_model.ctc_weight = max(0.2, 0.4 - 0.01 * (epoch - stage_transitions[1]))
        
        print(f"Epoch {epoch + 1}/{total_epochs} - Training (Stage {curriculum_stage}, CTC weight: {e2e_model.ctc_weight:.3f})...")
        epoch_loss, equal_loss = train_one_epoch(train_loader, scaler)
    
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            logging.info(f"GPU memory after training: {torch.cuda.memory_allocated()/1e6:.2f}MB")
        
        print(f"Epoch {epoch + 1}/{total_epochs} - Evaluating...")
        # First compute validation loss under teacher forcing
        val_loss = evaluate_loss(val_loader)
        
        # Compute CER every 5 epochs in stage 1 and every epoch from stage 2 onward
        eval_cer = (epoch + 1) % 5 == 0 or curriculum_stage >= 2
        
        if eval_cer:
            # Then compute decoding metrics (CER) via greedy decoding
            val_cer = {
                'Hybrid': evaluate_model(val_loader, epoch=epoch, print_samples=True, mode='hybrid'),
                # 'Transformer': evaluate_model(val_loader, epoch=epoch, print_samples=True, mode='transformer'),
                # 'CTC': evaluate_model(val_loader, epoch=epoch, print_samples=True, mode='ctc'),
            }
        else:
            val_cer = None  # Skip CER evaluation this epoch
        
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            logging.info(f"GPU memory after evaluation: {torch.cuda.memory_allocated()/1e6:.2f}MB")
        
        log_print(
            f"Epoch {epoch + 1}/{total_epochs}"
            f", Train Loss: {epoch_loss:.4f}"
            f", Equal Loss: {equal_loss:.4f}"
            f", Val Loss: {val_loss:.4f}"
            + (f", Val Hybrid CER: {val_cer['Hybrid']:.4f}" if isinstance(val_cer, dict) else "")
            # Disabled: ", Val Transformer CER: {val_cer['Transformer']:.4f}, Val CTC CER: {val_cer['CTC']:.4f}"
        )

        # Early stopping check
        improved = False
        
        # First priority - improve CER
        if isinstance(val_cer, dict) and val_cer['Hybrid'] < best_val_cer:
            best_val_cer = val_cer['Hybrid']
            best_epoch = epoch
            improved = True
            log_print(f"New best validation CER: {val_cer['Hybrid']:.4f}")

        # if isinstance(val_cer, dict) and val_cer['Transformer'] < best_val_cer:
        #     best_val_cer = val_cer['Transformer']
        #     if not improved:
        #         best_epoch = epoch
        #         improved = True
        #         log_print(f"New best validation CER: {val_cer['Transformer']:.4f}")

        # if isinstance(val_cer, dict) and val_cer['CTC'] < best_val_cer:
        #     best_val_cer = val_cer['CTC']
        #     if not improved:
        #         best_epoch = epoch
        #         improved = True
        #         log_print(f"New best validation CER: {val_cer['CTC']:.4f}")
        
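        # Second priority: a lower validation loss also counts as an improvement
        # (best_val_loss is always updated; the message is logged only if CER did not improve)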
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            if not improved:
                improved = True
                log_print(f"New best validation loss: {val_loss:.4f}")
            
        # Save checkpoint every 5 epochs or when validation improves
        if (epoch + 1) % 5 == 0 or improved:
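            # Remove older checkpoints, keeping the 2 most recent existing ones
            # (the checkpoint written below makes 3 on disk). Lexical sorting relies
            # on the zero-padded epoch number, so it is only correct for < 100 epochs.
            keep_last_n = 2
            checkpoints = sorted([f for f in os.listdir('.') if f.startswith('checkpoint_epoch_')])
            for old_ckpt in checkpoints[:-keep_last_n]:
                try:
                    os.remove(old_ckpt)
                    log_print(f"Removed old checkpoint: {old_ckpt}")
                except OSError as e:
                    log_print(f"Warning: could not remove {old_ckpt}: {e}")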

            checkpoint_path = f'checkpoint_epoch_{epoch+1:02d}.pth'
            torch.save({
                'epoch': epoch,
                'e2e_model_state_dict': e2e_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'loss': val_loss,
                'hybrid_cer': val_cer['Hybrid'] if isinstance(val_cer, dict) else None,
                # 'transformer_cer': val_cer['Transformer'] if isinstance(val_cer, dict) else None,
                # 'ctc_cer': val_cer['CTC'] if isinstance(val_cer, dict) else None,
                'rng_state': get_rng_state(),
                'best_val_loss': best_val_loss,
                'best_val_cer': best_val_cer,
                'best_epoch': best_epoch,
                'curriculum_stage': curriculum_stage
            }, checkpoint_path)
            log_print(f"Checkpoint saved to {checkpoint_path}")
        
            # Force synchronize CUDA operations and clear memory after saving
            if torch.cuda.is_available():
                torch.cuda.synchronize()
                torch.cuda.empty_cache()
        
        # Save best model if validation improves
        if improved:
            torch.save({
                'epoch': epoch,
                'e2e_model_state_dict': e2e_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'loss': val_loss,
                'hybrid_cer': val_cer['Hybrid'] if isinstance(val_cer, dict) else None,
                # 'transformer_cer': val_cer['Transformer'] if isinstance(val_cer, dict) else None,
                # 'ctc_cer': val_cer['CTC'] if isinstance(val_cer, dict) else None,
                'rng_state': get_rng_state(),
                'best_val_loss': best_val_loss,
                'best_val_cer': best_val_cer,
                'best_epoch': best_epoch,
                'curriculum_stage': curriculum_stage
            }, 'best_model.pth')
            log_print(f"New best model saved at epoch {epoch+1}")
            patience_counter = 0
        else:
            patience_counter += 1
            log_print(f"No improvement for {patience_counter} epochs (patience: {max_patience})")
            
            # After every 5 epochs without improvement, scale the current and scheduler base LRs by 0.6
            if patience_counter % 5 == 0:
                for param_group in optimizer.param_groups:
                    param_group['lr'] = param_group['lr'] * 0.6
                scheduler.base_lrs = [base_lr * 0.6 for base_lr in scheduler.base_lrs]
                log_print(f"Reducing learning rate to {optimizer.param_groups[0]['lr']:.7f}")

            # Early stopping
            if patience_counter >= max_patience:
                log_print(f"Early stopping triggered after {patience_counter} epochs without improvement")
                break
            
        del epoch_loss, val_loss
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    
    log_print("\nTraining completed!")
    log_print(f"Best validation loss: {best_val_loss:.4f}")
    log_print(f"Best validation CER: {best_val_cer:.4f} (epoch {best_epoch+1})")
    log_print(f"Best model saved to: best_model.pth")
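The patience logic at the end of each epoch can be replayed without a model. The sketch below uses the constants from `train_model` (`max_patience = 10`, a 0.6 decay every 5 non-improving epochs). It is a simplification: it tracks a single learning rate and ignores the scheduler's own updates to `base_lrs`. Each flag stands for the `improved` value of one epoch, which in `train_model` is true when either the hybrid CER or the validation loss improves. `simulate_plateau` is a hypothetical helper.

```python
def simulate_plateau(improved_flags, lr=1e-3, max_patience=10, factor=0.6, every=5):
    """Replay train_model's patience / LR-decay / early-stopping rules.

    improved_flags: one bool per epoch (True if validation improved).
    Returns (epochs_run, final_lr, stopped_early).
    """
    patience = 0
    for epoch, improved in enumerate(improved_flags):
        if improved:
            patience = 0
            continue
        patience += 1
        if patience % every == 0:
            lr *= factor  # applied before the early-stopping check, as in train_model
        if patience >= max_patience:
            return epoch + 1, lr, True
    return len(improved_flags), lr, False


# One improvement, then a long plateau
print(simulate_plateau([True] + [False] * 12))
```

Because the LR cut is applied before the early-stopping check, the second cut (at 10 epochs without improvement) happens on the same epoch that training stops. Only the first cut ever affects training.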

# 5. Launch Training

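This final step calls `train_model()` to run the full training pipeline with the environment, datasets, and model settings configured above. To resume an interrupted run, pass the path of a saved checkpoint instead, for example `train_model('checkpoint_epoch_05.pth')`.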
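`train_model` accepts `checkpoint_path` and falls back to training from scratch when it is `None`, so resuming can be automated by picking the newest `checkpoint_epoch_XX.pth` file. The sketch below uses two hypothetical helpers, `latest_checkpoint` and `rotate_checkpoints`. Unlike the lexical `sorted()` in `train_model`, they parse the epoch number, so `checkpoint_epoch_100.pth` correctly sorts after `checkpoint_epoch_99.pth`.

```python
import os
import re

# Matches train_model's naming scheme: checkpoint_epoch_{epoch+1:02d}.pth
CKPT_RE = re.compile(r"checkpoint_epoch_(\d+)\.pth")


def _epoch_checkpoints(directory="."):
    """Return (epoch_number, path) pairs sorted numerically by epoch."""
    found = []
    for name in os.listdir(directory):
        match = CKPT_RE.fullmatch(name)
        if match:
            found.append((int(match.group(1)), os.path.join(directory, name)))
    return sorted(found)


def latest_checkpoint(directory="."):
    """Path of the highest-numbered checkpoint, or None if there is none."""
    found = _epoch_checkpoints(directory)
    return found[-1][1] if found else None


def rotate_checkpoints(directory=".", keep_last_n=2):
    """Delete all but the newest keep_last_n checkpoints; return removed paths."""
    found = _epoch_checkpoints(directory)
    # Guard: found[:-0] would be empty, so handle keep_last_n == 0 explicitly
    stale = found[:-keep_last_n] if keep_last_n > 0 else found
    removed = []
    for _, path in stale:
        os.remove(path)
        removed.append(path)
    return removed


# In the notebook, this resumes if a checkpoint exists and otherwise starts fresh:
# train_model(checkpoint_path=latest_checkpoint())
```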


In [None]:
train_model()

No checkpoint specified. Starting training from scratch.
Starting training for 60 epochs
Logs will be saved to Logs/training_20250703_134728.log
Checkpoints will be saved every 5 epochs
--------------------------------------------------
Epoch 1/60 - Training (Stage 1, CTC weight: 0.700)...
