# Hybrid Model Sampling - FlowBending Framework

This notebook implements hybrid video generation using both Wan 14B and 1.3B models with configurable sampling schedules.

## Approach:
- Load both 14B and 1.3B models, which use the same VAE and T5 encoder weights
- Implement a flexible sampling schedule (e.g., the LSSSL pattern, where L = 14B and S = 1.3B)
- Work in latent space to maintain consistency during model switching
- Profile memory usage, latency, and total time per segment
- Compare with 14B-only baseline

## Resolution: 480p (832×480)
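At this resolution, the latent grid that the sampler works on can be computed by hand. The sketch below mirrors the `target_shape` and `seq_len` arithmetic in `hybrid_generate` further down. The constants (`z_dim=16`, VAE stride `(4, 8, 8)`, patch size `(1, 2, 2)`) are assumed Wan2.1 defaults rather than values read from the loaded models, and `latent_geometry` is a hypothetical helper.

```python
import math
from typing import Tuple

def latent_geometry(
    width: int,
    height: int,
    frame_num: int,
    z_dim: int = 16,                               # assumed Wan2.1 VAE latent channels
    vae_stride: Tuple[int, int, int] = (4, 8, 8),  # assumed (time, height, width) downsampling
    patch_size: Tuple[int, int, int] = (1, 2, 2),  # assumed DiT patch size
) -> Tuple[Tuple[int, int, int, int], int]:
    """Latent shape and transformer sequence length for a video (hypothetical helper)."""
    # frame_num - 1 must be a multiple of the temporal stride (4n + 1 for stride 4)
    if (frame_num - 1) % vae_stride[0] != 0:
        raise ValueError(f"frame_num must be {vae_stride[0]}n+1, got {frame_num}")
    shape = (
        z_dim,
        (frame_num - 1) // vae_stride[0] + 1,  # latent frames
        height // vae_stride[1],               # latent height
        width // vae_stride[2],                # latent width
    )
    # Tokens per latent frame after patchifying, times the number of latent frames
    seq_len = math.ceil(shape[2] * shape[3] / (patch_size[1] * patch_size[2]) * shape[1])
    return shape, seq_len

shape, seq_len = latent_geometry(832, 480, 81)
print(shape, seq_len)  # (16, 21, 60, 104) 32760
```

Under these assumptions, both models consume the same latent grid. This is a precondition for handing a partially denoised latent from one model to the other.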


In [None]:
# Imports
import os
import sys
import time
import random
import math
from datetime import datetime
from typing import Dict, List, Tuple
import json

import torch
import torch.cuda.amp as amp
import numpy as np
from tqdm import tqdm
from easydict import EasyDict

# Add parent directory to path
sys.path.insert(0, '/workspace/wan2.1/Wan2.1')

from wan.text2video import WanT2V
from wan.configs.wan_t2v_14B import t2v_14B
from wan.configs.wan_t2v_1_3B import t2v_1_3B
from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from wan.utils.fm_solvers import (
    FlowDPMSolverMultistepScheduler,
    get_sampling_sigmas,
    retrieve_timesteps,
)
from wan.utils.utils import cache_video
from wan.utils.acceleration import (
    AccelerationConfig,
    CacheDiTConfig,
    XDiTConfig,
    check_cache_dit_available,
    check_xdit_available,
    get_acceleration_summary,
    estimate_speedup,
)

print("Imports successful")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU count: {torch.cuda.device_count()}")
print(f"Cache-DiT available: {check_cache_dit_available()}")
print(f"xDiT available: {check_xdit_available()}")

## Configuration


In [None]:
# Basic Configuration
CONFIG = {
    # Paths
    'checkpoint_dir_14B': '/workspace/wan2.1/Wan2.1/Wan2.1-T2V-14B',
    'checkpoint_dir_1_3B': '/workspace/wan2.1/Wan2.1/Wan2.1-T2V-1.3B',
    'output_dir': '/workspace/wan2.1/Wan2.1/outputs',
    
    # Resolution (480p)
    'width': 832,
    'height': 480,
    'frame_num': 81,  # Must be 4n+1
    
    # Sampling parameters
    'total_sampling_steps': 50,
    'sample_solver': 'unipc',  # 'unipc' or 'dpm++'
    'shift': 5.0,
    'guide_scale': 5.0,
    
    # Prompt
    'prompt': 'Two cartoon cats animation in comfy boxing gear sparring playfully in a cozy living room',
    'negative_prompt': '',  # Will use default if empty
    'seed': 42,
    
    # Device
    'device_id': 0,
    
    # Model management
    'offload_models': False,  # Set to True to move each model to CPU after use (saves GPU memory)
    'keep_both_loaded': True,  # Not read by hybrid_generate or the baseline cells; both models are always loaded
    
    # ==========================================
    # ACCELERATION: Cache-DiT Configuration
    # ==========================================
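    # hybrid_generate implements a simplified cache: on cached steps it reuses the
    # previous step's guided noise prediction instead of running the model.
    # Expected speedup: 1.3-2x depending on settings (not measured in this notebook)
    
    'enable_cache_dit': False,  # Set to True to enable
    'cache_dit_type': 'db',  # Options: 'db' (DBCache), 'taylor' (TaylorSeer), 'scm' (SCM); only reported by hybrid_generate
    'cache_dit_interval': 3,  # Run the model every N steps, reuse the cached prediction in between (higher = faster, may hurt quality)
    'cache_dit_start_step': 5,  # Always run the model for the first N steps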
    'cache_dit_end_step': None,  # Stop caching N steps before end (None = cache until end)
    'cache_dit_fresh_ratio': 0.4,  # For taylor/adaptive caching
    
    # ==========================================
    # ACCELERATION: xDiT Configuration  
    # ==========================================
    # xDiT enables distributed inference across multiple GPUs
    # Expected speedup: Up to Nx with N GPUs
    
    'enable_xdit': False,  # Set to True to enable (requires multi-GPU)
    'xdit_ulysses_degree': 1,  # Ulysses sequence parallelism degree
    'xdit_ring_degree': 1,  # Ring attention parallelism degree
    'xdit_use_cfg_parallel': False,  # Parallel CFG computation (requires 2x GPUs)
}

# Create output directory
os.makedirs(CONFIG['output_dir'], exist_ok=True)

# Create acceleration config from CONFIG dict
accel_config = AccelerationConfig.from_dict(CONFIG)

print("Configuration:")
print(f"  Resolution: {CONFIG['width']}x{CONFIG['height']}, {CONFIG['frame_num']} frames")
print(f"  Sampling steps: {CONFIG['total_sampling_steps']}")
print(f"  Solver: {CONFIG['sample_solver']}")
print()
print(get_acceleration_summary(accel_config))

# Show estimated speedup
if accel_config.cache_dit.enabled or accel_config.xdit.enabled:
    speedups = estimate_speedup(accel_config)
    print(f"\nEstimated speedup: ~{speedups['combined']}x")
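As written in this notebook, `hybrid_generate` does not call into the cache-dit library. Its `should_use_cache` rule simply reuses the previous step's guided noise prediction on "hit" steps. The sketch below re-implements that rule in plain Python to show which steps skip the transformer for a given setting. The function name `reuse_steps` is hypothetical.

```python
from typing import List, Optional

def reuse_steps(total_steps: int, start_step: int, interval: int,
                end_step: Optional[int] = None) -> List[int]:
    """Steps on which the notebook's should_use_cache rule reuses the cached prediction."""
    reused = []
    for step in range(total_steps):
        if step < start_step:
            continue  # warm-up: always run the model
        if end_step is not None and total_steps - step <= end_step:
            continue  # final steps: always run the model
        if (step - start_step) % interval != 0:
            reused.append(step)  # reuse the previous guided noise prediction
    return reused

hits = reuse_steps(total_steps=50, start_step=5, interval=3)
print(len(hits), 50 - len(hits))  # 30 20
```

With the settings above (start 5, interval 3, no end cutoff), 30 of the 50 steps would reuse a cached prediction, so the transformer pair (conditional and unconditional) would run on only 20 steps. Whether quality holds at that ratio is not measured here.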

## Sampling Schedule Configuration

Configure the hybrid sampling schedule here. Each segment is a `(model_name, num_steps)` pair. Segments run in order, and their step counts must sum to `total_sampling_steps`.

**Patterns:**
- **LSL**: Large → Small → Large (e.g., 3-44-3 for 50 steps)
- **LSSSL**: Large → Small → Small → Small → Large (e.g., 10-10-10-10-10)
- **Custom**: Define your own segment boundaries
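A pattern name plus a list of segment lengths fully determines a schedule. The sketch below shows a hypothetical helper that turns, for example, `'LSL'` with `[3, 44, 3]` into the `(model_name, num_steps)` list used in the next cell. It assumes that L maps to the `'14B'` key and S to the `'1.3B'` key of the `models` dict.

```python
from typing import Dict, List, Sequence, Tuple

# Assumed mapping from pattern letters to the keys of the `models` dict
MODEL_FOR_LETTER: Dict[str, str] = {'L': '14B', 'S': '1.3B'}

def build_schedule(pattern: str, steps: Sequence[int], total: int = 50) -> List[Tuple[str, int]]:
    """Expand a pattern such as 'LSL' plus per-segment step counts into a schedule."""
    if len(pattern) != len(steps):
        raise ValueError("need exactly one step count per pattern letter")
    if sum(steps) != total:
        raise ValueError(f"segments sum to {sum(steps)}, expected {total}")
    schedule = []
    for letter, n in zip(pattern.upper(), steps):
        if letter not in MODEL_FOR_LETTER:
            raise ValueError(f"unknown pattern letter {letter!r}")
        schedule.append((MODEL_FOR_LETTER[letter], n))
    return schedule

print(build_schedule('LSL', [3, 44, 3]))
# [('14B', 3), ('1.3B', 44), ('14B', 3)]
```

With caching disabled, adjacent segments that use the same model (the three S segments in LSSSL) sample exactly like one merged segment. Splitting them only changes how the profiler groups steps and, with `offload_models=True`, adds an offload between them.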


In [3]:
# Sampling Schedule Configuration
# Define segments: list of (model_name, num_steps)
# model_name: '14B' or '1.3B'

# Example patterns (uncomment one or create your own):

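# Active schedule: LS (20-30), 14B for the first 20 steps, then 1.3B for the remaining 30
SAMPLING_SCHEDULE = [
    ('14B', 20),   # Steps 0-19 with 14B
    ('1.3B', 30),  # Steps 20-49 with 1.3B
]

# Pattern 1: LSL (3-44-3), the originally requested pattern
# SAMPLING_SCHEDULE = [
#     ('14B', 3),    # First 3 steps with 14B
#     ('1.3B', 44),
#     ('14B', 3),    # Last 3 steps with 14B
# ]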

# Pattern 2: LSSSL (10-10-10-10-10)
# SAMPLING_SCHEDULE = [
#     ('14B', 10),
#     ('1.3B', 10),
#     ('1.3B', 10),
#     ('1.3B', 10),
#     ('14B', 10),
# ]

# Pattern 3: Heavy start and end (5-40-5)
# SAMPLING_SCHEDULE = [
#     ('14B', 5),
#     ('1.3B', 40),
#     ('14B', 5),
# ]

# Pattern 4: Multiple switches (5-10-10-10-10-5)
# SAMPLING_SCHEDULE = [
#     ('14B', 5),
#     ('1.3B', 10),
#     ('14B', 10),
#     ('1.3B', 10),
#     ('14B', 10),
#     ('1.3B', 5),
# ]

# Validate schedule
total_steps_scheduled = sum(steps for _, steps in SAMPLING_SCHEDULE)
assert total_steps_scheduled == CONFIG['total_sampling_steps'], \
    f"Schedule steps ({total_steps_scheduled}) must match total_sampling_steps ({CONFIG['total_sampling_steps']})"

print("Sampling Schedule:")
print("-" * 50)
cumulative = 0
for i, (model, steps) in enumerate(SAMPLING_SCHEDULE):
    start_step = cumulative
    end_step = cumulative + steps
    cumulative = end_step
    print(f"Segment {i+1}: Steps {start_step:2d}-{end_step:2d} ({steps:2d} steps) → {model}")
print("-" * 50)
print(f"Total: {total_steps_scheduled} steps")


Sampling Schedule:
--------------------------------------------------
Segment 1: Steps  0-20 (20 steps) → 14B
Segment 2: Steps 20-50 (30 steps) → 1.3B
--------------------------------------------------
Total: 50 steps
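Because `shift=5.0` warps the noise schedule, equal step counts do not cover equal noise ranges. The sketch below assumes the shifted-sigma formula σ' = shift·σ / (1 + (shift − 1)·σ), applied to σ spaced linearly from 1 toward 0. This is the form Wan's `get_sampling_sigmas` is assumed to use here, and the UniPC scheduler's exact endpoints may differ slightly. `shifted_sigmas` and `segment_sigma_ranges` are hypothetical helpers.

```python
from typing import List, Tuple

def shifted_sigmas(steps: int, shift: float) -> List[float]:
    """Noise level at the start of each step: linear in (0, 1], then warped by `shift` (assumed formula)."""
    base = [1.0 - i / steps for i in range(steps)]
    return [shift * s / (1.0 + (shift - 1.0) * s) for s in base]

def segment_sigma_ranges(schedule: List[Tuple[str, int]], steps: int = 50, shift: float = 5.0):
    """(model, sigma at first step, sigma at last step) for each segment."""
    sigmas = shifted_sigmas(steps, shift)
    ranges, start = [], 0
    for model, n in schedule:
        ranges.append((model, sigmas[start], sigmas[start + n - 1]))
        start += n
    return ranges

for model, sigma_start, sigma_end in segment_sigma_ranges([('14B', 20), ('1.3B', 30)]):
    print(f"{model}: sigma {sigma_start:.3f} -> {sigma_end:.3f}")
# 14B: sigma 1.000 -> 0.891
# 1.3B: sigma 0.882 -> 0.093
```

Under this assumption, the 14B model's 20 steps (40% of the step budget) only take the latent from σ = 1.0 down to about 0.88. The 1.3B model handles the rest of the noise range.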


## Utility Functions


In [4]:
# Utility Functions for Profiling

def get_gpu_memory():
    """Get current and peak GPU memory usage in GB"""
    if torch.cuda.is_available():
        return {
            'allocated': torch.cuda.memory_allocated() / 1e9,
            'reserved': torch.cuda.memory_reserved() / 1e9,
            'max_allocated': torch.cuda.max_memory_allocated() / 1e9,
            'max_reserved': torch.cuda.max_memory_reserved() / 1e9,
        }
    return {'allocated': 0, 'reserved': 0, 'max_allocated': 0, 'max_reserved': 0}

def reset_peak_memory():
    """Reset peak memory stats"""
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()

def cuda_sync():
    """Block until queued GPU work finishes so wall-clock timings include it"""
    if torch.cuda.is_available():
        torch.cuda.synchronize()

class SegmentProfiler:
    """Profile a sampling segment"""
    def __init__(self, segment_name: str):
        self.segment_name = segment_name
        self.start_time = None
        self.end_time = None
        self.start_memory = None
        self.peak_memory = None
        self.step_times = []
        
    def start(self):
        """Start profiling"""
        reset_peak_memory()
        cuda_sync()
        self.start_time = time.time()
        self.start_memory = get_gpu_memory()
        
    def record_step(self, step_time: float):
        """Record time for a single step"""
        self.step_times.append(step_time)
        
    def end(self):
        """End profiling"""
        cuda_sync()
        self.end_time = time.time()
        self.peak_memory = get_gpu_memory()
        
    def get_report(self) -> Dict:
        """Get profiling report"""
        total_time = self.end_time - self.start_time if self.end_time else 0
        avg_step_time = np.mean(self.step_times) if self.step_times else 0
        
        return {
            'segment_name': self.segment_name,
            'num_steps': len(self.step_times),
            'total_time': total_time,
            'avg_step_time': avg_step_time,
            'min_step_time': np.min(self.step_times) if self.step_times else 0,
            'max_step_time': np.max(self.step_times) if self.step_times else 0,
            'memory_start_allocated_gb': self.start_memory['allocated'],
            'memory_peak_allocated_gb': self.peak_memory['max_allocated'],
            'memory_peak_reserved_gb': self.peak_memory['max_reserved'],
        }
        
    def print_report(self):
        """Print formatted report"""
        report = self.get_report()
        print(f"\n{'='*60}")
        print(f"Segment: {report['segment_name']}")
        print(f"{'='*60}")
        print(f"Steps: {report['num_steps']}")
        print(f"Total time: {report['total_time']:.2f}s")
        print(f"Avg step time: {report['avg_step_time']:.3f}s")
        print(f"Min step time: {report['min_step_time']:.3f}s")
        print(f"Max step time: {report['max_step_time']:.3f}s")
        print(f"Memory (start): {report['memory_start_allocated_gb']:.2f} GB")
        print(f"Memory (peak): {report['memory_peak_allocated_gb']:.2f} GB")
        print(f"Memory (reserved): {report['memory_peak_reserved_gb']:.2f} GB")
        print(f"{'='*60}\n")

print("✓ Utility functions defined")


✓ Utility functions defined
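CUDA kernels are launched asynchronously, so `time.time()` around a forward pass can return before the GPU has finished the work unless something forces a synchronization. The sketch below is a step timer that takes a synchronization hook: on a GPU you would pass `torch.cuda.synchronize`, while the default no-op keeps the example runnable on CPU. `timed` is a hypothetical helper, not part of the profiler above.

```python
import time
from contextlib import contextmanager
from typing import Callable, Iterator, List

@contextmanager
def timed(record: List[float], sync: Callable[[], None] = lambda: None) -> Iterator[None]:
    """Append the elapsed wall-clock time of the `with` block to `record`.

    `sync` is called before both clock reads so queued device work is counted
    (e.g. pass torch.cuda.synchronize when timing GPU code).
    """
    sync()
    start = time.perf_counter()
    try:
        yield
    finally:
        sync()
        record.append(time.perf_counter() - start)

step_times: List[float] = []
for _ in range(3):
    with timed(step_times):
        time.sleep(0.01)  # stand-in for one denoising step
print(len(step_times), [round(t, 3) for t in step_times])
```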


## Load Models

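Load both 14B and 1.3B models. They use the same VAE and T5 encoder weights, so text embeddings and latents are interchangeable between them. Each `WanT2V` instance still loads its own copy of these components.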
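It is worth confirming that the two checkpoint directories really contain identical VAE and T5 weights before relying on this. The sketch below hashes the same file name in both directories. The file names in the usage comment (`Wan2.1_VAE.pth`, `models_t5_umt5-xxl-enc-bf16.pth`) are assumed from the Wan2.1 release layout and should be checked against the actual directories. `files_match` is a hypothetical helper.

```python
import hashlib
import os

def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Stream a file through SHA-256 so multi-GB checkpoints are never fully loaded."""
    h = hashlib.sha256()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            h.update(chunk)
    return h.hexdigest()

def files_match(dir_a: str, dir_b: str, filename: str) -> bool:
    """True if `filename` exists in both directories with identical contents."""
    path_a, path_b = os.path.join(dir_a, filename), os.path.join(dir_b, filename)
    if not (os.path.isfile(path_a) and os.path.isfile(path_b)):
        return False
    if os.path.getsize(path_a) != os.path.getsize(path_b):
        return False  # cheap early exit before hashing
    return sha256_of(path_a) == sha256_of(path_b)

# Usage (paths from CONFIG; file names are assumptions):
# for name in ['Wan2.1_VAE.pth', 'models_t5_umt5-xxl-enc-bf16.pth']:
#     print(name, files_match(CONFIG['checkpoint_dir_14B'], CONFIG['checkpoint_dir_1_3B'], name))
```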


In [5]:
# Load 14B Model
print("Loading 14B model...")
model_14B = WanT2V(
    config=t2v_14B,
    checkpoint_dir=CONFIG['checkpoint_dir_14B'],
    device_id=CONFIG['device_id'],
    rank=0,
    t5_fsdp=False,
    dit_fsdp=False,
    use_usp=False,
    t5_cpu=False,
)
print("✓ 14B model loaded")
print(f"Memory after 14B: {get_gpu_memory()['allocated']:.2f} GB")

# Load 1.3B Model
print("\nLoading 1.3B model...")
model_1_3B = WanT2V(
    config=t2v_1_3B,
    checkpoint_dir=CONFIG['checkpoint_dir_1_3B'],
    device_id=CONFIG['device_id'],
    rank=0,
    t5_fsdp=False,
    dit_fsdp=False,
    use_usp=False,
    t5_cpu=False,
)
print("✓ 1.3B model loaded")
print(f"Memory after 1.3B: {get_gpu_memory()['allocated']:.2f} GB")

# Store models in a dictionary for easy access
models = {
    '14B': model_14B,
    '1.3B': model_1_3B,
}

print(f"\n✓ Both models loaded. Total memory: {get_gpu_memory()['allocated']:.2f} GB")


Loading 14B model...
✓ 14B model loaded
Memory after 14B: 57.68 GB

Loading 1.3B model...
✓ 1.3B model loaded
Memory after 1.3B: 63.89 GB

✓ Both models loaded. Total memory: 63.89 GB


## Hybrid Sampling

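FlowBending-inspired hybrid sampling with model switching in latent space. A single noise scheduler and a single latent tensor persist across all segments. At each segment boundary only the denoising transformer changes, so each model continues from the latent its predecessor produced. The prompt is encoded once with the first model's T5 encoder, and the final latent is decoded with its VAE.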
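The key property is that a model switch changes only which network predicts the velocity. The latent, the noise schedule, and the solver carry straight through. Below is a toy 1-D sketch of that idea. It assumes a plain Euler flow-matching step instead of the UniPC/DPM++ solvers the notebook actually uses, and two stand-in "models" that pull toward slightly different clean values. All names are hypothetical.

```python
from typing import Callable, Dict, List, Tuple

Velocity = Callable[[float, float], float]  # (latent, sigma) -> predicted velocity

def make_toy_model(target: float) -> Velocity:
    # For x_sigma = (1 - sigma) * x0 + sigma * noise, the exact velocity is (x - x0) / sigma
    return lambda x, sigma: (x - target) / sigma

def hybrid_euler(x: float, models: Dict[str, Velocity],
                 schedule: List[Tuple[str, int]], shift: float = 5.0) -> float:
    total = sum(n for _, n in schedule)
    base = [1.0 - i / total for i in range(total)] + [0.0]
    sigmas = [shift * s / (1.0 + (shift - 1.0) * s) for s in base]  # shifted schedule ending at 0
    step = 0
    for name, n in schedule:
        v = models[name]  # only the velocity predictor changes at a segment boundary
        for _ in range(n):
            x = x + (sigmas[step + 1] - sigmas[step]) * v(x, sigmas[step])  # Euler step
            step += 1
    return x

toy = {'14B': make_toy_model(2.0), '1.3B': make_toy_model(2.1)}
noise = 0.7
print(hybrid_euler(noise, toy, [('14B', 20), ('1.3B', 30)]))             # ~2.1
print(hybrid_euler(noise, toy, [('14B', 3), ('1.3B', 44), ('14B', 3)]))  # ~2.0
```

In this toy, whichever model takes the final step fixes the endpoint. This offers one intuition for patterns such as LSL that hand the last few steps back to the large model. Real denoisers are far from this idealized velocity field, so the toy says nothing about actual video quality.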


In [None]:
def hybrid_generate(
    models: Dict,
    sampling_schedule: List[Tuple[str, int]],
    config: Dict,
    accel_config: AccelerationConfig = None,
    profile: bool = True
):
    """
    Generate video using hybrid model sampling with optional acceleration.
    
    Args:
        models: Dictionary of models {'14B': model_14B, '1.3B': model_1_3B}
        sampling_schedule: List of (model_name, num_steps) tuples
        config: Configuration dictionary
        accel_config: Acceleration configuration (Cache-DiT, xDiT)
        profile: Whether to profile each segment
        
    Returns:
        video: Generated video tensor
        profiling_reports: List of profiling reports for each segment
    """
    device = torch.device(f"cuda:{config['device_id']}")
    
    # Setup acceleration
    use_cache_dit = False
    cache_config = None
    if accel_config is not None and accel_config.cache_dit.enabled:
        if check_cache_dit_available():
            use_cache_dit = True
            cache_config = accel_config.cache_dit
            print(f"Cache-DiT enabled: type={cache_config.cache_type}, interval={cache_config.cache_interval}")
        else:
            print("Warning: Cache-DiT requested but not installed. Install with: pip install cache-dit")
    
    # Use first model to get shared components
    first_model = models[sampling_schedule[0][0]]
    
    # Prepare latent shape
    F = config['frame_num']
    size = (config['width'], config['height'])
    vae_stride = first_model.vae_stride
    patch_size = first_model.patch_size
    
    target_shape = (
        first_model.vae.model.z_dim,
        (F - 1) // vae_stride[0] + 1,
        size[1] // vae_stride[1],
        size[0] // vae_stride[2]
    )
    
    seq_len = math.ceil(
        (target_shape[2] * target_shape[3]) / (patch_size[1] * patch_size[2]) * target_shape[1]
    )
    
    # Setup text encoding (shared across all models)
    n_prompt = config['negative_prompt'] if config['negative_prompt'] else first_model.sample_neg_prompt
    seed = config['seed'] if config['seed'] >= 0 else random.randint(0, sys.maxsize)
    
    print(f"Using seed: {seed}")
    print(f"Target latent shape: {target_shape}")
    print(f"Sequence length: {seq_len}")
    
    # Encode text prompt (use T5 from first model)
    first_model.text_encoder.model.to(device)
    context = first_model.text_encoder([config['prompt']], device)
    context_null = first_model.text_encoder([n_prompt], device)
    first_model.text_encoder.model.cpu()
    torch.cuda.empty_cache()
    
    # Initialize noise with seed
    seed_g = torch.Generator(device=device)
    seed_g.manual_seed(seed)
    
    noise = torch.randn(
        target_shape[0], target_shape[1], target_shape[2], target_shape[3],
        dtype=torch.float32,
        device=device,
        generator=seed_g
    )
    
    # Setup scheduler
    total_steps = config['total_sampling_steps']
    num_train_timesteps = first_model.num_train_timesteps
    if config['sample_solver'] == 'unipc':
        sample_scheduler = FlowUniPCMultistepScheduler(
            num_train_timesteps=num_train_timesteps,
            shift=1,
            use_dynamic_shifting=False
        )
        sample_scheduler.set_timesteps(total_steps, device=device, shift=config['shift'])
        timesteps = sample_scheduler.timesteps
    elif config['sample_solver'] == 'dpm++':
        sample_scheduler = FlowDPMSolverMultistepScheduler(
            num_train_timesteps=num_train_timesteps,
            shift=1,
            use_dynamic_shifting=False
        )
        sampling_sigmas = get_sampling_sigmas(total_steps, config['shift'])
        timesteps, _ = retrieve_timesteps(sample_scheduler, device=device, sigmas=sampling_sigmas)
    else:
        raise NotImplementedError(f"Unsupported solver: {config['sample_solver']}")
    
    print(f"Total timesteps: {len(timesteps)}")
    
    # Initialize latents
    latents = noise
    
    # Prepare model arguments
    arg_c = {'context': context, 'seq_len': seq_len}
    arg_null = {'context': context_null, 'seq_len': seq_len}
    
    # Cache-DiT state
    cached_noise_pred = None
    cache_hit_count = 0
    cache_miss_count = 0
    
    def should_use_cache(global_step):
        """Determine if we should use cached result for this step."""
        if not use_cache_dit:
            return False
        if global_step < cache_config.cache_start_step:
            return False
        if cache_config.cache_end_step is not None:
            remaining = total_steps - global_step
            if remaining <= cache_config.cache_end_step:
                return False
        steps_since_start = global_step - cache_config.cache_start_step
        return steps_since_start % cache_config.cache_interval != 0
    
    # Run hybrid sampling
    profiling_reports = []
    step_idx = 0
    global_step_idx = 0
    
    print("\n" + "="*80)
    print("HYBRID SAMPLING")
    if use_cache_dit:
        print(f"With Cache-DiT ({cache_config.cache_type})")
    print("="*80)
    
    for segment_idx, (model_name, num_steps) in enumerate(sampling_schedule):
        model = models[model_name]
        segment_name = f"Segment {segment_idx+1}: {model_name} ({num_steps} steps)"
        
        print(f"\n{segment_name}")
        print("-" * 60)
        
        # Setup profiler
        profiler = SegmentProfiler(segment_name) if profile else None
        if profiler:
            profiler.start()
        
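        # Move model to device if needed
        model.model.to(device)
        
        # Drop any prediction cached by the previous segment so the first step
        # of each segment always runs the newly selected model
        cached_noise_pred = None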
        
        # Sample for this segment
        segment_timesteps = timesteps[step_idx:step_idx + num_steps]
        
        with amp.autocast(dtype=model.param_dtype), torch.no_grad():
            for i, t in enumerate(tqdm(segment_timesteps, desc=model_name)):
                step_start = time.time()
                
                latent_model_input = [latents]
                timestep = torch.stack([t])
                
                # Cache-DiT: Check if we should use cached result
                if should_use_cache(global_step_idx) and cached_noise_pred is not None:
                    noise_pred = cached_noise_pred
                    cache_hit_count += 1
                else:
                    # Conditional prediction
                    noise_pred_cond = model.model(latent_model_input, t=timestep, **arg_c)[0]
                    # Unconditional prediction
                    noise_pred_uncond = model.model(latent_model_input, t=timestep, **arg_null)[0]
                    
                    # Classifier-free guidance
                    noise_pred = noise_pred_uncond + config['guide_scale'] * (noise_pred_cond - noise_pred_uncond)
                    
                    # Cache the result
                    if use_cache_dit:
                        cached_noise_pred = noise_pred.clone()
                        cache_miss_count += 1
                
                # Scheduler step
                temp_x0 = sample_scheduler.step(
                    noise_pred.unsqueeze(0),
                    t,
                    latents.unsqueeze(0),
                    return_dict=False,
                    generator=seed_g
                )[0]
                latents = temp_x0.squeeze(0)
                
                torch.cuda.synchronize(device)  # wait for queued GPU work so the step time is accurate
                step_time = time.time() - step_start
                if profiler:
                    profiler.record_step(step_time)
                
                global_step_idx += 1
        
        # End profiling
        if profiler:
            profiler.end()
            profiler.print_report()
            profiling_reports.append(profiler.get_report())
        
        # Offload model if requested
        if config['offload_models']:
            model.model.cpu()
            torch.cuda.empty_cache()
            print(f"Offloaded {model_name} model")
        
        step_idx += num_steps
    
    # Report cache statistics
    if use_cache_dit:
        print(f"\nCache-DiT Statistics:")
        print(f"  Cache hits: {cache_hit_count}")
        print(f"  Cache misses: {cache_miss_count}")
        total_cache_ops = cache_hit_count + cache_miss_count
        if total_cache_ops > 0:
            print(f"  Hit rate: {cache_hit_count / total_cache_ops * 100:.1f}%")
    
    print("\n" + "="*80)
    print("DECODING LATENTS")
    print("="*80)
    
    # Decode latents using VAE (from any model, they're the same)
    decode_start = time.time()
    with torch.no_grad():
        videos = first_model.vae.decode([latents])
    decode_time = time.time() - decode_start
    print(f"Decode time: {decode_time:.2f}s")
    
    return videos[0], profiling_reports

print("Hybrid sampling function defined (with Cache-DiT support)")

## Run Hybrid Generation


In [None]:
# Run Hybrid Generation
print("Starting hybrid generation...")
print(f"Prompt: {CONFIG['prompt']}")
print(f"Resolution: {CONFIG['width']}x{CONFIG['height']}")
print(f"Frames: {CONFIG['frame_num']}")

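hybrid_start = time.time()
hybrid_video, hybrid_reports = hybrid_generate(
    models=models,
    sampling_schedule=SAMPLING_SCHEDULE,
    config=CONFIG,
    accel_config=accel_config,  # Pass acceleration config
    profile=True
)
# End-to-end time (text encoding + sampling + VAE decoding), which is what the
# baseline cells measure around generate(); the segment reports cover sampling only
hybrid_wall_time = time.time() - hybrid_start
print(f"\nHybrid end-to-end time: {hybrid_wall_time:.2f}s")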

# Save hybrid video
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
hybrid_output_path = os.path.join(
    CONFIG['output_dir'],
    f"hybrid_{CONFIG['width']}x{CONFIG['height']}_{timestamp}.mp4"
)
cache_video(hybrid_video[None], save_file=hybrid_output_path, fps=16, nrow=1, normalize=True, value_range=(-1, 1))
print(f"\nHybrid video saved to: {hybrid_output_path}")

### Display Hybrid Video


In [13]:
from IPython.display import Video, display

# Display hybrid video
print(f"Hybrid Video: {hybrid_output_path}")
display(Video(hybrid_output_path, embed=True, width=832))


Hybrid Video: /workspace/wan2.1/Wan2.1/outputs/hybrid_832x480_20260114_054623.mp4


## Baseline: 14B-Only Generation

Run a full 50-step generation with only the 14B model, for comparison.


In [14]:
# Generate baseline with 14B only
print("Starting 14B baseline generation...")
print(f"Prompt: {CONFIG['prompt']}")

baseline_start = time.time()

baseline_video = model_14B.generate(
    input_prompt=CONFIG['prompt'],
    size=(CONFIG['width'], CONFIG['height']),
    frame_num=CONFIG['frame_num'],
    shift=CONFIG['shift'],
    sample_solver=CONFIG['sample_solver'],
    sampling_steps=CONFIG['total_sampling_steps'],
    guide_scale=CONFIG['guide_scale'],
    n_prompt=CONFIG['negative_prompt'],
    seed=CONFIG['seed'],
    offload_model=CONFIG['offload_models']
)

baseline_time = time.time() - baseline_start
print(f"\n✓ 14B baseline completed in {baseline_time:.2f}s")

# Save baseline video
baseline_output_path = os.path.join(
    CONFIG['output_dir'],
    f"baseline_14B_{CONFIG['width']}x{CONFIG['height']}_{timestamp}.mp4"
)
cache_video(baseline_video[None], save_file=baseline_output_path, fps=16, nrow=1, normalize=True, value_range=(-1, 1))
print(f"✓ Baseline video saved to: {baseline_output_path}")

# Create baseline report
baseline_report = {
    'model': '14B only',
    'total_steps': CONFIG['total_sampling_steps'],
    'total_time': baseline_time,
    'avg_step_time': baseline_time / CONFIG['total_sampling_steps'],
}


Starting 14B baseline generation...
Prompt: Two cartoon cats animation in comfy boxing gear sparring playfully in a cozy living room


100%|██████████| 50/50 [05:34<00:00,  6.69s/it]



✓ 14B baseline completed in 340.64s
✓ Baseline video saved to: /workspace/wan2.1/Wan2.1/outputs/baseline_14B_832x480_20260114_054623.mp4


### Display 14B Baseline Video


In [16]:
# Display 14B baseline video
print(f"Baseline Video: {baseline_output_path}")
display(Video(baseline_output_path, embed=True, width=832))


Baseline Video: /workspace/wan2.1/Wan2.1/outputs/baseline_14B_832x480_20260114_054623.mp4


## Baseline: 1.3B-Only Generation

Run a full 50-step generation with only the 1.3B model, for comparison.


In [11]:
# Generate baseline with 1.3B only
print("Starting 1.3B baseline generation...")
print(f"Prompt: {CONFIG['prompt']}")

baseline_1_3B_start = time.time()

baseline_1_3B_video = model_1_3B.generate(
    input_prompt=CONFIG['prompt'],
    size=(CONFIG['width'], CONFIG['height']),
    frame_num=CONFIG['frame_num'],
    shift=CONFIG['shift'],
    sample_solver=CONFIG['sample_solver'],
    sampling_steps=CONFIG['total_sampling_steps'],
    guide_scale=CONFIG['guide_scale'],
    n_prompt=CONFIG['negative_prompt'],
    seed=CONFIG['seed'],
    offload_model=CONFIG['offload_models']
)

baseline_1_3B_time = time.time() - baseline_1_3B_start
print(f"\n✓ 1.3B baseline completed in {baseline_1_3B_time:.2f}s")

# Save baseline video
baseline_1_3B_output_path = os.path.join(
    CONFIG['output_dir'],
    f"baseline_1.3B_{CONFIG['width']}x{CONFIG['height']}_{timestamp}.mp4"
)
cache_video(baseline_1_3B_video[None], save_file=baseline_1_3B_output_path, fps=16, nrow=1, normalize=True, value_range=(-1, 1))
print(f"✓ 1.3B Baseline video saved to: {baseline_1_3B_output_path}")

# Create baseline report
baseline_1_3B_report = {
    'model': '1.3B only',
    'total_steps': CONFIG['total_sampling_steps'],
    'total_time': baseline_1_3B_time,
    'avg_step_time': baseline_1_3B_time / CONFIG['total_sampling_steps'],
}


Starting 1.3B baseline generation...
Prompt: Two cartoon cats animation in comfy boxing gear sparring playfully in a cozy living room


100%|██████████| 50/50 [01:12<00:00,  1.44s/it]



✓ 1.3B baseline completed in 81.48s
✓ 1.3B Baseline video saved to: /workspace/wan2.1/Wan2.1/outputs/baseline_1.3B_832x480_20260114_053922.mp4


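### Display 1.3B Baseline Video


In [None]:
# Display 1.3B baseline video
print(f"1.3B Baseline Video: {baseline_1_3B_output_path}")
display(Video(baseline_1_3B_output_path, embed=True, width=832))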


### Side-by-Side Comparison

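Display all three videos, stacked vertically, with synchronized playback controls. The players load the files from their saved paths. If the players stay blank, the browser could not reach those paths. Absolute filesystem paths are often not served by Jupyter, so a path relative to the notebook's directory may be needed.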


In [None]:
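from IPython.display import HTML

# Also computed in the Profiling Report cell below; recomputed here so this
# cell does not depend on that cell having been run first
hybrid_total_time = sum(r['total_time'] for r in hybrid_reports)
speedup_vs_14B = baseline_report['total_time'] / hybrid_total_time
speedup_vs_1_3B = baseline_1_3B_report['total_time'] / hybrid_total_time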

# Create a comparison layout with synchronized video playback
html_content = f"""
<style>
    .video-comparison {{
        display: flex;
        flex-direction: column;
        gap: 20px;
        max-width: 900px;
        margin: 0 auto;
    }}
    .sync-controls {{
        text-align: center;
        margin: 20px 0;
        padding: 20px;
        background: #f5f5f5;
        border-radius: 10px;
    }}
    .sync-button {{
        background: #4CAF50;
        color: white;
        border: none;
        padding: 15px 40px;
        font-size: 18px;
        border-radius: 5px;
        cursor: pointer;
        margin: 0 10px;
    }}
    .sync-button:hover {{
        background: #45a049;
    }}
    .sync-button.pause {{
        background: #f44336;
    }}
    .sync-button.pause:hover {{
        background: #da190b;
    }}
    .video-container {{
        text-align: center;
        padding: 15px;
        background: white;
        border-radius: 8px;
        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
    }}
    .video-container h3 {{
        margin-top: 0;
        color: #333;
    }}
    .video-container video {{
        border-radius: 5px;
    }}
    .time-info {{
        color: #666;
        font-size: 14px;
        margin-top: 10px;
    }}
</style>

<div class="video-comparison">
    <div class="sync-controls">
        <h3>🎬 Synchronized Playback Controls</h3>
        <button class="sync-button" id="playAllBtn" onclick="playAll()">▶ Play All</button>
        <button class="sync-button pause" id="pauseAllBtn" onclick="pauseAll()">⏸ Pause All</button>
        <button class="sync-button" onclick="restartAll()">⏮ Restart All</button>
    </div>
    
    <div class="video-container">
        <h3>🔀 Hybrid Model</h3>
        <p style="font-size: 12px; color: #666;">{' + '.join(f'{n} steps {m}' for m, n in SAMPLING_SCHEDULE)}</p>
        <video id="video1" width="832" controls>
            <source src="{hybrid_output_path}" type="video/mp4">
        </video>
        <p class="time-info">Generation Time: {hybrid_total_time:.2f}s</p>
    </div>
    
    <div class="video-container">
        <h3>🎯 14B Baseline</h3>
        <p style="font-size: 12px; color: #666;">50 steps (highest quality)</p>
        <video id="video2" width="832" controls>
            <source src="{baseline_output_path}" type="video/mp4">
        </video>
        <p class="time-info">Generation Time: {baseline_report['total_time']:.2f}s</p>
    </div>
    
    <div class="video-container">
        <h3>⚡ 1.3B Baseline</h3>
        <p style="font-size: 12px; color: #666;">50 steps (fastest)</p>
        <video id="video3" width="832" controls>
            <source src="{baseline_1_3B_output_path}" type="video/mp4">
        </video>
        <p class="time-info">Generation Time: {baseline_1_3B_report['total_time']:.2f}s</p>
    </div>
</div>

<script>
    // Get all video elements
    const videos = [
        document.getElementById('video1'),
        document.getElementById('video2'),
        document.getElementById('video3')
    ];
    
    // Play all videos synchronously
    function playAll() {{
        videos.forEach(video => {{
            video.play();
        }});
    }}
    
    // Pause all videos
    function pauseAll() {{
        videos.forEach(video => {{
            video.pause();
        }});
    }}
    
    // Restart all videos
    function restartAll() {{
        videos.forEach(video => {{
            video.currentTime = 0;
            video.pause();
        }});
    }}
    
    // Sync videos when one is played
    videos.forEach((video, index) => {{
        video.addEventListener('play', () => {{
            videos.forEach((v, i) => {{
                if (i !== index && v.paused) {{
                    v.play();
                }}
            }});
        }});
        
        video.addEventListener('pause', () => {{
            videos.forEach((v, i) => {{
                if (i !== index && !v.paused) {{
                    v.pause();
                }}
            }});
        }});
        
        // Sync seeking
        video.addEventListener('seeked', () => {{
            const targetTime = video.currentTime;
            videos.forEach((v, i) => {{
                if (i !== index && Math.abs(v.currentTime - targetTime) > 0.3) {{
                    v.currentTime = targetTime;
                }}
            }});
        }});
    }});
    
    // Keep videos in sync during playback
    let syncInterval = setInterval(() => {{
        if (!videos[0].paused) {{
            const mainTime = videos[0].currentTime;
            videos.forEach((video, index) => {{
                if (index > 0 && Math.abs(video.currentTime - mainTime) > 0.3) {{
                    video.currentTime = mainTime;
                }}
            }});
        }}
    }}, 1000);
</script>
"""

display(HTML(html_content))

print("\n📊 Timing Summary:")
print(f"  Hybrid:       {hybrid_total_time:.2f}s")
print(f"  14B Baseline: {baseline_report['total_time']:.2f}s")
print(f"  1.3B Baseline: {baseline_1_3B_report['total_time']:.2f}s")
print(f"\n  Hybrid is {speedup_vs_14B:.2f}x faster than 14B")
print(f"  Hybrid is {1/speedup_vs_1_3B:.2f}x slower than 1.3B")


## Profiling Report

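Generate a profiling report comparing the hybrid run against both baselines. Hybrid times are summed over the sampling segments only. The baseline times, by contrast, wrap the whole `generate()` call (text encoding, sampling, and VAE decoding), so the reported speedups favor the hybrid somewhat.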
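Before reading the measured numbers, a back-of-envelope estimate helps sanity-check them. The sketch below assumes that each model's per-step cost inside the hybrid loop equals the rate its baseline progress bar reported in this run (6.69 s/it for 14B, 1.44 s/it for 1.3B). It ignores model-switch overhead, text encoding, and VAE decoding. `estimate_loop_time` is a hypothetical helper.

```python
from typing import Dict, List, Tuple

def estimate_loop_time(schedule: List[Tuple[str, int]], sec_per_step: Dict[str, float]) -> float:
    """Predicted denoising-loop time for a schedule, given per-model seconds per step."""
    return sum(sec_per_step[model] * n for model, n in schedule)

rates = {'14B': 6.69, '1.3B': 1.44}  # s/it from the baseline tqdm bars
hybrid_est = estimate_loop_time([('14B', 20), ('1.3B', 30)], rates)
full_14b_est = estimate_loop_time([('14B', 50)], rates)
print(f"hybrid ~{hybrid_est:.1f}s vs 14B ~{full_14b_est:.1f}s "
      f"-> ~{full_14b_est / hybrid_est:.2f}x")
# hybrid ~177.0s vs 14B ~334.5s -> ~1.89x
```

If the measured hybrid time lands far from this estimate, likely causes include model transfers (with offloading enabled) and the timing-scope difference between the hybrid segments and the baseline `generate()` calls.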


In [None]:
# Generate comprehensive report
print("\n" + "="*80)
print("PROFILING REPORT - HYBRID VS BASELINES")
print("="*80)

# Hybrid summary
print("\n### HYBRID MODEL SAMPLING ###")
print(f"Schedule: {SAMPLING_SCHEDULE}")
print(f"\nSegment Details:")
hybrid_total_time = 0
for report in hybrid_reports:
    print(f"\n  {report['segment_name']}")
    print(f"    Steps: {report['num_steps']}")
    print(f"    Total time: {report['total_time']:.2f}s")
    print(f"    Avg step time: {report['avg_step_time']:.3f}s")
    print(f"    Peak memory: {report['memory_peak_allocated_gb']:.2f} GB")
    hybrid_total_time += report['total_time']

print(f"\n  Hybrid Total Time (sampling loop only): {hybrid_total_time:.2f}s")
print(f"  Hybrid Avg Step Time: {hybrid_total_time / CONFIG['total_sampling_steps']:.3f}s")

# Baseline summaries
print("\n### BASELINE: 14B ONLY ###")
print(f"Total time: {baseline_report['total_time']:.2f}s")
print(f"Avg step time: {baseline_report['avg_step_time']:.3f}s")

print("\n### BASELINE: 1.3B ONLY ###")
print(f"Total time: {baseline_1_3B_report['total_time']:.2f}s")
print(f"Avg step time: {baseline_1_3B_report['avg_step_time']:.3f}s")

# Comparisons
speedup_vs_14B = baseline_report['total_time'] / hybrid_total_time
speedup_vs_1_3B = baseline_1_3B_report['total_time'] / hybrid_total_time

print("\n### COMPARISON: HYBRID vs 14B BASELINE ###")
print(f"Speedup: {speedup_vs_14B:.2f}x")
print(f"Time saved: {baseline_report['total_time'] - hybrid_total_time:.2f}s")
print(f"Time reduction: {(1 - 1/speedup_vs_14B) * 100:.1f}%")

print("\n### COMPARISON: HYBRID vs 1.3B BASELINE ###")
if speedup_vs_1_3B > 1:
    print(f"Speedup: {speedup_vs_1_3B:.2f}x")
    print(f"Time saved: {baseline_1_3B_report['total_time'] - hybrid_total_time:.2f}s")
    print(f"Time reduction: {(1 - 1/speedup_vs_1_3B) * 100:.1f}%")
else:
    slowdown = 1 / speedup_vs_1_3B
    print(f"Slowdown: {slowdown:.2f}x (hybrid is slower)")
    print(f"Extra time: {hybrid_total_time - baseline_1_3B_report['total_time']:.2f}s")
    print(f"Percentage slower: {(slowdown - 1) * 100:.1f}%")

print("\n### QUALITY vs SPEED TRADE-OFF ###")
print(f"14B baseline: {baseline_report['total_time']:.2f}s (highest quality)")
print(f"Hybrid model: {hybrid_total_time:.2f}s (balanced)")
print(f"1.3B baseline: {baseline_1_3B_report['total_time']:.2f}s (fastest)")
print(f"\nHybrid saves {(1 - hybrid_total_time/baseline_report['total_time']) * 100:.1f}% time vs 14B")
print(f"Hybrid adds {(hybrid_total_time/baseline_1_3B_report['total_time'] - 1) * 100:.1f}% time vs 1.3B")

# Save report to JSON
report_data = {
    'timestamp': timestamp,
    'config': CONFIG,
    'sampling_schedule': SAMPLING_SCHEDULE,
    'hybrid': {
        'segments': hybrid_reports,
        'total_time': hybrid_total_time,
        'avg_step_time': hybrid_total_time / CONFIG['total_sampling_steps'],
    },
    'baseline_14B': baseline_report,
    'baseline_1_3B': baseline_1_3B_report,
    'comparison': {
        'hybrid_vs_14B': {
            'speedup': speedup_vs_14B,
            'time_saved': baseline_report['total_time'] - hybrid_total_time,
            'percentage_faster': (1 - 1/speedup_vs_14B) * 100,
        },
        'hybrid_vs_1_3B': {
            'speedup': speedup_vs_1_3B,
            'time_difference': baseline_1_3B_report['total_time'] - hybrid_total_time,
            # Same value in both directions: > 0 means hybrid is faster, < 0 slower
            'percentage_difference': (1 - hybrid_total_time / baseline_1_3B_report['total_time']) * 100,
        }
    },
    'outputs': {
        'hybrid_video': hybrid_output_path,
        'baseline_14B_video': baseline_output_path,
        'baseline_1_3B_video': baseline_1_3B_output_path,
    }
}

report_path = os.path.join(CONFIG['output_dir'], f'profiling_report_{timestamp}.json')
with open(report_path, 'w') as f:
    # default=str stringifies any value json cannot encode natively
    json.dump(report_data, f, indent=2, default=str)

print(f"\n✓ Report saved to: {report_path}")
print("="*80)
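Every percentage in this report reduces to one quantity, (1 - hybrid/baseline) × 100. With s = baseline/hybrid, the expression 1 - 1/s is exactly that quantity, and it becomes negative when the hybrid run is slower. The hypothetical helper `compare_times` below makes this explicit for a single baseline. It is a sketch for illustration, not a function defined elsewhere in the notebook.

```python
from typing import Dict


def compare_times(baseline_s: float, hybrid_s: float) -> Dict[str, float]:
    """Compare hybrid wall-clock time against one baseline (hypothetical helper).

    speedup > 1 means the hybrid run was faster. time_reduction_pct is
    (1 - hybrid/baseline) * 100, i.e. the share of baseline time saved;
    it is negative when the hybrid run is slower.
    """
    if baseline_s <= 0 or hybrid_s <= 0:
        raise ValueError("timings must be positive")
    return {
        'speedup': baseline_s / hybrid_s,
        'time_saved_s': baseline_s - hybrid_s,
        'time_reduction_pct': (1 - hybrid_s / baseline_s) * 100,
    }


print(compare_times(300.0, 200.0))  # hybrid faster: speedup 1.5
print(compare_times(100.0, 125.0))  # hybrid slower: negative reduction
```

Because a single signed value covers both directions, no separate faster/slower branch is needed when the result is stored.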


## Memory Cleanup

Note: This cell does nothing by default. To release the models and generated videos from GPU memory, set `RUN_CLEANUP = True` and run the cell manually.


In [None]:
# Memory Cleanup Cell
# Set this to True and run manually when you want to clean up
RUN_CLEANUP = False

if RUN_CLEANUP:
    import gc

    print("Cleaning up memory...")
    print(f"Memory before cleanup: {get_gpu_memory()['allocated']:.2f} GB")

    # Drop notebook-level references to the models and generated videos.
    # globals() is the notebook namespace. Deleting a name frees the object
    # only once nothing else (e.g. the `models` dict or IPython's Out cache)
    # still refers to it.
    for name in ['model_14B', 'model_1_3B', 'models',
                 'hybrid_video', 'baseline_video', 'baseline_1_3B_video']:
        globals().pop(name, None)

    # Collect reference cycles first so their tensors are actually released,
    # then wait for pending kernels and return unused cached blocks.
    gc.collect()
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

    print(f"Memory after cleanup: {get_gpu_memory()['allocated']:.2f} GB")
    print("✓ Cleanup complete")
else:
    print("Cleanup skipped. Set RUN_CLEANUP = True and run this cell manually to clean up.")
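It matters to collect garbage before emptying the CUDA cache. `del` only removes a name. An object caught in a reference cycle (for example, one that stores a reference back to itself) survives until Python's cyclic collector runs, and `torch.cuda.empty_cache()` can only return cached blocks that no live tensor occupies. The stdlib-only demonstration below uses a hypothetical `Holder` class in place of a model. It shows the object surviving `del` and disappearing after `gc.collect()`.

```python
import gc
import weakref


class Holder:
    """Hypothetical stand-in for a model object that points back at itself."""

    def __init__(self):
        self.self_ref = self  # reference cycle: refcount never drops to zero


def freed_after_del() -> tuple:
    """Return (alive after del, alive after gc.collect) for a cyclic object."""
    was_enabled = gc.isenabled()
    gc.disable()  # keep the automatic collector from running mid-demo
    try:
        obj = Holder()
        probe = weakref.ref(obj)  # observes the object without keeping it alive
        del obj
        alive_after_del = probe() is not None
        gc.collect()  # explicit collection still runs while gc is disabled
        alive_after_gc = probe() is not None
    finally:
        if was_enabled:
            gc.enable()
    return alive_after_del, alive_after_gc


print(freed_after_del())  # (True, False)
```

The same ordering applies to GPU tensors owned by such objects: until the collector runs, their memory still counts as allocated, and `empty_cache()` cannot reclaim it.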


## Summary

This notebook implements hybrid video generation using the FlowBending framework:

### Key Features:
1. **Flexible Sampling Schedule**: Configure any pattern (LSL, LSSSL, custom)
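2. **Latent Space Consistency**: Both models share the VAE and T5 encoder, so when the schedule switches models, the partially denoised latent passes directly from one model to the other without decoding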
3. **Comprehensive Profiling**: Memory usage, latency, and total time per segment
4. **Dual Baseline Comparison**: Full 14B and 1.3B generations for quality vs speed analysis
5. **Acceleration Support**: Cache-DiT and xDiT for faster inference
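The pattern strings above can be read as one letter per segment. The following is a minimal sketch based on three assumptions: `L` selects the 14B (large) model, `S` selects the 1.3B (small) model, and steps are split as evenly as possible, with earlier segments absorbing any remainder. The notebook's own splitting rule may differ. The hypothetical `run_hybrid` uses a scalar in place of the latent tensor and plain functions in place of the denoisers. It shows that only the model changes at a segment boundary, while the latent and the global step index carry straight through.

```python
from typing import Callable, Dict, List, Tuple


def expand_schedule(pattern: str, total_steps: int) -> List[Tuple[str, int, int]]:
    """Split total_steps across the letters of pattern (hypothetical helper).

    Returns (model_key, start_step, end_step) per segment, end exclusive.
    Assumption: 'L' = 14B, 'S' = 1.3B, near-even split.
    """
    pattern = pattern.upper()
    if not pattern or set(pattern) - {'L', 'S'}:
        raise ValueError("pattern must be a non-empty string of 'L'/'S'")
    if total_steps < len(pattern):
        raise ValueError("need at least one step per segment")
    base, extra = divmod(total_steps, len(pattern))
    segments, start = [], 0
    for i, key in enumerate(pattern):
        n = base + (1 if i < extra else 0)  # earlier segments take the remainder
        segments.append((key, start, start + n))
        start += n
    return segments


def run_hybrid(latent: float, pattern: str, total_steps: int,
               models: Dict[str, Callable[[float, int], float]]) -> float:
    """Toy loop: one latent is updated by whichever model owns each step."""
    for key, start, end in expand_schedule(pattern, total_steps):
        for step in range(start, end):
            latent = models[key](latent, step)  # same latent across switches
    return latent


print(expand_schedule('LSSSL', 50))
# [('L', 0, 10), ('S', 10, 20), ('S', 20, 30), ('S', 30, 40), ('L', 40, 50)]
```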

### Acceleration Options:

**Cache-DiT (Single GPU speedup ~1.3-2x):**
```python
CONFIG['enable_cache_dit'] = True
CONFIG['cache_dit_type'] = 'db'  # or 'taylor', 'scm'
CONFIG['cache_dit_interval'] = 3
CONFIG['cache_dit_start_step'] = 5
```

**xDiT (Multi-GPU speedup approaching Nx with N GPUs in the ideal case):**
```python
CONFIG['enable_xdit'] = True
CONFIG['xdit_ulysses_degree'] = 4  # For 4 GPUs; ulysses_degree * ring_degree should match the GPU count
CONFIG['xdit_ring_degree'] = 1
CONFIG['xdit_use_cfg_parallel'] = False
```
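Before launching a multi-GPU run, it may help to check that the configured degrees match the hardware. This sketch assumes the notebook's wrapper follows xDiT's convention: the sequence-parallel degree is `ulysses_degree × ring_degree`, and enabling CFG parallelism doubles the number of GPUs needed. `check_xdit_degrees` is a hypothetical helper introduced here for illustration. In the notebook, `num_gpus` would come from `torch.cuda.device_count()`.

```python
def check_xdit_degrees(num_gpus: int, ulysses_degree: int, ring_degree: int,
                       use_cfg_parallel: bool = False) -> int:
    """Return the GPU count implied by the parallel degrees, or raise ValueError.

    Assumption: sequence parallelism spans ulysses_degree * ring_degree GPUs,
    and CFG parallelism runs the conditional and unconditional passes on two
    such groups, doubling the total.
    """
    if ulysses_degree < 1 or ring_degree < 1:
        raise ValueError("parallel degrees must be >= 1")
    required = ulysses_degree * ring_degree * (2 if use_cfg_parallel else 1)
    if required != num_gpus:
        raise ValueError(
            f"degrees imply {required} GPUs, but {num_gpus} are available")
    return required


# The configuration above on a 4-GPU node passes the check.
print(check_xdit_degrees(4, ulysses_degree=4, ring_degree=1))  # 4
```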

### Outputs:
- Hybrid video (saved to outputs directory)
- Baseline 14B video (highest quality, saved to outputs directory)
- Baseline 1.3B video (fastest, saved to outputs directory)
- JSON profiling report with detailed metrics and comparisons

### Usage:
1. Adjust `CONFIG` settings (resolution, prompt, seed, acceleration, etc.)
2. Modify `SAMPLING_SCHEDULE` to experiment with different patterns
3. Run all cells sequentially
4. Review profiling reports and compare all three videos
5. (Optional) Run cleanup cell manually when done

### Expected Results:
- **14B Baseline**: Highest quality, slowest generation
- **1.3B Baseline**: Lower quality, fastest generation
- **Hybrid**: Balanced quality/speed trade-off
- **With Cache-DiT**: ~1.3-2x faster with minimal quality loss

### Installation:
```bash
# For Cache-DiT support
pip install cache-dit

# For xDiT/multi-GPU support
pip install xfuser
```