In [1]:
from pathlib import Path
import re
import pandas as pd

In [2]:
# Mice to analyse are listed by hand.
mice = ['bt0410', 'bt1238']
mouse_dirs = [Path(f'/adata/electro/{mouse}/') for mouse in mice]

# Session directories are the sub-directories whose path ends in "0105".
session_suffix = re.compile('0105$')
paths = [
    entry
    for mouse_dir in mouse_dirs
    for entry in mouse_dir.iterdir()
    if entry.is_dir() and session_suffix.search(str(entry))
]

# A session is treated as clustered when it contains a cluster_group.tsv file.
clustered_paths = [entry for entry in paths if Path(f"{entry}/cluster_group.tsv").exists()]

# Session names are the last component of each session directory.
sessions = [entry.name for entry in paths]
clustered_sessions = [entry.name for entry in clustered_paths]
print(f"Sessions:{sessions},\n Number of sessions:{len(sessions)}\nClustered sessions:{len(clustered_sessions)}")

Sessions:['bt0410-22092022-0105', 'bt0410-11102022-0105', 'bt0410-29092022-0105', 'bt0410-23102022-0105', 'bt0410-19102022-0105', 'bt0410-24092022-0105', 'bt0410-08102022-0105', 'bt0410-18102022-0105', 'bt0410-27092022-0105', 'bt0410-06102022-0105', 'bt0410-16102022-0105', 'bt0410-10102022-0105', 'bt0410-14102022-0105', 'bt0410-21102022-0105', 'bt0410-04102022-0105', 'bt0410-12102022-0105', 'bt0410-25102022-0105', 'bt1238-25102022-0105', 'bt1238-23102022-0105'],
 Number of sessions:19
Clustered sessions:19


In [1]:
def get_firing_properties(path, projects_dir='/home/rowena/results', overwrite=False, overwrite_all=False):
    """Compute per-neuron firing and spatial scores for one session and save them as CSV.

    The session is loaded from ``path`` (its last path component is used as the
    session name, expected to look like ``<mouse>-<ddmmyyyy>-...``). For three
    trials (first, "middle" and last, see below) the function computes, for every
    neuron, mean/max firing rate, grid score, information score, mean field size,
    grid spacing, grid orientation and grid hexagon error, plus shuffled 95th
    percentile thresholds for the grid and information scores. The result is
    written to ``{projects_dir}/summary_scores/{subject}/{name}.csv``.

    Parameters
    ----------
    path : str or pathlib.Path
        Session directory.
    projects_dir : str
        Root directory under which ``summary_scores/`` is created.
    overwrite : bool
        Recompute the scores even if the score file exists, but reuse the
        shuffled thresholds stored in the existing file.
    overwrite_all : bool
        Recompute everything, including the shuffled thresholds.

    Returns
    -------
    str
        The session name. It is returned in every case, including when the
        session is skipped (it contains a ``sqr100`` trial, the score file
        already exists, or the pose file is missing).
    """
    # When running inside IPython, keep the autoreload behaviour of the original
    # cell (load the extension and reload modules). Calling the magics through
    # get_ipython() keeps this function valid Python outside a notebook as well.
    try:
        ipython = get_ipython()
    except NameError:
        ipython = None
    if ipython is not None:
        ipython.run_line_magic("load_ext", "autoreload")
        ipython.run_line_magic("autoreload", "")

    # Local imports keep the function self-contained (it is also called from
    # worker processes and after kernel restarts).
    import os.path
    import re
    from datetime import datetime
    from pathlib import Path

    import numpy as np
    import pandas as pd
    from spikeA.Session import Kilosort_session
    from spikeA.Spike_train_loader import Spike_train_loader
    from spikeA.Cell_group import Cell_group
    from spikeA.Animal_pose import Animal_pose

    # load session
    name = str(path).split("/")[-1]
    ses = Kilosort_session(name=name, path=path)
    ses.load_parameters_from_files()
    name_parts = name.split('-')
    mouse_name = name_parts[0]
    date_name = name_parts[1]

    # exist_ok avoids a crash when several workers create the directory at once.
    os.makedirs(f"{projects_dir}/summary_scores/", exist_ok=True)

    # Sessions containing a sqr100 trial are not analysed.
    if "sqr100" in ses.desen:
        return name

    score_dir = f'{projects_dir}/summary_scores/{ses.subject}'
    score_file = f'{score_dir}/{name}.csv'
    print(score_file)

    score_file_exists = os.path.exists(score_file)
    if score_file_exists and not (overwrite or overwrite_all):
        return name

    print('entering analysis')
    stl = Spike_train_loader()
    stl.load_spike_train_kilosort(ses)
    cg = Cell_group(stl)
    ap = Animal_pose(ses)

    # Only report when no known tracking file is present; the extension itself
    # is not needed further down.
    tracking_base = f"{path}/{mouse_name}-{date_name}_01"
    known_extensions = ("positrack", "trk", "positrack2")
    if not any(os.path.exists(f"{tracking_base}.{ext}") for ext in known_extensions):
        print("Position tracking extension unknown")

    if not os.path.exists(f"{ses.fileBase}.pose.npy"):
        print(f"Pose file for {ses.name} not yet created")
        return name
    ap.load_pose_from_file()

    # get time intervals for single trials:
    # first and last trial of the environment the session starts with, and the
    # first trial of the other environment as the "middle" trial.
    sqr70_indices = [i for i, env in enumerate(ses.desen) if env == 'sqr70']
    circ80_indices = [i for i, env in enumerate(ses.desen) if env == 'circ80']
    if ses.desen[0] == 'sqr70':
        repeated_indices, other_indices = sqr70_indices, circ80_indices
    else:
        repeated_indices, other_indices = circ80_indices, sqr70_indices
    trial_inter = ses.trial_intervals.inter
    first_trial_inter = trial_inter[repeated_indices[0:1], :]
    last_trial_inter = trial_inter[repeated_indices[-1:], :]
    middle_trial_inter = trial_inter[other_indices[0:1], :]
    trials = [first_trial_inter, middle_trial_inter, last_trial_inter]
    n_trials = len(trials)

    # Same directory that score_file lives in (also correct for a relative projects_dir).
    os.makedirs(score_dir, exist_ok=True)

    n_cells = len(cg.neuron_list)

    # session properties (one value per neuron)
    session = np.repeat(name, n_cells)
    mouse = np.repeat(ses.subject, n_cells)
    date = np.repeat(ses.session_dat_time, n_cells)
    # one row per trial in ses.desen, each row repeating that trial's environment
    environment = np.array([[str(env)] * n_cells for env in ses.desen])
    environment_name = [f"environment{i+1}" for i in range(len(ses.desen))]

    # arena shape per trial, derived from the environment name
    arena = []
    for env in ses.desen:
        if re.search('sqr', str(env)):
            arena.append('square')
        elif re.search('circ', str(env)):
            arena.append('circle')
        else:
            arena.append(env)

    df_path = Path("/ext_drives/d64/data/projects/alzheimer_in_vivo/mice_table.csv")
    df = pd.read_csv(df_path)
    # ses.subject is used directly so the lookup also works when there are no neurons.
    geno = df.loc[df.loc[:, 'Mouse'] == ses.subject, 'Genotype']
    genotype = np.repeat(geno.to_string(index=False), n_cells)
    # calculate age of mouse
    session_date = datetime.strptime(date_name, "%d%m%Y")
    birthdate = df.loc[df.loc[:, 'Mouse'] == mouse_name, 'Birthdate']
    birthdate = birthdate.to_string(index=False)
    birthdate = datetime.strptime(str(birthdate), "%Y-%m-%d")
    delta = session_date - birthdate
    age_days = np.repeat(delta.days, n_cells)

    # per-cell properties that do not depend on the trial
    clu = [neuron.name for neuron in cg.neuron_list]
    cell_number = list(range(n_cells))
    refractory_period_ratio = [neuron.spike_train.refractory_period_ratio() for neuron in cg.neuron_list]

    # Column suffixes are 1, 3, 5 for the first, middle and last trial.
    trial_numbers = [2 * j + 1 for j in range(n_trials)]

    def _score_table():
        # one row per trial, one column per neuron
        return np.zeros((n_trials, n_cells))

    def _score_names(prefix):
        return [f"{prefix}{trial_number}" for trial_number in trial_numbers]

    mean_firing_rate = _score_table()
    mean_firing_rate_name = _score_names("mean_firing_rate")
    max_firing_rate = _score_table()
    max_firing_rate_name = _score_names("max_firing_rate")
    grid_score = _score_table()
    grid_score_name = _score_names("grid_score")
    info_score = _score_table()
    info_score_name = _score_names("info_score")
    mean_field_size = _score_table()
    mean_field_size_name = _score_names("mean_field_size")
    grid_spacing = _score_table()
    grid_spacing_name = _score_names("grid_spacing")
    grid_orientation = _score_table()
    grid_orientation_name = _score_names("grid_orientation")
    grid_hexagon_error = _score_table()
    grid_hexagon_error_name = _score_names("grid_hexagon_error")
    info_score_threshold_name = _score_names("info_score_threshold")
    grid_score_threshold_name = _score_names("grid_score_threshold")

    # these properties take a long time to calculate, so read them from an
    # existing score file if possible and if not overwrite_all
    # NOTE(review): this assumes the existing file has one row per current neuron — confirm.
    reuse_thresholds = score_file_exists and not overwrite_all
    if reuse_thresholds:
        scores = pd.read_csv(score_file)
        info_score_threshold = [scores.loc[:, column] for column in info_score_threshold_name]
        grid_score_threshold = [scores.loc[:, column] for column in grid_score_threshold_name]
    else:
        info_score_threshold = _score_table()
        grid_score_threshold = _score_table()

    # loop over cells to get firing properties
    for i, neuron in enumerate(cg.neuron_list):
        for j, interv in enumerate(trials):
            trial_number = trial_numbers[j]
            # trials 1, 3, 5 correspond to positions 0, 2, 4 in ses.desen / arena
            trial_arena = arena[2 * j]

            neuron.spike_train.unset_intervals()
            ap.unset_intervals()
            neuron.spike_train.set_intervals(interv)
            ap.set_intervals(interv)

            bottomleft_file = f'/adata/electro/{ses.subject}/{ses.name}/bottomleft{trial_number}.npy'
            if os.path.exists(bottomleft_file):
                bl = np.load(bottomleft_file)
                # offset from the stored bottom-left corner to the arena centre
                half_extent = 35 if trial_arena == 'square' else 40
                center = (bl[0] + half_extent, bl[1] + half_extent)
            else:
                # Fall back to the centre of the bounding box of the pose data.
                # Columns 1 and 2 of ap.pose are used as x and y here, as in the
                # nanmin calls of the original fallback — TODO confirm.
                x_min = np.nanmin(ap.pose[:, 1])
                y_min = np.nanmin(ap.pose[:, 2])
                x_range_pose = np.nanmax(ap.pose[:, 1]) - x_min
                y_range_pose = np.nanmax(ap.pose[:, 2]) - y_min
                center = (x_min + x_range_pose / 2, y_min + y_range_pose / 2)
            ap.invalid_outside_spatial_area(shape=trial_arena, length=70, radius=40, center=center)

            neuron.set_spatial_properties(ap)
            mean_firing_rate[j][i] = neuron.spike_train.mean_firing_rate()
            neuron.spike_train.instantaneous_firing_rate(bin_size_sec=20)
            max_firing_rate[j][i] = np.nanmax(neuron.spike_train.ifr[0])

            neuron.spatial_properties.firing_rate_map_2d(cm_per_bin=2, smoothing_sigma_cm=2, smoothing=True)
            neuron.spatial_properties.spatial_autocorrelation_map_2d()
            neuron.spatial_properties.calculate_doughnut()
            grid_score[j][i] = neuron.spatial_properties.grid_score()
            grid_info = neuron.spatial_properties.grid_info()
            if not grid_info:
                grid_info = np.repeat(np.nan, 3)
            grid_spacing[j][i] = grid_info[0] * neuron.spatial_properties.map_cm_per_bin  # from bins to cm
            grid_orientation[j][i] = grid_info[1]
            grid_hexagon_error[j][i] = grid_info[2]

            neuron.spatial_properties.firing_rate_map_field_detection(min_pixel_number_per_field=25, max_fraction_pixel_per_field=0.33, min_peak_rate=4, min_fraction_of_peak_rate=0.45, max_min_peak_rate=10)
            if neuron.spatial_properties.firing_rate_map_field_size:
                mean_field_size[j][i] = np.nanmean(neuron.spatial_properties.firing_rate_map_field_size)
            else:
                mean_field_size[j][i] = np.nan

            # need to recalculate without smoothing
            neuron.spatial_properties.firing_rate_map_2d(cm_per_bin=2, smoothing_sigma_cm=0, smoothing=False)
            info_score[j][i] = neuron.spatial_properties.information_score()

            if not reuse_thresholds:
                shuGS, threshold = neuron.spatial_properties.shuffle_grid_score(iterations=200, cm_per_bin=2, smoothing=True, percentile=95)
                grid_score_threshold[j][i] = threshold

                shuIS, threshold = neuron.spatial_properties.shuffle_info_score(iterations=200, cm_per_bin=2, percentile=95)
                info_score_threshold[j][i] = threshold

    # Assemble one column per session property and one column per (score, trial).
    session_data = [session, mouse, date, genotype, age_days, clu, cell_number, refractory_period_ratio]
    per_trial_tables = [environment, mean_firing_rate, max_firing_rate, grid_score, grid_score_threshold,
                        info_score, info_score_threshold, mean_field_size,
                        grid_spacing, grid_orientation, grid_hexagon_error]
    cell_data = [row for table in per_trial_tables for row in table]
    all_data = session_data + cell_data

    session_data_names = ['session', 'mouse', 'date', 'genotype', 'age_days', 'clu', 'cell_number', 'refractory_period_ratio']
    per_trial_names = [environment_name, mean_firing_rate_name, max_firing_rate_name, grid_score_name, grid_score_threshold_name,
                       info_score_name, info_score_threshold_name, mean_field_size_name,
                       grid_spacing_name, grid_orientation_name, grid_hexagon_error_name]
    cell_data_names = [column for names in per_trial_names for column in names]
    # A flat list of names gives ordinary (non-MultiIndex) columns.
    all_data_names = session_data_names + cell_data_names

    summary_table = pd.DataFrame(np.column_stack(all_data), columns=all_data_names)
    summary_table.to_csv(score_file)
    return name

In [2]:
# Process every clustered session one after another; existing score files are kept.
for path in clustered_paths:
    get_firing_properties(
        path,
        overwrite=False,
        overwrite_all=False,
    )

NameError: name 'clustered_paths' is not defined

In [None]:
import multiprocessing as mp

# Process the clustered sessions in three worker processes. pool.map blocks
# until every session is done; the context manager then shuts the workers down,
# also when one of the calls raises.
with mp.Pool(3) as pool:
    results = pool.map(get_firing_properties, clustered_paths)