# Part 0:
## import everything
Run the cell below

In [None]:
# -*- coding: utf-8 -*-
import os
import glob
import numpy as np
from platform import system as OS
import pandas as pd
import scipy.stats
from scipy.ndimage.filters import gaussian_filter1d as smooth
import math
import datetime
from copy import deepcopy
import matplotlib.cm as cm
import warnings
warnings.filterwarnings("ignore")
import types
import inspect
import string
import sys, time
import pickle
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import mlab
from scipy import stats
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.animation as animation
import matplotlib.backends.backend_pdf
from sklearn.decomposition import KernelPCA
import mpl_toolkits.axes_grid1.inset_locator as inset
from matplotlib.ticker import FormatStrFormatter, MaxNLocator
import imageio
from set_rc_params import set_rc_params
import ROOT


# Notebook-only setup. "__file__" is not defined in a Jupyter kernel, so this
# branch runs when the code is executed as notebook cells. The IPython magics
# below (%matplotlib, %config, %run) are only valid inside IPython.
if "__file__" not in dir():
    %matplotlib inline
    # figures are not closed automatically at the end of each cell
    %config InlineBackend.close_figures = False
    matplotlib.rcdefaults()
    
    root=ROOT.root
    
    # realpath of the literal string "__file__" is resolved against the current
    # working directory, so ThisNoteBookPath is effectively the (real) cwd.
    ThisNoteBookPath=os.path.dirname(os.path.realpath("__file__"))
    # helper notebooks live in the sibling folder "load_preprocess_rat"
    CommonNoteBookesPath=os.path.join(os.path.split(ThisNoteBookPath)[0],"load_preprocess_rat")
    CWD=os.getcwd()
    # Run the helper notebooks from their own folder, then come back.
    # NOTE(review): they presumably define the helpers used further below
    # (data_fetch, Data, get_rat_group_statistic, batch_get_session_list, ...);
    # this is not visible from this file.
    os.chdir(CommonNoteBookesPath)
    %run UtilityTools.ipynb
    %run Animal_Tags.ipynb
    %run loadRat_documentation.ipynb
    %run plotRat_documentation_1_GeneralBehavior.ipynb
    %run plotRat_documentation_3_KinematicsInvestigation.ipynb
    %run RunBatchRat_3_CompareGroups.ipynb
    %run BatchRatBehavior.ipynb
    currentNbPath=os.path.join(os.path.split(ThisNoteBookPath)[0],'BehavioralPaper','NToTrd.ipynb')
    %run $currentNbPath
    os.chdir(CWD)

    # NOTE(review): `logging` is not imported above; presumably it is provided
    # by one of the %run notebooks.
    logging.getLogger().setLevel(logging.ERROR)
    
    # Preprocessing/analysis parameters handed to the data helpers below
    # (e.g. `parameter=param` in get_rat_group_statistic, `param=param` in Data).
    param={
        "goalTime":7,#needed for pavel data only
        "treadmillRange":[0,90],#pavel error conversion "treadmillRange":[0,80]
        "maxTrialDuration":15,
        "interTrialDuration":10,#None pavel
        "endTrial_frontPos":30,
        "endTrial_backPos":55, 
        "endTrial_minTimeSec":4,
        "cameraSamplingRate":25, #needed for new setup    

        "sigmaSmoothPosition":0.1,#0.33, 0.18 pavel
        "sigmaSmoothSpeed":0.3,#0.3, 0.5 pavel
        "nbJumpMax":100,#200 pavel
        "binSize":0.25,
        #parameters used to preprocess (will override the default parameters)
           }  
    # global treadmill bounds, used by several plotting functions below
    Y1,Y2=param['treadmillRange']

    print('os:',OS(),'\nroot:',root,'\nImport successful!')

---
---


# part 1:

# DEFINITIONS

### If you don't know what to do, move to part 2

In [None]:
def get_ordered_colors(colormap, n):
    """
    Return `n` RGBA colors sampled evenly (positions 0 to 1) along a colormap.

    colormap: name of a registered matplotlib colormap, or a Colormap instance.
    n: number of colors to return.
    Returns a list of RGBA tuples, ordered from the low end to the high end of the map.
    """
    if isinstance(colormap, matplotlib.colors.Colormap):
        cmap = colormap
    elif hasattr(matplotlib, 'colormaps'):
        # Colormap registry (matplotlib >= 3.5); matplotlib.cm.get_cmap, which
        # plt.cm.get_cmap refers to, was removed in matplotlib 3.9.
        cmap = matplotlib.colormaps[colormap]
    else:
        cmap = plt.cm.get_cmap(colormap)
    return [cmap(position) for position in np.linspace(0, 1, n)]

In [None]:
def add_panel_caption(axes: tuple, offsetX: tuple, offsetY: tuple, **kwargs):
    """
    Write panel letters ('a', 'b', 'c', ...) near the top-left corner of each Axes.

    The k-th Axes in `axes` receives the k-th lowercase letter. The letter is
    placed at the Axes' left edge moved left by offsetX[k] and at its top edge
    moved up by offsetY[k]; offsets are fractions of the figure size (figure
    coordinates). Extra keyword arguments are forwarded to Axes.text.
    """
    assert len(axes)==len(offsetX)==len(offsetY), 'Bad input!'

    figure = axes[0].get_figure()
    figureBox = figure.bbox
    for letter, (panel, shiftLeft, shiftUp) in zip(string.ascii_lowercase,
                                                   zip(axes, offsetX, offsetY)):
        panelBox = panel.get_window_extent()
        # convert the panel corner from display pixels to figure fractions
        left = panelBox.x0 / figureBox.xmax - shiftLeft
        top = panelBox.y1 / figureBox.ymax + shiftUp
        panel.text(x=left, y=top, s=letter,
                   fontweight='extra bold', fontsize=10, ha='left', va='center',
                   transform=figure.transFigure, **kwargs)


In [None]:
class TwoTailPermTest:
    """
    Two-tailed permutation test of whether group one and group two differ in their mean.

    group1, group2: the data. Each one is either
        - a 1-D numpy array: several realizations of a single value (single-point comparison),
        - a 2-D numpy array of shape (nRealizations, nBins): realizations through
          the time/space/... course. EX: x.shape==(15,500) means 15 trials/samples
          over 500 time bins, or
        - a dict {name: 1-D realization}; all realizations must have the same
          length and become the rows of a 2-D array.
        Both groups must have the same number of bins (columns).

    nIterations: number of shuffles. max(nIterations)=(len(x)+len(y))!/len(x)!len(y)!

    initGlobConfInterval: initial value (percent) for the global confidence band.

    smoothSigma: standard deviation of the gaussian kernel used to smooth the difference
        of means when there are multiple bins, based on the Fujisawa 2008 paper,
        default value: 0.05

    randomSeed: seed of the private random generator used for the shuffles
        (the global numpy random state is left untouched).

    Outputs (attributes):
        pVal: two-tailed p-value per bin.
        highBand, lowBand: AKA boundary, the global bands (None for a single-point comparison).
        pairwiseHighBand, pairwiseLowBand: per-bin (local) bands.
        significantDiff: boolean array, True where there is a significant difference.
    """
    def __init__(self, group1, group2, nIterations=1000, initGlobConfInterval=5, smoothSigma=0.05, randomSeed=1):
        self.__group1, self.__group2 = self.__setGroupData(group1), self.__setGroupData(group2)
        self.__nIterations, self.__smoothSigma = nIterations, smoothSigma
        self.__initGlobConfInterval = initGlobConfInterval
        self.__randomSeed = randomSeed

        # Two vectors mean a single-point comparison; remember it before
        # __checkGroups turns them into single-column matrices.
        isSinglePoint = np.ndim(self.__group1) == 1 and np.ndim(self.__group2) == 1
        self.__checkGroups()

        # origGroupDiff is also known as D0 in the definition of permutation test.
        # It is computed from the validated arrays (not the raw arguments) so
        # that dict inputs work too.
        self.__origGroupDiff = self.__computeGroupDiff(self.__group1, self.__group2)
        if isSinglePoint:
            # a scalar D0 selects the single-point branches below
            self.__origGroupDiff = self.__origGroupDiff[0]

        # Generate surrogate groups, compute difference of mean for each group, and put in a matrix.
        self.__diffSurGroups = self.__setDiffSurrGroups()

        # Set statistics
        self.pVal = self.__setPVal()
        self.highBand, self.lowBand = self.__setBands()
        self.pairwiseHighBand = self.__setPairwiseHighBand()
        self.pairwiseLowBand = self.__setPairwiseLowBand()
        self.significantDiff = self.__setSignificantGroup()

    def __setGroupData(self, groupData):
        """Convert a dict of equally long realizations into a 2-D array (rows = realizations)."""
        if not isinstance(groupData, dict):
            return groupData
        if not groupData:
            raise ValueError("The group dictionary must not be empty")

        realizations = list(groupData.values())
        nBins = len(realizations[0])
        dataMat = np.zeros((len(realizations), nBins))
        for index, realization in enumerate(realizations):
            if len(realization) != nBins:
                raise Exception("The length of all realizations in the group dictionary must be the same")
            dataMat[index] = realization

        return dataMat

    def __checkGroups(self):
        """Validate the groups and turn vectors into single-column matrices."""
        if not isinstance(self.__group1, np.ndarray) or not isinstance(self.__group2, np.ndarray):
            raise ValueError("In permutation test, \"group1\" and \"group2\" should be numpy arrays.")

        if self.__group1.ndim > 2 or self.__group2.ndim > 2:
            raise ValueError('In permutation test, the groups must be either vectors or matrices.')

        if self.__group1.ndim == 1:
            self.__group1 = np.reshape(self.__group1, (len(self.__group1), 1))
        if self.__group2.ndim == 1:
            self.__group2 = np.reshape(self.__group2, (len(self.__group2), 1))

        # rows of both groups are stacked for shuffling, so the bins must match
        if self.__group1.shape[1] != self.__group2.shape[1]:
            raise ValueError('In permutation test, both groups must have the same number of bins (columns).')

    def __computeGroupDiff(self, group1, group2):
        """Difference of the (NaN-ignoring) means over realizations, smoothed when there are several bins."""
        meanDiff = np.nanmean(group1, axis=0) - np.nanmean(group2, axis=0)

        if len(self.__group1[0]) == 1 and len(self.__group2[0]) == 1:
            return meanDiff

        return smooth(meanDiff, sigma=self.__smoothSigma)

    def __setDiffSurrGroups(self):
        """Build the (nIterations, nBins) matrix of surrogate mean differences."""
        # Private, seeded generator: same shuffle sequence as seeding the global
        # numpy state, without altering it for the caller.
        rng = np.random.RandomState(seed=self.__randomSeed)
        nRows1 = self.__group1.shape[0]
        self.__concatenatedData = np.concatenate((self.__group1, self.__group2), axis=0)

        diffSurrGroups = np.zeros((self.__nIterations, self.__group1.shape[1]))
        for iteration in range(self.__nIterations):
            # Shuffle the rows (realizations); each row keeps its bins together.
            rng.shuffle(self.__concatenatedData)

            # Surrogate groups of the same sizes as the original groups.
            surrGroup1 = self.__concatenatedData[:nRows1, :]
            surrGroup2 = self.__concatenatedData[nRows1:, :]

            diffSurrGroups[iteration, :] = self.__computeGroupDiff(surrGroup1, surrGroup2)

        return diffSurrGroups

    def __setPVal(self):
        """Two-tailed p-value per bin: twice the smaller one-sided tail, capped at 1."""
        positivePVals = np.sum(1*(self.__diffSurGroups > self.__origGroupDiff), axis=0) / self.__nIterations
        negativePVals = np.sum(1*(self.__diffSurGroups < self.__origGroupDiff), axis=0) / self.__nIterations
        return np.array([np.min([1, 2*pPos, 2*pNeg]) for pPos, pNeg in zip(positivePVals, negativePVals)])

    def __setBands(self):
        """
        Global bands: percentiles over all surrogate values, tightened until fewer
        than 5% of the surrogates break the band. Returns (None, None) for a
        single-point comparison.
        """
        if not isinstance(self.__origGroupDiff, np.ndarray):  # single point comparison
            return None, None

        alpha = 100  # Global alpha value
        highGlobCI = self.__initGlobConfInterval  # global confidence interval
        lowGlobCI = self.__initGlobConfInterval  # global confidence interval
        while alpha >= 5:
            highBand = np.percentile(a=self.__diffSurGroups, q=100-highGlobCI)
            lowBand = np.percentile(a=self.__diffSurGroups, q=lowGlobCI)

            # NOTE(review): a surrogate counts as breaking the band only when MORE
            # THAN ONE bin is outside it (`> 1`); confirm this is intended rather
            # than "any bin" (`> 0`).
            breaksPositive = np.sum(
                [np.sum(self.__diffSurGroups[i, :] > highBand) > 1 for i in range(self.__nIterations)])

            breaksNegative = np.sum(
                [np.sum(self.__diffSurGroups[i, :] < lowBand) > 1 for i in range(self.__nIterations)])

            alpha = ((breaksPositive + breaksNegative) / self.__nIterations) * 100
            highGlobCI = 0.95 * highGlobCI
            lowGlobCI = 0.95 * lowGlobCI
        return highBand, lowBand

    def __setSignificantGroup(self):
        """
        Bins outside the global band are significant; each such run is extended to
        contiguous neighbors that are outside the pairwise (local) band.
        """
        if not isinstance(self.__origGroupDiff, np.ndarray):  # single point comparison
            return self.pVal <= 0.05

        globalSig = np.logical_or(self.__origGroupDiff > self.highBand, self.__origGroupDiff < self.lowBand)
        pairwiseSig = np.logical_or(self.__origGroupDiff > self.pairwiseHighBand,
                                    self.__origGroupDiff < self.pairwiseLowBand)

        significantGroup = globalSig.copy()
        lastIndex = 0
        for currentIndex in range(len(pairwiseSig)):
            if globalSig[currentIndex]:
                lastIndex = self.__setNeighborsToTrue(significantGroup, pairwiseSig, currentIndex, lastIndex)

        return significantGroup

    def __setPairwiseHighBand(self, localBandValue=0.5):
        return np.percentile(a=self.__diffSurGroups, q=100 - localBandValue, axis=0)

    def __setPairwiseLowBand(self, localBandValue=0.5):
        return np.percentile(a=self.__diffSurGroups, q=localBandValue, axis=0)

    def __setNeighborsToTrue(self, significantGroup, pairwiseSig, currentIndex, previousIndex):
        """
            While the neighbors of a global point pass the local band (consecutively), set the global band to true.
            Returns the last index which was visited going forward.
        """
        if currentIndex < previousIndex:
            return previousIndex

        for index in range(currentIndex, previousIndex, -1):
            if pairwiseSig[index]:
                significantGroup[index] = True
            else:
                break

        previousIndex = currentIndex
        for index in range(currentIndex + 1, len(significantGroup)):
            previousIndex = index
            if pairwiseSig[index]:
                significantGroup[index] = True
            else:
                break

        return previousIndex

    def plotSignificant(self, ax: "plt.Axes", y: float, x=None, **kwargs):
        """
        Draw a horizontal segment at height `y` over every significant bin.

        x: bin edges; bin i spans [x[i], x[i+1]], so x needs nBins+1 entries
           (with only nBins entries the last bin cannot be drawn).
           Defaults to 1, 2, ..., nBins+1.
        kwargs: forwarded to ax.plot.
        """
        if x is None:
            # nBins + 1 edges so that the last bin is drawn too
            x = np.arange(0, len(self.significantDiff) + 1) + 1
        for x0, x1, p in zip(x[:-1], x[1:], self.significantDiff):
            if p:
                ax.plot([x0, x1], [y, y], zorder=-2, **kwargs)

    @staticmethod
    def plotSigPair(ax: "plt.Axes", y: float, x=None, s: str = '*', **kwargs):
        """
        Draw a significance bracket at height `y` spanning x=(x0, x1), labelled `s`.

        As a staticmethod this has no access to an instance, so `x` must be given.
        kwargs: forwarded to ax.plot ('color' defaults to black).
        """
        if x is None:
            raise ValueError('plotSigPair requires the x-span of the bracket, x=(x0, x1).')
        if 'color' not in kwargs:
            kwargs['color'] = 'k'

        dy = .03*(ax.get_ylim()[1]-ax.get_ylim()[0])
        ax.plot(x, [y, y], **kwargs)
        ax.plot([x[0], x[0]], [y-dy, y], [x[1], x[1]], [y-dy, y], **kwargs)
        ax.text(np.mean(x), y, s=s,
                ha='center', va='center', color=kwargs['color'],
                size='xx-small', fontstyle='italic', backgroundcolor='w')

**Group entrance-time (ET) learning curve**

In [None]:
def plot_learning_curve(ax, root, animalList, profile, TaskParamToPlot, 
                        stop_dayPlot,color='gray'):
    """
    Draw the group median of `TaskParamToPlot` per session with an interquartile error bar.

    The statistic comes from get_rat_group_statistic for `animalList`/`profile`
    over the first `stop_dayPlot` sessions. Points sit at session+0.1, i.e. offset
    from the integer positions used by plot_dotted_learning_curve.
    NOTE: other code reads this function's last default (`color`); keep it last.

    Returns the array built from Results[TaskParamToPlot].values()
    (presumably one row per animal, one column per session — confirm).
    """
    requested = [TaskParamToPlot]
    Results, _ = get_rat_group_statistic(root, animalList, profile,
                                         parameter=param, redo=False,
                                         stop_dayPlot=stop_dayPlot,
                                         TaskParamToPlot=requested)

    # goal time of the first animal (fetched but currently not used for drawing)
    fetched = data_fetch(root, animal=animalList[0], profile=profile,
                         PerfParam=[lambda data: data.goalTime[-1]],
                         NbSession=0)
    goalTime = list(fetched.values())[-1]

    sessions = 1.1 + np.arange(stop_dayPlot)
    data = np.array(list(Results[TaskParamToPlot].values()))
    median = np.nanpercentile(data, 50, axis=0)
    quartiles = np.nanpercentile(data, (25, 75), axis=0)

    ax.errorbar(sessions, median, yerr=abs(quartiles - median),
                ecolor=color, fmt='-o', color=color,
                elinewidth=1, markersize=4, markerfacecolor='w', zorder=1)

    return data
        

def plot_dotted_learning_curve(ax, root, animalList, profile, TaskParamToPlot, 
                               stop_dayPlot,colors, seed=3):
    """
    Draw the group median/IQR of `TaskParamToPlot` per session plus one jittered
    dot per animal, and format the Axes as a learning-curve panel.

    colors: colors of the per-animal dots (passed to scatter's `c`).
    seed: seeds the global numpy random state used for the horizontal jitter.
    A dashed magenta line marks the goal time fetched for the first animal.

    Returns the array built from Results[TaskParamToPlot].values()
    (presumably one row per animal, one column per session — confirm).
    """
    requested = [TaskParamToPlot]
    Results, _ = get_rat_group_statistic(root, animalList, profile,
                                         parameter=param, redo=False,
                                         stop_dayPlot=stop_dayPlot,
                                         TaskParamToPlot=requested)

    fetched = data_fetch(root, animal=animalList[0], profile=profile,
                         PerfParam=[lambda data: data.goalTime[-1]],
                         NbSession=0)
    goalTime = list(fetched.values())[-1]

    sessions = np.arange(stop_dayPlot) + 1
    data = np.array(list(Results[TaskParamToPlot].values()))
    median = np.nanpercentile(data, 50, axis=0)
    quartiles = np.nanpercentile(data, (25, 75), axis=0)
    np.random.seed(seed=seed)
    halfWidth = .3  # dots are spread uniformly within +/- this around each session

    ax.errorbar(sessions, median, yerr=abs(quartiles - median), ecolor='k', fmt='k-o',
                elinewidth=1, markersize=4, markerfacecolor='w', zorder=3)

    for day, values in zip(sessions, data.T):
        jittered = np.random.uniform(low=day - halfWidth, high=day + halfWidth, size=len(values))
        ax.scatter(jittered, values, s=.2, c=colors, marker='o', zorder=2)

    ax.set_xlim([sessions[0] - 1, sessions[-1] + 1])
    # first session, then every 5th session
    ax.set_xticks([1] + list(range(5, stop_dayPlot + 1, 5)))
    ax.spines['bottom'].set_bounds(sessions[0], sessions[-1])
    ax.set_ylim([0, 10])
    ax.set_yticks([0, 7, 10])
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    ax.hlines(y=goalTime, xmin=sessions[0], xmax=sessions[-1], linestyle='--', lw=1, color='m')
    ax.set_xlabel('Session#')
    ax.set_ylabel(TaskParamToPlot)

    return data
    
def add_legend_to_learning_curve(ax,dataType: str):
    """
    Add a small legend in the lower-right corner of a learning-curve Axes.

    Handles for the data group (labelled `dataType`) and for the control group
    (colored with plot_learning_curve's default color) are built, but only the
    'Significant' handle is placed in the legend. Returns the Legend.
    """
    markerStyle = dict(marker='o', markerfacecolor='w', markersize=4)
    dataHandle = matplotlib.lines.Line2D([], [], color='k', markeredgecolor='k',
                                         label=dataType, **markerStyle)

    controlColor = plot_learning_curve.__defaults__[-1]
    controlHandle = matplotlib.lines.Line2D([], [], color=controlColor, markeredgecolor=controlColor,
                                            label='Control', **markerStyle)

    significanceHandle = matplotlib.lines.Line2D([], [], color='goldenrod', label='$Significant$')
    return ax.legend(handles=[significanceHandle], title="", title_fontsize=6, handletextpad=.6,
                     bbox_to_anchor=(.99, 0), loc=4, ncol=1, fontsize=4, frameon=False)

In [None]:
# Notebook-only cell: dotted learning curve of the 'Control-Early-var' animals
# overlaid on the median/IQR curve of the 'Speed':'10' control animals.
if "__file__" not in dir():

    profile={'Type':'Good',
         'rewardType':'Progressive',
#          'initialSpeed':['var'],
         'Speed':'var',
         'Tag':'Control-Early-var'
                  }
    # NOTE(review): this lookup is immediately overwritten by the hard-coded
    # list on the next line; it only matters if the call has side effects.
    animalList=batch_get_animal_list(root,profile)
    animalList=['Rat125', 'Rat126', 'Rat127', 'Rat128', 'Rat129', 'Rat130', 'Rat153', 'Rat154', 'Rat159', 'Rat160']
    
    # control group drawn with plot_learning_curve
    profile0={'Type':'Good',
             'rewardType':'Progressive',
             'initialSpeed':'10',
             'Speed':'10',
             'Tag':['Control']
             }
    animalList0=['Rat103','Rat104','Rat110','Rat113','Rat120','Rat137','Rat138','Rat139','Rat140','Rat149',
                 'Rat150','Rat151','Rat152','Rat161','Rat162','Rat163','Rat164','Rat165','Rat166','Rat215',
                 'Rat216','Rat217','Rat218','Rat219','Rat220','Rat221','Rat222','Rat223','Rat224','Rat225',
                 'Rat226','Rat227','Rat228','Rat229','Rat230','Rat231','Rat232','Rat246','Rat247','Rat248',
                 'Rat249','Rat250','Rat251','Rat252','Rat253','Rat254','Rat255','Rat256','Rat257','Rat258',
                 'Rat259','Rat260','Rat261','Rat262','Rat263','Rat264','Rat265','Rat297','Rat298','Rat299',
                 'Rat300','Rat305','Rat306','Rat307','Rat308']

    TaskParamToPlot="percentile entrance time"
    stop_dayPlot =30
    # one color per animal: sample len+1 colors and drop the one at the top end of the colormap
    colors=get_ordered_colors(colormap='plasma', n=len(animalList)+1)[:-1]
    
    plt.close('all')
    fig=plt.figure(figsize=(10,5))
    ax=fig.add_subplot(111);
    
    plot_dotted_learning_curve(ax, root, animalList, profile, TaskParamToPlot, stop_dayPlot,colors)
    plot_learning_curve(ax, root, animalList0, profile0, TaskParamToPlot, stop_dayPlot)
    add_legend_to_learning_curve(ax,dataType=profile['Tag'])
    ax.set_ylabel('Entrance Times (s)')
    plt.show()

---

**Plot the probability distribution of the initial position for correct and error trials**

In [None]:
def initial_pos(root, profile, animalList, SessionRange, trdBins, ax):
    """
    Collect the start position of every trial, split by entrance time vs goal time.

    For each animal, sessions SessionRange[0]:SessionRange[1] of the data_fetch
    result are used. A trial's start position is data.position[trial] at
    data.startFrame[trial]. Sessions whose number of entrance times differs from
    their number of trials are skipped and counted in a logged warning.
    trdBins and ax are accepted for call compatibility and are not used here.

    Returns (initialPosG, initialPosB): start positions of trials with
    entranceTime >= the session's median goalTime, and of trials with
    entranceTime < that median, as numpy arrays.
    """
    def et_and_initial_pos(data):
        entrance = data.entranceTime
        start = np.array([data.position[trial][data.startFrame[trial]] for trial in data.position])
        goal = np.nanmedian(data.goalTime)

        if len(entrance) != len(start):
            return None, None, None
        return entrance, start, goal

    PerfParam = [et_and_initial_pos]
    positionsGood, positionsBad = [], []
    skipped = 0
    for animal in animalList:
        # session list lookup (result not used below)
        sessions = batch_get_session_list(root, [animal], profile=profile)['Sessions']
        perSession = data_fetch(root, animal=animal, profile=profile,
                                PerfParam=PerfParam,
                                NbSession=100)[PerfParam[0].__name__]
        perSession = perSession[SessionRange[0]:SessionRange[1]]

        for et, pos, gt in perSession:
            if et is None:
                skipped += 1
                continue
            positionsGood.extend(list(pos[et >= gt]))
            positionsBad.extend(list(pos[et < gt]))

    logging.warning(f'{skipped} sessions were removed, entranceTime != position')

    return np.array(positionsGood), np.array(positionsBad)
    
def plot_probablity_initial_pos(root,profile,animalList, SessionRange, trdBins, ax, pCum=True,):
    """
    Plot density histograms of the trial start positions (lime: entrance time >=
    goal time, tomato: entrance time < goal time) and, if pCum, their cumulative
    probabilities as dashed lines on a twin y axis.

    trdBins must be a `range` (its `.step` is used as the bin width).
    """
    goodPos, badPos = initial_pos(root, profile, animalList, SessionRange, trdBins, ax)

    lowEnd, highEnd = param['treadmillRange']
    histStyle = dict(density=True, edgecolor='None', alpha=.6, rwidth=1)
    goodDensity, _, _ = ax.hist(goodPos, trdBins, color='lime', **histStyle)
    badDensity, _, _ = ax.hist(badPos, trdBins, color='tomato', **histStyle)
    yBottom, yTop = ax.get_ylim()

    if pCum:
        twin = ax.twinx()
        width = trdBins.step
        centers = np.array(trdBins[1:]) - width/2
        # density * bin width summed cumulatively gives the cumulative probability
        twin.plot(centers, np.cumsum(goodDensity)*width, linewidth=1, color='lime', linestyle='--')
        twin.plot(centers, np.cumsum(badDensity)*width, linewidth=1, color='tomato', linestyle='--')

        twin.spines['top'].set_visible(False)
        twin.set_ylim([-.02, 1.02])
        twin.set_ylabel('Cumu. Prob.', color='k')
        twin.set_yticks([0, 1])
        twin.spines['bottom'].set_bounds(lowEnd, highEnd)
        twin.spines['right'].set_bounds(0, 1)

    ax.set_xlabel('Init. position (cm)')
    ax.set_ylabel('Probability')
    ax.set_ylim((yBottom - yTop)/100, yTop)
    ax.set_xlim([trdBins[0] - trdBins.step, trdBins[-1] + trdBins.step])
    ax.set_xticks([lowEnd, highEnd/2, highEnd])
    ax.set_xticklabels([lowEnd, '', highEnd])
    ax.set_yticks([])
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    ax.spines['bottom'].set_bounds(lowEnd, highEnd)
    ax.spines['left'].set_bounds(0, yTop)
    

In [None]:
# Notebook-only cell: initial-position distributions for sessions 20-29
# (SessionRange is a slice) of the 'Control' animals.
if "__file__" not in dir():
    profile={'Type':'Good',
             'rewardType':'Progressive',
             'initialSpeed':'10',
             'Speed':['10'],
             'Tag':'Control'
                  }
    animalList=batch_get_animal_list(root,profile)
    animalList=animalList  # no-op; kept as a hook to override the list by hand

    Y1,Y2=param['treadmillRange']
    SessionRange=[20,30]
    # one bin per unit of treadmill position; must stay a range (its .step is used)
    trdBins=range(Y1,Y2+1,1)
    
    plt.close('all')
    fig=plt.figure(figsize=(2,2),dpi=300)
    ax=fig.add_subplot(111)
    
    plot_probablity_initial_pos(root,profile,animalList, SessionRange, trdBins, ax)
    
    plt.show()


---

**Plot median trajectories**

In [None]:
def plot_session_median_trajectory(data,ax, **kwargs):
    """
    Plot the median position across the trials of one session against time.

    data: session object exposing `position` and `stopFrame` (dicts keyed by
        trial), `cameraToTreadmillDelay` and `cameraSamplingRate`.
    ax: Axes to draw on; kwargs are forwarded to ax.plot.

    The median is drawn up to the first frame where more than 30% of the trials
    have no data (i.e. while at least 70% of the trials still have data).
    Returns the (maxFrames, nTrials) position matrix, NaN after each trial's stop frame.
    """
    posDict = data.position
    maxL = int(np.nanmax(list(data.stopFrame.values())))
    position = np.full((maxL, len(posDict)), np.nan)
    time = np.arange(-data.cameraToTreadmillDelay,
                     (maxL - data.cameraSamplingRate) / data.cameraSamplingRate,
                     1 / data.cameraSamplingRate)

    for column, trial in enumerate(posDict):
        pos = posDict[trial][:data.stopFrame[trial]]
        position[:len(pos), column] = pos

    nanCount = np.sum(np.isnan(position), axis=1)
    sparseFrames = np.flatnonzero(nanCount > .3 * position.shape[1])
    # No sparse frame means every frame is kept: rows are frames (shape[0]),
    # columns are trials.
    maxTraj = sparseFrames[0] if len(sparseFrames) else position.shape[0]
    # `time` is built independently of the frame count; never plot past either array.
    nPoints = min(maxTraj, len(time))

    ax.plot(time[:nPoints], np.nanmedian(position, axis=1)[:nPoints], lw=.5, **kwargs)

    ax.set_ylabel("Position (cm)")
    ax.set_xlabel("Time (s) relative to camera start")

    return position
    

def plot_median_trajectory(root, ax, profile, animalList, sessionIdx,colors):
    """
    For each animal, load its session number `sessionIdx` (0-based) under
    `profile` and draw that session's median trajectory in colors[i]; then
    format the Axes (treadmill range Y1..Y2 on y, -1..8 s on x).

    Relies on the notebook globals `param`, `Y1` and `Y2`.
    """
    for i, animal in enumerate(animalList):
        sessionList = batch_get_session_list(root, animalList=[animal], profile=profile)['Sessions']
        session = sessionList[sessionIdx]

        data = Data(root, session[:6], session,
                    param=param, redoPreprocess=False, saveAsPickle=False)
        data.position_correction()

        plot_session_median_trajectory(data, ax, color=colors[i])

    ax.set_title(f'Session #{sessionIdx+1}', pad=0, fontsize='small')
    ax.set_ylim([Y1, Y2])
    ax.set_yticks([Y1, Y2/2, Y2])
    ax.set_yticklabels([Y1, '', Y2])
    ax.set_xlim([-1, 8])
    ax.set_xticks([0, 7])
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    ax.spines['bottom'].set_bounds(0, 7)
    # overrides the labels set by plot_session_median_trajectory
    ax.set_xlabel('Trial time (s)')
    ax.set_ylabel('Position (cm)')
    ax.spines['left'].set_bounds(Y1, Y2)

    return None

def plot_grand_average(root, ax, profile, animalList, sessionIdx, **kwargs):
    """
    Plot, on `ax`, the median across animals of each animal's median trajectory
    in session number `sessionIdx` (0-based); kwargs are forwarded to ax.plot.

    Returns traj: (200, nAnimals) array whose column i is animal i's per-frame
    median position (first 200 frames), NaN where that session is shorter.
    """
    maxL = int(8*25)  # 8s * 25fps
    traj = np.full((maxL, len(animalList)), np.nan)
    # NOTE(review): assumes 25 fps and a trajectory starting 1 s before treadmill
    # start (time -1..7 s); confirm against cameraSamplingRate/cameraToTreadmillDelay.
    time = np.linspace(-1, 7, maxL)

    # Scratch figure: plot_session_median_trajectory always draws, and only its
    # returned position matrix is needed here. Closed even if loading fails.
    scratchFig = plt.figure()
    try:
        scratchAx = scratchFig.add_subplot(111)
        for i, animal in enumerate(animalList):
            session = batch_get_session_list(root, animalList=[animal], profile=profile)['Sessions'][sessionIdx]

            data = Data(root, session[:6], session, param=param, redoPreprocess=False, saveAsPickle=False)
            data.position_correction()
            pos = plot_session_median_trajectory(data, scratchAx)
            medianPos = np.nanmedian(pos, axis=1)[:maxL]
            # sessions shorter than maxL frames leave the remaining rows NaN
            traj[:len(medianPos), i] = medianPos
    finally:
        plt.close(scratchFig)

    ax.plot(time, np.nanmedian(traj, axis=1), **kwargs)

    return traj

In [None]:
# Notebook-only cell: per-animal median trajectories of session 30
# (sessionIdx is 0-based) with their grand median drawn as a thick gray line.
if "__file__" not in dir():
    profile={'Type':'Good',
                 'rewardType':'Progressive',
#                  'initialSpeed':['var'],
                 'Speed':'var',
                 'Tag':'Control-Early-var'
                }
    animalList=['Rat125', 'Rat126', 'Rat127', 'Rat128', 'Rat129', 'Rat130', 'Rat153', 'Rat154', 'Rat159', 'Rat160']

    sessionIdx=29
    # one color per animal: sample len+1 colors and drop the one at the top end of the colormap
    colors=get_ordered_colors(colormap='plasma', n=len(animalList)+1)[:-1]
    
    plt.close('all')
    fig=plt.figure(figsize=(8,4))
    ax=fig.add_subplot(111)

    
    
    plot_median_trajectory(root,ax, profile, animalList, sessionIdx,colors)
    plot_grand_average(root, ax, profile, animalList, sessionIdx, color=[.6,.6,.6,.6],lw=6)

---

**Draw the definition of NTO (No-Timeout)**

In [None]:
def draw_table_NTO(ax):
    """
    Draw a table stating which entrance times (ET) are error trials for the
    No-Timeout and Control conditions, and hide the Axes frame and axes.

    The 'Control' row label uses plot_learning_curve's default color ('gray' if
    that cannot be looked up). Returns the matplotlib Table.
    """
    try:
        Clr = inspect.signature(plot_learning_curve).parameters['color'].default
    except (NameError, KeyError, TypeError, ValueError):
        # plot_learning_curve undefined, not introspectable, or without 'color'
        Clr = 'gray'

    # raw strings: '\l' is not a valid escape sequence in a normal string literal
    cellText = [[r'$0 \leq ET < 7$'],
                [r'$1.5 \leq ET < 7$']
                ]
    rowLabels = ['No-Timeout', 'Control']
    colLabels = ['Error Trials:']

    table = ax.table(cellText=cellText, cellColours=None, cellLoc='center',
                     rowLabels=rowLabels, rowColours=None, rowLoc='center',
                     colLabels=colLabels, colColours=None, colLoc='center', colWidths=None,
                     loc='center', bbox=[0.25, 0.25, 1, .5], edges='horizontal')
    # row 0 is the column header, column -1 holds the row labels -> 'Control'
    table.get_celld()[(2, -1)].set_text_props(color=Clr)
    for side in ('right', 'left', 'top', 'bottom'):
        ax.spines[side].set_visible(False)
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    return table

In [None]:
def plot_NTO_definition_bars(ax):
    """
    Sketch which entrance times (ET) count as error trials in each condition.

    Double-headed vertical arrows are drawn at x=0 from ET=0 to 7 (No-Timeout,
    black) and at x=1 from ET=1.5 to 7 (Control). The Control arrow and tick
    label use plot_learning_curve's default color ('gray' if it cannot be looked
    up). ET values are on the right-hand y axis.
    """
    try:
        Clr = inspect.signature(plot_learning_curve).parameters['color'].default
    except (NameError, KeyError, TypeError, ValueError):
        # plot_learning_curve undefined, not introspectable, or without 'color'
        Clr = 'gray'

    # The annotation text is passed positionally: Axes.annotate's `s=` keyword
    # was renamed to `text` in matplotlib 3.3 and is rejected by newer versions.
    arrowStyle = dict(arrowstyle='<->', shrinkA=0, shrinkB=0)
    ax.annotate('', xy=(0, 0), xytext=(0, 7), arrowprops=dict(arrowStyle, color='k'))
    ax.annotate('', xy=(1, 1.5), xytext=(1, 7), arrowprops=dict(arrowStyle, color=Clr))

    ax.tick_params(axis='both', labelsize='x-small')
    ax.yaxis.tick_right()
    ax.yaxis.set_label_position('right')
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position('bottom')
    ax.yaxis.set_major_formatter(FormatStrFormatter('%g'))
    ax.spines['left'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    ax.spines['right'].set_bounds(0, 7)
    ax.set_xlim([-.2, 1.2])
    ax.set_xticks([0, 1])
    ax.set_xticklabels(['No-\nTimeout', 'Control'])
    ax.tick_params('x', top=False, pad=-20)
    ax.set_ylim([-1, 10])
    ax.set_yticks([0, 1.5, 7])

    ax.get_xticklabels()[1].set_color(Clr)
    ax.set_xlabel('Error trials', fontsize='small')


In [None]:
# Notebook-only cell: draw the NTO definition sketch.
if "__file__" not in dir():
    # NOTE(review): `colors` is not used in this cell, and `animalList` comes
    # from whichever cell ran before.
    colors=get_ordered_colors(colormap='plasma', n=len(animalList)+1)[:-1]
    
    plt.close('all')
    fig=plt.figure(figsize=(2,2))
    ax=fig.add_subplot(111)

#     draw_table_NTO(ax)
    plot_NTO_definition_bars(ax)

---

**plotting the CTRL > NTO sketch**

In [None]:
def plot_ctrl_to_nto_sketch(ax):
    """
    Draw a timeline sketch of sessions: each column shows the beam start time
    (text above the axis) under a session tick label; the last two columns are
    'Before' and 'After' in bold, with vertical labels. Ticks under the '. . .'
    gaps are hidden.
    """
    # (x position, beam start time shown above the axis, tick label)
    columns = [
        (1, 1.5, 1),
        (3, 1.5, 2),
        (5, '. . .', '. . .'),
        (7, 1.5, 30),
        (10, '. . .', '. . .'),
        (13, 1.5, r'$\mathbf{Before}$'),
        (16, 0, r'$\mathbf{After}$'),
    ]
    positions = [column[0] for column in columns]

    ax.set_xlim([positions[0], positions[-1]])
    ax.set_xticks(positions)
    ax.set_xticklabels([column[2] for column in columns])
    for position, startTime, _ in columns:
        ax.text(x=position, y=.2, s=str(startTime), fontsize='xx-small',
                zorder=5, ha='center', va='center', color='k')

    # two tick lines per tick: entries 4 and 8 are the ticks under the '. . .' gaps
    tickLines = ax.get_xticklines()
    tickLines[4].set_alpha(0)
    tickLines[8].set_alpha(0)
    ax.set_ylim([0, .5])
    ax.set_yticks([])
    ax.set_title('$Beam$ start time (s)', pad=-10, fontsize='small')
    ax.set_xlabel('Session#', fontsize='small')

    for side in ('left', 'top', 'right'):
        ax.spines[side].set_visible(False)

    ax.spines['bottom'].set_bounds(positions[0], positions[-1])
    ax.tick_params(axis='y', left=False)
    ax.tick_params(axis='x', labelsize='xx-small')

    for label in ax.get_xticklabels()[-2:]:
        label.set_rotation('vertical')

In [None]:
# Notebook-only cell: draw the Control -> NTO session sketch.
if "__file__" not in dir():
    
    
    plt.close('all')
    fig=plt.figure(figsize=(5,1))
    ax=fig.add_subplot(111)
    
    plot_ctrl_to_nto_sketch(ax)

---

**Plot the effect of NTO**

In [None]:
def plot_event_1on1(root, Profiles, badAnimals, TaskParamToPlot, ax):
    """
    Compare `TaskParamToPlot` between two sessions around an event, per animal.

    Profiles: (profileBefore, profileAfter) passed to event_detect.
    badAnimals: animals excluded by event_detect.
    ax: Axes receiving two box plots (x=-0.5 'Before', x=+0.5 'After', 5-95%
        whiskers) and one line per animal joining its two values.

    The two columns are data[:, nSessionPre-1] and data[:, nSessionPre] of the
    event_statistic result — presumably the last pre-event and the first
    post-event session (TODO confirm the session ordering in event_statistic).

    Returns data: array with one row per animal and two columns (before, after).
    """
    _, SessionDict = event_detect(root, Profiles[0], Profiles[1], badAnimals=badAnimals)

    Results, nSessionPre, nSessionPost = event_statistic(root,
                                                         SessionDict,
                                                         parameter=param,
                                                         redo=False,
                                                         TaskParamToPlot=[TaskParamToPlot])

    assert 1<=nSessionPre and 1<=nSessionPost,"fewer sessions available than requested:"

    data = np.array(list(Results[TaskParamToPlot].values()))
    data = data[:, nSessionPre-1 : nSessionPre+1]

    # group distributions
    props = {'color': 'k', 'linewidth': 1}
    ax.boxplot(x=data, whis=[5, 95],
               positions=[-.5, .5], widths=.3,
               showcaps=False, showfliers=False,
               medianprops=props, boxprops=props, whiskerprops=props,
               zorder=2)

    # single animals
    for before, after in data:
        ax.plot([-.5, .5], [before, after], lw=.5, alpha=1, zorder=1)

    ax.set_xlim([-1, 1])
    ax.set_xticks([-.5, .5])
    ax.set_xticklabels(['Before', 'After'])
    ax.spines['bottom'].set_bounds(-.5, .5)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.yaxis.set_major_locator(MaxNLocator(integer=True))
    ax.set_ylabel(TaskParamToPlot)

    return data

In [None]:
# Notebook-only cell: before/after comparison around the switch from the
# profile1 tags to the NoTimeout tags of profile2.
if "__file__" not in dir():

    profile1={'Type':'Good',
              'rewardType':'Progressive',
              'initialSpeed':'10',
              'Speed':'10',
              'Tag':['RandomSpeed-BackTo10','Control-Late-Sharp']
              }
    profile2={'Type':'Good',
              'rewardType':'Progressive',
              'initialSpeed':'10',
              'Speed':'10',
              'Tag':['Control-Late-Sharp-Late-NoTimeout','Control-Late-NoTimeout']
              }

    # animals to exclude (presumably 'RatBAD' is a placeholder that matches none)
    badAnimals=['RatBAD']
    # NOTE(review): the first assignment is overwritten; only "% good trials" is plotted.
    TaskParamToPlot="standard deviation of entrance time"
    TaskParamToPlot="% good trials"
    wspace=0.05  # NOTE(review): not used in this cell
    
    
    Profiles=(profile1,profile2,)
    plt.close('all')
    fig=plt.figure(figsize=(5,4))
    ax=fig.add_subplot(111)
    
    a=plot_event_1on1(root, Profiles, badAnimals, TaskParamToPlot, ax)

---

**plotting the trajectories of example sessions**

In [None]:
def plot_trajectories(data,ax):
    """Draw every trial's position trace on *ax*, colour-coded by outcome.

    Each trial key of ``data.position`` is plotted against the matching entry
    of ``data.timeTreadmill``, truncated just before index
    ``data.stopFrame[trial]``. Trials contained in ``data.goodTrials`` are
    drawn in "lime"; every other trial is drawn in "tomato".

    Parameters
    ----------
    data : session object
        Must expose ``position``, ``timeTreadmill`` and ``stopFrame``
        (indexable by trial) and ``goodTrials`` (supports ``in``).
    ax : matplotlib Axes
        Axes that receive one line per trial plus the axis labels.

    Returns
    -------
    numpy.ndarray of str
        The colour used for each trial, in the iteration order of
        ``data.position``.
    """
    positions = data.position
    clock = data.timeTreadmill  # align on camera
    trialColours = []
    for trial in positions:
        colour = "lime" if trial in data.goodTrials else "tomato"
        trialColours.append(colour)
        stop = data.stopFrame[trial]
        ax.plot(clock[trial][:stop], positions[trial][:stop],
                color=colour, lw=.5)

    ax.set_ylabel("X Position (cm)")
    ax.set_xlabel("Time (s) relative to camera start")

    return np.array(trialColours)



def plot_trajectories_and_distributions(root, ax, session):
    """Plot a session's trial trajectories together with start-position histograms.

    On *ax* this draws:

    * every trial trajectory, via ``plot_trajectories`` (lime = trial in
      ``data.goodTrials``, tomato = any other trial);
    * two horizontal histograms (30 bins over ``param['treadmillRange']``)
      of row 0 of the transposed ``get_positions_array_beginning`` output,
      split by that same colour. They extend leftwards from
      ``x = -1.02 * data.cameraToTreadmillDelay`` and are scaled so the
      tallest bar is 3 x-units long;
    * a horizontal boxplot (5-95 % whiskers, no caps/fliers) of
      ``data.entranceTime`` centred at y = 10, ignoring NaN values.

    Parameters
    ----------
    root : str
        Data root forwarded to ``Data``.
    ax : matplotlib Axes
        Target axes; its limits, ticks and spines are restyled.
    session : str
        Session name. Its first 6 characters are used as the animal name
        (e.g. 'Rat335_2019_04_09_10_39' -> 'Rat335').

    Returns
    -------
    ax
        The same axes, for chaining.
    """
    # NOTE(review): session[:6] assumes 6-character animal names like 'Rat335'.
    data=Data(root,session[:6],session,redoPreprocess=False)
    y1,y2=param['treadmillRange']

    color=plot_trajectories(data,ax=ax)

    # Presumably (trials x samples) before the transpose, so row 0 holds each
    # trial's first position sample -- TODO confirm against the helper.
    position=get_positions_array_beginning(data,onlyGood=False,raw=False)
    position=position.T

    histGood,binsGood=np.histogram(position[0,color=='lime'],30,range=(y1,y2), density=False)
    histBad,binsBad=np.histogram(position[0,color=='tomato'],30,range=(y1,y2), density=False)

    # Scale so the tallest bar is 3 x-units long. If both histograms are
    # empty the max is 0; fall back to 1 so the bars have zero length
    # instead of NaN (0/0).
    maxBin=max(histGood.max(),histBad.max())/3
    if maxBin==0:
        maxBin=1
    barLeft=-1.02*data.cameraToTreadmillDelay
    ax.barh(binsGood[:-1],-histGood/maxBin,height=np.diff(binsGood)[0],align='edge',left=barLeft, color='lime', alpha=.6)
    ax.barh(binsBad[:-1],-histBad/maxBin,height=np.diff(binsBad)[0],align='edge',left=barLeft, color='tomato', alpha=.6)

    # Axes.boxplot does not skip NaNs: one NaN turns the quartiles and
    # whiskers into NaN and the box vanishes, so drop them first.
    entranceTime=np.asarray(data.entranceTime,dtype=float)
    entranceTime=entranceTime[~np.isnan(entranceTime)]
    props={'color':'k', 'linewidth':1}
    ax.boxplot(x=entranceTime,whis=[5,95],vert=False,
               positions=[10], widths=10,
               showcaps=False, showfliers=False,
               medianprops=props, boxprops=props, whiskerprops=props, zorder=5
              )

    ax.set_xlim([-4.2,10])
    ax.set_xticks((0,np.nanmedian(data.goalTime)))
    ax.xaxis.set_major_formatter(FormatStrFormatter('%g'))
    ax.set_ylim([y1,y2])
    ax.set_yticks([y1,y2/2,y2])
    ax.set_yticklabels([y1,'',y2])
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    # Bottom spine spans only from the first to the last tick (0 .. median goal time).
    ax.spines['bottom'].set_bounds(ax.get_xticks()[0],ax.get_xticks()[-1])
    # Left spine sits at the histogram baseline's reference x, not at the xlim edge.
    ax.spines['left'].set_position(('data',-data.cameraToTreadmillDelay))

    ax.set_xlabel('Trial time (s)')
    ax.set_ylabel('Position (cm)')

    return ax

In [None]:
if "__file__" not in dir():
    # Quick preview: trajectories and start-position distributions of one
    # rat, naive session on the left, trained session on the right.
    
    fig=plt.figure(figsize=(6,3),dpi=100)
    naiveAx=fig.add_subplot(121)
    trainedAx=fig.add_subplot(122)
    #plotting naive
    session='Rat335_2019_04_09_10_39'
    plot_trajectories_and_distributions(root, naiveAx, session)
    
    #plotting trained
    session='Rat335_2019_04_10_10_59'
    plot_trajectories_and_distributions(root, trainedAx, session)
    
    
#     fig.savefig('/home/david/Pictures/tst.pdf', format='pdf', bbox_inches='tight')
    plt.show()
    plt.close('all')

------



------

# Part 2:

# **GENERATING THE FIGURE**

**Definition of Parameters**

In [None]:
if "__file__" not in dir():
    #================================================
    # GENERAL
    # Preprocessing parameters (they override the default parameters).
    # Defined first so that everything derived from them in this cell
    # (Y1, Y2 and trdBins9) uses these values rather than whatever
    # ``param`` held before this cell ran.
    param={
        "goalTime":7,#needed for pavel data only
        "treadmillRange":[0,90],#pavel error conversion "treadmillRange":[0,80]
        "maxTrialDuration":15,
        "interTrialDuration":10,#None pavel
        "endTrial_frontPos":30,
        "endTrial_backPos":55, 
        "endTrial_minTimeSec":4,
        "cameraSamplingRate":25, #needed for new setup    

        "sigmaSmoothPosition":0.1,#0.33, 0.18 pavel
        "sigmaSmoothSpeed":0.3,#0.3, 0.5 pavel
        "nbJumpMax":100,#200 pavel
        "binSize":0.25,
        }
    # Treadmill position range.
    Y1,Y2=param['treadmillRange']
    
    #================================================
    # GRID 1 PARAMS
    # Panel 1: 'Before' vs 'After' comparison between the two profiles below.
    profile1={'Type':'Good',
              'rewardType':'Progressive',
              'initialSpeed':'10',
              'Speed':'10',
              'Tag':['RandomSpeed-BackTo10','Control-Late-Sharp']
              }
    profile1_={'Type':'Good',
              'rewardType':'Progressive',
              'initialSpeed':'10',
              'Speed':'10',
              'Tag':['Control-Late-Sharp-Late-NoTimeout','Control-Late-NoTimeout']
              }

    TaskParamToPlot1="percentile entrance time"
#     TaskParamToPlot1_ ="standard deviation of entrance time"
    TaskParamToPlot1__ ="% good trials"  
    
    Profiles1=(profile1,profile1_)       
    wspace1=.7
    colorSig1='goldenrod'
    
    #================================================
    # GRID 2 PARAMS
    # Example sessions of one rat shown as trajectory plots (before / after).
    sessionPre2 ='Rat335_2019_04_09_10_39'    
    sessionPost2='Rat335_2019_04_10_10_59'

    #================================================
    # GRID 3 PARAMS
    
    profile3Ctrl={'Type':'Good',
             'rewardType':'Progressive',
             'initialSpeed':'10',
             'Speed':'10',
             'Tag':['Control']
             }
    animalList3Ctrl=['Rat110','Rat113','Rat120','Rat161','Rat162','Rat163','Rat164','Rat165','Rat166','Rat215',
                     'Rat217','Rat218','Rat219','Rat220','Rat221','Rat222','Rat223','Rat224','Rat225','Rat226',
                     'Rat227','Rat228','Rat229','Rat230','Rat231','Rat232','Rat246','Rat247','Rat248','Rat249',
                     'Rat250','Rat251','Rat252','Rat253','Rat254','Rat255','Rat256','Rat257','Rat258','Rat259',
                     'Rat260','Rat261','Rat262','Rat263','Rat264','Rat265','Rat297','Rat298','Rat299','Rat300',
                     'Rat305','Rat306','Rat307','Rat308']
    
    TaskParamToPlot3="percentile entrance time"
    stop_dayPlot3 =30
    # Reuse plot_learning_curve's default line colour so the control-group
    # annotations match the curve it draws. Fall back to gray if the
    # function is undefined, has no 'color' parameter, or gives it no default.
    try:
        color3Ctrl=inspect.signature(plot_learning_curve).parameters['color'].default
    except (NameError, KeyError, TypeError, ValueError):
        color3Ctrl=inspect.Parameter.empty
    if color3Ctrl is inspect.Parameter.empty:
        color3Ctrl='gray'

    #=================================================
    # GRID 4 PARAMS
    profile4NTO={'Type':'Good',
                 'rewardType':'Progressive',
                 'initialSpeed':['10'],
                 'Speed':'10',
                 'Tag':'Control-NoTimeout'
                }
    animalList4NTO=['Rat141', 'Rat142', 'Rat143', 'Rat144', 'Rat155', 'Rat156', 'Rat157', 'Rat158']
    # Presumably n colours sampled from 'plasma'; the last of the n+1 is
    # dropped so one colour remains per rat.
    colors4NTO=get_ordered_colors(colormap='plasma', n=len(animalList4NTO)+1)[:-1]
    label4NTO='No-Timeout'

    #=================================================
    #GRID 9: INIT POSITION DISTRIBUTION
    SessionRange9=[20,30]
    # Unit-width bin edges covering the full treadmill range [Y1, Y2].
    trdBins9=range(Y1,Y2+1,1)

**Plotting the figure**

In [None]:
if "__file__" not in dir():
    # Assemble the multi-panel No-Timeout (NTO) figure and save it as
    # ../BehavioralPaper/Figures/NToTrd.pdf (relative to the current cwd).
    # Panels are positioned with one gridspec each, in figure coordinates.
    plt.close('all')
    set_rc_params()
    figsize=(7,5)
    fig=plt.figure(figsize=figsize,dpi=600)
    
    
    ###########################################
    # 12: DEFINITION NTO
   
    gs12=fig.add_gridspec(nrows=1, ncols=1, left=0.0, bottom=0.77, right=0.09, top=1.02)
    ax12= fig.add_subplot(gs12[0])
    plot_NTO_definition_bars(ax12)

    
    ##########################################
    # 4: learning curve NTO
    # No-Timeout rats (dotted curves, one colour per rat) vs. the Control
    # group, compared with a two-tailed permutation test; significance is
    # drawn at y=9.5.
    gs4= fig.add_gridspec(nrows=1, ncols=1, left=.16, bottom=0.77, right=0.58, top=1.02)
    ax4= fig.add_subplot(gs4[0])
    D4_1=plot_dotted_learning_curve(ax4, root, animalList4NTO, profile4NTO, TaskParamToPlot3, stop_dayPlot3,colors4NTO)
    D4_2=plot_learning_curve(ax4, root, animalList3Ctrl, profile3Ctrl, TaskParamToPlot3, stop_dayPlot3)
    permTest4=TwoTailPermTest(group1=D4_1, group2=D4_2, nIterations=1000)
    permTest4.plotSignificant(ax=ax4,y=9.5,color=colorSig1,lw=2)
    # Group-size legends, right-aligned at the last plotted day.
    s=f'{label4NTO}, $n={len(animalList4NTO)}$ rats'
    ax4.text(x=stop_dayPlot3, y=.2, s=s, fontsize='xx-small', zorder=5,ha='right',color='k')
    s=f'Control, $n={len(animalList3Ctrl)}$ rats'
    ax4.text(x=stop_dayPlot3, y=-.5, s=s, fontsize='xx-small', zorder=5,ha='right',color=color3Ctrl)


    ax4.set_ylabel('$ET$ (s)')
    ax4.set_ylim([-1,10])
    ax4.set_yticks([0,7])
    ax4.spines['left'].set_bounds(0,10)
    
    

    ###########################################
    # 8: average trajectory NTO
    # Per-rat median trajectories of NTO rats over the control grand average
    # (drawn behind them with zorder=-5).
    gs8= fig.add_gridspec(nrows=1, ncols=1, left=0.66, bottom=0.77, right=0.81, top=1.02)
    ax8= fig.add_subplot(gs8[0])
    plot_median_trajectory(root,ax8, profile4NTO, animalList4NTO, stop_dayPlot3-1,colors4NTO)
    plot_grand_average(root, ax8, profile3Ctrl, animalList3Ctrl, stop_dayPlot3-1, color=color3Ctrl,lw=2,zorder=-5)

        
    
    
    
    ###########################################
    # 9: probability of initial position
    gs9= fig.add_gridspec(nrows=1, ncols=1, left=0.85, bottom=0.77, right=0.99, top=1.02)
    ax9= fig.add_subplot(gs9[0])
    plot_probablity_initial_pos(root,profile4NTO,animalList4NTO, SessionRange9, trdBins9, ax9, pCum=False)
    
    
    
    ###########################################
    # 1: plot variable speed effect
    # Two side-by-side Before/After comparisons: entrance time (ax1) and
    # % correct trials (ax1__). An empty list is passed as badAnimals.
    gs1  = fig.add_gridspec(nrows=1, ncols=2, left=0.72, bottom=0.38, right=1, top=0.63,wspace=wspace1)
    ax1  = fig.add_subplot(gs1[0])
    ax1__= fig.add_subplot(gs1[1])

    data1  =plot_event_1on1(root, Profiles1, [], TaskParamToPlot1,   ax1  )
    data1__=plot_event_1on1(root, Profiles1, [], TaskParamToPlot1__, ax1__)
    
    # Dashed reference at 7 s (param['goalTime'] is 7 in this notebook).
    ax1.axhline(7 ,linestyle='--',color='m',lw=1, zorder=-5)
    ax1.set_ylabel('$ET$ (s)',labelpad=-5)
    ax1.set_yticks([0,7])
    ax1.spines['left'].set_bounds(0,7)
    ax1.set_ylim([-.5,8])
    ax1.tick_params('x',labelsize='xx-small')
    
    ax1__.set_ylabel('% Correct trials')
    ax1__.set_yticks([0,20,40,60])
    ax1__.spines['left'].set_bounds(0,60)
    ax1__.set_ylim([0,65])
    ax1__.set_xticklabels('')


    # Frameless overlay spanning both panel-1 axes; its ticks and labels are
    # made fully transparent (RGBA alpha 0).
    totAx1=fig.add_subplot(gs1[:],frameon=False)
#     totAx1.set_xlabel('Effect of variable speed')
    totAx1.yaxis.set_visible(False)
    totAx1.set_xlim([0,10])
    totAx1.set_xticks([10])
    totAx1.tick_params(color=(0, 0, 0, 0),labelcolor=(0, 0, 0, 0),zorder=-10)

    #PERM TESTS
    # Paired Before (column 0) vs After (column 1) comparison; '*' if the
    # test reports a significant difference, otherwise 'n.s.'.
    test1=TwoTailPermTest(group1=data1[:,0], group2=data1[:,1], nIterations=10000)
    s='*' if test1.significantDiff[0] else 'n.s.'
    test1.plotSigPair(ax1,y=7.9,x=(-.5,.5), s=s, color=colorSig1,lw=.8)
        
    test1__=TwoTailPermTest(group1=data1__[:,0], group2=data1__[:,1], nIterations=10000)
    s='*' if test1__.significantDiff[0] else 'n.s.'
    test1__.plotSigPair(ax1__,y=64,x=(-.5,.5), s=s, color=colorSig1,lw=.8)


    
    ###########################################
    # 6: CTRL>NTO sketch
    gs6= fig.add_gridspec(nrows=3, ncols=1, left=0.0, bottom=0.38, right=0.17, top=0.63,hspace=0,wspace=0)
    ax6= fig.add_subplot(gs6[0])
    plot_ctrl_to_nto_sketch(ax6)
    # NOTE(review): the rat count is hard-coded; confirm it still matches the
    # cohort this sketch describes (e.g. len(data1) from panel 1).
    ax6.text(x=.5, y=-1.5, s=f'$n={7}$ rats',
             fontsize='xx-small',ha='center',transform=ax6.transAxes,zorder=5)
    
    
    
    ###########################################
    # 2: Pre and Post trajectory plots
    gs2= fig.add_gridspec(nrows=1, ncols=2, left=0.23, bottom=0.38, right=0.65, top=0.63,wspace=.1)
    ax2= fig.add_subplot(gs2[0])
    ax2_= fig.add_subplot(gs2[1])

    plot_trajectories_and_distributions(root, ax2, sessionPre2)
    plot_trajectories_and_distributions(root, ax2_, sessionPost2)
    
    # Only the left panel keeps y labels; the x label is shared (below).
    ax2_.set_ylabel('')
    ax2_.set_yticks([])
    ax2_.set_xlabel('')
    ax2.set_xlabel('')
    ax2.set_title('Before',fontsize='medium')
    ax2_.set_title('After',fontsize='medium')
    
    # Frameless overlay spanning both panel-2 axes, used only to carry the
    # shared x label; its ticks and labels are hidden.
    axTot=fig.add_subplot(gs2[:],frameon=False)
    axTot.set_xlabel('Trial Time (s)')
    axTot.tick_params('both',left=False,bottom=False,labelcolor=[0,0,0,0],labelleft=False)

    
    
    
    #############################################
    #%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    # Panel letters: one caption per axes in AXES order; the third axes
    # (ax8) gets a larger horizontal offset.
    AXES=(ax12,ax4,ax8,ax9,ax6,ax2,ax1)
    OFFX=np.array([.03]*len(AXES))
    OFFY=np.array([.03]*len(AXES))
    OFFX[[2]]=0.05
    add_panel_caption(axes=AXES, offsetX=OFFX, offsetY=OFFY)


    fig.savefig(os.path.join(os.path.dirname(os.getcwd()),'BehavioralPaper','Figures','NToTrd.pdf'),
                format='pdf', bbox_inches='tight')
    
    plt.show()
    plt.close('all')
    # Undo set_rc_params() so later cells start from matplotlib defaults.
    matplotlib.rcdefaults()