# Training Transition Merger with Pairwise HOISDF Outputs

This notebook demonstrates how to train the transition merger model using a dataset of HOISDF outputs.
The model learns to create smooth transitions between different hand-object interactions.

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import pickle

from transition_merger_model import TransitionMergerModel, HOISDFOutputs
from transition_dataset import HOISDFTransitionDataset, TransitionTrainer

# Select the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

## 1. Prepare HOISDF Output Data

First, run your trained HOISDF model on each video sequence and save its outputs to disk, one pickle file per sequence.

In [None]:
def save_hoisdf_outputs(hoisdf_model, video_frames, save_path, sequence_name):
    """Run a HOISDF model over one video and pickle the outputs used for training.

    Args:
        hoisdf_model: Trained HOISDF model. It is switched to eval mode and called
            as ``hoisdf_model(video_frames)``; the result must be indexable by the
            tensor keys listed in ``tensor_keys`` below.
        video_frames: Input passed unchanged to ``hoisdf_model``.
        save_path: Output directory. It is created (including parents) if missing.
        sequence_name: File stem of the pickle and the value stored under
            ``'sequence_name'``.

    Returns:
        Path of the written ``<sequence_name>.pkl`` file.
    """
    # Outputs kept for the transition dataset, in the order they are stored.
    tensor_keys = (
        'mano_params',
        'hand_sdf',
        'object_sdf',
        'contact_points',
        'contact_frames',
        'hand_vertices',
        'object_center',
    )

    hoisdf_model.eval()
    with torch.no_grad():
        outputs = hoisdf_model(video_frames)

    # Move everything to CPU so the pickle does not depend on the device it was produced on.
    hoisdf_data = {key: outputs[key].cpu() for key in tensor_keys}
    hoisdf_data['sequence_name'] = sequence_name

    save_dir = Path(save_path)
    # Callers may pass a directory that does not exist yet; create it instead of
    # failing in open().
    save_dir.mkdir(parents=True, exist_ok=True)
    save_file = save_dir / f"{sequence_name}.pkl"
    with open(save_file, 'wb') as f:
        pickle.dump(hoisdf_data, f)

    print(f"Saved HOISDF outputs to {save_file}")
    return save_file

# Example: Save outputs for multiple sequences
# data_dir = Path('data/hoisdf_outputs')
# data_dir.mkdir(parents=True, exist_ok=True)
#
# for video_path in video_paths:
#     video_frames = load_video(video_path)
#     save_hoisdf_outputs(hoisdf_model, video_frames, data_dir, video_path.stem)

## 2. Create Example HOISDF Outputs

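For demonstration, let's create synthetic HOISDF outputs. Be aware that they are large: each 100-frame sequence stores two 100×64×64×64 float32 SDF grids (about 105 MB each), so the 30 sequences generated below take roughly 6 GB of disk space.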

In [None]:
def create_synthetic_hoisdf_sequence(sequence_type='pick_phone', num_frames=100, generator=None):
    """Create a synthetic HOISDF-style output dict for one of a few toy tasks.

    Only ``mano_params`` carries task-specific motion. Channels 0-2 follow a smooth
    per-task trajectory (labelled X/Y/Z translation below), and some further
    channels get a simple finger pattern. All other entries are random noise with
    the shapes this notebook stores on disk.

    Args:
        sequence_type: One of ``'pick_phone'``, ``'pick_bottle'`` or ``'typing'``.
        num_frames: Number of frames, i.e. the first dimension of every tensor.
        generator: Optional ``torch.Generator`` for the random tensors. A seeded
            generator makes the output reproducible; ``None`` uses PyTorch's global RNG.

    Returns:
        dict with:
          - ``mano_params`` of shape (num_frames, 51);
          - ``hand_sdf`` and ``object_sdf``, each (num_frames, 64, 64, 64);
          - ``contact_points`` (num_frames, 10, 3);
          - ``contact_frames`` (num_frames, 10), holding 0./1.;
          - ``hand_vertices`` (num_frames, 778, 3);
          - ``object_center`` (num_frames, 3);
          - the string ``sequence_type``.

        Each SDF grid alone holds num_frames * 64**3 float32 values (about 105 MB
        at 100 frames).

    Raises:
        ValueError: If ``sequence_type`` is not one of the supported tasks.
    """
    # Per-frame MANO parameter vectors; channels not set below stay zero.
    mano_params = torch.zeros(num_frames, 51)

    if sequence_type == 'pick_phone':
        # Simulate picking up phone motion
        t = torch.linspace(0, 1, num_frames)
        mano_params[:, 0] = 0.2 * torch.sin(t * np.pi)  # X translation
        mano_params[:, 1] = -0.3 + 0.2 * t  # Y translation (moving up)
        mano_params[:, 2] = 0.1 * torch.cos(t * np.pi)  # Z translation
        # One finger curve broadcast across channels 3-7
        mano_params[:, 3:8] = 0.5 * torch.sin(t.unsqueeze(1) * np.pi)

    elif sequence_type == 'pick_bottle':
        # Simulate picking up bottle motion
        t = torch.linspace(0, 1, num_frames)
        mano_params[:, 0] = -0.1 + 0.2 * torch.cos(t * np.pi)
        mano_params[:, 1] = -0.4 + 0.3 * t
        mano_params[:, 2] = 0.15 * torch.sin(t * np.pi)
        # Different finger pattern for bottle grasp (channels 3-7)
        mano_params[:, 3:8] = 0.7 * torch.cos(t.unsqueeze(1) * np.pi * 0.5)

    elif sequence_type == 'typing':
        # Simulate typing motion; t spans 0 to 4*pi
        t = torch.linspace(0, 4 * np.pi, num_frames)
        mano_params[:, 0] = 0.05 * torch.sin(t * 3)
        mano_params[:, 1] = -0.2 + 0.02 * torch.sin(t * 5)
        mano_params[:, 2] = 0.1
        # Finger tapping: five 3-channel groups (channels 3-17), each at its own frequency
        for i in range(5):
            mano_params[:, 3 + i * 3:6 + i * 3] = 0.3 * torch.sin(t.unsqueeze(1) * (i + 2))

    else:
        # Fail loudly instead of silently returning a motionless sequence.
        raise ValueError(
            f"Unknown sequence_type {sequence_type!r}; expected one of "
            "'pick_phone', 'pick_bottle', 'typing'"
        )

    # Create full HOISDF outputs
    return {
        'mano_params': mano_params,
        'hand_sdf': torch.randn(num_frames, 64, 64, 64, generator=generator) * 0.1,
        'object_sdf': torch.randn(num_frames, 64, 64, 64, generator=generator) * 0.1,
        'contact_points': torch.randn(num_frames, 10, 3, generator=generator) * 0.05,
        'contact_frames': torch.randint(0, 2, (num_frames, 10), generator=generator).float(),
        'hand_vertices': torch.randn(num_frames, 778, 3, generator=generator) * 0.1,
        'object_center': torch.randn(num_frames, 3, generator=generator) * 0.2,
        'sequence_type': sequence_type,
    }

# Write the synthetic dataset to disk: one pickle per (task, index) pair,
# named like "pick_phone_007.pkl".
data_dir = Path('data/synthetic_hoisdf')
data_dir.mkdir(parents=True, exist_ok=True)

# Tasks to synthesize and how many independent samples of each to write.
sequence_types = ['pick_phone', 'pick_bottle', 'typing']
num_sequences_per_type = 10

total_sequences = len(sequence_types) * num_sequences_per_type
for seq_type, i in ((task, k) for task in sequence_types for k in range(num_sequences_per_type)):
    filename = data_dir / f"{seq_type}_{i:03d}.pkl"
    seq_data = create_synthetic_hoisdf_sequence(seq_type)
    with filename.open('wb') as f:
        pickle.dump(seq_data, f)

print(f"Created {total_sequences} synthetic sequences")

## 3. Initialize Dataset and Model

In [None]:
# Settings shared by the train and validation splits.
dataset_config = dict(
    sequence_length=100,
    transition_length=30,
    context_frames=20,
    similarity_threshold=0.3,  # Lower threshold to get more pairs
)

# Build both splits from the same directory; only `mode` differs.
# Unpacking the generator constructs 'train' first, then 'val'.
train_dataset, val_dataset = (
    HOISDFTransitionDataset(data_dir=str(data_dir), mode=split, **dataset_config)
    for split in ('train', 'val')
)

print(f"Training pairs: {len(train_dataset)}")
print(f"Validation pairs: {len(val_dataset)}")

In [None]:
# Model and training configuration
#
# Sizes that appear in more than one sub-config are defined once so editing one
# cannot silently desynchronise the others.
# NOTE(review): the pairings below follow values that were equal in the original
# config and are assumed to be shape couplings inside TransitionMergerModel — confirm.
MANO_DIM = 51          # 'mano_dim' used by tokenizer, transformer and diffuser
TOKEN_DIM = 256        # tokenizer 'hidden_dim' == transformer 'input_dim'
TRANSFORMER_DIM = 512  # transformer 'hidden_dim' == diffuser 'condition_dim'

config = {
    # Model config
    'tokenizer': {
        'mano_dim': MANO_DIM,
        'sdf_resolution': 64,
        'hidden_dim': TOKEN_DIM,
        'num_tokens': 256,
    },
    'transformer': {
        'input_dim': TOKEN_DIM,
        'hidden_dim': TRANSFORMER_DIM,
        'num_heads': 8,
        'num_layers': 6,
        'mano_dim': MANO_DIM,
        'chunk_size': 50,
        'dropout': 0.1,
    },
    'diffuser': {
        'mano_dim': MANO_DIM,
        'hidden_dim': 256,
        'condition_dim': TRANSFORMER_DIM,
        'num_timesteps': 100,
    },

    # Training config
    'batch_size': 4,
    'num_workers': 0,  # Set to 0 for debugging
    'learning_rate': 1e-4,
    'weight_decay': 1e-5,
    'num_epochs': 50,
    # Taken from the dataset settings so training uses the same transition
    # length as the dataset's ground-truth transitions.
    'transition_length': dataset_config['transition_length'],
    'log_interval': 5,
    'device': device,

    # Loss weights
    'loss_weights': {
        'mano_recon': 1.0,
        'contact': 0.5,
        'smooth': 0.1,
        'boundary': 0.5,
        'contrastive': 0.2,
        'diffusion': 0.5,
    },
}

# Initialize model
model = TransitionMergerModel(config).to(device)
print(f"Model parameters: {sum(p.numel() for p in model.parameters())}")

## 4. Visualize Dataset Pairs

In [None]:
# Inspect one training pair: both source sequences, the ground-truth transition
# between them, and the pair's metadata.
sample = train_dataset[0]
outputs1, outputs2 = sample['outputs1'], sample['outputs2']
gt_transition = sample['gt_transition']

fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# Top row: MANO channels 0 and 1 for both sequences.
# NOTE(review): the dashed marker at frame 80 is hard-coded; it equals
# sequence_length - context_frames in dataset_config — confirm that is the intent.
translation_panels = [
    (0, 'Seq1 X', 'Seq2 X', 'X Translation'),
    (1, 'Seq1 Y', 'Seq2 Y', 'Y Translation'),
]
for col, (channel, label1, label2, title) in enumerate(translation_panels):
    ax = axes[0, col]
    ax.plot(outputs1.mano_params[:, channel].numpy(), label=label1, alpha=0.7)
    ax.plot(outputs2.mano_params[:, channel].numpy(), label=label2, alpha=0.7)
    if col == 0:
        # Only the first panel carries the legend entry for the marker.
        ax.axvline(x=80, color='red', linestyle='--', label='Transition region')
    else:
        ax.axvline(x=80, color='red', linestyle='--')
    ax.set_title(title)
    ax.legend()
    ax.grid(True)

# Bottom-left: ground-truth transition, one line per MANO channel 0-2.
ax = axes[1, 0]
for channel, label, color in [(0, 'GT X', 'green'), (1, 'GT Y', 'blue'), (2, 'GT Z', 'red')]:
    ax.plot(gt_transition.mano_params[:, channel].numpy(), label=label, color=color)
ax.set_title('Ground Truth Transition')
ax.legend()
ax.grid(True)

# Bottom-right: text-only panel with the pair's compatibility score and names.
ax = axes[1, 1]
ax.text(0.5, 0.5, f"Compatibility Score: {sample['compatibility']:.3f}",
        ha='center', va='center', fontsize=16)
ax.text(0.5, 0.3, f"Pair: {sample['names'][0]} → {sample['names'][1]}",
        ha='center', va='center', fontsize=12)
ax.axis('off')

plt.tight_layout()
plt.show()

## 5. Train the Model

In [None]:
# Short demo run: this local epoch count deliberately overrides config['num_epochs'] (50).
# NOTE(review): whether TransitionTrainer also reads config['num_epochs'] internally
# (e.g. for a learning-rate schedule) is not visible here — confirm.
num_epochs = 5

trainer = TransitionTrainer(model, train_dataset, val_dataset, config)
trainer.train(num_epochs)

## 6. Evaluate Trained Model

In [None]:
# Load best checkpoint
# NOTE(review): the file name and the 'model_state_dict' / 'epoch' / 'val_loss'
# keys are assumed to match what TransitionTrainer writes — confirm.
checkpoint = torch.load('checkpoint_best.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Loaded checkpoint from epoch {checkpoint['epoch']} with val_loss {checkpoint['val_loss']:.4f}")

# Request as many frames as the dataset's ground-truth transitions so prediction
# and ground truth can be compared frame by frame.
eval_transition_length = dataset_config['transition_length']

# Test on a validation pair
model.eval()
val_sample = val_dataset[0]

with torch.no_grad():
    # Add a leading batch dimension and move every field to the model's device.
    outputs1 = HOISDFOutputs(
        **{k: v.unsqueeze(0).to(device) for k, v in val_sample['outputs1'].__dict__.items()}
    )
    outputs2 = HOISDFOutputs(
        **{k: v.unsqueeze(0).to(device) for k, v in val_sample['outputs2'].__dict__.items()}
    )

    # Generate transition
    model_outputs = model(outputs1, outputs2, transition_length=eval_transition_length, mode='inference')

# Drop the batch dimension and convert to NumPy for plotting / error computation.
pred_transition = model_outputs['transformer']['refined_mano'][0].cpu().numpy()
gt_transition = val_sample['gt_transition'].mano_params.numpy()

# A length mismatch would otherwise surface as an opaque broadcasting error
# (or silently broadcast) in the error computation below.
if pred_transition.shape != gt_transition.shape:
    raise ValueError(
        f"Predicted transition shape {pred_transition.shape} does not match "
        f"ground-truth shape {gt_transition.shape}; check transition_length."
    )

# Visualize results
plt.figure(figsize=(12, 6))

# Panels 1-2: ground truth vs prediction for MANO channels 0 (X) and 1 (Y).
for panel, (channel, title) in enumerate([(0, 'X Translation'), (1, 'Y Translation')], start=1):
    plt.subplot(1, 3, panel)
    plt.plot(gt_transition[:, channel], label='GT', color='green', linewidth=2)
    plt.plot(pred_transition[:, channel], label='Predicted', color='red', linestyle='--', linewidth=2)
    plt.title(title)
    plt.xlabel('Frame')
    plt.ylabel('Translation')
    plt.legend()
    plt.grid(True)

# Panel 3: per-frame Euclidean error over all MANO parameters.
plt.subplot(1, 3, 3)
error = np.linalg.norm(pred_transition - gt_transition, axis=1)
plt.plot(error, color='blue', linewidth=2)
plt.title('Prediction Error')
plt.xlabel('Frame')
plt.ylabel('L2 Error')
plt.grid(True)

plt.tight_layout()
plt.show()

print(f"Mean prediction error: {error.mean():.4f}")

## 7. Generate Novel Transitions

Use the trained model to generate transitions between any two sequences.

In [None]:
def generate_novel_transition(model, seq1_path, seq2_path, device, transition_length=30):
    """Generate a transition from one saved sequence into another.

    Args:
        model: Trained transition merger, called as
            ``model(outputs1, outputs2, transition_length=..., mode='inference')``.
        seq1_path: Pickle of the sequence the transition starts from.
        seq2_path: Pickle of the sequence the transition leads into. Both files are
            in the format written by ``save_hoisdf_outputs`` or the synthetic
            generator. Only load files you trust: ``pickle.load`` can execute
            arbitrary code.
        device: Device the batched inputs are moved to.
        transition_length: Number of transition frames to request (default 30,
            matching this notebook's training configuration).

    Returns:
        Tuple ``(transition_mano, seq1_data, seq2_data)``:
          - ``transition_mano`` is the model's ``['transformer']['refined_mano']``
            output for the single batch item, as a NumPy array;
          - ``seq1_data`` and ``seq2_data`` are the raw loaded dicts.
    """
    # String metadata stored next to the tensors: 'sequence_type' in the synthetic
    # files, 'sequence_name' in files from save_hoisdf_outputs. torch cannot
    # convert these, so they are excluded from the model inputs.
    metadata_keys = {'sequence_type', 'sequence_name'}

    def to_batched_outputs(seq_data):
        # torch.as_tensor reuses tensors that are already tensors; torch.tensor()
        # would copy them and emit a copy-construct UserWarning.
        return HOISDFOutputs(
            **{k: torch.as_tensor(v).unsqueeze(0).to(device)
               for k, v in seq_data.items() if k not in metadata_keys}
        )

    # Load sequences
    with open(seq1_path, 'rb') as f:
        seq1_data = pickle.load(f)
    with open(seq2_path, 'rb') as f:
        seq2_data = pickle.load(f)

    outputs1 = to_batched_outputs(seq1_data)
    outputs2 = to_batched_outputs(seq2_data)

    # Generate transition
    model.eval()
    with torch.no_grad():
        model_outputs = model(outputs1, outputs2, transition_length=transition_length, mode='inference')

    transition_mano = model_outputs['transformer']['refined_mano'][0].cpu().numpy()

    return transition_mano, seq1_data, seq2_data

# Test with different sequence types
seq1_path = data_dir / 'pick_phone_000.pkl'
seq2_path = data_dir / 'typing_000.pkl'

transition, seq1, seq2 = generate_novel_transition(model, seq1_path, seq2_path, device)

# Visualize
plt.figure(figsize=(10, 6))
plt.plot(seq1['mano_params'][-20:, 0], 'b-', alpha=0.5, label='Phone (end)')
plt.plot(np.arange(20, 50), transition[:, 0], 'r-', linewidth=2, label='Transition')
plt.plot(np.arange(50, 70), seq2['mano_params'][:20, 0], 'g-', alpha=0.5, label='Typing (start)')
plt.xlabel('Frame')
plt.ylabel('X Translation')
plt.title('Novel Transition: Pick Phone → Typing')
plt.legend()
plt.grid(True)
plt.show()