In [None]:
def isClock(sig,fs):
    """
    Tell whether `sig` looks like a two-level (clock) signal.

    The signal is binarised around its mean with np.sign, and the Pearson
    correlation between the raw and the binarised signal is computed over
    at most the first 5 minutes (5*60*fs samples). Signals that are not
    longer than 5 minutes are compared over all but their last sample.

    sig: the sampled signal
    fs: sampling frequency, used only to convert 5 minutes into samples

    Returns True when the correlation is above 0.98, False otherwise
    (a NaN correlation also gives False).
    """
    binarised=np.sign(sig-np.mean(sig))  # a perfect binary version of the signal
    fiveMinutes=5*60*fs
    window=int(fiveMinutes) if len(sig) > fiveMinutes else len(sig)-1
    similarity=np.corrcoef(sig[:window],binarised[:window])[0,1]
    return bool(similarity > 0.98)

def prm_reader(prmFile):
    """
    Run a parameter (.prm) file and return the global namespace.

    prmFile: path of the script to run; it is executed from inside its own
    directory so that relative paths in the script resolve next to it.

    The working directory is always restored, even when the script fails.

    NOTE: the `%run` line is an IPython magic, so this function only works
    inside IPython/Jupyter, not as plain Python.
    """
    CWD=os.getcwd()
    try:
        os.chdir(os.path.dirname(prmFile))
        prmName=os.path.basename(prmFile)
        # presumably the names defined by the script end up in the namespace
        # returned below -- TODO confirm against the IPython %run behaviour
        %run $prmName
    finally:
        os.chdir(CWD)  # go back to where we started, whatever happened
    return globals()


### Reading an entire Excel file

In [None]:
class ReadExcelFile:
    """
    Locate an animal's Excel workbook and load all of its sheets.

    After construction:
      excelPath: path of the workbook that was read
      excelData: dict mapping each sheet name to the object returned by
                 pd.read_excel for that sheet
    """
    def __init__(self,root,animal,fileName=None):
        """
        root, animal: when fileName is None the workbook is searched with the
            glob pattern  <root>/<animal>/<animal>*.xls*  and exactly one
            match is required.
        fileName: explicit path of the workbook; must be an existing file.

        Raises AssertionError when no workbook, several workbooks, or a bad
        explicit path is found.
        """
        self.animal=animal
        self.excelPath=''
        if fileName is None:
            path=os.path.join(root,animal,animal)+'*.xls*'
            excelfiles=glob.glob(path)
            assert len(excelfiles)!=0, "No Excel files"+path
            assert len(excelfiles) ==1, "Too many Excel files"+str(excelfiles)
            self.excelPath=excelfiles[0]
        else:
            self.excelPath=fileName
            assert os.path.isfile(self.excelPath), "Bad Excel file path"

        self.read_excel_file()

    def __repr__(self):
        return " ".join(['Excel file at:',self.excelPath])

    def read_excel_file(self):
        """
        Read every sheet of the workbook into self.excelData.

        The file is accessed through safe_copy_from_nas; its stop() is called
        in a `finally` clause so the temporary copy is removed even when the
        workbook cannot be opened or parsed.
        """
        safeExcel=safe_copy_from_nas(self.excelPath)
        path=safeExcel.start()
        try:
            with pd.ExcelFile(path) as file:
                self.excelData={sheet:pd.read_excel(file,sheet) for sheet in file.sheet_names}
        finally:
            # fileTypes=[] : nothing has to be uploaded back, only cleaned up
            safeExcel.stop(fileTypes=[])

## Find files based on extension

In [None]:
def find_file(path, extension=['.raw.kwd']):
    """
    Find all the files under 'path' (including every level of subdirectory)
    whose name ends with one of the given extensions.

    path: directory where the search starts
    extension: a list of desired file endings, e.g. ['.dat','.prm'].
        A single string is also accepted; it is split on white space.

    Returns a list of file paths in os.walk order. Each file appears once,
    even when its name matches several of the requested extensions.
    """
    if isinstance(extension, str):
        extension=extension.split()   #turning extension into a list
    # str.endswith accepts a tuple: one test per file, hence no duplicates
    suffixes=tuple(extension)
    return [os.path.join(folder,name)
            for folder,_,names in os.walk(path)
            for name in names
            if name.endswith(suffixes)]

## Safe working with NAS

In [None]:
class safe_copy_from_nas:
    """
    Work on a local copy of a file that lives on the remote storage.

    A path is treated as remote when it contains "NAS" or "NETDATA".
    start() copies such a file into a temporary directory in the home folder
    and returns the path to use; stop() uploads results back next to the
    original file and deletes the temporary directory.
    For any other path both methods leave the file system untouched.
    """

    def __init__(self,path):
        assert isinstance(path,str)
        self.filePath=path

    def _is_remote(self):
        # single place for the "is this file on the NAS?" rule
        return "NAS" in self.filePath or "NETDATA" in self.filePath

    def start(self,tempName='.NASqwrytdvn2u1r'):
        """
        safely copying the file to a temp local dir

        tempName: name of the temporary directory created in the home folder

        Returns the path of the local copy, the original path when the file
        is not remote, or "" when the copy failed (the failure is logged).
        """
        if not self._is_remote():
            self.newPath=self.filePath
            return self.newPath

        logging.info("Copying from NAS... "+self.filePath)
        defaultPath=os.path.expanduser("~") #home folder
        self.tempdir=os.path.join(defaultPath,tempName)
        try:
            os.mkdir(self.tempdir)
        except FileExistsError: #in case of multiple files being copied the directory already exists
            pass
        try:
            # `copy` is expected to return the destination path
            # (presumably shutil.copy -- confirm against the file's imports)
            self.newPath=copy(self.filePath,self.tempdir)
        except Exception as e:
            logging.warning("could not copy from NAS to local drive!"+self.filePath)
            logging.info(repr(e))
            self.newPath=""
        return self.newPath

    def stop(self,fileTypes=['.prm','.dat','.eeg']):
        """
        uploading all files with fileTypes found in the temp directory back to
        the remote server, and removing everything from local drive

        Each file is uploaded independently: a failed upload is logged and
        the remaining files are still attempted. Exceptions are not raised
        at this level. Does nothing when the file is not remote or when
        start() was never called.
        """
        if not self._is_remote():
            return
        tempdir=getattr(self,'tempdir',None)
        if tempdir is None:  # start() was not called: nothing to upload or remove
            return
        try:
            for newfile in find_file(tempdir,fileTypes):
                try:
                    copy(newfile,os.path.dirname(self.filePath))
                    logging.info("Uploaded to NAS! "+newfile)
                except Exception as e:  #Exceptions should not be raised in this level!
                    logging.warning("upload to NAS failed! "+newfile)
                    logging.info(repr(e))
        finally:
            #remove everything, ignore errors
            rmtree(tempdir,ignore_errors=True)


### Permutation Hypothesis Testing

In [None]:
class permtest_output:
    """
    Plain container for the results of a permutation test.

    statistic:      the value passed as D0
    shuffled_data:  the value passed as shuffledD
    p_val:          the value passed as p_val
    boundary:       the value passed as band
    pairwise_alpha: the value passed as pairwise_CI
    significant:    the value passed as sig_signal
    """
    def __init__(self,D0,shuffledD=None,p_val=None,band=None, pairwise_CI=None,sig_signal=None):
        # observed statistic and the shuffled statistics it is compared to
        self.statistic,self.shuffled_data=D0,shuffledD
        # significance information
        self.p_val,self.significant=p_val,sig_signal
        # bands / levels
        self.boundary,self.pairwise_alpha=band,pairwise_CI

        
def perm_statistic(x,y,Nx,Ny,sigma=0):
    """
    Difference of the two group means, optionally smoothed.

    x, y: summed data of each group (one value per bin)
    Nx, Ny: number of trials/samples that were summed in x and in y
    sigma: standard deviation of the gaussian kernel, passed to
        scipy.ndimage.gaussian_filter1d. No smoothing is applied when
        sigma <= 0 (the default) or when there are fewer than 2 bins.

    Returns (x/Nx)-(y/Ny), computed on the smoothed data when smoothing
    applies.
    """
    # sigma<=0 must bypass the filter: recent SciPy versions divide by
    # sigma**2 when building the kernel
    if len(x) >= 2 and sigma > 0:
        #the gaussian kernel
        x=scipy.ndimage.gaussian_filter1d(x, sigma=sigma, order=0, mode='constant', cval=0, truncate=4.0)
        y=scipy.ndimage.gaussian_filter1d(y, sigma=sigma, order=0, mode='constant', cval=0, truncate=4.0)

    return (x/Nx)-(y/Ny)
    
def permtest(x,y, iterN=1000,sigma=0.05):
    """
    Permutation test as to whether x>y or not.
    x,y:
    represent the data. they could be either one dimensional (several realizations)
    or 2-D (several realizations throughout the time/space/... course)
        EX: x.shape==(15,500) means 15 trials/samples over 500 time bins
    x and y must have the same number of bins (columns).

    iterN:
    number of iterations used to shuffle. max(iterN)=(len(x)+len(y))!/len(x)!len(y)!

    sigma:
    the standard deviation of the gaussian kernel used for smoothing when there are multiple data points,
    based on the Fujisawa 2008 paper, default value: 0.05
    The observed statistic and the shuffled statistics are smoothed with the
    same kernel so that they are comparable.

    Returns a permtest_output.
    Single bin: p_val is the fraction of shuffles >= the observed statistic
        and `significant` is a bool (p_val<=0.05).
    Several bins: `boundary` is the global upper band, p_val is the
        percentile level (in %) at which that band was computed, and
        `significant` is a boolean array: bins above the global band,
        extended on both sides over the contiguous bins that are above the
        pairwise (5%) band.

    Raises ValueError when x or y has more than 2 dimensions.
    """

    #input check
    x=np.asarray(x)
    y=np.asarray(y)
    if x.ndim>2 or y.ndim>2:
        raise ValueError('bad input dimentions')
    if x.ndim==1:
        x=x.reshape(-1,1)
    if y.ndim==1:
        y=y.reshape(-1,1)

    #computing the test statistic
    xTrial,yTrial=x.shape[0],y.shape[0]
    D0=perm_statistic(np.nansum(x,axis=0),np.nansum(y,axis=0),xTrial,yTrial,sigma)

    # shuffling the data
    pooled=np.concatenate((x,y),axis=0)
    Dshuffled=np.full((iterN,pooled.shape[1]),np.nan)
    for i in range(iterN):
        shuffled=np.random.permutation(pooled)  #shuffled copy: trials (rows) are reordered
        xNew,yNew=shuffled[:xTrial,:],shuffled[xTrial:,:]
        Dshuffled[i,:]=perm_statistic(np.nansum(xNew,axis=0),np.nansum(yNew,axis=0),
                                      xNew.shape[0],yNew.shape[0],sigma)

    if len(D0)<2:  #single point comparison
        # NOTE(review): the usual estimator is (count+1)/(iterN+1); formula kept as it was
        p_val0=np.sum(Dshuffled>=D0,axis=0)/(iterN+1)
        return permtest_output(D0=D0,p_val=p_val0,shuffledD=Dshuffled,sig_signal=bool(p_val0<=0.05))

    #global bands
    CI=5  #global confidence level in %, lowered until the band is rarely broken
    pairwise_high_band=np.percentile(a=Dshuffled,q=100-CI,axis=0)

    while True:
        high_band=np.percentile(a=Dshuffled,q=100-CI,axis=0)
        # shuffles that exceed the band in more than one bin
        # NOTE(review): ">1" kept from the original; confirm it should not be ">=1"
        breaks=np.sum(np.sum(Dshuffled>high_band,axis=1)>1)
        alpha=(breaks/iterN)*100
        if alpha<5:
            break  # CI is the level that produced the final high_band
        CI=0.95*CI
        logging.info("Global Confidence interval at "+str(CI)+'\nComputing again...\n')

    #finding significant bins
    global_sig=D0>high_band
    pairwise_sig=D0>pairwise_high_band
    significant=global_sig.copy()
    nBins=len(pairwise_sig)

    for i in np.flatnonzero(global_sig):
        # walk left and right while the pairwise band is still crossed
        start=i
        while start>0 and pairwise_sig[start-1]:
            start-=1
        stop=i+1
        while stop<nBins and pairwise_sig[stop]:
            stop+=1
        significant[start:stop]=True

    return permtest_output(D0=D0,shuffledD=Dshuffled,sig_signal=significant,p_val=CI,band=high_band)

### Sample size control by random selection

In [None]:
class sample_size_control:
    """
    Run a function repeatedly on random subsets of animals.

    Results: list with the return value of every call to func
    animalRepeat: per-animal counter, starting at 1 and increased each time
        the animal is drawn
    """
    def __init__(self,func,animalList,NbAnimal,n,**kwargs):
        """
        func: function to be applied to the randomly chosen animals; it is
            called with the keyword argument `animalList` plus kwargs
        animalList: all the animals to draw from
        NbAnimal: number of animals to be considered in each iteration of func
        n: number of iterations (func is always run at least twice, because
           the first two runs are used to estimate the running time)
        kwargs: give any input necessary to run "func" comma-separated, like normal function arguments

        Raises ValueError when NbAnimal is larger than the animal list.
        The function runs during construction.
        """
        if NbAnimal>len(animalList):
            raise ValueError("NbAnimal must be smaller than animal list")

        self.iterN=n
        self.animalList=animalList
        self.func=func
        self.subsetSize=NbAnimal
        self.kwargs=kwargs

        self.animalRepeat=np.ones(len(self.animalList))

        self.Results=self.run_function()

    def random_animal_subset(self):
        """
        Draw `subsetSize` distinct animals; animals drawn less often so far
        get a higher probability (weights proportional to 1/animalRepeat).
        Updates animalRepeat and returns the drawn subset.
        """
        prob=np.sum(self.animalRepeat)*(1/self.animalRepeat)
        prob=prob/np.sum(prob)
        animalListSubset=np.random.choice(a=self.animalList,size=self.subsetSize,replace=False,p=prob)
        self.animalRepeat+=[animal in animalListSubset for animal in self.animalList]

        return animalListSubset

    def run_function(self):
        """
        Call func on a fresh random subset until iterN results are collected
        (two at least) and return the list of results.
        """
        result=[]
        Args=self.kwargs

        #calculating estimated time needed to process, from the first two runs
        t0=time.perf_counter()
        for _ in range(2):
            Args.update({'animalList':self.random_animal_subset()})
            result.append(self.func(**Args))
        t_elapsed=(time.perf_counter()-t0)/2
        logging.info("Estimated time to run sample size control<"+str(t_elapsed*self.iterN)+'s')

        while len(result)<self.iterN:
            Args.update({'animalList':self.random_animal_subset()})
            result.append(self.func(**Args))

        return result

### session RMSE

In [None]:
def plot_rmse(data,onlyGood=False,maxTreadmillLength=90,raw=False):
    '''
    Compute and plot the pairwise RMSE between the trial trajectories
    returned by get_positions_array_beginning (one row per trial).

    For each pair of trials the RMSE is taken over the first L samples,
    L being the smaller of the two trials' counts of non-NaN samples.
    The matrix is made symmetrical, divided by maxTreadmillLength and drawn
    with plt.pcolor (colour scale 0 to 1).

    Returns
        np.nan alone when there are fewer than 3 trials (nothing is plotted),
        otherwise the tuple (median of the off-diagonal values, RMSE matrix).
    '''
    trajectories=get_positions_array_beginning(data,onlyGood=onlyGood,raw=raw)
    nTrials=trajectories.shape[0]

    if nTrials<3:  # Not enough trials
        return np.nan

    # number of valid (non-NaN) samples of every trial
    validLen=np.sum(np.logical_not(np.isnan(trajectories)),axis=1)

    pairRmse=-np.ones((nTrials,nTrials))
    np.fill_diagonal(pairRmse,0)
    for a in range(nTrials):
        for b in range(a+1,nTrials):
            common=min([validLen[a],validLen[b]])
            gap=trajectories[a,:common]-trajectories[b,:common]
            pairRmse[a,b]=np.sqrt(np.sum(gap**2)/common)

    upper=np.triu(pairRmse,k=0)
    RMSEmatrix=(upper+upper.T)/maxTreadmillLength  #symmetrical, normalised
    mesh=plt.pcolor(RMSEmatrix,vmin=0,vmax=1,cmap="Reds")
    plt.colorbar(mesh)
    plt.xlim([0,RMSEmatrix.shape[0]])
    plt.ylim([0,RMSEmatrix.shape[1]])

    #median over one triangle of the symmetrical matrix (diagonal excluded)
    offDiagonal=RMSEmatrix[np.tril_indices(RMSEmatrix.shape[0],-1)]
    med=np.nanmedian(offDiagonal)
    maxSecond=trajectories.shape[1]/float(data.cameraSamplingRate)  # not used further below

    #title of the plot
    title="Good Trials" if onlyGood else "All Trials"
    plt.title(title+', trajectory median  r= %.2f'%med)

    return med,RMSEmatrix

### Trajectory PDF

In [None]:
def plot_trajectory_PDF(data,TimeRes=.5,PosRes=5,onlyGood=False,**kargs):
    """
    Build, smooth, normalise and plot a time x position distribution of the
    trajectories returned by get_positions_array_beginning.

    TimeRes: time resolution in seconds
    PosRes: position resolution in cm
    kargs: extra keyword arguments forwarded to plt.pcolor

    Returns the normalised array (time bins along axis 0, position bins
    along axis 1); its sum is 1.
    """
    positions=get_positions_array_beginning(data,onlyGood).T  # rows: frames
    trialDuration=scipy.stats.mode(data.maxTrialDuration)[0]

    posSize=len(np.arange(data.treadmillRange[0],data.treadmillRange[1],PosRes))
    timeSize=len(np.arange(0,trialDuration,TimeRes))
    trajDis=np.zeros([timeSize,posSize])

    # positions -> position-bin numbers (NaN stays NaN and matches no bin)
    positions=positions//PosRes

    nFrames=positions.shape[0]-1  # the last frame is left out
    for frame in range(nFrames):
        timeIndex=int((frame/data.cameraSamplingRate)//TimeRes)
        framePositions=positions[frame,:]
        # each frame overwrites the row of the time bin it falls in
        trajDis[timeIndex,:]=[np.sum(framePositions==binNb) for binNb in range(posSize)]

    trajDis=scipy.ndimage.filters.gaussian_filter(trajDis, sigma=[1,1],
                                                  order=0, mode='nearest', truncate=3)
    #normalizing as a PDF
    trajDis/=np.sum(trajDis)

    plt.pcolor(trajDis.T, cmap=cm.hot,**kargs)
    ax=plt.gca()
    # 5 ticks along time, 10 along position, labelled in real units
    ax.set_xticks(np.linspace(0,timeSize,5))
    ax.set_xticklabels(np.linspace(0,trialDuration,5))
    ax.set_yticks(np.linspace(0,posSize,10))
    ax.set_yticklabels(np.linspace(data.treadmillRange[0],data.treadmillRange[1],10))

    return trajDis

def twoD_entropy(trajDist):
    """
    Entropy, in bits, of a 2-D distribution: -sum(p*log2(p)).

    trajDist: 2-D array of probabilities
    Cells that are zero or negative have no defined logarithm and are left
    out of the sum; a NaN cell makes the result NaN.

    Returns a float (0.0 when no cell contributes).
    """
    p=np.asarray(trajDist,dtype=float)
    # explicit filter instead of catching whatever math.log would raise;
    # written as "not <= 0" so that NaN cells are kept and propagate
    p=p[~(p<=0)]
    if p.size==0:
        return 0.0
    return float(-np.sum(p*np.log2(p)))

### read session files

In [None]:
def read_file(data,paramName,extension=".behav_param",exclude=None,valueType=str):
    '''
    Read one parameter, trial by trial, from the text file
    data.fullPath+extension (.behav_param or .entrancetimes).

    Every line containing "paramName" (and not containing "exclude", when
    given) is split by white spaces and its last element is kept:
    "treadmill speed:     30.00" gives "30.00".
    The value is converted to valueType:
        int, float: a comma is first replaced by a dot; int goes through
                    float ("0.00" -> 0.00 -> 0)
        bool:       True only for "true" (any case)
        otherwise:  valueType(value)
    Lines containing "Trial #" set the current trial (their last element
    minus 1); a value is stored at the index of the current trial.

    Return an array with one entry per trial, up to the last "Trial #" seen;
    trials without a value hold nan.
    When the file does not exist: print a message, set data.hasBehavior to
    False and return [].
    '''
    behav=data.fullPath+extension
    if not os.path.exists(behav):
        print("No file %s"%behav)
        data.hasBehavior=False
        return []

    def convert(token):
        #integer or float: replace comma by dots
        if valueType in [int,float]:
            token=token.replace(",",".")
        if valueType is int:
            return int(float(token))
        if valueType is bool:
            return token.lower()=="true"
        return valueType(token)

    currentTrial=0
    found=[]
    with open(behav,"r") as f:
        for line in f:
            if "Trial #" in line:
                currentTrial=int(float(line.split()[-1]))-1
            if paramName not in line:
                continue
            if exclude is not None and exclude in line:
                continue
            found.append( (currentTrial,convert(line.split()[-1])) )

    out=[np.nan]*(currentTrial+1)
    for trialIndex,value in found:
        out[trialIndex]=value
    return np.asarray(out)

### Calculate rate

In [None]:
def compute_average_rate(timePoints,minDis=0,maxDis=np.inf):
    """
    Average event rate: the inverse of the mean interval between
    consecutive events.

    timePoints: list of times of occurrence of events (in sec)
    minDis, maxDis: only intervals strictly between minDis and maxDis
        are taken into account
    """
    intervals=np.diff(timePoints)
    valid=(intervals>minDis)&(intervals<maxDis)
    return 1/np.nanmean(intervals[valid])

def compute_rate (x,winLen,overlap=0.5,zero=0,end=None):
    """
    Event rate in sliding windows.

    x: list like data with times of event, in sec
    winLen: length of window in sec
    overlap: normalized overlap between consecutive windows, strictly
        between 0 and 1
    zero: beginning of the time axis
    end: maximum of time axis; the last element of x when None

    Windows start at zero, zero+(1-overlap)*winLen, ... (below end); each
    window counts the events in [start, start+winLen).

    Returns (rates, starts): the count of each window divided by winLen,
    and the array of window start times.
    """
    assert overlap<1 and overlap>0, "bad overlap value"
    x=np.array(x)
    if end is None:
        end=x[-1]
    Range=np.arange(zero,end,(1-overlap)*winLen)
    rates=[np.count_nonzero((x>=start)&(x<start+winLen))/winLen for start in Range]
    return np.array(rates),Range

### Copy PRB file

In [None]:
def prb_copy (prbfile, animalFolder):
    """
    Put a copy of the probe file next to every .dat file found under
    animalFolder, named like the .dat file but with the 'prb' extension.

    prbfile='/home/david/Mostafa/info/prb-config files/8tetrode_8channelgroup.prb'
    animalFolder='/NETDATA/Rat172/Experiments/'

    The probe file is copied straight to its final name, so a file of the
    data folder that happens to have the same name as prbfile is left
    untouched.
    """
    for dat in find_file(animalFolder, extension=['.dat']):
        # assumes `copy` accepts a file path as destination (shutil.copy does) -- confirm
        copy(prbfile,dat[:-3]+'prb')


### Read a single channel from a _*.dat_ file

In [None]:
def read_ephy_epoch(filename, fs, Nch, wantedCh, t0, t1):
    """
    Read one channel of an epoch from a *.dat file of interleaved 16-bit
    integer samples (native byte order).

    filename: path of the .dat file
    fs: sampling frequency (Hz)
    Nch: number of channels stored in the file
    wantedCh: channel to read, counted from 1
    t0: beginning of the epoch in sec
    t1: end of the epoch in sec, or -1 to read up to the end of the file

    Times are rounded to the nearest sample, so the epoch always starts on
    the first channel of a sample. An epoch running past the end of the
    file is cut at the end of the file.

    Returns a 1-D integer array.
    Raises ValueError for a bad t0/t1 or wantedCh.
    """
    assert filename.endswith(('.dat','.DAT','.Dat')), "bad file type: Not *.DAT"

    if not 1<=wantedCh<=Nch:
        raise ValueError("wantedCh must be between 1 and Nch")

    if t1>t0:
        lastFrame=int(round(t1*fs))
    elif t1==-1:
        lastFrame=None  # up to the end of the file
    else:
        raise ValueError("t1 must be greater than t0, or -1")

    firstFrame=int(round(t0*fs))
    if firstFrame<0:
        raise ValueError("t0 must not be negative")

    sampleSize=np.dtype(np.int16).itemsize
    frameSize=Nch*sampleSize  # bytes of one sample over all channels
    framesInFile=os.path.getsize(filename)//frameSize
    lastFrame=framesInFile if lastFrame is None else min(lastFrame,framesInFile)

    nFrames=lastFrame-firstFrame
    if nFrames<=0:
        return np.array([],dtype=int)

    # map the epoch instead of reading it sample by sample; only the pages
    # touched by the wanted channel are actually read
    frames=np.memmap(filename,dtype=np.int16,mode='r',
                     offset=firstFrame*frameSize,shape=(nFrames,Nch))
    signal=np.array(frames[:,wantedCh-1],dtype=int)
    del frames  # release the mapping

    return signal

## Merging several .dat files together
(All files MUST have the same number of channels and the same sampling frequency)

In [None]:
def dat_merger(files: list, nCh: int):
    """
    Concatenate several .dat files, in the given order, into MERGED.dat
    saved in the folder of the first file.

    files: a list of all the dat file paths you wish to merge (a list of strings)
    nCh: number of channels (int)

    The files are copied byte by byte without loading them in memory.
    Each file must contain a whole number of samples over nCh channels;
    16-bit samples are assumed, as in read_ephy_epoch.

    Raises ValueError when the list is empty, when a file does not hold a
    whole number of samples, or when MERGED.dat itself is in the list.
    """
    import shutil

    if not files:
        raise ValueError("no files to merge")

    path=os.path.join(os.path.dirname(files[0]),'MERGED.dat')
    frameSize=nCh*np.dtype(np.int16).itemsize
    for file in files:
        if os.path.abspath(file)==os.path.abspath(path):
            raise ValueError(f'{file} is the output file and cannot be merged into itself')
        if os.path.getsize(file)%frameSize:
            raise ValueError(f'{file} does not contain a whole number of {nCh}-channel samples')

    with open(path,'wb') as merged:
        for file in files:
            with open(file,'rb') as part:
                shutil.copyfileobj(part,merged)
    print(f'saved in {path}')

# Data Fetcher

In [None]:
def data_fetch(root: str, animal: str, profile: dict, PerfParam: list, NbSession: slice =5):
    """
    returns the data requested by PerfParam, session by session

    PerfParam: a list (or a single item) of performance parameter names
        (strings) and/or functions receiving the data object as input.
        Items that are neither a string nor a plain function are ignored.
    NbSession: a slice applied to the session list, or an integer:
        positive N keeps the first N sessions, otherwise slice(N,None)
        is used (the last |N| sessions for a negative N).

    Returns a dict with one list per parameter name and per function name
    (function.__name__), holding one value per session.

    The first 6 characters of the session name are given to Data as its
    second argument (presumably the animal name -- confirm).
    """
    if not isinstance(PerfParam,list):
        PerfParam=[PerfParam]

    # split the request into functions and named parameters
    func=[item for item in PerfParam if isinstance(item,types.FunctionType)]
    perf=[item for item in PerfParam
          if isinstance(item,str) and not isinstance(item,types.FunctionType)]

    if not isinstance(NbSession,slice):
        NbSession=slice(NbSession) if NbSession >0 else slice(NbSession,None)

    sessions=batch_get_session_list(root,[animal],profile=profile)['Sessions'][NbSession]

    res={param:[] for param in perf}
    res.update({fun.__name__:[] for fun in func})
    for session in sessions:
        data=Data(root,session[:6],session,redoPreprocess=False)

        stats=compute_or_read_stats(data, perf,
                                    saveAsPickle=False, redo=False)
        for param in perf:
            res[param].append(stats[param])

        for fun in func:
            res[fun.__name__].append(fun(data))

    return res

## Find Times when trials end (Treadmill stops)

In [None]:
def int_(x):
    """
    Return int(x) when x can be converted to an integer, otherwise return x
    unchanged (e.g. None, nan, inf or a non-numeric string).
    """
    try:
        return int(x)
    except (TypeError, ValueError, OverflowError):
        # only the errors int() raises for values that cannot be converted
        return x

def punishment_duration(data,trial,minDuration=1,maxDuration=10):
    """
    Duration of the punishment of a trial.

    It decreases linearly from maxDuration as the entrance time goes from
    the "consider beam state after (s)" value (read from the behaviour file
    for this trial) to the goal time, and is never below minDuration.
    """
    beamIgnore=read_file(data,paramName="consider beam state after (s)",valueType=float)[trial]
    waited=data.entranceTime[trial]-beamIgnore
    required=data.goalTime[trial]-beamIgnore
    punishTime=maxDuration-maxDuration*waited/required
    return max((punishTime,minDuration))

def detect_trial_end(data, trials=None):
    """
    Fill data.timeEndTrial and data.indexEndTrial for the given trials
    (all of data.trials when trials is None) and return data.timeEndTrial.

    A trial that already has an end time is left as it is when it is a good
    trial or when its end time is at least 0.99 after its entrance time.
    Otherwise the end time is
        the entrance time, when it is later than the goal time;
        the entrance time plus the punishment duration, in any other case.
    The end index is the end time multiplied by the camera sampling rate,
    as an integer when possible.
    """
    if trials is None:
        trials=data.trials

    for trial in trials:
        knownEnd=data.timeEndTrial[trial]
        if knownEnd is not None and (trial in data.goodTrials
                                     or knownEnd >= data.entranceTime[trial]+.99):
            continue

        if data.entranceTime[trial] > data.goalTime[trial]:
            endTime=data.entranceTime[trial]
        else: #implementing the punishment rule
            endTime=data.entranceTime[trial]+punishment_duration(data,trial)

        data.timeEndTrial[trial]=endTime
        # read the stored value back, exactly as it is kept in timeEndTrial
        data.indexEndTrial[trial]=int_(data.timeEndTrial[trial]*data.cameraSamplingRate)

    return data.timeEndTrial

## Get Ordered colors

In [None]:
def get_colors(n, colormap='plasma'):
    """
    Return a list of n colors taken at evenly spaced positions, from 0 to 1,
    along the given colormap.

    n: number of colors
    colormap: name of a matplotlib colormap

    plt.get_cmap is used because plt.cm.get_cmap no longer exists in recent
    matplotlib versions.
    """
    cmap = plt.get_cmap(colormap)
    return [cmap(colorVal) for colorVal in np.linspace(0, 1, n)]