In [1]:
import torch
import lpips
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL, AutoencoderKLOutput
from diffusers.models.autoencoders.vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
from typing import Union, Tuple
from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
from torchvision.transforms import Resize, ToTensor, Compose
from diffusers.utils import export_to_gif, load_image
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from LightCache.LightCache import LightCacher
from fme import FMEWrapper
from time import time
from PIL import Image, ImageSequence
import numpy as np
from skimage.metrics import structural_similarity as ssim

In [2]:
def preprocess_frames(frames, size=(256, 256)):
    """Resize each frame, convert it to a tensor and rescale it to [-1, 1].

    Args:
        frames: Iterable of images accepted by the ``Resize`` and ``ToTensor``
            transforms.
        size: Target size handed to ``Resize``.

    Returns:
        A list with one transformed tensor per input frame, in input order.
    """
    def _to_signed_range(tensor):
        # ToTensor output is in [0, 1]; map it linearly onto [-1, 1].
        return tensor * 2 - 1

    pipeline = Compose([Resize(size), ToTensor(), _to_signed_range])

    tensors = []
    for frame in frames:
        tensors.append(pipeline(frame))
    return tensors

def compute_lpips(frames1, frames2):
    """Average the LPIPS (AlexNet backbone) distance over index-aligned frames.

    Both sequences are passed through ``preprocess_frames`` at 224x224 and
    then compared pair by pair. The list of per-frame scores is printed
    before the mean is returned.

    Args:
        frames1: First sequence of frames; its length sets the number of pairs.
        frames2: Second sequence of frames, indexed in step with ``frames1``.

    Returns:
        ``np.mean`` of the per-frame LPIPS scores.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    lpips_model = lpips.LPIPS(net='alex').to(device)
    tensors_a = preprocess_frames(frames1, size=(224, 224))
    tensors_b = preprocess_frames(frames2, size=(224, 224))

    per_frame = []
    for idx, tensor_a in enumerate(tensors_a):
        # Add a leading batch dimension and move both frames to the model's device.
        batch_a = tensor_a.unsqueeze(0).to(device)
        batch_b = tensors_b[idx].unsqueeze(0).to(device)
        with torch.no_grad():
            distance = lpips_model(batch_a, batch_b)
        per_frame.append(distance.item())

    print(per_frame)
    return np.mean(per_frame)

def gif_to_frames(gif_path):
    """Read every frame of an animated image file as an RGB image.

    Args:
        gif_path: Location of the image, passed straight to ``Image.open``.

    Returns:
        A list of RGB image copies, one per frame, in iteration order.
    """
    rgb_frames = []
    with Image.open(gif_path) as gif:
        for raw_frame in ImageSequence.Iterator(gif):
            # Convert to RGB and keep a copy of each frame.
            rgb_frames.append(raw_frame.convert("RGB").copy())
    return rgb_frames

def calculate_psnr(img1, img2):
    """Compute the peak signal-to-noise ratio (dB) between two images.

    Both inputs are promoted to float64 before the difference is taken.
    Without that promotion, unsigned 8-bit arrays wrap around on subtraction
    and squaring, which silently corrupts the mean squared error (for
    example, a uniform difference of 16 would yield an MSE of 0).

    Args:
        img1: First image as an array-like with pixel values on a 0-255 scale.
        img2: Second image, with a shape compatible with ``img1``.

    Returns:
        The PSNR in decibels, using 255 as the peak value, or ``float('inf')``
        when the two images are identical.
    """
    pixels_a = np.asarray(img1, dtype=np.float64)
    pixels_b = np.asarray(img2, dtype=np.float64)

    mse = np.mean((pixels_a - pixels_b) ** 2)
    if mse == 0:
        # Identical images: the ratio is unbounded.
        return float('inf')
    PIXEL_MAX = 255.0
    return 20 * np.log10(PIXEL_MAX / np.sqrt(mse))

def compute_ssim(frames1, frames2):
    """Average the structural similarity over paired frames.

    Each frame is resized to 256x256 and converted to a float32 array. For
    3-D arrays the score is the mean of the SSIM of the first three
    channels; otherwise SSIM is computed on the array directly. The
    per-frame scores are printed before their mean is returned.

    Args:
        frames1: First sequence of images exposing ``resize``.
        frames2: Second sequence of images, paired with ``frames1`` via ``zip``.

    Returns:
        ``np.mean`` of the per-frame SSIM scores.
    """
    target_size = (256, 256)
    per_frame = []

    for frame_a, frame_b in zip(frames1, frames2):
        arr_a = np.array(frame_a.resize(target_size)).astype(np.float32)
        arr_b = np.array(frame_b.resize(target_size)).astype(np.float32)

        if arr_a.ndim != 3:
            # Single-plane image: one SSIM call covers the whole frame.
            per_frame.append(ssim(arr_a, arr_b, data_range=255))
            continue

        # Multi-channel image: average the SSIM of the first three channels.
        channel_total = 0
        for channel in range(3):
            channel_total += ssim(arr_a[:, :, channel], arr_b[:, :, channel], data_range=255)
        per_frame.append(channel_total / 3)

    print(per_frame)
    return np.mean(per_frame)

In [None]:
# Device and precision used for the motion adapter and the pipeline built below.
device = "cuda"
dtype = torch.float16

# `step` picks which checkpoint file is downloaded; the generation cell also
# passes it as `num_inference_steps`.
step = 8  # Options: [1,2,4,8]
repo = "ByteDance/AnimateDiff-Lightning"
ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
base = "emilianJR/epiCRealism"  # Swap in your preferred base model.

# Create a MotionAdapter on the target device, then load the checkpoint
# weights downloaded from the Hub repo into it.
adapter = MotionAdapter().to(device, dtype)
adapter.load_state_dict(load_file(hf_hub_download(repo ,ckpt), device=device))
# Build the AnimateDiff pipeline on top of the base model with that adapter.
pipe = AnimateDiffPipeline.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)
# Swap in an Euler discrete scheduler derived from the pipeline's existing
# scheduler config, overriding the timestep spacing and beta schedule.
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")

In [4]:
# Attach LightCache to the pipeline. The second argument (25) equals the
# `num_frames` used in the generation cell — presumably the frame count;
# TODO confirm against the LightCacher signature.
cacher = LightCacher(pipe, 25)
cacher.set_params(cache_interval=2, cache_branch_id=0)
# All three optional modes (Swap / Slice / Chunk) are switched off here.
cacher.enable(Swap=False, Slice=False, Chunk=False)

# Alternative FME wrapper, left disabled:
# cacher = FMEWrapper(num_temporal_chunk=9, num_spatial_chunk=2, num_frames=25)
# cacher.wrap(pipe)
# Reset the CUDA peak-memory counters so the peak read after generation
# reflects only what runs from this point on.
torch.cuda.reset_peak_memory_stats(device)

In [None]:
# Fixed seed on a device-side generator, passed to the pipeline call below.
seed = 42
generator = torch.Generator(device=device).manual_seed(seed)

# Wall-clock timing of a single pipeline call.
start_time = time()
output = pipe(prompt="A girl smiling", guidance_scale=1.0, 
              num_inference_steps=step, num_frames=25, generator=generator)

# Elapsed seconds for the call above.
print(time() - start_time)
# export_to_gif(output.frames[0], "AnimateDiff_Baseline.gif")
# peak_mem_alloc = torch.cuda.max_memory_allocated(device) / 1024 ** 2  # MB
# Peak memory reserved on `device` since the last peak-stats reset, in MB.
peak_mem_reserved = torch.cuda.max_memory_reserved(device) / 1024 ** 2  # MB

# print(f"Peak memory allocated: {peak_mem_alloc:.2f} MB")
print(f"Peak memory reserved: {peak_mem_reserved:.2f} MB")