In [None]:
import os
import sys
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
import pickle
import pathlib
from tqdm.notebook import tqdm
%matplotlib inline

In [None]:
img = Image.open('/gan-clouds/ffhq/ffhq/image_02220_psi1.png').convert('RGB')

In [None]:
img

In [None]:
img = Image.open('/gan-clouds/ffhq/stylegan2/image_02262_psi07.png').convert('RGB')

In [None]:
img

In [None]:
def read_cloud(path, substr = None, exts = ('png',)):
    """Load every matching image in a directory as one flattened row of a point cloud.

    Files are matched non-recursively by extension and processed in sorted
    order, so the row order is deterministic for a given directory listing.

    Args:
        path: directory containing the images (str or path-like).
        substr: if truthy, only files whose name contains this substring are
            loaded (e.g. ``'psi07'``); ``None`` / ``''`` loads every file.
        exts: file extensions (without the dot) to pick up; defaults to PNG only.

    Returns:
        ``np.ndarray`` with one row per loaded image. Each image is converted to
        RGB, passed through ``transforms.ToTensor()`` and flattened. If no file
        matches, the result is an empty array of shape ``(0,)``.
    """
    path = pathlib.Path(path)
    files = sorted(file for ext in exts for file in path.glob('*.{}'.format(ext)))
    if substr:
        # `in` also accepts a match at index 0, which the previous
        # `name.find(substr) > 0` test silently skipped.
        files = [file for file in files if substr in file.name]

    to_tensor = transforms.ToTensor()  # build once instead of once per image
    X = []

    for afile in tqdm(files):
        with Image.open(afile) as src:
            img = src.convert('RGB')
        X.append(to_tensor(img).flatten().numpy())

    return np.array(X)

In [None]:
X = read_cloud('/gan-clouds/ffhq/ffhq')

In [None]:
S1_07 = read_cloud('/gan-clouds/ffhq/stylegan1', 'psi07')

In [None]:
S1_1 = read_cloud('/gan-clouds/ffhq/stylegan1', 'psi1')

In [None]:
S2_07 = read_cloud('/gan-clouds/ffhq/stylegan2', 'psi07')

In [None]:
S2_1 = read_cloud('/gan-clouds/ffhq/stylegan2', 'psi1')

In [None]:
clouds = [X, S1_07, S1_1, S2_07, S2_1]

In [None]:
for cloud in clouds:
    print(cloud.shape)

# Sanity check: the comparisons below pair points from different clouds, so every
# cloud should be a non-empty (n_images, n_features) matrix with a common width.
# read_cloud returns a shape-(0,) array when nothing matched (wrong path / substr),
# which would otherwise only surface as a confusing error much later.
assert all(c.ndim == 2 and len(c) > 0 for c in clouds), [c.shape for c in clouds]
assert len({c.shape[1] for c in clouds}) == 1, [c.shape for c in clouds]

In [None]:
import mtd

In [None]:
%time
res_ffhq_large1 = []
trials = 20
cuda = 0

for i, cloud in enumerate(clouds):
    if i == 0:
        continue
        
    np.random.seed(7)
    barcs = [mtd.calc_cross_barcodes(clouds[i], clouds[0], batch_size1 = 1000, batch_size2 = 10000,\
                                    cuda = cuda) for _ in range(trials)]
    res_ffhq_large1.append(barcs)

In [None]:
%%time
res_ffhq_large2 = []
trials = 20
cuda = 0

for i, cloud in enumerate(clouds):
    if i == 0:
        continue
    
    np.random.seed(7)
    barcs = [mtd.calc_cross_barcodes(clouds[0], clouds[i], batch_size1 = 1000, batch_size2 = 10000,\
                                   cuda = cuda) for _ in range(trials)]
    res_ffhq_large2.append(barcs)

In [None]:
def get_scores(res, args_dict, trials = 10):
    """Average ``mtd.get_score`` over the repeated runs of each experiment.

    Args:
        res: sequence with one entry per compared model; each entry is a
            sequence of per-trial results (e.g. the barcode lists collected in
            ``res_ffhq_large2``).
        args_dict: keyword arguments forwarded unchanged to ``mtd.get_score``.
        trials: unused; the number of trials is taken from each entry's length.
            Kept so existing calls remain valid.

    Returns:
        List with the mean score for each entry of ``res``. The raw per-trial
        scores of every entry are printed as a side effect.
    """
    mean_scores = []
    for runs in res:
        run_scores = [mtd.get_score(barcode, **args_dict) for barcode in runs]
        print(run_scores)
        mean_scores.append(sum(run_scores) / len(runs))
    return mean_scores

In [None]:
scores = get_scores(res_ffhq_large2, {'h_idx' : 1, 'kind' : 'sum_length'})

In [None]:
descriptions = ['StyleGan1_psi07', 'StyleGan1_psi1', 'StyleGan2_psi07', 'StyleGan2_psi1']

In [None]:
for s, d in zip(scores, descriptions):
    print(d, s)

In [None]:
# additional experiment with IMD
from msid import msid_score

IMD_SAMPLE_SIZE = 2000  # rows drawn from each cloud per msid_score call
IMD_SEED = 7


def _imd_subsample(cloud, size=IMD_SAMPLE_SIZE, seed=IMD_SEED):
    """Return up to `size` rows of `cloud`, picked by a seeded random permutation.

    Indices come from this cloud's own length, so a cloud smaller than the
    reference can no longer be indexed out of range; clouds of equal length
    receive identical row picks. The global NumPy RNG state is not touched.
    """
    perm = np.random.RandomState(seed).permutation(len(cloud))
    return cloud[perm[:size]]


res_imd = [0] * len(clouds)  # slot 0 (the reference cloud itself) stays 0
reference_sample = _imd_subsample(clouds[0])  # loop-invariant: compute once

for i in tqdm(range(1, len(clouds))):
    res_imd[i] = msid_score(reference_sample, _imd_subsample(clouds[i]))