In [None]:
import os
import sys

In [None]:
# Expose only physical GPU 1 to CUDA; inside this process it becomes "cuda:0".
os.environ.update({'CUDA_VISIBLE_DEVICES': '1'})

In [None]:
from sklearn.metrics.pairwise import pairwise_distances
import numpy as np
import matplotlib.pyplot as plt
import sklearn
from scipy import stats
import pickle
from tqdm.notebook import tqdm
%matplotlib inline

In [None]:
import barcodes
from cka import cka

In [None]:
# Re-execute the local `barcodes` module so edits to its source are picked up
# without restarting the kernel (redundant right after the import above on a
# fresh kernel, useful while iterating on barcodes.py).
from importlib import reload
reload(barcodes)

In [None]:
# test barcodes

import numpy as np
import barcodes

# Seed the global NumPy RNG, then draw two independent 100x10 uniform [0, 1) clouds
# (P first, then Q) to smoke-test the barcode computation.
np.random.seed(7)
P, Q = np.random.random_sample((100, 10)), np.random.random_sample((100, 10))

In [None]:
# Smoke-test barcodes.calc_embed_dist on the random clouds P and Q.
# "cuda:0" is the single GPU exposed through CUDA_VISIBLE_DEVICES above.
barc = barcodes.calc_embed_dist(
    P,
    Q,
    fast=True,
    verbose=True,
    pdist_device='cuda:0',
)

In [None]:
# Load the per-model embeddings. Each entry unpacks as (suffix, test_loss, embeds),
# as used by the listing loop below.
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
EMBEDS_PATH = '/nas-bench-nlp-release/embeds90.pickle'
with open(EMBEDS_PATH, 'rb') as embeds_file:  # closes the handle deterministically
    data_full = pickle.load(embeds_file)

In [None]:
# Number of loaded entries.
print('%d' % len(data_full))

In [None]:
# One line per entry: its suffix identifier followed by its test loss.
for suffix, test_loss, embeds in data_full:
    print('%s %s' % (suffix, test_loss))

In [None]:
# `data` is the name every following cell works on; here it is the full set.
# NOTE(review): presumably kept as a separate alias so a subset (e.g. data_full[:k])
# can be swapped in for quicker runs — confirm.
data = data_full

In [None]:
# Shape of the first entry's embeddings (field 2 of each tuple); axis 0 is taken
# as the token count N_tokens in the next cell.
data[0][2].shape

In [None]:
# Sampling configuration for the pairwise comparison below.
N_tokens = data[0][2].shape[0]  # rows (tokens) in each model's embedding matrix
N_trials = 10                   # random token subsets averaged per model pair
batch = 100                     # tokens drawn per subset

# The same random row indices are applied to both models of a pair, so every model
# must have exactly N_tokens rows; and a subset larger than N_tokens would silently
# be truncated by the slice, so reject it up front.
assert all(entry[2].shape[0] == N_tokens for entry in data), "models have different token counts"
assert batch <= N_tokens, "batch is larger than the number of available tokens"

In [None]:
# Pairwise comparison of all models: for every ordered pair (idx1, idx2), draw
# N_trials random subsets of `batch` shared token rows, and average
#   - barcodes.h1sum of the barcode from barcodes.calc_embed_dist, and
#   - cka(a, b).
res1 = {}      # (idx1, idx2) -> mean h1sum over trials
res_cka = {}   # (idx1, idx2) -> mean cka over trials
barcs = {}     # (idx1, idx2, trial) -> raw barcode, reused by the relative-std check

for idx1 in tqdm(range(len(data))):
    # Does not depend on idx2, so convert model idx1's embeddings once per outer step.
    a_full = data[idx1][2].detach().numpy()

    for idx2 in range(len(data)):
        b_full = data[idx2][2].detach().numpy()

        h1sum_part = 0
        cka_part = 0

        for trial in range(N_trials):
            # Identical row indices for both models so the two point clouds stay paired.
            rnd_slice = np.random.permutation(N_tokens)[:batch]
            a = a_full[rnd_slice]
            b = b_full[rnd_slice]

            barc = barcodes.calc_embed_dist(a, b, norm='quantile', fast=True, verbose=True)
            barcs[(idx1, idx2, trial)] = barc

            h1sum_part += barcodes.h1sum(barc)
            cka_part += cka(a, b)

        res1[(idx1, idx2)] = h1sum_part / N_trials
        res_cka[(idx1, idx2)] = cka_part / N_trials

### Check relative standard error of the RTD estimate across trials

In [None]:
# For every off-diagonal pair, the relative standard error of the per-trial
# h1sum values: std / mean / sqrt(N_trials).
data_std = []

n_models = len(data)
for i in range(n_models):
    for j in range(n_models):
        if i == j:
            continue
        rtd_values = [barcodes.h1sum(barcs[(i, j, t)]) for t in range(N_trials)]
        data_std.append(np.std(rtd_values) / np.mean(rtd_values) / pow(N_trials, 0.5))

In [None]:
# Average relative standard error over all off-diagonal pairs.
np.asarray(data_std).mean()

In [None]:
# Persist / restore the pairwise results. An existing cache file takes precedence
# (as before); if there is none, the freshly computed results are saved so that a
# Restart & Run All on a clean checkout does not fail on a missing file.
RESULTS_CACHE = 'exp_nas-bench-nlp.pickle'
if os.path.exists(RESULTS_CACHE):
    # NOTE: pickle.load can execute arbitrary code; only load files you trust.
    with open(RESULTS_CACHE, 'rb') as cache_file:
        (res1, res_cka) = pickle.load(cache_file)
else:
    with open(RESULTS_CACHE, 'wb') as cache_file:
        pickle.dump((res1, res_cka), cache_file)

In [None]:
# Square model-by-model dissimilarity matrix (float64 zeros), filled in below.
sim = np.zeros((len(data),) * 2)

In [None]:
from math import log

In [None]:
# Symmetrised dissimilarity: sim[i, j] = res1[(i, j)] + res1[(j, i)].
# (A CKA-based alternative that was tried: 20 - res_cka[(i, j)] - res_cka[(j, i)].)
for idx1, idx2 in np.ndindex(len(data), len(data)):
    sim[idx1, idx2] = res1[(idx1, idx2)] + res1[(idx2, idx1)]

### Check triangle inequality violations

In [None]:
# Count index triples idx1 < idx2 < idx3 and how many of them violate the
# triangle inequality on `sim`.
cnt = 0
cnt_err = 0

for idx1 in range(len(data)):
    # The strictly increasing ranges already guarantee three distinct indices,
    # so no equality check is needed.
    for idx2 in range(idx1 + 1, len(data)):
        for idx3 in range(idx2 + 1, len(data)):
            a = sim[idx1, idx2]
            b = sim[idx1, idx3]
            c = sim[idx2, idx3]

            cnt += 1

            if a + b < c or a + c < b or b + c < a:
                cnt_err += 1

In [None]:
# Fraction of triples satisfying the triangle inequality; NaN when there are no
# triples (fewer than 3 models) instead of raising ZeroDivisionError.
1 - cnt_err / cnt if cnt else float('nan')

In [None]:
# Embed the models in 2-D from the precomputed dissimilarity matrix `sim`.
from collections import OrderedDict
from functools import partial
from time import time

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.ticker import NullFormatter

from sklearn import manifold, datasets

n_components = 2
SIM_SCALE = 100   # `sim` is divided by this before fitting (sets the MDS output scale)
EMBED_SEED = 0    # fixed seed so both embeddings are reproducible across runs

# Set-up manifold methods; both consume `sim` as a precomputed dissimilarity matrix.
methods = OrderedDict()
methods['MDS'] = manifold.MDS(n_components, max_iter=100, n_init=100,
                              dissimilarity="precomputed", random_state=EMBED_SEED)
# PCA initialisation is not allowed with precomputed distances; scikit-learn >= 1.2
# defaults to init="pca" and raises, while "random" was the default before 1.2.
methods['t-SNE'] = manifold.TSNE(n_components, metric="precomputed", init="random",
                                 random_state=EMBED_SEED)

Y_ALL = {}

# Fit each method; embeddings are keyed by the method's position in `methods`.
for i, (label, method) in enumerate(methods.items()):
    t0 = time()
    Y_ALL[i] = method.fit_transform(sim / SIM_SCALE)
    print("%s: %.2g sec" % (label, time() - t0))

In [None]:
#pickle.dump(Y_ALL, open('exp_nas-bench-nlp-mds.pickle', 'wb'))
#Y_ALL = pickle.load(open('exp_nas-bench-nlp-mds.pickle', 'rb'))

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable
from math import exp

In [None]:
# Scatter each 2-D embedding, coloured by test loss (field 1 of each data entry).
color = np.array([entry[1] for entry in data])

fig = plt.figure(figsize=(16, 6))

for i, label in enumerate(methods):
    Y = Y_ALL[i]
    # Panels occupy slots 2-5 of a 2x5 grid, then continue from slot 7.
    ax = fig.add_subplot(2, 5, 2 + i + (i > 3))
    im = ax.scatter(Y[:, 0], Y[:, 1], c=color, cmap=plt.cm.Spectral)
    ax.set_title(label)
    ax.xaxis.set_major_formatter(NullFormatter())
    ax.yaxis.set_major_formatter(NullFormatter())
    ax.axis('tight')

    # Colorbar axes on the right of ax: 5% of ax's width, 0.10 inch padding.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.10)
    fig.colorbar(im, cax=cax, label='test loss')

plt.show()