In [None]:
import os

import numpy as np
import torch

from DCGAN import DCGAN
from fid import FID, calculate_frechet_distance, compute_statistics_of_path
from inception import InceptionV3

%load_ext autoreload
%autoreload 2

In [None]:
# Number of channels in the training images. For color images this is 3
num_channels = 3

# Size of z latent vector (i.e. size of generator input)
latent_dim = 128

# Size of feature maps in generator
num_generator_features = 128

# Size of feature maps in discriminator
num_discriminator_features = 64

# Seed torch's RNG so model initialisation and the latent noise sampled below
# (and therefore the reported FID) are reproducible on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Leave one core free for the notebook. os.cpu_count() may return None, which
# made the previous `int(os.cpu_count() - 1)` raise a TypeError.
NUM_WORKERS = max((os.cpu_count() or 1) - 1, 0)

# NOTE(review): a batch of 6 samples is much smaller than the 2048-dim Inception
# features used below, so the covariance of a single generated batch is singular.
BATCH_SIZE = 6

IMAGE_SIZE = 512

# Fall back to CPU so the notebook still runs on machines without a GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

dataset_path = "../data/celeba/images"
# Alternative dataset:
# dataset_path = "../data/art_dataset/resized/resized/"

# Inception statistics are cached in the parent of the image directory.
# normpath strips a trailing slash so the parent is found for either spelling.
dataset_statistics_path = os.path.join(
    os.path.dirname(os.path.normpath(dataset_path)), "inception_statistics.npz"
)

In [None]:
# Build the DCGAN, giving each network its own feature-map width from the
# config cell (the generator previously received the discriminator's width).
# NOTE(review): no checkpoint is loaded in this notebook as shown, so the
# generator is presumably untrained here — confirm before reading the FID.
model = DCGAN(
    num_channels=num_channels,
    latent_dim=latent_dim,
    num_generator_features=num_generator_features,
    num_discriminator_features=num_discriminator_features,
).to(DEVICE)

In [None]:
# Inception network used as the FID feature extractor. It is called directly on
# generated images further down, so put it in inference mode here rather than
# relying on compute_statistics_of_path having done so.
# (Assumes InceptionV3 is a torch.nn.Module, as its use of .to() suggests.)
inception_model = InceptionV3().to(DEVICE).eval()

## Compute Inception statistics for the dataset

In [None]:
# Feature dimensionality requested from InceptionV3.
inception_dims = 2048

# Mean and covariance of Inception features over the whole dataset. Use the
# configured worker count instead of a hardcoded 12.
# NOTE(review): positional order presumably (path, model, batch_size, dims,
# device, num_workers) — confirm against fid.py.
m1, s1 = compute_statistics_of_path(
    dataset_path, inception_model, BATCH_SIZE, inception_dims, DEVICE, NUM_WORKERS
)

# Show shapes instead of dumping the full mean vector and covariance matrix.
print(np.shape(m1), np.shape(s1))


In [None]:

# Persist the dataset statistics under the keys "m" (mean) and "s" (covariance);
# the next cell reads them back from the same path.
statistics_to_save = {"m": m1, "s": s1}
np.savez(dataset_statistics_path, **statistics_to_save)

In [None]:
# Load the cached dataset statistics. np.load on an .npz returns an NpzFile that
# keeps the file open, so use it as a context manager. Indexing it reads each
# array into memory, so the arrays stay valid after the file is closed.
with np.load(dataset_statistics_path) as dataset_statistics_file:
    mu_dataset = dataset_statistics_file["m"]
    sigma_dataset = dataset_statistics_file["s"]

# The Frechet distance needs a (d,) mean and a (d, d) covariance.
assert sigma_dataset.shape == (mu_dataset.shape[0], mu_dataset.shape[0]), (
    mu_dataset.shape,
    sigma_dataset.shape,
)
print(mu_dataset.shape)
print(sigma_dataset.shape)

## Compute Inception statistics for a batch of generated images

In [None]:
# Sample one batch of latent vectors of shape (BATCH_SIZE, latent_dim, 1, 1)
# directly on the target device, then generate images. This is evaluation only,
# so skip building the autograd graph.
z = torch.randn(BATCH_SIZE, latent_dim, 1, 1, device=DEVICE)
with torch.no_grad():
    generated_images = model(z)
print(z.shape)

In [None]:
# Inception activations for the generated batch. [0] takes the first output of
# InceptionV3 (presumably the feature block matching the dataset statistics —
# confirm). No gradients are needed for FID.
# NOTE(review): check that the generator's output value range matches the input
# range InceptionV3 expects.
with torch.no_grad():
    pred = inception_model(generated_images)[0]

In [None]:
# Flatten the activations to (batch, features). A bare squeeze() also removes
# the batch axis when the batch holds a single image, which would make the mean
# below a scalar. Reshaping on the batch size keeps the 2-D layout.
resized_pred = pred.detach().cpu().numpy().reshape(pred.shape[0], -1)

# Rows are samples and columns are features, hence axis=0 / rowvar=False.
pred_mu = np.mean(resized_pred, axis=0)
# NOTE(review): with fewer samples than features this covariance is singular.
pred_sigma = np.cov(resized_pred, rowvar=False)

In [None]:
# Sanity-check the generated-batch statistics before comparing with the dataset.
print(pred_mu.shape, pred_sigma.shape, sep="\n")

In [None]:
# Frechet distance between the generated-batch statistics and the cached
# dataset statistics. The bare name on the last line displays the score.
fid_score = calculate_frechet_distance(
    mu1=pred_mu,
    sigma1=pred_sigma,
    mu2=mu_dataset,
    sigma2=sigma_dataset,
)
fid_score


In [None]:
# FID helper class from the local fid module (imported in the top import cell).
# The first argument presumably selects the dataset by name — confirm in fid.py.
# NOTE(review): confirm whether this reuses the inception_statistics.npz cached
# above or recomputes the dataset statistics from scratch.
ds_fid = FID("celeba", image_size=IMAGE_SIZE, batch_size=BATCH_SIZE, device=DEVICE)

ds_fid.compute_statistics_of_path()