# evlib Unified Models API

This notebook demonstrates the high-level unified API for event-to-video reconstruction models in evlib.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import evlib
from evlib import models

# Set up plotting
%matplotlib inline
plt.rcParams['figure.dpi'] = 100

## 1. Loading Event Data

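First, let's load some event data to work with. If the sample file is not available, the cell below falls back to generating synthetic events so the rest of the notebook can still run.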

In [None]:
# Load events from file, falling back to reproducible synthetic events.
# After this cell, xs / ys / ts / ps are equal-length arrays and n is the
# number of events actually held in them.
event_file = Path("../data/slider_depth/events.txt")
MAX_EVENTS = 10000  # cap on the number of events kept, for faster processing
SEED = 42  # fixes the synthetic polarities so Restart & Run All is reproducible

if event_file.exists():
    xs, ys, ts, ps = evlib.formats.load_events_py(str(event_file))
    # Use subset for faster processing; slicing clips to what the file holds
    xs, ys, ts, ps = xs[:MAX_EVENTS], ys[:MAX_EVENTS], ts[:MAX_EVENTS], ps[:MAX_EVENTS]
    # Report the count actually kept, which can be smaller than MAX_EVENTS
    n = len(xs)
    print(f"Loaded {n} events from file")
else:
    # Generate synthetic events: one pass around an ellipse with random polarity
    print("Generating synthetic events")
    n = MAX_EVENTS
    t = np.linspace(0, 1, n)
    xs = (120 + 100 * np.sin(2 * np.pi * t)).astype(np.int64)
    ys = (90 + 80 * np.cos(2 * np.pi * t)).astype(np.int64)
    # A dedicated seeded generator avoids depending on global RNG state
    rng = np.random.default_rng(SEED)
    ps = rng.choice([-1, 1], size=n).astype(np.int64)
    ts = t

# Fail early with a clear message; the min()/max() calls below would otherwise
# raise an opaque error on empty arrays
assert len(xs) == len(ys) == len(ts) == len(ps) == n, "event arrays differ in length"
assert n > 0, "no events available"

# Display event statistics
# NOTE(review): the negative count assumes polarities are encoded as -1/+1, as
# in the synthetic branch; confirm the file loader uses the same encoding.
print("Event statistics:")
print(f"  X range: [{xs.min()}, {xs.max()}]")
print(f"  Y range: [{ys.min()}, {ys.max()}]")
print(f"  Time range: [{ts.min():.3f}, {ts.max():.3f}]")
print(f"  Positive events: {(ps > 0).sum()}")
print(f"  Negative events: {(ps < 0).sum()}")

## 2. Basic Model Usage

The simplest way to use a model is to instantiate it and call `reconstruct()`.

In [None]:
# Instantiate E2VID with its default arguments
model = models.E2VID()
print(f"Created model: {model}")

# Run one reconstruction over the full event tuple
events = (xs, ys, ts, ps)
frame = model.reconstruct(events)

# Show the result in grayscale with an intensity scale alongside
fig, ax = plt.subplots(figsize=(8, 6))
image = ax.imshow(frame, cmap='gray')
fig.colorbar(image, ax=ax)
ax.set_title("E2VID Reconstruction")
plt.show()

## 3. Comparing Different Models

Let's compare the output of different reconstruction models.

In [None]:
# Models under comparison, as (panel title, model instance) pairs
models_to_compare = [
    ("E2VID (UNet)", models.E2VID(variant="unet")),
    ("FireNet", models.FireNet()),
    # ("SPADE", models.SPADE(variant="lite")),
    # ("SSL", models.SSL()),
]
n_models = len(models_to_compare)

# One panel per model; squeeze=False keeps the axes indexable even when only
# a single model is listed
fig, axes = plt.subplots(1, n_models, figsize=(6 * n_models, 5), squeeze=False)
axes = axes[0]

for ax, (name, model) in zip(axes, models_to_compare):
    print(f"Processing with {name}...")
    frame = model.reconstruct((xs, ys, ts, ps))

    # A 3-D result is treated as several frames; only the last one is shown
    frame = frame[-1] if frame.ndim == 3 else frame

    im = ax.imshow(frame, cmap='gray')
    ax.set_title(name)
    ax.axis('off')
    fig.colorbar(im, ax=ax, fraction=0.046)

plt.tight_layout()
plt.show()

## 4. Custom Model Configuration

You can customize model behavior using `ModelConfig`.

In [None]:
# Configurations to compare, as (label, ModelConfig) pairs
configs = [
    ("Default", models.ModelConfig()),
    ("High Resolution", models.ModelConfig(base_channels=128, num_layers=5)),
    ("Fast", models.ModelConfig(base_channels=32, num_layers=3)),
    ("More Time Bins", models.ModelConfig(num_bins=10)),
]

# Size the grid from the number of configurations. With a fixed 2x2 grid,
# zip() would silently drop any configuration beyond the fourth, and fewer
# than four would leave empty framed panels behind.
n_cols = 2
n_rows = max(1, (len(configs) + n_cols - 1) // n_cols)  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 5 * n_rows), squeeze=False)
axes = axes.flatten()

# Compare results: one E2VID model and one reconstruction per configuration
for ax, (name, config) in zip(axes, configs):
    print(f"Testing {name} configuration...")
    model = models.E2VID(config=config)
    frame = model.reconstruct((xs, ys, ts, ps))

    im = ax.imshow(frame, cmap='gray')
    ax.set_title(f"{name}\nChannels: {config.base_channels}, Bins: {config.num_bins}")
    ax.axis('off')
    plt.colorbar(im, ax=ax, fraction=0.046)

# Hide panels left over when len(configs) is not a multiple of n_cols
for unused_ax in axes[len(configs):]:
    unused_ax.axis('off')

plt.tight_layout()
plt.show()

## 5. Temporal Models (E2VID+ and FireNet+)

Temporal models can reconstruct multiple frames using ConvLSTM for temporal consistency.

In [None]:
# Create temporal models
e2vid_plus = models.E2VIDPlus()
firenet_plus = models.FireNetPlus()

# Reconstruct multiple frames from the same events with each model
num_frames = 5
print(f"Reconstructing {num_frames} frames...")

frames_e2vid = e2vid_plus.reconstruct((xs, ys, ts, ps), num_frames=num_frames)
frames_firenet = firenet_plus.reconstruct((xs, ys, ts, ps), num_frames=num_frames)

# Display temporal evolution: one row per model, one column per frame.
# squeeze=False keeps `axes` 2-D, so axes[row, col] also works when
# num_frames is set to 1 (the default squeezing would return a 1-D array).
fig, axes = plt.subplots(2, num_frames, figsize=(15, 6), squeeze=False)

# NOTE(review): assumes each result can be indexed per frame and holds at
# least num_frames entries; confirm against the reconstruct() contract.
rows = [("E2VID+", frames_e2vid), ("FireNet+", frames_firenet)]
for row, (label, frames) in enumerate(rows):
    for i in range(num_frames):
        axes[row, i].imshow(frames[i], cmap='gray')
        axes[row, i].set_title(f"{label} t={i+1}")
        axes[row, i].axis('off')

plt.tight_layout()
plt.show()

## 6. Model Zoo Utilities

The models module provides utilities for managing pre-trained models.

In [None]:
# Print the name of every model reported by the zoo
print("Available models:")
available_models = models.list_models()
for model_name in available_models:
    print(f"  - {model_name}")

# Divider between the listing and the per-model details
print(f"\n{'=' * 50}\n")

# Look up and print the metadata of a few selected models
detail_fields = (("Description", "description"), ("Architecture", "architecture"))
for model_name in ("e2vid_unet", "firenet", "spade_e2vid"):
    info = models.utils.get_model_info(model_name)
    print(f"{info['name']}:")
    for label, key in detail_fields:
        print(f"  {label}: {info[key]}")
    print(f"  Size: {info['size_mb']} MB")
    print()

## 7. Pre-defined Configurations

evlib provides several pre-defined configurations for common use cases.

In [None]:
# Names of the pre-defined configurations to tabulate
config_names = ['default', 'high_res', 'fast', 'temporal', 'spade', 'ssl']

# The header and every data row share one left-aligned column layout
rule = "-" * 60
row_template = "{:<15} {:<10} {:<10} {:<10} {:<20}"

print("Pre-defined configurations:")
print(rule)
print(row_template.format('Name', 'Channels', 'Layers', 'Bins', 'Extra'))
print(rule)

for name in config_names:
    config = models.config.get_config(name)
    if config.extra_params:
        # Show the extra parameter names, cut to the width of the last column
        extra = str(list(config.extra_params.keys()))[:20]
    else:
        extra = "-"
    print(row_template.format(name, config.base_channels, config.num_layers, config.num_bins, extra))

## 8. Advanced Usage: Event Stream Processing

For processing continuous event streams, you can split the stream into fixed-size chunks of events and reconstruct one frame per chunk, as each chunk arrives.

In [None]:
# Simulate processing an event stream in chunks
chunk_size = 2000
max_panels = 5  # cap on how many chunks are reconstructed and displayed
n_chunks = len(xs) // chunk_size  # complete chunks only; trailing events are ignored
n_shown = min(max_panels, n_chunks)

# Use FireNet for fast processing
model = models.FireNet()

if n_shown == 0:
    # Fewer than chunk_size events: plt.subplots cannot build a grid with zero
    # columns, so report the situation instead of raising
    print(f"Need at least {chunk_size} events to form a chunk; got {len(xs)}")
else:
    # squeeze=False keeps `axes` 2-D, so a single chunk needs no special case
    fig, axes = plt.subplots(1, n_shown, figsize=(15, 3), squeeze=False)

    for i, ax in enumerate(axes[0]):
        # Get chunk of events: the i-th consecutive window of chunk_size events
        window = slice(i * chunk_size, (i + 1) * chunk_size)
        chunk_xs, chunk_ys = xs[window], ys[window]
        chunk_ts, chunk_ps = ts[window], ps[window]

        # Reconstruct one frame from this chunk alone
        frame = model.reconstruct((chunk_xs, chunk_ys, chunk_ts, chunk_ps))

        # Display, labelled with the chunk's first and last timestamps
        ax.imshow(frame, cmap='gray')
        ax.set_title(f"Chunk {i+1}\nt=[{chunk_ts[0]:.2f}, {chunk_ts[-1]:.2f}]")
        ax.axis('off')

    plt.tight_layout()
    plt.show()

print(f"Processed {n_shown} chunks of {chunk_size} events each")

## Summary

The unified models API in evlib provides:

1. **Simple Interface**: Just instantiate a model and call `reconstruct()`
2. **Multiple Models**: E2VID, FireNet, SPADE, SSL, and their temporal variants
3. **Flexible Configuration**: Customize models with `ModelConfig`
4. **Pre-trained Weights**: Automatic downloading of pre-trained models
5. **Temporal Processing**: E2VID+ and FireNet+ for multi-frame reconstruction
6. **Model Zoo**: Utilities for managing and discovering models

For more information, see the [evlib documentation](https://github.com/tallamjr/evlib).