-
Notifications
You must be signed in to change notification settings - Fork 4
/
evaluate_FVD.py
93 lines (74 loc) · 3.02 KB
/
evaluate_FVD.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import argparse
import os
import numpy as np
import scipy.linalg
import torch
from tqdm import tqdm
from vidm.dataset import VideoFolder
def compute_fvd(feats_fake: np.ndarray, feats_real: np.ndarray):
    """Fréchet Video Distance between two sets of detector features.

    Fits a Gaussian to each feature set and returns the Fréchet distance
    between them:  ||mu_f - mu_r||^2 + Tr(S_f + S_r - 2*sqrtm(S_f @ S_r)).

    Args:
        feats_fake: [n, d] features of generated videos.
        feats_real: [n, d] features of reference videos.

    Returns:
        The FVD as a plain Python float.
    """
    mu_fake, cov_fake = compute_stats(feats_fake)
    mu_real, cov_real = compute_stats(feats_real)
    mean_term = np.sum((mu_fake - mu_real) ** 2)
    # disp=False makes sqrtm return (result, error-estimate) instead of
    # printing a warning on poor convergence.
    covmean, _ = scipy.linalg.sqrtm(
        cov_fake.dot(cov_real), disp=False
    )  # pylint: disable=no-member
    # sqrtm can carry tiny imaginary components from numerical error;
    # the true distance is real, so keep only the real part.
    trace_term = np.trace(cov_fake + cov_real - 2 * covmean)
    return float(np.real(mean_term + trace_term))
def compute_stats(feats: np.ndarray):
    """Return the Gaussian statistics of a feature batch.

    Args:
        feats: array of shape [n, d], one feature row per sample.

    Returns:
        (mu, sigma): the per-dimension mean of shape [d] and the
        sample covariance matrix of shape [d, d].
    """
    mean_vec = np.mean(feats, axis=0)
    # rowvar=False: rows are observations, columns are variables.
    cov_mat = np.cov(feats, rowvar=False)
    return mean_vec, cov_mat
if __name__ == "__main__":
    # Compute FVD between ground-truth clips (dir1) and predictions (dir2).
    parser = argparse.ArgumentParser()
    parser.add_argument("-dir1", "--dir1", type=str, default="./gts/")
    parser.add_argument("-dir2", "--dir2", type=str, default="./preds/")
    parser.add_argument("-b", "--batch", type=int, default=32)
    parser.add_argument("-r", "--resolution", type=int, default=256)
    parser.add_argument("-n", "--nframes", type=int, default=128)
    parser.add_argument("-ns", "--nsamples", type=int, default=2048)
    # BUG FIX: default was the integer 2048 (copy-paste from --nsamples),
    # which is not a valid value for a type=str model path. The argument is
    # the path to the TorchScript I3D feature detector.
    parser.add_argument("-i3d", "--i3d", type=str, default="./i3d_torchscript.pt")
    opt = parser.parse_args()

    device = "cuda:0"
    batch_size = opt.batch
    resolution = opt.resolution
    nframes = opt.nframes
    nsamples = opt.nsamples

    # Pre-trained I3D detector, loaded as TorchScript; eval mode for inference.
    detector = torch.jit.load(opt.i3d).eval().to(device)

    gt_dataset = VideoFolder(path=opt.dir1, nframes=nframes, size=resolution)
    gt_dataset_iter = iter(
        torch.utils.data.DataLoader(
            gt_dataset, num_workers=8, batch_size=batch_size, shuffle=False
        )
    )
    pred_dataset = VideoFolder(path=opt.dir2, nframes=nframes, size=resolution)
    pred_dataset_iter = iter(
        torch.utils.data.DataLoader(
            pred_dataset, num_workers=8, batch_size=batch_size, shuffle=False
        )
    )
    print(f"loading videos with number of {len(gt_dataset)} and {len(pred_dataset)}")

    def _extract_features(loader_iter, num_batches):
        """Run `num_batches` batches through the I3D detector; return [N, d] features."""
        feats = []
        for _ in tqdm(range(num_batches)):
            # Dataset yields b,n,h,w,c; the detector expects b,c,n,h,w.
            video = next(loader_iter).to(device)
            video = video.permute(0, 4, 1, 2, 3).contiguous()
            with torch.no_grad():
                feats.append(
                    detector(video, rescale=False, resize=True, return_features=True)
                    .cpu()
                    .numpy()
                )
        return np.concatenate(feats, axis=0)

    num_batches = nsamples // batch_size
    feats_real = _extract_features(gt_dataset_iter, num_batches)
    # BUG FIX: the original second loop consumed gt_dataset_iter again, so the
    # "fake" features were ground-truth features and the FVD compared the GT
    # set with itself. Predictions must come from pred_dataset_iter.
    feats_fake = _extract_features(pred_dataset_iter, num_batches)
    print(compute_fvd(feats_fake, feats_real))