In [None]:
# figure 6B: measuring drsc (the change in noise correlation rsc between the two sizes) across different rsignal and distance thresholds
# author: Amir Farzmahdi
# last update: Jul 1st, 2024

In [None]:
# library imports
import os
import random
import pickle
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import pyrtools as pt
import time
import scipy as sp
import pandas as pd
from scipy.stats import pearsonr

In [None]:
# Seed both random number generators used by this notebook so reruns are reproducible.
random.seed(42)
np.random.seed(42)

In [None]:
# settings

# directory 
file_path = 'path to csv files'

ncov_type = 'overlapped' 
dc_offset = 2 # 0 or 2

# fixed
ntest = 500 
ntest_rsignal = 100 
tr_cov_scale = 1
noise_cov_scale_single = 1
noise_cov_scale_pair_shared = 0.01
noise_cov_scale_pair_ind = 1
alpha = 1

# filters parameters
n_loc = 17
n_theta = 15
nsize = 2

drsc_model = np.zeros((nsize, n_loc, n_theta))

In [None]:
# Log-likelihood difference of natural images under the shared vs. independent GSM.
# NOTE: despite the .csv extension the file is a pickle; only load files you trust.
with open(f'p_diff_nat_images_{n_loc}_locs_{n_theta}_oris.csv', "rb") as fp:
    res = pickle.load(fp)

p_diff = res['p_diff']
# 1 where p_diff is strictly positive, 0 otherwise (NaN entries also map to 0).
p_diff_binary = (np.asarray(p_diff) > 0).astype(int)

In [None]:
# load gsm samples
def _pair_spike_counts(gs, i_size, i_theta, i_image, dc_offset, alpha):
    """Return the spike-count vectors of the two model neurons for one image.

    Neuron 1 reads ``gs[i_size]`` and neuron 2 reads ``gs[i_size + 2]``. For each neuron
    the count is ``alpha * (relu(g[:, i_image, 0] + dc_offset) + relu(g[:, i_image, 9] + dc_offset))``,
    i.e. two phases (channels 0 and 9) are offset, half-wave rectified and summed.
    The result has one entry per row of the first axis of ``gs[...][i_theta]``
    (presumably GSM samples — TODO confirm).
    """
    def count(g):
        phase1 = g[:, i_image, 0] + dc_offset
        phase2 = g[:, i_image, 9] + dc_offset
        # non-linearity: rectifier
        return alpha * (phase1 * (phase1 > 0) + phase2 * (phase2 > 0))

    # assumes gs lists the first neuron's sizes followed by the second neuron's sizes
    # (offset 2, matching the two-size setup) — TODO confirm for nsize != 2
    return count(gs[i_size][i_theta]), count(gs[i_size + 2][i_theta])


def load_gsm_samples(file_path, model_type, noise_cov_scale_pair, nsize, n_loc, n_theta, ntest,
                     ncov_type, tr_cov_scale, noise_cov_scale_single, dc_offset, alpha,
                     ntest_rsignal=None):
    """Load GSM samples per location and compute signal and noise correlations.

    One pickle file is read per location. Its name is built from ``file_path`` (used as a
    raw prefix, so it should end with a path separator), ``model_type``, ``ncov_type`` and
    the covariance-scale tags.

    Parameters
    ----------
    ntest : int
        Number of images evaluated per (size, location, orientation).
    ntest_rsignal : int, optional
        Number of leading images used for rsignal. When omitted, the notebook-level
        ``ntest_rsignal`` setting is used if defined (the original behavior); otherwise ``ntest``.

    Returns
    -------
    rsignal : ndarray, shape (nsize, n_loc, n_theta)
        Pearson correlation, across the first ``ntest_rsignal`` images, of the two neurons'
        mean spike counts.
    rsc : ndarray, shape (nsize, n_loc, n_theta)
        Per-image Pearson correlation of the two neurons' spike counts, averaged over
        images with NaNs ignored.
    """
    if ntest_rsignal is None:
        # Backward compatibility: this value used to be read implicitly from notebook state.
        ntest_rsignal = globals().get('ntest_rsignal', ntest)

    rsignal = np.zeros((nsize, n_loc, n_theta))
    rsc_images = np.zeros((nsize, n_loc, n_theta, ntest))

    for i_loc in range(n_loc):
        file_name = f'{file_path}{model_type}_gsm_{ncov_type}/{model_type}_gsm_level_1_bsd500_ncov_{ncov_type}_tr_{tr_cov_scale}_nsdp_{noise_cov_scale_single}_nsindp_{noise_cov_scale_pair}_nloc_{i_loc}_of_{n_loc}.csv'

        # NOTE: these ".csv" files are pickles; pickle.load can execute code, so only
        # load files produced by the trusted GSM pipeline.
        with open(file_name, "rb") as fp:
            gs = pickle.load(fp)['gs']

        for i_theta in range(n_theta):
            for i_size in range(nsize):
                avg_spike_count1 = np.zeros(ntest)
                avg_spike_count2 = np.zeros(ntest)
                for i_image in range(ntest):
                    spike_count1, spike_count2 = _pair_spike_counts(gs, i_size, i_theta, i_image, dc_offset, alpha)

                    # noise correlation for this image (NaN if either count vector is constant)
                    rsc_stat, _ = pearsonr(spike_count1, spike_count2)
                    rsc_images[i_size, i_loc, i_theta, i_image] = rsc_stat

                    avg_spike_count1[i_image] = spike_count1.mean()
                    avg_spike_count2[i_image] = spike_count2.mean()

                # signal correlation across images of the per-image mean responses
                rsignal_stat, _ = pearsonr(avg_spike_count1[:ntest_rsignal], avg_spike_count2[:ntest_rsignal])
                rsignal[i_size, i_loc, i_theta] = rsignal_stat

            # progress report for the last size index
            last_size = nsize - 1
            print(f'loc{i_loc}-theta{i_theta}  mean:{np.nanmean(rsc_images[last_size, i_loc, i_theta, :])}   median:{np.nanmedian(rsc_images[last_size, i_loc, i_theta, :])}')
        print('')

    rsc = np.nanmean(rsc_images, axis=3)

    return rsignal, rsc

In [None]:
# Shared GSM: load samples with the shared-model pairwise noise scale.
model_type = 'shared'
shared_rsignal, shared_rsc = load_gsm_samples(
    file_path=file_path,
    model_type=model_type,
    noise_cov_scale_pair=noise_cov_scale_pair_shared,
    nsize=nsize,
    n_loc=n_loc,
    n_theta=n_theta,
    ntest=ntest,
    ncov_type=ncov_type,
    tr_cov_scale=tr_cov_scale,
    noise_cov_scale_single=noise_cov_scale_single,
    dc_offset=dc_offset,
    alpha=alpha,
)

In [None]:
# Independent GSM: load samples with the independent-model pairwise noise scale.
model_type = 'ind'
ind_rsignal, ind_rsc = load_gsm_samples(
    file_path=file_path,
    model_type=model_type,
    noise_cov_scale_pair=noise_cov_scale_pair_ind,
    nsize=nsize,
    n_loc=n_loc,
    n_theta=n_theta,
    ntest=ntest,
    ncov_type=ncov_type,
    tr_cov_scale=tr_cov_scale,
    noise_cov_scale_single=noise_cov_scale_single,
    dc_offset=dc_offset,
    alpha=alpha,
)

In [None]:
# drsc per model = rsc at size index 0 minus rsc at size index 1.
# Entry 0 is the shared GSM, entry 1 the independent GSM.
drsc_model = [model_rsc[0] - model_rsc[1] for model_rsc in (shared_rsc, ind_rsc)]

In [None]:
# rsignal bin boundaries: 17 evenly spaced values from 1 down to -1 (step 0.125).
# 17 boundaries define 16 intervals; find_interval_index maps a value to one of them.
vector = np.linspace(1, -1, 17)
print(vector)

In [None]:
def find_interval_index(value, boundaries):
    """
    Locate the interval of ``value`` within descending interval boundaries.

    Interval ``i`` spans ``boundaries[i] >= value > boundaries[i + 1]``. The last interval
    also includes its lower edge, so ``boundaries[-1]`` maps to ``len(boundaries) - 2``.

    Parameters:
    - value: The float value to classify.
    - boundaries: Sequence of interval edges sorted in descending order.

    Returns:
    - The interval index, or -1 when ``value`` is NaN or lies outside the boundaries.
    """
    upper, lower = boundaries[0], boundaries[-1]

    # A NaN fails both comparisons, so it is rejected here as well.
    if not (lower <= value <= upper):
        return -1

    edge_pairs = zip(boundaries[:-1], boundaries[1:])
    for idx, (top, bottom) in enumerate(edge_pairs):
        if top >= value > bottom:
            return idx

    # Only the closed lower edge of the final interval is left to match.
    return len(boundaries) - 2 if value == lower else -1

In [None]:
# Bin every (location, orientation) pair by its rsignal and average drsc / p_diff per bin.
i_size = 1  # large images

# `vector` holds len(vector) descending boundaries, i.e. len(vector) - 1 bins.
# The bin axis is sized from the boundaries rather than from n_theta. With 17 boundaries
# there are 16 bins; sizing by n_theta (15) raised IndexError for rsignal in [-1, -0.875].
n_rsignal_bins = len(vector) - 1

rsc_sorted_r_d_shared = np.zeros((n_loc, n_rsignal_bins))
rsc_sorted_r_d_shared_idx = np.zeros((n_loc, n_rsignal_bins))
rsc_sorted_r_d_ind = np.zeros((n_loc, n_rsignal_bins))
rsc_sorted_r_d_ind_idx = np.zeros((n_loc, n_rsignal_bins))
imgs_stats_sorted_r = np.zeros((n_loc, n_rsignal_bins))
imgs_stats_sorted_r_idx = np.zeros((n_loc, n_rsignal_bins))


def update_rsc_arrays(rsignal, drsc_model, rsc_array, rsc_idx_array, i_loc, i_theta, vector, verbose=False):
    """Accumulate ``drsc_model[i_loc, i_theta]`` into the rsignal bin it falls in.

    Parameters
    ----------
    rsignal : float
        Value used to choose the bin (via ``find_interval_index``).
    drsc_model : 2-D array indexed ``[i_loc, i_theta]``
        Values to accumulate.
    rsc_array, rsc_idx_array : 2-D arrays indexed ``[i_loc, bin]``
        Running sum and running count; updated in place.
    i_loc, i_theta : int
        Location and orientation indices.
    vector : sequence
        Descending bin boundaries.
    verbose : bool, default False
        Print the value and its bin index for each call.

    Returns
    -------
    int
        The bin index, or -1 when ``rsignal`` is NaN or out of range. Nothing is
        accumulated in that case; indexing with -1 would silently add into the last bin.
    """
    index = find_interval_index(rsignal, vector)
    if verbose:
        print(f'value:{rsignal}, index:{index}')
    if index < 0:
        return index
    rsc_array[i_loc, index] += drsc_model[i_loc, i_theta]
    rsc_idx_array[i_loc, index] += 1
    return index


for i_loc in range(n_loc):
    for i_theta in range(n_theta):
        shared_r = shared_rsignal[i_size, i_loc, i_theta]
        ind_r = ind_rsignal[i_size, i_loc, i_theta]

        # Update shared arrays
        update_rsc_arrays(shared_r, drsc_model[0], rsc_sorted_r_d_shared, rsc_sorted_r_d_shared_idx, i_loc, i_theta, vector)

        # Update independent arrays
        update_rsc_arrays(ind_r, drsc_model[1], rsc_sorted_r_d_ind, rsc_sorted_r_d_ind_idx, i_loc, i_theta, vector)

        # Update image statistics arrays: bin p_diff by the rsignal of the model the images
        # favour (p_diff_binary == 1 -> shared GSM, 0 -> independent GSM).
        favoured_r = shared_r if p_diff_binary[i_loc, i_theta] == 1 else ind_r
        update_rsc_arrays(favoured_r, p_diff, imgs_stats_sorted_r, imgs_stats_sorted_r_idx, i_loc, i_theta, vector)


def safe_divide(numerator, denominator):
    """Element-wise numerator / denominator; cells with a zero denominator (empty bins) are 0."""
    return np.divide(numerator, denominator, where=denominator != 0,
                     out=np.zeros_like(numerator, dtype=float))


rsc_sorted_r_d_shared_val = safe_divide(rsc_sorted_r_d_shared, rsc_sorted_r_d_shared_idx)
rsc_sorted_r_d_ind_val = safe_divide(rsc_sorted_r_d_ind, rsc_sorted_r_d_ind_idx)
imgs_stats_sorted_r_val = safe_divide(imgs_stats_sorted_r, imgs_stats_sorted_r_idx)

In [None]:
# For each (location, rsignal-bin) cell, the mean image statistic chooses which model's
# drsc is kept: > 0 -> shared GSM, < 0 -> independent GSM, exactly 0 (e.g. empty bin) -> neither.
imgs_stats_binary_shared = (imgs_stats_sorted_r_val > 0).astype(int)
imgs_stats_binary_ind = (imgs_stats_sorted_r_val < 0).astype(int)

rsc_model_shared = imgs_stats_binary_shared * rsc_sorted_r_d_shared_val
rsc_model_ind = imgs_stats_binary_ind * rsc_sorted_r_d_ind_val

# The masks are mutually exclusive, so the sum picks one model per cell; NaNs become 0.
drsc_model_selected = sum(np.nan_to_num(part, nan=0) for part in (rsc_model_shared, rsc_model_ind))

In [None]:
# Persist the selected drsc model and the rsignal-binned image statistics.
# NOTE: the payload is a pickle even though the file name ends in .csv.
with open(f'drsc_model_imgs_stats_sorted_{n_loc}_locs_{n_theta}_oris_ncov_{ncov_type}_offset_{dc_offset}.csv', "wb") as fp:
    pickle.dump(
        {
            'drsc_model': drsc_model_selected,
            'imgs_stats_sorted_by_r': imgs_stats_sorted_r_val,
            'rsc_sorted_r_d_ind_val': rsc_sorted_r_d_ind_val,
            'rsc_sorted_r_d_shared_val': rsc_sorted_r_d_shared_val,
            'imgs_stats_binary_ind': imgs_stats_binary_ind,
            'imgs_stats_binary_shared': imgs_stats_binary_shared,
        },
        fp,
    )