In [None]:
# measuring response to test images
# author: Amir Farzmahdi
# last update: June 7th 2024
# filter-pair arrangement: 1D (locations along one axis) or 2D (grid of locations)

In [None]:
# library imports
import os
import random
import pickle
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import pyrtools as pt
import cv2
import time
import math

In [None]:
# Seed Python's `random` (used by random.sample when picking test images) and
# NumPy's legacy global RNG, so a fresh-kernel run draws the same values.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

In [None]:
# setting parameters

# images directory (uncomment the line matching the chosen image set)
test_path = '/home/images/test' # imagenet
# test_path = '/home/images/test_ephys_experiment' # bsd500
# test_path = '/home/images/nn2015_images' # NN2015


# imagenet: draw a reproducible random subset of n_test file names.
# The listing is sorted before sampling so the seeded draw does not depend on
# the order in which os.listdir happens to return entries.
n_test = 10000
imageset = 'imagenet'
test_lst = random.sample(sorted(os.listdir(test_path)), n_test)

# # bsd500
# n_test = 500
# imageset = 'bsd500'
# test_lst = os.listdir(test_path)
# test_lst.sort()

# # nn2015 images
# n_test = 270
# imageset = 'nn2015'
# test_lst = os.listdir(test_path)
# test_lst.sort()
# # test_lst = random.sample(test_lst, n_test)

# images setting
# pyramid level of each center group (0, 1, 2, 3 correspond to levels 1, 2, 3, 4)
cs_lev = [1, 1]
img_sz = 256                 # side of the square canvas images are cropped/padded to
img_half_sz = img_sz // 2
test_bkg = 113.0             # gray level used for padding and outside the aperture

# size
# [small, large] aperture sizes (pixels); half of each is used as the window radius
aperture_size = [25, 256]
aperture_half_size = (np.asarray(aperture_size) // 2).astype(int)
# image side and small-aperture radius at pyramid level cs_lev[0] (divide by 2**level)
filtered_image_size = [img_sz // 2**cs_lev[0]] * 2
filter_radius = aperture_size[0] // 2**cs_lev[0]

# filters parameters
# 1D arrangement: the first center stays at the origin while the second center
# takes n_row_col positions spaced evenly from 0 to
# max_dist_n_radius * filter_radius pixels along one axis.
sloc = 0                     # location index used when ranking images (sorting cell)
dim = 1
n_theta = 15
n_row_col = 17
max_dist_n_radius = 4 # 2
x1 = 0
y1 = np.zeros(n_row_col, dtype=int)
x2 = 0
y2 = np.linspace(0, max_dist_n_radius * filter_radius, n_row_col).astype(int)
# orientation (degrees) of the second center, one value per theta condition
ori_2 = np.linspace(84, 0, n_theta)

# # 2D
# sloc = 40
# dim = 2
# n_theta = 9
# n_row_col = 9
# max_dist_n_radius = 3 
# x1 = np.linspace(0,0,n_row_col).astype(int)
# y1 = np.linspace(0,0,n_row_col).astype(int)
# x2 = np.linspace(-max_dist_n_radius * filter_radius, max_dist_n_radius * filter_radius, n_row_col).astype(int)
# y2 = np.linspace(-max_dist_n_radius * filter_radius, max_dist_n_radius * filter_radius, n_row_col).astype(int)
# ori_2 = np.linspace(80, 0, n_theta)

ncent = 2            # number of center filters (one per orientation column of cs_ori)
nneuron = 2
nsurr = 8            # surround filters placed on a circle around each center
n_cent_surr = 1 + nsurr
nphase = 2           # real and imaginary parts of the complex steerable response
nfilt = ncent * n_cent_surr * nphase  # responses per orientation pair (= 36)

# filter setting, steerable pyramids
# orientation (degrees): the first center is fixed at 90, the second follows ori_2
ori_1 = np.linspace(90, 90, n_theta)
cs_ori = np.column_stack((ori_1, ori_2))
# nsurr equally spaced surround angles in [0, 2*pi); the endpoint duplicates 0
angs = np.linspace(0, 2*math.pi, num=n_cent_surr)
angs = angs[:-1]
# Distance between center and surround filters (pixels): the small-aperture
# radius at pyramid level cs_lev[0], minus one. Scale by 2**level, as
# filter_radius and filtered_image_size do. The previous 2*level raised
# ZeroDivisionError at level 0 and gave the wrong scale at level 3.
dist_cent_surr = int(aperture_size[0] / (2**cs_lev[0])) - 1 # Edit: 1

xv1, yv1 = np.meshgrid(x1, y1, indexing='ij')
xv2, yv2 = np.meshgrid(x2, y2, indexing='ij')

n_loc = len(xv1.flatten())
# NOTE: the first ('ij') meshgrid axis is stored as dy and the second as dx
dy = [xv1.astype('float64').flatten(), xv2.astype('float64').flatten()]
dx = [yv1.astype('float64').flatten(), yv2.astype('float64').flatten()]

# name of condition (part of the output file name)
cond_name = f'nloc_{n_loc}_ntheta_{n_theta}'

# colors for groups of filters: one RGB triple per center group, repeated for
# the center and each of its surround filters
selected_colors = [[0.9333, 0.4078, 0.6392], [0.8980, 0.7686, 0.5804]]
colors = [(rgb,) * n_cent_surr for rgb in selected_colors]

# figure settings
fig_labelsize = 8
fig_ticks_width = 1
fig_ticks_length = 0.5

In [None]:
# create windows
# Two boolean circular apertures on an aperture_size[1] x aperture_size[1] grid,
# both centered at pixel (L2, L2): Zsmall has radius L1 and Zlarge radius L2.
L1 = aperture_half_size[0]
L2 = aperture_half_size[1]
_coords = np.linspace(0, 2*L2-1, aperture_size[1])
X, Y = np.meshgrid(_coords, _coords)
_dist_from_center = np.sqrt((X - L2)**2 + (Y - L2)**2)
Zsmall = _dist_from_center < L1
Zlarge = _dist_from_center < L2
del _coords, _dist_from_center

In [None]:
# location of filters
# For each location k and each center group j, record the center position and
# then its nsurr surround positions. The surround offsets lie on a circle of
# radius dist_cent_surr, rounded to whole pixels. Each locs_list entry is
# [pyramid level, x, y]; x_pos / y_pos hold the same coordinates as flat lists.
ring_offsets = [
    (np.round(math.cos(angs[s]) * dist_cent_surr),
     np.round(math.sin(angs[s]) * dist_cent_surr))
    for s in range(nsurr)
]

locs_list, x_pos, y_pos = [], [], []
for k in range(n_loc):
    loc_entries, loc_x, loc_y = [], [], []
    for j in range(ncent):
        cx, cy = dx[j][k], dy[j][k]
        group = [(cx, cy)] + [(off_x + cx, off_y + cy) for off_x, off_y in ring_offsets]
        for px, py in group:
            loc_entries.append([cs_lev[j], px, py])
            loc_x.append(px)
            loc_y.append(py)
    locs_list.append(loc_entries)
    x_pos.append(loc_x)
    y_pos.append(loc_y)

In [None]:
# function for extracting filter outputs
def mask_conv_image(img_sz, cs_lev, cs_ori, image, locs_list, n_cent_surr, n_loc):
    """Sample steered complex steerable-pyramid responses at filter locations.

    The image is decomposed once with pyrtools' complex ``SteerablePyramidFreq``.
    For each orientation pair (a row of ``cs_ori``, in degrees), the pyramid is
    steered to those angles. Band ``n`` at level ``cs_lev[n]`` then supplies the
    response map for center group ``n``.

    Parameters
    ----------
    img_sz : int
        Unused; kept so existing callers keep working.
    cs_lev : sequence of int
        Pyramid level for each center group.
    cs_ori : 2-D array, (n_theta, n_centers)
        Orientation (degrees) of each center group for every theta condition.
    image : 2-D array
        Image to decompose.
    locs_list : list
        ``locs_list[k]`` holds ``n_centers * n_cent_surr`` entries ``[level, x, y]``,
        grouped by center. x and y are offsets from the band center.
    n_cent_surr : int
        Number of filters (center + surrounds) per center group.
    n_loc : int
        Number of locations (entries of ``locs_list``) to evaluate.

    Returns
    -------
    filt_res : list of lists
        ``filt_res[k]`` holds the responses for location ``k``. They are ordered
        by orientation pair, then center group, then phase (real, then
        imaginary), then filter within the group.
    filt_out_imgs : list
        One ``[real, imag]`` band pair per (location, orientation pair, center
        group), in that order. Entries for different locations share the same
        arrays.
    """
    cs_ori = np.asarray(cs_ori)
    # Derived from cs_ori instead of the notebook globals n_theta / ncent.
    n_orient, n_centers = cs_ori.shape

    pyr = pt.pyramids.SteerablePyramidFreq(image, is_complex=True)

    # Steering depends only on the orientation pair, not on the location, so
    # steer once per pair instead of once per (location, pair).
    steered = []
    for ii in range(n_orient):
        steered_coeffs, _ = pyr.steer_coeffs([i*np.pi/180 for i in cs_ori[ii, :]])
        pair_bands = []
        for n in range(n_centers):
            coeff = steered_coeffs[cs_lev[n], n]
            # round-trip through float16 (as before), which quantizes the responses
            pair_bands.append([np.float64(np.float16(coeff.real)),
                               np.float64(np.float16(coeff.imag))])
        steered.append(pair_bands)

    filt_res = [[] for _ in range(n_loc)]
    filt_out_imgs = []
    for k in range(n_loc):
        for ii in range(n_orient):
            for n in range(n_centers):
                bands = steered[ii][n]
                filt_out_imgs.append(bands)
                # band center along rows and columns; locs_list offsets are relative to it
                row_c = np.round(bands[0].shape[0] / 2)
                col_c = np.round(bands[0].shape[1] / 2)
                group = range(n_cent_surr * n, n_cent_surr * (n + 1))
                for band in bands:  # real phase, then imaginary phase
                    for idx in group:
                        # x and y are swapped: locs_list stores [level, x, y] while
                        # arrays are indexed [row, col]
                        row = int(locs_list[k][idx][2] + row_c)
                        col = int(locs_list[k][idx][1] + col_c)
                        filt_res[k].append(band[row, col])
    return filt_res, filt_out_imgs

In [None]:
# applying filters to test images
# For each aperture (small, then large), every test image is processed in turn:
#   1. read it and convert it to grayscale;
#   2. center-crop / pad it to img_sz x img_sz;
#   3. window it with the aperture mask around the test_bkg gray level;
#   4. divide by 255;
#   5. collect the filter responses at every location.


def _fit_to_canvas(gray, size, half_size, bkg):
    """Return `gray` center-cropped / constant-padded to exactly size x size.

    An axis shorter than `size` is padded with `bkg` on both sides; for an odd
    amount, the extra pixel goes after (bottom/right). An axis at least `size`
    long is cropped to [int(n/2 - half_size), int(n/2 + half_size)).
    """
    pad_width = []
    crop = []
    for n in gray.shape:
        if n < size:
            before = (size - n) // 2
            pad_width.append((before, size - n - before))
            crop.append(slice(None))
        else:
            pad_width.append((0, 0))
            crop.append(slice(int(n / 2 - half_size), int(n / 2 + half_size)))
    return np.pad(gray, pad_width, mode='constant', constant_values=bkg)[tuple(crop)]


test_filter_res = []
# windowed images per aperture; every image is kept in memory
test_images = [[] for _ in aperture_size]
test_imgs_dir = []
for i in range(len(aperture_size)):
    tmp_test_res = [[] for _ in range(n_loc)]
    mask = Zsmall if i == 0 else Zlarge
    for j, image_name in enumerate(test_lst):
        image_path = f"{test_path}/{image_name}"
        if i == 0:
            test_imgs_dir.append(image_path)

        start_time = time.time()
        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread returns None instead of raising for unreadable / non-image files
            raise OSError(f"could not read image file: {image_path}")
        # cv2.imread returns channels in BGR order; BGR2GRAY applies the luminance
        # weights to the right channels (RGB2GRAY swapped the R and B weights)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        img = _fit_to_canvas(img, img_sz, img_half_sz, test_bkg)

        # keep the aperture content, fill the rest with test_bkg, then scale by 1/255
        img = np.float32(img) - test_bkg
        new_img = (mask * img) + test_bkg
        new_img = new_img / 255.0

        test_images[i].append(new_img)

        tmp, band = mask_conv_image(img_sz, cs_lev, cs_ori,
                                    new_img, locs_list, n_cent_surr, n_loc)
        for k in range(n_loc):
            tmp_test_res[k].append(np.array(tmp[k]))
        print(str(j)+"--- %s seconds ---" % (time.time() - start_time))

    test_filter_res.append(tmp_test_res)

# axes: (aperture, location, image, response index)
test_filter_res = np.array(test_filter_res)

In [None]:
# sorting images based on the center 1 filter response (high to low)
# Index 0 of the last axis is the real-phase response of the first center
# filter for the first orientation pair. Axis 0 index 0 selects the small
# aperture, and `sloc` selects the location.
center1_response = test_filter_res[0, sloc, :, 0]
descending_order = np.argsort(center1_response)[::-1]
sorted_img_idx = descending_order[:n_test]

In [None]:
# save filter output results
# `date` is not defined in any cell above, so the original cell raised NameError
# on a fresh kernel. Stamp the output file with today's date instead.
date = time.strftime('%Y%m%d')
# NOTE: despite the .csv extension, this file is a binary pickle. Read it back
# with pickle.load, and only from trusted sources.
out_file = f'test_res_{imageset}_dim_{dim}_level_{cs_lev[0]}_{cond_name}_date_{date}.csv'
with open(out_file, "wb") as fp:
    pickle.dump({
                'test_filter_res': test_filter_res,
                'test_imgs_dir': test_imgs_dir,
                'filters_ori1': ori_1,
                'filters_ori2': ori_2,
                'filters_dx': dx,
                'filters_dy': dy,
                'locs_list': locs_list,
                'aperture_size': aperture_size,
                'filter_radius': filter_radius,
                'filtered_image_size': filtered_image_size,
                'dist_cent_surr': dist_cent_surr,
                'n_loc_per_row_col': n_row_col,
                'n_loc': n_loc,
                'n_theta': n_theta,
                'max_dist_n_radius': max_dist_n_radius,
                'dim': dim,
                'sorted_img_idx': sorted_img_idx},
                fp)