In [3]:
import torch
import os
import math
import torch.nn.functional as F
from typing import Tuple
from scipy.linalg import sqrtm
import numpy as np

# https://huggingface.co/spaces/LanguageBind/Open-Sora-Plan-v1.0.0/blob/b1ca112023d762ebab42c48a0d70254ec95b2e4d/opensora/eval/fvd/styleganv/fvd.py

# Module-level default device.
# NOTE(review): none of the functions or cells visible in this notebook read this
# variable (they take their own `device` argument) — confirm it is still needed.
device=torch.device('cpu')

def load_i3d_pretrained(device=torch.device('cpu')):
    """Load the TorchScript I3D detector, downloading the weights if they are missing.

    The weights are expected as ``i3d_torchscript.pt`` in the current working
    directory. When that file does not exist it is fetched with ``wget``.

    Args:
        device: Device the loaded module is moved to.

    Returns:
        The loaded module, switched to eval mode and wrapped in
        ``torch.nn.DataParallel``.

    Raises:
        RuntimeError: If ``wget`` cannot be started or exits with an error.
    """
    import subprocess  # local import: only the download path needs it

    i3D_WEIGHTS_URL = "https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt"
    filepath = 'i3d_torchscript.pt'
    print(filepath)
    if not os.path.exists(filepath):
        print(f"preparing for download {i3D_WEIGHTS_URL}, you can download it by yourself.")
        try:
            # Argument list (no shell) and check=True so a failed download is noticed.
            subprocess.run(["wget", i3D_WEIGHTS_URL, "-O", filepath], check=True)
        except (OSError, subprocess.CalledProcessError) as err:
            # A failed `wget -O` may leave an empty/partial file behind; remove it so
            # the next call retries instead of trying to load a corrupt checkpoint.
            if os.path.exists(filepath):
                os.remove(filepath)
            raise RuntimeError(
                f"could not download {i3D_WEIGHTS_URL}; download it manually to {filepath}"
            ) from err
    i3d = torch.jit.load(filepath).eval().to(device)
    i3d = torch.nn.DataParallel(i3d)
    return i3d

def get_feats(videos, detector, device, bs=10):
    """Run the detector over `videos` in mini-batches and collect its features.

    Args:
        videos: Sequence of videos, documented by the original author as a
            torch tensor in BCTHW layout with values in [0, 1]. Each item is
            passed through `preprocess_single` before being batched.
        detector: Callable invoked as
            ``detector(batch, rescale=False, resize=False, return_features=True)``.
        device: Device each preprocessed batch is moved to before the call.
        bs: Number of videos per detector call.

    Returns:
        float64 numpy array of shape (len(videos), 400); shape (0, 400) when
        `videos` is empty.
    """
    detector_kwargs = dict(rescale=False, resize=False, return_features=True) # Return raw features before the softmax layer.
    # Seed with an empty float64 (0, 400) block: it fixes the result's dtype and
    # shape for empty input and makes np.vstack reject any other feature width.
    chunks = [np.empty((0, 400))]
    with torch.no_grad():
        for start in range(0, len(videos), bs):
            clips = [preprocess_single(video) for video in videos[start:start + bs]]
            batch = torch.stack(clips).to(device)
            features = detector(batch, **detector_kwargs)
            chunks.append(features.detach().cpu().numpy())
    # Stack once at the end instead of re-copying the accumulated array per batch.
    return np.vstack(chunks)


def get_fvd_feats(videos, i3d, device, bs=10):
    """Return the I3D embeddings used for FVD.

    Thin wrapper that forwards all arguments to `get_feats`, which preprocesses
    each video itself.

    Args:
        videos: Videos in [0, 1] as a torch tensor in BCTHW layout.
        i3d: Detector passed through to `get_feats`.
        device: Device the batches are moved to.
        bs: Batch size passed through to `get_feats`.

    Returns:
        Whatever `get_feats` returns for these arguments.
    """
    return get_feats(videos, i3d, device, bs)

def preprocess_single(video, resolution=224, sequence_length=None):
    """Resize, center-crop and rescale one video for the detector.

    Args:
        video: Tensor in CTHW layout with values in [0, 1].
        resolution: Side length of the square output frames.
        sequence_length: If given, keep only the first `sequence_length` frames;
            must not exceed the number of frames in `video`.

    Returns:
        Contiguous tensor of shape (C, T', resolution, resolution) with values
        mapped from [0, 1] to [-1, 1], where T' is the (possibly cropped) length.

    Raises:
        AssertionError: If `sequence_length` is larger than the number of frames.
    """
    _, n_frames, height, width = video.shape

    # Optional temporal crop: keep the leading frames only.
    if sequence_length is not None:
        assert sequence_length <= n_frames
        video = video[:, :sequence_length]

    # Resize so the shorter spatial side equals `resolution`; the longer side
    # is rounded up.
    ratio = resolution / min(height, width)
    if height < width:
        resized_hw = (resolution, math.ceil(width * ratio))
    else:
        resized_hw = (math.ceil(height * ratio), resolution)
    # The 4-D CTHW tensor is passed as-is, so only the last two (spatial)
    # dimensions are interpolated.
    video = F.interpolate(video, size=resized_hw, mode='bilinear', align_corners=False)

    # Center crop to a resolution x resolution window.
    new_height, new_width = video.shape[-2:]
    top = (new_height - resolution) // 2
    left = (new_width - resolution) // 2
    video = video[:, :, top:top + resolution, left:left + resolution]

    # [0, 1] -> [-1, 1]
    rescaled = (video - 0.5) * 2
    return rescaled.contiguous()


"""
Copy-pasted from https://github.com/cvpr2022-stylegan-v/stylegan-v/blob/main/src/metrics/frechet_video_distance.py
"""


def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Return the per-feature mean and the feature covariance of `feats`.

    Args:
        feats: Array whose rows are samples and whose columns are features.

    Returns:
        (mean, covariance): the mean over rows, and ``np.cov`` evaluated with
        columns treated as variables.
    """
    covariance = np.cov(feats, rowvar=False) # [d, d]
    mean = np.mean(feats, axis=0) # [d]
    return mean, covariance

def frechet_distance(feats_fake: np.ndarray, feats_real: np.ndarray) -> float:
    """Fréchet distance between two feature sets.

    With at least two samples on both sides the result is
    ``||mu_fake - mu_real||^2 + Tr(S_fake + S_real - 2 * sqrt(S_fake @ S_real))``.

    When either side has fewer than two samples no covariance can be estimated,
    so only the squared distance between the means is returned. The original
    code applied this fallback only when `feats_fake` was the small set; a
    one-row `feats_real` fed an unusable covariance into the matrix square root.

    Args:
        feats_fake: [n_fake, d] features of generated samples.
        feats_real: [n_real, d] features of reference samples.

    Returns:
        The distance as a Python float.
    """
    # Decide on the fallback before touching np.cov, which does not yield a
    # usable [d, d] matrix for a single row.
    if min(feats_fake.shape[0], feats_real.shape[0]) <= 1:
        mean_gap = feats_fake.mean(axis=0) - feats_real.mean(axis=0)
        return float(np.real(np.square(mean_gap).sum()))

    mu_gen, sigma_gen = compute_stats(feats_fake)
    mu_real, sigma_real = compute_stats(feats_real)
    m = np.square(mu_gen - mu_real).sum()
    s, _ = sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
    # sqrtm may return a complex matrix; only the real part of the total is kept.
    fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
    return float(fid)


@torch.no_grad()
def compute_our_fvd(videos_fake: torch.Tensor, videos_real: torch.Tensor, detector, device: str='cuda', bs: int=2) -> float:
    """Compute the Fréchet Video Distance between generated and reference videos.

    Both sets are embedded with `get_fvd_feats` and the resulting features are
    compared with `frechet_distance`.

    Args:
        videos_fake: Generated videos, in the layout `get_fvd_feats` expects
            (documented there as BCTHW with values in [0, 1]).
        videos_real: Reference videos in the same layout.
        detector: Feature extractor passed through to `get_fvd_feats`.
        device: Device the video batches are moved to. Note the default is
            'cuda', so callers on a CPU-only machine must pass it explicitly.
        bs: Number of videos per detector call (previously hard-coded to 2).

    Returns:
        The distance as a Python float.
    """
    feats_fake = get_fvd_feats(videos_fake, detector, device, bs=bs)
    feats_real = get_fvd_feats(videos_real, detector, device, bs=bs)

    return frechet_distance(feats_fake, feats_real)

In [50]:
import cv2
import torch
import torchvision.transforms as transforms

def video_to_tensor(filepath, size=(128, 128), len_frame=600, start_frame=0, normalize=True):
    """
    Converts a video file into a PyTorch tensor.

    Args:
    filepath (str): Path to the video file.
    size (tuple): The desired (height, width) to resize each frame.
    len_frame (int): Maximum number of frames to return.
    start_frame (int): Number of leading frames to skip before collecting.
    normalize (bool): If True (default, the original behaviour), apply the
        mean/std normalization below, so values are no longer in [0, 1].
        The FVD helpers in this notebook document their input as [0, 1];
        pass False when the tensor is fed to them.

    Returns:
    torch.Tensor: Tensor of shape (1, n_frames, channels, height, width), where
        n_frames <= len_frame (fewer if the video ends early). Callers that need
        (1, channels, n_frames, height, width) must permute(0, 2, 1, 3, 4).
        An empty tensor (torch.empty(0)) is returned when no frame was collected.
    """
    # Frame transformation pipeline
    steps = [
        transforms.ToPILImage(),
        transforms.Resize(size),  # Resize each frame
        transforms.ToTensor(),    # Convert image to tensor
    ]
    if normalize:
        steps.append(
            transforms.Normalize(mean=[0.485, 0.456, 0.406],  # Standard normalization
                                 std=[0.229, 0.224, 0.225])
        )
    transform = transforms.Compose(steps)

    # Initialize a VideoCapture object to read video data from a file
    cap = cv2.VideoCapture(filepath)
    frames = []
    count = 0
    try:
        # Read frames until you have enough or the video ends
        while len(frames) < len_frame:
            ret, frame = cap.read()
            if not ret:
                # End of stream (or unreadable file): report how many frames were
                # read and stop at once, even if start_frame was not reached.
                print(count)
                break
            if count >= start_frame:
                # Convert color from BGR to RGB
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                # Apply transformations
                frames.append(transform(frame))
            count += 1
    finally:
        # Release the capture even if a transform raises.
        cap.release()

    if len(frames) > 0:
        # Stack along a new leading time dimension, then add a batch dimension:
        # (T, C, H, W) -> (1, T, C, H, W)
        return torch.stack(frames).unsqueeze(0)
    return torch.empty(0)


In [4]:
# Load the I3D detector once for all FVD computations below. Called without
# arguments, so the function's default device (CPU) is used.
i3d = load_i3d_pretrained()

i3d_torchscript.pt


In [20]:
# Smoke test of the metric on random tensors: 1 "fake" clip vs. 100 "real" clips.
# NOTE(review): torch.rand is not seeded, so the recorded value is not reproducible.
# NOTE(review): video1 holds 100*3*600*128*128 elements (~11 GiB if float32) —
# consider a much smaller tensor for a sanity check.
# NOTE(review): no `device` is passed, so compute_our_fvd uses its default 'cuda'.
video = torch.rand(1,3,600, 128,128)
video1 = torch.rand(100,3,600, 128,128)
compute_our_fvd(video,video1,i3d)

1.7644517032287588

Compute the FVD on our model's inference output

In [5]:
# Load one saved batch of frames per difficulty level as the "real" reference set.
# permute(0,2,1,3,4) swaps dims 1 and 2 and repeat(1, 3, 1, 1, 1) tiles dim 1 three
# times — presumably (B, T, 1, H, W) grayscale -> (B, 3, T, H, W); TODO confirm the
# stored layout and that values are in [0, 1] as the FVD helpers expect.
dataset_hard = torch.load('data/mmnist-hard/batch_0_frames.pt', weights_only=True).permute(0,2,1,3,4).repeat(1, 3, 1, 1, 1)
dataset_medium = torch.load('data/mmnist-medium/batch_0_frames.pt', weights_only=True).permute(0,2,1,3,4).repeat(1, 3, 1, 1, 1)

In [None]:
# Baseline: a single unseeded random-noise clip against the hard reference set.
# NOTE(review): this cell has no execution count and no recorded output.
compute_our_fvd(torch.rand(1,3,60, 72,128), dataset_hard, i3d)

In [51]:
# Read the generated clip, skipping its first 50 frames and resizing to 72x128.
# video_to_tensor stacks frames as (1, T, C, H, W); permute to (1, C, T, H, W).
# NOTE(review): video_to_tensor applies mean/std normalization, whereas get_feats
# documents its input as [0, 1] — confirm the value ranges match the reference set.
filepath = 'video_hard_autoregressive.mp4'
video_tensor = video_to_tensor(filepath, size=(72, 128), len_frame=100, start_frame=50).permute(0,2,1,3,4)
print("Tensor shape:", video_tensor.shape)

100
Tensor shape: torch.Size([1, 3, 50, 72, 128])


In [52]:
# FVD of the clip loaded in the previous cell against the hard reference set.
# The fake set contains a single video, so frechet_distance returns only the
# squared distance between feature means (no covariance term).
compute_our_fvd(video_tensor, dataset_hard, i3d)

14174.336900809723

In [43]:
# Same loading steps for 'medium_key.mp4', starting from the first frame.
# NOTE(review): `filepath` and `video_tensor` are reused across cells, so the
# following FVD cell depends on this cell having been run immediately before it.
filepath = 'medium_key.mp4'
video_tensor = video_to_tensor(filepath, size=(72, 128), len_frame=100).permute(0,2,1,3,4)
print("Tensor shape:", video_tensor.shape)

Tensor shape: torch.Size([1, 3, 100, 72, 128])


In [44]:
# FVD of the clip loaded in the previous cell against the medium reference set
# (single fake video, so only the mean term of the distance is used).
compute_our_fvd(video_tensor, dataset_medium, i3d)

41003.24953633075

In [41]:
# Same loading steps for 'medium_pure_autoregressive.mp4', skipping 50 frames.
# NOTE(review): this cell's execution count (41) is lower than that of the cell
# defining video_to_tensor (50), so the recorded output may come from an older
# version of the function — re-run after Restart & Run All.
filepath = 'medium_pure_autoregressive.mp4'
video_tensor = video_to_tensor(filepath, size=(72, 128), len_frame=100, start_frame=50).permute(0,2,1,3,4)
print("Tensor shape:", video_tensor.shape)

Tensor shape: torch.Size([1, 3, 100, 72, 128])


In [42]:
# FVD of the clip loaded in the previous cell against the medium reference set
# (single fake video, so only the mean term of the distance is used).
compute_our_fvd(video_tensor, dataset_medium, i3d)

21309.833642988368

In [56]:
# Same loading steps for 'video.mp4', starting from the first frame; up to 100
# frames are requested, fewer are returned if the video is shorter.
filepath = 'video.mp4'
video_tensor = video_to_tensor(filepath, size=(72, 128), len_frame=100).permute(0,2,1,3,4)
print("Tensor shape:", video_tensor.shape)

87
Tensor shape: torch.Size([1, 3, 87, 72, 128])


In [57]:
# FVD of the clip loaded in the previous cell against the hard reference set
# (single fake video, so only the mean term of the distance is used).
compute_our_fvd(video_tensor, dataset_hard, i3d)

18096.241959897998