In [1]:
from DCGAN import DCGAN
from fid import calculate_frechet_distance, compute_statistics_of_path, calculate_frechet_distance
from inception import InceptionV3
import os
import numpy as np
import torch

In [2]:
# Number of channels in the training images. For color images this is 3
num_channels = 3

# Size of z latent vector (i.e. size of generator input)
latent_dim = 128

# Size of feature maps in generator
num_generator_features = 128

# Size of feature maps in discriminator
num_discriminator_features = 64

NUM_WORKERS = int(os.cpu_count() - 1)
BATCH_SIZE = 6 

IMAGE_SIZE = 512

device = "cpu"

dataset_path = "../data/celeba/img_align_celeba/img_align_celeba/"

dataset_path = "../data/art_dataset/resized/resized/"
dataset_statistics_path = "/".join(dataset_path.split("/")[:-1]) + "/inception_statistics.npz"

In [3]:
model = DCGAN(
    num_channels=num_channels,
    latent_dim=latent_dim,
    num_generator_features=num_discriminator_features,
    num_discriminator_features=num_discriminator_features,
).to(device)

In [4]:
inception_model = InceptionV3().to(device)

## Compute statistics for dataset

In [5]:
m1, s1 = compute_statistics_of_path(dataset_path, inception_model, BATCH_SIZE, 2048, device, 12)
print(m1, s1)


100%|██████████| 1448/1448 [05:33<00:00,  4.34it/s]


[0.25800682 0.30712518 0.2384513  ... 0.39552785 0.35121119 0.48443139] [[ 0.03824288 -0.00230245  0.00187763 ...  0.00401214 -0.00159032
   0.00057394]
 [-0.00230245  0.0691257  -0.00308235 ... -0.0012949   0.01224022
  -0.00482495]
 [ 0.00187763 -0.00308235  0.04732785 ... -0.00113145  0.00652731
   0.00723025]
 ...
 [ 0.00401214 -0.0012949  -0.00113145 ...  0.14086064  0.00867903
   0.01277959]
 [-0.00159032  0.01224022  0.00652731 ...  0.00867903  0.09965251
   0.00404461]
 [ 0.00057394 -0.00482495  0.00723025 ...  0.01277959  0.00404461
   0.16002851]]


In [6]:

np.savez(dataset_statistics_path, m=m1, s=s1)

In [7]:
dataset_statistics_file = np.load(dataset_statistics_path)
m = dataset_statistics_file["m"]
s = dataset_statistics_file["s"]
print(m.shape)
print(s.shape)

(2048,)
(2048, 2048)


## Compute statistics for batch of predictions

In [8]:
z = torch.randn(BATCH_SIZE, latent_dim, 1, 1).to(device)
generated_images = model(z)
print(z.shape)

torch.Size([6, 128, 1, 1])


In [9]:
pred = inception_model(generated_images)[0]

In [10]:
resized_pred = pred.squeeze().detach().cpu().numpy()
mu2 = np.mean(resized_pred, axis=0)
sigma = np.cov(resized_pred, rowvar=False)

In [11]:
print(mu2.shape)
print(sigma.shape)

(2048,)
(2048, 2048)


In [12]:
calculate_frechet_distance(mu2, sigma, m, s)

545.1638588031259