# In [None]:
# Mount Google Drive inside the Colab runtime at /content/drive; the data and
# figure paths used further down all live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

# In [None]:
import sys
import os
from itertools import product
import math
import numpy
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import scipy
from scipy.cluster import hierarchy
from scipy.cluster.hierarchy import dendrogram, linkage, leaves_list
from scipy.spatial.distance import pdist
from sklearn.manifold import TSNE, MDS
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.axes_grid1 import make_axes_locatable
import time
import numpy as np

sys.path.append(os.path.abspath(".."))
from distance_functions_final import *

def symmetrize(A):
    """Return a copy of square matrix ``A`` whose lower triangle mirrors the upper.

    Every entry on or below the diagonal is overwritten with its transposed
    counterpart, so the result is symmetric and ``A`` itself is left untouched.
    """
    size = A.shape[0]
    lower_rows, lower_cols = np.tril_indices(size)
    mirrored = A.copy()
    # Entry (r, c) of the lower triangle takes the value stored at (c, r).
    mirrored[lower_rows, lower_cols] = A[lower_cols, lower_rows]
    return mirrored
# Select the representation files that define the set of models under study.
subset = "train"
n = 5000
mode = "eval"
reps_folder = f"/content/drive/MyDrive/UKP/mnist_experiments/reps/{subset}/{n}_{mode}"
filenames = os.listdir(reps_folder)

# Keep seed-2 runs only and drop the "saved", depth-3 and depth-6 files.
# The original condition ended with a bare `and "depth8"`, which is always
# true and therefore filtered nothing; it is removed here so the test says
# what it does. Depth-8 models stay included, as the rest of the script
# (depth list, 50-model layout) expects.
_EXCLUDED_TAGS = ("saved", "depth3", "depth6")
model_names = []
for filename in filenames:
    if "seed2" in filename and not any(tag in filename for tag in _EXCLUDED_TAGS):
        # Strip the last four characters (the file extension).
        model_names.append(filename[:-4])

model_names = np.sort(model_names)

def extract_width(s):
    """Return the integer that follows the first "width" in ``s``.

    The match is case-insensitive. Only the first run of digits after
    "width" is used, so ``"width100_depth2"`` yields ``100`` (the previous
    implementation glued all later digits together and returned ``1002``).

    Returns ``float('inf')`` when ``s`` contains no "width" or no digits
    follow it, which makes such names sort last when used as a sort key.
    """
    import re  # local import: `re` is not among the file-level imports

    parts = s.lower().split("width")
    if len(parts) > 1:
        first_number = re.search(r"\d+", parts[1])
        if first_number:
            return int(first_number.group())
    return float('inf')
# Group the model names into one class per network depth.
all_classes = {}

# Despite its name, this list holds the depth values used to form the classes.
width_list = [2, 4, 7, 8, 9]

for depth in width_list:
    classstring = f"depth{depth}"
    members = [
        str(name).replace("_seed2_rep", "")
        for name in model_names
        if classstring in name
    ]
    # Order the members of a class by their width.
    all_classes[classstring] = sorted(members, key=extract_width)

# Only classes with at least two members are kept.
classes = {key: members for key, members in all_classes.items() if len(members) > 1}
num_classes = len(classes)
class_names = np.array(list(classes))


def _build_label_names():
    """Build the mapping from distance-array keys to LaTeX plot labels.

    The table is generated instead of being typed out entry by entry. Every
    key, every label (including the original spacing quirks such as
    ``\\sigma= `` versus ``\\sigma = ``) and the insertion order match the
    former hand-written table. LaTeX fragments are raw strings, so ``\\lambda``
    and ``\\sigma`` no longer rely on invalid escape sequences.

    NOTE(review): labels are reproduced as they were, but two things look
    suspicious and should be confirmed against the code that writes the keys:
    (1) UKP keys whose second token is 0 are labelled ``\\sigma = 100``;
    (2) the first UKP token uses the same grid as the CKA sigma keys while the
    second uses the GULP lambda grid, so lambda and sigma may be swapped.
    """
    def pow10(exp):
        # '10^{-3}' for negative exponents, plain '1' / '10' / '100' otherwise.
        return str(10 ** exp) if exp >= 0 else f"10^{{{exp}}}"

    def sci(exp):
        # Key token in scientific notation, e.g. '1.000000e-03', '1.000000e+01'.
        return f"1.000000e{exp:+03d}"

    names = {
        'lin_cka_dist': 'lin CKA',
        'mean_sq_cca_e2e': 'CCA',
        'pwcca_dist_e2e': 'PWCCA',
        'procrustes': 'Procrustes',
    }

    # GULP: lambda = 0, then 10^-7 ... 10.
    names['GULP_dist_0.000000e+00'] = r'GULP, $\lambda = 0$'
    for exp in range(-7, 2):
        names[f'GULP_dist_{sci(exp)}'] = rf'GULP, $\lambda = {pow10(exp)}$'

    # First kernel parameter: 10^-3 ... 100, spelled differently per kernel.
    first_exps = range(-3, 3)
    rbf_first = [(sci(exp), pow10(exp)) for exp in first_exps]
    laplace_tokens = ['0.001', '0.01', '0.09999999999999999', '1.0', '10.0', '100.0']
    lap_first = [(token, pow10(exp)) for token, exp in zip(laplace_tokens, first_exps)]

    # CKA with RBF kernel, then CKA with Laplace kernel.
    for kernel_key, kernel_label, first in (
        ('RBF', 'RBF', rbf_first),
        ('Laplace', 'Lap', lap_first),
    ):
        for exp, (token, text) in zip(first_exps, first):
            # The original labels omit the space before '=' for negative exponents.
            sigma_eq = r'\sigma= ' if exp < 0 else r'\sigma = '
            names[f'CKA_dist_{kernel_key}_{token}'] = f'CKA_{kernel_label}, ${sigma_eq}{text}$'

    # Second UKP parameter: 10^-7 ... 10, plus a trailing 0 token labelled 100.
    rbf_second = [(sci(exp), pow10(exp)) for exp in range(-7, 2)]
    rbf_second.append(('0.000000e+00', '100'))
    lap_second = [(f"1e{exp:+03d}", pow10(exp)) for exp in range(-7, 0)]
    lap_second += [('1.0', '1'), ('10.0', '10'), ('0.0', '100')]

    # UKP: the second parameter is the outer loop, as in the original ordering.
    for kernel_key, kernel_label, first, second in (
        ('RBF', 'RBF', rbf_first, rbf_second),
        ('Laplace', 'Lap', lap_first, lap_second),
    ):
        for second_token, second_text in second:
            for first_token, first_text in first:
                names[f'UKP_dist_{kernel_key}_{first_token}_{second_token}'] = (
                    rf'UKP_{kernel_label}, $\lambda = {first_text},\sigma= {second_text}$'
                )

    return names


label_names_dict = _build_label_names()


# Open one pairwise distance archive and use the names of the arrays stored
# in it as the list of available distance functions.
datafilepath = "/content/drive/MyDrive/UKP/mnist_experiments/distances/widthdepth/5000_train/width100_depth2_seed2_rep_width100_depth4_seed2_rep.npz"
data = np.load(datafilepath)

distnames = data.files
print(distnames)

# ---------------------------------------------------------------------------
# 0.1. Show distance matrices for different distance functions
# ---------------------------------------------------------------------------
subset = "val"
distances_folder = "/content/drive/MyDrive/UKP/mnist_experiments/distances/widthdepth/5000_train/"
distances_categorized_folder = "/content/drive/MyDrive/UKP/mnist_experiments/distances/widthdepth/5000_train_categorized/"
figures_folder = "/content/drive/MyDrive/UKP/mnist_experiments/Figures"
num_models = len(model_names)

# Archive holding one array per distance function (looked up by name below).
dist_array = np.load(f'{distances_categorized_folder}all_distances_categorized_correctedfinal.npz')

# Model names grouped by depth, each group ordered by increasing width:
# width100_depth2 ... width1000_depth2, width100_depth4 ... width1000_depth9.
model_names_categorized = [
    f'width{width}_depth{depth}'
    for depth in (2, 4, 7, 8, 9)
    for width in range(100, 1001, 100)
]

# Index permutations that sort each name list alphabetically.
sorted_indices1 = sorted(range(len(model_names)), key=lambda pos: model_names[pos])
sorted_indices2 = sorted(range(len(model_names_categorized)), key=lambda pos: model_names_categorized[pos])

# NOTE(review): this is the argsort of the categorized names. If the intent is
# to move alphabetically ordered rows into the categorized order, the inverse
# permutation would be needed instead; confirm against how the archive was built.
matching_indices = sorted_indices2


# ### Testing correctness of distances
short_model_names = model_names_categorized

# For every class, find where its members sit in short_model_names, append
# those indices to sorted_inds, and record the midpoint of the class's run
# (used below as the tick position for the class label).
sorted_inds = np.array([], dtype=int)
spacings = []
for name in class_names:
    # np.isin replaces np.in1d, which is deprecated since NumPy 2.0.
    class_inds = np.where(np.isin(short_model_names, classes[name]))[0]
    spacings.append(len(sorted_inds) + len(class_inds) / 2)
    sorted_inds = np.append(sorted_inds, class_inds)
spacings = np.array(spacings)

dist_inds = list(range(len(distnames)))

labels = [f"Dist{i}" for i in range(len(distnames))]

# Draw and save one heatmap (PNG and PDF) per distance function.
# The upper-triangle index pairs are the same for every distance, so they are
# computed once instead of on every iteration.
row_indices, col_indices = np.triu_indices(num_models, k=1)

for distname in distnames:
    fig = None
    try:
        # Rebuild the full symmetric matrix from the stored upper-triangle values.
        D_temp = dist_array[distname]
        D = np.zeros((num_models, num_models))
        D[row_indices, col_indices] = D_temp
        D = symmetrize(D)

        # Negative GULP/UKP values are blanked out so they do not distort the colour scale.
        if "GULP" in distname or "UKP" in distname:
            D[D < 0] = np.nan

        D = D[np.ix_(matching_indices, matching_indices)]
        fig, ax = plt.subplots(figsize=(8, 8))
        im = ax.pcolormesh(D)
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size='5%', pad=0.05)
        fig.colorbar(im, cax=cax, orientation='vertical')
        # Fall back to the raw key so a missing label no longer skips the whole figure.
        title_label = label_names_dict.get(distname, distname)
        ax.set_title(f"Distance matrix for {title_label}", fontsize = 20)
        ax.set_yticks(spacings + 0.5)
        ax.set_yticklabels(class_names, fontsize = 15)
        ax.set_ylabel("Networks")
        ax.set_xticks(spacings + 0.5)
        ax.set_xticklabels(class_names, rotation=90, fontsize = 13)
        ax.set_xlabel("Networks")

        fig.tight_layout()
        fig.savefig(f"{figures_folder}/Heatmaps/Heatmap for {distname}.png", format='png', dpi=300,
                    bbox_inches='tight')
        fig.savefig(f"{figures_folder}/Heatmaps/Heatmap for {distname}.pdf", format='pdf', dpi=300,
                    bbox_inches='tight')
    except Exception as e:
        # Best effort: report the failing distance and continue with the next one.
        print(f"Skipped {distname} due to error {e}")
    finally:
        # Close the figure even when plotting or saving failed, so figures do not pile up.
        if fig is not None:
            plt.close(fig)

# 0.2 Dendrogram and TSNE

# Indices (into short_model_names) of every model that belongs to a kept class.
subset_inds = []
for j in range(num_classes):
    # np.isin replaces np.in1d, which is deprecated since NumPy 2.0.
    class_inds = np.where(np.isin(short_model_names, classes[class_names[j]]))[0]
    subset_inds.append(class_inds)
subset_inds = np.sort(np.hstack(subset_inds))
num_subset = len(subset_inds)

subset = "val"
num_dists = len(distnames)

embed_type = "TSNE"

# One 2-D embedding of the selected models per distance function.
row_indices, col_indices = np.triu_indices(num_models, k=1)
embeddings = np.zeros((num_dists, num_subset, 2))
for i in range(num_dists):
    print(f"Pretrained Networks: {distnames[i]}")
    # Rebuild the full symmetric matrix from the stored upper-triangle values.
    D_temp = dist_array[distnames[i]]
    D = np.zeros((num_models, num_models))
    D[row_indices, col_indices] = D_temp
    D = symmetrize(D)
    D_subset = D[subset_inds, :][:, subset_inds]

    try:
        if embed_type == "TSNE":
            X_embedded = TSNE(n_components=2, perplexity=20.0, init="random", metric="precomputed").fit_transform(
                D_subset)
        elif embed_type == "UMAP":
            # NOTE(review): `umap` is not imported in the visible part of this file;
            # unless the star import provides it, this branch fails with NameError
            # (which is reported below and turned into a NaN embedding).
            X_embedded = umap.UMAP(n_components=2, n_neighbors=20, min_dist=0.1).fit_transform(np.sqrt(D_subset))
        elif embed_type == "MDS":
            X_embedded = MDS(n_components=2, dissimilarity="precomputed").fit_transform(np.sqrt(D_subset))
        else:
            # Without this, an unknown type would silently reuse a stale X_embedded.
            raise ValueError(f"Unknown embed_type: {embed_type}")
    except Exception as err:
        # Keep going with a NaN placeholder, but say why the embedding failed.
        X_embedded = np.full(embeddings[i, :, :].shape, np.nan)
        print(f"Nan generated for {distnames[i]}: {err}")

    embeddings[i, :, :] = X_embedded


# Figures with the 2-D embedding on top and the dendrogram below, one per distance.
hierarchy.set_link_color_palette(['violet', 'brown', 'indigo', 'cyan'])

dist_inds = list(range(len(distnames)))

labels = [distnames[i] for i in range(len(distnames))]

cmap = plt.cm.tab10
colors = [cmap(i) for i in range(num_classes)]

# Indices (into short_model_names) of every model that belongs to a kept class.
subset_inds = []
for j in range(num_classes):
    # np.isin replaces np.in1d, which is deprecated since NumPy 2.0.
    class_inds = np.where(np.isin(short_model_names, classes[class_names[j]]))[0]
    subset_inds.append(class_inds)
subset_inds = np.sort(np.hstack(subset_inds))
num_subset = len(subset_inds)

# Names of the selected models, in the row order of D_subset and of the embeddings.
subset_model_names = [short_model_names[k] for k in subset_inds]
row_indices, col_indices = np.triu_indices(num_models, k=1)

for i in range(len(distnames)):
    fig = None
    try:
        fig, axs = plt.subplots(2, 1, figsize=(30, 20))
        ax = axs[0]
        X_embedded = embeddings[dist_inds[i], :, :]
        for j in range(num_classes):
            # Positions within the subset, because the embedding rows follow subset order.
            class_inds = np.where(np.isin(subset_model_names, classes[class_names[j]]))[0]
            ax.scatter(X_embedded[class_inds, 0], X_embedded[class_inds, 1], s=200, color=colors[j])
        ax.set_title(f"{embed_type} embedding", fontdict={'fontsize': 22})
        ax.legend(class_names, fontsize=15)

        ax = axs[1]
        # Rebuild the full symmetric matrix from the stored upper-triangle values.
        D_temp = dist_array[distnames[i]]
        D = np.zeros((num_models, num_models))
        D[row_indices, col_indices] = D_temp
        D = symmetrize(D)
        D_subset = D[subset_inds, :][:, subset_inds]

        # NOTE(review): pdist treats each row of D_subset as a feature vector, so the
        # clustering runs on Euclidean distances between rows of the distance matrix,
        # not on the distances themselves. If the latter is intended,
        # scipy.spatial.distance.squareform would be the conversion to use; confirm.
        Z = linkage(pdist(D_subset), 'ward')
        count_sort = 'ascending'
        # Label the leaves with the names of the rows that were actually clustered.
        dendrogram(Z, labels=subset_model_names, leaf_rotation=90, leaf_font_size=20,
                   count_sort=count_sort, ax=ax)
        ax.set_title("Dendrogram", fontdict={'fontsize': 22})

        fig.subplots_adjust(hspace=0.3)

        fig.savefig(f"{figures_folder}/DendogramandTSNE/DendogramandTSNE for {distnames[i]}.pdf", format='pdf',
                    dpi=300,
                    bbox_inches='tight')
        fig.savefig(f"{figures_folder}/DendogramandTSNE/DendogramandTSNE for {distnames[i]}.png", format='png',
                    dpi=300,
                    bbox_inches='tight')
    except Exception as err:
        # Best effort: report why this distance was skipped and continue.
        print(f"Skipped due to Error ({distnames[i]}): {err}")
    finally:
        # Close the figure even on failure so figures do not accumulate.
        if fig is not None:
            plt.close(fig)

# Only Dendrograms: one stand-alone dendrogram figure (PDF and PNG) per distance.
row_indices, col_indices = np.triu_indices(num_models, k=1)
# Names of the selected models, in the row order of D_subset.
subset_model_names = [short_model_names[k] for k in subset_inds]

for i in range(len(distnames)):
    fig = None
    try:
        fig, ax = plt.subplots(1, 1, figsize=(25, 15))

        # Rebuild the full symmetric matrix from the stored upper-triangle values.
        D_temp = dist_array[distnames[i]]
        D = np.zeros((num_models, num_models))
        D[row_indices, col_indices] = D_temp
        D = symmetrize(D)
        D_subset = D[subset_inds, :][:, subset_inds]

        # NOTE(review): pdist treats each row of D_subset as a feature vector, so the
        # clustering runs on Euclidean distances between rows of the distance matrix,
        # not on the distances themselves. If the latter is intended,
        # scipy.spatial.distance.squareform would be the conversion to use; confirm.
        Z = linkage(pdist(D_subset), 'ward')
        count_sort = 'ascending'

        # Label the leaves with the names of the rows that were actually clustered.
        dendrogram(Z, labels=subset_model_names, leaf_rotation=90, leaf_font_size=20,
                   count_sort=count_sort, ax=ax)
        # Fall back to the raw key so a missing label no longer skips the whole figure.
        title_label = label_names_dict.get(distnames[i], distnames[i])
        ax.set_title(f"Dendrogram for {title_label}", fontsize = 40)

        fig.subplots_adjust(hspace=0.3)

        fig.savefig(f"{figures_folder}/Dendogram/Dendogram for {distnames[i]}.pdf", format='pdf',
                    dpi=300,
                    bbox_inches='tight')
        fig.savefig(f"{figures_folder}/Dendogram/Dendogram for {distnames[i]}.png", format='png',
                    dpi=300,
                    bbox_inches='tight')
    except Exception as err:
        # Best effort: report why this distance was skipped and continue.
        print(f"Skipped due to Error ({distnames[i]}): {err}")
    finally:
        # Close the figure even on failure so figures do not accumulate.
        if fig is not None:
            plt.close(fig)


# Correlation plots
# Distance names grouped by family; Laplace-kernel variants are left out of
# the UKP and CKA groups.
GULP_dists = [name for name in distnames if "GULP" in name]
UKP_dists = [name for name in distnames if "UKP" in name and "Laplace" not in name]
CKA_dists = [name for name in distnames if "CKA" in name and "Laplace" not in name]

# Model indices taken into account by the correlation plots.
subset_inds = range(50)

def vectorize_upper_triangle(A, subset_inds):
    """Flatten the strict upper triangle of ``A`` restricted to ``subset_inds``.

    Args:
        A: square 2-D NumPy array.
        subset_inds: sequence of row/column indices to keep, or ``None`` to
            use every index of ``A``.

    Returns:
        1-D array with ``A[subset_inds[i], subset_inds[j]]`` for all ``i < j``,
        in row-major order (all pairs for ``i = 0`` first, then ``i = 1``, ...).

    Raises:
        AssertionError: if ``A`` is not square.
    """
    assert (A.shape[0] == A.shape[1])
    if subset_inds is None:
        # Was `range(len(A.shape[0]))`, which raised TypeError (len() of an int).
        subset_inds = range(A.shape[0])
    picked = np.asarray(subset_inds, dtype=int)
    # Position pairs (i, j) with i < j, mapped through the selected indices.
    upper_i, upper_j = np.triu_indices(len(picked), k=1)
    return A[picked[upper_i], picked[upper_j]]


def compare_d1_to_d2(d1, d2, subset_inds=None, d1_name=None, d2_name=None, fsz=17):
    """Scatter one distance against another and save the plot as PDF and PNG.

    Both distances are read from ``dist_array`` as flattened upper triangles,
    expanded to symmetric matrices and compared pair by pair. Pairs where
    either distance is not strictly positive (or is NaN) are dropped. The
    Pearson correlation of the remaining pairs is shown in the title.

    Args:
        d1: key of the distance plotted on the y axis.
        d2: key of the distance plotted on the x axis.
        subset_inds: model indices to include; ``None`` means all models.
        d1_name: optional y-axis label; defaults to the entry of
            ``label_names_dict`` for ``d1`` (or ``d1`` itself if absent).
        d2_name: optional x-axis label, same defaulting rule for ``d2``.
        fsz: base font size for labels; ticks and title use ``fsz - 2``.
    """
    # Local import: the file imports scipy and some submodules, but not scipy.stats.
    from scipy import stats

    def _load_square(key):
        # Expand the stored upper-triangle vector into a full symmetric matrix.
        condensed = dist_array[key]
        # Matrix size n from the vector length n * (n - 1) / 2 (50 for 1225 values).
        size = int(round((1 + np.sqrt(1 + 8 * np.size(condensed))) / 2))
        square = np.zeros((size, size))
        upper_rows, upper_cols = np.triu_indices(size, k=1)
        square[upper_rows, upper_cols] = condensed
        return symmetrize(square)

    d1_mat = _load_square(d1)
    d2_mat = _load_square(d2)

    if subset_inds is None:
        subset_inds = range(d1_mat.shape[0])

    d1_vec = vectorize_upper_triangle(d1_mat, subset_inds)
    d2_vec = vectorize_upper_triangle(d2_mat, subset_inds)

    # A boolean mask also works when nothing survives the filter; the former
    # index list became an empty float array, which cannot be used as an index.
    keep = (d1_vec > 0) & (d2_vec > 0)
    if not keep.all():
        print('WARNING: some distances that are <= 0 (or NaN) were removed')

    d1_filtered = d1_vec[keep]
    d2_filtered = d2_vec[keep]

    y_label = d1_name if d1_name is not None else label_names_dict.get(d1, d1)
    x_label = d2_name if d2_name is not None else label_names_dict.get(d2, d2)

    # Only one figure is created now; the extra plt.subplots() figure that used
    # to be opened here was never closed.
    fig = plt.figure(figsize=(8, 8))
    try:
        plt.scatter(d2_filtered, d1_filtered, marker='.', s=7)
        plt.xlabel(x_label, fontsize=fsz)
        plt.ylabel(y_label, fontsize=fsz)

        plt.yticks(fontsize=fsz - 2)
        plt.xticks(fontsize=fsz - 2)

        pearsonr = stats.pearsonr(d2_filtered, d1_filtered)
        plt.title(f"Correlation: {pearsonr[0]:.3f}", fontsize=fsz - 2)

        plt.tight_layout(pad=3.0)
        plt.savefig(f"{figures_folder}/Correlation/Correlation plot for {d1} and {d2}.pdf", format='pdf', dpi=300,
                    bbox_inches='tight')
        plt.savefig(f"{figures_folder}/Correlation/Correlation plot for {d1} and {d2}.png", format='png', dpi=300,
                    bbox_inches='tight')
    finally:
        # Close the figure even when the correlation or saving step fails.
        plt.close(fig)

# Produce a correlation plot for every pair of distances drawn from the
# GULP, UKP and CKA groups; pairs involving any other distance are skipped.
_comparable_dists = set(GULP_dists) | set(UKP_dists) | set(CKA_dists)

for i in range(0, len(distnames)):
    d1 = distnames[i]
    if d1 not in _comparable_dists:
        continue
    for j in range(i + 1, len(distnames)):
        d2 = distnames[j]
        if d2 not in _comparable_dists:
            continue
        try:
            print(f'({i} {d1},{j} {d2})')
            compare_d1_to_d2(d1, d2, None)
        except Exception as err:
            # Best effort: keep going, but report which pair failed and why.
            print(f"Skipped due Error ({d1}, {d2}): {err}")




