In [1]:
%load_ext autoreload
%autoreload 2

In [25]:
import os
import shutil
import io
import math
import random
from pathlib import Path
import cv2
import ffmpeg
import lmdb
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision.transforms import Resize
from torch_fidelity import calculate_metrics

from src.vqgan import ViTVQ


In [38]:
def get_total_frames(video_path):
    """Return the number of frames in the first video stream of ``video_path``.

    The file is inspected with ``ffmpeg.probe``. The stream's ``nb_frames``
    field is used when it is present and numeric; otherwise the count is
    estimated as ``duration * avg_frame_rate`` (some containers do not report
    ``nb_frames``).

    Args:
        video_path: Path of the video file, passed unchanged to ``ffmpeg.probe``.

    Returns:
        int: The reported (or estimated) frame count.

    Raises:
        ValueError: If the file has no video stream, or the frame count can be
            neither read nor estimated from the probe result.
    """
    probe = ffmpeg.probe(video_path)
    video_stream = next(
        (stream for stream in probe.get('streams', []) if stream.get('codec_type') == 'video'),
        None,
    )
    if video_stream is None:
        raise ValueError(f"No video stream found in {video_path}")

    try:
        return int(video_stream['nb_frames'])
    except (KeyError, TypeError, ValueError):
        # Field missing or not an integer (e.g. "N/A"): estimate below.
        pass

    # Stream-level duration first, container-level duration as a fallback.
    duration = video_stream.get('duration') or probe.get('format', {}).get('duration')
    # Frame rate is reported as a rational string such as "30000/1001".
    numerator, _, denominator = str(video_stream.get('avg_frame_rate', '')).partition('/')
    try:
        fps = float(numerator) / float(denominator or 1)
        return int(round(float(duration) * fps))
    except (TypeError, ValueError, ZeroDivisionError):
        raise ValueError(f"Cannot determine frame count for {video_path}") from None


# Baseline: FID between two consecutive frames of the same video, averaged
# over a sample of training videos. For each video one frame is written to the
# "input" folder and its successor to the "output" folder, and
# calculate_metrics compares the two folders.
# NOTE(review): each score is computed from one image per folder; confirm that
# a single-pair FID is the intended metric.
fids = []
folder_dir = "/mnt/e/kinetics-dataset/k400"
input_dir = "./output/input_images"
output_dir = "./output/output_images"
max_videos = 100  # stop after visiting this many files
assumed_fps = 30  # seek time = frame_index / assumed_fps; adjust for other frame rates
frame_rng = random.Random(0)  # local, seeded generator so the sampled frames are repeatable


def _reset_frame_dirs():
    """Leave both frame folders existing and empty."""
    for frame_dir in (input_dir, output_dir):
        if os.path.isdir(frame_dir):
            shutil.rmtree(frame_dir)
        os.makedirs(frame_dir)


for video_idx, path in enumerate(Path(folder_dir, "train").glob("*.mp4")):
    if video_idx == max_videos:
        break

    total_frames = get_total_frames(path)
    if total_frames < 2:
        # A frame and its successor are both needed.
        continue

    # Start from empty folders so frames left by an interrupted run cannot
    # leak into this video's score.
    _reset_frame_dirs()

    # randint is inclusive on both ends; stopping at total_frames - 2 keeps
    # first_frame + 1 inside the video.
    first_frame = frame_rng.randint(0, total_frames - 2)
    for frame_dir, frame_index in ((input_dir, first_frame), (output_dir, first_frame + 1)):
        frame_path = f'{frame_dir}/frame_{frame_index}.png'  # Output image file name
        (
            ffmpeg
            .input(path, ss=frame_index / assumed_fps)  # seek to the frame's timestamp
            .output(frame_path, vframes=1)  # Extract a single frame
            .run(overwrite_output=True, quiet=True)
        )

    fid_result = calculate_metrics(
        input1=input_dir,
        input2=output_dir,
        cuda=torch.cuda.is_available(),
        isc=False,
        fid=True,
        verbose=False,
    )

    fid_score = fid_result["frechet_inception_distance"]
    print(f"FID Score: {fid_score}")
    fids.append(fid_score)

# Clean up: leave the folders empty for whatever runs next.
_reset_frame_dirs()

np.mean(fids)

FID Score: 1.6920183618293607
FID Score: 111.68620848857971
FID Score: 24.92244784762694
FID Score: 1.129539171322589
FID Score: 0.4978480515451713
FID Score: 40.24957821209687
FID Score: 75.22133409061364
FID Score: 27.4493283899948
FID Score: 5.606919316043317
FID Score: 72.0803183746257
FID Score: 43.46672184743729
FID Score: 6.300716210478402
FID Score: 93.90158080626392
FID Score: 102.68643434968688
FID Score: 64.93260229877615
FID Score: 32.0216591556872
FID Score: 101.99337378213842
FID Score: 48.537427671247784
FID Score: 0.10090043351340777
FID Score: 56.4116368507035
FID Score: 13.059054317047407
FID Score: 14.837862343628998
FID Score: 11.544577547609233
FID Score: 13.977263631425718
FID Score: 50.001555400376944
FID Score: 28.325839084731793
FID Score: 27.56756362867404
FID Score: 0.7764108812768193
FID Score: 36.14546080787416
FID Score: 13.75409468283855
FID Score: 2.2025702148518187
FID Score: 15.43604259742405
FID Score: 13.315330947099886
FID Score: 0.0
FID Score: 6.73

np.float64(29.422054633119068)

In [39]:
# Summary of the per-video FID scores as a (median, min, max) tuple.
tuple(stat(fids) for stat in (np.median, np.min, np.max))

(np.float64(19.54708203799235),
 np.float64(0.0),
 np.float64(118.43841000994973))

In [40]:
# Mean FID after discarding scores that lie far from the mean.
# The cut-off is 1 standard deviation; a cut-off of 3 is the more common
# outlier rule, so raise z_threshold if a looser filter is wanted.
z_threshold = 1

fids = np.array(fids)
if fids.size and np.std(fids) > 0:
    z_scores = (fids - np.mean(fids)) / np.std(fids)
else:
    # Empty input or all scores identical: there is no spread, so nothing is
    # an outlier (avoids a 0/0 division that would drop every point).
    z_scores = np.zeros(fids.shape, dtype=float)

filtered_data = fids[np.abs(z_scores) < z_threshold]  # Keep data points with |Z-score| < z_threshold
mean_without_outliers = np.mean(filtered_data) if filtered_data.size else float("nan")

print(mean_without_outliers)

17.030669157368827


In [None]:

# FID for a single input/output image pair, shown side by side.
# NOTE(review): save_images, image_resized and output_image are not defined in
# the visible part of the notebook; they must come from cells run earlier.
fid_dirs = ("./output/input_images", "./output/output_images")

# calculate_metrics is given folder paths, so each image goes into its own folder first.
for images, target_dir in zip(([image_resized], [output_image]), fid_dirs):
    save_images(images, target_dir)

fid_result = calculate_metrics(
	input1=fid_dirs[0],
	input2=fid_dirs[1],
	cuda=torch.cuda.is_available(),
	isc=False,
	fid=True,
)

fid_score = fid_result["frechet_inception_distance"]
print(f"FID Score: {fid_score}")

# Plot the input and output images with FID score
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
panels = (
	(image_resized, "Input Image"),
	(output_image, f"Output Image\nFID Score: {fid_score:.2f}"),
)
for ax, (picture, caption) in zip(axes, panels):
	ax.imshow(picture)
	ax.set_title(caption)
	ax.axis("off")
plt.savefig("imgs/base.png")
plt.show()

# Clean up: remove both temporary folders.
for target_dir in fid_dirs:
	shutil.rmtree(target_dir)