In [None]:
# plot mirror-symmetric viewpoint tuning
# useful packages:
import torch
import torchvision.transforms as transforms
from torchvision import models
from scipy import stats
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import scipy
import scipy.io as sio
from scipy.stats import pearsonr
import warnings
warnings.filterwarnings("ignore")
import pickle
from PIL import Image
import random
import matplotlib.gridspec as gridspec
import seaborn as sns
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

In [None]:
# settings
nview, nexemplar = 5, 25
img_size, nchannel = 224, 3

# line colors: 0-255 RGB triplet rescaled to the 0-1 range
cmap = np.array([178, 223, 138]) / 255

# font and axes styling (mpl.rcParams and plt.rcParams are the same object)
mpl.rcParams.update({
    'font.family': 'Arial',
    'font.size': 6,
    'axes.linewidth': 0.8,
})

# dataset tag and directory used to build the pickle paths in the loading cells
dataset = 'fiv'
file_dir = 'comparison_w_AL/'

In [None]:
# Read the pickled neural results; later cells use its 'corrected_msvt',
# 'msvt' and 'name' entries.
# NOTE(review): pickle.load can execute arbitrary code — only open trusted files.
neural_path = f'{file_dir}neural_msvt_ns_ceiling.pkl'
with open(neural_path, 'rb') as neural_file:
    neural_data = pickle.load(neural_file)

In [None]:
# load models' msvt
model_names = ['AlexNet','EIG','VGG16','ResNet50','ConvNeXt','ViT']

# one entry per model, appended in model_names order
layer_name, mean_index, std_index = [], [], []

for name in model_names:
    model_path = f'{file_dir}comparison_neural_rdm_shapley_{name}_{dataset}.pkl'
    # NOTE(review): pickle.load can execute arbitrary code — only open trusted files.
    with open(model_path, 'rb') as f:
        data = pickle.load(f)

    msvt = data['msvt']
    layer_name.append(data['name'])

    # both statistics are reduced along axis 1 of the stored 'msvt' entry
    mean_index.append(np.mean(msvt, axis=1))
    # standard deviation, not SEM (no division by sqrt(nexemplar))
    std_index.append(np.std(msvt, axis=1))

In [None]:
# Summary statistics of the neural MSVT index, each reduced along axis 1.
# NOTE(review): the median is computed from the 'corrected_msvt' entry while the
# std is computed from the 'msvt' entry — confirm that this mismatch is intended.
corrected_msvt = neural_data['corrected_msvt']
msvt_values = neural_data['msvt']

median_neural_msvt_index = np.median(corrected_msvt, axis=1)
std_neural_msvt_index = np.std(msvt_values, axis=1)

In [None]:
# mean_neural_msvt_index / np.mean(neural_data['ns_ceiling'])

In [None]:
# plot the figure
# One panel per model: the per-layer mean MSVT index with std error bars, plus
# horizontal reference lines at the median neural MSVT index (first three rows
# of median_neural_msvt_index). The result is saved as '<figname>.pdf'.
msz = 3  # NOTE(review): not used below — the markers are drawn with markersize=4
lw = 1   # line width passed to errorbar

# font and axes styling; loop-invariant, so set once before any artist is created
mpl.rcParams['font.family'] = 'Arial'
plt.rcParams['font.size'] = 6
plt.rcParams['axes.linewidth'] = 0.8

# (legend label, color) for rows 0, 1, 2 of median_neural_msvt_index;
# the same color is reused for the reference line and its text annotation
neural_refs = [
    ('MLMF', [0.4980, 0.7882, 0.4980]),
    ('AL', [0.7451, 0.6824, 0.8314]),
    ('AM', [0.9922, 0.7529, 0.5255]),
]
y_pos = [-1, -0.5, 0, 0.5, 1]

figname = 'figureS10'
fig = plt.figure(figsize=(7,3))
ncols = 3
# enough rows to hold every model (ceil division; 2 rows when there are 6 models)
nrows = max(1, -(-len(model_names) // ncols))
gs = gridspec.GridSpec(nrows, ncols, left=0.02, bottom=0.02, right=0.98, top=0.98,
                       wspace=0.25, hspace=0.8, height_ratios=[1] * nrows)

# iterate over the models (not over the grid cells) so that a model list shorter
# than the grid does not index past the end of mean_index / layer_name
for i, model_name in enumerate(model_names):

    ax = fig.add_subplot(gs[i], zorder=2)

    x_pos = np.arange(0, len(mean_index[i]))

    # draw on the explicit axes rather than on pyplot's implicit current axes
    eplot = ax.errorbar(x_pos, mean_index[i], std_index[i],
                        markerfacecolor=[0.5, 0.5, 0.5], capsize=None,
                        ecolor=[0, 0, 0], markeredgecolor=[0, 0, 0], marker='o',
                        markersize=4, linewidth=lw, ls='', clip_on=False,
                        markeredgewidth=0.5)

    # eplot[2] holds the error-bar line collections; turn clipping off on them too
    for b in eplot[2]:
        b.set_clip_on(False)

    # neural reference lines; the labels are only shown if the legend below is re-enabled
    for k, (label, color) in enumerate(neural_refs):
        ax.axhline(y=median_neural_msvt_index[k], xmin=0, xmax=1.01, color=color,
                   alpha=0.35, linewidth=2, clip_on=False, label=label)

    # set x axis: one tick per layer, labelled with the layer name
    ax.set_xticks(x_pos)
    ax.set_xticklabels(layer_name[i],
                       rotation=90)

    # only the first panel carries the y label and the neural-area annotations
    if i == 0:
        ax.set_ylabel("Mirror-symmetric\nviewpoint tuning index", labelpad=0.8)
        for k, (_, color) in enumerate(neural_refs):
            ax.text(-0.3, median_neural_msvt_index[k] - 0.05, neural_data['name'][k],
                    color=color)

        # ax.legend(frameon=False, bbox_to_anchor=(-0.2, 0.57, 0.5, 0.5))

    ax.set_xlim([0 - 0.5, len(x_pos) - 1])
    ax.set_yticks(y_pos)
    ax.set_ylim((-1, 1))
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    ax.set_title(model_name, pad=0,
                 fontdict={'fontsize': 8,
                           'fontweight': 'bold',
                           'color': [0, 0, 0],
                           'verticalalignment': 'center',
                           'horizontalalignment': 'center'})
    # dashed zero line
    ax.hlines(y=0, xmin=x_pos[0] - 0.5, xmax=x_pos[-1], color=[0.7, 0.7, 0.7], ls='--', lw=0.5)

    # tick params (a single call; an earlier length=2 setting was always overridden by this one)
    ax.tick_params(length=1, width=0.8)

plt.tight_layout()
plt.savefig(figname+'.pdf', dpi=300, bbox_inches='tight', facecolor='w', pad_inches=0)
plt.show()