In [151]:
%load_ext autoreload
%autoreload 2
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [262]:
import numpy as np
import matplotlib.pyplot as plt
import yaml
import glob
import copy
import pandas as pd
import scipy.linalg

# Get the comparison YAML files

In [270]:
# Read every tuning file under FID/ and store each of its top-level entries
# as a DataFrame in `results`, keyed by the entry's name.
split_comparisons = glob.glob('FID/*tune.yaml')
results = {}
for tune_path in split_comparisons:
    with open(tune_path) as stream:
        tuning_entries = yaml.safe_load(stream)

    results.update(
        {entry_name: pd.DataFrame(records) for entry_name, records in tuning_entries.items()}
    )

# Find the value of s with the minimum score for each entry, based on the validation set

In [272]:
# For each tuning table, pick the `s` of the (first) row whose score equals
# the table's minimum score.
minimum_values = {
    name: tuning_df.loc[tuning_df['score'] == tuning_df['score'].min(), 's'].to_numpy()[0]
    for name, tuning_df in results.items()
}

# Load the actual FID results

In [273]:
meta_files = glob.glob('FID/*/*.yaml')

# Load every metadata file, keyed by its path with the first 4 and the last 10
# characters removed (presumably the 'FID/' prefix and a fixed-length
# '/<file>.yaml' suffix, leaving the folder name -- TODO confirm).
meta_datas = {}
for meta_path in meta_files:
    run_name = meta_path[4:-10]
    with open(meta_path) as stream:
        meta_datas[run_name] = yaml.safe_load(stream)
    
def get_fids(dataset_name, metas=None, keep_last=15):
    """Build one settings table per run whose name contains ``dataset_name``.

    Parameters
    ----------
    dataset_name : str
        Substring used to select runs; every key containing it is included.
    metas : dict, optional
        Mapping of run name -> metadata dict providing ``'settings'`` (anything
        ``pd.DataFrame`` accepts, with a ``'path'`` column) and
        ``'iteration_number'``. Defaults to the notebook-level ``meta_datas``.
    keep_last : int, optional
        For runs without ``'glow'`` in their name, only rows whose folder index
        is greater than ``max_index - keep_last`` are kept (default 15).

    Returns
    -------
    dict
        Run name -> filtered DataFrame with a fresh positional index. Each
        frame gains an ``'iteration_number'`` column and an integer ``'index'``
        column parsed from the text after the last underscore of ``'path'``.
    """
    if metas is None:
        # Fall back to the metadata loaded earlier in the notebook.
        metas = meta_datas

    ans = {}
    for key, val in metas.items():
        if dataset_name not in key:
            continue

        df = pd.DataFrame(val['settings'])
        df['iteration_number'] = val['iteration_number']
        # The folder index is the integer after the last underscore in the path.
        df['index'] = df['path'].apply(lambda x: int(x.split('_')[-1]))

        if 'glow' in key:
            # With several glow rows, drop the one(s) at index 0; a single row is kept as is.
            if df.shape[0] > 1:
                df = df[df['index'] > 0]
        else:
            # Keep the last `keep_last` folders.
            max_index = df['index'].max()
            df = df[df['index'] > max_index - keep_last]

        ans[key] = df.reset_index()
    return ans

In [274]:
# One dictionary of per-run score tables for each dataset keyword.
celeba_scores, cifar_scores, mnist_scores = [
    get_fids(dataset) for dataset in ('celeba', 'cifar', 'mnist')
]

In [275]:
# For every dataset, collect the score of each non-glow run at the `s` closest
# to the value selected in `minimum_values`, followed by the glow score.
full_min_scores = {}

for score_set in [celeba_scores, cifar_scores, mnist_scores]:

    # The dataset name is the part of the first run name before the first underscore.
    dataset_name = list(score_set.keys())[0].split('_')[0]
    try:

        min_scores = {}
        for name, run_df in score_set.items():
            if 'glow' in name:
                continue

            # Take the row whose `s` is nearest to the selected value for this run.
            corresponding_min = minimum_values[name]
            min_index = np.abs(run_df['s'] - corresponding_min).argmin()
            min_scores[name] = run_df.iloc[min_index]['score']

        # Order the runs by the integer suffix of their name, then keep only the scores.
        sorted_min_scores = sorted(min_scores.items(), key=lambda x: int(x[0].split('_')[-1]))

        full_min_scores[dataset_name] = [val[1] for val in sorted_min_scores]

    except Exception as err:
        # Best effort: keep building the table with placeholder values, but
        # report what went wrong instead of hiding it. Catching `Exception`
        # (not a bare `except`) lets KeyboardInterrupt/SystemExit propagate.
        print('Could not compute the tuned scores for {}: {!r}'.format(dataset_name, err))
        full_min_scores[dataset_name] = [-1, -1, -1, -1]

    # The glow entry is the score in the last row of the first run whose name contains 'glow'.
    glow_key = [key for key in score_set.keys() if 'glow' in key][0]
    glow_score = score_set[glow_key].iloc[-1]['score']

    full_min_scores[dataset_name].append(glow_score)

In [276]:
# Assemble the results table: one labelled row per entry of each score list,
# one column per dataset (renamed to its display name).
df = pd.DataFrame(full_min_scores, index=[64, 128, 256, 512, 'GLOW'])
df.columns = ['CelebA', 'CIFAR-10', 'Fashion MNIST']

In [277]:
# Emit the table as LaTeX source; print() is used so the raw text can be copied.
print(df.to_latex())

\begin{tabular}{lrrr}
\toprule
{} &     CelebA &   CIFAR-10 &  Fashion MNIST \\
\midrule
64   &  30.960336 &  80.150178 &      23.978618 \\
128  &  34.465172 &  79.386961 &      23.233066 \\
256  &  33.950003 &  78.440733 &      24.842087 \\
512  &  35.960198 &  77.479778 &      25.342172 \\
GLOW &  63.071708 &  78.581376 &      42.775308 \\
\bottomrule
\end{tabular}



In [290]:
# Take the first file matching the pattern and read its 'z', 'y' and 'u' arrays.
path_64 = glob.glob('Results/*/cifar_64_test_embeddings.npz')[0]
with np.load(path_64) as archive:
    z_64, y_64, u_64 = (archive[field] for field in ('z', 'y', 'u'))

In [291]:
# Take the first file matching the pattern and read its 'z', 'y' and 'u' arrays.
path_glow = glob.glob('Results/*/cifar_glow_test_embeddings.npz')[0]
with np.load(path_glow) as archive:
    z_glow, y_glow, u_glow = (archive[field] for field in ('z', 'y', 'u'))

In [301]:
# Columns: the `u` values, then the labels `y`, then a constant 0 flag for this source.
glow_data = np.column_stack([u_glow, y_glow, np.zeros_like(y_glow)])

In [302]:
# Columns: the `u` values, then the labels `y`, then a constant 1 flag for this source.
nif_data = np.column_stack([u_64, y_64, np.ones_like(y_64)])

In [306]:
# Stack both sources row-wise (nif rows first) and name the four columns.
df = pd.DataFrame(
    np.concatenate([nif_data, glow_data], axis=0),
    columns=['x', 'y', 'category', 'algorithm'],
)

In [310]:
# Store the 'category' column as integers; the other columns keep their dtype.
df = df.astype({'category':int})

In [311]:
# Show the combined frame as the cell's rich output.
df

Unnamed: 0,x,y,category,algorithm
0,-2.097250,0.171886,3,1.0
1,4.408762,1.097425,8,1.0
2,1.581044,-1.343657,5,1.0
3,-1.425326,-3.071453,4,1.0
4,0.488655,5.978664,0,1.0
...,...,...,...,...
47995,2.124771,12.533604,8,0.0
47996,1.876477,13.597491,3,0.0
47997,1.073510,10.969111,5,0.0
47998,2.032319,13.578838,1,0.0


In [316]:
# Translate the numeric source flag into its label. Values that are not a
# known flag (e.g. labels written by an earlier run of this cell) are passed
# through unchanged, so re-running the cell does not turn the column into NaN.
algorithm_labels = {1.0: 'nif', 0.0: 'nf'}
df['algorithm'] = df['algorithm'].map(lambda flag: algorithm_labels.get(flag, flag))

In [318]:
# Write the frame as CSV text to 'embeddings_data' in the working directory.
# NOTE(review): the file name has no '.csv' extension -- confirm this is what
# downstream readers expect before changing it.
df.to_csv('embeddings_data')