In [None]:
# Install required dependencies
!pip install torch torchvision diffusers transformers accelerate imageio imageio-ffmpeg ipywidgets -q

In [None]:
import os
import pathlib

# Directory for model downloads. Defaults to a persistent path inside a
# Lightning AI Studio teamspace; set the MODEL_CACHE_DIR environment variable
# to use another location (e.g. when running locally or on Colab).
cache_dir = os.environ.get("MODEL_CACHE_DIR") or "/teamspace/studios/this_studio/models"
hub_cache_dir = os.path.join(cache_dir, "hub")

# parents=True also creates cache_dir itself, so one mkdir covers both levels.
pathlib.Path(hub_cache_dir).mkdir(parents=True, exist_ok=True)

# Set environment variables BEFORE importing huggingface libraries: they read
# these values when first imported, so re-running this cell in a kernel that
# has already imported diffusers/transformers will not move the cache.
os.environ["HF_HOME"] = cache_dir
os.environ["HF_HUB_CACHE"] = hub_cache_dir
# Older name for HF_HUB_CACHE, kept for older huggingface_hub releases.
os.environ["HUGGINGFACE_HUB_CACHE"] = hub_cache_dir

import torch
import ipywidgets as widgets
from IPython.display import display, HTML, Video
from diffusers import WanPipeline
from diffusers.utils import export_to_video
import gc
from datetime import datetime

# Sanity-check the runtime: framework version, GPU visibility, and where weights are cached.
print(
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
    f"Model cache directory: {cache_dir}",
    sep="\n",
)

In [None]:
# ---------------------------------------------------------------------------
# Configuration constants
# ---------------------------------------------------------------------------

# Hugging Face repo id of the Diffusers-format Wan 2.1 text-to-video 1.3B checkpoint.
MODEL_ID = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"

# Generated clips are saved here, relative to the current working directory.
OUTPUT_DIR = "generated_videos"
pathlib.Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

# Style name -> descriptive phrase that construct_prompt() appends to the prompt.
STYLE_PROMPTS: dict[str, str] = {
    "cinematic": "cinematic lighting, film grain, dramatic shadows, professional cinematography",
    "cartoon":   "cartoon style, animated, vibrant colors, hand-drawn aesthetic",
    "realistic": "photorealistic, 4K quality, natural lighting, detailed textures",
    "anime":     "anime style, Japanese animation, cel shading, expressive",
    "vintage":   "vintage film, retro aesthetic, faded colors, nostalgic",
}

# Camera-angle name -> descriptive phrase that construct_prompt() appends to the prompt.
CAMERA_PROMPTS: dict[str, str] = {
    "front view": "front view, facing camera",
    "side view":  "side view, profile shot",
    "top-down":   "top-down view, bird's eye perspective, overhead shot",
    "low angle":  "low angle shot, looking up, dramatic perspective",
    "close-up":   "close-up shot, detailed focus",
    "wide shot":  "wide shot, establishing shot, full scene view",
}

print("Configuration loaded successfully!")

In [None]:
# Global pipeline cache: filled by the first successful load_pipeline() call so
# the model is loaded only once per kernel session.
pipe = None

def load_pipeline():
    """Load the Wan 2.1 T2V 1.3B pipeline once and cache it in the global ``pipe``.

    Weights are loaded in bfloat16. When CUDA is available, model CPU offload
    is enabled to reduce GPU memory use. Without CUDA, offload is skipped
    (diffusers' model offload targets an accelerator device). The pipeline then
    stays on the CPU, and generation will be very slow.

    The global ``pipe`` is only set after setup has fully succeeded, so a
    failure part-way through does not leave a half-configured pipeline cached.

    Returns:
        The cached ``WanPipeline`` instance.
    """
    global pipe
    if pipe is not None:
        return pipe
    print("Loading Wan 2.1 T2V-1.3B model...")

    # Load the full pipeline directly (includes VAE, transformer, scheduler, etc.)
    loaded = WanPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16
    )

    if torch.cuda.is_available():
        # Memory optimization: move sub-models to the GPU only while they run (uses accelerate).
        loaded.enable_model_cpu_offload()
    else:
        print("Warning: CUDA not available - running on CPU without offloading; generation will be very slow.")

    pipe = loaded
    print("Model loaded successfully!")
    return pipe

In [None]:
# Load the pipeline now so the first generation is not delayed. On the first
# run this downloads the model weights (estimated ~20GB) into the cache
# configured via HF_HOME; later runs load from that cache.
print("Downloading and loading model... This may take a while on first run.")
pipe = load_pipeline()  # load_pipeline() also stores the instance in the global `pipe`
print("Model ready for video generation!")

In [None]:
def construct_prompt(subject_action: str, style: str, camera_angle: str) -> str:
    """Build the full text prompt sent to the video model.

    The user's subject/action text is combined with the descriptive phrases
    for the chosen camera angle and style, followed by generic quality tags.

    Args:
        subject_action: Free-form description of the subject and what it does.
            Internal whitespace/line breaks are collapsed to single spaces and
            trailing spaces or ".,;" are removed, so the comma-joined prompt
            stays clean (e.g. "A cat playing.\\n" -> "A cat playing").
        style: Key into STYLE_PROMPTS; unknown values are used verbatim.
        camera_angle: Key into CAMERA_PROMPTS; unknown values are used verbatim.

    Returns:
        A comma-separated prompt. Empty segments are omitted so no leading or
        doubled commas appear.
    """
    # Textarea input may contain newlines or end with a sentence period.
    subject = " ".join(subject_action.split()).rstrip(" .,;")
    style_desc = STYLE_PROMPTS.get(style, style)
    camera_desc = CAMERA_PROMPTS.get(camera_angle, camera_angle)
    parts = [subject, camera_desc, style_desc, "high quality", "smooth motion"]
    return ", ".join(part for part in parts if part)

def duration_to_frames(duration_seconds: float, fps: int = 16,
                       min_frames: int = 17, max_frames: int = 161) -> int:
    """Convert a clip duration into a frame count accepted by WanPipeline.

    WanPipeline expects ``num_frames`` of the form 4k + 1. The duration is
    converted to frames at ``fps`` (truncated to an int). It is then snapped to
    the nearest 4k + 1 value, choosing the smaller one on a tie. Finally it is
    clamped to ``[min_frames, max_frames]``. The defaults (17..161 frames)
    span roughly 1-10 s at 16 fps.

    Every 4k + 1 count is reachable. For example, 4.5 s maps to 73 frames
    instead of collapsing onto 4.0 s (65 frames) as a sparse lookup table would.

    Args:
        duration_seconds: Requested clip length in seconds; non-positive values
            clamp to the lower bound.
        fps: Playback frame rate used to convert seconds into frames.
        min_frames: Lower bound; rounded up to the next 4k + 1 value.
        max_frames: Upper bound; rounded down to the previous 4k + 1 value.

    Returns:
        A frame count n with (n - 1) % 4 == 0, within the snapped bounds.

    Raises:
        ValueError: If no 4k + 1 value lies within [min_frames, max_frames].
    """
    lower = min_frames + (1 - min_frames) % 4
    upper = max_frames - (max_frames - 1) % 4
    if lower > upper:
        raise ValueError(f"No valid 4k+1 frame count in [{min_frames}, {max_frames}]")

    target_frames = int(duration_seconds * fps)
    # Largest 4k+1 value <= target; step up only when the next one is strictly closer.
    remainder = (target_frames - 1) % 4
    frames = target_frames - remainder
    if remainder > 2:
        frames += 4
    # NOTE(review): Wan 2.1 is usually run at 81 frames; counts far above that
    # may degrade quality - confirm before relying on long clips.
    return max(lower, min(frames, upper))

def generate_video(subject_action: str, style: str, camera_angle: str,
                   duration: float, seed: int = -1, *,
                   fps: int = 16,
                   num_inference_steps: int = 30,
                   guidance_scale: float = 5.0,
                   height: int = 480,
                   width: int = 832) -> tuple:
    """Generate a video from text inputs and save it as an MP4 in OUTPUT_DIR.

    Args:
        subject_action: Subject + action text; must contain non-whitespace.
        style: Style key (see STYLE_PROMPTS) or free-form style text.
        camera_angle: Camera key (see CAMERA_PROMPTS) or free-form text.
        duration: Requested length in seconds; converted by duration_to_frames().
        seed: RNG seed; -1 (or None) draws a random seed in [0, 2**32).
        fps: Frame rate used for the frame count and for the exported MP4.
        num_inference_steps: Denoising steps passed to the pipeline.
        guidance_scale: Classifier-free guidance scale passed to the pipeline.
        height: Output height in pixels passed to the pipeline.
        width: Output width in pixels passed to the pipeline.

    Returns:
        (output_path, prompt, seed): the saved file, the full prompt sent to
        the model, and the seed actually used (so a result can be reproduced).

    Raises:
        ValueError: If ``subject_action`` is empty or whitespace-only. This is
            checked before the model is loaded or any generation starts.
    """
    if not subject_action or not subject_action.strip():
        raise ValueError("subject_action must not be empty")

    pipe = load_pipeline()

    prompt = construct_prompt(subject_action, style, camera_angle)
    print(f"Generated prompt: {prompt}")

    num_frames = duration_to_frames(duration, fps)
    print(f"Generating {num_frames} frames (~{num_frames / fps:.1f}s at {fps}fps)...")

    if seed is None or seed == -1:
        seed = torch.randint(0, 2**32, (1,)).item()
    # CPU generator: diffusers recommends this for seeds that reproduce across devices.
    generator = torch.Generator(device="cpu").manual_seed(seed)
    print(f"Using seed: {seed}")

    negative_prompt = "blurry, low quality, distorted, deformed, static, no motion"

    try:
        output = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height, width=width,
            num_frames=num_frames,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator
        )

        # The seed in the filename makes each saved clip traceable/reproducible.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_path = os.path.join(OUTPUT_DIR, f"video_{timestamp}_seed{seed}.mp4")
        export_to_video(output.frames[0], output_path, fps=fps)
    finally:
        # Free memory even when generation fails (e.g. CUDA out-of-memory),
        # so the next attempt in this kernel starts from a clean state.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(f"Video saved to: {output_path}")
    return output_path, prompt, seed

In [None]:
# Create Interactive UI Widgets.
# Dropdown options are taken from the STYLE_PROMPTS / CAMERA_PROMPTS mappings in
# the configuration cell, so the UI stays in sync when entries are added or renamed.
prompt_input = widgets.Textarea(
    value="A cat playing with a ball in a sunny garden",
    placeholder="Enter subject + action (e.g., 'A dog running in a park')",
    description="Prompt:",
    layout=widgets.Layout(width="100%", height="80px")
)

style_dropdown = widgets.Dropdown(
    options=list(STYLE_PROMPTS),
    value="cinematic",
    description="Style:"
)

camera_dropdown = widgets.Dropdown(
    options=list(CAMERA_PROMPTS),
    value="front view",
    description="Camera:"
)

# Seconds of video requested; generate_video() converts this to a frame count.
duration_slider = widgets.FloatSlider(
    value=2.0, min=1.0, max=10.0, step=0.5,
    description="Duration (s):",
    continuous_update=False
)

seed_input = widgets.IntText(value=-1, description="Seed:", tooltip="-1 for random seed")

generate_btn = widgets.Button(description="Generate Video", button_style="success", icon="play")
output_area = widgets.Output()

def on_generate_click(b):
    """Run generate_video() with the current widget values and show the result.

    The Generate button is disabled while a generation runs so repeated clicks
    do not start duplicate long-running jobs. It is re-enabled in ``finally``,
    even if generation fails. Errors are shown with a full traceback in the
    output area rather than only the exception message.

    Args:
        b: The clicked Button, supplied by ipywidgets (unused).
    """
    import traceback  # local import keeps this cell self-contained

    generate_btn.disabled = True
    with output_area:
        output_area.clear_output()
        print("Starting Text-to-Video generation...")
        try:
            video_path, final_prompt, used_seed = generate_video(
                prompt_input.value,
                style_dropdown.value,
                camera_dropdown.value,
                duration_slider.value,
                seed_input.value
            )
            print("\n" + "=" * 50)
            print("Generation Complete!")
            print(f"Prompt: {final_prompt}")
            print(f"Seed used: {used_seed}")
            display(Video(video_path, embed=True, width=640))
        except Exception as e:
            print(f"Error: {e}")
            traceback.print_exc()
        finally:
            generate_btn.disabled = False

generate_btn.on_click(on_generate_click)

# Display UI.
# The header emoji is written as an HTML entity (&#127916; = U+1F3AC, clapper
# board) so it renders correctly regardless of how this file is encoded or
# re-saved; a literal emoji here had been mangled into mojibake.
ui = widgets.VBox([
    widgets.HTML("<h3>&#127916; Text-to-Video Generator</h3>"),
    prompt_input,
    widgets.HBox([style_dropdown, camera_dropdown]),
    widgets.HBox([duration_slider, seed_input]),
    generate_btn,
    output_area
])
display(ui)

In [None]:
# Example: Direct function call for quick testing
# Uncomment to run directly without UI

# video_path, prompt, seed = generate_video(
#     subject_action="A bird flying over mountains",
#     style="cinematic",
#     camera_angle="wide shot",
#     duration=2.0,
#     seed=42
# )
# display(Video(video_path, embed=True, width=640))

## Qualitative Evaluation

After generating a video, rate it from 1 to 5 on each of the following criteria:

1. **Prompt Adherence**: Does the video show the described subject and action?
2. **Style Accuracy**: Is the requested style (cinematic, cartoon, etc.) visible?
3. **Camera Angle**: Does the perspective match the requested camera angle?
4. **Motion Quality**: Is the motion smooth and natural?
5. **Visual Quality**: Is the video clear and visually coherent overall?

In [None]:
def evaluate_video(video_path: str, original_inputs: dict) -> dict:
    """Print a blank qualitative-evaluation form and return a matching record.

    Args:
        video_path: Path of the video being evaluated. It is only printed; the
            file is not opened.
        original_inputs: Generation inputs (prompt, style, camera, ...) to list
            on the form. ``None`` is treated as an empty mapping.

    Returns:
        A dict with these keys:
        ``video_path``: the path passed in.
        ``inputs``: a shallow copy of ``original_inputs``.
        ``scores``: each criterion mapped to ``None``, to be filled with 1-5 ratings.
        ``notes``: an empty string.
    """
    inputs = dict(original_inputs or {})
    criteria = [
        "Prompt Adherence (subject/action accuracy)",
        "Style Accuracy (visual style match)",
        "Camera Angle (perspective correctness)",
        "Motion Quality (smoothness/naturalness)",
        "Visual Quality (clarity/coherence)"
    ]

    print("=" * 50)
    print("QUALITATIVE EVALUATION FORM")
    print("=" * 50)
    print(f"\nVideo: {video_path}")
    print("\nOriginal Inputs:")
    for key, value in inputs.items():
        print(f"  - {key}: {value}")

    print("\nEvaluation Criteria (Rate 1-5):")
    for i, criterion in enumerate(criteria, 1):
        print(f"  {i}. {criterion}: ___/5")

    print("\nNotes:")
    print("  _" * 30)

    return {
        "video_path": video_path,
        "inputs": inputs,
        "scores": {criterion: None for criterion in criteria},
        "notes": "",
    }
    
# Example evaluation
# evaluate_video("generated_videos/video_xxx.mp4", {
#     "prompt": "A cat playing with a ball",
#     "style": "cartoon",
#     "camera": "front view",
#     "duration": 2.0
# })