In [None]:
# figure S11-13: measuring brain alignment
# useful packages:
import torch
import torchvision.transforms as transforms
from torchvision import models
from scipy import stats
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import scipy
import scipy.io as sio
from scipy.stats import pearsonr
import warnings
warnings.filterwarnings("ignore")
import pickle
from PIL import Image
import random
import matplotlib.gridspec as gridspec
import os
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

In [None]:
# ---- settings ---------------------------------------------------------------
nview, nexemplar = 5, 25
img_size, nchannel = 224, 3

# face patch to plot: 0 -> MLMF (figure S11), 1 -> AL (figure S12), 2 -> AM (figure S13)
roi_idx = 0
face_patch, face_patch_fig_name = (('MLMF', '11'), ('AL', '12'), ('AM', '13'))[roi_idx]
dataset = 'fiv'  # 'bfm' or 'fiv'
nsample = 10000

# line colour, RGB rescaled from 0-255 to 0-1
cmap = np.array([178, 223, 138]) / 255

# typography shared by every panel
plt.rcParams.update({'font.family': 'Arial',
                     'font.size': 6,
                     'axes.linewidth': 0.8})

# input pickles live in ./comparison_w_AL/ relative to the working directory
cdir = os.getcwd()
file_dir = f'{cdir}/comparison_w_AL/'

In [None]:
# load the neural noise-ceiling data (its 'ns_ceiling' entry sets the shaded band in the figure)
# NOTE(review): pickle.load can execute arbitrary code -- only open trusted, locally produced files
noise_ceiling_path = f'{file_dir}neural_msvt_ns_ceiling.pkl'
with open(noise_ceiling_path, 'rb') as pkl_file:
    neural_data = pickle.load(pkl_file)

In [None]:
# load models' msvt: per-layer Shapley values for every model, summarised for the selected ROI
model_names = ['AlexNet','EIG','VGG16','ResNet50','ConvNeXt','ViT']
n_models = len(model_names)

layer_name = [[] for _ in range(n_models)]
shapley_vals = [[] for _ in range(n_models)]
shapley_vals_mean = [[] for _ in range(n_models)]   # per-component mean over samples
shapley_vals_std = [[] for _ in range(n_models)]    # per-component std over samples
shapley_vals_ci = [[] for _ in range(n_models)]     # one 95% t-interval per component

shapley_vals_full_mean = [[] for _ in range(n_models)]  # mean of the component sum
shapley_vals_full_std = [[] for _ in range(n_models)]   # std of the component sum
shapley_vals_full_ci = [[] for _ in range(n_models)]    # 95% t-interval of the component sum


for i_model, name in enumerate(model_names):
    # NOTE(review): pickle.load can execute arbitrary code -- only open trusted files
    with open(f'{file_dir}comparison_neural_rdm_shapley_{name}_{dataset}.pkl', 'rb') as f:
        data = pickle.load(f)

    layer_name[i_model] = data['name']
    shapley_vals[i_model] = data['shapley_vals']

    # Values for the selected ROI. Axis 2 is the sample axis summarised below, axis 3 holds the
    # components (index 0 plotted as reflection-invariant, 1 as reflection-sensitive).
    # Presumably the layout is (layers, 1, samples, 2) -- TODO confirm against the pickle writer.
    roi_vals = np.asarray(shapley_vals[i_model][roi_idx])
    # t-interval degrees of freedom taken from the same axis stats.sem averages over, so the
    # interval stays consistent with the data instead of relying on the global `nsample`
    dof = roi_vals.shape[2] - 1

    # per-component summaries over samples
    comp_mean = np.squeeze(np.mean(roi_vals, axis=2))
    comp_sem = np.squeeze(stats.sem(roi_vals, axis=2))
    shapley_vals_mean[i_model] = comp_mean
    shapley_vals_std[i_model] = np.squeeze(np.std(roi_vals, axis=2))

    # full brain alignment = sum of the components; computed once, summarised over samples
    full_vals = np.squeeze(np.sum(roi_vals, axis=3))
    shapley_vals_full_mean[i_model] = np.mean(full_vals, axis=1)
    shapley_vals_full_std[i_model] = np.std(full_vals, axis=1)
    shapley_vals_full_ci[i_model].append(stats.t.interval(0.95, dof, loc=shapley_vals_full_mean[i_model],
                                                          scale=np.squeeze(stats.sem(full_vals, axis=1))))

    for i_components in range(2):
        shapley_vals_ci[i_model].append(stats.t.interval(0.95, dof, loc=comp_mean[:, i_components],
                                                         scale=comp_sem[:, i_components]))
    

In [None]:
def model_res_plot(ax, i, model_name, mean_data, yerr, layers, ylabel, y_pos, y_lim, x_axis, title, facecolor, lcolor, label, ncols=3):
    """Draw one per-layer line with error bars onto ``ax`` and style the panel.

    Points are placed at x = 0 .. len(mean_data) - 1, one per entry of ``mean_data``.
    All drawing goes through ``ax`` (never pyplot's "current axes"), so the function
    is correct for any axes passed in.

    Parameters
    ----------
    ax : matplotlib Axes to draw into.
    i : flat index of the panel in the figure grid; panels with ``i % ncols == 0``
        (the left column) receive the y-axis label.
    model_name : panel title text, used only when ``title`` is truthy.
    mean_data : values to plot, one per layer.
    yerr : error-bar sizes, passed unchanged to ``Axes.errorbar``.
    layers : x tick labels (rotated 90 degrees), shown only when ``x_axis`` is truthy.
    ylabel : y-axis label text.
    y_pos : y tick positions.
    y_lim : (lower, upper) y-axis limits.
    x_axis : whether to show ``layers`` as x tick labels; otherwise ticks are unlabeled.
    title : whether to set ``model_name`` as the panel title.
    facecolor : marker face colour.
    lcolor : line colour.
    label : legend label for the plotted line.
    ncols : number of columns in the panel grid (default 3), used to find left-column panels.
    """
    x_pos = np.arange(len(mean_data))

    eplot = ax.errorbar(x_pos, mean_data, yerr=yerr, markerfacecolor=facecolor, capsize=None, color=lcolor,
                        ecolor=[0,0,0], markeredgecolor=[0,0,0], marker='o', markersize=3.5, linewidth=1.5, ls='-',
                        clip_on=False, markeredgewidth=0.1, zorder=10, label=label)

    # eplot[2] holds the error-bar line collections; unclip them so bars that extend
    # beyond the y-limits are still drawn in full
    for bar_lines in eplot[2]:
        bar_lines.set_clip_on(False)

    # set tick params
    ax.tick_params(length=2, width=0.8)

    # set x axis
    ax.set_xticks(x_pos)
    if x_axis:
        ax.set_xticklabels(layers, rotation=90)
    else:
        ax.set_xticklabels([])

    # only the left-most panel of each grid row carries a y-label
    if i % ncols == 0:
        ax.set_ylabel(ylabel, labelpad=0.8)

    ax.set_xlim([0, len(x_pos) - 1])
    ax.set_yticks(y_pos)
    ax.set_ylim((y_lim[0], y_lim[1]))
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    if title:
        ax.set_title(model_name, pad=0,
                     fontdict={'fontsize': 8,
                               'fontweight': 'bold',
                               'color': [0,0,0],
                               'verticalalignment': 'center',
                               'horizontalalignment': 'center'})

In [None]:
# plot the figure
# Layout (5 x 3 grid): rows 0 and 3 show per-layer brain alignment for three models each,
# rows 1 and 4 show the reflection-sensitive / reflection-invariant Shapley components for the
# same models, and row 2 is a thin empty spacer separating the two model groups.
msz = 3
lw = 1

figname = f'figureS{face_patch_fig_name}'
fig = plt.figure(figsize=(7,6))
ncols = 3
nrows = 5

y_pos = [0, 0.5, 1]
y_lim = [-0.2, 1]

# grey band from the ROI's mean noise ceiling up to r = 1
start_point = np.mean(neural_data['ns_ceiling'][roi_idx])
end_point = 1

# grid row -> (panel type, index of the first model shown in that row); unlisted rows are spacers
row_layout = {0: ('alignment', 0), 1: ('shapley', 0), 3: ('alignment', ncols), 4: ('shapley', ncols)}
# same units in both model groups
alignment_ylabel = f'Similarity to {face_patch} (r)'
shapley_ylabel = 'Shapley values (r)'
sensitive_color = [0.4940, 0.1840, 0.5560]
invariant_color = [0.9290, 0.6940, 0.1250]

gs = gridspec.GridSpec(nrows, ncols, left=0.02, bottom=0.02, right=0.98, top=0.98, wspace=0.35, hspace=0.2, height_ratios=[1,1,0.3,1,1])
for i, ax_ in enumerate(gs):
    row, col = divmod(i, ncols)

    if row not in row_layout:
        # invisible spacer panel between the two model groups
        ax_spacer = fig.add_subplot(ax_, zorder=2)
        ax_spacer.axis('off')
        continue

    ax = fig.add_subplot(ax_)
    panel, first_model = row_layout[row]
    i_model = first_model + col
    model_name = model_names[i_model]
    layers = layer_name[i_model]

    if panel == 'alignment':
        ylabel = alignment_ylabel
        mean_data = shapley_vals_full_mean[i_model]
        std_data = shapley_vals_full_std[i_model]
        model_res_plot(ax, i, model_name, mean_data, std_data, layers, ylabel, y_pos, y_lim, x_axis=False, title=True,
                       facecolor=[0,0,0], lcolor=[0,0,0], label='brain alignment')
        ax.axhspan(start_point, end_point, color=[0.8,0.8,0.8], xmax=1.0, alpha=0.5, linewidth=0, clip_on=False)
    else:
        ylabel = shapley_ylabel
        mean_data = shapley_vals_mean[i_model]
        std_data = shapley_vals_std[i_model]
        # column 1 -> reflection-sensitive, column 0 -> reflection-invariant; the second call
        # draws the layer names on the x axis
        model_res_plot(ax, i, model_name, mean_data[:,1], std_data[:,1], layers, ylabel, y_pos, y_lim, x_axis=False, title=False,
                       facecolor=sensitive_color, lcolor=sensitive_color, label='reflection-sensitive')
        model_res_plot(ax, i, model_name, mean_data[:,0], std_data[:,0], layers, ylabel, y_pos, y_lim, x_axis=True, title=False,
                       facecolor=invariant_color, lcolor=invariant_color, label='reflection-invariant')

        # a single legend, on the first Shapley panel
        if row == 1 and col == 0:
            ax.legend(loc='upper left', frameon=False)
        ax.axhline(0, color=[0.5, 0.5, 0.5], linestyle='--', linewidth=1)

    # every left-column panel (both model groups) gets a y-label
    if col == 0:
        ax.set_ylabel(ylabel, labelpad=0.8)

    # tick params (fonts/line widths are configured once in the settings cell)
    ax.tick_params(length=0.8, width=0.8)

plt.tight_layout()
plt.savefig(figname+'.pdf', dpi=300, bbox_inches='tight', facecolor='w', pad_inches=0)
plt.show()