# loadRat_documentation
## This notebook generates a set of attributes for each experiment (session) using Python classes and functions

### This notebook is at the core of the data-processing pipeline. Do not modify it lightly inside the master folder (load_preprocess_rat)

#### 1. Only modify it if you are sure of what you are doing and that you are fixing a bug. Consult with David.
#### 2. If you do modify it, you MUST commit the modification using Bitbucket.
#### 3. If you want to play with this notebook (to understand it better), copy it to a toy folder distinct from the master folder.
#### 4. If you want to modify this code (fix a bug, improve it, add attributes...), it is recommended to first duplicate it in a draft folder. Try to keep track of your changes.
#### 5. When you are ready to commit: clear all output and clean everything between hashtags.


In [None]:
from IPython.display import display, HTML

import platform
import os
import glob
import scipy.io as spio
import pickle
import xmltodict 
import datetime
import six
from collections import Counter
import numpy as np
import pandas as pd 
from scipy.interpolate import interp1d
from scipy.ndimage import gaussian_filter as smooth
import matplotlib.pyplot as plt
from IPython.display import clear_output
%matplotlib inline


ThisNoteBookPath=os.path.dirname(os.path.realpath("__file__"))
CommunNoteBookesPath=os.path.join(os.path.split(ThisNoteBookPath)[0],"load_preprocess_rat")
CWD=os.getcwd()
if "__file__" not in dir():
    os.chdir(CommunNoteBookesPath) 
    %run loadRat_documentation.ipynb
    %run Animal_Tags.ipynb
    os.chdir(CWD)

#if available, loadRawSpike_documentation.ipynb must be imported to the scope.
try:
    os.chdir(CommunNoteBookesPath) 
    %run loadRawSpike_documentation.ipynb
except Exception:
    pass
finally:
    os.chdir(CWD)   
    
if "__file__" not in dir():

    if platform.system()=='Linux':
        root="/data"
    elif platform.system()=='Windows':
        root="C:\\data\\"
    else:
        root="/Users/davidrobbe/Documents/Data/"

    print("The path to data is %s"%root)
    
#utility function to check if something is None, or empty, or np.nan
def isNone(value):
    if value is None:
        return True
    elif isinstance(value,(float,int)):
        if np.isnan(value):
            return True
        return False
    else:
        return (not value)

## Loading raw behavior data for rats

**experiment** is a synonym for **session**  
one experiment = one recording session = one folder = one animal at one time
  
#### Paths
  
- **rootFolder**: "/data"
- **animal**: "RatXXX" (name of the rat folder)
- **experiment**: "RatXXX\_20XX\_XX\_..." (name of the session)
- **sessionPath**: folder of the session   
    "/data/RatXXX/Experiments/RatXXX_2010_04_10_18_30" 
- **fullPath**: folder of the session + basename   
    "/data/RatXXX/Experiments/RatXXX_2010_04_10_18_30/RatXXX_2010_04_10_18_30"     
    fullPath + ".dat" is the path from the root to the .dat file
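
The path conventions above can be sketched as follows. This is a minimal illustration, assuming the same layout as `compute_paths`; the helper name `build_session_paths` and the animal name "Rat001" are hypothetical.

```python
import os

def build_session_paths(root, animal, experiment):
    """Return (sessionPath, fullPath) for one session (hypothetical helper)."""
    # sessionPath: root / animal / "Experiments" / session folder
    session_path = os.path.join(root, animal, "Experiments", experiment)
    # fullPath: the session name is repeated as the basename of every file
    full_path = os.path.join(session_path, experiment)
    return session_path, full_path

# "Rat001" and the date are made-up values, used only for illustration
session_path, full_path = build_session_paths("/data", "Rat001", "Rat001_2010_04_10_18_30")
dat_file = full_path + ".dat"  # path to the .dat file of the session
```

Every other file of the session is reached the same way, by appending its extension to `fullPath`.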

#### Integers

- **cameraSamplingRate**: number of frames per second of the video (Hz)  
    other name: FrameRate  
    read from .samplingrate
- **nTrial**: total number of trials in the session (tracked and not tracked)

#### Numpy 1D array (horizontal, one row), length=nTrial
One value per trial, parameters

- **goalTime**: time after which the rat should cross the front of the treadmill (in seconds after the treadmill started)
    read from .goaltime or manually provided
    
- **treadmillSpeed**: fixed speed of the treadmill (cm/s) for each trial

- **cameraStartTime**: time at which the camera starts (in seconds from the beginning of the session)  
    other name: CameraEvents  
- **treadmillStartTime**: time at which the treadmill starts (in seconds from the beginning of the session)   
    other name: TreadmillEvents  
    
- **maxTrialDuration**: maximum duration of the trial, in seconds
- **interTrialDuration**: duration of the intertrial, in seconds (in the new setup: 1 s with the camera off, the rest with the camera on)

One value per trial, rat behavior

- **entranceTime**: time when the rat crossed the front of the treadmill (in seconds after the treadmill started)  
    entranceTime >= goalTime: trial successful  
    read from .entrancetimes
    
#### Dictionary  key(trial number-1): value

 - **rawPosition**: position of the animal in cm (~15: front of treadmill, ~90: back of treadmill)  
     one numpy.array per trial correctly tracked (trial not in trialNotTracked)  
     read from .position files
     
 - **rawTime**: time matching raw position (0=camera start)

     
The ".position" files in Pavel's data are actually positions of the paw.  
 The ".position" and ".paw" files in Teresa's data are extracted from the ".avi" video. The ".position" files contain the position of the body (detected from the white color of the rat), the ".paw" files the position of the paw (detected from the colored tape).
 
 In the new setup, the camera never stops. Position, time and trial/intertrial numbers are written as columns in one ".position" file. Position is the position of the animal. In older setups, the time has to be deduced from the number of positions and the camera sampling rate.
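
For older setups, deducing the time from the number of positions can be sketched as below. It mirrors what `read_rawTime` does in the base class; the function name `time_from_positions` and the toy values are only for illustration.

```python
import numpy as np

def time_from_positions(positions, camera_sampling_rate):
    """One time stamp per position, in seconds, with 0 = camera start."""
    n_positions = len(positions)
    # frame index divided by the frame rate gives the time of each frame
    return np.arange(n_positions) / float(camera_sampling_rate)

# toy trial: 5 tracked positions (cm) recorded at 25 frames per second
positions = np.array([80.0, 78.5, 77.0, 75.2, 74.0])
raw_time = time_from_positions(positions, 25.0)
```

With a 25 Hz camera, consecutive positions are 0.04 s apart.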
 
#### Trials

 - **trialNotTracked**: list of trials to remove from the analysis because the position was not tracked, or because of another issue   
     These trials are removed from rawPosition and rawPawPosition (dictionaries)  
     They are kept in the 1D numpy arrays, so as not to disrupt the order.
     
 - **trials**: list of trials, with trialNotTracked removed
 
 - **realTrials**: 1-based trial numbers (trials+1)
 
 - **goodTrials**: trials where goalTime <= entranceTime < maxTrialDuration
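
The relation between these lists can be sketched on toy data. This is a minimal example that follows the logic of `get_trialNotTracked` and `get_trial_info`; all values are made up.

```python
import numpy as np

# toy session with 4 trials (keys are trial number - 1)
entrance_time = np.array([7.5, 3.0, 15.0, 8.2])
goal_time = np.full(4, 7.0)
max_trial_duration = np.full(4, 15.0)
raw_position = {
    0: np.array([80.0, 60.0]),
    1: np.array([]),            # nothing tracked for this trial
    2: np.array([85.0, 84.0]),
    3: np.array([70.0, 20.0]),
}

# trials with no position are not tracked
trial_not_tracked = [t for t, pos in raw_position.items() if len(pos) == 0]
trials = [t for t in raw_position if t not in trial_not_tracked]
real_trials = [t + 1 for t in trials]
# a good trial: the rat entered after the goal time but before the timeout
good_trials = [t for t in trials
               if max_trial_duration[t] > entrance_time[t] >= goal_time[t]]
```

Here trial 1 is dropped because it has no position, and trial 2 is not a good trial because its entrance time equals the maximum trial duration (timeout).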
 
#### Other
    
 - **emptyAnalysisFiles**: set of files that are empty in RatXXX/Experiments/RatXXX_2014.../Analysis  
    other name: FileTextTags
    
 - **experimenter**: name of the experimenter
    
 - **cameraToTreadmillDelay**: time between the camera start and the treadmill start (default=2sec).  
     Not implemented currently, but could be useful
 
 - **treadmillRange**: size of the treadmill, list of two integers [minPos, maxPos]
 
#### Booleans

 - **hasBehavior**: whether the session has behavioral data that was loaded successfully 
 
 - **hasEEG**: whether there is a downsampled .dat file (.eeg,.low.kwd)
 


In [None]:
class BaseRawBehaviorData:
    '''
    Base class, not meant to be used directly.
    A class should be implemented for each data type, inheriting from this base class.
    saveAsPickle=True: save data as a dictionary in a pickle file "/Analysis/rawbehaviordata.p"
    parameters: dictionary with data not in raw text files (ex: treadmill range).
    '''
    def __init__(self,root,rat,experiment,parameters={},saveAsPickle=True,PrintWarning=False):
        #dictionary with treadmill speed, goal time...(things not in text files)
        self.parameters=parameters.copy() 

        #check if the experiment path exists, compute full path to the folder
        if not self.compute_paths(root,rat,experiment):
            if PrintWarning:
                print ("path error")
            self.hasBehavior=False
            return
        
        
        if len(self.read_entranceTime()[0])==1:
            self.hasBehavior=False
            return
        
        
        #rename files if needed
        self.rename_files()
             
        #xml, eeg
        self.xmlDict=self.read_xml()
        self.hasEEG=self.has_eeg()
        
        #behavior files
        positionFiles=glob.glob(os.path.join(self.sessionPath,"*.position"))
        entranceTimeFile=glob.glob(os.path.join(self.sessionPath,"*.entrancetimes"))
        if len(positionFiles)==0:
            self.hasBehavior=False
            if PrintWarning:
                print("position file missing")
            return
        
        elif len(entranceTimeFile)==0:
            self.hasBehavior=False
            if PrintWarning:
                print("entrancetimes file missing (reward habituation or locomotion test ?)")
            return
            
        else:
            self.hasBehavior=True
        if PrintWarning:
            print("ready to read everything")   
        self.read_everything()
        
        self.dataType=self.data_type()
        
        if saveAsPickle:
            self.save_as_pickle()
            
    def data_type(self):
        return "Base class"
            
    def read_everything(self):
        '''
        Calls a different method to read each attribute.
        When inheriting the base class, keep every method, but change the order if needed.
        '''   
        self.cameraSamplingRate=self.read_cameraSamplingRate()
        self.experimenter=self.read_experimenter()
        self.emptyAnalysisFiles=self.find_emptyAnalysisFiles()
        self.treadmillRange=self.read_treadmill_range()
        
        #1d numpy array from text files
        self.entranceTime=self.read_entranceTime()
    
            
        self.cameraStartTime=self.read_cameraStartTime()
        self.treadmillStartTime=self.read_treadmillStartTime()
        
        #dictionary key(trial number-1): value from position files
        self.rawPosition=self.read_rawPosition()
        
        self.trialNotTracked=self.get_trialNotTracked()
        
        #we need nTrial to read goalTime and treadmill speed
        #this could be removed when inheriting
        self.nTrial=len(self.rawPosition)
        
        # 1D numpy arrays 
        self.goalTime=self.read_goalTime()
        self.treadmillSpeed=self.read_treadmillSpeed()
        self.maxTrialDuration=self.read_maxTrialDuration()
        self.interTrialDuration=self.read_interTrialDuration()
        
        #trialNotTracked, trials, realTrials, goodTrials, nTrial
        # remove trial not tracked from any dict
        self.check_validity()
        self.get_trial_info()
        
        #compute rawTime from position and camera sampling rate
        self.rawTime=self.read_rawTime()
        
        #%%%%%%%%%
        #read lick time file
        self.lickTime=self.read_licktime_file ()
        #%%%%%%%%%
           
    #---------------------------------------------------------     
    def compute_paths(self,root,animal,experiment,PrintWarning=False):
        #clean name of folders (remove unnecessary slash or backslash)
        self.root=os.sep+root.strip(os.sep)
        self.animal=animal.strip(os.sep)
        self.experiment=experiment.strip(os.sep)
        #paths
        self.sessionPath=os.path.join(self.root,self.animal,"Experiments",self.experiment)
        self.fullPath=os.path.join(self.root,self.animal,"Experiments",self.experiment,self.experiment)  
        
        #Check if the path is correct
        if self.animal not in self.experiment:
            if PrintWarning:
                print("WARNING: session name (%s) does not contain animal name (%s)"%(self.experiment,self.animal))
        if not os.path.exists(self.sessionPath):
            if PrintWarning:
                print("STOP Loading - Path does not exist: %s"%self.sessionPath)
            return False
        return True
     
    def has_eeg(self):
        if glob.glob(self.fullPath+".eeg"):
            return True
        elif glob.glob(self.fullPath+".low.kwd"):
            return True
        else:
            return False
        
    def get_dict(self):
        return self.__dict__
    
    def save_as_pickle(self,folder="Analysis",name="rawbehaviordata.p"):
        folderPath=os.path.join(self.sessionPath,folder)
        if not os.path.exists(folderPath):
            os.mkdir(folderPath)
        filePath=os.path.join(folderPath,name)
        with open(filePath, "wb" ) as f:
            pickle.dump(self.__dict__, f)
    #---------------------------------------------------------   
    def rename_files(self):
        '''
        If the files inside the folder do not share the session basename, rename them.
        ex: "MOU087_2015_08_12-16_09" -> "MOU087_2015_08_12_16_09"
        Only for files that start with the animal name
        '''
        files=glob.glob(os.path.join(self.sessionPath,"*"))
        for path in files:
            if not os.path.isfile(path):
                #it's a folder, skip
                continue
            filename=os.path.basename(path)
            if self.experiment in filename:
                #it's already named correctly, skip
                continue
            #if it starts with "AnimalXXX" (it just has a wrong date)
            if filename.startswith(self.animal):
                # extract full extension ".behav_param", ".raw.kwd"
                extension=filename[filename.find("."):]
                #rename
                newName=self.experiment+extension
                newPath=os.path.join(self.sessionPath,newName)
                os.rename(path,newPath)
                print("Renamed %s to %s"%(filename,newName))
        
    def read_xml(self):
        d={}
        if os.path.exists(self.fullPath+".xml"):
            with open(self.fullPath+'.xml', "rb") as f:
                d = xmltodict.parse(f, xml_attribs=True)
        return d
    
    def look_param_or_ask_user(self,name,valueType=str,sentence=""):
        '''
        Look for a key 'name' in self.parameters, returns the value.
        If not found, ask the user to input the value
        '''
        try:
            return self.parameters[name]
        except KeyError:
            if sentence=="":
                sentence="Enter %s (type: %s):"%(name,valueType)
            return valueType(input(sentence))
        
    #---------------------------------------------------------------------------    
    def get_trialNotTracked(self):
        trialNotTracked=[]
        for trial in self.rawPosition:
            if len(self.rawPosition[trial])==0:
                trialNotTracked.append(trial)
        return trialNotTracked
    
    def get_trial_info(self):
        self.trials=[trial for trial in self.rawPosition if trial not in self.trialNotTracked]
        self.realTrials=[trial+1 for trial in self.trials]
        self.goodTrials=[t for t in self.trials if (self.maxTrialDuration[t]>self.entranceTime[t]>=self.goalTime[t])]
        self.nTrial=len(self.trials)+len(set(self.trialNotTracked))
        
        #iterate over a copy of the keys: a dict cannot change size while being iterated
        for trial in list(self.rawPosition):
            if trial in self.trialNotTracked:
                del self.rawPosition[trial]
        
    def check_validity(self):
        for l in [self.entranceTime,self.treadmillStartTime,self.cameraStartTime]:
            assert len(l)>=self.nTrial, "Wrong length of array (%s)"%len(l)
                
    #---------------------------------------------------------------------------
    # Read data from the provided dictionary "parameters"
    # If the data is not there, ask for user input
    
    def read_treadmill_range(self):
        try:
            return self.parameters["treadmillRange"]
        except KeyError:
            mini=0#input('Enter treadmill min position:')
            maxi=90#input('Enter treadmill max position:')
            return [int(mini),int(maxi)]
          
    def read_cameraSamplingRate(self):
        return self.look_param_or_ask_user("cameraSamplingRate",valueType=float,
                                      sentence="Enter camera sampling rate:")
        
    def read_experimenter(self):
        if "experimenters" in self.parameters:
            return self.parameters["experimenters"]
        else:
            return "unknown"
        
    def read_goalTime(self):
        goalTime=self.look_param_or_ask_user("goalTime",valueType=float,sentence="Enter goal time (seconds):")
        npGoalTime=np.full(self.nTrial,goalTime,dtype=np.float64)
        return npGoalTime
    
    def read_treadmillSpeed(self):
        speed=self.look_param_or_ask_user("treadmillSpeed",valueType=float,
                                         sentence='Enter treadmill speed (cm/sec):')
        npSpeed=np.full(self.nTrial,speed,dtype=np.float64)
        return npSpeed
    
    def read_maxTrialDuration(self):
        maxDuration=self.look_param_or_ask_user("maxTrialDuration",valueType=float,
                                               sentence="Enter maximum duration of trial (sec):")
        npMax=np.full(self.nTrial,maxDuration,dtype=np.float64)
        return npMax
    
    def read_interTrialDuration(self):
        inter=self.look_param_or_ask_user("interTrialDuration",valueType=float,
                                         sentence="Enter duration of intertrial (sec):")
        npInter=np.full(self.nTrial,inter,dtype=np.float64)
        return npInter
    
    #---------------------------------------------------------------------------
    #read data in text files with pandas (faster than numpy.loadtxt)    
    
    def read_csv_pandas(self,path,oneCol=False,header=None):
        if not os.path.exists(path):
            print("No file %s"%path)
            #self.hasBehavior=False
            return []
        try:
            csvData=pd.read_csv(path,header=header,sep=r"\s+")
        except ValueError:
            print("%s not valid (usually empty)"%path)
            #self.hasBehavior=False
            return []
        if oneCol:
            return csvData.values[:,0]
        else:
            return csvData
    
    def read_entranceTime(self):     
        return self.read_csv_pandas(self.fullPath+'.entrancetimes',oneCol=True)
    
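    def read_cameraStartTime(self):
        #np.asarray: a missing file gives an empty list, which cannot be divided
        time_ms=np.asarray(self.read_csv_pandas(self.fullPath+".evt.cam",oneCol=True),dtype=float)
        time_sec=time_ms/1000.0
        return time_sec
    
    def read_treadmillStartTime(self):
        #np.asarray: a missing file gives an empty list, which cannot be divided
        time_ms=np.asarray(self.read_csv_pandas(self.fullPath+".evt.tre",oneCol=True),dtype=float)
        time_sec=time_ms/1000.0
        return time_sec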
    
    def read_rawPosition(self,extension="*.position"):
        rawPosition={}
        positionFiles = sorted(glob.glob(self.fullPath+extension))
        for index,posFile in enumerate(positionFiles):
            rawPosition[index]=self.read_csv_pandas(posFile,oneCol=True)
        return rawPosition
                
    def read_rawTime(self):
        rawTime={}
        for trial in self.rawPosition:
            nbPos=len(self.rawPosition[trial])
            rawTime[trial]=np.arange(nbPos)/float(self.cameraSamplingRate)
        return rawTime

    #---------------------------------------------------------------------------
    def find_emptyAnalysisFiles(self):
        result = []
        path=self.sessionPath
        if os.path.exists(path):
            for root, dirs, files in os.walk(path):
                for name in files:
                    if os.stat(os.path.join(root,name)).st_size<=5:
                        result.append(name)
        return result
    
    def read_licktime_file (self,extension=".lickbreaktime",PrintWarning=False):
        trial=0                                              #Trial number
        trialLickTimes=[]                                    #Lick times of a single trial
        lickTimes=[ np.array([]) for i in range(self.nTrial) ]           #Lick times of the session 
        try:
            with open(self.fullPath+extension,'r') as f:
                for line in f:
                    res=line.split()
                    if  res[0] == 'Trial':
                        trial = int(float(res [-1]))
                        trialLickTimes=[]
                    else:
                        trialLickTimes.append(float(res[-1]))
                        lickTimes[trial-1]=np.array(trialLickTimes)
            if len(trialLickTimes)==0:
                if PrintWarning:
                    print("lickbreaktime file is empty")
        except Exception:
            if PrintWarning:
                print("No *.lickbreaktime file found!")
        #dtype=object: trials can have different numbers of licks (ragged array)
        return np.array(lickTimes,dtype=object)


### New setup with .behav_param
  
Files:
  - **behav_param**: parameters for each trial, like max trial duration, treadmill speed...
  
  - **entrancetimes**: for each trial, entrance time in second or "timeout".
  
  - **position**: in columns, trial/intertrial numbers, time and position (pixels)
  
      - `rawPosition[trial]`= positions for the trial and for the next intertrial, converted to cm (the camera is on before the treadmill starts)
      
      - `rawTime[trial]`= time for the trial and for the next intertrial, divided by the camera sampling rate to get seconds
          (N.B. there is 1 s between trial and intertrial during which the camera is off)
      - `cameraStartTime[trial]=rawTime[trial][0]`
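
How one trial and its intertrial are extracted from the ".position" table can be sketched as follows. This is a toy example, assuming the column layout used in `read_position_file` (0 = trial number, with N+0.5 for the intertrial after trial N, 1 = time in frames, 3 = x position in pixels); column 2 is a placeholder here, and `pixels_per_cm` and all values are made up.

```python
import numpy as np
import pandas as pd

# toy table laid out like the ".position" file
data = pd.DataFrame({
    0: [1.0, 1.0, 1.0, 1.5, 1.5, 2.0, 2.0],      # trial / intertrial number
    1: [0, 1, 2, 28, 29, 100, 101],              # time, in camera frames
    2: [0, 0, 0, 0, 0, 0, 0],                    # placeholder column
    3: [400.0, 390.0, 380.0, 100.0, 110.0, 410.0, 405.0],  # x position (pixels)
})
pixels_per_cm = 5.0          # assumed value, read from .behav_param in practice
camera_sampling_rate = 25.0

def split_trial(data, trial, pixels_per_cm):
    """Positions (cm) of trial + next intertrial, and index where the trial stops."""
    in_trial = data[0] == float(trial)
    in_intertrial = data[0] == float(trial) + 0.5
    trial_pos = data[3][in_trial].values / pixels_per_cm
    inter_pos = data[3][in_intertrial].values / pixels_per_cm
    return np.append(trial_pos, inter_pos), len(trial_pos)

positions_cm, stop_frame = split_trial(data, 1, pixels_per_cm)
# camera start time of trial 2: first time stamp of the trial, in seconds
camera_start_trial2 = data[1][data[0] == 2.0].values[0] / camera_sampling_rate
```

`stop_frame` marks where the trial ends and the intertrial begins inside `positions_cm`.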

Treadmill start time is Camera start time + "computed start-up delay"  
Camera sampling rate is usually 25.0

For some trials in some sessions, no value is written in .entrancetimes, and only half are written in .behav_param
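
The conversion of entrance times to numbers can be sketched as below. It follows the loop in `read_entranceTime`: "timeout" is replaced by the maximum trial duration, and "No value" (the placeholder inserted for trials with no entrance time) becomes `np.nan`. The function name `entrance_times_to_float` and the toy values are only for illustration.

```python
import numpy as np

def entrance_times_to_float(entrance_str, max_trial_duration):
    """Convert entrance times read as strings into a float array."""
    entrance = []
    for value, duration in zip(entrance_str, max_trial_duration):
        if value == "timeout":
            # the rat never entered: use the maximum trial duration
            entrance.append(duration)
        elif value == "No value":
            # nothing was written for this trial
            entrance.append(np.nan)
        else:
            entrance.append(float(value))
    return np.asarray(entrance)

entrance = entrance_times_to_float(["7.52", "timeout", "No value"],
                                   [15.0, 15.0, 15.0])
```

Trials whose entrance time is `np.nan` are the ones that end up in `trialNotTracked`.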

In [None]:
class NewRawBehaviorData(BaseRawBehaviorData):     
    
    def data_type(self):
        return "behav_param"
   
    def read_everything(self):  
        self.cameraSamplingRate=self.read_cameraSamplingRate()  
        
        #read entrance time and maxTrialDuration
        #check for missing values: put np.nan instead
        self.entranceTime,self.maxTrialDuration,missingTrial=self.read_entranceTime()
        if self.entranceTime is False:
            self.hasBehavior=False
            return
        
        if len(self.entranceTime)==1:
            self.hasBehavior=False
            return   
        
        
        self.trialNotTracked=[trial-1 for trial in missingTrial]
        self.realTrials=self.read_in_file("Trial #",valueType=int)
        
        self.realTrials=[trial for trial in self.realTrials if trial not in missingTrial]
        
        # only read trial in self.realTrials (skip trialNotTracked)
        self.rawPosition,self.stopFrame=self.read_position_file()
        self.rawTime=self.read_rawTime()
         
        # 1D numpy arrays (parameters) 
        # the size is the same as self.entranceTime, with np.nan in trial not tracked if needed
        self.goalTime=self.check_length(self.read_in_file("goal time 1",valueType=float))
        self.treadmillSpeed=self.check_length(self.read_in_file("computed treadmill speed",valueType=float))
        self.interTrialDuration=self.check_length(self.read_in_file("inter-trial duration",valueType=float))
        
        #other
        self.experimenter="unknown"
        self.emptyAnalysisFiles=self.find_emptyAnalysisFiles()
        self.treadmillRange=self.read_treadmill_range()
     
        #nTrial, trials, realTrials, goodTrials
        self.get_trial_info()
        
        #read camera start time from rawTime
        self.cameraStartTime=self.read_cameraStartTime()
        #compute treadmill start time from camera start and delay
        self.treadmillStartTime=self.read_treadmillStartTime()
        
        #read lick time file
        self.lickTime=self.read_licktime_file()
        
        #read reward dispensing mode Ok/kO:[A,B] | OK:B | KO:-A
        self.deliveredReward,self.deliveredRewardRatio=self.compute_delivered_reward_ratio()
        
    def read_cameraSamplingRate(self):
        return 25.0
        
    def get_trial_info(self):
        '''
        The video can freeze, so there can be fewer trials in rawPosition than in behav_param
        Ex: position 1,2,3,4 and behav_param 1,2,3,4,5,6 -> trials=[1,2,3,4]
        '''
        nTrialBehav=len(self.realTrials)
        
        nTrialPos=len(self.rawPosition)
        
        if nTrialPos<nTrialBehav:
            print("WARNING: Only %s trial in .position for %s trial in .behav_param"%(nTrialPos,nTrialBehav))
        elif nTrialPos>nTrialBehav:
            print("ERROR: %s trial in .position for only %s trial in .behav_param"%(nTrialPos,nTrialBehav))
        
        self.trialNotTracked.extend([trial-1 for trial in self.realTrials if trial-1 not in self.rawPosition])
        self.trialNotTracked=list(set(self.trialNotTracked))
        self.trials=[trial for trial in self.rawPosition if trial not in self.trialNotTracked]
        self.realTrials=[trial+1 for trial in self.trials]
        self.goodTrials=[t for t in self.trials if (self.maxTrialDuration[t]>self.entranceTime[t]>=self.goalTime[t])]
        self.nTrial=len(self.trials)+len(set(self.trialNotTracked))

    def check_length(self,array):
        #if length<nTrial, add nan where trial not tracked
        if len(array)!=len(self.entranceTime):
            for trial in self.trialNotTracked:
                array=np.insert(array,trial,np.nan)
        return array
    #-----------------------------------------------------------------------------------------------    
    def read_entranceTime(self):
        maxTrialDuration=self.read_in_file("maximum trial duration",extension=".behav_param",valueType=float)
        entranceTimeStr=list(self.read_in_file("time",extension=".entrancetimes",valueType=str))
        #fix possible issue
        trialNotTracked=self.detect_missing_value_in_entranceTime()
        if trialNotTracked:
            print("No entrance time for trials: "+str(trialNotTracked)+", they will be skipped")
            for badTrial in trialNotTracked:
                entranceTimeStr.insert(badTrial,"No value")
        #check if goalTime and entranceTime are of same size
#        if len(maxTrialDuration)!=len(entranceTimeStr):
#            print("Error: %s values for entranceTime, %s values for goalTime"%(len(entranceTimeStr),len(maxTrialDuration)))
#            return False,False,False
        #convert to float
        entranceTime=[]
        for e,duration in zip(entranceTimeStr,maxTrialDuration):
            if e=="timeout":
                entranceTime.append(duration)
            elif e=="No value":
                entranceTime.append(np.nan)
            else:
                entranceTime.append(float(e))
        entranceTime=np.asarray(entranceTime)
        return entranceTime,maxTrialDuration,trialNotTracked
    
    def detect_missing_value_in_entranceTime(self):
        '''
        Some "Trial #: X" are not followed by "Treadmill2 beam interruption time: value"
        those trials need to be skipped 
        '''
        missingTrials=[]
        with open(self.fullPath+".entrancetimes") as f:
            timeLine=False
            for line in f:
                if timeLine:
                    if "beam interruption time" not in line:
                        # there is no time, but there should be
                        missingTrials.append(trial)
                    timeLine=False
                elif line.startswith("Trial #"):
                    trial=int(float(line.split()[-1]))
                    timeLine=True  #next line, there should be a time
        #special case if file ends with "Trial #"
        if timeLine:
            missingTrials.append(trial)
        return missingTrials
    
    def compute_delivered_reward_ratio(self):
        def trig_func(x,lim):
            tmp=(-x/lim)+1
            if np.isnan(tmp):
                return 0
            if tmp<0:
                tmp=0
            elif tmp>1 and lim<0:
                tmp=1
            elif tmp>1 and lim>0:
                tmp=0
            return tmp
        
        rdm=self.read_rewardDispenseMode()
        tdr=[]   #total delivered reward
        for trial,_ in enumerate(rdm):
            if isinstance(rdm[trial],list):
                t=self.entranceTime[trial]-self.goalTime[trial]
                tdr.append(trig_func(t,-rdm[trial][0]) if t<0 else trig_func(t,rdm[trial][1]))
            elif rdm[trial]=='fixed':
                t=self.entranceTime[trial]-self.goalTime[trial]
                tdr.append(1 if t>0 else 0)
            else:
                t=self.entranceTime[trial]-self.goalTime[trial]
                tdr.append(trig_func(t,rdm[trial]))
        #delivered reward ratio
        drr=sum(tdr)/(self.nTrial-max(sum(np.isnan(tdr)),sum(np.isnan(self.entranceTime))))
        return tdr,drr
    
    #-----------------------------------------------------------------------------------------------    
    def read_position_file(self):
        #column 0=trial 1=time from session start  3=xPosition
        data = self.read_csv_pandas(self.fullPath+'.position',oneCol=False)
        if len(data)==0:
            return {},{}
         
        self.correct_position(data[3].values)  
        #read number of pixel per cm
        nbPixel=self.read_in_file("Number of pixels per cm:",valueType=float)[0]
        rawPosition={}  
        stopFrame={}
        for trial in self.realTrials:
            #trial position
            key=trial-1
            trialAsFloat=float(trial) 
            trialPos=(data[3][data[0]==trialAsFloat].values)/float(nbPixel)
            #intertrial position
            trialAsFloat+=0.5
            interPos=(data[3][data[0]==trialAsFloat].values)/float(nbPixel)  
            rawPosition[key]=np.append(trialPos,interPos)
            stopFrame[key]=len(trialPos)
        return rawPosition,stopFrame

    def correct_position(self,positions):
        '''
        Raw position is in number of pixels
        This correction has to be done before converting to cm
        It removes a frequently repeated integer value (tracking artefact) and interpolates the data to fill the gaps
        '''
        #find if an integer is present often
        integerValues=positions[np.equal(np.mod(positions,1),0)]
        if len(integerValues)>1:
            count=Counter(integerValues)#dictionary {integer: number in list}
            frequentIntegers=count.most_common()#list [(integer,number)] sorted by decreasing number
            number=frequentIntegers[0][1] #number of the most common integer
            if number>1:
                badTrackingValue=frequentIntegers[0][0] #most common integer
                #interpolate the bad values
                badIndex=np.where(positions==badTrackingValue)[0]
                goodIndex=np.where(positions!=badTrackingValue)[0]
                positions[badIndex]=np.interp(badIndex,goodIndex,positions[goodIndex])
                       
    def read_cameraStartTime(self):
        """
        Get the first time of every trial.
        """
        #column 0=trial 1=time from session start  3=xPosition
        data = self.read_csv_pandas(self.fullPath+'.position',oneCol=False)
        if len(data)==0:
            return []
      
        cameraStartTime=[]
        for trialIndex in range(self.nTrial):
            #trial time
            trialAsFloat=float(trialIndex+1) 
            trialTime=(data[1][data[0]==trialAsFloat].values)/float(self.cameraSamplingRate)
            if len(trialTime)>0:
                cameraStartTime.append(trialTime[0])
            else:
                cameraStartTime.append(np.nan)
        return np.asarray(cameraStartTime)
            
    def read_treadmillStartTime(self):
        """
        Treadmill starts after beginning of trial: "start-up delay"
        cameraToTreadmillDelay is used in preprocessing
        """
        delay=self.read_in_file("computed start-up delay",valueType=float)

        self.cameraToTreadmillDelay=np.nanmean(delay)
        if len(self.cameraStartTime) != len(delay):
            return []
        #return self.cameraStartTime+read_treadmillStartTime
        return self.cameraStartTime+delay
    
    def read_rewardDispenseMode(self):
        rewardType =self.read_in_file('reward dispensing mode')
        rewardStart=self.read_in_file('if interruption before',valueType=float)
        rewardStop =self.read_in_file('if interruption after',valueType=float)
        out=[]
        for trial,reward in enumerate(rewardType):
            if reward=='OK/KO':
                out.append([rewardStart[trial],rewardStop[trial]])
            elif reward=='OK':
                out.append(rewardStop[trial])
            elif reward=='KO':
                out.append(-rewardStart[trial])
            else:
                out.append('fixed')
        return out
            

    #-----------------------------------------------------------------------------------------------   
    def read_in_file(self,paramName,extension=".behav_param",exclude=None,valueType=str):
        '''
        Used to read from .behav_param or .entrancetimes
        Look for lines containing "paramName" and not containing "exclude"
        Split them by white spaces 
        (example: "treadmill speed:     30.00" becomes ["treadmill","speed:","30.00"])
        Keep their last element, converted to the specified valueType (in example: 30.00)
        Return a numpy array with one value per trial (np.nan where the parameter is missing)
        '''
        behav=self.fullPath+extension
        if not os.path.exists(behav):
            print("No file %s"%behav)
            self.hasBehavior=False
            return []
        result=[]
        trials=[0]
        with open(behav,"r") as f:
            for line in f:
                if "Trial #" in line:
                    trials.append(int(float(line.split()[-1]))-1)
                if paramName in line:
                    if (exclude is not None) and (exclude in line):
                        continue
                    res=line.split()[-1]
                    #integer or float: replace comma by dots
                    if valueType in [int,float]:
                        res=res.replace(",",".")                 
                    #integer: convert first to float ("0.00" -> 0.00 -> 0)
                    if valueType is int:
                        res=int(float(res))
                    #boolean "TRUE" "FALSE"
                    elif valueType is bool:
                        res=(res.lower()=="true")
                    else:
                        res=valueType(res)
                    result.append( (trials[-1],res) )
        
        out=[np.nan]*(trials[-1]+1)
        for item in result:
            out[item[0]]=item[1]
        return np.asarray(out)


In [None]:
#run only if inside this notebook (is not executed if "%run this_notebook")
if "__file__" not in dir():
    param={
        "cameraSamplingRate":25.0,
        "treadmillRange":[0,90]
    }
    #Rat041_2015_10_08_09_55 Rat124_2017_02_24_18_40 Rat124_2017_04_12_17_22 Rat106_2017_04_03_17_27
    experiment="Rat170_2018_01_05_09_55"
    animal=experiment[0:6]
    test=NewRawBehaviorData(root,animal,experiment,param,PrintWarning=False)
    plt.plot(test.entranceTime,test.deliveredReward,'.')
    print(test.deliveredRewardRatio)
    print(sum(test.entranceTime>7)/test.nTrial)

# Pavel Data

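Camera sampling rate is not in the files, but is always 60 (unless overridden by the `cameraSamplingRate` parameter)  
Experimenter can be read from the .xml file  
Treadmill speed can be in different places, or not in any text file (the default is then 35)  
Some position files are empty: the trial is not tracked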
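
The "different places" lookup for the treadmill speed can be sketched as below. This is a minimal illustration, not the class code: `first_existing_path` and `load_speed_or_default` are hypothetical helpers, the file names are only examples, and `np.loadtxt` stands in for the pandas reader used in the class.

```python
import os
import tempfile

import numpy as np


def first_existing_path(candidates):
    """Return the first path in `candidates` that exists on disk, else None."""
    for path in candidates:
        if os.path.exists(path):
            return path
    return None


def load_speed_or_default(candidates, n_trial, default_speed=35.0):
    """Read one speed per line from the first existing file, else use a constant."""
    path = first_existing_path(candidates)
    if path is None:
        # no speed file at all: same speed for every trial
        return np.full(n_trial, default_speed, dtype=np.float64)
    # ndmin=1 keeps a 1D array even when the file holds a single line
    return np.loadtxt(path, usecols=0, ndmin=1)


# toy demonstration in a temporary folder (file names are only illustrative)
with tempfile.TemporaryDirectory() as folder:
    missing = os.path.join(folder, "SpeedsCode.txt")
    present = os.path.join(folder, "TreadmillSpeed.txt")
    with open(present, "w") as f:
        f.write("30\n30\n35\n")
    speed_from_file = load_speed_or_default([missing, present], n_trial=3)
    speed_default = load_speed_or_default([missing], n_trial=3)
```

The candidate order matters: the first file found wins, and the constant is only used when none of them exists.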

In [None]:
class PavelRawBehaviorData(BaseRawBehaviorData):
    def data_type(self):
        return "pavel"
    
    def read_cameraSamplingRate(self):
        if "cameraSamplingRate" in self.parameters:
            return float(self.parameters["cameraSamplingRate"])
        return 60.0
    
    def read_experimenter(self):
        experimenters = str(self.xmlDict['parameters']['generalInfo']['experimenters'])
        return experimenters
    
    def read_treadmillSpeed(self):
        path1=os.path.join(self.sessionPath,"Analysis","SpeedsCode.txt")
        path2=os.path.join(self.sessionPath,"Analysis","TreadmillSpeed.txt")
        path3=os.path.join(self.sessionPath,self.experiment+".treadmillspeed")
        if os.path.exists(path1):
            speed=1.0*(pd.read_csv(path1,delim_whitespace=True,header=None)).values[:,0]
        elif os.path.exists(path2):
            speed=1.0*(pd.read_csv(path2,delim_whitespace=True,header=None)).values[:,0]
        elif os.path.exists(path3):
            speed=1.0*(pd.read_csv(path3,delim_whitespace=True,header=None)).values[:,0]
        else:
            speed=np.full(self.nTrial,35.0,dtype=np.float64)
        return speed
    
    def get_trialNotTracked(self):
        positionFiles = sorted(glob.glob(self.fullPath+"*.position"))
        trialNotTracked = [] 
        for index,posFile in enumerate(positionFiles):
            if os.stat(posFile).st_size==0:
                trialNotTracked.append(index)
            elif (index+1)>len(self.entranceTime): 
                    trialNotTracked.append(index)
                    print("warning : position file ",index+1," without entrance time")
            else:
                rawPosition=self.read_csv_pandas(posFile,oneCol=True)
                if len(rawPosition)==0 or sum(rawPosition)==0:
                    trialNotTracked.append(index)
            
        return trialNotTracked  
    def check_validity(self):
        for l in [self.entranceTime,self.treadmillStartTime,self.cameraStartTime]:
            assert len(l)>=(self.nTrial-1), "Wrong length of array (%s)"%len(l)
            #in PavelData sometimes there is an extra position file it is considered as non tracked trial
    def get_trial_info(self):
        for trial in self.rawPosition:
            if self.entranceTime[trial]==0:
                self.trialNotTracked=np.append(self.trialNotTracked,trial)
        self.trials=[trial for trial in self.rawPosition if trial not in self.trialNotTracked]
        self.realTrials=[trial+1 for trial in self.trials]
        self.goodTrials=[t for t in self.trials if (self.maxTrialDuration[t]>self.entranceTime[t]>=self.goalTime[t])]
        self.nTrial=len(self.trials)+len(set(self.trialNotTracked))
        for trial in range(self.nTrial):
            if trial in self.trialNotTracked :
                del self.rawPosition[trial]
                
    def read_rawPosition(self,extension="*.position"):
        # an empty position file must not switch hasBehavior to False
        hasbehavior = self.hasBehavior
        rawPosition={}
        positionFiles = sorted(glob.glob(self.fullPath+extension))
        for index,posFile in enumerate(positionFiles):
            rawPosition[index]=self.read_csv_pandas(posFile,oneCol=True)
            if len(rawPosition[index])==0:
                self.hasBehavior = hasbehavior
        return rawPosition

In [None]:
#run only if inside this notebook (is not executed if "%run this_notebook")
if "__file__" not in dir():
    param={
        "goalTime":7,
        "treadmillRange":[0,80],
        "maxTrialDuration":20,
        "interTrialDuration":None,
        "sigmaSmoothPosition":0.33,#0.18
        "sigmaSmoothSpeed":0.5,
    }
    
    experiment="Rat103_2016_12_12_18_48"
    animal=experiment[0:6]
    test=PavelRawBehaviorData("/data/PavelData",animal,experiment,param)
    for key in sorted(test.__dict__):
        print(key)
        if key.startswith("raw"):
            print(test.__dict__[key].keys())
        elif key.startswith("xml"):
            continue
        else:
            print(test.__dict__[key])

# Teresa data

Class for data like Rat034, with .avi, .pos, .paw
   - camera sampling rate is in text file
   - treadmillSpeed and goaltime are in text files
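
As a rough sketch of how such small text files can be read (assuming one number per line and no header), the hypothetical helpers below use `pandas.read_csv` directly; `io.StringIO` stands in for the real `.samplingrate` and `.treadmillspeed` files.

```python
import io

import pandas as pd


def read_first_value(file_like):
    """Return the first cell of a headerless text file as a float."""
    table = pd.read_csv(file_like, header=None)
    return float(table.iat[0, 0])


def read_one_column(file_like):
    """Return the first column of a headerless text file as a list of floats."""
    table = pd.read_csv(file_like, header=None)
    return [float(value) for value in table.iloc[:, 0]]


# in-memory stand-ins for the files (contents are made up)
sampling_rate = read_first_value(io.StringIO("25\n"))
treadmill_speed = read_one_column(io.StringIO("30\n30\n35\n"))
```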

In [None]:
class TeresaRawBehaviorData(BaseRawBehaviorData):
    def data_type(self):
        return "teresa"
    
    def read_treadmillSpeed(self):
        speed=self.read_csv_pandas(self.fullPath+".treadmillspeed",oneCol=True)
        return speed
        
    def read_cameraSamplingRate(self):
        csCsv=self.read_csv_pandas(self.fullPath+".samplingrate",oneCol=False)
        cs=csCsv.iat[0,0]
        return float(cs)
    
    def read_goalTime(self):
        return self.read_csv_pandas(self.fullPath+".goaltime", oneCol=True) 

In [None]:
#run only if inside this notebook (do not execute if "%run this_notebook")
if "__file__" not in dir():
    param={
        "treadmillRange":[0,90],
        "maxTrialDuration":20,
        "interTrialDuration":10,
        
    }
    root="/data"
    animal="Rat034"
    experiment="Rat034_2015_03_01_13_26"
    test=TeresaRawBehaviorData(root, animal,experiment,param)
           
    if test.hasBehavior:
        print(len(test.rawPosition[0]))
        print(len(test.rawTime[0]))
        

# View Point data

In [None]:
class ViewPointRawBehaviorData(NewRawBehaviorData):     
    
    def data_type(self):
        return "ViewPoint"
    
    def read_everything(self):
                
        #read entrance time and maxTrialDuration
        #check for missing values: put np.nan instead
        self.entranceTime,self.maxTrialDuration,missingTrial=self.read_entranceTime()
        if self.entranceTime is False:
            self.hasBehavior=False
            return
        
        #Read the .mat file
        mat_dict=self.read_mat(self.fullPath+'.mat')
        
        self.rawPosition,self.stopFrame,trialEmptyPosition=self.read_position_file(mat_dict)
        self.rawIntertrialPosition,_,_=self.read_intertrial_position(mat_dict)
        self.cameraSamplingRateInterTrial=np.nanmean(self.read_cameraSamplingRate_for_IT(mat_dict))

        missingTrial=np.append(missingTrial,trialEmptyPosition)
        self.trialNotTracked=[trial-1 for trial in missingTrial]
        self.realTrials=self.read_in_file("Trial #",valueType=int)
        self.realTrials=[trial for trial in self.realTrials if trial not in missingTrial]
        
        self.cameraSamplingRate=np.nanmean(self.read_cameraSamplingRate(mat_dict))
        self.rawMarkerPosition=self.read_marker_position_file(mat_dict)
        self.rawTime=self.read_rawTime()
         
        # 1D numpy arrays (parameters) 
        # the size is the same as self.entranceTime, with np.nan in trial not tracked if needed
        self.goalTime=self.check_length(self.read_in_file("goal time 1",valueType=float))
        self.treadmillSpeed=self.check_length(self.read_in_file("computed treadmill speed",valueType=float))
        self.interTrialDuration=self.check_length(self.read_in_file("inter-trial duration",valueType=float))
        
        #other
        self.experimenter="unknown"
        self.emptyAnalysisFiles=self.find_emptyAnalysisFiles()
        self.treadmillRange=self.read_treadmill_range()
     
        #nTrial, trials, realTrials, goodTrials
        self.get_trial_info()
        
        #read camera start time from the event file (.evt.cam), else from the .mat file
        self.cameraStartTime=self.read_cameraStartTime_eventFile()
        if len(self.cameraStartTime)==0:
            self.cameraStartTime=self.read_cameraStartTime(mat_dict)
        #read treadmill start time from the event file (.evt.tre), else compute it from camera start and delay
        self.treadmillStartTime=self.read_treadmillStartTime_eventFile()
        if len(self.treadmillStartTime)==0:
            self.treadmillStartTime=self.read_treadmillStartTime()
        
        delay=self.read_in_file("computed start-up delay",valueType=float)
        self.cameraToTreadmillDelay=np.nanmean(delay)

        #read lick time file
        self.lickTime=self.read_licktime_file()

        #read reward dispensing mode OK/KO:[A,B] | OK:B | KO:-A
        self.deliveredReward,self.deliveredRewardRatio=self.compute_delivered_reward_ratio()
        
    def read_mat(self,filePath):
        '''
        this function should be called instead of direct spio.loadmat
        as it cures the problem of not properly recovering python dictionaries
        from mat files. It calls the function check keys to cure all entries
        which are still mat-objects
        '''
        if not os.path.exists(filePath):
            print("No mat file for viewpoint data system:",filePath)
            return False
        data = spio.loadmat(filePath, struct_as_record=False, squeeze_me=True)
        matdict=self._check_keys(data)
        return matdict
    
    def read_cameraSamplingRate(self,mat_dict):
        if mat_dict is False:
            return super().read_cameraSamplingRate()
        try:
            fr=[mat_dict["Session"]["Fast"][trial]["FrameRate"] for trial in range(len(mat_dict["Session"]["Fast"]))]
        except:
            fr=[200]*len(mat_dict["Session"]["Fast"])
        return fr

    def read_cameraSamplingRate_for_IT(self,mat_dict):
        if mat_dict is False:
            return super().read_cameraSamplingRate()
        try:
            fr=[mat_dict["Session"]["Slow"][trial]["FrameRate"] for trial in range(len(mat_dict["Session"]["Fast"]))]
        except:
            fr=[25]*len(mat_dict["Session"]["Fast"])
        return fr

    
    def read_position_file(self,mat_dict):
        if mat_dict is False:
            return [],[],[]
        nbPixel=(mat_dict["Session"]["FileInfo"]['Scale'])*10
        nTrial=len(self.entranceTime)#len(mat_dict["Session"]["Fast"])
        rawPosition={}  
        stopFrame={}
        trialEmptyPosition=[]
        for key in range(nTrial):
            trialPos=np.empty(shape=(0,0))
            try:
                trialPos=(np.asarray(mat_dict["Session"]["Fast"][key]["Data"]["Body"]["Smooth"])[:,0])/float(nbPixel)
            except:
                trialEmptyPosition=np.append(trialEmptyPosition,key+1)
            finally:
                rawPosition[key]=trialPos
                stopFrame[key]=len(trialPos)
        if len(trialEmptyPosition) >=1:
            print("No position data for trial #:",trialEmptyPosition)
        return rawPosition,stopFrame,trialEmptyPosition

    def read_intertrial_position(self,mat_dict):
        if mat_dict is False:
            return [],[],[]
        nbPixel=(mat_dict["Session"]["FileInfo"]['Scale'])*10
        nTrial=len(self.entranceTime)#len(mat_dict["Session"]["Fast"])
        rawPosition={}  
        stopFrame={}
        trialEmptyPosition=[]
        for key in self.rawPosition:
            trialPos=np.empty(shape=(0,0))
            try:
                trialPos=(np.asarray(mat_dict["Session"]["Slow"][key]["Data"]["Body"]["Smooth"])[:,0])/float(nbPixel)
            except:
                trialEmptyPosition=np.append(trialEmptyPosition,key+1)
            finally:
                rawPosition[key]=trialPos
                stopFrame[key]=len(trialPos)
        if len(trialEmptyPosition) >=1:
            print("No position data for trial #:",trialEmptyPosition)
        return rawPosition,stopFrame,trialEmptyPosition
    
    def read_marker_position_file(self,mat_dict):
        if mat_dict is False:
            return []
        nbPixel=(mat_dict["Session"]["FileInfo"]['Scale'])*10
        nTrial=len(mat_dict["Session"]["Fast"])
        rawMarkerPosition=dict(
        Head     = {},
        ForeLimb = {},
        HindLimb = {},
        )
        markerRealName={'Head':'TracksHE', 'ForeLimb':'TracksFL', 'HindLimb':'TracksHL'}
        for key in range(nTrial):
            for marker in rawMarkerPosition.keys():
                trialPos=np.empty(shape=(0,0))
                try:
                    trialPos=(np.asarray(mat_dict["Session"]["Fast"][key]["Data"][markerRealName[marker]]["XYRepCam"])[:,0])/float(nbPixel)                
                except:
                    continue
                finally:
                    rawMarkerPosition[marker][key]=trialPos
        return rawMarkerPosition
    
    def read_cameraStartTime(self,mat_dict):
        """
        Get the first time of every trial.
        """
        if mat_dict is False:
            return np.ones(self.nTrial)*np.nan
        cameraStartTime=np.ones(self.nTrial)*np.nan
        for key in range(self.nTrial):
            try:
                trialTime=mat_dict["Session"]["Fast"][key]["TimeStampsAbs"]
                cameraStartTime[key]=trialTime[0]
            except:
                cameraStartTime[key]=np.nan
        return cameraStartTime
    
    def read_cameraStartTime_eventFile(self):
        time_ms=self.read_csv_pandas(self.fullPath+".evt.cam",oneCol=True)
        if len(time_ms)==0:
            return time_ms
        time_sec=time_ms/1000.0
        return time_sec
    
    def read_treadmillStartTime_eventFile(self):
        time_ms=self.read_csv_pandas(self.fullPath+".evt.tre",oneCol=True)
        if len(time_ms)==0:
            return time_ms        
        time_sec=time_ms/1000.0
        return time_sec

            
    def get_trial_info(self):

        nTrialGood=len(self.realTrials)
        
        nTrialPos=len(self.rawPosition)
        
        if nTrialPos != nTrialGood:
            print("%d trials in position file but %d trials acceptable!"%(nTrialPos,nTrialGood))
        
        self.trialNotTracked.extend([trial-1 for trial in self.realTrials if trial-1 not in self.rawPosition])
        self.trialNotTracked=list(set(self.trialNotTracked))
        self.trials=[trial for trial in self.rawPosition if trial not in self.trialNotTracked]
        self.realTrials=[trial+1 for trial in self.trials]
        self.goodTrials=[t for t in self.trials if (self.maxTrialDuration[t]>self.entranceTime[t]>=self.goalTime[t])]
        self.nTrial=len(self.trials)+len(set(self.trialNotTracked))
        

    # some local functions to read matlab .mat files
    def _check_keys(self,dict):
        '''
        checks if entries in dictionary are mat-objects. If yes
        todict is called to change them to nested dictionaries
        '''
        for key in dict:
            if isinstance(dict[key], spio.matlab.mio5_params.mat_struct):
                dict[key] = self._todict(dict[key])
        return dict

    def _todict(self,matobj):
        '''
        A recursive function which constructs from matobjects nested dictionaries
        '''
        dict = {}
        for strg in matobj._fieldnames:
            elem = matobj.__dict__[strg]
            if isinstance(elem, spio.matlab.mio5_params.mat_struct):
                dict[strg] = self._todict(elem)
            elif isinstance(elem,np.ndarray):
                dict[strg] = self._tolist(elem)
            else:
                dict[strg] = elem
        return dict

    def _tolist(self,ndarray):
        '''
        A recursive function which constructs lists from cellarrays 
        (which are loaded as numpy ndarrays), recursing into the elements
        if they contain matobjects.
        '''
        elem_list = []            
        for sub_elem in ndarray:
            if isinstance(sub_elem, spio.matlab.mio5_params.mat_struct):
                elem_list.append(self._todict(sub_elem))
            elif isinstance(sub_elem,np.ndarray):
                elem_list.append(self._tolist(sub_elem))
            else:
                elem_list.append(sub_elem)
        return elem_list


In [None]:
#run only if inside this notebook (do not execute if "%run this_notebook")
if "__file__" not in dir():
    root="/data"
    experiment="Rat356_2019_08_07_11_49"
    
    animal=experiment[:6]
    param={
    "goalTime":7,#needed for pavel data only
    "treadmillRange":[0,90],#pavel error conversion "treadmillRange":[0,80]
    "maxTrialDuration":15,
    "interTrialDuration":10,#None pavel
    "endTrial_frontPos":30,
    "endTrial_backPos":55, 
    "endTrial_minTimeSec":4,
    "cameraSamplingRate":200, #needed for new setup    

    "sigmaSmoothPosition":0.1,#0.33, 0.18 pavel
    "sigmaSmoothSpeed":0.3,#0.3, 0.5 pavel
     "nbJumpMax":100,#200 pavel
    "binSize":0.25,
    }
    
    data=ViewPointRawBehaviorData(root,animal,experiment,parameters=param)
    print(data.rawPosition[14])

## Preprocessing the behavioral data
Given root, animal, session, detect the format and call the appropriate rawBehavior class:
  
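  - ".mat" file (or an empty first .position file) -> it's View Point data
  - ".behav_param" and ".entrancetimes" -> it's the new setup
  - ".behav_param" without ".entrancetimes" -> reward habituation or locomotion session (hasBehavior=False)
  - None above and no event file (.evt.cam) -> it's not valid data (hasBehavior=False)
  - None above and there is a .samplingrate -> it's Teresa data
  - None above and there are .position files -> it's Pavel data
  - None above: hasBehavior=False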
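
The order of the checks done in `get_behavior_data_dict` can be summarised as a pure function of which files are present. This is only a sketch: `detect_data_type` and its boolean arguments are hypothetical, and the special case of an empty first `.position` file (treated as View Point data in the class) is left out.

```python
def detect_data_type(has_mat=False, has_behav_param=False, has_entrancetimes=False,
                     has_evt_cam=False, has_samplingrate=False, has_position=False):
    """Return a label for the raw-behavior reader, or None when there is no usable behavior."""
    if has_mat:
        return "ViewPoint"
    if has_behav_param:
        # without .entrancetimes the session is reward habituation or locomotion only
        return "behav_param" if has_entrancetimes else None
    if not has_evt_cam:
        return None
    if has_samplingrate:
        return "teresa"
    if has_position:
        return "pavel"
    return None


# a few combinations of files and the resulting label
label_viewpoint = detect_data_type(has_mat=True)
label_new_setup = detect_data_type(has_behav_param=True, has_entrancetimes=True)
label_pavel = detect_data_type(has_evt_cam=True, has_position=True)
```

Because the checks are ordered, a session that has both a `.samplingrate` and `.position` files is read as Teresa data, not Pavel data.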

Then, correct the raw position (jumps, artefacts) and bin it, so that every trial can be plotted and analysed with the same time vector.

Also compute speed, acceleration, median position...

#### Default parameters

 - **"binSize"**: size of the bin in seconds (0.25)
 - **"trialOffset"**: maximum of **maxTrialDuration**
 - **"sigmaSmoothPosition"**: standard deviation of the gaussian smoothing applied to position (0.1; 0.33 for pavel data)
 - **"sigmaSmoothSpeed"**: standard deviation of the gaussian smoothing applied to speed (0.3)
 - **"positionDiffRange"**: min and max differences between two consecutive positions ([2.,5.])
   - min is used to correct the start of the trial (detect when the treadmill actually starts, pavel data)  
   - max is used to detect and correct jumps (maximum difference allowed)
     
 - **"pawFrequencyRange"**: for pavel data, not implemented ([2.,10.])
 - **"startAnalysisParams"**: for pavel data, not implemented ([10,0.2,0.5])
 - **"cameraToTreadmillDelay"**: usual time between the camera start and the treadmill start, in seconds (2)
 - **"nbJumpMax"**: maximum number of jumps (100). If jumps>nbJumpMax, the trial goes in `trialBadlyTracked`
 
 - Parameters to detect the end of trial, defined as the first position minimum that satisfies all these conditions:
   - **"endTrial_backPos"** (55): the minimum is after the animal went once to the back (after the first time position>backPos) 
   - **"endTrial_frontPos"** (30): the minimum is at the front of the treadmill (position[end]\<frontPos)
   - **"endTrial_minTimeSec"** (4): the minimum is after minTimeSec seconds (time[end]>minTimeSec)
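
The three end-of-trial conditions can be illustrated with a small sketch. It is not the method used in the class: `find_end_of_trial` is a hypothetical helper, a "minimum" is taken here as a sample lower than its right neighbour and not higher than its left one, and the position trace is made up.

```python
import numpy as np


def find_end_of_trial(position, time, back_pos=55, front_pos=30, min_time_sec=4):
    """Return (index, time) of the first qualifying local minimum, or (None, None)."""
    position = np.asarray(position, dtype=float)
    time = np.asarray(time, dtype=float)
    behind = np.flatnonzero(position > back_pos)
    if behind.size == 0:
        # the animal never went to the back: no end of trial
        return None, None
    first_back = behind[0]
    for i in range(first_back + 1, len(position) - 1):
        is_minimum = position[i] <= position[i - 1] and position[i] < position[i + 1]
        if is_minimum and position[i] < front_pos and time[i] > min_time_sec:
            return i, time[i]
    return None, None


# made-up trial: the animal goes to the back, then returns to the front
time = np.arange(0, 8, 0.5)
position = np.array([20, 25, 40, 60, 70, 65, 50, 45, 48, 40, 28, 22, 25, 30, 35, 40], dtype=float)
end_index, end_time = find_end_of_trial(position, time)
```

In this trace the minimum at index 7 is ignored because it is not in front of `front_pos`; the first minimum meeting all three conditions is at index 11.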
  
  
#### Positions

  - **position**: corrected, smoothed position (match **timeTreadmill**, **timeCamera**)  
      dictionary {trial: [list of position]}  
      trials are skipped if badly tracked
      
  - **positionBin**: binned **position** (same size for every trial), match **timeBin**
  - **speedBin**: binned speed on position (match **timeBin**)
  - **speedSmoothBin**: binned smoothed speed (match **timeBin**)
  - **accelerationOnSpeedBin**: binned acceleration on speed (match **timeBin**)
  - **accelerationOnSpeedSmoothBin**: binned acceleration on smoothed speed (match **timeBin**)
  - **jumpFrame**: dictionary {trial: [list of index]}, frames (index) where the position jumped
  - **transientFrame**: same as **jumpFrame**, for transients
  - **medianPosition**: median of the binned position (match **timeBin**)
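
A possible way to find jump frames from the maximum allowed difference is sketched below. `find_jump_frames` is a hypothetical helper; the thresholds reuse the default values of `positionDiffRange` and `nbJumpMax`, and the positions are made up.

```python
import numpy as np


def find_jump_frames(position, max_diff=5.0):
    """Return the indices i where |position[i] - position[i-1]| exceeds max_diff."""
    position = np.asarray(position, dtype=float)
    step = np.abs(np.diff(position))
    # diff[k] compares frames k and k+1, so the jump is attributed to frame k+1
    return np.flatnonzero(step > max_diff) + 1


position = [10.0, 11.0, 12.5, 40.0, 13.0, 14.0]
jump_frames = find_jump_frames(position)

# a trial with too many jumps would be considered badly tracked
nb_jump_max = 100
is_badly_tracked = len(jump_frames) > nb_jump_max
```

A single outlier produces two jump frames, one when the position leaves the trajectory and one when it comes back.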

#### Times
  - **timeTreadmill** (dict): match rawPosition and position, aligned on treadmill start
  - **timeCamera** (dict):  match rawPosition and position, aligned on camera start
  - **timeBin** (list): match positionBin
  - **timeSpeed** (list): match speed and speedSmooth
  - **timeAcceleration** (list): match accelerationOnSpeed and accelerationOnSpeedSmooth
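
The relation between the time vectors can be sketched as follows. The alignment formula is the one used in `preprocess_behavior` (`rawTime + cameraStartTime - treadmillStartTime`); the binning step is only an approximation that uses `np.interp` in place of the notebook's own interpolation, and `bin_position` and all numbers are hypothetical.

```python
import numpy as np


def bin_position(time, position, bin_size=0.25, t_start=0.0, t_stop=2.0):
    """Interpolate one trial onto a regular time axis shared by all trials."""
    time_bin = np.arange(t_start, t_stop, bin_size)
    # outside the recorded range the position is unknown, so fill with NaN
    position_bin = np.interp(time_bin, time, position, left=np.nan, right=np.nan)
    return time_bin, position_bin


# made-up trial: the camera starts at 10 s, the treadmill at 11 s
raw_time = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 2.5])
camera_start, treadmill_start = 10.0, 11.0
time_treadmill = raw_time + camera_start - treadmill_start

position = np.array([10.0, 20.0, 30.0, 40.0, 50.0, 60.0])
time_bin, position_bin = bin_position(time_treadmill, position,
                                      bin_size=0.25, t_start=-1.0, t_stop=2.0)
```

Negative values of `time_treadmill` correspond to frames recorded before the treadmill started.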
  
#### Detect end of trial
  Positions are aligned on the detected end of trial instead of the treadmill start.  
  Positions are then interpolated and binned to match the same time axis.

  - **timeEndTrial**: time detected for each trial end, `None` if not detected (list of length `nTrial`)  
      detection is done on `position` if it exists (trial not skipped), otherwise on `rawPosition` 
  - **indexEndTrial**: same, but as an index
      
  - **timeAlignEnd**
  - **positionAlignEnd**: last 5 seconds of positions, binned
  - **medianPositionAlignEnd**
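
Aligning on the end of trial amounts to shifting each time vector so that the detected end is at 0, then resampling on a common axis covering the last seconds. A minimal sketch, assuming a 5 second window and using `np.interp` (the helper `align_on_end` and the data are hypothetical):

```python
import numpy as np


def align_on_end(time, position, time_end, window_sec=5.0, bin_size=0.25):
    """Return (time axis ending at 0, position interpolated on it) for one trial."""
    if time_end is None:
        # end of trial not detected: nothing to align
        return None, None
    n_bin = int(round(window_sec / bin_size)) + 1
    time_align = np.linspace(-window_sec, 0.0, n_bin)
    shifted_time = np.asarray(time, dtype=float) - time_end
    position_align = np.interp(time_align, shifted_time, position,
                               left=np.nan, right=np.nan)
    return time_align, position_align


# made-up trial with a detected end at 6 s
time = np.arange(0.0, 8.5, 0.5)
position = 10.0 * time
time_align, position_align = align_on_end(time, position, time_end=6.0, bin_size=1.0)
```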
 
#### Trials
  - **trialBadlyTracked**: skipped trials (trials where the position data is not good enough) 
  - **trials**: list of all the trials in dictionary `position` (trials not skipped)
  - **goodTrials**: trials in `trials` whose entrance time is at least the goal time and below the maximum trial duration
  - **realTrials**: real index (`trials=0->nTrial-1` and `realTrials=1->nTrial`)
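
The relation between these lists can be shown on a toy session. The `goodTrials` condition mirrors the one in `get_trial_info` (`maxTrialDuration > entranceTime >= goalTime`); the helper `select_trials` and the values are hypothetical.

```python
import numpy as np


def select_trials(n_trial, not_tracked, entrance_time, goal_time, max_trial_duration):
    """Build the trials, realTrials and goodTrials lists for one session."""
    trials = [t for t in range(n_trial) if t not in not_tracked]
    real_trials = [t + 1 for t in trials]
    good_trials = [t for t in trials
                   if max_trial_duration[t] > entrance_time[t] >= goal_time[t]]
    return trials, real_trials, good_trials


# toy session with 4 trials, trial index 2 not tracked
entrance_time = np.array([3.0, 7.5, 9.0, 20.0])
goal_time = np.full(4, 7.0)
max_trial_duration = np.full(4, 15.0)
trials, real_trials, good_trials = select_trials(
    4, [2], entrance_time, goal_time, max_trial_duration)
```

Trial 0 entered too early and trial 3 reached the maximum duration, so only trial 1 is a good trial.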
  
#### Other
  - **startFrame**: index where the treadmill starts
  - **stopFrame**: index where the treadmill stops in new setup data
  

In [None]:
#NB: startRunningFrame and startAnalysisFrame not implemented
class PreprocessTreadmillOn:
    def __init__(self,root,rat,experiment,param={},saveAsPickle=True,redo=False):
        self.hasBehavior=False
        #path
        
        if not self.compute_paths(root,rat,experiment):
            return        
        
        #Load OR create raw behavior with param
        behaviorDict=self.get_behavior_data_dict(root,rat,experiment,param,saveAsPickle,redo)        
        
        #update the default parameters with the ones provided
        defaultParam={
            "binSize":0.25,
            "sigmaSmoothPosition":0.1,
            "sigmaSmoothSpeed":0.3,
            "positionDiffRange": [2.,5.],
            "pawFrequencyRange":[2.,10.],
            "startAnalysisParams":[10,0.2,0.5],
            "cameraToTreadmillDelay":2.,#sec
            "nbJumpMax" : 100.,
            
            "endTrial_backPos":55, 
            "endTrial_frontPos":30,
            "endTrial_minTimeSec":4,
        }
        #pavel data use a larger smoothing on position
        if behaviorDict.get("dataType")=='pavel':
            defaultParam["sigmaSmoothPosition"]=0.33
        #update default parameter by the one provided by the user   
        defaultParam.update(param)      
        #update the parameters with the ones used to read raw behavior data
        defaultParam.update(behaviorDict)        
        #update classe attributes
        self.__dict__.update(defaultParam)
        
        if not self.hasBehavior:
            if saveAsPickle:
                name = "preprocesseddata_binsize"+str(int(self.binSize*1000))+"ms_.p"
                self.save_as_pickle(name=name)
            return
        
        self.trialOffset=max(self.maxTrialDuration)
        
        #preprocess 
        self.preprocess_behavior()  
        #save according to bin size
        self.align_Behavior_OnBinCenters()  
        if saveAsPickle:
            name = "preprocesseddata_binsize"+str(int(self.binSize*1000))+"ms_.p"
            self.save_as_pickle(name=name)
    #------------------------------------------------------------------------------------------ 
    def get_dict(self):
        return self.__dict__

    def save_as_pickle(self,folder="Analysis",name="preprocessTreadmillOn.p"):
        folderPath=os.path.join(self.sessionPath,folder)
        if not os.path.exists(folderPath):
            os.mkdir(folderPath)
        filePath=os.path.join(folderPath,name)
        with open(filePath,"wb") as f:
            pickle.dump(self.__dict__, f)
     
    def get_behavior_data_dict(self,rootFolder,rat,experiment,param,saveAsPickle,redo,PrintWarning=False):
        rawBehaviorPath=os.path.join(self.analysisPath,"rawbehaviordata.p")
        self.hasBehavior=True
        if os.path.exists(rawBehaviorPath) and (not redo):
            try:
                with open(rawBehaviorPath,"rb") as f:
                    behaviorDict=pickle.load(f)
                if PrintWarning:
                    print("behavior data loaded from %s"%rawBehaviorPath)
                return behaviorDict
            except:
                pass
        
        sampleFile=glob.glob(self.fullPath+".samplingrate")
        behavFile=glob.glob(self.fullPath+".behav_param")
        position=glob.glob(self.fullPath+"*.position")
        eventFile=glob.glob(self.fullPath+".evt.cam")
        matFile=glob.glob(self.fullPath+".mat")
        entrancetimeFiles=glob.glob(self.fullPath+".entrancetimes")
        if matFile or (position and os.stat(position[0]).st_size<=1):
            behaviorDict=ViewPointRawBehaviorData(rootFolder,rat,experiment,
                                           param,saveAsPickle).get_dict()
            if PrintWarning:
                print("Behavior data loaded from text files: view point setup (.mat)")
        elif behavFile and entrancetimeFiles:
            behaviorDict=NewRawBehaviorData(rootFolder,rat,experiment,
                                           param,saveAsPickle).get_dict()
            if PrintWarning:
                print("Behavior data loaded from text files: new setup (.behav_param)")
            
        elif behavFile and not entrancetimeFiles:
            if PrintWarning:
                print("Stop Loading: reward habituation or locomotion session, no .entrancetimes file")
            behaviorDict={}
            self.hasBehavior=False
            
        elif not eventFile:
            if PrintWarning:
                print("Stop Loading: No .evt.cam or .behav_param file")
            self.hasBehavior=False
            behaviorDict={}
        elif sampleFile:
            behaviorDict=TeresaRawBehaviorData(rootFolder,rat,experiment,
                                                            param,saveAsPickle).get_dict()
            if PrintWarning:
                print("Behavior data loaded from text files: Teresa data (.samplingrate)")
        elif position:
            behaviorDict=PavelRawBehaviorData(rootFolder,rat,experiment,
                                                          param,saveAsPickle).get_dict()
            if PrintWarning:
                print("Behavior data loaded from text files: Pavel data (.position with no .samplingrate)")
        else:
            if PrintWarning:
                print("Stop Loading: no .position file(s).")
            behaviorDict={}
            self.hasBehavior=False
        return behaviorDict    
    #------------------------------------------------------------------------------------------ 
    def compute_paths(self,root,animal,experiment,PrintWarning=False):
        #clean name of folders (remove unnecessary slash or backslash)
        self.root=os.sep+root.strip(os.sep)
        self.animal=animal.strip(os.sep)
        self.experiment=experiment.strip(os.sep)
        #paths
        self.sessionPath=os.path.join(self.root,self.animal,"Experiments",self.experiment)
        self.fullPath=os.path.join(self.root,self.animal,"Experiments",self.experiment,self.experiment)  
        self.analysisPath=os.path.join(self.root,self.animal,"Experiments",self.experiment,"Analysis")       
        #Check if the path is correct
        if self.animal not in self.experiment:
            if PrintWarning:
                print("WARNING: session name (%s) does not contain animal name (%s)"%(self.experiment,self.animal))
        if not os.path.exists(self.sessionPath):
            if PrintWarning:
                print("STOP Loading - Path does not exist: %s"%self.sessionPath)
            return False      
        #create analysis folder if needed
        if not os.path.exists(self.analysisPath):
            os.mkdir(self.analysisPath)
        return True                
    #------------------------------------------------------------------------------------------ 
    def preprocess_behavior(self):
        '''
        Main function called in init
        '''
        #new:start frame (index of position where the treadmill start)
        self.startFrame=self.get_start_frame()
               
        #time matching position, aligned to treadmill or camera start (no bin)
        self.timeTreadmill={trial: self.rawTime[trial]+self.cameraStartTime[trial]-self.treadmillStartTime[trial] for trial in self.rawTime}
        #Correct the positions (no smooth, no bin)
        
        #Creates self.position={}, self.jumpFrame,self.transientFrame
        self.preprocess_positions() 
        if self.dataType=="pavel" or self.dataType=="teresa":
            self.stopFrame = self.get_stop_frame()          
        #smooth and bin position
        self.timeBin,self.positionBin=self.get_bin_time_and_position()
            
        #median position from the binned positions
        self.medianPosition=self.get_median_positionDict(self.positionBin)
        
        #update trial list
        self.trials=[trial for trial in self.trials if trial not in self.trialNotTracked]
        self.realTrials=[trial+1 for trial in self.trials]
        self.goodTrials=[trial for trial in self.goodTrials if trial not in self.trialNotTracked]

        #detect end of trials
        self.indexEndTrial=[]
        self.timeEndTrial=[]
        
        for trial in range(self.nTrial):
            if trial in self.trialNotTracked:
                index,time=None,None
            else:
                try:
                    position=self.position[trial]
                    timePos=self.timeTreadmill[trial]
                except KeyError:
                    position=self.rawPosition[trial]
                    timePos=self.rawTime[trial]
                #if self.dataType=="behav_param":
                position=position[:self.stopFrame[trial]]
                timePos=timePos[:self.stopFrame[trial]]
                index, time = self.get_index_time_end_of_trial(position, timePos, trial)
            self.indexEndTrial.append(index)
            self.timeEndTrial.append(time)
            
        #time and position aligned to end + median    
        self.timeAlignEnd,self.positionAlignEnd=self.get_time_and_position_align_end(self.position,
                                                                                     self.timeTreadmill,
                                                                                     self.timeEndTrial)
        self.positionAlignEnd,self.timeAlignEnd=self.remove_allNan_slice(self.positionAlignEnd,self.timeAlignEnd)
        self.medianPositionAlignEnd=self.get_median_positionDict(self.positionAlignEnd)
        
        #speed, acceleration
        self.compute_bin_speed_and_acceleration()

    #------------------------------------------------------------------------------------------ 
    def preprocess_positions(self):   
        '''
        Correct artefacts on positions
        positionTreadmillLightOnTreadmillON: the camera is on before treadmill on
        positionTreadmillAfterTreadmillON: the camera is on in the trial and the intertrial for the new setup
        '''    
        self.position,positionTreadmillLightOnTreadmillON,positionTreadmillAfterTreadmillON={},{},{}
        self.jumpFrame,self.transientFrame={},{}
        cs=self.cameraSamplingRate
        for trial in self.trials:        
            pos=self.rawPosition[trial]
            startFrame=int(self.startFrame[trial])
            #correct the position vector
            #----------------------------------------------------
            if self.dataType=="behav_param" or self.dataType=="ViewPoint" :
                pos=self.process_one_trial_position_continuousFile(pos,trial)
            else:
                pos=self.process_one_trial_position(pos,trial)   
            if pos is False:
                self.trialNotTracked.append(trial)
                continue  
            #correct positions out of the treadmill range
            posCorrectRange = self.correct_outofrange(pos,trial)
            if len(posCorrectRange)==0:
                self.trialNotTracked.append(trial)
                continue
                
            #following fixes should only be applied during treadmill on
            #keep one before start frame, to avoid jump (except if startFrame is 0)
            if startFrame>=1:
                startFrame=startFrame-1
            posCutted=posCorrectRange[startFrame:] 
            #find jump frames and transient frames
            jumpFrame,transientFrame=self.detect_jump_transient(posCutted,self.positionDiffRange[1])            
            #remove trial if badly tracked (too many jumps)
            nJump=len(jumpFrame)
            if nJump>self.nbJumpMax:
                print("trial %s bad video quality, number of jumps: %s"%(trial,nJump))
                self.trialNotTracked.append(trial)
                continue            
            #correct lost tracking
            borderCorrection=[posCutted[0],posCutted[-1]]
            posFixed=self.correct_trackingbreak(posCutted,borderCorrection)
            #----------------
            if len(posFixed)==0:
                print("trial %s bad tracking (many plateau)"%trial)
                self.trialNotTracked.append(trial)
                continue
            #correct jumps
            border=[posFixed[0],posFixed[-1]]
            posFixed2=self.correct_jumps(posFixed,self.positionDiffRange[1],border)    
            
            #smooth
            posSmooth=self.correct_outofrange(smooth(posFixed2,self.sigmaSmoothPosition*cs,mode="nearest"),trial)
            self.jumpFrame[trial]=jumpFrame
            self.transientFrame[trial]=transientFrame
            positionTreadmillLightOnTreadmillON[trial] = posCorrectRange[:startFrame]
            positionTreadmillAfterTreadmillON[trial] = posSmooth
        #rescale positions for the long treadmill (pavel only)
        if self.dataType=="pavel":
            if self.parameters["treadmillRange"][1]>=80:#long_treadmill
                print("scaling ..")
                positionTreadmillAfterTreadmillON = (self.correct_pixelConversion_Pavel_longTreadmill(positionTreadmillAfterTreadmillON)).copy()
            else:
                print("You may have a scaling problem, ",
                      "conversion from pixel to cm was not always correct")
        #re-attach the beginning (positions between camera start and treadmill start)
        for trial in self.trials:  
            if trial in  self.trialNotTracked: continue
            posAll=np.append(positionTreadmillLightOnTreadmillON[trial],positionTreadmillAfterTreadmillON[trial])
            self.position[trial]=posAll

    
    def process_one_trial_position(self,position,trial):
        '''process_one_trial_position
        Correct the position vector for one trial
        Return False if the trial was badly tracked / bad quality
        Otherwise returns corrected position
        '''
        #correct ending for good trials
        if self.entranceTime[trial]>=self.goalTime[trial]:
            posCorrectEnd,lastGoodValue = self.correct_ending(position,trial) 
        else:
            posCorrectEnd=position  
        #correct starting artefacts (pavel only)
        if self.dataType=="pavel":
            posCorrectStart,firstGoodValue=self.correct_starting(posCorrectEnd,trial,mindiff=self.positionDiffRange[0])
        else:
            posCorrectStart=posCorrectEnd 
        return posCorrectStart
    
    def process_one_trial_position_continuousFile(self,position,trial):
        '''
        Correct the position vector for one trial
        For the new setup: one position file with continuous position and time
        Skip trial if nb unique position <5
        '''   
        #Number of unique positions 
        nbPosition=len(np.unique(np.around(position)))
        if nbPosition<5 and self.treadmillSpeed[trial] != 0:
            print("trial %s: only %s unique positions, skip"%(trial,nbPosition))
            return False             
        return position
    #------------------------------------------------------------------------------------------ 
    def get_start_frame(self):
        '''
        index of the position and time vectors where the treadmill starts
        some values in treadmillStartTime or cameraStartTime can be None (in new setup)
        '''
        startFrame,indexes=[],[]
        trial=0
        for tr,cm in zip(self.treadmillStartTime,self.cameraStartTime):
            if isNone(tr) or isNone(cm):
                start=0
            else:
                start=max(np.floor((tr-cm)*self.cameraSamplingRate),0)
                #check the delay between camera and treadmill, print warning if needed
                #(only when both start times are known, tr-cm fails on None)
                if (tr-cm)>(self.cameraToTreadmillDelay+0.5):
                    indexes.append(trial)
            startFrame.append(int(start))
            trial+=1
        if indexes:
            print("WARNING: some trials with difference between Camera and Treadmill start times> %s"
                  %(self.cameraToTreadmillDelay+0.5))
            print("Trial indexes: %s"%indexes)
        return startFrame
    #------------------------------------------------------------------------------------------ 
    def correct_jumps(self,position,maxdiff=5,border_correction=[0,0]): 
        ''' 
        A function to correct for big jumps
            position: 1d array of position values in cm
            border_correction : default values for first and last positions (nans in borders alter interpolation)
            maxdiff : maximum difference between two successive positions (Treadmill On), set by default to 5.
            Returns a 1d array of position values where jumps are replaced by linear interpolation
        '''
        fixed = position.copy()
        differenceBetweenpositions = np.abs(np.diff(fixed))
        fixed[np.where(np.abs(differenceBetweenpositions)>maxdiff)[0]+1] = np.nan
        if np.isnan(fixed[0]): 
            fixed[0] = border_correction[0]
        if np.isnan(fixed[-1]): 
            fixed[-1] = border_correction[1]
        goodIndex = np.where(np.isnan(fixed)==False)[0]
        correctJump = interp1d(goodIndex,fixed[goodIndex])(np.arange(len(position)))
        return correctJump
            
    def correct_trackingbreak(self,position,border_correction=[0,0]):
        ''' 
        A function to correct for lost tracking : 
            constant position values (plateau) are explained by lost tracking or treadmill not moving
            position: 1d array of position values in cm
            border_correction : default values for first and last positions (nans in borders alter interpolation)
            Returns a 1d array of position values where plateaus are replaced by linear interpolation,
            or an empty list if fewer than 10 valid points remain
        '''
        fixed = position.copy()
        differenceBetweenpositions = np.diff(fixed)
        fixed[np.where(differenceBetweenpositions==0)[0]+1] = np.nan
        if np.isnan(fixed[0]): 
            fixed[0] = border_correction[0]
        if np.isnan(fixed[-1]): 
            fixed[-1] = border_correction[1]
        goodIndex = np.where(np.isnan(fixed)==False)[0]
        
        if len(goodIndex)<10:
            #trial will be added to trialNotTracked
            return []

        recoveredPositions = interp1d(goodIndex,fixed[goodIndex])(np.arange(len(position)))
        return recoveredPositions
            
    def correct_starting(self,rawPosition,trial,mindiff=2):
        ''' 
        A function to correct the starting artefacts in the position values
        rawPosition: 1d array of position values in cm for one trial
        trial: trial index, used to look up startFrame and treadmillSpeed (cm/s, one float)
        mindiff: minimum difference between two successive positions (Treadmill On), set by default to 2.
        firstGoodValue: index of the first position in the given treadmill range and having a difference
                        with the previous position >=mindiff
        Returns (position, firstGoodValue)
        position: 1d array where the position values before firstGoodValue were replaced by theoretical 
            values corresponding to the rat immobile and the treadmill moving at treadmillSpeed,
            the resulting values <range[0] were set to range[0] (range[0]: minimum position on the treadmill)  
        '''
        #discard before treadmill start; copy so that rawPosition is not modified in place
        position=rawPosition[self.startFrame[trial]:].copy()
        conditionDiff=np.abs(np.diff(position))>=mindiff
        conditionMax=position[1:]<=self.treadmillRange[1]
        if self.dataType=="pavel":
            conditionMax=position[1:]<=90#although the real treadmill size is 80
        conditionMin=position[1:]>=self.treadmillRange[0]
        try:
            firstGoodValue=((conditionDiff & conditionMax & conditionMin).nonzero()[0][0])+1 
        except IndexError:
            #if no firstGoodValue found
            firstGoodValue=0            
        maxFrame=2.0*self.cameraSamplingRate #firstGoodValue should be between 0 and 2 sec
        if firstGoodValue>maxFrame:
            firstGoodValue=0            
        ts=float(self.treadmillSpeed[trial])
        cs=float(self.cameraSamplingRate)
        position[0:firstGoodValue] = position[firstGoodValue]-(ts*firstGoodValue/cs)
        position[np.where(position<self.treadmillRange[0])[0]] = self.treadmillRange[0]
        
        pos=np.append(rawPosition[:self.startFrame[trial]],position) #
        return pos, firstGoodValue
    
    def correct_ending(self,position,trial):
        '''
        If the treadmill does not stop after entrance time: create a plateau at the end
        '''
        position_ = position.copy()
        time = self.timeTreadmill[trial]#np.arange(len(position))/float(self.cameraSamplingRate)
        try:
            lastGoodValue =(((np.abs(np.diff(position))==0)&(time[1:]>self.entranceTime[trial])).nonzero()[0][0])+1 
            position_[lastGoodValue:] = position_[lastGoodValue]
        except IndexError:
            lastGoodValue = len(position)-1
        return position_, lastGoodValue
    
    def correct_outofrange(self,position,trial):
        ''' 
        A function to correct the out of range values
            position: 1d array of position values in cm for one trial
            inRangePositions: 1d array of position values corrected to be in the good treadmill range 
        '''
        correctedPosition = position.copy()
        if self.dataType=="pavel":
            correctedPosition[position>90] = np.nan #although the real treadmill size is 80
        else:
            correctedPosition[position>self.treadmillRange[1]] = np.nan
        
        correctedPosition[position<self.treadmillRange[0]] = np.nan
        goodIndex = np.where(np.isnan(correctedPosition)==False)[0]
        #=======Mostafa: to avoid error when no point is within normal range===============
        if len(goodIndex)==0:
            print("trial {0} skipped, position out of range!".format(str(trial)))
            return np.array([])
        #=================================================================================
        if np.isnan(correctedPosition[0]):        
            correctedPosition[0:goodIndex[0]] = correctedPosition[goodIndex[0]]
        if np.isnan(correctedPosition[-1]): 
            correctedPosition[goodIndex[-1]:] =  correctedPosition[goodIndex[-1]]
        goodIndex = np.where(np.isnan(correctedPosition)==False)[0]
        inRangePositions = interp1d(goodIndex,correctedPosition[goodIndex])(np.arange(len(position)))
        #Mostafa: to fix the start artifact
        ind=np.where(abs(np.diff(inRangePositions[:50]))>10)[0]
        if len(ind)<5:
            for i in ind:
                inRangePositions[i]=inRangePositions[max([i-1,0])]
        return inRangePositions
    
    def detect_jump_transient(self,position,maxdiff=3,ratio=0.8):
        ''' 
        A function to detect the frames corresponding to jumps or transients 
            position: 1d array of position values in cm
            maxdiff : maximum difference between two successive positions (Treadmill On), set by default to 3.
            ratio : ratio between successive jumps in a transient
            jumps: list of frames corresponding to jumps (diff>maxdiff)
            transients: list of frames corresponding to transients (~successive opposite jumps)
        '''
        differenceBetweenpositions = np.diff(position)
        jumps = (np.abs(differenceBetweenpositions[:-1])>maxdiff).nonzero()[0]
        jumps = sorted(jumps) #simple ascend sorting
        transients = []
        for zz in jumps: # flag the detected jumps that are transients
            if differenceBetweenpositions[zz+1]/differenceBetweenpositions[zz]<-ratio:
                transients.append(zz)
        return jumps,transients      
    
    #------------------------------------------------------------------------------------------ 
    def get_bin_time_and_position(self,removeNanSlice=True,binSize=None,trialOffset=None):
        '''
        Bin the smoothed corrected position
        so that the positions of every trial match one unique time vector
        binSize: bin size in seconds (default is self.binSize)
        trialOffset: trial duration to consider, in seconds (default is self.trialOffset, i.e. max(maxTrialDuration))
        OUTPUT: np.array for time, dictionary {trial:np.array} for positions
        '''
        if binSize is None:
            binSize=self.binSize
        if trialOffset is None:
            trialOffset=self.trialOffset
        #time, same for every trial
        timeBin = np.arange(0,trialOffset+binSize-trialOffset%binSize,binSize)
        #bins for raw position
        rawBinSize = 1.0/self.cameraSamplingRate
        rawTimeBinEdges  = np.arange(0,trialOffset+rawBinSize-trialOffset%rawBinSize,rawBinSize)
        nbFramesMax = len(rawTimeBinEdges)
        cs=self.cameraSamplingRate 
        positionBin={}
        for trial in self.position:
            #cut at treadmill start and stop
            pos=self.position[trial][int(self.startFrame[trial]):self.stopFrame[trial]-1]
               
            #number of frames to keep
            nbFrames = np.min([len(pos),nbFramesMax])   
            posSmooth=np.full(nbFramesMax,np.nan)
            #smooth position w: position is already smoothed
            #sigma=self.sigmaSmoothPosition
            posSmooth[:nbFrames]=pos[:nbFrames]#self.correct_outofrange(smooth(pos,sigma*cs,mode="nearest"))[:nbFrames]
            #bin
            positionBin[trial]=interp1d(rawTimeBinEdges,posSmooth,bounds_error=False,fill_value=np.nan)(timeBin)
        if removeNanSlice:
            positionBin,timeBin=self.remove_allNan_slice(positionBin,timeBin)    
        return timeBin,positionBin
    
    def compute_bin_speed_and_acceleration(self):
        '''
        From a smoothed corrected position (not the binned one), compute the speed and acceleration
        Also compute a smoothed speed, and an acceleration on smoothed speed
        '''
        time = np.arange(0,self.trialOffset+self.binSize-self.trialOffset%self.binSize,self.binSize)
        self.timeSpeed = time[:-1]+self.binSize/2.
        self.timeAcceleration = time[1:-1] 
        
        cs=self.cameraSamplingRate 
        rawBinSize = 1.0/cs
        rawTimeBinEdges  = np.arange(0,self.trialOffset+rawBinSize-self.trialOffset%rawBinSize,rawBinSize)
        rawTimeBinCenters = rawTimeBinEdges[:-1]+rawBinSize/2.
        nbFramesMax = len(rawTimeBinEdges)
        
        self.speedBin={}
        self.speedSmoothBin={}
        self.accelerationOnSpeedBin={}
        self.accelerationOnSpeedSmoothBin={}
        for trial in self.position:
            #cut at treadmill start and stop
            pos=self.position[trial][int(self.startFrame[trial]):self.stopFrame[trial]-1]
            #number of frames to keep
            nbFrames = np.min([len(pos),nbFramesMax])   
            #empty arrays of the right size
            ##posSmooth=np.full(nbFramesMax,np.nan)
            speed=np.full(nbFramesMax-1,np.nan)
            speedSmooth=np.full(nbFramesMax-1,np.nan)
            acc=np.full(nbFramesMax-2,np.nan)
            accS=np.full(nbFramesMax-2,np.nan)
            #position
            ##sigma=self.sigmaSmoothPosition
            ##posSmooth[:nbFrames]=self.correct_outofrange(smooth(pos,sigma*cs,mode="nearest"))[:nbFrames]
            #speed
            #speed[:nbFrames-1]=self.treadmillSpeed[trial]-cs*np.diff(posSmooth[:nbFrames])
            speed[:nbFrames-1]=self.treadmillSpeed[trial]-cs*np.diff(pos[:nbFrames])
            #speed smooth      
            sigma=self.sigmaSmoothSpeed
            speedSmooth[:nbFrames-1]=smooth(speed[:nbFrames-1],sigma*cs,mode="constant",cval=0)
            #acceleration on speed
            acc[:nbFrames-2]=cs*np.diff(speed[:nbFrames-1])
            #acceleration on speed smooth
            accS[:nbFrames-2]=cs*np.diff(speedSmooth[:nbFrames-1])
                 
            #Bin everything
            self.speedBin[trial]=interp1d(rawTimeBinCenters,speed,bounds_error=False,
                                             fill_value=np.nan)(self.timeSpeed)
            self.speedSmoothBin[trial]=interp1d(rawTimeBinCenters,speedSmooth,bounds_error=False,
                                                   fill_value=np.nan)(self.timeSpeed)
            self.accelerationOnSpeedBin[trial]=interp1d(rawTimeBinEdges[1:-1],acc,bounds_error=False,
                                                           fill_value=np.nan)(self.timeAcceleration)
            self.accelerationOnSpeedSmoothBin[trial]=interp1d(rawTimeBinEdges[1:-1],accS,bounds_error=False,
                                                                 fill_value=np.nan)(self.timeAcceleration)      
        self.speedBin, self.timeSpeed = self.remove_allNan_slice(self.speedBin, self.timeSpeed)
        self.speedSmoothBin,t=self.remove_allNan_slice(self.speedSmoothBin)
        self.accelerationOnSpeedBin,self.timeAcceleration=self.remove_allNan_slice(self.accelerationOnSpeedBin,
                                                                                self.timeAcceleration)
        self.accelerationOnSpeedSmoothBin,t=self.remove_allNan_slice(self.accelerationOnSpeedSmoothBin)
    #------------------------------------------------------------------------------------------ 
    def get_index_time_end_of_trial(self,position,time,trial):
        '''
        Detect the end of the trajectories
        Use corrected position or raw position
        '''
        backPos=self.endTrial_backPos 
        frontPos=self.endTrial_frontPos
        minTimeSec=self.endTrial_minTimeSec          
        if np.sum(position>=backPos)==0:
             #skip trial if the animal never goes above backPos centimeters (=back of the treadmill)
            return None,None
        elif len(position)!=len(time):
            #error
            print("WARNING: position and time of different length, can't detect end")
            return None,None
        else:
            #find first occurrence of "being above backPos" (back of the treadmill)
            firstTimeAboveX=time[np.where(position>=backPos)[0][0]]
        #for all position, find if it's a minima or not (True/False)
        # <= or won't catch minima if there is a plateau at the end
        isMinima=np.r_[False,position[1:] < position[:-1]] & np.r_[position[:-1]<=position[1:],True]
        #condition1: position<=frontPos (position is near front of the treadmill)
        positionIsLow=position<=frontPos
        #condition2: time>minTimeSec (don't keep early minimas)
        timeIsAboveMin=time>=minTimeSec
        #condition3: animal has been on the back of the treadmill once (above backPos)
        timeIsAfterReachBack=time>firstTimeAboveX
        
        #apply the conditions
        isMinima=isMinima & positionIsLow & timeIsAboveMin & timeIsAfterReachBack
        try:
            minimaIndex=np.where(isMinima==True)[0][0]
        except IndexError:
            return None, None

        #look for a minimum before, in case of a nearly flat plateau
        minPos=position[minimaIndex]+0.5 #the position for the minimum, in cm
        #test the index first and stop at 0, a negative index would wrap to the end of the array
        while (minimaIndex>0) and (position[minimaIndex]<=minPos):
            minimaIndex=minimaIndex-1
        return minimaIndex,time[minimaIndex]
    
        
    def get_time_and_position_align_end(self,positionDict,timeDict,timeEnd,minEnd=-5,maxEnd=5,binSize=None):
        '''
        Given position and time for each trial, align them to the end
        Interpolate to match one time vector (between minEnd and maxEnd, with binSize)
        timeEnd= list of end time (relative to zero of timeDict)
        '''
        if binSize is None:
            binSize=self.binSize
        timeAlignEnd=np.arange(minEnd,maxEnd,binSize)
        
        positionAlignEnd={}
        for trial in positionDict:
            end=timeEnd[trial]
            if isNone(end):
                continue
            filledPos=interp1d(timeDict[trial]-end,positionDict[trial],
                               bounds_error=False,fill_value=np.nan)(timeAlignEnd)
            positionAlignEnd[trial] = filledPos
        return timeAlignEnd, positionAlignEnd
    #------------------------------------------------------------------------------------------ 
    def get_median_positionDict(self,positionDict):
        '''
        Given a dictionary {trial: [np.array]} with all array of same length,
        compute the median. Returns one array.
        '''
        try:
            posArray=np.asarray(list(positionDict.values()),dtype=np.float64)
        except ValueError:
            print("Can't compute median, arrays are not all the same size")
            return []
        if posArray.shape[0]==0:
            return []
        medianPosition=np.nanmedian(posArray,axis=0)
        if np.all(np.isnan(medianPosition)):
            medianPosition=[]
        return medianPosition
    
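    def remove_allNan_slice(self,dictionary,time=None):
        '''
        Given a dictionary {trial: np.array} with all arrays of same length,
        remove the indexes (time bins) where every trial is NaN,
        from each array and from the time vector if one is given.
        Returns (dictionary, time), unchanged if the arrays are not all the same size
        '''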
        try:
            posArray=np.asarray(list(dictionary.values()),dtype=np.float64)
        except ValueError:
            #the array are not all the same size
            return dictionary,time
        if posArray.shape[0]==0:
            return dictionary,time
        indexToRemove=[]
        for i in range(posArray.shape[1]):
            sli=posArray[:,i]
            if np.all(np.isnan(sli)):
                indexToRemove.append(i)               
        if indexToRemove:
            for trial in dictionary:
                dictionary[trial]=np.delete(dictionary[trial],indexToRemove)
            if time is not None:
                time=np.delete(time,indexToRemove)
        return dictionary,time
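    def correct_pixelConversion_Pavel_longTreadmill(self,positionSmoothBinned):
        '''
        Rescale the positions of a session when the pixel to cm conversion looks wrong (pavel, long treadmill)
        Uses the maximum position of each trial (before entrance time or trialOffset):
            more than 10% of trials above treadmillSize: scale by treadmillSize/upperTreadmillSize
            more than half of trials below lowerTreadmillSize: scale by treadmillSize/lowerTreadmillSize
        Scaled positions above treadmillSize are set to treadmillSize-1
        '''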
        lowerTreadmillSize=60
        upperTreadmillSize=90
        treadmillSize = 80
        maxPos = []
        pos = {}
        for trial in positionSmoothBinned:
            startFrame=int(self.startFrame[trial])
            if startFrame>=1:
                startFrame=startFrame-1
            cuttedTime = self.timeTreadmill[trial][startFrame:] 
            trialOffset = np.min([self.entranceTime[trial],self.trialOffset])
            if trialOffset<self.goalTime[trial]: trialOffset = self.trialOffset
            indexSelect= (cuttedTime<trialOffset)
            #print("i",indexSelect)
            p = positionSmoothBinned[trial].copy()
            pos [trial] = p[indexSelect].copy()
            maxPos = np.append(maxPos,np.max(pos[trial]))
        scaling = 1.
        if np.sum(maxPos>treadmillSize)>len(self.entranceTime)/10.:
            scaling = 1.*treadmillSize/upperTreadmillSize
        if np.sum(maxPos<lowerTreadmillSize)>len(self.entranceTime)/2.:
            scaling = 1.*treadmillSize/lowerTreadmillSize
        if not np.isnan(scaling):
            scaledPosition = {}
            for trial in positionSmoothBinned:
                scaledPosition[trial] = positionSmoothBinned[trial]*scaling
                scaledPosition[trial][scaledPosition[trial]>treadmillSize] = treadmillSize-1 
        return  scaledPosition#
    def align_Behavior_OnBinCenters(self):
        
        for behaviorAttribute in ["positionBin","speedBin","speedSmoothBin","accelerationOnSpeedBin","accelerationOnSpeedSmoothBin"]:
            if behaviorAttribute in["positionBin"]: 
                t = self.timeBin
            if behaviorAttribute in ["speedBin","speedSmoothBin"]: 
                t = self.timeSpeed
            if behaviorAttribute in ["accelerationOnSpeedBin","accelerationOnSpeedSmoothBin"]: 
                t = self.timeAcceleration
            for trial in self.speedBin:
                self.__dict__[behaviorAttribute][trial] = interp1d(t,self.__dict__[behaviorAttribute][trial],bounds_error=False, fill_value=np.nan)(self.timeSpeed[1:-1])
        self.timeBin = self.timeSpeed[1:-1]
        for key in ["timeSpeed","timeAcceleration"]:
            delattr(self, key)      
        self.medianPosition=self.get_median_positionDict(self.positionBin)
    #------------------------------------------------------------------------------------------ 
    def get_stop_frame(self):
        if self.dataType=="pavel" or self.dataType=="teresa":
            stopFrame={}
            for trial in self.position:
                if self.entranceTime[trial]>=self.goalTime[trial]:
                    stopFrame[trial]=np.where(self.timeTreadmill[trial]>=self.entranceTime[trial])[0][0]
                else:
                    stopFrame[trial]=np.where(self.timeTreadmill[trial]>=self.maxTrialDuration[trial])[0][0]
            return stopFrame
        else:
            return None


In [None]:
if "__file__" not in dir():
    root="/data"
    animal="Rat262"
    experiment="Rat262_2018_06_04_15_38"
    param={
        "goalTime":7,
        "treadmillRange":[0,90],
        "maxTrialDuration":60,
        "interTrialDuration":10,
        "endTrial_frontPos":30,
        "endTrial_backPos":55, 
        "endTrial_minTimeSec":4,
        "binSize":0.25,
    }  
    data=PreprocessTreadmillOn(root,animal,experiment,param,redo=True)
    print(data.hasBehavior)
    print(data.experiment)
    if data.hasBehavior:
        print(data.cameraSamplingRate)
        print(len(data.goalTime))

### class Data: data for one session (rootFolder, rat, experiment)
  - Load preprocessed behavior and spike data (read pickles if they exist, or load raw data and preprocess it)
  - Optional arguments:  
    - **param**: dictionary of parameters for loading/preprocessing
    - **saveAsPickle**: whether to save .p files if they don't exist already (default True)
    - **redoPreprocess**: whether to do the preprocessing even if a pickle already exists (default False) 
  - Functions to save any plot to html/png
  - Functions to select a cluster group and to add a new one
  - Function **describe()** to display all attributes
  
  
  - **date**: datetime.datetime for the session (Python `datetime` module)
  - **hasSpike**: whether there is sorted spike data (.kwik or .clu)
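
The "read pickles if they exist, or load raw data and preprocess it" behavior can be sketched as below. This is a minimal illustration, not the code of `Data`: the helper `load_or_preprocess`, the toy `fake_preprocess` function and the file name `behavior.p` are hypothetical, and the sketch assumes that a redone preprocessing overwrites an existing pickle.

```python
import os
import pickle
import tempfile

def load_or_preprocess(picklePath, preprocess, saveAsPickle=True, redoPreprocess=False):
    """Return a dictionary of attributes (hypothetical helper, for illustration only).

    picklePath: path of the .p file used as cache
    preprocess: function without argument returning the dictionary to cache
    """
    # read the pickle if it exists, unless the preprocessing must be redone
    if os.path.isfile(picklePath) and not redoPreprocess:
        with open(picklePath, "rb") as f:
            return pickle.load(f)
    # otherwise load raw data and preprocess it
    result = preprocess()
    # assumption: a redone preprocessing overwrites the existing pickle
    if saveAsPickle:
        with open(picklePath, "wb") as f:
            pickle.dump(result, f)
    return result

# toy usage: count how many times the "preprocessing" really runs
calls = []
def fake_preprocess():
    calls.append(1)
    return {"hasBehavior": True, "binSize": 0.25}

path = os.path.join(tempfile.mkdtemp(), "behavior.p")
first = load_or_preprocess(path, fake_preprocess)                        # preprocess and save
second = load_or_preprocess(path, fake_preprocess)                       # read the pickle
third = load_or_preprocess(path, fake_preprocess, redoPreprocess=True)   # preprocess again
```

The returned dictionary is the kind of object that `Data.__init__` merges into the instance with `self.__dict__.update(...)`.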

In [None]:
class Data:
    header1="<!DOCTYPE html>\n<html>\n<head>\n<meta charset='utf-8' />\n<title>\n"
    header2="</title>\n</head><body><p>\n"
    bottom="</p></body></html>"
    
    def __init__(self,rootFolder,animal,experiment,param={},saveAsPickle=True,redoPreprocess=False,PrintWarning=False):
        self.hasBehavior=False
        #path
        if not self.compute_path(rootFolder,animal,experiment,PrintWarning=PrintWarning):
            return
             
        if "binSize" in param:
            binSize=param["binSize"]
        else:
            binSize=0.25
            
        #load (or do) preprocess behavior for given bin size
        
        dicBehavior=self.get_preprocess_behavior_dict(binSize,rootFolder,animal,experiment,
                                                      param,saveAsPickle,redoPreprocess,PrintWarning=PrintWarning)
        self.__dict__.update(dicBehavior)
        if not self.hasBehavior:
            return

       
        self.date=self.get_date()
        self.daySinceStart=self.get_session_day()

        
        #load raw spike 
        try:
            self.hasSpike=False
            dicSpike=self.get_spike_data_dict(rootFolder,animal,experiment,param,saveAsPickle,redoPreprocess)
            self.__dict__.update(dicSpike)
#             if self.hasSpike:
#                 self.spikeCountSmooth,self.spikeCountTime=self.get_smoothSpikeCounts([0,self.trialOffset])
        except NameError as e:
            if 'Klusta_RawSpikeData' in globals():
                print(repr(e))
            else:
                print("you must: %run loadRawSpike_documentation.ipynb for spike data")
    
        AnimalTagPath=os.path.join(self.root,self.animal,animal+".tag")
        if os.path.isfile(AnimalTagPath): 
            #function is called from Animal_Tags.ipynb
            self.tag=get_session_profile(self.root,self.animal,os.path.basename(self.sessionPath))['Tag']

#===========Mostafa: to manually fix the position artifact when needed========================================
    def position_correction(self):
        for trial in self.trials:
            ind=np.where(abs(np.diff(self.position[trial][:50]))>10)[0]
            if len(ind)<5:
                for i in ind:
                    self.position[trial][i]=self.position[trial][max([i-1,0])]
        return self.position

#--------------------------------------------------------------------------------  
    def compute_path(self,root,animal,experiment,PrintWarning=False):
        #clean name of folders (remove unnecessary slash or backslash)
        self.root=os.sep+root.strip(os.sep)
        self.animal=animal.strip(os.sep)
        self.experiment=experiment.strip(os.sep)
        #paths
        self.sessionPath=os.path.join(self.root,self.animal,"Experiments",self.experiment)
        self.fullPath=os.path.join(self.root,self.animal,"Experiments",self.experiment,self.experiment)  
        self.analysisPath=os.path.join(self.root,self.animal,"Experiments",self.experiment,"Analysis")
        #Check if the path is correct
        if self.animal not in self.experiment:
            if PrintWarning:
                print("WARNING: session name (%s) does not contain animal name (%s)"%(self.experiment,self.animal))
        if not os.path.exists(self.sessionPath):
            if PrintWarning:
                print("STOP Loading - Path does not exist: %s"%self.sessionPath)
            return False
        return True
    
    def get_date(self,experimentName=None):
        if experimentName is None:
            experimentName=self.experiment
        else:
            experimentName=os.path.basename(experimentName).strip(os.sep)
        goodFormat="%Y_%m_%d_%H_%M"
        dateFormats=[goodFormat,"%Y-%m-%d_%H-%M-%S","%Y_%m_%d_%H_%M_%S","%Y_%m_%d-%H_%M","%Y_%m_%d-%H_%M_%S"]
        for dateFormat in dateFormats:
            fullFormat=self.animal+"_"+dateFormat
            try:
                date=datetime.datetime.strptime(experimentName,fullFormat)
            except ValueError:
                continue #try next format
                
#            if dateFormat!=goodFormat:
#                #rename the folder
#                newName=date.strftime(goodFormat)
#                newPath=os.path.join(self.root,self.animal,"Experiments",newName)
#                os.rename(self.sessionPath,newPath)
#                print("Renamed %s in %s"%(experimentName,newName))
#                #if it was this session, change the attribute
#                if experimentName==self.experiment:
#                    self.experiment=newName
            return date

        print("WARNING: session %s does not match any date formats"%experimentName)
        return None
    
    def get_session_day(self):
        #this session is the Xth day since the first recording of this animal (the first session is day 1)
        expList=glob.glob(os.path.join(self.root,self.animal,"Experiments",self.animal+"*"))
        firstSession=sorted(expList)[0]
        firstDate=self.get_date(experimentName=firstSession)
        if (firstDate is None) or (self.date is None):
            return None
        firstDate=firstDate.replace(hour=0,minute=0,second=0)
        date=self.date.replace(hour=0,minute=0,second=0)
        X=(date-firstDate).days +1
        return X

    #--------------------------------------------------------------------------------  
    def get_preprocess_behavior_dict(self,binSize,rootFolder,rat,experiment,param,saveAsPickle,redo,PrintWarning=False):
        #use the builtin str: the np.str alias was removed from recent NumPy versions
        name="preprocesseddata_binsize"+str(int(binSize*1000))+"ms_.p"
        preprocessPath=os.path.join(self.analysisPath,name)
        if os.path.exists(preprocessPath) and (not redo):
            try:
                with open(preprocessPath,"rb") as f:
                    dicBehavior=pickle.load(f)
                if PrintWarning:
                    print("Preprocess behavior data loaded from %s"%preprocessPath)
                return dicBehavior
            except Exception:
                #unreadable pickle: fall through and redo the preprocessing
                pass
        if PrintWarning:
            print("Preprocessing behavior data...")
        dicBehavior=PreprocessTreadmillOn(rootFolder,rat,experiment,param,saveAsPickle,redo).get_dict()
        if PrintWarning:
            print("Preprocessing done")
            
        
        return dicBehavior
            
    def get_spike_data_dict(self,rootFolder,rat,experiment,param,saveAsPickle,redo,PrintWarning=False):
        self.hasSpike=True
        rawSpikePath=os.path.join(self.analysisPath,"rawspikedata.p")
        if os.path.exists(rawSpikePath) and (not redo):
            try:
                with open(rawSpikePath,"rb") as f:
                    spikeDict=pickle.load(f)
                print("Spike data loaded from %s"%rawSpikePath)
                return spikeDict
            except Exception:
                #unreadable pickle: fall through and reload from the raw files
                pass
        if glob.glob(self.sessionPath+os.sep+"*.kwik"):
            spikeDict=Klusta_RawSpikeData(rootFolder,experiment,parameters=param,saveAsPickle=saveAsPickle).get_dict()  
        elif glob.glob(self.sessionPath+os.sep+"*.clu*"):
            spikeDict=Kluster_RawSpikeData(rootFolder,rat,experiment,param,saveAsPickle=saveAsPickle).get_dict()
        else:
            if PrintWarning:
                print("No spike data")
            self.hasSpike=False
            return {}
        if PrintWarning:
            print("Spike data loaded from raw files")
        return spikeDict

    #-------------------------------------------------------------------------------- 
    def create_empty_html(self,path,name):
        if not path.endswith(".html"):
            path=path+".html"
        with open(path,"w") as f:
            f.write(self.header1+name+self.header2)
            f.write(self.bottom)
            
    def insert_in_html(self,path,insertList,name=None):
        if not os.path.exists(path):
            if name is None:
                name=os.path.basename(path)
            self.create_empty_html(path,name)
        if isinstance(insertList,str):
            insertList=insertList.split("\n")
            
        with open(path,"r+") as f:
            contents=[line.rstrip("\n") for line in f if line!="\n"]
            f.seek(0)
            f.truncate() 
            contents=contents[:-1]+insertList+[self.bottom]
            f.write("\n".join(contents))
            
    def remove_lines_in_html(self,path,contents):
        if not os.path.exists(path):
            return
        newLines=[]
        with open(path,"r+") as f:
            for line in f:
                if (contents not in line) and (line!="\n"):
                    newLines.append(line.strip("\n"))
            f.seek(0)
            f.truncate()
            f.write("\n".join(newLines))
    
    #--------------------------------------------------------------------------------  
    #--------------------------------------------------------------------------------
    def plot_session_png_html(self,plotFunctionList,name=None,override=False,**kwargs):
        if not isinstance(plotFunctionList,list):
            plotFunctionList=[plotFunctionList]
        
        if name is None:
            name=str(plotFunctionList[0].__name__)
        
        #html for the animal
        generalName="all_"+name
        generalFolder=os.path.join(self.root,self.animal,"Analysis")
        if not os.path.exists(generalFolder):
            os.mkdir(generalFolder)
        generalPath=os.path.join(generalFolder,generalName+".html")
        
        #save the plots as png, create html image tag
        images=[]
        for plotFunction in plotFunctionList:
            name=plotFunction.__name__+".png"
            path=os.path.join(self.sessionPath,name)
            #override or not
            if  override or not os.path.exists(path):
                hasPlot=plotFunction(self,**kwargs)
                if hasPlot is False:
                    continue
                try:
                    plt.savefig(path,bbox_inches='tight')
                except Exception:
                    #saving with a tight bounding box failed: retry with default settings
                    try:
                        plt.savefig(path)
                    except Exception:
                        pass
                finally:
                    plt.close()
            
            images.append("<a href=#%s><img src='%s' alt='%s' title='%s'/></a>"%(self.experiment,path,name,name))
            self.remove_lines_in_html(generalPath,path)
                
        #insert images in general html
        self.insert_in_html(generalPath,images,generalName)
        #print("Html updated: %s"%generalPath)
            
    def plot_all_clusters_png_html(self,plotFunctionList,name=None,override=False,groupList=None,**kwargs):
        #create folder for plots
        folderPath=os.path.join(self.sessionPath,"plots")
        if not os.path.exists(folderPath):
            os.mkdir(folderPath)

        if not isinstance(plotFunctionList,list):
            plotFunctionList=[plotFunctionList]
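        #groupList=None means "no filter": wrapping None in a list would skip every group below
        if (groupList is not None) and (not isinstance(groupList,list)):
            groupList=[groupList]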
            
        #html for the session
        if name is None:
            name=str(plotFunctionList[0].__name__)
        htmlName=name
        htmlPath=os.path.join(self.sessionPath,htmlName+".html")
                
        #html for the animal
        generalName="all_"+htmlName
        generalFolder=os.path.join(self.root,self.animal,"Analysis")
        if not os.path.exists(generalFolder):
            os.mkdir(generalFolder)
        generalPath=os.path.join(generalFolder,generalName+".html")
           
        #override if needed
        if os.path.exists(htmlPath):
            if override:
                print("Override html %s"%htmlPath)
                os.remove(htmlPath)
                #remove links in general html
                self.remove_lines_in_html(generalPath,self.experiment)
            else:
                print("Html already exists: %s"%htmlPath)
                return
    
        #save the plots as png, create html image tag
        images=[]
        for shank in sorted(self.clusterGroup):
            print("Shank %s"%shank)
            for group in self.clusterGroup[shank]:
                if (groupList is not None) and (group not in groupList):
                    continue
                for cluster in sorted(self.clusterGroup[shank][group]):
                    six.print_(cluster,end=" ")
                    for plotFunction in plotFunctionList:
                        name="shank%s_cluster%s_%s.png"%(shank,cluster,plotFunction.__name__)
                        path=os.path.join(folderPath,name)
                        if override or (not os.path.exists(path)):
                            hasPlot=plotFunction(self,shank,cluster,group,**kwargs)
                            if hasPlot is False:
                                continue
                            plt.savefig(path)
                            plt.close()
                        images.append("<a href=#%s-%s-%s><img src='%s' alt='%s' title='%s'/></a>"%(self.experiment,shank,cluster,path,name,name))
                print("")
                
        #insert all images in session html
        self.insert_in_html(htmlPath,images,htmlName)
        #insert a link in general html
        link=["<a href='"+htmlPath+"'>"+self.experiment+"</a><br>"] 
        self.insert_in_html(generalPath,link,generalName)
        
        print("Html updated: %s"%htmlPath)
        print("Html updated: %s"%generalPath)
     
    #--------------------------------------------------------------------------------
    def select_cluster_group(self,selectedGroups):
        for shank in self.channelGroupList:
            selectedClusters=[]
            for group in selectedGroups:
                selectedClusters+=self.clusterGroup[shank][group]
            self.spikeSample[shank]={cluID:self.spikeSample[shank][cluID] for cluID in selectedClusters}
            self.spikeTime[shank]={cluID:self.spikeTime[shank][cluID] for cluID in selectedClusters}
            self.clusterGroup[shank]={group:self.clusterGroup[shank][group] for group in selectedGroups}
            
    def add_cluster_group(self,name,shankCluList):
        for (shank,clu) in shankCluList:
            if name not in self.clusterGroup[shank]:
                self.clusterGroup[shank][name]=[clu]
            elif clu not in self.clusterGroup[shank][name]:
                self.clusterGroup[shank][name].append(clu)
     
    #-----------------------------add 22/3/16--------------------------------------------------
    def get_smoothSpikeCounts(self,TimeRange):
        '''
        Return smoothed spike counts aligned on treadmill start, for every cluster and every trial.
        Uses self.spikeTime, self.treadmillStartTime and self.binSize.
        Input:
            TimeRange:  time axis min and max relative to treadmill start, ex: [0,20] in sec
        Output:
            SmoothSpikeCount: dictionary {shank: {cluster: 2d (Ntrials, Nbins) array}} containing the
                              spike counts of each trial, smoothed with a gaussian kernel (sigma = 0.25 s)
            SpikingTimeBinCenters: 1d array of time bin centers (time axis)
        '''
        SpikeTime = self.spikeTime
        TrialStartTime = self.treadmillStartTime
        SpikingTimeBinEdges = np.arange(TimeRange[0],TimeRange[1]+self.binSize-TimeRange[1]%self.binSize,self.binSize)
        SpikingTimeBinCenters = SpikingTimeBinEdges[:-1]+self.binSize/2.
        SmoothSpikeCount = {}
        for shank in self.spikeTime:
            #if shank not in self.channelGroupSubList:
            #    continue
            SmoothSpikeCount[shank] ={}
            for clu in SpikeTime[shank]:
                #consider_clu=False
                #for clusterGroupName in self.clusterGroupSelection: 
                #    if clu in self.clusterGroup[shank][clusterGroupName]:
                #        consider_clu=True
                #if not consider_clu:
                #    continue
                SmoothSpikeCount[shank][clu] = np.zeros((len(TrialStartTime),len(SpikingTimeBinCenters)))
                for trial in range(len(TrialStartTime)):  
                        SpikeTime_ = SpikeTime[shank][clu]-TrialStartTime[trial]
                        SpikeTime_= SpikeTime_[ (SpikeTime_<=TimeRange[1])& (SpikeTime_>=TimeRange[0]) ]
                        SpikeCount,a = np.histogram(SpikeTime_,SpikingTimeBinEdges)
                        SmoothSpikeCount[shank][clu][trial,:] = smooth(1.*SpikeCount,0.25/self.binSize)        
        return SmoothSpikeCount,SpikingTimeBinCenters
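
    def get_firingTreadmillOnOff(self):
        '''
        Return smoothed spike counts split, for each trial, into treadmill-on and treadmill-off periods.
        Uses self.spikeTime, self.treadmillStartTime, self.entranceTime, self.goalTime and self.binSize.
        Spike counts are binned over the whole session (from the first to the last spike of each cluster)
        and smoothed with a gaussian kernel (sigma = 0.25 s).
        For each trial (the last trial is skipped because it has no following treadmill start):
            on period:  from treadmill start to treadmill start + offset, where offset is
                        self.entranceTime[trial] if it is >= self.goalTime[trial], else 20 s
                        (the on mask extends one bin past this limit)
            off period: from treadmill start + offset to the next treadmill start
        Output:
            SmoothSpikeCountON:  dictionary {shank: {cluster: {trial: 1d array of smoothed counts}}}
            SmoothSpikeCountOFF: dictionary {shank: {cluster: {trial: 1d array of smoothed counts}}}
        '''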
        SpikeTime = self.spikeTime
        TrialStartTime = self.treadmillStartTime
        #SpikingTimeBinEdges = np.arange(TimeRange[0],TimeRange[1]+self.binSize-TimeRange[1]%self.binSize,self.binSize)
        #SpikingTimeBinCenters = SpikingTimeBinEdges[:-1]+self.binSize/2.
        SmoothSpikeCountON = {}
        SmoothSpikeCountOFF = {}
        for shank in self.spikeTime:
            #if shank not in self.channelGroupSubList:
            #    continue
            SmoothSpikeCountON[shank] ={}
            SmoothSpikeCountOFF[shank] ={}
            for clu in SpikeTime[shank]:
                #consider_clu=False
                # for clusterGroupName in self.clusterGroupSelection: 
                #    if clu in self.clusterGroup[shank][clusterGroupName]:
                #        consider_clu=True
                #if not consider_clu:
                #    continue
                SmoothSpikeCountON[shank][clu] ={}
                SmoothSpikeCountOFF[shank][clu] ={}
                SpikeTime_ = SpikeTime[shank][clu]
                startSessionTime = SpikeTime[shank][clu][0]
                endSessionTime = SpikeTime[shank][clu][-1]
                SpikingTimeBinEdges = np.arange(startSessionTime,endSessionTime+self.binSize,self.binSize)
                SpikeCount,a = np.histogram(SpikeTime_,SpikingTimeBinEdges)
                SmoothSpikeCount= smooth(1.*SpikeCount,0.25/self.binSize) 
                for trial in range(len(TrialStartTime)-1):  
                       
                        offset = 20
                        if self.entranceTime[trial]>=self.goalTime[trial]:
                            offset=self.entranceTime[trial]
                        start = TrialStartTime[trial]
                        stop = TrialStartTime[trial]+offset
                        end = TrialStartTime[trial+1]
                        on = (SpikingTimeBinEdges[:-1]>=start) & (SpikingTimeBinEdges[:-1]<stop+self.binSize)
                        #print(on)
                        off = (SpikingTimeBinEdges[:-1]>=stop) & (SpikingTimeBinEdges[:-1]<=end)
                        
                        SmoothSpikeCountON[shank][clu][trial] = SmoothSpikeCount[on]
                        SmoothSpikeCountOFF[shank][clu][trial] = SmoothSpikeCount[off]
                        #print(SmoothSpikeCount[on][-1],SmoothSpikeCount[off][0])
                        #print(SmoothSpikeCount[SpikingTimeBinEdges[:-1]==],SmoothSpikeCount[off][-1])
        return SmoothSpikeCountON,SmoothSpikeCountOFF
    
    #----------------------------------------------------------------------
    def describe(self):
        dic=self.__dict__
        print("Session: %s"%self.experiment)
        print("Full Path: %s"%self.fullPath)
        print("Number of trials: %s"%self.nTrial)
        sep="-"*(28+10+30+20+5)
        print(sep)
        print("{: <28} {: <10} {: <30} {: <30}".format("**Name**","**Type**","**Content**","**Extract**"))
        print(sep)
        for (key,value) in sorted(dic.items()):
            t=type(value).__name__
            glance=""
            if isinstance(value,list) or isinstance(value,np.ndarray):
                v="length="+str(len(value))    
                glance="["+" ".join(["%.2f"%x if isinstance(x,float) else str(x) for x in value])
                if len(glance)<30:
                    glance+="]"
                else:
                    glance=glance[:27]+"..."
            elif isinstance(value,dict):
                keys=list(value.keys())
                v="nKeys="+str(len(keys))
                glance="keys: "+str(keys)
                if len(glance)>30:
                    glance=glance[:27]+"..."
            else:
                v=value
            row=[key,t,str(v),glance]
            print("{: <28} {: <10} {: <30} {:<30}".format(*row))
            

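The date logic of `get_date` and `get_session_day` can be tried outside the class with the sketch below. It is only an illustration: `parse_session_date` and `day_since_start` are hypothetical helper names, not functions of this notebook, and the first session name is made up for the example.

```python
import datetime

# Same formats as in get_date; the first one is the expected ("good") format
DATE_FORMATS = [
    "%Y_%m_%d_%H_%M",
    "%Y-%m-%d_%H-%M-%S",
    "%Y_%m_%d_%H_%M_%S",
    "%Y_%m_%d-%H_%M",
    "%Y_%m_%d-%H_%M_%S",
]

def parse_session_date(animal, experiment_name):
    """Try each format in turn; return a datetime, or None if nothing matches."""
    for date_format in DATE_FORMATS:
        try:
            return datetime.datetime.strptime(experiment_name, animal + "_" + date_format)
        except ValueError:
            continue  # try next format
    return None

def day_since_start(first_date, date):
    """Day index counted on calendar days; the first session is day 1."""
    return (date.date() - first_date.date()).days + 1

# hypothetical first session, followed by the session used in the demo cell
first = parse_session_date("Rat284", "Rat284_2019_01_25_10_05")
current = parse_session_date("Rat284", "Rat284_2019_01_29_14_32")
print(day_since_start(first, current))  # prints 5
```

Comparing calendar dates rather than full timestamps means that a morning session and an afternoon session of the same day get the same day index.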

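The binning and smoothing done in `get_smoothSpikeCounts` can be sketched for a single cluster as follows. This is a minimal standalone version under assumptions: `smooth_trial_counts` is a hypothetical name, the spike times and trial starts are toy values, and `gaussian_filter` is imported from `scipy.ndimage` directly instead of `scipy.ndimage.filters`.

```python
import numpy as np
from scipy.ndimage import gaussian_filter

def smooth_trial_counts(spike_times, trial_starts, time_range, bin_size, sigma_sec=0.25):
    """Return a (Ntrials, Nbins) array of smoothed spike counts and the bin centers."""
    t0, t1 = time_range
    # same edge computation as in get_smoothSpikeCounts
    edges = np.arange(t0, t1 + bin_size - t1 % bin_size, bin_size)
    centers = edges[:-1] + bin_size / 2.0
    counts = np.zeros((len(trial_starts), len(centers)))
    for i, start in enumerate(trial_starts):
        aligned = spike_times - start                      # align on treadmill start
        aligned = aligned[(aligned >= t0) & (aligned <= t1)]
        hist, _ = np.histogram(aligned, edges)
        # the gaussian sigma is given in bins: 0.25 s divided by the bin size
        counts[i, :] = gaussian_filter(1.0 * hist, sigma_sec / bin_size)
    return counts, centers

# toy data (assumed values): 4 spikes, 2 trials starting at 0 s and 10 s
spike_times = np.array([1.0, 1.1, 5.0, 11.2])
trial_starts = [0.0, 10.0]
counts, centers = smooth_trial_counts(spike_times, trial_starts, [0, 20], 0.25)
print(counts.shape)  # (2, 80)
```

With `binSize=0.25` the kernel has a sigma of one bin, so the smoothing spreads each spike over its neighbouring bins while leaving the total count of the trial unchanged when spikes are far from the edges of the time range.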
In [None]:
if "__file__" not in dir():
    root="/data"
    experiment="Rat284_2019_01_29_14_32"
    animal=experiment[:6]
    #Rat041_2015_10_08_09_55 Rat124_2017_02_24_18_40 Rat124_2017_04_12_17_22 Rat106_2017_04_03_17_27

    param={
        "goalTime":7,
        "treadmillRange":[0,90],
        "maxTrialDuration":20,
        "interTrialDuration":10,
        "endTrial_frontPos":30,
        "endTrial_backPos":55, 
        "endTrial_minTimeSec":4,
        "binSize":0.25,
    }  
    data=Data(root,animal,experiment,param=param,saveAsPickle=False,redoPreprocess=False,PrintWarning=True)
    print(data.lickTime)
    #--------------------