In [1]:
import numpy as np
import scipy
from scipy.interpolate import interp1d

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib import collections  as mc
%matplotlib inline

import glob
import os
import pickle
import datetime


if "__file__" not in dir():
    %run loadRat_documentation.ipynb
    %run loadRawSpike_documentation.ipynb


    
# INFO: all the default parameters for preprocessing
# NOTE(review): presumably merged with the user-supplied `param` dict inside Data(...) - confirm in loadRat_documentation.ipynb
defaultParam={
    "binSize":0.25,
    "trialOffset":20., #max end of trial, in seconds (the position is cut after it)
    "sigmaSmoothPosition":0.1,  #smoothing applied to the position
    #"sigmaSmoothPosition":0.33 for pavel dataType
    "sigmaSmoothSpeed":0.3, #smoothing applied to the speed
    "positionDiffRange": [2.,5.], #min and max differences allowed between two consecutive positions
                                  #min to correct start, max to correct jumps
    "pawFrequencyRange":[2.,10.],
    "startAnalysisParams":[10,0.2,0.5],
    "cameraToTreadmillDelay":2., #seconds, usual time between camera start and treadmill start
    "nbJumpMax" : 100., #if jumps>nbJumpMax, trial is badly tracked
    
    #parameters used to detect the end of a trial (first position minimum)
    "endTrial_backPos":55,  # the minimum comes after the animal went once to the back (after first time position>backPos)
    "endTrial_frontPos":30, # the minimum's position is at the front of the treadmill (position[end]<frontPos)
    "endTrial_minTimeSec":4, # the minimum comes after minTimeSec seconds (time[end]>minTimeSec)
    }

### Load Data
Load the preprocessed data (corrected and binned position, speed, median position, ...)
 DEFAULT: load the pickle files if they exist, otherwise create them
data=Data(root,animal,experiment,param=param)

 OPTION 1: do not save any new pickle file
data=Data(root,animal,experiment,param,saveAsPickle=False)

 OPTION 2: redo the preprocessing with param, even if the pickle files already exist
 (to be sure everything is preprocessed with the same parameters)
data=Data(root,animal,experiment,param,redoPreprocess=True)

In [2]:
#run only if inside this notebook (do not execute if "%run this_notebook")
# Example session: builds the `data` object that every demo cell below relies on.
if "__file__" not in dir():
    # NOTE(review): root/animal/experiment are hard-coded for one recording; edit them for another session
    root="/data/"
    animal="Rat034"
    experiment="Rat034_2015_04_28_10_08"
    
    # session-specific parameters passed to Data(...)
    param={
        "goalTime":7,
        "treadmillRange":[0,90],
        "maxTrialDuration":20,
        "interTrialDuration":10,
        "endTrial_frontPos":30,
        "endTrial_backPos":55, 
        "endTrial_minTimeSec":4,
        "binSize":0.25,
    }  
    # redoPreprocess=True: redo the preprocessing even if pickle files already exist
    data=Data(root,animal,experiment,param,redoPreprocess=True)
    print(data.hasBehavior)
    #print(data.dataType)
    #print(data.treadmillSpeed)
    #print(data.experiment)
    #print(data.cameraSamplingRate)

Preprocessing behavior data...
No *.lickbreaktime file found!
Behavior data loaded from text files: Teresa data (.samplingrate)
Preprocessing done
Initializing features and masks: 100.0%.
Features and masks initialized.[K
Initializing features and masks: 100.0%.
Initializing statistics: 100.0%.
Statistics initialized.[K
Initializing statistics: 100.0%.
13:50:16 [I] Switched to channel group 2.
Features and masks initialized.[K
Initializing statistics: 100.0%.
Statistics initialized.[K
Initializing statistics: 100.0%.
13:50:16 [I] Switched to channel group 3.
Features and masks initialized.[K
Initializing statistics: 100.0%.
Statistics initialized.[K
Initializing statistics: 100.0%.
13:50:16 [I] Switched to channel group 4.
Features and masks initialized.[K
Initializing statistics: 100.0%.
Statistics initialized.[K
Initializing statistics: 100.0%.
Spike data loaded from raw files
True


In [None]:
#run only if inside this notebook (do not execute if "%run this_notebook")
# Overlay, for trial 5, the raw position, the corrected position and the binned position.
if "__file__" not in dir():
    # raw position against camera time
    plt.plot(data.rawTime[5],data.rawPosition[5])
    #plt.plot(data.timeTreadmill[5],data.rawPosition[5])
    # corrected position against treadmill-aligned time
    plt.plot(data.timeTreadmill[5],data.position[5])
    plt.plot(data.timeBin,data.positionBin[5])#centered
    plt.ylabel("position(cm)")
    plt.xlabel("time(s)")

### Display all attributes and their type  
Every attribute can be accessed with `data.attributeName`  

In [None]:
#run only if inside this notebook (do not execute if "%run this_notebook")
if "__file__" not in dir():
    # list the attributes of the loaded session (see the heading above)
    data.describe()
    print(data.emptyAnalysisFiles)

### Shanks and their clusters, per group

In [None]:
#run only if inside this notebook (do not execute if "%run this_notebook")
# Print the cluster groups of every shank, when the session has spike data.
if "__file__" not in dir():
    # add a list of (shank,clu) in a group "testGroup"
    #data.add_cluster_group("testGroup",[(1,2),(1,3),(2,4)])
    if not data.hasSpike:
        print("no spike data")
    else:
        #display all groups for every shank
        for shank in data.channelGroupList:
            print("shank %s:"%shank)
            print(data.clusterGroup[shank])
        print(data.channelGroupList)

## Plot Behavioral data : TreadmillOn
###  Position, speed, acceleration binned

In [None]:
#run only if inside this notebook (do not execute if "%run this_notebook")
# Overview figure (3x2 panels) of the binned / aligned behavioral data.
if "__file__" not in dir():
    # trial displayed in the single-trial panels (321 and 326)
    trial=0
    plt.figure(figsize=(12,8))
    plt.subplot(322)
    plt.plot(data.timeBin,data.medianPosition);
    plt.title("Median of corrected position");

    # single trial: binned position, smoothed speed and acceleration on the same axes
    plt.subplot(321)
    plt.plot(data.timeBin,data.positionBin[trial])
    plt.title("position smooth and binned for trial %s"%trial)

    plt.plot(data.timeBin,data.speedSmoothBin[trial],label="speed smooth")
    plt.plot(data.timeBin,data.accelerationOnSpeedSmoothBin[trial],label="acceleration")
    plt.legend();
    
    # The loops below use their own variable names: they used to reuse `trial`,
    # which silently changed the trial shown in the last panel (326).
    plt.subplot(323)
    for alignedTrial in data.positionAlignEnd:
        plt.plot(data.timeAlignEnd,data.positionAlignEnd[alignedTrial])
    plt.title("Position align on end (%s)"%len(data.positionAlignEnd))
    
    plt.subplot(324)
    for binnedTrial in data.trials:
        plt.plot(data.timeBin,data.positionBin[binnedTrial])
    plt.title("Position 0=treadmill start")
    
    plt.subplot(325)
    plt.plot(data.timeAlignEnd,data.medianPositionAlignEnd)
    plt.title("median align on end")
    
    # single trial again (same `trial` as panel 321)
    plt.subplot(326)
    plt.plot(data.timeTreadmill[trial],data.position[trial])
    plt.title("position correct (smooth, no bin), aligned on treadmill start")

### Positions, percentile, correlation

  - `onlyGood (False/True)`: plot/compute only on good trials
  - `raw (False/True)`: use raw positions. If False, use the corrected smoothed position.

In [None]:
def get_positions_array_beginning(data,onlyGood=False,raw=False):
    '''
    Return the positions of all trials in one 2D array (one row per trial).

    Each row starts at the trial's startFrame (treadmill start) and spans
    ceil(mean(data.maxTrialDuration)) seconds, converted to camera frames.
    Rows shorter than this window are padded at the end with NaN.

    data     : object exposing position, rawPosition, startFrame, goodTrials,
               maxTrialDuration and cameraSamplingRate
    onlyGood : if True, skip the trials that are not in data.goodTrials
    raw      : if True use data.rawPosition, otherwise data.position

    Returns a float array of shape (nTrials, nFrames); the array is empty
    (shape (0,)) when no trial is kept.
    '''
    posDict=data.rawPosition if raw else data.position
    # Window length in frames. It must be a plain int: with a numpy float,
    # "[np.nan] * (size-len(pos))" is an array multiplication that yields ONE nan
    # instead of the missing number of frames, which produced ragged rows.
    nFrames=int(np.ceil(np.nanmean(data.maxTrialDuration))*data.cameraSamplingRate)
    rows=[]
    for trial in posDict:
        if onlyGood and (trial not in data.goodTrials):
            continue
        start=int(data.startFrame[trial])
        pos=np.asarray(posDict[trial][start:start+nFrames],dtype=float)
        #pad with nan at the end if the recording stops before the end of the window
        missing=nFrames-len(pos)
        if missing>0:
            pos=np.concatenate([pos,np.full(missing,np.nan)])
        rows.append(pos)
    return np.asarray(rows)

#----------------------------------------------------------------------------------------------------------------------
def plot_positions(data,onlyGood=False,raw=False):
    '''
    Plot the position of every trial (green=good trial, red=other), aligned on camera start,
    then overlay the 25%, 50% and 75% percentiles computed on the array returned by
    get_positions_array_beginning (window starting at treadmill start).

    onlyGood : if True, plot and use only the trials listed in data.goodTrials
    raw      : if True use data.rawPosition, otherwise the corrected data.position

    Returns the (3, nFrames) array of percentiles, or False when there is no trial to use.
    '''
    posDict=data.rawPosition if raw else data.position
    cameraTime=data.rawTime #align on camera
    for trial in posDict:
        isGood=trial in data.goodTrials
        if onlyGood and not isGood:
            continue
        plt.plot(cameraTime[trial],posDict[trial],color="green" if isGood else "red")
        
    #Get the positions from treadmill start onwards
    allTraj=get_positions_array_beginning(data,onlyGood=onlyGood,raw=raw)
    #Check that the array is not empty
    if allTraj.shape[0]==0:
        # exactly one title is set: "0 Good trials" used to be overwritten by "No positions"
        plt.title("0 Good trials" if onlyGood else "No positions")
        return False

    #plot the percentiles
    trajP=np.nanpercentile(allTraj,[25, 50, 75],axis=0)
    # frames -> seconds, shifted by the camera-to-treadmill delay to stay on the camera time axis
    time=(np.arange(allTraj.shape[1])/data.cameraSamplingRate)+data.cameraToTreadmillDelay
    plt.plot(time,trajP[0],"b--",lw=2)
    plt.plot(time,trajP[1],"b-",lw=2)
    plt.plot(time,trajP[2],"b--",lw=2)
    #title, labels, grid
    plt.ylabel("X Position (cm)")
    plt.xlabel("Time (s) relative to camera start")
    #plt.xlim([0,max(data.entranceTime)])
    plt.xlim([0,16])
    plt.grid()
    # grey band = time window used for the percentiles
    plt.axvspan(time[0],time[-1],alpha=0.2,color="grey")
    title="Raw " if raw else ""
    if onlyGood:
        title+="Positions of %s good trials"%(len(data.goodTrials))
    else:
        title+="Positions, %s good / %s trials"%(len(data.goodTrials),len(data.trials))
    plt.title(title)
    return trajP

#----------------------------------------------------------------------------------------------------------------------
def plot_correlation_position(data,onlyGood=False,raw=False):
    '''
    Draw the matrix of pairwise correlation coefficients between the trial positions
    returned by get_positions_array_beginning, and return the median coefficient
    (upper triangle of the matrix, diagonal excluded).
    With fewer than 3 trials nothing is drawn, the title says so and NaN is returned.
    '''
    import pandas as pd
    trajectories=get_positions_array_beginning(data,onlyGood=onlyGood,raw=raw)
    # one column per trial
    frame = pd.DataFrame(trajectories.transpose())

    if trajectories.shape[0]<3:
        plt.title("Not enough trials")
        return np.nan

    corMatrix=np.array(frame.corr())
    mesh=plt.pcolor(corMatrix,vmin=0,vmax=1,cmap="Reds")
    plt.colorbar(mesh)
    nRows,nCols=corMatrix.shape
    plt.xlim([0,nRows])
    plt.ylim([0,nCols])
    #median of the coefficients strictly above the diagonal
    upper=corMatrix[np.triu_indices(nRows,1)]
    med=np.nanmedian(upper)
    #duration covered by the trajectories, in seconds
    maxSecond=trajectories.shape[1]/float(data.cameraSamplingRate)
    #title of the plot
    prefix=("Good " if onlyGood else "")+("Raw " if raw else "")
    plt.title(prefix+'position correlation up to %ss, median r= %.2f'%(maxSecond,med))
    return med
     
#----------------------------------------------------------------------------------------------------------------------
# Demo (notebook only): corrected positions on the top row, raw positions on the bottom row,
# each next to its correlation matrix.
if "__file__" not in dir():
    plt.figure(figsize=(15,10))
    plt.subplot(221)
    plot_positions(data,onlyGood=False)
    plt.subplot(222)
    plot_correlation_position(data,onlyGood=False)
    plt.subplot(223)
    plot_positions(data,raw=True)
    plt.subplot(224)
    plot_correlation_position(data,onlyGood=False,raw=True)
    
    #path=os.path.join(root,"Analysis")
    #pathFig=os.path.join(path,"Borrar_Rat051_AFTERM1_LESION.svg")
    #plt.savefig(pathFig,format="svg")
    
    

### Position align on end and correlation

The end of each trial trajectory is detected during preprocessing with the parameters:

 - "endTrial_backPos":55,  minima is after the animal went once to the back (after first time position>backPos)
 - "endTrial_frontPos":30,  minima's position is in front of treadmill (position[end]\<frontPos)
 - "endTrial_minTimeSec":4, minima is after minTimeSec seconds (time[end]>minTimeSec)

The detection is done on the corrected smoothed position.  
`data.indexEndTrial` contains the detected index for each trial. If the end could not be detected, `indexEndTrial[trial]=None`.  
For more details see Preprocess_treadmillOn_detailed_documentation.ipynb

  - `minTime (int)`: time in seconds to consider, relative to detected end  
    `minTime= -5` means "keep 5 seconds before the end", for each trial
  - `onlyGood (True/False)`: whether to keep only good trials
  - `raw (True/False)`: whether to use raw positions

In [None]:
def get_positions_array_end(data,minTime=-9,onlyGood=False,raw=False):
    '''
    Return a 2D array of positions (one row per trial) aligned on the detected end of
    the trajectory, covering [minTime, 0] seconds (0 = detected end).

    Trials are skipped when no end was detected (data.indexEndTrial[trial] is None)
    or when the window between startFrame and the detected end contains no sample.
    When the window is shorter than abs(minTime) seconds (the detected end is too close
    to startFrame), the row is padded at the beginning with a plateau equal to its first value.

    minTime  : negative time in seconds, relative to the detected end
    onlyGood : if True, keep only the trials listed in data.goodTrials
    raw      : if True use data.rawPosition, otherwise data.position
    '''
    posDict=data.rawPosition if raw else data.position
    #number of frames to keep 
    cs=data.cameraSamplingRate
    size=int(abs(minTime*cs))
    #put all positions in a 2D array
    rows=[]
    for trial in posDict:
        if onlyGood and (trial not in data.goodTrials):
            continue
        #index where the trajectory ends
        endIndex=data.indexEndTrial[trial]
        if isNone(endIndex):
            #no end was detected, skip the trial
            continue 
        #Position is cut between (end + minTime) and end.
        #The cut never begins before startFrame; cast to int so the slice is always valid
        #(same expression as in get_speed_end)
        startIndex=int(max(data.startFrame[trial],endIndex+(minTime*cs)))
        pos=np.asarray(posDict[trial][startIndex:endIndex])
        if len(pos)==0:
            #the detected end is not after startFrame: nothing to align (pos[0] would raise)
            continue
        #if too short, pad the start with a plateau (NaN would cause trouble to compute speed later)
        missing=size-len(pos)
        if missing>0:
            pos=np.concatenate([np.full(missing,pos[0]),pos])
        rows.append(pos)
    return np.asarray(rows)

#----------------------------------------------------------------------------------------------------------------------
def plot_position_align_end(data,minTime=-9,xUpLimit=5,onlyGood=False):
    '''
    Plot the corrected position of each trial with time 0 = detected end of the trajectory
    (green=good trial, red=other), and highlight in blue the segment between the position
    maximum and the detected end.
    The title reports the median pairwise correlation of the positions in [minTime, 0].

    minTime  : start (s, negative) of the window used for the correlation and of the x axis
    xUpLimit : x axis max limit
    onlyGood : whether to consider only good trials
    nb: data.timeEndTrial is computed with data.position

    Returns the median correlation coefficient, or None with fewer than 3 usable trials.
    '''
    for trial in data.position:
        end=data.timeEndTrial[trial]
        if isNone(end):
            #no end was detected for this trial
            continue
        isGood=trial in data.goodTrials
        if onlyGood and not isGood:
            continue
        pos=data.position[trial]
        relTime=data.timeTreadmill[trial]-end
        indexStart=np.argmax(pos)
        indexEnd=data.indexEndTrial[trial]
        plt.plot(relTime,pos,color="green" if isGood else "red")
        plt.plot(relTime[indexStart:indexEnd],pos[indexStart:indexEnd],color="b")
    plt.xlim([minTime,xUpLimit])
    # positions are in cm (the label used to read "cm/s")
    plt.ylabel("Position (cm)")
    plt.xlabel("Time relative to detected end (s)")
    plt.axvline(0,color="b",ls="--")

    w="correct " if onlyGood else ""
    
    allTraj=get_positions_array_end(data,minTime=minTime,onlyGood=onlyGood,raw=False)
    if allTraj.shape[0]<3:
        plt.title("End of %strials"%w)
        return None
    #median of the pairwise correlation coefficients (upper triangle, diagonal excluded)
    corMatrix=np.corrcoef(allTraj)
    coeff=corMatrix[np.triu_indices(corMatrix.shape[0],1)]
    med=np.nanmedian(coeff)
    plt.title("End of %strials, median last %ss r=%.2f" %(w,abs(minTime),med))    
    return med

#----------------------------------------------------------------------------------------------------------------------
# Demo (notebook only): end-aligned positions for all trials (left) and for good trials only (right).
if "__file__" not in dir():
    plt.figure(figsize=(20,10))
    plt.subplot(221)
    plot_position_align_end(data,onlyGood=False)
    plt.subplot(222)
    plot_position_align_end(data,onlyGood=True)
    
    #path=os.path.join(root,"Analysis")
    #pathFig=os.path.join(path,"Borrar_Rat041_AFTERM1_LESION.svg")
    #plt.savefig(pathFig,format="svg")

### Speed

Treadmill speed has to be taken into account to compute rat speed

` speed= treadmill speed - np.diff(position)*camera sampling rate `

positive speed= rat moving towards the front of the treadmill  
zero speed= rat not moving, dragged towards the end by the treadmill  
negative speed= rat moving towards the end, faster than the treadmill  

When the rat crosses the beam, does the treadmill stop? It depends on the entrance time: the treadmill does not stop when the trial is bad (entrance time < goal time).

### Acceleration

` acceleration= np.diff(speed)*camera sampling rate `

In [None]:
#---------------------------------------------------
def get_position_treadmillOn(data,trial):
    '''
    Position and time of one trial, cut between treadmill start (startFrame) and treadmill stop.
    The stop is the first frame at or after the entrance time when entranceTime>=goalTime,
    otherwise the last frame of the trial. For dataType "behav_param" the stop is data.stopFrame[trial].
    Returns (position, time), both cut with the same indices.
    '''
    trialPos=data.position[trial]
    trialTime=data.timeTreadmill[trial]
    firstFrame=data.startFrame[trial]
    #by default the treadmill runs until the last recorded time
    stopTime=trialTime[-1]
    if data.entranceTime[trial]>=data.goalTime[trial]:
        stopTime=data.entranceTime[trial]
    #first frame reaching the stop time (-1 when the stop time is never reached)
    candidates=np.where(trialTime>=stopTime)[0]
    lastFrame=candidates[0] if len(candidates)>0 else -1
    if data.dataType=="behav_param":
        lastFrame=data.stopFrame[trial]
    window=slice(firstFrame,lastFrame)
    return trialPos[window],trialTime[window]
#---------------------------------------------------
def speed_from_pos(data,position,trial,sigmaSpeed=0.3):
    '''
    speed= treadmill speed - (diff pos)* camera sampling rate
    The result is floored at "- treadmill speed", then smoothed with smooth(speed, sigmaSpeed*cs).
    positive speed= rat moving towards the front of the treadmill
    negative speed= rat moving towards the end, faster than the treadmill
    Returns one value per pair of consecutive positions (len(position)-1 values before smoothing).
    '''
    rate=data.cameraSamplingRate
    belt=data.treadmillSpeed[trial]
    # displacement per frame converted to cm/s, then expressed relative to the belt
    relative=belt-np.diff(position)*rate
    # nothing below -belt is kept
    floored=np.maximum(relative,-belt)
    return smooth(floored,sigmaSpeed*rate)
#----------------------------------------------------------------------------------------------------------------------
def get_speed_treadmillON(data,trial,sigmaSpeed=0.3):
    '''
    Speed of one trial between treadmill start and treadmill stop.
    Returns (speed, timeSpeed) where timeSpeed holds the midpoints of consecutive time samples.
    '''
    positions,times=get_position_treadmillOn(data,trial)
    # each speed value comes from two consecutive samples: time it at their midpoint
    midTimes=(times[:-1]+times[1:])/2.0
    return speed_from_pos(data,positions,trial,sigmaSpeed),midTimes
#----------------------------------------------------------------------------------------------------------------------
def get_speed_end(data,trial,sigmaSpeed=0.3,minTime=-9):
    '''
    Speed of one trial over the last abs(minTime) seconds before the detected end
    (time zero = detected end); the window never starts before startFrame.
    Returns (speed, timeSpeed), or (False, False) when no end was detected for the trial.
    '''
    rate=data.cameraSamplingRate
    lastIndex=data.indexEndTrial[trial]
    if isNone(lastIndex):
        return False, False
    firstIndex=int(max(data.startFrame[trial],lastIndex+(minTime*rate)))
    window=slice(firstIndex,lastIndex)
    positions=data.position[trial][window]
    # time relative to the detected end of the trajectory
    relTime=(data.timeTreadmill[trial]-data.timeEndTrial[trial])[window]
    # each speed value is timed at the midpoint of the two samples it comes from
    midTimes=(relTime[:-1]+relTime[1:])/2.0
    return speed_from_pos(data,positions,trial,sigmaSpeed), midTimes
#----------------------------------------------------------------------------------------------------------------------
def plot_position_speed_acceleration(data,trial,sigmaSpeed=0.3):
    '''
    Plot, for one trial and while the treadmill is on, the position (left axis, black)
    and the speed (red) and acceleration (blue) on a twin right axis.
    Dashed horizontal lines mark +/- treadmill speed (green) and zero (blue);
    vertical lines mark time 0, the entrance time and the max trial duration.
    '''
    #plot position
    cs=data.cameraSamplingRate
    pos,time=get_position_treadmillOn(data,trial)
    plt.plot(time,pos,"k-",label="position")
    plt.ylabel("position (cm)")
    plt.xlabel("time (s)")
    # plot speed
   
    speed,timeSpeed=get_speed_treadmillON(data,trial,sigmaSpeed) 
    # second y axis sharing the same time axis
    ax2=plt.gca().twinx()
    ax2.plot(timeSpeed,speed,color="red",label="speed")
    ax2.yaxis.label.set_color('red')
    ax2.tick_params(axis='y', colors='red') 
    ax2.set_ylabel("speed")
    plt.xlim([time[0],time[-1]])
    #plot acceleration
    # np.diff drops one sample: the acceleration is plotted against timeSpeed[1:]
    timeAcceleration=timeSpeed[1:]
    cs=data.cameraSamplingRate
    acceleration= np.diff(speed)*cs
    ax2.plot(timeAcceleration,acceleration,color='blue',label="acceleration")
    
    #vertical and horizontal lines (treadmill speed, entrance time, startFrame)
    ax2.axhline(data.treadmillSpeed[trial],ls="--",color="g")
    ax2.axhline(0,ls="--",c="blue")
    ax2.axhline(-data.treadmillSpeed[trial],ls="--",color="g")
    plt.axvline(data.entranceTime[trial],color="purple")
    plt.axvline(data.maxTrialDuration[trial])
    plt.axvline(0,color="purple")

#----------------------------------------------------------------------------------------------------------------------
# Demo (notebook only): kinematics of trial 5 computed at camera resolution (top)
# versus the binned kinematics stored in `data` (bottom).
if "__file__" not in dir():
    plt.figure(figsize=(10,5))
    trial =5
    plt.subplot(211)
    plt.title("Treadmill On, kinematics with raw time binning ")
    plot_position_speed_acceleration(data,trial,sigmaSpeed=0.3)
    plt.xlim(0,25)
    plt.subplot(212)
    plt.title("Treadmill On, Binned kinematics timeBin=0.25")
    plt.plot(data.timeBin,data.positionBin[trial],"k",label="position")
    plt.plot(data.timeBin,data.speedSmoothBin[trial],"r",label="speed")
    plt.plot(data.timeBin,data.accelerationOnSpeedSmoothBin[trial],"b",label="acceleration")
    plt.legend(loc="best")
    plt.xlim(0,25)

In [None]:
def plot_kinematic_histo(data,kinematic="speedSmoothBin",binSize=1,title="Distribution of speed",xlablel="speed(cm/s)",onlyGood=False):
    '''
    Plot the density histogram of one per-trial kinematic attribute of `data`, pooled over trials.

    kinematic : name of the attribute of `data` to use, indexed by trial (e.g. "speedSmoothBin")
    binSize   : width of the histogram bins, in the unit of the kinematic
    title     : plot title (the number of good trials is appended when onlyGood is True)
    xlablel   : x axis label (parameter name kept as is for backward compatibility)
    onlyGood  : if True, pool only the trials listed in data.goodTrials

    Returns the histogram densities, or NaN when there is no finite value to plot.
    '''
    perTrial=getattr(data,kinematic)
    #pool the values of the selected trials (single concatenation instead of repeated np.append)
    chunks=[np.ravel(np.asarray(perTrial[trial],dtype=float))
            for trial in data.trials
            if (not onlyGood) or (trial in data.goodTrials)]
    k=np.concatenate(chunks) if chunks else np.asarray([])
    #check that there is something to plot (NaN/inf cannot define bin edges)
    finite=k[np.isfinite(k)]
    if len(finite)==0:
        return np.nan   
    #histogram between min and max value, with binSize
    mink=finite.min()
    maxk=finite.max()
    # the last edge must reach maxk: "maxk+1" only guaranteed it for binSize<=1
    # (identical edges for the default binSize=1)
    edges=np.arange(mink,maxk+binSize,binSize)
    if len(edges)<2:
        #all values identical: use a single bin
        edges=np.array([mink,mink+binSize])
    hist,bins=np.histogram(k,edges,density=True)
    #plot histogram
    centers=(bins[:-1]+bins[1:])/2.0
    plt.bar(centers,hist,width=binSize,color='b')
    #title, labels
    if onlyGood:
        title+=" - %s good trials"%(len(data.goodTrials))
    plt.title(title)
    plt.xlabel(xlablel+", binSize=%s"%binSize)
    plt.ylabel("Density")
    
    #vertical lines for 0 and treadmill speed
    plt.axvline(0,color="r",ls="--")
    if kinematic=="speedSmoothBin":
        plt.axvline(np.nanmean(data.treadmillSpeed),color="c",ls="--")
        #plt.xlim([-20,70])
    plt.ylim([0,0.14])
    return hist
#----------------------------------------------------------------------------------------------------------------------
# Demo (notebook only): distribution of speed, then of acceleration in a new figure.
if "__file__" not in dir():
    hist=plot_kinematic_histo(data,kinematic="speedSmoothBin",title="Distribution of speed",xlablel="speed(cm/s)")
    plt.figure()
    hist=plot_kinematic_histo(data,kinematic="accelerationOnSpeedSmoothBin",title="Distribution of acceleration",xlablel="acceleration(cm/s²)")

In [None]:
def plot_entrance_time(data,smoothSigma=2):
    '''
    Scatter the entrance time of each trial of data.trials against data.realTrials
    (green=good trial, red=other), with the smoothed entrance time (blue dashes),
    the max trial duration (black dashes) and a grey band from 0 to the mean goal time.
    smoothSigma is passed to smooth() for the smoothed line.
    '''
    trialIds=data.trials
    entrance=[data.entranceTime[t] for t in trialIds]
    #one color per trial: good trials in green, the others in red
    colors=[("green" if t in data.goodTrials else "red") for t in trialIds]
    plt.scatter(data.realTrials,entrance,color=colors,marker="o")
    #smoothed entrance time along the session
    plt.plot(data.realTrials,smooth(entrance,smoothSigma),"b--",linewidth=2)
    #max trial duration of each trial
    timeouts=[data.maxTrialDuration[t] for t in trialIds]
    plt.plot(data.realTrials,timeouts,"k--")
    #grey band up to the mean goal time
    goal=np.nanmean(data.goalTime)
    plt.axhspan(0,goal,alpha=0.2,color="grey")
    #limits, labels, title
    upper=max(goal,max(entrance))+1
    plt.ylim([0,upper])
    plt.xlim([0,data.nTrial+1])
    plt.grid()
    plt.ylabel('Entrance Time (s), mean goal=%s'%goal)
    plt.xlabel('Trial Number')
    plt.title("Entrance times, %s good trials"%(len(data.goodTrials)))

#----------------------------------------------------------------------------------------------------------------------
def plot_correlation_entrance_time(data,removeTimeout=False):    
    '''
    Plot the entrance time of trial n+1 against the entrance time of trial n and report
    the Spearman correlation between the two in the title.

    removeTimeout : if True, drop the entrance times that are not below the max trial duration

    Returns the Spearman coefficient, or NaN when there are not enough entrance times.
    '''
    # explicit submodule import: "import scipy" alone does not guarantee that scipy.stats is loaded
    from scipy import stats
    #entrance time with or without timeout
    #(missing values are filtered before the comparison, "None<m" would raise)
    if removeTimeout:
        entrance=[e for e,m in zip(data.entranceTime,data.maxTrialDuration) if (not isNone(e)) and e<m]
    else:
        entrance=[e for e in data.entranceTime if not isNone(e)]
    #check if not empty
    if not entrance:
        plt.title("No entrance times")
        return np.nan
    #plot entrance time n/entrance time n+1
    previous,following=entrance[:-1],entrance[1:]
    plt.plot(previous,following,"ko")
    plt.xlabel("Entrance time trial n")
    plt.ylabel("Entrance time trial n+1")
    #limits
    goal=np.nanmean(data.goalTime)
    m=max(goal,max(entrance))
    plt.ylim([0,m])
    plt.xlim([0,m])
    #diagonal and xspan
    plt.plot([0,m],[0,m],"k--")
    plt.axhspan(0,goal,alpha=0.2,color="grey")
    plt.axvspan(0,goal,alpha=0.2,color="grey")
    #a rank correlation needs at least two (n, n+1) pairs
    if len(previous)<2:
        plt.title("Correlation entrance times, not enough trials")
        return np.nan
    #spearman
    rho,p=stats.spearmanr(previous,following)
    if p<0.001:
        pvalue='p<0.001'
    else:
        pvalue='p= %.3f'%p
    plt.title("Correlation entrance times, Spearman r= %.2f, %s"%(rho,pvalue));  
    return rho
    
#----------------------------------------------------------------------------------------------------------------------
# Demo (notebook only): entrance times along the session (left) and
# entrance time n+1 against entrance time n (right).
if "__file__" not in dir():
    plt.figure(figsize=(15,5))
    plt.subplot(121)
    plot_entrance_time(data)
    plt.subplot(122)
    plot_correlation_entrance_time(data)
    #plt.xlim([0,25]),plt.ylim([0,25])

### General plot behavior

In [None]:
def plot_session_behavior(data):
    '''
    Summary figure (3x3 panels) of one session, built from the plotting functions of this notebook.
    Row 1: positions, positions aligned on end (good trials), entrance times.
    Row 2: position correlation, speed distribution, acceleration distribution (all trials).
    Row 3: entrance time correlation, speed and acceleration distributions (good trials only).
    '''
    plt.figure(figsize=(20,15))
    
    plt.subplot(331)
    plot_positions(data)
    #plt.xlim(0,12)
    
    plt.subplot(332)
    plot_position_align_end(data,onlyGood=True)
    
    plt.subplot(333)
    plot_entrance_time(data)
    
    plt.subplot(334)
    plot_correlation_position(data,onlyGood=False)

    plt.subplot(337)
    plot_correlation_entrance_time(data)
    
    # speed distribution: all trials (335), then good trials only (338)
    plt.subplot(335)
    hist=plot_kinematic_histo(data,kinematic="speedSmoothBin",title="Distribution of speed",xlablel="speed(cm/s)")
    plt.subplot(338)
    hist=plot_kinematic_histo(data,kinematic="speedSmoothBin",title="Distribution of speed",xlablel="speed(cm/s)",onlyGood=True)
    
    
    # acceleration distribution: all trials (336), then good trials only (339)
    plt.subplot(336)
    hist=plot_kinematic_histo(data,kinematic="accelerationOnSpeedSmoothBin",title="Distribution of acceleration",xlablel="acceleration(cm/s²)")
    
    plt.subplot(339)
    hist=plot_kinematic_histo(data,kinematic="accelerationOnSpeedSmoothBin",title="Distribution of acceleration",xlablel="acceleration(cm/s²)",onlyGood=True)
    
    # figure title: experiment name, day since start, number of trials
    title=data.experiment+", day %s, %s trials"%(data.daySinceStart,data.nTrial)
    plt.suptitle(title,fontsize=16)
    
#----------------------------------------------------------------------------------------------------------------------
# Demo (notebook only): summary figure of the loaded session.
if "__file__" not in dir():  
    plot_session_behavior(data)  
    
    #path=os.path.join(root,"Analysis")
    #pathFig=os.path.join(path,"Rat051_2016_05_06_16_04.png")
    #plt.savefig(pathFig)
    

# Speed investigation

In [None]:
def plot_ratio_speed(data,minTime=-9,binSize=1,smoothSpeed=0.3,onlyGood=True,highSpeed=None):
    '''
    Histogram of the speeds over the last abs(minTime) seconds before the detected end
    (position aligned on detected end, see get_speed_end).
    Speed= treadmill speed - difference(position) * camera sampling rate

    binSize     : width of the histogram bins (cm/s)
    smoothSpeed : sigma passed to get_speed_end
    onlyGood    : if True, use only the trials listed in data.goodTrials
    highSpeed   : speed threshold (cm/s); default is 2 * mean treadmill speed

    Returns the fraction of the histogram mass at or above highSpeed
    (NaN when there is no speed to use).
    '''
    if highSpeed is None:
        highSpeed = 2*np.nanmean(data.treadmillSpeed)
    chunks=[]
    for trial in data.trials:
        if isNone(data.timeEndTrial[trial]):
            continue
        if onlyGood and (trial not in data.goodTrials):
            continue
        s,t=get_speed_end(data,trial,sigmaSpeed=smoothSpeed,minTime=minTime)
        if s is False:
            #get_speed_end found no detected end: do not append its False placeholder as a speed of 0
            continue
        chunks.append(np.ravel(s))
    speed=np.concatenate(chunks) if chunks else np.asarray([])
    #check that the vector is not empty (an all-NaN vector cannot define bin edges)
    if len(speed)==0 or np.all(np.isnan(speed)):
        return np.nan   
    #histogram between min and max speed, with binSize
    minSpeed=np.floor(np.nanmin(speed))
    maxSpeed=np.ceil(np.nanmax(speed))
    # np.arange excludes its stop value: go one bin further so that the highest speeds
    # are counted (with "arange(minSpeed,maxSpeed)" the top bin was missing)
    edges=np.arange(minSpeed,maxSpeed+binSize,binSize)
    if len(edges)<2:
        edges=np.array([minSpeed,minSpeed+binSize])
    hist,bins=np.histogram(speed,edges,density=True)    
    #plot histogram
    centers=(bins[:-1]+bins[1:])/2.0
    plt.bar(centers,hist,width=binSize)  
    #compute ratio
    # centers are already in cm/s: the threshold must not be rescaled by binSize
    highMass=hist[centers>=highSpeed].sum()
    totalMass=hist.sum()
    ratio=highMass/totalMass if totalMass>0 else np.nan

    #title, labels
    title="highSpeed/totalSpeed %.d sec trajectory= %.2f"%(-minTime,ratio)
    if onlyGood:
        title+=" - %s good trials"%(len(data.goodTrials))
    plt.title(title)
    plt.xlabel("End Speeds (cm/s), binSize=%s"%binSize)
    plt.ylabel("Density")
    #fix a range for x, because artefacts can give outliers values
    plt.xlim([-50,100])
    #vertical lines for 0 and treadmill speed
    plt.axvline(0,color="r",ls="--")
    plt.axvline(np.nanmean(data.treadmillSpeed),color="r",ls="--")
    return ratio

#----------------------------------------------------------------------------------------------------------------------
# Demo (notebook only): ratio of high speeds at the end of the trajectories, good trials only.
if "__file__" not in dir():
    plot_ratio_speed(data,binSize=1,smoothSpeed=0.3, onlyGood=True,highSpeed=None)
    print(data.treadmillSpeed[0])

In [None]:
def plot_proportionTimeRunningForward_TrajEnd(data,binSize=1,minSpeed=-20,maxSpeed=120,smoothSpeed=0.3,onlyGood=False,highSpeed=None):
    '''
    Speed while the treadmill is on (see get_speed_treadmillON)
    Speed= treadmill speed - difference(position) * camera sampling rate
    Histograms are computed trial by trial then averaged
    Plot the mean histogram of speeds (+/- std over trials)

    binSize, minSpeed, maxSpeed : definition of the histogram bins (cm/s)
    smoothSpeed : sigma passed to get_speed_treadmillON
    onlyGood    : if True, use only the trials listed in data.goodTrials
    highSpeed   : speed threshold (cm/s); default is the mean treadmill speed

    Returns the area of the mean histogram above highSpeed
    (NaN when no trial is selected or the mean histogram is empty)
    '''
    bins = np.arange(minSpeed,maxSpeed,binSize)
    # one row per trial actually used (rows were sized on len(data.goodTrials) before)
    selected=[trial for trial in data.trials if (not onlyGood) or (trial in data.goodTrials)]
    if not selected:
        return np.nan
    Hist = np.zeros((len(selected),len(bins)-1))
    for k,trial in enumerate(selected):
        s,t=get_speed_treadmillON(data,trial,sigmaSpeed=smoothSpeed)
        Hist[k,:],_= np.histogram(s,bins,density=True)   
    avgHist = np.mean(Hist,axis=0)
    stdHist = np.std(Hist,axis=0)
    #check that the vector is not empty
    if np.sum(avgHist)==0:
        return np.nan   
    centers=(bins[:-1]+bins[1:])/2.0
    
    #area under curve
    # always defined: it is also needed for the vertical line below, even when highSpeed is given
    treadmillSpeed = np.nanmean(data.treadmillSpeed)
    if highSpeed is None:
        #highSpeed = 2*treadmillSpeed
        highSpeed = treadmillSpeed
    above = np.where(centers > (highSpeed))[0]
    #first index where speed > highSpeed (no bin above the threshold -> empty area)
    index = above[0] if len(above)>0 else len(centers)
    areaUnderCurve =  np.sum(avgHist[index:]) * binSize
    print (areaUnderCurve)
    #plot
    plt.plot(centers,avgHist,"k-")
    plt.plot(centers,avgHist+stdHist,"k--")
    plt.plot(centers,avgHist-stdHist,"k--")
    plt.axvline(0,color="r",ls="--")
    plt.axvline(treadmillSpeed, color="r",ls="--")
    plt.ylim([-0.05,0.35])
    plt.ylabel("Mean Density over Trials")
    plt.xlabel("Speed(cm/s)")
    plt.fill_between(centers[index:],0*avgHist[index:], avgHist[index:], facecolor='green', interpolate=True)
    title="%.2f proportion time running forward"%(areaUnderCurve)
    if onlyGood:
        title+=" - %s good trials"%(len(data.goodTrials))
    plt.title(title)
    return areaUnderCurve 

#----------------------------------------------------------------------------------------------------------------------
# Demo (notebook only): proportion of time running forward, all trials, saved as an svg file.
if "__file__" not in dir():
    plot_proportionTimeRunningForward_TrajEnd(data,binSize=1,smoothSpeed=0.3, onlyGood=False)
    #save figure
    # NOTE(review): unlike the other cells (where saving is commented out) this one writes a file on
    # every run, and the hard-coded name mentions Rat081 while the loaded session is Rat034 - confirm it is wanted
    path=os.path.join(root,"Analysis")
    pathFig=os.path.join(path,"Borrar_Rat081_last_time.svg")
    plt.savefig(pathFig,format="svg")
    
    #path=os.path.join(root,"Analysis")
    #pathFig=os.path.join(path,"Rat051_2016_05_06_16_04.png")
    #plt.savefig(pathFig)

In [None]:
def plot_mean_forwardSpeed(data,smoothSpeed=0.3, onlyGood=True):
    '''
    For each trial, average the speeds (treadmill on) that exceed the mean treadmill speed,
    plot these per-trial means and return their mean over the session.
    The mean treadmill speed and the session mean are printed.
    '''
    beltSpeed = np.nanmean(data.treadmillSpeed)
    print(beltSpeed)
    perTrialMeans=[]
    for trial in data.trials:
        if onlyGood and (trial not in data.goodTrials):
            continue
        speed,_=get_speed_treadmillON(data,trial,sigmaSpeed=smoothSpeed)
        # keep only the samples faster than the belt
        perTrialMeans.append(np.nanmean(speed[speed>beltSpeed]))

    sessionMean=np.nanmean(perTrialMeans)
    print(sessionMean)
    #plot
    plt.plot(perTrialMeans,"ko")
    plt.ylabel("Forward Speed")
    plt.xlabel("trials over session")
    plt.ylim([1,60])
    return sessionMean

#----------------------------------------------------------------------------------------------------------------------
# Demo (notebook only): mean forward speed per trial, all trials.
if "__file__" not in dir():
    x = plot_mean_forwardSpeed(data,smoothSpeed=0.3,onlyGood=False)
    #plt.close()

In [None]:
def plot_Tortuosity(data, onlyGood=True):
    '''
    Plot per-trial tortuosity and straight-line speed of the last run before the detected trial end.

    For every kept trial with a detected end, the path goes from the position maximum found
    before data.indexEndTrial[trial] up to that end index, and:
      - tortuosity    = sum(|treadmillSpeed*binSize - diff(path)|) / |start - end|
      - straight speed = |start - end| / (number of path samples / cameraSamplingRate)

    data: session object (reads position, timeEndTrial, indexEndTrial, treadmillSpeed,
          binSize, cameraSamplingRate and goodTrials)
    onlyGood: if True, only the trials listed in data.goodTrials are used

    Trials without a detected end (isNone) or with no sample before the end are skipped.
    Opens a new 15x5 figure with two panels, prints both session means, and returns
    (SessionTortuosity, SessionStraightSpeed), each the nanmean over the kept trials.
    '''
    cs=data.cameraSamplingRate
    ValuesSessionTortuosity=[]
    ValuesSessionStraightSpeed=[]

    for trial in data.position:
        # honour onlyGood the same way the other per-trial plots of this notebook do
        if onlyGood and (trial not in data.goodTrials):
            continue
        end=data.timeEndTrial[trial]
        if isNone(end):
            continue
        indexEnd=data.indexEndTrial[trial]
        beforeEnd=data.position[trial][0:indexEnd]
        if len(beforeEnd)==0:
            # nothing before the detected end: np.argmax would raise on an empty slice
            continue
        indexStart=np.argmax(beforeEnd)
        PositionStart=data.position[trial][indexStart]
        PositionEnd=data.position[trial][indexEnd]
        PositionPath=data.position[trial][indexStart:indexEnd]
        # NOTE(review): treadmillSpeed*binSize is compared with the sample-to-sample displacement,
        # which presumably assumes one position sample per bin of binSize seconds; confirm
        RunDistance=np.sum(np.abs(data.treadmillSpeed[trial]*data.binSize-np.diff(PositionPath)))
        straightDistance=np.abs(PositionStart-PositionEnd)
        tortuosity=RunDistance/straightDistance
        straightSpeed=np.abs(straightDistance/((indexStart-indexEnd)/cs))
        ValuesSessionTortuosity.append(tortuosity)
        ValuesSessionStraightSpeed.append(straightSpeed)

    SessionTortuosity=np.nanmean(ValuesSessionTortuosity)
    SessionStraightSpeed=np.nanmean(ValuesSessionStraightSpeed)
    print(SessionTortuosity)

    plt.figure(figsize=(15,5))

    # left panel: tortuosity of each kept trial
    plt.subplot(121)
    plt.plot(ValuesSessionTortuosity,"ko")
    plt.ylabel("Tortuosity")
    plt.xlabel("trials over session")
    plt.ylim([0,10])

    print(SessionStraightSpeed)
    # right panel: straight-line speed of each kept trial
    plt.subplot(122)
    plt.plot(ValuesSessionStraightSpeed,"ko")
    plt.ylabel("Straight Speed (cm/sec)")
    plt.xlabel("trials over session")
    plt.ylim([0,50])

    return SessionTortuosity,SessionStraightSpeed

#----------------------------------------------------------------------------------------------------------------------

# Demo cell: runs only when this notebook is executed directly, not when it is loaded with %run.
if "__file__" not in dir():
    plt.figure(figsize=(15,5))
    plt.subplot(121)
    plot_position_align_end(data,onlyGood=False)
    plt.subplot(122)
    # NOTE(review): plot_Tortuosity calls plt.figure itself, so its two panels land in a new
    # figure and the 122 panel selected above stays empty
    x = plot_Tortuosity(data,onlyGood=False)
    #plt.xlim([0,25]),plt.ylim([0,25])


# Spiking Activity : Needs Revision

### Spike Raster

In [None]:
def plot_raster(data,shank,cluster,minTime=-5,maxTime=20,firstTrial=15,ax=None,legend=False,alignEnd=False):
    '''
    Spike raster of one cluster: one row of vertical ticks per trial.

    data: session object (reads spikeTime, trials, nTrial, treadmillStartTime, entranceTime and,
          depending on alignEnd, timeBin/medianPosition/realTrials or
          timeEndTrial/timeAlignEnd/medianPositionAlignEnd)
    shank, cluster: keys into data.spikeTime
    minTime, maxTime: window around the alignment time, in seconds
    firstTrial: lower y-limit of the plot (the upper limit is data.nTrial)
    ax: axes to draw on (default: current axes)
    legend: if True, draw the legend of ax to the right of the plot
    alignEnd: if True, time 0 is treadmill start + detected end and trials whose detected end
              isNone are skipped; otherwise time 0 is the treadmill start

    The spikes of trial t are drawn between y=t+1 and y=t+1.9. The median position is drawn on
    a hidden twin axis and the entrance times as red crosses.
    '''
    if ax is None:
        ax=plt.gca()

    ax.set_ylim(firstTrial,data.nTrial)
    ax.set_ylabel('trial number',fontsize=12)

    if alignEnd:
        ax.axvline(x=0,linestyle="--",color="k",label="Detected end (0)")
        ax.set_xlabel('time relative to detected end (s)',fontsize=12)
        ax.set_title('Shank %s Cluster %s, aligned on end'%(shank,cluster),fontsize=15)
    else:
        ax.axvline(x=0,linestyle="--",color="k",label="Treadmill start (0)")
        ax.set_xlabel('time relative to treadmill start (s)',fontsize=12)
        ax.set_title('Shank %s Cluster %s'%(shank,cluster),fontsize=15)

    cluSpikeTime=data.spikeTime[shank][cluster]
    lines=[]
    for trial in data.trials:
        zero=data.treadmillStartTime[trial]
        if alignEnd:
            end=data.timeEndTrial[trial]
            if isNone(end):
                continue
            zero=zero+end
        start=zero+minTime
        stop=zero+maxTime
        trialSpikeTime=cluSpikeTime[(cluSpikeTime>=start)&(cluSpikeTime<=stop)]
        # one vertical segment per spike, at its time relative to the alignment point
        alignedSpikes=trialSpikeTime-zero
        lines.extend([(alignTime,trial+1),(alignTime,trial+1.9)] for alignTime in alignedSpikes)
    lc= mc.LineCollection(lines,colors="blue",label="spikes")
    ax.add_collection(lc)

    ax.set_xlim([minTime,maxTime])

    # hidden twin axis: carries the median position with its own y scale
    ax2=ax.twinx()
    ax2.axis("off")

    #Plot median position and entrance time
    if alignEnd:
        trials=[trial for trial in data.trials if not isNone(data.timeEndTrial[trial])]
        entrance=[data.entranceTime[trial]-data.timeEndTrial[trial] for trial in trials]
        trialAxis=[trial+1 for trial in trials]
        ax2.plot(data.timeAlignEnd,data.medianPositionAlignEnd,'g-',linewidth=2,label="Median position")
        ax.plot(entrance,trialAxis,'rx',label="Entrance times");
    else:
        entrance=[data.entranceTime[trial] for trial in data.trials]
        ax2.plot(data.timeBin,data.medianPosition,'g-',linewidth=2,label="Median position");
        ax.plot(entrance,data.realTrials,'rx',label="Entrance times");

    if legend:
        ax.legend(loc='center left', bbox_to_anchor=(1.1, 0.5));
        
#----------------------------------------------------------------------------------------------------------------------
# Demo cell: runs only when this notebook is executed directly, not when it is loaded with %run.
# Top panel: raster aligned on treadmill start; bottom panel: raster aligned on the detected end.
if "__file__" not in dir():
    print(data.clusterGroup)
    SHANK=1
    CLU=12
    plt.figure(figsize=(10,10))
    plt.subplot(211)
    plot_raster(data,SHANK,CLU,legend=True)
    plt.subplot(212)
    plot_raster(data,SHANK,CLU,legend=True,minTime=-10,maxTime=10,alignEnd=True)

### Raster and mean with firing rate

`pcolormesh(x,y,res)`: `x` and `y` are the coordinates of the rectangle edges, so each has one more entry than the matching dimension of `res`. Each rectangle is colored according to the corresponding value in `res`.

To display `res` as 3x3 cells, you need:

      res        x        y
               |0 1 2 3|  |0|
    |0 1 2|               |1|
    |4 0 0|               |2|
    |5 3 1|               |3|

In [None]:
def compute_firing_rate(data,shank,cluster,binSize,minTime,maxTime):
    '''
    Per-trial firing rate of one cluster, aligned on treadmill start.

    data: session object (reads spikeTime, trials and treadmillStartTime)
    shank, cluster: keys into data.spikeTime
    binSize: width of the time bins
    minTime, maxTime: window around the treadmill start; spikes on both bounds are included

    Returns (firingRate, center, timeBin):
      firingRate: dict trial -> array of spike counts per bin divided by binSize
      center: centers of the bins
      timeBin: bin edges, np.arange(minTime, maxTime+binSize-maxTime%binSize, binSize)
    '''
    edges=np.arange(minTime,maxTime+binSize-maxTime%binSize,binSize)
    clusterSpikes=data.spikeTime[shank][cluster]

    ratePerTrial={}
    for trial in data.trials:
        t0=data.treadmillStartTime[trial]
        inWindow=(clusterSpikes>=t0+minTime)&(clusterSpikes<=t0+maxTime)
        # histogram of the spike times relative to the treadmill start
        counts=np.histogram(clusterSpikes[inWindow]-t0,edges)[0]
        ratePerTrial[trial]=counts/float(binSize)

    binCenters=(edges[:-1]+edges[1:])/2
    return ratePerTrial,binCenters,edges

#----------------------------------------------------------------------------------------------------------------------
def plot_raster_firing_rate(data,shank,cluster,binSize=0.25,minTime=-5,maxTime=20,ax=None):
    '''
    Heat map (trials x time bins) of the firing rate of one cluster, aligned on treadmill start.

    data: session object (reads trials, realTrials, entranceTime, timeBin, medianPosition,
          plus what compute_firing_rate reads)
    shank, cluster: keys into data.spikeTime
    binSize, minTime, maxTime: forwarded to compute_firing_rate
    ax: axes to draw on (default: current axes)

    A colorbar is added in a new axes to the right of ax, the median position is drawn on a
    hidden twin axis and the entrance times as red crosses. The title reports the binSize
    that was actually used for the histogram.
    '''
    if ax is None:
        ax=plt.gca()

    firingRate,center,timeBin=compute_firing_rate(data,shank,cluster,binSize,minTime,maxTime)
    res=np.asarray(list(firingRate.values()))

    # pcolormesh takes cell edges: one more edge than there are rows in res
    # NOTE(review): the row edges start at 1.5 while the entrance markers sit at data.realTrials
    # and the y-limits start at data.realTrials[0]; confirm the intended vertical alignment
    y= [t+0.5 for t in range(1,len(data.trials)+2)]
    resColor=ax.pcolormesh(timeBin,y,res,cmap="Greys")

    # colorbar in its own thin axes, just right of the heat map
    box = ax.get_position()
    axColor = plt.axes([box.x0*1.02 + box.width * 1.02, box.y0, 0.01, box.height])
    plt.colorbar(resColor, cax=axColor, label="firing rate")

    ax2=ax.twinx()
    ax2.plot(data.timeBin,data.medianPosition,'g-',linewidth=2,label="Median position");
    ax2.axis("off")

    entrance=[data.entranceTime[trial-1] for trial in data.realTrials]
    ax.plot(entrance,data.realTrials,"rx")

    ax.set_xlim([timeBin[0],timeBin[-1]])
    ax.set_ylim([data.realTrials[0],data.realTrials[-1]])
    ax.set_xlabel("Time relative to treadmill start")
    ax.set_ylabel("Trial Number")
    # report the bin size used above, not the preprocessing bin size stored in data
    ax.set_title("Cluster %s, firing rate, binSize=%s s"%(cluster,binSize))

#----------------------------------------------------------------------------------------------------------------------
def plot_mean_firing_rate(data,shank,cluster,binSize=0.25,minTime=-5,maxTime=20,sigma=1,ax=None):
    '''
    Plot the firing rate of one cluster averaged over trials, aligned on treadmill start.

    data, shank, cluster, binSize, minTime, maxTime: forwarded to compute_firing_rate
    sigma: standard deviation of the gaussian smoothing, in number of bins
    ax: axes to draw on (default: current axes)
    '''
    if ax is None:
        ax=plt.gca()

    firingRate,center,timeBin=compute_firing_rate(data,shank,cluster,binSize,minTime,maxTime)
    # average over trials, bin by bin
    meanRate=np.mean(list(firingRate.values()),axis=0)

    # gaussian_filter1d is imported from scipy.ndimage directly:
    # the scipy.ndimage.filters namespace is deprecated
    from scipy.ndimage import gaussian_filter1d
    smoothMeanRate=gaussian_filter1d(meanRate, sigma)

    ax.plot(center,smoothMeanRate);
    ax.set_xlabel("Time relative to treadmill start")
    ax.set_ylabel("Mean firing rate")
    ax.set_title("Cluster %s, mean firing rate + gaussian (sigma=%s)"%(cluster,sigma))

#----------------------------------------------------------------------------------------------------------------------
# Demo cell: runs only when this notebook is executed directly, not when it is loaded with %run.
# Top panel: firing-rate heat map; bottom panel: mean firing rate, both with 0.1 s bins.
if "__file__" not in dir():
    SHANK=1
    CLU=3

    plt.figure(figsize=(8,10))
    plt.subplot(211)
    plot_raster_firing_rate(data,SHANK,CLU,binSize=0.1)

    plt.subplot(212)
    plot_mean_firing_rate(data,SHANK,CLU,binSize=0.1,sigma=1)

### Raster trial by trial

In [None]:
def plot_raster_one_trial(data,shank,cluster,trial,minTime=0,maxTime=20,ax=None):
    '''
    Position of the animal during one trial, with the spike times of one cluster on top.

    The trial window is ]treadmillStart+minTime, treadmillStart+maxTime[ (bounds excluded)
    and all times are aligned on the treadmill start.
    Prints "Trial not tracked" and draws nothing if the trial is in data.trialNotTracked.

    data: session object (reads trialNotTracked, spikeTime, treadmillStartTime, entranceTime,
          timeTreadmill, position and treadmillRange)
    ax: axes to draw on (default: current axes)
    '''
    if ax is None:
        ax=plt.gca()
    if trial in data.trialNotTracked:
        print("Trial not tracked")
        return

    allSpikes=data.spikeTime[shank][cluster]
    t0=data.treadmillStartTime[trial]
    inWindow=(allSpikes>t0+minTime)&(allSpikes<t0+maxTime)
    alignedSpikes=allSpikes[inWindow]-t0

    # the spikes go on a hidden twin axis, as dashed segments from y=0 to y=1
    spikeAx=ax.twinx()
    spikeAx.axis("off")
    ticks=[[(spikeTime,0),(spikeTime,1)] for spikeTime in alignedSpikes]
    spikeAx.add_collection(mc.LineCollection(ticks,colors="blue",label="spike",linestyle="--",linewidth=0.5))

    # entrance time (red cross at y=5) and position trace, drawn above the spikes
    ax.plot(data.entranceTime[trial],5,'rx',zorder=10);
    ax.plot(data.timeTreadmill[trial],data.position[trial],'g-',linewidth=3,zorder=9);

    ax.set_ylim(data.treadmillRange)
    ax.set_xlim([minTime,maxTime])
    ax.set_title('Shank %s Cluster %s, trial %s (%s spikes)'%(shank,cluster,trial,len(alignedSpikes)));

#----------------------------------------------------------------------------------------------------------------------
def plot_raster_trial_by_trial(data,shank,cluster,group="not specified",minTime=0,maxTime=20,paw=False):
    '''
    Summary raster of one cluster followed by one small panel per trial.

    - First row: raster of all trials (plot_raster), or the firing-rate heat map
      (plot_raster_firing_rate) when the cluster has 20000 spikes or more
    - Following rows: plot_raster_one_trial for every trial except the last one of data.trials
    A trial time range is [treadmillStart+minTime, treadmillStart+maxTime],
    times are aligned on treadmillStart.

    group: only used in the text written next to the summary plot
    paw: not used by this function
    '''
    nCols=3 #2 columns will result in Value Error (plot too big)
    nRows=int(np.ceil(data.nTrial*1.0/nCols))
    plt.figure(figsize=(15,nRows*(15.0/nCols)))

    # summary plot over the first two columns of the first row
    summaryGrid=gridspec.GridSpec(nRows+1,nCols,hspace=0.5)
    ax=plt.subplot(summaryGrid[0,0:2])
    fewSpikes=len(data.spikeTime[shank][cluster])<20000
    if fewSpikes:
        plot_raster(data,shank,cluster)
    else:
        plot_raster_firing_rate(data,shank,cluster)
    ax.text(1.2,0.5,"Shank %s, Cluster %s \n(group %s)"%(shank,cluster,group),transform=ax.transAxes,fontsize=15)

    # same grid with a tighter vertical spacing for the per-trial panels
    trialGrid=gridspec.GridSpec(nRows+1,nCols,hspace=0.2)
    for trial in data.trials[:-1]:
        row,col=divmod(trial,nCols)
        axT=plt.subplot(trialGrid[1+row,col])
        plot_raster_one_trial(data,shank,cluster,trial,minTime,maxTime,ax=axT)
#----------------------------------------------------------------------------------------------------------------------
# Demo cell: runs only when this notebook is executed directly, not when it is loaded with %run.
# NOTE(review): this rebinds the module-level names animal, experiment and data,
# so the cells executed afterwards work on this session.
if "__file__" not in dir():
    animal="Rat025"
    experiment="Rat025_2014_08_21_16_18"
    data=Data(root,animal,experiment,param,redoPreprocess=False)
    SHANK=4
    CLU=37
    #print(data.trialNotTracked)
    #plot_raster_one_trial(data,SHANK,CLU,trial=4)

    plot_raster_trial_by_trial(data,SHANK,CLU)

In [None]:
#def plot_raster_trial_by_trial_new(data,shank,cluster,group="not specified",minTime=0,maxTime=20,paw=False):
def plot_raster_trial_by_trial_new(data,group="Good",minTime=0,maxTime=20,paw=False):
    '''
    Asymmetry of the per-trial firing rate distribution, for every cluster of a group.
    (Despite its name, this function does not plot anything.)

    For each cluster listed in data.clusterGroup[shank][group], the firing rate of every trial
    except the last one of data.trials is the number of spikes strictly inside
    ]treadmillStart+minTime, treadmillStart+maxTime[ divided by the window length. Then
        r = (median - percentile20) / (percentile80 - median)

    paw: not used by this function
    Returns a dict {shank: {cluster: r}}.
    '''
    rvalue={}
    for shank,clusters in data.spikeTime.items():
        rvalue[shank]={}
        for cluster,cluData in clusters.items():
            if cluster not in data.clusterGroup[shank][group]:
                continue

            # one firing rate per trial (the last trial is left out)
            firingRateSession=[]
            for trial in data.trials[:-1]:
                zero=data.treadmillStartTime[trial]
                start=zero+minTime
                stop=zero+maxTime
                nbSpike=np.count_nonzero((cluData>start)&(cluData<stop))
                firingRateSession.append(nbSpike/(stop-start))

            low,middle,high=(np.percentile(firingRateSession,q) for q in (20,50,80))
            rvalue[shank][cluster]=(middle-low)/(high-middle)

    return rvalue
#----------------------------------------------------------------------------------------------------------------------
# Demo cell: runs only when this notebook is executed directly, not when it is loaded with %run.
# NOTE(review): this rebinds the module-level names animal, experiment and data,
# so the cells executed afterwards work on this session.
if "__file__" not in dir():
    animal="Rat025"
    experiment="Rat025_2014_08_16_14_18"
    data=Data(root,animal,experiment,param,redoPreprocess=False)
    # earlier single-cluster call, kept for reference:
    #SHANK=1
    #CLU=130
    #percentile20,percentile80,median,r = plot_raster_trial_by_trial_new(data,SHANK,CLU)
    #print(shank,clu,"r",rvalues)
    rvalue=plot_raster_trial_by_trial_new(data,group="Good")
    
    # print one line per cluster of the "Good" group
    for shank in rvalue:
        for cluster in rvalue[shank]:
            print('shank:',shank,',','clu:',cluster,',',rvalue[shank][cluster])

### Raster aligned on the detected end
For some trials, detected end is None (not detected). Those trials are not plotted.

In [None]:
def compute_firing_rate_alignEnd(data,shank,cluster,binSize,minTime=-10,maxTime=10):
    '''
    Per-trial firing rate of one cluster, aligned on the detected end of each trial.

    Time 0 of a trial is treadmillStartTime[trial] + timeEndTrial[trial].
    Trials whose detected end isNone are left out of the result.

    data: session object (reads spikeTime, trials, timeEndTrial and treadmillStartTime)
    binSize: width of the time bins
    minTime, maxTime: window around time 0; spikes on both bounds are included

    Returns (firingRate, center, timeBin):
      firingRate: dict trial -> array of spike counts per bin divided by binSize
      center: centers of the bins
      timeBin: bin edges, np.arange(minTime, maxTime+binSize-maxTime%binSize, binSize)
    '''
    edges=np.arange(minTime,maxTime+binSize-maxTime%binSize,binSize)
    clusterSpikes=data.spikeTime[shank][cluster]

    ratePerTrial={}
    for trial in data.trials:
        detectedEnd=data.timeEndTrial[trial]
        if not isNone(detectedEnd):
            t0=data.treadmillStartTime[trial]+detectedEnd
            inWindow=(clusterSpikes>=t0+minTime)&(clusterSpikes<=t0+maxTime)
            counts,_=np.histogram(clusterSpikes[inWindow]-t0,edges)
            ratePerTrial[trial]=counts/float(binSize)

    binCenters=(edges[:-1]+edges[1:])/2
    return ratePerTrial,binCenters,edges

#----------------------------------------------------------------------------------------------------------------------
def plot_raster_firing_rate_alignEnd(data,shank,cluster,binSize=0.25,minTime=-10,maxTime=10,legend=False,ax=None):
    '''
    Heat map (trials x time bins) of the firing rate of one cluster, aligned on the detected end.

    Only the trials with a detected end (not isNone) are shown, stacked without gaps:
    row i spans y in [i, i+1] and is labelled with the real trial number (trial+1).

    data: session object (reads trials, timeEndTrial, entranceTime, timeAlignEnd,
          medianPositionAlignEnd, plus what compute_firing_rate_alignEnd reads)
    binSize, minTime, maxTime: forwarded to compute_firing_rate_alignEnd
    legend: not used by this function
    ax: axes to draw on (default: current axes)
    '''
    if ax is None:
        ax=plt.gca()

    firingRate,center,timeBin=compute_firing_rate_alignEnd(data,shank,cluster,binSize,minTime,maxTime)
    res=np.asarray(list(firingRate.values()))

    realTrials=[trial+1 for trial in data.trials if not isNone(data.timeEndTrial[trial])]
    nRows=len(realTrials)

    # row edges 0..nRows: row i is centred on the tick and entrance marker drawn at i+0.5
    # and the whole map fits inside the y-limits (0, nRows) set below
    y=np.arange(nRows+1)
    resColor=ax.pcolormesh(timeBin,y,res,cmap="Greys")

    # colorbar in its own thin axes, just right of the heat map
    box = ax.get_position()
    axColor = plt.axes([box.x0*1.02 + box.width * 1.02, box.y0, 0.01, box.height])
    plt.colorbar(resColor, cax=axColor, label="firing rate")

    yAxis=[row+0.5 for row in range(0,nRows)]
    entrance=[data.entranceTime[trial-1]-data.timeEndTrial[trial-1] for trial in realTrials]
    ax.plot(entrance,yAxis,"x",color="red",label="Entrance times",linewidth=2)

    ax2=ax.twinx()
    ax2.axis("off")
    ax2.plot(data.timeAlignEnd,data.medianPositionAlignEnd,'g-',linewidth=2,label="Median position");

    #set correct ticks and limits
    ax.set_xlim([timeBin[0],timeBin[-1]])
    ax.set_yticks(yAxis)
    # show about 10 labels; the step is at least 1 so that sessions with fewer than
    # 10 shown trials do not divide by zero
    labelStep=max(1,nRows//10)
    for i,label in enumerate(ax.get_yticklabels()):
        label.set_visible(i%labelStep==0)

    ax.set_yticklabels([str(t) for t in realTrials])
    ax.set_ylim(0,nRows)
    ax.set_xlabel("Time relative to detected end")
    ax.set_ylabel("Trial Number")
    # report the bin size used above, not the preprocessing bin size stored in data
    ax.set_title("Cluster %s aligned end, firing rate, binSize=%s s"%(cluster,binSize))

#----------------------------------------------------------------------------------------------------------------------
def plot_mean_firing_rate_alignEnd(data,shank,cluster,binSize=0.25,minTime=-10,maxTime=10,sigma=1,ax=None):
    '''
    Plot the firing rate of one cluster averaged over trials, aligned on the detected end.

    data, shank, cluster, binSize, minTime, maxTime: forwarded to compute_firing_rate_alignEnd
    sigma: standard deviation of the gaussian smoothing, in number of bins
    ax: axes to draw on (default: current axes)
    '''
    if ax is None:
        ax=plt.gca()

    firingRate,center,timeBin=compute_firing_rate_alignEnd(data,shank,cluster,binSize,minTime,maxTime)
    # average over the trials returned by compute_firing_rate_alignEnd, bin by bin
    meanRate=np.mean(list(firingRate.values()),axis=0)

    # gaussian_filter1d is imported from scipy.ndimage directly:
    # the scipy.ndimage.filters namespace is deprecated
    from scipy.ndimage import gaussian_filter1d
    smoothMeanRate=gaussian_filter1d(meanRate, sigma)

    ax.plot(center,smoothMeanRate);
    ax.set_xlabel("Time relative to detected end")
    ax.set_ylabel("Mean firing rate")
    ax.set_title("Cluster %s aligned End, firing rate + gaussian (sigma=%s)"%(cluster,sigma))

#----------------------------------------------------------------------------------------------------------------------
# Demo cell: runs only when this notebook is executed directly, not when it is loaded with %run.
# Three stacked panels: heat map aligned on treadmill start, heat map aligned on the
# detected end, and mean firing rate aligned on the detected end.
if "__file__" not in dir():
    SHANK=1
    CLU=4

    plt.figure(figsize=(8,15))
    plt.subplot(311)
    plot_raster_firing_rate(data,SHANK,CLU)

    plt.subplot(312)
    plot_raster_firing_rate_alignEnd(data,SHANK,CLU)

    plt.subplot(313)
    plot_mean_firing_rate_alignEnd(data,SHANK,CLU)

###  Autocorrelogram

In [None]:
import phy.stats

def plot_autocorrelogram(data, shank, cluster, bin_ms=1,half_width_ms=25,ax=None):
    '''
    Autocorrelogram of one cluster, computed with phy.stats.pairwise_correlograms.

    data: session object (reads spikeSamplingRate and spikeSample)
    shank, cluster: keys into data.spikeSample
    bin_ms: bin size in ms, clipped to [0.1, 1000]
    half_width_ms: half width of the time axis in ms, clipped to [0.1, 1000]
    ax: axes to draw on (default: current axes)
    '''
    if ax is None:
        ax=plt.gca()

    bin_ms=np.clip(bin_ms,.1,1e3) #bin size in ms, clipped
    binsize=int(data.spikeSamplingRate*bin_ms*0.001) #bin size in time samples

    half_width_ms=np.clip(half_width_ms,.1,1e3) #ms, clipped
    winsize_bins= 2*int(half_width_ms/bin_ms) +1 #number of bins in window (odd)

    sample=data.spikeSample[shank][cluster]
    # every spike gets the same cluster label (1)
    clu=np.ones_like(sample,dtype="int64")

    pairwiseCorr=phy.stats.pairwise_correlograms(sample,clu,binsize,winsize_bins)

    # only entry [0,0,:] is read
    # NOTE(review): presumably indexed as [cluster, cluster, lag bin]; confirm against phy
    autoCorr=pairwiseCorr[0,0,:]

    # bin edges in ms: winsize_bins+1 edges, the central bin is centred on 0
    halfWinsize=winsize_bins//2
    xaxis=np.arange(-halfWinsize-0.5, halfWinsize+1.5)
    xaxis=xaxis*binsize/data.spikeSamplingRate*1000 #ms

    # xaxis[:-1] are the LEFT edges of the bins: align="edge" is required, because the default
    # alignment of recent matplotlib ("center") would shift every bar half a bin to the left
    ax.bar(xaxis[:-1],autoCorr,width=bin_ms,align="edge",color="blue",edgecolor="blue");
    ax.set_title("Cluster %s, Autocorrelogram"%cluster);
    ax.set_xlim([xaxis[0],xaxis[-1]]);
    ax.set_xlabel("time (ms), binsize=%s ms"%bin_ms)
    ax.set_ylabel("spike count")

#----------------------------------------------------------------------------------------------------------------------
# Demo cell: runs only when this notebook is executed directly, not when it is loaded with %run.
# shank, clusterID and bin_ms defined here are also read by the trial/intertrial demo cell below.
if "__file__" not in dir():
    shank=1
    clusterID = 3

    #bin_ms: bin size in ms
    bin_ms=1
    #half_width_ms: half width of the x axis (time), in ms
    # NOTE(review): this variable is not passed below, the calls use the literals 30 and 1000
    half_width_ms=30 #1000

    plt.figure(figsize=(15,5))    
    plt.subplot(121)
    plot_autocorrelogram(data,shank,clusterID,bin_ms,30)
    plt.subplot(122)
    plot_autocorrelogram(data,shank,clusterID,bin_ms,1000)

### Autocorrelogram for trial/intertrial

In [None]:
import phy.stats

def plot_autocorrelogram_trial(data,shank,cluster,bin_ms=1,half_width_ms=25,minTime=-5,maxTime=20,inTrial=True,ax=None):
    '''
    Autocorrelogram of one cluster restricted to the spikes inside (or outside) the trials.

    A spike is "in trial" when it is strictly inside ]treadmillStart+minTime, treadmillStart+maxTime[
    for at least one trial of data.trials.

    data: session object (reads spikeSamplingRate, spikeSample, spikeTime, trials, treadmillStartTime)
    shank, cluster: keys into data.spikeSample and data.spikeTime
    bin_ms: bin size in ms, clipped to [0.1, 1000]
    half_width_ms: half width of the time axis in ms, clipped to [0.1, 1000]
    inTrial: True to keep the spikes in trials, False to keep the intertrial spikes
    ax: axes to draw on (default: current axes)
    '''
    if ax is None:
        ax=plt.gca()

    bin_ms=np.clip(bin_ms,.1,1e3) #bin size in ms, clipped
    binsize=int(data.spikeSamplingRate*bin_ms*0.001) #bin size in time samples

    half_width_ms=np.clip(half_width_ms,.1,1e3) #ms, clipped
    winsize_bins= 2*int(half_width_ms/bin_ms) +1 #number of bins in window (odd)

    sample=data.spikeSample[shank][cluster]
    spikeTime=data.spikeTime[shank][cluster]

    # explicit boolean mask: with the dtype of `sample` (integers) and no trial, the array
    # would be used as fancy indices (all 0 or all -1) instead of a mask
    isInTrial=np.zeros_like(sample,dtype=bool)
    #select only spike during trials
    for trial in data.trials:
        zero=data.treadmillStartTime[trial]
        start=zero+minTime
        stop=zero+maxTime
        isInTrial=np.logical_or(isInTrial,(spikeTime>start)&(spikeTime<stop))

    if inTrial:
        newSample=sample[isInTrial]
        title="trial"
    else:
        newSample=sample[np.invert(isInTrial)]
        title="intertrial"
    # every spike gets the same cluster label (1)
    clu=np.ones_like(newSample,dtype="int64")

    pairwiseCorr=phy.stats.pairwise_correlograms(newSample,clu,binsize,winsize_bins)

    # only entry [0,0,:] is read
    # NOTE(review): presumably indexed as [cluster, cluster, lag bin]; confirm against phy
    autoCorr=pairwiseCorr[0,0,:]

    # bin edges in ms: winsize_bins+1 edges, the central bin is centred on 0
    halfWinsize=winsize_bins//2
    xaxis=np.arange(-halfWinsize-0.5, halfWinsize+1.5)
    xaxis=xaxis*binsize/data.spikeSamplingRate*1000 #ms

    # xaxis[:-1] are the LEFT edges of the bins: align="edge" is required, because the default
    # alignment of recent matplotlib ("center") would shift every bar half a bin to the left
    ax.bar(xaxis[:-1],autoCorr,width=bin_ms,align="edge",color="blue",edgecolor="blue");
    ax.set_title("Cluster %s, Autocorrelogram during %s"%(cluster,title));
    ax.set_xlim([xaxis[0],xaxis[-1]]);
    ax.set_xlabel("time (ms), binsize=%s ms"%bin_ms)
    ax.set_ylabel("spike count")

#----------------------------------------------------------------------------------------------------------------------
# Demo cell: runs only when this notebook is executed directly, not when it is loaded with %run.
# Uses shank, clusterID and bin_ms set by the plot_autocorrelogram demo cell, so that cell
# must have been run first.
# Top row: spikes during trials (30 ms and 1000 ms half width); bottom row: intertrial spikes.
if "__file__" not in dir():
    plt.figure(figsize=(15,10))    
    plt.subplot(221)
    plot_autocorrelogram_trial(data,shank,clusterID,bin_ms,30)
    plt.subplot(222)
    plot_autocorrelogram_trial(data,shank,clusterID,bin_ms,1000)
    plt.subplot(223)
    plot_autocorrelogram_trial(data,shank,clusterID,bin_ms,30,inTrial=False)
    plt.subplot(224)
    plot_autocorrelogram_trial(data,shank,clusterID,bin_ms,1000,inTrial=False)

### Waveform for teresa's data (.kwx)
  - Load .kwx with h5py and .kwik with phy  
  - Choose randomly 150 spikes from a cluster
  - Do not plot spike if mask=0
  
  In the kwik/kwx format, `spikeIndex` holds the indexes of the spikes in the .kwx file.
  
  `spikeSample` holds the indexes of the spikes in the .dat file (once it is reshaped as `nSample*nChannel`).

In [None]:
import h5py
from phy.session import Session

def read_dat_waveform(data, shank, cluster, subSample = 150, extract = 16):
    '''
    Read the raw waveforms of one cluster directly from the .dat file.

    The file data.fullPath + '.dat' is memory-mapped as int16 with shape (nSamples, data.nChannels).
    For each selected spike, the samples [spike-extract, spike+extract[ of the channels listed in
    data.channelGroupList[shank] are copied.

    data: session object (reads fullPath, nChannels, spikeSample and channelGroupList)
    shank, cluster: keys into data.spikeSample
    subSample: maximum number of spikes; above it, subSample spikes are drawn at random
               without replacement
    extract: number of samples kept on each side of the spike sample

    Spikes closer than `extract` samples to the beginning or the end of the file are dropped,
    because their window cannot be read entirely.

    Returns an int16 array of shape (nSpike, 2*extract, number of channels of the shank).
    Raises ValueError if the file size is not a multiple of the size of one row of channels.
    '''
    #memory map to dat file
    dtype = np.int16
    datPath = data.fullPath + '.dat'
    size = os.stat(datPath).st_size
    row_size = data.nChannels * np.dtype(dtype).itemsize
    if size % row_size != 0:
        raise ValueError(("Shape error: the file {f} has S={s} bytes, "
                          "but there are C={c} channels. C should be a divisor of S."
                          "").format(f=datPath, s=size, c=data.nChannels))
    nsamples = size // row_size
    shape = (nsamples, data.nChannels)
    datFile = np.memmap(datPath, dtype = dtype, mode = 'r', offset = 0, shape = shape)

    #indexes of spikes for this cluster, restricted to those whose window fits in the file
    spikeID = np.asarray(data.spikeSample[shank][cluster])
    fitsInFile = (spikeID >= extract) & (spikeID + extract <= nsamples)
    spikeID = spikeID[fitsInFile]
    nSpike = len(spikeID)
    if nSpike > subSample:
        spikeID = np.random.choice(spikeID, subSample, replace = False)
        nSpike = subSample

    #get waveform for each index
    channels = data.channelGroupList[shank]
    waveform = np.zeros(shape=(nSpike, extract * 2, len(channels)), dtype=dtype )
    for index, spike in enumerate(spikeID):
        waveform[index, :, :] = datFile[spike-extract : spike+extract, channels]
    return waveform

def read_kwx_waveform(data, shank, clusterID, sample = 150):
    """
    Read the raw waveforms of one cluster from the .kwx file.

    The dataset 'channel_groups/<shank>/waveforms_raw' is loaded entirely, then indexed with
    data.spikeIndex[shank][clusterID] on its first axis.
    kwx array = [ spikes indexes, n data points, n channels]

    sample: maximum number of spikes; above it, `sample` spikes are drawn at random
            without replacement
    Prints the shape of the dataset and the largest spike index.
    Returns None if data.fullPath + ".kwx" does not exist.
    """
    kwxPath = data.fullPath+".kwx"
    if not os.path.exists(kwxPath):
        return None

    with h5py.File(kwxPath,"r") as kwxFile:
        allWaveforms = kwxFile.get('channel_groups/%s/waveforms_raw' % shank)[()]
    print(allWaveforms.shape)

    #index of spikes where cluster==X
    chosen = data.spikeIndex[shank][clusterID]
    print(max(chosen))
    if len(chosen) > sample:
        chosen = np.random.choice(chosen, sample, replace = False)
    return allWaveforms[chosen, :, :]

In [None]:
#Mean waveform from .kwx
def plot_mean_waveform(data, shank, clusterID, noPlot = False, sample = 150, kwx = True):
    '''
    Mean waveform of one cluster, on the channel where it has the largest amplitude.

    The waveforms are read with read_kwx_waveform (kwx=True) or read_dat_waveform (kwx=False).
    For each channel, the waveforms are averaged over spikes and the amplitude is max - min
    of that average; the channel with the largest amplitude is kept.

    noPlot: if True, nothing is plotted
    sample: maximum number of spikes read
    Returns the mean waveform of the kept channel (an empty list if no channel has an
    amplitude greater than 0).
    '''
    reader = read_kwx_waveform if kwx else read_dat_waveform
    waveform = reader(data, shank, clusterID, sample)

    bestAmplitude = 0
    bestChannel = 0
    meanWaveform = []
    for channel in range(waveform.shape[2]):
        channelMean = np.mean(waveform[:, :, channel], axis = 0)
        amplitude = np.max(channelMean) - np.min(channelMean)
        if amplitude > bestAmplitude:
            bestAmplitude, bestChannel, meanWaveform = amplitude, channel, channelMean

    if noPlot:
        return meanWaveform
    plt.plot(meanWaveform)
    plt.title("Mean waveform - Shank %s Cluster %s channel %s" %(shank, clusterID, bestChannel))
    return meanWaveform
        
#Loading waveform from .kwx
def plot_waveforms(data, shank, clusterID, group="not specified", sample = 150, kwx = True):
    '''
    Plot the individual waveforms of one cluster, one small panel per channel.

    The waveforms are read with read_kwx_waveform (kwx=True) or read_dat_waveform (kwx=False).
    The channels are laid out in two staggered columns: channel c is drawn in row c,
    in column 1 for even c and in column 0 for odd c. The grid has at least 8 rows and
    grows with the number of channels, so shanks with more than 8 channels are supported.

    group: not used by this function
    sample: maximum number of spikes read
    '''
    if kwx:
        waveform = read_kwx_waveform(data, shank, clusterID, sample)
    else:
        waveform = read_dat_waveform(data, shank, clusterID, sample)

    nChannels = waveform.shape[2]
    nRows = max(8, nChannels)

    plt.figure(figsize = (3*2, 6))
    plt.suptitle("Cluster %s" % clusterID, fontsize = 14)
    # negative hspace: the rows overlap vertically
    gs = gridspec.GridSpec(nRows, 3, hspace = -0.3, wspace = 0)
    for channel in range(nChannels):
        channelWaveform = waveform[:, :, channel]
        # staggered layout: (0,1), (1,0), (2,1), (3,0), ...
        ax = plt.subplot(gs[channel, 1 - channel % 2])
        for spike in channelWaveform:
            ax.plot(spike, color = "blue");
        ax.set_title("%s" % channel)
        ax.set_axis_off()
    return

# Alias kept so that code calling the function by its old name, plot_waveforms_teresa, still works
plot_waveforms_teresa = plot_waveforms

#----------------------------------------------------------------------------------------------------------------------
# Demo cell: runs only when this notebook is executed directly, not when it is loaded with %run.
if "__file__" not in dir():
    shank = 1
    clu = 3
    # individual waveforms read from the .dat file (kwx=False)
    plot_waveforms(data, shank, clu, kwx=False)
    plt.figure()
    # mean waveform read with the default kwx=True, i.e. from the .kwx file
    plot_mean_waveform(data, shank, clu)

### General plot on each cluster

In [None]:
def plot_raster_correlogram(data,shank,cluster,group="not specify"):
    '''
    Summary figure (2x2 panels) of one cluster.

    - top right: raster (plot_raster), or firing-rate heat map (plot_raster_firing_rate)
      when the cluster has more than 20000 spikes
    - bottom right: mean firing rate
    - top left / bottom left: autocorrelogram with a half width of 30 ms / 1000 ms

    group: only used in the figure title
    '''
    nSpikes=len(data.spikeTime[shank][cluster])

    plt.figure(figsize=(15,10))

    # right column: raster or heat map, then mean firing rate
    plt.subplot(222)
    rasterPlot=plot_raster_firing_rate if nSpikes>20000 else plot_raster
    rasterPlot(data,shank,cluster)
    plt.subplot(224)
    plot_mean_firing_rate(data,shank,cluster)

    # left column: autocorrelograms at two time scales
    for panel,halfWidth in ((221,30),(223,1000)):
        plt.subplot(panel)
        plot_autocorrelogram(data,shank,cluster,1,halfWidth)

    plt.suptitle("%s, Shank %s, Cluster %s (group:%s), %s spikes"%(data.experiment,shank,cluster,group,nSpikes),fontsize=16)

#----------------------------------------------------------------------------------------------------------------------
# Demo cell: runs only when this notebook is executed directly, not when it is loaded with %run.
if "__file__" not in dir():
    shank=1
    cluster=3
    # "Good" is only written in the figure title
    plot_raster_correlogram(data,shank,cluster,"Good")