# Enhanced WLASL Framework Demo

This notebook demonstrates the complete pipeline of the enhanced WLASL framework, including:
1. Data Loading and Preprocessing
2. Data Analysis
3. Model Training
4. Cross-Validation
5. Evaluation and Visualization
6. Model Comparison and Analysis

In [None]:
import os
import sys
from pathlib import Path
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import Image, display

# Add project root to path so the `configs` and `src` packages below resolve.
# NOTE(review): this uses the current working directory, i.e. it assumes the
# notebook is launched from the project root - confirm.
PROJECT_ROOT = Path().absolute()
# Guarded so that re-running this cell does not keep appending duplicates.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))

from configs.base_config import *
from src.preprocessing.video_processor import VideoProcessor, BatchVideoProcessor
from src.data.data_loader import SignLanguageDataset, create_data_loaders
from src.training.trainer import Trainer
from src.training.cross_validate import CrossValidator

# Enable interactive plots
%matplotlib inline
plt.style.use('seaborn')

## 1. Data Loading and Preprocessing

First, let's load and preprocess some sample videos using our memory-efficient processing pipeline.

In [None]:
# Initialize video processor from the shared data configuration
video_processor = VideoProcessor(
    frame_size=DATA_CONFIG['frame_size'],
    num_frames=DATA_CONFIG['num_frames'],
    fps=DATA_CONFIG['fps']
)

# Collect the raw videos. glob() does not guarantee any particular order, so
# sort the paths to make "the first N videos" the same files on every run.
sample_video_dir = DATA_DIR / 'raw_videos'
video_paths = sorted(sample_video_dir.glob('*.mp4'))

print(f"Found {len(video_paths)} videos to process")

# Process only the first few videos as an example
NUM_SAMPLE_VIDEOS = 5
batch_processor = BatchVideoProcessor(video_processor)
batch_processor.process_batch(video_paths[:NUM_SAMPLE_VIDEOS])

### Visualize Processed Frames

Let's look at some processed frames to verify our preprocessing pipeline.

In [None]:
def plot_video_frames(video_path, num_frames=5):
    """Plot evenly spaced sample frames from a processed video.

    The video is run through the notebook-level ``video_processor`` and up to
    ``num_frames`` of the resulting frames are shown side by side.

    Args:
        video_path: Video to preview; passed to
            ``video_processor.process_video``.
        num_frames: Maximum number of frames to display (default 5).

    Raises:
        ValueError: If ``num_frames`` is smaller than 1.
    """
    if num_frames < 1:
        raise ValueError(f"num_frames must be at least 1, got {num_frames}")

    frames = video_processor.process_video(video_path)
    total_frames = len(frames)
    if total_frames == 0:
        print(f"No frames were produced for {video_path}; nothing to plot")
        return

    # Never request more panels than there are frames: otherwise the integer
    # step below becomes 0 and every panel would repeat frame 0.
    panel_count = min(num_frames, total_frames)
    step = total_frames // panel_count

    # squeeze=False keeps `axes` a 2-D array even for a single panel, so the
    # loop below also works when only one frame is shown.
    fig, axes = plt.subplots(1, panel_count, figsize=(15, 3), squeeze=False)

    for i, ax in enumerate(axes[0]):
        ax.imshow(frames[i * step])
        ax.axis('off')
        ax.set_title(f'Frame {i * step}')

    plt.tight_layout()
    plt.show()

# Visualize frames from first video (skipped when the glob found no videos,
# instead of failing with an IndexError)
if video_paths:
    plot_video_frames(video_paths[0])
else:
    print(f"No .mp4 files found in {sample_video_dir}; skipping frame preview")

## 2. Data Analysis

Now let's analyze our dataset to understand its characteristics.

In [None]:
# Load dataset info
import json
from collections import Counter

# JSON interchange files are UTF-8; do not depend on the platform default.
with open(DATA_DIR / 'data_info.json', 'r', encoding='utf-8') as f:
    data_info = json.load(f)

# Analyze class distribution: number of records per label and per signer.
# Each record is read as a mapping with 'label' and 'signer_id' keys.
class_counts = Counter(item['label'] for item in data_info)
signer_counts = Counter(item['signer_id'] for item in data_info)


def _plot_ranked_counts(counts, title, xlabel, ylabel):
    """Bar-plot the values of ``counts`` sorted from largest to smallest."""
    ranked = sorted(counts.values(), reverse=True)
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.bar(range(len(ranked)), ranked)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    plt.show()


# The bars are ordered by count, so the x position is a rank, not the
# original class index or signer id; the axis labels say so explicitly.
_plot_ranked_counts(
    class_counts,
    'Class Distribution',
    'Class Rank (by sample count)',
    'Number of Samples',
)
_plot_ranked_counts(
    signer_counts,
    'Signer Distribution',
    'Signer Rank (by video count)',
    'Number of Videos',
)

## 3. Model Training

Let's train both I3D and TGCN models on our processed dataset.

In [None]:
# Create data loaders from the metadata records loaded in the Data Analysis
# cell; the training helper below reads the 'train' and 'val' entries.
dataloaders = create_data_loaders(data_info)

# Function to train and evaluate a model
def train_and_evaluate_model(model_name):
    """Build the requested model, train it and return the training history.

    The model is trained with Adam, a ``ReduceLROnPlateau`` scheduler and
    cross-entropy loss on the notebook-level ``dataloaders`` ('train' and
    'val' entries), using hyper-parameters from ``TRAIN_CONFIG``.

    Args:
        model_name: ``'i3d'`` or ``'tgcn'`` (case-insensitive).

    Returns:
        The value returned by ``Trainer.train()``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported models.
    """
    # Validate up front: previously any unrecognised name (including 'I3D')
    # silently fell through to the TGCN branch.
    model_key = str(model_name).lower()
    if model_key not in ('i3d', 'tgcn'):
        raise ValueError(
            f"Unknown model name {model_name!r}; expected 'i3d' or 'tgcn'"
        )

    # Setup model
    # NOTE(review): `code` is also the name of a standard-library module; if
    # that module is imported first, these package imports fail - confirm the
    # project's `code` package is the one that gets resolved.
    if model_key == 'i3d':
        from code.I3D.pytorch_i3d import InceptionI3d
        model = InceptionI3d(**I3D_CONFIG)
    else:
        from code.TGCN.tgcn_model import TGCN
        model = TGCN(**TGCN_CONFIG)

    # Setup training
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=TRAIN_CONFIG['learning_rate'],
        weight_decay=TRAIN_CONFIG['weight_decay']
    )

    # mode='min': the learning rate is reduced when the monitored quantity
    # stops decreasing.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=TRAIN_CONFIG['reduce_lr_factor'],
        patience=TRAIN_CONFIG['reduce_lr_patience']
    )

    trainer = Trainer(
        model=model,
        train_loader=dataloaders['train'],
        val_loader=dataloaders['val'],
        # Referenced through `torch` because the notebook never imports `nn`.
        criterion=torch.nn.CrossEntropyLoss(),
        optimizer=optimizer,
        scheduler=scheduler
    )

    # Train model
    history = trainer.train()
    return history

# Train I3D model; the returned history is kept for the training-curve plots
# in the Evaluation and Visualization section
print("Training I3D model...")
i3d_history = train_and_evaluate_model('i3d')

# Train TGCN model through the same helper, so it uses the same data loaders
# and training configuration
print("\nTraining TGCN model...")
tgcn_history = train_and_evaluate_model('tgcn')

## 4. Cross-Validation

Now let's perform cross-validation to get a better estimate of model performance.

In [None]:
def run_cross_validation(model_name):
    """Run k-fold cross-validation for the requested model.

    A ``CrossValidator`` is built from the model class, its configuration,
    the notebook-level ``data_info`` and ``TRAIN_CONFIG['num_folds']``.

    Args:
        model_name: ``'i3d'`` or ``'tgcn'`` (case-insensitive).

    Returns:
        The value returned by ``CrossValidator.run()``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported models.
    """
    # Validate up front: previously any unrecognised name (including 'I3D')
    # silently fell through to the TGCN branch.
    model_key = str(model_name).lower()
    if model_key not in ('i3d', 'tgcn'):
        raise ValueError(
            f"Unknown model name {model_name!r}; expected 'i3d' or 'tgcn'"
        )

    # NOTE(review): `code` is also the name of a standard-library module; if
    # that module is imported first, these package imports fail - confirm the
    # project's `code` package is the one that gets resolved.
    if model_key == 'i3d':
        from code.I3D.pytorch_i3d import InceptionI3d as ModelClass
        model_params = I3D_CONFIG
    else:
        from code.TGCN.tgcn_model import TGCN as ModelClass
        model_params = TGCN_CONFIG

    validator = CrossValidator(
        model_class=ModelClass,
        model_params=model_params,
        data_info=data_info,
        num_folds=TRAIN_CONFIG['num_folds']
    )

    results = validator.run()
    return results

# Run cross-validation for both models; the results are kept for the
# summary printout and the model comparison plot later in the notebook
print("Running cross-validation for I3D...")
i3d_cv_results = run_cross_validation('i3d')

print("\nRunning cross-validation for TGCN...")
tgcn_cv_results = run_cross_validation('tgcn')

## 5. Evaluation and Visualization

Let's visualize the results of our training and cross-validation.

In [None]:
def plot_training_curves(history, title):
    """Draw loss and accuracy curves for training and validation side by side.

    Args:
        history: Mapping with 'train_loss', 'val_loss', 'train_acc' and
            'val_acc' entries, each a sequence plotted against its index.
        title: Prefix used in the title of both panels.
    """
    figure, (loss_axis, accuracy_axis) = plt.subplots(1, 2, figsize=(15, 5))

    # One row per panel: axis, train key, validation key, panel title, y label
    panels = (
        (loss_axis, 'train_loss', 'val_loss', f'{title} - Loss', 'Loss'),
        (accuracy_axis, 'train_acc', 'val_acc', f'{title} - Accuracy', 'Accuracy (%)'),
    )

    for axis, train_key, val_key, panel_title, y_label in panels:
        axis.plot(history[train_key], label='Train')
        axis.plot(history[val_key], label='Validation')
        axis.set_title(panel_title)
        axis.set_xlabel('Epoch')
        axis.set_ylabel(y_label)
        axis.legend()

    plt.tight_layout()
    plt.show()

# Plot training curves for each model from the histories returned by the
# training cell
plot_training_curves(i3d_history, 'I3D Model')
plot_training_curves(tgcn_history, 'TGCN Model')

# Cross-validation summary printer
def print_cv_results(results, model_name):
    """Print the mean and standard deviation of every aggregate metric.

    Args:
        results: Mapping whose 'aggregate_metrics' entry maps each metric name
            to a mapping with 'mean' and 'std' values.
        model_name: Label shown in the heading line.
    """
    heading = f"\n{model_name} Cross-Validation Results:"
    print(heading)
    print("-" * 40)

    aggregate = results['aggregate_metrics']
    for metric in aggregate:
        values = aggregate[metric]
        print(f"{metric}: {values['mean']:.2f} ± {values['std']:.2f}")

# Report mean ± std of every aggregate metric for each model
print_cv_results(i3d_cv_results, 'I3D')
print_cv_results(tgcn_cv_results, 'TGCN')

## 6. Model Comparison and Analysis

Finally, let's compare the performance of both models and analyze their strengths and weaknesses.

In [None]:
def compare_models(i3d_results, tgcn_results,
                   metrics=('accuracy', 'precision', 'recall', 'f1')):
    """Compare performance metrics between models in a grouped bar chart.

    One group of bars is drawn per model, with one bar per metric.

    Args:
        i3d_results: Results for I3D, read as
            ``results['aggregate_metrics'][metric]['mean']``.
        tgcn_results: Results for TGCN with the same structure.
        metrics: Names of the aggregate metrics to plot. Defaults to
            accuracy, precision, recall and f1.
    """
    # Prepare data for plotting: metric name -> [I3D mean, TGCN mean]
    model_names = ['I3D', 'TGCN']
    metric_data = {
        metric: [i3d_results['aggregate_metrics'][metric]['mean'],
                 tgcn_results['aggregate_metrics'][metric]['mean']]
        for metric in metrics
    }
    metric_count = len(metric_data)

    # Create comparison plot
    fig, ax = plt.subplots(figsize=(10, 6))
    x = np.arange(len(model_names))
    # 0.15 for the default four metrics; shrink for longer metric lists so a
    # group never spills into its neighbour (groups are 1.0 apart).
    width = min(0.15, 0.8 / max(metric_count, 1))

    for position, (attribute, measurement) in enumerate(metric_data.items()):
        ax.bar(x + width * position, measurement, width, label=attribute)

    ax.set_ylabel('Score')
    ax.set_title('Model Performance Comparison')
    # Bars are centre-aligned at x, x + width, ..., so the middle of each
    # group lies at x + width * (n - 1) / 2; the former fixed offset of
    # 2 * width put the model labels half a bar to the right.
    ax.set_xticks(x + width * (metric_count - 1) / 2)
    ax.set_xticklabels(model_names)
    ax.legend(loc='lower right')
    # NOTE(review): the fixed 0-100 range assumes metrics are percentages -
    # confirm against the cross-validator's output.
    ax.set_ylim(0, 100)

    plt.show()

# Compare model performances using the cross-validation results of both models
compare_models(i3d_cv_results, tgcn_cv_results)