In [None]:
def isClock(sig,fs):
    """
    Tell whether `sig` looks like a binary clock (square-wave) signal.

    The signal is centered on its mean and binarized with np.sign. If the
    binarized version correlates with the original with Pearson r > 0.98
    (computed on at most the first 5 minutes), the signal is considered a clock.

    sig: 1-D array of samples
    fs: sampling rate in Hz (only used to size the 5-minute window)
    Returns a Python bool. A constant signal gives r = NaN and therefore False.
    """
    sig = np.asarray(sig)
    # A perfect clock is unchanged (up to scale and offset) by this binarization.
    binarized = np.sign(sig - np.mean(sig))
    # Only the first 5 minutes are needed; shorter signals are used in full.
    L = min(len(sig), int(5*60*fs))
    r = np.corrcoef(sig[:L], binarized[:L])[0, 1]
    return bool(r > 0.98)

def prm_reader(prmFile):
    """
    Run a parameter (.prm) file with IPython's `%run` and return the notebook's globals().

    While the file runs, the working directory is switched to the .prm file's
    directory, so relative paths inside it resolve there. The previous working
    directory is always restored in `finally`, even if the file raises.

    prmFile: path to the parameter file (executed as Python by %run).
    Returns: the dict returned by globals(); presumably the variables defined
        by the .prm file end up there through %run — TODO confirm.

    NOTE(review): `%run` is IPython syntax, so this only works inside IPython/Jupyter.
    NOTE(review): if prmFile has no directory part, os.path.dirname() returns ""
        and os.chdir("") raises; pass a path that includes a directory.
    NOTE(review): the file is executed as code; only use trusted .prm files.
    """
    CWD=os.getcwd()
    try:
        os.chdir(os.path.dirname(prmFile))
        prmName=os.path.basename(prmFile)
        %run $prmName
    finally:
        os.chdir(CWD)
    return globals()


## Find files by extension

In [None]:
def find_file(path, extension=('.raw.kwd',)):
    """
    Find all files under `path` (searched recursively) whose names end with one of the extensions.

    path: root directory; all its subdirectories, sub-subdirectories, etc. are searched.
    extension: a single extension string (ex: '.dat'; a whitespace-separated
        string such as '.dat .prm' is also accepted) or a list/tuple of
        desired extensions, ex: ['.dat', '.prm'].
    Returns a list of full file paths in os.walk order. Each file appears
    once, even if it matches several extensions.
    """
    if isinstance(extension, str):
        extension = extension.split()  # turn a single string into a list of extensions
    suffixes = tuple(extension)  # str.endswith accepts a tuple: one test per file
    return [os.path.join(dirPath, fileName)
            for dirPath, _, fileNames in os.walk(path)
            for fileName in fileNames
            if fileName.endswith(suffixes)]

## Safely working with files on the NAS

In [None]:
class safe_copy_from_nas:
    """
    Work on a local copy of a file that lives on a network share.

    A path counts as remote when it contains "NAS" or "NETDATA". For such a
    path, start() copies the file into a folder in the home directory and
    returns the local path. stop() uploads the results back next to the
    original file and deletes that folder. Non-remote paths are used in
    place, and neither method touches the file system for them.

    NOTE: all instances using the same tempName share one temp folder, so
    stop() on one instance also uploads and deletes files copied by the others.
    """

    def __init__(self,path):
        assert isinstance(path,str)
        self.filePath=path
        self.tempdir=None  # set by start() for remote files

    def _is_remote(self):
        """True if filePath looks like it lives on the NAS."""
        return "NAS" in self.filePath or "NETDATA" in self.filePath

    def start(self,tempName='.NASqwrytdvn2u1r'):
        """
        Safely copy the file to a temporary local directory.

        tempName: name of the temp folder created in the user's home directory.
        Returns the path to work on: the local copy; the original path for a
        non-remote file; or "" if the copy failed (the failure is logged).
        """
        if not self._is_remote():
            self.newPath=self.filePath
            return self.newPath

        logging.info("Copying from NAS... "+self.filePath)
        self.tempdir=os.path.join(os.path.expanduser("~"),tempName)
        # exist_ok: when several files are copied, the folder already exists
        os.makedirs(self.tempdir,exist_ok=True)
        try:
            self.newPath=copy(self.filePath,self.tempdir)
        except Exception as e:
            logging.warning("could not copy from NAS to local drive!"+self.filePath)
            logging.info(repr(e))
            self.newPath=""
        return self.newPath

    def stop(self,fileTypes=('.prm','.dat','.eeg')):
        """
        Upload every file with one of fileTypes found in the temp directory to
        the original file's directory, then remove the temp directory.

        Upload failures are logged and do not stop the remaining uploads.
        Nothing happens for a non-remote file, or if start() was never called.
        """
        if not self._is_remote() or self.tempdir is None:
            return
        try:
            destination=os.path.dirname(self.filePath)
            for newfile in find_file(self.tempdir,fileTypes):
                try:
                    copy(newfile,destination)
                    logging.info("Uploaded to NAS! "+newfile)
                except Exception as e:  # exceptions should not be raised at this level!
                    logging.warning("upload to NAS failed! "+newfile)
                    logging.info(repr(e))
        finally:
            # remove everything, ignoring errors
            rmtree(self.tempdir,ignore_errors=True)


### Permutation Hypothesis Testing

In [None]:
class TwoTailPermTest:
    """
    Two-tailed permutation test of whether there is a significant difference between group one and two.

    group1, group2: the data. Either numpy arrays that are one dimensional
        (several realizations) or 2-D (several realizations throughout the
        time/space/... course), EX: x.shape==(15,500) means 15 trials/samples
        over 500 time bins; or dicts {subgroup: realization}, where all
        realizations have the same length and are stacked row-wise into a 2-D array.

    nIterations: number of shuffles. max(nIterations)=(len(x)+len(y))!/(len(x)!len(y)!)

    initGlobConfInterval:
        Initial value (in percent) for the global confidence band. It is
        shrunk by 5% per step (widening the band) until fewer than 5% of the
        surrogate curves break the band.

    smoothSigma: standard deviation of the gaussian kernel used for smoothing
        when there are multiple data points, based on the Fujisawa 2008 paper,
        default value: 0.05

    randomSeed: seed of the private random generator used for shuffling.
        numpy's global random state is not modified.

    Outputs:
        pVal: two-tailed p-values, one per bin
        highBand, lowBand: AKA boundary. Global bands (None for a single-point comparison).
        pairwiseHighBand, pairwiseLowBand: per-bin (local) bands, by default the
            99.5th / 0.5th percentiles of the surrogate differences.
        significantDiff: an array of True or False, indicating whether there is a difference.
    """
    def __init__(self, group1, group2, nIterations=1000, initGlobConfInterval=5, smoothSigma=0.05, randomSeed=1):
        self.__group1, self.__group2 = self.__setGroupData(group1), self.__setGroupData(group2)
        self.__nIterations, self.__smoothSigma = nIterations, smoothSigma
        self.__initGlobConfInterval = initGlobConfInterval
        self.__randomSeed = randomSeed

        # Keep the (dict-converted) data as given. __checkGroups turns 1-D vectors
        # into columns, but D0 of a 1-D comparison must stay a scalar so that the
        # single-point branches below are taken.
        rawGroup1, rawGroup2 = self.__group1, self.__group2
        self.__checkGroups()

        # origGroupDiff is also known as D0 in the definition of the permutation test.
        self.__origGroupDiff = self.__computeGroupDiff(rawGroup1, rawGroup2)

        # Generate surrogate groups, compute the difference of means for each, and put them in a matrix.
        self.__diffSurGroups = self.__setDiffSurrGroups()

        # Set statistics (the pairwise bands must exist before significantDiff is computed)
        self.pVal = self.__setPVal()
        self.highBand, self.lowBand = self.__setBands()
        self.pairwiseHighBand = self.__setPairwiseHighBand()
        self.pairwiseLowBand = self.__setPairwiseLowBand()
        self.significantDiff = self.__setSignificantGroup()

    def __setGroupData(self, groupData):
        """Return non-dict data unchanged; stack a dict's realizations row-wise into a 2-D float array."""
        if not isinstance(groupData, dict):
            return groupData

        realizations = list(groupData.values())
        nBins = len(realizations[0])
        if any(len(realization) != nBins for realization in realizations):
            raise ValueError("The length of all realizations in the group dictionary must be the same")

        dataMat = np.zeros((len(realizations), nBins))
        for index, realization in enumerate(realizations):
            dataMat[index] = realization
        return dataMat

    def __checkGroups(self):
        """Validate the groups and reshape 1-D vectors into single-column matrices."""
        if not isinstance(self.__group1, np.ndarray) or not isinstance(self.__group2, np.ndarray):
            raise ValueError("In permutation test, \"group1\" and \"group2\" should be numpy arrays.")

        if self.__group1.ndim > 2 or self.__group2.ndim > 2:
            raise ValueError('In permutation test, the groups must be either vectors or matrices.')

        elif self.__group1.ndim == 1 or self.__group2.ndim == 1:
            self.__group1 = np.reshape(self.__group1, (len(self.__group1), 1))
            self.__group2 = np.reshape(self.__group2, (len(self.__group2), 1))

    def __computeGroupDiff(self, group1, group2):
        """Difference of the means along axis 0 (NaNs ignored), smoothed unless the groups have a single column."""
        meanDiff = np.nanmean(group1, axis=0) - np.nanmean(group2, axis=0)

        if len(self.__group1[0]) == 1 and len(self.__group2[0]) == 1:
            return meanDiff

        return smooth(meanDiff, sigma=self.__smoothSigma)

    def __setDiffSurrGroups(self):
        """Return an (nIterations, nBins) matrix of mean differences between shuffled surrogate groups."""
        # Private generator: same permutations as seeding numpy's legacy global
        # state, without resetting the caller's global random state.
        rng = np.random.RandomState(seed=self.__randomSeed)
        nGroup1 = self.__group1.shape[0]
        self.__concatenatedData = np.concatenate((self.__group1, self.__group2), axis=0)

        diffSurrGroups = np.zeros((self.__nIterations, self.__group1.shape[1]))
        for iteration in range(self.__nIterations):
            # Shuffle the rows (realizations) in place; bins stay aligned within each row.
            rng.shuffle(self.__concatenatedData)

            # Surrogate groups of the same sizes as the original groups.
            surrGroup1 = self.__concatenatedData[:nGroup1, :]
            surrGroup2 = self.__concatenatedData[nGroup1:, :]

            diffSurrGroups[iteration, :] = self.__computeGroupDiff(surrGroup1, surrGroup2)

        return diffSurrGroups

    def __setPVal(self):
        """Two-tailed p-value per bin: 2 * min(P(surrogate > D0), P(surrogate < D0)), capped at 1."""
        positivePVals = np.sum(self.__diffSurGroups > self.__origGroupDiff, axis=0) / self.__nIterations
        negativePVals = np.sum(self.__diffSurGroups < self.__origGroupDiff, axis=0) / self.__nIterations
        return np.minimum(1, 2 * np.minimum(positivePVals, negativePVals))

    def __setBands(self):
        """
        Global bands: percentiles over all surrogate differences (all bins pooled),
        widened until fewer than 5% of the surrogate curves break them.
        Returns (None, None) for a single-point comparison.
        """
        if not isinstance(self.__origGroupDiff, np.ndarray):  # single point comparison
            return None, None

        alpha = 100  # global alpha value, in percent
        highGlobCI = self.__initGlobConfInterval  # global confidence interval
        lowGlobCI = self.__initGlobConfInterval  # global confidence interval
        while alpha >= 5:
            highBand = np.percentile(a=self.__diffSurGroups, q=100-highGlobCI)
            lowBand = np.percentile(a=self.__diffSurGroups, q=lowGlobCI)

            # A surrogate breaks a band when more than one of its bins lies outside it.
            # NOTE(review): a single out-of-band bin does not count as a break; confirm this is intended.
            breaksPositive = np.sum(np.sum(self.__diffSurGroups > highBand, axis=1) > 1)
            breaksNegative = np.sum(np.sum(self.__diffSurGroups < lowBand, axis=1) > 1)

            alpha = ((breaksPositive + breaksNegative) / self.__nIterations) * 100
            highGlobCI = 0.95 * highGlobCI
            lowGlobCI = 0.95 * lowGlobCI
        return highBand, lowBand

    def __setSignificantGroup(self):
        """
        Bins where D0 leaves the global band, extended to contiguous neighbours
        where D0 leaves the pairwise (local) band.
        Single-point comparison: pVal <= 0.05.
        """
        if not isinstance(self.__origGroupDiff, np.ndarray):  # single point comparison
            return self.pVal <= 0.05

        # finding significant bins
        globalSig = np.logical_or(self.__origGroupDiff > self.highBand, self.__origGroupDiff < self.lowBand)
        pairwiseSig = np.logical_or(self.__origGroupDiff > self.pairwiseHighBand,
                                    self.__origGroupDiff < self.pairwiseLowBand)

        significantGroup = globalSig.copy()
        lastIndex = 0
        for currentIndex in range(len(pairwiseSig)):
            if globalSig[currentIndex]:
                lastIndex = self.__setNeighborsToTrue(significantGroup, pairwiseSig, currentIndex, lastIndex)

        return significantGroup

    def __setPairwiseHighBand(self, localBandValue=0.5):
        """Per-bin upper band: the (100 - localBandValue)th percentile of the surrogate differences."""
        return np.percentile(a=self.__diffSurGroups, q=100 - localBandValue, axis=0)

    def __setPairwiseLowBand(self, localBandValue=0.5):
        """Per-bin lower band: the localBandValue-th percentile of the surrogate differences."""
        return np.percentile(a=self.__diffSurGroups, q=localBandValue, axis=0)

    def __setNeighborsToTrue(self, significantGroup, pairwiseSig, currentIndex, previousIndex):
        """
        While the neighbors of a global point pass the local band (consecutively), set them to True.
        Walks backwards down to (but excluding) previousIndex, then forwards.
        Returns the index where the forward walk stopped.
        """
        if currentIndex < previousIndex:
            return previousIndex

        for index in range(currentIndex, previousIndex, -1):
            if pairwiseSig[index]:
                significantGroup[index] = True
            else:
                break

        previousIndex = currentIndex
        for index in range(currentIndex + 1, len(significantGroup)):
            previousIndex = index
            if pairwiseSig[index]:
                significantGroup[index] = True
            else:
                break

        return previousIndex

    def plotSignificant(self, ax: "plt.Axes", y: float, x=None, **kwargs):
        """
        Draw a horizontal segment at height y from x[i] to x[i+1] for every significant bin i.

        x defaults to 1..len(significantDiff); kwargs are forwarded to ax.plot.
        """
        if x is None:
            x = np.arange(0, len(self.significantDiff)) + 1
        for x0, x1, p in zip(x[:-1], x[1:], self.significantDiff):
            if p:
                ax.plot([x0, x1], [y, y], zorder=-2, **kwargs)

    @staticmethod
    def plotSigPair(ax: "plt.Axes", y: float, x=None, s: str = '*', **kwargs):
        """
        Draw a significance bracket from x[0] to x[1] at height y, labelled with `s`.

        ax: matplotlib axes to draw on
        y: height of the bracket; its legs go down by 3% of the y-axis range
        x: (x0, x1), the horizontal extent of the bracket (required)
        s: label written at the middle of the bracket
        kwargs: forwarded to ax.plot; color defaults to black ('k')

        Raises ValueError if x is not given: a staticmethod has no test
        instance from which a default extent could be taken.
        """
        if x is None:
            raise ValueError("plotSigPair needs x=(x0, x1): the horizontal extent of the bracket")
        kwargs.setdefault('color', 'k')

        yMin, yMax = ax.get_ylim()
        dy = .03 * (yMax - yMin)
        ax.plot(x, [y, y], **kwargs)
        ax.plot([x[0], x[0]], [y-dy, y], [x[1], x[1]], [y-dy, y], **kwargs)
        ax.text(np.mean(x), y, s=s,
                ha='center', va='center', color=kwargs['color'],
                size='xx-small', fontstyle='italic', backgroundcolor='w')

### Bootstrap
Code adapted from astroML's [resample.py](https://github.com/astroML/astroML/blob/master/astroML/resample.py).

In [None]:
from sklearn.utils import check_random_state

def bootstrap(data, n_bootstraps, user_statistic=lambda x:np.mean(x,axis=1), kwargs=None,
              pass_indices=False, random_state=1):
    """Compute bootstrapped statistics of a dataset.
    Parameters
    ----------
    data : array_like
        An n-dimensional data array of size n_samples by n_attributes
    n_bootstraps : integer
        the number of bootstrap samples to compute.  Note that internally,
        two arrays of size (n_bootstraps, n_samples) will be allocated.
        For very large numbers of bootstraps, this can cause memory issues.
    user_statistic : function
        The statistic to be computed.  This should take an array of data
        of size (n_bootstraps, n_samples) and return the row-wise statistics
        of the data. Default: mean along axis 1.
    kwargs : dictionary (optional)
        A dictionary of keyword arguments to be passed to the
        user_statistic function.
    pass_indices : boolean (optional)
        if True, then the indices of the points rather than the points
        themselves are passed to `user_statistic`
    random_state: RandomState or an int seed (1 by default)
        A random number generator instance
    Returns
    -------
    distribution : ndarray
        the bootstrapped distribution of statistics (length = n_bootstraps
        for 1-D data)
    """
    # we don't set kwargs={} by default in the argument list, because using
    # a mutable type as a default argument can lead to strange results
    if kwargs is None:
        kwargs = {}

    rng = check_random_state(random_state)
    data = np.asarray(data)
    if data.ndim != 1:
        n_samples = data.shape[0]
        logging.warning("bootstrap data are n-dimensional: assuming ordered n_samples by n_attributes")
    else:
        n_samples = data.size

    # Generate random indices with repetition
    ind = rng.randint(n_samples, size=(n_bootstraps, n_samples))
    if pass_indices:
        # the resampled data are not needed in this mode
        return user_statistic(ind, **kwargs)

    # Index once. 1-D data keeps shape (n_bootstraps, n_samples). n-D data of
    # shape (n_samples, ..., last) is flattened to (-1, last).
    # NOTE(review): for n-D data the default statistic then averages over the
    # last axis of the flattened rows, not per bootstrap sample; confirm intent.
    resampled = data[ind]
    resampled = resampled.reshape(-1, resampled.shape[-1])
    return user_statistic(resampled, **kwargs)

### Sample size control by random selection

In [None]:
class sample_size_control:
    """
    Apply `func` to many random subsets of animals (a sample-size control).

    Each subset holds NbAnimal distinct animals. Animals picked less often so
    far are more likely to be picked (see random_animal_subset). The list of
    func's return values is stored in `self.Results`.
    """
    def __init__(self,func,animalList,NbAnimal,n,**kwargs):
        """
        func: function to be applied to the randomly chosen animals; it is
            called as func(animalList=<subset>, **kwargs)
        animalList: the animals to draw from
        NbAnimal: number of animals to be considered in each iteration of func
        n: number of iterations, i.e. number of calls to func
        kwargs: any other input necessary to run "func", given like normal keyword arguments

        Raises ValueError if NbAnimal is larger than len(animalList).
        """
        if NbAnimal>len(animalList):
            raise ValueError("NbAnimal must be smaller than animal list")

        self.iterN=n
        self.animalList=animalList
        self.func=func
        self.subsetSize=NbAnimal
        self.kwargs=kwargs

        # 1 + number of times each animal has been drawn (kept >= 1 so 1/x is finite)
        self.animalRepeat=np.ones(len(self.animalList))

        self.Results=self.run_function()

    def random_animal_subset(self):
        """
        Draw subsetSize distinct animals, favouring those drawn least so far.

        The selection probability is proportional to 1/animalRepeat, and
        animalRepeat is incremented for every animal drawn.
        """
        prob=np.sum(self.animalRepeat)*(1/self.animalRepeat)
        prob=prob/np.sum(prob)
        animalListSubset=np.random.choice(a=self.animalList,size=self.subsetSize,replace=False,p=prob)
        self.animalRepeat+=[animal in animalListSubset for animal in self.animalList]

        return animalListSubset

    def run_function(self):
        """
        Call func iterN times, each on a fresh random subset, and return the list of results.

        The duration of the first (up to) two calls is used to log an estimate
        of the total run time. self.kwargs is not modified.
        """
        args=dict(self.kwargs)  # copy: the subset is injected under 'animalList' for each call
        result=[]

        # time the first calls to estimate the total run time
        t0=time.perf_counter()
        while len(result)<min(2,self.iterN):
            args['animalList']=self.random_animal_subset()
            result.append(self.func(**args))
        if result:
            t_elapsed=(time.perf_counter()-t0)/len(result)
            logging.info("Estimated time to run sample size control<"+str(t_elapsed*self.iterN)+'s')

        while len(result)<self.iterN:
            args['animalList']=self.random_animal_subset()
            result.append(self.func(**args))

        return result

### Trajectory PDF

In [None]:
def plot_trajectory_PDF(data,TimeRes=.5,PosRes=5,onlyGood=False,**kargs):
    """
    Calculate and plot the joint (time, position) PDF of the trajectories.

    data: session object. This function reads data.maxTrialDuration (its mode
        sets the length of the time axis), data.treadmillRange,
        data.cameraSamplingRate, and passes data to get_positions_array_beginning.
    TimeRes: time resolution in seconds
    PosRes: position resolution in cm
    onlyGood: forwarded to get_positions_array_beginning
    kargs: extra keyword arguments forwarded to plt.pcolor
    Returns the smoothed histogram normalized to sum 1, shape (timeBins, posBins).
    """
    # assumed (frames, trials) after the transpose — TODO confirm the layout
    # returned by get_positions_array_beginning
    allTraj=get_positions_array_beginning(data,onlyGood).T
    trialDuration=scipy.stats.mode(data.maxTrialDuration)[0]

    posSize =len(np.arange(data.treadmillRange[0],data.treadmillRange[1],PosRes))
    timeSize=len(np.arange(0,trialDuration,TimeRes))
    trajDis=np.zeros([timeSize,posSize])

    # position bin of every sample; NaN positions stay NaN and are never counted
    allTraj=allTraj//PosRes

    for t,frame in enumerate(allTraj):
        timeIndex=int((t/data.cameraSamplingRate)//TimeRes)
        if timeIndex>=timeSize:
            break  # frames beyond the modal trial duration have no time bin
        inRange=(frame>=0)&(frame<posSize)  # NaN compares False
        # accumulate: several camera frames fall into the same time bin
        trajDis[timeIndex,:]+=np.bincount(frame[inRange].astype(int),minlength=posSize)

    trajDis=scipy.ndimage.gaussian_filter(trajDis, sigma=[1,1],
                                          order=0, mode='nearest', truncate=3)
    # normalizing as a PDF
    trajDis/=np.sum(trajDis)

    plt.pcolor(trajDis.T, cmap=cm.hot,**kargs)
    ax=plt.gca()
    ax.set_xticks     (np.linspace(0,timeSize,5))
    ax.set_xticklabels(np.linspace(0,trialDuration,5))
    ax.set_yticks     (np.linspace(0,posSize,10))
    ax.set_yticklabels(np.linspace(data.treadmillRange[0],data.treadmillRange[1],10))

    return trajDis

def twoD_entropy(trajDist):
    """
    Shannon entropy, in bits, of a discrete distribution given as an array.

    H = -sum(p * log2(p)) over the entries. Zero and negative entries are
    skipped (p*log p -> 0 as p -> 0). A NaN entry makes the result NaN.
    Works for arrays of any shape; a 2-D array is the typical input.
    """
    p=np.asarray(trajDist,dtype=float)
    # keep positive entries (and NaNs), i.e. exactly the entries log() accepts
    p=p[~(p<=0)]
    return float(-np.sum(p*np.log2(p)))

### Read session files

In [None]:
def read_file(data,paramName,extension=".behav_param",exclude=None,valueType=str):
    '''
    Read the per-trial values of one parameter from a session text file
    (e.g. .behav_param or .entrancetimes) located at data.fullPath + extension.

    A line containing "Trial #" sets the current trial (its last token - 1).
    Every line containing `paramName` and not containing `exclude` gives a
    value: its last whitespace-separated token, converted to valueType.
    Example: "treadmill speed:     30.00" -> "30.00" -> 30.0 for valueType=float.
    For int/float, decimal commas are accepted ("30,00"). For int the token
    goes through float first ("0.00" -> 0). For bool, "TRUE"/"true" -> True.

    Returns an array indexed by trial (NaN for trials without a value; a
    later value for the same trial overwrites an earlier one). If the file
    does not exist, prints a message, sets data.hasBehavior = False and returns [].
    '''
    path=data.fullPath+extension
    if not os.path.exists(path):
        print("No file %s"%path)
        data.hasBehavior=False
        return []

    def to_value(token):
        if valueType in [int,float]:
            token=token.replace(",",".")
        if valueType is int:
            return int(float(token))
        if valueType is bool:
            return token.lower()=="true"
        return valueType(token)

    currentTrial=0
    found=[]
    with open(path,"r") as f:
        for line in f:
            if "Trial #" in line:
                currentTrial=int(float(line.split()[-1]))-1
            if paramName not in line:
                continue
            if exclude is not None and exclude in line:
                continue
            found.append((currentTrial,to_value(line.split()[-1])))

    values=[np.nan]*(currentTrial+1)
    for trial,value in found:
        values[trial]=value
    return np.asarray(values)

### Calculate rate

In [None]:
def compute_average_rate(timePoints,minDis=0,maxDis=np.inf):
    """
    Average event rate: the inverse of the mean interval between consecutive events.

    timePoints: list of event times (in sec)
    minDis: an interval is used only if it is strictly longer than this
    maxDis: an interval is used only if it is strictly shorter than this
    Returns 1/mean(valid intervals); NaN (with a numpy warning) if no interval is valid.
    """
    intervals=np.diff(timePoints)
    keep=(intervals>minDis)&(intervals<maxDis)
    meanInterval=np.nanmean(intervals[keep])
    return 1/meanInterval

def compute_rate (x,winLen,overlap=0.5,zero=0,end=None):
    """
    Event rate in sliding windows.

    x: list-like of event times, in sec (need not be sorted; NaN times are never counted)
    winLen: length of each window, in sec
    overlap: normalized overlap between consecutive windows, strictly in (0,1);
        windows start every (1-overlap)*winLen seconds
    zero: beginning of the time axis (start of the first window)
    end: windows start before this time. Defaults to the latest event, or to
        `zero` (no windows) when x is empty.

    Returns (rates, starts): rates[i] is the number of events in
    [starts[i], starts[i]+winLen) divided by winLen.
    Raises ValueError if overlap is not in (0,1).
    """
    if not 0<overlap<1:
        raise ValueError("bad overlap value")
    times=np.sort(np.asarray(x,dtype=float))
    if end is None:
        end=np.nanmax(times) if times.size else zero
    Range=np.arange(zero,end,(1-overlap)*winLen)
    # number of events in [start, start+winLen): two binary searches per window
    counts=(np.searchsorted(times,Range+winLen,side='left')
            -np.searchsorted(times,Range,side='left'))
    return counts/winLen,Range

# Data Fetcher

In [None]:
def data_fetch(root: str, animal: str, profile: dict, PerfParam: list, NbSession: slice =5):
    """
    Collect performance values for the sessions of one animal that match a profile.

    root: passed through to batch_get_session_list and Data
    animal: the animal whose sessions are listed
    profile: session profile used to select the sessions
    PerfParam: a known performance-parameter name (str), a function receiving
        a Data object, or a list mixing both. Items of any other type are ignored.
    NbSession: which of the listed sessions to use. Either a slice, or an int
        n: n > 0 keeps the first n sessions, n < 0 keeps the last |n|, and
        0 keeps all of them.
    Returns a dict mapping each parameter name / function __name__ to a list
    with one value per session.
    """
    params = PerfParam if isinstance(PerfParam, list) else [PerfParam]
    functions = [item for item in params if isinstance(item, types.FunctionType)]
    names = [item for item in params if isinstance(item, str)]

    if not isinstance(NbSession, slice):
        NbSession = slice(NbSession) if NbSession > 0 else slice(NbSession, None)

    sessions = batch_get_session_list(root, [animal], profile=profile)['Sessions'][NbSession]

    res = {name: [] for name in names}
    res.update({fun.__name__: [] for fun in functions})
    for session in sessions:
        # session[:6] is passed as Data's second argument (presumably the animal name)
        data = Data(root, session[:6], session, redoPreprocess=False)

        stats = compute_or_read_stats(data, names,
                                      saveAsPickle=False, redo=False)
        for name in names:
            res[name].append(stats[name])

        for fun in functions:
            res[fun.__name__].append(fun(data))

    return res

## Find the times when trials end (treadmill stops)

In [None]:
def int_(x):
    """
    Convert `x` to int when possible; otherwise return it unchanged.

    None, NaN and infinities pass through untouched: int() raises TypeError,
    ValueError or OverflowError for them. Other errors are no longer swallowed.
    """
    try:
        return int(x)
    except (TypeError, ValueError, OverflowError):
        return x

def punishment_duration(data,trial,minDuration=1,maxDuration=10):
    """
    Punishment duration for `trial`, in the same time units as data.entranceTime
    (presumably seconds, given the "(s)" parameter name).

    Linear in the entrance time: maxDuration when the animal enters at the
    beam-ignore time, 0 when it enters at goal time. The result is clipped
    below at minDuration.

    The beam-ignore time ("consider beam state after (s)") is re-read from the
    .behav_param file with read_file on every call.
    """
    beamIgnore = read_file(data, paramName="consider beam state after (s)", valueType=float)[trial]
    elapsed = data.entranceTime[trial] - beamIgnore
    window = data.goalTime[trial] - beamIgnore
    return max(maxDuration - maxDuration * elapsed / window, minDuration)

def detect_trial_end(data, trials=None):
    """
    Fill in data.timeEndTrial / data.indexEndTrial for `trials` (default: data.trials).

    A trial is left untouched when its end time is already set (not None) and
    it is either in data.goodTrials or ends at least 0.99 after its entrance time.
    Otherwise:
      - entranceTime > goalTime: the trial ends at the entrance time;
      - else: it ends at entranceTime + punishment_duration(data, trial).
    indexEndTrial is end time * data.cameraSamplingRate, converted with int_
    (values int() cannot convert are stored unchanged).

    Returns data.timeEndTrial.
    """
    for trial in (data.trials if trials is None else trials):
        knownEnd = data.timeEndTrial[trial]
        if knownEnd is not None and (trial in data.goodTrials or knownEnd >= data.entranceTime[trial] + .99):
            continue

        entrance = data.entranceTime[trial]
        if entrance > data.goalTime[trial]:
            data.timeEndTrial[trial] = entrance
        else:  # implementing the punishment rule
            data.timeEndTrial[trial] = entrance + punishment_duration(data, trial)
        data.indexEndTrial[trial] = int_(data.timeEndTrial[trial] * data.cameraSamplingRate)

    return data.timeEndTrial