# 🎥 WAN 2.1 Video Generation Demo

This notebook demonstrates how to use the integrated WAN 2.1 text-to-video generation system with our batch processing framework.

## Features Demonstrated:
- Loading WAN 2.1 model components
- Configuring the pipeline for optimal performance
- Generating videos from text prompts
- Integrating with our project's batch processing system
- Exporting and displaying results

## Requirements:
- CUDA-capable GPU (recommended)
- WAN model dependencies installed
- At least 16GB VRAM for high-quality generation

In [None]:
# Import Required Libraries
import torch
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from diffusers.utils import export_to_video
from matplotlib import pyplot as plt
from tqdm import tqdm
import shutil
import os
import hashlib
from enum import Enum
import cv2
import numpy as np
import imageio
import json
from pathlib import Path
import time

# Import our project modules
from src.config import ConfigManager, GenerationConfig
from src.generators import WAN13BVideoGenerator
from src.utils import FileManager, ProgressTracker

print("✅ All libraries imported successfully")
print(f"🔥 PyTorch version: {torch.__version__}")
print(f"🎮 CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"🚀 GPU: {torch.cuda.get_device_name()}")
    print(f"💾 VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Set Model Configuration
model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"

# Configuration settings
config = {
    "model_id": model_id,
    "device": "cuda:0" if torch.cuda.is_available() else "cpu",
    "torch_dtype": torch.bfloat16,
    "vae_dtype": torch.float32,
    
    # Video settings - choose resolution based on your VRAM
    "width": 512,      # Use 1280 for 720P (requires more VRAM)
    "height": 512,     # Use 720 for 720P
    "fps": 24,
    "duration": 4.0,   # seconds
    
    # Generation settings
    "num_inference_steps": 50,
    "guidance_scale": 7.5,
    "seed": 42
}

# Flow shift: 5.0 for 720P, 3.0 for 480P (keyed on the shorter side, so 832x480 counts as 480P)
flow_shift = 5.0 if min(config["width"], config["height"]) >= 720 else 3.0
config["flow_shift"] = flow_shift

print("🎛️  Model Configuration:")
for key, value in config.items():
    print(f"   {key}: {value}")
    
print(f"\n📺 Video will be: {config['width']}x{config['height']} at {config['fps']} FPS")
print(f"⏱️  Duration: {config['duration']} seconds ({int(config['fps'] * config['duration'])} frames)")
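
The two derived values in the cell above, the frame count and the flow shift, can be wrapped in small helpers. The sketch below is illustrative and not part of the project: `snap_num_frames` and `pick_flow_shift` are hypothetical names, the temporal factor of 4 assumes the Wan 2.1 VAE compresses time by 4x (so the frame count should have the form 4k + 1), and the 720-pixel threshold simply mirrors the comment in the configuration cell.

```python
def snap_num_frames(fps: float, duration: float, temporal_factor: int = 4) -> int:
    """Return a frame count of the form temporal_factor * k + 1 (assumed Wan constraint)."""
    requested = max(1, int(fps * duration))
    return requested // temporal_factor * temporal_factor + 1


def pick_flow_shift(width: int, height: int) -> float:
    """Use 5.0 when the shorter side is at least 720 pixels, otherwise 3.0."""
    return 5.0 if min(width, height) >= 720 else 3.0


print(snap_num_frames(24, 4.0))                               # 97
print(pick_flow_shift(1280, 720), pick_flow_shift(832, 480))  # 5.0 3.0
```

Keying the flow shift on the shorter side means a 832x480 clip is treated as 480P even though its width exceeds 720.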

In [None]:
# Load Wan Model Components
print("🔄 Loading WAN model components...")

# Load the video VAE (variational autoencoder)
print("   📦 Loading VAE...")
vae = AutoencoderKLWan.from_pretrained(
    config["model_id"], 
    subfolder="vae", 
    torch_dtype=config["vae_dtype"]
)
print("   ✅ VAE loaded successfully")

# Create scheduler with flow prediction
print("   ⚙️  Creating scheduler...")
scheduler = UniPCMultistepScheduler(
    prediction_type='flow_prediction',
    use_flow_sigmas=True,
    num_train_timesteps=1000,
    flow_shift=config["flow_shift"]
)
print(f"   ✅ UniPCMultistepScheduler created (flow_shift={config['flow_shift']})")

print("🎉 Model components loaded successfully!")
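
For some intuition about what `flow_shift` does, the sketch below applies the time-shift formula commonly used by flow-matching schedulers, sigma' = s * sigma / (1 + (s - 1) * sigma), to a few noise levels. It assumes the scheduler applies the shift in this form; `shift_sigma` is a hypothetical helper, not a diffusers function.

```python
def shift_sigma(sigma: float, shift: float) -> float:
    """Warp a noise level in [0, 1]; a shift above 1 pushes values toward 1 (more noise)."""
    return shift * sigma / (1.0 + (shift - 1.0) * sigma)


for shift in (1.0, 3.0, 5.0):
    warped = [round(shift_sigma(s, shift), 3) for s in (0.25, 0.5, 0.75)]
    print(f"shift={shift}: {warped}")
```

Under this formula a larger shift keeps the sampler at high noise levels for more of its steps, which is the usual motivation for raising it at higher resolutions.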

In [None]:
# Configure Pipeline Settings
print("🚀 Initializing WAN pipeline...")

# Load the complete pipeline
pipe = WanPipeline.from_pretrained(
    config["model_id"],
    vae=vae,
    torch_dtype=config["torch_dtype"]
)

# Set our custom scheduler
pipe.scheduler = scheduler

# Move to GPU if available
if torch.cuda.is_available():
    print(f"   🎮 Moving pipeline to {config['device']}")
    pipe.to(config["device"])
    
    # Enable memory efficient attention if available
    try:
        pipe.enable_xformers_memory_efficient_attention()
        print("   ⚡ XFormers memory efficient attention enabled")
    except Exception:
        print("   ⚠️  XFormers not available, using default attention")
        
    # Enable CPU offload for large models if needed
    # pipe.enable_sequential_cpu_offload()  # Uncomment if you have limited VRAM
else:
    print("   ⚠️  CUDA not available, using CPU (very slow)")

print("✅ Pipeline configured and ready for generation!")

# Print memory usage if on CUDA
if torch.cuda.is_available():
    memory_allocated = torch.cuda.memory_allocated() / 1e9
    memory_reserved = torch.cuda.memory_reserved() / 1e9
    print(f"💾 GPU Memory - Allocated: {memory_allocated:.1f}GB, Reserved: {memory_reserved:.1f}GB")

In [None]:
# Generate Video from Text Prompt
prompt = "a romantic kiss between two people at sunset"
print(f"🎬 Generating video for prompt: '{prompt}'")

# Create generator for reproducible results
generator = torch.Generator(device=config["device"]).manual_seed(config["seed"])

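# Calculate number of frames. WanPipeline expects (num_frames - 1) to be divisible by 4
# and otherwise rounds with a warning, so snap to the form 4k + 1 up front.
num_frames = int(config["fps"] * config["duration"]) // 4 * 4 + 1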
print(f"📊 Generation parameters:")
print(f"   Resolution: {config['width']}x{config['height']}")
print(f"   Frames: {num_frames} ({config['duration']}s at {config['fps']} FPS)")
print(f"   Steps: {config['num_inference_steps']}")
print(f"   Guidance: {config['guidance_scale']}")
print(f"   Seed: {config['seed']}")

# Start generation
start_time = time.time()
print(f"\n⏳ Starting generation... (this may take several minutes)")

with torch.no_grad():
    video_frames = pipe(
        prompt=prompt,
        width=config["width"],
        height=config["height"],
        num_frames=num_frames,
        num_inference_steps=config["num_inference_steps"],
        guidance_scale=config["guidance_scale"],
        generator=generator
    ).frames[0]

generation_time = time.time() - start_time
print(f"✅ Video generated successfully!")
print(f"⏱️  Generation time: {generation_time:.1f} seconds")
print(f"📦 Generated {len(video_frames)} frames")

# Clear some GPU memory
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    memory_allocated = torch.cuda.memory_allocated() / 1e9
    print(f"💾 GPU Memory after generation: {memory_allocated:.1f}GB")

In [None]:
# Save and Export Video Output
print("💾 Saving video to file...")

# Create a timestamped output directory
timestamp = time.strftime("%Y%m%d_%H%M%S")
output_dir = Path(f"outputs/notebook_generation_{timestamp}")
output_dir.mkdir(parents=True, exist_ok=True)

# Export video
video_path = output_dir / "generated_video.mp4"
export_to_video(video_frames, str(video_path), fps=config["fps"])

print(f"🎥 Video saved to: {video_path}")
print(f"📏 File size: {video_path.stat().st_size / 1e6:.1f} MB")

# Save generation metadata
metadata = {
    "prompt": prompt,
    "timestamp": timestamp,
    "generation_time": generation_time,
    "config": config,
    "model_info": {
        "model_id": config["model_id"],
        "pipeline_type": "WanPipeline",
        "scheduler": "UniPCMultistepScheduler"
    },
    "video_info": {
        "path": str(video_path),
        "frames": len(video_frames),
        "duration": config["duration"],
        "fps": config["fps"],
        "resolution": f"{config['width']}x{config['height']}"
    }
}

metadata_path = output_dir / "metadata.json"
with open(metadata_path, 'w') as f:
    json.dump(metadata, f, indent=2, default=str)

print(f"📋 Metadata saved to: {metadata_path}")

# Also save the prompt for easy reference
prompt_path = output_dir / "prompt.txt"
with open(prompt_path, 'w') as f:
    f.write(prompt)

print(f"📝 Prompt saved to: {prompt_path}")
print(f"\n📁 Complete output directory: {output_dir}")
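
The `default=str` argument in the metadata cell matters because `config` holds values such as `torch.bfloat16` that `json` cannot encode on its own. The sketch below shows the round trip with a stand-in object (`FakeDtype` is hypothetical and only mimics the string form of a torch dtype), so it runs without torch installed.

```python
import json
import tempfile
from pathlib import Path


class FakeDtype:
    """Stand-in for a value like torch.bfloat16 that json cannot encode natively."""

    def __str__(self) -> str:
        return "torch.bfloat16"


metadata = {"prompt": "a test prompt", "config": {"torch_dtype": FakeDtype(), "seed": 42}}

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "metadata.json"
    # default=str is called only for objects json cannot serialize itself
    path.write_text(json.dumps(metadata, indent=2, default=str), encoding="utf-8")
    loaded = json.loads(path.read_text(encoding="utf-8"))

print(loaded["config"])
```

Note that the dtype comes back as a plain string, so code that reloads the metadata has to map it back to a torch dtype itself.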

In [None]:
# Display Generated Video
print("🖼️  Displaying video frames...")

# Pick up to 8 evenly spaced frames, keeping their indices so the titles match
num_display_frames = min(8, len(video_frames))
step = max(1, len(video_frames) // max(1, num_display_frames))
frame_indices = list(range(0, len(video_frames), step))[:num_display_frames]
frames_to_show = [video_frames[i] for i in frame_indices]

# Setup matplotlib figure
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
axes = axes.flatten()

for ax, idx, frame in zip(axes, frame_indices, frames_to_show):
    # Convert frame to numpy array if needed
    ax.imshow(np.array(frame))
    ax.set_title(f"Frame {idx}")
    ax.axis('off')

# Hide unused subplots
for i in range(len(frames_to_show), len(axes)):
    axes[i].axis('off')

plt.suptitle(f"Generated Video Frames - '{prompt}'", fontsize=14)
plt.tight_layout()
plt.show()

print(f"📺 Showing {len(frames_to_show)} sample frames from the {len(video_frames)}-frame video")
print(f"🎞️  Full video saved as: {video_path.name}")

# Display video info
print(f"\n📊 Video Statistics:")
print(f"   • Resolution: {config['width']} x {config['height']} pixels")
print(f"   • Duration: {config['duration']} seconds")
print(f"   • Frame Rate: {config['fps']} FPS") 
print(f"   • Total Frames: {len(video_frames)}")
print(f"   • Generation Time: {generation_time:.1f} seconds")
print(f"   • Time per Frame: {generation_time/len(video_frames):.2f} seconds")

# Note about playing the video
print(f"\n💡 To play the video:")
print(f"   • Open: {video_path}")
print(f"   • Or embed it in the notebook: IPython.display.Video('{video_path}', embed=True)")
print(f"   • Or any video player that supports MP4")
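
The frame grid in the cell above samples with a fixed stride, which can stop short of the final frames of the clip. As an alternative, the sketch below spreads the indices across the whole clip, including the first and last frame. `sample_indices` is a hypothetical helper, not part of the project.

```python
def sample_indices(total: int, count: int) -> list[int]:
    """Return up to `count` evenly spaced indices in [0, total - 1], endpoints included."""
    if total <= 0 or count <= 0:
        return []
    if count >= total:
        return list(range(total))
    if count == 1:
        return [0]
    return [i * (total - 1) // (count - 1) for i in range(count)]


print(sample_indices(97, 8))  # [0, 13, 27, 41, 54, 68, 82, 96]
```

The returned indices can be used both to select frames and to label each subplot, which keeps the titles in sync with what is shown.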

In [None]:
# Integration with Batch Processing System
print("🔄 Demonstrating integration with our batch processing system...")

# Load our project configuration
config_manager = ConfigManager()

try:
    # Load the default configuration
    project_config = config_manager.load_config("configs/default.yaml")
    print("✅ Project configuration loaded successfully")
    
    # Test the WAN generator from our system
    wan_generator = WAN13BVideoGenerator(project_config)
    
    if wan_generator.is_real_model():
        print("🎉 Real WAN model is being used by our system!")
    else:
        print("⚠️  System is using mock implementation (WAN dependencies not available)")
    
    print(f"🔧 System configuration:")
    print(f"   Model ID: {project_config.model_settings.model_id}")
    print(f"   Sampler: {project_config.model_settings.sampler}")
    print(f"   Resolution: {project_config.video_settings.width}x{project_config.video_settings.height}")
    print(f"   Videos per variation: {project_config.videos_per_variation}")
    
except Exception as e:
    print(f"⚠️  Error loading project config: {e}")

print(f"\n💡 To use our batch processing system:")
print(f'   python main.py --template "a [romantic|passionate] kiss" --videos-per-variation 2')
print(f'   python main.py --config configs/high_quality.yaml --template "your prompt template"')
print(f'   python main.py --preview --template "test [this|that] template"')

In [None]:
# Demonstrate Prompt Variations (Quick Example)
print("🎭 Quick demonstration of prompt variations...")

# Example prompts with variations
prompt_templates = [
    "a [gentle|passionate] kiss between [two people|a couple]",
    "a [cute|playful] [cat|dog] in a [garden|house]",
    "a [dramatic|serene] [sunset|sunrise] over [mountains|ocean]"
]

from src.prompts import PromptTemplate

for i, template in enumerate(prompt_templates):
    print(f"\n📝 Template {i+1}: {template}")
    
    prompt_template = PromptTemplate(template)
    variations = prompt_template.generate_variations()
    
    print(f"   Generates {len(variations)} variations:")
    for j, variation in enumerate(variations[:4]):  # Show first 4
        print(f"     {j+1}. {variation.text}")
    
    if len(variations) > 4:
        print(f"     ... and {len(variations) - 4} more")

print(f"\n🚀 To generate videos for all variations:")
print(f'   python main.py --template "a [gentle|passionate] kiss" --videos-per-variation 2')
kiss_variations = PromptTemplate("a [gentle|passionate] kiss").generate_variations()
print(f"   This would create {len(kiss_variations)} × 2 = {len(kiss_variations) * 2} videos total!")
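
The internals of `PromptTemplate` are not shown in this notebook. As a rough illustration of how bracketed `[a|b]` groups can be expanded into every combination, here is a minimal stand-in built on `itertools.product`. `expand_template` is a hypothetical function and may differ from what `src.prompts` actually does (for example, it returns plain strings rather than objects with a `.text` attribute).

```python
import itertools
import re

# Matches one [option|option|...] group without nested brackets
_GROUP = re.compile(r"\[([^\[\]]*)\]")


def expand_template(template: str) -> list[str]:
    """Expand every [a|b|...] group into the Cartesian product of its options."""
    # With one capture group, re.split alternates: literal, group, literal, group, ...
    parts = _GROUP.split(template)
    literals = parts[0::2]
    choices = [group.split("|") for group in parts[1::2]]
    results = []
    for combo in itertools.product(*choices):
        pieces = [literals[0]]
        for option, literal in zip(combo, literals[1:]):
            pieces.append(option)
            pieces.append(literal)
        results.append("".join(pieces))
    return results


for text in expand_template("a [gentle|passionate] kiss between [two people|a couple]"):
    print(text)
```

The number of variations is the product of the group sizes, so the three templates in the cell above would yield 4, 8, and 8 prompts under this scheme.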

## 🎉 Demo Complete!

### What we accomplished:
- ✅ **Loaded WAN 2.1 model** with proper VAE and scheduler configuration
- ✅ **Generated a video** from text prompt using the real model
- ✅ **Saved video and metadata** with proper organization
- ✅ **Integrated with our batch system** for scalable processing
- ✅ **Demonstrated prompt variations** for systematic content generation

### 🚀 Next Steps:

1. **Install Dependencies:**
   ```bash
   pip install -r requirements.txt
   ```

2. **Run Batch Generation:**
   ```bash
   # Preview what will be generated
   python main.py --preview --template "a [romantic|dramatic] scene with [two people|a couple]"
   
   # Generate with default settings
   python main.py --template "your template here"
   
   # High-quality generation
   python main.py --config configs/high_quality.yaml --template "your template"
   ```

3. **Customize Settings:**
   - Edit `configs/default.yaml` for your preferred settings
   - Adjust resolution based on your VRAM (512x512 vs 1280x720)
   - Modify generation parameters (steps, guidance, etc.)

### 💡 Tips for Best Results:

- **Memory Management:** Use 512x512 for lower VRAM, 720P for high-end GPUs
- **Quality vs Speed:** More inference steps generally improve quality, but generation time grows roughly linearly with the step count
- **Batch Processing:** Use our system for generating many variations efficiently
- **Organization:** All outputs are automatically organized with metadata
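
As a back-of-the-envelope aid for the memory tip above, the sketch below estimates the weights-only footprint of the transformer from its parameter count. The parameter counts are read off the model names (14B for the checkpoint used here, 1.3B for the smaller Wan 2.1 variant); the text encoder, VAE, and activations are not included, so real usage will be higher. `weights_gb` is a hypothetical helper.

```python
def weights_gb(num_params: float, bytes_per_param: int) -> float:
    """Weights-only footprint in GB (1e9 bytes); activations and other components excluded."""
    return num_params * bytes_per_param / 1e9


for name, params in (("14B", 14e9), ("1.3B", 1.3e9)):
    print(f"{name}: ~{weights_gb(params, 2):.1f} GB in bfloat16")
```

If the estimate exceeds the available VRAM, the CPU offload option that is commented out in the pipeline cell is the natural thing to try first.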

### 🔧 System Integration:
The WAN model is now fully integrated into our batch processing system. You can use either:
- **This notebook** for interactive single-video generation
- **Command line tool** for automated batch processing with prompt variations

Both approaches use the same underlying WAN model implementation!