# Setup

In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import spearmanr
import seaborn
# Root directory of the RSA data and plots. Defaults to the original location;
# set the RSA_MAIN_PATH environment variable to use a different machine/path.
# os.path.join(..., '') guarantees a trailing separator, which code that
# concatenates sub-directory names directly onto this string relies on.
main_path = os.path.join(os.environ.get('RSA_MAIN_PATH', '/mnt/raid/ni/agnessa/RSA/'), '')


# Define the function to create filenames

In [2]:
def getFileName(n_samples,name,model_name,layer_name):
    """Return the .npy filename for a saved result.

    The layout is ``<name>_<n_samples>__<model_name>__<layer_name>.npy``.
    Note the double underscores before the model and layer fields; presumably
    files already on disk follow this exact pattern, so do not "tidy" it
    without renaming them.
    """
    suffix = "_{}__{}__{}.npy".format(n_samples, model_name, layer_name)
    return name + suffix

# Average correlation over a subset of layers for two networks (same architecture, different training tasks) - computed over the whole selected block of the RDM, diagonal included (not just the upper triangle)

In [3]:
def avg_correlation_cross_task(model_name,trained_on_1,trained_on_2,tested_on,min_layer_idx,max_layer_idx):
    """Average correlation over a block of layers in two layer-by-layer model RDMs.

    Two precomputed RDMs are loaded with ``np.load``:

    1. ``<main_path>/<trained_on_1>/Model_RDM/``
       + ``getFileName(n_samples, 'Model_RDM', model_name, 'all')``
    2. ``<main_path>/<trained_on_2>/<tested_on>/Model_RDM/``
       + ``getFileName(n_samples, 'Model_RDM_cross_task', model_name + '_' + model_name, 'all')``

    From each, the square block of rows/columns ``min_layer_idx..max_layer_idx``
    (both inclusive) is taken. Its average correlation is reported as
    ``1 - mean(block)``. The whole block is averaged, diagonal included, not
    just the upper triangle.
    NOTE(review): ``1 - mean`` assumes RDM entries are ``1 - correlation``;
    confirm against the code that produced the RDMs.

    Parameters
    ----------
    model_name : str
        Architecture name used in the filenames (e.g. 'resnet50').
    trained_on_1, trained_on_2 : str
        Training-task directories holding the first and second RDM.
    tested_on : str
        Test dataset. 'ImageNet' or '' means 10000 samples; 'Places365' means
        10220 samples. It is also a path component of the second RDM.
    min_layer_idx, max_layer_idx : int
        Inclusive, non-negative range of layer indices to average over.

    Returns
    -------
    (float, float)
        Average correlation of the first and of the second RDM block. Callers
        in this notebook treat them as within-task and across-task values.

    Raises
    ------
    ValueError
        If ``tested_on`` is not a known dataset (previously this surfaced as an
        UnboundLocalError on ``n_samples``), or if the layer range is negative
        or empty (which would otherwise silently average an empty slice).
    """
    layer_name = 'all'

    # Number of samples in the test dataset; it is part of the saved filenames.
    n_samples_by_test_set = {'ImageNet': 10000, '': 10000, 'Places365': 10220}
    if tested_on not in n_samples_by_test_set:
        raise ValueError("Unknown tested_on {!r}; expected one of {}".format(
            tested_on, sorted(n_samples_by_test_set)))
    if min_layer_idx < 0 or max_layer_idx < min_layer_idx:
        raise ValueError(
            "Invalid layer range [{}, {}]: need 0 <= min_layer_idx <= max_layer_idx".format(
                min_layer_idx, max_layer_idx))
    n_samples = n_samples_by_test_set[tested_on]

    #load the models
    model_rdm_1_filename = os.path.join(main_path, trained_on_1, 'Model_RDM',
                                        getFileName(n_samples, 'Model_RDM', model_name, layer_name))
    model_name_2 = model_name + '_' + model_name
    model_rdm_2_filename = os.path.join(main_path, trained_on_2, tested_on, 'Model_RDM',
                                        getFileName(n_samples, 'Model_RDM_cross_task', model_name_2, layer_name))
    model_rdm_1 = np.load(model_rdm_1_filename)
    model_rdm_2 = np.load(model_rdm_2_filename)

    #select only the desired layers (+1 so that max_layer_idx itself is included)
    layers = slice(min_layer_idx, max_layer_idx + 1)
    selected_model_rdm_1 = model_rdm_1[layers, layers]
    selected_model_rdm_2 = model_rdm_2[layers, layers]

    #get the correlation
    avg_correlation_1 = 1 - np.mean(selected_model_rdm_1)
    avg_correlation_2 = 1 - np.mean(selected_model_rdm_2)
    return avg_correlation_1, avg_correlation_2

# Plot the correlations (bar plot)

In [124]:
def bar_plot_correlations(**kwargs):
    """Grouped bar plot of within-task vs. across-task average correlations.

    Each keyword argument describes one model and must be a dict with:
      - 'model name': label drawn above that model's group of bars;
      - 'correlations': 2 x 2 array-like. Row 0 holds the within-task values,
        row 1 the across-task values; columns are (early layers, late layers).

    Models are drawn left to right in keyword order. Each group is shifted
    2.5 x-units from the previous one. Any number of models (>= 1) is
    supported; the early/late tick labels are repeated once per model.
    The figure is shown and then returned.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If no model is given or a 'correlations' entry is not 2 x 2.
    """
    num_bars = 2  # bars per row: early layers, late layers
    width = 0.35
    group_spacing = 2.5  # x distance between consecutive models' groups
    font_small = 25
    font_title = 30

    if not kwargs:
        raise ValueError('bar_plot_correlations needs at least one model')

    # Validate inputs up front so a bad entry fails with a clear message.
    models = []
    for key, spec in kwargs.items():
        correlations = np.asarray(spec.get('correlations'), dtype=float)
        if correlations.shape != (2, num_bars):
            raise ValueError("{}: 'correlations' must have shape (2, {}), got {}".format(
                key, num_bars, correlations.shape))
        models.append((spec.get('model name'), correlations))

    # All model-name labels sit at the height of the tallest bar overall;
    # nanmax keeps one NaN average from turning that height into NaN.
    max_height = np.nanmax([correlations for _, correlations in models])

    fig, ax = plt.subplots(figsize=(15,13))

    def autolabel(bars, model_name_label, height):
        """Write the model name above (roughly) the middle of a model's bar group."""
        first_bar = bars[0]
        ax.annotate('{}'.format(model_name_label),
                    # left edge of the group's first bar + 2.5 bar widths
                    xy=(first_bar.get_x() + first_bar.get_width()*2.5, height),
                    xytext=(0, 3),  # 3 points vertical offset
                    textcoords="offset points",
                    ha='center', va='bottom',
                    fontsize=font_small,
                    color='mediumblue')

    #loop over models to make the bar plots
    tick_positions = []
    for index, (model_label, correlations) in enumerate(models):
        x = np.arange(num_bars) + group_spacing*index
        within_task = ax.bar(x, correlations[0,:], width, label='Within task', color='purple')
        across_tasks = ax.bar(x + width, correlations[1,:], width, label='Across tasks', color='lightseagreen')
        tick_positions.extend(x + width/2)  # tick between each within/across pair
        autolabel(within_task, model_label, max_height)

    #plotting parameters
    ax.set_title('Average correlations for early and late layers in the within-task and across-task RDMs',fontsize=font_title)
    ax.set_xticks(tick_positions)
    # One 'Early layers'/'Late layers' pair per model, so any model count works.
    ax.set_xticklabels(('Early layers', 'Late layers') * len(models), fontsize=font_small)
    ax.tick_params(axis='y', labelsize=font_small)
    ax.set_ylabel('Spearman\'s coefficient',fontsize=font_small)
    # Colours are shared across models, so the last model's bar handles suffice.
    ax.legend([within_task,across_tasks],['Within-task', 'Across-task'],loc='best', fontsize=font_small)

    plt.show()
    return fig

In [None]:
#choose the datasets
trained_on_1 = 'Objects'
model_name_1 = 'resnet50'
trained_on_2 = 'Scenes'
model_name_2 = 'alexnet'
tested_on = 'ImageNet'

#get the correlations (layer ranges are inclusive and specific to each architecture)
# NOTE(review): the early/late boundaries below are hard-coded; confirm they
# match the layer ordering used when the RDMs were saved.
mod1_within_early,mod1_across_early = avg_correlation_cross_task(model_name_1,trained_on_1,trained_on_2,tested_on,0,6)
mod1_within_late,mod1_across_late = avg_correlation_cross_task(model_name_1,trained_on_1,trained_on_2,tested_on,7,15)
mod2_within_early,mod2_across_early = avg_correlation_cross_task(model_name_2,trained_on_1,trained_on_2,tested_on,0,8)
mod2_within_late,mod2_across_late = avg_correlation_cross_task(model_name_2,trained_on_1,trained_on_2,tested_on,9,20)

#plot the correlations (row 0: within task, row 1: across tasks; columns: early, late)
fig = bar_plot_correlations(
    corr_model_1={'model name': 'ResNet-50',
                  'correlations': np.array([[mod1_within_early, mod1_within_late],
                                            [mod1_across_early, mod1_across_late]])},
    corr_model_2={'model name': 'AlexNet',
                  'correlations': np.array([[mod2_within_early, mod2_within_late],
                                            [mod2_across_early, mod2_across_late]])})

#save the bar plot as PNG and SVG in <main_path>/other_plots, creating the folder if needed
# (os.path.join works whether or not main_path ends with a separator)
plot_dir = os.path.join(main_path, 'other_plots')
os.makedirs(plot_dir, exist_ok=True)
plot_stem = "{}_{}_{}_{}_{}_{}".format(
    'avg_corr_bar_plot',trained_on_1,trained_on_2,tested_on,model_name_1,model_name_2)
path_png = os.path.join(plot_dir, plot_stem + '.png')
path_svg = os.path.join(plot_dir, plot_stem + '.svg')
fig.savefig(path_png)
fig.savefig(path_svg)


In [119]:
# Leftover scratch cell: sets two small integers that no other cell shown
# here uses. The trailing string literal is a bare expression; it is only
# what the cell displays and has no other effect.
n, m = 1, 2


'_1_2_'