In [1]:
import os

import matplotlib.pyplot as plt
import neuroseries as nts
import numpy as np
import pandas as pd
from scipy.stats import zscore
from tqdm import tqdm

import bk.compute
import bk.load

In [2]:
def firing_rates_intervals(neurons,intervals,nbins):
    """Average binned firing rate of the population across a set of intervals.

    Each interval is divided into ``nbins`` equal-width bins. The output of
    ``bk.compute.binSpikes`` (presumably spike counts — TODO confirm) is
    divided by the bin width in seconds to give rates. The per-interval
    matrices are then averaged element-wise.

    Parameters
    ----------
    neurons : passed unchanged to ``bk.compute.binSpikes``
        Presumably one spike train per neuron — TODO confirm against bk.
    intervals : nts.IntervalSet
        Intervals to average over. It is read via ``as_units('s')``, and each
        row is unpacked as ``(start, end)``.
    nbins : int
        Number of bins each interval is split into.

    Returns
    -------
    numpy.ndarray
        Element-wise mean over intervals of the rate matrices returned by
        ``binSpikes`` (presumably neurons x nbins — TODO confirm).

    Raises
    ------
    ValueError
        If ``intervals`` is empty, so there is nothing to average.
    """
    intervals_s = intervals.as_units('s')
    if len(intervals_s) == 0:
        # np.mean([]) would silently return a scalar NaN, which later makes
        # np.vstack in merging() fail with an unrelated shape error.
        raise ValueError('firing_rates_intervals: no intervals to average over')

    intervals_activity_matrix = []
    # Iterate explicit (start, end) rows rather than the ``.iloc()`` indexer object.
    for s, e in tqdm(intervals_s.to_numpy(), total=len(intervals_s)):
        bin_size = (e - s) / nbins  # bin width in seconds
        _, binned = bk.compute.binSpikes(neurons, start=s, stop=e, nbins=nbins)
        intervals_activity_matrix.append(binned / bin_size)
    return np.mean(intervals_activity_matrix, 0)

def mean_firing_rates_intervals(neurons,intervals,name = None):
    """Mean firing rate of each neuron over ``intervals``.

    The rate is the number of elements returned by ``n.restrict(intervals)``
    divided by ``intervals.tot_length(time_units='s')``.

    Parameters
    ----------
    neurons : iterable
        Objects exposing ``restrict(intervals)``. There is one output row per
        item, in iteration order.
    intervals : nts.IntervalSet
        Epochs to count spikes in. Must expose ``tot_length(time_units='s')``.
    name : str, optional
        Column name of the result. Defaults to ``'FR'``.

    Returns
    -------
    pandas.DataFrame
        A single-column frame with a default RangeIndex and one row per neuron.
    """
    if name is None:
        name = 'FR'
    # The total duration does not depend on the neuron, so compute it once.
    total_s = intervals.tot_length(time_units = 's')
    firing_rates = [len(n.restrict(intervals)) / total_s for n in neurons]
    return pd.DataFrame(firing_rates, columns=[name])

In [3]:
def main(base_folder, local_path, *args, **kwargs):
    """Per-session worker: state-wise binned activity plus sleep firing rate.

    Loads the session selected by ``base_folder``/``local_path`` through
    ``bk.load``. It drops the 'wake' and 'drowsy' states and computes
    ``firing_rates_intervals`` for every remaining state. It then appends each
    neuron's mean firing rate over REM ∪ SWS to the metadata as ``'FR_sleep'``.

    Parameters
    ----------
    base_folder, local_path :
        Forwarded to ``bk.load.current_session_linux``.
    *args :
        Unused. Presumably accepted for the ``bk.load.batch`` calling
        convention — TODO confirm.
    **kwargs :
        Must contain ``'nbins'``, a mapping from every state name kept below to
        its number of bins.

    Returns
    -------
    tuple (dict, pandas.DataFrame)
        ``{state: activity matrix}`` and the metadata with a ``'FR_sleep'`` column.
    """
    bk.load.current_session_linux(base_folder=base_folder,local_path=local_path)
    neurons,metadata = bk.load.spikes()
    states = bk.load.states()
    # Sleep epochs are built before any state is dropped.
    sleep = states['Rem'].union(states['sws'])

    # Excluded from the per-state activity. pop() without a default raises
    # KeyError if a session lacks either state.
    for s in ['wake','drowsy']:
        states.pop(s)

    all_average_firing = {
        k: firing_rates_intervals(neurons, state, kwargs['nbins'][k])
        for k, state in states.items()
    }

    fr = mean_firing_rates_intervals(neurons,sleep,'FR_sleep')
    # Pass axis by keyword: positional arguments to pd.concat (other than objs)
    # are deprecated and rejected by newer pandas.
    # NOTE(review): concat aligns on index. ``fr`` has a RangeIndex, so this
    # assumes ``metadata`` is indexed 0..n-1 in neuron order — confirm.
    metadata = pd.concat((metadata,fr),axis=1)
    return all_average_firing,metadata

def merging(batch_output):
    """Stack the per-session outputs of ``main`` into population-level arrays.

    Parameters
    ----------
    batch_output : dict
        Maps a session name to ``(activity_by_state, metadata)``, as returned
        by ``main``.

    Returns
    -------
    tuple (dict, pandas.DataFrame)
        The first element is ``{state: np.vstack of that state's matrices
        across sessions}``. The second is the row-wise concatenation of every
        session's metadata, with the original session indices kept.
        Both are filled in the same session order, so rows stay aligned as
        long as every session reports every state.

    Raises
    ------
    ValueError
        Raised by ``np.vstack`` when a state received no data, e.g. when
        ``batch_output`` is empty.
    """
    # The three default states are always present in the output. Any extra
    # state reported by a session is added instead of raising KeyError.
    states_activity = {'Rem': [], 'sws': [], 'wake_homecage': []}
    metadata_frames = []
    for b in batch_output.values():
        session_activity, session_metadata = b[0], b[1]
        metadata_frames.append(session_metadata)
        for state, state_activity in session_activity.items():
            states_activity.setdefault(state, []).append(state_activity)

    # A single concat, instead of growing the frame inside the loop (quadratic).
    metadata = pd.concat(metadata_frames) if metadata_frames else pd.DataFrame()
    stacked = {state: np.vstack(chunks) for state, chunks in states_activity.items()}
    return stacked, metadata

def plot(activity,metadata,norm = True,order = None,quintile = False,ax = None,ylim = (-1,1)):
    """Heat-map of per-neuron activity plus its population-average trace(s).

    Parameters
    ----------
    activity : numpy.ndarray
        2-D array with one row per neuron. Row i must correspond to
        ``metadata.iloc[i]``.
    metadata : pandas.DataFrame
        Needs a ``FR_sleep`` column when ``order='fr'`` or ``quintile=True``.
    norm : bool
        If True, z-score each row (``scipy.stats.zscore`` along axis 1) first.
    order : {None, 'fr'}
        ``'fr'`` sorts rows by ascending ``FR_sleep``, with NaN rates last.
    quintile : bool
        If True, draw one mean trace per ``FR_sleep`` quintile with a legend.
        Otherwise, draw a single dashed population mean.
    ax : sequence of two matplotlib Axes, optional
        ``ax[0]`` gets the heat-map and ``ax[1]`` the trace(s). A new 1x2
        figure is created when omitted.
    ylim : tuple
        y-limits applied to ``ax[1]``.

    Returns
    -------
    The ``ax`` sequence that was drawn on.
    """
    if norm:
        activity = zscore(activity,1)
    if order == 'fr':
        # Sort the raw numpy values. pandas' Series.argsort returns -1 for NaN,
        # and -1 silently selects the last row, which scrambles the ordering
        # whenever a neuron has an undefined firing rate.
        order = np.argsort(metadata.FR_sleep.to_numpy())
        activity = activity[order,:]
        metadata = metadata.iloc[order]

    if ax is None:
        _, ax = plt.subplots(1,2)
    ax[0].imshow(activity,aspect = 'auto',interpolation = 'None')
    if quintile:
        labels = ['Very Low','Low','Medium','High','Very High']
        labels_neurons = pd.qcut(metadata.FR_sleep,5,labels = labels)
        for label in labels:
            ax[1].plot(np.nanmean(activity[labels_neurons == label,:],0))
        # Build the legend once, after all five lines exist.
        ax[1].legend(labels)
    else:
        ax[1].plot(np.nanmean(activity,0),'k--')
    ax[1].set_ylim(*ylim)
    return ax


In [4]:
# Number of time bins per state; main() looks up nbins for each state it keeps.
kwargs = {'nbins': dict.fromkeys(('Rem', 'sws', 'wake_homecage'), 60)}
batch_output_60 = bk.load.batch(main, **kwargs)
states_activity, metadata = merging(batch_output_60)

100%|██████████| 61/61 [04:08<00:00,  4.07s/it]

Batch finished in 248.15912222862244
Some session were not processed correctly
['Rat08-20130720', 'Rat08-20130722', 'Rat09-20140407', 'Rat09-20140408', 'Rat09-20140409', 'Rat11-20150402', 'Rat11-20150403']
11.475409836065573 %





In [9]:
%matplotlib qt
# metadata.Type[metadata.Region == 'CeCM'] = 'Int'
for stru in np.unique(['Hpc','BLA']):
    fig,ax = plt.subplots(3,3,gridspec_kw = {'height_ratios':[4,1,5]},figsize =(16,9),dpi = 300)
    for i,state in enumerate(['sws','Rem','wake_homecage']):
        mask_pyr = (metadata.Region == stru) & (metadata.Type == 'Pyr') 
        mask_int = (metadata.Region == stru) & (metadata.Type == 'Int') 
        if np.any(mask_pyr):
            plot(states_activity[state][mask_pyr],metadata[mask_pyr],
                norm = True,
                order = 'fr',
                quintile = True,
                ax = [ax[0,i],ax[2,i]])
        if np.any(mask_int):
            plot(states_activity[state][mask_int],metadata[mask_int],
                norm = True,
                order = 'fr',
                quintile = False,
                ax = [ax[1,i],ax[2,i]])

        plt.suptitle(f'{stru} ordered - zScored')
        plt.sca(ax[0,i])
        plt.title(f'{state}')
    plt.savefig(f'/home/billel/pCloudDrive/IFM/work/Figures/Figures_Gabrielle/states/dynamic/{stru}.svg')
    plt.savefig(f'/home/billel/pCloudDrive/IFM/work/Figures/Figures_Gabrielle/states/dynamic/{stru}.png')

