  # 1. Imports

In [1]:
# %%

import torch, os, cv2
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision import transforms
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
from PIL import Image
import editdistance
from lipreading.model import Lipreading
from lipreading.optim_utils import CosineScheduler
# Import the Transformer decoder instead of TCN decoder
from lipreading.transformer_decoder import ArabicTransformerDecoder
# We need the mask utility for transformer
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask



  # 2. Initialize the seed and the device

In [2]:
# %%

# Setting the seed for reproducibility
seed = 0


def reset_seed():
    """Re-seed the NumPy and PyTorch (CPU and CUDA) generators with ``seed``."""
    # Same three seeders, applied in the same order, driven by one loop.
    for apply_seed in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        apply_seed(seed)

# Setting the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# On a GPU, turn off cuDNN autotuning and request deterministic kernels
if device.type == 'cuda':
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False



  # 3. Dataset preparation

  ## 3.1. List of Classes

In [4]:
# %%
def extract_label(file):
    """Split the words of a label CSV into letter-plus-diacritics tokens.

    Args:
        file: Path (or anything ``pd.read_csv`` accepts) of a CSV that has a
            ``word`` column.

    Returns:
        A flat list of strings covering all words in order. Every character
        that is not one of the nine marks below starts a new token; every mark
        is appended to the most recent token.

    Note:
        A mark that appears before any base character has no token to attach
        to and raises ``IndexError``. A mark at the start of a later word
        attaches to the last token of the previous word.
    """
    # Tanween, short vowels, shadda, sukun and the small high meem
    combining_marks = frozenset(
        '\u064B\u064C\u064D\u064E\u064F\u0650\u0651\u0652\u06E2'
    )

    tokens = []
    for word in pd.read_csv(file).word:
        for ch in word:
            if ch in combining_marks:
                tokens[-1] += ch
            else:
                tokens.append(ch)
    return tokens

# Gather every distinct token that occurs in any label CSV
classes = set()
for csv_name in os.listdir('../Dataset/Csv (with Diacritics)'):
    classes.update(extract_label('../Dataset/Csv (with Diacritics)/' + csv_name))

# Ids are handed out in reverse-sorted token order starting at 1, so 0 is never
# assigned to a token here.
# NOTE(review): later cells hard-code 1 and 2 as start/end-of-sequence ids,
# which are also the first two token ids produced here - confirm this overlap
# is intended.
mapped_classes = {
    token: index
    for index, token in enumerate(sorted(classes, reverse=True), start=1)
}

print(mapped_classes)



{'ٱ': 1, 'يْ': 2, 'يّْ': 3, 'يِّ': 4, 'يُّ': 5, 'يَّ': 6, 'يٌّ': 7, 'يِ': 8, 'يُ': 9, 'يَ': 10, 'يٌ': 11, 'ي': 12, 'ى': 13, 'وْ': 14, 'وِّ': 15, 'وُّ': 16, 'وَّ': 17, 'وِ': 18, 'وُ': 19, 'وَ': 20, 'وً': 21, 'و': 22, 'هْ': 23, 'هُّ': 24, 'هِ': 25, 'هُ': 26, 'هَ': 27, 'نۢ': 28, 'نْ': 29, 'نِّ': 30, 'نُّ': 31, 'نَّ': 32, 'نِ': 33, 'نُ': 34, 'نَ': 35, 'مْ': 36, 'مّْ': 37, 'مِّ': 38, 'مُّ': 39, 'مَّ': 40, 'مِ': 41, 'مُ': 42, 'مَ': 43, 'مٍ': 44, 'مٌ': 45, 'مً': 46, 'لْ': 47, 'لّْ': 48, 'لِّ': 49, 'لُّ': 50, 'لَّ': 51, 'لِ': 52, 'لُ': 53, 'لَ': 54, 'لٍ': 55, 'لٌ': 56, 'لً': 57, 'كْ': 58, 'كِّ': 59, 'كَّ': 60, 'كِ': 61, 'كُ': 62, 'كَ': 63, 'قْ': 64, 'قَّ': 65, 'قِ': 66, 'قُ': 67, 'قَ': 68, 'قٍ': 69, 'قً': 70, 'فْ': 71, 'فِّ': 72, 'فَّ': 73, 'فِ': 74, 'فُ': 75, 'فَ': 76, 'غْ': 77, 'غِ': 78, 'غَ': 79, 'عْ': 80, 'عَّ': 81, 'عِ': 82, 'عُ': 83, 'عَ': 84, 'عٍ': 85, 'ظْ': 86, 'ظِّ': 87, 'ظَّ': 88, 'ظِ': 89, 'ظُ': 90, 'ظَ': 91, 'طْ': 92, 'طِّ': 93, 'طَّ': 94, 'طِ': 95, 'طُ': 96, 'طَ': 97, 'ضْ': 98, 'ض

  ## 3.2. Video Dataset Class

In [None]:
# %%
# Defining the video dataset class
class VideoDataset(torch.utils.data.Dataset):
    """Dataset that pairs every video file with its label CSV.

    Each item is ``(frames, input_length, label, label_length)``:

    * ``frames`` - the per-frame transform outputs stacked and permuted with
      ``(1, 0, 2, 3)``, so the frame axis sits at position 1;
    * ``input_length`` - number of decoded frames, as a 0-dim long tensor;
    * ``label`` - ``mapped_classes`` id of every token returned by
      ``extract_label`` for the label file (long tensor);
    * ``label_length`` - number of label tokens, as a 0-dim long tensor.
    """

    def __init__(self, video_paths, label_paths, transform=None):
        """Store the parallel path lists and the optional per-frame transform.

        ``transform`` receives each grayscale PIL frame. Its outputs are
        combined with ``torch.stack``, so it has to return tensors.
        """
        self.video_paths = video_paths
        self.label_paths = label_paths
        self.transform = transform

    def __len__(self):
        """Return the number of videos."""
        return len(self.video_paths)

    def __getitem__(self, index):
        """Load the clip and label at ``index``."""
        frames = self.load_frames(video_path=self.video_paths[index])
        tokens = extract_label(self.label_paths[index])
        # Explicit dtype keeps an empty label a long tensor as well.
        label = torch.tensor([mapped_classes[token] for token in tokens], dtype=torch.long)
        # load_frames moves the frame axis to position 1, so the clip length is
        # shape[1]; len(frames) would report the size of axis 0 instead.
        input_length = torch.tensor(frames.shape[1], dtype=torch.long)
        label_length = torch.tensor(len(label), dtype=torch.long)
        return frames, input_length, label, label_length

    def load_frames(self, video_path):
        """Decode ``video_path`` into a tensor of transformed grayscale frames.

        Returns:
            The stacked frames permuted with ``(1, 0, 2, 3)``.

        Raises:
            RuntimeError: If no frame could be decoded from the file.
        """
        frames = []
        video = cv2.VideoCapture(video_path)
        try:
            total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
            # read() advances to the next frame by itself, so no seek is
            # issued before each read.
            for _ in range(total_frames):
                ret, frame = video.read()
                if not ret:
                    continue
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                frames.append(Image.fromarray(frame, 'L'))
        finally:
            # Free the decoder handle even when decoding raises.
            video.release()

        if not frames:
            raise RuntimeError(f"No frames could be decoded from {video_path!r}")

        if self.transform is not None:
            frames = [self.transform(frame) for frame in frames]
        # Stack along a new leading frame axis, then swap it with axis 1.
        return torch.stack(frames).permute(1, 0, 2, 3)

# Defining the video transform
# The name ``transforms`` is rebound below from the torchvision module to the
# composed pipeline (the dataset cell passes it as ``transform=transforms``).
# Building the pipeline through a private module alias keeps this cell
# re-runnable after that rebinding.
from torchvision import transforms as _tv_transforms

transforms = _tv_transforms.Compose([
    _tv_transforms.Resize((96, 96)),
    _tv_transforms.ToTensor(),
    _tv_transforms.Normalize(mean=0.421, std=0.165),
])



  ## 3.2. Load the dataset

In [None]:
# %%
videos_dir = "Dataset/Preprocessed_Video"
labels_dir = "Dataset/Csv (with Diacritics)"
# NOTE(review): the vocabulary cell reads labels from
# '../Dataset/Csv (with Diacritics)' - confirm which working directory is meant.

# Keep only ".mp4" entries (other files in the folder would otherwise get their
# last four characters chopped off), and sort because os.listdir returns
# entries in arbitrary order, which would make the seeded split below differ
# between machines.
file_names = sorted(
    os.path.splitext(entry)[0]
    for entry in os.listdir(videos_dir)
    if entry.endswith(".mp4")
)

# Each video "<name>.mp4" is paired with the label file "<name>.csv"
videos = [os.path.join(videos_dir, file_name + ".mp4") for file_name in file_names]
labels = [os.path.join(labels_dir, file_name + ".csv") for file_name in file_names]



  ## 3.3. Split the dataset

In [None]:
# %%
# Split the dataset into training, validation, test sets
# First hold out 10% of all samples for testing ...
X_temp, X_test, y_temp, y_test = train_test_split(
    videos,
    labels,
    test_size=0.1000,
    random_state=seed,
)
# ... then take 11.11% of the remaining 90% (about 10% of the total) for validation
X_train, X_val, y_train, y_val = train_test_split(
    X_temp,
    y_temp,
    test_size=0.1111,
    random_state=seed,
)



  ## 3.4. DataLoaders

In [None]:
# %%
def pad_packed_collate(batch):
    """Collate variable-length clips and labels into one zero-padded batch.

    Args:
        batch: Sequence of ``(frames, input_length, label, label_length)``
            samples, where ``frames`` is indexed as ``(C, T, H, W)``.

    Returns:
        A tuple ``(data, input_lengths, labels_flat, label_lengths)``:

        * ``data`` - float tensor ``(B, C, T_max, H, W)``, zero-padded along T;
        * ``input_lengths`` - long tensor ``(B,)`` with each clip's frame count;
        * ``labels_flat`` - all label sequences concatenated into one long
          tensor (the flat target layout accepted by CTC loss);
        * ``label_lengths`` - long tensor ``(B,)`` with each label's length.

    The per-sample ``input_length`` entries are not trusted: frame counts are
    read from the clip tensors themselves, so a wrong upstream length can no
    longer truncate the copied clip.
    """
    data_list, _, labels_list, label_lengths = zip(*batch)
    c, max_len, h, w = max(data_list, key=lambda x: x.shape[1]).shape
    frame_counts = [clip.shape[1] for clip in data_list]

    data = torch.zeros((len(data_list), c, max_len, h, w))
    for idx, (clip, n_frames) in enumerate(zip(data_list, frame_counts)):
        # Copy the whole clip; the tail beyond n_frames stays zero.
        data[idx, :, :n_frames] = clip

    # Flatten labels for CTC loss
    labels_flat = torch.cat(
        [torch.as_tensor(label, dtype=torch.long).reshape(-1) for label in labels_list]
    )

    input_lengths = torch.tensor(frame_counts, dtype=torch.long)
    label_lengths = torch.tensor([int(n) for n in label_lengths], dtype=torch.long)
    return data, input_lengths, labels_flat, label_lengths


# Defining the video dataloaders (train, validation, test)
# One dataset per split, all sharing the same frame transform
train_dataset, val_dataset, test_dataset = (
    VideoDataset(split_videos, split_labels, transform=transforms)
    for split_videos, split_labels in ((X_train, y_train), (X_val, y_val), (X_test, y_test))
)

# Settings common to the three loaders; only the training loader shuffles
_loader_options = dict(batch_size=32, pin_memory=True, collate_fn=pad_packed_collate)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_options)



  # 4. Model

In [None]:
# %%
def indices_to_text(indices, idx2char):
    """Convert a sequence of vocabulary indices to text.

    Args:
        indices: Iterable of integer-like indices (Python ints, NumPy integers
            or 0-dim tensors, e.g. from iterating an array or a tensor).
        idx2char: Mapping from ``int`` index to its string token.

    Returns:
        The concatenation of the tokens of all indices; indices that are not
        in ``idx2char`` contribute nothing.
    """
    # Each index is coerced to a plain int before the lookup: 0-dim tensors do
    # not hash like the int they hold, so looking them up directly would miss
    # every key. Joining strings performs no encoding, so no encoding-error
    # handling is needed here.
    return ''.join(idx2char.get(int(index), '') for index in indices)

def compute_cer(reference_indices, hypothesis_indices):
    """Compute the character error rate between two token-index sequences.

    Both arguments are compared token by token as given (one index is one
    token); nothing is converted to text first. Both sequences are also
    printed for debugging.

    Returns a tuple of (CER, edit_distance), where CER is the edit distance
    divided by the reference length (a length of 0 is treated as 1).
    """
    n_ref = len(reference_indices)
    n_hyp = len(hypothesis_indices)

    try:
        print(f"Debug - Reference tokens ({n_ref} tokens): {reference_indices}")
        print(f"Debug - Hypothesis tokens ({n_hyp} tokens): {hypothesis_indices}")
    except UnicodeEncodeError:
        # The console could not encode the output; report the counts only
        print(f"Debug - Reference tokens ({n_ref} tokens): [Token indices omitted due to encoding issues]")
        print(f"Debug - Hypothesis tokens ({n_hyp} tokens): [Token indices omitted due to encoding issues]")

    # Levenshtein distance over the token sequences
    edit_distance = editdistance.eval(reference_indices, hypothesis_indices)

    # An empty reference is normalised by 1 to avoid dividing by zero
    denominator = n_ref if n_ref > 1 else 1
    return edit_distance / denominator, edit_distance

# Initializing the hyper-parameters
densetcn_options = dict(
    block_config=[3, 3, 3, 3],             # Number of layers in each dense block
    growth_rate_set=[384, 384, 384, 384],  # Growth rate for each block (must be divisible by len(kernel_size_set))
    reduced_size=512,                      # Reduced size between blocks (must be divisible by len(kernel_size_set))
    kernel_size_set=[3, 5, 7],             # Kernel sizes for multi-scale processing
    dilation_size_set=[1, 2, 5],           # Dilation rates for increasing receptive field
    squeeze_excitation=True,               # Whether to use SE blocks for channel attention
    dropout=0.2,                           # Dropout rate
)
initial_lr, total_epochs = 3e-4, 80
scheduler = CosineScheduler(initial_lr, total_epochs)

# Reverse vocabulary (index -> token) used when decoding predictions;
# index 0 is not produced by mapped_classes and is added as the CTC blank
idx2char = {index: token for token, index in mapped_classes.items()}
idx2char[0] = ""  # Blank token for CTC

# Visual encoder; the output size is the vocabulary plus the blank
model = Lipreading(
    densetcn_options=densetcn_options,
    hidden_dim=512,
    num_classes=len(mapped_classes) + 1,
    relu_type='prelu',
).to(device)

# Transformer decoder stacked on top of the visual encoder
transformer_decoder = ArabicTransformerDecoder(
    vocab_size=len(mapped_classes) + 1,  # +1 for blank token
    attention_dim=512,                   # Matching hidden_dim from the model
    attention_heads=8,
    num_blocks=6,
    dropout_rate=0.1,
).to(device)

print(model)
print(transformer_decoder)

# Losses: CTC on the encoder output, cross-entropy on the decoder output
# (index 0 doubles as CTC blank and as the ignored padding id)
ctc_loss_fn = nn.CTCLoss(blank=0, zero_infinity=True)
ce_criterion = nn.CrossEntropyLoss(ignore_index=0)  # 0 is pad token

# One Adam optimizer with two parameter groups; the decoder group starts at
# 1.5x the encoder learning rate
optimizer = optim.Adam([
    dict(params=model.parameters(), lr=initial_lr),
    dict(params=transformer_decoder.parameters(), lr=initial_lr * 1.5),
])



  # 5. Training and Evaluation

In [None]:
# %%
# Replace beam search with Transformer decoder inference
def transformer_decode(log_probs, beam_size=8, maxlen=24, blank_index=0):
    """
    Beam-search decode with the transformer decoder, printing debug output.

    Relies on the module-level ``device``, ``idx2char`` and
    ``transformer_decoder``. The start token is hard-coded as 1 and the
    end-of-sequence token as 2.

    Args:
        log_probs: Tensor passed to the decoder unchanged as its memory;
            indexed as (B, T, ...) below
        beam_size: Number of hypotheses kept after every step
        maxlen: Maximum number of expansion steps
        blank_index: Index that is never proposed as a next token

    Returns:
        The beam of the FIRST batch item only: a list of dicts with 'yseq',
        'score' and 'cache' keys sorted by descending score, or an empty list
        when that item produced no result. The beams of the other batch items
        are computed but not returned.
    """
    batch_size = log_probs.size(0)
    print(f"Starting transformer_decode with batch size: {batch_size}")

    # Create memory from encoder features
    memory = log_probs

    # Create memory mask (indicates valid encoder positions); every position
    # is marked valid
    memory_mask = torch.ones((batch_size, memory.size(1)), device=device)

    print(f"Memory shape: {memory.shape}")
    print(f"Memory mask shape: {memory_mask.shape}")

    # Debug memory stats
    print(f"Memory stats: mean={memory.mean().item():.4f}, std={memory.std().item():.4f}")

    all_results = []

    for b in range(batch_size):
        print(f"\nProcessing batch item {b+1}/{batch_size} in beam search")
        # Get this item's memory
        single_memory = memory[b:b+1]  # Keep batch dimension
        single_memory_mask = memory_mask[b:b+1]  # Keep batch dimension

        try:
            print("Starting beam search...")
            # Beam search params
            # Candidate next tokens are the keys of the global idx2char
            char_list = list(idx2char.keys())  # List of valid character indices

            # Beam search initialization
            # Initial state with start token
            # NOTE(review): `y` is assigned here but never read afterwards
            y = torch.tensor([1], dtype=torch.long, device=device).reshape(1, 1)  # Start token

            # Initialize beam with single hypothesis
            beam = [{'score': 0.0, 'yseq': [1], 'cache': None}]

            # Set up length normalization parameters
            length_penalty = 0.6  # Adjust this parameter for better results
            diversity_penalty = 0.1  # Penalty for repeated tokens

            for i in range(maxlen):
                print(f"Beam search step {i+1}/{maxlen}")
                if len(beam) == 0:
                    print("Empty beam, breaking")
                    break

                # Collect candidates from all beam hypotheses
                # NOTE(review): every hypothesis is expanded, including ones
                # whose last token is already 2 (EOS)
                new_beam = []

                for hyp in beam:
                    # Convert yseq to tensor
                    vy = torch.tensor(hyp['yseq'], dtype=torch.long, device=device).reshape(1, -1)

                    # Create proper causal mask for autoregressive property
                    vy_mask = subsequent_mask(vy.size(1)).to(device)

                    # Forward through transformer decoder
                    try:
                        decoder_out = transformer_decoder(
                            vy,                # Input token sequence
                            vy_mask,           # Self-attention causal mask 
                            single_memory,     # Memory from encoder
                            single_memory_mask  # Memory mask (valid positions)
                        )

                        # Get the last prediction (most recent token)
                        y_logits = decoder_out[:, -1]

                        # Convert to log probs
                        local_scores = F.log_softmax(y_logits, dim=-1)

                        # Add to beam for every possible next token
                        for c in char_list:
                            # Skip blank token
                            if c == blank_index:
                                continue

                            # Apply length normalization to scores
                            # NOTE(review): hyp['score'] is itself the
                            # normalised (and penalised) value stored on the
                            # previous step, so the division compounds from
                            # step to step - confirm this is intended
                            normalized_score = (hyp['score'] + local_scores[0, c].item()) / \
                                              ((len(hyp['yseq']) + 1) ** length_penalty)

                            # Apply diversity penalty for repeated tokens
                            if c in hyp['yseq']:
                                normalized_score -= diversity_penalty

                            # Create new hypothesis
                            new_hyp = {
                                'score': normalized_score,
                                'yseq': hyp['yseq'] + [c],
                                'cache': None
                            }

                            new_beam.append(new_hyp)
                    except Exception as e:
                        # A failing forward pass drops this hypothesis only;
                        # the remaining ones are still expanded
                        print(f"Error in decoder forward pass: {str(e)}")
                        continue

                # No candidates found
                if len(new_beam) == 0:
                    print("No candidates in new beam, breaking")
                    break

                # Sort and keep top beam_size hypotheses
                new_beam.sort(key=lambda x: x['score'], reverse=True)
                beam = new_beam[:beam_size]

                # Debug beam status
                print(f"Top beam after step {i+1}:")
                for j, top_hyp in enumerate(beam[:3]):  # Just show top 3
                    print(f"  {j+1}: score={top_hyp['score']:.4f}, seq={top_hyp['yseq']}")

                # Check if all beam hypotheses end with EOS
                if all(hyp['yseq'][-1] == 2 for hyp in beam):
                    print("All hypotheses end with EOS, breaking")
                    break

            print(f"Beam search complete for batch item {b+1}")
            print(f"Final beam size: {len(beam)}")

            # Sort final beam
            beam.sort(key=lambda x: x['score'], reverse=True)
            all_results.append(beam)

        except Exception as e:
            print(f"Error during beam search for batch item {b+1}: {str(e)}")
            import traceback
            traceback.print_exc()

            # Add an empty result for this batch
            all_results.append([])

    # Return the best hypothesis for the first batch item (simplified)
    # (the whole sorted beam of item 0 is returned, not a single hypothesis)
    if len(all_results) > 0 and len(all_results[0]) > 0:
        return all_results[0]
    else:
        return []

def create_transformer_inputs(labels_flat, label_lengths, device):
    """
    Build teacher-forcing tensors for the transformer decoder.

    Every label sequence is cut out of ``labels_flat``, wrapped as
    ``[1] + sequence + [2]`` (start / end markers) and right-padded with 0 to
    the longest wrapped sequence of the batch.

    Args:
        labels_flat: 1-D tensor holding all label sequences back to back
        label_lengths: 1-D tensor with the length of each label sequence
        device: Device on which the new tensors are created

    Returns:
        decoder_input: Padded sequences without their last column
        decoder_target: Padded sequences without their first column
        tgt_mask: ``subsequent_mask`` for the input length, expanded to
            [batch_size, seq_len, seq_len]
    """
    sos_token = torch.tensor([1], device=device)
    eos_token = torch.tensor([2], device=device)

    # Slice each sequence out of the flat tensor and wrap it in the markers
    framed_seqs = []
    offset = 0
    for seq_len in label_lengths.tolist():
        body = labels_flat[offset:offset + seq_len]
        framed_seqs.append(torch.cat([sos_token, body, eos_token]))
        offset += seq_len

    # Zero-filled buffer; each row receives one sequence, the rest stays 0
    max_len = max(len(seq) for seq in framed_seqs)
    target_tensor = torch.zeros((len(framed_seqs), max_len), device=device, dtype=torch.long)
    for row, seq in zip(target_tensor, framed_seqs):
        row[:len(seq)] = seq

    # Input drops the final column, target drops the first one, so position t
    # of the input is trained to predict position t + 1 of the sequence
    decoder_input = target_tensor[:, :-1]
    decoder_target = target_tensor[:, 1:]

    # One causal mask shared (as an expanded view) by all batch items
    batch_size, steps = decoder_input.size(0), decoder_input.size(1)
    tgt_mask = subsequent_mask(steps).to(device).unsqueeze(0).expand(batch_size, -1, -1)

    return decoder_input, decoder_target, tgt_mask

# Replace the train_one_epoch function with the transformer approach
def train_one_epoch():
    """Run one pass over ``train_loader`` with the joint CTC + CE loss.

    For every batch the visual encoder output is scored with CTC, the
    transformer decoder is run with teacher forcing and scored with
    cross-entropy, and ``ctc_weight * CTC + (1 - ctc_weight) * CE`` is
    back-propagated through the shared ``optimizer`` after gradient clipping.

    If the decoder forward pass or its loss raises, the error is printed and
    the batch continues with a constant CE value of 5.0, so the CTC term still
    updates the encoder. Errors in the backward pass or the optimizer step are
    not intercepted.

    Uses the module-level ``model``, ``transformer_decoder``, ``optimizer``,
    ``train_loader`` and ``device``.

    Returns:
        Mean combined loss per batch (0.0 for an empty loader).
    """
    # Defined before the loop so that the fallback path can use it even when
    # the very first decoder call fails.
    ctc_weight = 0.7  # Adjust this weight as needed
    running_loss = 0.0
    model.train()
    transformer_decoder.train()
    ctc_loss_fn = nn.CTCLoss(blank=0, zero_infinity=True)
    ce_criterion = nn.CrossEntropyLoss(ignore_index=0)  # Ignore padding index (0)

    for batch_idx, (inputs, input_lengths, labels_flat, label_lengths) in enumerate(train_loader):
        # Print input shape for debugging
        print(f"Batch {batch_idx+1} - Input shape: {inputs.shape}")

        # Move data to device
        inputs = inputs.to(device)

        # The lengths delivered by the loader are replaced by the padded frame
        # count of the batch (axis 2 of [batch, channels, frames, h, w]).
        # NOTE(review): this treats padding frames as real input for every
        # clip shorter than the longest one - confirm before relying on it.
        actual_input_lengths = torch.full((inputs.size(0),), inputs.size(2), dtype=torch.long, device=device)
        print(f"Input lengths: {input_lengths}")
        print(f"Corrected input lengths: {actual_input_lengths}")
        input_lengths = actual_input_lengths  # Use corrected lengths

        labels_flat = labels_flat.to(device)
        label_lengths = label_lengths.to(device)

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass through the visual encoder
        encoder_features = model(inputs, input_lengths)

        # Every item is given the full encoder output length
        output_lengths = torch.full((encoder_features.size(0),), encoder_features.size(1), dtype=torch.long, device=device)

        # Print shape to verify sequence output
        print(f"Batch {batch_idx+1} - Encoder features shape: {encoder_features.shape}")

        # CTC branch: log-softmax over the last axis, then (B, T, C) -> (T, B, C)
        log_probs = F.log_softmax(encoder_features, dim=2)
        ctc_loss_val = ctc_loss_fn(log_probs.transpose(0, 1), labels_flat, output_lengths, label_lengths)

        # Memory mask: position t of item b is valid when t < output_lengths[b]
        frame_positions = torch.arange(encoder_features.size(1), device=device)
        memory_mask = frame_positions.unsqueeze(0) < output_lengths.unsqueeze(1)

        # Teacher-forcing input/target and causal mask for the decoder
        # NOTE(review): encoder_features is used both as CTC logits above and
        # as decoder memory below - confirm the decoder expects that width.
        decoder_input, decoder_target, tgt_mask = create_transformer_inputs(labels_flat, label_lengths, device)
        print(f"Original decoder input shape: {decoder_input.shape}")

        # No need to truncate or pad to exactly 8 tokens anymore - use dynamic mask
        print(f"Final decoder input shape: {decoder_input.shape}")
        print(f"Final decoder target shape: {decoder_target.shape}")
        print(f"Mask shape: {tgt_mask.shape}")

        print("Applying forward pass through transformer decoder...")

        print(f"Input shapes: decoder_input={decoder_input.shape}, tgt_mask={tgt_mask.shape}")
        print(f"Memory shapes: encoder_features={encoder_features.shape}, memory_mask={memory_mask.shape}")

        # Only the decoder forward pass and its loss are guarded; the backward
        # pass and the optimizer step run once, after the try block.
        try:
            decoder_output = transformer_decoder(
                decoder_input,     # (batch_size, seq_len)
                tgt_mask,          # (batch_size, seq_len, seq_len)
                encoder_features,  # encoder output used as memory
                memory_mask        # (batch_size, encoder steps)
            )
            print(f"Decoder output shape: {decoder_output.shape}")

            # Calculate statistics of the decoder output for debugging
            print(f"Decoder output stats: mean={decoder_output.float().mean():.4f}, std={decoder_output.float().std():.4f}")

            # Cross-entropy over all positions; padding targets (0) are ignored
            decoder_output_flat = decoder_output.reshape(-1, decoder_output.size(-1))
            decoder_target_flat = decoder_target.reshape(-1)
            ce_loss = ce_criterion(decoder_output_flat, decoder_target_flat)
            print(f"CE Loss: {ce_loss.item():.6f}")
        except Exception as e:
            print(f"Error in transformer decoder forward pass: {str(e)}")
            print(f"Error type: {type(e).__name__}")
            import traceback
            traceback.print_exc()

            # Check specific tensor shapes in more detail
            print(f"decoder_input dtype: {decoder_input.dtype}, device: {decoder_input.device}")
            print(f"tgt_mask dtype: {tgt_mask.dtype}, device: {tgt_mask.device}")

            # Constant penalty: contributes to the reported loss but carries
            # no gradient, so only the CTC term trains on this batch
            ce_loss = torch.tensor(5.0, device=device)  # Default penalty

        # Calculate combined loss (weighted sum of CTC and CE)
        combined_loss = ctc_weight * ctc_loss_val + (1 - ctc_weight) * ce_loss
        print(f"Combined Loss: {combined_loss.item():.6f}")

        # Backward pass and optimize
        combined_loss.backward()

        # Gradient clipping to prevent exploding gradients (important for transformers)
        torch.nn.utils.clip_grad_norm_(list(model.parameters()) + list(transformer_decoder.parameters()), 1.0)

        # Update weights
        optimizer.step()

        running_loss += combined_loss.item()

    return running_loss / max(len(train_loader), 1)

def _ctc_greedy_decode(frame_log_probs, blank=0):
    """Best-path CTC decoding: argmax per frame, merge repeats, drop blanks."""
    best_path = np.argmax(frame_log_probs, axis=1)
    collapsed = []
    prev_idx = -1
    for idx in best_path:
        if idx != blank and idx != prev_idx:  # Skip blanks and duplicates
            collapsed.append(idx)
        prev_idx = idx
    return np.array(collapsed)


def evaluate_model(data_loader):
    """Evaluate encoder + decoder on ``data_loader`` and print per-sample results.

    For every batch the CTC loss of the encoder output is computed (for
    monitoring), predictions are obtained with
    ``transformer_decoder.batch_beam_search`` and each prediction is compared
    with its target via ``compute_cer``. If beam search or the unpacking of its
    results raises, the whole batch falls back to greedy CTC decoding.
    Predictions are collected first and scored afterwards, so every sample is
    counted exactly once regardless of where a failure happened.

    Uses the module-level ``model``, ``transformer_decoder``, ``idx2char`` and
    ``device``.

    Args:
        data_loader: Loader yielding ``(inputs, input_lengths, labels_flat,
            label_lengths)`` batches.

    Returns:
        Mean CTC loss per batch (0 for an empty loader).
    """
    model.eval()
    transformer_decoder.eval()
    ctc_loss_fn = nn.CTCLoss(blank=0, zero_infinity=True)

    # Track statistics
    total_cer = 0
    total_edit_distance = 0
    total_loss = 0
    samples_seen = 0  # running sample number, correct for a short last batch

    with torch.no_grad():
        for i, (inputs, input_lengths, labels_flat, label_lengths) in enumerate(data_loader):
            # Move to device
            inputs = inputs.to(device)

            # As in training, the loader lengths are replaced by the padded
            # frame count of the batch.
            actual_input_lengths = torch.full((inputs.size(0),), inputs.size(2), dtype=torch.long, device=device)
            print(f"Input lengths: {input_lengths}")
            print(f"Corrected input lengths: {actual_input_lengths}")
            input_lengths = actual_input_lengths  # Use corrected lengths

            labels_flat = labels_flat.to(device)
            label_lengths = label_lengths.to(device)

            # Forward pass through visual encoder
            batch_size = inputs.size(0)
            encoder_features = model(inputs, input_lengths)

            # Every item is given the full encoder output length
            output_lengths = torch.full((encoder_features.size(0),), encoder_features.size(1), dtype=torch.long, device=device)

            # CTC loss (for monitoring only); counted once per batch
            log_probs = F.log_softmax(encoder_features, dim=2)  # (B, T, C)
            loss = ctc_loss_fn(log_probs.transpose(0, 1), labels_flat, output_lengths, label_lengths)
            total_loss += loss.item()

            print(f"\nRunning Transformer decoding for batch {i+1}...")

            # Step 1: obtain one prediction per batch item
            sample_suffix = ""
            try:
                print(f"Encoder features shape: {encoder_features.shape}")
                print(f"Beam size: 8, Max length: 24")

                # Position t of item b is valid when t < output_lengths[b]
                frame_positions = torch.arange(encoder_features.size(1), device=device)
                memory_mask = frame_positions.unsqueeze(0) < output_lengths.unsqueeze(1)

                all_nbest_hyps = transformer_decoder.batch_beam_search(
                    memory=encoder_features,
                    memory_mask=memory_mask,
                    beam_size=8,
                    maxlen=24,
                    minlen=1,  # Minimum length requirement
                    sos=1,     # Start of sequence token
                    eos=2      # End of sequence token
                )

                print(f"Transformer decoding completed for batch {i+1}")
                print(f"Received {len(all_nbest_hyps)} hypotheses sets")

                predictions = []
                for b in range(batch_size):
                    print(f"\nProcessing batch item {b+1}/{batch_size}")
                    if b < len(all_nbest_hyps):
                        # Results are unpacked as (score, hyp) tuples
                        score, pred_indices = all_nbest_hyps[b]
                        print(f"Found beam hypothesis for item {b+1} with score {score:.4f}")
                        pred_indices = np.array(pred_indices)
                        if len(pred_indices) == 0:
                            print("WARNING: Prediction sequence is empty!")
                    else:
                        # No hypotheses for this batch item - use empty prediction
                        print(f"No hypotheses for batch item {b+1}")
                        pred_indices = np.array([])
                    predictions.append(pred_indices)

            except Exception as e:
                print(f"Error during Transformer decoding: {str(e)}")
                import traceback
                traceback.print_exc()
                # Fall back to CTC greedy decoding for the whole batch
                print("Falling back to CTC greedy decoding")
                predictions = [
                    _ctc_greedy_decode(log_probs[b].cpu().numpy())
                    for b in range(batch_size)
                ]
                sample_suffix = " (Greedy CTC)"

            # Step 2: score every prediction against its target exactly once
            target_lengths = label_lengths.tolist()
            start_idx = 0
            for b, pred_indices in enumerate(predictions):
                # Targets are stored back to back in labels_flat
                end_idx = start_idx + target_lengths[b]
                target_idx = labels_flat[start_idx:end_idx].cpu().numpy()
                start_idx = end_idx

                # Convert indices to text
                pred_text = indices_to_text(pred_indices, idx2char)
                target_text = indices_to_text(target_idx, idx2char)

                # Calculate CER using custom function
                cer, edit_distance = compute_cer(target_idx, pred_indices)

                # Update statistics
                total_cer += cer
                total_edit_distance += edit_distance
                samples_seen += 1

                # Print info
                print("-" * 50)
                print(f"Sample {samples_seen}{sample_suffix}:")
                try:
                    print(f"Predicted text: {pred_text}")
                    print(f"Target text: {target_text}")
                except UnicodeEncodeError:
                    print("Predicted text: [Contains characters that can't be displayed in console]")
                    print("Target text: [Contains characters that can't be displayed in console]")
                    print(f"Predicted indices: {pred_indices}")
                    print(f"Target indices: {target_idx}")

                print(f"Edit distance: {edit_distance}")
                print(f"CER: {cer:.4f}")
                print("-" * 50)

        # Write summary statistics (denominators are clamped to 1 so an empty
        # dataset or loader reports zeros instead of raising)
        n_samples = len(data_loader.dataset)
        avg_cer = total_cer / max(n_samples, 1)
        avg_edit_distance = total_edit_distance / max(n_samples, 1)
        avg_loss = total_loss / max(len(data_loader), 1)

        print("=== Summary Statistics ===")
        print(f"Total samples: {n_samples}")
        print(f"Average CER: {avg_cer:.4f}")
        print(f"Average Edit Distance: {avg_edit_distance:.2f}")
        print(f"Average Loss: {avg_loss:.4f}")

    return avg_loss




In [None]:
# %%
def test_batch(transformer_decoder, encoder_features, memory_mask, beam_size=8,
               maxlen=24, minlen=1, sos=1, eos=2):
    """Run batch beam search on encoder features and print every hypothesis.

    Args:
        transformer_decoder: Object exposing ``batch_beam_search``.
        encoder_features: Encoder output; its first dimension is the batch size.
        memory_mask: Passed through unchanged as ``memory_mask``.
        beam_size: Forwarded as ``beam_size`` (default: 8).
        maxlen: Forwarded as ``maxlen`` (default: 24).
        minlen: Forwarded as ``minlen`` (default: 1).
        sos: Start-of-sequence token index (default: 1).
        eos: End-of-sequence token index (default: 2).

    Returns:
        Whatever ``batch_beam_search`` returned (iterated here as
        ``(score, hypothesis)`` pairs), or ``None`` if decoding or printing
        raised. Failures are reported on stdout/stderr instead of propagating,
        so a bad batch does not abort an interactive session.
    """
    import traceback

    batch_size = encoder_features.size(0)

    print(f"Testing batch with beam search (batch_size={batch_size})")

    try:
        results = transformer_decoder.batch_beam_search(
            memory=encoder_features,
            memory_mask=memory_mask,
            beam_size=beam_size,
            maxlen=maxlen,
            minlen=minlen,
            sos=sos,
            eos=eos,
        )

        # Printing stays inside the try: a malformed result is reported, not raised
        for b, (score, hyp) in enumerate(results):
            print(f"Batch item {b+1}:")
            print(f"  Score: {score:.4f}")
            print(f"  Hypothesis: {hyp}")

        return results

    except Exception as e:
        print(f"Error during beam search: {str(e)}")
        traceback.print_exc()
        return None


# The ``test_`` prefix would make pytest collect this helper as a test case
# (and fail on its parameters); opt out explicitly.
test_batch.__test__ = False
    
def train_model():
    """Run the full training schedule, one train/validate cycle per epoch.

    Relies on the notebook-level ``total_epochs``, ``train_one_epoch``,
    ``scheduler``, ``optimizer``, ``evaluate_model`` and ``val_loader``.
    Prints the training and validation loss after every epoch.
    """
    for epoch_idx in range(total_epochs):
        train_loss = train_one_epoch()

        # The scheduler is driven by the zero-based epoch index
        scheduler.adjust_lr(optimizer, epoch_idx)

        validation_loss = evaluate_model(val_loader)

        summary = (
            f"Epoch {epoch_idx + 1}/{total_epochs}, "
            f"Train Loss: {train_loss:.4f}, Val Loss: {validation_loss:.4f}"
        )
        print(summary)




In [None]:
# %%
def _build_decoder_targets(labels_flat, label_lengths, sos=1, eos=2, pad=0):
    """Turn flat CTC-style labels into padded teacher-forcing tensors.

    Args:
        labels_flat: 1-D tensor holding every sample's label indices back to back.
        label_lengths: 1-D tensor with the number of labels of each sample.
        sos: Index prepended to every sequence.
        eos: Index appended to every sequence.
        pad: Index used to right-pad shorter sequences.

    Returns:
        ``(decoder_input, decoder_output)``: the padded ``sos + labels + eos``
        batch without its last column, and the same batch without its first.
    """
    # int64 is requested explicitly so the concatenated targets end up as long
    sos_tok = torch.tensor([sos], dtype=torch.long, device=labels_flat.device)
    eos_tok = torch.tensor([eos], dtype=torch.long, device=labels_flat.device)

    sequences = []
    offset = 0
    for seq_len in label_lengths.tolist():
        sequences.append(torch.cat([sos_tok, labels_flat[offset:offset + seq_len], eos_tok]))
        offset += seq_len

    padded = nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=pad)
    return padded[:, :-1], padded[:, 1:]


def quick_experiment(model, transformer_decoder, full_dataset, num_epochs=5, num_samples=50,
                     *, batch_size=8, ctc_weight=0.7, beam_size=5, maxlen=50, sos=1, eos=2):
    """Run a quick experiment with a small subset of the data.

    Each epoch trains on a random subset with a weighted sum of CTC loss (on the
    encoder output) and cross-entropy loss (on the transformer decoder output),
    then beam-search decodes one batch of the same subset and prints
    predictions next to their targets. Progress and tensor shapes are printed
    verbosely; errors are printed with a traceback and never propagate.

    Uses the notebook-level ``optimizer``, ``device``, ``pad_packed_collate``,
    ``subsequent_mask``, ``indices_to_text`` and ``idx2char``.

    Args:
        model: The visual encoder model
        transformer_decoder: The transformer decoder model
        full_dataset: The complete training dataset
        num_epochs: Number of epochs to train (default: 5)
        num_samples: Number of samples to use (default: 50)
        batch_size: Batch size of the subset loader (default: 8)
        ctc_weight: Weight of the CTC loss; cross-entropy gets ``1 - ctc_weight``
            (default: 0.7)
        beam_size: Beam width used when decoding (default: 5)
        maxlen: ``maxlen`` passed to ``batch_beam_search`` (default: 50)
        sos: Start-of-sequence token index (default: 1)
        eos: End-of-sequence token index (default: 2)
    """
    import traceback

    print(f"Running quick experiment with {num_samples} samples for {num_epochs} epochs")

    try:
        # Hand Subset plain Python ints rather than 0-d tensors as indices
        indices = torch.randperm(len(full_dataset))[:num_samples].tolist()
        small_dataset = torch.utils.data.Subset(full_dataset, indices)

        small_loader = DataLoader(small_dataset, batch_size=batch_size, shuffle=True,
                                  pin_memory=True, collate_fn=pad_packed_collate)

        # Index 0 doubles as the CTC blank and as the decoder padding index
        ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)
        ce_loss = nn.CrossEntropyLoss(ignore_index=0)

        for epoch in range(num_epochs):
            print(f"\n===== Epoch {epoch+1}/{num_epochs} Training =====")

            running_loss = 0.0
            batch_count = 0

            model.train()
            transformer_decoder.train()

            for batch_idx, (inputs, input_lengths, labels_flat, label_lengths) in enumerate(small_loader):
                try:
                    print(f"Training batch {batch_idx+1} of {len(small_loader)}")

                    # Print shapes for debugging
                    print(f"Input shape: {inputs.shape}")
                    print(f"Input lengths: {input_lengths}")
                    print(f"Label flat shape: {labels_flat.shape}")
                    print(f"Label lengths: {label_lengths}")

                    inputs = inputs.to(device)
                    input_lengths = input_lengths.to(device)
                    labels_flat = labels_flat.to(device)
                    label_lengths = label_lengths.to(device)
                    cur_batch = inputs.size(0)

                    # Every sample is assigned the same length, inputs.size(2).
                    # NOTE(review): presumably the padded frame count of the batch;
                    # confirm against pad_packed_collate and the encoder's output length.
                    actual_input_lengths = torch.full((cur_batch,), inputs.size(2),
                                                      dtype=torch.long, device=device)
                    print(f"Corrected input lengths: {actual_input_lengths}")

                    optimizer.zero_grad()

                    encoder_features = model(inputs, actual_input_lengths)
                    print(f"Encoder features shape: {encoder_features.shape}")
                    print(f"Encoder features stats: mean={encoder_features.mean():.4f}, std={encoder_features.std():.4f}")

                    # CTC branch: CTCLoss expects time-major log-probs, hence the transpose
                    log_probs = F.log_softmax(encoder_features, dim=2)
                    outputs_for_ctc = log_probs.transpose(0, 1)
                    ctc_loss_val = ctc_loss(outputs_for_ctc, labels_flat, actual_input_lengths, label_lengths)
                    print(f"CTC Loss: {ctc_loss_val.item():.6f}")

                    # Attention branch: teacher forcing on sos + labels + eos
                    decoder_input, decoder_output = _build_decoder_targets(
                        labels_flat, label_lengths, sos=sos, eos=eos
                    )

                    print(f"Original decoder input shape: {decoder_input.shape}")

                    # Causal mask, shared by every sample in the batch
                    tgt_mask = subsequent_mask(decoder_input.size(1)).to(device)
                    tgt_mask = tgt_mask.expand(cur_batch, -1, -1)

                    # All-ones memory mask: no encoder position is masked out
                    memory_mask = torch.ones((cur_batch, encoder_features.size(1)), device=device)

                    print(f"Final decoder input shape: {decoder_input.shape}")
                    print(f"Final decoder target shape: {decoder_output.shape}")
                    print(f"Mask shape: {tgt_mask.shape}")

                    print("Applying forward pass through transformer decoder...")
                    print(f"Input shapes: decoder_input={decoder_input.shape}, tgt_mask={tgt_mask.shape}")
                    print(f"Memory shapes: encoder_features={encoder_features.shape}, memory_mask={memory_mask.shape}")

                    decoder_out = transformer_decoder(
                        decoder_input, tgt_mask, encoder_features, memory_mask
                    )

                    print(f"Decoder output shape: {decoder_out.shape}")
                    print(f"Decoder output stats: mean={decoder_out.mean():.4f}, std={decoder_out.std():.4f}")

                    # Flatten batch and time so CrossEntropyLoss sees (N, C) vs (N,)
                    decoder_out_flat = decoder_out.reshape(-1, decoder_out.size(-1))
                    decoder_output_flat = decoder_output.reshape(-1)
                    transformer_loss = ce_loss(decoder_out_flat, decoder_output_flat)
                    print(f"CE Loss: {transformer_loss.item():.6f}")

                    loss = ctc_weight * ctc_loss_val + (1 - ctc_weight) * transformer_loss
                    print(f"Combined Loss: {loss.item():.6f}")

                    loss.backward()
                    optimizer.step()

                    running_loss += loss.item()
                    batch_count += 1

                except Exception as e:
                    # Best effort: report the failing batch and keep training
                    print(f"Error in batch {batch_idx+1}: {str(e)}")
                    traceback.print_exc()
                    continue

            if batch_count > 0:
                epoch_loss = running_loss / batch_count
                print(f"\nEpoch {epoch+1}/{num_epochs} - Average Loss: {epoch_loss:.6f}")
            else:
                print(f"\nEpoch {epoch+1}/{num_epochs} - No valid batches processed")

            # Sanity-check decoding on one batch of the SAME subset that was just
            # trained on; there is no held-out split in this quick experiment.
            print("\nRunning evaluation...")
            model.eval()
            transformer_decoder.eval()

            with torch.no_grad():
                for inputs, input_lengths, labels_flat, label_lengths in small_loader:
                    try:
                        inputs = inputs.to(device)
                        input_lengths = input_lengths.to(device)
                        labels_flat = labels_flat.to(device)
                        label_lengths = label_lengths.to(device)

                        # NOTE(review): the encoder gets the collated input_lengths here,
                        # while training passes the uniform actual_input_lengths; confirm
                        # which one the encoder expects.
                        encoder_features = model(inputs, input_lengths)

                        results = transformer_decoder.batch_beam_search(
                            encoder_features,
                            memory_mask=None,  # Let the decoder handle mask creation
                            beam_size=beam_size,
                            maxlen=maxlen,
                            sos=sos,
                            eos=eos
                        )

                        # Walk the flat labels with a running offset, one slice per hypothesis
                        lengths = label_lengths.tolist()
                        offset = 0
                        for b, (score, hyp) in enumerate(results):
                            seq_len = lengths[b]
                            target_idx = labels_flat[offset:offset + seq_len].cpu().numpy()
                            offset += seq_len

                            pred_text = indices_to_text(hyp, idx2char)
                            target_text = indices_to_text(target_idx, idx2char)

                            print(f"\nSample {b+1}:")
                            print(f"  Predicted: {pred_text}")
                            print(f"  Target: {target_text}")
                            print(f"  Score: {score:.4f}")

                        # Only process first batch during evaluation
                        break

                    except Exception as e:
                        # A failed batch is reported and the next one is tried instead
                        print(f"Error in evaluation: {str(e)}")
                        traceback.print_exc()
                        continue

    except Exception as e:
        print(f"Error in experiment: {str(e)}")
        traceback.print_exc()

# Re-seed the NumPy and PyTorch RNGs so the run below starts from the fixed seed
reset_seed()
# Entry point: keep exactly one of the two calls below active.
# train_model() runs the full training; quick_experiment(...) is a small smoke test.
# train_model()
quick_experiment(model, transformer_decoder, train_dataset, num_samples=50)  # Smoke test on a random 50-sample subset of train_dataset
