In [None]:
# figure2B: plot the log-likelihood difference between the shared and independent covariance-matrix models, alongside the filter arrangement
# author: Amir Farzmahdi
# last update: June 6th, 2024

In [None]:
# required packages
import os
import pickle
import math
import random

import numpy as np
import scipy as sp
import cv2

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
from matplotlib.colors import TwoSlopeNorm

In [None]:
# Seed Python's built-in `random` module so any draws from it are reproducible.
SEED = 42
random.seed(SEED)

In [None]:
# settings

# filter-grid parameters
n_row_col = 9                   # filters per side of the square spatial grid
n_loc = n_row_col * n_row_col   # number of grid locations (81)
n_theta = 9                     # presumably the number of orientations (selected_ori indexes 0..8) — TODO confirm
nsurr, ncent = 8, 2             # presumably surround / centre filter counts — TODO confirm

# Presumably pyramid level(s); filter offsets from locs_list are scaled by 2**cs_lev[0] when drawn.
cs_lev = [1, 1]
ndim = 10000                    # number of test images

# aperture geometry: [small, large] aperture sizes in pixels, and their integer halves
aperture_size = [50, 256]
aperture_half_size = np.asarray(aperture_size) // 2
background_value = 255          # grey level placed outside the image windows (white for 8-bit)

# orientation columns of p_data shown in the heatmap panels
selected_ori = [0, 4, 8]

# RGB colours for the two filter groups: index 0 = centre grid location, index 1 = all others
colors = [[0.93, 0.4, 0.64], [0.9, 0.76, 0.58]]

# Figure typography: Arial at 8 pt, applied globally to all subsequent matplotlib text.
matplotlib.rcParams.update({'font.family': 'Arial', 'font.size': 8})

In [None]:
# Both inputs are pickled mappings, despite their ".csv" extension.
# NOTE(review): pickle.load can execute arbitrary code — only load files from a trusted source.
with open('locs_list.csv', "rb") as fh:
    locs_list = pickle.load(fh)['locs_list']

with open('p_shared_ind_nat_test_images.csv', "rb") as fh:
    p_diff = pickle.load(fh)
p_data = p_diff['p_diff']

In [None]:
# Read the example natural image (OpenCV loads colour images in BGR order) and convert it to grayscale.
img = cv2.imread('284.jpg')
if img is None:
    # cv2.imread signals a missing or unreadable file by returning None instead of raising,
    # which would otherwise surface as an opaque error inside cv2.cvtColor.
    raise FileNotFoundError("could not read image file '284.jpg'")
img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

# Crop a square patch the size of the large aperture; the top-left offset is specific to this image.
crop_size = aperture_size[1]
crop_top, crop_left = 42, 122
img_cropped = img_gray[crop_top:crop_top + crop_size, crop_left:crop_left + crop_size]
if img_cropped.shape != (crop_size, crop_size):
    # numpy slicing silently truncates at the image border; the window masks below need a full patch.
    raise ValueError(f"image too small for a {crop_size}x{crop_size} crop at "
                     f"({crop_top}, {crop_left}); got {img_cropped.shape}")

# Window radii, expressed in the meshgrid units defined below (not pixels).
window_small_radius = aperture_size[0]  # Radius of the small window
window_large_radius = aperture_size[1] - 5  # Radius of the large window

# Coordinates run 0 .. 2*window_large_radius-1 across aperture_size[1] samples, so the window
# centre (window_large_radius) sits near the middle of the patch and one pixel spans ~2 grid units.
grid = np.linspace(0, 2 * window_large_radius - 1, aperture_size[1])
X, Y = np.meshgrid(grid, grid)

# Both windows share the same centre, so a single distance map serves both masks.
distances_from_center_large = np.sqrt((X - window_large_radius)**2 + (Y - window_large_radius)**2)
distances_from_center_small = distances_from_center_large

# Circular boolean masks for the small and large windows.
small_window_mask = distances_from_center_small < window_small_radius
large_window_mask = distances_from_center_large < window_large_radius

# Shift so background_value maps to 0, mask, then shift back: pixels inside a window keep
# img/255 and pixels outside it become background_value/255.
img_normalized = (np.float32(img_cropped) - background_value) / 255.0
small_img = (small_window_mask * img_normalized) + (background_value / 255.0)
large_img = (large_window_mask * img_normalized) + (background_value / 255.0)

In [None]:
# Figure 2B: the left panel shows the filter arrangement over the windowed image; each remaining
# panel is a heatmap of log(p_shared) - log(p_independent) over the filter grid for one orientation.
nrow = 1
ncol = 1 + len(selected_ori)  # image panel + one heatmap per selected orientation (4 here)
fig = plt.figure(figsize=(7, 1.5))
gs_main = gridspec.GridSpec(nrow, ncol, figure=fig, wspace=0.25, width_ratios=[1] * ncol)

# Dedicated axis for the single colorbar shared by all heatmaps (drawn by the first one).
cbar_ax = fig.add_axes([0.92, .12, .01, .75])

# Ticks at the centres of the first, middle and last heatmap cells.
_xticks = np.linspace(0, n_row_col - 1, 3) + 0.5
_yticks = np.linspace(0, n_row_col - 1, 3) + 0.5

# Symmetric colour range anchored on the most negative value; larger positive values saturate.
# NOTE(review): the hard-coded colorbar ticks at +/-0.02 below assume min(p_data) is about -0.02 — confirm.
vmin = np.min(p_data)
vmax = -np.min(p_data)

# --- left panel: filter apertures drawn over the image ---
ax = fig.add_subplot(gs_main[0, 0])

img_center = len(large_img) // 2
center_num = (n_row_col * n_row_col) // 2  # flat index of the grid centre (40 for a 9x9 grid)
for num in range(n_row_col * n_row_col):
    # colors[0] marks the centre location, colors[1] every other location.
    color_idx = 0 if num == center_num else 1
    # NOTE(review): locs_list[num][9][1] / [2] are used as the x / y offsets of this location,
    # scaled by 2**cs_lev[0]; what index 9 selects is not visible here — confirm.
    circle = plt.Circle((img_center + locs_list[num][9][1] * 2**cs_lev[0],
                         img_center + locs_list[num][9][2] * 2**cs_lev[0]),
                        aperture_half_size[0], ls='-', lw=0.4,
                        color=colors[color_idx], fill=False, alpha=1)
    ax.add_patch(circle)

ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
for spine in ax.spines.values():
    spine.set_visible(False)

ax.imshow(large_img, cmap='gray')

# --- heatmap panels: one per selected orientation ---
for i, ori in enumerate(selected_ori):
    ax = fig.add_subplot(gs_main[0, i + 1])
    is_first = i == 0

    # p_data[:, ori] must hold n_row_col**2 values so it can be laid out on the filter grid.
    sns.heatmap(np.reshape(p_data[:, ori], (n_row_col, n_row_col)), ax=ax, annot=False,
                linewidths=0.25, cmap='PRGn', center=0, vmin=vmin, vmax=vmax,
                cbar=is_first, cbar_ax=cbar_ax if is_first else None,
                cbar_kws={'label': r'$\mathregular{log(p_{shared})-log(p_{independent})}$'})

    if is_first:
        # Only the first heatmap carries axis labels and ticks.
        ax.set_xlabel('dx', fontsize=8, labelpad=1)  # could be rendered as \u0394x
        ax.set_xticks(_xticks)
        ax.set_xticklabels(['-3RF', '0', '+3RF'])

        ax.set_ylabel('dy', fontsize=8, labelpad=1)  # could be rendered as \u0394y
        ax.set_yticks(_yticks)
        ax.set_yticklabels(['+3RF', '0', '-3RF'])
    else:
        ax.set_xticks([])
        ax.set_yticks([])

    ax.tick_params(axis='both', labelsize=8, width=0.5, length=0.4, pad=2)

# Style the shared colorbar once, after the first heatmap has drawn into it.
cbar_ax.tick_params(labelsize=8, width=0.5, length=0.4, pad=2)
cbar_ax.yaxis.label.set_fontsize(8)
cbar_ax.set_yticks([-0.02, 0, 0.02])
cbar_ax.set_yticklabels(['-0.02', '0', '>0.02'])
cbar_ax.yaxis.set_label_coords(5, 0.5)

fig.savefig('figure2B.pdf', bbox_inches='tight', dpi=300)