In [None]:
from time import perf_counter, time
from typing import Tuple, Union

import lpips
import numpy as np
import torch
from diffusers import AnimateDiffPipeline, DDIMScheduler, EulerDiscreteScheduler, MotionAdapter
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL, AutoencoderKLOutput
from diffusers.models.autoencoders.vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
from diffusers.utils import export_to_gif, load_image
from huggingface_hub import hf_hub_download
from PIL import Image, ImageSequence
from safetensors.torch import load_file
from skimage.metrics import structural_similarity as ssim
from torchvision.transforms import Compose, Resize, ToTensor

from LightCache.LightCache import LightCacher
from fme import FMEWrapper

In [2]:
def preprocess_frames(frames, size=(256, 256)):
    """Turn PIL frames into float tensors normalized to [-1, 1].

    Args:
        frames: iterable of PIL images.
        size: target size forwarded unchanged to torchvision ``Resize``.

    Returns:
        A list with one (C, H, W) tensor per input frame, in input order.
    """
    def to_signed_unit_range(tensor):
        # ToTensor produces values in [0, 1]; stretch them linearly onto [-1, 1].
        return tensor * 2 - 1

    pipeline = Compose([Resize(size), ToTensor(), to_signed_unit_range])

    tensors = []
    for frame in frames:
        tensors.append(pipeline(frame))
    return tensors

def compute_lpips(frames1, frames2, lpips_model=None):
    """Average LPIPS distance between two frame sequences, compared frame by frame.

    Both sequences are resized to 224x224 and scaled to [-1, 1] with
    ``preprocess_frames`` before scoring.

    Args:
        frames1: iterable of PIL images (e.g. reference video frames).
        frames2: iterable of PIL images compared pairwise against ``frames1``;
            must contain the same number of frames.
        lpips_model: optional pre-built ``lpips.LPIPS`` module. Pass one in when
            scoring many videos so the network is not rebuilt (and its weights
            reloaded) on every call. When omitted, an AlexNet-backed model is
            created on CUDA if available, otherwise on CPU.

    Returns:
        Mean per-frame LPIPS score as returned by ``np.mean`` (NaN, with a NumPy
        warning, if both sequences are empty).

    Raises:
        ValueError: if the two sequences contain different numbers of frames.
    """
    frames1_tensor = preprocess_frames(frames1, size=(224, 224))
    frames2_tensor = preprocess_frames(frames2, size=(224, 224))
    if len(frames1_tensor) != len(frames2_tensor):
        # Fail before loading any weights; mismatched lengths would otherwise
        # silently skip frames or raise IndexError partway through the loop.
        raise ValueError(
            f"frame count mismatch: {len(frames1_tensor)} vs {len(frames2_tensor)}"
        )

    if lpips_model is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        lpips_model = lpips.LPIPS(net='alex').to(device)
    else:
        # Score on whichever device the caller's model already lives on.
        device = next(lpips_model.parameters()).device

    scores = []
    with torch.no_grad():
        for f1, f2 in zip(frames1_tensor, frames2_tensor):
            # LPIPS expects batched input, so add a leading batch dimension of 1.
            score = lpips_model(
                f1.unsqueeze(0).to(device), f2.unsqueeze(0).to(device)
            ).item()
            scores.append(score)
    return np.mean(scores)

def gif_to_frames(gif_path):
    """Decode every frame of an animated image file into RGB PIL images.

    Each frame is converted and copied while the file is still open, so the
    returned images stay usable after the file handle is closed.

    Args:
        gif_path: path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        A list of RGB ``PIL.Image`` objects in playback order.
    """
    decoded = []
    with Image.open(gif_path) as animation:
        for raw_frame in ImageSequence.Iterator(animation):
            decoded.append(raw_frame.convert("RGB").copy())
    return decoded

def calculate_psnr(img1, img2, pixel_max=255.0):
    """Peak signal-to-noise ratio, in decibels, between two same-shaped images.

    Inputs are converted to float64 before differencing. For ``uint8`` arrays
    (e.g. ``np.array(pil_image)``), direct subtraction would wrap modulo 256
    and corrupt the MSE whenever a pixel difference reaches 16 or more.

    Args:
        img1: array-like image (NumPy array, PIL image, nested list).
        img2: array-like image with the same shape as ``img1``.
        pixel_max: largest possible pixel value; 255.0 for 8-bit images.

    Returns:
        PSNR in dB, or ``float('inf')`` when the images are identical.

    Raises:
        ValueError: if the two inputs have different shapes.
    """
    a = np.asarray(img1, dtype=np.float64)
    b = np.asarray(img2, dtype=np.float64)
    if a.shape != b.shape:
        # Refuse to let NumPy broadcast mismatched images into a bogus score.
        raise ValueError(f"image shape mismatch: {a.shape} vs {b.shape}")
    mse = np.mean((a - b) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * np.log10(pixel_max / np.sqrt(mse))

def compute_ssim(frames1, frames2, size=(256, 256)):
    """Average SSIM between two frame sequences, compared frame by frame.

    Each frame pair is resized with PIL, converted to float32 and scored with
    scikit-image's ``structural_similarity`` using an 8-bit data range (255).
    Multi-channel frames are scored channel by channel, and the per-channel
    values are averaged over every channel present. That includes alpha for
    RGBA input; RGB frames such as those from ``gif_to_frames`` are unaffected.

    Args:
        frames1: iterable of PIL images.
        frames2: iterable of PIL images paired with ``frames1``; must yield the
            same number of frames.
        size: (width, height) passed to ``PIL.Image.resize``; default 256x256.

    Returns:
        Mean SSIM over all frame pairs as returned by ``np.mean`` (NaN, with a
        NumPy warning, if both sequences are empty).

    Raises:
        ValueError: if the sequences differ in length, or a frame pair differs
            in channel layout after resizing.
    """
    frames1 = list(frames1)
    frames2 = list(frames2)
    if len(frames1) != len(frames2):
        # zip() would otherwise silently drop the unmatched tail.
        raise ValueError(f"frame count mismatch: {len(frames1)} vs {len(frames2)}")

    scores = []
    for img1, img2 in zip(frames1, frames2):
        img1_np = np.array(img1.resize(size)).astype(np.float32)
        img2_np = np.array(img2.resize(size)).astype(np.float32)
        if img1_np.shape != img2_np.shape:
            raise ValueError(f"frame shape mismatch: {img1_np.shape} vs {img2_np.shape}")

        if img1_np.ndim == 3:
            # Score each channel independently, then average the channel scores.
            channel_scores = [
                ssim(img1_np[:, :, c], img2_np[:, :, c], data_range=255)
                for c in range(img1_np.shape[2])
            ]
            ssim_val = np.mean(channel_scores)
        else:
            ssim_val = ssim(img1_np, img2_np, data_range=255)

        scores.append(ssim_val)
    return np.mean(scores)

In [None]:
# All model components below are loaded in float16 and moved to the GPU.
device = "cuda"

# Motion adapter checkpoint for AnimateDiff.
adapter = MotionAdapter.from_pretrained(
    "guoyww/animatediff-motion-adapter-v1-5-2",
    torch_dtype=torch.float16,
)

# SD 1.5-based fine-tuned checkpoint used as the image backbone.
model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
pipe = AnimateDiffPipeline.from_pretrained(
    model_id,
    motion_adapter=adapter,
    torch_dtype=torch.float16,
)

# DDIM scheduler loaded from the checkpoint's own scheduler config, with overrides.
pipe.scheduler = DDIMScheduler.from_pretrained(
    model_id,
    subfolder="scheduler",
    clip_sample=False,
    timestep_spacing="linspace",
    beta_schedule="linear",
    steps_offset=1,
)

pipe = pipe.to(device)

In [4]:
# Wrap the pipeline with LightCache. Argument semantics are defined by the
# LightCache package, which is not visible from this notebook.
# NOTE(review): the 25 presumably has to match num_frames in the generation cell — TODO confirm.
cacher = LightCacher(pipe, 25)
# NOTE(review): cache_interval / cache_branch_id meanings come from LightCache — verify there.
cacher.set_params(cache_interval=2, cache_branch_id=0)
# Turn on LightCache's Swap, Slice and Chunk options (see LightCache for what each does).
cacher.enable(Swap=True, Slice=True, Chunk=True)

# Alternative: FME wrapper, kept commented out for comparison runs.
# cacher = FMEWrapper(num_temporal_chunk=9, num_spatial_chunk=2, num_frames=25)
# cacher.wrap(pipe)

# Reset the CUDA peak-memory counter so a later torch.cuda.max_memory_allocated()
# reflects only allocations made after this point.
torch.cuda.reset_peak_memory_stats('cuda')

In [None]:
# Generate one clip, then report its latency and (on CUDA) the peak GPU memory.
start_time = perf_counter()  # monotonic, high-resolution clock meant for elapsed-time measurement
output = pipe(
    prompt=(
        "masterpiece, bestquality, highlydetailed, ultradetailed, sunset, "
        "orange sky, warm lighting, fishing boats, ocean waves seagulls, "
        "rippling water, wharf, silhouette, serene atmosphere, dusk, evening glow, "
        "golden hour, coastal landscape, seaside scenery"
    ),
    negative_prompt="bad quality, worse quality",
    num_frames=25,  # NOTE(review): keep in sync with the frame count given to the cacher
    guidance_scale=7.5,
    num_inference_steps=50,
    generator=torch.Generator(device).manual_seed(42),  # fixed seed for a reproducible sample
)
elapsed_s = perf_counter() - start_time
print(f"generation time: {elapsed_s:.2f} s")
if torch.cuda.is_available():
    # Peak allocation since the most recent reset_peak_memory_stats call.
    peak_gib = torch.cuda.max_memory_allocated(device) / 1024**3
    print(f"peak CUDA memory: {peak_gib:.2f} GiB")

frames = output.frames[0]
# export_to_gif(frames, "./generated_videos/AnimateDiff_50step_FME.gif")