In [None]:
# Measure correlation matrices of steerable-pyramid filter responses to training images and to white-noise images
# author: Amir Farzmahdi
# last update: June 11th, 2024

In [None]:
# library imports
import os
import random
import pickle
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import pyrtools as pt
import cv2
import time
import seaborn as sns
import math
from scipy import ndimage
from datetime import datetime

In [None]:
# Seed both random number generators used in this notebook -- NumPy's legacy
# global generator (white-noise images) and Python's `random` (the sample of
# training images) -- so a fresh Restart & Run All is reproducible.
np.random.seed(42)
random.seed(42)

In [None]:
# date info: today's date as YYYY_MM_DD (used in the output file name)
date = datetime.now().strftime("%Y_%m_%d")
print(f"Date: {date}")

In [None]:
# setting parameters

# directory (set the TRAIN_PATH environment variable to override the default)
train_path = os.environ.get('TRAIN_PATH', '/home/images/val')

# images setting
cs_lev = [1, 1] # Edit: [1, 1] # 0, 1, 2, 3: level 1, 2, 3, 4 (one pyramid level per center group)
n_train = 2500 # number of train images; each one is used at every angle in `rotate`
n_noise = 10000 # number of noise images
rotate = [0,45,-45,-90] # rotation angles (degrees) applied to every train image
img_sz = 256 
img_half_sz = img_sz // 2
train_bkg = 113.0 # background gray level (0-255) used for padding and outside the aperture

# white noise
nmean = 0
nstd = 1
noise_bkg = train_bkg / 255.0 # same background on the 0-1 scale of the noise images

# train set: sort before sampling so the seeded random.sample is reproducible
train_lst = os.listdir(train_path)
train_lst.sort()
train_lst = random.sample(train_lst, n_train)

# size
aperture_size = [25, 256] # Edit: 25 or 50  ([small, large] aperture diameter in pixels)
aperture_half_size = np.divide(aperture_size,2).astype(int)
filtered_image_size = [int(img_sz/(2**cs_lev[0])), int(img_sz/(2**cs_lev[0]))] # img_size / 2^pyr_level
filter_radius = int(aperture_size[0]/(2**cs_lev[0]))

# filters parameters
# # 1D
# sloc = 0
# dim = 1
# n_theta = 15
# n_row_col = 17
# max_dist_n_radius = 4 # Edit: 2
# x1 = 0
# y1 = np.linspace(0,0,n_row_col).astype(int)
# x2 = 0
# y2 = np.linspace(0, max_dist_n_radius * filter_radius, n_row_col).astype(int)
# ori_2 = np.linspace(84, 0, n_theta)

# 2D
sloc = 40
dim = 2
n_theta = 9
n_row_col = 9
max_dist_n_radius = 3 
x1 = np.linspace(0,0,n_row_col).astype(int)
y1 = np.linspace(0,0,n_row_col).astype(int)
x2 = np.linspace(-max_dist_n_radius * filter_radius, max_dist_n_radius * filter_radius, n_row_col).astype(int)
y2 = np.linspace(-max_dist_n_radius * filter_radius, max_dist_n_radius * filter_radius, n_row_col).astype(int)
ori_2 = np.linspace(80, 0, n_theta)

ncent = 2       # number of center groups (one per entry of cs_lev)
nneuron = 2
nsurr = 8       # surround positions around each center
n_cent_surr = 1 + nsurr
nphase = 2      # real and imaginary phase of each complex band
# responses per (location, orientation): centers x phases x (center + surrounds) = 36
nfilt = ncent * nphase * n_cent_surr

# filter setting, steerable pyramids
# orientation
ori_1 = np.linspace(90,90,n_theta) 
cs_ori = np.column_stack((ori_1,ori_2)) # row i: orientations (degrees) of the center groups in condition i
angs = np.linspace(0, 2*math.pi, num=n_cent_surr)
angs = angs[:-1] # nsurr equally spaced angles; 2*pi is dropped since it duplicates 0
# distance between center and surround filter (pixels of the filtered band).
# Scaled by 2**level, consistent with filter_radius above; the former 2*level
# only agreed for levels 1 and 2 and divided by zero at level 0.
dist_cent_surr = int(aperture_size[0]/(2**cs_lev[0])) - 1 # Edit: 1

xv1, yv1 = np.meshgrid(x1, y1, indexing='ij')
xv2, yv2 = np.meshgrid(x2, y2, indexing='ij')

n_loc = len(xv1.flatten())
# note: dy is built from the x grids and dx from the y grids
dy = [xv1.astype('float64').flatten(),xv2.astype('float64').flatten()]
dx = [yv1.astype('float64').flatten(),yv2.astype('float64').flatten()]

# name of condition
cond_name = f'nloc_{n_loc}_ntheta_{n_theta}'

# colors for groups of filters
selected_colors = [[0.9333, 0.4078, 0.6392],[0.8980, 0.7686, 0.5804]]
colors = []
for i in range(2):
    colors.append((selected_colors[i], ) * n_cent_surr)
    
# figure settings
fig_labelsize = 8
fig_ticks_width = 1 
fig_ticks_length = 0.5

In [None]:
# Circular binary apertures on an aperture_size[1] x aperture_size[1] pixel grid
# centered at (L2, L2): Zsmall is True within radius L1 (small aperture) and
# Zlarge within radius L2 (large aperture).
L1, L2 = aperture_half_size[0], aperture_half_size[1]
grid_coords = np.linspace(0, 2*L2 - 1, aperture_size[1])
X, Y = np.meshgrid(grid_coords, grid_coords)
dist_from_center = np.sqrt((X - L2)**2 + (Y - L2)**2)
Zsmall = dist_from_center < L1
Zlarge = dist_from_center < L2

In [None]:
# location of filters
# For every location k and center group j, locs_list[k] receives the center
# entry [level, x, y] followed by nsurr surround entries offset by rounded
# points on a circle of radius dist_cent_surr. x_pos / y_pos mirror the x / y
# coordinates of those entries.
surround_offsets = [(np.round(math.cos(angs[s]) * dist_cent_surr),
                     np.round(math.sin(angs[s]) * dist_cent_surr))
                    for s in range(nsurr)]

locs_list = []
x_pos = []
y_pos = []
for k in range(n_loc):
    loc_entries, xs, ys = [], [], []
    for j in range(ncent):
        cx, cy = dx[j][k], dy[j][k]
        loc_entries.append([cs_lev[j], cx, cy])
        xs.append(cx)
        ys.append(cy)
        for off_x, off_y in surround_offsets:
            loc_entries.append([cs_lev[j], off_x + cx, off_y + cy])
            xs.append(off_x + cx)
            ys.append(off_y + cy)
    locs_list.append(loc_entries)
    x_pos.append(xs)
    y_pos.append(ys)

In [None]:
# function for extracting filter outputs
def image_convolve_mask(img_sz, cs_lev, cs_ori, image, locs_list, n_cent_surr, n_loc,
                        n_theta=None, ncent=None):
    """Sample steered complex steerable-pyramid responses of `image` at filter locations.

    Parameters
    ----------
    img_sz : int
        Unused; kept so existing callers keep working.
    cs_lev : sequence of int
        Pyramid scale index used for each center group.
    cs_ori : ndarray
        Orientations in degrees; row ``ii`` gives the angle each center group
        is steered to in condition ``ii``.
    image : 2-D ndarray
        Image passed to ``pt.pyramids.SteerablePyramidFreq(..., is_complex=True)``.
    locs_list : list
        ``locs_list[k]`` holds ``ncent * n_cent_surr`` entries ``[level, x, y]``;
        x/y are offsets (in band pixels) from the band center.
    n_cent_surr : int
        Number of sample points (center + surrounds) per center group.
    n_loc : int
        Number of location configurations in ``locs_list``.
    n_theta : int, optional
        Number of orientation conditions; defaults to ``cs_ori.shape[0]``.
    ncent : int, optional
        Number of center groups; defaults to ``len(cs_lev)``.

    Returns
    -------
    filt_res : list of list
        ``filt_res[k]`` is a flat list of ``n_theta * ncent * 2 * n_cent_surr``
        samples ordered by (orientation, center group, phase [real, imag], point).
    filt_out_imgs : list
        ``[real, imag]`` band pair for every (k, orientation, center group), in
        loop order. Pairs for the same orientation/group are shared objects.
    """
    if n_theta is None:
        n_theta = cs_ori.shape[0]
    if ncent is None:
        ncent = len(cs_lev)

    pyr = pt.pyramids.SteerablePyramidFreq(image, is_complex=True)

    # Steering depends only on the orientation row, never on the location k, so
    # steer once per orientation instead of once per (location, orientation).
    bands_per_theta = []
    for ii in range(n_theta):
        steered_coeffs, _ = pyr.steer_coeffs([a * np.pi / 180 for a in cs_ori[ii, :]])
        theta_bands = []
        for n in range(ncent):
            tmp = steered_coeffs[cs_lev[n], n]
            # Values are round-tripped through float16 before sampling
            # (presumably a deliberate precision reduction -- TODO confirm).
            theta_bands.append([np.float64(np.float16(tmp.real)),
                                np.float64(np.float16(tmp.imag))])
        bands_per_theta.append(theta_bands)

    filt_res = [[] for _ in range(n_loc)]
    filt_out_imgs = []
    for k in range(n_loc):
        for theta_bands in bands_per_theta:
            for n, bands in enumerate(theta_bands):
                filt_out_imgs.append(bands)
                center_ = np.round(bands[0].shape[0] / 2)
                first = n_cent_surr * n  # this center group's slice of locs_list[k]
                for band in bands:  # 2 phases: real, then imaginary
                    for i in range(first, first + n_cent_surr):
                        # x and y axis are swapped due to the difference between
                        # array [row, col] indexing and the (x, y) list points
                        row = int(locs_list[k][i][2] + center_)
                        col = int(locs_list[k][i][1] + center_)
                        filt_res[k].append(band[row, col])
    return filt_res, filt_out_imgs

In [None]:
# Apply filters to training images.
# Each image is converted to grayscale, center-cropped / padded with train_bkg to
# img_sz x img_sz, rotated by every angle in `rotate`, windowed by the circular
# aperture Zlarge (train_bkg outside it), scaled to [0, 1] and sampled with
# image_convolve_mask.


def _fit_to_square(gray, size, fill):
    """Center-crop axes longer than `size` and constant-pad shorter ones with `fill`.

    Crops start at (length - size) // 2; for an odd size difference the extra
    padded pixel goes after the image (bottom / right).
    """
    crops, pads = [], []
    for length in gray.shape[:2]:
        if length < size:
            before = (size - length) // 2
            crops.append(slice(None))
            pads.append((before, size - length - before))
        else:
            start = (length - size) // 2
            crops.append(slice(start, start + size))
            pads.append((0, 0))
    return np.pad(gray[tuple(crops)], pads, mode='constant', constant_values=fill)


assert Zlarge.shape == (img_sz, img_sz), "aperture grid must match the image size"

train_filter_res = [[] for _ in range(n_loc)]
train_orig_images = []
train_imgs_dir = []
for idx, image_name in enumerate(train_lst):
    img_path = f"{train_path}/{image_name}"
    train_imgs_dir.append(img_path)
    start_time = time.time()
    img = cv2.imread(img_path)
    if img is None:  # imread returns None instead of raising on unreadable files
        raise OSError(f"could not read image: {img_path}")
    # imread returns channels in BGR order, so BGR2GRAY (not RGB2GRAY) applies the
    # luminance weights to the correct channels.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    img = _fit_to_square(img, img_sz, train_bkg)

    # Rotate in float relative to the background: the interpolated result is not
    # re-quantized to uint8, and pixels rotated in from outside the image are
    # filled with the background level (cval=0 here) instead of black.
    img_centered = np.float32(img) - train_bkg
    for angle in rotate:
        img_ = ndimage.rotate(img_centered, angle, reshape=False, cval=0.0)
        new_img = (Zlarge * img_) + train_bkg
        new_img = new_img / 255.0

        if idx == 0:
            train_orig_images.append(new_img)

        tmp, image = image_convolve_mask(img_sz, cs_lev, cs_ori, new_img,
                                         locs_list, n_cent_surr, n_loc)
        for k in range(n_loc):
            train_filter_res[k].append(np.array(tmp[k]))
    print(str(idx)+"--- %s seconds ---" % (time.time() - start_time))
# shape: (n_loc, n_train * len(rotate), responses per image)
train_filter_res = np.array(train_filter_res)

In [None]:
def create_white_noise_image(img_size, mean=0, std=1, rng=None):
    """Return a square Gaussian white-noise image min-max scaled to [0, 1].

    Parameters
    ----------
    img_size : int
        Side length of the square image.
    mean, std : float
        Parameters of the normal distribution the pixels are drawn from. After
        min-max normalization the output is (up to rounding) independent of
        ``mean`` and of any positive ``std``.
    rng : numpy.random.Generator, optional
        Generator to draw from. When None (default) the global ``np.random``
        state is used, so results follow ``np.random.seed``.

    Returns
    -------
    ndarray of float64, shape (img_size, img_size)
        Values span exactly [0, 1]. If every sample is equal (e.g. ``std=0`` or
        a 1x1 image) an all-zero image is returned instead of NaNs from 0/0.
    """
    sampler = np.random if rng is None else rng
    # Generate random noise with a normal distribution
    noise = sampler.normal(loc=mean, scale=std, size=(img_size, img_size))

    # Normalize the noise to 0-1 range
    lo = noise.min()
    span = noise.max() - lo
    if span == 0:
        return np.zeros_like(noise)
    return (noise - lo) / span

In [None]:
# random white noise images: n_noise independent img_sz x img_sz images,
# min-max scaled to [0, 1] by create_white_noise_image and stored as float32.
nimg = [np.float32(create_white_noise_image(img_sz, mean=nmean, std=nstd))
        for _ in range(n_noise)]

In [None]:
# Filter every white-noise image with the same aperture and filter bank as the
# training images; results go in noise_filter_res[k] for each location k.
noise_filter_res = [[] for _ in range(n_loc)]
noise_orig_images = []

for i_noise in range(n_noise):
    t_start = time.time()
    # Inside the large aperture keep the noise; outside it, fill with noise_bkg.
    new_nimg = Zlarge * (nimg[i_noise] - noise_bkg) + noise_bkg
    noise_orig_images.append(new_nimg)

    tmp, image = image_convolve_mask(img_sz, cs_lev, cs_ori, new_nimg,
                                     locs_list, n_cent_surr, n_loc)
    for loc_responses, loc_store in zip(tmp, noise_filter_res):
        loc_store.append(np.array(loc_responses))

    print(f"{i_noise}--- {time.time() - t_start} seconds ---")

# shape: (n_loc, n_noise, responses per image)
noise_filter_res = np.array(noise_filter_res)


In [None]:
# compute *correlation* matrices
# For each location k and orientation condition i, correlate (across samples)
# the nfilt responses belonging to condition i; each result is nfilt x nfilt.
assert train_filter_res.shape[2] == n_theta * nfilt, "unexpected response layout (train)"
assert noise_filter_res.shape[2] == n_theta * nfilt, "unexpected response layout (noise)"

train_corr_mat = [[] for _ in range(n_loc)]
noise_corr_mat = [[] for _ in range(n_loc)]
for k in range(n_loc):
    for i in range(n_theta):
        cols = np.arange(i * nfilt, (i + 1) * nfilt)
        # Indexing [k, :, cols] mixes advanced indices separated by a slice, so the
        # `cols` axis comes first: rows are filters (variables), columns are samples,
        # which is the layout np.corrcoef expects.
        train_corr_mat[k].append(np.corrcoef(train_filter_res[k, :, cols]))
        noise_corr_mat[k].append(np.corrcoef(noise_filter_res[k, :, cols]))

# shape: (n_loc, n_theta, nfilt, nfilt)
train_corr_mat = np.asarray(train_corr_mat)
noise_corr_mat = np.asarray(noise_corr_mat)

In [None]:
# save the correlation matrices together with the filter configuration.
# The original referenced undefined names (train_cov_mat / noise_cov_mat); the
# matrices are the train_corr_mat / noise_corr_mat computed in the previous cell.
# The '*_cov_mat' keys are kept unchanged for compatibility with readers of this
# file even though they hold correlation (np.corrcoef) matrices.
# NOTE(review): the file is a pickle despite the `.csv` extension; the name is left
# as-is in case downstream code loads it by this exact name.
out_file = f'cov_mat_dim_{dim}_level_{cs_lev[0]}_{cond_name}_date_{date}.csv'
with open(out_file, "wb") as fp:
    pickle.dump({
                'train_cov_mat': train_corr_mat,
                'noise_cov_mat': noise_corr_mat,
                'train_imgs_dir': train_imgs_dir,
                'filters_ori1': ori_1,
                'filters_ori2': ori_2,
                'filters_dx': dx,
                'filters_dy': dy,
                'locs_list': locs_list,
                'aperture_size': aperture_size,
                'filter_radius': filter_radius,
                'filtered_image_size': filtered_image_size,
                'dist_cent_surr': dist_cent_surr,
                'n_loc_per_row_col': n_row_col,
                'n_loc': n_loc,
                'n_theta': n_theta,
                'max_dist_n_radius': max_dist_n_radius,
                'dim':dim},
                fp)