# This notebook loads the camera and treadmill clock signals and saves their start times (in ms) to .evt.tre and .evt.cam files

### *.dat* file is needed to extract clock signals
### *.prm* file is needed to load the sampling frequency
### Behavior data (*.behav_param* and *.entrancetimes*) are needed to separate the treadmill recording from the homecage recording.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import glob
from IPython.display import clear_output
from pprint import pprint  #if the module is not available, comment this line out

#utility function to check if a given signal is a binary (clock-like) signal
#using the correlation coefficient
def isClock(sig,fs,window_s=5*60,threshold=0.98):
    '''
    Return True if `sig` looks like a two-level (square-wave) clock signal.

    The signal is binarized around its mean with np.sign and the result is
    correlated with the raw signal; a correlation above `threshold` means the
    raw signal already was (almost) binary.

    sig       : 1-D array of samples
    fs        : sampling frequency in Hz
    window_s  : only the first `window_s` seconds are tested (default 5 min),
                which keeps the check cheap on long recordings
    threshold : minimal Pearson correlation to accept the signal as a clock
    '''
    sig=np.asarray(sig)
    if len(sig) > window_s*fs:
        L=int(window_s*fs)
    else:
        L=len(sig)-1
    segment=sig[:L]
    binary=np.sign(sig-np.mean(sig))[:L]  #perfect binary version of the signal
    #a flat segment (or a flat binarized one) has zero variance: the
    #correlation is undefined (NaN), so it cannot be a clock
    if segment.size < 2 or segment.min() == segment.max() \
            or binary.min() == binary.max():
        return False
    return bool(np.corrcoef(segment,binary)[0,1] > threshold)

In [None]:
def clock_to_rate(clock,fs,gap_s=2):
    '''
    this function receives a clock signal (the treadmill channel in this
    notebook) and returns the corresponding rate signal with 2 levels:
    0: no pulses (before the first pulse, from the last pulse on, and
       during inter-trial gaps)
    1: pulses

    clock : 1-D array with the raw clock channel
    fs    : sampling frequency in Hz
    gap_s : a silence longer than `gap_s` seconds between two consecutive
            rising edges is treated as an inter-trial interval (default 2 s)

    Returns False when the input does not look like a clock signal or
    contains no rising edge.
    '''
    if not isClock(clock,fs):
        return False
    clock=np.sign(clock-np.mean(clock))  #clock is a perfect binary signal now
    #rising edges only: index i marks the transition from sample i to i+1
    shotTimeIndex=np.nonzero(np.diff(clock)>0)[0]
    if shotTimeIndex.size == 0:
        #e.g. a single falling step: there is no pulse to rate
        return False

    frameRate=np.ones(clock.shape)
    #positions (in shotTimeIndex) of pulses followed by a long silence
    interTrialOnset=np.nonzero(np.diff(shotTimeIndex)>gap_s*fs)[0]
    for lastShot,nextShot in zip(shotTimeIndex[interTrialOnset],
                                 shotTimeIndex[interTrialOnset+1]):
        frameRate[lastShot:nextShot]=0
    frameRate[:shotTimeIndex[0]]=0
    frameRate[shotTimeIndex[-1]:]=0
    return frameRate


if "__file__" not in dir():
    #interactive example: only runs where __file__ is undefined (normally the
    #notebook); uncomment the lines below to try clock_to_rate on one file
    #datfile="/data/Rat107/Experiments/Rat107_2017_03_17_11_01/Rat107_2017_03_17_11_01.dat"
    #--------------------------------------------------------------------------------
    #dat=np.fromfile(datfile,dtype=np.int16)
    #dat=np.reshape(dat,(-1,1),order='C')
    #dat=np.reshape(dat,(-1,13),order='C')
    #clock=dat[:,-3].copy()
    #del dat
    #a=clock_to_rate(clock,2e4)
    #print(a) if a is False else []
    pass  #an if-body made only of comments is a SyntaxError

In [None]:
def camera_clock_to_rate(clock,fs):
    '''
    this function receives a camera clock signal and returns the
    corresponding rate signal with 3 levels:
    0: no pulses (before the first classified pulse and from the last
       pulse to the end of the recording)
    1: slow pulses
    2: fast pulses

    Each interval between two rising edges is labelled slow or fast
    depending on whether its instantaneous rate is below or above the mean
    rate. The output then switches level at the first edge of each new
    slow/fast segment.

    Returns False when the input does not look like a clock signal, has
    fewer than two rising edges, or shows only one rate class (e.g. a
    constant frame rate), because slow and fast cannot be told apart.
    '''
    if not isClock(clock,fs):
        return False
    clock=np.sign(clock-np.mean(clock))  #clock is a perfect binary signal now
    shotTimeIndex=np.nonzero(np.diff(clock)>0.5)[0]  #rising edges
    if shotTimeIndex.size < 2:
        return False
    #instantaneous rate (Hz) of each inter-edge interval
    rate=np.reciprocal(np.diff(shotTimeIndex)/fs)
    rateClass=(np.sign(rate-np.mean(rate))+3)/2  #-1/0/+1 -> 1/1.5/2
    #the last edge has no following interval: reuse the previous class
    rateClass=np.append(rateClass,rateClass[-1])
    slow=shotTimeIndex[rateClass==1]
    fast=shotTimeIndex[rateClass==2]
    if slow.size == 0 or fast.size == 0:
        return False

    frameType=(slow,fast)
    current=int(fast[0] < slow[0])  #0 if the first edge is slow, 1 if fast
    index=frameType[current][0]
    frameRate=np.zeros(clock.shape)
    while True:  #the loop iterates over slow/fast segments, not samples
        frameRate[index:]=current+1  #convert to 1 or 2
        current=1-current
        #first edge of the other type strictly after the current index
        pos=np.searchsorted(frameType[current],index,side='right')
        if pos == frameType[current].size:
            break
        index=frameType[current][pos]
    frameRate[max(slow[-1],fast[-1]):]=0
    return frameRate


if "__file__" not in dir():
    #interactive example: only runs where __file__ is undefined (normally the
    #notebook); uncomment the lines below to try camera_clock_to_rate
    #datfile="/data/Rat105/Experiments/Rat105_2016_12_14_16_53/Rat105_2016_12_14_16_53.dat"
    #dat=np.fromfile(datfile,dtype=np.int16)
    #dat=np.reshape(dat,(-1,1),order='C')
    #dat=np.reshape(dat,(-1,37),order='C')
    #clock=dat[:,-2].copy()
    #del dat
    #--------------------------------------------------------------------------------
    #a=camera_clock_to_rate(clock,2e4)
    pass  #an if-body made only of comments is a SyntaxError
    

In [None]:
def write_evt_file(outputPath,sig,fs,overwrite):
    '''
    writes the onset times (in ms) of the highest level of a rate signal
    to an event file inside `outputPath`, one value per line.

    The file type is inferred from the number of distinct levels in `sig`:
    2 levels (0/1, treadmill) -> <experiment>.evt.tre
    3 levels (0/1/2, camera)  -> <experiment>.evt.cam
    where <experiment> is the last component of `outputPath` (a trailing
    path separator is ignored).

    Returns
    True  : file written
    None  : the file already exists and overwrite is False
    False : bad signal (not 2 or 3 levels, e.g. a failed clock conversion
            that returned False) or the file could not be written
    '''
    uniqueValues=np.unique(sig)
    #number of distinct levels -> file extension
    fileType={2:'.evt.tre',3:'.evt.cam'}.get(len(uniqueValues),'')
    experiment=os.path.basename(os.path.normpath(outputPath))
    output=os.path.join(outputPath,experiment+fileType)
    if not fileType:
        #0,1 and 2 in case of camera; anything else is unusable
        print("bad clock signal, cannot write:",output)
        return False
    if os.path.exists(output) and not overwrite:
        return None

    clock=np.array(sig)  #copy: the caller's array must not be modified
    clock[clock<uniqueValues[-1]]=0  #keep only the highest level
    clock[np.append(0,np.diff(clock))<=0]=0   #each nonzero element is an event
    eventList_ms=np.nonzero(clock)[0]/(fs/1000)
    try:
        with open(output,'w') as f:
            f.writelines(str(event)+'\n' for event in eventList_ms)
    except OSError as e:
        print(repr(e))
        return False
    print("wrote:",output)
    return True

#--------------------------------------------
if "__file__" not in dir():
    #interactive example: only runs where __file__ is undefined (normally the
    #notebook); uncomment to write the events of the rate signal `a`
    #outputPath="/data/Rat107/Experiments/Rat107_2017_03_10_10_33"

    #-------
    #b=write_evt_file(outputPath,a,int(2e4),True)
    pass  #an if-body made only of comments is a SyntaxError

In [None]:
def extract_session_evt(sessionPath,nbTre,nbCam,overwrite=False):
    '''
    loads the required information from `sessionPath` and writes the
    <experiment>.evt.tre and <experiment>.evt.cam files of that session,
    where <experiment> is the name of the session folder.

    nbTre, nbCam : treadmill / camera channel counted from the last channel
                   of the .dat file; only the magnitude is used, so -1 and 1
                   both select the last channel
    overwrite    : if False, sessions that already have both .evt files are
                   skipped without reading the .dat file

    Returns
    []   : a required input (.behav_param, .entrancetimes, .dat, .prm) is
           missing
    None : both outputs already exist and overwrite is False
    otherwise the combination `tre_result and cam_result` of the two
    write_evt_file calls (True, False or None)
    '''
    sessionPath=os.path.normpath(sessionPath)  #tolerate a trailing separator
    experiment=os.path.basename(sessionPath)
    base=os.path.join(sessionPath,experiment)
    behavFile=base+'.behav_param'
    etFile=base+'.entrancetimes'
    datFile=base+'.dat'
    prmFile=base+'.prm'

    if not all(os.path.exists(f) for f in (behavFile,etFile,datFile,prmFile)):
        return []
    if not overwrite:
        #skip before touching the (large) .dat file when both outputs exist
        if os.path.exists(base+'.evt.cam') and os.path.exists(base+'.evt.tre'):
            return None

    #reading the prm file to get the required info
    nameDic=prm_reader(prmFile)
    fs=nameDic['sample_rate']
    nchannels=nameDic['nchannels']

    #the .dat file holds int16 samples, nchannels per time step (C order);
    #memory-map it so only the two clock channels are copied into RAM
    if os.path.getsize(datFile)==0:
        print("empty dat file:",datFile)
        return False
    dat=np.memmap(datFile,dtype=np.int16,mode='r').reshape(-1,nchannels)
    treClock=np.array(dat[:,-abs(nbTre)])  #np.array copies out of the map
    camClock=np.array(dat[:,-abs(nbCam)])
    del dat  #release the memory map
    treClock=clock_to_rate(treClock,fs)
    camClock=camera_clock_to_rate(camClock,fs)
    treEventResult=write_evt_file(sessionPath,treClock,fs,overwrite)
    camEventResult=write_evt_file(sessionPath,camClock,fs,overwrite)
    return treEventResult and camEventResult

def prm_reader(prmFile):
    '''
    executes the .prm parameter file (plain Python source) and returns the
    variables it defines as a dict (this notebook reads 'sample_rate' and
    'nchannels' from it).

    The file runs in a fresh namespace (runpy.run_path). Values from a
    previously read .prm file therefore cannot leak into this one, and the
    notebook's own globals are not overwritten by the .prm variables. The
    working directory is switched to the folder of the .prm file while it
    runs and is always restored afterwards.
    '''
    import runpy
    prmFile=os.path.abspath(prmFile)  #dirname of a bare filename would be ''
    CWD=os.getcwd()
    try:
        os.chdir(os.path.dirname(prmFile))
        return runpy.run_path(prmFile)
    finally:
        os.chdir(CWD)
 
#------------------------------------------------------    
if "__file__" not in dir():
    #example values for a single-session run; both channels are counted
    #backwards from the last channel of the .dat file
    nbTre,nbCam=-3,-2
    #sessionPath="/data/Rat107/Experiments/Rat107_2017_03_16_10_52"
    #--------------------------------------------------------------------------------
    #a=extract_session_evt(sessionPath,nbTre,nbCam)
    #print(a)

In [None]:
def extract_evt_batch(root, animalList, nbTre, nbCam, overwrite):
    '''
    runs extract_session_evt on every folder matching
    <root>/<animal>/Experiments/<animal>* for each animal in `animalList`
    (sessions are processed in sorted order). It then prints the sessions
    whose .evt files were not written, grouped by reason:
    'Write'      : writing failed (bad clock signal or I/O error)
    'Exist'      : the .evt files already exist and overwrite is False
    'NoBehavior' : a required input file is missing
    The same dict is also returned so it can be used programmatically.
    '''
    failedSessions={'Write':[],'Exist':[],'NoBehavior':[]}
    for animal in animalList:
        experimentsPath=os.path.join(root,animal,"Experiments")
        #all paths share the same folder, so sorting paths sorts session names
        for sessionPath in sorted(glob.glob(os.path.join(experimentsPath,animal+'*'))):
            session=os.path.basename(sessionPath)
            res=extract_session_evt(sessionPath,nbTre,nbCam,overwrite=overwrite)
            if res is None:
                failedSessions['Exist'].append(session)
            elif res==[]:
                failedSessions['NoBehavior'].append(session)
            elif res is False:
                failedSessions['Write'].append(session)
    clear_output()
    print('\nfailed sessions:')
    try:
        pprint(failedSessions)
    except NameError:  #the pprint import at the top may be commented out
        print(failedSessions)
    print('Done!')
    return failedSessions

## Script to Run the notebook as a Batch
### only sessions with *.dat*, *.prm*, *.behav_param* and *.entrancetimes* files will be included
#### root: path to main data folder
#### animalList: animals for which to generate the *.evt* files
#### nbTre: channel number of the TREADMILL signal, counted from the last channel, e.g. -1 (last channel), -2 (channel before the last one); the sign is ignored, so 2 and -2 select the same channel
#### nbCam: channel number of the CAMERA signal, counted from the last channel, e.g. -1 (last channel), -2 (channel before the last one); the sign is ignored, so 2 and -2 select the same channel

In [None]:
if "__file__" not in dir():
    #--- user settings ---------------------------------------------------------
    root="/data/"          #main data folder
    #on Windows escape the backslashes, e.g. root="C:\\Data\\Recordings\\"
    animalList=["Rat105"]  #animals to process
    nbTre=-2               #treadmill channel, counted from the last channel
    nbCam=-1               #camera channel, counted from the last channel
    overwrite=False        #keep .evt files that already exist
    #--------------------------------------------------------------------------------
    extract_evt_batch(root=root,animalList=animalList,nbTre=nbTre,
                      nbCam=nbCam,overwrite=overwrite)