In [None]:
from google.colab import drive
# Attach Google Drive at /content/drive so the /content/drive/MyDrive/... paths used below resolve.
drive.mount("/content/drive")

In [None]:
import sys
import os
from itertools import product
import math
import numpy
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from scipy.cluster.hierarchy import dendrogram, leaves_list
from scipy.spatial.distance import pdist
from sklearn.manifold import TSNE, MDS
import scipy
# import umap
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.axes_grid1 import make_axes_locatable
import time

from sklearn.kernel_ridge import KernelRidge

sys.path.append(os.path.abspath("../"))

from distance_functions_torch import *
import scipy.stats

folder = "/content/drive/MyDrive/UKP/mnist_experiments/distances/widthdepth/5000_train/"
folder_categorized = "/content/drive/MyDrive/UKP/mnist_experiments/distances/widthdepth/5000_train_categorized"
# Load one pairwise distance file only to list the distance names (npz keys)
# it contains; these names are used below to index the categorized distances.
data = np.load(f"{folder}/width100_depth2_seed2_rep_width100_depth4_seed2_rep.npz")
print(data.files)

distnames = data.files

# Splitting each distance into its own array was done on Google Colab with the
# notebook "separating distance into separate files.ipynb" under mnist_experiments.

# Models to analyse: seed-2 representation files, excluding "saved" files and
# the depth-3 and depth-6 architectures.
# NOTE(review): the original filter ended in `and "depth8"` -- a bare string,
# which is always true -- so depth-8 models were (and still are) included.
# The categorized distance vectors loaded below are indexed by this exact
# model list, so only add "depth8" here if those vectors were computed
# without the depth-8 models as well.
excluded_substrings = ("saved", "depth3", "depth6")

reps_folder = "/content/drive/MyDrive/UKP/mnist_experiments/reps/train/5000_eval"
filenames = os.listdir(reps_folder)
model_names = [
    filename[:-4]  # strip the ".npy" extension
    for filename in filenames
    if "seed2" in filename and not any(s in filename for s in excluded_substrings)
]

model_names = np.sort(model_names)
num_models = len(model_names)

# One upper-triangle distance vector per distance name, over the sorted model_names.
dist_array = np.load(f'{folder_categorized}/all_distances_categorized_correctedfinal.npz')

# Loading representations (both train and test)

rep_folder_prefix = '/content/drive/MyDrive/UKP/mnist_experiments/'
train_reps_folder = rep_folder_prefix + 'reps/train/5000_eval/'
val_reps_folder = rep_folder_prefix + 'reps/test/5000_eval/'

dist_n = 5000


def _load_normalized_rep(path, n_samples):
    """Load a representation array and normalize it.

    Keeps the first ``n_samples`` columns (samples are along axis 1), centers
    every row across those samples, then rescales the whole array so that its
    Frobenius norm equals sqrt(number of kept columns).

    NOTE(review): an all-constant representation has zero norm after
    centering and would produce NaNs here.
    """
    rep = np.load(path)[:, :n_samples]
    rep = rep - rep.mean(axis=1, keepdims=True)
    rep = rep / np.linalg.norm(rep)
    return rep * np.sqrt(rep.shape[1])


# Load MNIST representations for every selected model.
reps_train = {}  # Train dataset
reps_test = {}  # Test (validation) dataset
try:
    for model_name in model_names:
        print(model_name)
        reps_train[model_name] = _load_normalized_rep(train_reps_folder + model_name + ".npy", dist_n)
        reps_test[model_name] = _load_normalized_rep(val_reps_folder + model_name + ".npy", dist_n)
except FileNotFoundError:
    print('WARNING: IN ORDER TO RUN THIS CODE, THE MNIST REPRESENTATIONS MUST BE COMPUTED. SEE README.')
    raise

##########################################################################################################

#### Predict generalization, random y's
def find_best_pred(y, lmbda, sigma, reps, kernelstr):
    """Fit kernel ridge regression of ``y`` on ``reps``.

    Args:
        y: regression targets, one per column of ``reps``.
        lmbda: ridge penalty, passed to ``KernelRidge`` as ``alpha``.
        sigma: kernel bandwidth; for 'rbf', gamma = 1 / (2 * sigma) ** 2.
        reps: array of shape (dimension, number of datapoints); it is
            transposed before fitting.
        kernelstr: kernel name. Only 'rbf' is supported.

    Returns:
        The fitted ``KernelRidge`` model.

    Raises:
        ValueError: if ``kernelstr`` is not 'rbf'.
    """
    if kernelstr != 'rbf':
        raise ValueError(f"Unsupported kernelstr {kernelstr!r}; only 'rbf' is implemented")
    # NOTE(review): this gives gamma = 1 / (4 * sigma**2), not the more common
    # 1 / (2 * sigma**2); kept as-is, presumably to match the convention used
    # when computing the RBF distances -- confirm.
    gamma = 1 / ((2 * sigma) ** 2)
    krr = KernelRidge(alpha=lmbda, kernel='rbf', gamma=gamma)
    krr.fit(reps.T, y)
    return krr


def find_best_lin_pred(y, lmbda, reps):
    """Return ridge-regression weights predicting ``y`` linearly from ``reps``.

    ``reps`` has shape (dimension, number of datapoints n). The weights ``w``
    minimise (1/n) * ||reps.T @ w - y||^2 + lmbda * ||w||^2, i.e. they solve
        (lmbda * I + reps @ reps.T / n) w = reps @ y / n.
    Both sides are normalised by n. Previously only the left side was, which
    returned n times the ridge solution.

    Args:
        y: targets of shape (n,) or (n, k).
        lmbda: ridge penalty.
        reps: representation matrix of shape (dimension, n).

    Returns:
        Weights of shape (dimension,) or (dimension, k).
    """
    rep_dim = reps.shape[0]
    numpts = reps.shape[1]
    gram = (reps @ reps.T) / numpts
    return np.linalg.solve(lmbda * np.eye(rep_dim) + gram, (reps @ y) / numpts)


def symmetrize(A):
    """Return a copy of ``A`` whose lower triangle mirrors its upper triangle.

    Entries on and above the diagonal come from ``A``; each entry (i, j) with
    i > j is replaced by ``A[j, i]``. ``A`` itself is not modified.
    """
    lower_rows, lower_cols = np.tril_indices(A.shape[0])
    mirrored = A.copy()
    mirrored[lower_rows, lower_cols] = A[lower_cols, lower_rows]
    return mirrored

def dist_from_upper_tri_vec(vec, num_models):
    """Rebuild a symmetric (num_models, num_models) distance matrix.

    ``vec`` holds the strictly upper-triangular entries in the order produced
    by ``np.triu_indices(num_models, k=1)`` (row-major). Those entries are
    mirrored below the diagonal, and the diagonal is zero.

    Raises:
        ValueError: if ``vec`` does not have exactly
            num_models * (num_models - 1) / 2 entries. Without this check a
            length-1 ``vec`` would be silently broadcast into every
            off-diagonal slot.
    """
    vec = np.asarray(vec)
    row_indices, col_indices = np.triu_indices(num_models, k=1)
    if vec.shape != row_indices.shape:
        raise ValueError(
            f"expected {row_indices.size} upper-triangle entries for {num_models} models, "
            f"got array of shape {vec.shape}")
    D = np.zeros((num_models, num_models))
    D[row_indices, col_indices] = vec
    D[col_indices, row_indices] = vec  # mirror into the lower triangle
    return D

def flatten_upper_right_triangle(curr_mat):
    """Flatten the strictly upper triangle of a model-by-model matrix.

    Returns ``curr_mat[i, j]`` for every i < j in row-major order as a 1-D
    array. This is the layout that ``dist_from_upper_tri_vec`` expands back
    into a full matrix. Asserts that ``curr_mat`` is square, with one row per
    entry of the global ``model_names``.
    """
    assert (curr_mat.shape[0] == curr_mat.shape[1])
    size = len(model_names)
    assert (curr_mat.shape[0] == size)
    return np.asarray([curr_mat[row, col]
                       for row in range(size - 1)
                       for col in range(row + 1, size)])

# Expand every stored upper-triangle vector into a full symmetric
# num_models x num_models matrix, keyed by distance name, echoing each one.
distances = dict()
for distname in distnames:
    print(distname)
    distances[distname] = dist_from_upper_tri_vec(dist_array[distname], num_models)
    print(distances[distname])

# Directory holding the per-trial *_err_RBF_vec.npy files read by
# get_collected_correlations_kertasks.
# NOTE(review): unlike every other path in this notebook, this is a local
# (non-Drive) path -- confirm it is reachable from the Colab runtime.
err_folder = "/home/soumya/Documents/mnist_experiments/err_folder/"

import gc


# Human-readable (matplotlib mathtext) labels for each distance key.
# Raw strings are used so "\lambda" / "\sigma" are not parsed as (invalid)
# escape sequences; the resulting label text is unchanged.
label_names_dict = {'lin_cka_dist': 'lin CKA', 'mean_sq_cca_e2e' : 'CCA', 'pwcca_dist_e2e': 'PWCCA', 'procrustes': 'Procrustes'}

label_names_dict['GULP_dist_0.000000e+00'] = r'GULP, $\lambda = 0$'
label_names_dict['GULP_dist_1.000000e-07'] = r'GULP, $\lambda = 10^{-7}$'
label_names_dict['GULP_dist_1.000000e-06'] = r'GULP, $\lambda = 10^{-6}$'
label_names_dict['GULP_dist_1.000000e-05'] = r'GULP, $\lambda = 10^{-5}$'
label_names_dict['GULP_dist_1.000000e-04'] = r'GULP, $\lambda = 10^{-4}$'
label_names_dict['GULP_dist_1.000000e-03'] = r'GULP, $\lambda = 10^{-3}$'
label_names_dict['GULP_dist_1.000000e-02'] = r'GULP, $\lambda = 10^{-2}$'
label_names_dict['GULP_dist_1.000000e-01'] = r'GULP, $\lambda = 10^{-1}$'
label_names_dict['GULP_dist_1.000000e+00'] = r'GULP, $\lambda = 1$'
label_names_dict['GULP_dist_1.000000e+01'] = r'GULP, $\lambda = 10$'

label_names_dict['CKA_dist_RBF_1.000000e-03'] = r'CKA_RBF, $\sigma= 10^{-3}$'
label_names_dict['CKA_dist_RBF_1.000000e-02'] = r'CKA_RBF, $\sigma= 10^{-2}$'
label_names_dict['CKA_dist_RBF_1.000000e-01'] = r'CKA_RBF, $\sigma= 10^{-1}$'
label_names_dict['CKA_dist_RBF_1.000000e+00'] = r'CKA_RBF, $\sigma = 1$'
label_names_dict['CKA_dist_RBF_1.000000e+01'] = r'CKA_RBF, $\sigma = 10$'
label_names_dict['CKA_dist_RBF_1.000000e+02'] = r'CKA_RBF, $\sigma = 100$'

label_names_dict['CKA_dist_Laplace_0.001'] = r'CKA_Lap, $\sigma= 10^{-3}$'
label_names_dict['CKA_dist_Laplace_0.01'] = r'CKA_Lap, $\sigma= 10^{-2}$'
label_names_dict['CKA_dist_Laplace_0.09999999999999999'] = r'CKA_Lap, $\sigma= 10^{-1}$'
label_names_dict['CKA_dist_Laplace_1.0'] = r'CKA_Lap, $\sigma = 1$'
label_names_dict['CKA_dist_Laplace_10.0'] = r'CKA_Lap, $\sigma = 10$'
label_names_dict['CKA_dist_Laplace_100.0'] = r'CKA_Lap, $\sigma = 100$'

label_names_dict['UKP_dist_RBF_1.000000e-03_1.000000e-07'] = r'UKP_RBF, $\lambda = 10^{-3},\sigma= 10^{-7}$'
label_names_dict['UKP_dist_RBF_1.000000e-02_1.000000e-07'] = r'UKP_RBF, $\lambda = 10^{-2},\sigma= 10^{-7}$'
label_names_dict['UKP_dist_RBF_1.000000e-01_1.000000e-07'] = r'UKP_RBF, $\lambda = 10^{-1},\sigma= 10^{-7}$'
label_names_dict['UKP_dist_RBF_1.000000e+00_1.000000e-07'] = r'UKP_RBF, $\lambda = 1,\sigma= 10^{-7}$'
label_names_dict['UKP_dist_RBF_1.000000e+01_1.000000e-07'] = r'UKP_RBF, $\lambda = 10,\sigma= 10^{-7}$'
label_names_dict['UKP_dist_RBF_1.000000e+02_1.000000e-07'] = r'UKP_RBF, $\lambda = 100,\sigma= 10^{-7}$'

label_names_dict['UKP_dist_RBF_1.000000e-03_1.000000e-06'] = r'UKP_RBF, $\lambda = 10^{-3},\sigma= 10^{-6}$'
label_names_dict['UKP_dist_RBF_1.000000e-02_1.000000e-06'] = r'UKP_RBF, $\lambda = 10^{-2},\sigma= 10^{-6}$'
label_names_dict['UKP_dist_RBF_1.000000e-01_1.000000e-06'] = r'UKP_RBF, $\lambda = 10^{-1},\sigma= 10^{-6}$'
label_names_dict['UKP_dist_RBF_1.000000e+00_1.000000e-06'] = r'UKP_RBF, $\lambda = 1,\sigma= 10^{-6}$'
label_names_dict['UKP_dist_RBF_1.000000e+01_1.000000e-06'] = r'UKP_RBF, $\lambda = 10,\sigma= 10^{-6}$'
label_names_dict['UKP_dist_RBF_1.000000e+02_1.000000e-06'] = r'UKP_RBF, $\lambda = 100,\sigma= 10^{-6}$'

label_names_dict['UKP_dist_RBF_1.000000e-03_1.000000e-05'] = r'UKP_RBF, $\lambda = 10^{-3},\sigma= 10^{-5}$'
label_names_dict['UKP_dist_RBF_1.000000e-02_1.000000e-05'] = r'UKP_RBF, $\lambda = 10^{-2},\sigma= 10^{-5}$'
label_names_dict['UKP_dist_RBF_1.000000e-01_1.000000e-05'] = r'UKP_RBF, $\lambda = 10^{-1},\sigma= 10^{-5}$'
label_names_dict['UKP_dist_RBF_1.000000e+00_1.000000e-05'] = r'UKP_RBF, $\lambda = 1,\sigma= 10^{-5}$'
label_names_dict['UKP_dist_RBF_1.000000e+01_1.000000e-05'] = r'UKP_RBF, $\lambda = 10,\sigma= 10^{-5}$'
label_names_dict['UKP_dist_RBF_1.000000e+02_1.000000e-05'] = r'UKP_RBF, $\lambda = 100,\sigma= 10^{-5}$'

label_names_dict['UKP_dist_RBF_1.000000e-03_1.000000e-04'] = r'UKP_RBF, $\lambda = 10^{-3},\sigma= 10^{-4}$'
label_names_dict['UKP_dist_RBF_1.000000e-02_1.000000e-04'] = r'UKP_RBF, $\lambda = 10^{-2},\sigma= 10^{-4}$'
label_names_dict['UKP_dist_RBF_1.000000e-01_1.000000e-04'] = r'UKP_RBF, $\lambda = 10^{-1},\sigma= 10^{-4}$'
label_names_dict['UKP_dist_RBF_1.000000e+00_1.000000e-04'] = r'UKP_RBF, $\lambda = 1,\sigma= 10^{-4}$'
label_names_dict['UKP_dist_RBF_1.000000e+01_1.000000e-04'] = r'UKP_RBF, $\lambda = 10,\sigma= 10^{-4}$'
label_names_dict['UKP_dist_RBF_1.000000e+02_1.000000e-04'] = r'UKP_RBF, $\lambda = 100,\sigma= 10^{-4}$'

label_names_dict['UKP_dist_RBF_1.000000e-03_1.000000e-03'] = r'UKP_RBF, $\lambda = 10^{-3},\sigma= 10^{-3}$'
label_names_dict['UKP_dist_RBF_1.000000e-02_1.000000e-03'] = r'UKP_RBF, $\lambda = 10^{-2},\sigma= 10^{-3}$'
label_names_dict['UKP_dist_RBF_1.000000e-01_1.000000e-03'] = r'UKP_RBF, $\lambda = 10^{-1},\sigma= 10^{-3}$'
label_names_dict['UKP_dist_RBF_1.000000e+00_1.000000e-03'] = r'UKP_RBF, $\lambda = 1,\sigma= 10^{-3}$'
label_names_dict['UKP_dist_RBF_1.000000e+01_1.000000e-03'] = r'UKP_RBF, $\lambda = 10,\sigma= 10^{-3}$'
label_names_dict['UKP_dist_RBF_1.000000e+02_1.000000e-03'] = r'UKP_RBF, $\lambda = 100,\sigma= 10^{-3}$'

label_names_dict['UKP_dist_RBF_1.000000e-03_1.000000e-02'] = r'UKP_RBF, $\lambda = 10^{-3},\sigma= 10^{-2}$'
label_names_dict['UKP_dist_RBF_1.000000e-02_1.000000e-02'] = r'UKP_RBF, $\lambda = 10^{-2},\sigma= 10^{-2}$'
label_names_dict['UKP_dist_RBF_1.000000e-01_1.000000e-02'] = r'UKP_RBF, $\lambda = 10^{-1},\sigma= 10^{-2}$'
label_names_dict['UKP_dist_RBF_1.000000e+00_1.000000e-02'] = r'UKP_RBF, $\lambda = 1,\sigma= 10^{-2}$'
label_names_dict['UKP_dist_RBF_1.000000e+01_1.000000e-02'] = r'UKP_RBF, $\lambda = 10,\sigma= 10^{-2}$'
label_names_dict['UKP_dist_RBF_1.000000e+02_1.000000e-02'] = r'UKP_RBF, $\lambda = 100,\sigma= 10^{-2}$'

label_names_dict['UKP_dist_RBF_1.000000e-03_1.000000e-01'] = r'UKP_RBF, $\lambda = 10^{-3},\sigma= 10^{-1}$'
label_names_dict['UKP_dist_RBF_1.000000e-02_1.000000e-01'] = r'UKP_RBF, $\lambda = 10^{-2},\sigma= 10^{-1}$'
label_names_dict['UKP_dist_RBF_1.000000e-01_1.000000e-01'] = r'UKP_RBF, $\lambda = 10^{-1},\sigma= 10^{-1}$'
label_names_dict['UKP_dist_RBF_1.000000e+00_1.000000e-01'] = r'UKP_RBF, $\lambda = 1,\sigma= 10^{-1}$'
label_names_dict['UKP_dist_RBF_1.000000e+01_1.000000e-01'] = r'UKP_RBF, $\lambda = 10,\sigma= 10^{-1}$'
label_names_dict['UKP_dist_RBF_1.000000e+02_1.000000e-01'] = r'UKP_RBF, $\lambda = 100,\sigma= 10^{-1}$'

label_names_dict['UKP_dist_RBF_1.000000e-03_1.000000e+00'] = r'UKP_RBF, $\lambda = 10^{-3},\sigma= 1$'
label_names_dict['UKP_dist_RBF_1.000000e-02_1.000000e+00'] = r'UKP_RBF, $\lambda = 10^{-2},\sigma= 1$'
label_names_dict['UKP_dist_RBF_1.000000e-01_1.000000e+00'] = r'UKP_RBF, $\lambda = 10^{-1},\sigma= 1$'
label_names_dict['UKP_dist_RBF_1.000000e+00_1.000000e+00'] = r'UKP_RBF, $\lambda = 1,\sigma= 1$'
label_names_dict['UKP_dist_RBF_1.000000e+01_1.000000e+00'] = r'UKP_RBF, $\lambda = 10,\sigma= 1$'
label_names_dict['UKP_dist_RBF_1.000000e+02_1.000000e+00'] = r'UKP_RBF, $\lambda = 100,\sigma= 1$'

label_names_dict['UKP_dist_RBF_1.000000e-03_1.000000e+01'] = r'UKP_RBF, $\lambda = 10^{-3},\sigma= 10$'
label_names_dict['UKP_dist_RBF_1.000000e-02_1.000000e+01'] = r'UKP_RBF, $\lambda = 10^{-2},\sigma= 10$'
label_names_dict['UKP_dist_RBF_1.000000e-01_1.000000e+01'] = r'UKP_RBF, $\lambda = 10^{-1},\sigma= 10$'
label_names_dict['UKP_dist_RBF_1.000000e+00_1.000000e+01'] = r'UKP_RBF, $\lambda = 1,\sigma= 10$'
label_names_dict['UKP_dist_RBF_1.000000e+01_1.000000e+01'] = r'UKP_RBF, $\lambda = 10,\sigma= 10$'
label_names_dict['UKP_dist_RBF_1.000000e+02_1.000000e+01'] = r'UKP_RBF, $\lambda = 100,\sigma= 10$'

label_names_dict['UKP_dist_RBF_1.000000e-03_0.000000e+00'] = r'UKP_RBF, $\lambda = 10^{-3},\sigma= 100$'
label_names_dict['UKP_dist_RBF_1.000000e-02_0.000000e+00'] = r'UKP_RBF, $\lambda = 10^{-2},\sigma= 100$'
label_names_dict['UKP_dist_RBF_1.000000e-01_0.000000e+00'] = r'UKP_RBF, $\lambda = 10^{-1},\sigma= 100$'
label_names_dict['UKP_dist_RBF_1.000000e+00_0.000000e+00'] = r'UKP_RBF, $\lambda = 1,\sigma= 100$'
label_names_dict['UKP_dist_RBF_1.000000e+01_0.000000e+00'] = r'UKP_RBF, $\lambda = 10,\sigma= 100$'
label_names_dict['UKP_dist_RBF_1.000000e+02_0.000000e+00'] = r'UKP_RBF, $\lambda = 100,\sigma= 100$'

label_names_dict['UKP_dist_Laplace_0.001_1e-07'] = r'UKP_Lap, $\lambda = 10^{-3},\sigma= 10^{-7}$'
label_names_dict['UKP_dist_Laplace_0.01_1e-07'] = r'UKP_Lap, $\lambda = 10^{-2},\sigma= 10^{-7}$'
label_names_dict['UKP_dist_Laplace_0.09999999999999999_1e-07'] = r'UKP_Lap, $\lambda = 10^{-1},\sigma= 10^{-7}$'
label_names_dict['UKP_dist_Laplace_1.0_1e-07'] = r'UKP_Lap, $\lambda = 1,\sigma= 10^{-7}$'
label_names_dict['UKP_dist_Laplace_10.0_1e-07'] = r'UKP_Lap, $\lambda = 10,\sigma= 10^{-7}$'
label_names_dict['UKP_dist_Laplace_100.0_1e-07'] = r'UKP_Lap, $\lambda = 100,\sigma= 10^{-7}$'

label_names_dict['UKP_dist_Laplace_0.001_1e-06'] = r'UKP_Lap, $\lambda = 10^{-3},\sigma= 10^{-6}$'
label_names_dict['UKP_dist_Laplace_0.01_1e-06'] = r'UKP_Lap, $\lambda = 10^{-2},\sigma= 10^{-6}$'
label_names_dict['UKP_dist_Laplace_0.09999999999999999_1e-06'] = r'UKP_Lap, $\lambda = 10^{-1},\sigma= 10^{-6}$'
label_names_dict['UKP_dist_Laplace_1.0_1e-06'] = r'UKP_Lap, $\lambda = 1,\sigma= 10^{-6}$'
label_names_dict['UKP_dist_Laplace_10.0_1e-06'] = r'UKP_Lap, $\lambda = 10,\sigma= 10^{-6}$'
label_names_dict['UKP_dist_Laplace_100.0_1e-06'] = r'UKP_Lap, $\lambda = 100,\sigma= 10^{-6}$'

label_names_dict['UKP_dist_Laplace_0.001_1e-05'] = r'UKP_Lap, $\lambda = 10^{-3},\sigma= 10^{-5}$'
label_names_dict['UKP_dist_Laplace_0.01_1e-05'] = r'UKP_Lap, $\lambda = 10^{-2},\sigma= 10^{-5}$'
label_names_dict['UKP_dist_Laplace_0.09999999999999999_1e-05'] = r'UKP_Lap, $\lambda = 10^{-1},\sigma= 10^{-5}$'
label_names_dict['UKP_dist_Laplace_1.0_1e-05'] = r'UKP_Lap, $\lambda = 1,\sigma= 10^{-5}$'
label_names_dict['UKP_dist_Laplace_10.0_1e-05'] = r'UKP_Lap, $\lambda = 10,\sigma= 10^{-5}$'
label_names_dict['UKP_dist_Laplace_100.0_1e-05'] = r'UKP_Lap, $\lambda = 100,\sigma= 10^{-5}$'

label_names_dict['UKP_dist_Laplace_0.001_1e-04'] = r'UKP_Lap, $\lambda = 10^{-3},\sigma= 10^{-4}$'
label_names_dict['UKP_dist_Laplace_0.01_1e-04'] = r'UKP_Lap, $\lambda = 10^{-2},\sigma= 10^{-4}$'
label_names_dict['UKP_dist_Laplace_0.09999999999999999_1e-04'] = r'UKP_Lap, $\lambda = 10^{-1},\sigma= 10^{-4}$'
label_names_dict['UKP_dist_Laplace_1.0_1e-04'] = r'UKP_Lap, $\lambda = 1,\sigma= 10^{-4}$'
label_names_dict['UKP_dist_Laplace_10.0_1e-04'] = r'UKP_Lap, $\lambda = 10,\sigma= 10^{-4}$'
label_names_dict['UKP_dist_Laplace_100.0_1e-04'] = r'UKP_Lap, $\lambda = 100,\sigma= 10^{-4}$'

label_names_dict['UKP_dist_Laplace_0.001_1e-03'] = r'UKP_Lap, $\lambda = 10^{-3},\sigma= 10^{-3}$'
label_names_dict['UKP_dist_Laplace_0.01_1e-03'] = r'UKP_Lap, $\lambda = 10^{-2},\sigma= 10^{-3}$'
label_names_dict['UKP_dist_Laplace_0.09999999999999999_1e-03'] = r'UKP_Lap, $\lambda = 10^{-1},\sigma= 10^{-3}$'
label_names_dict['UKP_dist_Laplace_1.0_1e-03'] = r'UKP_Lap, $\lambda = 1,\sigma= 10^{-3}$'
label_names_dict['UKP_dist_Laplace_10.0_1e-03'] = r'UKP_Lap, $\lambda = 10,\sigma= 10^{-3}$'
label_names_dict['UKP_dist_Laplace_100.0_1e-03'] = r'UKP_Lap, $\lambda = 100,\sigma= 10^{-3}$'

label_names_dict['UKP_dist_Laplace_0.001_1e-02'] = r'UKP_Lap, $\lambda = 10^{-3},\sigma= 10^{-2}$'
label_names_dict['UKP_dist_Laplace_0.01_1e-02'] = r'UKP_Lap, $\lambda = 10^{-2},\sigma= 10^{-2}$'
label_names_dict['UKP_dist_Laplace_0.09999999999999999_1e-02'] = r'UKP_Lap, $\lambda = 10^{-1},\sigma= 10^{-2}$'
label_names_dict['UKP_dist_Laplace_1.0_1e-02'] = r'UKP_Lap, $\lambda = 1,\sigma= 10^{-2}$'
label_names_dict['UKP_dist_Laplace_10.0_1e-02'] = r'UKP_Lap, $\lambda = 10,\sigma= 10^{-2}$'
label_names_dict['UKP_dist_Laplace_100.0_1e-02'] = r'UKP_Lap, $\lambda = 100,\sigma= 10^{-2}$'

label_names_dict['UKP_dist_Laplace_0.001_1e-01'] = r'UKP_Lap, $\lambda = 10^{-3},\sigma= 10^{-1}$'
label_names_dict['UKP_dist_Laplace_0.01_1e-01'] = r'UKP_Lap, $\lambda = 10^{-2},\sigma= 10^{-1}$'
label_names_dict['UKP_dist_Laplace_0.09999999999999999_1e-01'] = r'UKP_Lap, $\lambda = 10^{-1},\sigma= 10^{-1}$'
label_names_dict['UKP_dist_Laplace_1.0_1e-01'] = r'UKP_Lap, $\lambda = 1,\sigma= 10^{-1}$'
label_names_dict['UKP_dist_Laplace_10.0_1e-01'] = r'UKP_Lap, $\lambda = 10,\sigma= 10^{-1}$'
label_names_dict['UKP_dist_Laplace_100.0_1e-01'] = r'UKP_Lap, $\lambda = 100,\sigma= 10^{-1}$'

label_names_dict['UKP_dist_Laplace_0.001_1.0'] = r'UKP_Lap, $\lambda = 10^{-3},\sigma= 1$'
label_names_dict['UKP_dist_Laplace_0.01_1.0'] = r'UKP_Lap, $\lambda = 10^{-2},\sigma= 1$'
label_names_dict['UKP_dist_Laplace_0.09999999999999999_1.0'] = r'UKP_Lap, $\lambda = 10^{-1},\sigma= 1$'
label_names_dict['UKP_dist_Laplace_1.0_1.0'] = r'UKP_Lap, $\lambda = 1,\sigma= 1$'
label_names_dict['UKP_dist_Laplace_10.0_1.0'] = r'UKP_Lap, $\lambda = 10,\sigma= 1$'
label_names_dict['UKP_dist_Laplace_100.0_1.0'] = r'UKP_Lap, $\lambda = 100,\sigma= 1$'

label_names_dict['UKP_dist_Laplace_0.001_10.0'] = r'UKP_Lap, $\lambda = 10^{-3},\sigma= 10$'
label_names_dict['UKP_dist_Laplace_0.01_10.0'] = r'UKP_Lap, $\lambda = 10^{-2},\sigma= 10$'
label_names_dict['UKP_dist_Laplace_0.09999999999999999_10.0'] = r'UKP_Lap, $\lambda = 10^{-1},\sigma= 10$'
label_names_dict['UKP_dist_Laplace_1.0_10.0'] = r'UKP_Lap, $\lambda = 1,\sigma= 10$'
label_names_dict['UKP_dist_Laplace_10.0_10.0'] = r'UKP_Lap, $\lambda = 10,\sigma= 10$'
label_names_dict['UKP_dist_Laplace_100.0_10.0'] = r'UKP_Lap, $\lambda = 100,\sigma= 10$'

label_names_dict['UKP_dist_Laplace_0.001_0.0'] = r'UKP_Lap, $\lambda = 10^{-3},\sigma= 100$'
label_names_dict['UKP_dist_Laplace_0.01_0.0'] = r'UKP_Lap, $\lambda = 10^{-2},\sigma= 100$'
label_names_dict['UKP_dist_Laplace_0.09999999999999999_0.0'] = r'UKP_Lap, $\lambda = 10^{-1},\sigma= 100$'
label_names_dict['UKP_dist_Laplace_1.0_0.0'] = r'UKP_Lap, $\lambda = 1,\sigma= 100$'
label_names_dict['UKP_dist_Laplace_10.0_0.0'] = r'UKP_Lap, $\lambda = 10,\sigma= 100$'
label_names_dict['UKP_dist_Laplace_100.0_0.0'] = r'UKP_Lap, $\lambda = 100,\sigma= 100$'

def get_collected_correlations_kertasks(lmbda, sigma, lmbda_name_err,sigma_name_err, numtrials=50, numtrainsamples=5000):
    """Spearman-correlate per-trial error vectors with every representation distance.

    For each trial ``tri`` in ``range(numtrials)`` this loads
    ``{err_folder}{lmbda_name_err}_{sigma_name_err}_tri{tri}_numtrial{numtrials}_numtrainsamples{numtrainsamples}_err_RBF_vec.npy``.
    Presumably each file holds pairwise prediction-error distances between
    models -- confirm. The loaded vector is rank-correlated
    (``scipy.stats.spearmanr``) with ``dist_array[distname]`` for every key in
    the global ``distances`` except ``'predictor_dist_range'``.

    Args:
        lmbda, sigma: unused; kept for call-site compatibility. The error
            files are selected by ``lmbda_name_err`` / ``sigma_name_err``.
        lmbda_name_err, sigma_name_err: formatted hyperparameter strings used
            in the error-file names.
        numtrials: number of trials to load.
        numtrainsamples: sample count embedded in the error-file names.

    Returns:
        ``(labels, collected_correlations)``. ``labels`` lists the distance
        names. ``collected_correlations`` has one list per trial, each aligned
        with ``labels``.
    """
    labels = [ky for ky in distances.keys() if ky != 'predictor_dist_range']

    # Indexing an NpzFile re-reads (and decompresses) the member from disk on
    # every access; read each distance vector once instead of once per trial.
    dist_vecs = [dist_array[distname] for distname in labels]

    collected_correlations = []
    for tri in range(numtrials):
        print(f'Trial {tri}')

        err_vec = np.load(f"{err_folder}{lmbda_name_err}_{sigma_name_err}_tri{tri}_numtrial{numtrials}_numtrainsamples{numtrainsamples}_err_RBF_vec.npy")

        collected_correlations.append(
            [scipy.stats.spearmanr(err_vec, dist_vec).correlation for dist_vec in dist_vecs])

    return labels, collected_correlations

lmbda_vals = [0.01, 1]
lmbda_names = ["10^{-2}", "1"]
lmbda_names_err = ["1.000000e-02","1.000000e+00"]
sigma_vals = [0.1, 1]
sigma_names = ["10^{-1}", "1"]
sigma_names_err = ["1.000000e-01","1.000000e+00"]
# NOTE(review): unlike the other Drive paths in this notebook, this one has no
# "UKP/" component -- confirm it is the intended output location.
figures_folder = "/content/drive/MyDrive/mnist_experiments/Figures/"

# Distances shown in every bar chart, the colour group of each, and the group
# colours. None of this depends on (lmbda, sigma), so it is built once.
subset_labels = ['lin_cka_dist', 'mean_sq_cca_e2e', 'GULP_dist_1.000000e-07', 'GULP_dist_1.000000e-05',
                 'GULP_dist_1.000000e-03', 'GULP_dist_1.000000e-01', 'GULP_dist_1.000000e+00',
                 'GULP_dist_1.000000e+01',
                 'UKP_dist_RBF_1.000000e-02_1.000000e-01', 'UKP_dist_RBF_1.000000e+00_1.000000e-01',
                 'UKP_dist_RBF_1.000000e-02_1.000000e+00', 'UKP_dist_RBF_1.000000e+00_1.000000e+00',
                 'CKA_dist_RBF_1.000000e-03', 'CKA_dist_RBF_1.000000e-02', 'CKA_dist_RBF_1.000000e-01',
                 'CKA_dist_RBF_1.000000e+00', 'CKA_dist_RBF_1.000000e+01', 'CKA_dist_RBF_1.000000e+02']
label_groups = ['lin_cka', 'CCA', 'GULP', 'GULP', 'GULP', 'GULP', 'GULP', 'GULP', 'UKP', 'UKP', 'UKP', 'UKP','CKA_RBF','CKA_RBF','CKA_RBF','CKA_RBF','CKA_RBF','CKA_RBF']
group_colors = {'lin_cka': 'violet', 'CCA': 'orange', 'GULP': 'green', 'UKP': 'blue', 'CKA_RBF': 'red'}
if len(label_groups) != len(subset_labels):
    # The bars are paired with their groups positionally; a length mismatch
    # would otherwise silently drop bars.
    raise ValueError(
        f"label_groups has {len(label_groups)} entries but subset_labels has {len(subset_labels)}")
subset_label_names = [label_names_dict[x] for x in subset_labels]
bar_colors = [group_colors[group] for group in label_groups]

for lmbda_ind, lmbda in enumerate(lmbda_vals):
    for sigma_ind, sigma in enumerate(sigma_vals):
        lmbda_name_err = lmbda_names_err[lmbda_ind]
        sigma_name_err = sigma_names_err[sigma_ind]
        labels, collected_correlations = get_collected_correlations_kertasks(lmbda, sigma, lmbda_name_err,
                                                                             sigma_name_err, numtrials=30,
                                                                             numtrainsamples=5000)

        # Rows are trials, columns are distances (aligned with `labels`).
        corr_table = np.asarray(collected_correlations)
        means = corr_table.mean(axis=0)
        # Error bars are the standard error of the mean across trials (the
        # name std_devs is kept from earlier versions of this code).
        std_devs = scipy.stats.sem(corr_table, axis=0)

        subset_indices = [labels.index(name) for name in subset_labels]

        fig = plt.figure()
        ax = fig.add_axes([0, 0, 1, 1])
        ax.bar(range(len(subset_labels)), means[subset_indices],
               color=bar_colors, yerr=std_devs[subset_indices])

        plt.xticks(range(len(subset_labels)), labels=subset_label_names, rotation='vertical', fontsize=15)
        plt.yticks(fontsize=15)
        plt.ylim(top=1)
        plt.title(
            f'Spearman $\\rho$ with prediction distance for $\\lambda = {lmbda_names[lmbda_ind]}$,$\\sigma = {sigma_names[sigma_ind]}$',
            fontsize=16)

        file_stem = f'generalization(includingCKARBF)_lambda{lmbda}_sigma{sigma}'
        plt.savefig(f'{figures_folder}krrgen/pdf/{file_stem}.pdf',
                    format='pdf', dpi=300, bbox_inches='tight')
        plt.savefig(f'{figures_folder}krrgen/png/{file_stem}.png',
                    format='png', dpi=300, bbox_inches='tight')
        plt.close(fig)

