In [None]:
import os
import sys
import torch
from PIL import Image
import numpy as np
import subprocess
import matplotlib.pyplot as plt
from math import sqrt
from scipy import stats
import pickle
import pathlib
from tqdm.notebook import tqdm
%matplotlib inline

In [None]:
# Reference point clouds, indexed by dataset name (e.g. datasets['cifar10']).
# NOTE(review): torch.load unpickles its input — only load trusted files. Newer
# torch versions default to weights_only=True, which may reject non-tensor
# payloads such as numpy arrays — TODO confirm against the installed version.
DATASETS_PATH = '/gan-clouds/datasets.pt'
datasets = torch.load(DATASETS_PATH)

In [None]:
# Names of the datasets stored in datasets.pt.
datasets.keys()

In [None]:
# Generated point clouds, one file per GAN variant. Each loaded object is then
# indexed by dataset name, e.g. clouds['dcgan']['cifar10'].
# Files are loaded in the order listed here.
CLOUD_PATHS = {
    'dcgan': '/gan-clouds/DCGAN-2.pt',
    'lsgan': '/gan-clouds/LSGAN-2.pt',
    'rel': '/gan-clouds/Relativistic-2.pt',
    'wgan': '/gan-clouds/WGAN.pt',
    'wgan-gp': '/gan-clouds/WGAN-GP.pt',
}
clouds = {variant: torch.load(path) for variant, path in CLOUD_PATHS.items()}

In [None]:
# Shape of the CIFAR-10 reference array; presumably (N, 3*32*32), given the
# reshape((3, 32, 32)) used when previewing samples — confirm.
datasets['cifar10'].shape

In [None]:
# Scratch handle on the DCGAN-generated CIFAR-10 cloud; only used for the shape check below.
A = clouds['dcgan']['cifar10']

In [None]:
# Shape of the DCGAN CIFAR-10 cloud, to compare against the reference array's shape.
A.shape

In [None]:
import mtd

In [None]:
# Preview reference CIFAR-10 sample #10. Each row is a flat channel-first image:
# reshape to (C, H, W), then move channels last because PIL expects (H, W, C).
# Indexing directly replaces iterating with enumerate() and breaking at idx == 10.
data = datasets['cifar10'][10].reshape((3, 32, 32)).astype('uint8')
data = np.transpose(data, (1, 2, 0))
png_data = Image.fromarray(data)

In [None]:
# Display the CIFAR-10 preview image built in the previous cell.
png_data

In [None]:
# Preview WGAN-generated CIFAR-10 sample #5, using the same flat channel-first
# layout as the reference data: reshape to (C, H, W), then channels-last for PIL.
# Indexing directly replaces iterating with enumerate() and breaking at idx == 5.
data = clouds['wgan']['cifar10'][5].reshape((3, 32, 32)).astype('uint8')
data = np.transpose(data, (1, 2, 0))
png_data = Image.fromarray(data)

In [None]:
# Display the WGAN CIFAR-10 preview image built in the previous cell.
png_data

In [None]:
# Preview reference MNIST sample #2. Rows are flat 32x32 grayscale images; the
# single plane is replicated into 3 channels (channels-last) so PIL builds an RGB image.
# Indexing directly replaces iterating with enumerate() and breaking at idx == 2.
gray = datasets['mnist'][2].reshape((32, 32)).astype('uint8')
data = np.stack([gray] * 3, axis=-1)  # (32, 32, 3)
png_data = Image.fromarray(data)

In [None]:
# Display the MNIST preview image (grayscale replicated to RGB) built in the previous cell.
png_data

### Calculate FID (Fréchet Inception Distance) with `pytorch-fid`

In [None]:
def write_dir(adir, cloud, shape = (3, 32, 32), copy_channels = False):
    """Dump every sample of `cloud` as a PNG into a freshly emptied directory.

    Each element is reshaped to `shape` and cast to uint8. With
    `copy_channels=True` the element is treated as a single 2-D grayscale
    plane and replicated into 3 channels. Channels are then moved last,
    (C, H, W) -> (H, W, C), as PIL expects, and the image is written to
    `<adir>/<idx>.png`.

    Args:
        adir: output directory. Any existing file, symlink or directory at
            this path is removed first, then an empty directory is created.
        cloud: iterable of flat samples. Each element must support
            `.reshape(shape).astype('uint8')` (e.g. numpy rows).
        shape: per-sample shape; channel-first for colour images.
        copy_channels: replicate a 2-D grayscale plane into 3 channels.
    """
    import shutil

    # Clear the target without going through a shell: the previous
    # `rm -rf %s` broke on paths containing spaces and allowed shell injection.
    if os.path.isdir(adir) and not os.path.islink(adir):
        shutil.rmtree(adir)
    elif os.path.lexists(adir):
        os.remove(adir)
    os.makedirs(adir)

    for idx, elem in enumerate(cloud):
        data = elem.reshape(shape).astype('uint8')

        if copy_channels:
            # Stack the plane 3 times on a new leading axis: (H, W) -> (3, H, W).
            # Uses the image's own size instead of assuming 32x32.
            data = np.stack([data] * 3, axis=0)

        data = np.transpose(data, (1, 2, 0))
        Image.fromarray(data).save(os.path.join(adir, '%d.png' % idx))

In [None]:
# Per-dataset image layout for write_dir: (per-sample shape, copy_channels).
# Grayscale sets are stored as 32x32 planes and are replicated to RGB;
# colour sets are stored channel-first.
FID_LAYOUTS = {
    'mnist': ((32, 32), True),
    'fashion_mnist': ((32, 32), True),
    'cifar10': ((3, 32, 32), False),
    'svhn': ((3, 32, 32), False),
}
# Only these generators are scored with FID.
FID_GENERATORS = ['wgan', 'wgan-gp']

for d in datasets.keys():

    print()
    print(d)
    print()

    # Fail loudly instead of reusing a previous dataset's layout (or hitting a
    # NameError) for a dataset that has no configured layout.
    if d not in FID_LAYOUTS:
        raise ValueError('No FID image layout configured for dataset %r' % d)
    shape, copy_channels = FID_LAYOUTS[d]

    # The reference images do not depend on the generator: write them once per dataset.
    write_dir('tmp1', datasets[d], shape, copy_channels = copy_channels)

    for g in clouds.keys():
        if g not in FID_GENERATORS:
            continue

        write_dir('tmp2', clouds[g][d], shape, copy_channels = copy_channels)

        cmd = ['pytorch-fid', 'tmp1', 'tmp2', '--device', 'cuda:0']
        proc = subprocess.run(cmd, capture_output=True, text=True)
        if proc.returncode != 0:
            # Surface the error rather than printing an empty result line.
            print(g, 'pytorch-fid failed:', proc.stderr)
            continue

        print(g, proc.stdout)

### MTopDiv (Manifold Topology Divergence)

In [None]:
# MTopDiv cross-barcodes with the GENERATED cloud as the first argument and the
# reference dataset as the second: `trials` barcodes per (dataset, generator) pair.
res1 = {}
trials = 200

for d, base_cloud in datasets.items():
    for g, gen_by_dataset in clouds.items():
        mod_cloud = gen_by_dataset[d]

        # Reset numpy's global RNG before each pair so every pair starts from the
        # same state (presumably mtd draws its batches from it — TODO confirm).
        # NOTE(review): GPU index 3 is hard-coded here.
        np.random.seed(7)
        barcs = []
        for _ in range(trials):
            barcs.append(
                mtd.calc_cross_barcodes(mod_cloud, base_cloud,
                                        batch_size1 = 100, batch_size2 = 1000, cuda = 3)
            )

        res1[(d, g)] = barcs

In [None]:
# MTopDiv cross-barcodes with the REFERENCE dataset as the first argument and the
# generated cloud as the second (argument order reversed relative to res1):
# `trials` barcodes per (dataset, generator) pair.
res2 = {}
trials = 200

for d, base_cloud in datasets.items():
    for g, gen_by_dataset in clouds.items():
        mod_cloud = gen_by_dataset[d]

        # Reset numpy's global RNG before each pair so every pair starts from the
        # same state (presumably mtd draws its batches from it — TODO confirm).
        # NOTE(review): GPU index 3 is hard-coded here.
        np.random.seed(7)
        barcs = []
        for _ in range(trials):
            barcs.append(
                mtd.calc_cross_barcodes(base_cloud, mod_cloud,
                                        batch_size1 = 100, batch_size2 = 1000, cuda = 3)
            )

        res2[(d, g)] = barcs

In [None]:
def get_scores(res, args_dict, trials = 10):
    """Summarise per-trial barcodes as (mean, standard error) of their scores.

    Args:
        res: mapping from a key (e.g. ``(dataset, generator)``) to a list of
            barcodes, one per trial, such as those produced by
            ``mtd.calc_cross_barcodes`` in the MTopDiv cells.
        args_dict: keyword arguments forwarded unchanged to ``mtd.get_score``
            for every barcode (e.g. ``{'h_idx': 1, 'kind': 'sum_length'}``).
        trials: unused. Every barcode stored in ``res[k]`` is scored
            regardless of this value; the parameter is kept only so existing
            calls keep working.

    Returns:
        dict mapping each key of ``res`` (inserted in sorted key order) to
        ``(mean, std / sqrt(n))``, where ``std`` is the population standard
        deviation (numpy default, ddof=0) of the ``n`` per-trial scores.
        Keys whose barcode list is empty map to ``(nan, nan)``.
    """
    scores = {}

    for k in sorted(res.keys()):
        trial_scores = [mtd.get_score(barc, **args_dict) for barc in res[k]]

        if not trial_scores:
            # Avoid numpy's empty-slice warnings and the division by sqrt(0).
            scores[k] = float('nan'), float('nan')
            continue

        scores[k] = np.mean(trial_scores), np.std(trial_scores) / sqrt(len(trial_scores))

    return scores

In [None]:
# List the dataset names, one per line.
for dataset_name in datasets:
    print(dataset_name)

In [None]:
# Score the res2 barcodes (reference cloud first) with mtd.get_score(h_idx=1, kind='sum_length').
scores = get_scores(res2, dict(h_idx=1, kind='sum_length'))

In [None]:
# Emit the mean score of every generator on CIFAR-10 as one line: each value is
# followed by a space, and decimal points are turned into commas (presumably for
# pasting into a comma-decimal spreadsheet).
cifar_cells = [str(scores[('cifar10', g)][0]).replace('.', ',') + ' ' for g in clouds.keys()]
sys.stdout.write(''.join(cifar_cells))
sys.stdout.write('\n')

In [None]:
# Full table: (dataset, generator) -> (mean, standard error) of the per-trial scores.
scores

In [None]:
# Additional experiments with IMD (Intrinsic Multi-scale Distance) via msid_score,
# computed on a fixed random subsample of each (reference, generated) pair.

from msid import msid_score

IMD_SUBSAMPLE = 5000  # max number of points taken from each cloud
IMD_SEED = 7

res = {}

for d in tqdm(datasets.keys()):
    base_cloud = datasets[d]
    for g in clouds.keys():
        mod_cloud = clouds[g][d]

        # The same index list is applied to both clouds, so draw it from a range
        # that is valid for BOTH. Drawing only from len(base_cloud) overruns a
        # shorter generated cloud. When the lengths are equal, this reproduces
        # the previous permutation exactly.
        indices = list(range(min(len(base_cloud), len(mod_cloud))))
        np.random.seed(IMD_SEED)
        np.random.shuffle(indices)
        rnd_idx = indices[:IMD_SUBSAMPLE]

        res[(d, g)] = msid_score(base_cloud[rnd_idx], mod_cloud[rnd_idx])

In [None]:
# Sanity check of the IMD ordering between WGAN and WGAN-GP per dataset.
# Each printed boolean is True when the observed ordering matches the expected one.
expected_wgan_lower = [
    ('cifar10', True),
    ('svhn', True),
    ('mnist', False),
    ('fashion_mnist', False),
]

print('CHECKING CORRECTNESS')
print('--------------------')
for dataset_name, wgan_lower in expected_wgan_lower:
    wgan_imd = res[(dataset_name, 'wgan')]
    wgan_gp_imd = res[(dataset_name, 'wgan-gp')]
    if wgan_lower:
        matches = wgan_imd < wgan_gp_imd
    else:
        matches = wgan_imd > wgan_gp_imd
    print(dataset_name, matches)