In [None]:
# figure S8: function for measuring symmetry plane
# useful packages:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import scipy
import scipy.io as sio
from scipy.stats import pearsonr
import warnings
warnings.filterwarnings("ignore")
import matplotlib.gridspec as gridspec
import torchlens as tl
%matplotlib inline
import matplotlib.patheffects as pe

In [None]:
# settings
# Stimulus design: ncatg categories x nview viewpoints x nexemplar exemplars.
ncatg = 9
nview = 9
nexemplar = 25
num_imgs = ncatg * nview * nexemplar  # 2025 images in total
catg_names = ['car','boat','face','chair','airplane','tool','animal','fruit','flower']
assert len(catg_names) == ncatg, "need exactly one name per category"

# line colors: one RGB triplet per category (values from the ColorBrewer
# 'Paired' palette), scaled from 0-255 to the 0-1 range matplotlib expects
cmap = np.divide([[166,206,227],
        [31,120,180],
        [178,223,138],
        [51,160,44],
        [253,191,111],
        [255,127,0],
        [202,178,214],
        [251,154,153],
        [227,26,28]],255)
assert cmap.shape == (ncatg, 3), "need exactly one RGB color per category"

figname = 'figureS8'

In [None]:
# our stimulus set
# loadmat returns a dict that, besides the stored variables, may contain
# metadata entries such as '__header__', '__version__' and '__globals__'.
# Select the first real variable by name instead of by its position among the
# dict values, which silently breaks if the set of metadata keys differs.
imgs_mat = sio.loadmat('imgs.mat')
data_keys = [key for key in imgs_mat if not key.startswith('__')]
assert data_keys, "imgs.mat contains no data variables"
# NOTE(review): assumes the image array is the first variable in the file and
# that after .T the first index is the category (used as imgs[i_catg] later) — confirm
imgs = imgs_mat[data_keys[0]].T
imgs.shape

In [None]:
# Function to compute the **horizontal** (left/right mirror) symmetry score
# for a candidate vertical mirror axis placed at column y
def compute_symmetry_score(img, y):
    """Correlate the two strips on either side of a vertical mirror axis.

    The image is edge-padded on the left and right by
    ``half = img.shape[0] // 2`` columns. ``y`` is a column index in that
    *padded* frame, so ``y = half + c`` puts the mirror axis just left of
    original column ``c``. The ``half`` columns before ``y`` are compared
    with the left-right mirrored ``half`` columns starting at ``y``.

    Parameters
    ----------
    img : 2-D array
        Grayscale image (rows x columns).
    y : int
        Axis position in padded-column coordinates.

    Returns
    -------
    float
        Pearson correlation between the left strip and the mirrored right
        strip. This is typically NaN when either strip is constant.
    """
    half = img.shape[0] // 2
    padded = np.pad(img, ((0, 0), (half, half)), mode='edge')
    left_strip = padded[:, y - half:y]
    # reverse the column order so the right strip mirrors the left one
    right_strip = padded[:, y:y + half][:, ::-1]
    r, _pvalue = pearsonr(left_strip.ravel(), right_strip.ravel())
    return r

In [None]:
# measure symmetry reflection plane
# For every category / exemplar / view, score every candidate mirror axis
# across the image width and keep the best correlation and the axis index
# (0-based original column) where it occurs.
best_score = np.empty((ncatg,nexemplar,nview))
best_score_idx = np.empty((ncatg,nexemplar,nview))  # float so it can hold NaN
for i_catg in range(ncatg):
    for i_exemplar in range(nexemplar):
        # assumes image index = i_view * nexemplar + i_exemplar within a category,
        # so the nview images of one exemplar are nexemplar apart
        idx = np.arange(i_exemplar, nview * nexemplar, nexemplar)
        for i_view, img_idx in enumerate(idx):
            image = imgs[i_catg][img_idx]
            # compute_symmetry_score pads the columns by image.shape[0] // 2, so
            # padded positions offset .. offset + width - 1 correspond to every
            # original column (also correct for non-square images).
            offset = image.shape[0] // 2
            score = np.array([compute_symmetry_score(image, y)
                              for y in range(offset, offset + image.shape[1])])
            if np.all(np.isnan(score)):
                # e.g. a blank image: every strip is constant, so no axis is
                # valid (np.nanargmax would raise ValueError here)
                best_score_idx[i_catg, i_exemplar, i_view] = np.nan
                best_score[i_catg, i_exemplar, i_view] = np.nan
            else:
                best_score_idx[i_catg, i_exemplar, i_view] = np.nanargmax(score)
                best_score[i_catg, i_exemplar, i_view] = np.nanmax(score)
    print(f'category {i_catg+1}/{ncatg} ({catg_names[i_catg]}) done')

In [None]:
# remove inf values: treat every non-finite score as missing (inf -> NaN;
# NaN stays NaN) so the nan-aware average below skips it
best_score[~np.isfinite(best_score)] = np.nan
# average over exemplars (axis 1) -> one score per (category, view)
best_mscore = np.nanmean(best_score, axis=1)

In [None]:
# Per-category summary of the exemplar-averaged scores across views.
# Mean and std are both taken over the same array (best_mscore, shape
# (ncatg, nview)) along the view axis, so each category gets one mean and
# one std. Previously the std was computed on best_score over the exemplar
# axis, giving a mismatched (ncatg, nview) array.
# nan-aware so a view without any valid score does not turn the stats into NaN.
catg_mean = np.nanmean(best_mscore, 1)
catg_std = np.nanstd(best_mscore, 1)

In [None]:
# threshold for selecting symmetry plane:
# flag the views of one category whose exemplar-averaged score exceeds
# catg_mean + sd_factor * catg_std for that category
i_catg = 8        # index into catg_names ('flower')
sd_factor = 0.5   # how many std above the mean a view must lie
threshold = catg_mean[i_catg] + sd_factor * catg_std[i_catg]
best_mscore[i_catg] > threshold

In [None]:
# plot the figure
# Font and line-width settings must be in place BEFORE the figure is built:
# Text objects resolve their font family and Axes read axes.linewidth when
# they are created. Setting them after plotting only affected the next run
# of this cell.
mpl.rcParams['font.family'] = 'Arial'
plt.rcParams['font.size'] = 6
plt.rcParams['axes.linewidth'] = 0.8

fig = plt.figure(figsize=(4,4))
ncols = 3
nrows = 3
# All margins and spacing are fixed explicitly here (the negative hspace
# overlaps the half-disc panels), so this GridSpec fully controls the layout.
gs = gridspec.GridSpec(nrows, ncols, left=0.02, bottom=0.02, right=0.98, top=0.98, wspace=0.15, hspace=-.4)

# one angle per view, evenly spaced over 0-180 degrees
angles = np.linspace(0, np.pi, nview)
rticks = [0.25, 0.5, 0.75, 1]  # fewer radial ticks
# label only every other radial tick (0.5 and 1) to reduce clutter
rlabels = [f'{loc:g}' for loc in rticks]
rlabels[0] = ''
rlabels[2] = ''

for i, ax_ in enumerate(gs):
    ax = fig.add_subplot(ax_,projection='polar',zorder=2)

    # individual exemplars (thin, translucent) and their mean (outlined)
    ax.plot(angles, best_score[i].T,markerfacecolor=cmap[i],color=cmap[i], marker='o',linewidth=0.75,markersize=3, alpha=0.6, markeredgecolor='w',markeredgewidth=0.1,zorder=1)
    ax.plot(angles, best_mscore[i],linewidth=1.5,color=cmap[i],path_effects=[pe.Stroke(linewidth=2, foreground='k'), pe.Normal()],zorder=3)

    ax.set_thetamin(0)
    ax.set_thetamax(180)
    ax.set_yticks(rticks)
    ax.set_rmax(1)
    ax.set_rmin(0)
    ax.set_xticks(np.deg2rad([0, 22.5, 45, 67.5, 90, 112.5, 135, 157.5, 180]))
    ax.set_title(catg_names[i],pad = -10,
                      fontdict = {'fontsize': 8,
                         'fontweight': 'bold',
                         'color': [0,0,0],
                         'verticalalignment': 'center',
                         'horizontalalignment': 'center'})

    # label size and tick-mark length/width on EVERY panel (plt.tick_params
    # after the loop only styled the last panel)
    ax.tick_params(labelsize=6, length=1, width=0.8)

    if i == 0:
        # only the first panel carries the orientation labels (degrees)
        ax.set_xticklabels([90, 67.5, 45, 22.5, 0, -22.5, -45, -67.5, -90])
        # adjust the gap between axes and xticklabels
        for label in ax.get_xticklabels():
            label.set_horizontalalignment('center')
            label.set_verticalalignment('center')
            label.set_position((0,0.19))
    else:
        ax.set_xticklabels([])

    # radial labels (0.5 and 1) are shown on every panel
    ax.set_yticklabels(rlabels)

plt.show()