In [3]:
# Fréchet Inception Distance (FID) Evaluation for Video Frames
# This notebook implements FID evaluation for comparing frames from generated videos against frames from real videos.

import os
import numpy as np
import torch
from torch.nn.functional import adaptive_avg_pool2d
from scipy import linalg
from tqdm.notebook import tqdm
import cv2
from PIL import Image
from torchvision import models, transforms
import glob
from pathlib import Path

# Load InceptionV3 Model for Feature Extraction
class InceptionV3FeatureExtractor():
    def __init__(self, device='cuda'):
        # Load pre-trained InceptionV3 model
        self.model = models.inception_v3(pretrained=True, transform_input=False)
        # Remove the final classification layer
        self.model.fc = torch.nn.Identity()
        # Set to evaluation mode
        self.model.eval()
        self.model = self.model.to(device)
        self.device = device
        
        # Define image preprocessing
        self.preprocess = transforms.Compose([
            transforms.Resize(299),
            transforms.CenterCrop(299),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
    
    def extract_features(self, images):
        """Extract features from a batch of images"""
        with torch.no_grad():
            features = self.model(images)
        return features

def calculate_fid(real_features, fake_features):
    """Calculate FID score between real and fake feature distributions"""
    # Calculate mean and covariance for real features
    mu1 = np.mean(real_features, axis=0)
    sigma1 = np.cov(real_features, rowvar=False)
    
    # Calculate mean and covariance for fake features
    mu2 = np.mean(fake_features, axis=0)
    sigma2 = np.cov(fake_features, rowvar=False)
    
    # Calculate FID score
    ssdiff = np.sum((mu1 - mu2) ** 2.0)
    covmean = linalg.sqrtm(sigma1.dot(sigma2))
    
    # Check if covmean contains complex numbers
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    
    fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
    return fid

def calculate_fid_for_frame(real_video_path, gen_video_path, frame_idx=0, device='cpu'):
    """
    Calculate FID score for a specific frame between two videos
    
    Args:
        real_video_path: Path to the real video
        gen_video_path: Path to the generated video
        frame_idx: Index of the frame to compare (default: 0, first frame)
        device: Device to run the model on ('cuda' or 'cpu')
        
    Returns:
        FID score for the specified frame
    """
    # Initialize feature extractor
    feature_extractor = InceptionV3FeatureExtractor(device=device)
    
    # Extract the specified frame from each video
    def extract_single_frame(video_path, frame_idx):
        cap = cv2.VideoCapture(video_path)
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        ret, frame = cap.read()
        cap.release()
        
        if not ret:
            raise ValueError(f"Could not read frame {frame_idx} from {video_path}")
        
        # Convert BGR to RGB
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        return frame
    
    # Get frames
    real_frame = extract_single_frame(real_video_path, frame_idx)
    gen_frame = extract_single_frame(gen_video_path, frame_idx)
    
    # Process frames
    real_pil = Image.fromarray(real_frame)
    gen_pil = Image.fromarray(gen_frame)
    
    real_tensor = feature_extractor.preprocess(real_pil).unsqueeze(0).to(device)
    gen_tensor = feature_extractor.preprocess(gen_pil).unsqueeze(0).to(device)
    
    # Extract features
    with torch.no_grad():
        real_features = feature_extractor.extract_features(real_tensor).cpu().numpy()
        gen_features = feature_extractor.extract_features(gen_tensor).cpu().numpy()
    
    # Calculate FID score
    fid_score = calculate_fid(real_features, gen_features)
    
    return fid_score


In [4]:
real_video = "/proj/aicell/users/x_aleho/video-diffusion/data/processed/idr0013/LT0001_02/00001_01.mp4"
gen_video = "/proj/aicell/users/x_aleho/video-diffusion/CogVideo/test_generations/i2v_eval1_night/LT0004_06-00058_01_withLORA_highPROF.mp4"
fid = calculate_fid_for_frame(real_video, gen_video, frame_idx=5)
print(f"FID score for frame 5: {fid}")

ValueError: Non-matrix input to matrix function.