## This notebook generates a set of plots showing different aspects of task performance, the effect of optogenetic stimulation and the relationship with electrophysiology

### This notebook is at the core of the data-processing pipeline. Do not modify it lightly inside the master folder (load_preprocess_mouse)

#### 1. Only modify it if you are sure of what you are doing and that you are fixing a bug
#### 2. If you do modify it, you MUST commit this modification using Bitbucket
#### 3. If you want to play with this notebook (to understand it better), copy it to a toy folder distinct from the master folder
#### 4. If you want to modify this code (fix a bug, improve it, add attributes ...), it is recommended to first duplicate it in a draft folder. Try to keep track of your changes.
#### 5. When you are ready to commit: clear all outputs and clean everything between the hashtag lines


# 0. Import packages and define a few basic functions

In [None]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import os
import platform
import glob
import pickle
import itertools


import numpy as np
import pandas as pd
from scipy.ndimage.filters import gaussian_filter as smooth
from scipy.signal import argrelextrema
from scipy import stats

import warnings

import matplotlib.pyplot as plt
from matplotlib import collections  as mc
from matplotlib.patches import Rectangle
import matplotlib.gridspec as gridspec

%matplotlib inline

## The lines below allow to run the required notebooks from the master folder.
# Executed only when this notebook is run directly (not through "%run this_notebook").
if "__file__" not in dir():
    
    # "__file__" is a literal string here, so this resolves to the current working directory
    ThisNoteBookPath=os.path.dirname(os.path.realpath("__file__"))
    # master folder = "load_preprocess_mouse", a sibling of the current working directory
    CommunNoteBookesPath=os.path.join(os.path.split(ThisNoteBookPath)[0],"load_preprocess_mouse")
    # NB: the working directory is changed to the master folder for the rest of the kernel session
    os.chdir(CommunNoteBookesPath)   
    
    # load the shared notebooks (they presumably define Data, isNone, ... used below — TODO confirm)
    %run loadMouse_documentation.ipynb
    %run loadRawSpike_documentation.ipynb
    %run plotMouse_documentation.ipynb
    
def has_tag(root, animal, session, tagList):
    """Return True when the session folder contains at least one of the names in tagList.

    A tag is an empty file with a specific name placed in the session folder
    root/animal/Experiments/session.
    """
    sessionFolder = os.path.join(root, animal, "Experiments", session)
    presentFiles = os.listdir(sessionFolder)
    return any(tag in presentFiles for tag in tagList)

def cm2inch(value):
    """Convert a length in centimetres to inches (1 inch = 2.54 cm)."""
    inches = value / 2.54
    return inches

if "__file__" not in dir():

    # root folder of the data: "/data" on Linux machines, a local folder otherwise
    root = "/data" if platform.system() == 'Linux' else "/Users/davidrobbe/Documents/Data/"

    print("The path to data is %s" % root)
    
    
def contiguous_regions(condition):
    """Find the contiguous True regions of a 1-D boolean array.

    Parameters
    ----------
    condition : 1-D array-like, interpreted as booleans (lists and 0/1 arrays are accepted)

    Returns
    -------
    2D integer array of shape (n_regions, 2): column 0 is the start index of each region,
    column 1 the end index (exclusive, so condition[start:end] is all True).
    An empty (0, 2) array is returned when there is no True value or condition is empty.
    """
    # force booleans: with e.g. [1, 2] a raw np.diff would see a "change" between two True values
    condition = np.asarray(condition, dtype=bool)
    if condition.size == 0:
        return np.empty((0, 2), dtype=np.intp)

    # indexes where the value changes (for booleans np.diff is an element-wise "!=")
    idx, = np.diff(condition).nonzero()
    # a change between i and i+1 means a region starts (or ends) at i+1
    idx += 1

    if condition[0]:
        # region already open at the start of the array
        idx = np.r_[0, idx]
    if condition[-1]:
        # region still open at the end of the array
        idx = np.r_[idx, condition.size]

    # consecutive pairs are the (start, end) of each region
    return idx.reshape(-1, 2)
    
# you can manually specify the root data folder between the hashtag lines below
##############################

##############################

# 1. Load and preprocess one session

In [None]:
if "__file__" not in dir():
    # Load and preprocess one example session (only when this notebook is run directly, not via %run)
    
    """
    Below is an example of data path information. DO NOT CHANGE THIS LINE BELOW
    BUT you can put your data path information between the 2 hashtag lines after the example
    (for git tracking issue)
    """
    
    SESSION="MOU073_2015_09_07_10_06/"
    
    
##############################
    
##############################


    """
    If you want to commit changes made in this NB please delete leave 1 empty line between the 2
    hashtag lines instead of your own data path
    """
    # animal ID = first 6 characters of the session folder name (e.g. "MOU073")
    ANIMAL=SESSION[0:6]
    
    #Those parameters are overwritten if there is a .behav_param file
    paramCarola={
        "distanceToRun":100,
        "maxTrialDuration": 60,
        "valveONTime":50,
        "minInterTrialDuration":15,
        "immobilityDuration":2,
        #to read .eeg (put None to not read .eeg)
        "nChannelElectro":32, #32
        "channel_opto": -6, #not used
        "channel_lickBreak":-5,
        "channel_reward": -4, #not used currently
        "channel_sound": -3, #not used
        "channel_trialON": -2,
        "channel_beamBreak": -1,
    }
    # build the session object (Data comes from the %run notebooks);
    # redoPreprocess=True presumably forces the preprocessing to be recomputed — TODO confirm
    data=Data(root,ANIMAL,SESSION,paramCarola,redoPreprocess=True)
    print("----------------")

    # summary only for sessions with behavior that are not lick-training sessions
    if data.hasBehavior and not data.isLickTraining:
        data.describe()

# 2. A series of behavioral plots

### Create "fake" licks after reward if licks are missing
#### obviously this is only for approximate analysis purposes

In [None]:
def CreateMissingLicks(data, LickEstimatedDuration=4, lickInterval=0.125):
    """Fabricate approximate lick-break times after each good trial when lick data are missing.

    A trial is "good" when its duration, rounded to 0.1 s, is shorter than its maxTrialDuration.
    For each good trial, fake licks are placed every `lickInterval` seconds during
    `LickEstimatedDuration` seconds, starting at the end of the trial (reward).
    This is only meant for approximate analyses.

    Parameters
    ----------
    data : session object exposing trials, durationTrial, maxTrialDuration and realStartTrial
    LickEstimatedDuration : duration (s) of the fake licking bout after each good trial
    lickInterval : spacing (s) between two fake licks (default 0.125 s, i.e. 8 Hz)

    Returns
    -------
    alllickstimes : list of absolute fake lick times (same clock as data.realStartTrial)
    licktimesintrialintetrital : dict with key `trial` -> [] (no fake licks during trials) and
        key `trial + 0.5` -> fake lick times relative to the end of the trial ([] for bad trials)
    """
    licktimesintrialintetrital = {}
    for trial in data.trials:
        licktimesintrialintetrital[trial] = []
        licktimesintrialintetrital[trial + .5] = []

    # Offsets are computed once from 0 so every bout has exactly the same number of licks:
    # np.arange with a large float start can gain one sample through rounding.
    offsets = np.arange(0, LickEstimatedDuration, lickInterval)

    alllickstimes = []
    for trial in data.trials:
        if np.round(data.durationTrial[trial], 1) < data.maxTrialDuration[trial]:  # good trial
            endofgoodtrialtime = data.realStartTrial[trial] + data.durationTrial[trial]
            # absolute fake lick times, for data.allLickBreak
            alllickstimes.extend(list(endofgoodtrialtime + offsets))
            # intertrial fake lick times, relative to the end of the trial
            licktimesintrialintetrital[trial + .5] = list(offsets)

    return alllickstimes, licktimesintrialintetrital

#---------------------------------------------------------------------------------------------------------
if "__file__" not in dir():  
    alllickstimes,licktimesintrialintetrital=CreateMissingLicks(data)
 

### Plot beam breaks for every trial

In [None]:
def plot_break(data, legend=False, colorOpto="orange", xmax=60, ax=None, lick=True):  
    """
    Plot the beam breaks, lick breaks and indicate optogenetic stimulation.
    One line: one trial + following intertrial. 0=end of trial (reward)
    Input
      - legend: whether to put a legend, on the top right outside the plot
      - colorOpto: color of the rectangular boxes indicating optogenetic stimulation
      - xmax: maximum on the x-axis
      - ax: matplotlib axis to draw on (default: current axis). Everything (lines, patches,
            ticks, limits, labels, title, legend) is drawn on this axis, so it is safe to use
            inside complex subplots
      - lick: whether to plot the lick breaks
    Output
      - the plot title (string)
    NB: if the session has no lick data, approximate licks are generated with
    CreateMissingLicks and stored in data.allLickBreak / data.lickBreakTime.
    """
    if ax is None:
        ax = plt.gca()
    minTime = - max(data.durationTrial)
    maxTime = min(xmax, max(data.durationInterTrial))
    distanceToRun = data.distanceToRun[0]
    darkLines = []
    greenLines = []
    #ticks where the distance to run has changed
    boldTicks = {1:distanceToRun}

    ## if there is no lick data we estimate them approximatively!
    if not data.allLickBreak:
        print("there was no lick so we create some")
        data.allLickBreak, data.lickBreakTime = CreateMissingLicks(data)

    for trial in data.trials:
        # each trial occupies the row [trial+1, trial+2]; tick marks cover part of that height
        yBottom = trial + 1.1   # bottom of both beam and lick ticks
        yLickTop = trial + 1.7  # lick ticks are shorter than beam ticks
        yBeamTop = trial + 1.9
        #look if distance to run changed
        newDistance = data.distanceToRun[trial]
        if newDistance != distanceToRun:
            boldTicks[trial+1] = newDistance
            distanceToRun = newDistance
        #we want to align the plot on the trial end (=reward =start of intertrial)
        zero = data.durationTrial[trial]
        #trial beam break times are relative to trial start
        for breakTime in data.beamBreakTime[trial]:
            x = breakTime - zero
            darkLines.append([(x, yBottom), (x, yBeamTop)])
        #intertrial beam break times are already aligned correctly
        for x in data.beamBreakTime[trial+0.5]:
            darkLines.append([(x, yBottom), (x, yBeamTop)])
        #if there are lick break times
        if (len(data.lickBreakTime) > 0) and lick:
            for lickTime in data.lickBreakTime[trial]:
                x = lickTime - zero
                greenLines.append([(x, yBottom), (x, yLickTop)])
            for x in data.lickBreakTime[trial+0.5]:
                greenLines.append([(x, yBottom), (x, yLickTop)])
        #color in grey the duration of trial and intertrial
        #rectangle= (x,y) lower left, width, height
        endInterTrial = data.durationInterTrial[trial]
        ax.add_patch(Rectangle((-zero, trial+1), zero, 1, facecolor="lightgrey", edgecolor="none"))
        ax.add_patch(Rectangle((0, trial+1), endInterTrial, 1, facecolor="lavender", edgecolor="none"))

    #Plot all the lines at once (much faster than one plot call per break)
    lc = mc.LineCollection(darkLines, colors="black", label="beam breaks")
    lc2 = mc.LineCollection(greenLines, colors="forestgreen", label="lick breaks")
    ax.add_collection(lc)
    ax.add_collection(lc2)
    #color the optogenetic stimulation
    #start and stop are relative to trial start
    if data.hasOptogenetic:
        # empty line = legend proxy for the stimulation boxes
        ax.plot([], [], linewidth=6, color=colorOpto, label="optogenetic stimulation")
        for trial in data.trials:
            zero = data.durationTrial[trial]
            start = data.startStimulation[trial]
            if start is not None:
                stop = data.stopStimulation[trial]
                ax.add_patch(Rectangle((start-zero, trial+1), (stop-start), 1,
                                       edgecolor=colorOpto, fill=False, lw=2, zorder=10))

    #blue line at 0 (= reward / intertrial start)
    ax.axvline(0, color="blue")
    #one y tick per trial, labelled with the distance to run where it changed
    rangeList = list(set(range(1, data.nTrial+2)).union(boldTicks))
    ticksPosition = [y+0.5 for y in rangeList]
    ticksLabel = [str(boldTicks[y]) + "cm| " + str(y) if y in boldTicks else str(y) for y in rangeList]
    ax.set_yticks(ticksPosition)
    ax.set_yticklabels(ticksLabel)
    #axis limits
    ax.set_xlim([minTime, maxTime])
    ax.invert_yaxis()
    ax.set_ylim([data.realTrials[-1]+1, data.realTrials[0]])
    #axis labels and title
    ax.set_xlabel("time in seconds (0=intertrial start)", fontsize=14)
    ax.set_ylabel("trial number", fontsize=14)
    title = data.experiment + " (day " + str(data.daySinceStart) + ")\nBeam and lick break time"
    if data.hasOptogenetic:
        title += "\n %s" %data.stimulationNames
    ax.set_title(title, fontsize=14)
    #legend with no duplicate
    if legend:
        ax.plot([], [], linewidth=6, color='lightgrey', label="trial time range")
        ax.plot([], [], linewidth=6, color='lavender', label="intertrial time range")
        # legend's upper-left corner pinned to the axis' top-right corner = outside, top right
        ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
    return title
            
#---------------------------------------------------------------------------------------------------------
if "__file__" not in dir():    
    plt.figure(figsize=(10,30))
    plot_break(data, legend=True,lick=False)

### Plot mean breaks

In [None]:
def plot_mean_breaks(data, binSize=0.25, minTime=-20, maxTime=40, align="trial end", separate="good trial",
                     lick=False, s=1, displayOnly=0, number=0, ax=None):  
    '''
    Plots the mean running speed (or licking frequency), for two groups of trials,
      according to the argument 'separate'
    Input
      - binSize: histogram bin size, in seconds
      - s: sigma (in bins) of the gaussian smoothing applied to the means
      - lick: True to plot lick frequency instead of running speed
      - separate: 
          "good trial" (/bad trials)
          "optogenetic" (/no optogenetic)
          "none" (mean for all trials, only one curve is plotted)
          "trial number" (trial < number / trial >= number) --- NB: first trial is 0
      - align: where to align the trials (where to consider 0)
         either "trial end" or "trial start"
         overridden by "stimulation start" when separate="optogenetic" and the protocol name starts
         with "Stimulate after", "Stimulate in trial" or "Stimulate in intertrial"
         (trials without a putative stimulation time are then skipped)
      - minTime, maxTime: time range for xaxis, relative to 0 (see 'align')
      - displayOnly: 
          0 = nothing
          1 = plots only the separated group, in black (dark blue for licks)
              (only the good trials, or only the optogenetic)
          2 = plots only the other group (bad trials, or trials without optogenetic)
      - number: see argument separate="trial number"
      - ax: matplotlib axis to draw on (default: current axis)
    Output
      - "nothing to plot" when there are no breaks, None otherwise
    Raises ValueError if align is neither "trial end" nor "trial start" (and no stimulation
    alignment applies).
    NB: if the session has no lick data, approximate licks are created with CreateMissingLicks
    and stored on data.
    '''
    if ax is None:
        ax = plt.gca()
    #bin edges between minTime and maxTime
    timeBin = np.arange(minTime, maxTime+binSize-maxTime%binSize, binSize)
    centers = (timeBin[:-1]+timeBin[1:]) / 2.0

    ## if there is no lick data we estimate them approximatively!
    if not data.allLickBreak:
        print("there was no lick so we create some")
        data.allLickBreak,data.lickBreakTime=CreateMissingLicks(data)

    #lick or the beam breaks
    if lick:
        allBreak = np.asarray(data.allLickBreak)
    else:
        allBreak = np.asarray(data.allBeamBreak)
    #check we have data
    if len(allBreak) == 0:
        ax.set_title("Nothing to plot")
        return "nothing to plot"
    #colors depending on case
    if separate == "good trial":
        color = "green"
    elif separate == "optogenetic":
        color = "purple"
    elif separate == "trial number":
        color = "blue"
    elif separate != "none":
        print("Unvalid value for separate. Choose between 'good trial', 'optogenetic', 'trial number' and 'none'")
        return
    if displayOnly:
        color = "darkblue" if lick else "black"

    #Protocols for which trials are aligned on the (putative) stimulation time instead of 'align'.
    #The attribute gives, per trial, that time (None = no stimulation); it is added to the trial start.
    #NOTE(review): the intertrial protocol is also offset from the trial start (as before) — confirm.
    stimAttribute = None
    if data.hasOptogenetic and (separate == "optogenetic"):
        for prefix, attribute in (("Stimulate after", "PutativeStimTimeAtNTicks"),
                                  ("Stimulate in trial", "PutativeStimTimeInTrial"),
                                  ("Stimulate in intertrial", "PutativeStimTimeInInterTrial")):
            if data.stimulationNames.startswith(prefix):
                stimAttribute = attribute
                align = "stimulation start"
                break
    if stimAttribute is None and align not in ("trial end", "trial start"):
        raise ValueError("align must be 'trial end' or 'trial start', got %r" % (align,))

    #count breaks per bin, per trial
    separateBeamCount = [] #"good trials", "optogenetic trials" or trials >= number
    allBeamCount = []      #all the other trials
    for trial in data.trials:
        #where is 0
        if stimAttribute is not None:
            stimTime = getattr(data, stimAttribute)[trial]
            if stimTime is None:
                continue
            zero = data.realStartTrial[trial] + stimTime
        elif align == "trial end":
            zero = data.realStartTrial[trial] + data.durationTrial[trial]
        else:  # "trial start"
            zero = data.realStartTrial[trial]
        #align on zero and count breaks per bin
        hist, _ = np.histogram(allBreak - zero, timeBin)
        #add to group
        inSeparateGroup = ((separate == "good trial" and trial in data.goodTrials)
                           or (separate == "optogenetic" and data.hasOptogenetic
                               and data.stimulationOccured[trial] == True)
                           or (separate == "trial number" and trial >= number))
        if inSeparateGroup:
            if displayOnly != 2:
                separateBeamCount.append(hist)
        elif displayOnly != 1:
            allBeamCount.append(hist)

    #compute means, avoiding "All Nan Slice" warning
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        meanBeamCount = np.nanmean(np.asarray(allBeamCount), axis=0)
        meanSeparate = np.nanmean(np.asarray(separateBeamCount), axis=0)
    if lick:
        #licks per bin -> licks per second
        meanBin = meanBeamCount/float(binSize)
        meanBinSeparate = meanSeparate/float(binSize)
        ax.set_ylabel("mean licking frequency (Hz)", fontsize=14)
    else:
        #ticks per bin -> cm per second (tickDistance presumably in cm per tick)
        meanBin = meanBeamCount*data.tickDistance/float(binSize)
        meanBinSeparate = meanSeparate*data.tickDistance/float(binSize)
        ax.set_ylabel("mean running speed (cm/sec)", fontsize=14)
    #plot if not empty
    if len(allBeamCount) > 0:
        ax.plot(centers, smooth(meanBin, s), "k-")
    if len(separateBeamCount) > 0:
        ax.plot(centers, smooth(meanBinSeparate, s), "-", color=color, label=separate)
    ax.set_xlabel("time (sec), binSize=%ss, 0=%s"%(binSize, align), fontsize=14)
    ax.set_xlim(minTime, maxTime)
    #dashed line at 0
    if lick:
        title = "Lick"
        ax.axvline(0, color="orange", linestyle="--")
    else:
        title = "Running"
        ax.axvline(0, color="blue", linestyle="--")
    #title
    if separate != "none":
        title = title + " (%s=%s %s, black=%s other)"%(color, len(separateBeamCount), separate, len(allBeamCount))
    if (data.hasOptogenetic) and (separate=="optogenetic"):
        title += "\n " + data.stimulationNames
    ax.set_title(title, fontsize=14)
    
    
    

#---------------------------------------------------------------------------------------------------------
if "__file__" not in dir():
    plt.figure(figsize=(20,10))
    plt.subplot(221)
    plot_mean_breaks(data, align="trial end", separate="good trial", minTime=-30, maxTime=30)

    plt.subplot(222)
    plot_mean_breaks(data, align="trial start", minTime=-10, maxTime=60, separate="good trial")

    plt.subplot(223)
    plot_mean_breaks(data, lick=True, align="trial end", minTime=-10, maxTime=20, displayOnly=1)
    
    plt.subplot(224)
    plot_mean_breaks(data, align="trial end", minTime=-20, maxTime=40, separate="optogenetic")
    
    plt.tight_layout()



### General behavior plot

In [None]:
def _shade_stimulation(start, width):
    """Shade the stimulation window [start, start+width] on the current axis; skipped when undefined (NaN)."""
    if np.isfinite(start) and np.isfinite(width):
        plt.axvspan(start, start + width, color="lightgray", alpha=0.5)


def _plot_opto_pair(data, shadeStart, shadeWidth, runKwargs, lickKwargs):
    """Plot mean running on the current axis, then mean licking on subplot 528 (optogenetic vs other
    trials), shading the stimulation window on both panels."""
    plot_mean_breaks(data, separate="optogenetic", **runKwargs)
    _shade_stimulation(shadeStart, shadeWidth)
    plt.subplot(528)
    plot_mean_breaks(data, lick=True, separate="optogenetic", **lickKwargs)
    _shade_stimulation(shadeStart, shadeWidth)


def behavior_plot(data):
    """
    A general plot to display the behavior during one session:
      - left half: beam/lick breaks of every trial (plot_break)
      - right column: mean running speed aligned on trial start and trial end (good vs bad trials),
        the optogenetic effect on running and licking when the session has stimulation,
        and the mean licking frequency when lick data exist
    """
    plt.figure(figsize=(20,25))
    plt.suptitle(data.experiment + " (day %s)" %data.daySinceStart, fontsize=20)    
    plt.subplot(121)
    plot_break(data, legend=False, xmax=40)

    plt.subplot(524)
    plot_mean_breaks(data, align="trial end", separate="good trial", minTime=-30, maxTime=30)

    plt.subplot(522)
    plot_mean_breaks(data, align="trial start", minTime=-10, maxTime=60, separate="good trial")
       
    if data.hasOptogenetic:
        # mean stimulation duration, in seconds (opticalDuration presumably in ms, hence /1000)
        # NaN when there is no valid duration (avoids np.nanmean's "Mean of empty slice" warning)
        durations = [d/1000.0 for d in data.opticalDuration if not isNone(d) and d>0]
        stop = np.nanmean(durations) if durations else np.nan
        names = data.stimulationNames
        plt.subplot(526)
        if "," in names:
            plt.title("Can't plot: two or more stimulation types")
            
            plt.subplot(528)
            plot_mean_breaks(data, lick=True, align="trial end", minTime=-10, maxTime=10)
            
        elif names == "Stimulate at beginning of trial":
            _plot_opto_pair(data, 0, stop,
                            dict(align="trial start", minTime=-10, maxTime=20),
                            dict(align="trial start", minTime=-10, maxTime=10))
            
        elif names.startswith("Stimulate before end of inter-trial"):
            # mean stimulation onset, relative to the trial end
            delays = [d for d in data.stimulateBeforeEnd_time if not isNone(d) and d>0]
            t = np.nanmean(delays) if delays else np.nan
            _plot_opto_pair(data, t, stop,
                            dict(align="trial end", minTime=-10, maxTime=20),
                            dict(align="trial end", minTime=-10, maxTime=10))
            
        elif names == "Stimulate  upon reward":
            # NOTE(review): two spaces in this name — presumably matches the string stored by Data; verify
            _plot_opto_pair(data, 0, stop,
                            dict(align="trial end", minTime=-10, maxTime=10),
                            dict(align="trial end", minTime=-10, maxTime=10))
            
        elif names.startswith("Stimulate after"):
            # running uses the default time range of plot_mean_breaks
            _plot_opto_pair(data, 0, stop, {}, dict(minTime=-10, maxTime=30))

        elif names.startswith("Stimulate in trial") or names.startswith("Stimulate in intertrial"):
            _plot_opto_pair(data, 0, stop, dict(minTime=-10, maxTime=20), dict(minTime=-10, maxTime=30))

        else:
            plt.title("No stimulation to plot", fontsize=14)
            
    if data.lickBreakTime:
        plt.subplot(5,2,10)
        plot_mean_breaks(data, lick=True, align="trial end", minTime=-10, maxTime=10)
        
    plt.subplots_adjust(top=0.93, hspace=0.3)
  

if "__file__" not in dir():
    behavior_plot(data)

#### the part below is an addition by wahiba/Loubna to save some plots. Check with them
#run only if inside this notebook (does not execute if "%run this_notebook")
# if "__file__" not in dir():  
#     if data.hasBehavior:
#         behavior_plot(data)
#         name = "behavior_plot" + ".png"
#         path = os.path.join(data.sessionPath, name)
#         plt.savefig(path)
        
#     else:
#         print("no behavior")
        
        

### Detect running periods and immobility periods

In [None]:
def detect_running_period(data, minDurationSecond=2, maxDurationSecond=None, runType="all", maxTimeBetweenBreak=None):
    """
    Detect running periods: groups of beam breaks where consecutive breaks are less than
    maxTimeBetweenBreak apart (None = use data.maxTimeBetweenBreak), lasting at least
    minDurationSecond and at most maxDurationSecond (None = no maximum).

    runType can be "all", "intertrial", "trial good run", "trial bad run", or "unrewarded"
    (bad runs and intertrial runs).
    Returns (starts, ends, indexes), see detect_activity_period.
    """
    return detect_activity_period(
        data,
        minDurationSecond=minDurationSecond,
        maxDurationSecond=maxDurationSecond,
        periodType=runType,
        lick=False,
        maxTimeBetweenBreak=maxTimeBetweenBreak,
    )

def detect_licking_period(data, minDurationSecond=0, maxDurationSecond=None, lickType="all", maxTimeBetweenBreak=None):
    """
    Detect licking periods: same as detect_running_period, but on lick breaks.

    lickType can be "all", "rewarded" or "unrewarded".
    Returns (starts, ends, indexes), see detect_activity_period.
    """
    return detect_activity_period(
        data,
        minDurationSecond=minDurationSecond,
        maxDurationSecond=maxDurationSecond,
        periodType=lickType,
        lick=True,
        maxTimeBetweenBreak=maxTimeBetweenBreak,
    )
    
def _group_consecutive_breaks(allBreaks, maxTimeBetweenBreak):
    """Split break times into groups whose consecutive breaks are less than maxTimeBetweenBreak apart.

    Assumes allBreaks is ordered in time. Returns a list of (startIndex, stopIndex) pairs,
    both INCLUSIVE indexes into allBreaks ([] when allBreaks is empty).
    """
    bounds = []
    startIndex = 0
    for index in range(1, len(allBreaks)):
        if (allBreaks[index] - allBreaks[index-1]) >= maxTimeBetweenBreak:
            bounds.append((startIndex, index-1))
            startIndex = index
    if len(allBreaks) > 0:
        # the last group always ends on the last break
        bounds.append((startIndex, len(allBreaks)-1))
    return bounds


def detect_activity_period(data, minDurationSecond=0, maxDurationSecond=None, periodType="all", lick=False,
                           maxTimeBetweenBreak = None, allActivity=False):
    """ detect periods of consecutive breaks
    
    - minDurationSecond: minimum duration of the period, in second
    - maxDurationSecond: maximum duration of the period, in second (None = no maximum)
    - periodType: period types to keep (see the "get_running/licking_period_type" functions)
       for running: "all", "intertrial", "trial good run", "trial bad run", "unrewarded"(=intertrial + trial bad run)
       for licking: "all", "rewarded", "unrewarded"
       NB - periodType can be a string, or a list of strings (the caller's list is never modified)
    - lick: whether to use beambreaks (running, lick=False) or lickbreaks (lick=True)
    - maxTimeBetweenBreak: maximum time allowed between two consecutive breaks for them to be in
        the same period (None = data.maxTimeBetweenBreak)
    - allActivity: whether to use beambreaks and lickbreaks together (as if it was one set of breaks)
        implies periodType="all"

    Returns (starts, ends, indexes): start/end time of each kept period and the (first, last)
    inclusive indexes of its breaks in the break array used. Empty lists when there is no break.
    NB: if the session has no lick data, approximate licks are created with CreateMissingLicks
    and stored on data.
    """
    if maxTimeBetweenBreak is None:
        maxTimeBetweenBreak = data.maxTimeBetweenBreak

    ## if there is no lick data we estimate them approximatively!
    if not data.allLickBreak:
        print("there was no lick so we create some")
        data.allLickBreak,data.lickBreakTime=CreateMissingLicks(data)

    #choose the breaks to group
    if allActivity:
        allBreaks = np.sort(np.append(data.allLickBreak, data.allBeamBreak))
        periodType = "all"
    elif lick:
        if not(data.allLickBreak):
            print("No lick Data")
            return [], [], []
        allBreaks = data.allLickBreak
    else:
        allBreaks = data.allBeamBreak

    #period type has to be a list of strings; copy so the caller's list is not mutated below
    if isinstance(periodType, list):
        periodType = list(periodType)
    else:
        periodType = [periodType]
    if "unrewarded" in periodType:
        periodType.append("trial bad run")
        periodType.append("intertrial")

    #get start and end of each group, keeping only groups with an acceptable duration
    startRunning = []
    endRunning = []
    indexes = []
    for start, stop in _group_consecutive_breaks(allBreaks, maxTimeBetweenBreak):
        first, last = allBreaks[start], allBreaks[stop]
        duration = last - first
        if duration < minDurationSecond:
            continue
        if (maxDurationSecond is not None) and (duration > maxDurationSecond):
            continue
        indexes.append((start, stop))
        startRunning.append(first)
        endRunning.append(last)

    if "all" in periodType:
        return startRunning, endRunning, indexes

    #select periods according to trial/intertrial, good/bad (running) or rewarded/unrewarded (licking)
    classify = get_licking_period_type if lick else get_running_period_type
    newStart = []
    newEnd = []
    newIndexes = []
    for start, end, ind in zip(startRunning, endRunning, indexes):
        if classify(data, start, end) in periodType:
            newStart.append(start)
            newEnd.append(end)
            newIndexes.append(ind)
    return newStart, newEnd, newIndexes
#---------------------------------------------------------------------------------------------------------------
def get_running_period_type(data, start, stop):
    '''
    Classify a running period [start, stop] (same clock as data.realStartTrial).
    Returns one of:
     - "intertrial": the run starts at/after the intertrial start (reward), or a run of a
       non-good trial that extends into its intertrial
     - "trial good run": run of a good trial that reaches the reward (1 s of tolerance
       for imprecise timing)
     - "trial bad run": any other run that starts during a trial

    Warning: during electrophy recordings using Carola's original code there is a bug that allows
    good runs (that will be rewarded) to start at the end of the previous trial (only a few cases):
    a good run can therefore start in the intertrial of the previous trial.
    This is handled by using the END of the detected run (instead of its start, as originally in
    Typhaine's code) to find the closest preceding trial start.
    '''
    # last trial that started at or before the end of the run (trial 0 if none)
    candidates = np.where(data.realStartTrial <= stop)[0]
    trialIndex = candidates[-1] if len(candidates) > 0 else 0
    # corresponding intertrial start (= end of trial = reward)
    interTime = data.realStartInterTrial[trialIndex]
    if start >= interTime:
        return "intertrial"
    if trialIndex in data.goodTrials:
        # good run: end of run is on reward, or later (allows 1 second before reward)
        if (stop + 1) >= interTime:
            return "trial good run"
    elif stop >= interTime:
        return "intertrial"
    return "trial bad run"
#---------------------------------------------------------------------------------------------------------------
def get_licking_period_type(data, start, stop):
    '''
    Classify a licking period [start, stop] as "rewarded" or "unrewarded".

    The period is "rewarded" when it starts no later than 0.5 s after the intertrial start
    (reward) of the last trial started at/before `start` (trial 0 if none), and lasts at least
    until 1 s after that reward. Otherwise it is "unrewarded".
    '''
    precedingTrials = np.flatnonzero(data.realStartTrial <= start)
    trialIdx = precedingTrials[-1] if precedingTrials.size else 0
    rewardTime = data.realStartInterTrial[trialIdx]
    startsAroundReward = start <= rewardTime + 0.5
    lastsAfterReward = stop >= rewardTime + 1
    return "rewarded" if (startsAroundReward and lastsAfterReward) else "unrewarded"
#---------------------------------------------------------------------------------------------------------------    
def detect_immobility_period(data, minDurationSecond=2, maxDurationSecond=None, immobilityType="all",
                             runMinDuration=0.1, allActivity=False, lick=False):
    '''
    Detect the periods where the animal is immobile, i.e. the gaps between consecutive activity periods.

    Activity periods come from the first matching option, always with minDurationSecond=runMinDuration
    and no maximum duration:
      - allActivity=True: detect_activity_period(..., periodType="all", allActivity=True)
      - lick=True:        detect_licking_period(..., lickType="all")
      - otherwise:        detect_running_period(..., runType="all")
    If runMinDuration=0: every tick is in a run, so an immobility period contains zero ticks.
    If runMinDuration=0.1: a lone tick is not considered a run and ends up inside an immobility period.

    A gap is kept when:
      - its duration is strictly greater than minDurationSecond
      - its duration is strictly smaller than maxDurationSecond (None = no maximum)
      - get_immobility_period_type(data, gapStart) equals immobilityType ("all" keeps every gap)

    Returns three lists: gap start times, gap end times, and for each gap the tuple
    (indexes[k][1], indexes[k+1][0]) taken from the detector's index pairs
    (presumably the end index of the previous activity and the start index of the next one).
    '''
    if allActivity:
        activityStarts, activityEnds, activityIndexes = detect_activity_period(data, minDurationSecond=runMinDuration,
                                                                               maxDurationSecond=None, periodType="all",
                                                                               allActivity=True)
    elif lick:
        activityStarts, activityEnds, activityIndexes = detect_licking_period(data, minDurationSecond=runMinDuration,
                                                                              maxDurationSecond=None, lickType="all")
    else:
        activityStarts, activityEnds, activityIndexes = detect_running_period(data, minDurationSecond=runMinDuration,
                                                                              maxDurationSecond=None, runType="all")

    # a gap goes from the end of one activity period to the start of the following one
    gapStarts = activityEnds[:-1]
    gapEnds = activityStarts[1:]
    gapIndexes = [(previous[1], following[0])
                  for previous, following in zip(activityIndexes[:-1], activityIndexes[1:])]

    def is_selected(gapStart, gapEnd):
        duration = gapEnd - gapStart
        if not (duration > minDurationSecond):
            return False
        if not ((maxDurationSecond is None) or (duration < maxDurationSecond)):
            return False
        return (immobilityType == "all") or (get_immobility_period_type(data, gapStart) == immobilityType)

    kept = [(gapStart, gapEnd, gapIndex)
            for gapStart, gapEnd, gapIndex in zip(gapStarts, gapEnds, gapIndexes)
            if is_selected(gapStart, gapEnd)]

    newStart = [gap[0] for gap in kept]
    newEnd = [gap[1] for gap in kept]
    newIndexes = [gap[2] for gap in kept]
    return newStart, newEnd, newIndexes

#---------------------------------------------------------------------------------------------------------------   
def get_immobility_period_type(data,start):
    '''
    Tell whether an immobility period starts during the trial or during the intertrial.

    The trial is the last one whose real start is at or before `start` (trial 0 when the
    period begins before the first trial). A period starting less than 0.5 s before the
    intertrial start already counts as "intertrial".
    Returns "intertrial" or "trial".
    '''
    startedTrials = np.flatnonzero(data.realStartTrial <= start)
    trialIndex = startedTrials[-1] if startedTrials.size > 0 else 0
    interTime = data.realStartInterTrial[trialIndex]
    return "intertrial" if start + 0.5 >= interTime else "trial"
    
#---------------------------------------------------------------------------------------------------------------
def check_period_detection(data, start, end):
    """
    Visual check of a period detection: draw every detected period on top of the
    trial-by-trial break plot.

    Input:
      - data: session object (uses realStartTrial, durationTrial, durationInterTrial,
        nTrial and allBeamBreak)
      - start, end: matching sequences of period start / end times, on the same clock
        as data.realStartTrial (e.g. the first two outputs of detect_running_period)

    A new 20x15 figure is created and plot_break(data, legend=False, xmax=200) is drawn
    first. Each period is then added as a rectangle of height 0.8 on row trial+1, with
    x relative to the trial "zero" (trial start + trial duration):
      - orange: the period ends before the end of its trial (trial + intertrial)
      - lightcoral: the period runs past the end of its trial; it is split into a part
        on the current row and a part at the beginning of the next row
    Periods starting at or before the first trial start are not drawn.

    Historical note: in the original code (Typhaine's) the last trial has no intertrial.
    This caused a good trial to be wrongly detected as overlapping intertrial and trial,
    and the program to stop early. It is corrected by using the time of the last wheel
    (beam) break + 1 as the end of the last trial when its intertrial duration is 0.
    """
    plt.figure(figsize=(20,15))
    plot_break(data, legend=False, xmax=200)
    ax = plt.gca()
    
    # boundaries of the trial currently being filled: [trialStart, trialEnd[, zero = trial duration
    trialStart = data.realStartTrial[0]
    trial = 0
    zero = data.durationTrial[0]
    trialEnd = trialStart + zero + data.durationInterTrial[trial]    

    # loop only counts the periods (it was used by debug prints)
    loop=0
    for s, e in zip(start, end):
        loop+=1
        if s <= trialStart:
            # period starts before the first trial: nothing to draw
             continue
        # advance to the trial that contains the start of this period
        while s >= trialEnd:
            trial += 1
            if trial+1 > data.nTrial:
                # NOTE(review): past the last trial the period is still drawn below on row
                # trial+1 with the previous trial's timing -- TODO confirm this is intended
                break
            
            zero = data.durationTrial[trial]
            trialStart = trialEnd
            if (trial+1==data.nTrial) and (data.durationInterTrial[trial]==0):
                # last trial without intertrial: end it just after the last beam break
                trialEnd = data.allBeamBreak[-1]+1
            else:
                trialEnd += zero + data.durationInterTrial[trial]
            
        if s <= trialEnd and e >= trialEnd:
            # period crosses the end of its trial: split it over consecutive rows
            # NOTE(review): s is not updated between iterations, so for a period spanning
            # more than two trials the later rectangles may be misplaced -- TODO confirm
            while e >= trialEnd:
                #period is in beween two trials
                startRec = s - trialStart - zero
                length = data.durationInterTrial[trial] - startRec
                ax.add_patch(Rectangle((startRec, trial+1), length, 0.8, facecolor="lightcoral", edgecolor="none"))              
                if trial+2 > data.nTrial:
                    # no next trial to draw the remainder on
                    break                
                # remainder drawn from the start of the next trial, capped at that trial's full length
                length2 = min(e-s-length, data.durationTrial[trial+1]+data.durationInterTrial[trial+1])                
                ax.add_patch(Rectangle((-data.durationTrial[trial+1], trial+2), length2, 0.8, facecolor="lightcoral", 
                                       edgecolor="none"))
                trial += 1
                zero = data.durationTrial[trial]
                trialStart = trialEnd
                if (trial+1==data.nTrial) and (data.durationInterTrial[trial]==0):
                    # last trial without intertrial: end it just after the last beam break
                    trialEnd = data.allBeamBreak[-1]+1
                else:
                    trialEnd += zero + data.durationInterTrial[trial]
                    
        else:
            # period fully inside its trial
            ax.add_patch(Rectangle((s - trialStart - zero, trial+1), e-s, 0.8, facecolor="orange", edgecolor="none"))
            
#---------------------------------------------------------------------------------------------------------------
if "__file__" not in dir():
    # Visual check of the run detection on the current session
    # (call detect_immobility_period or detect_activity_period instead to check those detections)
    start, end, indexes = detect_running_period(data, runType="all", minDurationSecond=3, maxTimeBetweenBreak=1)
    check_period_detection(data, start, end)
    


In [None]:
#---------------------------------------------------------------------------------------------------------------
if "__file__" not in dir():
    # Visual check of the intertrial immobility detection (gaps longer than 1 s between activity periods).
    # Note: lick=True is ignored here because allActivity=True takes precedence in detect_immobility_period.
    start, end, indexes = detect_immobility_period(data, minDurationSecond=1, immobilityType="intertrial",
                                                   allActivity=True, lick=True)
    check_period_detection(data, start, end)

In [None]:
#---------------------------------------------------------------------------------------------------------------
if "__file__" not in dir():
    # Visual check of the licking-period detection, drawn only when the session has lick breaks
    start, end, indexes = detect_licking_period(data, lickType="all", minDurationSecond=2)
    if data.allLickBreak:
        check_period_detection(data, start, end)

# 3. A series of code cells to visualize firing activity

## 3.1 List the units in the present session or across sessions

### List all the Good cluster numbers of this session

In [None]:
def PrintClustersByShank(data):
    """
    Print, shank by shank, the list of 'Good' clusters of the session
    (or a short message when the session has no spike data).
    """
    if data.hasSpike:
        for shankId in data.clusterGroup:
            print ("Shank Nber %s" % shankId)
            print ("Good Clusters: %s" % data.clusterGroup[shankId]['Good'])
        return
    print ("no spikes in session %s" % data.experiment)
    #-------------------------------------------------------------------------------------------------------------
if "__file__" not in dir():
    # list the Good clusters of the loaded session, shank by shank
    PrintClustersByShank(data=data)

### List how many good and MUA clusters there are in this session

In [None]:
def CountCluPerSession(data):
    """
    Count the 'Good' and 'MUA' clusters of one session.

    Returns (GoodCluPerSession, MUACluPerSession), each a 2-element list [label, total]:
      - label: "<total> (<n shank 1>+<n shank 2>+...)" listing only the shanks that have
        that group, or "0" when the total is 0
      - total: the number of clusters of that group over all shanks
    Also prints "session <experiment>".
    """
    def summarize(groupName):
        perShank = []
        for shankId in data.clusterGroup:
            groups = data.clusterGroup[shankId]
            if groupName in groups:
                perShank.append(len(groups[groupName]))
        total = sum(perShank)
        if total > 0:
            return [str(total) + " (" + "+".join(str(count) for count in perShank) + ")", total]
        return ["0", 0]

    GoodCluPerSession = summarize("Good")
    MUACluPerSession = summarize("MUA")

    print("session %s" % data.experiment)
    return GoodCluPerSession, MUACluPerSession

#-------------------------------------------------------------------------------------------------------------
if "__file__" not in dir():
    # number of Good / MUA clusters of the loaded session, with the per-shank detail
    GoodCluPerSession, MUACluPerSession = CountCluPerSession(data)
    print("{} Good clusters and {} MUA clusters".format(GoodCluPerSession[0], MUACluPerSession[0]))
            
            

### List how many good and MUA clusters there are across a group of analyzed sessions

In [None]:

def TableAllCluAccrosSesson(animalList=None, tagList=("GoodPerfo",)):
    """
    Count Good and MUA clusters for every selected session of several animals and build a table.

    Input:
      - animalList: animal folder names (e.g. ["MOU025", "MOU026"]); empty or None -> every
        folder matching root/MOU*
      - tagList: a session is kept only if it has at least one of these tags
        (tag = empty file with a specific name in the session folder, checked by has_tag).
        Empty or None -> no tag filtering, every session is kept.
    Returns:
      - CluPerSessions: rows [session, good label, MUA label]. Each animal with at least one
        Good cluster gets a "Total for <animal>" row and an empty separator row. The last row
        is the grand TOTAL.
      - AllGoodCluAccrossSessions, AllMUACluAccrossSessions: grand totals
      - tagList: the tag list actually used, as a list
    Relies on the notebook globals root, paramCarola and Data. Progress output is cleared
    before returning.
    """
    from IPython.display import clear_output

    tagList = list(tagList) if tagList else []
    if not animalList:
        animalList = [os.path.basename(path) for path in sorted(glob.glob(root+"/MOU*"))]

    AllGoodCluAccrossSessions = 0
    AllMUACluAccrossSessions = 0
    CluPerSessions = []

    for animal in animalList:
        print("Animal", animal)

        ThisAnimalGoodCluAccrossSessions = 0
        ThisAnimalMUACluAccrossSessions = 0

        sessionList = sorted(os.path.basename(expPath)
                             for expPath in glob.glob(root+"/"+animal+"/Experiments/MOU*"))

        for session in sessionList:
            # an empty tag list means "no tag filtering" (this used to skip every session)
            if tagList and not has_tag(root, animal, session, tagList):
                continue

            print(session)
            #load data for this session (add redoPreprocess=True to overwrite preprocess)
            data = Data(root, animal, session, paramCarola, redoPreprocess=False)

            if not data.hasSpike:
                print("########")
                print("%s has no spike" %data.experiment)
                print("########")
                continue

            GoodCluThisSession, MUACluThisSession = CountCluPerSession(data)
            CluPerSessions.append([data.experiment, GoodCluThisSession[0], MUACluThisSession[0]])

            AllGoodCluAccrossSessions += GoodCluThisSession[1]
            AllMUACluAccrossSessions += MUACluThisSession[1]

            ThisAnimalGoodCluAccrossSessions += GoodCluThisSession[1]
            ThisAnimalMUACluAccrossSessions += MUACluThisSession[1]

        if ThisAnimalGoodCluAccrossSessions > 0:
            CluPerSessions.append(["   Total for "+animal, ThisAnimalGoodCluAccrossSessions,
                                   ThisAnimalMUACluAccrossSessions])
            CluPerSessions.append(["", "", ""])

    # last line is the grand total
    CluPerSessions.append(["TOTAL", AllGoodCluAccrossSessions, AllMUACluAccrossSessions])
    clear_output()

    return CluPerSessions, AllGoodCluAccrossSessions, AllMUACluAccrossSessions, tagList

if "__file__" not in dir():
    animalList = []
    CluPerSessions, AllGoodCluAccrossSessions, AllMUACluAccrossSessions, tagList = TableAllCluAccrosSesson(animalList)

    from tabulate import tabulate
    headers = ["session", "Good clusters", "MUA clusters"]
    print(tabulate(CluPerSessions, headers, tablefmt="fancy_grid"))

    # Save a plain-grid copy of the table, named after the first tag
    # ("AllSessions" when no tag filter was used). The output folder is created if missing.
    tableDirectory = os.path.join(root, "ALLMOU_Analysis")
    os.makedirs(tableDirectory, exist_ok=True)
    tableName = (tagList[0] if tagList else "AllSessions") + "_Table.txt"
    PathForTableSaving = os.path.join(tableDirectory, tableName)
    # `with` guarantees the file is closed even if writing fails
    with open(PathForTableSaving, 'w') as f:
        f.write(tabulate(CluPerSessions, headers, tablefmt="grid"))


## 3.2 Plot wheel break with spikes for one cluster

In [None]:
def get_cluster_spikes(data, shank, cluster):
    """
    Look up the spike times of one cluster.

    Returns data.spikeTime[shank][cluster], or None (after printing why) when the
    session has no spike data or the shank/cluster lookup raises KeyError; in the
    latter case the session's clusterGroup is printed to help pick a valid cluster.
    """
    if data.hasSpike:
        try:
            return data.spikeTime[shank][cluster]
        except KeyError:
            print("No shank %s cluster %s" % (shank, cluster))
            print("List of clusters for this session:")
            print(data.clusterGroup)
            return None
    print("No spike data")
    return None

#-------------------------------------------------------------------------------------------------------------
def plot_break_cluster(data, shank, cluster, group="not defined", legend=False, colorOpto="yellow",
                       xmax=60, ax=None, lick=False):
    """
    Calls plot_break, and adds the spikes of one cluster on top of it.

    Input:
      - data, shank, cluster: session and cluster to plot (nothing is drawn if the
        cluster does not exist, see get_cluster_spikes)
      - group: cluster group ("good", "noise", ...), only used for the title of the plot
      - legend, colorOpto, xmax, lick: forwarded to plot_break
      - ax: axes to draw on (default: current axes). Legend and title now go to this
        axes too, instead of whichever axes happens to be current.
    Each spike of trial k is drawn as a short red vertical tick on row k+1 (y from
    k+1.1 to k+1.5). Its x is the spike time relative to the trial "zero"
    (trial start + trial duration), keeping only spikes strictly between the trial
    start and the end of its intertrial.
    """
    #get the spikes for the cluster, if it exists
    cluSpike = get_cluster_spikes(data, shank, cluster)
    if cluSpike is None:
        return
    cluSpike = np.asarray(cluSpike)
    #plot the beam breaks, lick breaks and optogenetic
    if ax is None:
        ax = plt.gca()
    title = plot_break(data, legend=legend, colorOpto=colorOpto, xmax=xmax, ax=ax, lick=lick)

    #add spikes for the given cluster
    lines = []
    for trial in data.trials:
        start = data.realStartTrial[trial]
        zero = start + data.durationTrial[trial]
        stop = zero + data.durationInterTrial[trial]
        trialSpikes = cluSpike[(cluSpike > start) & (cluSpike < stop)] - zero
        lines.extend([(spike, trial+1.1), (spike, trial+1.5)] for spike in trialSpikes)
    lc = mc.LineCollection(lines, colors="red", label="spikes of shank %s cluster %s"%(shank, cluster))
    ax.add_collection(lc)

    if legend:
        ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
    ax.set_title(title+" (shank %s cluster %s, group %s)"%(shank, cluster, group))

#-------------------------------------------------------------------------------------------------------------
if "__file__" not in dir():
    # Cluster shown here and reused by the following plotting cells.
    # It is defined outside the hasBehavior/hasSpike test so that later cells reusing
    # `shank` and `cluster` do not fail with a NameError on sessions without spikes.
    shank = 1
    cluster = 2

    # As for the data path: to look at other clusters, redefine shank / cluster between
    # the two hashtag lines below, and do not forget to leave that area empty when you commit.
    ##############################


    ##############################

    if data.hasBehavior and data.hasSpike:
        plt.figure(figsize=(15,20))
        plot_break_cluster(data, shank, cluster, group="Good", legend=True, lick=True)

## 3.3 Mean Speed, Lick and Firing Rate, aligned to trial end

In [None]:
def plot_mean_breaks_firing_rate(data, shank, cluster, binSize=0.25, minTime=-60, maxTime=20, align="trial end",
                                 sigma=1, trialType="all", lick=False, s=1, ax=None):
    '''
    Calls plot_mean_breaks (mean running/licking over trials) and adds the mean firing rate
    of one cluster on a second (red) y axis.

    Input
      - data, shank, cluster: session and cluster (nothing is drawn if the cluster does not exist)
      - binSize, minTime, maxTime: time bins (s) around the alignment point
      - align: "trial end" (zero = trial start + trial duration) or "trial start"
      - sigma: sigma of the gaussian smoothing of the firing rate (applied once)
      - trialType: consider only the "good" or "bad", or "all" trials
      - lick: True to replace running by lick frequency
      - s: smoothing passed to plot_mean_breaks for the running/lick curve
      - ax: axes to draw on (default: current axes)
    Returns the twin axis holding the firing rate, or None when the cluster does not exist
    or when no trial of the requested type is found.
    Raises ValueError for an unknown align value.
    '''
    #get the spikes for the cluster, if it exists
    cluSpike = get_cluster_spikes(data, shank, cluster)
    if cluSpike is None:
        return
    if align not in ("trial end", "trial start"):
        # any other value used to leave `zero` unbound inside the trial loop
        raise ValueError("align must be 'trial end' or 'trial start', got %r" % (align,))

    #bins between minTime and maxTime
    timeBin = np.arange(minTime, maxTime + binSize - maxTime%binSize, binSize)
    centers = (timeBin[:-1]+timeBin[1:]) / 2.0

    #plot speed (or lick rate), remove its title
    if ax is None:
        ax = plt.gca()
    # plot_mean_breaks is not given ax, so make the requested axes the current one
    plt.sca(ax)
    if trialType == "good":
        plot_mean_breaks(data, binSize, minTime, maxTime, align, separate="good trial", lick=lick,
                         displayOnly=1, s=s)
    elif trialType == "bad":
        plot_mean_breaks(data, binSize, minTime, maxTime, align, separate="good trial", lick=lick,
                         displayOnly=2, s=s)
    else:
        plot_mean_breaks(data, binSize, minTime, maxTime, align, separate="none", lick=lick, s=s)
    plt.title("")

    #histogram of the aligned spikes, one per selected trial
    allHist = []
    for trial in data.trials:
        if (trialType == "good") and (trial not in data.goodTrials):
            continue
        if (trialType == "bad") and (trial in data.goodTrials):
            continue
        zero = data.realStartTrial[trial]
        if align == "trial end":
            zero = zero + data.durationTrial[trial]
        hist, _ = np.histogram(cluSpike - zero, timeBin)
        allHist.append(hist)

    title = "Lick" if lick else "Running"
    title = title+", Shank %s Cluster %s, %s trials (%i)" %(shank, cluster, trialType, len(allHist))

    if not allHist:
        # nanmean of an empty list cannot be plotted: keep the running curve and stop here
        print("No %s trials: firing rate not plotted" % trialType)
        ax.set_title(title, fontsize=14)
        return None

    #mean firing rate, smoothed once with sigma (it used to be smoothed a second time with s)
    meanFiring = np.nanmean(np.asarray(allHist), axis=0) / float(binSize)
    meanFiring = smooth(meanFiring, sigma)

    #plot firing rate on a new y axis
    ax2 = ax.twinx()
    ax2.plot(centers, meanFiring, color="red")
    ax2.set_ylabel("firing rate (smooth sigma=%s)"%sigma, color="red", fontsize=14)
    ax2.set_ylim([0, max(meanFiring)])
    ax2.yaxis.set_major_locator(plt.MaxNLocator(integer=True))
    ax2.tick_params(axis="y", labelcolor="red")
    ax2.set_title(title, fontsize=14)
    return ax2

#-----------------------------------------------------------------------------------------------------------------
if "__file__" not in dir():
    plt.figure(figsize=(20, 10))
    # top left: good trials, top right: all trials,
    # bottom left: licking of good trials between -10 s and 20 s around trial end
    plt.subplot(2, 2, 1)
    plot_mean_breaks_firing_rate(data, shank, cluster, trialType="good")
    plt.subplot(2, 2, 2)
    plot_mean_breaks_firing_rate(data, shank, cluster)
    plt.subplot(2, 2, 3)
    plot_mean_breaks_firing_rate(data, shank, cluster, trialType="good", lick=True, minTime=-10, maxTime=20)

## 3.4 Firing rate aligned relative to start or end of behavioral epochs

In [None]:
def plot_period_firing_rate(data, shank, cluster, binSize=0.25, minTime=-20, maxTime=20,
                            align="start", minDurationSecond=2, maxDurationSecond=None, sigma=1, ax=None,
                            immobility=False, periodType="all", runMinDuration=0.1):
    """
    Plot the mean running speed (black) and the mean firing rate of one cluster (red, second
    y axis) around running or immobility periods.

    Input:
      - data, shank, cluster: session and cluster (nothing is drawn if the cluster does not exist)
      - binSize, minTime, maxTime: time bins (s) around the alignment point
      - align: "start" aligns on the period starts, any other value on the period ends
      - minDurationSecond, maxDurationSecond, periodType: forwarded to the period detection
      - runMinDuration: forwarded to detect_immobility_period (immobility=True only)
      - sigma: sigma of the gaussian smoothing of the firing rate
      - immobility: True for immobility periods, False for running periods
      - ax: axes to draw on (default: current axes)
    Returns the twin axis holding the firing rate, or None when the cluster does not exist
    or no period is detected.
    """
    #get the spikes for the cluster, if it exists
    cluSpike = get_cluster_spikes(data, shank, cluster)
    if cluSpike is None:
        return
    if ax is None:
        ax = plt.gca()

    #detect periods
    if immobility:
        startRunning, endRunning, periods = detect_immobility_period(data, minDurationSecond, maxDurationSecond,
                                                                     periodType, runMinDuration)
        runOrImmobility = "immobility"
    else:
        startRunning, endRunning, periods = detect_running_period(data, minDurationSecond, maxDurationSecond,
                                                                  periodType)
        runOrImmobility = "run"
    zeroes = startRunning if align == "start" else endRunning
    title = "All %s periods, Shank %s Cluster %s"%(runOrImmobility, shank, cluster)

    if len(zeroes) == 0:
        # averaging over zero periods gives NaN, which matplotlib cannot plot
        print("No %s period detected: nothing to plot" % runOrImmobility)
        ax.set_title(title, fontsize=14)
        return None

    #histograms of breaks and spikes aligned on each period
    allBreaks = np.asarray(data.allBeamBreak)
    timeBin = np.arange(minTime, maxTime+binSize-maxTime%binSize, binSize)
    centers = (timeBin[:-1]+timeBin[1:])/2.0
    allHist = [np.histogram(allBreaks - zero, timeBin)[0] for zero in zeroes]
    spikeHist = [np.histogram(cluSpike - zero, timeBin)[0] for zero in zeroes]

    #speed
    meanSpeed = np.nanmean(np.asarray(allHist), axis=0)*data.tickDistance/float(binSize)
    ax.plot(centers, meanSpeed, color="black")
    ax.set_ylabel("mean running speed (cm/s)", fontsize=14)
    ax.set_xlabel("time (sec), binSize=%ss, 0=%s"%(binSize, align), fontsize=14)

    #firing rate on second axis
    ax2 = ax.twinx()
    meanFiring = np.nanmean(np.asarray(spikeHist), axis=0)/float(binSize)
    meanFiring = smooth(meanFiring, sigma)
    ax2.plot(centers, meanFiring, color="red")
    ax2.set_ylim([0, max(meanFiring)])
    ax2.set_ylabel("firing rate (smooth sigma=%s)"%sigma, color="red", fontsize=14)
    ax2.yaxis.set_major_locator(plt.MaxNLocator(integer=True))
    ax2.tick_params(axis="y", labelcolor="red")

    ax2.set_title(title, fontsize=14)
    ax2.set_xlim([minTime, maxTime])
    return ax2
    
#---------------------------------------------------------------------------------------------------------------
if "__file__" not in dir():
    plt.figure(figsize=(15, 10))
    # top row: running periods; bottom row: immobility periods (runMinDuration=0: every tick counts as a run)
    # left column aligned on period start, right column on period end
    plt.subplot(2, 2, 1)
    plot_period_firing_rate(data, shank, cluster, align="start")
    plt.subplot(2, 2, 2)
    plot_period_firing_rate(data, shank, cluster, align="end")
    plt.subplot(2, 2, 3)
    plot_period_firing_rate(data, shank, cluster, align="start", immobility=True, runMinDuration=0)
    plt.subplot(2, 2, 4)
    plot_period_firing_rate(data, shank, cluster, align="end", immobility=True, runMinDuration=0)
    plt.subplots_adjust(hspace=0.6)
    

## 3.5 Firing rates during behavioral epochs whose lengths have been normalized in this session

In [None]:
def plot_normalized_running_periods_firing_rate(data, shank, cluster, binSize=0.25, minDuration=2, 
                                                maxDuration=15, SideLength=3, sigma=1, runType="all", ax=None):
    """
    Running-period version of plot_normalized_periods_firing_rate.
    runType is forwarded as periodType ("all", "trial good run", "trial bad run", "intertrial", ...);
    returns whatever plot_normalized_periods_firing_rate returns.
    """
    options = dict(binSize=binSize, minDuration=minDuration, maxDuration=maxDuration,
                   SideLength=SideLength, sigma=sigma, periodType=runType, ax=ax,
                   immobility=False, runMinDuration=None)
    return plot_normalized_periods_firing_rate(data, shank, cluster, **options)

#---------------------------------------------------------------------------------------------------------------
def plot_normalized_immobility_periods_firing_rate(data, shank, cluster, binSize=0.25, minDuration=2,
                                                   maxDuration=15, SideLength=3, sigma=1, immobilityType="all",
                                                   runMinDuration=0.1, ax=None):
    """
    Immobility-period version of plot_normalized_periods_firing_rate.
    immobilityType is forwarded as periodType ("all", "trial", "intertrial") and runMinDuration
    to the immobility detection; returns whatever plot_normalized_periods_firing_rate returns.
    """
    options = dict(binSize=binSize, minDuration=minDuration, maxDuration=maxDuration,
                   SideLength=SideLength, sigma=sigma, periodType=immobilityType, ax=ax,
                   immobility=True, runMinDuration=runMinDuration)
    return plot_normalized_periods_firing_rate(data, shank, cluster, **options)

#---------------------------------------------------------------------------------------------------------------
def plot_normalized_periods_firing_rate(data, shank, cluster, binSize=0.25, minDuration=2, maxDuration=15, SideLength=3,
                                        sigma=1, periodType="all", immobility=False, runMinDuration=0.1, ax=None):
    """
    Time-normalized mean running speed and firing rate over running / immobility periods.

    Takes all running/immobility periods between minDuration and maxDuration seconds.
    The number of bins per period is nBins = ceil(mean duration / binSize). Each period is
    cut into nBins equal bins of its own length (_bin = duration / nBins), and nSideBin+1
    bins of the same length are added before and after it. Speed and firing rate are
    computed in these bins, averaged across periods and plotted against the bin number.
    Input:
      - data, shank, cluster: session and cluster (spikes are read directly from
        data.spikeTime[shank][cluster])
      - binSize: target bin size in seconds, used to choose nBins and nSideBin
      - minDuration, maxDuration: duration range of the periods
      - SideLength: nSideBin = round(SideLength/binSize)
      - sigma: parameter for the gaussian smoothing of the firing rate
      - periodType: to specify a type of period
           for the runs: "trial good run", "trial bad run", "intertrial"
           for the immobility: "trial", "intertrial"
      - immobility: whether to plot immobility periods or running periods
      - runMinDuration: parameter for immobility period detection
        (a run of less than runMinDuration is considered to be an immobility)
      - ax: matplotlib axes where to plot, useful when doing complex subplots
    Returns (ax2, meanFiring, nSideBin, spikeHist, allSpeed, allDuration).
    When no period is detected, nothing is plotted (only the title) and
    (None, empty array, nSideBin, [], [], allDuration) is returned.
    """
    if ax is None:
        ax = plt.gca()
    cluSpike = data.spikeTime[shank][cluster]

    if immobility:
        startPeriod, endPeriod, indexes = detect_immobility_period(data, minDuration, maxDuration, periodType,
                                                                   runMinDuration)
        runOrImmobile = "immobility"
    else:
        startPeriod, endPeriod, indexes = detect_running_period(data, minDuration, maxDuration, periodType)
        runOrImmobile = "running"

    title = "'%s' normalized %s periods (%s) \n (%s-%s s)"%(periodType, runOrImmobile, len(startPeriod),
                                                          minDuration, maxDuration)
    title = title+", Shank %s Cluster %s"%(shank, cluster)

    nSideBin = int(np.round(SideLength/binSize))
    allDuration = np.asarray(endPeriod) - np.asarray(startPeriod)
    if allDuration.size == 0:
        # the mean duration would be NaN and the bin count undefined
        print("No %s period of type '%s' between %s and %s s" % (runOrImmobile, periodType, minDuration, maxDuration))
        ax.set_title(title, fontsize=14)
        return None, np.array([]), nSideBin, [], [], allDuration

    #number of bins per period, based on mean duration and binsize
    meanDuration = np.nanmean(allDuration)
    nBins = int(np.ceil(meanDuration/float(binSize)))
    nTotalBins = nBins + 2 * (nSideBin + 1)

    #mean speed and firing rate in the bins
    allSpeed = []
    spikeHist = []
    for start, duration in zip(startPeriod, allDuration):
        _bin = duration/float(nBins)
        # Edges come from an integer count, not np.arange with a float step, so every
        # period gets exactly nTotalBins bins. The edges go from
        # start-(nSideBin+1)*_bin to stop+(nSideBin+1)*_bin.
        timeBin = (start - (nSideBin+1) * _bin) + _bin * np.arange(nTotalBins + 1)
        hist, _ = np.histogram(data.allBeamBreak, timeBin)
        allSpeed.append(hist*data.tickDistance/float(_bin))
        hist, _ = np.histogram(cluSpike, timeBin)
        spikeHist.append(hist/float(_bin))

    #compute the means and smooth
    meanSpeed = np.nanmean(np.asarray(allSpeed), axis=0)
    meanFiring = smooth(np.nanmean(np.asarray(spikeHist), axis=0), sigma)

    xmax = nTotalBins
    xaxis = np.arange(0.5, xmax, 1)
    #plot speed
    ax.plot(xaxis, meanSpeed, color="black")
    ax.set_ylabel("mean running speed (cm/s)", fontsize=14)
    ax.set_xlabel("%i bins, mean duration %.2fs with binsize %.2fs" %(nBins, meanDuration, binSize), fontsize=14)
    #plot firing rate on other y axis with a different color
    ax2 = ax.twinx()
    ax2.plot(xaxis, meanFiring, color="red")
    ax2.set_ylim([0, max(meanFiring)])
    ax2.set_ylabel("firing rate (smooth=%s)" %sigma, color="red", fontsize=14)
    ax2.yaxis.set_major_locator(plt.MaxNLocator(integer=True))
    ax2.tick_params(axis="y", labelcolor="red")
    ax2.set_title(title, fontsize=14)
    ax2.set_xlim([0, xmax])

    return ax2, meanFiring, nSideBin, spikeHist, allSpeed, allDuration

#---------------------------------------------------------------------------------------------------------------
if "__file__" not in dir():
    binSize = 0.25
    plt.figure(figsize=(20, 15))
    # left column: running periods (all, good run, bad run, intertrial, unrewarded)
    # right column: immobility periods (all, trial, intertrial)

    plt.subplot(5, 2, 1)
    ax2, meanFiring, nSideBin, spikeHist, allSpeed, allDuration = plot_normalized_running_periods_firing_rate(
        data, shank, cluster, binSize)

    plt.subplot(5, 2, 7)
    plot_normalized_running_periods_firing_rate(data, shank, cluster, binSize, runType="intertrial")

    plt.subplot(5, 2, 9)
    plot_normalized_running_periods_firing_rate(data, shank, cluster, binSize, runType="unrewarded")

    plt.subplot(5, 2, 3)
    plot_normalized_running_periods_firing_rate(data, shank, cluster, binSize, runType="trial good run")

    plt.subplot(5, 2, 5)
    plot_normalized_running_periods_firing_rate(data, shank, cluster, binSize, runType="trial bad run")

    plt.subplot(5, 2, 2)
    plot_normalized_immobility_periods_firing_rate(data, shank, cluster, binSize)

    plt.subplot(5, 2, 8)
    plot_normalized_immobility_periods_firing_rate(data, shank, cluster, binSize, immobilityType="intertrial")

    plt.subplot(5, 2, 4)
    plot_normalized_immobility_periods_firing_rate(data, shank, cluster, binSize, immobilityType="trial")

    plt.subplots_adjust(hspace=0.6)

# Firing rate versus Distance

In [None]:
def _mean_by_unique_key(keys, values):
    """Average *values* over groups of equal *keys*.

    Returns (uniqueKeys, groupMeans): the sorted distinct keys (as np.unique gives them) and, for each,
    the mean of the values whose key equals it. Both are float arrays; empty input gives two empty arrays.
    """
    keys = np.asarray(keys, dtype=float)
    values = np.asarray(values, dtype=float)
    uniqueKeys, inverse = np.unique(keys, return_inverse=True)
    inverse = np.ravel(inverse)
    sums = np.bincount(inverse, weights=values, minlength=len(uniqueKeys))
    counts = np.bincount(inverse, minlength=len(uniqueKeys))
    return uniqueKeys, sums / counts


def FiringVersusDistanceInRuns(data,shank,cluster,runType="all",binSize=0.025,plotIndex=False,checkrundetection=False):
    """
    Firing rate of one cluster as a function of the distance run inside each detected run.

    Runs are detected with detect_running_period (minDurationSecond=3, maxDurationSecond=30,
    maxTimeBetweenBreak=1, of type *runType*). Inside each run, beam-break times and spike times are
    histogrammed in bins of *binSize* seconds, both aligned on the first beam break of the run. The
    cumulative beam-break count times data.tickDistance gives the distance run at the end of each bin.

    Parameters
    ----------
    data : session object (uses tickDistance, allBeamBreak, spikeTime[shank][cluster])
    shank, cluster : cluster identifiers
    runType : run type forwarded to detect_running_period
    binSize : time bin (s) for the distance / spike histograms
    plotIndex : if True, plot the mean firing rate vs time (fine and 0.125 s bins), vs distance, and the
        run distance vs run duration scatter with its Spearman correlation
    checkrundetection : if True, call check_period_detection on the detected runs

    Returns
    -------
    FRateVersusDistanceAcrossRuns : list of lists
        For each run, the mean firing rate (Hz) at each distinct cumulative distance.
    AllDistancesAccrosAllTrials : list of lists
        For each run, the matching sorted distinct cumulative distances.
    Runs that contain no beam break are skipped (they used to raise IndexError).
    """
    tickDistance = data.tickDistance
    allBeamBreakTimes = np.asarray(data.allBeamBreak, dtype=float)
    cluSpike = np.asarray(data.spikeTime[shank][cluster], dtype=float)

    ## First detect the runs (runType used to be ignored and "all" always used)
    start, end, indexes = detect_running_period(data, minDurationSecond=3, maxDurationSecond=30,
                                                runType=runType, maxTimeBetweenBreak=1)
    if checkrundetection:
        check_period_detection(data, start, end)

    binSizeTimeTuningCurve = 0.125  # coarser bins (s) for the firing rate vs time curve

    # fine-bin samples pooled across runs
    pooledDistances, pooledTimeCenters, pooledRates = [], [], []
    # per-run outputs
    FRateVersusDistanceAcrossRuns, AllDistancesAccrosAllTrials = [], []
    coarseTimeCenters, coarseRates = [], []
    TotalDistancePerRunAcrrossAllRuns, RunDurations = [], []

    for thisstart, thisend in zip(start, end):
        ## beam breaks of this run, realigned on the first one (time zero)
        breaks = allBeamBreakTimes[(allBeamBreakTimes >= thisstart) & (allBeamBreakTimes <= thisend)]
        if len(breaks) == 0:
            continue
        origin = breaks[0]
        alignedBreaks = breaks - origin
        maxTime = alignedBreaks[-1]
        timeBin = np.arange(0, maxTime + binSize - maxTime % binSize + binSize, binSize)
        centers = (timeBin[:-1] + timeBin[1:]) / 2.0

        ## distance run at the end of each bin, and spike rate in the same bins
        breakCount = np.histogram(alignedBreaks, timeBin)[0]
        distanceInRun = np.cumsum(breakCount * tickDistance)
        # spikes use the same origin as the beam breaks, otherwise the two histograms are shifted
        spikeCount = np.histogram(cluSpike - origin, timeBin)[0]
        rateInRun = spikeCount / binSize

        pooledDistances.extend(distanceInRun)
        pooledTimeCenters.extend(centers)
        pooledRates.extend(rateInRun)
        TotalDistancePerRunAcrrossAllRuns.append(np.sum(breakCount) * tickDistance)
        RunDurations.append(thisend - thisstart)

        ### firing rate versus distance tuning curve for this run
        distances, meanRates = _mean_by_unique_key(distanceInRun, rateInRun)
        AllDistancesAccrosAllTrials.append(distances.tolist())
        FRateVersusDistanceAcrossRuns.append(meanRates.tolist())

        ### firing rate versus time in this run (coarse bins starting at the detected run start)
        timeBinForTime = np.arange(thisstart,
                                   thisend + binSizeTimeTuningCurve - thisend % binSizeTimeTuningCurve + binSizeTimeTuningCurve,
                                   binSizeTimeTuningCurve)
        coarseCount, coarseEdges = np.histogram(cluSpike, timeBinForTime)
        coarseEdges = coarseEdges - coarseEdges[0]
        coarseTimeCenters.extend((coarseEdges[:-1] + coarseEdges[1:]) / 2.0)
        coarseRates.extend(coarseCount / binSizeTimeTuningCurve)

    # average tuning curves from the pooled samples
    Distances, MeanFiringRateVerusDistance = _mean_by_unique_key(pooledDistances, pooledRates)
    UniqueTimeCenters, MeanFiringRateVersusTime = _mean_by_unique_key(pooledTimeCenters, pooledRates)
    UniqueTimeCentersLargeBins, MeanFRatePerTimeBin = _mean_by_unique_key(coarseTimeCenters, coarseRates)

    ### Part II: some plotting and comparison with firing rate versus time
    if plotIndex:
        plt.figure(figsize=(15, 5))
        plt.subplot(131)
        plt.plot(UniqueTimeCenters, smooth(MeanFiringRateVersusTime, 1))
        plt.xlabel('run time (s)')
        plt.ylabel("Firing Rate (Hz)")

        plt.subplot(132)
        plt.plot(UniqueTimeCentersLargeBins, smooth(MeanFRatePerTimeBin, 1))
        plt.xlabel('run time (s)')
        plt.ylabel("Firing Rate (Hz)")

        plt.subplot(133)
        plt.plot(Distances, smooth(MeanFiringRateVerusDistance, 1))
        plt.xlabel('run distance (cm)')
        plt.ylabel("Firing Rate (Hz)")

        plt.figure(figsize=(5, 5))
        plt.scatter(TotalDistancePerRunAcrrossAllRuns, RunDurations)
        plt.xlabel('run distance (cm)')
        plt.ylabel("run duration (s)")
        SpearManResults = stats.spearmanr(TotalDistancePerRunAcrrossAllRuns, RunDurations)

        rvalue = str(round(SpearManResults[0], 2))
        if SpearManResults[1] < 0.0001:
            pvalue = "p<0.0001"
        else:
            pvalue = "p=" + str(round(SpearManResults[1], 4))
        plt.title("r=%s, %s" % (rvalue, pvalue))

        print(SpearManResults)

    return FRateVersusDistanceAcrossRuns, AllDistancesAccrosAllTrials
    
 


In [None]:
if "__file__" not in dir():
    # Firing rate vs run distance for the selected cluster, with the diagnostic figures
    # (plotIndex) and a check of the run detection.
    distanceDemoOptions = dict(runType="all", plotIndex=True, checkrundetection=True)
    FiringVersusDistanceInRuns(data, shank, cluster, **distanceDemoOptions)

# Firing rate versus Distance Normalized

In [None]:
def FiringRateVersusDistanceInRunDistanceNormalized(data,shank,cluster,NBinsForNormalization=20,runType="all"):
    """
    Plot the mean firing rate against run distance normalized to [0, 1] within each run.

    For every run returned by FiringVersusDistanceInRuns, the interval (0, last distance] is cut into
    NBinsForNormalization equal bins, and each bin takes the mean rate of the distances it contains.
    Empty bins are filled by linear interpolation between filled bins (np.interp, constant at the ends).
    The per-run profiles are averaged bin by bin and plotted against the bin centers. Returns None.
    """
    perRunRates, perRunDistances = FiringVersusDistanceInRuns(data, shank=shank, cluster=cluster, runType=runType)

    normalizedProfiles = []
    for runRates, runDistances in zip(perRunRates, perRunDistances):
        runDistances = np.asarray(runDistances)
        edges = np.linspace(0, runDistances[-1], NBinsForNormalization + 1)

        profile = []
        for lowEdge, highEdge in zip(edges[:-1], edges[1:]):
            members = np.flatnonzero((runDistances > lowEdge) & (runDistances <= highEdge))
            if members.size:
                profile.append(np.nanmean([runRates[k] for k in members]))
            else:
                profile.append(float('nan'))
        profile = np.asarray(profile, dtype=float)

        # fill empty bins from their filled neighbours
        filled = ~np.isnan(profile)
        if filled.any():
            profile = np.interp(np.arange(len(profile)), np.flatnonzero(filled), profile[filled])
        normalizedProfiles.append(profile.tolist())

    meanProfile = np.mean(np.asarray(normalizedProfiles), 0)
    normalizedEdges, step = np.linspace(0, 1, NBinsForNormalization + 1, retstep=True)
    binCenters = normalizedEdges[:-1] + step / 2.0

    fig = plt.figure()
    ax = fig.add_subplot(111)
    for label, xPosition in (('run start', 0), ('run stop', 1)):
        ax.annotate(label, xy=(xPosition, -0.1), xycoords='axes fraction',
                    horizontalalignment='center', verticalalignment='center')
    plt.plot(binCenters, meanProfile)
    plt.xlabel('run distance (normalized)')
    plt.ylabel("Firing Rate (Hz)")


In [None]:
if "__file__" not in dir():
    # Same tuning curve, with each run's distance rescaled to [0, 1] and split into 20 bins.
    FiringRateVersusDistanceInRunDistanceNormalized(data, shank, cluster,
                                                    NBinsForNormalization=20, runType="all")
    

## 3.5 Real and SHUFFLED firing rates during behavioral epochs whose lengths have been normalized inside this session

In [None]:
def plot_shuffled_running_firing_rate(data, shank, cluster, binSize=0.25, minDuration=2, maxDuration=15, 
                                      SideLength=3, sigma=1, runType="all", ax=None, nShuffle=500,noPlot=True):
    """Shuffle control for running periods.

    Thin wrapper around plot_shuffled_firing_rate with immobility=False. *runType* is forwarded as
    periodType, and runMinDuration is passed as None (it only matters for immobility periods).
    Returns whatever plot_shuffled_firing_rate returns.
    """
    options = dict(binSize=binSize, minDuration=minDuration, maxDuration=maxDuration,
                   SideLength=SideLength, sigma=sigma, periodType=runType, ax=ax,
                   immobility=False, runMinDuration=None, nShuffle=nShuffle, noPlot=noPlot)
    return plot_shuffled_firing_rate(data, shank, cluster, **options)

#---------------------------------------------------------------------------------------------------------------
def plot_shuffled_immobility_firing_rate(data, shank, cluster, binSize=0.25, minDuration=2, maxDuration=15, 
                                         SideLength=3, sigma=1, immobilityType="all", runMinDuration=0.1, 
                                         ax=None, nShuffle=500,noPlot=True):
    """Shuffle control for immobility periods.

    Thin wrapper around plot_shuffled_firing_rate with immobility=True. *immobilityType* is forwarded
    as periodType, and *runMinDuration* is forwarded to the immobility detection.
    Returns whatever plot_shuffled_firing_rate returns.
    """
    forwarded = dict(binSize=binSize, minDuration=minDuration, maxDuration=maxDuration,
                     SideLength=SideLength, sigma=sigma, periodType=immobilityType, ax=ax,
                     immobility=True, runMinDuration=runMinDuration, nShuffle=nShuffle, noPlot=noPlot)
    return plot_shuffled_firing_rate(data, shank, cluster, **forwarded)

#---------------------------------------------------------------------------------------------------------------
def plot_shuffled_firing_rate(data, shank, cluster, binSize=0.25, minDuration=2, maxDuration=15, SideLength=3, sigma=1,
                              periodType="all", immobility=False,runMinDuration=0.1, ax=None, nShuffle=200, 
                              noPlot=True):
    """
    Shuffle control for plot_normalized_periods_firing_rate.

    Uses the same period detection and per-period binning as the real curve. Each period is split into
    nBins = ceil(mean period duration / binSize) bins, plus nSideBin+1 bins on each side, where
    nSideBin = round(SideLength / binSize). For each period, the spike train is shifted by a random
    offset drawn uniformly in [-duration, +duration). This is repeated nShuffle times, and each shuffled
    mean curve is smoothed with a Gaussian of width *sigma* bins.

    Plotting goes to *ax* if given, otherwise to the current pyplot axes:
      - if noPlot is False, every shuffled curve is drawn as a dashed dark-red line;
      - in all cases, the pointwise 2.5-97.5 percentile band of the shuffles is shaded in gray, and the
        global 0.5 / 99.5 percentiles (over all shuffles and bins) are drawn as dashed red lines.

    Returns
    -------
    percentile : ndarray
        The 0.05, 5, 95 and 99.95 percentiles of all shuffled values, pooled over bins and shuffles.
    allMeanShuffledFiring : list of ndarray
        The nShuffle smoothed shuffled mean firing-rate curves.
    """
    if (ax is None) and (not noPlot):
        ax = plt.gca()
    nSideBin = int(np.round(SideLength / binSize))
    cluSpike = data.spikeTime[shank][cluster]

    if immobility:
        startPeriod, endPeriod, indexes = detect_immobility_period(data, minDuration, maxDuration, periodType,
                                                                   runMinDuration)
    else:
        startPeriod, endPeriod, indexes = detect_running_period(data, minDuration, maxDuration, periodType)

    # number of bins inside a period, from the mean period duration and the bin size
    allDuration = np.asarray(endPeriod) - np.asarray(startPeriod)
    meanDuration = np.nanmean(allDuration)
    nBins = np.ceil(meanDuration / float(binSize))

    xmax = nBins + (nSideBin + 1) * 2
    xaxis = np.arange(0.5, xmax, 1)

    # the bin edges of each period do not depend on the shuffle: compute them once
    periodBins = []
    for start, stop, duration in zip(startPeriod, endPeriod, allDuration):
        _bin = duration / float(nBins)
        rStart = start - (nSideBin + 1) * _bin
        rStop = stop + (nSideBin + 1) * _bin - stop % _bin
        periodBins.append((duration, _bin, np.arange(rStart, rStop + _bin, _bin)))

    allMeanShuffledFiring = []
    for i in range(nShuffle):
        spikeHist = []
        for duration, _bin, timeBin in periodBins:
            # shift by a random number between -duration and +duration
            hist, bins = np.histogram(cluSpike + duration * (np.random.random_sample(1) * 2 - 1), timeBin)
            spikeHist.append(hist / float(_bin))
        meanFiring = smooth(np.nanmean(np.asarray(spikeHist), axis=0), sigma)
        allMeanShuffledFiring.append(meanFiring)
        if not noPlot:
            ax.plot(xaxis, meanFiring, color="darkred", linestyle="--")

    # confidence bands: pointwise (2.5-97.5 %) and global (0.5 / 99.5 % over all bins and shuffles)
    target = ax if ax is not None else plt.gca()
    globalconfidenceband = np.percentile(np.asarray(allMeanShuffledFiring), [0.5, 99.5])
    PointConfidenceBand = np.percentile(allMeanShuffledFiring, [2.5, 97.5], axis=0)
    target.fill_between(xaxis, PointConfidenceBand[0], PointConfidenceBand[1], facecolor='gray', alpha=0.2)
    for level in globalconfidenceband:
        target.plot([xaxis[0], xaxis[-1]], [level, level], 'r', linestyle="--")

    percentile = np.percentile(allMeanShuffledFiring, [0.05, 5, 95, 99.95])
    return percentile, allMeanShuffledFiring

#---------------------------------------------------------------------------------------------------------------
if "__file__" not in dir():
    # Real normalized profile (red, on a twin axis) with its shuffle-based confidence bands:
    # running periods in the first figure, immobility periods in the second.
    plt.figure(figsize=(15, 5))
    runOutputs = plot_normalized_running_periods_firing_rate(data, shank, cluster, binSize=0.25, SideLength=3)
    ax2, meanFiring, nSideBin = runOutputs[0:3]
    percentileShuffle, allMeanShuffledFiring = plot_shuffled_running_firing_rate(
        data, shank, cluster, binSize=0.25, SideLength=3, ax=ax2)

    plt.figure(figsize=(15, 5))
    immobilityOutputs = plot_normalized_immobility_periods_firing_rate(data, shank, cluster, binSize=0.25, SideLength=3)
    ax2, mean, n = immobilityOutputs[0:3]
    percentile, allMeanShuffledFiring = plot_shuffled_immobility_firing_rate(
        data, shank, cluster, binSize=0.25, SideLength=3, ax=ax2)

## 3.6 General plot for one cluster
#### (could be improved by adding the ACG and the waveform)

In [None]:
def _share_firing_rate_ylim(axes):
    """Set every axis in *axes* to the y range [0, largest current upper y limit among them]."""
    maxY = max(axis.get_ylim()[1] for axis in axes)
    for axis in axes:
        axis.set_ylim(0, maxY)


def cluster_plot(data, shank, cluster, group):
    """
    Summary figure for one cluster, drawn in a new 20x20 inch figure.

    - left column: plot_break_cluster for "good" trials, with licks;
    - first row, right: plot_mean_breaks_firing_rate aligned on "trial start" (minTime=-10, maxTime=20)
      and on "trial end" (minTime=-20), good trials;
    - second row, right (only if data.lickBreakTime is truthy): the same with lick=True
      ("trial start": minTime=0, maxTime=50; "trial end": minTime=-20);
    - last two rows, right: normalized running-period firing rate for "all", "intertrial",
      "trial good run" and "trial bad run" periods.
    The firing-rate axes returned by each group of panels share the same [0, max] y range.
    Does nothing if get_cluster_spikes(data, shank, cluster) is None. Returns None.
    """
    # check if cluster exists
    if get_cluster_spikes(data, shank, cluster) is None:
        return
    plt.figure(figsize=(20,20))

    # Left plot
    plt.subplot(131)
    plot_break_cluster(data, shank, cluster, "good", lick=True)

    # Firing rate aligned on trial start / trial end (good trials)
    startAx = plt.subplot(432)
    startTwin = plot_mean_breaks_firing_rate(data, shank, cluster, ax=startAx, align="trial start", trialType="good",
                                             minTime=-10, maxTime=20)
    endAx = plt.subplot(433, sharey=startAx)
    endTwin = plot_mean_breaks_firing_rate(data, shank, cluster, ax=endAx, align="trial end", trialType="good",
                                           minTime=-20)
    _share_firing_rate_ylim([startTwin, endTwin])

    # Lick breaks, if there is data
    if data.lickBreakTime:
        lickStartAx = plt.subplot(435)
        # ax was not passed here, unlike in every other panel
        lickStartTwin = plot_mean_breaks_firing_rate(data, shank, cluster, ax=lickStartAx, align="trial start",
                                                     lick=True, minTime=0, maxTime=50)
        lickEndAx = plt.subplot(436, sharey=lickStartAx)
        lickEndTwin = plot_mean_breaks_firing_rate(data, shank, cluster, ax=lickEndAx, align="trial end",
                                                   lick=True, minTime=-20)
        _share_firing_rate_ylim([lickStartTwin, lickEndTwin])

    # Normalized runs
    runPanels = ((8, "all"), (9, "intertrial"), (11, "trial good run"), (12, "trial bad run"))
    firstRunAx = None
    runTwins = []
    for position, runType in runPanels:
        if firstRunAx is None:
            firstRunAx = plt.subplot(4, 3, position)
        else:
            plt.subplot(4, 3, position, sharey=firstRunAx)
        runTwins.append(plot_normalized_running_periods_firing_rate(data, shank, cluster, runType=runType)[0])
    _share_firing_rate_ylim(runTwins)

    # Title
    title = data.experiment+" (day %s), shank %s cluster %s, group %s" %(data.daySinceStart, shank, cluster, group)
    if not data.hasEEG:
        title+=" - read from .beambreaktime"
    plt.suptitle(title, fontsize=20)
    plt.subplots_adjust(wspace=0.3, hspace=0.3, top=0.94)
    
#------------------------------------------------------------------------------------------------------------------
if "__file__" not in dir():
    # cluster_plot opens its own figure; an extra plt.figure() here only produced an empty figure
    if data.hasBehavior and data.hasSpike:
        cluster_plot(data, shank, cluster, group="Good");

## 3.7 Batch version of 3.5 (real and shuffled firing rates during behavioral epochs of standardized length), saving the data for all clusters of this session

In [None]:
def meanFiringRatevsRunandImmoNorm(data, groupList=["Good"], saveAsPickle=True, redo=False, binSize=0.25, 
                                   SideLength=3, nShuffle=500,showplot=False):
    """
    Computes the mean normalized firing rate and its shuffle-based limits during running and immobility
    periods, for every selected cluster of the session.
    Loads <sessionPath>/Analysis/AnalyzedSpikeData.p instead if it exists and redo=False.

    Input
      -groupList: list of cluster groups (good, noise, mua..) or a single group name. If None, takes all clusters.
      -saveAsPickle: whether to save results in the pickle file (the Analysis folder is created if needed)
      -redo: whether to redo the analysis even if there is already a pickle file
      -binSize, SideLength: passed to the normalized-period and shuffle functions
      -nShuffle: number of shuffles for the confidence limits
      -showplot: keep the per-cluster figures open (otherwise each one is closed once computed)
    Output
      dict with, for each shank and cluster ({shank: {cluster: value}}):
        "meanFiringRateDuringRunNorm", "meanFiringRateDuringImmoNorm": normalized mean firing rates,
        "confidenceLimitsRunning", "confidenceLimitsImmobility": first value returned by the shuffle functions,
      plus "nSideBinForNormPlots" and "binSizeForNormPlots".
      The same dict is returned whether it was computed or loaded from the pickle.
    """
    # load and return the pickle if it exists (and redo=False)
    picklePath = os.path.join(data.sessionPath, "Analysis", "AnalyzedSpikeData.p")
    if (not redo) and os.path.exists(picklePath):
        # pickle.load can execute arbitrary code: only load files written by this pipeline
        with open(picklePath, 'rb') as f:
            print("loaded pickle %s" % picklePath)
            return pickle.load(f)

    # wrap a single group name into a list; None must stay None ("all groups"),
    # otherwise it became [None] and every cluster was skipped
    if (groupList is not None) and (not isinstance(groupList, list)):
        groupList = [groupList]

    perClusterKeys = ("meanFiringRateDuringRunNorm", "meanFiringRateDuringImmoNorm",
                      "confidenceLimitsRunning", "confidenceLimitsImmobility")
    analyzedSpikeData = {key: {} for key in perClusterKeys}
    nSideBin = None
    for shank in sorted(data.clusterGroup):
        print("Shank %s" % shank)
        for key in perClusterKeys:
            analyzedSpikeData[key][shank] = {}
        for group in data.clusterGroup[shank]:
            if (groupList is not None) and (group not in groupList):
                continue
            for cluster in sorted(data.clusterGroup[shank][group]):
                print(cluster, end=" ")

                plt.figure(figsize=(cm2inch(30), cm2inch(10)))
                plt.subplot(1, 2, 1)
                ax, meanFiring, nSideBin = plot_normalized_running_periods_firing_rate(
                    data, shank, cluster, binSize=binSize, SideLength=SideLength)[0:3]
                percentileShuffle = plot_shuffled_running_firing_rate(
                    data, shank, cluster, binSize=binSize, SideLength=SideLength, ax=ax, nShuffle=nShuffle)
                analyzedSpikeData["meanFiringRateDuringRunNorm"][shank][cluster] = meanFiring
                analyzedSpikeData["confidenceLimitsRunning"][shank][cluster] = percentileShuffle[0]

                plt.subplot(1, 2, 2)
                ax, meanFiringI = plot_normalized_immobility_periods_firing_rate(
                    data, shank, cluster, binSize=binSize, SideLength=SideLength)[0:2]
                percentileShuffleI = plot_shuffled_immobility_firing_rate(
                    data, shank, cluster, binSize=binSize, SideLength=SideLength, ax=ax, nShuffle=nShuffle)
                analyzedSpikeData["meanFiringRateDuringImmoNorm"][shank][cluster] = meanFiringI
                analyzedSpikeData["confidenceLimitsImmobility"][shank][cluster] = percentileShuffleI[0]

                plt.tight_layout()
                if not showplot:
                    plt.close()
        print("")
    analyzedSpikeData["nSideBinForNormPlots"] = nSideBin
    analyzedSpikeData["binSizeForNormPlots"] = binSize

    # save as a pickle file
    if saveAsPickle:
        analysisDir = os.path.dirname(picklePath)
        if not os.path.isdir(analysisDir):
            os.makedirs(analysisDir)
        with open(picklePath, 'wb') as f:
            pickle.dump(analyzedSpikeData, f)
        print("Saved pickle: %s" % picklePath)

    # previously returned None here, unlike the pickle-loading branch
    return analyzedSpikeData

In [None]:
#------------------------------------------------------------------------------------------------------------------
if "__file__" not in dir():
    # Run the batch on this session without writing the pickle, keeping the per-cluster figures open.
    meanFiringRatevsRunandImmoNorm(data, saveAsPickle=False, redo=False,
                                   binSize=0.25, SideLength=3, showplot=True)

## 3.8 Autocorrelograms

In [None]:
import phy.stats

def plot_autocorrelogram(data, shank, cluster, bin_ms=1, half_width_ms=25, ax=None):
    """
    Bar plot of the autocorrelogram of all spikes of one cluster (phy.stats.pairwise_correlograms).

    bin_ms: bin size in ms, clipped to [0.1, 1000]; converted to an integer number of samples
      with data.spikeSamplingRate (truncated).
    half_width_ms: half width of the time window in ms, clipped to [0.1, 1000].
    ax: axes to draw on (current axes by default).
    """
    if ax is None:
        ax = plt.gca()

    bin_ms = np.clip(bin_ms, .1, 1e3)  # bin size in ms (clipped, not rounded)
    binsize = int(data.spikeSamplingRate * bin_ms * 0.001)  # bin size in time samples

    half_width_ms = np.clip(half_width_ms, .1, 1e3)  # ms (clipped)
    winsize_bins = 2 * int(half_width_ms / bin_ms) + 1  # number of bins in window

    sample = data.spikeSample[shank][cluster]
    clu = np.ones_like(sample, dtype="int64")

    pairwiseCorr = phy.stats.pairwise_correlograms(sample, clu, binsize, winsize_bins)
    autoCorr = pairwiseCorr[0, 0, :]

    # bin edges in ms; the zero-lag bin spans [-0.5, +0.5] bin
    halfWinsize = winsize_bins // 2
    xaxis = np.arange(-halfWinsize - 0.5, halfWinsize + 1.5)
    xaxis = xaxis * binsize / data.spikeSamplingRate * 1000
    barWidth = xaxis[1] - xaxis[0]  # actual bin width after the conversion to whole samples

    # xaxis[:-1] are left edges: align="edge" (matplotlib >= 2 centers bars by default)
    ax.bar(xaxis[:-1], autoCorr, width=barWidth, align="edge", color="darkred", edgecolor="darkred")
    ax.set_title("Cluster %s, Autocorrelogram" % cluster)
    ax.set_xlim([xaxis[0], xaxis[-1]])
    ax.set_xlabel("time (ms), binsize=%s ms" % bin_ms)
    ax.set_ylabel("spike count")

#----------------------------------------------------------------------------------------------------------------------
if "__file__" not in dir():
    # Autocorrelogram of the selected cluster with 1 ms bins, over +/-30 ms and over +/-1000 ms.
    bin_ms = 1  # bin size in ms
    half_width_ms = 30  # half width of the x axis (time), in ms, for the first panel
    plt.figure(figsize=(15, 5))
    for subplotIndex, halfWidth in ((121, half_width_ms), (122, 1000)):
        plt.subplot(subplotIndex)
        plot_autocorrelogram(data, shank, cluster, bin_ms, halfWidth)

## 3.9 Autocorrelogram by behavioral state  (running / lick/ true immobility [no run])

In [None]:
def plot_autocorrelogram_period(data, shank, cluster, bin_ms=1, half_width_ms=25, immobility=False, lick=False, ax=None):
    """
    Autocorrelogram of one cluster using only the spikes that fall inside one kind of behavioral period.

    Periods:
      - immobility=True: detect_immobility_period (minDurationSecond=0, no max, runMinDuration=0.1,
        lick=True, allActivity=True);
      - otherwise, lick=True: detect_licking_period (minDurationSecond=2);
      - otherwise: running periods from detect_running_period (minDurationSecond=0, no max).
    A spike is kept when start < spikeTime < end for at least one period. The correlogram is computed
    on the kept spikes together, so (presumably) spikes of two different periods closer than the window
    also contribute.

    bin_ms: bin size in ms, clipped to [0.1, 1000]; half_width_ms: half window in ms, clipped to [0.1, 1000].
    ax: axes to draw on (current axes by default).
    """
    if ax is None:
        ax = plt.gca()

    if immobility:
        # immobility periods
        startPeriod, endPeriod, indexes = detect_immobility_period(data, minDurationSecond=0, maxDurationSecond=None,
                                                                   immobilityType="all", runMinDuration=0.1,
                                                                   lick=True, allActivity=True)
        runOrImmobile = "immobility"
    elif lick:
        # licking periods
        startPeriod, endPeriod, indexes = detect_licking_period(data, minDurationSecond=2, lickType="all")
        runOrImmobile = "lick"
    else:
        # running periods
        startPeriod, endPeriod, indexes = detect_running_period(data, minDurationSecond=0, maxDurationSecond=None,
                                                                runType="all")
        runOrImmobile = "running"

    bin_ms = np.clip(bin_ms, .1, 1e3)  # bin size in ms (clipped)
    binsize = int(data.spikeSamplingRate * bin_ms * 0.001)  # bin size in time samples

    half_width_ms = np.clip(half_width_ms, .1, 1e3)  # ms (clipped)
    winsize_bins = 2 * int(half_width_ms / bin_ms) + 1  # number of bins in window

    sample = np.asarray(data.spikeSample[shank][cluster])
    spikeTime = np.asarray(data.spikeTime[shank][cluster])

    # boolean mask: np.full_like(sample, False) had the integer dtype of sample, so with no period it
    # became an integer index array and repeated sample[0] instead of selecting nothing
    isInPeriod = np.zeros(len(sample), dtype=bool)
    for s, e in zip(startPeriod, endPeriod):
        isInPeriod |= (spikeTime > s) & (spikeTime < e)

    newSample = sample[isInPeriod]
    clu = np.ones_like(newSample, dtype="int64")

    pairwiseCorr = phy.stats.pairwise_correlograms(newSample, clu, binsize, winsize_bins)
    autoCorr = pairwiseCorr[0, 0]

    # bin edges in ms
    halfWinsize = winsize_bins // 2
    xaxis = np.arange(-halfWinsize - 0.5, halfWinsize + 1.5)
    xaxis = xaxis * binsize / data.spikeSamplingRate * 1000
    barWidth = xaxis[1] - xaxis[0]

    # xaxis[:-1] are left edges: align="edge"
    ax.bar(xaxis[:-1], autoCorr, width=barWidth, align="edge", color="darkred", edgecolor="darkred")
    ax.set_title("Cluster %s, Autocorrelogram during %s" % (cluster, runOrImmobile))
    ax.set_xlim([xaxis[0], xaxis[-1]])
    ax.set_xlabel("time (ms), binsize=%s ms" % bin_ms)
    ax.set_ylabel("spike count")

#----------------------------------------------------------------------------------------------------------------------
if "__file__" not in dir():
    # Autocorrelograms restricted to running, immobility and licking periods (one row each):
    # left column 1 ms bins over +/-30 ms, right column 20 ms bins over +/-1000 ms.
    bin_ms = 1
    plt.figure(figsize=(15, 10))
    for row, stateOptions in enumerate(({}, {"immobility": True}, {"lick": True})):
        plt.subplot(3, 2, 2 * row + 1)
        plot_autocorrelogram_period(data, shank, cluster, bin_ms, 30, **stateOptions)
        plt.subplot(3, 2, 2 * row + 2)
        plot_autocorrelogram_period(data, shank, cluster, 20, 1000, **stateOptions)
    plt.tight_layout()
    

## 3.10 Cross correlograms spike and behavior

In [None]:
def plot_crosscorrelogram_behaviour(data, shank, cluster, behaviorType="run", bin_ms=1,half_width_ms=30):
    """
    2x2 grid of auto- and cross-correlograms between one cluster's spikes and behavioral photobeam breaks.

    behaviorType: "run" uses data.allBeamBreak, any other value uses data.allLickBreak. Break times are
      converted to samples with data.spikeSamplingRate.
    bin_ms: bin size in ms, clipped to [0.1, 1000]; half_width_ms: half window in ms, clipped to [0.1, 10000].
    Panel (row i, column j) shows pairwiseCorr[i, j]: dark red for (0, 0), black for (1, 1), blue otherwise.
    Index 0 is assumed to be the spikes (cluster id 1) and index 1 the breaks (id 2); TODO confirm phy's ordering.
    Draws with plt.subplot(2, 2, k) in the current figure.

    Returns (sample, sampleBreak): spike samples and break samples, as uint64 arrays.
    """
    bin_ms = np.clip(bin_ms, .1, 1e3)  # bin size in ms (clipped)
    binsize = int(data.spikeSamplingRate * bin_ms * 0.001)  # bin size in time samples
    half_width_ms = np.clip(half_width_ms, .1, 10e3)  # ms (clipped)
    winsize_bins = 2 * int(half_width_ms / bin_ms) + 1  # number of bins in window
    halfWinsize = winsize_bins // 2
    xaxis = np.arange(-halfWinsize - 0.5, halfWinsize + 1.5)
    xaxis = xaxis * binsize / data.spikeSamplingRate * 1000  # bin edges in ms
    barWidth = xaxis[1] - xaxis[0]

    # spikes for the cluster
    sample = np.array(data.spikeSample[shank][cluster], dtype="uint64")
    clu = np.full_like(sample, 1, dtype="int64")

    # behavior photobeam break times (`is` compared object identity, not the string content)
    if behaviorType == "run":
        allbeambreak = data.allBeamBreak
    else:
        allbeambreak = data.allLickBreak

    sampleBreak = np.array([round(X * data.spikeSamplingRate) for X in allbeambreak], dtype="uint64")
    cluBreak = np.full_like(sampleBreak, 2, dtype="int64")

    # phy.stats.pairwise_correlograms takes the spike samples (old .res file) and the matching cluster
    # numbers (old .clu file); both must be sorted by increasing sample or the correlograms are wrong
    twoSample = np.append(sample, sampleBreak)
    twoClu = np.append(clu, cluBreak)
    sortingOrder = twoSample.argsort()
    pairwiseCorr = phy.stats.pairwise_correlograms(twoSample[sortingOrder], twoClu[sortingOrder], binsize, winsize_bins)

    labels = ("cluster %s" % cluster, "%s breaks" % behaviorType)
    k = 1
    for i in range(2):
        for j in range(2):
            if i + j == 0:
                col = "darkred"
            elif i + j == 2:
                col = "black"
            else:
                col = "blue"
            ax = plt.subplot(2, 2, k)
            # xaxis[:-1] are left edges: align="edge"
            ax.bar(xaxis[:-1], pairwiseCorr[i, j], width=barWidth, align="edge", color=col, edgecolor=col)
            ax.set_xlim([xaxis[0], xaxis[-1]])
            ax.set_title("%s x %s" % (labels[i], labels[j]))
            ax.set_xlabel("time (ms), binsize=%s ms" % bin_ms)
            k += 1
    return sample, sampleBreak
            
#----------------------------------------------------------------------------------------------------------------------
if "__file__" not in dir():
    # Spike / lick-break correlograms with 50 ms bins over +/-8 s.
    bin_ms = 50  # bin size in ms
    plt.figure(figsize=(15, 5))
    clu, cluBreak = plot_crosscorrelogram_behaviour(data, shank, cluster, behaviorType="lick",
                                                    bin_ms=bin_ms, half_width_ms=8000)

## 3.11 Waveform for a given unit (get the mean waveform and some basic shape characteristics)
  - Load the waveforms from the .kwx file with h5py, or extract them from the .dat file
  - Randomly choose up to `subSample` spikes (200 by default) from a cluster
  
  In the kwik/kwx format, spikeIndex gives the index of each spike in the kwx file.
  
  spikeSample gives the index of each spike in the .dat file (once it is reshaped as `nSample*nChannel`).

In [None]:
import h5py

def read_dat_waveform(data, shank, cluster, subSample=200, extract=50):
    """ extract the waveforms of one cluster from the .dat file, return array [nSpikes, extract*2, nChannelsInShank]

    The file data.fullPath + '.dat' is memory-mapped as int16 samples interleaved over data.nChannels
    channels, i.e. an array of shape (nSamples, nChannels). Only the channels in
    data.channelGroupList[shank] are returned.

    subSample: maximum number of spikes returned (nSpikes <= subSample)
      If the cluster has more usable spikes, *subSample* of them are chosen randomly (without replacement)
    extract: number of samples to extract before and after the spike sample

    Spikes closer than *extract* samples to either end of the file are skipped (a full window cannot be
    read for them). Returns False if there is no .dat file.
    Raises ValueError if the file size is not a multiple of one row of nChannels int16 values.
    """
    path = data.fullPath + '.dat'
    if not os.path.exists(path):
        print("No .dat file")
        return False
    # memory map to dat file
    dtype = np.int16
    size = os.stat(path).st_size
    row_size = data.nChannels * np.dtype(dtype).itemsize
    if size % row_size != 0:
        # the message used the undefined name `filename`, which raised NameError instead
        raise ValueError(("Shape error: the file {f} has S={s} bytes, "
                          "but there are C={c} channels. C should be a divisor of S."
                          "").format(f=path, s=size, c=data.nChannels))
    nsamples = size // row_size
    datFile = np.memmap(path, dtype=dtype, mode='r', offset=0, shape=(nsamples, data.nChannels))

    # plain Python ints: arithmetic on uint64 sample indices can produce floats, which are invalid slice bounds
    spikeID = [int(s) for s in data.spikeSample[shank][cluster]]
    spikeID = [s for s in spikeID if extract <= s <= nsamples - extract]
    if len(spikeID) > subSample:
        spikeID = [int(s) for s in np.random.choice(spikeID, subSample, replace=False)]

    # get waveform for each spike
    channels = data.channelGroupList[shank]
    waveform = np.zeros(shape=(len(spikeID), extract * 2, len(channels)), dtype=dtype)
    for index, spike in enumerate(spikeID):
        waveform[index, :, :] = datFile[spike - extract: spike + extract, channels]
    del datFile
    return waveform

#-------------------------------------------------------------------------------------------------------
def read_kwx_waveform(data, shank, cluster, subSample=200, filtered=False):
    """ Read the waveforms of one cluster from the .kwx file.

    The dataset channel_groups/<shank>/waveforms_filtered (filtered=True) or
    channel_groups/<shank>/waveforms_raw (filtered=False) is an array
    [spike index, data point, channel]; the rows of data.spikeIndex[shank][cluster] are returned,
    as an array [nSpikes, nDataPoints, nChannelsInShank].

    subSample: maximum number of spikes returned (nSpikes <= subSample).
      Clusters with more spikes are randomly subsampled without replacement.
    Returns False (and prints a message) when the .kwx file does not exist.
    """
    kwxPath = data.fullPath + ".kwx"
    if not os.path.exists(kwxPath):
        print("No .kwx file")
        return False

    datasetTemplate = ('channel_groups/%s/waveforms_filtered' if filtered
                       else 'channel_groups/%s/waveforms_raw')
    with h5py.File(kwxPath, "r") as kwxFile:
        allWaveforms = kwxFile.get(datasetTemplate % shank)[()]

    # rows of the kwx arrays that belong to this cluster
    clusterRows = data.spikeIndex[shank][cluster]
    tooMany = len(clusterRows) > subSample
    if tooMany:
        clusterRows = np.random.choice(clusterRows, subSample, replace=False)
    return allWaveforms[clusterRows, :, :]

#-------------------------------------------------------------------------------------------------------
def get_mean_waveform_and_caracteristics(data, shank, cluster, sample=200, kwx=True, filtered=False,redo=True):
    """ Compute the mean waveform of one cluster and the caracteristics of its spike shape.

    Waveforms are read from the .dat file with read_dat_waveform (50 samples on each side of the
    spike) and averaged on each channel. The "best channel" is the one whose mean waveform has the
    largest max-min range; its mean waveform is linearly upsampled 50 times (step 0.02 sample)
    before the caracteristics are measured.

    data: session object (uses analysisPath, fullPath, spikeSamplingRate and what read_dat_waveform needs)
    shank, cluster: unit to analyse
    sample: maximum number of spikes averaged (random subsample above that)
    kwx, filtered: kept for backward compatibility only. The .kwx waveforms are not used: the
      measurements below assume the 100-sample window of the .dat waveforms (5000 upsampled points)
    redo: if False and <analysisPath>/waveforms.p exists (written by get_all_mean_waveforms),
      the saved result of this unit is returned instead of being recomputed

    Returns (caracteristics, peakIndex, valleyIndex, baseIndex, widthEnd, widthStart), the indexes
    being positions in the upsampled best-channel waveform (flipped if the spike is positive),
    or {} if there is no .dat file or no spike.
    caracteristics keys: "mean waveforms" (one mean waveform per channel), "baseline-peak amplitude",
    "peak to valley amplitude", "HalfWidth", "best channel", "PeakToValley", "asymetrie".
    HalfWidth and PeakToValley are in ms, assuming spikeSamplingRate is in Hz.
    """
    picklePath = os.path.join(data.analysisPath, "waveforms.p")
    # load previously saved results instead of recomputing them
    if (not redo) and os.path.exists(picklePath):
        print("load waveforms from %s"%(picklePath))
        with open(picklePath, 'rb') as f:
            res = pickle.load(f)[shank][cluster]
        return res[0], res[1], res[2], res[3], res[4], res[5]

    # Only the .dat waveforms are used (see docstring); reading the .kwx as well would be wasted I/O
    if not os.path.exists(data.fullPath + '.dat'):
        print('no spike')
        return {}
    waveform = read_dat_waveform(data, shank, cluster, sample)
    if waveform.shape[0] == 0:
        print('no spike')
        return {}

    # mean waveform in each channel; the best channel is the one with the biggest max-min range
    meanWaveforms = [np.mean(waveform[:, :, channel], axis=0) for channel in range(waveform.shape[2])]
    bestChannel, diffMinMax = 0, 0
    for channel, meanChannel in enumerate(meanWaveforms):
        minMaxChannel = np.max(meanChannel) - np.min(meanChannel)
        if minMaxChannel > diffMinMax:
            diffMinMax, bestChannel = minMaxChannel, channel

    # linear interpolation (step 0.02 sample) to get more precise results, e.g. for the half width
    UPSAMPLING = 50  # = 1 / 0.02, number of interpolated points per original sample
    wave = meanWaveforms[bestChannel]
    wave = np.interp(np.arange(0, len(wave), 0.02), np.arange(0, len(wave)), wave)

    # peak: minimum in the middle of the waveform (the spike time is the centre of the window)
    middleIndex = int(len(wave) / 2)
    middleWave = wave[middleIndex-3 : middleIndex+3]
    peakValue = middleWave.min()
    peakIndex = middleIndex-3 + middleWave.argmin()
    if peakValue > 0:
        # Positive spike: flip it so that the peak becomes a minimum, then locate the peak again.
        # The argmin above is the smallest positive value, not the extremum of the spike.
        wave = -wave
        middleWave = -middleWave
        peakValue = middleWave.min()
        peakIndex = middleIndex-3 + middleWave.argmin()

    # valley: first local maximum after the peak (the global maximum if there is none),
    # searched up to 1000 points (20 original samples) before the end of the window
    afterPeak = wave[peakIndex : -1000]
    try:
        valleyIndex = peakIndex + argrelextrema(afterPeak, np.greater)[0][0]
    except IndexError:
        valleyIndex = peakIndex + afterPeak.argmax()
    valleyValue = wave[valleyIndex]

    # baseline index: maximum before the peak (index in wave)
    baseIndex = 1 + wave[1 : peakIndex].argmax()
    # baseline value: mean of the first 1000 points (20 original samples)
    baseValue = np.nanmean(wave[0:1000])

    # amplitude: baseline - peak (values)
    amp = abs(wave[peakIndex]-baseValue)

    # width at half height, on each side of the peak
    middleY = baseValue - amp/2.0
    widthStart = np.abs(wave[:peakIndex] - middleY).argmin()
    widthEnd = peakIndex + np.abs(wave[peakIndex:] - middleY).argmin()

    upsampledRate = data.spikeSamplingRate * UPSAMPLING
    peaktovalleytime = (valleyIndex - peakIndex) / upsampledRate * 1000
    halfwidth = (widthEnd - widthStart) / upsampledRate * 1000

    # spike asymmetry: correlation between the 1000 points before the centre of the window and the
    # mirrored 1000 points after it (wave[1501:2501] vs wave[3500:2500:-1] for 5000 points)
    RMat = np.corrcoef(wave[middleIndex-999 : middleIndex+1], wave[middleIndex+1000 : middleIndex : -1])
    asymetrie = RMat[0, 1]

    caracteristics = {
        "mean waveforms": meanWaveforms,
        "baseline-peak amplitude": amp,
        "peak to valley amplitude":  abs(peakValue-valleyValue),
        "HalfWidth": halfwidth,
        "best channel": bestChannel,
        "PeakToValley": peaktovalleytime,
        "asymetrie": asymetrie
    }

    return caracteristics, peakIndex, valleyIndex, baseIndex, widthEnd, widthStart



In [None]:
#-------------------------------------------------------------------------------------------------------
def plot_mean_waveform(data, shank, cluster, group="undefined", sample=200, kwx=True, filtered=False,redo=True,plotAllSpikes=False,ax1=None,ax2=None):
    """ plot the mean waveform and its caracteristics for one cluster

    The upsampled mean waveform of the best channel is drawn on ax1 (a new 10x5 figure with two
    subplots is created when ax1 is None). The plot also shows the half-width span (grey), the
    baseline before the peak, the valley level after the peak, and the half-width and
    peak-to-valley segments. If plotAllSpikes is True and
    <root>/ALLMOU_Analysis/AllUnitsHalfWidthPeakToValley.p exists, ax2 shows the half width vs
    peak-to-valley of all saved units, with this unit as a cross.

    group: unused, kept for backward compatibility
    sample, kwx, filtered, redo: passed to get_mean_waveform_and_caracteristics

    Returns (wave, caracteristics): the plotted waveform, and the caracteristics dict without its
    "best channel" and "mean waveforms" entries.
    """
    if ax1 is None:
        plt.figure(figsize = (10, 5))
        ax1=plt.subplot(1,2,1)

    caracteristics, p, v, b, e, s = get_mean_waveform_and_caracteristics(data, shank, cluster, sample, kwx, filtered,redo)
    best = caracteristics.pop("best channel")
    meanWaveforms = caracteristics.pop("mean waveforms")
    # same upsampling (step 0.02 sample) as get_mean_waveform_and_caracteristics,
    # so that the indexes p, v, e and s point into this array
    wave = meanWaveforms[best]
    wave = np.interp(np.arange(0,len(wave),0.02),np.arange(0,len(wave)),wave)
    # positive spikes are flipped, as in get_mean_waveform_and_caracteristics
    if wave[p]>0:
        wave=-wave

    # Time of each upsampled point in ms, centred on the middle of the window. It is built from
    # the number of points so it always has exactly len(wave) elements; np.arange with a float
    # step can return one extra point, which would make ax1.plot fail.
    upsampledRate = data.spikeSamplingRate/0.02
    timeAxis = np.arange(len(wave))/upsampledRate
    timeAxis = timeAxis*1000-timeAxis[-1]*1000/2

    ax1.plot(timeAxis,wave)
    title = "Mean waveform - Shank %s Cluster %s channel %s" %(shank, cluster, best)
    title += "\n" + " - ".join(["%.2f"%(caracteristics[key]) for key in sorted(caracteristics)])
    ax1.set_title(title)

    # half width span
    ax1.axvspan(timeAxis[s], timeAxis[e], color="lightgrey", alpha=0.5)

    # baseline (peak value + baseline-peak amplitude), before the peak
    baseline = caracteristics["baseline-peak amplitude"] + wave[p]
    ax1.plot((timeAxis[0], timeAxis[p-1]), (baseline, baseline))

    # valley level, after the peak
    ax1.plot((timeAxis[p], timeAxis[len(wave)-1]), (wave[v], wave[v]))

    # half width and peak-to-valley segments
    ax1.plot((timeAxis[e],timeAxis[s]),(wave[e],wave[s]),'x-k')
    ax1.plot((timeAxis[p],timeAxis[v]),(wave[p],wave[p]),'x-k')
    ax1.set_xlim(-2,2)
    ax1.set_ylim(wave.min()-10, wave.max()+10)

    # half width and peak to valley of the cluster versus all saved units
    AllUnitspicklePath=os.path.join(root,"ALLMOU_Analysis","AllUnitsHalfWidthPeakToValley.p")
    if (os.path.isfile(AllUnitspicklePath)) and (plotAllSpikes):
        if ax2 is None:
            ax2=plt.subplot(1,2,2)
        with open(AllUnitspicklePath, 'rb') as f:
            AllUnitsCharacteristicsData = pickle.load(f)["Characteristics"]
        ax2.scatter(AllUnitsCharacteristicsData["HalfWidth"],AllUnitsCharacteristicsData["PeakToValley"])
        ax2.scatter(caracteristics["HalfWidth"],caracteristics["PeakToValley"],marker='x',s=200)
        ax2.set_xlim(0,0.5)

    return wave, caracteristics
        
#-------------------------------------------------------------------------------------------------------
def plot_waveforms(data, shank, cluster, group="not specified", sample=200, kwx=True, place=None,
                   filtered=False):
    """ plot the waveforms on every channel for one cluster

    Waveforms come from the .kwx file when it exists and kwx is True, otherwise from the .dat file.

    group: unused, kept for backward compatibility
    sample: maximum number of spikes plotted (random subsample above that)
    place: [row, column] of each channel in the grid. The default is the zig-zag layout
      [0, 1], [1, 0], [2, 1], [3, 0], ... (one row per channel), for any number of channels.
      The grid has at least 8 rows and 3 columns and grows to fit *place*.
    filtered: for .kwx waveforms, plot the filtered waveforms instead of the raw ones
    """
    if os.path.exists(data.fullPath + '.kwx') and kwx:
        waveform = read_kwx_waveform(data, shank, cluster, sample, filtered)
    elif os.path.exists(data.fullPath + '.dat'):
        waveform = read_dat_waveform(data, shank, cluster, sample)
    else:
        print('no spikes')
        return
    nChannels = waveform.shape[2]
    if place is None:
        # zig-zag layout: channel i on row i, alternating between columns 1 and 0
        place = [[channel, (channel + 1) % 2] for channel in range(nChannels)]
    usedPlace = place[:nChannels]
    nRows = max([8] + [position[0] + 1 for position in usedPlace])
    nCols = max([3] + [position[1] + 1 for position in usedPlace])
    plt.figure(figsize = (3*2, 6 * nRows / 8))
    plt.suptitle("Cluster %s" % cluster, fontsize = 14)
    gs = gridspec.GridSpec(nRows, nCols, hspace = -0.3, wspace = 0)
    for channel in range(nChannels):
        x = place[channel][0]
        y = place[channel][1]
        ax = plt.subplot(gs[x,y])
        for spike in waveform[:, :, channel]:
            ax.plot(spike, color = "darkred")
        ax.set_title("%s" % channel)
        ax.set_axis_off()

In [None]:
#----------------------------------------------------------------------------------------------------------------------
if "__file__" not in dir():

# below you can change shank and clu
##############################


##############################
    #plot_waveforms(data, shank, cluster, kwx=False)
    # plot_mean_waveform opens its own figure when no axes are given, so no plt.figure() is needed
    wave, k = plot_mean_waveform(data, shank, cluster, redo=True, plotAllSpikes=True, kwx=False)

## 3.12 Mean waveform and characteristics of all the units of this session (and save them)

In [None]:
def get_all_mean_waveforms(data, groupList=["Good"], sample=200, redo=True, saveAsPickle=True):
    """ Mean waveform and caracteristics of every unit of the session whose group is in groupList.

    Returns {shank: {cluster: output of get_mean_waveform_and_caracteristics}}; every shank of
    data.clusterGroup gets an entry, even when it has no selected unit.
    If redo is False and <analysisPath>/waveforms.p exists, that file is loaded and returned.
    Otherwise the results are computed and, if saveAsPickle, written to that file.
    """
    picklePath = os.path.join(data.analysisPath, "waveforms.p")

    # reuse previously saved results when allowed
    if os.path.exists(picklePath) and not redo:
        print("load waveforms from %s"%(picklePath))
        with open(picklePath, 'rb') as pickleFile:
            return pickle.load(pickleFile)

    allResults = {}
    for shank in data.clusterGroup:
        print("Shank", shank)
        shankGroups = data.clusterGroup[shank]
        shankResults = allResults[shank] = {}
        selectedClusters = (clu
                            for group in shankGroups if group in groupList
                            for clu in shankGroups[group])
        for clu in selectedClusters:
            print(clu, end=" ")
            shankResults[clu] = get_mean_waveform_and_caracteristics(data, shank, clu, sample=sample)
        print()

    if saveAsPickle:
        with open(picklePath, 'wb') as pickleFile:
            pickle.dump(allResults, pickleFile)
        print("Saved pickle: %s"%picklePath)

    return allResults

## 3.13 Waveform and firing-rate characteristics of all units of this session (and save them)

In [None]:
def getUnitsMainCharacteristics(data, groupList=["Good"], sample=200, redofiringrate=False, redowaveform=True,saveAsPickle=True):
    """ Gather the firing-rate and waveform caracteristics of the units of one session.

    Returns a dict {characteristic name: list with one value per unit}:
      - "meanFRate": 1 / mean inter-spike interval (ISI)
      - "PropISI": fraction of the total ISI time spent in ISIs longer than 5 (spikeTime units,
        presumably seconds); NaN when the total ISI time is 0 (e.g. fewer than two spikes)
      - "runFRate" / "immoFRate": mean normalized firing rate during run / immobility, from
        meanFiringRatevsRunandImmoNorm (bins nSideBin+1 to -nSideBin)
      - "HalfWidth", "PeakToValley", "asymetrie": from get_all_mean_waveforms
    Each value is always a list, even when the session has no selected unit, so that lists can
    be concatenated across sessions.

    Units are those whose group is in groupList. "meanFRate", "PropISI" and the waveform
    caracteristics share the same unit order: shanks of data.channelGroupList, then groups in
    groupList order. NOTE(review): "runFRate"/"immoFRate" follow the order of the dicts returned
    by meanFiringRatevsRunandImmoNorm — confirm it matches.

    If redofiringrate and redowaveform are both False and
    <analysisPath>/UnitsWaveFormAndFiringRateCharacteristics.p exists, that file is loaded instead.
    Returns None if the session has no spikes.
    """
    WaveformCharacteristics=["HalfWidth","PeakToValley","asymetrie"]
    print("redofiring: %s" %(redofiringrate))
    print("redowaveform: %s" %(redowaveform))
    picklePath = os.path.join(data.analysisPath, "UnitsWaveFormAndFiringRateCharacteristics.p")

    if (not redofiringrate) and (not redowaveform) and (os.path.exists(picklePath)):
        print("load waveform and firing rate characteristics from %s"%(picklePath))
        with open(picklePath, 'rb') as f:
            allCharacteristics = pickle.load(f)
        return allCharacteristics

    if not data.hasSpike:
        print("no spike")
        return

    # (shank, cluster) of every unit of the requested groups
    units = [(shankID, clusterID)
             for shankID in data.channelGroupList
             for group in groupList if group in data.clusterGroup[shankID]
             for clusterID in data.clusterGroup[shankID][group]]

    resFiring = meanFiringRatevsRunandImmoNorm(data, groupList=groupList, redo=redofiringrate, saveAsPickle=saveAsPickle)

    ## firing rate caracteristics from the inter-spike intervals
    allCharacteristics = {"meanFRate": [], "PropISI": []}
    for shankID, clusterID in units:
        SpikeTimes = data.spikeTime[shankID][clusterID]
        ISI = [j-i for i, j in zip(SpikeTimes[:-1], SpikeTimes[1:])]
        allCharacteristics["meanFRate"].append(1/np.nanmean(ISI))
        totalISI = sum(ISI)
        # undefined (NaN) rather than a ZeroDivisionError when the unit has fewer than two spikes
        PropISI = sum([x for x in ISI if x>5])/totalISI if totalISI else np.nan
        allCharacteristics["PropISI"].append(PropISI)

    ## mean firing rate during run and during immobility
    nSideBin = resFiring["nSideBinForNormPlots"]
    for name, resKey in (("runFRate", "meanFiringRateDuringRunNorm"),
                         ("immoFRate", "meanFiringRateDuringImmoNorm")):
        print('Run or Immo Firing rate is beeing computed or loaded')
        meanFiring = resFiring[resKey]
        allCharacteristics[name] = [np.mean(meanFiring[shank][clu][nSideBin+1 : -nSideBin])
                                    for shank in meanFiring for clu in meanFiring[shank]]

    ## waveform caracteristics
    resWave = get_all_mean_waveforms(data, groupList=groupList, sample=sample, redo=redowaveform, saveAsPickle=saveAsPickle)
    for caracteristic in WaveformCharacteristics:
        allCharacteristics[caracteristic] = [resWave[shankID][cluID][0][caracteristic] for shankID, cluID in units]

    if saveAsPickle:
        with open(picklePath, 'wb') as f:
            pickle.dump(allCharacteristics, f)
        print("Saved pickle: %s"%picklePath)

    return allCharacteristics

In [None]:
#----------------------------------------------------------------------------------------------------------------------
if "__file__" not in dir():
    # firing-rate and waveform caracteristics of the Good units of this session
    allCharacteristics = getUnitsMainCharacteristics(
        data,
        groupList=["Good"],
        sample=200,
        redowaveform=True,
        redofiringrate=False,
    )

## 3.14 Plot unit characteristics (from waveform and firing rate) across animals

In [None]:
## plot waveform characteristics 
def PlotUnitsCharacteristics(animalList=[],groupList=["Good"], sample=200, redowaveform=False,redofiringrate=False,
                             CharacteristicsToPlot=["PropISI","HalfWidth","PeakToValley"],redoPreprocess=False,saveAsPickle=False,
                             tagList=None):
    """ Pool unit caracteristics across animals and sessions and scatter-plot them.

    animalList: animals to use; empty means every <root>/MOU* folder
    groupList, sample, redowaveform, redofiringrate, saveAsPickle: passed to getUnitsMainCharacteristics
    CharacteristicsToPlot: names of the caracteristics (keys of getUnitsMainCharacteristics);
      the first two are plotted in 2D, the first three in 3D when more than two are given
    redoPreprocess: passed to Data to overwrite the preprocessing
    tagList: only sessions having at least one of these tags are used (see has_tag).
      None (default) means ["GoodPerfo"]; an empty list disables the tag filtering.

    If saveAsPickle, {"Characteristics": ..., "Animals": animalList} is saved in
    <root>/ALLMOU_Analysis/AllUnits<concatenated names>.p.
    Returns {characteristic name: list of values of all units of all used sessions}.
    """
    if tagList is None:
        # tag = empty file in the session folder with a specific name
        tagList = ["GoodPerfo"]

    if not animalList:
        animalList = [os.path.basename(path) for path in sorted(glob.glob(root+"/MOU*"))]

    allCharacteristicsAcrossSessions = {name: [] for name in CharacteristicsToPlot}
    print(allCharacteristicsAcrossSessions)

    fig = plt.figure()
    if len(CharacteristicsToPlot)>2:
        # importing Axes3D registers the '3d' projection on older matplotlib versions
        from mpl_toolkits.mplot3d import Axes3D
        ax = fig.add_subplot(111, projection='3d')
    else:
        ax = fig.add_subplot(111)

    for animal in animalList:
        print("Animal", animal)
        sessionList = sorted(os.path.basename(expPath) for expPath in glob.glob(root+"/"+animal+"/Experiments/MOU*"))
        if not sessionList:
            # skip this animal only, the other animals are still processed
            print("no sessions")
            continue

        for session in sessionList:
            # skip sessions without any of the tags (no filtering when tagList is empty)
            if tagList and not has_tag(root, animal, session, tagList):
                continue

            print(session)
            sessionData = Data(root, animal, session, paramCarola, redoPreprocess=redoPreprocess)

            if not sessionData.hasSpike:
                print("########")
                print("%s has no spike" %sessionData.experiment)
                print("########")
                continue

            allCharacteristics=getUnitsMainCharacteristics(sessionData, groupList=groupList,sample=sample,redofiringrate=redofiringrate,redowaveform=redowaveform,saveAsPickle=saveAsPickle)

            for CharacteristicToPlot in CharacteristicsToPlot:
                # list() also accepts the empty {} that older saved pickles contain for sessions without unit
                allCharacteristicsAcrossSessions[CharacteristicToPlot] = (allCharacteristicsAcrossSessions[CharacteristicToPlot]
                                                                          + list(allCharacteristics[CharacteristicToPlot]))

    if len(CharacteristicsToPlot)>2:
        ax.scatter(allCharacteristicsAcrossSessions[CharacteristicsToPlot[0]],allCharacteristicsAcrossSessions[CharacteristicsToPlot[1]],allCharacteristicsAcrossSessions[CharacteristicsToPlot[2]])
        ax.set_xlabel(CharacteristicsToPlot[0])
        ax.set_ylabel(CharacteristicsToPlot[1])
        ax.set_zlabel(CharacteristicsToPlot[2])
    else:
        ax.scatter(allCharacteristicsAcrossSessions[CharacteristicsToPlot[0]],allCharacteristicsAcrossSessions[CharacteristicsToPlot[1]])
        ax.set_xlabel(CharacteristicsToPlot[0])
        ax.set_ylabel(CharacteristicsToPlot[1])
        ax.set_xlim(0,0.5)

    if saveAsPickle:
        VariableNamesForPickelName=''.join(CharacteristicsToPlot)
        PickleName="AllUnits" + VariableNamesForPickelName + ".p"
        picklePath = os.path.join(root,"ALLMOU_Analysis",PickleName)
        AllUnitsCharacteristics={}
        AllUnitsCharacteristics["Characteristics"]=allCharacteristicsAcrossSessions
        AllUnitsCharacteristics["Animals"]=animalList
        with open(picklePath, 'wb') as f:
            pickle.dump(AllUnitsCharacteristics, f)
        print("Saved pickle: %s"%picklePath)

    return allCharacteristicsAcrossSessions

In [None]:
#----------------------------------------------------------------------------------------------------------------------
if "__file__" not in dir():
    #animalList=["MOU025","MOU026","MOU027","MOU035","MOU074","MOU075","MOU093"]

    ##############################

    ##############################

    # half width vs peak-to-valley of the Good units of all animals (saved in ALLMOU_Analysis)
    allCharacteristicsAcrossSessions = PlotUnitsCharacteristics(
        CharacteristicsToPlot=["HalfWidth", "PeakToValley"],
        groupList=["Good"],
        sample=200,
        redowaveform=True,
        redofiringrate=False,
        redoPreprocess=False,
        saveAsPickle=True,
    )

## 3.17 Mean waveforms and autocorrelograms for all units of this session 

For each cluster of one session, the autocorrelogram and the mean waveform are plotted, and the cluster's half width and peak-to-valley time are plotted versus those of all the saved units.

In [None]:
def plotWaveformAndACGPerSession(data, groupList=["Good"], sample=200, redowaveform=False):
    """ Plot one row per unit of the session whose group is in groupList.

    Each row contains the autocorrelogram, the mean waveform (plot_mean_waveform), and the half
    width vs peak-to-valley of the unit (cross). When it exists, the row also shows the same
    values of all the units saved in <root>/ALLMOU_Analysis/AllUnitsHalfWidthPeakToValley.p
    (written by PlotUnitsCharacteristics).

    sample, redowaveform: passed to plot_mean_waveform (sample / redo)
    """
    # half width and peak-to-valley of all the units already characterized
    AllUnitspicklePath = os.path.join(root, "ALLMOU_Analysis", "AllUnitsHalfWidthPeakToValley.p")
    AllUnitsCharacteristicsData = None
    if os.path.isfile(AllUnitspicklePath):
        with open(AllUnitspicklePath, 'rb') as f:
            AllUnitsCharacteristicsData = pickle.load(f)["Characteristics"]
    else:
        print("no waveform carateristics saved in ALLMOU_Analysis folder")

    # (shank, cluster) of every unit to plot
    units = [(shank, clu)
             for shank in data.clusterGroup
             for group in data.clusterGroup[shank] if group in groupList
             for clu in data.clusterGroup[shank][group]]
    TotalNClusters = len(units)
    if TotalNClusters == 0:
        print("no unit in groups %s" % groupList)
        return

    plt.figure(figsize=(15,3*TotalNClusters))
    for row, (shank, clu) in enumerate(units):
        plt.subplot(TotalNClusters, 3, 3*row + 1)
        plot_autocorrelogram(data,shank,clu,1,30)

        thisax = plt.subplot(TotalNClusters, 3, 3*row + 2)
        caracteristics = plot_mean_waveform(data, shank, clu, sample=sample, redo=redowaveform, plotAllSpikes=False, ax1=thisax)[1]

        plt.subplot(TotalNClusters, 3, 3*row + 3)
        if AllUnitsCharacteristicsData is not None:
            plt.scatter(AllUnitsCharacteristicsData["HalfWidth"],AllUnitsCharacteristicsData["PeakToValley"])
        plt.scatter(caracteristics["HalfWidth"],caracteristics["PeakToValley"],marker='x',s=200)
        plt.xlim(0,0.5)
    plt.tight_layout()

In [None]:
#----------------------------------------------------------------------------------------------------------------------
if "__file__" not in dir():
    # one row per Good unit: autocorrelogram, mean waveform, half width vs peak-to-valley
    plotWaveformAndACGPerSession(data, redowaveform=False, sample=200, groupList=["Good"])
    

## 3.18 For a given cluster, find significant modulations during run 


In [None]:
def _emptyModulationResults(ModulationSign):
    """ ModulationResults of a unit whose modulation is not analysed or not found """
    return {
        "AllModulatedIndexes":[],
        "BiggestModulationIndexes":[],
        "MeanFiringRateZscored":[],
        "AllSignificantModulationIndexInRunSameSignThanBiggest":[],
        "AllModulatedPortionsSameSignThatBiggest":[],
        "ModulatedFractionSameSignThatBiggest":[],
        "BiggestModulationSign":ModulationSign
        }


def GetModulatedPortionsDuringRun(data,shank,cluster,runType="all",printoutput=True,plotoutput=True,ax=None):
    """ Find the significantly modulated bins of the run-normalized firing rate of one unit.

    The firing rate (binSize=0.25, SideLength=3) is compared with 500 shuffles:
      1. contiguous portions beyond the global band (0.5 / 99.5 percentiles of all shuffled
         values) are detected;
      2. for each of them, the contiguous portions beyond the pointwise band (2.5 / 97.5
         percentiles of each bin), in the same direction, that overlap it and the run bins are kept;
      3. the kept portion with the largest |z-scored firing rate| inside the run defines the
         modulation sign, and all portions of that sign are merged.
    Modulated bins are marked on the firing-rate plot: green = biggest modulation, red circles =
    other portions of the same sign, red crosses = portions of the opposite sign.

    Returns (ModulationResults, InsideRunIndexes, spikeHist, allSpeed, allDuration, nSideBin, plotaxinfo).
    ModulationResults["BiggestModulationSign"] is "positive", "negative", "not modulated", or
    "not enough spikes" (mean firing rate inside the run < 0.1).
    If plotoutput is False the figure is closed.
    """
    plotaxinfo,meanFiring,nSideBin,spikeHist,allSpeed,allDuration = plot_normalized_running_periods_firing_rate(data,shank,cluster,binSize=0.25,SideLength=3,runType=runType,ax=ax)
    percentileShuffle,allMeanShuffledFiring = plot_shuffled_running_firing_rate(data,shank,cluster,binSize=0.25,SideLength=3,ax=plotaxinfo,nShuffle=500,runType=runType)
    allMeanShuffledFiring = np.array(allMeanShuffledFiring)

    MeanFiringRateDuringRun=np.nanmean(meanFiring[nSideBin+1:-nSideBin])
    if MeanFiringRateDuringRun<0.1:
        if printoutput:
            print("warning !!! there is not enough spike")
        if not plotoutput:
            plt.close()
        return _emptyModulationResults("not enough spikes"),[],spikeHist,allSpeed,allDuration,nSideBin,plotaxinfo

    MeanFiringRateZscored=stats.zscore(meanFiring)
    # bins inside the run: the nSideBin bins on each side are outside the run
    InsideRunIndexes=list(range(nSideBin, len(meanFiring)-nSideBin))

    ## portions beyond the global confidence band (ModulationBeyondGlobalLimit)
    globalpercentile=np.percentile(allMeanShuffledFiring,[0.5,99.5])
    ModulatedFiringRateEpoch=(meanFiring>=globalpercentile[-1]) | (meanFiring<=globalpercentile[0])
    ModulationBeyondGlobalLimit=contiguous_regions(ModulatedFiringRateEpoch)

    ## pointwise confidence band, used to extend the global portions
    PointConfidenceBand=np.percentile(allMeanShuffledFiring,[2.5,97.5],axis=0)
    LowPointLimit=PointConfidenceBand[0].tolist()
    HighPointLimit=PointConfidenceBand[1].tolist()

    # AllModulatedIndexes[i] and Modulationmagnitude[i] always describe the same portion
    AllModulatedIndexes=[]
    Modulationmagnitude=[]
    for globallimits in ModulationBeyondGlobalLimit:
        if printoutput:
            print("significant portion beyond globallimits: %s"%globallimits)
        isPositive = np.nanmean(meanFiring[globallimits[0]:globallimits[1]]) >= globalpercentile[-1]
        if isPositive: # case positive modulation
            ModulationBeyondPointWiseLimit=contiguous_regions(meanFiring>HighPointLimit)
        else:   # case negative modulation
            ModulationBeyondPointWiseLimit=contiguous_regions(meanFiring<LowPointLimit)
        ModulationBeyondPointWiseLimit=ModulationBeyondPointWiseLimit.tolist()

        for pointwiselimits in ModulationBeyondPointWiseLimit:
            pointwiseBins=np.arange(pointwiselimits[0],pointwiselimits[1])
            # keep only pointwise portions overlapping this global portion and the run
            if np.intersect1d(pointwiseBins,np.arange(globallimits[0],globallimits[1])).size == 0:
                continue
            if np.intersect1d(pointwiseBins,InsideRunIndexes).size == 0:
                continue
            PointWiseSignificantModulation=list(range(pointwiselimits[0],pointwiselimits[1]))
            AllModulatedIndexes.append(PointWiseSignificantModulation)
            ModulatedPortionDuringRunOnly=[i for i in PointWiseSignificantModulation if i in InsideRunIndexes]
            if printoutput:
                print ("ModulatedPortionDuringRunOnly",ModulatedPortionDuringRunOnly,"PointWiseSignificantModulation",PointWiseSignificantModulation,"InsideRunIndexes",InsideRunIndexes)

            # magnitude: max |z-scored firing rate| inside the run, signed by the modulation direction
            ThisModulationMagnitude=np.nanmax(np.abs(MeanFiringRateZscored[ModulatedPortionDuringRunOnly]))
            if isPositive:
                NameForPrint="This is a positive modulation"
                Modulationmagnitude.append(ThisModulationMagnitude)
            else:
                NameForPrint="This is a negative modulation"
                Modulationmagnitude.append(ThisModulationMagnitude*-1)

            if printoutput:
                print("%s of inside run amplitude %s Zscored" %(NameForPrint,Modulationmagnitude[-1]))
                print("###########")
                print("")

    # Remove duplicates (the same pointwise portion can overlap several global portions).
    # Indexes and magnitudes are deduplicated together so that they stay aligned.
    uniqueIndexes, uniqueMagnitudes = [], []
    for indexes, magnitude in zip(AllModulatedIndexes, Modulationmagnitude):
        if indexes not in uniqueIndexes:
            uniqueIndexes.append(indexes)
            uniqueMagnitudes.append(magnitude)
    AllModulatedIndexes = uniqueIndexes
    Modulationmagnitude = np.array(uniqueMagnitudes)

    if printoutput:
        print("")
        print("###")
        print ("Modulationmagnitude",Modulationmagnitude)
        print("AllModulatedIndexes",AllModulatedIndexes)
        print("")

    if len(Modulationmagnitude) == 0: # case there is no modulation
        ModulationSign="not modulated"
        ModulationResults=_emptyModulationResults(ModulationSign)
        if printoutput:
            print("This unit is %s during run" %ModulationSign)
        if not plotoutput:
            plt.close()
        return ModulationResults,InsideRunIndexes,spikeHist,allSpeed,allDuration,nSideBin,plotaxinfo

    ## index of the strongest modulation (first one in case of ties)
    ModulationmagnitudeAbs=list(np.abs(Modulationmagnitude))
    Index=ModulationmagnitudeAbs.index(max(np.abs(Modulationmagnitude)))
    ModulationSign = "negative" if Modulationmagnitude[Index]<0 else "positive"
    BiggestModulationIndexes=AllModulatedIndexes[Index]

    ## merge all the modulations of the same sign as the biggest one, and mark them on the plot
    AllSignificantModulationIndexInRunSameSignThanBiggest=[]
    for count,value in enumerate(AllModulatedIndexes):
        thisxaxis=[x+0.5 for x in value]
        if Modulationmagnitude[count] * Modulationmagnitude[Index] > 0:
            AllSignificantModulationIndexInRunSameSignThanBiggest=AllSignificantModulationIndexInRunSameSignThanBiggest+value
            plotaxinfo.plot(thisxaxis,meanFiring[value],'go' if count==Index else 'ro',markersize=12)
        else:
            plotaxinfo.plot(thisxaxis,meanFiring[value],'r+',markersize=12)

    # unique sorted bins; NOTE(review): they include bins of the portions lying outside the run
    AllSignificantModulationIndexInRunSameSignThanBiggest=sorted(set(AllSignificantModulationIndexInRunSameSignThanBiggest))
    # position of each bin relative to the run (0 = first run bin, 1 = last run bin)
    AllModulatedPortionsSameSignThatBiggest=[(x-InsideRunIndexes[0])/(len(InsideRunIndexes)-1) for x in AllSignificantModulationIndexInRunSameSignThanBiggest]
    ModulatedFractionSameSignThatBiggest=len(AllSignificantModulationIndexInRunSameSignThanBiggest)/len(InsideRunIndexes)
    if printoutput:
        print("Biggest Modulation Sign: %s" %ModulationSign)
        print("AllSignificantModulationIndexInRunSameSignThanBiggest", AllSignificantModulationIndexInRunSameSignThanBiggest)
        print("")
        print("Modulated Portions: %s"%AllModulatedPortionsSameSignThatBiggest)
        print("")
        print("Modulated Fraction (relative to run): %s"%ModulatedFractionSameSignThatBiggest)

    ModulationResults={
        "AllModulatedIndexes":AllModulatedIndexes,
        "BiggestModulationIndexes":BiggestModulationIndexes,
        "MeanFiringRateZscored":MeanFiringRateZscored,
        "AllSignificantModulationIndexInRunSameSignThanBiggest":AllSignificantModulationIndexInRunSameSignThanBiggest,
        "AllModulatedPortionsSameSignThatBiggest":AllModulatedPortionsSameSignThatBiggest,
        "ModulatedFractionSameSignThatBiggest":ModulatedFractionSameSignThatBiggest,
        "BiggestModulationSign":ModulationSign
        }

    if not plotoutput:
        plt.close()

    return ModulationResults,InsideRunIndexes,spikeHist,allSpeed,allDuration,nSideBin,plotaxinfo
        
    
    

In [None]:
#------------------------------------------------------------------------------------------------------------------
if "__file__" not in dir():

    ##############################


    ##############################
    # significantly modulated bins of the current unit during run (printed and plotted)
    (ThisModulationResults, InsideRunIndexes, spikeHist, allSpeed, allDuration,
     nSideBin, plotaxinfo) = GetModulatedPortionsDuringRun(data, shank, cluster,
                                                            printoutput=True, plotoutput=True)

## 3.21 For modulated portions, correlate firing rate with run kinematics trial by trial

In [None]:
def correlationFrateVsKinematic(data,shank,cluster,runType="all",redoModulation=False,showplot=True):
    """Correlate, trial by trial, a unit's firing rate in its biggest modulated epoch with run kinematics.

    The modulated epoch is the list of bin indexes ``BiggestModulationIndexes``, either read
    from the saved ``ModulationDuring<runType>Runs.p`` pickle or recomputed with
    GetModulatedPortionsDuringRun. For every kept run, the mean firing rate in those bins is
    compared (Spearman correlation + linear fit plot) with:

    - ``SpeedDuringModulatedEpoch``: mean running speed in the modulated bins,
    - ``AccelDuringModulatedEpoch``: speed at the last modulated bin minus speed at the
      first one, divided by the time separating these two bins,
    - ``RunDistance``: sum of speed * bin duration over the bins inside the run
      (the nSideBin padding bins on each side are excluded),
    - ``RunDuration``: duration of the run.

    Parameters
    ----------
    data : session object (data.animal and data.experiment locate the pickle; uses the global ``root``).
    shank, cluster : identifiers of the unit.
    runType : type of runs analysed; selects the pickle and is forwarded to the helpers.
    redoModulation : if True, ignore the saved pickle and recompute the modulation.
    showplot : if False, the figure is closed before returning.

    Returns
    -------
    dict
        Keys are the four parameter names above, each mapped to ``[spearman_r, p_value]``.
        Every value is ``[]`` when the unit is "not modulated" / has "not enough spikes",
        when fewer than 2 modulated bins exist, or when the median firing rate in the
        modulated epoch is below 0.2 (or no trial is kept).
    """
    ParamNames=("SpeedDuringModulatedEpoch","AccelDuringModulatedEpoch","RunDistance","RunDuration")
    NoCorrelationSigns=("not modulated","not enough spikes")

    def _no_correlation():
        # Same keys as a successful result, every value left empty.
        return {param:[] for param in ParamNames}

    def _close_and_return(result):
        if not showplot:
            plt.close()
        return result

    fig = plt.figure(figsize=(10,15))
    gs1 = gridspec.GridSpec(1, 1)
    ax1=fig.add_subplot(gs1[0,0])

    ## load the saved modulation data if it exists (and not redoModulation), otherwise recompute it
    picklePath=os.path.join(root,data.animal,"Experiments",data.experiment,"Analysis","ModulationDuring" + runType + "Runs.p")

    if os.path.exists(picklePath) and not redoModulation:
        # NOTE(review): pickle.load can execute arbitrary code; only load files written by this pipeline.
        with open(picklePath,"rb") as f:
            ModulationResults=pickle.load(f)
        print("spiking modulation  data loaded from %s"%picklePath)

        if ModulationResults["BiggestModulationSign"][shank][cluster] in NoCorrelationSigns:
            return _close_and_return(_no_correlation())

        # get and plot behavioral and neural data for all the behavioral epochs
        plotaxinfo,meanFiring,nSideBin,spikeHist,allSpeed,allDuration = plot_normalized_running_periods_firing_rate(data,shank,cluster,binSize=0.25,SideLength=3,runType=runType,ax=ax1)
        ModulatedBins=ModulationResults["BiggestModulationIndexes"][shank][cluster]
        plotaxinfo.plot([i+0.5 for i in ModulatedBins],meanFiring[ModulatedBins],'og')

    else:
        # runType is forwarded (it used to be hard-coded to "all", inconsistent with the pickle name)
        ModulationResults,InsideRunIndexes,spikeHist,allSpeed,allDuration,nSideBin,plotaxinfo=GetModulatedPortionsDuringRun(data,shank,cluster,runType=runType,printoutput=False,plotoutput=True,ax=ax1)
        # Same early exit as the pickle branch: an unmodulated unit has no modulated bins to analyse.
        if (not ModulationResults) or (ModulationResults["BiggestModulationSign"] in NoCorrelationSigns):
            return _close_and_return(_no_correlation())
        ModulatedBins=ModulationResults["BiggestModulationIndexes"]

    gs1.tight_layout(fig, rect=[0.25, 0.66, 0.75, 1],h_pad=0.33)

    # The acceleration needs two distinct modulated bins.
    if len(ModulatedBins)<2:
        print("")
        print("fewer than 2 modulated bins, no correlation stats")
        print("")
        return _close_and_return(_no_correlation())

    # Number of bins separating the first and the last modulated bin (time span of the acceleration).
    # The old denominator sum(ModulatedBins)-1 summed bin *indexes*, not a number of bins.
    ModulatedSpanInBins=ModulatedBins[-1]-ModulatedBins[0]
    nSide=int(nSideBin)

    FRateDuringModulatedEpoch=[]
    SpeedDuringModulatedEpoch=[]
    AccelDuringModulatedEpoch=[]
    DistanceRun=[]
    RunDuration=[]
    for x,r in enumerate(allDuration):
        # NOTE(review): spike counts are non-negative, so this only drops trials whose sum is NaN.
        if sum(spikeHist[x][ModulatedBins])>=0:

            #frate trial by trial
            FRateDuringModulatedEpoch.append(np.nanmean(spikeHist[x][ModulatedBins]))

            #running speed
            SpeedInModulatedBins=allSpeed[x][ModulatedBins]
            SpeedDuringModulatedEpoch.append(np.nanmean(SpeedInModulatedBins))

            #acceleration
            BinDuration=r/(len(allSpeed[x])-2*nSideBin)
            Acceleration=(SpeedInModulatedBins[-1]-SpeedInModulatedBins[0])/(ModulatedSpanInBins*BinDuration)
            AccelDuringModulatedEpoch.append(Acceleration)

            #run distance (explicit end index so that nSide == 0 keeps the whole run)
            TotalDistanceRun=sum(allSpeed[x][nSide:len(allSpeed[x])-nSide]*BinDuration)
            DistanceRun.append(TotalDistanceRun)

            # duration of this same trial, so that every list stays aligned with the firing rates
            RunDuration.append(r)

    AllKinematicParameters={
            "SpeedDuringModulatedEpoch":SpeedDuringModulatedEpoch,
            "AccelDuringModulatedEpoch":AccelDuringModulatedEpoch,
            "RunDistance":DistanceRun,
            "RunDuration":RunDuration
            }
    # label and subplot position of each parameter (keyed by name rather than by list position)
    YLabels={
            "AccelDuringModulatedEpoch":r"Running accel $\mathregular{(cm/s^2)}$",
            "RunDistance":"Run distance (cm)",
            "RunDuration":"Run duration (s)",
            "SpeedDuringModulatedEpoch":"Running speed (cm/s)"
            }
    SubplotCoordinates={
            "AccelDuringModulatedEpoch":(0,1),
            "RunDistance":(1,0),
            "RunDuration":(1,1),
            "SpeedDuringModulatedEpoch":(0,0)
            }

    x=FRateDuringModulatedEpoch

    if len(x)==0 or np.median(x)<0.2: # case when firing rate during modulation is very low accross trials
        print("")
        print("firing rate too low for correlation stats")
        print("")
        return _close_and_return(_no_correlation())

    gs2 = gridspec.GridSpec(2, 2)
    FiringRateCorrelationsWith=_no_correlation()

    for param in sorted(AllKinematicParameters):
        row,col=SubplotCoordinates[param]
        ax = fig.add_subplot(gs2[row,col])
        print(param)
        y=AllKinematicParameters[param]

        fit_fn = np.poly1d(np.polyfit(x,y,1))
        ax.plot(x,y, 'ro', x, fit_fn(x), '-k',linewidth=2)
        ax.set_xlabel("Firing rate (Hz)",fontsize=20,weight="bold")
        ax.set_ylabel(YLabels[param],fontsize=20,weight="bold")

        # robust y-range: ignore the 2% most extreme values on each side
        ax.set_ylim(np.percentile(y,[2,98]))
        SpearManResults=stats.spearmanr(x,y)
        rvalue=str(round(SpearManResults[0],2))
        if SpearManResults[1]<0.0001:
            pvalue="p<0.0001"
        else:
            pvalue="p="+ str(round(SpearManResults[1],4))

        title="r=%s, %s"%(rvalue,pvalue)
        ax.set_title(title,fontsize=20,weight="bold")
        ax.tick_params(axis='both',which='major',labelsize=20,width=2)
        FiringRateCorrelationsWith[param]=[SpearManResults[0],SpearManResults[1]]

    gs2.tight_layout(fig, rect=[0, 0, 1, 0.66],h_pad=0.33)

    return _close_and_return(FiringRateCorrelationsWith)

In [None]:
#------------------------------------------------------------------------------------------------------------------
# Demo: trial-by-trial firing rate vs run kinematics for the selected unit, recomputing the modulation.
if "__file__" not in dir():
    #######################

    #######################
    FiringRateCorrelationsWith = correlationFrateVsKinematic(
        data, shank, cluster, redoModulation=True, showplot=True)

## 3.19 For all clusters of this session, find significant modulations during runs AND SAVE

In [None]:
def GetAllSignificantModulation(data, groupList=["Good"], saveAsPickle=True, redo=False, binSize=0.25, 
                                   SideLength=3, nShuffle=500,runType="all",printoutput=False,plotoutput=False):
    """
    Loop through the units of a session and run GetModulatedPortionsDuringRun on each,
    to detect significant firing-rate modulation during behavioral epochs of type
    `runType`, then (optionally) save everything in a pickle.

    Parameters
    ----------
    data : session object (needs hasSpike, analysisPath and clusterGroup).
    groupList : cluster groups to include (a single name is accepted; None = all groups).
    saveAsPickle : write the result to ``<analysisPath>/ModulationDuring<runType>Runs.p``.
    redo : if False and the pickle exists, load and return it instead of recomputing.
    binSize, SideLength, nShuffle : currently NOT forwarded to GetModulatedPortionsDuringRun
        (kept for backward compatibility). TODO: forward them once its signature is confirmed.
    runType : type of runs analysed; forwarded to GetModulatedPortionsDuringRun and used in
        the pickle name.
    printoutput, plotoutput : forwarded to GetModulatedPortionsDuringRun.

    Returns
    -------
    dict ``result[key][shank][cluster]`` for each per-unit modulation key, plus
    ``result["InsideRunIndexese"]`` (first non-empty InsideRunIndexes met, [] if none).
    None when the session has no spike data.
    """
    if not data.hasSpike:
        return

    picklePath=os.path.join(data.analysisPath,"ModulationDuring" + runType + "Runs.p")

    #load and return the pickle if it exists (and redo=False)
    if (not redo) and os.path.exists(picklePath):
        # NOTE(review): pickle.load can execute arbitrary code; only load files written by this pipeline.
        with open(picklePath, 'rb') as f:
            print("loaded pickle %s"%picklePath)
            return pickle.load(f)

    # accept a single group name; None means "all groups" (it must not be wrapped into [None])
    if groupList is not None and not isinstance(groupList,list):
        groupList=[groupList]

    SavedInsideRunIdex=[]

    ModulationSavedData={
        "AllModulatedIndexes":{},
        "BiggestModulationIndexes":{},
        "MeanFiringRateZscored":{},
        "AllSignificantModulationIndexInRunSameSignThanBiggest":{},
        "AllModulatedPortionsSameSignThatBiggest":{},
        "ModulatedFractionSameSignThatBiggest":{},
        "BiggestModulationSign":{},
        }

    for shank in sorted(data.clusterGroup):
        for key in ModulationSavedData:
            ModulationSavedData[key][shank]={}
        for group in data.clusterGroup[shank]:
            if (groupList is not None) and (group not in groupList):
                continue
            for cluster in sorted(data.clusterGroup[shank][group]):
                print("Shank %s  Cluster %s" %(shank,cluster))
                ModulationResults,InsideRunIndexes=GetModulatedPortionsDuringRun(data,shank,cluster,runType=runType,printoutput=printoutput,plotoutput=plotoutput)[0:2]
                if ModulationResults:
                    for key in sorted(ModulationSavedData):
                        ModulationSavedData[key][shank][cluster]=ModulationResults[key]
                # keep the first non-empty InsideRunIndexes
                if InsideRunIndexes and not SavedInsideRunIdex:
                    SavedInsideRunIdex=InsideRunIndexes

    # Store the saved indexes (previously the loop variable of the last unit was stored, which
    # raised a NameError when no unit was processed). Key spelling kept for pickle compatibility.
    ModulationSavedData["InsideRunIndexese"]=SavedInsideRunIdex

    ## Create and save pickle
    if saveAsPickle:
        with open(picklePath, 'wb') as f:
            pickle.dump(ModulationSavedData, f)
        print("Saved pickle: %s"%picklePath)

    return ModulationSavedData
                    
                    

In [None]:
#------------------------------------------------------------------------------------------------------------------
# Demo: recompute (redo=True) the run modulation of the selected units of this session;
# saveAsPickle defaults to True, so the result is also written to disk.
if "__file__" not in dir():
    ModulationSavedData = GetAllSignificantModulation(data, plotoutput=True, redo=True)

In [None]:
#------------------------------------------------------------------------------------------------------------------
# Reload the modulation pickle of this session and keep it in `test` for inspection.
if "__file__" not in dir():
    runType = "all"
    picklePath = os.path.join(data.analysisPath, "ModulationDuring" + runType + "Runs.p")
    print(picklePath)
    with open(picklePath, 'rb') as f:
        print("loaded pickle %s" % picklePath)
        test = pickle.load(f)

## 3.21 For modulated portions of all units in this session, correlate firing rate with run kinematics trial by trial AND SAVE

In [None]:
def GetAllCorrCoefFRateVsKinematic(data, groupList=["Good"], saveAsPickle=False, redo=False, redoModulation=False, binSize=0.25, 
                                  runType="all",showplot=False):
    """
    Loop through the units of a session and run correlationFrateVsKinematic on each,
    to get the trial-by-trial correlation between the firing rate during the biggest
    modulation and run kinematics; optionally save the result in a pickle.

    Parameters
    ----------
    data : session object (needs hasSpike, analysisPath and clusterGroup).
    groupList : cluster groups to include (a single name is accepted; None = all groups).
    saveAsPickle : write the result to ``<analysisPath>/FRateVsKinematciCorrDuring<runType>Runs.p``.
    redo : if False and the pickle exists, load and return it instead of recomputing.
    redoModulation : forwarded to correlationFrateVsKinematic.
    binSize : currently unused (kept for backward compatibility).
    runType : type of runs analysed; forwarded to correlationFrateVsKinematic and used in
        the pickle name.
    showplot : forwarded to correlationFrateVsKinematic.

    Returns
    -------
    dict ``result[kinematic_parameter][shank][cluster]`` holding what
    correlationFrateVsKinematic returned for that parameter ([] when nothing was returned),
    or None when the session has no spike data.
    """
    if not data.hasSpike:
        return

    # file name spelling ("Kinematci") kept so that existing pickles are still found
    picklePath=os.path.join(data.analysisPath,"FRateVsKinematciCorrDuring" + runType + "Runs.p")

    #load and return the pickle if it exists (and redo=False)
    if (not redo) and os.path.exists(picklePath):
        # NOTE(review): pickle.load can execute arbitrary code; only load files written by this pipeline.
        with open(picklePath, 'rb') as f:
            print("loaded pickle %s"%picklePath)
            return pickle.load(f)

    # accept a single group name; None means "all groups" (it must not be wrapped into [None])
    if groupList is not None and not isinstance(groupList,list):
        groupList=[groupList]

    FRateCorrelationWithSavedData={
            "SpeedDuringModulatedEpoch":{},
            "AccelDuringModulatedEpoch":{},
            "RunDistance":{},
            "RunDuration":{}
            }

    for shank in sorted(data.clusterGroup):
        for key in FRateCorrelationWithSavedData:
            FRateCorrelationWithSavedData[key][shank]={}
        for group in data.clusterGroup[shank]:
            if (groupList is not None) and (group not in groupList):
                continue
            for cluster in sorted(data.clusterGroup[shank][group]):
                print("Shank %s  Cluster %s" %(shank,cluster))
                # runType is forwarded (it used to be hard-coded to "all", inconsistent with the pickle name)
                FiringRateCorrelationsWith=correlationFrateVsKinematic(data,shank,cluster,runType=runType,redoModulation=redoModulation,showplot=showplot)

                for key in sorted(FRateCorrelationWithSavedData):
                    if FiringRateCorrelationsWith:
                        FRateCorrelationWithSavedData[key][shank][cluster]=FiringRateCorrelationsWith[key]
                    else:
                        FRateCorrelationWithSavedData[key][shank][cluster]=[]

    ## Create and save pickle
    if saveAsPickle:
        with open(picklePath, 'wb') as f:
            pickle.dump(FRateCorrelationWithSavedData, f)
        print("Saved pickle: %s"%picklePath)

    return FRateCorrelationWithSavedData
                    
                    

In [None]:
#------------------------------------------------------------------------------------------------------------------
# Demo: recompute and SAVE the firing rate vs kinematic correlations of this session.
# saveAsPickle defaults to False in GetAllCorrCoefFRateVsKinematic; it is set explicitly here so the
# FRateVsKinematciCorrDuring<runType>Runs.p pickle read by the population cell below is written.
if "__file__" not in dir():

    FRateCorrelationWithSavedData=GetAllCorrCoefFRateVsKinematic(data,redo=True,saveAsPickle=True)

## 3.22 Loop through saved modulation data across animals and plot some population statistics about modulations:

In [None]:
#------------------------------------------------------------------------------------------------------------------
# Population summary of run modulation. For every session tagged with tagList, read the saved
# ModulationDuring<runType>Runs.p pickle and pool across units:
#  - the position bins (relative to run) holding modulated portions,
#  - the modulated fraction,
#  - the position of the peak significant modulation,
#  - the sign of the biggest modulation.
if "__file__" not in dir():

    animalList=[os.path.basename(path) for path in sorted(glob.glob(root+"/MOU*"))]

    # Redefine animalList if you don't want all the animals
    # animalList=["MOU015","MOU016","MOU017","MOU018","MOU019","MOU028","MOU029"]

    tagList = ["GoodPerfo"]

    #Whether to read the existing pickle files (redo=False) or to reload from raw text files (redo=True)
    # NOTE(review): not used by this cell
    redo=False

    runType="all"
    binmiddles=np.round(np.arange(-0.95,2.05,.1)*100)/100
    binedgesinput=np.round(np.arange(-1,2.1,.1)*10)/10
    AllModulatedPortion=[]
    AllSignsOfModulation=[]
    AllModulatedFraction=[]
    AllPeakModulationPositionRelativeToRun=[]
    # coding of the biggest modulation sign: 1 = negative, 2 = anything else (not modulated, ...), 3 = positive
    SignCode={"positive":3,"negative":1}

    for animal in animalList:
        sessionList=sorted(os.path.basename(expPath) for expPath in glob.glob(root+"/"+animal+"/Experiments/MOU*"))

        for session in sessionList:
            if not has_tag(root, animal, session, tagList):
                continue

            picklePath=os.path.join(root,animal,"Experiments",session,"Analysis","ModulationDuring" + runType + "Runs.p")
            if not os.path.exists(picklePath):
                print("no pickle data at %s"%picklePath)
                print("")
                continue

            # NOTE(review): pickle.load can execute arbitrary code; only load files written by this pipeline.
            with open(picklePath,"rb") as f:
                ModulationResults=pickle.load(f)

            for shank in ModulationResults["AllModulatedPortionsSameSignThatBiggest"]:
                for cluster in ModulationResults["AllModulatedPortionsSameSignThatBiggest"][shank]:
                    ThisClusterModulationsResults=ModulationResults["AllModulatedPortionsSameSignThatBiggest"][shank][cluster]

                    if ThisClusterModulationsResults:
                        # 1 for every position bin holding at least one modulated portion of this unit
                        histcount=np.clip(np.histogram(ThisClusterModulationsResults,binedgesinput)[0],0,1)
                        AllModulatedPortion.append(histcount)

                        AllModulatedFraction.append(ModulationResults["ModulatedFractionSameSignThatBiggest"][shank][cluster])

                        # NOTE(review): assumes the significant-index list and the modulated-portion list are aligned element by element
                        SignifiantModulationAmplitude = np.abs(ModulationResults["MeanFiringRateZscored"][shank][cluster][ModulationResults['AllSignificantModulationIndexInRunSameSignThanBiggest'][shank][cluster]])
                        AllPeakModulationPositionRelativeToRun.append(ThisClusterModulationsResults[SignifiantModulationAmplitude.argmax()])

                    AllSignsOfModulation.append(SignCode.get(ModulationResults["BiggestModulationSign"][shank][cluster],2))

    # Distribution of the modulated portions along the run (np.vstack fails on an empty list)
    if AllModulatedPortion:
        AllModulatedPortion=np.vstack(AllModulatedPortion)
        fig,ax=plt.subplots()
        ax.bar(binmiddles,AllModulatedPortion.sum(0),width=0.1)
        ax.set(xlabel="Position relative to run",ylabel="Number of units",title="Modulated portions (same sign as biggest)")
    else:
        print("no modulated unit found")

    # Distribution of the modulated fraction; bar centers derived from the histogram edges
    FractionEdges=np.arange(0,1.5,0.1)
    [histcount,binedges]=np.histogram(AllModulatedFraction,FractionEdges)
    fig,ax=plt.subplots()
    ax.bar(FractionEdges[:-1]+0.05,histcount,width=0.1)
    ax.set(xlabel="Modulated fraction (same sign as biggest)",ylabel="Number of units")

    # Distribution of the peak modulation position
    [histcount,binedges]=np.histogram(AllPeakModulationPositionRelativeToRun,binedgesinput)
    fig,ax=plt.subplots()
    ax.bar(binmiddles,histcount,width=0.1)
    ax.set(xlabel="Peak modulation position relative to run",ylabel="Number of units")

    [histcount,binedges]=np.histogram(AllSignsOfModulation,[1,2,3,4])
    print("Nber of unit negatively/not/posively modulated during run %s/%s/%s"%(histcount[0],histcount[1],histcount[2]))


## 3.23 Loop through saved correlation data across animals and plot some population statistics:

In [None]:
#------------------------------------------------------------------------------------------------------------------
# Population summary of the firing rate vs kinematic correlations. For every session tagged with
# tagList, read the saved FRateVsKinematciCorrDuring<runType>Runs.p pickle, pool the Spearman r and
# p-values of all units, and plot one histogram of r per kinematic variable together with the
# percentage of units with p < Pthreshold.
if "__file__" not in dir():

    animalList=[os.path.basename(path) for path in sorted(glob.glob(root+"/MOU*"))]

    # Redefine animalList if you don't want all the animals
    #animalList=["MOU035","MOU026","MOU027","MOU025"]

    tagList = ["GoodPerfo"]

    #Whether to read the existing pickle files (redo=False) or to reload from raw text files (redo=True)
    # NOTE(review): not used by this cell
    redo=False

    runType="all"
    Pthreshold=0.01

    KinematicVariables=("SpeedDuringModulatedEpoch","AccelDuringModulatedEpoch","RunDistance","RunDuration")
    AllUNitsFRateCorrelation={variable:[] for variable in KinematicVariables}
    AllUNitsFRateCorrelationPValue={variable:[] for variable in KinematicVariables}

    for animal in animalList:
        sessionList=sorted(os.path.basename(expPath) for expPath in glob.glob(root+"/"+animal+"/Experiments/MOU*"))

        for session in sessionList:
            if not has_tag(root, animal, session, tagList):
                continue

            picklePath=os.path.join(root,animal,"Experiments",session,"Analysis","FRateVsKinematciCorrDuring" + runType + "Runs.p")
            if not os.path.exists(picklePath):
                print("no pickle data at %s"%picklePath)
                print("")
                continue

            # NOTE(review): pickle.load can execute arbitrary code; only load files written by this pipeline.
            with open(picklePath,"rb") as f:
                CorrelationResults=pickle.load(f)
            print("firing rate vs kinematic correlation data loaded from %s"%picklePath)

            for kinematicvariable in CorrelationResults:
                for shank in CorrelationResults[kinematicvariable]:
                    for cluster in CorrelationResults[kinematicvariable][shank]:
                        ThisClusterCorrelationResult=CorrelationResults[kinematicvariable][shank][cluster]
                        # [] means no correlation was computed for this unit
                        if ThisClusterCorrelationResult:
                            AllUNitsFRateCorrelation[kinematicvariable].append(ThisClusterCorrelationResult[0])
                            AllUNitsFRateCorrelationPValue[kinematicvariable].append(ThisClusterCorrelationResult[1])

    binmiddles = np.round(np.arange(-0.95,1.05,.1)*100)/100
    binedgesinput = np.round(np.arange(-1,1.1,.1)*10)/10

    fig,axes=plt.subplots(2,2,figsize=(15,15))
    for ax,kinematicvariable in zip(axes.flat,AllUNitsFRateCorrelation):
        histcount=np.histogram(AllUNitsFRateCorrelation[kinematicvariable],binedgesinput)[0]
        ax.bar(binmiddles,histcount,width=0.1)

        PValues=np.asarray(AllUNitsFRateCorrelationPValue[kinematicvariable],dtype=float)
        # NaN when no unit was correlated for this variable (this used to raise a ZeroDivisionError)
        PercentSignificant=100*np.mean(PValues<Pthreshold) if PValues.size else float("nan")
        ax.set_ylabel('count')
        ax.set_xlabel('cor coef')

        ax.set_title("%s, \n %3.2f percent units are singificantly correlated " %(kinematicvariable,PercentSignificant))



## 3.24 Summary plot for a given cluster: basic information + modulation + correlations

In [None]:
def cluster_ratevskinematicplot(data,shank,cluster,group,binSize=0.25,SideLength=3,printoutput=False):
    """Draw a summary figure for one unit.

    Layout:
    - left column: trial-by-trial behavior + spike raster (plot_break_cluster) and, below,
      the mean firing rate (plot_mean_breaks_firing_rate);
    - top right: autocorrelograms (plot_autocorrelogram, plot_autocorrelogram_period
      without and with immobility=True);
    - middle right: mean waveform (plot_mean_waveform);
    - bottom right: firing rate aligned on "trial good run" and "unrewarded" runs (same
      y-range), on immobility periods with shuffle, on all runs with the significant
      modulation (GetModulatedPortionsDuringRun), and, when the biggest modulation spans
      at least 3 bins, scatter + linear fit + Spearman correlation of the trial-by-trial
      firing rate in the modulated bins vs running speed, acceleration, run distance and
      run duration.

    Parameters
    ----------
    data : session object.
    shank, cluster : identifiers of the unit.
    group : cluster group passed to plot_break_cluster (e.g. "Good").
    binSize : bin size forwarded to the run/immobility firing-rate plots.
    SideLength : forwarded to the immobility firing-rate and shuffle plots.
    printoutput : forwarded to GetModulatedPortionsDuringRun.

    Returns
    -------
    None
    """
    fig=plt.figure(figsize=(25,30))

    ## Left column: trial-by-trial raster and behavior, then mean firing rate
    gs1 = gridspec.GridSpec(5, 1)

    ax1=fig.add_subplot(gs1[0:4,0])
    # the `group` argument is now used (it was hard-coded to "Good")
    plot_break_cluster(data,shank,cluster,group=group,legend=False,lick=True)
    xlimvalue=ax1.get_xlim()
    ax1.set_title("")

    # NOTE(review): no ax is passed, so this presumably draws on the axes created just before
    fig.add_subplot(gs1[4,0])
    ax2=plot_mean_breaks_firing_rate(data,shank,cluster,trialType = "good")
    ax2.set_xlim(xlimvalue)
    ax2.set_title("")

    gs1.tight_layout(fig, rect=[0, 0, 0.33, 1])

    ## Autocorrelograms
    gs2 = gridspec.GridSpec(1, 3)

    ax3 = fig.add_subplot(gs2[0,0])
    plot_autocorrelogram(data,shank,cluster,1,30)
    ax3.set_title("")

    fig.add_subplot(gs2[0,1])
    plot_autocorrelogram_period(data, shank, cluster, 5, 1000)

    fig.add_subplot(gs2[0,2])
    plot_autocorrelogram_period(data, shank, cluster, 5, 1000, immobility=True)

    gs2.tight_layout(fig, rect=[0.33, 0.85, 1, 1],h_pad=0.33)

    ## Waveform
    gs4 = gridspec.GridSpec(1, 2)
    ax9=fig.add_subplot(gs4[0,0])
    ax10=fig.add_subplot(gs4[0,1])
    plot_mean_waveform(data, shank, cluster,redo=False,plotAllSpikes=True,ax1=ax9,ax2=ax10)

    gs4.tight_layout(fig,rect=[0.45, 0.7, 0.9, 0.85],h_pad=0.33)

    ## Mean firing rate during runs / immobility, and correlation plots
    gs3 = gridspec.GridSpec(4, 2)

    # firing rate vs good runs, and vs unrewarded runs
    axGoodRun = fig.add_subplot(gs3[0,0])
    axGoodRun=plot_normalized_running_periods_firing_rate(data,shank,cluster,binSize=binSize,runType="trial good run",ax=axGoodRun)[0]
    axUnrewarded = fig.add_subplot(gs3[0,1])
    axUnrewarded=plot_normalized_running_periods_firing_rate(data,shank,cluster,binSize=binSize,runType="unrewarded",ax=axUnrewarded)[0]

    # Both run panels get the union of their y-ranges (the former tuple comparison was lexicographic)
    (lowGood,highGood),(lowUnrew,highUnrew)=axGoodRun.get_ylim(),axUnrewarded.get_ylim()
    sharedYLim=(min(lowGood,lowUnrew),max(highGood,highUnrew))
    axGoodRun.set_ylim(sharedYLim)
    axUnrewarded.set_ylim(sharedYLim)

    # firing rate vs all immobility periods + shuffling
    axImmo = fig.add_subplot(gs3[1,0])
    axImmo = plot_normalized_immobility_periods_firing_rate(data,shank,cluster,binSize=binSize,SideLength=SideLength,ax=axImmo)[0]
    plot_shuffled_immobility_firing_rate(data,shank,cluster,binSize=binSize,SideLength=SideLength,ax=axImmo)

    # firing rate vs all runs + shuffling + modulated bins
    axRun = fig.add_subplot(gs3[1,1])
    ModulationResults,InsideRunIndexes,spikeHist,allSpeed,allDuration,nSideBin,plotaxinfo=GetModulatedPortionsDuringRun(data,shank,cluster,runType="all",printoutput=printoutput,ax=axRun)

    if ((ModulationResults["BiggestModulationSign"] in ("not modulated","not enough spikes"))
            or (len(ModulationResults["BiggestModulationIndexes"])<3)):
        print("no correlation possible")
        gs3.tight_layout(fig, rect=[0.35, 0, 1, 0.7],h_pad=0.33)
        return

    ## Trial-by-trial values inside the modulated bins
    ModulatedBins=ModulationResults["BiggestModulationIndexes"]
    # bins separating the first and last modulated bin (the former sum(ModulatedBins)-1 summed bin indexes)
    ModulatedSpanInBins=ModulatedBins[-1]-ModulatedBins[0]
    nSide=int(nSideBin)

    FRateDuringModulatedEpoch=[]
    RSpeedDuringModulatedEpoch=[]
    RAccelDuringModulatedEpoch=[]
    DistanceRun=[]
    RunDuration=[]
    for x,r in enumerate(allDuration):
        # NOTE(review): spike counts are non-negative, so this only drops trials whose sum is NaN.
        if sum(spikeHist[x][ModulatedBins])>=0:
            FRateDuringModulatedEpoch.append(np.nanmean(spikeHist[x][ModulatedBins]))

            SpeedInModulatedBins=allSpeed[x][ModulatedBins]
            RSpeedDuringModulatedEpoch.append(np.nanmean(SpeedInModulatedBins))

            BinDuration=r/(len(allSpeed[x])-2*nSideBin)
            RAccelDuringModulatedEpoch.append((SpeedInModulatedBins[-1]-SpeedInModulatedBins[0])/(ModulatedSpanInBins*BinDuration))

            # explicit end index so that nSide == 0 keeps the whole run
            DistanceRun.append(sum(allSpeed[x][nSide:len(allSpeed[x])-nSide]*BinDuration))

            # duration of this same trial, so that every list stays aligned with the firing rates
            RunDuration.append(r)

    x=FRateDuringModulatedEpoch
    if len(x)<2:
        print("no correlation possible")
        gs3.tight_layout(fig, rect=[0.35, 0, 1, 0.7],h_pad=0.33)
        return

    ## Correlation plots: (values, y label, gs3 position)
    Panels=[
        (RSpeedDuringModulatedEpoch,"Running speed (cm/s)",(2,0)),
        (RAccelDuringModulatedEpoch,r"Running accel $\mathregular{(cm/s^2)}$",(2,1)),
        (DistanceRun,"Run distance (cm)",(3,0)),
        (RunDuration,"Run duration (s)",(3,1)),
        ]

    for y,ylabel,(row,col) in Panels:
        ax = fig.add_subplot(gs3[row,col])

        fit_fn = np.poly1d(np.polyfit(x,y,1))
        ax.plot(x,y, 'ro', x, fit_fn(x), '-k',linewidth=2)
        ax.set_xlabel("Firing rate (Hz)",fontsize=20,weight="bold")
        ax.set_ylabel(ylabel,fontsize=20,weight="bold")

        # robust y-range: ignore the 2% most extreme values on each side
        ax.set_ylim(np.percentile(y,[2,98]))
        SpearManResults=stats.spearmanr(x,y)
        rvalue=str(round(SpearManResults[0],2))
        if SpearManResults[1]<0.0001:
            pvalue="p<0.0001"
        else:
            pvalue="p="+ str(round(SpearManResults[1],4))

        title="r=%s, %s"%(rvalue,pvalue)
        ax.set_title(title,fontsize=20,weight="bold")
        ax.tick_params(axis='both',which='major',labelsize=20,width=2)

    gs3.tight_layout(fig, rect=[0.35, 0, 1, 0.7],h_pad=0.33)

    return
    

In [None]:
#------------------------------------------------------------------------------------------------------------------
if "__file__" not in dir() and data.hasBehavior and data.hasSpike:
    # Summary figure of the selected unit (only drawn when the session has behavior and spikes).
    # below you can change shank and clu
    ##############################

    ##############################
    ModulationDuringRun = cluster_ratevskinematicplot(data, shank, cluster, group="Good", printoutput=True);