# Dependencies

In [None]:
import bk.load
import bk.plot
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

from sklearn.preprocessing import normalize
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures

import neuroseries as nts

from scipy.stats import zscore
from scipy.stats import spearmanr,pearsonr
from scipy.stats import ttest_ind, linregress
import os
import pl
import re
from itertools import chain

%matplotlib qt

paths = pd.read_csv('Z:/All-Rats/Billel/session_indexing.csv',sep = ';')['Path']

# General Functions

In [None]:
def flatten(t):
    """Concatenate the items of every sub-iterable of ``t`` into one flat list."""
    return list(chain.from_iterable(t))

In [None]:
def extract(lst, i,e):
    """Return a new list holding the slice ``[i:e]`` of every element of ``lst``."""
    window = slice(i, e)
    sliced = []
    for element in lst:
        sliced.append(element[window])
    return sliced

# Sleep

In [None]:
# Importing pre and post RUN sleeps
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary Python objects,
# so these .npy files must come from a trusted source.
sleep_pre=np.load('C://Users//Panagiota.Loizidou//Desktop//Amy-Hpc-sleep-dynamics-python-master//REM project//Sleeps_epochs.npy', allow_pickle=True)
sleep_post=np.load('C://Users//Panagiota.Loizidou//Desktop//Amy-Hpc-sleep-dynamics-python-master//REM project//Sleeps_epochs_post.npy', allow_pickle=True)
# Stack the pre-RUN entries followed by the post-RUN entries along the first axis.
sleep=np.concatenate((sleep_pre,sleep_post), axis=0)

In [None]:
def parsing_sleep(path, pre_RUN=True, min_wake=60):
    """
    Split the pre- or post-RUN period of a session into sleep bouts.

    The 'wake', 'sws', 'Rem' and 'drowsy' epochs returned by ``bk.load.states()``
    are merged into one table (labels 'wake', 'nrem', 'rem', 'drowsy') sorted by
    start time. The chosen period is then cut at every wake epoch lasting more
    than ``min_wake`` seconds; the rows between two cuts form one sleep bout.
    Durations are obtained as (stop - start) / 1e6, i.e. the timestamps are
    treated as microseconds.

    input:
        path: the path of session you want to analyze
        pre_RUN: True to parse the pre-RUN period, False for the post-RUN period
        min_wake: wake duration (in seconds) above which a wake epoch separates
                  two sleep bouts
    ->  all states in chronological order in the pre run period
        all states in chronological order in the post run period
        list of one-element lists ``[DataFrame]``, one per non-empty sleep bout
        of the designated period (pre or post run)
    """
    #loading data
    bk.load.current_session(path)
    states=bk.load.states()
    pre, post = bk.load.sleep()

    #turning them into a pd.DataFrame: (start, stop) plus a third column with the state label
    labelled=[]
    for key, label in (('wake', 'wake'), ('sws', 'nrem'), ('Rem', 'rem'), ('drowsy', 'drowsy')):
        labelled.append(np.insert(np.array(states[key],dtype='object'), 2, label, axis=1))
    whole_session=pd.DataFrame(np.concatenate(labelled), columns=['start', 'stop', 'state'])
    whole_session_sorted=whole_session.sort_values('start', ignore_index=True)

    # separating pre and post RUN
    pre_run_period = whole_session_sorted.loc[whole_session_sorted['start']<int(pre['end'])]
    post_run_period = whole_session_sorted.loc[whole_session_sorted['start']>int(post['start'])]

    # reseting indexing so that index labels equal row positions (needed by the iloc slicing below)
    pre_run_period=pre_run_period.reset_index(drop=True)
    post_run_period=post_run_period.reset_index(drop=True)

    # chosing which period to parse
    l=pre_run_period if pre_RUN else post_run_period

    long_wake=l[(((l['stop']-l['start'])/1e6)>min_wake) & (l['state']=='wake')]
    # Cut positions: a virtual cut before the first row, every long wake epoch,
    # and a virtual cut after the last row. Using len(l) (not -1) as the final
    # sentinel keeps the last epoch of the last bout.
    cuts=[-1]+long_wake.index.values.tolist()+[len(l)]

    sleeps=[l.iloc[cuts[n]+1:cuts[n+1]] for n in range(len(cuts)-1)]
    sleeps=[[bout] for bout in sleeps if len(bout)>0] #dropping empty periods

    return pre_run_period,post_run_period, sleeps

In [None]:
pre_run_period,post_run_period, sleeps = parsing_sleep(paths[10], pre_RUN=True)

In [None]:
def get_duration(sleep, unit='s'):
    """
    Length of one sleep bout, from the start of its first epoch to the stop of its last.

    input: only a specific sleep session (i.e. sleeps[0]) can be imported;
           ``sleep[0]`` must be the table with 'start' and 'stop' columns
    unit: time unit passed to ``IntervalSet.tot_length``
    """
    bout=sleep[0]
    first_start=bout['start'].iloc[0]
    last_stop=bout['stop'].iloc[-1]
    return nts.IntervalSet(start=first_start, end=last_stop).tot_length(unit)

In [None]:
def removing_short_sessions(path, pre_RUN=True, duration=30, path_list=False):
    '''
    Keep only the sleep bouts of a session that last longer than ``duration`` minutes.

    path: session path, forwarded to parsing_sleep
    pre_RUN: True for the pre-RUN period, False for the post-RUN period
    duration: select how long the extended sleep should be (in minutes).
    path_list returns a list containing n times the path where n is the number of long sleeps.
    Used later for labeling purposes
    -> object array of shape (n, 1); element [k][0] is the DataFrame of the k-th
       extended sleep. With path_list=True the tuple (array, list of paths) is returned.
    '''
    _,_, sleeps=parsing_sleep(path, pre_RUN)

    # get_duration returns seconds by default, so the threshold is converted from minutes to seconds
    long_list=[bout for bout in sleeps if get_duration(bout) > (duration*60)]

    # Fill the object array cell by cell: np.array(list, dtype=object) would expand
    # DataFrames of identical shape into extra array dimensions instead of storing them.
    long_sleeps=np.empty((len(long_list), 1), dtype=object)
    for row, bout in enumerate(long_list):
        long_sleeps[row, 0]=bout[0]

    if path_list:
        return long_sleeps, [path]*len(long_sleeps)
    return long_sleeps

In [None]:
removing_short_sessions(path, pre_RUN=True)

In [None]:
def get_table(path, pre_RUN=True, duration=30):
    """
    Table of mean firing rates for every epoch of every extended sleep of a session.

    path: session path
    pre_RUN: True for the pre-RUN period, False for the post-RUN period
    duration: minimum length (minutes) of an extended sleep, see removing_short_sessions
    -> DataFrame with one row per epoch and the columns
       'index' (position of the epoch inside its extended sleep, restarts at 0),
       'start', 'stop', 'duration' (ms), 'state',
       'all_cells' (mean rate over all neurons, spikes/s),
       'BLA_Pyr', 'BLA_Int', 'Hpc_Pyr', 'Hpc_Int' (mean rate of that population; NaN if it has no cell),
       'path'
    """

    session, path, rat, day,n_channels=bk.load.current_session(path, return_vars=True)
    neurons,metadata = bk.load.loadSpikeData(bk.load.path)

    long_sleeps=removing_short_sessions(path, pre_RUN, duration)

    # boolean masks of the four populations, computed once
    populations={}
    for region_name in ('BLA', 'Hpc'):
        for type_name in ('Pyr', 'Int'):
            mask=(metadata.Region == region_name) & (metadata.Type == type_name)
            populations[region_name+'_'+type_name]=np.asarray(mask)

    columns=['index', 'start', 'stop', 'duration', 'state', 'all_cells', 'BLA_Pyr','BLA_Int','Hpc_Pyr','Hpc_Int']
    rows=[]
    for s in range(len(long_sleeps)):
        bout=long_sleeps[s][0]
        for i in range(len(bout)):
            epoch_row=bout.iloc[i]
            epoch=nts.IntervalSet(start=epoch_row.start, end=epoch_row.stop)
            epoch_length=epoch.tot_length('s')

            # mean firing rate (spikes/s) of every neuron during this epoch
            mean_FRs=np.zeros(len(neurons))
            for n in range(len(neurons)):
                spk_time = neurons[n].restrict(epoch).as_units('s').index.values
                mean_FRs[n]=len(spk_time)/epoch_length

            row={'index': i,
                 'start': epoch_row.start,
                 'stop': epoch_row.stop,
                 'duration': epoch.tot_length('ms'),
                 'state': epoch_row.state,
                 # average over ALL neurons (not just the last one of the loop)
                 'all_cells': np.nanmean(mean_FRs)}
            for name, mask in populations.items():
                row[name]=np.nanmean(mean_FRs[mask])
            rows.append(row)

    table=pd.DataFrame(rows, columns=columns)
    for col in columns:
        if col != 'state':
            table[col] = table[col].astype(float)

    #adding the path column (useful if anything weird is seen)
    table['path']=path
    return table

In [None]:
get_table(path, pre_RUN=True)

In [None]:
def separate_ES(path, pre_RUN=True, duration=30):
    """
    Cut the get_table output of one period (pre or post run) into one table per 'extended sleep'.

    A new extended sleep starts at every row whose 'index' column equals 0.
    *Not all sessions are supposed to have BLA and Hpc cells (because of recordings outside the BLA and Hpc)
    """

    table=get_table(path, pre_RUN, duration)
    onsets=table.index[table['index']==0.0].tolist()
    edges=onsets + [len(table)+1]   # slicing past the end is safe with iloc
    return [table.iloc[lo:hi] for lo, hi in zip(edges[:-1], edges[1:])]

In [None]:
def separate_ES_multisession(paths=paths, pre_RUN=True, duration=30):
    """
    parses into 'extended sleep' for multiple sessions but always in the same period pre or post RUN

    paths: iterable of session paths (defaults to the module-level ``paths`` bound at definition time)
    pre_RUN, duration: forwarded to get_table
    -> flat list with one table per extended sleep, over all sessions that could be loaded
    """
    life_ruining=[] #add it to the return if you need to have it, but it is generally stable.
                    #See list of sessions not used and reason why in Summer Internship-Part II
    tables=[]
    for i in paths:
        try:
            tables.append(get_table(i, pre_RUN, duration))
        except Exception:
            # best effort: a session that cannot be processed is recorded and skipped
            life_ruining.append(i)

    individual_sessions=[]
    for table in tables:
        # A new extended sleep starts where 'index' is 0. len(table) (not -1) closes
        # the last one so that its final row is kept, as in separate_ES.
        h=table.index[table['index']==0.0].tolist()+[len(table)]
        individual_sessions.extend(table.iloc[h[n]:h[n+1]] for n in range(len(h)-1))
    return individual_sessions

In [None]:
#checking if indeed we have only 4 states returned from the function separate_ES_multisession
# post-RUN extended sleeps of every session in `paths`
ses=separate_ES_multisession(paths=paths, pre_RUN=False)
# collect the 'state' column of every extended sleep
lst1=[]
for i in range(len(ses)):
    lst1.append(ses[i]['state'])
lst2=flatten(lst1)
set(lst2) #identifies unique elements

In [None]:
#checking difference between separate_ES and removing_short_sessions
separate_ES(path)
removing_short_sessions(path, pre_RUN=True)

# ==> removing_short_sessions is far faster because it does not call get_table

## Visualizing ES

In [None]:
# getting the ES for all recording sessions
# longlong : one entry per extended sleep (entry[0] is its table)
# longpaths: matching "PRE <path>" / "POST <path>" labels, with the "Z:\Rat" prefix removed
# life     : sessions that could not be processed
longlong=[]
longpaths=[]
life=[]
for i in paths:
    try:
        # the keyword is `duration` (minutes); there is no `min_dur` parameter
        long_sleeps_pre, path_preR=removing_short_sessions(i, pre_RUN=True, duration=30, path_list=True)
        long_sleeps_post, path_postR=removing_short_sessions(i, pre_RUN=False, duration=30, path_list=True)
        long_sleeps=np.concatenate((long_sleeps_pre,long_sleeps_post))
    except Exception:
        life.append(i)  #find out more about excluded sessions and reasons why in Methodology file
    else:
        path_pre = ["PRE " + p for p in path_preR]
        path_post = ["POST " + p for p in path_postR]
        paths_all=path_pre+path_post
        longlong.extend(long_sleeps)
        longpaths.extend(p.replace("Z:\\Rat","") for p in paths_all)

In [None]:
len(life)

In [None]:
# checking that all epochs identified are consecutive without any time gaps
for i in range(len(longlong)):
    epochs=longlong[i][0]
    starts=epochs['start'].to_numpy()
    stops=epochs['stop'].to_numpy()
    # Every epoch must start exactly where the previous one stopped. The raw
    # timestamps are compared, so no float rounding is involved.
    for k in range(1,len(epochs)):
        if starts[k]!=stops[k-1]:
            print('there\'s a prob friend')

In [None]:
# One horizontal bar per extended sleep, coloured by state.
colors = {'nrem':'tab:blue', 'drowsy':'tab:purple', 'rem':'tab:orange',  'wake':'tab:pink'}
fig, ax = plt.subplots()

for i in range(len(longlong)):
    # duration of every epoch and its start relative to the first epoch (both divided by 60e6)
    dur=longlong[i][0]['stop']/60e6-longlong[i][0]['start']/60e6
    start=(longlong[i][0]['start']-longlong[i][0]['start'].iloc[0])/60e6  # subtracting start time so it starts from 0
    start=start.tolist()
    # broken_barh expects (xmin, xwidth) pairs
    couples=list(zip(start, dur))
    
    coloring=[]
    for c in range(len(longlong[i][0])):   # the [0] is to get the pd.array instead of the np.ndarray
        g=colors[longlong[i][0]['state'].iloc[c]]
        coloring.append(g)  
    
    # row i of the plot: bars starting at y=i with height 0.85
    ax.broken_barh(couples, (i, 0.85), facecolors=coloring)

# one tick per extended sleep, placed at i + 0.5 and labelled with its path
y_pos=np.arange(0.5,len(longpaths)+0.5)
ax.set_yticks(y_pos)
ax.set_yticklabels(longpaths)

# adding the legend
pop_a = mpatches.Patch(color='tab:blue', label='NREM')
pop_b = mpatches.Patch(color='tab:orange', label='REM')
pop_c = mpatches.Patch(color='tab:purple', label='Drowsy')
pop_d = mpatches.Patch(color='tab:pink', label='Wake')


plt.legend(handles=[pop_a,pop_b, pop_c, pop_d])

plt.title('States in Extended Sleeps')
plt.ylabel('Extended Sleep Session')
plt.xlabel('Time since extended sleep onset (m)')

# Normalization

## Over the Whole Recording
Firing-rate z-scores are computed from the mean and SD of the binned firing rate (1-min bins) over the whole recording, with no separation of the pre- and post-RUN periods.
The RUN interval is included.

In [None]:
# Load one session into module-level variables. The functions defined below use
# path, neurons and metadata as default arguments (bound when each `def` runs)
# and read pre, post and states as globals.
path=paths[5]
bk.load.current_session(path, return_vars=True)
neurons,metadata = bk.load.loadSpikeData(bk.load.path)
pre, post =bk.load.sleep()
states=bk.load.states()

In [None]:
def norm_FR(path, region='Hpc', celltype='Pyr', min=0, max=2040, bin=1, whole=True):
    """
    Load a session and z-score the binned spike counts of the selected cells.

    inputs:
    path: session path; the session is loaded inside the function
    region, celltype: values matched against metadata.Region and metadata.Type
    min, max, bin is in seconds (bin defaults to 1 s)
    if whole is True, the interval goes from the start of the pre-RUN sleep to the
    end of the post-RUN sleep. It overwrites the min and max specified.
    -> (fr_z, edges): fr_z has one row per selected cell and one column per bin;
       each neuron is z-scored across the bins of the interval. edges are the bin
       edges (s).
    """

    bk.load.current_session(path)
    neurons,metadata = bk.load.loadSpikeData(bk.load.path)
    pre, post =bk.load.sleep()

    if whole:
        # scalars, from microseconds to seconds
        min=float(pre['start'].values[0])/1e6
        max=float(post['end'].values[0])/1e6

    window = nts.IntervalSet(min,max,time_units = 's')
    bins=np.arange(min,max,bin)

    # One row of counts per neuron, kept in the order of `metadata`: the rows must
    # not be re-sorted, because the boolean mask below is aligned with `metadata`.
    hist=[]
    for i in range(len(neurons)):
        spikes=neurons[i].restrict(window).as_units('s').index
        hist.append(np.histogram(spikes,bins)[0])
    z=zscore(hist, axis=1)

    selected=np.asarray((metadata.Type==celltype) &(metadata.Region==region))
    fr_z=z[selected]

    return fr_z, bins

# Normalizing over each ES

In [None]:
# removing paths without any ES both for pre and post RUN
# A session is dropped for a period when it has no extended sleep in that period
# or when it cannot be processed at all.
not_use_pre=[]
not_use_post=[]
for _not_use, _pre_flag in ((not_use_pre, True), (not_use_post, False)):
    for i in range(len(paths)):
        try:
            ES=removing_short_sessions(paths[i], pre_RUN=_pre_flag)
        except Exception:
            _not_use.append(paths[i])
            continue
        if len(ES)==0:
            _not_use.append(paths[i])

useful_paths_pre=[ele for ele in paths if ele not in not_use_pre]
# useful_paths_pre.remove("Z:\Rat11\Rat11-20150401")
useful_paths_post=[ele for ele in paths if ele not in not_use_post]


In [None]:
# Normalizing within an ES and then selecting intervals to compare

def _collect_ES_traces(session_paths, pre_RUN):
    """
    For every extended sleep (ES) of every session, z-score the firing rate within
    the ES (norm_FR defaults: Hpc Pyr cells, 1-s bins) and average over cells.

    -> (all_ES, first_epochs, last_epochs): the mean trace of every ES, and the
       parts of that trace falling in the first and in the last NREM epoch of the
       ES. An ES without any NREM epoch contributes to all_ES only.
    """
    all_ES=[]
    first_epochs=[]
    last_epochs=[]
    for session_path in session_paths:              # going over all sessions
        for bout in removing_short_sessions(session_path, pre_RUN=pre_RUN):  # going over all ES
            epochs=bout[0]
            start_time=epochs['start'].iloc[0]
            stop_time=epochs['stop'].iloc[-1]
            c, e=norm_FR(session_path, min=start_time/1e6, max=stop_time/1e6, whole=False)
            mean_c=np.nanmean((c), axis=0)
            all_ES.append(mean_c)

            nrem=epochs[epochs['state']=='nrem']
            if len(nrem)==0:
                continue
            # Seconds since ES onset are used directly as bin indices, which relies
            # on the 1-s default bin of norm_FR.
            start_first=(nrem.iloc[0]['start']-start_time)/1e6
            stop_first=(nrem.iloc[0]['stop']-start_time)/1e6
            start_last=(nrem.iloc[-1]['start']-start_time)/1e6
            stop_last=(nrem.iloc[-1]['stop']-start_time)/1e6
            first_epochs.append(mean_c[int(start_first):int(stop_first)])
            last_epochs.append(mean_c[int(start_last):int(stop_last)])
    return all_ES, first_epochs, last_epochs

all_ES_pre, first_epochs_pre, last_epochs_pre=_collect_ES_traces(useful_paths_pre, pre_RUN=True)
all_ES_post, first_epochs_post, last_epochs_post=_collect_ES_traces(useful_paths_post, pre_RUN=False)

all_ES=all_ES_pre+all_ES_post

In [None]:
# visualizing how z-scored FR looks like
for i in range(len(all_ES)):
    plt.scatter(np.arange(len(all_ES[i])),all_ES[i])

In [None]:
fr_t=[]
for i in range(len(all_ES)):
    fr_t.extend(list(zip(all_ES[i], np.arange(len(all_ES[i])))))
    
a=np.array(fr_t)

In [None]:
# removing nan values
filt = np.isfinite(a[:,0])
y = a[filt,0]
x = a[filt,1]

In [None]:
# linear regression
slope, intercept, r, p, se=linregress(x,y)

In [None]:
# visualizing the linear regression fitted line
plt.scatter(x,y)
plt.plot(x, intercept + slope*x, 'r', label='fitted line')

In [None]:
v_clean

In [None]:
len(a[:,0])

In [None]:
a[:,1]

In [None]:
def regression_Line(path=path, neurons=neurons,metadata=metadata,region='Hpc', celltype='Pyr',state='Rem', min=0, max=2040, bin=1, whole=True):
    """
    Fit one linear regression per epoch of ``state``: mean z-scored firing rate
    against z-scored time.

    All arguments are forwarded to separate_States_Epochs. Note that the defaults
    path, neurons and metadata are the module-level values at definition time.
    -> g: mean firing-rate trace of every epoch
       bins: time bins of every epoch
       slopes: fitted coefficient array (model.coef_) of every epoch
       y_preds: fitted values of every epoch
    """
    g, bins=separate_States_Epochs(path, neurons,metadata,region, celltype,state, min, max, bin, whole)
    slopes=[]
    y_preds=[]
    for i in np.arange(len(bins)):
        x=np.array(zscore(bins[i])).reshape((-1,1))  #normalizing x (y is already z-scored)
        y=np.array(g[i])
        model = LinearRegression().fit(x, y)   # a single fit is enough
        slopes.append(model.coef_)
        y_preds.append(model.predict(x))
    return g, bins, slopes, y_preds

In [None]:
ttest_ind(flatten(last_epochs_pre), flatten(first_epochs_pre))

In [None]:
len(all_ES)

In [None]:
ttest_ind(flatten(last_epochs_post), flatten(first_epochs_post))

In [None]:
plt.plot(last_epochs_post[5])
plt.plot(first_epochs_pre[5])

In [None]:
ES_pre=removing_short_sessions(useful_paths_pre[0])  #going over all ES in useful sessions
start=ES_pre[0][0]['start'].iloc[0]/1e6
stop=ES_pre[0][0]['stop'].iloc[-1]/1e6
c, e=norm_FR(useful_paths_pre[0], min=start, max=stop, whole=False)
mean_c=np.nanmean((c), axis=1)

In [None]:
mean_c

In [None]:
#comparing pre and post FR
ttest_ind(flatten(first_epochs_post), flatten(first_epochs_pre)) 
ttest_ind(flatten(last_epochs_post), flatten(last_epochs_pre)) 

###  Plotting firing rates (z-score) for single sessions

In [None]:
# Load session paths[3] into the module-level variables and compute the z-scored
# firing rate of its Hpc Pyr cells in 60-s bins (whole=True overrides min and max).
session, path, rat, day,n_channels=bk.load.current_session(paths[3],return_vars=True)
neurons,metadata = bk.load.loadSpikeData(bk.load.path)
pre, post =bk.load.sleep()
states=bk.load.states()
region='Hpc'
celltype='Pyr'
c,e=norm_FR(paths[3],region, celltype, min=0, max=2040, bin=60, whole=True)
# c=np.mean(v, axis=0)
# c=np.mean(v, axis=0)

In [None]:
# Plot the z-scored trace of one cell (row 1 of c) with the REM and NREM epochs shaded.
plt.figure()
plt.plot(c[1])
# NOTE(review): the state intervals are divided by 60 here, while other cells of
# this notebook divide timestamps by 60e6 to get minutes; confirm which units
# bk.plot.intervals expects so that the shading lines up with the 1-min bins.
bk.plot.intervals(states['Rem']/60, col = 'blue', alpha=0.3)
# bk.plot.intervals(states['wake']/60, col = 'r', alpha=0.3)
# bk.plot.intervals(states['drowsy']/60, col = 'g', alpha=0.3)
bk.plot.intervals(states['sws']/60, col = 'red', alpha=0.3)

plt.suptitle('Single Session Firing Rate (z-score)')
plt.title('Session: '+str(session) +', Rat: '+ str(rat) +', Day: '+  str(day)+', Region: '+ region+', Celltype: '+ celltype)
plt.ylabel('Normalized Firing Rate (z)')
plt.xlabel('1-min Timebins (m)')

# legend patches use the same colours as the shaded intervals
REM_patch = mpatches.Patch(color = 'blue', alpha=0.3, label='REM')
NREM_patch = mpatches.Patch(color = 'red', alpha=0.3, label='NREM')
plt.legend(handles=[REM_patch, NREM_patch])

plt.show()

In [None]:
# h=[]
# life=[]
# for i in range(len(paths)):
#     try:
#         v,e=norm_FR(paths[i],region='Hpc', celltype='Pyr', min=0, max=2040, bin=60, whole=True)
#         c=np.mean(v, axis=0)
#         h.append(c)
#     except:
#         life.append(paths[i])

In [None]:
# s=long_sleeps[0][0]['start'][long_sleeps[0][0]['state']=='nrem'].tolist()
# f=long_sleeps[0][0]['stop'][long_sleeps[0][0]['state']=='nrem'].tolist()
# for start, finish in zip(s,f):
#     plt.axvspan(start/60e6,finish/60e6, facecolor='b', alpha=0.1)
# s=long_sleeps[0][0]['start'][long_sleeps[0][0]['state']=='rem'].tolist()
# f=long_sleeps[0][0]['stop'][long_sleeps[0][0]['state']=='rem'].tolist()
# for start, finish in zip(s,f):
#     plt.axvspan(start/60e6,finish/60e6, facecolor='r', alpha=0.1)
# plt.plot(np.mean(fr[0], axis=1))

In [None]:
def norm_FR_sessionLoaded(path=path, neurons=neurons,metadata=metadata,region='Hpc', celltype='Pyr', min=0, max=2040, bin=60, whole=True):
    """
    z-score the binned spike counts of the selected cells of an already loaded session.

    inputs:
    path: not used by the computation (kept for a signature parallel to norm_FR)
    neurons, metadata: spike trains and their metadata table; the defaults are the
        module-level values at definition time
    region, celltype: values matched against metadata.Region and metadata.Type
    min, max, bin is in seconds
    if whole is True, the interval goes from the start of the pre-RUN sleep to the
    end of the post-RUN sleep, read from the module-level ``pre`` and ``post``.
    It overwrites the min and max specified.
    -> (c, edges): c has one row per selected cell and one column per bin; each
       neuron is z-scored across the bins of the interval. edges are the bin edges (s).
    """

    if whole:
        # scalars, from microseconds to seconds
        min=float(pre['start'].values[0])/1e6
        max=float(post['end'].values[0])/1e6
    window = nts.IntervalSet(min,max,time_units = 's')

    bins=np.arange(min,max,bin)
    # one row of counts per neuron, in the order of `metadata`
    hist=[]
    for i in range(len(neurons)):
        spikes=neurons[i].restrict(window).as_units('s').index
        hist.append(np.histogram(spikes,bins)[0])
    z=zscore(hist, axis=1)

    c=z[np.asarray((metadata.Type==celltype) &(metadata.Region==region))]

    return c,bins

In [None]:
c, e=norm_FR_sessionLoaded()

In [None]:
plt.plot(c[2])

In [None]:
np.mean(c[2])

In [None]:
def norm_FR_noRUN_sessionLoaded(path=path, neurons=neurons,metadata=metadata,region='Hpc', celltype='Pyr', min=0, max=2040, bin=60, whole=True):
    """
    z-score the binned spike counts of the selected cells, leaving out the RUN period.

    inputs:
    path: not used by the computation
    neurons, metadata: spike trains and their metadata table; the defaults are the
        module-level values at definition time
    min, max, bin is in seconds
    if whole is True, the pre-RUN and the post-RUN sleeps (module-level ``pre`` and
    ``post``) are binned separately and concatenated, WITHOUT TAKING INTO ACCOUNT
    THE RUN SESSION. It overwrites the min and max specified.
    if whole is False, the single interval [min, max] is used.
    -> (c, edges): c has one row per selected cell and one column per bin; each
       neuron is z-scored across all bins. edges are the concatenated bin edges (s)
       of the windows.
    """

    if whole:
        windows=[(float(pre['start'].values[0])/1e6, float(pre['end'].values[0])/1e6),
                 (float(post['start'].values[0])/1e6, float(post['end'].values[0])/1e6)]
    else:
        windows=[(min, max)]

    window = nts.IntervalSet([lo for lo, _ in windows],[hi for _, hi in windows],time_units = 's')

    # Each window gets its own edges and its own histogram, so that no bin spans
    # the gap between the end of one window and the start of the next.
    edges=[np.arange(lo,hi,bin) for lo, hi in windows]
    hist=[]
    for i in range(len(neurons)):
        spikes=neurons[i].restrict(window).as_units('s').index
        hist.append(np.concatenate([np.histogram(spikes,window_edges)[0] for window_edges in edges]))
    z=zscore(hist, axis=1)

    c=z[np.asarray((metadata.Type==celltype) &(metadata.Region==region))]

    return c,np.concatenate(edges)

In [None]:
# def norm_FR_noRUN_sessionLoaded_withinState(path=path, neurons=neurons,metadata=metadata,region='Hpc', celltype='Pyr', min=0, max=2040, bin=60, whole=True):
#     """
#     inputs:
#     min, max, bin is in seconds
#     if whole is True, the interval will be from 0 to the end of recording WITHOUT TAKING INTO ACCOUNT THE RUN SESSION. It overwrites the min and max specified.
#     -> Firing rate z scores for the celltypes at structure specified calculated using means and SDs in 1-min bins of the specified interval. 
#     ***normalization performed within each state: i.e. the avg of all nrem states within an ext. sleep will be 0***
#     """
    
#     if whole:
#         min=pre['start'].values/1e6
#         min_2=post['start'].values/1e6
#         max=pre['end'].values/1e6
#         max_2=post['end'].values/1e6
        
#     window = nts.IntervalSet([min, min_2],[max, max_2],time_units = 's')
    
#     n=[]
#     for i in range(len(neurons)):
#         n.append(neurons[i].restrict(window).as_units('s').index)
# #     n=sorted(n,key=len)
#     n=np.array(n,dtype=object)
    
#     bins_pre=np.arange(min,max,bin).tolist()
#     bins_post=np.arange(min_2,max_2,bin).tolist()
#     bins=bins_pre+bins_post
#     hist=[]
#     for i in range(len(neurons)):
#         j,e=np.histogram(n[i],bins)
#         hist.append(j)
#     z=zscore(hist, axis=1)
    
#     c=z[(metadata.Type==celltype) &(metadata.Region==region)]
    
#     return c,e

# under construction

In [None]:
d=states['Rem']['start'][states['Rem']['start']>12692000000]

In [None]:
d

In [None]:
def norm_FR_noRUN_sessionLoaded_withinState(path=path, neurons=neurons,metadata=metadata,region='Hpc', celltype='Pyr',state='Rem', min=0, max=2040, bin=1, whole=True):
    """
    z-score the binned spike counts of the selected cells using only the epochs of one state.

    inputs:
    path: not used by the computation
    neurons, metadata: spike trains and their metadata table; the defaults are the
        module-level values at definition time
    state= choose between 'Rem', 'sws','drowsy', 'wake' (key of the module-level ``states``)
    min, max, bin is in seconds
    if whole is True, the state epochs are restricted to the pre-RUN and post-RUN
    sleeps (module-level ``pre`` and ``post``), WITHOUT TAKING INTO ACCOUNT THE RUN
    SESSION. It overwrites the min and max specified.
    if whole is False, the state epochs are restricted to [min, max].
    -> fr_z: z-scored counts, one row per selected neuron
       bins_flat: bin edges (s) of all epochs, concatenated
       bins: bin edges (s) of each epoch (bins[0] is the first epoch)
    Because the edges of all epochs are concatenated, each row of fr_z has
    len(bins_flat)-1 columns and contains one extra bin between the last edge of
    an epoch and the first edge of the next one.
    ***normalization performed within each state: i.e. the avg of all nrem states within a sleep session will be 0***
    raises ValueError if the requested interval contains no epoch of ``state``
    """

    #selecting interesting neurons
    neurons=neurons[(metadata.Type==celltype) &(metadata.Region==region)]

    # time windows (s) in which state epochs are kept
    if whole:
        allowed=[(float(pre['start'].values[0])/1e6, float(pre['end'].values[0])/1e6),
                 (float(post['start'].values[0])/1e6, float(post['end'].values[0])/1e6)]
    else:
        allowed=[(min, max)]

    #creating an interval window for the state chosen, clipped to the allowed windows
    # (np.maximum/np.minimum are used because the parameters shadow the builtins min/max)
    epochs=[]
    for ep_start, ep_end in states[state].values/1e6:
        for lo, hi in allowed:
            clipped_start=np.maximum(ep_start, lo)
            clipped_end=np.minimum(ep_end, hi)
            if clipped_end>clipped_start:
                epochs.append((clipped_start, clipped_end))
    if not epochs:
        raise ValueError("no '"+str(state)+"' epoch in the requested interval")

    window = nts.IntervalSet([a for a, _ in epochs],[b for _, b in epochs],time_units = 's')

    #taking the spike times of each neuron within the window of each state
    n=[neurons[i].restrict(window).as_units('s').index for i in range(len(neurons))]

    #creating bins in order to extract the firing rate of each neuron. binsize specified above
    bins=[np.arange(a,b,bin).tolist() for a, b in epochs]
    bins_flat=flatten(bins)

    #extractin firing rate/bin
    hist=[np.histogram(n[i],bins_flat)[0] for i in range(len(neurons))]

    #normalizing (z-scoring)
    fr_z=zscore(hist, axis=1)

    return fr_z,bins_flat,bins

In [None]:
# start and stop come from the cells run before this one; min and max are expected in seconds
norm_FR_noRUN_sessionLoaded_withinState(path=path, neurons=neurons,metadata=metadata,region='Hpc', celltype='Pyr',state='Rem',min=start, max=stop, bin=1, whole=False)

In [None]:
ES=removing_short_sessions(paths[0])

In [None]:
ES=removing_short_sessions(paths[0])
start=ES[0][0]['start'].iloc[0]
stop=ES[0][0]['stop'].iloc[-1]

In [None]:
fr_z,bins_flat,bins=norm_FR_noRUN_sessionLoaded_withinState()

In [None]:
plt.plot(bins_flat[0:-1], np.mean(fr_z, axis=0))

In [None]:
def separate_States_Epochs(path=path, neurons=neurons,metadata=metadata,region='Hpc', celltype='Pyr',state='Rem', min=0, max=2040, bin=1, whole=True):
    """
    Average the z-scored firing rate over neurons, separately for every epoch of ``state``.

    All arguments are forwarded to norm_FR_noRUN_sessionLoaded_withinState.
    -> g: one array per epoch with the across-neuron mean of the z-scored rate
       bins: the time bins (s) of every epoch; the last edge of the last epoch is
             removed so that bins[i] and g[i] are meant to have the same length
    """

    fr_z, bins_flat, bins=norm_FR_noRUN_sessionLoaded_withinState(path, neurons,metadata,region, celltype,state, min, max, bin, whole)

    g=[] #average fr for each second in each epoch
    counter=0
    for i in np.arange(len(bins)):
        # columns counter .. counter+len(bins[i]) of fr_z belong to epoch i
        fr_ep=extract(fr_z,counter,counter+len(bins[i]))
        counter+=int(len(bins[i]))
        g.append(np.mean(fr_ep, axis=0))

    # n edges give n-1 bins: only the last epoch has no following bin in fr_z, so
    # its last edge is dropped (nothing to drop when there is no epoch)
    if bins and bins[-1]:
        del bins[-1][-1]
    return g, bins

In [None]:
g, bins=separate_States_Epochs()

In [None]:
l=flatten(g)
k = flatten(bins)
len(k)
plt.plot(k,l)
np.mean(l)

In [None]:
# first and last epoch trace of every session
g_0=[]
g_last=[]
for i in paths:
    # Load session i first: separate_States_Epochs works on the neurons/metadata it
    # is given and on the module-level pre, post and states; the path argument
    # alone does not select the session.
    bk.load.current_session(i)
    neurons,metadata = bk.load.loadSpikeData(bk.load.path)
    pre, post =bk.load.sleep()
    states=bk.load.states()
    g, _=separate_States_Epochs(i, neurons, metadata)
    if len(g)>0:
        g_0.append(g[0])
        g_last.append(g[-1])

In [None]:
g_0=flatten(g_0)
g_last=flatten(g_last)

In [None]:
ttest_ind(g_0, g_last)

In [None]:
plt.scatter(k,l)

In [None]:
# kill   : sessions that could not be processed
# avg_avg: mean regression slope of every processed session
kill=[]
avg_avg=[]
for i in np.arange(len(paths)):
    try:
        bk.load.current_session(paths[i], return_vars=True)
        neurons,metadata = bk.load.loadSpikeData(bk.load.path)
        pre, post =bk.load.sleep()
        states=bk.load.states()
        # pass the freshly loaded data explicitly: the default arguments of
        # regression_Line were bound when the function was defined
        g, bins, slopes, y_preds= regression_Line(paths[i], neurons, metadata)
        slope_list=[]
        for k in np.arange(len(g)):
            plt.plot(bins[k],y_preds[k])
            plt.scatter(bins[k],g[k])
            slope_list.append(slopes[k])
        avg=np.mean(slope_list)
        avg_avg.append(avg)
    except Exception:
        kill.append(paths[i])

In [None]:
kill

In [None]:
slope_list=[]
for i in np.arange(len(g)):
    plt.plot(bins[i],y_preds[i])
    plt.scatter(bins[i],g[i])
    slope_list.append(slopes[i])
avg=np.mean(slope_list)
avg

In [None]:
# Polynomial Regression [only to play with, not used]
transformer = PolynomialFeatures(degree=2, include_bias=False)
transformer.fit(x)
x_ = transformer.transform(x)
x_ = PolynomialFeatures(degree=3, include_bias=False).fit_transform(x)
model = LinearRegression()
model.fit(x_, y)
model = LinearRegression().fit(x_, y)

In [None]:
# the function returns (fr_z, bins_flat, bins); the flat edges are not needed here
c, _, bins=norm_FR_noRUN_sessionLoaded_withinState()

In [None]:
# z-scored firing rate within one state, for four combinations of state and cell population
# (defaults: region='Hpc', celltype='Pyr', state='Rem')
c,e,_=norm_FR_noRUN_sessionLoaded_withinState()
c_1,e_1,_=norm_FR_noRUN_sessionLoaded_withinState(region='BLA', celltype='Pyr')
c_2,e_2,_=norm_FR_noRUN_sessionLoaded_withinState(state='sws')
c_3,e_3,_=norm_FR_noRUN_sessionLoaded_withinState(state='sws',region='BLA', celltype='Pyr')
# mean over cells
g=np.mean(c, axis=0)
g_1=np.mean(c_1, axis=0)
g_2=np.mean(c_2, axis=0)
g_3=np.mean(c_3, axis=0)
# n edges give n-1 bins, hence the [:-1]
plt.plot(e[:-1], g)
plt.plot(e_1[:-1], g_1)
plt.plot(e_2[:-1], g_2)
plt.plot(e_3[:-1], g_3)
plt.title("Avg firing rate selected neurons in the state selected (z-score obtained based on firing rates during x)")
# NOTE(review): the traces are plotted against edges in seconds while the state
# intervals are passed unscaled; confirm the units bk.plot.intervals expects.
bk.plot.intervals(states['Rem'], col = 'blue', alpha=0.3)
# bk.plot.intervals(states['wake']/60, col = 'r', alpha=0.3)
# bk.plot.intervals(states['drowsy']/60, col = 'g', alpha=0.3)
bk.plot.intervals(states['sws'], col = 'red', alpha=0.3)
REM_patch = mpatches.Patch(color = 'blue', alpha=0.3, label='REM')
NREM_patch = mpatches.Patch(color = 'red', alpha=0.3, label='NREM')
plt.legend(handles=[REM_patch, NREM_patch])

# Plotting

In [None]:
#plot graphs of spikes and z-score run included
# NOTE(review): norm_FR_sessionLoaded does not load paths[6]; it works on the
# `neurons` passed here and on the module-level pre/post.
c, e=norm_FR_sessionLoaded(paths[6], neurons)
plt.figure()
# spike times of neuron 0 divided by 60e6, relative to the start of the pre-RUN sleep
plt.eventplot(neurons[0].restrict(post).index/60e6-pre['start'].values/60e6, alpha=0.2, label='spike train')
plt.eventplot(neurons[0].restrict(pre).index/60e6-pre['start'].values/60e6, alpha=0.2)
plt.plot(c[0], c='orange',linewidth=3, label='normalized firing rate')
# shaded span from the end of the pre-RUN sleep to the start of the post-RUN sleep
plt.axvspan(pre['end'].values/60e6-pre['start'].values/60e6, post['start'].values/60e6-pre['start'].values/60e6, facecolor='g', alpha=0.2, label='run')
plt.ylabel('Firing Rate (z)')
plt.xlabel('Time since session startes (mins)')
plt.legend()
plt.title('Example of normalized firing rate as compared to the spike train of the same neuron')

In [None]:
#plot graphs of spikes and z-score run NOT included
# NOTE(review): norm_FR_noRUN_sessionLoaded does not load paths[5]; it works on
# the `neurons` passed here and on the module-level pre/post.
c, e=norm_FR_noRUN_sessionLoaded(paths[5], neurons)
plt.figure()
# post-RUN spikes are shifted so that the post-RUN sleep starts where the pre-RUN sleep ends
plt.eventplot(neurons[103].restrict(post).index/60e6-(post['start'].values-pre['end'].values+pre['start'].values)/60e6, alpha=0.2, label='spike train')
plt.eventplot(neurons[103].restrict(pre).index/60e6-pre['start'].values/60e6, alpha=0.2)
plt.plot(c[103], c='orange',linewidth=2.5, label='normalized firing rate')
# shaded span from 0 to the duration of the pre-RUN sleep
plt.axvspan(pre['start'].values/60e6-pre['start'].values/60e6, pre['end'].values/60e6-pre['start'].values/60e6, facecolor='g', alpha=0.2, label='pre-run')
plt.ylabel('Firing Rate (z)')
plt.xlabel('Time since session startes (mins)')
plt.legend()
plt.title('Example of normalized firing rate as compared to the spike train of the same neuron')

# Normalizing within each epoch, taking the mean of all neurons during each epoch and plotting them against the time since the start of sleep session

Problematic, because this compares mean z-scores that were each normalized over a different period.

In [None]:
def fav(path, region='Hpc', celltype='Pyr', pre_RUN=True):
    """
    Normalized firing rates for every epoch of every long sleep session.

    input: path      -> path of the session to analyze (a single string)
           region    -> passed through to norm_FR_sessionLoaded
           celltype  -> passed through to norm_FR_sessionLoaded
           pre_RUN   -> passed through to removing_short_sessions
    ->  long_sleeps: whatever removing_short_sessions(path, pre_RUN) returns
        fr: list with one entry per epoch (all sleep sessions, in order); each
            entry is the first value returned by norm_FR_sessionLoaded when it
            is restricted to that epoch (whole=False, bin=60)

    raises TypeError when path is not a string. Looping over several sessions
    was only sketched in commented-out code and is not implemented.
    """
    if not isinstance(path, str):
        # Previously this fell through to the return with long_sleeps and fr
        # unbound and failed with an UnboundLocalError.
        raise TypeError('fav expects a single session path (str), got ' + type(path).__name__)

    bk.load.current_session(path)
    neurons, metadata = bk.load.loadSpikeData(bk.load.path)
    # NOTE(review): the result was never used; the call is kept in case the
    # loader has side effects -- confirm and remove.
    bk.load.sleep()

    long_sleeps = removing_short_sessions(path, pre_RUN)
    fr = []
    for sleep_session in long_sleeps:  # for each sleep session
        epochs = sleep_session[0]
        for start, stop in zip(epochs['start'].values, epochs['stop'].values):  # for each state/epoch in sleep
            # start/stop are divided by 1e6 before being used as the window
            # (presumably microseconds -> seconds; confirm against norm_FR_sessionLoaded)
            c, _ = norm_FR_sessionLoaded(path, neurons, metadata, region, celltype,
                                         min=start / 1e6, max=stop / 1e6, bin=60, whole=False)
            fr.append(c)
    return long_sleeps, fr

In [None]:
# Make paths[4] the current session and load its spike data, the pre/post
# sleep epochs and the state intervals into module-level names.
bk.load.current_session(paths[4])
neurons,metadata = bk.load.loadSpikeData(bk.load.path)
pre, post =bk.load.sleep()
states=bk.load.states()

In [None]:
# Run fav() on session paths[4] and keep the first element of the first
# returned sleep session, re-indexed from 0.
# NOTE(review): two functions named `fav` are defined in this file; which one
# runs here depends on the order in which the cells were executed.
long_sleeps_all, fr=fav(paths[4], region='Hpc', celltype='Pyr', pre_RUN=True)
long_sleeps=long_sleeps_all[0][0].reset_index(drop=True) 

In [None]:
# Inspect the first sleep session returned by fav().
long_sleeps_all[0]

In [None]:
# Average over neurons (axis 0) inside every epoch, keep only as many epochs
# as there are rows in long_sleeps, then select the REM epochs.
per_epoch_population = [np.nanmean(epoch_fr, axis=0) for epoch_fr in fr]
f = np.array(per_epoch_population[:len(long_sleeps)])
rem_rows = f[long_sleeps.state == 'rem']

# One value per REM epoch: the NaN-ignoring mean of that epoch.
g = [np.nanmean(row) for row in rem_rows]
g

In [None]:
# Scatter the per-epoch REM means (g) against the start of each REM epoch,
# measured from the start of the first epoch in long_sleeps.
rem_epochs = long_sleeps[long_sleeps.state == 'rem']
rem_onset = (rem_epochs['start'] - long_sleeps['start'][0]) / 60e6

plt.figure()
plt.scatter(rem_onset, g)
plt.xlabel('Timebins (mins)')
plt.ylabel('Firing Rate (z)')

# Normalizing for whole session, taking the mean of all neurons during each epoch and plotting them against the time since the start of sleep session

In [None]:
def fav(path, region='Hpc', celltype='Pyr', pre_RUN=True, duration=30):
    """
    Normalized (RUN excluded) firing rates for every epoch of every long sleep session.

    input: path      -> path of the session to analyze (a single string)
           region    -> passed through to norm_FR_noRUN_sessionLoaded
           celltype  -> passed through to norm_FR_noRUN_sessionLoaded
           pre_RUN   -> passed through to removing_short_sessions
           duration  -> passed through to removing_short_sessions
    ->  long_sleeps: whatever removing_short_sessions(path, pre_RUN, duration) returns
        fr: list with one entry per epoch (all sleep sessions, in order); each
            entry is the first value returned by norm_FR_noRUN_sessionLoaded
            when it is restricted to that epoch (whole=False, bin=60)

    raises TypeError when path is not a string. Looping over several sessions
    was only sketched in commented-out code and is not implemented.
    """
    if not isinstance(path, str):
        # Previously this fell through to the return with long_sleeps and fr
        # unbound and failed with an UnboundLocalError.
        raise TypeError('fav expects a single session path (str), got ' + type(path).__name__)

    bk.load.current_session(path)
    neurons, metadata = bk.load.loadSpikeData(bk.load.path)
    # NOTE(review): the result was never used; the call is kept in case the
    # loader has side effects -- confirm and remove.
    bk.load.sleep()

    long_sleeps = removing_short_sessions(path, pre_RUN, duration)
    fr = []
    for sleep_session in long_sleeps:  # for each sleep session
        epochs = sleep_session[0]
        for start, stop in zip(epochs['start'].values, epochs['stop'].values):  # for each state/epoch in sleep
            # start/stop are divided by 1e6 before being used as the window
            # (presumably microseconds -> seconds; confirm against norm_FR_noRUN_sessionLoaded)
            c, _ = norm_FR_noRUN_sessionLoaded(path, neurons, metadata, region, celltype,
                                               min=start / 1e6, max=stop / 1e6, bin=60, whole=False)
            fr.append(c)
    return long_sleeps, fr

In [None]:
def final(path=paths[5], neurons=None, metadata=None, region='Hpc', celltype='Pyr', state='nrem', min=0, max=90, bin=60, whole=True, duration=30):
    """
    Mean population firing rate (z) of every epoch of a given state, per long sleep session.

    input: path              -> path of the session to analyze
           neurons, metadata -> ignored: both are reloaded from `path`; the
                                parameters are kept so existing calls still work
           region, celltype, min, max, bin, whole
                             -> passed through to norm_FR_noRUN_sessionLoaded
           state             -> value of the 'state' column to select (e.g. 'nrem', 'rem')
           duration          -> passed through to removing_short_sessions
    ->  state_fr: one list per long sleep session with the mean z of every
                  epoch whose state equals `state`
        t_state:  one pandas Series per long sleep session with the midpoint
                  of those epochs, measured from the start of the session's
                  first epoch and divided by 1e6
        fr:       one list per long sleep session holding, for every epoch
                  (all states), the slice of the population mean that falls
                  inside that epoch
    """
    bk.load.current_session(path)
    neurons, metadata = bk.load.loadSpikeData(bk.load.path)
    # NOTE(review): the results of these two loads were never used; the calls
    # are kept in case the loaders have side effects -- confirm and remove.
    bk.load.sleep()
    bk.load.states()

    # c is normalized fr for each neuron, e is the timebins used for normalization
    c, e = norm_FR_noRUN_sessionLoaded(path, neurons, metadata, region, celltype, min, max, bin, whole)
    r = np.mean(c, axis=0)  # r is the average z of all cells per timebin
    # Time of every bin, truncated to the shorter of r and e (as zip does).
    bin_times = np.array(list(zip(r, e)))[:, 1]

    # Positional arguments: the original wrote `pre_RUN=True, duration`, which
    # is a SyntaxError (positional argument after a keyword argument).
    long_sleeps_pre = removing_short_sessions(path, True, duration)
    long_sleeps_post = removing_short_sessions(path, False, duration)
    if len(long_sleeps_pre) > 0 and len(long_sleeps_post) == 0:
        long_sleeps = long_sleeps_pre
    elif len(long_sleeps_post) > 0 and len(long_sleeps_pre) == 0:
        long_sleeps = long_sleeps_post
    else:
        long_sleeps = np.concatenate((long_sleeps_pre, long_sleeps_post))

    sessions = [sleep_session[0].reset_index(drop=True) for sleep_session in long_sleeps]

    state_fr = []
    fr = []
    t_state = []
    for epochs in sessions:
        # Epoch boundaries: every epoch start plus the end of the last epoch.
        edges = (epochs['start'] / 1e6).tolist()
        edges.append(epochs['stop'].iloc[-1] / 1e6)
        counts, _ = np.histogram(bin_times, edges)  # how many timepoints are in each epoch

        # Timepoints before this session's first epoch must be skipped. The
        # original started every session at r[0], so any session that does
        # not begin at the very first timepoint read the wrong part of r.
        # Assumes bin_times is in ascending order (contiguous slices of r
        # only make sense in that case).
        first = int(np.count_nonzero(bin_times < edges[0]))
        bounds = first + np.concatenate(([0], np.cumsum(counts)))
        per_epoch = [r[lo:hi] for lo, hi in zip(bounds[:-1], bounds[1:])]
        fr.append(per_epoch)

        epoch_means = np.array([np.mean(values) for values in per_epoch])
        in_state = epochs['state'] == state
        state_fr.append(epoch_means[in_state].tolist())

        # taking the midpoint of an epoch, relative to the session's first epoch
        midpoints = (epochs['start'][in_state] + epochs['stop'][in_state]) / 2
        t_state.append((midpoints - epochs['start'][0]) / 1e6)

    return state_fr, t_state, fr

In [None]:
# Run final() on session paths[56] with its default arguments.
state_fr,t_state, fr = final(paths[56])

In [None]:
# Pool all sleep sessions and scatter epoch firing rate against epoch time.
plt.scatter(flatten(t_state), flatten(state_fr))

In [None]:
# Collect the REM results of final() (region 'BLA') over every session.
op_rem = []
t_rem = []
fre_rem = []

cest_la_vie = []  # sessions for which final() raised
for path in paths:
    try:
        op, t, fre = final(path, region='BLA', state='rem')
    except Exception:
        # Best effort: remember the failing session and carry on. A bare
        # `except:` would also swallow KeyboardInterrupt, making this long
        # loop impossible to stop.
        cest_la_vie.append(path)
    else:
        op_rem.extend(op)
        t_rem.extend(t)
        fre_rem.extend(fre)

In [None]:
# Show the sessions for which final() raised in the most recently run collection loop.
cest_la_vie

In [None]:
# Bin the pooled REM values by time, plot them and get the correlation
# statistics (cuantilization is defined in a later cell of this file).
dataframe, corr, spear=cuantilization(op_rem, t_rem, cell='pyramidal', structure='BLA')

In [None]:
# Collect the NREM results of final() (default region) over every session.
op_nrem = []
t_nrem = []
fre_nrem = []

cest_la_vie = []  # sessions for which final() raised
for path in paths:
    try:
        op, t, fre = final(path, state='nrem')
    except Exception:
        # Best effort: remember the failing session and carry on. A bare
        # `except:` would also swallow KeyboardInterrupt.
        cest_la_vie.append(path)
    else:
        # The original appended to the *_rem lists here, which left the
        # *_nrem lists empty and mixed NREM values into the REM results.
        op_nrem.extend(op)
        t_nrem.extend(t)
        fre_nrem.extend(fre)

In [None]:
# Collect the drowsy results of final() (default region) over every session.
op_dro = []
t_dro = []
fre_dro = []

cest_la_vie = []  # sessions for which final() raised
for path in paths:
    try:
        op, t, fre = final(path, state='drowsy')
    except Exception:
        # Best effort: remember the failing session and carry on. A bare
        # `except:` would also swallow KeyboardInterrupt.
        cest_la_vie.append(path)
    else:
        # The original appended to the *_rem lists here, which left the
        # *_dro lists empty and mixed drowsy values into the REM results.
        op_dro.extend(op)
        t_dro.extend(t)
        fre_dro.extend(fre)

In [None]:
def cuantilization(op_a, t_a, bins=100, state='REM', cell='pyramidal', structure='Hpc'):
    """
    Pool firing rates and times, bin them by time, plot and correlate.

    input: op_a      -> iterable of iterables of firing rates (z)
           t_a       -> iterable of iterables of the matching times
           bins      -> number of equal-width intervals and of quantiles
           state, cell, structure -> strings used only in the plot title
    ->  dataframe: one row per pooled value with the columns
                   'fr', 'timing', 'bins' (equal-width interval),
                   'mid' (midpoint of that interval) and 'quantiles'
        corr:      DataFrame.corr() of the 'mid' and 'fr' columns
        spear:     spearmanr of 'mid' against 'fr' with NaNs omitted

    Draws, on the current figure, the per-quantile means together with the
    straight line fitted to 'fr' against 'mid', then calls plt.show().
    """
    rates = list(chain.from_iterable(op_a))
    onsets = list(chain.from_iterable(t_a))

    # zip() truncates to the shorter list, so everything below is derived
    # from the paired table rather than from the raw lists.
    paired = pd.DataFrame(list(zip(rates, onsets)), columns=['fr', 'timing'])
    time_bins = pd.cut(paired['timing'], bins)
    bin_mids = [interval.mid for interval in time_bins]
    quantile_ids = pd.qcut(paired['timing'].values, bins, labels=np.arange(0, bins, 1))

    dataframe = pd.DataFrame(
        list(zip(rates, onsets, time_bins, bin_mids, quantile_ids)),
        columns=['fr', 'timing', 'bins', 'mid', 'quantiles'],
    )

    # Per-quantile means: these are the points of the scatter plot.
    by_quantile = dataframe.groupby('quantiles')
    f = by_quantile['fr'].mean()
    t = by_quantile['timing'].mean()

    # Straight-line fit on the rows that have a firing rate.
    fit_rows = dataframe.loc[dataframe['fr'].notna(), ['fr', 'mid']]
    m, b = np.polyfit(fit_rows['mid'], fit_rows['fr'], 1)

    corr = dataframe[['mid', 'fr']].corr()
    spear = spearmanr(dataframe['mid'], dataframe['fr'], nan_policy='omit')

    minutes = t / 60
    plt.scatter(minutes, f)
    plt.plot(minutes, m * t + b)
    plt.title('Mean firing rate of ' + cell + ' cells in the ' + structure + ' during ' + state)
    plt.ylabel('Firing Rate (z)')
    plt.xlabel('Time from onset of ES (min)')
    annotation = '\n'.join([
        r'$r_s$: ' + str(np.around(spear[0], 3)),
        'p: ' + "{:.2e}".format(spear[1]),
        'Pearsons r: ' + str(np.around(corr['fr']['mid'], 3)),
    ])
    plt.text(50, -0.3, annotation)
    plt.show()

    return dataframe, corr, spear

In [None]:
# Bin, plot and correlate the pooled REM values with the default labels.
# The original called `quantilization(op_, t_rem)`: neither name is defined in
# the visible cells; the function is spelled `cuantilization` and the list
# that pairs with t_rem is op_rem.
dataframe, corr, spear = cuantilization(op_rem, t_rem)

In [None]:
# Flatten the per-session lists of times and of firing rates.
# NOTE(review): op_a is not assigned at module level in the visible cells (it
# is a parameter name of cuantilization); op_rem was probably intended -- confirm.
fio=list(chain.from_iterable(t_rem))
fiu=list(chain.from_iterable(op_a))

In [None]:
# Pair every firing rate with its time and tabulate the pairs.
data_tuples=list(zip(fiu, fio))
dataframe=pd.DataFrame(data_tuples, columns=['fr', 'timing'])

In [None]:
# Assign every row to one of 10 equal-width time intervals.
bin_times=pd.cut(dataframe['timing'], 10)

In [None]:
# Midpoint of the time interval each row was assigned to.
o = [interval.mid for interval in bin_times]

In [None]:
# Rebuild the table with the interval midpoint of every row added.
data_tuples=list(zip(fiu, fio,o))
dataframe=pd.DataFrame(data_tuples, columns=['fr', 'timing', 'mid'])

In [None]:
# Assign every row to one of 10 quantile bins of time, labelled 0-9.
quantiles=pd.qcut(dataframe['timing'].values, 10, labels=np.arange(0,10,1))

In [None]:
# Final table: firing rate, time, time interval, interval midpoint and quantile label.
data_tuples=list(zip(fiu, fio,bin_times, o, quantiles))
dataframe=pd.DataFrame(data_tuples, columns=['fr', 'timing', 'bins', 'mid', 'quantiles'])

In [None]:
# Mean firing rate within each time quantile.
f=dataframe.groupby('quantiles')['fr'].mean()

In [None]:
# Mean time within each time quantile.
t=dataframe.groupby('quantiles')['timing'].mean()

In [None]:
# Firing rate and interval midpoint of the rows that have a firing rate;
# these are the inputs of the line fit.
fr_mid = pd.DataFrame(list(zip(dataframe['fr'], dataframe['mid'])), columns=['fr', 'mid'])
has_rate = dataframe['fr'].notna()
fr_mid = fr_mid[has_rate]

In [None]:
# Degree-1 least-squares fit of firing rate against interval midpoint:
# m is the slope and b the intercept.
m, b = np.polyfit(fr_mid['mid'], fr_mid['fr'], 1)

In [None]:
# Per-quantile means together with the fitted line.
plt.scatter(t, f)
plt.plot(t, m * t + b)
# plt.title() was called without its required label, which raises a TypeError.
plt.title('Mean firing rate per time quantile')
plt.ylabel('Firing Rate (z)')
plt.xlabel('Time from onset of ES (s)')

In [None]:
# Correlation matrix (DataFrame.corr defaults) of interval midpoint and firing rate.
f_tuples=list(zip(dataframe['mid'], dataframe['fr']))
f_df=pd.DataFrame(f_tuples, columns=['mid','fr'])
f_df.corr()

In [None]:
# Plot the two columns of f_df ('mid' and 'fr') in a new figure.
# NOTE(review): f_df is passed as the only argument, so the x axis is
# presumably the row position rather than a time in minutes as the label
# says -- confirm.
plt.figure()
plt.plot(f_df)

plt.ylabel('Firing Rate (z)')
plt.xlabel('Timebins (mins)')

In [None]:
# Spearman rank correlation of interval midpoint against firing rate, NaNs omitted.
spearmanr(dataframe['mid'], dataframe['fr'], nan_policy='omit')

In [None]:
# Inspect the firing-rate column.
dataframe['fr']