 # 1. Imports

In [None]:
import torch, os, cv2, gc
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision import transforms
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
from PIL import Image
import editdistance
from lipreading.pretrained_frontend.encoder_models_pretrained import Lipreading
from lipreading.optim_utils import CosineScheduler
from lipreading.transformer_decoder import ArabicTransformerDecoder
from espnet.transformer.mask import subsequent_mask
from utils import *
import logging
from datetime import datetime

# Setup logging: every run writes to its own timestamped file under ../Logs.
# The file handler is opened explicitly as UTF-8 because the log receives Arabic
# predictions/targets; logging.basicConfig(filename=...) would open the file in
# the platform's default encoding, which cannot encode Arabic on some systems
# (e.g. cp1252 on Windows), and those records would fail to be written.
os.makedirs('../Logs', exist_ok=True)
log_filename = f'../Logs/training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler(log_filename, encoding='utf-8')],
)


 # 2. Initialize the seed and the device

In [None]:
# Setting the seed for reproducibility
seed = 0


def reset_seed(value=None):
    """Re-seed every random number generator the training code may draw from.

    Seeds Python's ``random`` module, NumPy's global generator, the PyTorch CPU
    generator and the generators of every CUDA device.

    Args:
        value: Seed to use. Defaults to the module-level ``seed``.
    """
    import random  # local import: only this function needs it

    if value is None:
        value = seed
    random.seed(value)
    np.random.seed(value)
    torch.manual_seed(value)
    # Seed all GPUs explicitly (torch.cuda.manual_seed only covers the current one).
    torch.cuda.manual_seed_all(value)


# cuDNN: disable autotuning and request deterministic kernels for reproducibility.
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Setting the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


 # 3. Dataset preparation

 ## 3.1. List of Classes

In [None]:
# Arabic combining marks that are merged into the token of the preceding letter.
_ARABIC_DIACRITICS = frozenset({
    '\u064B',  # Fathatan
    '\u064C',  # Dammatan
    '\u064D',  # Kasratan
    '\u064E',  # Fatha
    '\u064F',  # Damma
    '\u0650',  # Kasra
    '\u0651',  # Shadda
    '\u0652',  # Sukun
    '\u06E2',  # Small High meem
})


def extract_label(file):
    """Split the ``word`` column of a label CSV into character tokens.

    Every character that is not in ``_ARABIC_DIACRITICS`` starts a new token;
    each diacritic is appended to the token before it, so a letter and its
    marks form one token. Word boundaries are not encoded: the tokens of
    consecutive words are simply concatenated.

    A diacritic with no preceding token (the very first character read) becomes
    its own token instead of raising IndexError. Empty cells, which pandas
    reads as NaN, are skipped.

    Args:
        file: Path to a CSV file with a ``word`` column.

    Returns:
        List of string tokens in reading order.
    """
    tokens = []
    words = pd.read_csv(file)['word'].dropna().astype(str)
    for word in words:
        for char in word:
            if char in _ARABIC_DIACRITICS and tokens:
                tokens[-1] += char
            else:
                tokens.append(char)
    return tokens

# Gather every distinct token appearing in any label CSV, then give each one an
# integer id starting at 1 (id 0 stays free for the CTC blank). Tokens are
# ordered by a descending string sort so the ids do not depend on listing order.
label_csv_dir = '../Dataset/Csv (with Diacritics)'
classes = set()
for csv_name in os.listdir(label_csv_dir):
    classes |= set(extract_label(label_csv_dir + '/' + csv_name))

mapped_classes = {
    token: class_id
    for class_id, token in enumerate(sorted(classes, reverse=True), start=1)
}

print(mapped_classes)


 ## 3.2. Video Dataset Class

In [None]:
# Defining the video dataset class
class VideoDataset(torch.utils.data.Dataset):
    """Paired (video file, label CSV) dataset for lipreading.

    ``__getitem__`` returns ``(frames, input_length, label, label_length)``:
      * ``frames``: per-frame tensors stacked along a new axis 0 and permuted
        so time becomes axis 1 — (C, T, H, W) assuming the transform yields
        (C, H, W) per frame (``ToTensor`` on a grayscale frame gives C = 1);
      * ``input_length``: T, the number of decoded frames (0-d long tensor);
      * ``label``: 1-D long tensor of ids from ``mapped_classes``;
      * ``label_length``: number of label tokens (0-d long tensor).
    """

    def __init__(self, video_paths, label_paths, transform=None):
        self.video_paths = video_paths
        self.label_paths = label_paths
        self.transform = transform

    def __len__(self):
        return len(self.video_paths)

    def __getitem__(self, index):
        video_path = self.video_paths[index]
        label_path = self.label_paths[index]
        frames = self.load_frames(video_path=video_path)
        label_ids = [mapped_classes[token] for token in extract_label(label_path)]
        # Explicit dtype: torch.tensor([]) would otherwise be float32, which CTC rejects.
        label = torch.tensor(label_ids, dtype=torch.long)
        input_length = torch.tensor(frames.size(1), dtype=torch.long)
        label_length = torch.tensor(label.numel(), dtype=torch.long)
        return frames, input_length, label, label_length

    def load_frames(self, video_path):
        """Decode every frame of ``video_path`` as grayscale and stack them.

        Frames are read sequentially (seeking before each read is slow and
        can be inaccurate for some codecs), and the capture is always released.

        Raises:
            ValueError: if no frame could be decoded.
        """
        frames = []
        video = cv2.VideoCapture(video_path)
        try:
            while True:
                ok, frame = video.read()
                if not ok:
                    break
                gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                # A 2-D uint8 array is converted to a mode 'L' (grayscale) image.
                frames.append(Image.fromarray(gray))
        finally:
            video.release()

        if not frames:
            raise ValueError(f"No frames could be decoded from {video_path}")

        if self.transform is not None:
            frames = [self.transform(frame) for frame in frames]
        else:
            # No transform: convert like ToTensor -> (1, H, W) float in [0, 1],
            # so torch.stack below still works.
            frames = [
                torch.from_numpy(np.asarray(frame, dtype=np.float32) / 255.0).unsqueeze(0)
                for frame in frames
            ]
        return torch.stack(frames).permute(1, 0, 2, 3)

# Per-frame preprocessing shared by the train, validation and test splits.
# There is no random augmentation: each grayscale PIL frame becomes a float
# tensor in [0, 1] and is standardised with fixed mean/std constants
# (presumably dataset statistics — TODO confirm their origin).
# Center-cropping to 88x88 is currently disabled:
#     transforms.CenterCrop(88),
data_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=0.419232189655303955078125,
            std=0.133925855159759521484375,
        ),
    ]
)


 ## 3.3. Load the dataset

In [None]:
videos_dir = "../Dataset/Preprocessed_Video"
labels_dir = "../Dataset/Csv (with Diacritics)"
# The listing is sorted because os.listdir order is filesystem-dependent: the
# seeded train_test_split below is only reproducible if its input order is
# stable. Only .mp4 files are used; <name>.mp4 is paired with <labels_dir>/<name>.csv.
video_files = sorted(
    name for name in os.listdir(videos_dir) if name.lower().endswith(".mp4")
)
file_names = [os.path.splitext(name)[0] for name in video_files]
videos = [os.path.join(videos_dir, name) for name in video_files]
labels = [os.path.join(labels_dir, stem + ".csv") for stem in file_names]

 ## 3.4. Split the dataset

In [None]:
# Two-stage split: first hold out 10% of the data as the test set, then take
# 0.1111 (~1/9) of the remaining 90% as validation -> roughly 80/10/10 overall.
X_temp, X_test, y_temp, y_test = train_test_split(
    videos, labels, test_size=0.1000, random_state=seed
)
X_train, X_val, y_train, y_val = train_test_split(
    X_temp, y_temp, test_size=0.1111, random_state=seed
)


 ## 3.5. DataLoaders

In [None]:
# One dataset per split, all with the same preprocessing; only the training
# loader shuffles. pad_packed_collate batches variable-length samples.
train_dataset, val_dataset, test_dataset = (
    VideoDataset(paths, targets, transform=data_transforms)
    for paths, targets in ((X_train, y_train), (X_val, y_val), (X_test, y_test))
)
train_loader, val_loader, test_loader = (
    DataLoader(ds, batch_size=32, shuffle=shuffle, pin_memory=True, collate_fn=pad_packed_collate)
    for ds, shuffle in ((train_dataset, True), (val_dataset, False), (test_dataset, False))
)


 # 4. Model Configuration

In [None]:
# Vocabulary id layout:
#   0                           -> CTC blank (also used as the padding id)
#   1 .. len(mapped_classes)    -> character tokens from mapped_classes
#   base_vocab_size             -> <sos>
#   base_vocab_size + 1         -> <eos>
base_vocab_size = len(mapped_classes) + 1
sos_token_idx = base_vocab_size
eos_token_idx = sos_token_idx + 1
full_vocab_size = eos_token_idx + 1

# Reverse mapping (id -> printable text) used when decoding predictions.
idx2char = {
    **{class_id: token for token, class_id in mapped_classes.items()},
    0: "",
    sos_token_idx: "<sos>",
    eos_token_idx: "<eos>",
}

print(f"Total vocabulary size: {full_vocab_size}")
print(f"SOS token index: {sos_token_idx}")
print(f"EOS token index: {eos_token_idx}")


 ## 4.1 Temporal Encoder Options

In [None]:
# Option sets for the temporal encoders supported by Lipreading. Only the one
# named by TEMPORAL_ENCODER (bottom of this cell) is actually used.

# DenseTCN: dense blocks of multi-scale dilated temporal convolutions.
densetcn_options = dict(
    block_config=[3, 3, 3, 3],             # layers in each dense block
    growth_rate_set=[384, 384, 384, 384],  # growth rate per block (384 is divisible by len(kernel_size_set))
    # NOTE(review): an earlier comment said reduced_size must be divisible by
    # len(kernel_size_set), but 512 % 3 != 0 — confirm the real constraint.
    reduced_size=512,                      # channel width between blocks
    kernel_size_set=[3, 5, 7],             # parallel kernel sizes (multi-scale)
    dilation_size_set=[1, 2, 5],           # dilation rates to widen the receptive field
    squeeze_excitation=True,               # squeeze-and-excitation channel attention
    dropout=0.1,
)

# Multi-scale TCN.
mstcn_options = dict(
    tcn_type='multiscale',
    hidden_dim=512,
    num_channels=[171, 171, 171, 171],  # 4 layers x 171 channels (171 is divisible by 3 kernels)
    kernel_size=[3, 5, 7],              # one branch per kernel size
    dropout=0.1,
    stride=1,
    width_mult=1.0,
)

# Conformer.
conformer_options = dict(
    attention_dim=512,             # model width; also passed as hidden_dim
    attention_heads=8,
    linear_units=2048,             # position-wise feed-forward size
    num_blocks=6,
    dropout_rate=0.1,
    positional_dropout_rate=0.1,   # dropout on the positional encoding
    attention_dropout_rate=0.0,
    cnn_module_kernel=31,          # convolution-module kernel size
)

# Selected temporal encoder: 'densetcn', 'mstcn', or 'conformer'.
TEMPORAL_ENCODER = 'conformer'


 ## 4.2 Model Initialization and Pretrained Frontend

In [None]:
# Step 1: Initialize the model first
print(f"Initializing model with {TEMPORAL_ENCODER} temporal encoder...")
logging.info(f"Initializing model with {TEMPORAL_ENCODER} temporal encoder")

# Encoder-specific constructor arguments; the remaining arguments are shared.
_encoder_kwargs = {
    'densetcn': {'densetcn_options': densetcn_options, 'hidden_dim': 512},
    'mstcn': {'tcn_options': mstcn_options, 'hidden_dim': mstcn_options['hidden_dim']},
    'conformer': {'conformer_options': conformer_options,
                  'hidden_dim': conformer_options['attention_dim']},
}
if TEMPORAL_ENCODER not in _encoder_kwargs:
    raise ValueError(f"Unknown temporal encoder type: {TEMPORAL_ENCODER}")

model = Lipreading(
    **_encoder_kwargs[TEMPORAL_ENCODER],
    num_classes=base_vocab_size,
    relu_type='swish',
).to(device)

print("Model initialized successfully.")

# Step 2: Load pretrained frontend weights
print("\nStep 4.2: Loading pretrained frontend weights...")
logging.info("Loading pretrained frontend weights")

pretrained_path = 'lipreading/pretrained_frontend/frontend.pth'
# NOTE: torch.load unpickles the file — only load trusted checkpoints.
pretrained_weights = torch.load(pretrained_path, map_location=device)
print(f"Loaded pretrained weights from {pretrained_path}")

# Accept both {'state_dict': ...} checkpoints and bare state dicts.
if isinstance(pretrained_weights, dict) and 'state_dict' in pretrained_weights:
    frontend_state = pretrained_weights['state_dict']
else:
    frontend_state = pretrained_weights

# strict=False tolerates extra or missing entries, but it also loads *nothing*
# without complaint when key names do not line up (e.g. a prefix mismatch).
# The frontend is frozen right below, so that would silently train on top of a
# randomly initialised frontend — check what was actually matched.
load_result = model.visual_frontend.load_state_dict(frontend_state, strict=False)
expected_keys = model.visual_frontend.state_dict().keys()
loaded_count = len(expected_keys) - len(load_result.missing_keys)
if expected_keys and loaded_count == 0:
    raise RuntimeError(
        f"No frontend weights matched the keys in {pretrained_path}; "
        f"checkpoint keys start with {list(frontend_state)[:5]}"
    )
if load_result.missing_keys:
    logging.warning(f"Frontend keys missing from checkpoint: {load_result.missing_keys}")
if load_result.unexpected_keys:
    logging.warning(f"Unexpected checkpoint keys ignored: {load_result.unexpected_keys}")
print(f"Successfully loaded pretrained weights "
      f"({loaded_count}/{len(expected_keys)} frontend entries matched)")

# Freeze frontend parameters
for param in model.visual_frontend.parameters():
    param.requires_grad = False

print("Frontend frozen - parameters will not be updated during training")
logging.info("Successfully loaded and froze pretrained frontend")


 ## 4.3 Decoder and Training Setup

In [None]:
print("\nStep 4.3: Initializing transformer decoder and training components...")

# Attention decoder over the full vocabulary (blank/pad, characters, <sos>, <eos>):
# width 512 (meant to match the encoder's hidden size), 8 heads, 6 layers.
transformer_decoder = ArabicTransformerDecoder(
    vocab_size=full_vocab_size,
    attention_dim=512,
    attention_heads=8,
    num_blocks=6,
    dropout_rate=0.1,
).to(device)

# Learning-rate schedule (project CosineScheduler) over the epoch budget.
initial_lr, total_epochs = 3e-4, 80
scheduler = CosineScheduler(initial_lr, total_epochs)

# CTC on the encoder output (blank id 0) and cross-entropy on the decoder
# output, where id 0 is padding and therefore ignored.
ctc_loss_fn = nn.CTCLoss(blank=0, zero_infinity=True)
ce_criterion = nn.CrossEntropyLoss(ignore_index=0)

# Two parameter groups: the decoder trains with a 1.5x larger learning rate.
optimizer = optim.Adam(
    [
        dict(params=model.parameters(), lr=initial_lr),
        dict(params=transformer_decoder.parameters(), lr=1.5 * initial_lr),
    ]
)

print("Selected temporal encoder:", TEMPORAL_ENCODER)
print(model)
print(transformer_decoder)


 # 5. Training and Evaluation

In [None]:
def get_rng_state():
    """Snapshot the torch CPU, NumPy and (if present) CUDA RNG states.

    Returns:
        dict with keys ``'torch'``, ``'numpy'`` and ``'cuda'`` (``None`` when
        CUDA is unavailable). If capturing fails, a warning is printed and
        logged, and a fresh snapshot is taken instead.
    """
    try:
        snapshot = {
            'torch': torch.get_rng_state(),
            'numpy': np.random.get_state(),
            'cuda': torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
        }

        # Sanity check on the torch state type.
        if not isinstance(snapshot['torch'], torch.Tensor):
            print("Warning: torch RNG state is not a tensor, creating a valid state")
            snapshot['torch'] = torch.random.get_rng_state()

    except Exception as err:
        reason = f"Error capturing RNG state: {str(err)}. Using default state."
        print(f"Warning: {reason}")
        logging.warning(reason)
        snapshot = {
            'torch': torch.random.get_rng_state(),
            'numpy': np.random.get_state(),
            'cuda': torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
        }
    return snapshot

def set_rng_state(state):
    """Restore RNG states captured by ``get_rng_state``.

    The torch CPU, NumPy and CUDA generators are restored independently, so a
    problem with one entry no longer prevents the others from being restored.
    Tensor states are moved to the CPU first: ``torch.load(...,
    map_location=device)`` relocates every tensor in a checkpoint, including
    these byte-tensor states, while ``torch.set_rng_state`` and
    ``torch.cuda.set_rng_state`` expect CPU ByteTensors. Any failure is printed
    and logged, and that generator keeps its current state.

    Args:
        state: dict with optional ``'torch'``, ``'numpy'`` and ``'cuda'`` entries.
    """
    def _report(name, err):
        print(f"Warning: Failed to set {name} RNG state: {str(err)}")
        logging.warning(f"Failed to set {name} RNG state: {str(err)}")
        print("Continuing with current RNG state")
        logging.info("Continuing with current RNG state")

    if not isinstance(state, dict):
        _report("any", TypeError(f"expected a dict, got {type(state).__name__}"))
        return

    torch_state = state.get('torch')
    if isinstance(torch_state, torch.Tensor):
        try:
            torch.set_rng_state(torch_state.cpu())
        except Exception as e:
            _report("torch", e)

    numpy_state = state.get('numpy')
    if numpy_state is not None:
        try:
            np.random.set_state(numpy_state)
        except Exception as e:
            _report("numpy", e)

    cuda_state = state.get('cuda')
    if torch.cuda.is_available() and isinstance(cuda_state, torch.Tensor):
        try:
            torch.cuda.set_rng_state(cuda_state.cpu())
        except Exception as e:
            _report("cuda", e)

def create_transformer_inputs(labels_flat, label_lengths, device):
    """Build decoder input/target tensors and a causal mask from flat labels.

    ``labels_flat`` stores the label sequences of the whole batch back to back,
    and ``label_lengths[b]`` gives the number of entries belonging to sample
    ``b``. Each sequence is wrapped as ``<sos> tokens <eos>`` and right-padded
    with 0 to the longest wrapped length.

    Returns:
        decoder_input: with probability 0.5 the padded sequences without their
            last column (teacher forcing); otherwise a single ``<sos>`` column.
        decoder_target: the padded sequences without their first column.
        tgt_mask: ``subsequent_mask`` for the decoder-input length, repeated
            over the batch -> [batch, len, len].
    """
    sos = torch.tensor([sos_token_idx], device=device)
    eos = torch.tensor([eos_token_idx], device=device)

    wrapped, cursor = [], 0
    for length in label_lengths.tolist():
        wrapped.append(torch.cat([sos, labels_flat[cursor:cursor + length], eos]))
        cursor += length

    padded = torch.nn.utils.rnn.pad_sequence(wrapped, batch_first=True, padding_value=0)

    use_ground_truth = torch.rand(1).item() < 0.5
    if use_ground_truth:
        decoder_input = padded[:, :-1]
    else:
        decoder_input = torch.full((padded.size(0), 1), sos_token_idx, device=device)
    decoder_target = padded[:, 1:]

    n_batch, n_steps = decoder_input.size(0), decoder_input.size(1)
    tgt_mask = subsequent_mask(n_steps).to(device).unsqueeze(0).expand(n_batch, -1, -1)
    return decoder_input, decoder_target, tgt_mask

def _encoder_output_lengths(input_lengths, batch_size, max_t, device):
    """Return the number of valid encoder frames for each sample.

    ``input_lengths`` is used when the encoder output is exactly as long as the
    longest input. That is taken as a sign that the encoder keeps the frame
    rate and that padding sits at the end — an assumption; TODO confirm it
    against pad_packed_collate. Otherwise all ``max_t`` frames count as valid,
    as before.
    """
    if input_lengths.numel() > 0 and int(input_lengths.max()) == max_t:
        return input_lengths.to(device=device, dtype=torch.long)
    return torch.full((batch_size,), max_t, dtype=torch.long, device=device)


def _scheduled_sampling_logits(decoder_input, decoder_target, encoder_features,
                               memory_mask, teacher_forcing_ratio, max_steps=None):
    """Decode step by step from the ``<sos>`` column and return stacked logits.

    The result has shape [batch, steps, vocab], and the logits at step ``t``
    predict ``decoder_target[:, t]``. The token fed in after step ``t`` is the
    ground truth ``decoder_target[:, t]`` with probability
    ``teacher_forcing_ratio``, and otherwise that step's argmax.
    """
    batch_size = decoder_input.size(0)
    steps = decoder_target.size(1)
    if max_steps is not None:
        steps = max(1, min(steps, max_steps))

    step_logits = []
    for t in range(steps):
        causal_mask = subsequent_mask(decoder_input.size(1)).to(decoder_input.device)
        causal_mask = causal_mask.unsqueeze(0).expand(batch_size, -1, -1)
        output = transformer_decoder(decoder_input, causal_mask, encoder_features, memory_mask)
        logits_t = output[:, -1, :]  # prediction for target position t
        step_logits.append(logits_t)
        if t + 1 == steps:
            break  # no further input is needed after the last prediction
        if torch.rand(1).item() < teacher_forcing_ratio:
            next_token = decoder_target[:, t:t + 1]
        else:
            next_token = logits_t.argmax(dim=-1, keepdim=True)
        decoder_input = torch.cat([decoder_input, next_token.to(decoder_input.dtype)], dim=1)
    return torch.stack(step_logits, dim=1)


def train_one_epoch(ctc_weight=0.2, teacher_forcing_ratio=0.5, max_decode_steps=None):
    """Train the encoder (``model``) and ``transformer_decoder`` for one epoch.

    Loss per batch: ``ctc_weight * CTC(encoder) + (1 - ctc_weight) * CE(decoder)``.
    The encoder output feeds the CTC loss (after log_softmax) and also serves
    as decoder memory.

    ``create_transformer_inputs`` randomly selects one of two decoder regimes:
      * full teacher forcing: the whole shifted target goes through the
        decoder in one parallel pass;
      * scheduled sampling: decoding starts from ``<sos>``, and at each step
        the next input is the ground-truth token with probability
        ``teacher_forcing_ratio`` and the model's own argmax otherwise.
    In both regimes the logits for position t are scored against target
    token t. (Previously they were scored against token t+1.)

    Args:
        ctc_weight: Weight of the CTC term in the combined loss.
        teacher_forcing_ratio: Per-step probability of feeding the ground-truth
            token in the scheduled-sampling regime.
        max_decode_steps: Optional cap on the number of scheduled-sampling
            steps; None decodes the full padded target length.

    Returns:
        Mean combined loss over the batches that trained successfully, or NaN
        if every batch failed.
    """
    model.train()
    transformer_decoder.train()
    # The pretrained frontend is frozen via requires_grad only. Keep it in eval
    # mode too, so layers such as BatchNorm (if present) do not keep updating
    # their running statistics during training.
    frontend = getattr(model, 'visual_frontend', None)
    if frontend is not None and not any(p.requires_grad for p in frontend.parameters()):
        frontend.eval()

    ctc_loss_fn = nn.CTCLoss(blank=0, zero_infinity=True)
    ce_criterion = nn.CrossEntropyLoss(ignore_index=0)  # 0 = padding id

    running_loss = 0.0
    trained_batches = 0

    for batch_idx, (inputs, input_lengths, labels_flat, label_lengths) in enumerate(train_loader):
        logging.debug(f"Batch {batch_idx+1} - Input shape: {inputs.shape}")

        inputs = inputs.to(device)
        input_lengths = input_lengths.to(device)
        labels_flat = labels_flat.to(device)
        label_lengths = label_lengths.to(device)

        optimizer.zero_grad()

        # NOTE(review): the same encoder output is used as CTC logits and as
        # decoder memory — confirm the decoder expects memory of this width.
        encoder_features = model(inputs, input_lengths)
        batch_size, max_t = encoder_features.size(0), encoder_features.size(1)
        output_lengths = _encoder_output_lengths(input_lengths, batch_size, max_t, device)
        logging.debug(f"Batch {batch_idx+1} - Encoder features shape: {encoder_features.shape}")

        log_probs = F.log_softmax(encoder_features, dim=2)  # (B, T, C)
        # CTC expects (T, B, C).
        ctc_loss_val = ctc_loss_fn(log_probs.transpose(0, 1), labels_flat, output_lengths, label_lengths)

        # True where the encoder frame is valid (not padding).
        memory_mask = torch.arange(max_t, device=device).unsqueeze(0) < output_lengths.unsqueeze(1)

        decoder_input, decoder_target, tgt_mask = create_transformer_inputs(labels_flat, label_lengths, device)
        logging.debug(f"decoder_input={tuple(decoder_input.shape)}, "
                      f"decoder_target={tuple(decoder_target.shape)}, tgt_mask={tuple(tgt_mask.shape)}")

        try:
            if decoder_input.size(1) == decoder_target.size(1):
                # Full teacher forcing: every position is predicted in one pass.
                logits = transformer_decoder(decoder_input, tgt_mask, encoder_features, memory_mask)
                targets = decoder_target
            else:
                # Only <sos> was provided: decode autoregressively.
                logits = _scheduled_sampling_logits(
                    decoder_input, decoder_target, encoder_features, memory_mask,
                    teacher_forcing_ratio, max_decode_steps,
                )
                targets = decoder_target[:, :logits.size(1)]

            ce_loss = ce_criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
            combined_loss = ctc_weight * ctc_loss_val + (1 - ctc_weight) * ce_loss
            logging.debug(f"CE loss: {ce_loss.item():.6f}, combined loss: {combined_loss.item():.6f}")

            combined_loss.backward()
            optimizer.step()

            running_loss += combined_loss.item()
            trained_batches += 1
            if batch_idx % 10 == 0:
                logging.info(f"Batch {batch_idx}, Loss: {combined_loss.item():.4f}")
            del logits, targets, ce_loss, combined_loss

        except Exception as e:
            import traceback
            traceback_str = traceback.format_exc()
            logging.error(f"Error in transformer decoder forward pass: {str(e)}")
            logging.error(f"Error type: {type(e).__name__}")
            logging.error(traceback_str)
            logging.error(f"decoder_input dtype: {decoder_input.dtype}, device: {decoder_input.device}")
            logging.error(f"tgt_mask dtype: {tgt_mask.dtype}, device: {tgt_mask.device}")
            print(f"Error in batch {batch_idx}: {str(e)}")
            print(traceback_str)
            optimizer.zero_grad()  # drop any partial gradients; the batch is skipped

        # Release large tensors before the next batch.
        del encoder_features, log_probs, memory_mask, decoder_input, decoder_target, tgt_mask
        if batch_idx % 3 == 0:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                logging.info(f"Memory cleared. Current GPU memory: {torch.cuda.memory_allocated()/1e6:.2f}MB")

    if trained_batches == 0:
        logging.error("No training batch completed successfully in this epoch")
        return float('nan')
    if trained_batches < len(train_loader):
        logging.warning(f"{len(train_loader) - trained_batches} training batches were skipped")
    return running_loss / trained_batches

def _report_eval_sample(sample_no, pred_text, target_text, edit_distance, cer, to_console):
    """Log one decoded sample, and print it as well when ``to_console`` is true."""
    logging.info("-" * 50)
    logging.info(f"Sample {sample_no}:")
    logging.info(f"Predicted text: {pred_text}")
    logging.info(f"Target text: {target_text}")
    logging.info(f"Edit distance: {edit_distance}")
    logging.info(f"CER: {cer:.4f}")
    logging.info("-" * 50)

    if not to_console:
        return
    print("-" * 50)
    print(f"Sample {sample_no}:")
    try:
        print(f"Predicted text: {pred_text}")
        print(f"Target text: {target_text}")
    except UnicodeEncodeError:
        # Consoles with a non-UTF-8 encoding cannot show Arabic text.
        print("Predicted text: [Contains characters that can't be displayed in console]")
        print("Target text: [Contains characters that can't be displayed in console]")
    print(f"Edit distance: {edit_distance}")
    print(f"CER: {cer:.4f}")
    print("-" * 50)


def evaluate_model(data_loader, ctc_weight=0.2, epoch=None, print_samples=True):
    """
    Evaluate the model on the given data loader.

    Each batch goes through the encoder (``model``), and its CTC loss is
    accumulated. The batch is then decoded with
    ``transformer_decoder.batch_beam_search`` (beam 5, at most 24 tokens,
    CTC-weighted by ``ctc_weight``). Each best hypothesis is compared with its
    reference using ``compute_cer``.

    CER and edit distance are averaged over the samples actually decoded, so a
    batch whose decoding fails no longer pulls the reported CER toward 0.

    Args:
        data_loader: DataLoader yielding (inputs, input_lengths, labels_flat, label_lengths)
        ctc_weight: Weight for CTC scoring (0.0 to 1.0)
        epoch: Current epoch number (optional); samples are printed when it
            is None, 0, or every 5th epoch
        print_samples: Whether to print sample predictions to console

    Returns:
        Average CTC loss per batch (NaN for an empty loader)
    """
    model.eval()
    transformer_decoder.eval()
    ctc_loss_fn = nn.CTCLoss(blank=0, zero_infinity=True)

    total_cer = 0.0
    total_edit_distance = 0
    total_loss = 0.0
    sample_count = 0   # samples visited (used to number printed samples)
    scored_count = 0   # samples whose CER / edit distance entered the totals

    show_samples = print_samples and (epoch is None or epoch == 0 or (epoch + 1) % 5 == 0)
    max_samples_to_print = 20  # limit console output

    with torch.no_grad():
        for i, (inputs, input_lengths, labels_flat, label_lengths) in enumerate(data_loader):
            inputs = inputs.to(device)
            input_lengths = input_lengths.to(device)
            labels_flat = labels_flat.to(device)
            label_lengths = label_lengths.to(device)

            batch_size = inputs.size(0)
            encoder_features = model(inputs, input_lengths)
            max_t = encoder_features.size(1)

            # Use the real input lengths when the encoder output is exactly as
            # long as the longest input. That suggests the frame rate is kept
            # and padding sits at the end (assumption — TODO confirm against
            # pad_packed_collate). Otherwise treat every frame as valid.
            if input_lengths.numel() > 0 and int(input_lengths.max()) == max_t:
                output_lengths = input_lengths.long()
            else:
                output_lengths = torch.full((batch_size,), max_t, dtype=torch.long, device=device)

            log_probs_ctc = F.log_softmax(encoder_features, dim=2).transpose(0, 1)  # (T, B, C)
            ctc_loss = ctc_loss_fn(log_probs_ctc, labels_flat, output_lengths, label_lengths)
            total_loss += ctc_loss.item()

            # Start offset of each sample inside labels_flat (computed once per batch).
            labels_cpu = labels_flat.cpu().numpy()
            offsets = np.concatenate(([0], np.cumsum(label_lengths.cpu().numpy()))).astype(int)

            logging.info(f"\nRunning hybrid CTC/Attention decoding for batch {i+1}...")
            if show_samples and i == 0:
                print(f"\nRunning hybrid CTC/Attention decoding for validation...")

            try:
                # True where the encoder frame is valid (not padding).
                memory_mask = torch.arange(max_t, device=device).unsqueeze(0) < output_lengths.unsqueeze(1)

                all_nbest_hyps = transformer_decoder.batch_beam_search(
                    memory=encoder_features,
                    memory_mask=memory_mask,
                    beam_size=5,
                    maxlen=24,
                    minlen=1,
                    sos=sos_token_idx,
                    eos=eos_token_idx,
                    ctc_weight=ctc_weight
                )
                logging.info(f"Hybrid decoding completed for batch {i+1}")
                logging.info(f"Received {len(all_nbest_hyps)} hypotheses sets")

                for b in range(batch_size):
                    sample_count += 1
                    if b < len(all_nbest_hyps):
                        score, pred_indices = all_nbest_hyps[b]
                        logging.info(f"Found beam hypothesis for item {b+1} with score {score:.4f}")
                        pred_indices = np.array(pred_indices)
                        if len(pred_indices) == 0:
                            logging.info("WARNING: Prediction sequence is empty!")
                    else:
                        logging.info(f"No hypotheses for batch item {b+1}")
                        pred_indices = np.array([])

                    target_idx = labels_cpu[offsets[b]:offsets[b + 1]]
                    logging.debug(f"Debug - Reference tokens ({len(target_idx)} tokens): {target_idx}")
                    logging.debug(f"Debug - Hypothesis tokens ({len(pred_indices)} tokens): {pred_indices}")

                    pred_text = indices_to_text(pred_indices, idx2char)
                    target_text = indices_to_text(target_idx, idx2char)
                    cer, edit_distance = compute_cer(target_idx, pred_indices)

                    total_cer += cer
                    total_edit_distance += edit_distance
                    scored_count += 1

                    _report_eval_sample(
                        sample_count, pred_text, target_text, edit_distance, cer,
                        to_console=show_samples and sample_count <= max_samples_to_print,
                    )

            except Exception as e:
                # logging.exception also records the traceback.
                logging.exception(f"Error during hybrid decoding: {str(e)}")
                print(f"Error during hybrid decoding: {str(e)}")

            del encoder_features, log_probs_ctc
            if i % 3 == 0:  # periodically release cached memory
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    logging.info(f"Memory cleared. Current GPU memory: {torch.cuda.memory_allocated()/1e6:.2f}MB")

    n_samples = len(data_loader.dataset)
    n_batches = len(data_loader)
    avg_cer = total_cer / scored_count if scored_count else float('nan')
    avg_edit_distance = total_edit_distance / scored_count if scored_count else float('nan')
    avg_loss = total_loss / n_batches if n_batches else float('nan')

    if scored_count < n_samples:
        logging.warning(f"Only {scored_count} of {n_samples} samples were decoded and scored")

    summary = [
        "\n=== Summary Statistics ===",
        f"Total samples: {n_samples}",
        f"Scored samples: {scored_count}",
        f"Average CER: {avg_cer:.4f}",
        f"Average Edit Distance: {avg_edit_distance:.2f}",
        f"Average Loss: {avg_loss:.4f}",
        f"CTC Weight used: {ctc_weight}",
    ]
    for line in summary:
        print(line)
        logging.info(line)

    return avg_loss


In [None]:
def _check_state_dict_compat(module_state_dict, checkpoint_state_dict, label):
    """Abort if a checkpoint's parameter names do not match a live module's.

    Only the key sets are compared here, not tensor shapes.

    Args:
        module_state_dict: ``state_dict()`` of the module being restored.
        checkpoint_state_dict: the corresponding state dict read from disk.
        label: component name used in the messages, e.g. ``"Model"`` or
            ``"Transformer decoder"``.

    Raises:
        RuntimeError: if the key sets differ. The missing/unexpected keys are
            printed and logged first so the mismatch can be diagnosed.
    """
    if set(module_state_dict.keys()) == set(checkpoint_state_dict.keys()):
        return

    missing_keys = [k for k in module_state_dict.keys() if k not in checkpoint_state_dict]
    unexpected_keys = [k for k in checkpoint_state_dict.keys() if k not in module_state_dict]
    error_msg = f"{label} architecture mismatch detected!\n"
    if missing_keys:
        error_msg += f"Missing keys in checkpoint: {missing_keys}\n"
    if unexpected_keys:
        error_msg += f"Unexpected keys in checkpoint: {unexpected_keys}\n"
    error_msg += "Cannot proceed with training due to incompatible architecture."
    print(error_msg)
    logging.error(error_msg)
    raise RuntimeError(f"{label} architecture mismatch. Training aborted to prevent corruption.")


def _load_checkpoint(checkpoint_path):
    """Restore model, decoder, optimizer and (if present) RNG state from disk.

    Args:
        checkpoint_path: path to a file written by ``train_model``.

    Returns:
        ``(start_epoch, best_val_loss)`` to resume the training loop from.

    Raises:
        Any exception raised while reading or applying the checkpoint
        (including the architecture-mismatch ``RuntimeError``) is printed,
        logged and re-raised so training never continues on a partially
        restored state. A failure to restore only the RNG state is downgraded
        to a warning.
    """
    print(f"Loading checkpoint from {checkpoint_path}...")
    logging.info(f"Loading checkpoint from {checkpoint_path}")

    try:
        # NOTE(review): recent PyTorch releases default torch.load to
        # weights_only=True; if get_rng_state() stores non-tensor objects this
        # load may need an explicit weights_only=False — confirm for your version.
        checkpoint = torch.load(checkpoint_path, map_location=device)

        # Validate each component before loading it so a mismatched file
        # cannot half-overwrite the live weights.
        checkpoint_model_state_dict = checkpoint['model_state_dict']
        _check_state_dict_compat(model.state_dict(), checkpoint_model_state_dict, "Model")
        model.load_state_dict(checkpoint_model_state_dict)

        checkpoint_decoder_state_dict = checkpoint['transformer_decoder_state_dict']
        _check_state_dict_compat(
            transformer_decoder.state_dict(), checkpoint_decoder_state_dict, "Transformer decoder"
        )
        transformer_decoder.load_state_dict(checkpoint_decoder_state_dict)
        print("Successfully loaded checkpoint")

        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # The stored epoch has already been completed, so resume at the next one.
        start_epoch = checkpoint['epoch'] + 1
        best_val_loss = checkpoint.get('best_val_loss', float('inf'))

        if 'rng_state' in checkpoint:
            try:
                set_rng_state(checkpoint['rng_state'])
                print("RNG state restored successfully")
                logging.info("RNG state restored successfully")
            except Exception as e:
                # Best effort: an unrestorable RNG state should not block resuming.
                print(f"Warning: Could not restore RNG state: {str(e)}")
                logging.warning(f"Could not restore RNG state: {str(e)}")
                print("Continuing with current RNG state")
                logging.info("Continuing with current RNG state")

        print(f"Checkpoint loaded successfully. Resuming from epoch {start_epoch}")
        logging.info(f"Checkpoint loaded successfully. Resuming from epoch {start_epoch}")

    except Exception as e:
        print(f"Error loading checkpoint: {str(e)}")
        logging.error(f"Error loading checkpoint: {str(e)}")
        print("Aborting training due to checkpoint loading failure.")
        raise  # Re-raise the exception to stop execution

    return start_epoch, best_val_loss


def _release_memory(label):
    """Run the garbage collector and, on CUDA, empty the cache and log usage."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        logging.info(f"{label}: {torch.cuda.memory_allocated()/1e6:.2f}MB")


def _print_epoch_details(epoch, ctc_weight, num_samples=3):
    """Beam-search the first validation batch and print a few predictions plus the current LR.

    Args:
        epoch: zero-based epoch index (printed one-based).
        ctc_weight: forwarded to ``transformer_decoder.batch_beam_search``.
        num_samples: maximum number of samples from the batch to print.
    """
    print("\n" + "="*50)
    print(f"DETAILED RESULTS AFTER EPOCH {epoch + 1}")
    print("="*50)

    with torch.no_grad():
        # Only the first batch is used; an empty loader prints nothing.
        for inputs, input_lengths, labels_flat, label_lengths in val_loader:
            inputs = inputs.to(device)
            input_lengths = input_lengths.to(device)
            labels_flat = labels_flat.to(device)
            label_lengths = label_lengths.to(device)

            batch_size = inputs.size(0)
            encoder_features = model(inputs, input_lengths)

            # True on each sample's first input_lengths[b] encoder frames, False elsewhere.
            memory_mask = torch.zeros((batch_size, encoder_features.size(1)), device=device).bool()
            for b in range(batch_size):
                memory_mask[b, :input_lengths[b]] = True

            all_nbest_hyps = transformer_decoder.batch_beam_search(
                memory=encoder_features,
                memory_mask=memory_mask,
                beam_size=5,
                maxlen=24,
                minlen=1,
                sos=sos_token_idx,
                eos=eos_token_idx,
                ctc_weight=ctc_weight
            )

            shown = min(num_samples, batch_size)
            print(f"\nShowing predictions for {shown} samples:")
            # Targets are concatenated in labels_flat; walk them with a running offset.
            lengths = label_lengths.cpu().tolist()
            start_idx = 0
            for b in range(shown):
                end_idx = start_idx + lengths[b]
                target_idx = labels_flat[start_idx:end_idx].cpu().numpy()
                start_idx = end_idx

                _, pred_indices = all_nbest_hyps[b]

                pred_text = indices_to_text(pred_indices, idx2char)
                target_text = indices_to_text(target_idx, idx2char)
                cer, edit_distance = compute_cer(target_idx, pred_indices)

                print(f"\nSample {b+1}:")
                print(f"  Prediction: {pred_text}")
                print(f"  Target: {target_text}")
                print(f"  CER: {cer:.4f}, Edit Distance: {edit_distance}")

            break  # Just show the first batch

    print("="*50 + "\n")

    current_lr = optimizer.param_groups[0]['lr']
    print(f"Current learning rate: {current_lr:.6f}")


def train_model(ctc_weight=0.2, checkpoint_path=None, save_every=1):
    """Train ``model`` and ``transformer_decoder``, validating and checkpointing along the way.

    Relies on module-level objects defined earlier in the notebook: ``model``,
    ``transformer_decoder``, ``optimizer``, ``scheduler``, ``val_loader``,
    ``total_epochs``, ``train_one_epoch`` and ``evaluate_model``.

    Args:
        ctc_weight: forwarded to ``evaluate_model`` and to the decoder's beam
            search used for the periodic sample printout.
        checkpoint_path: optional checkpoint to resume from. A falsy value or a
            path that does not exist starts training from scratch.
        save_every: write ``checkpoint_epoch_<n>.pth`` every ``save_every``
            epochs. The default of 1 saves after every epoch.
            ``best_model.pth`` is written whenever validation loss improves,
            independent of this setting.

    Raises:
        ValueError: if ``save_every`` is smaller than 1.
        Exception: any error raised while loading ``checkpoint_path`` is re-raised.
    """
    if save_every < 1:
        raise ValueError(f"save_every must be >= 1, got {save_every}")

    best_val_loss = float('inf')
    start_epoch = 0

    if checkpoint_path and os.path.exists(checkpoint_path):
        start_epoch, best_val_loss = _load_checkpoint(checkpoint_path)
    elif checkpoint_path:
        print(f"Checkpoint file {checkpoint_path} not found. Starting training from scratch.")
        logging.info(f"Checkpoint file {checkpoint_path} not found. Starting training from scratch.")
    else:
        print("No checkpoint specified. Starting training from scratch.")
        logging.info("No checkpoint specified. Starting training from scratch.")

    print(f"Starting training for {total_epochs} epochs")
    print(f"Logs will be saved to {log_filename}")
    print(f"Checkpoints will be saved every {save_every} epoch(s)")
    print("-" * 50)

    last_checkpoint_path = None
    for epoch in range(start_epoch, total_epochs):
        print(f"Epoch {epoch + 1}/{total_epochs} - Training...")
        epoch_loss = train_one_epoch()
        _release_memory("GPU memory after training")

        scheduler.adjust_lr(optimizer, epoch)
        print(f"Epoch {epoch + 1}/{total_epochs} - Evaluating...")
        val_loss = evaluate_model(val_loader, ctc_weight=ctc_weight, epoch=epoch)
        _release_memory("GPU memory after evaluation")

        logging.info(f"Epoch {epoch + 1}/{total_epochs}, Train Loss: {epoch_loss:.4f}, Val Loss: {val_loss:.4f}")
        print(f"Epoch {epoch + 1}/{total_epochs} - Train Loss: {epoch_loss:.4f}, Val Loss: {val_loss:.4f}")

        if (epoch + 1) % 5 == 0:
            _print_epoch_details(epoch, ctc_weight)

        # Decide whether this epoch is the new best BEFORE writing any file so
        # every checkpoint records the up-to-date best_val_loss. Otherwise a
        # run resumed from this epoch's checkpoint would compare against a
        # stale best and could overwrite best_model.pth with a worse model.
        is_best = val_loss < best_val_loss
        if is_best:
            best_val_loss = val_loss
        save_periodic = (epoch + 1) % save_every == 0

        if save_periodic or is_best:
            training_state = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'transformer_decoder_state_dict': transformer_decoder.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': val_loss,
                # Captured at save time so neither file carries an outdated RNG state.
                'rng_state': get_rng_state(),
                'best_val_loss': best_val_loss,
            }

        if save_periodic:
            # Separate name so the checkpoint_path argument is not shadowed.
            epoch_checkpoint_path = f'checkpoint_epoch_{epoch+1}.pth'
            torch.save(training_state, epoch_checkpoint_path)
            last_checkpoint_path = epoch_checkpoint_path
            print(f"Checkpoint saved to {epoch_checkpoint_path}")
            logging.info(f"Saved checkpoint to {epoch_checkpoint_path}")

            # Force synchronize CUDA operations and clear memory after saving
            if torch.cuda.is_available():
                torch.cuda.synchronize()
                torch.cuda.empty_cache()

        if is_best:
            torch.save(training_state, 'best_model.pth')
            print(f"New best model saved with validation loss: {val_loss:.4f}")
            logging.info(f"New best model saved with validation loss: {val_loss:.4f}")

    print("\nTraining completed!")
    print(f"Best validation loss: {best_val_loss:.4f}")
    if last_checkpoint_path is not None:
        print(f"Final checkpoint saved to: {last_checkpoint_path}")
    else:
        print("No checkpoint was saved during this run.")
    print(f"Best model saved to: best_model.pth")

    
# Re-seed numpy/torch/CUDA immediately before training, then launch a fresh run:
# an empty checkpoint path takes the "No checkpoint specified" branch.
reset_seed()
train_model(checkpoint_path="", ctc_weight=0.2)
