In [None]:
# First cell - Setup and Installation
!pip install gradio torch torchvision matplotlib pandas opencv-python-headless



In [None]:
import gradio as gr
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import cv2
import logging
from torch.utils.data import Dataset, DataLoader

from fid_metrics import (
    build_inception,
    build_inception3d,
    calculate_fid,
    postprocess_i2d_pred
)

# Logging: create this module's named logger and request INFO-level output
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

class VideoDataset(Dataset):
    """Dataset that yields one decoded video per index as a float tensor of frames.

    Args:
        video_paths: sequence of video file paths, one per dataset item.
        sequence_length: maximum number of frames returned per video; videos
            with more frames are subsampled, shorter ones are returned whole.
    """

    def __init__(self, video_paths, sequence_length=32):
        self.sequence_length = sequence_length
        self.video_paths = video_paths

    def __len__(self):
        """Return the number of videos in the dataset."""
        return len(self.video_paths)

    def load_video(self, path):
        """Decode every readable frame of the video at ``path``.

        Returns:
            A single array made by stacking all frames after BGR -> RGB conversion.

        Raises:
            ValueError: if not a single frame could be read.
        """
        capture = cv2.VideoCapture(path)
        decoded = []
        try:
            ok, bgr = capture.read()
            while ok:
                decoded.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
                ok, bgr = capture.read()
        finally:
            # Release the capture handle even if decoding fails part-way through.
            capture.release()

        if len(decoded) == 0:
            raise ValueError(f"No frames could be read from {path}")

        return np.stack(decoded)

    def __getitem__(self, idx):
        """Return the video at ``idx`` as a float tensor in TCHW layout."""
        clip = self.load_video(self.video_paths[idx])
        total = len(clip)

        # More frames than wanted: keep `sequence_length` of them, evenly spaced
        # between the first and the last frame.
        if total > self.sequence_length:
            keep = np.linspace(0, total - 1, self.sequence_length, dtype=int)
            clip = clip[keep]

        # Scale pixel values by 1/255, then reorder THWC -> TCHW.
        scaled = torch.FloatTensor(clip) / 255.0
        return scaled.permute(0, 3, 1, 2)

def _upload_path(upload):
    """Return the filesystem path of one uploaded file as a string.

    Accepts a plain path (``str`` or ``Path``) or an object that exposes its
    path through a ``name`` attribute.
    """
    if isinstance(upload, (str, Path)):
        return str(upload)
    return str(upload.name)


def _extract_features(model, paths, device):
    """Run each video in ``paths`` through ``model``; return one feature array per video."""
    loader = DataLoader(VideoDataset(paths), batch_size=1, shuffle=False)
    collected = []
    with torch.no_grad():
        for clip in loader:
            clip = clip.to(device)
            # batch_size=1: dropping dim 0 hands the model one video's frames as its batch.
            features = postprocess_i2d_pred(model(clip.squeeze(0)))
            collected.append(features.cpu().numpy())
    return collected


def evaluate_videos(real_videos, synthetic_videos):
    """Compute an FID score between uploaded real and synthetic videos.

    Features are extracted from every video with the 2D Inception model,
    concatenated per group, and passed to ``calculate_fid`` once. The score is
    therefore a single value for the whole upload; the table repeats it on
    every row.

    Args:
        real_videos: uploaded real videos (paths or objects with a ``name`` path).
        synthetic_videos: uploaded synthetic videos, same form.

    Returns:
        ``(text, figure)`` on success, where ``text`` is the results summary and
        ``figure`` is a matplotlib bar chart. If either input is empty, or any
        exception is raised during processing, returns ``(message, None)``;
        exceptions are logged and not propagated.
    """
    try:
        if not real_videos or not synthetic_videos:
            return "Please upload both real and synthetic videos.", None

        # Local imports keep this block self-contained (no change to the import cell).
        from itertools import zip_longest
        from matplotlib.figure import Figure

        # Resolve paths once so extraction and the results table agree on them.
        real_paths = [_upload_path(video) for video in real_videos]
        synth_paths = [_upload_path(video) for video in synthetic_videos]

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        logger.info(f"Using device: {device}")

        # Initialize models
        inception_2d = build_inception(2048).to(device).eval()  # For FID-VID

        # Log input information
        logger.info(f"Processing {len(real_paths)} real videos and {len(synth_paths)} synthetic videos")

        # Extract features
        real_features = np.concatenate(_extract_features(inception_2d, real_paths, device), axis=0)
        synth_features = np.concatenate(_extract_features(inception_2d, synth_paths, device), axis=0)

        # Calculate metrics
        fid_score = calculate_fid(real_features, synth_features)
        fid_value = round(float(fid_score), 2)

        # Build the results table. zip_longest tolerates unequal upload counts:
        # indexing both lists by the real count raised IndexError when fewer
        # synthetic videos were given and hid the extra ones when more were given.
        results = []
        for pair_no, (real_path, synth_path) in enumerate(zip_longest(real_paths, synth_paths), start=1):
            results.append({
                'Pair': f'Pair {pair_no}',
                'Real Video': Path(real_path).name if real_path is not None else '-',
                'Synthetic Video': Path(synth_path).name if synth_path is not None else '-',
                'FID Score': fid_value
            })

        df = pd.DataFrame(results)

        # Create visualization. A standalone Figure is not registered with
        # pyplot's global figure list, so repeated requests do not accumulate
        # open figures in the long-running server process.
        fig = Figure(figsize=(10, 6))
        ax = fig.subplots()
        ax.bar(df['Pair'], df['FID Score'])
        ax.set_title('FID Scores')
        ax.set_ylabel('Score')
        fig.tight_layout()

        result_text = (
            f"Results Summary:\n\n{df.to_string(index=False)}\n\n"
            f"Average FID Score: {df['FID Score'].mean():.2f}"
        )

        return result_text, fig

    except Exception as e:
        # Report the failure in the UI instead of crashing the request handler.
        logger.error(f"Error in evaluation: {str(e)}", exc_info=True)
        return f"Error processing videos: {str(e)}", None

# Build the Gradio UI: two multi-file upload inputs (real, synthetic) feed
# evaluate_videos, whose text summary and plot are shown as the outputs.
demo = gr.Interface(
    fn=evaluate_videos,
    inputs=[
        # Both inputs are configured identically apart from their label;
        # type="filepath" is requested for the uploaded files.
        gr.Files(label=upload_label, file_count="multiple", type="filepath")
        for upload_label in ("Real Videos (Sims4Action)", "Synthetic Videos")
    ],
    outputs=[
        gr.Textbox(label="Results", lines=10),
        gr.Plot(label="Quality Scores"),
    ],
    title="Video Quality Evaluation",
    description="Upload pairs of real and synthetic videos to evaluate their quality using FID metrics.",
    flagging_mode="never",
)

# Start the app only when run directly; launch options are chosen for Colab.
if __name__ == "__main__":
    demo.queue()
    # debug=True keeps the cell running so logs and errors stay visible,
    # share=True requests a public share URL, show_error=True is passed
    # so errors are reported to the user.
    demo.launch(show_error=True, share=True, debug=True)

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://c7f1e16bebb4c95f85.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Downloading: "https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth" to /root/.cache/torch/hub/checkpoints/pt_inception-2015-12-05-6726825d.pth
100%|██████████| 91.2M/91.2M [00:01<00:00, 70.0MB/s]


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://c7f1e16bebb4c95f85.gradio.live
