# Part 0:
## import everything
Run the cell below

In [None]:
# -*- coding: utf-8 -*-
import os
import glob
import numpy as np
from platform import system as OS
import pandas as pd
import scipy.stats
from scipy.ndimage.filters import gaussian_filter1d as smooth
import math
import datetime
from copy import deepcopy
import matplotlib.cm as cm
import warnings
warnings.filterwarnings("ignore")
import types
import inspect
import string
import sys, time
import pickle
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import mlab
from scipy import stats
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.animation as animation
import matplotlib.backends.backend_pdf
from sklearn.decomposition import KernelPCA
import mpl_toolkits.axes_grid1.inset_locator as inset
from matplotlib.ticker import FormatStrFormatter
import imageio
from set_rc_params import set_rc_params


if "__file__" not in dir():
    # Interactive setup. This branch runs only when the name `__file__` is not
    # defined in the current namespace (NOTE(review): presumably true for
    # direct notebook use, false when this notebook is %run from elsewhere).
    %matplotlib inline
    %config InlineBackend.close_figures = False
    matplotlib.rcdefaults()
    
    # Data root directory, chosen per operating system.
    if OS()=='Linux':
        root="/data"
    elif OS()=='Windows':
        root="C:\\DATA\\"
    else:
        root="/Users/davidrobbe/Documents/Data/"
            
    # NOTE(review): realpath is called on the *string* "__file__", not on a
    # variable, so ThisNoteBookPath is effectively the current working directory.
    ThisNoteBookPath=os.path.dirname(os.path.realpath("__file__"))
    CommonNoteBookesPath=os.path.join(os.path.split(ThisNoteBookPath)[0],"load_preprocess_rat")
    CWD=os.getcwd()
    # Run the shared helper notebooks from the sibling "load_preprocess_rat"
    # directory. Names such as Data, data_fetch, batch_get_* and
    # get_rat_group_statistic are used below but not defined in this file;
    # presumably they come from these notebooks.
    os.chdir(CommonNoteBookesPath)
    %run UtilityTools.ipynb
    %run Animal_Tags.ipynb
    %run loadRat_documentation.ipynb
    %run plotRat_documentation_1_GeneralBehavior.ipynb
    %run plotRat_documentation_3_KinematicsInvestigation.ipynb
    %run RunBatchRat_3_CompareGroups.ipynb
    %run BatchRatBehavior.ipynb
    currentNbPath=os.path.join(os.path.split(ThisNoteBookPath)[0],'BehavioralPaper','NoT-VarTrd.ipynb')
    %run $currentNbPath
    os.chdir(CWD)

    # `logging` is not imported in this cell; presumably it is brought into
    # scope by one of the %run notebooks above.
    logging.getLogger().setLevel(logging.ERROR)
    
    # Notebook-wide preprocessing parameters; the functions below read this
    # global `param` directly.
    param={
        "goalTime":7,#needed for pavel data only
        "treadmillRange":[0,90],#pavel error conversion "treadmillRange":[0,80]
        "maxTrialDuration":15,
        "interTrialDuration":10,#None pavel
        "endTrial_frontPos":30,
        "endTrial_backPos":55, 
        "endTrial_minTimeSec":4,
        "cameraSamplingRate":25, #needed for new setup    

        "sigmaSmoothPosition":0.1,#0.33, 0.18 pavel
        "sigmaSmoothSpeed":0.3,#0.3, 0.5 pavel
        "nbJumpMax":100,#200 pavel
        "binSize":0.25,
        #parameters used to preprocess (will override the default parameters)
           }  
    # Treadmill extent (cm); Y1/Y2 are also read as globals by later functions.
    Y1,Y2=param['treadmillRange']

    print('os:',OS(),'\nroot:',root,'\nImport successful!')

---
---


# part 1:

# DEFINITIONS

### If you don't know what to do, move to part 2

In [None]:
def get_ordered_colors(colormap, n):
    """
    Return `n` RGBA colours sampled evenly (from 0 to 1) along a colormap.

    Parameters
    ----------
    colormap : str or matplotlib Colormap
        Name of a registered colormap (e.g. 'plasma') or a Colormap object.
    n : int
        Number of colours to return.

    Returns
    -------
    list of RGBA tuples, ordered from the low end to the high end of the map.
    """
    try:
        # matplotlib >= 3.6: registry lookup; accepts a name or a Colormap.
        cmap = matplotlib.colormaps.get_cmap(colormap)
    except AttributeError:
        # Older matplotlib has no registry; `plt.cm.get_cmap` was removed in 3.9.
        cmap = plt.cm.get_cmap(colormap)
    return [cmap(colorVal) for colorVal in np.linspace(0, 1, n)]

In [None]:
def add_panel_caption(axes: tuple, offsetX: tuple, offsetY: tuple, **kwargs):
    """
    Add letter captions (a, b, c, ...) at the top-left corner of each Axes.

    The caption of ``axes[i]`` is placed at that Axes' left edge shifted left
    by ``abs(offsetX[i])`` and at its top edge shifted up by
    ``abs(offsetY[i])``, in RELATIVE figure coordinates (fractions of the
    figure size). All Axes are assumed to belong to the figure of ``axes[0]``.

    Any ``kwargs`` are forwarded to ``Axes.text``. They override the default
    styling (fontweight, fontsize, ha, va, transform) instead of clashing
    with it.

    Raises
    ------
    ValueError
        If ``axes``, ``offsetX`` and ``offsetY`` differ in length.
    """
    if not len(axes) == len(offsetX) == len(offsetY):
        raise ValueError('Bad input!')
    if len(axes) == 0:
        return

    fig = axes[0].get_figure()
    fbox = fig.bbox
    textProps = dict(fontweight='extra bold', fontsize=10, ha='left', va='center',
                     transform=fig.transFigure)
    textProps.update(kwargs)  # caller-supplied options win over the defaults
    for ax, dx, dy, letter in zip(axes, offsetX, offsetY, string.ascii_lowercase):
        axbox = ax.get_window_extent()  # display (pixel) coordinates
        ax.text(x=(axbox.x0 / fbox.xmax) - abs(dx), y=(axbox.y1 / fbox.ymax) + abs(dy),
                s=letter, **textProps)


In [None]:
class TwoTailPermTest:
    """
    Two-tailed permutation test of whether group one and group two differ.

    Parameters
    ----------
    group1, group2 : np.ndarray
        The data. Each is either 1-D (several realizations of one value) or
        2-D (several realizations through a time/space/... course), with
        rows as realizations and columns as bins.
        EX: x.shape==(15,500) means 15 trials/samples over 500 time bins.
        Both groups must have the same number of columns.
    nIterations : int
        Number of shuffles used to build the surrogate distribution.
        max(nIterations) = (len(x)+len(y))! / (len(x)! len(y)!)
    initGlobConfInterval : float
        Initial value (percent) of the global confidence band. Also used as
        the percentile of the pointwise (pairwise) band.
    smoothSigma : float
        Standard deviation (in bins) of the gaussian kernel used to smooth the
        difference of means when there are multiple bins, based on the
        Fujisawa 2008 paper. Default: 0.05.

    Outputs (attributes)
    --------------------
    pVal : np.ndarray
        Two-tailed p-values, one per bin.
    highBand, lowBand : np.ndarray or None
        Global bands (None for a single-point comparison).
    significantDiff : np.ndarray of bool
        Whether each bin is significantly different.

    Shuffling uses the global ``np.random`` state; seed it for reproducibility.

    Raises
    ------
    ValueError
        If the groups are not numpy arrays, have more than 2 dimensions, or
        have different numbers of columns.
    """
    def __init__(self, group1, group2, nIterations=1000, initGlobConfInterval=5, smoothSigma=0.05):
        self.group1, self.group2 = group1, group2
        self.nIterations, self.smoothSigma = nIterations, smoothSigma
        self.initGlobConfInterval = initGlobConfInterval

        self.checkGroups()

        # origGroupDiff is also known as D0 in the definition of permutation test.
        self.origGroupDiff = self.computeGroupDiff(group1, group2)

        # Generate surrogate groups, compute the difference of means for each, and stack them in a matrix.
        self.diffSurGroups = self.setDiffSurrGroups()

        # Set statistics
        self.pVal = self.setPVal()
        self.highBand, self.lowBand = self.setBands()
        self.significantDiff = self.setSignificantGroup()

    def checkGroups(self):
        """Validate the inputs and turn 1-D groups into (n, 1) column matrices."""
        if not isinstance(self.group1, np.ndarray) or not isinstance(self.group2, np.ndarray):
            raise ValueError("In permutation test, \"group1\" and \"group2\" should be numpy arrays.")

        if self.group1.ndim > 2 or self.group2.ndim > 2:
            raise ValueError('In permutation test, the groups must be either vectors or matrices.')

        # Reshape each vector independently. Reshaping an already 2-D group
        # with several columns into (n, 1) is invalid.
        if self.group1.ndim == 1:
            self.group1 = np.reshape(self.group1, (len(self.group1), 1))
        if self.group2.ndim == 1:
            self.group2 = np.reshape(self.group2, (len(self.group2), 1))

        if self.group1.shape[1] != self.group2.shape[1]:
            raise ValueError('In permutation test, the groups must have the same number of columns (bins).')

    def computeGroupDiff(self, group1, group2):
        """Difference of column means (NaN-aware). Smoothed unless there is a single bin."""
        meanDiff = np.nanmean(group1, axis=0) - np.nanmean(group2, axis=0)
        
        if len(self.group1[0]) == 1 and len(self.group2[0]) == 1:
            return [meanDiff]
        
        return smooth(meanDiff, sigma=self.smoothSigma, order=0, 
                    mode='constant', cval=0, truncate=4.0)

    def setDiffSurrGroups(self):
        """Return an (nIterations, nBins) matrix of surrogate mean differences."""
        # Pool both groups; the pooled rows are reshuffled on every iteration.
        self.concatenatedData = np.concatenate((self.group1,  self.group2), axis=0)
        
        diffSurrGroups = np.zeros((self.nIterations, self.group1.shape[1]))
        for iteration in range(self.nIterations):
            # Generate surrogate groups
            surrGroup1, surrGroup2 = self.generateSurrGroup()
            
            # Compute the difference between mean of surrogate groups
            surrGroupDiff = self.computeSurrGroupDiff(surrGroup1, surrGroup2) 
            
            # Store individual differences in a matrix.
            diffSurrGroups[iteration, :] = surrGroupDiff

        return diffSurrGroups

    def generateSurrGroup(self):
        """Shuffle the pooled rows in place and split them into two groups of the original sizes."""
        # np.random.shuffle permutes along the first axis: whole rows
        # (realizations) are exchanged, each row's bins stay together.
        np.random.shuffle(self.concatenatedData)  

        # Return surrogate groups of same size.
        return self.concatenatedData[: self.group1.shape[0], :], self.concatenatedData[self.group1.shape[0]:, :]

    def computeSurrGroupDiff(self, surrGroup1, surrGroup2):
        return self.computeGroupDiff(surrGroup1, surrGroup2)
  
    def setPVal(self):
        """Two-tailed p-value per bin: min(1, 2*P(D>D0), 2*P(D<D0)) over surrogates."""
        positivePVals = np.sum(1*(self.diffSurGroups > self.origGroupDiff), axis=0) / self.nIterations
        negativePVals = np.sum(1*(self.diffSurGroups < self.origGroupDiff), axis=0) / self.nIterations
        return np.array([np.min([1, 2*pPos, 2*pNeg]) for pPos, pNeg in zip(positivePVals, negativePVals)])

    def setBands(self):
        """
        Compute global bands.

        The percentile interval is shrunk by 5% per loop until fewer than 5%
        of surrogates cross the high or low band at more than one bin.
        Returns (None, None) for a single-point comparison.
        """
        if len(self.origGroupDiff) < 2:  # single point comparison
            return None, None
        
        alpha = 100 # Global alpha value (percent of surrogates breaking a band)
        highGlobCI = self.initGlobConfInterval  # global confidence interval
        lowGlobCI = self.initGlobConfInterval  # global confidence interval
        while alpha >= 5:
            highBand = np.percentile(a=self.diffSurGroups, q=100-highGlobCI, axis=0)
            lowBand = np.percentile(a=self.diffSurGroups, q=lowGlobCI, axis=0)
            
            breaksPositive = np.sum(
                [np.sum(self.diffSurGroups[i, :] > highBand) > 1 for i in range(self.nIterations)]) 
            
            breaksNegative = np.sum(
                [np.sum(self.diffSurGroups[i, :] < lowBand) > 1 for i in range(self.nIterations)])
            
            alpha = ((breaksPositive + breaksNegative) / self.nIterations) * 100
            highGlobCI = 0.95 * highGlobCI
            lowGlobCI = 0.95 * lowGlobCI
        return highBand, lowBand           

    def setSignificantGroup(self):
        """
        Boolean significance per bin.

        Bins outside the global band are significant. Runs of consecutive bins
        outside the pointwise band that touch such a bin are also marked.
        """
        if len(self.origGroupDiff) < 2:  # single point comparison
            return self.pVal <= 0.05

        # finding significant bins
        globalSig = np.logical_or(self.origGroupDiff > self.highBand, self.origGroupDiff < self.lowBand)
        pairwiseSig = np.logical_or(self.origGroupDiff > self.setPairwiseHighBand(), self.origGroupDiff < self.setPairwiseLowBand())
        
        significantGroup = globalSig.copy()
        lastIndex = 0
        for currentIndex in range(len(pairwiseSig)):
            if (globalSig[currentIndex] == True):
                lastIndex = self.setNeighborsToTrue(significantGroup, pairwiseSig, currentIndex, lastIndex)

        return significantGroup
    
    def setPairwiseHighBand(self):        
        return np.percentile(a=self.diffSurGroups, q=100 - self.initGlobConfInterval, axis=0)

    def setPairwiseLowBand(self):        
        return np.percentile(a=self.diffSurGroups, q=self.initGlobConfInterval, axis=0)


    def setNeighborsToTrue(self, significantGroup, pairwiseSig, currentIndex, previousIndex):
        """
        Extend a globally significant bin to its pointwise-significant neighbours.

        Walks backwards (down to, but excluding, previousIndex) and then forwards
        from currentIndex while pairwiseSig stays True. Matching bins are set
        True in significantGroup, which is modified in place.
        Returns the last index visited by the forward walk.
        """
        if (currentIndex < previousIndex):
            return previousIndex
        
        for index in range(currentIndex, previousIndex, -1):
            if (pairwiseSig[index] == True):
                significantGroup[index] = True
            else:
                break

        previousIndex = currentIndex
        for index in range(currentIndex + 1, len(significantGroup)):
            previousIndex = index
            if (pairwiseSig[index] == True):
                significantGroup[index] = True
            else:
                break
        
        return previousIndex
    
    def plotSignificant(self, ax: "plt.Axes", y: float, x=None, **kwargs):
        """Draw a horizontal segment at height `y` between consecutive x values for each significant bin."""
        if x is None:
            x=np.arange(0,len(self.significantDiff))+1
        for x0,x1,p in zip(x[:-1],x[1:],self.significantDiff):
            if p:
                ax.plot([x0,x1],[y,y],zorder=-2,**kwargs)
        

**group ET learning curve**

In [None]:
def plot_learning_curve(ax, root, animalList, profile, TaskParamToPlot, 
                        stop_dayPlot,color='gray'):
    """
    Plot the group median of `TaskParamToPlot` per session, with IQR error bars.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    root : str
        Data root directory.
    animalList : list of str
        Animals in the group.
    profile : dict
        Session-selection profile passed to get_rat_group_statistic.
    TaskParamToPlot : str
        Key of the statistic to plot.
    stop_dayPlot : int
        Number of sessions (x = 1..stop_dayPlot).
    color : colour of the line, markers and error bars. The default is also
        read by add_legend_to_learning_curve via ``__defaults__``, so keep it
        as the last default.

    Returns
    -------
    data : np.ndarray
        The values of ``Results[TaskParamToPlot]`` stacked row-wise
        (presumably one row per animal, one column per session — TODO confirm).
    """
    Results,_=get_rat_group_statistic(root,
                                      animalList,
                                      profile,
                                      parameter=param,
                                      redo=False,
                                      stop_dayPlot=stop_dayPlot,
                                      TaskParamToPlot=[TaskParamToPlot])
    # (A goal-time data_fetch that used to live here was never used; removed.)
    x=np.arange(stop_dayPlot)+1
    data=np.array(list(Results[TaskParamToPlot].values()))
    median=np.nanpercentile(data,50,axis=0)
    quartiles=np.nanpercentile(data,(25,75),axis=0)
    
    # errorbar expects distances from the median, not absolute quartile values.
    ax.errorbar(x,median,yerr=abs(quartiles-median), ecolor=color, fmt='-o',color=color,
                elinewidth=1, markersize=4, markerfacecolor='w',zorder=1)

    return data
        

def plot_dotted_learning_curve(ax, root, animalList, profile, TaskParamToPlot, 
                               stop_dayPlot,colors, seed=3):
    """
    Plot the group median (IQR error bars) plus jittered per-entry dots for each session.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    root, animalList, profile, TaskParamToPlot, stop_dayPlot :
        As in plot_learning_curve.
    colors : colour(s) for the scatter dots of each session column
        (passed as ``c=`` to ``ax.scatter``).
    seed : int
        Seed of a *local* random generator used for the horizontal jitter.
        The global numpy RNG is left untouched.

    Returns
    -------
    data : np.ndarray
        Stacked values of ``Results[TaskParamToPlot]`` (one column per session).
    """
    Results,_=get_rat_group_statistic(root,
                                      animalList,
                                      profile,
                                      parameter=param,
                                      redo=False,
                                      stop_dayPlot=stop_dayPlot,
                                      TaskParamToPlot=[TaskParamToPlot])
    
    # Goal time of the last session fetched for the first animal; drawn as a reference line.
    goalTime=data_fetch(root, animal=animalList[0], profile=profile,
                        PerfParam= [lambda data:data.goalTime[-1]],
                        NbSession=0).values()
    goalTime=list(goalTime)[-1]
    
    x=np.arange(stop_dayPlot)+1
    data=np.array( list( Results[TaskParamToPlot].values() ) )
    median=np.nanpercentile(data,50,axis=0)
    quartiles=np.nanpercentile(data,(25,75),axis=0)
    # RandomState(seed) yields the same stream np.random.seed(seed) used to,
    # without mutating global RNG state for the rest of the notebook.
    rng=np.random.RandomState(seed)
    halfWidth=.3  # dots are spread uniformly over day +/- halfWidth
    
    ax.errorbar(x,median,yerr=abs(quartiles-median), ecolor='k', fmt='k-o',elinewidth=1, markersize=4, markerfacecolor='w',zorder=3)
    
    for pts,day in zip(data.T,x):
        jitter=rng.uniform(low=day-halfWidth, high=day+halfWidth, size=len(pts))
        ax.scatter(jitter,pts,s=.2,c=colors, marker='o',zorder=2)

    ax.set_xlim([x[0]-1,x[-1]+1])
    ax.set_xticks([1]+list(range(5,stop_dayPlot+1,5)))
    ax.spines['bottom'].set_bounds(x[0],x[-1])
    ax.set_ylim([0,10])
    ax.set_yticks([0,7,10])
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.hlines(y=goalTime, xmin=x[0], xmax=x[-1], linestyle='--', lw=1, color='m')
    ax.set_xlabel('Session#')
    ax.set_ylabel(TaskParamToPlot)
    
    return data
    
def add_legend_to_learning_curve(ax,dataType: str, show=('stat',)):
    """
    Add a legend to a learning-curve Axes.

    Parameters
    ----------
    ax : matplotlib Axes.
    dataType : str
        Label of the 'data' handle (black line, white-filled markers).
    show : iterable of {'data', 'ctrl', 'stat'}
        Handles to include, in this order. The default ('stat',) reproduces
        the historical legend, which shows only the 'Significant' entry.
        'ctrl' uses the default colour of plot_learning_curve.

    Returns
    -------
    The matplotlib Legend.

    Raises
    ------
    KeyError
        For an unknown entry in `show`.
    """
    ctrlColor=plot_learning_curve.__defaults__[-1]
    available={
        'data': matplotlib.lines.Line2D([], [], color='k',
                                        marker='o', markerfacecolor='w',
                                        markeredgecolor='k',markersize=4,
                                        label=dataType),
        'ctrl': matplotlib.lines.Line2D([], [], color=ctrlColor,
                                        marker='o', markerfacecolor='w',
                                        markeredgecolor=ctrlColor,markersize=4,
                                        label='Control'),
        'stat': matplotlib.lines.Line2D([], [], color='goldenrod',label='$Significant$'),
    }
    handles=[available[key] for key in show]
    leg=ax.legend(handles=handles, title="",title_fontsize=6, handletextpad=.6,
                  bbox_to_anchor=(.99, 0),loc=4, ncol=1, fontsize=4, frameon=False)
    
    return leg

In [None]:
if "__file__" not in dir():

    # Learning curve of the variable-speed group (jittered dots + black median)
    # overlaid on the speed-10 control group (median +/- IQR, default grey).
    profile={'Type':'Good',
         'rewardType':'Progressive',
#          'initialSpeed':['var'],
         'Speed':'var',
         'Tag':'Control-Early-var'
                  }
    # NOTE(review): the result of batch_get_animal_list is immediately
    # overwritten by the hard-coded list on the next line, so it is unused.
    animalList=batch_get_animal_list(root,profile)
    animalList=['Rat125', 'Rat126', 'Rat127', 'Rat128', 'Rat129', 'Rat130', 'Rat153', 'Rat154', 'Rat159', 'Rat160']
    
    # Control group: constant treadmill speed 10.
    profile0={'Type':'Good',
             'rewardType':'Progressive',
             'initialSpeed':'10',
             'Speed':'10',
             'Tag':['Control']
             }
    animalList0=['Rat103','Rat104','Rat110','Rat113','Rat120','Rat137','Rat138','Rat139','Rat140','Rat149',
                 'Rat150','Rat151','Rat152','Rat161','Rat162','Rat163','Rat164','Rat165','Rat166','Rat215',
                 'Rat216','Rat217','Rat218','Rat219','Rat220','Rat221','Rat222','Rat223','Rat224','Rat225',
                 'Rat226','Rat227','Rat228','Rat229','Rat230','Rat231','Rat232','Rat246','Rat247','Rat248',
                 'Rat249','Rat250','Rat251','Rat252','Rat253','Rat254','Rat255','Rat256','Rat257','Rat258',
                 'Rat259','Rat260','Rat261','Rat262','Rat263','Rat264','Rat265','Rat297','Rat298','Rat299',
                 'Rat300','Rat305','Rat306','Rat307','Rat308']

    TaskParamToPlot="percentile entrance time"
    stop_dayPlot =30
    # One colour per animal; the last (brightest) plasma colour is dropped.
    colors=get_ordered_colors(colormap='plasma', n=len(animalList)+1)[:-1]
    
    plt.close('all')
    fig=plt.figure(figsize=(10,5))
    ax=fig.add_subplot(111);
    
    plot_dotted_learning_curve(ax, root, animalList, profile, TaskParamToPlot, stop_dayPlot,colors)
    plot_learning_curve(ax, root, animalList0, profile0, TaskParamToPlot, stop_dayPlot)
    add_legend_to_learning_curve(ax,dataType=profile['Tag'])
    ax.set_ylabel('Entrance Times (s)')
    plt.show()

---

**plotting several consecutive trials**

In [None]:
def plot_consecutive_trajectories(root, session, trials, ax,**kwargs):
    """
    Plot several consecutive trials of one session end-to-end on a single time axis.

    For each trial:
    - the position trace is drawn in black;
    - a red cross marks the entrance time, when it differs from the trial's
      maxTrialDuration;
    - an open red circle marks the goal time;
    - the trial period is shaded lime for good trials, tomato when
      entranceTime < goalTime, and cyan otherwise.

    Parameters
    ----------
    root : str
        Data root directory.
    session : str
        Session name; its first 6 characters are used as the animal name.
    trials : sized collection of trial indices (e.g. a range)
        Passed to detect_trial_end and then iterated, so it should not be a
        one-shot iterator.
    ax : matplotlib Axes to draw on.
    **kwargs : forwarded to the position-trace ``ax.plot`` call.

    Returns
    -------
    ax

    Raises
    ------
    ValueError
        If `trials` is empty (previously an UnboundLocalError).
    """
    if len(trials) == 0:
        raise ValueError('`trials` must contain at least one trial.')

    delta=1 #last delta second of position is not recorded
    data=Data(root,session[:6],session,param=param,redoPreprocess=False, saveAsPickle=False)
    data.position_correction()
    detect_trial_end(data, trials)
    y1,y2=param['treadmillRange']

    maxT=0
    for trial in trials:
        # Shift this trial's timestamps so it starts after the previous one.
        trialTime=data.rawTime[trial]+maxT
        maxT=trialTime[-1]+delta
        #plotting position
        ax.plot(trialTime,data.position[trial],'k',**kwargs)
        #plotting entrance time (only if the animal actually entered)
        ET=data.entranceTime[trial]+data.cameraToTreadmillDelay+trialTime[0]
        if data.entranceTime[trial]!=data.maxTrialDuration[trial]:
            ax.plot(ET,data.position[trial][trialTime[trialTime<ET].argmax()+1],'rx')
        #plotting the goal time
        GT=data.goalTime[trial]+data.cameraToTreadmillDelay+trialTime[0]
        ax.plot(GT,data.position[trial][trialTime[trialTime<GT].argmax()], marker='o',markeredgecolor='r', markerfacecolor='None')
        #plotting the highlight for trials
        x=(trialTime[0], trialTime[0]+data.timeEndTrial[trial]+data.cameraToTreadmillDelay)
        c='lime' if trial in data.goodTrials else 'tomato' if data.entranceTime[trial]<data.goalTime[trial] else 'cyan'
        ax.fill_betweenx(y=(y1,y2),x1=x[0],x2=x[1], color=c, alpha=.2)
    
    ax.set_xlim([-1,trialTime[-1]+1])
    ax.set_ylim([y1-2,y2+2])
    ax.set_yticks([y1,y2/2,y2])
    ax.set_yticklabels([y1,'',y2])
    ax.set_ylabel('Position (cm)')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
        
    return ax

In [None]:
if "__file__" not in dir():
    # Example: 11 consecutive trials (93..103) of one session.
    session='Rat156_2017_09_08_14_10'
    i=93
    trials=range(i,i+11)

    plt.close('all')
    ax=plt.figure(figsize=(20,3)).add_subplot(111);

    plot_consecutive_trajectories(root, session, trials, ax)
#     add_treadmill_to_axis(ax,path,Xextention=12)

In [None]:
if "__file__" not in dir():
    # Two example sessions stacked vertically (11 consecutive trials each).
    fig=plt.figure(figsize=(9,4),dpi=100)
    naiveAx=fig.add_subplot(211)
    trainedAx=fig.add_subplot(212)
    
    #plotting VAR
    session='Rat128_2017_06_02_12_26'
    trials=range(119,130)
    plot_consecutive_trajectories(root, session, trials, naiveAx)
    # NOTE(review): both panels are labelled "trained" although the top Axes is
    # named naiveAx; confirm which label is intended.
    naiveAx.text(x=0, y=0, s=f'{session[:6]} trained')
    #plotting NTO
    session='Rat156_2017_09_08_14_10'
    trials=range(93,104)
    plot_consecutive_trajectories(root, session, trials, trainedAx)
    trainedAx.text(x=0, y=0, s=f'{session[:6]} trained')
    

    
#     fig.savefig('/home/david/Pictures/tst.pdf',format='pdf')
    plt.show()
    plt.close('all')

---

**Plotting the trajectories of the example sessions above**

In [None]:
def plot_trajectories(data,ax):
    """
    Draw every trial's position trace up to its stop frame.

    Good trials (listed in data.goodTrials) are drawn lime and all others
    tomato. A dashed black vertical line marks the median goal time.

    Returns
    -------
    np.ndarray of colour names, one per trial, in data.position iteration order.
    """
    positions=data.position
    trialTimes=data.timeTreadmill #align on camera
    trialColors=[]
    for trial, pos in positions.items():
        trialColor="lime" if trial in data.goodTrials else "tomato"
        trialColors.append(trialColor)
        stop=data.stopFrame[trial]
        ax.plot(trialTimes[trial][:stop], pos[:stop], color=trialColor, lw=.5)

    ax.set_ylabel("X Position (cm)")
    ax.set_xlabel("Time (s) relative to camera start")

    lowY, highY = param['treadmillRange'][0], param['treadmillRange'][1]
    ax.vlines(x=np.nanmedian(data.goalTime), ymin=lowY, ymax=highY,
              colors='k', linestyle='--')

    return np.array(trialColors)



def plot_trajectories_and_distributions(root, ax, session):
    """
    Plot all trajectories of a session plus horizontal histograms of initial positions.

    The histograms (30 bins over the treadmill range) are drawn to the left of
    x = -cameraToTreadmillDelay: lime for good trials, tomato for the others.
    They are scaled so that the tallest bar is 3 x-units long.

    NOTE(review): unlike other helpers here, Data is built without
    `param=param`; confirm whether the default parameters are intended.

    Returns
    -------
    ax
    """
    data=Data(root,session[:6],session,redoPreprocess=False)
    y1,y2=param['treadmillRange']
    
    color=plot_trajectories(data,ax=ax)
    
    position=get_positions_array_beginning(data,onlyGood=False,raw=False)
    position=position.T
    
    # Row 0 after transposing = first sample of each trial; presumably columns
    # follow data.position order, matching `color` — TODO confirm.
    histGood,binsGood=np.histogram(position[0,color=='lime'],30,range=(y1,y2), density=False)
    histBad,binsBad=np.histogram(position[0,color=='tomato'],30,range=(y1,y2), density=False)
    
    # Floor the peak count at 1 so empty histograms do not divide by zero.
    maxBin=max(histGood.max(),histBad.max(),1)/3
    ax.barh(binsGood[:-1],-histGood/maxBin,height=np.diff(binsGood)[0],align='edge',left=-1.02*data.cameraToTreadmillDelay, color='lime', alpha=.6)
    ax.barh(binsBad[:-1],-histBad/maxBin,height=np.diff(binsBad)[0],align='edge',left=-1.02*data.cameraToTreadmillDelay, color='tomato', alpha=.6)

    ax.set_xlim([-4.2,data.maxTrialDuration[0]])
    ax.set_xticks((0,np.nanmedian(data.goalTime),np.nanmedian(data.maxTrialDuration)))
    ax.xaxis.set_major_formatter(FormatStrFormatter('%g'))
    ax.set_ylim([y1,y2])
    ax.set_yticks([y1,y2/2,y2])
    ax.set_yticklabels([y1,'',y2])
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_bounds(ax.get_xticks()[0],ax.get_xticks()[-1])
    ax.spines['left'].set_position(('data',-data.cameraToTreadmillDelay))

    ax.set_xlabel('Trial time (s)')
    ax.set_ylabel('Position (cm)')

    return ax

In [None]:
if "__file__" not in dir():
    #the inputs
    # Two example sessions: variable-speed (top) and NTO (bottom).
    
    fig=plt.figure(figsize=(3,6),dpi=100)
    naiveAx=fig.add_subplot(211)
    trainedAx=fig.add_subplot(212)
    #plotting VAR
    session='Rat128_2017_06_02_12_26'
    plot_trajectories_and_distributions(root, naiveAx, session)
    
    #plotting NTO
    session='Rat156_2017_09_08_14_10'
    plot_trajectories_and_distributions(root, trainedAx, session)
    
    
#     fig.savefig('/home/david/Pictures/tst.pdf', format='pdf', bbox_inches='tight')
    plt.show()
    plt.close('all')

---

**Plot the probability distribution of initial positions for correct and incorrect trials**

In [None]:
def initial_pos(root, profile, animalList, SessionRange, trdBins, ax):
    """
    Collect initial treadmill positions of trials, split by entrance vs. goal time.

    Parameters
    ----------
    root : str
        Data root directory.
    profile : dict
        Session-selection profile.
    animalList : list of str
        Animals to pool.
    SessionRange : (start, stop)
        Slice applied to each animal's list of sessions.
    trdBins, ax : unused; kept for backward compatibility of the signature.

    Returns
    -------
    (initialPosG, initialPosB) : np.ndarray, np.ndarray
        Initial positions of trials with entranceTime >= median goal time (G)
        and with entranceTime < median goal time (B).

    Sessions whose entranceTime and initial-position arrays differ in length
    are skipped, and a warning reports how many were skipped.
    """
    import logging

    def et_and_initial_pos(data):
        et     =data.entranceTime
        initPos=np.array([data.position[i][data.startFrame[i]] for i in data.position])
        gt     =np.nanmedian(data.goalTime)

        if len(et)==len(initPos):
            return et,initPos,gt
        else:
            return None,None,None
    
    PerfParam=[et_and_initial_pos]
    initialPosB=[]
    initialPosG=[]
    countN=0
    for animal in animalList:
        data=data_fetch(root, animal=animal, profile=profile,
                        PerfParam= PerfParam,
                        NbSession=100)[PerfParam[0].__name__]

        for et,pos,gt in data[SessionRange[0]:SessionRange[1]]:
            if et is None:
                countN+=1
                continue
            initialPosG.extend(list(pos[et>=gt]))
            initialPosB.extend(list(pos[et<gt]))

    if countN:
        logging.warning(f'{countN} sessions were removed, entranceTime != position')
    
    return np.array(initialPosG),np.array(initialPosB)
    
def plot_probablity_initial_pos(root,profile,animalList, SessionRange, trdBins, ax, pCum=True,):
    """
    Plot density histograms of initial positions for good (lime) and bad (tomato) trials.

    Parameters
    ----------
    root, profile, animalList, SessionRange :
        As in initial_pos.
    trdBins : sequence of increasing bin edges (e.g. a range).
        Any spacing is accepted; `.step` is no longer required.
    ax : matplotlib Axes to draw on.
    pCum : bool
        If True, overlay dashed cumulative probabilities on a twin y-axis.
    """
    initialPosG,initialPosB= initial_pos(root, profile, animalList, SessionRange, trdBins, ax)

    Y1,Y2=param['treadmillRange']
    edges=np.asarray(trdBins, dtype=float)
    widths=np.diff(edges)
    n,_,_=ax.hist(initialPosG,trdBins,density=True,edgecolor='None',color='lime'  ,alpha=.6, rwidth=1)
    m,_,_=ax.hist(initialPosB,trdBins,density=True,edgecolor='None',color='tomato',alpha=.6, rwidth=1)
    y0,y1=ax.get_ylim()
    
    if pCum:
        ax2=ax.twinx()
        centers=edges[:-1]+widths/2
        # density * bin width = probability mass per bin, valid for uneven bins too.
        ax2.plot(centers,np.cumsum(n*widths),linewidth=1,color='lime', linestyle='--')
        ax2.plot(centers,np.cumsum(m*widths),linewidth=1,color='tomato', linestyle='--')

        ax2.spines['top'].set_visible(False)
        ax2.set_ylim([-.02,1.02])
        ax2.set_ylabel('Cumu. Prob.',color='k')
        ax2.set_yticks([0,1])
        ax2.spines['bottom'].set_bounds(Y1,Y2)
        ax2.spines['right'].set_bounds(0,1)

    ax.set_xlabel('Init. position (cm)')
    ax.set_ylabel('Probability')
    ax.set_ylim((y0-y1)/100,y1)
    ax.set_xlim([edges[0]-widths[0],edges[-1]+widths[-1]])
    ax.set_xticks([Y1,Y2/2,Y2])
    ax.set_xticklabels([Y1,'',Y2])
    ax.set_yticks([])
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_bounds(Y1,Y2)
    ax.spines['left'].set_bounds(0,y1)
    

In [None]:
if "__file__" not in dir():
    # Initial-position distributions for the speed-10 control group, sessions 20..29.
    profile={'Type':'Good',
             'rewardType':'Progressive',
             'initialSpeed':'10',
             'Speed':['10'],
             'Tag':'Control'
                  }
    animalList=batch_get_animal_list(root,profile)
    animalList=animalList  # NOTE(review): no-op assignment

    Y1,Y2=param['treadmillRange']
    SessionRange=[20,30]
    trdBins=range(Y1,Y2+1,1)  # 1 cm bins over the whole treadmill
    
    plt.close('all')
    fig=plt.figure(figsize=(2,2),dpi=300)
    ax=fig.add_subplot(111)
    
    plot_probablity_initial_pos(root,profile,animalList, SessionRange, trdBins, ax)
    
    plt.show()


---

**Plot probability of correct trial given Treadmill Speeds**

In [None]:
def _binned_probability(correct, speed, spdBins, minSamples=10):
    """Fraction of True `correct` values per [lo, hi) speed bin; NaN where a bin has <= minSamples samples."""
    prob=np.ones((len(spdBins)-1,))*np.nan
    for i,(loBin,hiBin) in enumerate(zip(spdBins[:-1],spdBins[1:])):
        signal=correct[np.logical_and(speed>=loBin,speed<hiBin)]
        if len(signal)>minSamples:
            prob[i]=sum(signal)/len(signal)
    return prob


def prob_correct_given_distance(animalList, profile, SessionRange, GT, spdBins, minSamples=10):
    """
    Probability of a "correct" trial as a function of treadmill speed.

    Despite the name, trials are binned by treadmill speed, not distance.
    A trial is correct when min(GT) < entranceTime < max(GT). Trials that
    timed out (entranceTime == maxTrialDuration) count as entranceTime 0.
    Sessions whose speed and entrance-time arrays differ in length are skipped.
    Reads the notebook global `root`.

    Parameters
    ----------
    animalList : list of str
    profile : dict
        Session-selection profile.
    SessionRange : (start, stop)
        Slice applied to each animal's sessions.
    GT : pair of numbers
        Bounds of the correct entrance-time window.
    spdBins : sequence of increasing speed bin edges.
    minSamples : int
        A bin needs more than this many trials to get a value; otherwise NaN.

    Returns
    -------
    Ptotal : np.ndarray, shape (len(spdBins)-1,)
        Probability pooled over all animals.
    Panimal : np.ndarray, shape (len(spdBins)-1, len(animalList))
        Probability per animal (one column each).
    """
    def entrance_time(data):
        # Work on a copy so the Data object's entranceTime is not modified in place.
        et=np.array(data.entranceTime, copy=True)
        et[et==data.maxTrialDuration]=0
        return et
    def treadmill_speed(data):
        return data.treadmillSpeed

    perfFuncs=[treadmill_speed,entrance_time]
    loGT,hiGT=min(GT),max(GT)

    corrData={}
    spdData={}
    for animal in animalList:
        fetched=data_fetch(root,animal,profile, perfFuncs, NbSession=100)
        sessSpd=fetched[treadmill_speed.__name__][SessionRange[0]:SessionRange[-1]]
        sessEt=fetched[entrance_time.__name__][SessionRange[0]:SessionRange[-1]]
        corrData[animal]=[]
        spdData[animal]=[]
        for et,spd in zip(sessEt,sessSpd):
            if len(et)==len(spd):
                corrData[animal].extend(np.logical_and(et>loGT,et<hiGT))
                spdData[animal].extend(spd)

    Panimal=np.ones((len(spdBins)-1,len(animalList)))*np.nan
    for col,animal in enumerate(animalList):
        Panimal[:,col]=_binned_probability(np.array(corrData[animal]),np.array(spdData[animal]),
                                           spdBins,minSamples)

    # Pool every entry of animalList (duplicates counted again, as before).
    pooledCorrect=[]
    pooledSpeed=[]
    for animal in animalList:
        pooledCorrect.extend(corrData[animal])
        pooledSpeed.extend(spdData[animal])
    Ptotal=_binned_probability(np.array(pooledCorrect),np.array(pooledSpeed),spdBins,minSamples)
    
    return Ptotal, Panimal


def plot_cond_prob_correct(ax, animalList, profile, SessionRange, GT, spdBins, **kwargs):
    """
    Plot the pooled P(correct | treadmill speed) curve at speed-bin centres.

    `kwargs` are forwarded to ``ax.plot``. The Axes is formatted with a
    0-1 probability y-axis and the speed range on the x-axis.
    Returns the Axes.
    """
    pooledProb, _perAnimal = prob_correct_given_distance(animalList, profile,
                                                         SessionRange, GT, spdBins)

    firstBin, lastBin = spdBins[0], spdBins[-1]
    binCentres = np.array(spdBins[:-1]) + (spdBins[1] - spdBins[0]) / 2
    ax.plot(binCentres, pooledProb, **kwargs)

    # Separate formatter instances: a Formatter is bound to a single axis.
    ax.xaxis.set_major_formatter(FormatStrFormatter('%g'))
    ax.yaxis.set_major_formatter(FormatStrFormatter('%g'))
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    ax.spines['bottom'].set_bounds(firstBin, lastBin)
    ax.spines['left'].set_bounds(0, 1)
    ax.set_ylim([-.02, 1])
    ax.set_yticks([0, .5, 1])
    ax.set_yticklabels([0, '', 1])
    ax.set_xlim([firstBin - 1, lastBin])
    ax.set_xticks([firstBin, 10, 20, lastBin])
    ax.set_xlabel('Speed ($cm.s^{-1}$)')
    ax.set_ylabel("Probability")
    ax.yaxis.set_label_coords(-0.04, 0.5)

    return ax

def add_legend_to_cond_prob_plot(ax, colors, labels):
    """
    Add a legend with one plain line handle per (color, label) pair.

    Returns the matplotlib Legend, consistent with add_legend_to_learning_curve.
    Previously the legend was built but not returned.
    """
    lines=[matplotlib.lines.Line2D([], [], color=color,label=label)
           for color, label in zip(colors,labels)]
    
    leg=ax.legend(handles=lines, title="",title_fontsize=6, handletextpad=.6,
                  bbox_to_anchor=(0, 1),loc=2, ncol=1, fontsize=4, frameon=False)
    return leg
    

In [None]:
if "__file__" not in dir():
    # P(6 < ET < 8 | treadmill speed) for the variable-speed group, sessions 20..199.
    profile={'Type':'Good',
             'rewardType':'Progressive',
#            'initialSpeed':['var'],
             'Speed':'var',
             'Tag':'Control-Early-var'
            }
    animalList=['Rat125', 'Rat126', 'Rat127', 'Rat128', 'Rat129', 'Rat130', 'Rat153', 'Rat154', 'Rat159', 'Rat160']

    Y1,Y2=param['treadmillRange']
    SessionRange=[20,200]
    spdBins=range(5,31,1)  # 1 unit wide speed bins from 5 to 30
    GTrange=(6,8)  # entrance-time window counted as "correct"
    colors=get_ordered_colors(colormap='RdBu', n=1)
    labels=[f'$ P ({GTrange[0]} < ET < {GTrange[1]})$']
    plt.close('all')
    fig=plt.figure(figsize=(2,2),dpi=600)
    ax=fig.add_subplot(111)
    
    plot_cond_prob_correct(ax, animalList, profile, SessionRange, GTrange, spdBins, color=colors[0],label=labels[0])
    add_legend_to_cond_prob_plot(ax, colors, labels)
    plt.show()


---

**Plot average trajectory**

In [None]:
def plot_session_median_trajectory(data,ax, **kwargs):
    """
    Plot the across-trial median position of one session.

    Each trial's position up to its stop frame is stored in one column of a
    (nFrames, nTrials) matrix, NaN-padded. Frame k is plotted at
    k/cameraSamplingRate - cameraToTreadmillDelay seconds. The median is
    drawn only up to the first frame where more than 30% of trials are
    missing.

    Parameters
    ----------
    data : session object
        Must provide position (dict trial -> array), stopFrame (dict),
        cameraSamplingRate and cameraToTreadmillDelay.
    ax : Axes to draw on.
    **kwargs : forwarded to ``ax.plot``.

    Returns
    -------
    position : np.ndarray, shape (nFrames, nTrials)
    """
    posDict=data.position
    fps=data.cameraSamplingRate
    maxL=int(np.nanmax(list(data.stopFrame.values())))
    position=np.full((maxL,len(posDict)),np.nan)
    # Exactly one timestamp per frame (a float arange could be off by one).
    time=np.arange(maxL)/fps-data.cameraToTreadmillDelay

    for i,trial in enumerate(posDict):
        pos=posDict[trial][:data.stopFrame[trial]]
        position[:len(pos),i]=pos
    
    #keeping data where 70% of points exist
    nanSum=np.sum(np.isnan(position),axis=1)
    tooSparse=np.where(nanSum>.3*position.shape[1])[0]
    # If no frame is too sparse, keep all frames (rows), not nTrials.
    maxTraj=tooSparse[0] if len(tooSparse) else position.shape[0]
    
    ax.plot(time[:maxTraj], np.nanmedian(position,axis=1)[:maxTraj], lw=.5, **kwargs)    

    ax.set_ylabel("Position (cm)")
    ax.set_xlabel("Time (s) relative to camera start")
    
    return position
    

def plot_median_trajectory(root, ax, profile, animalList, sessionIdx,colors):
    """
    Overlay, for each animal, its median trajectory in session number `sessionIdx`.

    Animal number k is drawn with colors[k]. The Axes is then formatted:
    - title "Session #<sessionIdx+1>";
    - y-axis spans the global treadmill range Y1..Y2;
    - x-axis spans -1..8 s, with ticks at 0 and 7.
    """
    for idx, animal in enumerate(animalList):
        sessions = batch_get_session_list(root, animalList=[animal], profile=profile)['Sessions']
        session = sessions[sessionIdx]

        sessionData = Data(root, session[:6], session,
                           param=param, redoPreprocess=False, saveAsPickle=False)
        sessionData.position_correction()

        plot_session_median_trajectory(sessionData, ax, color=colors[idx])

    ax.set_title(f'Session #{sessionIdx+1}', pad=0, fontsize='small')
    ax.set_ylim([Y1, Y2])
    ax.set_yticks([Y1, Y2/2, Y2])
    ax.set_yticklabels([Y1, '', Y2])
    ax.set_xlim([-1, 8])
    ax.set_xticks([0, 7])
    for hidden in ('top', 'right'):
        ax.spines[hidden].set_visible(False)
    ax.spines['bottom'].set_bounds(0, 7)
    ax.set_xlabel('Trial time (s)')
    ax.set_ylabel('Position (cm)')
    ax.spines['left'].set_bounds(Y1, Y2)

    return None

def plot_grand_average(root, ax, profile, animalList, sessionIdx, **kwargs):
    """Plot the across-animal median of each animal's median trajectory.

    For every rat in ``animalList``, the ``sessionIdx``-th session matching
    ``profile`` is loaded. Its across-trial median trajectory is computed with
    ``plot_session_median_trajectory``, drawn into a temporary figure that is
    discarded. The first 8 s (200 frames at 25 fps) of each median are stacked
    column-wise. Their across-animal ``nanmedian`` is then plotted on ``ax``
    against a time axis that starts at -1 s with one sample per frame.

    Parameters
    ----------
    root : str
        Data root directory.
    ax : matplotlib Axes
        Axis receiving the grand-average curve.
    profile : dict
        Session-selection profile passed to ``batch_get_session_list``.
    animalList : list of str
        Rats to include.
    sessionIdx : int
        Zero-based index into each rat's session list.
    **kwargs
        Forwarded to ``ax.plot`` (e.g. ``color``, ``lw``).

    Returns
    -------
    traj : ndarray, shape (200, len(animalList))
        Per-animal median trajectories, NaN-padded where a session provides
        fewer than 200 frames.
    """
    fps = 25                 # camera frame rate assumed for this window
    tStart = -1              # time (s) of the first frame
    maxL = int(8 * fps)      # 8 s window
    traj = np.full((maxL, len(animalList)), np.nan)
    # Same 1/fps spacing as the per-animal curves; np.linspace(-1, 7, maxL)
    # would stretch the axis to an 8/199 s step.
    tAxis = tStart + np.arange(maxL) / fps

    fig_ = plt.figure()
    try:
        ax_ = fig_.add_subplot(111)
        for i, animal in enumerate(animalList):
            session = batch_get_session_list(root, animalList=[animal], profile=profile)['Sessions'][sessionIdx]
            data = Data(root, session[:6], session,
                        param=param, redoPreprocess=False, saveAsPickle=False)
            data.position_correction()
            pos = plot_session_median_trajectory(data, ax_)
            median = np.nanmedian(pos, axis=1)[:maxL]
            # shorter sessions leave the remaining rows NaN instead of failing
            # with a shape mismatch
            traj[:len(median), i] = median
    finally:
        # never leak the scratch figure, even if loading a session fails
        plt.close(fig_)

    ax.plot(tAxis, np.nanmedian(traj, axis=1), **kwargs)

    return traj

In [None]:
if "__file__" not in dir():
    # Stand-alone preview: per-rat median trajectories of the variable-speed
    # group on one session, with the group's grand average drawn in thick grey.
    profile = {
        'Type': 'Good',
        'rewardType': 'Progressive',
        'Speed': 'var',
        'Tag': 'Control-Early-var',
    }
    animalList = [f'Rat{n}' for n in (125, 126, 127, 128, 129, 130, 153, 154, 159, 160)]

    sessionIdx = 29  # zero-based, shown as "Session #30"
    # one plasma colour per rat; the extra (last) colour is dropped
    colors = get_ordered_colors(colormap='plasma', n=len(animalList) + 1)[:-1]

    plt.close('all')
    fig = plt.figure(figsize=(8, 4))
    ax = fig.add_subplot(111)

    plot_median_trajectory(root, ax, profile, animalList, sessionIdx, colors)
    plot_grand_average(root, ax, profile, animalList, sessionIdx,
                       color=[.6, .6, .6, .6], lw=6)

---

**Plot speed distributions**

In [None]:
def prob_var_speed(animalList, profile, spdBins):
    """Return the normalised histogram of treadmill speeds pooled over rats.

    For every rat in ``animalList``, treadmill speeds of its sessions matching
    ``profile`` are fetched with ``data_fetch``. ``NbSession=100`` is passed,
    presumably a cap on the number of sessions. All speeds are pooled into a
    single sample and histogrammed with
    ``np.histogram(..., bins=spdBins, density=True)``.

    Parameters
    ----------
    animalList : list of str
        Rats to pool.
    profile : dict
        Session-selection profile passed to ``data_fetch``.
    spdBins : sequence of numbers
        Histogram bin edges (cm/s).

    Returns
    -------
    hist : ndarray, shape (len(spdBins) - 1,)
        Probability density per speed bin.
    """
    def treadmill_speed(data):
        return data.treadmillSpeed

    # data_fetch takes a list of extractor functions and returns a dict keyed
    # by each extractor's __name__; each value is iterated as a sequence of
    # speed sequences (presumably one per session).
    extractors = [treadmill_speed]

    speed = []
    for animal in animalList:
        fetched = data_fetch(root, animal, profile, extractors, NbSession=100)
        for sessionSpeeds in fetched[treadmill_speed.__name__]:
            speed.extend(sessionSpeeds)

    hist, _ = np.histogram(a=np.array(speed), bins=spdBins, density=True)
    return hist


def plot_prob_var_speed(gs, animalList, profile, spdBins,
                        wspace,labels=('Variable\nspeed','Control'), **kwargs):
    """Draw the treadmill-speed distributions of variable-speed vs control rats.

    The region ``gs`` is split into two side-by-side axes, ``wspace`` apart:
    - left: the pooled, normalised speed histogram of ``animalList`` /
      ``profile`` (see ``prob_var_speed``) as black bars. Extra ``**kwargs``
      are forwarded to ``Axes.bar``.
    - right: the control condition, a single bar at 10 cm/s of height
      ``1/binSize`` above a thin flat baseline, in the default colour of
      ``plot_learning_curve`` (grey if unavailable).

    An invisible axes spanning ``gs`` carries the shared x label.

    Parameters
    ----------
    gs : matplotlib SubplotSpec
        Region to draw into. The figure is taken from its gridspec (falling
        back to the current figure) rather than from a global ``fig``.
    animalList, profile :
        Forwarded to ``prob_var_speed``.
    spdBins : sequence of numbers
        Speed bin edges (cm/s); assumed evenly spaced (width = first step).
    wspace : float
        Horizontal spacing between the two axes.
    labels : tuple of 2 str
        Titles of the left and right axes.

    Returns
    -------
    (axL, axR) : the two distribution axes.
    """
    try:
        Clr = inspect.signature(plot_learning_curve).parameters['color'].default
    except (NameError, KeyError, TypeError, ValueError):
        # plot_learning_curve lives in another notebook and may be missing
        Clr = inspect.Parameter.empty
    if Clr is inspect.Parameter.empty:
        Clr = 'gray'

    fig = getattr(gs.get_gridspec(), 'figure', None)
    if fig is None:
        fig = plt.gcf()

    gssub = gs.subgridspec(1, 2, wspace=wspace)
    axL = fig.add_subplot(gssub[0, 0])
    axR = fig.add_subplot(gssub[0, 1])
    axes = (axL, axR)

    Ptotal = prob_var_speed(animalList, profile, spdBins)
    maxBar = np.nanmax(Ptotal)

    binSize = spdBins[1] - spdBins[0]
    axL.bar(x=spdBins[:-1], height=Ptotal, width=binSize*.9, align='edge',
            edgecolor='None', color='k', **kwargs)
    # control: near-zero baseline everywhere plus one bar centred on 10 cm/s
    axR.bar(x=spdBins[:-1], height=0.01, width=binSize, align='edge',
            color=Clr, edgecolor='None')
    axR.bar(x=10-binSize/2, height=1/binSize, width=binSize, align='edge',
            color=Clr, edgecolor='None')

    for label, axis in zip(labels, axes):
        axis.xaxis.set_major_formatter(FormatStrFormatter('%g'))
        axis.yaxis.set_major_formatter(FormatStrFormatter('%g'))
        axis.spines['top'].set_visible(False)
        axis.set_yticks([0])
        axis.set_yticklabels('')
        axis.set_ylabel('')
        axis.set_title(label, pad=2, fontsize='x-small')
        axis.tick_params(axis='both', labelsize='x-small')

    axL.set_xlim([0, 35])
    axL.set_xticks([5, 30])
    axL.set_ylim([-.02*maxBar, 1.1*maxBar])
    axL.spines['bottom'].set_bounds(5, 30)
    axL.spines['right'].set_visible(False)
    axL.spines['left'].set_bounds(0, maxBar*1.1)
    axL.set_ylabel('Probability', fontsize='x-small')

    axR.spines['right'].set_visible(False)
    axR.spines['left'].set_visible(False)
    axR.yaxis.set_visible(False)
    axR.spines['bottom'].set_bounds(6, 14)
    axR.set_xlim([5, 15])
    axR.set_xticks([10])
    axR.set_ylim([-.02, 1.1])
    axR.set_title(axR.get_title(), color=Clr, fontsize='x-small')

    # frameless, fully transparent axes spanning both panels: only its
    # x label is visible, acting as a shared label
    labelAx = fig.add_subplot(gs, frameon=False)
    labelAx.set_xlabel('Speed ($cm.s^{-1}$)', fontsize='x-small')
    labelAx.yaxis.set_visible(False)
    labelAx.set_xlim([0, 100])
    labelAx.tick_params(color=(0, 0, 0, 0), labelcolor=(0, 0, 0, 0), zorder=-10)

    return axL, axR
    

In [None]:
if "__file__" not in dir():
    # Stand-alone preview: pooled treadmill-speed distribution of the
    # variable-speed group next to the fixed-speed control condition.
    profile = {
        'Type': 'Good',
        'rewardType': 'Progressive',
        'Speed': 'var',
        'Tag': 'Control-Early-var',
    }
    animalList = [f'Rat{n}' for n in (125, 126, 127, 128, 129, 130, 153, 154, 159, 160)]
    spdBins = range(5, 31)   # 1 cm/s bins from 5 to 30 cm/s
    wspace = 0.3

    plt.close('all')
    fig = plt.figure(figsize=(1, 1))
    # a single cell covering the whole figure
    gs = fig.add_gridspec(1, 1, left=0, right=1, top=1, bottom=0)[0]

    plot_prob_var_speed(gs, animalList, profile, spdBins, wspace)

---

**Draw the definition of the No-Timeout (NTO) condition**

In [None]:
def draw_table_NTO(ax):
    """Render the error-trial definitions of No-Timeout and Control as a table.

    The table has one column ('Error Trials:') and two rows ('No-Timeout',
    'Control'). Each row gives, in LaTeX, the entrance-time (ET) range that
    counts as an error trial. The 'Control' row label is coloured with the
    default colour of ``plot_learning_curve`` (grey if unavailable). All
    spines and both axes of ``ax`` are hidden so only the table shows.

    Returns the matplotlib ``Table``.
    """
    try:
        Clr = inspect.signature(plot_learning_curve).parameters['color'].default
    except (NameError, KeyError, TypeError, ValueError):
        # plot_learning_curve lives in another notebook and may be missing
        Clr = inspect.Parameter.empty
    if Clr is inspect.Parameter.empty:
        Clr = 'gray'

    # raw strings: '\l' is not a Python escape, so the text is unchanged but
    # no invalid-escape warning is emitted
    cellText = [[r'$0 \leq ET < 7$'],
                [r'$1.5 \leq ET < 7$']]
    rowLabels = ['No-Timeout', 'Control']
    colLabels = ['Error Trials:']

    table = ax.table(cellText=cellText, cellColours=None, cellLoc='center',
                     rowLabels=rowLabels, rowColours=None, rowLoc='center',
                     colLabels=colLabels, colColours=None, colLoc='center', colWidths=None,
                     loc='center', bbox=[0.25, 0.25, 1, .5], edges='horizontal')
    # Row 0 is the column header, so (2, -1) is the row-label cell of the
    # second data row ('Control').
    table.get_celld()[(2, -1)].set_text_props(color=Clr)

    for side in ('right', 'left', 'top', 'bottom'):
        ax.spines[side].set_visible(False)
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    return table

In [None]:
def plot_NTO_definition_bars(ax):
    """Draw the error-trial ET ranges of No-Timeout and Control as arrows.

    Two vertical double-headed arrows are drawn:
    - x=0 spans ET 0 to 7 (No-Timeout, black);
    - x=1 spans ET 1.5 to 7 (Control, in the default colour of
      ``plot_learning_curve``, grey if unavailable).

    The y axis (ticks 0, 1.5, 7) sits on the right. The category names are
    drawn as x tick labels inside the axes (negative pad), and the axis is
    labelled 'Error trials'.
    """
    try:
        Clr = inspect.signature(plot_learning_curve).parameters['color'].default
    except (NameError, KeyError, TypeError, ValueError):
        # plot_learning_curve lives in another notebook and may be missing
        Clr = inspect.Parameter.empty
    if Clr is inspect.Parameter.empty:
        Clr = 'gray'

    # Annotation text is passed positionally: the `s=` keyword of
    # Axes.annotate was deprecated in Matplotlib 3.3 and later removed.
    ax.annotate('', xy=(0, 0), xytext=(0, 7),
                arrowprops=dict(arrowstyle='<->', color='k', shrinkA=0, shrinkB=0))
    ax.annotate('', xy=(1, 1.5), xytext=(1, 7),
                arrowprops=dict(arrowstyle='<->', color=Clr, shrinkA=0, shrinkB=0))

    ax.tick_params(axis='both', labelsize='x-small')
    ax.yaxis.tick_right()
    ax.yaxis.set_label_position('right')
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position('bottom')
    ax.yaxis.set_major_formatter(FormatStrFormatter('%g'))
    ax.spines['left'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    ax.spines['right'].set_bounds(0, 7)
    ax.set_xlim([-.2, 1.2])
    ax.set_xticks([0, 1])
    ax.set_xticklabels(['No-\nTimeout', 'Control'])
    # hide the top tick marks and pull the labels down into the axes
    ax.tick_params('x', top=False, pad=-20)
    ax.set_ylim([-1, 10])
    ax.set_yticks([0, 1.5, 7])

    ax.get_xticklabels()[1].set_color(Clr)
    ax.set_xlabel('Error trials', fontsize='small')


In [None]:
if "__file__" not in dir():
    # Stand-alone preview of the No-Timeout definition panel.
    # `colors` is recomputed from the notebook-global animalList.
    colors = get_ordered_colors(colormap='plasma',
                                n=len(animalList) + 1)[:-1]

    plt.close('all')
    fig = plt.figure(figsize=(2, 2))
    ax = fig.add_subplot(111)

    # Alternative table rendering: draw_table_NTO(ax)
    plot_NTO_definition_bars(ax)

------



------

# Part 2:

# **GENERATING THE FIGURE**

**Definition of Parameters**

In [None]:
if "__file__" not in dir():
    #================================================
    # GENERAL
    # Preprocessing parameters passed to Data(...). They are defined first so
    # that the treadmill bounds Y1/Y2, and trdBins9 derived from them, come
    # from this dictionary. Otherwise they would come from whatever `param`
    # was in scope before this cell ran.
    param={
        "goalTime":7,#needed for pavel data only
        "treadmillRange":[0,90],#pavel error conversion "treadmillRange":[0,80]
        "maxTrialDuration":15,
        "interTrialDuration":10,#None pavel
        "endTrial_frontPos":30,
        "endTrial_backPos":55,
        "endTrial_minTimeSec":4,
        "cameraSamplingRate":25, #needed for new setup

        "sigmaSmoothPosition":0.1,#0.33, 0.18 pavel
        "sigmaSmoothSpeed":0.3,#0.3, 0.5 pavel
        "nbJumpMax":100,#200 pavel
        "binSize":0.25,
        #parameters used to preprocess (will override the default parameters)
        }
    Y1,Y2=param['treadmillRange']

    #================================================
    # GRID 1 PARAMS: example sessions / trials

    #plotting VAR
    session1VAR='Rat128_2017_06_02_12_26'
    trials1VAR=range(119,130)
    label1VAR='Variable Speed'
    #plotting NTO
    session1NTO='Rat156_2017_09_08_14_10'
    trials1NTO=range(93,104)
    label1NTO='No-Timeout'

    dayVAR, dayNTO=days=(30,30)

    #================================================
    # GRID 2: Trajectory examples
    #nothing

    #================================================
    # GRID 3 PARAMS: learning curves, variable speed (VAR) vs Control
    profile3VAR={'Type':'Good',
                 'rewardType':'Progressive',
                 'Speed':'var',
                 'Tag':'Control-Early-var'
                }
    animalList3VAR=['Rat125', 'Rat126', 'Rat127', 'Rat128', 'Rat129', 'Rat130', 'Rat153', 'Rat154', 'Rat159', 'Rat160']

    profile3Ctrl={'Type':'Good',
             'rewardType':'Progressive',
             'initialSpeed':'10',
             'Speed':'10',
             'Tag':['Control']
             }
    animalList3Ctrl=['Rat110','Rat113','Rat120','Rat161','Rat162','Rat163','Rat164','Rat165','Rat166','Rat215',
                     'Rat217','Rat218','Rat219','Rat220','Rat221','Rat222','Rat223','Rat224','Rat225','Rat226',
                     'Rat227','Rat228','Rat229','Rat230','Rat231','Rat232','Rat246','Rat247','Rat248','Rat249',
                     'Rat250','Rat251','Rat252','Rat253','Rat254','Rat255','Rat256','Rat257','Rat258','Rat259',
                     'Rat260','Rat261','Rat262','Rat263','Rat264','Rat265','Rat297','Rat298','Rat299','Rat300',
                     'Rat305','Rat306','Rat307','Rat308']

    TaskParamToPlot3="percentile entrance time"
    stop_dayPlot3 =30
    colors3VAR=get_ordered_colors(colormap='plasma', n=len(animalList3VAR)+1)[:-1]
    colorSig3='goldenrod'
    # Control colour = default `color` of plot_learning_curve (defined in
    # another notebook); grey if it is unavailable or has no default.
    try:
        color3Ctrl=inspect.signature(plot_learning_curve).parameters['color'].default
    except (NameError, KeyError, TypeError, ValueError):
        color3Ctrl=inspect.Parameter.empty
    if color3Ctrl is inspect.Parameter.empty:
        color3Ctrl='gray'

    #=================================================

    # GRID 4 PARAMS: No-Timeout (NTO) group
    profile4NTO={'Type':'Good',
                 'rewardType':'Progressive',
                 'initialSpeed':['10'],
                 'Speed':'10',
                 'Tag':'Control-NoTimeout'
                }
    animalList4NTO=['Rat141', 'Rat142', 'Rat143', 'Rat144', 'Rat155', 'Rat156', 'Rat157', 'Rat158']
    colors4NTO=get_ordered_colors(colormap='plasma', n=len(animalList4NTO)+1)[:-1]

    #=================================================

    #GRID 9: INIT POSITION DISTRIBUTION
    profile={'Type':'Good',
             'rewardType':'Progressive',
             'initialSpeed':'10',
             'Speed':['10'],
             'Tag':'Control'
                  }
    animalList=batch_get_animal_list(root,profile)

    SessionRange9=[20,30]
    trdBins9=range(Y1,Y2+1,1)

    #=================================================

    #GRID 10: Prob correct for speed

    SessionRange10=[20,200]
    spdBins10=range(5,31,1)
    GTrange10Correct=(7,15)
    GTrange10Perfect=(6,8)
    colors10=get_ordered_colors(colormap='RdBu', n=2)
    # raw f-strings: '\l' is not a Python escape, so the text is unchanged
    # but no invalid-escape warning is emitted
    labels10=[
              rf'$ P ({GTrange10Correct[0]} \leq ET < {GTrange10Correct[1]})$',
              rf'$ P ({GTrange10Perfect[0]} <    ET < {GTrange10Perfect[1]})$'
             ]

    #=================================================

    #GRID 11: speed distribution

    spdBins11=range(5,31,1)
    wspace11=0.5

    #=================================================

    #GRID 12: NTO table
    #nothing

**Plotting the figure**

In [None]:
if "__file__" not in dir():
    # Assemble the full figure (7 x 3 in, 600 dpi) and save it as a PDF.
    # Layout: the top row (bottom=0.62 .. top~0.98) holds the variable-speed
    # (VAR) panels 11, 3, 5, 10. The bottom row (0 .. 0.38) holds the
    # No-Timeout (NTO) panels 12, 4, 8, 9. Panels 1, 2, 6 and 7 are
    # placeholders with no code.
    plt.close('all')
    set_rc_params()
    figsize=(7,3)
    fig=plt.figure(figsize=figsize,dpi=600)


    ###########################################
    # 1: consecutive trajectory VAR (placeholder, nothing drawn)


    ###########################################
    # 11: speed distribution -- pooled VAR treadmill speeds vs the control speed

    gs11=fig.add_gridspec(nrows=1, ncols=1, left=0.0, bottom=0.62, right=0.10, top=.94)
    plot_prob_var_speed(gs11[0], animalList3VAR, profile3VAR, spdBins11,wspace11)



    ###########################################
    # 2: plot trajectories VAR (placeholder, nothing drawn)



    ###########################################
    # 5: average trajectory VAR -- per-rat median trajectories of session
    # index stop_dayPlot3-1 (zero-based), overlaid with the grand average of
    # the *control* group in translucent grey
    gs5= fig.add_gridspec(nrows=1, ncols=1, left=0.66, bottom=0.62, right=0.81, top=0.98)
    ax5= fig.add_subplot(gs5[0])
    plot_median_trajectory(root,ax5, profile3VAR, animalList3VAR, stop_dayPlot3-1,colors3VAR)
    plot_grand_average(root, ax5, profile3Ctrl, animalList3Ctrl, stop_dayPlot3-1, color=[.6,.6,.6,.6],lw=2)



    ###########################################
    # 3: learning curve VAR vs Control for TaskParamToPlot3, with a two-tailed
    # permutation test between the two groups drawn at y=9.5 (presumably
    # marking the sessions where the groups differ)
    gs3= fig.add_gridspec(nrows=1, ncols=1, left=0.16, bottom=0.62, right=0.58, top=0.98)
    ax3= fig.add_subplot(gs3[0])

    D3_1=plot_dotted_learning_curve(ax3, root, animalList3VAR, profile3VAR, TaskParamToPlot3, stop_dayPlot3,colors3VAR)
    D3_2=plot_learning_curve(ax3, root, animalList3Ctrl, profile3Ctrl, TaskParamToPlot3, stop_dayPlot3)
    permTest3=TwoTailPermTest(group1=D3_1, group2=D3_2, nIterations=1000)
    permTest3.plotSignificant(ax=ax3,y=9.5,color=colorSig3,lw=2)
    ax3.text(x=1,y=9.2,s='Significant',color=colorSig3, ha='left',va='top',fontstyle='italic',fontsize='xx-small')
    # group names with sample sizes, right-aligned at the last plotted session
    s=f'{label1VAR}, $n={len(animalList3VAR)}$ rats'
    ax3.text(x=stop_dayPlot3, y=1, s=s, fontsize=4, zorder=5,ha='right',color='k')
    s=f'Control, $n={len(animalList3Ctrl)}$ rats'
    ax3.text(x=stop_dayPlot3, y=.3, s=s, fontsize=4, zorder=5,ha='right',color=color3Ctrl)


    ax3.set_ylabel('$ET$ (s)')
    ax3.set_ylim([0,10])
    ax3.set_yticks([0,7])
    ax3.spines['left'].set_bounds(0,10)


    ###########################################
    # 10: probability of a correct / near-perfect trial as a function of
    # treadmill speed (VAR group), one curve per ET range in labels10
    gs10= fig.add_gridspec(nrows=1, ncols=1, left=0.85, bottom=0.62, right=0.99, top=0.98)
    ax10= fig.add_subplot(gs10[0])

    plot_cond_prob_correct(ax10, animalList3VAR, profile3VAR, SessionRange10,
                           GTrange10Correct, spdBins10, color=colors10[0],label=labels10[0])
    plot_cond_prob_correct(ax10, animalList3VAR, profile3VAR, SessionRange10,
                       GTrange10Perfect, spdBins10, color=colors10[1],label=labels10[1])

    add_legend_to_cond_prob_plot(ax10, colors10, labels10)
    ax10.set_ylabel('')



    ###########################################
    # 6: consecutive trajectory NTO (placeholder, nothing drawn)


    ###########################################
    # 7: plot trajectories NTO (placeholder, nothing drawn)


    ###########################################
    # 12: DEFINITION NTO -- ET ranges counted as error trials, NTO vs Control

    gs12=fig.add_gridspec(nrows=1, ncols=1, left=0.0, bottom=0.0, right=0.09, top=.38)
    ax12= fig.add_subplot(gs12[0])
    plot_NTO_definition_bars(ax12)



    ##########################################
    # 4: learning curve NTO vs Control (same statistics as panel 3)
    gs4= fig.add_gridspec(nrows=1, ncols=1, left=.16, bottom=0.0, right=0.58, top=.38)
    ax4= fig.add_subplot(gs4[0])
    D4_1=plot_dotted_learning_curve(ax4, root, animalList4NTO, profile4NTO, TaskParamToPlot3, stop_dayPlot3,colors4NTO)
    D4_2=plot_learning_curve(ax4, root, animalList3Ctrl, profile3Ctrl, TaskParamToPlot3, stop_dayPlot3)
    permTest4=TwoTailPermTest(group1=D4_1, group2=D4_2, nIterations=1000)
    permTest4.plotSignificant(ax=ax4,y=9.5,color=colorSig3,lw=2)
    s=f'{label1NTO}, $n={len(animalList4NTO)}$ rats'
    ax4.text(x=stop_dayPlot3, y=0, s=s, fontsize=4, zorder=5,ha='right',color='k')

    ax4.set_ylabel('$ET$ (s)')
    # lower limit at -1 (panel 3 uses 0) leaves room for the label at y=0
    ax4.set_ylim([-1,10])
    ax4.set_yticks([0,7])
    ax4.spines['left'].set_bounds(0,10)



    ###########################################
    # 8: average trajectory NTO -- per-rat medians + control grand average
    gs8= fig.add_gridspec(nrows=1, ncols=1, left=0.66, bottom=0.0, right=0.81, top=0.38)
    ax8NTO= fig.add_subplot(gs8[0])
    plot_median_trajectory(root,ax8NTO, profile4NTO, animalList4NTO, stop_dayPlot3-1,colors4NTO)
    plot_grand_average(root, ax8NTO, profile3Ctrl, animalList3Ctrl, stop_dayPlot3-1, color=[.6,.6,.6,.6],lw=2)





    ###########################################
    # 9: probability of initial position (NTO group, sessions SessionRange9)
    gs9= fig.add_gridspec(nrows=1, ncols=1, left=0.85, bottom=0, right=0.99, top=0.38)
    ax9= fig.add_subplot(gs9[0])
    plot_probablity_initial_pos(root,profile4NTO,animalList4NTO, SessionRange9, trdBins9, ax9, pCum=False)





#     #############################################
#     # Panel captions (disabled): the axes named below (axes3, ax1, ax2,
#     # axes7, ax6) are not created in this cell and must be updated first.
#     #%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
#     AXES=(axes3[0],ax1,ax2,axes7[0],ax5,ax6)
#     OFFX=(.05,)*len(AXES)
#     OFFY=(.03,)*len(AXES)
#     add_panel_caption(axes=AXES, offsetX=OFFX, offsetY=OFFY)


    # Saved next to the current working directory's parent:
    # <parent of cwd>/BehavioralPaper/Figures/NoT-VarTrd.pdf
    fig.savefig(os.path.join(os.path.dirname(os.getcwd()),'BehavioralPaper','Figures','NoT-VarTrd.pdf'),
                format='pdf', bbox_inches='tight')

    plt.show()
    plt.close('all')
    # undo the global rc changes made by set_rc_params()
    matplotlib.rcdefaults()