# Set up the environment

In [None]:
import numpy as np
import os
import json
import matplotlib.pyplot as plt
import seaborn
from sklearn.manifold import TSNE, MDS

# Define a function to create a filename


In [None]:
def getFileName(name, n_samples, model_name, layer_name):
    """Return an RDM file name ``<name>_<n_samples>__<model_name>__<layer_name>.npy``.

    Each field is wrapped in its own underscore separators, which produces the
    double underscores seen in the RDM file names used in this notebook.
    """
    return f"{name}_{n_samples}__{model_name}__{layer_name}.npy"

# Build (load and stitch) the model RDMs for plotting

In [None]:
# Use the same layers across all networks (1) or all layers (0).
same_layers = 1

# Load the model RDMs. The cross-model RDMs are assumed (from how they are
# indexed below) to have the first model in the file name as rows, e.g.
# resnext101_32x8d_resnet34 -> resnext layers x resnet34 layers. TODO confirm.
RDM_DIR = '/mnt/raid/ni/agnessa/RSA/Model_RDM/'
path_34 = RDM_DIR + 'Model_RDM_10000__resnet34__all.npy'
path_50 = RDM_DIR + 'Model_RDM_10000__resnet50__all.npy'
path_34_50 = RDM_DIR + 'Model_RDM_10000__resnet34_resnet50__all.npy'
path_34_resnext = RDM_DIR + 'Model_RDM_10000__resnext101_32x8d_resnet34__all.npy'
path_50_resnext = RDM_DIR + 'Model_RDM_10000__resnext101_32x8d_resnet50__all.npy'
path_resnext = RDM_DIR + 'Model_RDM_10000__resnext101_32x8d__all.npy'

model_rdm_34 = np.load(path_34)
model_rdm_50 = np.load(path_50)
model_rdm_34_50 = np.load(path_34_50)
model_rdm_34_resnext = np.load(path_34_resnext)
model_rdm_50_resnext = np.load(path_50_resnext)
model_rdm_resnext = np.load(path_resnext)

layer_names_resnets = ['layer1.0', 'layer1.1', 'layer1.2', 'layer2.0', 'layer2.1', 'layer2.2', 'layer2.3', 'layer3.0',
                       'layer3.1', 'layer3.2', 'layer3.3', 'layer3.4', 'layer3.5', 'layer4.0', 'layer4.1', 'layer4.2']
layer_names_resnext = ['layer1.0', 'layer1.1', 'layer1.2', 'layer2.0', 'layer2.1', 'layer2.2', 'layer2.3', 'layer3.0',
                       'layer3.1', 'layer3.2', 'layer3.3', 'layer3.4', 'layer3.5', 'layer3.6', 'layer3.7', 'layer3.8', 'layer3.9',
                       'layer3.10', 'layer3.11', 'layer3.12', 'layer3.13', 'layer3.14', 'layer3.15', 'layer3.16', 'layer3.17',
                       'layer3.18', 'layer3.19', 'layer3.20', 'layer3.21', 'layer3.22', 'layer4.0', 'layer4.1', 'layer4.2']
n_resnet = len(layer_names_resnets)  # 16 layers in both resnet34 and resnet50

if same_layers == 1:
    # Keep only the resnext layers whose names also occur in the resnets:
    # layer1.0-layer3.5 (indices 0-12) and layer4.0-layer4.2 (indices 30-32).
    c = np.concatenate((np.arange(13), np.arange(30, 33)), axis=0)
    # np.ix_ selects the full c x c sub-matrix; plain [c, c] would only pick
    # the diagonal entries (and then silently broadcast them).
    model_rdm_resnext_resized = model_rdm_resnext[np.ix_(c, c)]
    rdm_resnext = model_rdm_resnext_resized
    rdm_34_resnext = model_rdm_34_resnext[c, 0:n_resnet]
    rdm_50_resnext = model_rdm_50_resnext[c, 0:n_resnet]
    layer_names = layer_names_resnets * 3
else:
    # Use every resnext layer.
    rdm_resnext = model_rdm_resnext
    rdm_34_resnext = model_rdm_34_resnext[:, 0:n_resnet]
    rdm_50_resnext = model_rdm_50_resnext[:, 0:n_resnet]
    layer_names = layer_names_resnets * 2 + layer_names_resnext

# Stitch one symmetric RDM whose block rows/columns are resnet34, resnet50 and
# resnext (16 + 16 + 16 = 48 or 16 + 16 + 33 = 65 layers). np.block raises on
# mismatched pieces instead of broadcasting them into the wrong shape.
stitched_rdm = np.block([
    [model_rdm_34,      model_rdm_34_50, rdm_34_resnext.T],
    [model_rdm_34_50.T, model_rdm_50,    rdm_50_resnext.T],
    [rdm_34_resnext,    rdm_50_resnext,  rdm_resnext],
]).astype(float)



In [None]:
# Heatmap of the stitched RDM, with the layer names on both axes.
fig, ax = plt.subplots(figsize=(18, 15))
seaborn.heatmap(
    stitched_rdm,
    ax=ax,
    xticklabels=layer_names,
    yticklabels=layer_names,
    cmap='rainbow',
    vmin=0.0,
    vmax=1.0,
)
plt.setp(ax.get_xticklabels(), rotation=90)
plt.setp(ax.get_yticklabels(), rotation=0)
plt.show()

In [None]:
ROOT_PATH = '/mnt/raid/ni/agnessa/RSA/'
NR_OF_SAMPLES = 10000
# Put the layer selection into the output names, so that the same-layers
# (48x48) and all-layers (65x65) results do not overwrite each other.
layer_name_model_rdm = 'same_layers' if same_layers == 1 else 'all'
model_name = 'resnet34_resnet50_resnext'
# getFileName() always ends in '.npy'; swap that for '.png' instead of
# producing a '*.npy.png' file name.
plot_name = os.path.splitext(
    getFileName("Model_RDM", NR_OF_SAMPLES, model_name, layer_name_model_rdm))[0] + '.png'
path = os.path.join(ROOT_PATH, 'Model_RDM_plots', plot_name)
fig.savefig(path)

In [None]:
# Save the stitched RDM into the same Model_RDM directory the per-model RDMs
# were loaded from. Join the path components separately instead of relying
# on ROOT_PATH ending with '/'.
path_stitched = os.path.join(
    ROOT_PATH, 'Model_RDM',
    getFileName("Model_RDM_stitched", NR_OF_SAMPLES, model_name, layer_name_model_rdm))
np.save(path_stitched, stitched_rdm)

# Multidimensional scaling

In [None]:
# Embed the stitched RDM in 2-D, treating its entries as precomputed
# dissimilarities. A fixed random_state makes the embedding (and therefore the
# saved MDS figure) reproducible across runs.
mds = MDS(n_components=2, n_init=200, max_iter=2000, eps=0.0001,
          dissimilarity='precomputed', random_state=0)
data_embedded = mds.fit_transform(stitched_rdm)

In [None]:
def _layer_depth(name):
    """Map a layer name 'layerS.B' to S + B/100 (e.g. 'layer3.12' -> 3.12).

    Dividing the block index by 100 keeps e.g. layer3.1 (3.01) and layer3.10
    (3.10) distinct, which writing them as decimal literals does not.
    """
    stage, block = name[len('layer'):].split('.')
    return int(stage) + int(block) / 100


# Colour each point by its layer, so that equally named layers share a colour
# across the three networks. Use shared colour limits because each scatter call
# would otherwise normalise its own colour range.
layer_depths = np.array([_layer_depth(name) for name in layer_names])
vmin, vmax = layer_depths.min(), layer_depths.max()

# Row ranges of the stitched RDM: resnet34, resnet50, then all remaining rows
# for resnext (16 or 33 of them, depending on same_layers).
n_resnet = len(layer_names_resnets)
groups = [
    (slice(0, n_resnet), '^', 'resnet34'),
    (slice(n_resnet, 2 * n_resnet), '*', 'resnet50'),
    (slice(2 * n_resnet, None), 'o', 'resnext101_32x8d'),
]

fig, ax = plt.subplots()
for rows, marker, label in groups:
    ax.scatter(data_embedded[rows, 0], data_embedded[rows, 1], c=layer_depths[rows],
               cmap='rainbow', vmin=vmin, vmax=vmax, marker=marker, label=label)

for name, (x, y) in zip(layer_names, data_embedded):
    ax.annotate(name, (x, y))

ax.legend()
plt.show()


In [None]:
# getFileName() always ends in '.npy'; swap that for '.png' instead of
# producing a '*.npy.png' file name, and join the directory components
# separately instead of relying on ROOT_PATH ending with '/'.
mds_plot_name = os.path.splitext(
    getFileName("Model_RDM_MDS_full", NR_OF_SAMPLES, model_name, layer_name_model_rdm))[0] + '.png'
path = os.path.join(ROOT_PATH, 'Model_RDM_MDS', mds_plot_name)
fig.savefig(path)