 # 1. Imports

In [None]:
# %%

import torch, os, cv2
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision import transforms
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
from PIL import Image
import editdistance
from lipreading.model import Lipreading
from lipreading.optim_utils import CosineScheduler
# Import the TCN decoder instead of transformer decoder
from lipreading.tcn_decoder import TCNDecoder
# from lipreading.transformer_decoder import ArabicTransformerDecoder

# We don't need the mask utility for TCN
# from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask



 # 2. Initialize the seed and the device

In [None]:
# %%
# Setting the seed for reproducibility
seed = 0
def reset_seed():
    """Re-seed every random number generator this notebook can draw from.

    Seeds Python's ``random`` module, NumPy's global generator, the PyTorch
    CPU generator and the generators of all CUDA devices with the
    module-level ``seed``.

    NOTE(review): nothing in the visible part of the notebook calls this
    function, so it has to be invoked explicitly (for example right before
    the model is built) for the seeding to take effect.
    """
    # Local import: `random` is not among the file-level imports.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all covers every visible GPU rather than only the current
    # one; it is a silent no-op when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)

# Setting the device: use the GPU when one is visible, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

if device.type == 'cuda':
    # Ask cuDNN for deterministic kernels and switch its benchmark mode off,
    # in line with the seeding above.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


 # 3. Dataset preparation

 ## 3.1. List of Classes

In [None]:
# %%
# Arabic combining marks that are merged into the preceding base character.
# Built once at import time instead of on every call.
_DIACRITICS = frozenset({
    '\u064B',  # Fathatan
    '\u064C',  # Dammatan
    '\u064D',  # Kasratan
    '\u064E',  # Fatha
    '\u064F',  # Damma
    '\u0650',  # Kasra
    '\u0651',  # Shadda
    '\u0652',  # Sukun
    '\u06E2',  # Small High meem
})


def extract_label(file):
    """Read a sentence CSV and return its label tokens.

    The CSV must have a ``word`` column. Every character of every word is
    visited in order; a base character starts a new token and each diacritic
    is appended to the token before it, so one token is a base character
    plus its diacritics.

    Tokens of all words are collected into one flat list with no word
    separator. As a consequence a diacritic at the start of a later word is
    attached to the last token of the previous word.

    Args:
        file: Path (or any object ``pandas.read_csv`` accepts) of the CSV.

    Returns:
        list[str]: The label tokens in reading order.
    """
    sentence = pd.read_csv(file)
    label = []
    for word in sentence.word:
        for char in word:
            if char in _DIACRITICS and label:
                label[-1] += char
            else:
                # Base character, or a diacritic with nothing before it: the
                # latter used to raise IndexError on `label[-1]`; it is now
                # kept as a token of its own.
                label.append(char)

    return label

# Gather every distinct label token that appears in any annotation CSV.
classes = set()
for csv_name in os.listdir('Dataset/Csv (with Diacritics)'):
    classes |= set(extract_label('Dataset/Csv (with Diacritics)/' + csv_name))

# Token -> integer id. Ids start at 1 and follow the reverse-sorted token
# order; id 0 is left unassigned here.
mapped_classes = {
    token: token_id
    for token_id, token in enumerate(sorted(classes, reverse=True), 1)
}

print(mapped_classes)


 ## 3.2. Video Dataset Class

In [None]:
# %%
# Defining the video dataset class
class VideoDataset(torch.utils.data.Dataset):
    """Dataset pairing each video clip with its label-token sequence.

    Each item is ``(frames, input_length, label, label_length)``:

    * ``frames`` - stacked, transformed grayscale frames with the first two
      axes swapped, i.e. ``(C, T, H, W)`` assuming the transform returns
      ``(C, H, W)`` tensors.
    * ``input_length`` - number of frames ``T`` as a 0-d long tensor.
    * ``label`` - class ids (looked up in ``mapped_classes``) of the tokens
      returned by ``extract_label`` for the matching CSV.
    * ``label_length`` - number of label tokens as a 0-d long tensor.

    Args:
        video_paths: Sequence of video file paths.
        label_paths: Sequence of CSV paths, aligned with ``video_paths``.
        transform: Callable applied to every PIL frame. It must return a
            tensor, because the frames are combined with ``torch.stack``.
    """

    def __init__(self, video_paths, label_paths, transform=None):
        self.video_paths = video_paths
        self.label_paths = label_paths
        self.transform = transform

    def __len__(self):
        return len(self.video_paths)

    def __getitem__(self, index):
        video_path = self.video_paths[index]
        label_path = self.label_paths[index]
        frames = self.load_frames(video_path=video_path)
        label = torch.tensor(
            [mapped_classes[token] for token in extract_label(label_path)],
            dtype=torch.long,
        )
        # The time axis is dim 1 after the permute in load_frames.
        # len(frames) would return the size of dim 0 (the channel count),
        # which made every clip report a length of 1.
        input_length = torch.tensor(frames.shape[1], dtype=torch.long)
        label_length = torch.tensor(len(label), dtype=torch.long)
        return frames, input_length, label, label_length

    def load_frames(self, video_path):
        """Decode all frames of ``video_path`` into one stacked tensor.

        Frames are read sequentially until the capture reports no more data,
        converted to grayscale PIL images and passed through ``transform``.

        Raises:
            RuntimeError: If no frame could be decoded.
        """
        frames = []
        video = cv2.VideoCapture(video_path)
        try:
            # Sequential read() calls already advance frame by frame, so no
            # per-frame seek is needed.
            while True:
                ret, frame = video.read()
                if not ret:
                    break
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                frames.append(Image.fromarray(frame, 'L'))
        finally:
            # Always free the decoder handle, even if a conversion fails.
            video.release()

        if not frames:
            raise RuntimeError(f"No frames could be decoded from {video_path!r}")

        if self.transform is not None:
            frames = [self.transform(frame) for frame in frames]
        # Stack along a new leading time axis, then swap it with the next axis.
        return torch.stack(frames).permute(1, 0, 2, 3)

# Defining the video transform: resize each frame to 96x96, convert it to a
# tensor, then normalise with a fixed mean/std.
# NOTE(review): the assignment below rebinds the name `transforms`, hiding the
# torchvision.transforms module imported at the top of the file. Re-executing
# this cell in the same session would then look up `Resize` on the Compose
# object. The name is kept because the dataset construction further down
# refers to it.
_frame_steps = [
    transforms.Resize((96, 96)),
    transforms.ToTensor(),
    transforms.Normalize(mean=0.421, std=0.165),
]
transforms = transforms.Compose(_frame_steps)


 ## 3.2. Load the dataset

In [None]:
# %%
videos_dir = "Dataset/Preprocessed_Video"
labels_dir = "Dataset/Csv (with Diacritics)"
videos, labels = [], []
# os.listdir returns entries in arbitrary, platform-dependent order. Sorting
# makes the seeded train/val/test split below reproducible across machines.
# Only `.mp4` entries are kept, so stray files in the directory are no longer
# turned into video paths that do not exist.
file_names = sorted(
    stem
    for stem, ext in map(os.path.splitext, os.listdir(videos_dir))
    if ext == ".mp4"
)
# Each video is paired with the CSV that shares its base name.
for file_name in file_names:
    videos.append(os.path.join(videos_dir, file_name + ".mp4"))
    labels.append(os.path.join(labels_dir, file_name + ".csv"))


 ## 3.3. Split the dataset

In [None]:
# %%
# Two-stage split into training, validation and test sets.
# Stage 1 holds out 10% of all samples for testing. Stage 2 takes 11.11% of
# the remaining 90% (about 10% of the total) for validation.
X_temp, X_test, y_temp, y_test = train_test_split(
    videos,
    labels,
    test_size=0.1000,
    random_state=seed,
)
X_train, X_val, y_train, y_val = train_test_split(
    X_temp,
    y_temp,
    test_size=0.1111,
    random_state=seed,
)


 ## 3.4. DataLoaders

In [None]:
# %%
def pad_packed_collate(batch):
    """Collate variable-length clips and labels into one batch.

    Args:
        batch: Iterable of ``(clip, input_length, label, label_length)``
            tuples, where ``clip`` is indexed as ``(C, T, H, W)``.

    Returns:
        tuple: ``(data, input_lengths, labels_flat, label_lengths)`` where
        ``data`` is a zero-padded float tensor of shape
        ``(B, C, T_max, H, W)``, ``labels_flat`` is all label sequences
        concatenated into one 1-D long tensor (the layout CTC loss accepts
        together with ``label_lengths``), and both length outputs are 1-D
        long tensors.
    """
    clips, input_lengths, label_seqs, label_lengths = zip(*batch)

    # The clip with the most frames determines the padded size.
    longest_clip = max(clips, key=lambda clip: clip.shape[1])
    channels, max_frames, height, width = longest_clip.shape
    data = torch.zeros((len(clips), channels, max_frames, height, width))

    # Copy only the declared number of frames; the remainder stays zero.
    for row, (clip, n_frames) in enumerate(zip(clips, input_lengths)):
        data[row, :, :n_frames] = clip[:, :n_frames]

    # Flatten labels for CTC loss.
    labels_flat = torch.LongTensor(
        [token for label_seq in label_seqs for token in label_seq]
    )

    return (
        data,
        torch.LongTensor(input_lengths),
        labels_flat,
        torch.LongTensor(label_lengths),
    )


# Defining the video datasets and dataloaders (train, validation, test).
# All three splits share the same frame transform.
train_dataset, val_dataset, test_dataset = (
    VideoDataset(split_videos, split_labels, transform=transforms)
    for split_videos, split_labels in (
        (X_train, y_train),
        (X_val, y_val),
        (X_test, y_test),
    )
)

# Settings common to every loader; only the training loader shuffles.
_loader_options = dict(batch_size=32, pin_memory=True, collate_fn=pad_packed_collate)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_options)


 # 4. Model

In [None]:
# %%
def indices_to_text(indices, idx2char):
    """Convert a sequence of token indices to text.

    Args:
        indices: Iterable of indices. Plain ints, NumPy scalars (for example
            the elements of an integer array) and 0-d tensors are accepted.
        idx2char: Mapping from integer index to the string it stands for.

    Returns:
        str: The mapped strings concatenated in order. Indices missing from
        ``idx2char`` contribute nothing.
    """
    pieces = []
    for index in indices:
        # NumPy scalars and 0-d tensors expose .item(). Unwrapping them to a
        # builtin number lets the lookup match the int keys; a tensor hashes
        # by identity and would otherwise always miss.
        key = index.item() if hasattr(index, 'item') else index
        pieces.append(idx2char.get(key, ''))
    # str.join never encodes, so it cannot raise UnicodeEncodeError. The
    # former cp1252 fallback branch was unreachable and has been removed;
    # console encoding problems can only occur where the text is printed.
    return ''.join(pieces)

def compute_cer(reference_indices, hypothesis_indices):
    """Compute the Character Error Rate (CER) between two token-index sequences.

    The comparison is made on vocabulary indices, so each index counts as one
    token no matter how many Unicode code points it stands for.

    Args:
        reference_indices: Ground-truth token indices.
        hypothesis_indices: Predicted token indices.

    Returns:
        tuple: ``(cer, edit_distance)`` where ``cer`` is the edit distance
        divided by the reference length (an empty reference is treated as
        length 1).
    """
    n_ref = len(reference_indices)
    n_hyp = len(hypothesis_indices)

    # Debug dump of both sequences, with a reduced message if the console
    # cannot encode the output.
    try:
        print(f"Debug - Reference tokens ({n_ref} tokens): {reference_indices}")
        print(f"Debug - Hypothesis tokens ({n_hyp} tokens): {hypothesis_indices}")
    except UnicodeEncodeError:
        print(f"Debug - Reference tokens ({n_ref} tokens): [Token indices omitted due to encoding issues]")
        print(f"Debug - Hypothesis tokens ({n_hyp} tokens): [Token indices omitted due to encoding issues]")

    distance = editdistance.eval(reference_indices, hypothesis_indices)

    # max(..., 1) avoids dividing by zero on an empty reference.
    return distance / max(n_ref, 1), distance

# Hyper-parameters of the dense TCN inside the visual encoder.
# NOTE(review): the original comments said growth rates and `reduced_size`
# must be divisible by len(kernel_size_set) (3). 384 is, 512 is not - confirm
# the actual constraint against the Lipreading implementation.
densetcn_options = dict(
    block_config=[3, 3, 3, 3],              # layers in each dense block
    growth_rate_set=[384, 384, 384, 384],   # growth rate of each block
    reduced_size=512,                       # reduced size between blocks
    kernel_size_set=[3, 5, 7],              # kernel sizes used side by side
    dilation_size_set=[1, 2, 5],            # dilation rates
    squeeze_excitation=True,                # enable squeeze-excitation blocks
    dropout=0.2,                            # dropout rate
)

# Optimisation schedule.
initial_lr = 3e-4
total_epochs = 80
scheduler = CosineScheduler(initial_lr, total_epochs)

# Reverse mapping (id -> token) for decoding; id 0 is the CTC blank.
idx2char = {token_id: token for token, token_id in mapped_classes.items()}
idx2char[0] = ""

# Visual encoder; the extra class is the blank token.
model = Lipreading(
    densetcn_options=densetcn_options,
    hidden_dim=512,
    num_classes=len(mapped_classes) + 1,
    relu_type='prelu',
).to(device)

# TCN decoder placed on top of the visual encoder.
tcn_decoder = TCNDecoder(
    vocab_size=len(mapped_classes) + 1,     # +1 for the blank token
    hidden_dim=512,
    num_channels=[384, 384, 384, 384],
    kernel_size=3,
    dropout=0.2,
    emb_dropout=0.2,
    mode='multibranch',
).to(device)

print(model)

# One Adam optimizer updates the encoder and the decoder together.
optimizer = optim.Adam(
    [*model.parameters(), *tcn_decoder.parameters()],
    lr=initial_lr,
)


 # 5. Training and Evaluation

In [None]:
# %%
# Replace beam search with TCN decoder inference
def tcn_decode(log_probs, blank_index=0, beam_size=5, maxlen=24):
    """Run the TCN decoder's batched beam search on encoder output.

    Args:
        log_probs: Encoder output passed to the decoder as memory; the
            original documentation gives its shape as (B, T, C).
        blank_index: Accepted for backward compatibility; not used.
        beam_size: Number of beams kept per sample (default 5, the value
            that used to be hard-coded).
        maxlen: Maximum decoded length (default 24, previously hard-coded
            as the dataset's maximum label length).

    Returns:
        Whatever ``tcn_decoder.batch_beam_search`` returns. ``evaluate_model``
        in this file treats it as one list of hypotheses per batch item, best
        first, each a dict with ``'yseq'`` and ``'score'`` keys.
    """
    return tcn_decoder.batch_beam_search(log_probs, beam_size=beam_size, maxlen=maxlen)

def _build_decoder_targets(labels_flat, label_lengths, sos_id=1, eos_id=2, pad_id=0):
    """Rebuild per-sample label sequences as padded decoder targets.

    Each sample's labels are cut out of ``labels_flat``, wrapped as
    ``[sos_id, *labels, eos_id]`` and right-padded with ``pad_id`` to a
    common length. The leading SOS column is then dropped, so the result has
    shape ``(B, max_label_len + 1)``.
    """
    sos = labels_flat.new_tensor([sos_id])
    eos = labels_flat.new_tensor([eos_id])
    framed = []
    start = 0
    # One tolist() call replaces a device sync per sample.
    for seq_len in label_lengths.tolist():
        framed.append(torch.cat([sos, labels_flat[start:start + seq_len], eos]))
        start += seq_len
    padded = nn.utils.rnn.pad_sequence(framed, batch_first=True, padding_value=pad_id)
    return padded[:, 1:]


# Training the model
def train_one_epoch():
    """Train the encoder and the TCN decoder for one pass over ``train_loader``.

    For each batch the loss is ``0.7 * CTC + 0.3 * cross-entropy``:

    * CTC is computed on the log-softmax of the encoder output against the
      flattened labels.
    * Cross-entropy is computed between the TCN decoder output and the
      padded label targets, after both are cut to their common length along
      dim 1. Padding (id 0) is ignored.

    NOTE(review): the SOS/EOS ids (1 and 2) are also valid class ids, because
    ``mapped_classes`` numbers its tokens from 1 - confirm this is intended.
    NOTE(review): the CTC input length is the full padded time dimension for
    every sample; ``input_lengths`` is only passed to the model.

    Returns:
        float: Mean combined loss over the batches of the epoch.
    """
    running_loss = 0.0
    model.train()
    tcn_decoder.train()
    ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)
    ce_loss = nn.CrossEntropyLoss(ignore_index=0)  # Ignore padding index (0)
    alpha = 0.7  # Weight for CTC loss

    for batch_idx, (inputs, input_lengths, labels_flat, label_lengths) in enumerate(train_loader):
        # Print input shape for debugging
        print(f"Batch {batch_idx+1} - Input shape: {inputs.shape}")

        # Move data to device
        inputs = inputs.to(device)
        input_lengths = input_lengths.to(device)
        labels_flat = labels_flat.to(device)
        label_lengths = label_lengths.to(device)

        optimizer.zero_grad()

        # Forward pass through the visual encoder
        encoder_features = model(inputs, input_lengths)
        output_lengths = torch.full(
            (encoder_features.size(0),),
            encoder_features.size(1),
            dtype=torch.long,
            device=device,
        )
        print(f"Batch {batch_idx+1} - Encoder features shape: {encoder_features.shape}")

        # CTC expects log-probabilities laid out as (T, B, C)
        log_probs = F.log_softmax(encoder_features, dim=2)  # (B, T, C)
        ctc_loss_val = ctc_loss(log_probs.transpose(0, 1), labels_flat, output_lengths, label_lengths)

        # Decoder targets: labels followed by EOS, zero-padded. The masks and
        # shifted decoder input that used to be built here were never used
        # and have been dropped.
        decoder_output = _build_decoder_targets(labels_flat, label_lengths)

        # Forward through TCN decoder
        tcn_out = tcn_decoder(encoder_features)

        tcn_seq_len = tcn_out.size(1)
        decoder_seq_len = decoder_output.size(1)
        print(f"TCN output shape: {tcn_out.shape}, Decoder output shape: {decoder_output.shape}")

        # The two sequences generally differ in length; only their common
        # prefix is compared.
        if tcn_seq_len != decoder_seq_len:
            print(f"Sequence length mismatch: TCN={tcn_seq_len}, Decoder={decoder_seq_len}")
            min_seq_len = min(tcn_seq_len, decoder_seq_len)
            tcn_out = tcn_out[:, :min_seq_len, :]
            decoder_output = decoder_output[:, :min_seq_len]
            print(f"Using common prefix with length {min_seq_len}")
            print(f"New shapes - TCN: {tcn_out.shape}, Decoder: {decoder_output.shape}")

        # Flatten for cross entropy loss
        tcn_out_flat = tcn_out.reshape(-1, tcn_out.size(-1))
        decoder_output_flat = decoder_output.reshape(-1)
        print(f"Flattened shapes - TCN: {tcn_out_flat.shape}, Decoder: {decoder_output_flat.shape}")

        tcn_loss = ce_loss(tcn_out_flat, decoder_output_flat)

        # Combined loss (weighted sum of CTC and TCN losses)
        loss = alpha * ctc_loss_val + (1 - alpha) * tcn_loss

        print(f"Batch {batch_idx+1} - CTC Loss: {ctc_loss_val.item():.4f}, TCN Loss: {tcn_loss.item():.4f}, Combined Loss: {loss.item():.4f}")

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    return running_loss / max(len(train_loader), 1)

def _strip_special_tokens(yseq, sos_id=1, eos_id=2, verbose=False):
    """Drop a leading SOS token and everything from the first EOS onward."""
    tokens = list(yseq)
    if tokens and tokens[0] == sos_id:
        tokens = tokens[1:]
    if verbose:
        print(f"After SOS removal: {tokens}")
    if eos_id in tokens:
        tokens = tokens[:tokens.index(eos_id)]
        if verbose:
            print(f"After EOS removal: {tokens}")
    return tokens


def _tcn_beam_decode_batch(encoder_features, batch_size, batch_number):
    """Beam-search decode one batch; return one index array per batch item."""
    print(f"Encoder features shape: {encoder_features.shape}")
    print("Beam size: 5, Max length: 24")

    all_nbest_hyps = tcn_decoder.batch_beam_search(
        encoder_features, beam_size=5, maxlen=24
    )

    print(f"TCN decoding completed for batch {batch_number}")
    print(f"Received {len(all_nbest_hyps)} hypotheses sets")

    predictions = []
    for b in range(batch_size):
        print(f"\nProcessing batch item {b+1}/{batch_size}")

        if b >= len(all_nbest_hyps) or len(all_nbest_hyps[b]) == 0:
            # No hypotheses for this batch item - use an empty prediction
            print(f"No hypotheses for batch item {b+1}")
            if b >= len(all_nbest_hyps):
                print(f"  Issue: Batch index {b} is out of range for all_nbest_hyps (len={len(all_nbest_hyps)})")
            else:
                print(f"  Issue: Empty hypothesis list for batch item {b+1}")
            predictions.append(np.array([]))
            continue

        nbest_hyps = all_nbest_hyps[b]  # beams for this item, best first
        print(f"Found {len(nbest_hyps)} beam hypotheses for item {b+1}")

        best_hyp = nbest_hyps[0]
        print(f"Best hypothesis raw sequence: {best_hyp['yseq']}")
        pred_indices = _strip_special_tokens(best_hyp["yseq"], verbose=True)
        if len(pred_indices) == 0:
            print("WARNING: Prediction sequence is empty after token filtering!")

        print("\nTop beam search results:")
        for j, hyp in enumerate(nbest_hyps[:3]):  # show the top 3 results
            hyp_indices = _strip_special_tokens(hyp["yseq"])
            hyp_text = indices_to_text(hyp_indices, idx2char)
            try:
                print(f"  Hyp {j+1}: {hyp_text} (Score: {hyp['score']:.4f})")
            except UnicodeEncodeError:
                print(f"  Hyp {j+1}: [Text contains non-displayable characters] (Score: {hyp['score']:.4f})")
                print(f"  Token indices: {hyp_indices}")

        predictions.append(np.array(pred_indices))
    return predictions


def _greedy_ctc_decode(frame_log_probs, blank=0):
    """Best-path CTC decoding: argmax per frame, collapse repeats, drop blanks."""
    best_path = np.argmax(frame_log_probs.cpu().numpy(), axis=1)
    collapsed = []
    prev_idx = -1
    for idx in best_path:
        if idx != blank and idx != prev_idx:
            collapsed.append(idx)
        prev_idx = idx
    return np.array(collapsed)


def evaluate_model(data_loader):
    """Evaluate the encoder and the TCN decoder on ``data_loader``.

    For every batch the CTC loss is computed for monitoring, then predictions
    are obtained from the TCN decoder's beam search. If beam decoding raises,
    the whole batch is decoded with greedy CTC instead. Each prediction is
    compared with its target through ``compute_cer``; per-sample details and
    summary statistics are printed.

    Decoding a batch completes before any of its samples is scored, so a
    decoding failure can no longer add items that were already scored to the
    statistics a second time.

    Args:
        data_loader: Loader yielding
            ``(inputs, input_lengths, labels_flat, label_lengths)``.

    Returns:
        float: Mean per-batch CTC loss.
    """
    model.eval()
    tcn_decoder.eval()
    ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)

    # Track statistics
    total_cer = 0
    total_edit_distance = 0
    total_loss = 0
    samples_seen = 0

    with torch.no_grad():
        for i, (inputs, input_lengths, labels_flat, label_lengths) in enumerate(data_loader):
            # Move to device
            inputs = inputs.to(device)
            input_lengths = input_lengths.to(device)
            labels_flat = labels_flat.to(device)
            label_lengths = label_lengths.to(device)

            # Forward pass through visual encoder
            batch_size = inputs.size(0)
            encoder_features = model(inputs, input_lengths)

            # CTC loss (for monitoring only)
            output_lengths = torch.full(
                (encoder_features.size(0),),
                encoder_features.size(1),
                dtype=torch.long,
                device=device,
            )
            log_probs = F.log_softmax(encoder_features, dim=2)  # (B, T, C)
            loss = ctc_loss(log_probs.transpose(0, 1), labels_flat, output_lengths, label_lengths)
            # Equivalent to adding loss / batch_size once per sample.
            total_loss += loss.item()

            print(f"\nRunning TCN decoding for batch {i+1}...")

            try:
                predictions = _tcn_beam_decode_batch(encoder_features, batch_size, i + 1)
                decoder_tag = ""
            except Exception as e:
                # Deliberate best-effort: any decoder failure falls back to
                # greedy CTC for the whole batch.
                print(f"Error during TCN decoding: {str(e)}")
                import traceback
                traceback.print_exc()
                print("Falling back to CTC greedy decoding")
                predictions = [_greedy_ctc_decode(log_probs[b]) for b in range(batch_size)]
                decoder_tag = " (Greedy CTC)"

            # Move the targets to the host once and walk through them with a
            # running offset instead of re-summing the lengths per item.
            lengths = label_lengths.cpu().tolist()
            flat_targets = labels_flat.cpu().numpy()
            start_idx = 0
            for b in range(batch_size):
                end_idx = start_idx + lengths[b]
                target_idx = flat_targets[start_idx:end_idx]
                start_idx = end_idx
                pred_indices = predictions[b]

                # Convert indices to text
                pred_text = indices_to_text(pred_indices, idx2char)
                target_text = indices_to_text(target_idx, idx2char)

                cer, edit_distance = compute_cer(target_idx, pred_indices)

                # Update statistics
                total_cer += cer
                total_edit_distance += edit_distance
                # A running counter keeps numbering right when the last batch
                # is smaller than the others.
                samples_seen += 1

                # Print info
                print("-" * 50)
                print(f"Sample {samples_seen}{decoder_tag}:")
                try:
                    print(f"Predicted text: {pred_text}")
                    print(f"Target text: {target_text}")
                except UnicodeEncodeError:
                    print("Predicted text: [Contains characters that can't be displayed in console]")
                    print("Target text: [Contains characters that can't be displayed in console]")
                    print(f"Predicted indices: {pred_indices}")
                    print(f"Target indices: {target_idx}")

                print(f"Edit distance: {edit_distance}")
                print(f"CER: {cer:.4f}")
                print("-" * 50)

        # Write summary statistics (guarded against an empty loader)
        n_samples = samples_seen
        avg_cer = total_cer / max(n_samples, 1)
        avg_edit_distance = total_edit_distance / max(n_samples, 1)
        avg_loss = total_loss / max(n_samples, 1)

        print("=== Summary Statistics ===")
        print(f"Total samples: {n_samples}")
        print(f"Average CER: {avg_cer:.4f}")
        print(f"Average Edit Distance: {avg_edit_distance:.2f}")
        print(f"Average Loss: {avg_loss:.4f}")

    return total_loss / max(len(data_loader), 1)


In [None]:
# %%
def train_model():
    """Run the full training schedule.

    Each of the ``total_epochs`` rounds trains for one epoch, lets the
    scheduler adjust the optimizer's learning rate for that epoch index, then
    evaluates on the validation loader and prints the validation loss.
    """
    epoch = 0
    while epoch < total_epochs:
        train_one_epoch()
        scheduler.adjust_lr(optimizer, epoch)
        val_loss = evaluate_model(val_loader)
        epoch += 1
        print(f"Epoch {epoch}/{total_epochs}, Val Loss: {val_loss:.4f}")


In [None]:
# %%
def quick_experiment():
    """
    Run a quick experiment with a small subset of the data to test the TCN decoder.
    Uses 30 training samples and 5 test samples.
    """
    print("Running quick experiment with a small dataset...")
    
    try:
        # Open a file to save results (avoids console encoding issues)
        with open('tcn_results.txt', 'w', encoding='utf-8') as results_file:
            # Write function to ensure output is flushed to disk
            def write_line(line):
                results_file.write(line + "\n")
                results_file.flush()  # Flush after each write to ensure data is saved
            
            write_line("TCN Decoder Experiment Results")
            write_line("============================")
            write_line("")
            
            try:
                # Create small datasets for quick testing
                small_train_dataset = torch.utils.data.Subset(train_dataset, list(range(30)))
                small_val_dataset = torch.utils.data.Subset(val_dataset, list(range(5)))
                
                # Create dataloaders with the small datasets
                small_train_loader = DataLoader(small_train_dataset, batch_size=5, shuffle=True, 
                                            pin_memory=True, collate_fn=pad_packed_collate)
                small_val_loader = DataLoader(small_val_dataset, batch_size=2, shuffle=False, 
                                         pin_memory=True, collate_fn=pad_packed_collate)
                
                # Train for only 5 epochs
                write_line("Training on 30 samples...")
                for epoch in range(5):
                    try:
                        # Training
                        model.train()
                        tcn_decoder.train()
                        running_loss = 0.0
                        ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)
                        ce_loss = nn.CrossEntropyLoss(ignore_index=0)
                        
                        for batch_idx, (inputs, input_lengths, labels_flat, label_lengths) in enumerate(small_train_loader):
                            try:
                                # Move data to device
                                inputs = inputs.to(device)
                                input_lengths = input_lengths.to(device)
                                labels_flat = labels_flat.to(device)
                                label_lengths = label_lengths.to(device)
                                
                                # Zero the gradients
                                optimizer.zero_grad()
                                
                                # Forward pass through encoder
                                encoder_features = model(inputs, input_lengths)
                                output_lengths = torch.full((encoder_features.size(0),), encoder_features.size(1), 
                                                       dtype=torch.long, device=device)
                                
                                # Apply log_softmax for CTC
                                log_probs = F.log_softmax(encoder_features, dim=2)
                                outputs_for_ctc = log_probs.transpose(0, 1)
                                
                                # Compute CTC loss
                                ctc_loss_val = ctc_loss(outputs_for_ctc, labels_flat, output_lengths, label_lengths)
                                
                                # Prepare target sequences for TCN
                                target_seqs = []
                                start_idx = 0
                                batch_size = inputs.size(0)
                                
                                for b in range(batch_size):
                                    seq_len = label_lengths[b].item()
                                    target_seq = labels_flat[start_idx:start_idx + seq_len]
                                    # Add SOS and EOS tokens
                                    target_seq = torch.cat([torch.tensor([1], device=device), 
                                                       target_seq, 
                                                       torch.tensor([2], device=device)])
                                    target_seqs.append(target_seq)
                                    start_idx += seq_len
                                
                                # Pad sequences to same length
                                max_len = max(len(seq) for seq in target_seqs)
                                padded_seqs = []
                                for seq in target_seqs:
                                    padded = torch.cat([seq, torch.zeros(max_len - len(seq), 
                                                                     device=device, dtype=torch.long)])
                                    padded_seqs.append(padded)
                                
                                # Stack sequences
                                target_tensor = torch.stack(padded_seqs)
                                
                                # Create masks
                                memory_mask = torch.ones((batch_size, encoder_features.size(1)), device=device)
                                
                                # Prepare decoder input/output
                                decoder_input = target_tensor[:, :-1]  # Remove last token (shift right)
                                decoder_output = target_tensor[:, 1:]  # Remove first token
                                
                                # Create target mask for TCN
                                tgt_mask = torch.ones((batch_size, decoder_input.size(1)), device=device)
                                tgt_mask = tgt_mask.expand(batch_size, -1)
                                
                                # Forward through TCN
                                tcn_out = tcn_decoder(encoder_features)
                                
                                # Calculate TCN loss - need to handle shape mismatch
                                # Get sequence lengths for proper comparison
                                tcn_seq_len = tcn_out.size(1)
                                decoder_seq_len = decoder_output.size(1)
                                
                                # Log shapes to file
                                write_line(f"Batch {batch_idx+1} - TCN output shape: {tcn_out.shape}, " + 
                                       f"Decoder output shape: {decoder_output.shape}")
                                
                                # Adjust decoder_output to match tcn_out length using interpolation if needed
                                if tcn_seq_len != decoder_seq_len:
                                    write_line(f"Sequence length mismatch: TCN={tcn_seq_len}, Decoder={decoder_seq_len}")
                                    # Use only the common prefix of both sequences
                                    min_seq_len = min(tcn_seq_len, decoder_seq_len)
                                    tcn_out = tcn_out[:, :min_seq_len, :]
                                    decoder_output = decoder_output[:, :min_seq_len]
                                    write_line(f"Using common prefix with length {min_seq_len}")
                                    write_line(f"New shapes - TCN: {tcn_out.shape}, Decoder: {decoder_output.shape}")
                                
                                # Flatten for cross entropy loss
                                tcn_out_flat = tcn_out.reshape(-1, tcn_out.size(-1))
                                decoder_output_flat = decoder_output.reshape(-1)
                                
                                # Calculate loss
                                tcn_loss = ce_loss(tcn_out_flat, decoder_output_flat)
                                
                                # Combined loss
                                alpha = 0.7  # Weight for CTC loss
                                loss = alpha * ctc_loss_val + (1 - alpha) * tcn_loss
                                
                                # Backward pass and optimize
                                loss.backward()
                                optimizer.step()
                                
                                # Print and log progress
                                running_loss += loss.item()
                                log_msg = f"Batch {batch_idx+1} - CTC: {ctc_loss_val.item():.4f}, " + \
                                      f"TCN: {tcn_loss.item():.4f}, Loss: {loss.item():.4f}"
                                print(log_msg)
                                write_line(log_msg)
                            except Exception as e:
                                write_line(f"Error in batch {batch_idx+1}: {str(e)}")
                                import traceback
                                write_line(traceback.format_exc())
                        
                        # Evaluate
                        epoch_msg = f"\nEpoch {epoch+1}/5 - Loss: {running_loss/len(small_train_loader):.4f}"
                        print(epoch_msg)
                        write_line(epoch_msg)
                        
                        # Test with beam search to see actual outputs
                        write_line("\nTesting model with beam search decoding...")
                        test_batch(small_val_loader, results_file, write_line)
                        
                    except Exception as e:
                        write_line(f"Error in epoch {epoch+1}: {str(e)}")
                        import traceback
                        write_line(traceback.format_exc())
            
            except Exception as e:
                write_line(f"Error setting up experiment: {str(e)}")
                import traceback
                write_line(traceback.format_exc())
    
    except Exception as e:
        print(f"Critical error: {str(e)}")
        import traceback
        traceback.print_exc()
    
    print("\nExperiment complete! Check tcn_results.txt for details")

def _strip_special_tokens(yseq, sos_id=1, eos_id=2, log=None):
    """Return ``yseq`` without its leading SOS token, cut at the first EOS.

    Args:
        yseq: Token-id sequence of one hypothesis; a list, or any object
            exposing ``tolist()``.
        sos_id: Start-of-sequence id, dropped only when it is the first token.
        eos_id: End-of-sequence id; it and everything after it are dropped.
        log: Optional callable that receives one message per stripping step.

    Returns:
        A list with the remaining token ids (possibly empty).
    """
    tokens = yseq.tolist() if hasattr(yseq, "tolist") else list(yseq)
    # Check for emptiness first so an empty hypothesis cannot raise IndexError.
    if tokens and tokens[0] == sos_id:
        tokens = tokens[1:]
    if log is not None:
        log(f"After SOS removal: {tokens}")
    if eos_id in tokens:
        tokens = tokens[:tokens.index(eos_id)]
        if log is not None:
            log(f"After EOS removal: {tokens}")
    return tokens


def test_batch(data_loader, results_file, write_fn, *, beam_size=5, maxlen=24):
    """
    Test the model with the TCN decoder using beam search.

    For every batch the encoder output is decoded with
    ``tcn_decoder.batch_beam_search``; for every sample the target, the best
    hypothesis, the top three hypotheses, the edit distance and the CER are
    reported through ``write_fn``.

    Errors raised while handling a batch or a single sample are logged (with
    traceback) through ``write_fn`` and are not re-raised.

    NOTE: ``model`` and ``tcn_decoder`` are switched to eval mode and are not
    switched back; a caller that resumes training must call ``.train()``.

    Args:
        data_loader: DataLoader with test data; each batch unpacks to
            (inputs, input_lengths, labels_flat, label_lengths).
        results_file: File handle for saving results. Not used by this
            function; kept so existing callers keep working.
        write_fn: Function for writing lines to the file
        beam_size: Beam width forwarded to ``batch_beam_search``.
        maxlen: Maximum length forwarded to ``batch_beam_search``.
    """
    import traceback

    model.eval()
    tcn_decoder.eval()

    # Number of samples in all earlier batches. Using a running count keeps
    # the reported sample numbers right even when batch sizes differ
    # (e.g. a shorter final batch).
    samples_seen = 0

    with torch.no_grad():
        # Process validation data
        for i, (inputs, input_lengths, labels_flat, label_lengths) in enumerate(data_loader):
            batch_size = 0
            try:
                # Move to device
                inputs = inputs.to(device)
                input_lengths = input_lengths.to(device)
                labels_flat = labels_flat.to(device)
                label_lengths = label_lengths.to(device)

                # Forward pass through visual encoder
                batch_size = inputs.size(0)
                encoder_features = model(inputs, input_lengths)  # (B, T, hidden_dim)

                write_fn(f"\nRunning TCN decoding for batch {i+1}...")
                write_fn(f"Encoder features shape: {encoder_features.shape}")
                write_fn(f"Using beam size: {beam_size}")

                # Batch beam search
                all_nbest_hyps = tcn_decoder.batch_beam_search(
                    encoder_features, beam_size=beam_size, maxlen=maxlen
                )

                write_fn(f"TCN decoding completed for batch {i+1}")
                write_fn(f"Received {len(all_nbest_hyps)} hypotheses sets")

                # Label boundaries inside labels_flat, computed once per batch
                # (a single host transfer instead of one per sample).
                offsets = [0]
                for length in label_lengths.tolist():
                    offsets.append(offsets[-1] + length)

                # Process each batch item
                for b in range(batch_size):
                    try:
                        write_fn(f"\nProcessing batch item {b+1}/{batch_size}")

                        # Get target indices
                        target_idx = labels_flat[offsets[b]:offsets[b + 1]].cpu().numpy()

                        # Get target text representation
                        target_text = indices_to_text(target_idx, idx2char)
                        write_fn(f"Target text: {target_text}")
                        write_fn(f"Target indices: {target_idx.tolist()}")

                        # Get best hypothesis for this batch item
                        if b < len(all_nbest_hyps) and len(all_nbest_hyps[b]) > 0:
                            nbest_hyps = all_nbest_hyps[b]  # List of beams for this batch item
                            write_fn(f"Found {len(nbest_hyps)} beam hypotheses for item {b+1}")

                            best_hyp = nbest_hyps[0]  # First entry is treated as the best beam
                            write_fn(f"Best hypothesis raw sequence: {best_hyp['yseq']}")

                            # Remove SOS (1) / EOS (2), logging each step
                            pred_tokens = _strip_special_tokens(best_hyp["yseq"], log=write_fn)

                            # Print warning if the prediction is empty
                            if not pred_tokens:
                                write_fn("WARNING: Prediction sequence is empty after token filtering!")

                            # Convert list to numpy array
                            pred_indices = np.array(pred_tokens)

                            # Print top beam search results
                            write_fn("\nTop beam search results:")
                            for rank, hyp in enumerate(nbest_hyps[:3], 1):  # Show top 3 results
                                hyp_indices = _strip_special_tokens(hyp["yseq"])
                                hyp_text = indices_to_text(hyp_indices, idx2char)
                                write_fn(f"  Hyp {rank}: {hyp_text} (Score: {hyp['score']:.4f})")
                                write_fn(f"  Token indices: {hyp_indices}")
                        else:
                            # No hypotheses for this batch item - use empty prediction
                            write_fn(f"No hypotheses for batch item {b+1}")

                            # Say which of the two possible causes applies
                            if b >= len(all_nbest_hyps):
                                write_fn(f"  Issue: Batch index {b} is out of range for all_nbest_hyps (len={len(all_nbest_hyps)})")
                            elif len(all_nbest_hyps[b]) == 0:
                                write_fn(f"  Issue: Empty hypothesis list for batch item {b+1}")

                            pred_indices = np.array([])

                        # Convert indices to text
                        pred_text = indices_to_text(pred_indices, idx2char)

                        # Calculate CER using custom function
                        cer, edit_distance = compute_cer(target_idx, pred_indices)

                        # Print info to file
                        write_fn("-" * 50)
                        write_fn(f"Sample {samples_seen + b + 1}:")
                        write_fn(f"Predicted text: {pred_text}")
                        write_fn(f"Predicted indices: {pred_indices.tolist()}")
                        write_fn(f"Target text: {target_text}")
                        write_fn(f"Target indices: {target_idx.tolist()}")
                        write_fn(f"Edit distance: {edit_distance}")
                        write_fn(f"CER: {cer:.4f}")
                        write_fn("-" * 50)
                    except Exception as e:
                        # Best effort: one bad sample must not stop the report.
                        write_fn(f"Error processing batch item {b+1}: {str(e)}")
                        write_fn(traceback.format_exc())
            except Exception as e:
                # Best effort: one bad batch must not stop the evaluation.
                write_fn(f"Error processing batch {i+1}: {str(e)}")
                write_fn(traceback.format_exc())

            # Stays 0 when the batch failed before its size was known.
            samples_seen += batch_size


In [None]:
# %%
# Re-seed the NumPy and torch (CPU and CUDA) RNGs before starting the run.
reset_seed()
# Uncomment one of the following lines to run the full training or quick experiment
# train_model()
quick_experiment()  # Run the quick experiment with the TCN decoder (used in place of the transformer decoder)
