In [1]:
import torch
from diffusers import StableVideoDiffusionPipeline, AutoencoderKL
from diffusers.utils import load_image, export_to_video
from tqdm import tqdm
from torchvision.transforms import Resize, ToTensor, Compose
from torchvision.utils import save_image
import torch.nn.functional as F
import os
import matplotlib.pyplot as plt
import time
import lpips
import numpy as np
from torchvision.models.video import r3d_18, R3D_18_Weights
from scipy import linalg
from PIL import Image
import cv2
from typing import List
from LightCache.LightCache import LightCacher
from fme import FMEWrapper
from skimage.metrics import structural_similarity as ssim

# Restrict this process to physical GPU 3. This only takes effect if it runs
# before the first CUDA call in the kernel. Once the mask is applied, the single
# visible GPU is enumerated by CUDA as device index 0.
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

In [2]:
def preprocess_frames(frames, size=(256, 256)):
    """Resize every frame and convert it to a tensor scaled to [-1, 1].

    Args:
        frames: Iterable of images accepted by torchvision's ``Resize`` and
            ``ToTensor`` (the callers in this notebook pass PIL images).
        size: Target size forwarded to ``Resize``.

    Returns:
        A list containing one tensor per input frame, in input order.
    """
    def to_signed_range(tensor):
        # ToTensor output is in [0, 1]; shift and scale it onto [-1, 1].
        return tensor * 2 - 1

    pipeline = Compose([Resize(size), ToTensor(), to_signed_range])

    processed = []
    for frame in frames:
        processed.append(pipeline(frame))
    return processed

def measure_block_latency(unet):
    """Attach CUDA-event timing hooks to the top-level blocks of ``unet``.

    Hooks are registered on the modules named ``down_blocks.0``..``down_blocks.3``,
    ``up_blocks.0``..``up_blocks.3`` and ``mid_block``. Every forward call of a
    hooked block contributes one latency sample, so a multi-step denoising run
    yields one entry per step and per block (previously each call overwrote the
    stored events and only the last call could be measured).

    Args:
        unet: Module whose ``named_modules()`` exposes the block names above.

    Returns:
        A zero-argument callable. Calling it returns ``(timings, handles)`` where
        ``timings`` maps block name -> list of elapsed milliseconds (one per
        forward call so far) and ``handles`` are the hook handles; call
        ``handle.remove()`` on each to detach the hooks.
    """
    timings = {}
    pending_starts = {}

    tracked_names = {
        f"{prefix}.{i}" for prefix in ('down_blocks', 'up_blocks') for i in range(4)
    }
    tracked_names.add("mid_block")

    def pre_hook(name):
        def inner(module, input):
            # Drain outstanding GPU work so it is not attributed to this block.
            torch.cuda.synchronize()
            evt = torch.cuda.Event(enable_timing=True)
            evt.record()
            pending_starts[name] = evt
        return inner

    def post_hook(name):
        def inner(module, input, output):
            start_evt = pending_starts.pop(name, None)
            if start_evt is None:
                # Forward hook fired without a matching pre-hook; nothing to time.
                return
            torch.cuda.synchronize()
            end_evt = torch.cuda.Event(enable_timing=True)
            end_evt.record()
            # elapsed_time() requires both events to have completed.
            end_evt.synchronize()
            timings.setdefault(name, []).append(start_evt.elapsed_time(end_evt))  # ms
        return inner

    handles = []
    for name, module in unet.named_modules():
        if name in tracked_names:
            handles.append(module.register_forward_pre_hook(pre_hook(name)))
            handles.append(module.register_forward_hook(post_hook(name)))

    def compute_timings():
        # Samples are accumulated by the hooks, so repeated calls never duplicate them.
        return timings, handles

    return compute_timings



def compute_lpips(frames1, frames2):
    """Average LPIPS distance (AlexNet backbone) between two frame sequences.

    Both sequences are resized to 224x224 and scaled via ``preprocess_frames``.
    Frames are paired by index over the length of ``frames1``, so ``frames2``
    must contain at least as many frames.

    Args:
        frames1: First sequence of frames.
        frames2: Second sequence of frames, compared index by index.

    Returns:
        The mean of the per-frame LPIPS scores.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    metric = lpips.LPIPS(net='alex').to(device)

    batch_a = preprocess_frames(frames1, size=(224, 224))
    batch_b = preprocess_frames(frames2, size=(224, 224))

    distances = []
    with torch.no_grad():
        for idx, tensor_a in enumerate(batch_a):
            tensor_b = batch_b[idx]
            lhs = tensor_a.unsqueeze(0).to(device)
            rhs = tensor_b.unsqueeze(0).to(device)
            distances.append(metric(lhs, rhs).item())
    return np.mean(distances)

def read_video_as_pil_frames(video_path: str) -> List[Image.Image]:
    """Decode every frame of a video file into RGB PIL images.

    Args:
        video_path: Path of the video to read with OpenCV.

    Returns:
        The decoded frames in playback order.

    Raises:
        ValueError: If OpenCV cannot open ``video_path``.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise ValueError(f"Failed to open video file: {video_path}")

        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV decodes to BGR; PIL expects RGB channel order.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame_rgb))
        return frames
    finally:
        # Release the capture on every path, including failures mid-decode.
        cap.release()

def register_memory_hooks(model):
    """Hook UNet-block submodules to record allocated-memory deltas per forward.

    Every submodule whose qualified name contains ``down_blocks``, ``up_blocks``
    or ``mid_block`` gets a pre/post forward hook pair. Because the match is a
    substring test, nested children of those blocks are hooked as well.

    Args:
        model: Module whose ``named_modules()`` is scanned.

    Returns:
        ``(mem_stats, handles)``. ``mem_stats`` maps the module class name (from
        ``module._get_name()``) to a list of deltas in MiB, so modules of the same
        class share one list. ``handles`` are the registered hook handles.
    """
    mem_stats = {}
    block_keywords = ('down_blocks', 'up_blocks', 'mid_block')
    bytes_per_mib = 1024 ** 2

    def pre_hook(module, input):
        torch.cuda.synchronize()
        # Stash the baseline on the module itself so the post hook can find it.
        module._pre_mem = torch.cuda.memory_allocated()

    def post_hook(module, input, output):
        torch.cuda.synchronize()
        allocated_after = torch.cuda.memory_allocated()
        allocated_before = getattr(module, '_pre_mem', 0)
        delta = (allocated_after - allocated_before) / bytes_per_mib
        mem_stats.setdefault(module._get_name(), []).append(delta)

    handles = []
    for qualified_name, submodule in model.named_modules():
        if not any(keyword in qualified_name for keyword in block_keywords):
            continue
        handles.extend((
            submodule.register_forward_pre_hook(pre_hook),
            submodule.register_forward_hook(post_hook),
        ))

    return mem_stats, handles

def calculate_psnr(img1, img2):
    """Peak signal-to-noise ratio in dB, assuming a peak pixel value of 255.

    Args:
        img1: First image as an array-like (e.g. a uint8 or float array).
        img2: Second image, broadcast-compatible with ``img1``.

    Returns:
        The PSNR in decibels, or ``float('inf')`` when the images are identical.
    """
    # Work in float64: subtracting or squaring uint8 arrays wraps modulo 256
    # (e.g. a uniform difference of 16 squares to 0), which corrupts the MSE.
    diff = np.asarray(img1, dtype=np.float64) - np.asarray(img2, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return float('inf')
    PIXEL_MAX = 255.0
    return 20 * np.log10(PIXEL_MAX / np.sqrt(mse))

def calculate_ssim(img1, img2):
    """Structural similarity between two channels-last images.

    Args:
        img1: First image array; the last axis is treated as the channel axis.
        img2: Second image array of the same shape. Its value span
            (``max - min``) is used as ``data_range``.

    Returns:
        The value returned by scikit-image's ``structural_similarity``.
    """
    # `channel_axis=-1` is the current spelling of the deprecated
    # `multichannel=True` (scikit-image >= 0.19). Releases that dropped the old
    # keyword would otherwise treat an HxWx3 image as a 3-D volume.
    return ssim(img1, img2, data_range=img2.max() - img2.min(), channel_axis=-1)


In [None]:
# Conditioning image for image-to-video generation, resized to 1024x576.
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png?download=true")
file_name = "rocket"
image = image.resize((1024, 576))

torch.cuda.empty_cache()
# CUDA_VISIBLE_DEVICES='3' (set in the import cell) leaves exactly one visible GPU,
# which CUDA renumbers as index 0. 'cuda:3' is therefore an invalid device ordinal
# on a fresh kernel; address the visible GPU as 'cuda:0' instead.
device = 'cuda:0'

# Stock pipeline with its bundled VAE, kept for reference:
# pipe = StableVideoDiffusionPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid-xt", 
#                                                     torch_dtype=torch.float16, variant="fp16").to(device)

# Load the pipeline in fp16 with the sd-vae-ft-ema image VAE passed in as `vae`.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema", torch_dtype=torch.float16)
pipe = StableVideoDiffusionPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid-xt", vae=vae,
                                                    torch_dtype=torch.float16, variant="fp16").to(device)


# Fixed seed for reproducible sampling. The generator is stateful, so re-run this
# cell before regenerating if identical output is required.
generator = torch.Generator(device=device).manual_seed(42)

In [4]:
# Optional pipeline wrappers under evaluation. Both variants are commented out,
# so no wrapper is applied to `pipe` in this run.
# NOTE(review): the meaning of the arguments below (cache_interval, cache_branch_id,
# num_temporal_chunk, ...) is defined by the LightCache / fme packages and is not
# visible in this notebook - confirm against those packages before changing them.
# cacher = LightCacher(pipe, 25)
# cacher.set_params(cache_interval=8, cache_branch_id=0)
# cacher.enable(Swap=True, Slice=True, Chunk=True)

# cacher = FMEWrapper(num_temporal_chunk=9, num_spatial_chunk=2, num_frames=25)
# cacher.wrap(pipe)

# Reset the peak-memory counters so the peak read after generation reflects only
# work done from this point on. `device` comes from the model-loading cell.
torch.cuda.reset_peak_memory_stats(device)

In [None]:
# Optional per-block latency profiling (disabled):
# get_timings = measure_block_latency(pipe.unet)

# Wall-clock time of one full generation call (25 frames, VAE decode in chunks of 8).
start_time = time.time()
frames = pipe(
    image,
    decode_chunk_size=8,
    generator=generator,
    num_frames=25,
).frames[0]
elapsed_seconds = time.time() - start_time
print(elapsed_seconds)

# Peak memory reserved by the CUDA caching allocator since the last reset, in MB.
bytes_per_mb = 1024 ** 2
peak_mem_reserved = torch.cuda.max_memory_reserved(device) / bytes_per_mb

# Optional export of the generated clip (disabled):
# export_to_video(frames, "generated_videos/SVD_our.mp4", fps=7)

print(f"Peak memory reserved: {peak_mem_reserved:.2f} MB")

# Perceptual distance between this run and the stored reference video.
origin = read_video_as_pil_frames("generated_videos/rocket_origin.mp4")
lpips_score = compute_lpips(frames, origin)
print(lpips_score)