# Video Autoencoder Tutorial

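This notebook demonstrates how to use the video autoencoder. It covers:
- sampling and decoding latent codes;
- programmatically modifying model weights;
- interpolating between latent codes;
- applying custom latent transforms;
- exploring individual latent dimensions interactively.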

In [None]:
import sys
sys.path.append('..')

import torch
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import Video, display

from src.models import VideoAutoencoder, load_pretrained_model
from src.latent import LatentManipulator, interpolate_latents
from src.utils import load_video, save_video

## 1. Initialize Model

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Seed torch's RNGs (CPU and every CUDA device) before building the model.
# This makes the weight initialisation, the random latent codes and the
# injected noise later in the notebook reproducible under
# Restart Kernel -> Run All.
SEED = 42
torch.manual_seed(SEED)

# Set to True to start from pretrained weights instead of a freshly
# initialised model.
USE_PRETRAINED = False

if USE_PRETRAINED:
    model = load_pretrained_model('vae_ucf101', device=device)
else:
    # latent_dim=512 must match the size of the latent codes built in the cells below.
    model = VideoAutoencoder(latent_dim=512, base_channels=64)

# Applied in both branches: moving to the device and switching to eval mode
# are harmless if the loader already did so.
model.to(device)
model.eval()

## 2. Sampling from the Latent Space

In [None]:
# Draw four latent vectors from a standard normal distribution, move them to
# the model's device, and decode them. Pure generation needs no gradients,
# so autograd is disabled for the decode.
z_random = torch.randn(4, 512)
z_random = z_random.to(device)
with torch.no_grad():
    generated_videos = model.decode(z_random)
print(f"Generated shape: {generated_videos.shape}")

## 3. Programmatic Weight Control

In [None]:
manipulator = LatentManipulator(model)
# Snapshot the current parameters; they are put back by
# restore_original_params() at the end of this cell.
manipulator.save_original_params()

# Example 1: Scale weights
manipulator.scale_weights('decoder_fc', scale_factor=1.5)

# Example 2: Add controlled noise
manipulator.add_noise('encoder_fc', noise_std=0.1)

# Statistics of the *modified* weights
stats = manipulator.compute_weight_statistics()
for layer, layer_stats in stats.items():
    print(f"\n{layer}:")
    for param_type, param_stats in layer_stats.items():
        print(f"  {param_type}: mean={param_stats['mean']:.4f}, std={param_stats['std']:.4f}")

# Undo the demo edits. Otherwise sections 4-6 would silently decode with the
# scaled/noisy weights. A re-run of this cell would also snapshot the
# already-modified weights and compound the changes.
manipulator.restore_original_params()

## 4. Latent Space Interpolation

In [None]:
# Two independent random endpoints in latent space (drawn in this order).
z1 = torch.randn(1, 512).to(device)
z2 = torch.randn(1, 512).to(device)

# interpolate_latents is called on single latent vectors here: the batch
# axis is stripped from the endpoints and re-added before each decode.
interpolated = interpolate_latents(z1[0], z2[0], steps=10)

# Decode every interpolation step, one single-item batch at a time.
with torch.no_grad():
    interpolated_videos = [model.decode(z.unsqueeze(0)) for z in interpolated]

print(f"Created {len(interpolated_videos)} interpolated videos")

## 5. Custom Latent Manipulation Functions

In [None]:
def apply_custom_transform(z, transform_type='sine'):
    """Apply one of a few hand-written transformations to latent codes.

    Args:
        z: latent tensor indexed as (batch, latent_dim). The transforms read
            ``z.shape[1]`` and slice ``z[:, ...]``.
        transform_type: which transform to apply:
            - 'sine': multiply each dimension by sin(t), with t evenly spaced
              over [0, 2*pi] across the latent dimensions;
            - 'threshold': zero out every entry whose magnitude is <= 1.0;
            - 'amplify_dims': on a clone of ``z``, double the first 100
              dimensions and then halve the last 100.
            Any other value returns ``z`` itself, unchanged.

    Returns:
        The transformed tensor. For an unrecognised ``transform_type`` it is
        ``z`` itself.
    """
    def _sine(latent):
        # Sinusoidal modulation across the latent dimensions.
        phase = torch.linspace(0, 2*np.pi, latent.shape[1]).to(latent.device)
        return latent * torch.sin(phase)

    def _threshold(latent):
        # Keep only strong activations.
        keep = torch.abs(latent) > 1.0
        return torch.where(keep, latent, torch.zeros_like(latent))

    def _amplify_dims(latent):
        boosted = latent.clone()
        boosted[:, :100] *= 2.0   # first 100 dimensions up...
        boosted[:, -100:] *= 0.5  # ...then last 100 down (order matters if they overlap)
        return boosted

    # Checked in order with ==, falling through to the identity.
    for name, handler in (('sine', _sine),
                          ('threshold', _threshold),
                          ('amplify_dims', _amplify_dims)):
        if transform_type == name:
            return handler(z)
    return z

# Push one random latent code through each custom transform and decode it.
z_test = torch.randn(1, 512).to(device)

with torch.no_grad():
    for transform in ('sine', 'threshold', 'amplify_dims'):
        z_transformed = apply_custom_transform(z_test, transform)
        video = model.decode(z_transformed)
        print(f"Applied {transform} transform")

## 6. Interactive Latent Space Explorer

In [None]:
from ipywidgets import interact, FloatSlider, IntSlider

def explore_latent_dimension(dim=0, value=0.0):
    """Decode a one-hot latent code and display the first frame.

    Every latent coordinate is zero except ``dim``, which is set to ``value``.
    Meant to be driven by the ``interact`` sliders in this section.

    Args:
        dim: index of the latent dimension to set (0-511 for the 512-d
            latent used throughout this notebook).
        value: value assigned to that dimension.
    """
    z = torch.zeros(1, 512).to(device)
    z[0, dim] = value

    with torch.no_grad():
        video = model.decode(z)

    # NOTE(review): assumes decode() returns (batch, channels, time, H, W)
    # with values in [-1, 1]. TODO confirm against VideoAutoencoder.decode.
    frame = video[0, :, 0].cpu().numpy().transpose(1, 2, 0)
    # Map [-1, 1] -> [0, 1]. Clip explicitly: imshow would otherwise clip
    # out-of-range RGB floats itself and emit a warning.
    frame = np.clip((frame + 1) / 2, 0.0, 1.0)

    imshow_kwargs = {}
    if frame.shape[-1] == 1:
        # imshow rejects (H, W, 1); show single-channel frames as grayscale
        # on a fixed [0, 1] scale instead of an auto-normalised colormap.
        frame = frame[..., 0]
        imshow_kwargs = dict(cmap='gray', vmin=0.0, vmax=1.0)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(frame, **imshow_kwargs)
    ax.set_title(f'Dimension {dim} = {value:.2f}')
    ax.axis('off')
    plt.show()

# One slider picks the latent dimension (0-511) and the other picks the
# value written into it; interact re-renders the frame on every change.
dim_slider = IntSlider(min=0, max=511, step=1, value=0)
value_slider = FloatSlider(min=-3.0, max=3.0, step=0.1, value=0.0)
interact(explore_latent_dimension, dim=dim_slider, value=value_slider);

## 7. Sweeping Decoder Weight Scales

In [None]:
# Sweep the scale applied to the decoder's 'decoder_fc' weights and decode
# the same latent code under each setting.
modifications = [
    {'type': 'scale', 'factor': 0.5},
    {'type': 'scale', 'factor': 1.0},
    {'type': 'scale', 'factor': 1.5},
    {'type': 'scale', 'factor': 2.0},
]

z_base = torch.randn(1, 512).to(device)
results = []

try:
    for mod in modifications:
        # Start every run from the saved weights so successive scalings
        # don't stack on top of each other.
        manipulator.restore_original_params()

        if mod['type'] == 'scale':
            manipulator.scale_weights('decoder_fc', scale_factor=mod['factor'])
        else:
            # Fail loudly rather than decode unmodified weights and still
            # report the modification as applied.
            raise ValueError(f"Unsupported modification type: {mod['type']!r}")

        with torch.no_grad():
            video = model.decode(z_base)
            results.append(video)

        print(f"Applied {mod['type']} with factor {mod['factor']}")
finally:
    # Reset to original weights even if a step above fails, so later cells
    # never run against a half-modified model.
    manipulator.restore_original_params()