In [14]:
# %%
import torch, os, cv2
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision import transforms
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
from PIL import Image
from lipreading.model import Lipreading
from lipreading.optim_utils import CosineScheduler

# Add the greedy CTC decoder functions
def greedy_ctc_decoder(logits, blank_index=0):
    """
    Greedy (best-path) decoding for CTC.

    Args:
        logits: Per-timestep class scores of shape (T, C), as a torch.Tensor
            or anything ``np.argmax`` accepts. Only the argmax over the class
            axis is used, so raw logits and log-probabilities decode the same.
        blank_index: Index of the CTC blank class.

    Returns:
        A list of plain Python ints for one sample: the best path with
        consecutive repeats collapsed and blanks removed.
    """
    # Convert to numpy if it's a tensor
    if isinstance(logits, torch.Tensor):
        logits = logits.detach().cpu().numpy()

    # Highest-scoring class at each timestep
    best_path = np.argmax(logits, axis=1)  # (T,)
    if best_path.size == 0:
        return []

    # Keep a timestep only if its class differs from the previous timestep
    # (collapse repeats) and it is not the blank.
    keep = np.empty(best_path.shape, dtype=bool)
    keep[0] = True
    keep[1:] = best_path[1:] != best_path[:-1]
    keep &= best_path != blank_index

    # tolist() yields built-in ints rather than NumPy scalars, so the indices
    # print cleanly and behave like ordinary dictionary keys.
    return best_path[keep].tolist()

def indices_to_text(indices, idx2char):
    """
    Render a sequence of vocabulary indices as a string.

    Each index is looked up in ``idx2char``; indices that are not present in
    the mapping contribute nothing to the result.
    """
    pieces = []
    for token in indices:
        pieces.append(idx2char.get(token, ''))
    return ''.join(pieces)

def normalize_arabic_text(text):
    """
    Split Arabic text into clusters of one base character plus the marks
    that follow it.

    Returns a list of strings. Marks (the diacritics listed below, plus
    tatweel) are appended to the cluster that precedes them. Marks at the
    very start of the text form a cluster of their own.
    """
    marks = set(
        '\u064B\u064C\u064D\u064E\u064F'
        '\u0650\u0651\u0652\u0670\u06E2'
        '\u0640'  # tatweel
    )

    clusters = []
    for symbol in text:
        if symbol in marks and clusters:
            # Attach the mark to the cluster currently being built
            clusters[-1] += symbol
        else:
            # A base character (or a leading mark) opens a new cluster
            clusters.append(symbol)

    return clusters

def compute_cer(reference_indices, hypothesis_indices):
    """
    Character Error Rate between two sequences of vocabulary token indices.

    The error rate is the Levenshtein distance (unit cost for substitution,
    insertion and deletion) divided by the reference length. An empty
    reference is treated as length 1 for the division.

    Returns a tuple of (CER, reference_len, hypothesis_len, edit_distance)
    """
    print(f"Debug - Reference tokens ({len(reference_indices)} tokens): {reference_indices}")
    print(f"Debug - Hypothesis tokens ({len(hypothesis_indices)} tokens): {hypothesis_indices}")

    ref_len = len(reference_indices)
    hyp_len = len(hypothesis_indices)

    # Row 0 of the distance table: turning an empty reference into the first
    # `col` hypothesis tokens takes `col` insertions.
    previous_row = list(range(hyp_len + 1))

    for row, ref_token in enumerate(reference_indices, start=1):
        # Column 0: dropping the first `row` reference tokens
        current_row = [row]
        for col, hyp_token in enumerate(hypothesis_indices, start=1):
            if ref_token == hyp_token:
                cost = previous_row[col - 1]
            else:
                cost = 1 + min(
                    previous_row[col - 1],  # substitution
                    current_row[col - 1],   # insertion
                    previous_row[col],      # deletion
                )
            current_row.append(cost)
        previous_row = current_row

    distance = previous_row[hyp_len]
    return distance / max(ref_len, 1), ref_len, hyp_len, distance


 # 2. Initialize the seed and the device

In [15]:
# %%
# Seed shared by every random number generator used here
seed = 0


def reset_seed():
    """Re-seed NumPy, PyTorch and PyTorch's CUDA generator with `seed`."""
    for seeder in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)


# Ask cuDNN for deterministic behaviour when a GPU is present
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


 # 3. Dataset preparation

 ## 3.1. List of Classes

In [24]:
# %%
def extract_label(file):
    """
    Read a label CSV and return its text as a list of tokens.

    Args:
        file: Path or file-like object accepted by ``pd.read_csv``. The CSV
            must have a ``word`` column.

    Returns:
        A list of strings. Each token is one non-diacritic character followed
        by any diacritics that come after it. Tokens are gathered across all
        rows in order, so a diacritic at the start of a row attaches to the
        last token of the previous row. A diacritic with no preceding token
        at all becomes a token of its own.

    Empty (NaN) ``word`` cells are skipped. Non-string cells are converted
    with ``str`` before being split into characters.
    """
    diacritics = {
        '\u064B',  # Fathatan
        '\u064C',  # Dammatan
        '\u064D',  # Kasratan
        '\u064E',  # Fatha
        '\u064F',  # Damma
        '\u0650',  # Kasra
        '\u0651',  # Shadda
        '\u0652',  # Sukun
        '\u06E2',  # Small High meem
    }

    label = []
    sentence = pd.read_csv(file)
    for word in sentence.word.dropna():
        for char in str(word):
            if char in diacritics and label:
                label[-1] += char
            else:
                # Base character, or a diacritic with nothing to attach to
                label.append(char)

    return label

# Collect every distinct token (character plus attached diacritics) that
# appears in the label CSVs.
# NOTE(review): the dataset-loading code further down reads labels from
# "Dataset/Csv (with Diacritics)"; confirm both paths point at the same folder.
csv_dir = 'Arabic-Lib-Reading/Dataset/Csv (with Diacritics)'
classes = set()
for csv_name in sorted(os.listdir(csv_dir)):
    if not csv_name.lower().endswith('.csv'):
        continue  # ignore anything in the folder that is not a label file
    # os.path.join supplies the separator between the folder and the file name
    classes.update(extract_label(os.path.join(csv_dir, csv_name)))

# Create mapping while safely handling Arabic characters.
# Indices start at 1; index 0 is left free for the CTC blank.
mapped_classes = {token: index for index, token in enumerate(sorted(classes), 1)}

# Print in a way that handles encoding properly
with open('class_mapping.txt', 'w', encoding='utf-8') as f:
    for char, idx in mapped_classes.items():
        f.write(f"{char}: {idx}\n")

# Just print count rather than the actual characters to avoid console encoding issues
print(f"Total characters in vocabulary: {len(mapped_classes)}")


FileNotFoundError: [WinError 3] The system cannot find the path specified: 'Arabic-Lib-Reading/Dataset/Csv (with Diacritics)'

 ## 3.2. Video Dataset Class

In [None]:
# %%
# Defining the video dataset class
class VideoDataset(torch.utils.data.Dataset):
    """
    Dataset pairing each video file with its label CSV.

    Args:
        video_paths: Sequence of video file paths.
        label_paths: Sequence of label CSV paths, aligned with ``video_paths``.
        transform: Optional callable applied to every frame (a PIL image).
            ``__getitem__`` stacks the frames with ``torch.stack``, so the
            transform is expected to return a tensor for each frame.
    """

    def __init__(self, video_paths, label_paths, transform=None):
        self.video_paths = video_paths
        self.label_paths = label_paths
        self.transform = transform

    def __len__(self):
        return len(self.video_paths)

    def __getitem__(self, index):
        """
        Return ``(frames, input_length, label, label_length)`` for one sample.

        ``frames`` has shape [C, T, H, W]. ``label`` is a 1-D long tensor of
        vocabulary indices looked up in the module-level ``mapped_classes``.
        Both lengths are 0-d long tensors.
        """
        frames = self.load_frames(video_path=self.video_paths[index])
        label = [mapped_classes[token] for token in extract_label(self.label_paths[index])]

        # Get the number of frames for sequence length
        input_length = torch.tensor(len(frames), dtype=torch.long)
        label_length = torch.tensor(len(label), dtype=torch.long)

        # Stack frames into a tensor of shape [C, T, H, W]
        if len(frames) > 0:
            stacked_frames = torch.stack(frames)  # Shape: [T, C, H, W]
            stacked_frames = stacked_frames.permute(1, 0, 2, 3)  # Shape: [C, T, H, W]
        else:
            # No frame could be read: return a single-channel placeholder.
            # input_length stays 0 in this case.
            stacked_frames = torch.zeros((1, 1, 112, 112))

        return stacked_frames, input_length, torch.tensor(label, dtype=torch.long), label_length

    def load_frames(self, video_path):
        """
        Decode a video into a list of grayscale frames.

        Frames are read sequentially, with at most as many reads as the
        container's reported frame count. Reads that fail are skipped. Each
        frame is converted to a grayscale PIL image and then passed through
        ``self.transform`` when one is set.
        """
        frames = []
        video = cv2.VideoCapture(video_path)
        try:
            total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
            for _ in range(total_frames):
                # read() already advances to the next frame, so no explicit
                # seek is needed before each call.
                ret, frame = video.read()
                if ret:
                    # Convert to grayscale as the Lipreading model expects single-channel input
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                    frames.append(Image.fromarray(frame, 'L'))  # 'L' is for grayscale
        finally:
            # Always free the decoder handle, even if decoding raised
            video.release()

        if self.transform is not None:
            frames = [self.transform(frame) for frame in frames]

        return frames

# Defining the video transform.
# The pipeline is bound to the name `transforms` because later code passes it
# as `transform=transforms`. That rebinding hides the torchvision module of the
# same name, so the module is imported here under a private alias. This keeps
# the cell re-runnable.
from torchvision import transforms as _tv_transforms

transforms = _tv_transforms.Compose([
    _tv_transforms.Resize((112, 112)),
    _tv_transforms.ToTensor(),
    _tv_transforms.Normalize(mean=[0.421], std=[0.165])  # For grayscale, single channel
])


 ## 3.2. Load & Split the dataset

In [None]:
# %%
# Limit to 30 samples total for testing
videos_dir = "Dataset/Video"
labels_dir = "Dataset/Csv (with Diacritics)"
videos, labels = [], []

# Only .mp4 entries are samples. Sorting makes the 30-sample subset and the
# seeded split below reproducible, because os.listdir returns entries in
# arbitrary order.
file_names = sorted(
    os.path.splitext(file_name)[0]
    for file_name in os.listdir(videos_dir)
    if file_name.lower().endswith(".mp4")
)
for file_name in file_names:
    videos.append(os.path.join(videos_dir, file_name + ".mp4"))
    labels.append(os.path.join(labels_dir, file_name + ".csv"))

# Slicing leaves shorter lists untouched
videos = videos[:30]
labels = labels[:30]

# Split the dataset into training, validation, test sets
X_temp, X_test, y_temp, y_test = train_test_split(videos, labels, test_size=0.2, random_state=seed)  # 20% test
X_train, X_val, y_train, y_val = train_test_split(X_temp, y_temp, test_size=0.25, random_state=seed)  # 25% of remaining for validation

print(f"Dataset sizes: Train: {len(X_train)}, Val: {len(X_val)}, Test: {len(X_test)}")


 ## 3.3. DataLoaders

In [None]:
# %%
def collate_fn(batch):
    """
    Merge samples of differing lengths into one batch.

    Each item of ``batch`` is ``(frames, input_length, label, label_length)``
    with ``frames`` shaped [C, T, H, W]. Clips are zero-padded along the time
    axis to the longest clip in the batch, and labels are concatenated into a
    single 1-D tensor as CTC loss expects.

    Returns ``(padded_frames, input_lengths, labels_flat, label_lengths)``
    where ``padded_frames`` is [B, C, T_max, H, W].
    """
    clips, clip_lengths, targets, target_lengths = zip(*batch)

    frame_counts = [length.item() for length in clip_lengths]
    longest = max(frame_counts)

    # Channel count and spatial size are taken from the first clip
    channels, _, height, width = clips[0].shape
    padded_frames = torch.zeros((len(clips), channels, longest, height, width))
    for row, (clip, n_frames) in enumerate(zip(clips, frame_counts)):
        padded_frames[row, :, :n_frames, :, :] = clip[:, :n_frames, :, :]

    # Flatten labels for CTC loss
    labels_flat = torch.LongTensor([token for target in targets for token in target])

    return (
        padded_frames,
        torch.LongTensor(clip_lengths),
        labels_flat,
        torch.LongTensor(target_lengths),
    )


# Build the train / validation / test datasets and wrap each one in a
# DataLoader. Only the training loader shuffles.
batch_size = 2  # Changed from 4 to 2 for testing

train_dataset = VideoDataset(X_train, y_train, transform=transforms)
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
    collate_fn=collate_fn,
)

val_dataset = VideoDataset(X_val, y_val, transform=transforms)
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    pin_memory=True,
    collate_fn=collate_fn,
)

test_dataset = VideoDataset(X_test, y_test, transform=transforms)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    pin_memory=True,
    collate_fn=collate_fn,
)


 # 4. Model

In [None]:
# %%
# Define a custom version of Lipreading that fixes the dimension issue
class FixedLipreading(Lipreading):
    """Lipreading variant whose forward pass reshapes features explicitly."""

    def forward(self, x, lengths):
        """
        Run frontend -> trunk -> (optionally) TCN.

        ``x`` must be 5-D; the code unpacks it as [B, C, T, H, W]. Returns
        the per-timestep trunk features reshaped to [B, T', -1] when
        ``self.extract_feats`` is set, otherwise the output of ``self.tcn``.
        """
        # Unpacking five values also rejects inputs that are not 5-D
        batch, _channels, _frames, _height, _width = x.size()

        feats = self.frontend(x)
        # Channel count and time length may differ from the input after the frontend
        n_channels, n_steps = feats.shape[1], feats.shape[2]

        # Fold time into the batch axis so the trunk sees one frame per row
        feats = feats.permute(0, 2, 1, 3, 4).contiguous()
        feats = feats.view(-1, n_channels, feats.size(3), feats.size(4))

        feats = self.trunk(feats)

        # Unfold back into one sequence per sample
        feats = feats.view(batch, n_steps, -1)

        if self.extract_feats:
            return feats

        # NOTE(review): assumes self.tcn accepts (B, T, C) input and handles
        # any transpose itself -- confirm against the lipreading package.
        return self.tcn(feats, lengths, batch)

def initialize_model():
    """
    Build a FixedLipreading network for character-level recognition and move
    it to the module-level ``device``.

    The output layer has one class per vocabulary entry in ``mapped_classes``
    plus one extra class for the CTC blank.
    """
    # DenseTCN options tailored for character-level sequence recognition
    densetcn_options = dict(
        block_config=[3, 3, 3],        # Configuration for DenseTCN blocks
        growth_rate_set=[9, 9, 9],     # Ensure it's divisible by number of kernels (3)
        reduced_size=120,              # Ensure it's divisible by number of kernels (3)
        kernel_size_set=[3, 5, 7],     # Multiple kernel sizes for different features
        dilation_size_set=[1, 2, 4],   # Increasing dilation for longer dependencies
        dropout=0.2,                   # Regularization
        squeeze_excitation=True,       # Use SE for feature refinement
    )

    vocabulary_size = len(mapped_classes) + 1  # characters + blank

    network = FixedLipreading(
        modality='video',
        hidden_dim=512,  # Match ResNet output size
        backbone_type='resnet',
        num_classes=vocabulary_size,
        relu_type='prelu',
        densetcn_options=densetcn_options,
        extract_feats=False,  # We want predictions, not features
    )
    network = network.to(device)
    return network

model = initialize_model()

# Reverse vocabulary (index -> token) for decoding. Index 0 is the CTC blank
# and decodes to the empty string.
idx2char = dict(zip(mapped_classes.values(), mapped_classes.keys()))
idx2char[0] = ""

# Only the optimizer is defined here; the CTC loss is created inside the
# training and testing functions.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


 # 5. Training and Evaluation

In [None]:
# %%
# Training the model
def train_model():
    """
    Run one pass over ``train_loader``, updating the module-level ``model``
    with ``optimizer`` under CTC loss (blank index 0).

    For every batch the first sample is greedily decoded and printed next to
    its target as a diagnostic.

    Returns:
        The mean CTC loss over the processed batches, or ``nan`` when the
        loader yields no batches.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)

    # Process all batches (limited since we restricted the dataset size)
    for batch_idx, (inputs, input_lengths, labels_flat, label_lengths) in enumerate(train_loader):
        # Print input shape for debugging
        print(f"Batch {batch_idx+1} - Input shape: {inputs.shape}")

        # Move data to device
        inputs = inputs.to(device)
        input_lengths = input_lengths.to(device)
        labels_flat = labels_flat.to(device)
        label_lengths = label_lengths.to(device)

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass - get sequence logits
        logits = model(inputs, input_lengths)

        # Print shape to verify sequence output
        print(f"Batch {batch_idx+1} - Logits shape: {logits.shape}")

        # Apply log_softmax for CTC
        log_probs = F.log_softmax(logits, dim=2)  # (B, T, C)

        # Prepare for CTC loss - requires (T, B, C) format
        outputs_for_ctc = log_probs.transpose(0, 1)  # from (B, T, C) to (T, B, C)

        # Compute CTC loss
        loss = ctc_loss(outputs_for_ctc, labels_flat, input_lengths, label_lengths)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

        # Show detailed diagnostic for first sample in batch
        with torch.no_grad():  # Don't track gradients for decoding
            sample_idx = 0
            sample_logits = log_probs[sample_idx]  # Shape: [T, C]
            seq_len = input_lengths[sample_idx].item()

            # Only use valid timesteps (up to seq_len)
            valid_logits = sample_logits[:seq_len]

            # Decode using greedy CTC
            pred_indices = greedy_ctc_decoder(valid_logits, blank_index=0)
            pred_text = indices_to_text(pred_indices, idx2char)

            # The first sample's labels sit at the start of the flattened targets
            start_idx = 0
            end_idx = label_lengths[sample_idx].item()
            target_indices = labels_flat[start_idx:end_idx].cpu().tolist()
            target_text = indices_to_text(target_indices, idx2char)

            print(f"Training Sample - Batch {batch_idx+1}:")
            print(f"  Pred indices: {pred_indices}")
            print(f"  Target indices: {target_indices}")
            print(f"  Pred text: {pred_text}")
            print(f"  Target text: {target_text}")
            print(f"  CTC Loss: {loss.item():.4f}")
            print(f"  Sequence length: {seq_len}, Logits shape: {valid_logits.shape}")

    # An empty loader has no loss to average
    if num_batches == 0:
        return float('nan')
    return running_loss / num_batches

# Define a separate testing function that uses our implementation
def test_model():
    """
    Evaluate the module-level ``model`` on ``test_loader``.

    For every sample this greedily decodes the valid (unpadded) timesteps and
    compares the result with the target using the token-based CER. It prints
    a per-sample report and writes the same information plus summary
    statistics to ``predictions.txt`` (UTF-8 with BOM).

    Returns:
        ``(avg_cer, avg_loss)`` averaged over the evaluated samples. The loss
        is the batch-mean CTC loss, counted once per sample of the batch.
        Both values are ``nan`` when the test loader yields no samples.
    """
    model.eval()  # Set the model to evaluation mode
    ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True, reduction='mean')

    # Track statistics
    total_cer = 0
    total_loss = 0
    total_edit_distance = 0
    pred_lengths = []
    target_lengths = []
    n_samples = 0  # running count, also used to number the samples

    # Open the predictions file for writing with proper encoding for Arabic
    with open("predictions.txt", "w", encoding="utf-8-sig") as f:
        # Write header with Arabic support
        f.write("=== New Evaluation Run ===\n\n")
        f.write("Format: UTF-8 with Arabic support\n")
        f.write("Note: Lengths shown are token counts from class_mapping.txt\n\n")

        # Process all batches in the test loader
        with torch.no_grad():
            for frames, input_lengths, labels_flat, label_lengths in test_loader:
                # Move to device
                frames = frames.to(device)
                input_lengths = input_lengths.to(device)
                labels_flat = labels_flat.to(device)
                label_lengths = label_lengths.to(device)

                # Forward pass through the entire model
                logits = model(frames, input_lengths)

                # Apply log_softmax for CTC
                log_probs = F.log_softmax(logits, dim=2)  # (B, T, C)

                # For CTC loss we need (T, N, C) format
                log_probs_ctc = log_probs.transpose(0, 1)  # (T, B, C)

                # Per-sample valid lengths, capped at the number of timesteps
                # the model produced (CTC needs lengths <= T).
                output_lengths = torch.clamp(input_lengths, max=logits.size(1))

                # Calculate CTC loss
                loss = ctc_loss(log_probs_ctc, labels_flat, output_lengths, label_lengths)
                batch_loss = loss.item()

                # Move everything needed for decoding to the CPU once per batch
                logits_np = log_probs.cpu().numpy()  # (B, T, C)
                valid_steps = output_lengths.cpu().tolist()
                flat_targets = labels_flat.cpu().tolist()
                target_sizes = label_lengths.cpu().tolist()

                offset = 0  # start of the current sample inside flat_targets
                for b in range(frames.size(0)):
                    # Decode only the unpadded timesteps of this sample
                    pred_indices = [
                        int(idx)
                        for idx in greedy_ctc_decoder(logits_np[b][:valid_steps[b]])
                    ]

                    # Get target indices
                    target_idx = flat_targets[offset:offset + target_sizes[b]]
                    offset += target_sizes[b]

                    # Convert indices to text
                    pred_text = indices_to_text(pred_indices, idx2char)
                    target_text = indices_to_text(target_idx, idx2char)

                    # Compute CER directly using token indices
                    cer, ref_len, hyp_len, edit_distance = compute_cer(target_idx, pred_indices)

                    # Update statistics
                    n_samples += 1
                    total_cer += cer
                    total_loss += batch_loss
                    total_edit_distance += edit_distance

                    pred_lengths.append(hyp_len)
                    target_lengths.append(ref_len)

                    # Print info
                    print("-" * 50)
                    print(f"Sample {n_samples}:")
                    print(f"Predicted indices: {pred_indices}")
                    print(f"Target indices: {list(target_idx)}")
                    print(f"Predicted text: {pred_text}")
                    print(f"Target text: {target_text}")
                    print(f"Lengths: pred={hyp_len} tokens, target={ref_len} tokens")
                    print(f"Edit Distance: {edit_distance}")
                    print(f"CER: {cer:.4f}")
                    print(f"CTC Loss: {batch_loss:.4f}")
                    print("-" * 50)

                    # Write to file with detailed information
                    f.write(f"Sample {n_samples}:\n")
                    f.write("Prediction:\n")
                    f.write(f"  Text ({hyp_len} tokens): {pred_text}\n")
                    f.write(f"  Indices: {pred_indices}\n")
                    f.write("Target:\n")
                    f.write(f"  Text ({ref_len} tokens): {target_text}\n")
                    f.write(f"  Indices: {list(target_idx)}\n")
                    f.write("Metrics:\n")
                    f.write(f"  Edit Distance: {edit_distance}\n")
                    f.write(f"  CER: {cer:.4f}\n")
                    f.write(f"  CTC Loss: {batch_loss:.4f}\n")
                    f.write("--------------------------------------------------\n\n")

        # Summary statistics; an empty test set has nothing to average
        if n_samples:
            avg_cer = total_cer / n_samples
            avg_loss = total_loss / n_samples
            avg_edit_distance = total_edit_distance / n_samples
        else:
            avg_cer = avg_loss = avg_edit_distance = float('nan')

        f.write("=== Summary Statistics ===\n")
        f.write(f"Total samples: {n_samples}\n")
        f.write(f"Average CER: {avg_cer:.4f}\n")
        f.write(f"Average Edit Distance: {avg_edit_distance:.2f}\n")
        f.write(f"Average Loss: {avg_loss:.4f}\n\n")

        # Length statistics (skipped when nothing was evaluated)
        if pred_lengths:
            avg_pred_len = sum(pred_lengths) / len(pred_lengths)
            avg_target_len = sum(target_lengths) / len(target_lengths)

            f.write("Length Statistics:\n")
            f.write(f"Predictions:\n")
            f.write(f"  Average: {avg_pred_len:.1f} tokens\n")
            f.write(f"  Range: {min(pred_lengths)} to {max(pred_lengths)} tokens\n")
            f.write(f"Targets:\n")
            f.write(f"  Average: {avg_target_len:.1f} tokens\n")
            f.write(f"  Range: {min(target_lengths)} to {max(target_lengths)} tokens\n")

    return avg_cer, avg_loss

# Script entry point: evaluate the model on the test set.
if __name__ == "__main__":
    # Use the module-level `model` created above, already on `device`.
    # Building a second model here would discard any weights it has learned
    # and leave `optimizer` bound to the parameters of the old model.
    avg_cer, avg_loss = test_model()
    print(f"Test CER: {avg_cer:.4f}, Test Loss: {avg_loss:.4f}")
