# Video Transition Merger using HOISDF Outputs

This notebook demonstrates how to generate a smooth transition that merges two video sequences, using:
- HOISDF outputs (MANO parameters, SDFs, contact points)
- Transformer for learning transition patterns
- Diffusion model for smooth, contact-aware transitions

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from transition_merger_model import *

# Seed the RNGs this notebook draws from so that the randomly generated
# example tensors are the same after every Restart Kernel -> Run All.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Set device: use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

## 1. Model Configuration

In [None]:
# Hyper-parameters, grouped by the sub-module that consumes them.
config = dict(
    tokenizer=dict(
        mano_dim=51,  # 3 trans + 45 pose + 3 shape
        sdf_resolution=64,
        hidden_dim=256,
        num_tokens=256,
    ),
    transformer=dict(
        input_dim=256,
        hidden_dim=512,
        num_heads=8,
        num_layers=6,
        mano_dim=51,
        chunk_size=50,
        dropout=0.1,
    ),
    diffuser=dict(
        mano_dim=51,
        hidden_dim=256,
        condition_dim=512,
        num_timesteps=100,
    ),
)

# Relative weight of each term handed to the loss.
loss_weights = dict(
    mano_recon=1.0,    # MANO parameter reconstruction
    contact=0.5,       # Contact consistency
    smooth=0.1,        # Movement smoothness
    boundary=0.5,      # Task boundary detection
    contrastive=0.2,   # Task embedding consistency
    diffusion=0.5,     # Diffusion model training
)

# Build the model on the selected device, then the loss and the optimizer.
model = TransitionMergerModel(config)
model = model.to(device)
criterion = TransitionLoss(loss_weights)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

print(f"Model initialized with {sum(p.numel() for p in model.parameters())} parameters")

## 2. Understanding HOISDF Outputs

The model takes HOISDF outputs from two videos and learns to merge them smoothly.

In [None]:
# Example: Create HOISDF outputs structure
def create_example_hoisdf_outputs(num_frames, device):
    """Build a random stand-in for HOISDF outputs spanning ``num_frames`` frames.

    Every field is sampled independently, so the result is only suitable for
    exercising shapes and the pipeline, not for producing meaningful motion.

    Args:
        num_frames: Size of the leading (frame) axis of every tensor.
        device: Device each tensor is moved to via ``.to(device)``.

    Returns:
        An ``HOISDFOutputs`` instance filled with random tensors.
    """
    def gaussian(*trailing_dims):
        # Standard-normal sample with the frame axis first, moved to ``device``.
        return torch.randn(num_frames, *trailing_dims).to(device)

    # Fields are sampled in the same order they are passed below, so the
    # random stream is consumed in a fixed sequence.
    mano_params = gaussian(51)          # 3 translation + 45 pose + 3 shape
    hand_sdf = gaussian(64, 64, 64)     # signed distance field of the hand
    object_sdf = gaussian(64, 64, 64)   # signed distance field of the object
    contact_points = gaussian(10, 3)
    # Integer draws from {0, 1}, stored as floats.
    contact_frames = torch.randint(0, 2, (num_frames, 10)).float().to(device)
    hand_vertices = gaussian(778, 3)
    object_center = gaussian(3)

    return HOISDFOutputs(
        mano_params=mano_params,
        hand_sdf=hand_sdf,
        object_sdf=object_sdf,
        contact_points=contact_points,
        contact_frames=contact_frames,
        hand_vertices=hand_vertices,
        object_center=object_center,
    )

# Two independently sampled 100-frame clips stand in for real HOISDF results.
video1_outputs, video2_outputs = [
    create_example_hoisdf_outputs(100, device) for _ in range(2)
]

# Report the shapes of a few representative fields of the first clip.
print("HOISDF Output Components:")
print(f"  MANO params shape: {video1_outputs.mano_params.shape}")
print(f"  Hand SDF shape: {video1_outputs.hand_sdf.shape}")
print(f"  Contact points shape: {video1_outputs.contact_points.shape}")

## 3. Generate Transition

In [None]:
# Generate transition between two videos
transition_length = 30  # 30 frames for transition

# Put the model in evaluation mode before inference. torch.no_grad() only
# disables gradient tracking; it does not switch off train-mode layers such as
# dropout (the transformer config sets dropout to 0.1), which would otherwise
# stay active and perturb the generated transition on every run.
model.eval()

# Forward pass
with torch.no_grad():
    outputs = model(video1_outputs, video2_outputs,
                    transition_length=transition_length,
                    mode='inference')

# Extract transition MANO parameters; axis 1 is reported as the frame axis below.
transition_mano = outputs['transformer']['refined_mano']
print(f"Generated transition shape: {transition_mano.shape}")
print(f"Transition covers {transition_mano.shape[1]} frames")

## 4. Visualize Transition

In [None]:
# Visualize MANO parameter trajectories on a 2x2 grid of panels
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
transformer_outputs = outputs['transformer']

# Top-left: the first three MANO entries of batch item 0, one line per axis
ax = axes[0, 0]
trans_params = transition_mano[0, :, :3].cpu().numpy()
for column, (axis_name, line_color) in enumerate(
        [('X', 'red'), ('Y', 'green'), ('Z', 'blue')]):
    ax.plot(trans_params[:, column], label=axis_name, color=line_color)
ax.set(title='Hand Translation During Transition',
       xlabel='Frame', ylabel='Translation (m)')
ax.legend()
ax.grid(True)

# Top-right: per-frame L2 norm over MANO entries 3..47 of batch item 0
ax = axes[0, 1]
rot_params = transition_mano[0, :, 3:48].cpu().numpy()
rot_magnitude = np.linalg.norm(rot_params, axis=1)
ax.plot(rot_magnitude, color='purple')
ax.set(title='Hand Rotation Magnitude',
       xlabel='Frame', ylabel='Rotation Magnitude')
ax.grid(True)

# Bottom-left: first channel of the boundary head for batch item 0
ax = axes[1, 0]
boundaries = transformer_outputs['boundaries'][0, :, 0].cpu().numpy()
ax.plot(boundaries, color='orange')
ax.set(title='Task Boundary Predictions',
       xlabel='Frame', ylabel='Boundary Probability')
ax.grid(True)

# Bottom-right: first channel of the transition-quality head for batch item 0
ax = axes[1, 1]
quality = transformer_outputs['transition_quality'][0, :, 0].cpu().numpy()
ax.plot(quality, color='green')
ax.set(title='Transition Quality Score',
       xlabel='Frame', ylabel='Quality Score')
ax.grid(True)

plt.tight_layout()
plt.show()

## 5. Training Loop

In [None]:
def train_step(model, video1_outputs, video2_outputs, gt_transition,
               criterion, optimizer, transition_length=30):
    """Run a single optimisation step and return the loss dictionary.

    Args:
        model: Module invoked as ``model(video1_outputs, video2_outputs,
            transition_length=..., mode='train')``. It is switched to
            training mode before the forward pass.
        video1_outputs: HOISDF outputs for the first video.
        video2_outputs: HOISDF outputs for the second video.
        gt_transition: Ground-truth transition handed to ``criterion``.
        criterion: Callable ``criterion(outputs, gt_transition, model)``
            returning a dict of losses that contains a ``'total'`` entry.
        optimizer: Optimizer used to zero gradients and apply the update.
        transition_length: Number of transition frames requested from the
            model. Defaults to 30, the value that used to be hard-coded, so
            existing callers behave as before.

    Returns:
        The dict of losses produced by ``criterion``.
    """
    model.train()

    # Forward pass
    outputs = model(video1_outputs, video2_outputs,
                    transition_length=transition_length, mode='train')

    # Compute losses
    losses = criterion(outputs, gt_transition, model)

    # Backward pass: clear gradients left over from any previous step first,
    # so only this step's gradients contribute to the update.
    optimizer.zero_grad()
    losses['total'].backward()
    # Rescale gradients so their global norm does not exceed 1.0.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()

    return losses

# Random 30-frame target standing in for a real ground-truth transition
gt_transition = create_example_hoisdf_outputs(30, device)

# One optimisation step on the example clips
losses = train_step(
    model, video1_outputs, video2_outputs,
    gt_transition, criterion, optimizer,
)

print("Loss breakdown:")
for k, v in losses.items():
    # The total is printed separately below, and entries whose key starts
    # with 'diffusion_' are left out of the breakdown.
    if k == 'total' or k.startswith('diffusion_'):
        continue
    print(f"  {k}: {v.item():.4f}")
print(f"\nTotal loss: {losses['total'].item():.4f}")

## 6. Loading Real HOISDF Outputs

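To use this with real data, you need to load HOISDF outputs from your trained model. The function below is only a placeholder: until you wire in your own HOISDF model, it returns randomly generated dummy outputs.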

In [None]:
def load_hoisdf_outputs(hoisdf_model, video_frames, device):
    """Placeholder loader that returns dummy HOISDF outputs sized to the clip.

    ``hoisdf_model`` is not invoked yet. Implement the real call based on
    your HOISDF model, roughly::

        hoisdf_model.eval()
        with torch.no_grad():
            outputs = hoisdf_model(video_frames)

    Args:
        hoisdf_model: Trained HOISDF model (currently unused).
        video_frames: Tensor of video frames. When it has more than three
            dimensions, its size along axis 1 is used as the frame count;
            otherwise 100 frames are used.
        device: Device passed on to ``create_example_hoisdf_outputs``.

    Returns:
        The dummy outputs built by ``create_example_hoisdf_outputs``.
    """
    # NOTE(review): reading the frame count from axis 1 presumes a
    # (batch, time, ...) layout — confirm against the real data loader.
    if video_frames.dim() > 3:
        frame_count = video_frames.shape[1]
    else:
        frame_count = 100

    # For now, return dummy outputs
    return create_example_hoisdf_outputs(frame_count, device)

# Example: Load outputs for your videos
# video1_frames = load_video('path/to/video1.mp4')
# video2_frames = load_video('path/to/video2.mp4')
# 
# hoisdf_outputs1 = load_hoisdf_outputs(hoisdf_model, video1_frames, device)
# hoisdf_outputs2 = load_hoisdf_outputs(hoisdf_model, video2_frames, device)

## 7. Export Merged Sequence

Combine the original sequences with the generated transition: the last 20 frames of video 1 and the first 20 frames of video 2 are replaced by the transition.

In [None]:
def create_merged_sequence(video1_outputs, video2_outputs, transition_mano,
                           context_frames=20):
    """Splice a generated transition between two MANO parameter sequences.

    The last ``context_frames`` frames of video 1 and the first
    ``context_frames`` frames of video 2 are dropped, and the transition is
    inserted in their place.

    Args:
        video1_outputs: Object whose ``mano_params`` tensor has frames on
            axis 0.
        video2_outputs: Same as ``video1_outputs``, for the second video.
        transition_mano: Batched transition tensor; only batch item 0 is used.
        context_frames: Number of frames trimmed from the end of video 1 and
            from the start of video 2. Defaults to 20, the value that used to
            be hard-coded.

    Returns:
        Tensor formed by concatenating, along axis 0, the kept head of
        video 1, the transition, and the kept tail of video 2.

    Raises:
        ValueError: If ``context_frames`` is negative.
    """
    if context_frames < 0:
        raise ValueError(
            f"context_frames must be non-negative, got {context_frames}"
        )

    video1_mano = video1_outputs.mano_params
    video2_mano = video2_outputs.mano_params

    # Use an explicit end index: ``[:-context_frames]`` would select nothing
    # when context_frames == 0 instead of keeping the whole sequence.
    keep_from_video1 = max(video1_mano.shape[0] - context_frames, 0)

    merged_mano = torch.cat([
        video1_mano[:keep_from_video1],
        transition_mano[0],  # Remove batch dimension
        video2_mano[context_frames:],
    ], dim=0)

    return merged_mano

# Create merged sequence
merged_sequence = create_merged_sequence(video1_outputs, video2_outputs, transition_mano)
print(f"Merged sequence shape: {merged_sequence.shape}")
print(f"Total frames: {merged_sequence.shape[0]}")

# Derive the transition window from the data instead of hard-coding frame
# indices, so the markers stay correct if the clip or transition length changes.
# 20 is the number of frames create_merged_sequence trims from the end of video 1.
trimmed_context_frames = 20
transition_start = max(
    video1_outputs.mano_params.shape[0] - trimmed_context_frames, 0
)
transition_end = transition_start + transition_mano.shape[1]

# Visualize merged sequence (first MANO entry only)
plt.figure(figsize=(12, 4))
plt.plot(merged_sequence[:, 0].cpu().numpy(), label='X translation')
plt.axvline(x=transition_start, color='red', linestyle='--', label='Transition start')
plt.axvline(x=transition_end, color='red', linestyle='--', label='Transition end')
plt.xlabel('Frame')
plt.ylabel('Translation')
plt.title('Merged Sequence with Smooth Transition')
plt.legend()
plt.grid(True)
plt.show()