# Hybrid Transformer Models for Sign Language Recognition

This notebook demonstrates the usage of memory-efficient hybrid transformer models:
1. CNN-Transformer
2. TimeSformer

We'll compare their performance, memory usage, and training characteristics.

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import pandas as pd
import seaborn as sns
from tqdm.notebook import tqdm
import psutil
import gc

from src.models.hybrid_transformers import create_model
from src.training.hybrid_trainers import create_trainer
from configs.hybrid_transformer_config import get_config, print_memory_recommendations

# Enable notebook-wide memory tracking
%load_ext memory_profiler

## 1. Memory Usage Analysis

First, let's measure the RAM and GPU memory each model uses for a single forward pass at several batch sizes.

In [None]:
def profile_model_memory(model_name: str, batch_size: int = 8):
    """Profile memory usage of one inference forward pass.

    Builds the model via ``create_model`` (26 classes, 30 frames), runs a
    single forward pass on a random batch under ``torch.no_grad()``, and
    samples process and GPU memory afterwards.

    Args:
        model_name: Model identifier forwarded to ``create_model``.
        batch_size: Number of clips in the random input batch.

    Returns:
        Dict of memory figures, each divided by ``1024**3``:
        ``ram_used`` (resident set size of this process),
        ``gpu_allocated`` and ``gpu_reserved`` (allocator state after the
        forward pass), and ``gpu_peak`` (highest allocation seen since the
        weights were moved to the GPU). All GPU entries are 0 when CUDA is
        unavailable.
    """
    use_cuda = torch.cuda.is_available()
    bytes_per_unit = 1024**3

    model = create_model(
        model_name=model_name,
        num_classes=26,
        num_frames=30
    )
    # This is an inference measurement, so take the model out of training mode.
    model.eval()

    if use_cuda:
        model = model.cuda()
        # Reset after the weights are resident so the peak reflects the
        # input batch and the forward pass rather than earlier runs.
        torch.cuda.reset_peak_memory_stats()

    # Create dummy input
    x = torch.randn(batch_size, 30, 3, 224, 224)
    if use_cuda:
        x = x.cuda()

    # Profile forward pass. The result is kept referenced on purpose so it
    # is still allocated when the memory counters are sampled below.
    with torch.no_grad():
        output = model(x)

    memory_stats = {
        'ram_used': psutil.Process().memory_info().rss / bytes_per_unit,
        'gpu_allocated': torch.cuda.memory_allocated() / bytes_per_unit if use_cuda else 0,
        'gpu_reserved': torch.cuda.memory_reserved() / bytes_per_unit if use_cuda else 0,
        # memory_allocated() only sees what is still alive after the forward
        # pass; the peak also captures transient intermediate tensors.
        'gpu_peak': torch.cuda.max_memory_allocated() / bytes_per_unit if use_cuda else 0,
    }

    return memory_stats

# Sweep every (model, batch size) combination and collect one row per run
models = ['cnn_transformer', 'timesformer']
batch_sizes = [1, 2, 4, 8, 16]
memory_data = []

for model_name in models:
    for bs in batch_sizes:
        row = {
            **profile_model_memory(model_name, bs),
            'model': model_name,
            'batch_size': bs,
        }
        memory_data.append(row)
        # Drop leftovers from this run before starting the next one
        gc.collect()
        torch.cuda.empty_cache()

# Plot results: one panel per measured quantity, one line per model
df = pd.DataFrame(memory_data)
plt.figure(figsize=(12, 6))

panels = [
    ('gpu_allocated', 'GPU Memory Usage'),
    ('ram_used', 'RAM Usage'),
]
for position, (column, title) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    sns.lineplot(data=df, x='batch_size', y=column, hue='model')
    plt.title(title)
    plt.ylabel('Memory (GB)')

plt.tight_layout()
plt.show()

## 2. Training Performance Comparison

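Let's train both models for a few epochs on randomly generated clips and compare their loss curves, validation accuracy, and GPU memory usage. Because the data is random, the accuracy numbers only exercise the pipeline; they are not meaningful results.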

In [None]:
class _RandomClipDataset(torch.utils.data.Dataset):
    """Synthetic dataset yielding random clips with random integer labels."""

    def __init__(self, num_samples: int, num_frames: int, num_classes: int):
        self.num_samples = num_samples
        self.num_frames = num_frames
        self.num_classes = num_classes

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # A fresh random sample is drawn on every access; idx is ignored.
        x = torch.randn(self.num_frames, 3, 224, 224)
        y = torch.randint(0, self.num_classes, (1,))[0]
        return x, y


def train_model(
    model_name: str,
    num_epochs: int = 5,
    *,
    num_classes: int = 26,
    num_frames: int = 30,
    num_samples: int = 100,
):
    """Train a model on synthetic data and track per-epoch metrics.

    The class and frame counts are shared between ``create_model`` and the
    synthetic dataset so the two cannot drift apart.

    Args:
        model_name: Model identifier forwarded to ``get_config``,
            ``create_model`` and ``create_trainer``.
        num_epochs: Number of train/validate rounds to run.
        num_classes: Number of target classes for the model and the labels.
        num_frames: Frames per clip for the model and the generated clips.
        num_samples: Length of each synthetic dataset (train and validation).

    Returns:
        Dict of lists with one entry per epoch: ``train_loss``,
        ``val_loss``, ``val_accuracy`` and ``memory_usage``.
        ``memory_usage`` holds ``torch.cuda.memory_allocated()`` divided by
        ``1024**3`` and stays empty when CUDA is unavailable.
    """
    config = get_config(model_name)
    batch_size = config['trainer'].batch_size

    # Create model
    model = create_model(
        model_name=model_name,
        num_classes=num_classes,
        num_frames=num_frames
    )

    # Create dummy datasets (independent random data for train and val)
    train_loader = torch.utils.data.DataLoader(
        _RandomClipDataset(num_samples, num_frames, num_classes),
        batch_size=batch_size,
        shuffle=True
    )

    val_loader = torch.utils.data.DataLoader(
        _RandomClipDataset(num_samples, num_frames, num_classes),
        batch_size=batch_size
    )

    # Create trainer
    # NOTE(review): the model is not moved to the device here; presumably
    # the trainer does that with the device it is given. Confirm.
    trainer = create_trainer(
        model_name=model_name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=torch.nn.CrossEntropyLoss(),
        optimizer=torch.optim.Adam(model.parameters()),
        device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
        config=config['trainer'],
        checkpoint_dir=Path('checkpoints') / model_name
    )

    # Training loop
    history = {
        'train_loss': [],
        'val_loss': [],
        'val_accuracy': [],
        'memory_usage': []
    }

    for epoch in range(num_epochs):
        train_metrics = trainer.train_epoch()
        val_metrics = trainer.validate()

        history['train_loss'].append(train_metrics['loss'])
        history['val_loss'].append(val_metrics['val_loss'])
        history['val_accuracy'].append(val_metrics['val_accuracy'])

        # GPU memory is only sampled when a GPU exists; on CPU the list
        # is left empty.
        if torch.cuda.is_available():
            history['memory_usage'].append(
                torch.cuda.memory_allocated() / 1024**3
            )

    return history

# Train both models
histories = {}
for model_name in models:
    print(f"\nTraining {model_name}...")
    histories[model_name] = train_model(model_name)
    # Release the finished model before the next one is built
    gc.collect()
    torch.cuda.empty_cache()

# Plot training curves: loss, validation accuracy, GPU memory
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

for model_name in models:
    history = histories[model_name]
    axes[0].plot(history['train_loss'], label=f'{model_name} train')
    axes[0].plot(history['val_loss'], label=f'{model_name} val')

    axes[1].plot(history['val_accuracy'], label=model_name)

    # train_model always creates the 'memory_usage' key but leaves the list
    # empty on CPU-only runs, so check for data rather than key presence.
    memory_usage = history.get('memory_usage')
    if memory_usage:
        axes[2].plot(memory_usage, label=model_name)

# Titles and legends are set once, after every series has been drawn
axes[0].set_title('Loss')
axes[1].set_title('Validation Accuracy')
axes[2].set_title('GPU Memory Usage (GB)')

for ax in axes:
    handles, _ = ax.get_legend_handles_labels()
    # Skip the legend on a panel that received no series (e.g. no GPU)
    if handles:
        ax.legend()

plt.tight_layout()
plt.show()

## 3. Model Features and Attention Analysis

In [None]:
def visualize_attention(model_name: str, sample_video: torch.Tensor):
    """Plot a heatmap of the last block's attention weights.

    Builds a freshly initialised model via ``create_model`` (26 classes,
    30 frames), puts it in eval mode, asks its last block for attention
    weights on ``sample_video`` and shows them with ``sns.heatmap``.

    Args:
        model_name: Model identifier. ``'cnn_transformer'`` reads the
            weights from ``model.blocks[-1].attn``; any other name reads
            them from ``model.blocks[-1]`` directly.
        sample_video: Input tensor handed to ``get_attention_weights``.

    Returns:
        None. The figure is displayed via ``plt.show()``.
    """
    model = create_model(
        model_name=model_name,
        num_classes=26,
        num_frames=30
    ).eval()

    # Get attention weights. This is inference only, so autograd is switched
    # off to avoid building a graph for the whole clip.
    with torch.no_grad():
        if model_name == 'cnn_transformer':
            # Get transformer attention from last layer
            attention = model.blocks[-1].attn.get_attention_weights(sample_video)
        else:
            # Get space-time attention
            attention = model.blocks[-1].get_attention_weights(sample_video)

    # First batch element, averaged over its leading axis.
    # NOTE(review): presumably that axis is attention heads; confirm
    # against get_attention_weights.
    attention_map = attention[0].mean(0).cpu().detach()

    # Plot attention patterns
    plt.figure(figsize=(10, 5))
    sns.heatmap(attention_map)
    plt.title(f'{model_name} Attention Pattern')
    plt.xlabel('Key frames')
    plt.ylabel('Query frames')
    plt.show()

# One random clip of 30 frames, reused as the probe input for every model
clip_shape = (1, 30, 3, 224, 224)
sample_video = torch.randn(clip_shape)

# Render the attention heatmap of each model in turn
for name in models:
    visualize_attention(name, sample_video)

## 4. Memory Efficiency Tips

Here are some key recommendations for memory-efficient training:

In [None]:
print_memory_recommendations()

# Extra implementation hints; each entry becomes one printed line
tips = [
    "1. Use appropriate batch size and gradient accumulation:",
    "   - CNN-Transformer: batch_size=8, grad_accum=4",
    "   - TimeSformer: batch_size=4, grad_accum=8",
    "\n2. Enable memory optimizations:",
    "   model_config.use_checkpoint = True",
    "   trainer_config.mixed_precision = True",
    "\n3. Monitor memory usage:",
    "   trainer.memory_tracker.check_memory()"
]

# Heading first, then every tip, all separated by newlines
print("\nImplementation Tips:", *tips, sep="\n")