In [None]:
# Preprocessing configuration for this session.
# Time-like values are presumably in seconds (min_time is converted to ms below) — TODO confirm.
PARAMS = dict(
    min_contrast=0.25,
    t_bin=0.02,
    pre_stim=0.5,
    post_stim=1.0,
    min_time=0,
    filter_regions=['VISp', 'VISpm', 'VISam', 'VISa', 'VISrl', 'VISal', 'VISli', 'VISl'],
    only_good_clusters=True,
    probabilityLeft_filter=[0.5],
    contrast_stim_filter=[0, 1],
)

import numpy as np
import pandas as pd
import re
from functions import firingRate_OnClusters, get_spikes, get_channels, get_behavior

# Unpack PARAMS into module-level settings; every lookup falls back to a default
# when the key is missing.
min_contrast = PARAMS.get('min_contrast', 0.25)  # not used further in this cell
t_bin = PARAMS.get('t_bin', 0.02)
pre_stim, post_stim = PARAMS.get('pre_stim', 0.5), PARAMS.get('post_stim', 1.0)
# min_time is compared against the firing-rate time axis, which is in milliseconds.
min_time = 1000 * PARAMS.get('min_time', 0)
filter_regions = PARAMS.get('filter_regions', None)
only_good_clusters = PARAMS.get('only_good_clusters', True)
contrast_filter = PARAMS.get('contrast_stim_filter', [0, 0.25, 1])
######################################################
# Extract trial information
######################################################
# NOTE(review): `one` and `eid` are not defined in this cell. `eid` (and `pid`,
# used further down) is only assigned in a later cell, and `one` nowhere in this
# notebook, so this cell fails on Restart & Run All unless they are defined first.
passiveGabor = pd.DataFrame(one.load_object(eid, 'passiveGabor'))

# Keep only trials whose contrast is one of the requested values.
# valid_trials holds row *positions*, hence .iloc rather than label-based .loc.
valid_trials = np.flatnonzero(np.isin(passiveGabor['contrast'], contrast_filter))
behavior = passiveGabor.iloc[valid_trials].reset_index(drop=True)

# Per-trial metadata (positions into the unfiltered passiveGabor table).
trial_indx = valid_trials
trial_onsets = behavior['start'].values
contrasts = behavior['contrast'].values
phases = behavior['phase'].values
positions = behavior['position'].values

# Labels: 1 = stimulus at +35 (right), -1 = stimulus at -35 (left),
# 0 = zero-contrast trial or any other position. First matching condition wins.
labels = np.select(
    [contrasts == 0, positions == 35, positions == -35],
    [0, 1, -1],
    default=0,
)

######################################################
# Load spikes data
######################################################

# Fetch spike-sorting output and the channel table via the local `functions` helpers.
# NOTE(review): `pid`/`eid` are only assigned in a later cell.
spike_activity = get_spikes(pid, modee='download')
channels = get_channels(eid, pid, modee='download')
clusters = spike_activity['clusters']
spikes = spike_activity['spikes']
channels_clusters = clusters['channels']  # channel assigned to each cluster
spike_times = spikes['times']
spike_clusters = spikes['clusters']

# Discard spikes whose cluster id is NaN.
kept_spikes = np.flatnonzero(~np.isnan(spike_clusters))
spike_times = spike_times[kept_spikes]
spike_clusters = spike_clusters[kept_spikes]

# Decide which cluster ids survive the quality filter.
if not only_good_clusters:
    # No quality filter: every cluster id that still has spikes.
    good_clusters = np.unique(spike_clusters)
else:
    # Rows of the metrics table whose 'ks2_label' is 'good'.
    # NOTE(review): the row position is used as the cluster id, i.e. this assumes
    # metrics rows are ordered by cluster id — confirm.
    metrics = clusters['metrics'].reset_index(drop=True)
    good_clusters = np.flatnonzero(metrics['ks2_label'] == 'good')

# Restrict to clusters recorded on channels inside the requested brain regions.
if filter_regions:
    # One anchored pattern for all regions: a region acronym immediately followed by
    # one of the digits 1, 2, 4, 5, 6 (presumably a cortical layer suffix).
    # re.escape guards against regex metacharacters in region names, and a single
    # alternation means each channel is listed at most once.
    region_pattern = re.compile(
        r'^(?:' + '|'.join(re.escape(region) for region in filter_regions) + r')[12456]'
    )
    # Set membership instead of an O(n) `in` on the ndarray for every channel.
    channels_with_clusters = set(np.unique(channels_clusters).tolist())
    index_channel = [
        i for i, acronym in enumerate(channels['acronym'])
        # Non-string acronyms (e.g. NaN) cannot match and would crash re.match.
        if isinstance(acronym, str)
        and region_pattern.match(acronym)
        and i in channels_with_clusters
    ]
    region_clusters = np.flatnonzero(np.isin(channels_clusters, index_channel))
else:
    index_channel = np.unique(channels_clusters)
    region_clusters = np.unique(spike_clusters)

# Clusters that pass both the quality and the region filter. Cast to int because the
# ids may come from a float spike_clusters array and are used as array indices below.
selected_clusters = np.intersect1d(good_clusters, region_clusters).astype(int)
keep_indices = np.flatnonzero(np.isin(spike_clusters, selected_clusters))
spike_clusters = spike_clusters[keep_indices]
spike_times = spike_times[keep_indices]
# From here on channels_clusters[k] is the channel of selected_clusters[k].
channels_clusters = channels_clusters[selected_clusters]
# Keep only channels that still carry at least one selected cluster, so the channel
# metadata built later lines up with the per-channel firing rates (sorted order kept).
index_channel = np.intersect1d(index_channel, channels_clusters)


# Bin the spikes of each cluster around every trial onset.
# Returned shapes: z_score_firing_rate (n_trials, n_clusters, n_time_bins),
# times (n_time_bins,) in milliseconds, clusters (n_clusters,).
# Note: this rebinds `clusters`, replacing the cluster table loaded above.
z_score_firing_rate, times, clusters = firingRate_OnClusters(
    trial_onsets, spike_times, spike_clusters,
    t_bin=t_bin, pre_stim=pre_stim, post_stim=post_stim,
)

# Drop time bins whose time value is below min_time (ms).
time_mask = times >= min_time
times, z_score_firing_rate = times[time_mask], z_score_firing_rate[:, :, time_mask]

# Group firing rates by recording channel: FR_channel[ch] is the
# (n_trials, n_clusters_on_ch, n_time_bins) slice for the clusters assigned to ch.
# The cluster axis is indexed by position in channels_clusters, which assumes
# firingRate_OnClusters returns one row per selected cluster in the same order.
# Fail loudly if the counts disagree instead of silently mislabelling channels.
assert z_score_firing_rate.shape[1] == len(channels_clusters), (
    f'firing rates cover {z_score_firing_rate.shape[1]} clusters but '
    f'{len(channels_clusters)} clusters have channel assignments'
)
FR_channel = {}
nan_channels = []  # NOTE(review): never filled in this cell
for ch in index_channel:
    indx_cluster = np.flatnonzero(channels_clusters == ch)
    if indx_cluster.size > 0:
        FR_channel[ch] = z_score_firing_rate[:, indx_cluster, :]


#############################
# Extract channel metadata
##############################

# One metadata row per channel in index_channel, fetched with a single .loc lookup.
# Unlike a per-channel loop, this also yields a (0, 3) coordinate array when
# index_channel is empty (the loop version crashed on coordinates[:, 0]).
channel_rows = channels.loc[list(index_channel)]
ids = np.array(channel_rows['atlas_id'].tolist())
acronyms = np.array(channel_rows['acronym'].tolist())
depths = channel_rows['axial_um'].tolist()
ch_indexs = np.array(list(index_channel))
coordinates = channel_rows[['x', 'y', 'z']].to_numpy()

# Save channel metadata and trial metadata into dataframes.
channel_info = pd.DataFrame({
    'depth': depths,
    'ids': ids,
    'acronyms': acronyms,
    'ch_indexs': ch_indexs,
    'x_coordinates': coordinates[:, 0],
    'y_coordinates': coordinates[:, 1],
    'z_coordinates': coordinates[:, 2],
})
# Columns filled with NaN are left empty in this cell.
trial_info = pd.DataFrame({
    'trial_index': trial_indx,
    'labels': labels,
    'assigned_side': np.full(len(trial_indx), np.nan),
    'contrasts': contrasts,
    'distance_to_change': np.full(len(trial_indx), np.nan),
    'prob_left': np.full(len(trial_indx), np.nan),
    'probe_id': np.repeat(pid, trial_indx.size),
    'experiment_id': np.repeat(eid, trial_indx.size),
    'phase': phases,
})

# Final preprocessed data.
pre_processed_data = {'firing_rates': FR_channel, 'trial_info': trial_info, 'channel_info': channel_info, 'time_bins': times}

In [None]:
from functions import pre_processed_active_data , pre_processed_passive_data
PARAMS = {'min_contrast': 0.25,
          't_bin': 0.02, 'pre_stim': 0.5, 'post_stim': 1.0, 'min_time': 0,
          'filter_regions': ['VISp', 'VISpm', 'VISam', 'VISa', 'VISrl', 'VISal', 'VISli', 'VISl'],
          'only_good_clusters': True,
          'probabilityLeft_filter': [0.5],
          'contrast_stim_filter': [0, 1]}

# Session to preprocess: probe id (pid) and experiment id (eid).
pid = 'a8a59fc3-a658-4db4-b5e8-09f1e4df03fd'
eid = '5ae68c54-2897-4d3a-8120-426150704385'
pre_processed_data_active = pre_processed_active_data(eid, pid, **PARAMS)
# The next cell reads pre_processed_data_passive. Leaving this call commented out
# made that cell raise NameError on Restart & Run All.
pre_processed_data_passive = pre_processed_passive_data(eid, pid, **PARAMS)

In [None]:
# Sanity-check the passive preprocessing output.
FR_channel = pre_processed_data_passive['firing_rates']
channel_info = pre_processed_data_passive['channel_info']
trial_info = pre_processed_data_passive['trial_info']

print(len(FR_channel))  # number of channels with firing rates
for key in channel_info.columns:
    print(f'{key}: {len(channel_info[key].values)}')
# Shape of one channel's firing-rate array: channel 324 when present, else the first channel.
if FR_channel:
    example_ch = 324 if 324 in FR_channel else next(iter(FR_channel))
    print(FR_channel[example_ch].shape)
for key in trial_info.columns:
    print(f'{key}: {len(trial_info[key].values)}')
# Trial counts per label.
for label_name, label_value in (('right', 1), ('left', -1), ('nostim', 0)):
    print(label_name)
    print(int((trial_info['labels'] == label_value).sum()))

In [None]:
# Sanity-check the active preprocessing output.
FR_channel = pre_processed_data_active['firing_rates']
channel_info = pre_processed_data_active['channel_info']
trial_info = pre_processed_data_active['trial_info']

print(len(FR_channel))  # number of channels with firing rates
for key in channel_info.columns:
    print(f'{key}: {len(channel_info[key].values)}')
# Shape of one channel's firing-rate array: channel 324 when present, else the first channel.
if FR_channel:
    example_ch = 324 if 324 in FR_channel else next(iter(FR_channel))
    print(FR_channel[example_ch].shape)
for key in trial_info.columns:
    print(f'{key}: {len(trial_info[key].values)}')
# Trial counts per label.
for label_name, label_value in (('right', 1), ('left', -1), ('nostim', 0)):
    print(label_name)
    print(int((trial_info['labels'] == label_value).sum()))