In [None]:
# install and upgrade all dependencies
!pip install --upgrade diffusers transformers ftfy accelerate imageio lpips torch torchvision decord git+https://github.com/openai/CLIP.git

In [None]:
# import all required libraries
import cv2
import lpips
from torchvision import transforms
import imageio
import gradio as gr
import torch
import subprocess
import tqdm
import tqdm.auto
import matplotlib.pyplot as plt
import time
import numpy as np
from PIL import Image
import clip

In [None]:
# Output files written by generate_video: the playable clip, plus a second
# encoding of the same frames at 1 fps.
video_path = "output_video.mp4"
output_frames_path = "output_frames.mp4"

# Hugging Face repository id of the text-to-video pipeline loaded further down.
model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"

# Timestamp buffer filled by the tqdm spy: one slot per denoising step (0 = unset).
denoise_steps = [0 for _ in range(50)]

In [None]:
# Keep a handle on the real tqdm class: TqdmSpy subclasses it, and it can be used
# to restore the original patching later.
original_tqdm = tqdm.auto.tqdm


class TqdmSpy(original_tqdm):
    """tqdm subclass that records a timestamp for every progress update.

    After each ``update`` the bar's running count ``self.n`` is used as a 1-based
    position into the global ``denoise_steps`` list, so ``denoise_steps[k]`` holds
    the time at which the bar reached ``k + 1`` iterations.
    """

    def update(self, n=1):
        result = super().update(n)
        # This class replaces tqdm globally, so unrelated bars (e.g. model loading)
        # also end up here. Their counts can be 0, fractional, or far larger than
        # the preallocated buffer, so only positions inside the buffer are recorded.
        idx = int(self.n) - 1
        if 0 <= idx < len(denoise_steps):
            # Monotonic clock: only differences between entries are ever used.
            denoise_steps[idx] = time.perf_counter()
        return result

In [None]:
# Route progress bars through TqdmSpy so each denoising step gets a timestamp.
# Order matters: a module that does `from tqdm.auto import tqdm` binds the class
# when it is imported, so only modules imported after these assignments see TqdmSpy.
# NOTE(review): whether the diffusers pipeline progress bar picks this up depends on
# when diffusers imports its pipeline module (it is first touched at from_pretrained) — verify.
import diffusers
diffusers.utils.tqdm = TqdmSpy
tqdm.tqdm = TqdmSpy
tqdm.auto.tqdm = TqdmSpy

In [None]:
# decord is used by extract_frames to sample frames from the written video.
import decord
# Select decord's torch bridge. Per decord's bridge API, reader outputs then come back
# as torch tensors rather than decord NDArrays — NOTE(review): code calling
# `.asnumpy()` on get_batch results must account for this.
decord.bridge.set_bridge("torch")
from decord import VideoReader

In [None]:
# Device for the CLIP scorer; fall back to CPU when CUDA is not available.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Display CUDA availability and the first GPU's name. get_device_name(0) raises
# without a CUDA device, so it is only queried when one exists.
torch.cuda.is_available(), (torch.cuda.get_device_name(0) if device == "cuda" else None)

In [None]:
# CLIP ViT-B/32 model and its matching image preprocessing transform, placed on
# `device`; compute_clip_score uses both (as the globals `model` and `preprocess`).
model, preprocess = clip.load("ViT-B/32", device=device)

In [None]:
# Perceptual distance (LPIPS, VGG backbone) used to score frame-to-frame consistency.
loss_fn = lpips.LPIPS(net='vgg')

# Per-frame preprocessing for LPIPS: resize to a fixed 256x256, convert to a
# tensor, then normalize each channel with mean 0.5 / std 0.5.
_LPIPS_FRAME_SIZE = (256, 256)
_LPIPS_HALF = (0.5, 0.5, 0.5)
transform = transforms.Compose(
    [
        transforms.Resize(_LPIPS_FRAME_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=_LPIPS_HALF, std=_LPIPS_HALF),
    ]
)

In [None]:
# Load the diffusion pipeline and record how long it took; the UI's "Load time"
# box shows `load_time` via data_load_time. perf_counter is a monotonic clock, so
# the interval is unaffected by system clock changes during the long load.
start_time = time.perf_counter()
pipe = diffusers.DiffusionPipeline.from_pretrained(model_id, device_map="balanced")
end_time = time.perf_counter()
load_time = end_time - start_time

In [None]:
# helper function for encoding prompt
def patch_encode_prompt(pipe):
    """Wrap ``pipe.encode_prompt`` so that a missing or ``None`` ``device`` keyword
    defaults to the device holding the text encoder's input embeddings.

    ``pipe`` is modified in place; nothing is returned.
    """
    embed_device = pipe.text_encoder.get_input_embeddings().weight.device
    original = pipe.encode_prompt

    def encode_on_text_encoder_device(*args, **kwargs):
        # Only fill in the device when the caller did not choose one.
        if kwargs.get("device") is None:
            kwargs["device"] = embed_device
        return original(*args, **kwargs)

    pipe.encode_prompt = encode_on_text_encoder_device

In [None]:
# helper function for extracting frames from video for CLIP score
def extract_frames(num_frames=8, path=None):
    """Sample evenly spaced frames from a video as PIL images.

    Args:
        num_frames: How many frames to sample. Clamped to the video length, so
            each frame is returned at most once.
        path: Video file to read; defaults to the global ``video_path``.

    Returns:
        List of ``PIL.Image`` frames in temporal order (empty if the video has no
        frames or ``num_frames`` is not positive).
    """
    vr = VideoReader(video_path if path is None else path)
    total_frames = len(vr)
    count = min(int(num_frames), total_frames)
    if count <= 0:
        return []
    indices = torch.linspace(0, total_frames - 1, steps=count).long().tolist()
    batch = vr.get_batch(indices)  # (T, H, W, C)
    # With decord's torch bridge active (set at import time in this notebook),
    # get_batch returns a torch.Tensor, which has no `.asnumpy()`. Accept both
    # the decord NDArray and the torch tensor form.
    if hasattr(batch, "asnumpy"):
        batch = batch.asnumpy()
    else:
        batch = batch.cpu().numpy()
    return [Image.fromarray(frame) for frame in batch]

In [None]:
# helper function for extracting frames from video for LPIPS score
def extract_frames_imageio(video_path):
    """Decode every frame of ``video_path`` into a list of PIL images.

    Args:
        video_path: Path of the video file to read.

    Returns:
        List of ``PIL.Image`` frames in file order.
    """
    reader = imageio.get_reader(video_path)
    try:
        return [Image.fromarray(frame) for frame in reader]
    finally:
        # Release the decoder even if a frame fails to decode part-way.
        reader.close()

In [None]:
# helper function to plot graph of denoising
def denoise_graph(steps=None):
    """Plot how long each denoising step took.

    Args:
        steps: Per-step completion timestamps; defaults to the global
            ``denoise_steps``. Slots still at 0 (never recorded) are ignored.

    Returns:
        A new ``matplotlib.figure.Figure`` with step number on x and seconds on y.
    """
    from matplotlib.figure import Figure

    timestamps = [t for t in (denoise_steps if steps is None else steps) if t]
    # Entry k is the moment step k+1 finished, so consecutive differences give the
    # durations of steps 2..N (step 1 has no recorded start time).
    durations = [later - earlier for earlier, later in zip(timestamps, timestamps[1:])]
    step_numbers = list(range(2, len(timestamps) + 1))

    # A standalone Figure instead of pyplot's current figure: repeated calls no
    # longer draw over earlier runs or accumulate open pyplot figures.
    fig = Figure()
    ax = fig.subplots()
    ax.plot(step_numbers, durations, marker="o")
    ax.set_xlabel('Denoising step')
    ax.set_ylabel('Time taken (s)')
    ax.set_title('Denoising')
    return fig

In [None]:
# helper function for computing CLIP score
def compute_clip_score(frames, text, batch_size=64):
    """Mean CLIP cosine similarity between a text prompt and a set of frames.

    Args:
        frames: Iterable of images accepted by the global CLIP ``preprocess``.
        text: Prompt to compare the frames against.
        batch_size: Frames encoded per forward pass (bounds device memory use).

    Returns:
        Mean cosine similarity as a Python float, or NaN if ``frames`` is empty.
    """
    frames = list(frames)
    if not frames:
        # torch.stack/cat cannot handle an empty list; mirror the LPIPS helper's NaN.
        return float("nan")
    batch_size = max(1, int(batch_size))
    text_token = clip.tokenize([text]).to(device)
    with torch.no_grad():
        text_features = model.encode_text(text_token).float()
        # Encode frames in batches instead of one forward pass per frame.
        feature_chunks = []
        for start in range(0, len(frames), batch_size):
            image_batch = torch.stack(
                [preprocess(frame) for frame in frames[start:start + batch_size]]
            ).to(device)
            feature_chunks.append(model.encode_image(image_batch).float())
        frame_features = torch.cat(feature_chunks)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        frame_features = frame_features / frame_features.norm(dim=-1, keepdim=True)
        similarities = frame_features @ text_features.T  # (num_frames, 1)
        return similarities.mean().item()

In [None]:
# helper function for computing LPIPS score
def compute_temporal_lpips(frames, metric_fn=None, frame_transform=None):
    """Mean perceptual distance between each pair of consecutive frames.

    Args:
        frames: Ordered sequence of frames (PIL images for the default transform).
        metric_fn: Callable ``(a, b)`` returning a one-element tensor distance for
            two batched frames; defaults to the global LPIPS ``loss_fn``.
        frame_transform: Callable turning one frame into an unbatched tensor (a
            batch axis is added with ``unsqueeze(0)``); defaults to the global
            ``transform``.

    Returns:
        Mean distance as a float, or NaN when fewer than two frames are given or
        no consecutive pair could be scored.
    """
    import warnings

    if not frames or len(frames) < 2:
        return float("nan")
    if metric_fn is None:
        metric_fn = loss_fn
    if frame_transform is None:
        frame_transform = transform

    def _to_batch(index):
        # One-item batch for frames[index], or None if preprocessing fails.
        try:
            return frame_transform(frames[index]).unsqueeze(0)
        except Exception as exc:
            warnings.warn(f"LPIPS: could not preprocess frame {index}: {exc}")
            return None

    scores = []
    # Each frame is transformed once and reused as the left side of the next
    # pair, so only two frame tensors are alive at any time.
    previous = _to_batch(0)
    with torch.no_grad():
        for index in range(1, len(frames)):
            current = _to_batch(index)
            if previous is not None and current is not None:
                try:
                    scores.append(metric_fn(previous, current).item())
                except Exception as exc:
                    # Best-effort metric: skip the pair, but say so instead of hiding it.
                    warnings.warn(f"LPIPS: skipping frames {index - 1}-{index}: {exc}")
            previous = current

    if not scores:
        return float("nan")
    return float(np.mean(scores))

In [None]:
# helper function to get model load time
def data_load_time():
    """Format the global ``load_time`` (seconds) as e.g. ``"12.345 s"``."""
    seconds = load_time
    return "{:.3f} s".format(seconds)

In [None]:
# helper functions to get GPU stats
def get_gpu_stats():
    """Query ``nvidia-smi`` for live GPU statistics.

    Returns:
        Dict with keys ``memory_used``, ``memory_free``, ``utilization``,
        ``temperature`` and ``power``; each value is a display string with its
        unit. With several GPUs, the per-GPU values are joined with ``" | "`` in
        ``nvidia-smi`` order. If the query fails or its output cannot be parsed,
        every value is ``"Error"``.
    """
    keys_and_units = (
        ("memory_used", "MB"),
        ("memory_free", "MB"),
        ("utilization", "%"),
        ("temperature", "°C"),
        ("power", "W"),
    )
    try:
        result = subprocess.check_output([
            "nvidia-smi",
            "--query-gpu=memory.used,memory.free,utilization.gpu,temperature.gpu,power.draw",
            "--format=csv,nounits,noheader"
        ], encoding='utf-8', timeout=10)
        # nvidia-smi prints one CSV row per GPU. The pipeline is loaded with
        # device_map="balanced", so multi-row output is expected on multi-GPU hosts.
        rows = [
            [field.strip() for field in line.split(",")]
            for line in result.strip().splitlines()
            if line.strip()
        ]
        if not rows or any(len(row) != len(keys_and_units) for row in rows):
            raise ValueError(f"unexpected nvidia-smi output: {result!r}")
        return {
            key: " | ".join(f"{row[col]} {unit}" for row in rows)
            for col, (key, unit) in enumerate(keys_and_units)
        }
    except Exception:
        # Best-effort: the UI polls this on a timer, so report instead of raising.
        return {key: "Error" for key, _ in keys_and_units}


def get_gpu_info_only():
    """Return GPU stats as a tuple ordered for the hardware-metric textboxes."""
    stats = get_gpu_stats()
    return (
        stats["memory_used"],
        stats["memory_free"],
        stats["utilization"],
        stats["temperature"],
        stats["power"],
    )

In [None]:
def generate_video(prompt, negative_prompt="Blurry, unrealistic, shaky", frames=60, fps=12, resolution=480,
                   num_inference_steps=50, guidance_scale=5.0, seed=42):
    """Generate a video with the global ``pipe`` and collect timing/quality metrics.

    Args:
        prompt: Text prompt for the video.
        negative_prompt: Text passed to the pipeline as ``negative_prompt``.
        frames: Requested frame count; the pipeline receives ``4 * frames + 1``.
        fps: Frame rate of the exported video at ``video_path``.
        resolution: Output height in pixels. 240/480/720/1080 select a matching
            width; any other height keeps the default width of 832.
        num_inference_steps: Denoising steps passed to the pipeline; also sizes
            the ``denoise_steps`` timing buffer.
        guidance_scale: Guidance scale passed to the pipeline.
        seed: Seed for torch's global RNG and the pipeline's generator.

    Returns:
        ``(video_path, total latency, per-frame latency, CLIP score, throughput,
        denoising-time figure)``; the middle four are formatted strings.

    Raises:
        TypeError: If the pipeline's ``frames`` output is not a numpy array.
    """
    global denoise_steps
    # Fresh timestamp buffer (one slot per step) for the tqdm spy to fill.
    denoise_steps = [0] * int(num_inference_steps)

    # Frame-count mapping kept as before: the pipeline gets 4 * frames + 1 frames,
    # a 4k+1 count. The UI slider value may arrive as a float, so coerce to int
    # before it reaches the pipeline's shape arithmetic.
    # NOTE(review): with the UI default of 61 this requests 245 frames — confirm intended.
    num_frames = 4 * int(frames) + 1

    # Height comes from the chosen resolution; width from a fixed table.
    height = int(resolution)
    width = {240: 416, 480: 832, 720: 1248, 1080: 1872}.get(height, 832)

    # Seed both torch's global RNG and the pipeline generator for reproducibility.
    torch.manual_seed(seed)
    generator = torch.Generator().manual_seed(seed)

    # perf_counter is monotonic, so latency is immune to wall-clock adjustments.
    start = time.perf_counter()
    output = pipe(
        prompt,
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        num_frames=num_frames,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
    )
    total_latency = time.perf_counter() - start
    frame_latency = total_latency / num_frames
    throughput = num_frames / total_latency

    video = output.frames
    if not isinstance(video, np.ndarray):
        raise TypeError("Unexpected output format from pipeline")
    # Drop singleton axes, then map values (assumed to be in [0, 1]) to uint8.
    video = np.squeeze(video)
    video = (video * 255).clip(0, 255).astype("uint8")

    frame_images = [Image.fromarray(frame) for frame in video]

    # Playable clip at the requested fps, plus the same frames at 1 fps.
    diffusers.utils.export_to_video(frame_images, video_path, fps=fps)
    diffusers.utils.export_to_video(frame_images, output_frames_path, fps=1)

    # CLIP score is measured on frames read back from the encoded video.
    sampled_frames = extract_frames(num_frames=num_frames)
    score = compute_clip_score(sampled_frames, prompt)

    return (
        video_path,
        f"{total_latency:.3f} s",
        f"{frame_latency:.3f} s",
        f"{score:.3f}",
        f"{throughput:.3f} fps",
        denoise_graph(),
    )

In [None]:
# Make pipe.encode_prompt fill in a missing/None `device` with the text encoder's device.
patch_encode_prompt(pipe)

In [None]:
# Default text pre-filled in the UI's "Negative prompt" textbox.
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

In [None]:
# Dark, translucent Gradio theme (indigo accents on a dark gradient) for the Blocks app.
_theme_radius = gr.themes.Size(
    xxs="6px", xs="6px", sm="8px", md="10px", lg="12px", xl="14px", xxl="16px"
)
_theme_spacing = gr.themes.Size(
    xxs="2px", xs="4px", sm="6px", md="10px", lg="16px", xl="24px", xxl="32px"
)
_theme_variables = {
    "body_background_fill": "linear-gradient(135deg, #0f2027, #203a43, #2c5364)",
    "body_text_color": "white",
    "block_background_fill": "rgba(255, 255, 255, 0.08)",
    "block_border_color": "rgba(255, 255, 255, 0.2)",
    "block_shadow": "0 12px 40px rgba(0, 0, 0, 0.4)",
    "input_background_fill": "rgba(255, 255, 255, 0.1)",
    "input_border_color": "rgba(255, 255, 255, 0.2)",
    "button_primary_background_fill": "rgba(99, 102, 241, 0.85)",
    "button_primary_text_color": "white",
    "button_primary_background_fill_hover": "rgba(99, 102, 241, 1)",
}
_base_theme = gr.themes.Base(
    primary_hue="indigo",
    secondary_hue="gray",
    radius_size=_theme_radius,
    spacing_size=_theme_spacing,
)
themeeeeee = _base_theme.set(**_theme_variables)

# Custom fonts/CSS injected into the page through gr.HTML in the Blocks layout.
# NOTE(review): class selectors such as .gr-button, .gr-slider and .gr-accordion
# may not match the DOM of the installed Gradio version — verify in the browser.
css_reset = """
<style>
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap');
body, #root, .gradio-container {
    font-family: 'Inter', sans-serif !important;
    background: linear-gradient(135deg, #0f2027, #203a43, #2c5364);
    perspective: 1000px;
    overflow-x: hidden;
    animation: fadeIn 1s ease-in-out;
}
#root { transform-style: preserve-3d; }
.gradio-container > * {
    /*transform: rotateX(1deg) rotateY(-2deg);*/
    transition: transform 0.4s ease, box-shadow 0.4s ease;
    backdrop-filter: blur(14px);
    -webkit-backdrop-filter: blur(14px);
    border-radius: 20px;
    background: rgba(255, 255, 255, 0.05);
    box-shadow: 0 12px 40px rgba(0, 0, 0, 0.3);
    border: 1px solid rgba(255, 255, 255, 0.15);
    margin: 30px auto;
    padding: 25px;
    width: 85% !important;
    box-sizing: border-box;
    color: white !important;
}
.gr-button { font-weight: bold; border-radius: 12px !important; box-shadow: 0 4px 16px rgba(0,0,0,0.3); }
.gr-button:hover { transform: scale(1.05); }
.gr-slider input[type="range"] { accent-color: #6366f1; }
@keyframes fadeIn { from { opacity: 0; transform: translateY(20px); } to { opacity: 1; transform: translateY(0); } }

/* Accordion container */
.gr-accordion {
    background-color: #00ffff !important;   /* Cyan background */
    color: #ffffff !important;             /* White text */
    border-radius: 8px;
    border: 1px solid #444;
    padding: 4px;
}

/* Accordion header */
.gr-accordion .prose {
    color: #1e1e2f !important;             /* Dark heading text */
    font-weight: bold;
}
</style>
"""

# Assemble the Gradio app: metric readouts, generation parameters, output video,
# and the events that run generation and poll GPU stats.
with gr.Blocks(theme=themeeeeee) as demo:
    # Raw <style> block carrying the custom CSS from css_reset.
    gr.HTML(css_reset)
    gr.Markdown("# Accelerate-WAN")
    gr.Markdown("<hr>")

    # Metrics panel: read-only textboxes plus the denoising-time plot.
    with gr.Group():
        with gr.Row():
            gr.Markdown("## <div style='text-align:center; padding:15px;'>Metrics</div>")
        with gr.Row():
            with gr.Accordion("Hardware", open=False):
                with gr.Row():
                    # Refreshed on every timer tick from get_gpu_info_only, in its tuple order.
                    mem_used = gr.Textbox(label="Memory used", interactive=False)
                    mem_free = gr.Textbox(label="Memory free", interactive=False)
                    gpu_util = gr.Textbox(label="GPU utilization", interactive=False)
                    temp = gr.Textbox(label="Temperature", interactive=False)
                    powe = gr.Textbox(label="Power draw", interactive=False)
            with gr.Accordion("Efficiency", open=True):
                with gr.Row():
                    # Filled by the Generate button via wrapper -> generate_video.
                    clip_latency_box = gr.Textbox(label="Clip-wise latency", interactive=False)
                    frame_latency_box = gr.Textbox(label="Frame-wise latency", interactive=False)
                    throughput_box = gr.Textbox(label="Throughput", interactive=False)
                    dngraphoutput = gr.Plot(label="Time vs Denoising-Steps Graph")
        with gr.Row():
            with gr.Accordion("Accuracy", open=False):
                with gr.Row():
                    lpips_score_box = gr.Textbox(label="LPIPS score", interactive=False)
                    clip_score_box = gr.Textbox(label="CLIP score", interactive=False)
            with gr.Accordion("Others", open=False):
                with gr.Row():
                    # NOTE(review): "Compile time" and "Batch processing efficiency" are not
                    # bound to a variable or any event, so they always stay empty.
                    gr.Textbox(label="Compile time", interactive=False)
                    lolu = gr.Textbox(label="Load time", interactive=False)
                    gr.Textbox(label="Batch processing efficiency", interactive=False)

    # Generation inputs on the left, generated video on the right.
    with gr.Row():
        with gr.Group():
            gr.Markdown("## <div style='text-align:center; padding:15px;'>Parameters</div>")
            # Output height; generate_video derives the width from it.
            res = gr.Radio(choices=[240, 480, 720, 1080], value=480, label="Output resolution", interactive=True)
            with gr.Row():
                prompt = gr.Textbox(placeholder="e.g. A cat walking on moon", label="Prompt")
                nprompt = gr.Textbox(value = negative_prompt, label="Negative prompt")
            with gr.Row():
                fps = gr.Slider(minimum=1, maximum=120, label="FPS", interactive=True, value=12)
                # NOTE(review): generate_video passes 4 * frames + 1 to the pipeline, so the
                # default of 61 requests 245 frames — confirm this is intended.
                frames = gr.Slider(minimum=1, maximum=480, label="Number of frames", interactive=True, value=61)
            # NOTE(review): `opt` is not an input to any event, so these choices have no effect yet.
            opt = gr.CheckboxGroup(choices = ['Flash Attention', 'Operator Fusion', 'CFG Parallelism', 'LoRA', 'Quantization', 'Best'], value = 'Best', label = 'Optimization techniques')

        with gr.Group():
            gr.Markdown("## <div style='text-align:center; padding:15px;'>Output</div>")
            output = gr.Video(label="Generated video")

    generate = gr.Button("Generate")

    # Timers for GPU and load time monitoring (gr.Timer() uses its default interval).
    timer = gr.Timer()
    timer.tick(fn=get_gpu_info_only, inputs=[], outputs=[mem_used, mem_free, gpu_util, temp, powe])
    # load_time does not change after startup; it is simply re-read on every tick.
    timer.tick(fn=data_load_time, inputs=[], outputs=[lolu])

    # Click handler: generate the video, then add a temporal LPIPS score computed
    # from the frames of the written file.
    def wrapper(prompt, nprompt, frames, fps, res):
        video_path, clip_latency, frame_latency, clip_scoree, throughputt, dngraph = generate_video(prompt, nprompt, frames, fps, res)
        extracted_frames = extract_frames_imageio(video_path)
        lpips_score = compute_temporal_lpips(extracted_frames)
        # NaN (fewer than two frames, or no scorable pair) is displayed as "N/A".
        lpips_display = f"{lpips_score:.3f}" if not np.isnan(lpips_score) else "N/A"
        # Order must match the `outputs` list passed to generate.click below.
        return video_path, clip_latency, frame_latency, lpips_display, clip_scoree, throughputt, dngraph

    generate.click(
        fn=wrapper,
        inputs=[prompt, nprompt, frames, fps, res],
        outputs=[output, clip_latency_box, frame_latency_box, lpips_score_box, clip_score_box, throughput_box, dngraphoutput]
    )

# Start the Gradio app server.
demo.launch(debug=True)