# Unified Model Loading in evlib

This notebook demonstrates the unified model loading system in evlib that supports multiple model formats:
- PyTorch models (.pth files)
- ONNX models (.onnx files) 
- SafeTensors models (.safetensors files)

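The system provides automatic format detection, priority-based loading when several formats of the same model exist, and integration with evlib's event-processing pipeline (for example, voxel-grid inputs). Several cells below simulate the loader's behaviour with placeholder data instead of calling evlib, so their output is illustrative.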

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import time

try:
    import evlib
    print("✅ evlib imported successfully")
except ImportError as e:
    print(f"❌ Failed to import evlib: {e}")
    raise

## 1. Model Format Detection

The unified loader can automatically detect model formats based on file extensions and content validation.
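Extension checks are easy to fool, so content validation means looking at the bytes themselves. The sketch below assumes nothing about evlib's internals: `sniff_model_format` is a hypothetical helper built only on published format facts. A SafeTensors file begins with an 8-byte little-endian header length followed by a JSON object; `torch.save` writes a ZIP archive by default (PyTorch 1.6 and later) or, in the legacy format, a raw pickle stream; ONNX files are protobuf messages with no magic number, so the sketch falls back to the extension for them.

```python
import json
import struct
import zipfile
from pathlib import Path

def sniff_model_format(path):
    """Guess a model file's format from its leading bytes (hypothetical helper, not evlib)."""
    path = Path(path)
    with path.open("rb") as f:
        head = f.read(8)
        # SafeTensors: u64 little-endian header length, then a JSON object starting with '{'
        if len(head) == 8:
            (header_len,) = struct.unpack("<Q", head)
            if 0 < header_len < 100_000_000 and f.read(1) == b"{":
                header = b"{" + f.read(header_len - 1)
                if len(header) == header_len:
                    try:
                        json.loads(header)
                        return "SafeTensors"
                    except ValueError:
                        pass  # not valid JSON: try the other formats
    # torch.save has written ZIP archives by default since PyTorch 1.6
    if zipfile.is_zipfile(path):
        return "PyTorch"
    # Legacy torch.save output is a raw pickle stream (PROTO opcode 0x80); heuristic only
    if head[:1] == b"\x80":
        return "PyTorch"
    # ONNX files are protobuf messages with no magic number: trust the suffix
    if path.suffix.lower() == ".onnx":
        return "ONNX"
    return "Unknown"
```

Because the pickle check only looks at one byte, any pickle file is reported as PyTorch; the result is best treated as a hint for choosing a parser, not as proof that the file is valid.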

In [None]:
def demonstrate_format_detection():
    """Demonstrate automatic model format detection"""
    
    # Test different file extensions
    test_paths = [
        "models/e2vid_model.pth",
        "models/e2vid_model.onnx", 
        "models/e2vid_model.safetensors",
        "models/unknown_model.bin"
    ]
    
    print("Model Format Detection:")
    print("=" * 40)
    
    # Priority per extension: lower is preferred (see Section 2)
    extension_table = {
        ".onnx": ("ONNX", 1),
        ".pth": ("PyTorch", 2),
        ".safetensors": ("SafeTensors", 3),
    }

    for path in test_paths:
        # Simulated: classify by extension only; no evlib function is called here
        format_type, priority = extension_table.get(Path(path).suffix, ("Unknown", 999))
        print(f"📁 {path}")
        print(f"   Format: {format_type} (Priority: {priority})")
        print()

demonstrate_format_detection()

## 2. Priority-Based Loading

When multiple model formats are available, the loader uses a priority system:
1. **ONNX** (highest priority) - Optimised for inference
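2. **PyTorch** - Native training format (pickle-based)
3. **SafeTensors** - Secure serialisation format: raw tensors plus a JSON header, with no pickle, so loading cannot execute arbitrary code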
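The same priority order can be applied to files on disk. This is a minimal sketch, assuming (hypothetically) that all formats of a model live in one directory and share one file stem; `resolve_model_path` and `FORMAT_PRIORITY` are illustrative names, not evlib APIs.

```python
from pathlib import Path

# Lower number = preferred, matching the priority list above (hypothetical table)
FORMAT_PRIORITY = {".onnx": 1, ".pth": 2, ".safetensors": 3}

def resolve_model_path(model_dir, model_name):
    """Return the highest-priority existing file for model_name, or None if none exists."""
    for suffix in sorted(FORMAT_PRIORITY, key=FORMAT_PRIORITY.get):
        candidate = Path(model_dir) / f"{model_name}{suffix}"
        if candidate.is_file():
            return candidate
    return None
```

Since ONNX always wins when present, exporting an `.onnx` file next to an existing `.pth` silently changes which file gets loaded.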

In [None]:
def demonstrate_priority_loading():
    """Show how priority-based loading works"""
    
    # Simulate available models for the same architecture
    available_models = {
        "e2vid_base": ["e2vid_base.pth", "e2vid_base.onnx"],
        "firenet": ["firenet.pth", "firenet.safetensors"],
        "spade_e2vid": ["spade_e2vid.onnx"]
    }
    
    print("Priority-Based Model Loading:")
    print("=" * 50)
    
    for model_name, formats in available_models.items():
        print(f"🧠 Model: {model_name}")
        print(f"   Available formats: {formats}")
        
        # Lower number = higher priority; unrecognised extensions are ignored
        priority_by_ext = {'.onnx': 1, '.pth': 2, '.safetensors': 3}
        ranked = [(fmt, priority_by_ext[Path(fmt).suffix])
                  for fmt in formats if Path(fmt).suffix in priority_by_ext]

        if ranked:
            selected = min(ranked, key=lambda x: x[1])[0]
            print(f"   ✅ Selected: {selected}")
        else:
            print("   ❌ No supported format available")
        print()

demonstrate_priority_loading()

## 3. Performance Comparison

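Different model formats have different loading and inference characteristics. The figures in the cell below are simulated placeholders rather than measurements. The file format mainly affects load time; inference speed is largely determined by the runtime and hardware backend that execute the model.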
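To replace the placeholder numbers with real measurements, each load or inference call should be timed with a monotonic clock after a few warm-up runs, and summarised by the median of several repeats. `time_call` is a hypothetical helper sketching that procedure; `fn` would be whatever loads or runs a model. Asynchronous GPU backends return before the work finishes, so with CUDA an explicit synchronisation inside `fn` would be needed for the timing to be meaningful.

```python
import statistics
import time

def time_call(fn, *args, warmup=3, repeats=20, **kwargs):
    """Return (median, min) wall-clock seconds for fn(*args, **kwargs).

    Warm-up runs are discarded so one-off costs (lazy initialisation,
    context creation, cold caches) do not skew the result.
    """
    for _ in range(warmup):
        fn(*args, **kwargs)
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args, **kwargs)
        samples.append(time.perf_counter() - start)
    return statistics.median(samples), min(samples)
```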

In [None]:
def benchmark_model_formats():
    """Simulate performance comparison between model formats"""
    
    # Placeholder values for illustration only (not measured); a real benchmark would time actual models
    benchmark_data = {
        'ONNX': {
            'load_time': 0.05,
            'inference_time': 0.012,
            'memory_usage': 145,
            'supported_backends': ['CPU', 'CUDA', 'DirectML']
        },
        'PyTorch': {
            'load_time': 0.15,
            'inference_time': 0.018,
            'memory_usage': 180,
            'supported_backends': ['CPU', 'CUDA']
        },
        'SafeTensors': {
            'load_time': 0.08,
            'inference_time': 0.020,
            'memory_usage': 150,
            'supported_backends': ['CPU', 'CUDA']
        }
    }
    
    print("Model Format Performance Comparison:")
    print("=" * 60)
    print(f"{'Format':<12} {'Load (s)':<10} {'Inference (s)':<13} {'Memory (MB)':<12} {'Backends'}")
    print("-" * 60)
    
    for format_name, metrics in benchmark_data.items():
        backends_str = ', '.join(metrics['supported_backends'])
        print(f"{format_name:<12} {metrics['load_time']:<10.3f} {metrics['inference_time']:<13.3f} "
              f"{metrics['memory_usage']:<12} {backends_str}")
    
    # Create visualisation: one bar chart per metric
    formats = list(benchmark_data.keys())
    colors = ['#ff7f0e', '#2ca02c', '#d62728']
    panels = [
        ('load_time', 'Model Load Time', 'Time (seconds)'),
        ('inference_time', 'Inference Time', 'Time (seconds)'),
        ('memory_usage', 'Memory Usage', 'Memory (MB)'),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (metric, title, ylabel) in zip(axes, panels):
        ax.bar(formats, [benchmark_data[fmt][metric] for fmt in formats], color=colors)
        ax.set_title(title)
        ax.set_ylabel(ylabel)

    plt.tight_layout()
    plt.show()

benchmark_model_formats()

## 4. Model Configuration System

The unified loader supports flexible configuration for different model architectures and inference settings:
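Plain dictionaries accept misspelled keys silently. One option, sketched here as an assumption rather than evlib's actual configuration type, is a frozen dataclass that validates the shared fields and keeps model-specific keys (such as FireNet's `use_separable_conv`) in an `extras` dictionary. `ModelConfig` and `from_dict` are hypothetical names.

```python
from dataclasses import dataclass, field, fields

@dataclass(frozen=True)
class ModelConfig:
    """Hypothetical typed model configuration (not an evlib class)."""
    input_channels: int = 5
    output_channels: int = 1
    base_channels: int = 32
    use_skip_connections: bool = True
    device_preference: str = "cuda"
    extras: dict = field(default_factory=dict)  # model-specific keys

    def __post_init__(self):
        # Validate shared fields as soon as the config is built
        if self.input_channels < 1 or self.output_channels < 1:
            raise ValueError("channel counts must be positive")
        if self.device_preference not in {"cpu", "cuda"}:
            raise ValueError(f"unsupported device: {self.device_preference!r}")

    @classmethod
    def from_dict(cls, config):
        """Split a plain dict into known fields and model-specific extras."""
        known = {f.name for f in fields(cls)} - {"extras"}
        base = {k: v for k, v in config.items() if k in known}
        extras = {k: v for k, v in config.items() if k not in known}
        return cls(**base, extras=extras)
```

Restricting `device_preference` to `cpu` and `cuda` mirrors the values used in this notebook; a real loader might accept more.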

In [None]:
def demonstrate_model_configuration():
    """Show different model configuration options"""
    
    configurations = {
        "E2VID Base": {
            "input_channels": 5,
            "output_channels": 1,
            "base_channels": 32,
            "use_skip_connections": True,
            "device_preference": "cuda"
        },
        "E2VID Lightweight": {
            "input_channels": 5,
            "output_channels": 1,
            "base_channels": 16,
            "use_skip_connections": False,
            "device_preference": "cpu"
        },
        "FireNet": {
            "input_channels": 3,
            "output_channels": 1,
            "use_separable_conv": True,
            "activation": "relu",
            "device_preference": "cuda"
        }
    }
    
    print("Model Configuration Examples:")
    print("=" * 50)
    
    for model_name, config in configurations.items():
        print(f"🔧 {model_name}:")
        for key, value in config.items():
            print(f"   {key}: {value}")
        print()

demonstrate_model_configuration()

## 5. Event Processing Workflow

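This section walks through a complete workflow with the unified model loading system: generating synthetic events, converting them to a voxel grid, loading a model, and running (simulated) inference.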
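For reference, a voxel grid splits the event stream into equal time bins and accumulates events per pixel in each bin. The NumPy sketch below assumes a signed count (each event adds its polarity, +1 or -1); evlib's `"count"` mode may use a different convention, such as unsigned counts or temporal interpolation between neighbouring bins, so `voxel_grid_count` (a hypothetical name) is best used as a cross-check rather than a drop-in replacement.

```python
import numpy as np

def voxel_grid_count(xs, ys, ts, ps, num_bins, width, height):
    """Signed event counts per (time bin, y, x); assumes polarities are +1/-1."""
    ts = np.asarray(ts, dtype=np.float64)
    grid = np.zeros((num_bins, height, width), dtype=np.float32)
    if ts.size == 0:
        return grid
    t0, t1 = ts.min(), ts.max()
    span = t1 - t0 if t1 > t0 else 1.0
    # Map timestamps to bin indices 0..num_bins-1 (the latest event lands in the final bin)
    bins = np.clip(((ts - t0) / span * num_bins).astype(np.int64), 0, num_bins - 1)
    # np.add.at accumulates correctly when several events hit the same cell
    np.add.at(grid, (bins, np.asarray(ys), np.asarray(xs)), np.asarray(ps, dtype=np.float32))
    return grid
```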

In [None]:
def create_synthetic_events(num_events=1000, width=128, height=128):
    """Create synthetic event data for demonstration"""
    xs = np.random.randint(0, width, num_events, dtype=np.int64)
    ys = np.random.randint(0, height, num_events, dtype=np.int64)
    ts = np.sort(np.random.uniform(0, 1.0, num_events))
    ps = np.random.choice(np.array([-1, 1], dtype=np.int64), num_events)  # choice() has no dtype argument
    
    return xs, ys, ts, ps

def demonstrate_processing_workflow():
    """Show complete event processing workflow"""
    
    print("Event Processing with Unified Model Loading:")
    print("=" * 55)
    
    # Step 1: Create synthetic events
    print("📊 Generating synthetic events...")
    xs, ys, ts, ps = create_synthetic_events(2000)
    print(f"   Generated {len(xs)} events")
    
    # Step 2: Create voxel grid representation
    print("🔄 Converting to voxel grid...")
    try:
        voxel_data, voxel_shape = evlib.representations.events_to_voxel_grid(
            xs, ys, ts, ps, 5, (128, 128), "count"
        )
        voxel_grid = voxel_data.reshape(voxel_shape)
        print(f"   Voxel grid shape: {voxel_grid.shape}")
        
        # Visualise the voxel grid
        fig, axes = plt.subplots(1, 5, figsize=(15, 3))
        for i in range(5):
            axes[i].imshow(voxel_grid[i], cmap='viridis')
            axes[i].set_title(f'Time Bin {i}')
            axes[i].axis('off')
        plt.suptitle('Event Voxel Grid Representation')
        plt.tight_layout()
        plt.show()
        
    except Exception as e:
        print(f"   ⚠️  Voxel grid creation failed: {e}")
        
    # Step 3: Simulate model loading and inference
    print("🧠 Loading model with unified loader...")
    model_info = {
        "name": "E2VID Base",
        "format": "ONNX",
        "input_shape": (1, 5, 128, 128),
        "output_shape": (1, 1, 128, 128)
    }
    
    print(f"   Model: {model_info['name']}")
    print(f"   Format: {model_info['format']}")
    print(f"   Input shape: {model_info['input_shape']}")
    
    # Step 4: Simulate inference
    print("⚡ Running inference...")
    start_time = time.perf_counter()  # monotonic, high-resolution timer

    # Placeholder: sleep stands in for model execution and random noise for the
    # output frame; a real run would feed voxel_grid to the loaded model
    time.sleep(0.01)
    reconstructed_frame = np.random.rand(128, 128).astype(np.float32)

    inference_time = time.perf_counter() - start_time
    print(f"   Inference completed in {inference_time:.3f}s")
    
    # Visualise results
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    
    # Original events
    ax1.scatter(xs[ps > 0], ys[ps > 0], c='red', s=1, alpha=0.6, label='Positive')
    ax1.scatter(xs[ps < 0], ys[ps < 0], c='blue', s=1, alpha=0.6, label='Negative')
    ax1.set_xlim(0, 128)
    ax1.set_ylim(0, 128)
    ax1.invert_yaxis()
    ax1.set_title('Original Events')
    ax1.legend()
    
    # Reconstructed frame
    ax2.imshow(reconstructed_frame, cmap='gray')
    ax2.set_title('Reconstructed Frame')
    ax2.axis('off')
    
    plt.tight_layout()
    plt.show()
    
    print("✅ Processing workflow completed successfully!")

demonstrate_processing_workflow()

## 6. Error Handling and Validation

The unified loader is designed to handle the failure modes listed below; for each, the cell prints the error type raised and the recovery strategy.
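Falling back to an alternative format when a model file is missing or fails to parse can be sketched as a loop over candidate files in priority order that records each failure and moves on. `load_with_fallback` and the per-suffix `loaders` mapping are hypothetical; in practice each loader would wrap a format-specific parser.

```python
from pathlib import Path

def load_with_fallback(candidates, loaders):
    """Try candidate files in priority order; return (path, model) for the first success.

    `loaders` maps a suffix such as ".onnx" to a callable taking a Path
    (hypothetical placeholders, not evlib functions).
    """
    errors = []
    for path in map(Path, candidates):
        loader = loaders.get(path.suffix)
        if loader is None:
            errors.append(f"{path}: no loader for {path.suffix!r}")
            continue
        try:
            return path, loader(path)
        except (OSError, ValueError, RuntimeError) as exc:
            # OSError covers FileNotFoundError; record the failure and try the next format
            errors.append(f"{path}: {type(exc).__name__}: {exc}")
    raise RuntimeError("no model could be loaded:\n" + "\n".join(errors))
```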

In [None]:
def demonstrate_error_handling():
    """Show error handling capabilities"""
    
    print("Error Handling and Validation:")
    print("=" * 40)
    
    error_scenarios = [
        {
            "scenario": "Missing model file",
            "error": "FileNotFoundError",
            "handling": "Graceful fallback to alternative formats"
        },
        {
            "scenario": "Corrupted model file",
            "error": "ModelValidationError",
            "handling": "Checksum verification and error reporting"
        },
        {
            "scenario": "Incompatible model architecture",
            "error": "ArchitectureMismatchError",
            "handling": "Clear error messages with suggestions"
        },
        {
            "scenario": "Insufficient memory",
            "error": "OutOfMemoryError",
            "handling": "Automatic fallback to CPU or smaller model"
        }
    ]
    
    for i, scenario in enumerate(error_scenarios, 1):
        print(f"{i}. {scenario['scenario']}")
        print(f"   Error Type: {scenario['error']}")
        print(f"   Handling: {scenario['handling']}")
        print()
    
    print("✅ Robust error handling ensures reliable operation!")

demonstrate_error_handling()

## Summary

The unified model loading system in evlib provides:

✅ **Multi-format support**: PyTorch, ONNX, and SafeTensors  
✅ **Automatic detection**: format identified from file extension and content  
✅ **Priority-based loading**: ONNX first, then PyTorch, then SafeTensors  
✅ **Performance optimisation**: inference-optimised ONNX preferred, with a per-model device preference (CPU or CUDA)  
✅ **Robust error handling**: fallback to alternative formats or to CPU when loading fails  
✅ **Flexible configuration**: per-model architecture and device settings  

This unified approach simplifies model deployment while maximising performance across different hardware configurations and use cases.