In [None]:
# plot mirror-symmetric viewpoint tuning
# useful packages:
import torch
import torchvision.transforms as transforms
from torchvision import models
from scipy import stats
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import scipy
import scipy.io as sio
from scipy.stats import pearsonr
import warnings
warnings.filterwarnings("ignore")
import pickle
from PIL import Image
import random
import matplotlib.gridspec as gridspec

In [None]:
# settings
# Stimulus-set dimensions; their product is the total number of images.
ncatg = 9
nview = 9
nexemplar = 25
num_imgs = ncatg * nview * nexemplar # number of images
img_size = 227
nchannel = 3

# line colors: one RGB row per category, rescaled from 0-255 to the 0-1 range
# NOTE(review): the last row is (227, 26, 2); ColorBrewer's "Paired" red is
# (227, 26, 28) -- confirm whether the final component is a typo.
cmap = np.array([
    [166, 206, 227],
    [31, 120, 180],
    [178, 223, 138],
    [51, 160, 44],
    [253, 191, 111],
    [255, 127, 0],
    [202, 178, 214],
    [251, 154, 153],
    [227, 26, 2],
]) / 255

# font and axes styling (mpl.rcParams and plt.rcParams are the same object)
mpl.rcParams.update({
    'font.family': 'Arial',
    'font.size': 6,
    'axes.linewidth': 0.8,
})

# Initialize the grid with 1 rows and 1 columns
ncols = 1
nrows = 1
grid = gridspec.GridSpec(
    nrows, ncols,
    left=0.05, right=0.95, bottom=0.1, top=0.91,
    wspace=0.23, hspace=0.3,
)

# figure: marker size, line width, y-axis tick positions, and the directory
# holding the per-model msvt index files
msz = 3
lw = 1
y_pos = [-0.5, 0, 0.5, 1]
model_dir = 'msvt/'

In [None]:
# load models' msvt
# For every model, read its per-layer msvt indices from `model_dir` and reduce
# them to a (ncatg, nlayer) table of category means and standard errors.
model_names = ['VGG16','VGGFace','AlexNet','EIG','HMAX','ResNet50','ConvNeXt','ViT']

# Models stored as MATLAB .mat files; every other model is stored as a pickle.
matlab_models = ('VGG16', 'VGGFace', 'AlexNet')
# Models whose pickled layer list gets the first HMAX entry prepended
# (presumably a shared input-level representation -- TODO confirm).
hmax_prefixed_models = ('ResNet50', 'ViT', 'ConvNeXt')

mean_catg_index = [[] for _ in model_names]
std_catg_index = [[] for _ in model_names]
layer_name = [[] for _ in model_names]

for i_model, name in enumerate(model_names):
    # Membership compares strings by value. The previous `name is '...'` test
    # compared object identity, which only works when the interpreter happens
    # to intern both strings and raises a SyntaxWarning on Python >= 3.8.
    if name in matlab_models:
        # matlab (loadmat appends the .mat extension when it is missing)
        data = scipy.io.loadmat(f'{model_dir}msvt_index_{name}')
        # transpose so that the first axis indexes layers
        data['msvt'] = data['msvt'].T
        nlayer = data['msvt'].shape[0]
        for i_layer in range(nlayer):
            # the 'name' cell array is laid out differently in the AlexNet file
            if name == 'AlexNet':
                layer_name[i_model].append(data['name'][i_layer][0][0])
            else:
                layer_name[i_model].append(data['name'][0][i_layer][0])
    else:
        # python
        # NOTE: pickle.load can execute arbitrary code; only load index files
        # that come from a trusted source.
        with open(f'{model_dir}msvt_index_{name}.pkl', 'rb') as f:
            data = pickle.load(f)
        if name in hmax_prefixed_models:
            # separate handle name: the HMAX file is no longer opened while
            # shadowing the handle of the model's own file
            with open(f'{model_dir}msvt_index_HMAX.pkl', 'rb') as f_hmax:
                tmp = pickle.load(f_hmax)
            data['msvt'] = [tmp['msvt'][0]] + data['msvt']
            data['name'] = [tmp['name'][0]] + data['name']

        nlayer = len(data['msvt'])
        layer_name[i_model] = data['name']

    mean_catg_index[i_model] = np.empty((ncatg, nlayer))
    std_catg_index[i_model] = np.empty((ncatg, nlayer))
    # catg msvt: each layer's values are sliced into consecutive runs of
    # `nexemplar` entries, one run per category
    for i_catg in range(ncatg):
        first = i_catg * nexemplar
        last = (i_catg + 1) * nexemplar
        for i_layer in range(nlayer):
            msvt_catg = data['msvt'][i_layer][first:last]
            mean_catg_index[i_model][i_catg, i_layer] = np.mean(msvt_catg)
            # despite the variable name this is the standard error of the mean:
            # population std (ddof=0) divided by sqrt(nexemplar)
            std_catg_index[i_model][i_catg, i_layer] = np.std(msvt_catg) / np.sqrt(nexemplar)


In [None]:
# plot the figure
# One panel per model: the mean msvt index of every category across layers,
# with error bars taken from std_catg_index. The figure is saved as a PDF.
figname = 'figureS1'

# font and axes styling, applied once before any artist is created
# (previously re-assigned on every loop iteration)
mpl.rcParams.update({
    'font.family': 'Arial',
    'font.size': 6,
    'axes.linewidth': 0.8,
})

fig = plt.figure(figsize=(7,8))
ncols = 2
# ceil division: enough rows to hold one panel per model (4 for 8 models)
nrows = -(-len(model_names) // ncols)
gs = gridspec.GridSpec(nrows, ncols, left=0.02, bottom=0.02, right=0.98, top=0.98, wspace=0.2, hspace=0.4)
catg_names = ['car','boat','face','chair','airplane','tool','animal','fruit','flower']
# panel letters 'A', 'B', ... -- one per model
labels = [chr(ord('A') + i) for i in range(len(model_names))]

# Iterate over the models (not the grid cells) so the number of panels always
# matches the data that was loaded.
for i, model_name in enumerate(model_names):
    x_pos = np.arange(0, mean_catg_index[i].shape[1])

    ax = fig.add_subplot(gs[i], zorder=2)

    # draw on the explicit Axes instead of relying on pyplot's current axes
    for i_catg in range(ncatg):
        ax.errorbar(x_pos, mean_catg_index[i][i_catg, :], yerr=std_catg_index[i][i_catg, :],
                    color=cmap[i_catg], ecolor=cmap[i_catg], marker='o', markersize=msz,
                    linewidth=lw, label=catg_names[i_catg], clip_on=False)

    # tick params (the earlier length=2 call was always overridden by this one)
    ax.tick_params(length=1, width=0.8)

    # x axis: one tick per layer, labelled with the layer name
    ax.set_xticks(x_pos)
    ax.set_xticklabels(layer_name[i], rotation=45)
    ax.set_xlim([0, len(x_pos) - 1])

    # the y label and the legend are shown on the first panel only
    if i == 0:
        ax.set_ylabel("Mirror-symmetric viewpoint tuning", labelpad=0.8)
        ax.legend(bbox_to_anchor=(0.01, 0.03, 1, 1.1), frameon=False, prop={'size': 4})

    # y axis
    ax.set_yticks(y_pos)
    ax.set_ylim((-1, 1))
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    ax.set_title(model_name, pad=0,
                 fontdict={'fontsize': 8,
                           'fontweight': 'bold',
                           'color': [0, 0, 0],
                           'verticalalignment': 'center',
                           'horizontalalignment': 'center'})

    # panel letter, placed above the top-left corner in axes coordinates
    ax.text(-0.1, 1.2, labels[i], transform=ax.transAxes, fontsize=16, fontweight='bold', va='top')

fig.tight_layout()
fig.savefig(figname+'.pdf', dpi=300, bbox_inches='tight', facecolor='w', pad_inches=0)
plt.show()