In [1]:
import os, sys
sys.path.append("..")

import matplotlib
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline 

import numpy as np
import torch
import torch.nn as nn
import torchvision
import gc

from src.tools import freeze, load_dataset, get_Z_pushed_loader_stats
from src.fid_score import calculate_frechet_distance
from src.cunet import CUNet

import json

from tqdm import tqdm_notebook as tqdm
from IPython.display import clear_output
from collections import OrderedDict

# Some datasets contain PNGs whose text chunks exceed Pillow's default
# MAX_TEXT_CHUNK limit; raise the cap so their dataloaders can open them.
from PIL import PngImagePlugin
LARGE_ENOUGH_NUMBER = 100
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * 1024 * 1024  # 100 MiB

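## FID scores

FID between samples $T(x, z)$ pushed from the `DATASET1` test set and the precomputed `DATASET2` test statistics. The estimate is repeated 10 times, and the mean and standard deviation are reported.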

In [2]:
DEVICE_ID = 3

# Source (DATASET1, mapped by T) -> target (DATASET2) pair to evaluate.
# Pairs used in other runs are kept commented out for reference.
DATASET1, DATASET1_PATH = 'dtd', '../../data/dtd/images'
DATASET2, DATASET2_PATH = 'shoes', '../../data/shoes_128.hdf5'
# DATASET2, DATASET2_PATH = 'handbag', '../../data/handbag_128.hdf5'

# DATASET1, DATASET1_PATH = 'outdoor', '../../data/outdoor_128.hdf5'
# DATASET2, DATASET2_PATH = 'church', '../../data/church_128.hdf5'

# DATASET1, DATASET1_PATH = 'celeba_female', '../../data/img_align_celeba'
# DATASET2, DATASET2_PATH = 'aligned_anime_faces', '../../data/aligned_anime_faces'

IMG_SIZE = 128
COST = 'energy'

# Noise z fed to T: ZC channels, scaled by Z_STD.
ZC = 128
Z_STD = 1.

assert torch.cuda.is_available()
torch.cuda.set_device(f'cuda:{DEVICE_ID}')

In [3]:
# (The CUDA device is selected once, in the configuration cell.)

# Datasets whose loaders presumably apply random augmentation: their pushed
# statistics are accumulated over several passes (n_epochs) of the test loader
# rather than one -- TODO confirm against load_dataset.
AUGMENTED_DATASETS = ['dtd']
FID_EPOCHS = 50 if DATASET1 in AUGMENTED_DATASETS else 1

In [6]:
# Precomputed FID reference statistics ('mu', 'sigma') of the DATASET2 test split.
filename = f'stats/{DATASET2}_{IMG_SIZE}_test.json'
with open(filename, 'r') as fp:
    data_stats = json.load(fp)
mu_data = data_stats['mu']
sigma_data = data_stats['sigma']
del data_stats

In [7]:
# Only the source test sampler is needed: target-domain FID statistics are read
# from the JSON file above. (A Y_test_sampler for DATASET2 could be built the
# same way if needed.)
_, X_test_sampler = load_dataset(
    DATASET1, DATASET1_PATH, img_size=IMG_SIZE, batch_size=256
)

# Transport map T(x, z); presumably (in_channels=3, out_channels=3, noise
# channels=ZC) -- TODO confirm against CUNet's signature.
T = CUNet(3, 3, ZC, base_factor=48)

In [8]:
# Load the trained map for this (cost, source, target, resolution) run.
folder = os.path.join('../checkpoints', COST, '{}_{}_{}'.format(DATASET1, DATASET2, IMG_SIZE))
model = 'T.pt'
path = os.path.join(folder, model)

# map_location='cpu': the checkpoint may have been saved from a different GPU;
# deserialising to CPU avoids "invalid device" errors and a stray copy on that
# GPU. load_state_dict copies the weights into T, which is then moved to the
# current CUDA device.
T.load_state_dict(torch.load(path, map_location='cpu'))
T.cuda()
freeze(T)  # presumably disables gradients / training mode -- see src.tools
torch.cuda.empty_cache()

In [10]:
# Seed both RNGs so the repeated FID estimates are reproducible.
torch.manual_seed(0xBADBEEF)
np.random.seed(0xBADBEEF)

# Each estimate draws fresh noise z, so FID is computed several times and summarised.
num_calculation_fid = 10

fid_values = []
for _ in range(num_calculation_fid):
    # Statistics of T(x, z) pushed over the source test loader.
    mu, sigma = get_Z_pushed_loader_stats(
        T,
        X_test_sampler.loader,
        ZC=ZC,
        Z_STD=Z_STD,
        n_epochs=FID_EPOCHS,
        use_downloaded_weights=True,
    )
    fid = calculate_frechet_distance(mu_data, sigma_data, mu, sigma)
    print(f"FID = {fid}")
    fid_values.append(fid)

fid_values = np.asarray(fid_values)
fid_mean = fid_values.mean()
# Default ddof=0: population standard deviation over the runs.
fid_std = fid_values.std()

print("--------")
print(f"Mean FID = {fid_mean}")
print(f"Std FID = {fid_std}")

  x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
  diffY // 2, diffY - diffY // 2])


FID = 24.79571868987776
FID = 24.959808788050623
FID = 24.823207289467234
FID = 24.778168681212236
FID = 24.854378966612302
FID = 24.69220724157134
FID = 24.82760368139742
FID = 24.768911862425966
FID = 24.854055437890963
FID = 25.030413809289456
--------
Mean FID = 24.83844744477953
Std FID = 0.09171019746958688


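## Variance-similarity trade-off

For checkpoints trained with different $\gamma$, compare two quantities and the resulting cost:

- the conditional variance of $T(x, z)$ over $z$;
- the mean distance $\|x - T(x, z)\|$.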

In [None]:
# Configuration for the variance-similarity experiment (dtd -> shoes at 64x64).
DEVICE_ID = 0

DATASET1, DATASET1_PATH = 'dtd', '../../data/dtd/images'
DATASET2, DATASET2_PATH = 'shoes', '../../data/shoes_128.hdf5'

IMG_SIZE = 64
COST = 'energy'

# Noise z fed to T: ZC channels, scaled by Z_STD.
ZC = 128
Z_STD = 1.

assert torch.cuda.is_available()
torch.cuda.set_device(f'cuda:{DEVICE_ID}')

In [None]:
_, X_test_sampler = load_dataset(DATASET1, DATASET1_PATH, img_size=IMG_SIZE, batch_size=256)
# NOTE(review): Y_test_sampler is not used by the cells below -- drop if unused elsewhere.
_, Y_test_sampler = load_dataset(DATASET2, DATASET2_PATH, img_size=IMG_SIZE, batch_size=256)

# `CondUNetV2` is not imported in the imports cell (NameError on a fresh kernel);
# use the imported CUNet with the same arguments as in the FID section.
# TODO confirm the T_<gamma>.pt checkpoints were saved from CUNet (state_dict keys must match).
T = CUNet(3, 3, ZC, base_factor=48)

In [None]:
folder = os.path.join('../checkpoints', COST, '{}_{}_{}'.format(DATASET1, DATASET2, IMG_SIZE))
Z_SIZE = 8    # noise samples z drawn per input x
X_SIZE = 128  # number of test inputs x

# Each sampled input is repeated once per noise draw along a new dim 1
# (presumably giving (X_SIZE, Z_SIZE, 3, IMG_SIZE, IMG_SIZE)).
X = X_test_sampler.sample(X_SIZE)[:, None].repeat(1, Z_SIZE, 1, 1, 1)
Z = torch.randn(X_SIZE, Z_SIZE, ZC, 1, 1, device='cuda') * Z_STD

# One checkpoint per gamma; the value is encoded in the file name as T_<gamma>.pt.
for model in ['T_0.pt', 'T_0.33.pt', 'T_0.66.pt', 'T_1.pt', 'T_1.33.pt']:
    path = os.path.join(folder, model)
    # map_location='cpu' lets checkpoints saved from any GPU load here;
    # load_state_dict copies the weights into T's existing parameters.
    T.load_state_dict(torch.load(path, map_location='cpu'))
    T.cuda()
    freeze(T)
    torch.cuda.empty_cache()

    # 'T_0.33.pt' -> 0.33: strip the extension, keep the part after the first '_'.
    gamma = float(os.path.splitext(model)[0].split('_', 1)[1])

    with torch.no_grad():
        # T runs on a flat batch of X_SIZE * Z_SIZE (x, z) pairs. Flat row n is
        # (x = n // Z_SIZE, z = n % Z_SIZE), so a row-major reshape restores the
        # (X_SIZE, Z_SIZE, 3, IMG_SIZE, IMG_SIZE) layout.
        T_XZ = T(
            X.flatten(start_dim=0, end_dim=1), Z.flatten(start_dim=0, end_dim=1)
        ).reshape(-1, Z_SIZE, 3, IMG_SIZE, IMG_SIZE)
        # Half the mean pairwise L2 distance between the Z_SIZE outputs of the same x.
        # cdist's mean also averages the Z_SIZE zero diagonal entries; the
        # Z_SIZE / (Z_SIZE - 1) factor turns it into the off-diagonal mean.
        var = .5 * torch.cdist(T_XZ.flatten(start_dim=2), T_XZ.flatten(start_dim=2)).mean() * Z_SIZE / (Z_SIZE - 1)
        # Mean (non-squared) L2 distance between each input x and T(x, z).
        dist2 = (X - T_XZ).flatten(start_dim=2).norm(dim=2).mean()
        cost = (.5 * dist2 - .5 * gamma * var) * 0.5
    print(model, gamma)
    # Division by 2 since 1/2 in kernel
    print('Var:', round(var.item() / 2, 2), 'Dist:', round(dist2.item() / 2, 2))
    print('Cost:', round(cost.item(), 2), '\n')