# This file contains all helper functions that are called from the main-analysis and revision Jupyter notebooks.

# In[6]:
from pynwb import TimeSeries
from datetime import datetime
from dateutil.tz import tzlocal
from pynwb import NWBFile
import numpy as np
from pynwb import NWBHDF5IO
import h5py
#from pynwb import h5py
import pynwb
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import pandas as pd 
import seaborn as sns
import math
from scipy import stats
from scipy.ndimage import gaussian_filter
import os, sys
import scipy
from pydoc import help
from scipy.stats.stats import pearsonr
import importnb



def get_session_data(session_name, dictionary, num_cells):
    """Stack the per-cell 20x20 rate maps of one session into a 3-D array.

    Keys of ``dictionary`` are strings that encode a session name and a cell
    number; values are 2-D arrays (the activity map of that cell in that
    session). The session name is read from ``key[6:16]``, ``key[6:19]`` or
    ``key[6:20]`` (tried in that order, first match wins). The cell number
    starts one character after the name and is read as two characters,
    falling back to one character when the two-character slice is not an
    integer (e.g. ``"3_"``).

    Parameters
    ----------
    session_name : str
        Session to extract.
    dictionary : mapping
        Key -> 2-D rate map, see above.
    num_cells : int
        Size of the first (cell) axis of the result.

    Returns
    -------
    numpy.ndarray, shape (num_cells, 20, 20)
        Slice ``[c]`` holds cell ``c``'s map. Cells without an entry keep the
        placeholder value of all ones, which flags them as missing data.
    """
    stacked = np.ones(shape=(num_cells, 20, 20))
    # (end of the session-name slice, first index of the cell-number field)
    layouts = ((16, 17), (19, 20), (20, 21))

    for key in dictionary:
        for name_end, cell_start in layouts:
            if key[6:name_end] != session_name:
                continue
            try:
                # two-character cell number, e.g. "12" (or "3" at end of key)
                stacked[int(key[cell_start:cell_start + 2])] = dictionary[key]
            except ValueError:
                # single digit followed by a non-digit, e.g. "3_"
                stacked[int(key[cell_start])] = dictionary[key]
            break

    return stacked

def remove_missing_data(dictionary, num_cell):
    """Replace placeholder entries of ``dictionary`` with empty lists, in place.

    An entry is treated as a placeholder (missing data) when the first slice
    of its value, ``value[0]``, is entirely ones -- the fill value that
    ``get_session_data`` uses for cells without data.

    ``num_cell`` is accepted for call-site compatibility but is not used.

    Returns the same (mutated) dictionary.
    """
    for key, stack in dictionary.items():
        if not np.all(stack[0] == 1):
            continue
        # only the value changes, so mutating during iteration is safe
        dictionary[key] = []
    return dictionary
            


def remove_outer_list(dictionary):
    """Unwrap single-element values of ``dictionary`` in place.

    Every value whose length is exactly 1 is replaced by ``value[0]``; all
    other values are left as they are. Returns the same (mutated) dictionary.
    """
    for key, value in dictionary.items():
        if len(value) != 1:
            continue
        dictionary[key] = value[0]
    return dictionary

# Functions to identify and match repeated and different sessions 

# In[2]:
def all_different_sessions(a):
    """Return the sessions that form a *different*-session pair with ``a``.

    A different-session pair is any combination of ``a`` with a session that
    is not one of its repetitions. The function is meant to be called for
    every key of the table below. Each pair is listed only once: 'N1' is in
    the list for 'F1', so 'F1' is not in the list for 'N1' (this prevents
    double counting).

    The order of each returned list is fixed and matters to callers that
    flatten the pairs.

    Raises
    ------
    KeyError
        For a session with no entry (e.g. 'F2*') or an unknown name.
    """
    # Shared fragments; the order inside each fragment is significant.
    f_partners = ['F1*', 'F2', 'F2*']
    n6_pair = ['N6', 'N6!']
    n1_star_pair = ['N1*', 'N1*!']
    tail = ['N7', 'N7!', 'N8', 'N8!', 'N9', 'N9!', 'N10', 'N10!', 'N6*', 'N6*!']

    def with_bang(*names):
        # ('N2', 'N3') -> ['N2', 'N2!', 'N3', 'N3!']
        return [variant for name in names for variant in (name, name + '!')]

    table = {
        'F1': with_bang('N1', 'N2', 'N3', 'N4', 'N5') + n6_pair + n1_star_pair + tail,
        'N1': f_partners + with_bang('N2', 'N3', 'N4', 'N5') + n6_pair + tail,
        'N2': f_partners + with_bang('N3', 'N4', 'N5') + n6_pair + n1_star_pair + tail,
        'N3': f_partners + with_bang('N4', 'N5') + n6_pair + n1_star_pair + tail,
        'N4': f_partners + with_bang('N5') + n6_pair + n1_star_pair + tail,
        'N5': f_partners + n6_pair + n1_star_pair + tail,
        'N1*': f_partners + n6_pair + tail,
        'F1*': n6_pair + tail,
        'F2': n6_pair + tail,
        'N6': ['F2*'] + tail[:-2],
        'N7': ['F2*'] + tail[2:],
        'N8': ['F2*'] + tail[4:],
        'N9': ['F2*'] + tail[6:],
        'N10': ['F2*'] + tail[8:],
        'N6*': ['F2*'],
    }
    # Every "!" session has exactly the same partners as its base session.
    for base in ('N1', 'N2', 'N3', 'N4', 'N5', 'N1*',
                 'N6', 'N7', 'N8', 'N9', 'N10', 'N6*'):
        table[base + '!'] = list(table[base])

    return table[a]

def all_repeated_sessions(session):
    """Return the repeated-session partners of ``session``.

    Like ``all_different_sessions``, each pair is listed only once: 'N1!'
    is a partner of 'N1', so 'N1' is not a partner of 'N1!'.

    Raises
    ------
    KeyError
        For a session without an entry (e.g. 'N2!', 'F2*') or an unknown name.
    """
    # Sessions whose only repetition is their "!" twin.
    lookup = {name: [name + '!'] for name in
              ('N2', 'N3', 'N4', 'N5', 'N7', 'N8', 'N9', 'N10', 'N1*', 'N6*')}
    # Sessions with several repetitions.
    lookup.update({
        'F1': ['F1*', 'F2', 'F2*'],
        'F1*': ['F2', 'F2*'],
        'F2': ['F2*'],
        'N1': ['N1!', 'N1*', 'N1*!'],
        'N1!': ['N1*', 'N1*!'],
        'N6': ['N6!', 'N6*', 'N6*!'],
        'N6!': ['N6*', 'N6*!'],
    })
    return lookup[session]

# Functions to compute Correlations

# In[1]:
def specific_rate_map_corr(session1, session2, data):
    """Pearson rate-map correlation (RMC) of every cell between two sessions.

    Used to get the RMCs of individual cells in figure 3 only. Unlike
    ``int_rate_map_corr``, no threshold-based selection is applied: every
    cell contributes exactly one coefficient.

    Parameters
    ----------
    session1, session2 : hashable
        Keys of ``data``.
    data : mapping
        ``data[session][cell]`` is a 2-D rate map (array-like). ``session2``
        must hold at least as many cells as ``session1``, and paired maps
        must have the same number of bins.

    Returns
    -------
    list
        ``result[cell]`` is the Pearson r of the two flattened maps (NaN if
        either map is constant, as returned by ``scipy.stats.pearsonr``).
    """
    # The previous version looked up ``thresholds[...]`` from an undefined
    # global and never used the result; it is dropped so the function works
    # on its own.
    result = []
    for cell, map1 in enumerate(data[session1]):
        map2 = data[session2][cell]
        # ravel flattens any 2-D map, not only square ones
        result.append(pearsonr(np.ravel(map1), np.ravel(map2))[0])
    return result

def pv_dot(vec1, vec2):
    """Per-bin, NaN-aware dot product of two population-vector stacks.

    ``vec1`` and ``vec2`` are 3-D arrays indexed ``[cell, x, y]``. For every
    spatial bin ``(x, y)``, visited x-major (x outer, y inner), the two
    population vectors ``vec[:, x, y]`` are restricted to the cells where
    neither value is NaN. Their dot product is then divided by the full
    length of the population vector.

    NOTE(review): the divisor is the total number of cells, not the number
    of non-NaN cells as an earlier comment claimed. Bins with some NaN cells
    are therefore scaled down, and bins where every cell is NaN give 0.0.
    Confirm which normalization the analysis intends before changing it.

    Returns
    -------
    list
        One value per bin, ``vec1.shape[1] * vec1.shape[2]`` entries.
    """
    per_bin = []
    for x, y in np.ndindex(vec1.shape[1], vec1.shape[2]):
        pv1 = vec1[:, x, y]
        pv2 = vec2[:, x, y]
        both_valid = ~(np.isnan(pv1) | np.isnan(pv2))
        per_bin.append(np.dot(pv1[both_valid], pv2[both_valid]) / both_valid.size)
    return per_bin

# In[ ]:
def int_rate_map_corr(session1, session2, data, thresholds):
    """Threshold-aware rate-map correlations of all cells between two sessions.

    Each cell is classified per session as above threshold (``True``), below
    threshold (``False``) or ``'silent'``. Only cells above threshold in at
    least one session contribute a value.

    Parameters
    ----------
    session1, session2 : hashable
        Session keys into ``data`` and ``thresholds``.
    data : mapping
        ``data[session][cell]`` is a square 2-D rate map (numpy array); it is
        flattened to ``shape[0] ** 2`` values.
    thresholds : mapping
        ``thresholds[session][str(cell)]`` is the cell's classification.

    Returns
    -------
    result : list
        One entry per contributing cell: the Pearson r of the flattened maps
        when the other session was above or below threshold, and ``0`` when
        the other session was silent.
    counters : dict
        Cell counts per combination -- 'above_above', 'above_below',
        'above_zero', 'zero_zero', 'below_zero', 'below_below' -- plus
        'total', the number of cells inspected. Cells with an unrecognised
        classification only count towards 'total'.
    """
    def classify(th):
        # `==` rather than `is` on purpose: numpy booleans (and 1/0) must match.
        if th == True:  # noqa: E712
            return 'above'
        if th == False:  # noqa: E712
            return 'below'
        if th == 'silent':
            return 'zero'
        return None

    # (tag in session1, tag in session2) -> (counter name, what goes in result)
    rules = {
        ('above', 'above'): ('above_above', 'pearson'),
        ('above', 'below'): ('above_below', 'pearson'),
        ('below', 'above'): ('above_below', 'pearson'),
        ('above', 'zero'): ('above_zero', 'zero'),
        ('zero', 'above'): ('above_zero', 'zero'),
        ('zero', 'zero'): ('zero_zero', None),
        ('zero', 'below'): ('below_zero', None),
        ('below', 'zero'): ('below_zero', None),
        ('below', 'below'): ('below_below', None),
    }

    result = []
    counters = dict.fromkeys(('above_above', 'above_below', 'above_zero',
                              'zero_zero', 'below_zero', 'below_below',
                              'total'), 0)

    for cell in range(len(data[session1])):
        counters['total'] += 1
        map_a = data[session1][cell]
        map_b = data[session2][cell]
        tag_a = classify(thresholds[session1][str(cell)])
        tag_b = classify(thresholds[session2][str(cell)])
        n_flat = map_a.shape[0] ** 2
        flat_a = np.reshape(map_a, n_flat)
        flat_b = np.reshape(map_b, n_flat)

        outcome = rules.get((tag_a, tag_b))
        if outcome is None:
            continue
        bucket, entry = outcome
        if entry == 'pearson':
            result.append(pearsonr(flat_a, flat_b)[0])
        elif entry == 'zero':
            result.append(0)
        counters[bucket] += 1

    return result, counters



def int_rate_map_corr_2(session1, session2, data, thresholds):
    """Threshold-aware rate-map correlations plus a per-value case label.

    Variant of ``int_rate_map_corr`` used for the histograms. Cells are
    classified per session as above threshold (``True``), below threshold
    (``False``) or ``'silent'``.

    Parameters
    ----------
    session1, session2 : hashable
        Session keys into ``data`` and ``thresholds``.
    data : mapping
        ``data[session][cell]`` is a square 2-D rate map (numpy array).
    thresholds : mapping
        ``thresholds[session][str(cell)]`` is the cell's classification.

    Returns
    -------
    result : list
        Pearson r of the flattened maps for above/above and above/below
        cells, ``0`` for above/silent cells. Other combinations are skipped.
    case : list of str
        ``case[i]`` labels ``result[i]``: 'above_above', 'above_below' or
        'above_zero'.
    excluded : int
        Number of cells that did not contribute to ``result``.
    """
    def classify(th):
        # `==` rather than `is` on purpose: numpy booleans (and 1/0) must match.
        if th == True:  # noqa: E712
            return 'above'
        if th == False:  # noqa: E712
            return 'below'
        if th == 'silent':
            return 'zero'
        return None

    # Only these combinations are kept; the rest (silent/silent,
    # below/below, below/silent) are ignored.
    rules = {
        ('above', 'above'): ('above_above', True),
        ('above', 'below'): ('above_below', True),
        ('below', 'above'): ('above_below', True),
        ('above', 'zero'): ('above_zero', False),
        ('zero', 'above'): ('above_zero', False),
    }

    result, case = [], []
    inspected = 0

    for cell in range(len(data[session1])):
        inspected += 1
        map_a = data[session1][cell]
        map_b = data[session2][cell]
        tag_a = classify(thresholds[session1][str(cell)])
        tag_b = classify(thresholds[session2][str(cell)])
        n_flat = map_a.shape[0] ** 2
        flat_a = np.reshape(map_a, n_flat)
        flat_b = np.reshape(map_b, n_flat)

        outcome = rules.get((tag_a, tag_b))
        if outcome is None:
            continue
        label, correlate = outcome
        value = pearsonr(flat_a, flat_b)[0] if correlate else 0
        result.append(value)
        case.append(label)

    return result, case, inspected - len(result)


    
def get_corr_coeff(comparisons, data):
    """Collect the matrix entry for every requested (session_1, session_2) pair.

    Parameters
    ----------
    comparisons : dict
        Maps a session name to the list of session names it is compared with.
    data : 28 x 28 matrix
        Correlation coefficient between two sessions, indexable as
        ``data[row][col]``. Rows and columns follow ``session_order`` below.

    Returns
    -------
    list
        Flat list of ``data[row(session_1)][col(session_2)]``, in the
        iteration order of ``comparisons`` and of each partner list.

    Raises
    ------
    KeyError
        If a session name is not one of the 28 known sessions.
    """
    session_order = ['F1', 'N1', 'N1!', 'N2', 'N2!', 'N3', 'N3!', 'N4', 'N4!',
                     'N5', 'N5!', 'N1*', 'N1*!', 'F1*', 'F2', 'N6', 'N6!',
                     'N7', 'N7!', 'N8', 'N8!', 'N9', 'N9!', 'N10', 'N10!',
                     'N6*', 'N6*!', 'F2*']
    index = {name: i for i, name in enumerate(session_order)}

    return [data[index[first]][index[second]]
            for first, partners in comparisons.items()
            for second in partners]
                
                
       

# Averaging between animals 

# In[1]:
def corr_average(mat_1, mat_2, mat_3, mat_4, mat_5):
    """Element-wise, NaN-ignoring mean of five equally sized 2-D matrices.

    Used to average the per-animal correlation matrices of Part 6 across the
    five animals in Part 7. Rows and columns are the sessions that were
    correlated.

    NaN entries are ignored. An element that is NaN in all five matrices
    stays NaN, and numpy emits a "Mean of empty slice" RuntimeWarning.

    Returns
    -------
    numpy.ndarray of float, shape ``mat_1.shape[:2]``
        The averaged matrix.
    """
    rows, cols = mat_1.shape[0], mat_1.shape[1]
    stacked = np.empty((5, rows, cols))
    # Slice assignment (rather than np.stack) keeps the original float
    # casting and broadcasting of the inputs.
    for depth, mat in enumerate((mat_1, mat_2, mat_3, mat_4, mat_5)):
        stacked[depth] = mat
    # One vectorized reduction replaces the per-element Python double loop.
    return np.nanmean(stacked, axis=0)

# In[ ]:
def cohen_d(x,y):
    nx = len(x)
    ny = len(y)
    dof = nx + ny - 2
    s =  np.sqrt(((nx-1)*np.std(x, ddof=1) ** 2 + (ny-1)*np.std(y, ddof=1) ** 2) / dof)
    return (np.mean(x) - np.mean(y)) / s



def get_value(animal, sessions, data):
    """Return the mean of ``data[s]`` over every session ``s`` in ``sessions``.

    ``animal`` is accepted for call-site compatibility but is not used:
    ``data`` is indexed directly by session name. Callers presumably pass
    the per-animal sub-dictionary (e.g. ``data_all[animal]``) -- verify.
    An empty ``sessions`` yields NaN (numpy mean of an empty list).
    """
    return np.mean([data[session] for session in sessions])


def get_diff(comparisons, data):
    """Collect behavioural-metric differences for every requested session pair.

    Parameters
    ----------
    comparisons : dict
        Maps a session name to the list of session names it is compared with.
    data : 28 x 28 matrix
        Difference in a behavioural metric between two sessions, indexable as
        ``data[row][col]``. Rows and columns follow ``order`` below.

    Returns
    -------
    list
        Flat list of differences, in the iteration order of ``comparisons``
        and of each partner list.

    Raises
    ------
    KeyError
        If a session name is not one of the 28 known sessions.
    """
    order = ('F1',
             'N1', 'N1!', 'N2', 'N2!', 'N3', 'N3!', 'N4', 'N4!', 'N5', 'N5!',
             'N1*', 'N1*!',
             'F1*', 'F2',
             'N6', 'N6!', 'N7', 'N7!', 'N8', 'N8!', 'N9', 'N9!', 'N10', 'N10!',
             'N6*', 'N6*!',
             'F2*')
    position = dict(zip(order, range(len(order))))

    diff = []
    for session_1, partners in comparisons.items():
        diff.extend(data[position[session_1]][position[session_2]]
                    for session_2 in partners)
    return diff

# In[ ]:
def threshold_counter(session1, session2, animal, thresholds):
    """Classify the cells of one animal by their threshold status in two sessions.

    Parameters
    ----------
    session1, session2 : hashable
        Session keys.
    animal : hashable
        Animal key.
    thresholds : mapping
        ``thresholds[animal][session][str(cell)]`` is the cell's
        classification: ``True`` (above threshold), ``False`` (below) or
        ``'silent'``. Cells are numbered ``0 .. len(thresholds[animal][session1]) - 1``.

    Returns
    -------
    case : list of str
        One tag per cell that was above threshold in at least one session,
        in cell order: 'above_above', 'above_below' or 'above_zero'.
    fraction_above : float
        Percentage of cells above threshold in both sessions.
    fraction_above_below : float
        Percentage of cells above threshold in exactly one session, whether
        the other session was below threshold *or silent*.
    fraction_below_below : float
        100 minus the two percentages above (cells with no tag).

    Raises
    ------
    ZeroDivisionError
        If ``session1`` has no cells for this animal.
    """
    def classify(th):
        # `==` rather than `is` on purpose: numpy booleans (and 1/0) must match.
        if th == True:  # noqa: E712
            return 'above'
        if th == False:  # noqa: E712
            return 'below'
        if th == 'silent':
            return 'zero'
        return None

    # Silent/silent, below/below and below/silent cells get no tag.
    tag_for = {
        ('above', 'above'): 'above_above',
        ('above', 'below'): 'above_below',
        ('below', 'above'): 'above_below',
        ('above', 'zero'): 'above_zero',
        ('zero', 'above'): 'above_zero',
    }

    per_session = thresholds[animal]
    num_cell = len(per_session[session1])

    case = []
    for cell in range(num_cell):
        key = str(cell)
        tag = tag_for.get((classify(per_session[session1][key]),
                           classify(per_session[session2][key])))
        if tag is not None:
            case.append(tag)

    both_above = case.count('above_above')
    one_above = len(case) - both_above  # 'above_below' + 'above_zero'

    fraction_above = both_above / num_cell * 100
    fraction_above_below = one_above / num_cell * 100
    fraction_below_below = 100 - fraction_above_below - fraction_above

    return case, fraction_above, fraction_above_below, fraction_below_below