# Comprehensive Model Comparison for Sign Language Recognition

This notebook compares all implemented models:
1. Original I3D Model (WLASL baseline)
2. Memory-Efficient Model (EfficientSignNet)
3. Hybrid Transformer Models (CNNTransformer, TimeSformer)

We'll evaluate:
- Model performance
- Memory efficiency
- Training speed
- Cross-validation results

In [None]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import time
import logging
from tqdm.notebook import tqdm
import psutil
import gc

# Import original I3D model
from WLASL.code.I3D.models.pytorch_i3d import InceptionI3d

# Import memory-efficient models
from wlasl_modified.src.models.efficient_sign_net import EfficientSignNet

# Import hybrid models
from wlasl_modified.src.models.hybrid_transformers import CNNTransformer, TimeSformer

# Import data processing
from wlasl_modified.src.data.preprocessing import MemoryEfficientPreprocessor
from wlasl_modified.src.data.loader import create_data_loaders

# Import training utilities
from wlasl_modified.src.training.hybrid_trainers import create_trainer
from wlasl_modified.src.training.cross_validate import CrossValidationTrainer

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

%matplotlib inline
plt.style.use('seaborn')

## 1. Data Loading and Preprocessing

In [None]:
# Input and output locations, relative to the notebook's working directory
DATA_DIR, PROCESSED_DIR = Path('data'), Path('processed')

def prepare_dataset(data_dir: Path, processed_dir: Path):
    """Run the memory-efficient preprocessor over every ``.mp4`` under ``data_dir``.

    Args:
        data_dir: Directory searched recursively for ``.mp4`` files.
        processed_dir: Output root; the preprocessor is pointed at its
            ``frames`` subdirectory.

    Returns:
        A list holding one ``preprocess_video`` result for each video that
        was processed without raising. Failures are logged and skipped.
    """
    frame_extractor = MemoryEfficientPreprocessor(
        output_dir=processed_dir / 'frames',
        frame_size=(224, 224),
        target_fps=25,
        chunk_size=32,
    )

    # rglob('*.mp4') is the recursive spelling of glob('**/*.mp4')
    video_paths = list(data_dir.rglob('*.mp4'))
    logger.info(f"Found {len(video_paths)} videos")

    processed = []
    for video_path in tqdm(video_paths, desc="Processing videos"):
        try:
            outcome = frame_extractor.preprocess_video(video_path)
        except Exception as e:
            # Best effort: one bad video must not abort the whole run
            logger.error(f"Error processing {video_path}: {str(e)}")
        else:
            processed.append(outcome)

    return processed

# Reuse existing preprocessing output when the frames directory is present;
# otherwise run the (slow) preprocessing pass now.
if (PROCESSED_DIR / 'frames').exists():
    logger.info("Using previously processed data")
else:
    preprocessing_results = prepare_dataset(DATA_DIR, PROCESSED_DIR)

## 2. Data Analysis

In [None]:
def analyze_dataset(processed_dir: Path):
    """Collect per-video frame counts and on-disk sizes for a processed dataset.

    Each subdirectory of ``processed_dir / 'frames'`` is treated as one video
    sample whose frames are the ``*.jpg`` files directly inside it.

    Args:
        processed_dir: Root of the preprocessing output.

    Returns:
        dict with keys:
            ``num_samples``: number of video directories found.
            ``frames_per_video``: list of JPEG counts, one per video.
            ``video_sizes``: list of total JPEG sizes in MB, one per video.
        A missing ``frames`` directory yields zero samples and empty lists.
    """
    frames_root = processed_dir / 'frames'

    # Only directories are samples: stray files next to them (indexes, OS
    # metadata) would otherwise be counted as zero-frame videos. Sorting keeps
    # the per-video lists in a reproducible order across runs and platforms.
    if frames_root.is_dir():
        frame_dirs = sorted(p for p in frames_root.iterdir() if p.is_dir())
    else:
        frame_dirs = []

    stats = {
        'num_samples': len(frame_dirs),
        'frames_per_video': [],
        'video_sizes': [],
    }

    for frame_dir in frame_dirs:
        frames = list(frame_dir.glob('*.jpg'))
        stats['frames_per_video'].append(len(frames))

        # Total size of this video's frames, converted from bytes to MB
        size = sum(f.stat().st_size for f in frames) / 1024**2
        stats['video_sizes'].append(size)

    return stats

# Analyze dataset
dataset_stats = analyze_dataset(PROCESSED_DIR)

# Plot both distributions side by side, one histogram per statistic
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

histogram_specs = (
    (ax1, 'frames_per_video', 'Frames per Video Distribution', 'Number of Frames'),
    (ax2, 'video_sizes', 'Video Sizes Distribution', 'Size (MB)'),
)
for hist_ax, stat_key, hist_title, hist_xlabel in histogram_specs:
    sns.histplot(dataset_stats[stat_key], ax=hist_ax)
    hist_ax.set_title(hist_title)
    hist_ax.set_xlabel(hist_xlabel)

plt.tight_layout()
plt.show()

## 3. Model Creation and Configuration

In [None]:
def create_models(num_classes: int, num_frames: int = 30, in_channels: int = 3):
    """Create one instance of every model variant being compared.

    Args:
        num_classes: Number of output classes, passed to every model.
        num_frames: Clip length passed to the two transformer models.
            Defaults to 30, the value previously hard-coded here.
        in_channels: Input channel count passed to the I3D and
            EfficientSignNet models. Defaults to 3.

    Returns:
        dict mapping the names ``'i3d'``, ``'efficient'``,
        ``'cnn_transformer'`` and ``'timesformer'`` to freshly constructed
        model instances.
    """
    return {
        'i3d': InceptionI3d(
            num_classes=num_classes,
            in_channels=in_channels,
        ),
        'efficient': EfficientSignNet(
            num_classes=num_classes,
            in_channels=in_channels,
        ),
        'cnn_transformer': CNNTransformer(
            num_classes=num_classes,
            num_frames=num_frames,
        ),
        'timesformer': TimeSformer(
            num_classes=num_classes,
            num_frames=num_frames,
        ),
    }

# Create models
NUM_CLASSES = 26  # Adjust based on your dataset
models = create_models(NUM_CLASSES)

# Print a short summary per model: its name and total parameter count
for name, model in models.items():
    print(f"\n{name.upper()} Model Summary:")
    param_count = sum(param.numel() for param in model.parameters())
    print(f"Parameters: {param_count:,}")

## 4. Cross-Validation Training

In [None]:
def train_and_evaluate(models: dict, data_loaders: dict, num_epochs: int = 20):
    """Train and evaluate all models with cross-validation.

    Models are trained one after another on CUDA when available, otherwise
    on CPU. After each model finishes (or fails) it is moved back to the CPU
    so the next model starts with the GPU free.

    Args:
        models: Mapping of model name to model instance.
        data_loaders: Mapping providing the ``'train'`` and ``'val'`` loaders.
        num_epochs: Passed to ``CrossValidationTrainer.train``.

    Returns:
        dict mapping each model name to whatever the trainer's ``train``
        call returned. When that value is a dict without a
        ``'training_time'`` entry, the measured wall-clock duration in
        seconds is stored under that key.
    """
    results = {}
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    for name, model in models.items():
        logger.info(f"\nTraining {name}...")
        model = model.to(device)

        trainer = None
        try:
            # Create trainer
            trainer = CrossValidationTrainer(
                model=model,
                train_loader=data_loaders['train'],
                val_loader=data_loaders['val'],
                device=device,
                num_folds=7
            )

            # Train with cross-validation, timing the whole call
            started = time.perf_counter()
            cv_results = trainer.train(num_epochs)
            elapsed = time.perf_counter() - started
        finally:
            # Module.to() moves parameters in place and the caller's `models`
            # dict still references the module, so dropping the local name
            # alone would leave the weights on the GPU. Offload explicitly.
            model.cpu()
            del trainer
            gc.collect()
            if use_cuda:
                torch.cuda.empty_cache()

        # Keep a trainer-supplied timing if present; otherwise record ours so
        # downstream comparisons do not silently fall back to 0.
        if isinstance(cv_results, dict):
            cv_results.setdefault('training_time', elapsed)
        results[name] = cv_results

    return results

# Create data loaders
loader_config = {
    'processed_dir': PROCESSED_DIR,
    'batch_size': 8,
    'num_workers': 4,
}
data_loaders = create_data_loaders(**loader_config)

# Train all models
cv_results = train_and_evaluate(models, data_loaders)

# Plot cross-validation results: accuracy on the left, loss on the right
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
curve_panels = (
    (ax1, 'val_accuracy', 'Validation Accuracy'),
    (ax2, 'val_loss', 'Validation Loss'),
)

# One curve per model on each panel
for name, result in cv_results.items():
    for curve_ax, metric_key, _panel_title in curve_panels:
        curve_ax.plot(result[metric_key], label=name)

for curve_ax, _metric_key, panel_title in curve_panels:
    curve_ax.set_title(panel_title)
    curve_ax.set_xlabel('Epoch')
    curve_ax.legend()

plt.tight_layout()
plt.show()

## 5. Memory Usage Analysis

In [None]:
def profile_memory_usage(models: dict, sample_batch: torch.Tensor):
    """Profile GPU memory used by one no-grad forward pass of each model.

    Every model is first moved to the CPU so that only the model under test
    is resident on the GPU while it is measured. Each model is then moved to
    the device, run once in eval mode on ``sample_batch``, measured, restored
    to its previous train/eval mode and moved back to the CPU.

    Args:
        models: Mapping of model name to ``torch.nn.Module``.
        sample_batch: Input tensor fed to every model's forward call.

    Returns:
        dict mapping model name to a dict of values in MB:
            ``gpu_allocated``: memory allocated right after the forward pass
                (the forward output is still alive at that point).
            ``gpu_reserved``: memory reserved by the caching allocator.
            ``gpu_peak_allocated``: peak allocation since the counters were
                reset just before the forward pass.
        Without CUDA the forward pass still runs on CPU and every value
        is ``0.0``.
    """
    memory_stats = {}
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    mb = 1024**2

    # Module.to() is in place, so models used earlier may still sit on the
    # GPU; evict them all or their weights would inflate every reading.
    for model in models.values():
        model.cpu()
    gc.collect()
    if use_cuda:
        torch.cuda.empty_cache()

    for name, model in models.items():
        was_training = model.training
        model = model.to(device)
        # Eval mode keeps the probe from updating batch-norm running stats
        model.eval()
        output = None
        try:
            # The CUDA statistics API raises on CPU-only machines
            if use_cuda:
                torch.cuda.reset_peak_memory_stats()

            # Forward pass
            with torch.no_grad():
                output = model(sample_batch.to(device))

            if use_cuda:
                memory_stats[name] = {
                    'gpu_allocated': torch.cuda.memory_allocated() / mb,
                    'gpu_reserved': torch.cuda.memory_reserved() / mb,
                    'gpu_peak_allocated': torch.cuda.max_memory_allocated() / mb,
                }
            else:
                memory_stats[name] = {
                    'gpu_allocated': 0.0,
                    'gpu_reserved': 0.0,
                    'gpu_peak_allocated': 0.0,
                }
        finally:
            # Release the output and weights before the next model is measured
            output = None
            model.train(was_training)
            model.cpu()
            gc.collect()
            if use_cuda:
                torch.cuda.empty_cache()

    return memory_stats

# Create sample batch
# NOTE(review): this layout looks like (batch, frames, channels, H, W); confirm
# that every model in `models` expects frames before channels.
sample_batch = torch.randn(8, 30, 3, 224, 224)

# Profile memory usage
memory_stats = profile_memory_usage(models, sample_batch)

# Plot memory usage as grouped bars: one group per model, one bar per counter
fig, ax = plt.subplots(figsize=(10, 5))

bar_width = 0.35
positions = np.arange(len(memory_stats))
per_model = list(memory_stats.values())

ax.bar(
    positions - bar_width/2,
    [entry['gpu_allocated'] for entry in per_model],
    bar_width,
    label='Allocated',
)
ax.bar(
    positions + bar_width/2,
    [entry['gpu_reserved'] for entry in per_model],
    bar_width,
    label='Reserved',
)

ax.set_ylabel('Memory (MB)')
ax.set_title('GPU Memory Usage by Model')
ax.set_xticks(positions)
ax.set_xticklabels(list(memory_stats))
ax.legend()

plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

## 6. Performance Metrics Comparison

In [None]:
def compile_metrics(cv_results: dict, memory_usage=None):
    """Compile performance metrics for all models into one table.

    Args:
        cv_results: Mapping of model name to a result mapping that holds
            ``'val_accuracy'`` (a sequence of numbers) and optionally
            ``'training_time'``.
        memory_usage: Mapping of model name to a dict with a
            ``'gpu_allocated'`` entry. When omitted, the notebook-level
            ``memory_stats`` variable is used, which was the previous
            (implicit) behaviour.

    Returns:
        pandas.DataFrame with one row per model and the columns ``model``,
        ``accuracy`` (mean of ``val_accuracy``), ``accuracy_std``,
        ``training_time`` (0 when absent) and ``memory_usage``. The columns
        are present even when ``cv_results`` is empty.

    Raises:
        KeyError: If a model in ``cv_results`` has no memory entry.
    """
    if memory_usage is None:
        # Backward-compatible default: the profiling results held in the
        # notebook's global namespace.
        memory_usage = memory_stats

    columns = ['model', 'accuracy', 'accuracy_std', 'training_time', 'memory_usage']
    rows = [
        {
            'model': name,
            'accuracy': np.mean(result['val_accuracy']),
            'accuracy_std': np.std(result['val_accuracy']),
            'training_time': result.get('training_time', 0),
            'memory_usage': memory_usage[name]['gpu_allocated'],
        }
        for name, result in cv_results.items()
    ]

    # Explicit columns keep an empty comparison plottable and well-formed
    return pd.DataFrame(rows, columns=columns)

# Compile and display metrics
metrics_df = compile_metrics(cv_results)
display(metrics_df)

# Plot comparison on a 2x2 grid: three bar charts plus an accuracy/memory scatter
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

bar_panels = (
    (axes[0, 0], 'accuracy', 'Model Accuracy'),
    (axes[0, 1], 'training_time', 'Training Time (s)'),
    (axes[1, 0], 'memory_usage', 'Memory Usage (MB)'),
)
for bar_ax, metric_column, bar_title in bar_panels:
    sns.barplot(data=metrics_df, x='model', y=metric_column, ax=bar_ax)
    bar_ax.set_title(bar_title)
    bar_ax.tick_params(axis='x', rotation=45)

# Accuracy vs Memory trade-off, each point labelled with its model name
tradeoff_ax = axes[1, 1]
sns.scatterplot(data=metrics_df, x='memory_usage', y='accuracy', ax=tradeoff_ax)
for _row_index, row in metrics_df.iterrows():
    point = (row['memory_usage'], row['accuracy'])
    tradeoff_ax.annotate(row['model'], point)
tradeoff_ax.set_title('Accuracy vs Memory Trade-off')

plt.tight_layout()
plt.show()

## 7. Conclusion and Recommendations

Based on the comparison above:

1. Model Performance:
   - Best accuracy: [Model name]
   - Most memory efficient: [Model name]
   - Best accuracy/memory trade-off: [Model name]

2. Recommendations:
   - For high accuracy: Use [Model name]
   - For limited resources: Use [Model name]
   - For balanced performance: Use [Model name]

3. Memory Optimization Tips:
   - Use gradient checkpointing for large models
   - Enable mixed precision training
   - Adjust batch size based on available memory
   - Choose a model that fits your hardware constraints