# Hybrid VACE Video Editing (14B + 1.3B)

Combines VACE-14B and VACE-1.3B models using a hybrid sampling approach:
- Use **14B** for initial steps (high-quality structure)
- Switch to **1.3B** for middle steps (fast iteration)
- Return to **14B** for final steps (high-quality refinement)

Based on the FlowBending framework: both models share the same VAE latent space, so the partially denoised latent is handed from one model to the next without decoding to pixels.

## Key Features:
- 🎯 Configurable sampling schedule (e.g., LSSSL pattern)
- 📊 Full profiling (memory, latency, total time)
- 🆚 Baseline comparisons (14B-only, 1.3B-only)
- 🤖 AI-enhanced prompts (optional)

## Resolution: 480p (832×480)

In [None]:
# Imports
import os
import sys
import time
import json
import glob
import random
import math
from datetime import datetime
from dataclasses import dataclass, asdict
from typing import List, Tuple, Dict

import torch
import numpy as np
from IPython.display import Video, display, HTML
import anthropic
from PIL import Image
import torchvision.transforms as transforms

sys.path.insert(0, '/workspace/wan2.1/Wan2.1')

from wan.vace import WanVace
from wan.configs.wan_t2v_14B import t2v_14B
from wan.configs.wan_t2v_1_3B import t2v_1_3B
from wan.utils.utils import cache_video
from wan.utils.acceleration import (
    AccelerationConfig,
    CacheDiTConfig,
    XDiTConfig,
    check_cache_dit_available,
    check_xdit_available,
    get_acceleration_summary,
    estimate_speedup,
)

print("Imports successful")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'N/A'}")
print(f"Cache-DiT available: {check_cache_dit_available()}")
print(f"xDiT available: {check_xdit_available()}")

## Configuration

In [None]:
CONFIG = {
    # Paths
    'checkpoint_dir_14B': '/workspace/wan2.1/Wan2.1/Wan2.1-VACE-14B',
    'checkpoint_dir_1.3B': '/workspace/wan2.1/Wan2.1/Wan2.1-VACE-1.3B',
    'output_dir': '/workspace/wan2.1/Wan2.1/outputs',
    
    # Source video (set to '' to auto-select the most recent video in output_dir)
    'source_video': '/workspace/wan2.1/Wan2.1/outputs/baseline_14B_832x480_20260106_183542.mp4',
    
    # Resolution
    'width': 832,
    'height': 480,
    'frame_num': 81,
    
    # Prompt Configuration
    'source_prompt': 'Two anthropomorphic cats in comfy boxing gear sparring playfully in a cozy living room',
    'edit_prompt': 'anime style, playful sparring in a traditional Japanese dojo',
    'use_ai_prompt_enhancement': True,
    'negative_prompt': '',
    
    # Context Scale
    'context_scale': 0.1,  # Lower = stronger edit: try 0.3-0.5 for strong edits, 0.6-0.8 for subtle ones
    
    # Sampling
    'sampling_steps': 50,
    'sample_solver': 'unipc',
    'shift': 16.0,
    'guide_scale': 5.0,
    'seed': 42,
    
    # Device
    'device_id': 0,
    
    # ==========================================
    # ACCELERATION: Cache-DiT Configuration
    # ==========================================
    # Cache-DiT caches intermediate activations for speedup
    # Expected speedup: 1.3-2x depending on settings
    
    'enable_cache_dit': False,  # Set to True to enable
    'cache_dit_type': 'db',  # Options: 'db' (DBCache), 'taylor' (TaylorSeer), 'scm' (SCM)
    'cache_dit_interval': 3,  # Recompute every N steps, reuse the cached result in between (higher = faster, may affect quality)
    'cache_dit_start_step': 5,  # Start caching after this many steps
    'cache_dit_end_step': None,  # Stop caching N steps before end (None = cache until end)
    'cache_dit_fresh_ratio': 0.4,  # For taylor/adaptive caching
    
    # ==========================================
    # ACCELERATION: xDiT Configuration  
    # ==========================================
    # xDiT enables distributed inference across multiple GPUs
    # Expected speedup: Up to Nx with N GPUs
    
    'enable_xdit': False,  # Set to True to enable (requires multi-GPU)
    'xdit_ulysses_degree': 1,  # Ulysses sequence parallelism degree
    'xdit_ring_degree': 1,  # Ring attention parallelism degree
    'xdit_pipefusion_degree': 1,  # PipeFusion pipeline parallelism degree
    'xdit_use_cfg_parallel': False,  # Parallel CFG computation (requires 2x GPUs)
    'xdit_use_torch_compile': False,  # Enable torch.compile for additional speedup
    'xdit_attention_backend': 'flash',  # Options: 'sdpa', 'flash', 'sage'
    
    # ==========================================
    # Other Optimizations
    # ==========================================
    'offload_models': False,  # Offload models to CPU between segments
}

os.makedirs(CONFIG['output_dir'], exist_ok=True)

# Create acceleration config from CONFIG dict
accel_config = AccelerationConfig.from_dict(CONFIG)

print('Configuration set')
print(f"  Resolution: {CONFIG['width']}x{CONFIG['height']}, {CONFIG['frame_num']} frames")
print(f"  Context scale: {CONFIG['context_scale']}")
print(f"  Sampling steps: {CONFIG['sampling_steps']}")
print()
print(get_acceleration_summary(accel_config))

# Show estimated speedup
if accel_config.cache_dit.enabled or accel_config.xdit.enabled:
    speedups = estimate_speedup(accel_config)
    print(f"\nEstimated speedup: ~{speedups['combined']}x")

## Acceleration Features

This notebook supports advanced optimizations for faster inference:

### 1. Cache-DiT (Activation Caching)
Caches intermediate DiT activations to skip redundant computations.

**Configuration:**
- `enable_cache_dit`: Enable/disable feature
- `cache_dit_type`: Caching strategy
  - `'db'`: DBCache (recommended, stable)
  - `'taylor'`: TaylorSeer (adaptive, higher speedup)
  - `'scm'`: SCM cache (aggressive caching)
- `cache_dit_interval`: Recompute every N steps and reuse the cached result in between (higher = faster but may reduce quality)
- `cache_dit_start_step`: Start caching after X steps (let model stabilize first)
- `cache_dit_end_step`: Stop caching N steps before end (optional, for quality)
- `cache_dit_fresh_ratio`: Ratio of fresh computations for adaptive caching

**Expected speedup:** 1.3-2x depending on settings

**Installation:** `pip install cache-dit`

### 2. xDiT (Distributed Inference)
Uses the xFuser library for distributed/parallel inference across multiple GPUs.

**Configuration:**
- `enable_xdit`: Enable/disable feature
- `xdit_ulysses_degree`: Ulysses sequence parallelism (splits sequence across GPUs)
- `xdit_ring_degree`: Ring attention parallelism (splits attention computation)
- `xdit_pipefusion_degree`: PipeFusion pipeline parallelism
- `xdit_use_cfg_parallel`: Parallel CFG computation (2x throughput, needs 2x GPUs)
- `xdit_use_torch_compile`: Enable torch.compile for additional JIT speedup
- `xdit_attention_backend`: Backend for attention ('sdpa', 'flash', 'sage')

**Expected speedup:** Near-linear with number of GPUs (e.g., ~7x with 8 GPUs)

**Installation:** `pip install xfuser`

**Note:** The product of all parallel degrees (Ulysses × ring × PipeFusion, and × 2 when `xdit_use_cfg_parallel` is enabled) must equal the number of GPUs.
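The GPU-count constraint can be checked before launching a run. A minimal sketch, assuming (as in xDiT's parallel layout) that CFG parallelism contributes a factor of 2 when enabled; `check_xdit_degrees` is a hypothetical helper, not part of `wan.utils.acceleration` or xFuser.

```python
def check_xdit_degrees(config: dict, num_gpus: int) -> int:
    """Return the GPU count implied by the xDiT settings; raise if it differs from num_gpus.

    Hypothetical helper. Assumes CFG parallelism doubles the required GPU count
    when 'xdit_use_cfg_parallel' is True.
    """
    required = (
        config.get('xdit_ulysses_degree', 1)
        * config.get('xdit_ring_degree', 1)
        * config.get('xdit_pipefusion_degree', 1)
        * (2 if config.get('xdit_use_cfg_parallel', False) else 1)
    )
    if required != num_gpus:
        raise ValueError(
            f"xDiT degrees require {required} GPUs, but {num_gpus} are available"
        )
    return required


# Example: 2-way Ulysses x 2-way ring x CFG parallel -> 8 GPUs
example = {'xdit_ulysses_degree': 2, 'xdit_ring_degree': 2, 'xdit_use_cfg_parallel': True}
print(check_xdit_degrees(example, num_gpus=8))
```

In the notebook, `check_xdit_degrees(CONFIG, torch.cuda.device_count())` would run the same check against the current configuration.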

### Usage Examples:

```python
# Example 1: Cache-DiT only (single GPU, ~1.5x speedup)
CONFIG['enable_cache_dit'] = True
CONFIG['cache_dit_type'] = 'db'
CONFIG['cache_dit_interval'] = 3

# Example 2: xDiT only (multi-GPU, ~4x speedup with 4 GPUs)
# Requires a multi-process launch (e.g. torchrun) and models initialized with use_usp=True
CONFIG['enable_xdit'] = True
CONFIG['xdit_ulysses_degree'] = 4  # For 4 GPUs

# Example 3: Combined (multi-GPU with caching)
CONFIG['enable_cache_dit'] = True
CONFIG['enable_xdit'] = True
CONFIG['xdit_ulysses_degree'] = 4
# Expected: ~5-6x total speedup

# After changing CONFIG, rebuild the acceleration config so the changes take effect
accel_config = AccelerationConfig.from_dict(CONFIG)
```

### Quality vs Speed Tradeoffs:

| `cache_dit_interval` | Approx. speedup | Quality impact |
|----------------------|-----------------|----------------|
| 2 | 1.3x | Minimal |
| 3 | 1.5x | Minor |
| 5 | 1.8x | Noticeable |
| 8+ | 2x+ | Significant |
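To see which steps a given interval actually skips, the step-selection rule used by `should_use_cache` inside `hybrid_vace_edit` (defined later in the notebook) can be replayed on its own. The sketch below is a standalone copy of that rule for illustration: steps before `start_step` and, optionally, the last `end_step` steps are always computed; otherwise one step in every `interval` is recomputed and the rest reuse the cached prediction. `cache_plan` is a hypothetical name.

```python
from typing import List, Optional


def cache_plan(total_steps: int, start_step: int, interval: int,
               end_step: Optional[int] = None) -> List[bool]:
    """Return one flag per step: True = reuse the cached prediction, False = compute.

    Standalone copy of the should_use_cache() rule, for illustration only.
    """
    plan = []
    for step in range(total_steps):
        if step < start_step:
            plan.append(False)  # warm-up: always compute
        elif end_step is not None and total_steps - step <= end_step:
            plan.append(False)  # protected final steps: always compute
        else:
            plan.append((step - start_step) % interval != 0)
    return plan


plan = cache_plan(total_steps=50, start_step=5, interval=3)
computed = plan.count(False)
print(f"{computed} of {len(plan)} steps run the transformer")
```

With the notebook defaults (`cache_dit_start_step=5`, `cache_dit_interval=3`, no end step), 20 of the 50 steps run the transformer; the other 30 reuse the last computed prediction.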


## Sampling Schedule

Configure which model handles which sampling steps.

**Pattern Examples:**
- `LSSSL`: Large-Small-Small-Small-Large (14B → 1.3B → 14B)
- `LSL`: Large-Small-Large
- `L`: Large only (14B baseline)
- `S`: Small only (1.3B baseline)
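The pattern strings are shorthand; the notebook itself expects an explicit list of `(model_name, num_steps)` tuples. A sketch of a hypothetical helper, `pattern_to_schedule`, that expands a pattern by splitting the steps as evenly as possible across letters and merging consecutive identical letters. The even split is an assumption for illustration; the schedule used below assigns steps by hand.

```python
from typing import List, Tuple


def pattern_to_schedule(pattern: str, total_steps: int) -> List[Tuple[str, int]]:
    """Expand e.g. 'LSSSL' into [(model_name, steps), ...] summing to total_steps."""
    names = {'L': '14B', 'S': '1.3B'}
    if not pattern or any(c not in names for c in pattern):
        raise ValueError(f"pattern must use only 'L' and 'S': {pattern!r}")
    base, extra = divmod(total_steps, len(pattern))
    # Give any remainder to the earliest letters so the total is exact.
    per_letter = [base + (1 if i < extra else 0) for i in range(len(pattern))]
    schedule: List[Tuple[str, int]] = []
    for letter, steps in zip(pattern, per_letter):
        name = names[letter]
        if schedule and schedule[-1][0] == name:
            schedule[-1] = (name, schedule[-1][1] + steps)  # merge runs like 'SSS'
        else:
            schedule.append((name, steps))
    return schedule


print(pattern_to_schedule('LSSSL', 50))
```

Here `'LSSSL'` with 50 steps becomes `[('14B', 10), ('1.3B', 30), ('14B', 10)]`, which has the same form as `SAMPLING_SCHEDULE`.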

In [60]:
# Sampling Schedule: List of (model_name, num_steps) tuples
# Total steps must equal CONFIG['sampling_steps']

SAMPLING_SCHEDULE = [
    ('14B', 15),    # Initial structure: 15 steps with 14B
    ('1.3B', 35),   # Refinement: remaining 35 steps with 1.3B
    #('14B', 5),    # Optional final polish with 14B (shorten another segment so the total stays 50)
]

# Verify schedule
total_steps = sum(steps for _, steps in SAMPLING_SCHEDULE)
assert total_steps == CONFIG['sampling_steps'], f"Schedule steps ({total_steps}) != config steps ({CONFIG['sampling_steps']})"

print("Sampling Schedule:")
print("="*60)
for i, (model, steps) in enumerate(SAMPLING_SCHEDULE):
    print(f"  Segment {i+1}: {model:5s} - {steps:2d} steps")
print("="*60)
print(f"Total: {total_steps} steps")

Sampling Schedule:
  Segment 1: 14B   - 15 steps
  Segment 2: 1.3B  - 35 steps
Total: 50 steps
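How much of the denoising trajectory each segment covers depends on `shift`, not just on the step count. A sketch assuming the flow-matching time shift σ' = s·σ / (1 + (s − 1)·σ) applied to evenly spaced σ from 1 toward 0, which is how `get_sampling_sigmas` in Wan's `fm_solvers` builds the schedule; UniPC's `set_timesteps` uses slightly different endpoints, so treat the numbers as approximate. `shifted_sigmas` is a hypothetical name.

```python
from typing import List


def shifted_sigmas(steps: int, shift: float) -> List[float]:
    """Noise level at the start of each step, after the flow-matching time shift."""
    sigmas = [1 - i / steps for i in range(steps)]  # 1.0, 0.98, ..., 0.02 for 50 steps
    return [shift * s / (1 + (shift - 1) * s) for s in sigmas]


sig = shifted_sigmas(steps=50, shift=16.0)
# Noise level when the 1.3B segment takes over after 15 steps of 14B
print(f"sigma at step 15: {sig[15]:.3f}")
print(f"sigma at last step: {sig[-1]:.3f}")
```

Under this approximation, with `shift=16.0` and the 15/35 split above, the 14B model hands over at σ ≈ 0.974, so the latent is still mostly noise when the 1.3B model takes over. Moving the boundary later or lowering `shift` gives more of the trajectory to 14B.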


## AI-Enhanced Prompt Generation (Optional)

Uses Claude to merge source and edit prompts into a detailed final prompt.

In [24]:
import anthropic

def enhance_prompt_with_ai(source_prompt, edit_prompt, api_key):
    """
    Use Claude to intelligently merge source description with edit instructions.
    
    Args:
        source_prompt: Description of the original video content
        edit_prompt: Description of desired changes/edits
        api_key: Anthropic API key
    
    Returns:
        Enhanced final prompt that combines both
    """
    client = anthropic.Anthropic(api_key=api_key)
    
    system_prompt = """You are a video editing prompt expert. Your job is to create a detailed, 
descriptive prompt for a video editing AI model (VACE) by intelligently merging:
1. A SOURCE PROMPT describing the original video
2. An EDIT PROMPT describing desired changes

Your output should:
- Take the source prompt
- Incorporate all the changes/style from the edit prompt
- Maintain coherence and natural language flow
- Do not miss any changes
- Be a SINGLE comprehensive prompt (not instructions or explanations)
- Don't add additional details, just the changes to the source prompt

Return ONLY the final prompt, nothing else."""

    user_message = f"""SOURCE PROMPT (original video): {source_prompt}

EDIT PROMPT (desired changes): {edit_prompt}

Create a single, detailed prompt that combines these:"""
    
    message = client.messages.create(
        model="claude-sonnet-4-20250514",  # Claude Sonnet 4
        max_tokens=500,
        temperature=0.3,
        system=system_prompt,
        messages=[{
            "role": "user",
            "content": user_message
        }]
    )
    
    return message.content[0].text.strip()


# Generate enhanced prompt
if CONFIG['use_ai_prompt_enhancement']:
    print("🤖 Generating AI-enhanced prompt...\n")
    print(f"Source prompt: {CONFIG['source_prompt']}")
    print(f"Edit prompt:   {CONFIG['edit_prompt']}\n")
    
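    # Get the API key: prefer the environment, fall back to sourcing ~/.bash_aliases
    api_key = os.environ.get('ANTHROPIC_API_KEY', '')
    if not api_key:
        import subprocess
        result = subprocess.run(['bash', '-c', 'source ~/.bash_aliases && echo $ANTHROPIC_API_KEY'],
                                capture_output=True, text=True)
        api_key = result.stdout.strip()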
    
    final_prompt = enhance_prompt_with_ai(
        CONFIG['source_prompt'],
        CONFIG['edit_prompt'],
        api_key
    )
    
    print("✓ AI-Enhanced Final Prompt:")
    print("="*80)
    print(final_prompt)
    print("="*80)
    
    # Store in config
    CONFIG['final_prompt'] = final_prompt
else:
    # Use edit_prompt directly
    CONFIG['final_prompt'] = CONFIG['edit_prompt']
    print(f"Using edit prompt directly: {CONFIG['final_prompt']}")


🤖 Generating AI-enhanced prompt...

Source prompt: Two anthropomorphic cats in comfy boxing gear sparring playfully in a cozy living room
Edit prompt:   anime style, playful sparring in a traditional Japanese dojo

✓ AI-Enhanced Final Prompt:
Two anthropomorphic cats in comfy boxing gear sparring playfully in a traditional Japanese dojo, rendered in anime style


## Utility Functions

In [11]:
@dataclass
class SegmentProfile:
    """Profile metrics for a single segment."""
    model: str
    steps: int
    duration: float
    memory_start: float
    memory_end: float
    memory_peak: float


def get_gpu_memory():
    """Get GPU memory in GB."""
    if torch.cuda.is_available():
        return {
            'allocated': torch.cuda.memory_allocated() / 1e9,
            'reserved': torch.cuda.memory_reserved() / 1e9,
            'max_allocated': torch.cuda.max_memory_allocated() / 1e9,
        }
    return {'allocated': 0, 'reserved': 0, 'max_allocated': 0}


def reset_peak_memory():
    """Reset peak memory stats."""
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()


print("✓ Utility functions defined")

✓ Utility functions defined
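The per-segment `SegmentProfile` records can be rolled up into run-level numbers (total time, throughput, peak memory). A minimal sketch; `summarize_profiles` is a hypothetical helper that only relies on the `steps`, `duration` and `memory_peak` attributes, so it is demonstrated here with `SimpleNamespace` stand-ins and made-up values.

```python
from types import SimpleNamespace
from typing import Dict, Iterable


def summarize_profiles(profiles: Iterable) -> Dict[str, float]:
    """Aggregate per-segment profiles (objects with .steps, .duration, .memory_peak).

    Hypothetical helper: durations and steps are summed, memory_peak takes the maximum (GB).
    """
    profiles = list(profiles)
    total_time = sum(p.duration for p in profiles)
    total_steps = sum(p.steps for p in profiles)
    return {
        'total_time_s': total_time,
        'total_steps': total_steps,
        'sec_per_step': total_time / total_steps if total_steps else 0.0,
        'peak_memory_gb': max((p.memory_peak for p in profiles), default=0.0),
    }


# Illustrative numbers only (not measured); SegmentProfile objects work the same way.
demo = [
    SimpleNamespace(model='14B', steps=15, duration=120.0, memory_peak=110.0),
    SimpleNamespace(model='1.3B', steps=35, duration=70.0, memory_peak=95.0),
]
print(summarize_profiles(demo))
```

After the hybrid run, `summarize_profiles(hybrid_profiles)` would give the same roll-up for the measured segments. Because the peak-memory counter is reset at the start of each segment, the maximum over segments is the peak for the whole sampling loop.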


## Select Source Video

In [12]:
# List available videos
videos = sorted(glob.glob(os.path.join(CONFIG['output_dir'], '*.mp4')), 
                key=os.path.getmtime, reverse=True)

print("Available videos:")
print("="*80)
for i, v in enumerate(videos[:10]):
    name = os.path.basename(v)
    size = os.path.getsize(v) / 1024 / 1024
    mtime = datetime.fromtimestamp(os.path.getmtime(v)).strftime('%Y-%m-%d %H:%M')
    print(f"{i}: {name} ({size:.1f}MB) - {mtime}")
print("="*80)

# Auto-select the most recent video, but only if CONFIG['source_video'] is empty
if not CONFIG['source_video'] and videos:
    CONFIG['source_video'] = videos[0]
    print(f"\n✓ Auto-selected: {os.path.basename(CONFIG['source_video'])}")

# Preview
if CONFIG['source_video'] and os.path.exists(CONFIG['source_video']):
    print(f"\nSource video: {CONFIG['source_video']}")
    display(Video(CONFIG['source_video'], embed=True, width=832))
else:
    print("\n⚠️  No source video found!")
    print("   Set CONFIG['source_video'] = '/path/to/video.mp4'")

Available videos:
0: edited_1.3B_ctx0.1_20260114_073245.mp4 (3.1MB) - 2026-01-14 07:32
1: edited_1.3B_ctx0.1_20260114_072801.mp4 (2.7MB) - 2026-01-14 07:28
2: edited_1.3B_ctx0.1_20260114_072328.mp4 (3.2MB) - 2026-01-14 07:23
3: edited_1.3B_ctx0.1_20260114_071741.mp4 (2.8MB) - 2026-01-14 07:17
4: edited_1.3B_ctx0.1_20260114_071228.mp4 (2.9MB) - 2026-01-14 07:12
5: edited_1.3B_ctx0.1_20260114_063534.mp4 (2.7MB) - 2026-01-14 06:35
6: edited_1.3B_ctx0.1_20260114_062949.mp4 (2.5MB) - 2026-01-14 06:29
7: edited_1.3B_ctx0.001_20260114_062502.mp4 (2.0MB) - 2026-01-14 06:25
8: edited_1.3B_ctx0.4_20260114_062107.mp4 (3.0MB) - 2026-01-14 06:21
9: baseline_14B_832x480_20260114_054623.mp4 (2.7MB) - 2026-01-14 05:53

Source video: /workspace/wan2.1/Wan2.1/outputs/baseline_14B_832x480_20260106_183542.mp4


## Load VACE Models (14B + 1.3B)

In [13]:
print("Loading VACE-14B...")
print(f"Checkpoint: {CONFIG['checkpoint_dir_14B']}")

vace_14B = WanVace(
    config=t2v_14B,
    checkpoint_dir=CONFIG['checkpoint_dir_14B'],
    device_id=CONFIG['device_id'],
    rank=0,
    t5_fsdp=False,
    dit_fsdp=False,
    use_usp=False,
    t5_cpu=False
)

mem = get_gpu_memory()
print(f"\n✓ VACE-14B loaded")
print(f"  GPU Memory: {mem['allocated']:.2f} GB")

Loading VACE-14B...
Checkpoint: /workspace/wan2.1/Wan2.1/Wan2.1-VACE-14B



✓ VACE-14B loaded
  GPU Memory: 75.87 GB


In [14]:
print("\nLoading VACE-1.3B...")
print(f"Checkpoint: {CONFIG['checkpoint_dir_1.3B']}")

vace_1_3B = WanVace(
    config=t2v_1_3B,
    checkpoint_dir=CONFIG['checkpoint_dir_1.3B'],
    device_id=CONFIG['device_id'],
    rank=0,
    t5_fsdp=False,
    dit_fsdp=False,
    use_usp=False,
    t5_cpu=False
)

mem = get_gpu_memory()
print(f"\n✓ VACE-1.3B loaded")
print(f"  GPU Memory: {mem['allocated']:.2f} GB")
print(f"\n✓ Both models loaded successfully!")


Loading VACE-1.3B...
Checkpoint: /workspace/wan2.1/Wan2.1/Wan2.1-VACE-1.3B

✓ VACE-1.3B loaded
  GPU Memory: 85.01 GB

✓ Both models loaded successfully!


## Prepare Source Video

Load and preprocess the source video once. The prepared tensors are reused by the hybrid run and both baselines; VAE encoding happens inside each run.

In [15]:
print("Preparing source video...")
print(f"Source: {os.path.basename(CONFIG['source_video'])}")

# Use 14B to prepare (both models share same VAE)
prepared_video, prepared_mask, prepared_refs = vace_14B.prepare_source(
    src_video=[CONFIG['source_video']],
    src_mask=[None],  # No mask supplied: prepare_source fills in an all-ones (full-frame) mask
    src_ref_images=[None],  # No reference images
    num_frames=CONFIG['frame_num'],
    image_size=(CONFIG['width'], CONFIG['height']),
    device=vace_14B.device
)

print(f"\n✓ Source prepared")
print(f"  Video shape: {prepared_video[0].shape}")
print(f"  Mask: {'Yes' if prepared_mask[0] is not None else 'No (full-frame)'}")
print(f"  Reference images: {'Yes' if prepared_refs[0] is not None else 'No'}")

Preparing source video...
Source: baseline_14B_832x480_20260106_183542.mp4

✓ Source prepared
  Video shape: torch.Size([3, 81, 480, 832])
  Mask: Yes
  Reference images: No


## Hybrid Editing Function

Core function that switches between 14B and 1.3B models during sampling.
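Stripped of VACE conditioning, CFG and the UniPC solver, the mechanism is: one noise latent, one shared timestep schedule, and a different denoiser per segment, with the latent handed over unchanged at each boundary. A toy sketch on a scalar "latent" with a plain Euler flow-matching update (x ← x + (σ_next − σ)·v). The two velocity functions made by the hypothetical `make_model` are stand-ins for the 14B and 1.3B networks, not the real models.

```python
from typing import Callable, Dict, List, Tuple

Velocity = Callable[[float, float], float]


def make_model(target: float) -> Velocity:
    """Toy denoiser: returns the flow-matching velocity that points x toward `target`."""
    return lambda x, sigma: (x - target) / sigma


def hybrid_sample(noise: float, sigmas: List[float],
                  models: Dict[str, Velocity],
                  schedule: List[Tuple[str, int]]) -> float:
    """Run Euler steps over `sigmas`, switching the denoiser per schedule segment."""
    assert sum(n for _, n in schedule) == len(sigmas) - 1, "schedule must cover every step"
    x, step = noise, 0
    for name, num_steps in schedule:
        model = models[name]  # switch model; the latent x is carried over as-is
        for _ in range(num_steps):
            sigma, sigma_next = sigmas[step], sigmas[step + 1]
            x = x + (sigma_next - sigma) * model(x, sigma)
            step += 1
    return x


sigmas = [1 - i / 10 for i in range(11)]  # 1.0 -> 0.0 in 10 steps
models = {'14B': make_model(2.0), '1.3B': make_model(1.9)}
result = hybrid_sample(0.5, sigmas, models, [('14B', 3), ('1.3B', 5), ('14B', 2)])
print(result)
```

In this toy each model is an exact denoiser for its own target, so whichever model takes the final step determines the output (2.0 here, 1.9 if the schedule ends on `'1.3B'`). The real models are not exact, so the toy illustrates only the hand-off mechanics, not output quality.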

In [None]:
def hybrid_vace_edit(
    vace_14B,
    vace_1_3B,
    prepared_video,
    prepared_mask,
    prepared_refs,
    prompt,
    schedule,
    config,
    accel_config=None
):
    """
    Hybrid VACE editing with true latent-space model switching.
    
    Supports acceleration via:
    - Cache-DiT: Caches intermediate activations for speedup
    - xDiT: Distributed inference (when models are initialized with use_usp=True)
    
    Args:
        vace_14B: VACE 14B model
        vace_1_3B: VACE 1.3B model
        prepared_video: Prepared source video [list of tensors]
        prepared_mask: Prepared mask [list of tensors or None]
        prepared_refs: Prepared reference images [list of tensors or None]
        prompt: Text prompt
        schedule: List of (model_name, steps) tuples
        config: Configuration dict
        accel_config: AccelerationConfig object (optional)
    
    Returns:
        edited_video: Final edited video tensor
        profiles: List of SegmentProfile objects
    """
    import math
    import torch.cuda.amp as amp
    from contextlib import contextmanager
    from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
    from wan.utils.fm_solvers import (
        FlowDPMSolverMultistepScheduler,
        get_sampling_sigmas,
        retrieve_timesteps
    )
    from tqdm import tqdm
    
    device = vace_14B.device
    models = {'14B': vace_14B, '1.3B': vace_1_3B}
    first_model = models[schedule[0][0]]
    
    # Setup acceleration if provided.
    # The step-level cache used below is implemented in this function; the cache-dit
    # package only needs to be installed, so its (unused) config classes are not imported.
    use_cache_dit = False
    cache_config = None
    if accel_config is not None and accel_config.cache_dit.enabled:
        if check_cache_dit_available():
            cache_config = accel_config.cache_dit
            use_cache_dit = True
            print(f"Cache-DiT enabled: type={cache_config.cache_type}, interval={cache_config.cache_interval}")
        else:
            print("Warning: Cache-DiT requested but not installed. Install with: pip install cache-dit")
    
    print("\n" + "="*80)
    print("HYBRID VACE EDITING - LATENT SPACE SWITCHING")
    print("="*80)
    print(f"Schedule: {' -> '.join([f'{m}({s})' for m, s in schedule])}")
    print(f"Total steps: {sum(s for _, s in schedule)}")
    if use_cache_dit:
        print(f"Acceleration: Cache-DiT ({cache_config.cache_type})")
    
    # Setup text encoding (shared across models)
    n_prompt = config['negative_prompt'] if config['negative_prompt'] else first_model.sample_neg_prompt
    seed = config['seed'] if config['seed'] >= 0 else random.randint(0, sys.maxsize)
    
    print(f"Seed: {seed}")
    print(f"Context scale: {config['context_scale']}")
    
    # Encode text prompt
    if not first_model.t5_cpu:
        first_model.text_encoder.model.to(device)
        context = first_model.text_encoder([prompt], device)
        context_null = first_model.text_encoder([n_prompt], device)
        first_model.text_encoder.model.cpu()
    else:
        context = first_model.text_encoder([prompt], torch.device('cpu'))
        context_null = first_model.text_encoder([n_prompt], torch.device('cpu'))
        context = [t.to(device) for t in context]
        context_null = [t.to(device) for t in context_null]
    
    torch.cuda.empty_cache()
    
    # Encode VACE context (source video/mask/refs)
    print("\nEncoding VACE context...")
    z0 = first_model.vace_encode_frames(prepared_video, prepared_refs, masks=prepared_mask)
    m0 = first_model.vace_encode_masks(prepared_mask, prepared_refs)
    vace_context = first_model.vace_latent(z0, m0)
    
    # Determine latent shape
    target_shape = list(z0[0].shape)
    target_shape[0] = int(target_shape[0] / 2)
    
    seq_len = math.ceil(
        (target_shape[2] * target_shape[3]) / 
        (first_model.patch_size[1] * first_model.patch_size[2]) *
        target_shape[1] / first_model.sp_size
    ) * first_model.sp_size
    
    print(f"Latent shape: {target_shape}")
    print(f"Sequence length: {seq_len}")
    
    # Initialize noise
    seed_g = torch.Generator(device=device)
    seed_g.manual_seed(seed)
    
    noise = torch.randn(
        target_shape[0], target_shape[1], target_shape[2], target_shape[3],
        dtype=torch.float32,
        device=device,
        generator=seed_g
    )
    
    # Setup scheduler
    total_steps = sum(steps for _, steps in schedule)
    num_train_timesteps = first_model.num_train_timesteps
    
    if config['sample_solver'] == 'unipc':
        sample_scheduler = FlowUniPCMultistepScheduler(
            num_train_timesteps=num_train_timesteps,
            shift=1,
            use_dynamic_shifting=False
        )
        sample_scheduler.set_timesteps(total_steps, device=device, shift=config['shift'])
        timesteps = sample_scheduler.timesteps
    elif config['sample_solver'] == 'dpm++':
        sample_scheduler = FlowDPMSolverMultistepScheduler(
            num_train_timesteps=num_train_timesteps,
            shift=1,
            use_dynamic_shifting=False
        )
        sampling_sigmas = get_sampling_sigmas(total_steps, config['shift'])
        timesteps, _ = retrieve_timesteps(sample_scheduler, device=device, sigmas=sampling_sigmas)
    else:
        raise NotImplementedError(f"Unsupported solver: {config['sample_solver']}")
    
    print(f"Total timesteps: {len(timesteps)}")
    
    # Initialize latents
    latents = [noise]
    
    # Prepare model arguments
    arg_c = {'context': context, 'seq_len': seq_len}
    arg_null = {'context': context_null, 'seq_len': seq_len}
    
    # Noop context manager for no_sync
    @contextmanager
    def noop_no_sync():
        yield
    
    # Cache-DiT: Setup caching state
    cached_noise_pred = None
    cache_hit_count = 0
    cache_miss_count = 0
    
    def should_use_cache(global_step):
        """Determine if we should use cached result for this step."""
        if not use_cache_dit:
            return False
        # Don't cache in early steps
        if global_step < cache_config.cache_start_step:
            return False
        # Don't cache in final steps if specified
        if cache_config.cache_end_step is not None:
            remaining = total_steps - global_step
            if remaining <= cache_config.cache_end_step:
                return False
        # Check interval
        steps_since_start = global_step - cache_config.cache_start_step
        return steps_since_start % cache_config.cache_interval != 0
    
    # Run hybrid sampling
    profiles = []
    step_idx = 0
    global_step_idx = 0
    
    print("\n" + "="*80)
    print("SAMPLING")
    print("="*80)
    
    for segment_idx, (model_name, num_steps) in enumerate(schedule):
        model = models[model_name]
        segment_name = f"Segment {segment_idx+1}: {model_name} ({num_steps} steps)"
        
        print(f"\n{segment_name}")
        print("-" * 60)
        
        # Start profiling
        reset_peak_memory()
        mem_start = get_gpu_memory()
        segment_start = time.time()
        
        # Move model to device
        model.model.to(device)
        no_sync = getattr(model.model, 'no_sync', noop_no_sync)
        
        # Get timesteps for this segment
        segment_timesteps = timesteps[step_idx:step_idx + num_steps]
        
        # Sample for this segment
        with amp.autocast(dtype=model.param_dtype), torch.no_grad(), no_sync():
            for i, t in enumerate(tqdm(segment_timesteps, desc=model_name)):
                latent_model_input = latents
                timestep = torch.stack([t])
                
                # Cache-DiT: Check if we should use cached result
                if should_use_cache(global_step_idx) and cached_noise_pred is not None:
                    noise_pred = cached_noise_pred
                    cache_hit_count += 1
                else:
                    # Conditional prediction (with VACE context)
                    noise_pred_cond = model.model(
                        latent_model_input,
                        t=timestep,
                        vace_context=vace_context,
                        vace_context_scale=config['context_scale'],
                        **arg_c
                    )[0]
                    
                    # Unconditional prediction (with VACE context)
                    noise_pred_uncond = model.model(
                        latent_model_input,
                        t=timestep,
                        vace_context=vace_context,
                        vace_context_scale=config['context_scale'],
                        **arg_null
                    )[0]
                    
                    # Classifier-free guidance
                    noise_pred = noise_pred_uncond + config['guide_scale'] * (
                        noise_pred_cond - noise_pred_uncond
                    )
                    
                    # Cache-DiT: Store result for potential reuse
                    if use_cache_dit:
                        cached_noise_pred = noise_pred.clone()
                        cache_miss_count += 1
                
                # Scheduler step
                temp_x0 = sample_scheduler.step(
                    noise_pred.unsqueeze(0),
                    t,
                    latents[0].unsqueeze(0),
                    return_dict=False,
                    generator=seed_g
                )[0]
                
                latents = [temp_x0.squeeze(0)]
                global_step_idx += 1
        
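        # End profiling (synchronize first so queued GPU kernels are counted in this segment)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        segment_duration = time.time() - segment_start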
        mem_end = get_gpu_memory()
        
        profile = SegmentProfile(
            model=model_name,
            steps=num_steps,
            duration=segment_duration,
            memory_start=mem_start['allocated'],
            memory_end=mem_end['allocated'],
            memory_peak=mem_end['max_allocated']
        )
        profiles.append(profile)
        
        print(f"  Duration: {segment_duration:.2f}s ({segment_duration/num_steps:.3f}s/step)")
        print(f"  Memory: {mem_end['allocated']:.2f} GB (peak: {mem_end['max_allocated']:.2f} GB)")
        
        # Offload model if requested
        if config.get('offload_models', False):
            model.model.cpu()
            torch.cuda.empty_cache()
        
        step_idx += num_steps
    
    # Report cache statistics
    if use_cache_dit:
        print(f"\nCache-DiT Statistics:")
        print(f"  Cache hits: {cache_hit_count}")
        print(f"  Cache misses: {cache_miss_count}")
        print(f"  Hit rate: {cache_hit_count / (cache_hit_count + cache_miss_count) * 100:.1f}%")
    
    print("\n" + "="*80)
    print("DECODING LATENTS")
    print("="*80)
    
    # Decode latents using VAE
    decode_start = time.time()
    with torch.no_grad():
        videos = first_model.decode_latent(latents, prepared_refs)
    decode_time = time.time() - decode_start
    print(f"Decode time: {decode_time:.2f}s")
    
    # Cleanup
    del noise, latents, sample_scheduler, vace_context
    torch.cuda.empty_cache()
    
    print("\n" + "="*80)
    print("HYBRID EDITING COMPLETE")
    print("="*80)
    
    return videos[0], profiles


print("Hybrid VACE editing function defined")
print("   Using true latent-space model switching (FlowBending approach)")
print("   With optional Cache-DiT acceleration")

## Run Hybrid Editing

In [None]:
print("Starting hybrid VACE editing...")
print(f"Prompt: {CONFIG['final_prompt'][:100]}...")
print(f"Schedule: {len(SAMPLING_SCHEDULE)} segments")

hybrid_start = time.time()

hybrid_video, hybrid_profiles = hybrid_vace_edit(
    vace_14B=vace_14B,
    vace_1_3B=vace_1_3B,
    prepared_video=prepared_video,
    prepared_mask=prepared_mask,
    prepared_refs=prepared_refs,
    prompt=CONFIG['final_prompt'],
    schedule=SAMPLING_SCHEDULE,
    config=CONFIG,
    accel_config=accel_config  # Pass acceleration config
)

hybrid_total_time = time.time() - hybrid_start

# Save
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
hybrid_path = os.path.join(
    CONFIG['output_dir'],
    f"vace_hybrid_{CONFIG['width']}x{CONFIG['height']}_{timestamp}.mp4"
)
cache_video(hybrid_video[None], save_file=hybrid_path, fps=16, nrow=1,
            normalize=True, value_range=(-1, 1))

print(f"\nHybrid video saved: {hybrid_path}")
print(f"  Total time: {hybrid_total_time:.2f}s")

## Run Baselines

Run 14B-only and 1.3B-only for comparison.

In [41]:
print("Running 14B baseline...")

reset_peak_memory()
baseline_14B_start = time.time()

baseline_14B_video = vace_14B.generate(
    input_prompt=CONFIG['final_prompt'],
    input_frames=prepared_video,
    input_masks=prepared_mask,
    input_ref_images=prepared_refs,
    size=(CONFIG['width'], CONFIG['height']),
    frame_num=CONFIG['frame_num'],
    context_scale=CONFIG['context_scale'],
    shift=CONFIG['shift'],
    sample_solver=CONFIG['sample_solver'],
    sampling_steps=CONFIG['sampling_steps'],
    guide_scale=CONFIG['guide_scale'],
    n_prompt=CONFIG['negative_prompt'],
    seed=CONFIG['seed'],
    offload_model=False
)

baseline_14B_time = time.time() - baseline_14B_start
baseline_14B_mem = get_gpu_memory()

# Save
baseline_14B_path = os.path.join(
    CONFIG['output_dir'],
    f"vace_baseline_14B_{CONFIG['width']}x{CONFIG['height']}_{timestamp}.mp4"
)
cache_video(baseline_14B_video[None], save_file=baseline_14B_path, fps=16, nrow=1,
            normalize=True, value_range=(-1, 1))

print(f"\n✓ 14B baseline complete")
print(f"  Time: {baseline_14B_time:.2f}s")
print(f"  Peak memory: {baseline_14B_mem['max_allocated']:.2f} GB")

Running 14B baseline...


100%|██████████| 50/50 [06:45<00:00,  8.11s/it]



✓ 14B baseline complete
  Time: 409.03s
  Peak memory: 121.08 GB


In [42]:
print("\nRunning 1.3B baseline...")

reset_peak_memory()
baseline_1_3B_start = time.time()

baseline_1_3B_video = vace_1_3B.generate(
    input_prompt=CONFIG['final_prompt'],
    input_frames=prepared_video,
    input_masks=prepared_mask,
    input_ref_images=prepared_refs,
    size=(CONFIG['width'], CONFIG['height']),
    frame_num=CONFIG['frame_num'],
    context_scale=CONFIG['context_scale'],
    shift=CONFIG['shift'],
    sample_solver=CONFIG['sample_solver'],
    sampling_steps=CONFIG['sampling_steps'],
    guide_scale=CONFIG['guide_scale'],
    n_prompt=CONFIG['negative_prompt'],
    seed=CONFIG['seed'],
    offload_model=False
)

baseline_1_3B_time = time.time() - baseline_1_3B_start
baseline_1_3B_mem = get_gpu_memory()

# Save
baseline_1_3B_path = os.path.join(
    CONFIG['output_dir'],
    f"vace_baseline_1.3B_{CONFIG['width']}x{CONFIG['height']}_{timestamp}.mp4"
)
cache_video(baseline_1_3B_video[None], save_file=baseline_1_3B_path, fps=16, nrow=1,
            normalize=True, value_range=(-1, 1))

print(f"\n✓ 1.3B baseline complete")
print(f"  Time: {baseline_1_3B_time:.2f}s")
print(f"  Peak memory: {baseline_1_3B_mem['max_allocated']:.2f} GB")


Running 1.3B baseline...


100%|██████████| 50/50 [01:48<00:00,  2.18s/it]



✓ 1.3B baseline complete
  Time: 112.41s
  Peak memory: 120.11 GB
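The baseline progress bars give per-step sampling rates (about 8.11 s/it for 14B and 2.18 s/it for 1.3B at 832×480, 81 frames), which allows a back-of-envelope estimate of the hybrid sampling time for any schedule. A sketch that ignores text encoding, VAE encode/decode and model-transfer costs, and assumes the per-step cost does not change when models are interleaved; `estimate_sampling_time` is a hypothetical helper.

```python
from typing import Dict, List, Tuple

# Per-step rates read from the tqdm bars of the two baseline runs above.
SEC_PER_STEP = {'14B': 8.11, '1.3B': 2.18}


def estimate_sampling_time(schedule: List[Tuple[str, int]],
                           sec_per_step: Dict[str, float]) -> float:
    """Approximate denoising time in seconds for a (model_name, num_steps) schedule."""
    return sum(sec_per_step[name] * steps for name, steps in schedule)


hybrid_est = estimate_sampling_time([('14B', 15), ('1.3B', 35)], SEC_PER_STEP)
full_14b = estimate_sampling_time([('14B', 50)], SEC_PER_STEP)
print(f"hybrid ~{hybrid_est:.1f}s vs 14B-only ~{full_14b:.1f}s "
      f"(~{full_14b / hybrid_est:.2f}x faster sampling)")
```

For the 15/35 schedule this gives 15 × 8.11 + 35 × 2.18 = 197.95 s versus 405.5 s for 14B alone. The measured end-to-end baseline times (409.03 s and 112.41 s) are only slightly above 50 × s/it, so the sampling loop dominates the total.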


## CLIP Score Evaluation

Compute CLIP scores to measure how well each video matches the text prompt.
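The function below encodes the prompt once, normalizes its embedding, and then encodes sampled frames. The usual reduction is to L2-normalize each frame embedding, take its dot product (cosine similarity) with the normalized prompt embedding, and average over frames. A dependency-free sketch of that reduction on plain vectors, using made-up embedding values. Any final scaling (for example the common ×100) is not shown, because the notebook's own choice falls outside this excerpt. `mean_cosine_score` is a hypothetical name.

```python
import math
from typing import List, Sequence


def normalize(v: Sequence[float]) -> List[float]:
    """Scale a vector to unit L2 norm."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]


def mean_cosine_score(frame_embs: Sequence[Sequence[float]],
                      text_emb: Sequence[float]) -> float:
    """Average cosine similarity between each frame embedding and the text embedding."""
    t = normalize(text_emb)
    sims = [sum(a * b for a, b in zip(normalize(f), t)) for f in frame_embs]
    return sum(sims) / len(sims)


# Made-up 3-d "embeddings" for two frames and one prompt
frames = [[1.0, 0.0, 0.0], [1.0, 1.0, 0.0]]
text = [2.0, 0.0, 0.0]
score = mean_cosine_score(frames, text)
print(score)
```

The first frame points in the same direction as the prompt (similarity 1.0), and the second is at 45° (about 0.707), so the mean is about 0.854.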


In [62]:
# Try to import CLIP
try:
    import clip
    CLIP_AVAILABLE = True
    print("✓ CLIP available")
except ImportError:
    CLIP_AVAILABLE = False
    print("⚠️  CLIP not available (install with: pip install git+https://github.com/openai/CLIP.git)")

def compute_clip_score(video_tensor, text_prompt, clip_model, clip_preprocess, device):
    """
    Compute a CLIP text-video alignment score.
    
    Args:
        video_tensor: Video tensor of shape (C, T, H, W) in range [-1, 1]
        text_prompt: Text description
        clip_model: Loaded CLIP model
        clip_preprocess: CLIP image preprocessing transform
        device: Device to run CLIP on
    
    Returns:
        Mean cosine similarity between the prompt embedding and the
        embeddings of every 4th frame
    """
    from PIL import Image  # Import here to ensure it's available
    
    # Convert tensor to list of PIL images
    # video_tensor is (C, T, H, W) in range [-1, 1]
    # float32 so that .numpy() below also works for bfloat16 tensors
    video_tensor = video_tensor.detach().float().cpu()
    
    # Normalize to [0, 1]
    video_normalized = (video_tensor + 1) / 2
    video_normalized = torch.clamp(video_normalized, 0, 1)
    
    # Extract frames: (T, H, W, C)
    frames = video_normalized.permute(1, 2, 3, 0)  # (T, H, W, C)
    
    # Sample every 4th frame to speed up
    frame_indices = range(0, frames.shape[0], 4)
    
    clip_scores = []
    
    with torch.no_grad():
        # Encode text once
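        # CLIP's text context is 77 tokens; truncate=True avoids a RuntimeError on long prompts
        text_tokens = clip.tokenize([text_prompt], truncate=True).to(device)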
        text_features = clip_model.encode_text(text_tokens)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        
        # Process each frame
        for idx in frame_indices:
            frame = frames[idx]  # (H, W, C)
            
            # Convert to PIL Image
            frame_np = (frame.numpy() * 255).astype(np.uint8)
            pil_image = Image.fromarray(frame_np)
            
            # Preprocess for CLIP
            image_input = clip_preprocess(pil_image).unsqueeze(0).to(device)
            
            # Encode image
            image_features = clip_model.encode_image(image_input)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            
            # Compute similarity
            similarity = (image_features @ text_features.T).item()
            clip_scores.append(similarity)
    
    return np.mean(clip_scores)


if CLIP_AVAILABLE:
    print("Loading CLIP model...")
    clip_device = torch.device(f"cuda:{CONFIG['device_id']}")
    clip_model, clip_preprocess = clip.load("ViT-B/32", device=clip_device)
    clip_model.eval()
    print("✓ CLIP model loaded\n")
    
    print("Computing CLIP scores...")
    print("=" * 60)
    
    # Compute CLIP scores for each video
    print("Computing Hybrid CLIP score...")
    clip_score_hybrid = compute_clip_score(
        hybrid_video, CONFIG['final_prompt'], 
        clip_model, clip_preprocess, clip_device
    )
    print(f"  Hybrid:      {clip_score_hybrid:.4f}")
    
    print("Computing 14B baseline CLIP score...")
    clip_score_14B = compute_clip_score(
        baseline_14B_video, CONFIG['final_prompt'],
        clip_model, clip_preprocess, clip_device
    )
    print(f"  14B:         {clip_score_14B:.4f}")
    
    print("Computing 1.3B baseline CLIP score...")
    clip_score_1_3B = compute_clip_score(
        baseline_1_3B_video, CONFIG['final_prompt'],
        clip_model, clip_preprocess, clip_device
    )
    print(f"  1.3B:        {clip_score_1_3B:.4f}")
    
    print("=" * 60)
    print("\n✓ CLIP scores computed")
    
    # Relative scores
    print("\nRelative to 14B baseline:")
    print(f"  Hybrid vs 14B:  {((clip_score_hybrid/clip_score_14B - 1) * 100):+.2f}%")
    print(f"  1.3B vs 14B:    {((clip_score_1_3B/clip_score_14B - 1) * 100):+.2f}%")
else:
    print("⚠️  CLIP not available, skipping CLIP score computation")
    print("   Install with: pip install git+https://github.com/openai/CLIP.git")
    clip_score_hybrid = None
    clip_score_14B = None
    clip_score_1_3B = None


✓ CLIP available
Loading CLIP model...


✓ CLIP model loaded

Computing CLIP scores...
Computing Hybrid CLIP score...
  Hybrid:      0.4260
Computing 14B baseline CLIP score...
  14B:         0.4221
Computing 1.3B baseline CLIP score...
  1.3B:        0.4346

✓ CLIP scores computed

Relative to 14B baseline:
  Hybrid vs 14B:  +0.93%
  1.3B vs 14B:    +2.94%
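The differences above (+0.93% and +2.94%) come from a single seed, and `compute_clip_score` returns only the mean over sampled frames. A percentile bootstrap over per-frame similarities gives a rough sense of whether such gaps exceed frame-to-frame variation. This is a sketch: `bootstrap_mean_ci` is a hypothetical helper, and it assumes a variant of `compute_clip_score` that returns the `clip_scores` list instead of `np.mean(clip_scores)`.

```python
import numpy as np

def bootstrap_mean_ci(scores, n_boot=10_000, alpha=0.05, seed=0):
    """Percentile-bootstrap CI for the mean of per-frame CLIP similarities (hypothetical helper)."""
    scores = np.asarray(scores, dtype=np.float64)
    rng = np.random.default_rng(seed)
    # Resample frame indices with replacement n_boot times and average each resample
    idx = rng.integers(0, len(scores), size=(n_boot, len(scores)))
    boot_means = scores[idx].mean(axis=1)
    lo, hi = np.quantile(boot_means, [alpha / 2, 1 - alpha / 2])
    return float(scores.mean()), float(lo), float(hi)

# Hypothetical usage, with per-frame scores from the modified compute_clip_score:
# mean, lo, hi = bootstrap_mean_ci(per_frame_scores_hybrid)
```

Frames of one video are strongly correlated, so this interval is likely too narrow, and it says nothing about variation across seeds; rerunning all three models with several seeds is the more reliable check.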


## Profiling Report

In [63]:
print("\n" + "="*80)
print("PROFILING REPORT")
print("="*80)

# Hybrid breakdown
print("\nHybrid Editing Breakdown:")
print("-" * 60)
for i, profile in enumerate(hybrid_profiles):
    print(f"  Segment {i+1} ({profile.model:5s}, {profile.steps:2d} steps):")
    print(f"    Duration: {profile.duration:6.2f}s")
    print(f"    Memory:   {profile.memory_peak:6.2f} GB (peak)")

# Comparison
print("\n" + "="*80)
print("COMPARISON")
print("="*80)

print(f"\n{'Model':<20} {'Time':>10} {'CLIP Score':>12} {'Quality':>10}")
print("-" * 60)
print(f"{'Hybrid (14B+1.3B)':<20} {hybrid_total_time:>9.2f}s", end="")
if clip_score_hybrid is not None:
    quality_hybrid = ((clip_score_hybrid/clip_score_14B - 1) * 100)
    print(f" {clip_score_hybrid:>11.4f} {quality_hybrid:>9.2f}%")
else:
    print(f" {'N/A':>11} {'N/A':>10}")

print(f"{'Baseline 14B':<20} {baseline_14B_time:>9.2f}s", end="")
if clip_score_14B is not None:
    print(f" {clip_score_14B:>11.4f} {'(ref)':>10}")
else:
    print(f" {'N/A':>11} {'N/A':>10}")

print(f"{'Baseline 1.3B':<20} {baseline_1_3B_time:>9.2f}s", end="")
if clip_score_1_3B is not None:
    quality_1_3B = ((clip_score_1_3B/clip_score_14B - 1) * 100)
    print(f" {clip_score_1_3B:>11.4f} {quality_1_3B:>9.2f}%")
else:
    print(f" {'N/A':>11} {'N/A':>10}")

speedup_vs_14B = baseline_14B_time / hybrid_total_time
speedup_vs_1_3B = baseline_1_3B_time / hybrid_total_time

print("\n" + "="*80)
print(f"Speedup vs 14B:     {speedup_vs_14B:.2f}x")
print(f"Speedup vs 1.3B:    {speedup_vs_1_3B:.2f}x")

if clip_score_hybrid is not None and clip_score_14B is not None:
    print(f"\nQuality Analysis:")
    print(f"  - Hybrid maintains {(clip_score_hybrid/clip_score_14B)*100:.1f}% of 14B quality")
    print(f"  - Hybrid is {abs((clip_score_hybrid/clip_score_1_3B - 1)*100):.1f}% {'better' if clip_score_hybrid > clip_score_1_3B else 'worse'} than 1.3B")

# Save report
report = {
    'config': CONFIG,
    'schedule': [(m, s) for m, s in SAMPLING_SCHEDULE],
    'hybrid': {
        'total_time': hybrid_total_time,
        'segments': [asdict(p) for p in hybrid_profiles],
        'clip_score': float(clip_score_hybrid) if clip_score_hybrid is not None else None
    },
    'baseline_14B': {
        'time': baseline_14B_time,
        'peak_memory': baseline_14B_mem['max_allocated'],
        'clip_score': float(clip_score_14B) if clip_score_14B is not None else None
    },
    'baseline_1_3B': {
        'time': baseline_1_3B_time,
        'peak_memory': baseline_1_3B_mem['max_allocated'],
        'clip_score': float(clip_score_1_3B) if clip_score_1_3B is not None else None
    },
    'speedup_vs_14B': speedup_vs_14B,
    'speedup_vs_1_3B': speedup_vs_1_3B,
    'clip_scores': {
        'hybrid': float(clip_score_hybrid) if clip_score_hybrid is not None else None,
        'baseline_14B': float(clip_score_14B) if clip_score_14B is not None else None,
        'baseline_1_3B': float(clip_score_1_3B) if clip_score_1_3B is not None else None,
    }
}

report_path = os.path.join(
    CONFIG['output_dir'],
    f"vace_hybrid_profile_{timestamp}.json"
)
with open(report_path, 'w') as f:
    json.dump(report, f, indent=2)

print(f"\n✓ Profile report saved: {report_path}")


PROFILING REPORT

Hybrid Editing Breakdown:
------------------------------------------------------------
  Segment 1 (14B  , 15 steps):
    Duration: 121.43s
    Memory:   116.10 GB (peak)
  Segment 2 (1.3B , 35 steps):
    Duration:  76.05s
    Memory:   109.99 GB (peak)

COMPARISON

Model                      Time   CLIP Score    Quality
------------------------------------------------------------
Hybrid (14B+1.3B)       213.67s      0.4260      0.93%
Baseline 14B            409.03s      0.4221      (ref)
Baseline 1.3B           112.41s      0.4346      2.94%

Speedup vs 14B:     1.91x
Speedup vs 1.3B:    0.53x

Quality Analysis:
  - Hybrid maintains 100.9% of 14B quality
  - Hybrid is 2.0% worse than 1.3B

✓ Profile report saved: /workspace/wan2.1/Wan2.1/outputs/vace_hybrid_profile_20260114_105010.json
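The report's numbers can be cross-checked by hand. The sketch below copies them from the output above; it assumes (not shown in this section) that each entry in `hybrid_profiles` times only its denoising segment, so the remainder of `hybrid_total_time` would be VAE encoding/decoding and model-switch overhead.

```python
# Figures copied from the profiling report above (single run)
hybrid_total = 213.67
segments = [("14B", 15, 121.43), ("1.3B", 35, 76.05)]  # (model, steps, seconds)
baseline_14B, baseline_1_3B = 409.03, 112.41

denoise_time = sum(sec for _, _, sec in segments)  # time inside the two sampling segments
overhead = hybrid_total - denoise_time              # everything else in the hybrid run
sec_per_step = {model: sec / steps for model, steps, sec in segments}

print(f"Denoising: {denoise_time:.2f}s, other: {overhead:.2f}s")
print(f"s/step: 14B {sec_per_step['14B']:.2f}, 1.3B {sec_per_step['1.3B']:.2f}")
print(f"Hybrid is {baseline_14B / hybrid_total:.2f}x faster than 14B")
print(f"Hybrid takes {hybrid_total / baseline_1_3B:.2f}x as long as 1.3B")
```

The per-step costs (about 8.10 s and 2.17 s) match the tqdm rates of the baselines (8.11 s/it and 2.18 s/it), about 16 s of the hybrid run falls outside the two segments, and "Speedup vs 1.3B: 0.53x" means the hybrid takes roughly 1.9x as long as the 1.3B baseline.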


## Video Comparison

In [64]:
html_content = f'''
<style>
    .comparison {{
        display: grid;
        grid-template-columns: repeat(2, 1fr);
        gap: 20px;
        max-width: 1800px;
        margin: 0 auto;
    }}
    .video-box {{
        text-align: center;
        padding: 15px;
        border: 2px solid #ddd;
        border-radius: 10px;
        background: #f9f9f9;
    }}
    .video-box h3 {{
        margin-top: 0;
        color: #333;
    }}
    .controls {{
        grid-column: 1 / -1;
        text-align: center;
        padding: 20px;
        background: #e8f4f8;
        border-radius: 10px;
        margin-bottom: 20px;
    }}
    .btn {{
        padding: 12px 30px;
        margin: 5px;
        font-size: 16px;
        background: #4CAF50;
        color: white;
        border: none;
        border-radius: 5px;
        cursor: pointer;
    }}
    .btn:hover {{ background: #45a049; }}
    .info {{
        font-size: 12px;
        color: #666;
        margin-top: 10px;
    }}
    .highlight {{
        border-color: #4CAF50;
        background: #e8f5e9;
    }}
</style>

<div class="comparison">
    <div class="controls">
        <h3>🎬 Synchronized Playback</h3>
        <button class="btn" onclick="playAll()">▶️ Play All</button>
        <button class="btn" onclick="pauseAll()">⏸️ Pause All</button>
        <button class="btn" onclick="restartAll()">⏮️ Restart</button>
    </div>
    
    <div class="video-box">
        <h3>📹 Original</h3>
        <video class="sync-video" width="750" controls>
            <source src="{CONFIG['source_video']}" type="video/mp4">
        </video>
        <div class="info">Source video</div>
    </div>
    
    <div class="video-box highlight">
        <h3>🔀 Hybrid (14B + 1.3B)</h3>
        <video class="sync-video" width="750" controls>
            <source src="{hybrid_path}" type="video/mp4">
        </video>
        <div class="info">
            Time: {hybrid_total_time:.1f}s<br>
            Schedule: {' → '.join([f"{m}({s})" for m, s in SAMPLING_SCHEDULE])}<br>
            {'CLIP: ' + f'{clip_score_hybrid:.4f}' if clip_score_hybrid is not None else ''}
        </div>
    </div>
    
    <div class="video-box">
        <h3>🎯 Baseline 14B</h3>
        <video class="sync-video" width="750" controls>
            <source src="{baseline_14B_path}" type="video/mp4">
        </video>
        <div class="info">
            Time: {baseline_14B_time:.1f}s<br>
            {baseline_14B_time/hybrid_total_time:.2f}x slower than hybrid<br>
            {'CLIP: ' + f'{clip_score_14B:.4f} (reference)' if clip_score_14B is not None else ''}
        </div>
    </div>
    
    <div class="video-box">
        <h3>⚡ Baseline 1.3B</h3>
        <video class="sync-video" width="750" controls>
            <source src="{baseline_1_3B_path}" type="video/mp4">
        </video>
        <div class="info">
            Time: {baseline_1_3B_time:.1f}s<br>
            {max(baseline_1_3B_time, hybrid_total_time)/min(baseline_1_3B_time, hybrid_total_time):.2f}x {'faster' if baseline_1_3B_time < hybrid_total_time else 'slower'} than hybrid<br>
            {'CLIP: ' + f'{clip_score_1_3B:.4f}' if clip_score_1_3B is not None else ''}
        </div>
    </div>
</div>

<script>
function playAll() {{
    document.querySelectorAll('.sync-video').forEach(v => v.play());
}}
function pauseAll() {{
    document.querySelectorAll('.sync-video').forEach(v => v.pause());
}}
function restartAll() {{
    document.querySelectorAll('.sync-video').forEach(v => {{
        v.currentTime = 0;
        v.pause();
    }});
}}
</script>
'''

display(HTML(html_content))
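The `<video>` tags above use absolute filesystem paths such as `/workspace/...` as `src`. Depending on the Jupyter frontend these are not URLs the notebook server serves, so the players may come up empty. IPython's `Video` documentation suggests relative paths (e.g. `Video('./video.mp4')`), or `Video(path, embed=True)`, which base64-encodes the file into the notebook at the cost of a much larger `.ipynb`. A sketch of the relative-path route, assuming the notebook runs from a directory that contains `outputs/`; `to_notebook_relative` is a hypothetical helper:

```python
import os

def to_notebook_relative(path, notebook_dir=None):
    """Rewrite an absolute file path relative to the notebook directory (hypothetical helper)."""
    notebook_dir = notebook_dir or os.getcwd()
    rel = os.path.relpath(path, start=notebook_dir)
    # Files above the notebook directory generally cannot be served; embed those instead
    if rel == os.pardir or rel.startswith(os.pardir + os.sep):
        raise ValueError(f"{path} is outside {notebook_dir}")
    return rel.replace(os.sep, "/")  # URLs always use forward slashes
```

Each `src="{...}"` in `html_content` would then wrap its path, e.g. `{to_notebook_relative(hybrid_path)}`.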

## Tips

### Sampling Schedule Strategies:

1. **LSSSL Pattern (Recommended):**
   - `[('14B', 5), ('1.3B', 40), ('14B', 5)]`
   - Aims to balance quality and speed (only 10 of 50 steps run on 14B)
   - 14B sets structure, 1.3B refines, 14B polishes

2. **LSL Pattern (Higher quality):**
   - `[('14B', 10), ('1.3B', 30), ('14B', 10)]`
   - More 14B influence (20 of 50 steps)
   - Potentially better quality, but slower than LSSSL

3. **Heavy 1.3B (Fastest):**
   - `[('14B', 3), ('1.3B', 44), ('14B', 3)]`
   - Maximum speed
   - Slight quality trade-off
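Before running a schedule, its denoising time can be estimated from the per-step rates in the baseline tqdm bars (8.11 s/it for 14B, 2.18 s/it for 1.3B). This sketch assumes a constant per-step cost for each model, 50 total steps as in the runs above, and ignores model-switch and VAE overhead; `estimate_denoise_time` is a hypothetical helper.

```python
# Per-step latencies read off the baseline progress bars above (832x480, this GPU)
SEC_PER_STEP = {"14B": 8.11, "1.3B": 2.18}

def estimate_denoise_time(schedule, sec_per_step=SEC_PER_STEP, total_steps=50):
    """Estimated seconds spent denoising for a [(model, steps), ...] schedule."""
    steps = sum(n for _, n in schedule)
    if steps != total_steps:
        raise ValueError(f"schedule has {steps} steps, expected {total_steps}")
    return sum(sec_per_step[model] * n for model, n in schedule)

for name, schedule in [
    ("LSSSL", [("14B", 5), ("1.3B", 40), ("14B", 5)]),
    ("LSL", [("14B", 10), ("1.3B", 30), ("14B", 10)]),
    ("Heavy 1.3B", [("14B", 3), ("1.3B", 44), ("14B", 3)]),
]:
    print(f"{name:<11} ~{estimate_denoise_time(schedule):.0f}s")
```

This gives about 168 s (LSSSL), 228 s (LSL) and 145 s (Heavy 1.3B). For the two-segment schedule profiled above (14B for 15 steps, then 1.3B for 35) it predicts 197.95 s against 197.48 s measured.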

### Context Scale:
- **0.2-0.3**: Dramatic changes
- **0.4-0.5**: Strong edits (recommended)
- **0.6-0.7**: Subtle refinements
- **0.8-0.9**: Minimal changes

### Acceleration Tips:

**Cache-DiT (Single GPU):**
```python
# Conservative (minimal quality loss)
CONFIG['enable_cache_dit'] = True
CONFIG['cache_dit_interval'] = 2
CONFIG['cache_dit_start_step'] = 5

# Aggressive (faster, some quality loss)
CONFIG['enable_cache_dit'] = True
CONFIG['cache_dit_interval'] = 4
CONFIG['cache_dit_start_step'] = 3
```
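How `wan.utils.acceleration` interprets these two knobs is not shown in this notebook. Assuming the usual caching scheme, where every step before `cache_dit_start_step` runs the full transformer and afterwards only every `cache_dit_interval`-th step recomputes (the others reuse cached features), counting full forward passes gives an upper bound on the transformer speedup of each preset. `full_forward_passes` is a hypothetical helper, not part of that module.

```python
import math

def full_forward_passes(total_steps, interval, start_step):
    """Steps that run the full DiT under the assumed cache schedule (hypothetical)."""
    warmup = min(start_step, total_steps)  # steps 0..start_step-1 always compute
    remaining = total_steps - warmup
    # After warm-up, recompute on every `interval`-th step and reuse the cache otherwise
    return warmup + math.ceil(remaining / interval)

for name, interval, start in [("conservative", 2, 5), ("aggressive", 4, 3)]:
    n = full_forward_passes(50, interval, start)
    print(f"{name}: {n}/50 full passes, DiT speedup at most ~{50 / n:.1f}x")
```

Under this assumption the conservative preset runs 28 of 50 full passes and the aggressive one 15, so end-to-end gains will be smaller once cache bookkeeping and VAE time are included.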

**xDiT (Multi-GPU):**
```python
# 2 GPUs
CONFIG['enable_xdit'] = True
CONFIG['xdit_ulysses_degree'] = 2

# 4 GPUs
CONFIG['enable_xdit'] = True
CONFIG['xdit_ulysses_degree'] = 4

# 8 GPUs with CFG parallel
CONFIG['enable_xdit'] = True
CONFIG['xdit_ulysses_degree'] = 4
CONFIG['xdit_use_cfg_parallel'] = True  # 4 (Ulysses) x 2 (CFG branches) = 8 GPUs total
```
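The GPU counts in these presets follow one rule, taken from the comment in the 8-GPU example: the Ulysses degree, doubled when CFG parallelism splits the conditional and unconditional branches. A small hypothetical check (`required_gpus` is not part of `wan.utils.acceleration`):

```python
def required_gpus(config):
    """GPUs implied by the xDiT settings in CONFIG (hypothetical helper).

    Assumes CFG parallelism runs the conditional and unconditional branches on
    separate groups, so world size = ulysses_degree * 2 when it is enabled.
    """
    if not config.get('enable_xdit', False):
        return 1
    degree = config.get('xdit_ulysses_degree', 1)
    return degree * (2 if config.get('xdit_use_cfg_parallel', False) else 1)
```

Comparing `required_gpus(CONFIG)` with `torch.cuda.device_count()` before loading models catches a preset that asks for more GPUs than the machine has.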

**Combined (Best Performance):**
```python
CONFIG['enable_cache_dit'] = True
CONFIG['cache_dit_interval'] = 3
CONFIG['enable_xdit'] = True
CONFIG['xdit_ulysses_degree'] = 4
# Rough expectation: ~6x with 4 GPUs (not measured in this notebook)
```

### Important Notes:
- This notebook uses **true latent-space model switching** for seamless hybrid sampling
- Cache-DiT works best with `cache_dit_start_step >= 3`, which allows the model to establish structure before caching starts
- xDiT requires models to be initialized with `use_usp=True` for full parallelism
- Monitor CLIP scores to ensure quality is maintained with aggressive acceleration

## Clean Up (Optional)

In [22]:
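# Optional: free GPU memory held by the two VACE pipelines
import gc

cleanup = False
if cleanup:
    del vace_14B, vace_1_3B
    gc.collect()              # drop lingering Python references first
    torch.cuda.empty_cache()  # then return cached blocks to the driver
    print("✓ GPU memory cleared")
else:
    print("Cleanup disabled (set cleanup=True to enable)")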

Cleanup disabled (set cleanup=True to enable)
