# Part 0:
## import everything
Run the cell below

In [None]:
# -*- coding: utf-8 -*-
import os
import glob
import numpy as np
from platform import system as OS
import pandas as pd
import scipy.stats
import math
import datetime
from copy import deepcopy
import matplotlib.cm as cm
import warnings
warnings.filterwarnings("ignore")
import types
import string
import sys, time
import pickle
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import mlab
from scipy import stats
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.animation as animation
import matplotlib.backends.backend_pdf
from sklearn.decomposition import KernelPCA
import mpl_toolkits.axes_grid1.inset_locator as inset
from matplotlib.ticker import FormatStrFormatter
import imageio
from set_rc_params import set_rc_params


if "__file__" not in dir():
    # Interactive set-up only: when this file is executed as a plain script,
    # `__file__` is defined at module level and this whole block is skipped.
    %matplotlib inline
    %config InlineBackend.close_figures = False
    # Pick the data root folder for the current operating system.
    if OS()=='Linux':
        root="/data"
    elif OS()=='Windows':
        root="C:\\DATA\\"
    else:
        root="/Users/davidrobbe/Documents/Data/"
            
    # NOTE: realpath receives the literal string "__file__", which is resolved
    # against the current working directory, so ThisNoteBookPath is the CWD
    # (assumed to be this notebook's own folder — TODO confirm).
    ThisNoteBookPath=os.path.dirname(os.path.realpath("__file__"))
    # Sibling folder "load_preprocess_rat" holding the shared helper notebooks.
    CommonNoteBookesPath=os.path.join(os.path.split(ThisNoteBookPath)[0],"load_preprocess_rat")
    CWD=os.getcwd()
    # Run the helper notebooks from their own folder (they presumably define the
    # loading/plotting/batch functions used below), then restore the CWD.
    os.chdir(CommonNoteBookesPath)
    %run UtilityTools.ipynb
    %run Animal_Tags.ipynb
    %run loadRat_documentation.ipynb
    %run plotRat_documentation_1_GeneralBehavior.ipynb
    %run plotRat_documentation_3_KinematicsInvestigation.ipynb
    %run RunBatchRat_3_CompareGroups.ipynb
    %run BatchRatBehavior.ipynb
    currentNbPath=os.path.join(os.path.split(ThisNoteBookPath)[0],'BehavioralPaper','SUPP_Imm2Ctrl.ipynb')
    %run $currentNbPath

    os.chdir(CWD)

    # NOTE(review): `logging` is not imported in this file; presumably one of
    # the %run notebooks above makes it available — confirm.
    logging.getLogger().setLevel(logging.ERROR)
    
    # Global preprocessing/analysis parameters; `param` is read as a global by
    # the plotting functions defined below.
    param={
        "goalTime":7,#needed for pavel data only
        "treadmillRange":[0,90],#pavel error conversion "treadmillRange":[0,80]
        "maxTrialDuration":15,
        "interTrialDuration":10,#None pavel
        "endTrial_frontPos":30,
        "endTrial_backPos":55, 
        "endTrial_minTimeSec":4,
        "cameraSamplingRate":25, #needed for new setup    

        "sigmaSmoothPosition":0.1,#0.33, 0.18 pavel
        "sigmaSmoothSpeed":0.3,#0.3, 0.5 pavel
        "nbJumpMax":100,#200 pavel
        "binSize":0.25,
        #parameters used to preprocess (will override the default parameters)
    }
    # Lower/upper bounds of the treadmill position range.
    Y1,Y2=param['treadmillRange']

    print('os:',OS(),'\nroot:',root,'\nImport successful!')

---
---


# part 1:

# DEFINITIONS

### If you don't know what to do, move to part 2

Utility Functions

In [None]:
def add_panel_caption(axes: tuple, offsetX: tuple, offsetY: tuple, **kwargs):
    """
    Label each Axes of `axes` with a lowercase letter (a, b, c, ...).

    The letter of the i-th Axes is placed at its top-left corner, moved left
    by |offsetX[i]| and up by |offsetY[i]|. Positions and offsets are in
    RELATIVE figure coordinates (fractions of the figure size).
    Any extra keyword arguments are forwarded to ``Axes.text``.
    """
    assert len(axes)==len(offsetX)==len(offsetY), 'Bad input!'

    figure = axes[0].get_figure()
    figBox = figure.bbox
    labelled = zip(axes, offsetX, offsetY, string.ascii_lowercase)
    for panelAx, shiftX, shiftY, letter in labelled:
        extent = panelAx.get_window_extent()
        # Convert the Axes corner from display pixels to figure fractions.
        xFig = extent.x0 / figBox.xmax - abs(shiftX)
        yFig = extent.y1 / figBox.ymax + abs(shiftY)
        panelAx.text(x=xFig, y=yFig, s=letter,
                     fontweight='extra bold', fontsize=10, ha='left', va='center',
                     transform=figure.transFigure, **kwargs)


In [None]:
def get_ordered_colors(colormap, n):
    """
    Return `n` RGBA colors sampled evenly from `colormap`, from its low end
    (position 0) to its high end (position 1).

    colormap: name of a registered matplotlib colormap (e.g. 'plasma') or a
        Colormap instance.
    n: number of colors. n == 1 returns only the low-end color; n == 0
        returns an empty list.

    Returns a list of RGBA tuples.
    """
    try:
        # Colormap registry API (matplotlib >= 3.6). `matplotlib.cm.get_cmap`
        # (i.e. plt.cm.get_cmap) was deprecated in 3.7 and removed in 3.9.
        cmap = matplotlib.colormaps.get_cmap(colormap)
    except AttributeError:
        # Older matplotlib without the registry's get_cmap method.
        cmap = plt.cm.get_cmap(colormap)
    return [cmap(position) for position in np.linspace(0, 1, n)]

In [None]:
class TwoTailPermTest:
    """
    Two-tailed permutation test of whether group one and group two differ.

    group1, group2: numpy arrays holding the data. They can be either
        one-dimensional (several realizations) or 2-D (several realizations
        throughout the time/space/... course).
        EX: x.shape==(15,500) means 15 trials/samples over 500 time bins.
        A 1-D group is turned into a column vector of shape (n, 1); both
        groups must end up with the same number of columns.

    nIterations: Number of shuffles. At most
        (len(x)+len(y))! / (len(x)! len(y)!) of them are distinct.

    initGlobConfInterval: Initial value (in percent) of the global confidence
        band. It is also the percentile used for the pairwise (local) bands.

    smoothSigma: standard deviation of the gaussian kernel used for smoothing
        when there are multiple data points (columns), based on the
        Fujisawa 2008 paper. Default value: 0.05.
        NOTE(review): smoothing calls `smooth`, which is not defined in this
        file; presumably it is provided by one of the %run notebooks.

    Outputs (attributes set on construction):
        pVal: two-tailed P-values, one per column.
        highBand, lowBand: AKA boundary. Global bands (None for a single-point
            comparison).
        significantDiff: array of True/False, indicating whether there is a
            difference at each column.

    The shuffles use the global ``np.random`` state; seed it beforehand for
    reproducible results.
    """
    def __init__(self, group1, group2, nIterations=1000, initGlobConfInterval=5, smoothSigma=0.05):
        self.group1, self.group2 = group1, group2
        self.nIterations, self.smoothSigma = nIterations, smoothSigma
        self.initGlobConfInterval = initGlobConfInterval

        # Validate inputs and promote 1-D groups to column vectors.
        self.checkGroups()

        # origGroupDiff is also known as D0 in the definition of permutation test.
        self.origGroupDiff = self.computeGroupDiff(group1, group2)

        # Generate surrogate groups, compute difference of mean for each group, and put in a matrix.
        self.diffSurGroups = self.setDiffSurrGroups()

        # Set statistics
        self.pVal = self.setPVal()
        self.highBand, self.lowBand = self.setBands()
        self.significantDiff = self.setSignificantGroup()

    def checkGroups(self):
        """Validate the inputs and reshape every 1-D group into an (n, 1) column."""
        if not isinstance(self.group1, np.ndarray) or not isinstance(self.group2, np.ndarray):
            raise ValueError("In permutation test, \"group1\" and \"group2\" should be numpy arrays.")

        if self.group1.ndim not in (1, 2) or self.group2.ndim not in (1, 2):
            raise ValueError('In permutation test, the groups must be either vectors or matrices.')

        # Reshape each 1-D group on its own; a 2-D group is left untouched.
        if self.group1.ndim == 1:
            self.group1 = np.reshape(self.group1, (len(self.group1), 1))
        if self.group2.ndim == 1:
            self.group2 = np.reshape(self.group2, (len(self.group2), 1))

        if self.group1.shape[1] != self.group2.shape[1]:
            raise ValueError('In permutation test, the groups must have the same number of columns.')

    def computeGroupDiff(self, group1, group2):
        """
        Difference of the column-wise (NaN-ignoring) means, group1 - group2.

        For single-column data the difference is returned wrapped in a list;
        otherwise it is smoothed with `smooth` (gaussian, sigma=smoothSigma).
        """
        meanDiff = np.nanmean(group1, axis=0) - np.nanmean(group2, axis=0)

        if len(self.group1[0]) == 1 and len(self.group2[0]) == 1:
            return [meanDiff]

        return smooth(meanDiff, sigma=self.smoothSigma, order=0,
                      mode='constant', cval=0, truncate=4.0)

    def setDiffSurrGroups(self):
        """Return an (nIterations, nColumns) matrix of surrogate mean differences."""
        # Pool both groups; rows are shuffled in place on every iteration.
        self.concatenatedData = np.concatenate((self.group1, self.group2), axis=0)

        diffSurrGroups = np.zeros((self.nIterations, self.group1.shape[1]))
        for iteration in range(self.nIterations):
            # Generate surrogate groups
            surrGroup1, surrGroup2 = self.generateSurrGroup()

            # Compute the difference between mean of surrogate groups
            surrGroupDiff = self.computeSurrGroupDiff(surrGroup1, surrGroup2)

            # Store individual differences in a matrix.
            diffSurrGroups[iteration, :] = surrGroupDiff

        return diffSurrGroups

    def generateSurrGroup(self):
        """Shuffle the pooled data and split it into two groups of the original sizes."""
        # np.random.shuffle permutes the ROWS (first axis) in place, so every
        # column receives the same permutation of samples.
        np.random.shuffle(self.concatenatedData)

        # Return surrogate groups of same size (views into the pooled data).
        n1 = self.group1.shape[0]
        return self.concatenatedData[:n1, :], self.concatenatedData[n1:, :]

    def computeSurrGroupDiff(self, surrGroup1, surrGroup2):
        """Mean difference of a pair of surrogate groups."""
        return self.computeGroupDiff(surrGroup1, surrGroup2)

    def setPVal(self):
        """Two-tailed p-value per column: min(1, 2*P(D>D0), 2*P(D<D0))."""
        positivePVals = np.sum(1*(self.diffSurGroups > self.origGroupDiff), axis=0) / self.nIterations
        negativePVals = np.sum(1*(self.diffSurGroups < self.origGroupDiff), axis=0) / self.nIterations
        return np.array([np.min([1, 2*pPos, 2*pNeg]) for pPos, pNeg in zip(positivePVals, negativePVals)])

    def setBands(self):
        """
        Global bands: shrink the confidence interval by 5% per step until fewer
        than 5% of the surrogates break either band. Returns (None, None) for a
        single-point comparison.
        """
        if len(self.origGroupDiff) < 2:  # single point comparison
            return None, None

        alpha = 100  # Global alpha value (percent)
        highGlobCI = self.initGlobConfInterval  # global confidence interval
        lowGlobCI = self.initGlobConfInterval  # global confidence interval
        while alpha >= 5:
            highBand = np.percentile(a=self.diffSurGroups, q=100-highGlobCI, axis=0)
            lowBand = np.percentile(a=self.diffSurGroups, q=lowGlobCI, axis=0)

            # NOTE(review): a surrogate counts as breaking a band only when it
            # crosses it in MORE than one bin (`> 1`); a strict global band
            # would count a single crossing (`> 0`). Kept as-is — confirm intent.
            breaksPositive = np.sum(np.sum(self.diffSurGroups > highBand, axis=1) > 1)
            breaksNegative = np.sum(np.sum(self.diffSurGroups < lowBand, axis=1) > 1)

            alpha = ((breaksPositive + breaksNegative) / self.nIterations) * 100
            highGlobCI = 0.95 * highGlobCI
            lowGlobCI = 0.95 * lowGlobCI
        return highBand, lowBand

    def setSignificantGroup(self):
        """
        Boolean significance per column: bins outside the global bands, extended
        to consecutive neighbours that are outside the pairwise bands.
        """
        if len(self.origGroupDiff) < 2:  # single point comparison
            return self.pVal <= 0.05

        # finding significant bins
        globalSig = np.logical_or(self.origGroupDiff > self.highBand, self.origGroupDiff < self.lowBand)
        pairwiseSig = np.logical_or(self.origGroupDiff > self.setPairwiseHighBand(), self.origGroupDiff < self.setPairwiseLowBand())

        significantGroup = globalSig.copy()
        lastIndex = 0
        for currentIndex in range(len(pairwiseSig)):
            if (globalSig[currentIndex] == True):
                lastIndex = self.setNeighborsToTrue(significantGroup, pairwiseSig, currentIndex, lastIndex)

        return significantGroup

    def setPairwiseHighBand(self):
        """Upper pairwise (per-bin) band at 100 - initGlobConfInterval percentile."""
        return np.percentile(a=self.diffSurGroups, q=100 - self.initGlobConfInterval, axis=0)

    def setPairwiseLowBand(self):
        """Lower pairwise (per-bin) band at the initGlobConfInterval percentile."""
        return np.percentile(a=self.diffSurGroups, q=self.initGlobConfInterval, axis=0)

    def setNeighborsToTrue(self, significantGroup, pairwiseSig, currentIndex, previousIndex):
        """
        While the neighbors of a global point pass the local band (consecutively),
        set them to True in `significantGroup` (modified in place).
        Returns the last index that was visited going forward.
        """
        if (currentIndex < previousIndex):
            return previousIndex

        # Walk backwards (not past previousIndex) while the local band is passed.
        for index in range(currentIndex, previousIndex, -1):
            if (pairwiseSig[index] == True):
                significantGroup[index] = True
            else:
                break

        # Walk forwards while the local band is passed.
        previousIndex = currentIndex
        for index in range(currentIndex + 1, len(significantGroup)):
            previousIndex = index
            if (pairwiseSig[index] == True):
                significantGroup[index] = True
            else:
                break

        return previousIndex

    def plotSignificant(self, ax: "plt.Axes", y: float, x=None, **kwargs):
        """
        Draw a horizontal segment at height `y` from x[i] to x[i+1] for every
        significant bin i. `x` defaults to 1..len(significantDiff).
        """
        if x is None:
            x = np.arange(0, len(self.significantDiff))+1
        for x0, x1, p in zip(x[:-1], x[1:], self.significantDiff):
            if p:
                ax.plot([x0, x1], [y, y], zorder=-2, **kwargs)

    @staticmethod
    def plotSigPair(ax: "plt.Axes", y: float, x=None, s: str = '*', **kwargs):
        """
        Draw a significance bracket at height `y` between x[0] and x[1] on `ax`,
        labelled with `s`.

        Raises ValueError if `x` is not given: this is a staticmethod, so no
        default span can be derived from a test instance.
        """
        if x is None:
            raise ValueError("plotSigPair: 'x' (the two x-positions of the bracket) must be given.")
        if 'color' not in kwargs:
            kwargs['color'] = 'k'

        dy = .03*(ax.get_ylim()[1]-ax.get_ylim()[0])
        ax.plot(x, [y, y], **kwargs)
        ax.plot([x[0], x[0]], [y-dy, y], [x[1], x[1]], [y-dy, y], **kwargs)
        ax.text(np.mean(x), y, s=s,
                ha='center', va='center', color=kwargs['color'],
                size='xx-small', fontstyle='italic', backgroundcolor='w')

---

Plot the Immobile → Control treadmill transition

In [None]:
def plot_goal_time_change(root, Profiles, N, GT, badAnimals, TaskParamToPlot, wspace, gs, fig):
    """
    Plot the group median (with 25th-75th percentile error bars) of
    `TaskParamToPlot` around each change between consecutive profiles.

    One panel is drawn per transition Profiles[i] -> Profiles[i+1]. It shows
    N[i] sessions before and N[i+1] sessions after the change, at
    x = -N[i]..-1 and 1..N[i+1]. The goal times GT[i] / GT[i+1] are drawn as
    a dashed magenta line.

    Parameters
    ----------
    root : data root folder, forwarded to `event_detect` / `event_statistic`.
    Profiles : sequence of profile dicts, one per training stage.
    N : number of sessions to plot for each profile (same length as Profiles).
    GT : goal time of each profile (same length as N).
    badAnimals : forwarded to `event_detect` as `badAnimals`.
    TaskParamToPlot : name of the statistic to plot (key of the results).
    wspace : horizontal spacing between panels (only used with a SubplotSpec).
    gs : either a matplotlib SubplotSpec, which is split into one panel per
        transition, or a sequence of existing Axes, one per transition.
    fig : figure owning `gs` (only used with a SubplotSpec).

    Returns
    -------
    tuple of the Axes used, one per transition.
    """
    useSubGrid = isinstance(gs, matplotlib.gridspec.SubplotSpec)
    if useSubGrid:
        gssub = gs.subgridspec(1, len(Profiles)-1, wspace=wspace)
        axes = []
    else:
        axes = gs

    for eventN in range(len(N)-1):
        # getting the data
        _, SessionDict = event_detect(root, Profiles[eventN], Profiles[eventN+1], badAnimals=badAnimals)

        Results, nSessionPre, nSessionPost = event_statistic(root,
                                                             SessionDict,
                                                             parameter=param,
                                                             redo=False,
                                                             TaskParamToPlot=[TaskParamToPlot])

        assert N[eventN]<nSessionPre and N[eventN+1]<nSessionPost,f"fewer sessions available than requested:{Profiles[eventN]['Tag']}"

        # Keep the N[eventN] sessions before and N[eventN+1] sessions after the change.
        data = np.array(list(Results[TaskParamToPlot].values()))
        data = data[:, nSessionPre-N[eventN] : nSessionPre+N[eventN+1]]

        y = np.nanpercentile(data, 50, axis=0)
        yerr = np.nanpercentile(data, (25, 75), axis=0)

        # getting the axes
        if useSubGrid:
            ax = fig.add_subplot(gssub[0, eventN])
        else:
            ax = axes[eventN]

        # plotting; session numbers skip 0 (last pre = -1, first post = +1)
        xLabel = list(range(-N[eventN], N[eventN+1]+1))
        xLabel.remove(0)
        ax.errorbar(xLabel, y, yerr=abs(yerr-y), ecolor='k', fmt='k-o', markersize=2, elinewidth=1, linewidth=1, markerfacecolor='w', zorder=1)
        ax.plot([xLabel[0], -1, 1, xLabel[-1]], [GT[eventN], GT[eventN], GT[eventN+1], GT[eventN+1]],
                'm--', lw=1, zorder=-5)

        # One tick per session by default; for long ranges keep only the
        # multiples of 10 so the labels stay readable.
        if xLabel[0] <= -10 or xLabel[-1] >= 10:
            newLabel = sorted({label for label in xLabel if label % 10 == 0})
        else:
            newLabel = xLabel

        ax.set_xlim([xLabel[0]-1, xLabel[-1]+1])
        ax.set_xticks(newLabel)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

        ax.set_ylabel(TaskParamToPlot)
        ax.set_yticks(list(set(GT))+[10, 0])
        # Vertical dashed marker at the change (x = 0).
        ax.vlines(x=0, ymin=ax.get_ylim()[0], ymax=ax.get_ylim()[1]*.9, color='k', linestyles='--', lw=.8)

        if useSubGrid:
            axes.append(ax)

    # Only the first panel keeps its y axis; the others share its scale visually.
    axes[0].set_xlabel('Session relative to goal time change')
    axes[0].spines['left'].set_position(('axes', -.05))
    for ax in axes[1:]:
        ax.yaxis.set_visible(False)
        ax.spines['left'].set_visible(False)

    return tuple(axes)

In [None]:
if "__file__" not in dir():
    # Interactive demo: entrance-time curve around the switch from the
    # immobile-treadmill profile ('Speed':'0') to the control profile
    # ('Speed':'10'), drawn with plot_goal_time_change.

    profile1={'Type':'Good',
              'rewardType':'Progressive',
              'initialSpeed':'0',
              'Speed':'0',
              'Tag':['ImmobileTreadmill-Late-cue',
                     'ImmobileTreadmill-ValuedReward',
                     'ImmobileTreadmill',
                     'ImmobileTreadmill-NormalReward',
                     'ImmobileTreadmill-BackToGT7']
             }
    
    profile2={'Type':'Good',
              'rewardType':'Progressive',
              'initialSpeed':'0',
              'Speed':'10',
              'Tag':['ImmobileTreadmill-Control','Control-AfterBreak']
             }

    #number of sessions to plot (before the change, after the change)
    N=[10,30]
    #goal times (before the change, after the change)
    GT=[7,7]
    # NOTE(review): [''] presumably excludes no animal — confirm with event_detect.
    badAnimals=['']#['Rat121','Rat122','Rat123','Rat124','Rat132','Rat131']
    TaskParamToPlot="percentile entrance time"
    wspace=0.05
    
    
    Profiles=(profile1,profile2)
    plt.close('all')
    fig=plt.figure(figsize=(10,4))
    # A single SubplotSpec, split by plot_goal_time_change into one panel per transition.
    gs=fig.add_gridspec(1,1)[0]
    
    axes=plot_goal_time_change(root, Profiles, N, GT, badAnimals, TaskParamToPlot, wspace, gs, fig)
#     axes=plot_goal_time_change(root, Profiles, N, GT, badAnimals,
#                                TaskParamToPlot='standard deviation of entrance time', wspace=wspace, gs=axes, fig=fig)
    
    axes[0].set_ylabel('Entrance Times (s)')

**group ET learning curve**

In [None]:
def plot_learning_curve(ax, root, animalList, profile, TaskParamToPlot,
                        stop_dayPlot,color='gray'):
    """
    Plot the across-animal median of `TaskParamToPlot` per session, with
    25th-75th percentile error bars, as one colored curve on `ax`.

    Sessions are drawn at x = 1.1, 2.1, ..., stop_dayPlot + 0.1. The 0.1 shift
    presumably keeps this curve from overlapping one drawn at integer session
    numbers on the same Axes.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    root : data root folder, forwarded to get_rat_group_statistic.
    animalList : animals to include.
    profile : profile dict selecting the sessions.
    TaskParamToPlot : name of the statistic to plot.
    stop_dayPlot : number of sessions to plot.
    color : color of the curve and of its error bars.

    Returns
    -------
    data : numpy array with one row per entry of Results[TaskParamToPlot]
        (presumably one per animal) and one column per session.
    """
    Results,_=get_rat_group_statistic(root,
                                      animalList,
                                      profile,
                                      parameter=param,
                                      redo=False,
                                      stop_dayPlot=stop_dayPlot,
                                      TaskParamToPlot=[TaskParamToPlot])

    x = np.arange(stop_dayPlot) + 1.1
    data = np.array(list(Results[TaskParamToPlot].values()))
    y = np.nanpercentile(data, 50, axis=0)
    yerr = np.nanpercentile(data, (25, 75), axis=0)

    # errorbar expects distances from y, hence |percentile - median|.
    ax.errorbar(x, y, yerr=abs(yerr-y), ecolor=color, fmt='-o', color=color,
                elinewidth=1, markersize=3, markerfacecolor='w', zorder=1)

    return data
        

def plot_dotted_learning_curve(ax, root, animalList, profile, TaskParamToPlot,
                               stop_dayPlot,colors, seed=3, legend=False):
    """
    Plot a group learning curve on `ax`. The per-session median of
    `TaskParamToPlot` across animals is drawn as a black line with
    25th-75th percentile error bars. The individual values are drawn as
    horizontally jittered colored dots.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    root : data root folder, forwarded to get_rat_group_statistic / data_fetch.
    animalList : animals to include.
    profile : profile dict selecting the sessions.
    TaskParamToPlot : name of the statistic to plot (also used as y label).
    stop_dayPlot : number of sessions to plot (x = 1 .. stop_dayPlot).
    colors : colors of the dots, one per row of the data (passed as `c` to
        ``Axes.scatter``).
    seed : seed of the GLOBAL NumPy RNG, used for the horizontal jitter.
    legend : if True, add a legend pairing colors[i] with animalList[i].

    Returns
    -------
    data : numpy array with one row per entry of Results[TaskParamToPlot]
        (presumably one per animal) and one column per session.
    """
    Results,_=get_rat_group_statistic(root,
                                      animalList,
                                      profile,
                                      parameter=param,
                                      redo=False,
                                      stop_dayPlot=stop_dayPlot,
                                      TaskParamToPlot=[TaskParamToPlot])

    # Goal time as reported by data_fetch for the first animal (last value
    # returned); drawn as the dashed magenta reference line.
    goalTime=data_fetch(root, animal=animalList[0], profile=profile,
                        PerfParam= [lambda data:data.goalTime[-1]],
                        NbSession=0).values()
    goalTime=list(goalTime)[-1]

    x = np.arange(stop_dayPlot) + 1
    data = np.array(list(Results[TaskParamToPlot].values()))
    y = np.nanpercentile(data, 50, axis=0)
    yerr = np.nanpercentile(data, (25, 75), axis=0)
    # NOTE: seeds the global NumPy RNG, so the jitter is reproducible and any
    # later np.random call in the session starts from this seed too.
    np.random.seed(seed=seed)
    halfWidth = .3  # dots are spread uniformly within +/- halfWidth of their session

    ax.errorbar(x, y, yerr=abs(yerr-y), ecolor='k', fmt='k-o', elinewidth=1, markersize=3, markerfacecolor='w', zorder=3)

    for pts, day in zip(data.T, x):
        jitter = np.random.uniform(low=day-halfWidth, high=day+halfWidth, size=len(pts))
        ax.scatter(jitter, pts, s=.2, c=colors, marker='o', zorder=2)

    ax.set_xlim([x[0]-1, x[-1]+1])
    # Tick the first session and every 5th one.
    ax.set_xticks([1] + list(range(5, stop_dayPlot+1, 5)))
    ax.spines['bottom'].set_bounds(x[0], x[-1])
    ax.yaxis.set_major_formatter(FormatStrFormatter('%g'))
    ax.set_ylim([0, 10])
    ax.set_yticks([0, 3.5, 7])
    ax.set_yticklabels([0, '', 7])
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.hlines(y=goalTime, xmin=x[0], xmax=x[-1], linestyle='--', lw=1, color='m')
    ax.set_xlabel('Session#')
    ax.set_ylabel(TaskParamToPlot)

    if legend and not isinstance(colors, str):
        # One entry per animal, drawn with the color its dots received.
        # NOTE(review): assumes the rows of `data` (and thus `colors`) follow
        # the order of `animalList` — confirm against get_rat_group_statistic.
        pairs = list(zip(animalList, colors))
        handles = [matplotlib.lines.Line2D([], [], linestyle='', marker='o',
                                           markersize=3, color=c)
                   for _, c in pairs]
        ax.legend(handles, [animal for animal, _ in pairs],
                  fontsize='xx-small', frameon=False, loc='best',
                  handletextpad=0.1)

    return data

In [None]:
if "__file__" not in dir():
    # Interactive demo: learning curves of the rats in `profile`
    # ('ImmobileTreadmill-Control' tags) overlaid on the 'Control' group
    # (`profile0` / `animalList0`), for the entrance time (top panel) and its
    # standard deviation (bottom panel).

    profile={'Type':'Good',
              'rewardType':'Progressive',
              'initialSpeed':'0',
              'Speed':'10',
              'Tag':['ImmobileTreadmill-Control','Control-AfterBreak']
             }
#     animalList=batch_get_animal_list(root,profile)
    animalList=['Rat121','Rat122','Rat123','Rat124','Rat132','Rat131']
    
    profile0={'Type':'Good',
             'rewardType':'Progressive',
             'initialSpeed':'10',
             'Speed':'10',
             'Tag':['Control']
             }
    animalList0=['Rat103','Rat104','Rat110','Rat113','Rat120','Rat137','Rat138','Rat139','Rat140','Rat149',
                 'Rat150','Rat151','Rat152','Rat161','Rat162','Rat163','Rat164','Rat165','Rat166','Rat215',
                 'Rat216','Rat217','Rat218','Rat219','Rat220','Rat221','Rat222','Rat223','Rat224','Rat225',
                 'Rat226','Rat227','Rat228','Rat229','Rat230','Rat231','Rat232','Rat246','Rat247','Rat248',
                 'Rat249','Rat250','Rat251','Rat252','Rat253','Rat254','Rat255','Rat256','Rat257','Rat258',
                 'Rat259','Rat260','Rat261','Rat262','Rat263','Rat264','Rat265','Rat297','Rat298','Rat299',
                 'Rat300','Rat305','Rat306','Rat307','Rat308']

    TaskParamToPlot="percentile entrance time"
    stop_dayPlot =30
    # One color per animal; the color at the top end of the colormap is dropped.
    colors=get_ordered_colors(colormap='plasma', n=len(animalList)+1)[:-1]
    colorSig='goldenrod'
    
    TaskParamToPlot2="standard deviation of entrance time"
    TaskLabel2='$SD_{ET}$ (s)'

    plt.close('all')
    fig=plt.figure(figsize=(15,12))
    ax=fig.add_subplot(211);
    ax2=fig.add_subplot(212);
    
    # Top panel: entrance time.
    d1=plot_dotted_learning_curve(ax, root, animalList, profile, TaskParamToPlot, stop_dayPlot,colors)
    d2=plot_learning_curve(ax, root, animalList0, profile0, TaskParamToPlot, stop_dayPlot)
    ax.set_ylabel('Entrance Times (s)')
#     permTest=TwoTailPermTest(group1=d1, group2=d2, nIterations=10000)
#     permTest.plotSignificant(ax=ax,y=7.5,color=colorSig,lw=2)

    # Bottom panel: standard deviation of the entrance time.
    _d1=plot_dotted_learning_curve(ax2, root, animalList, profile, TaskParamToPlot2, stop_dayPlot,colors)
    _d2=plot_learning_curve(ax2, root, animalList0, profile0, TaskParamToPlot2, stop_dayPlot)
    ax2.set_ylabel(TaskLabel2)
#     permTest=TwoTailPermTest(group1=_d1, group2=_d2, nIterations=10000)
#     permTest.plotSignificant(ax=ax2,y=1,color=colorSig,lw=2)


    plt.show()
    plt.close('all')

---

Plot average population trajectory

In [None]:
def animal_position_aligned_on_entrance_time(root,animalList,profile, Win, cs, SessionRange):
    """
    Gather, per animal, the entrance-time-aligned position traces of the
    trials flagged 'lime' (presumably the correct trials), over a slice of
    that animal's sessions.

    Win: window length forwarded to position_aligned_on_entrance_time.
    cs: presumably the sampling rate; every kept trace has int(cs*Win)+1 samples.
    SessionRange: [start, stop] slice bounds into the animal's session list.

    A session whose aligned position has fewer than int(cs*Win)+1 samples is
    reported with a 'bad session: ...' message and skipped.

    Returns a dict mapping each animal to a list of traces (each a list).
    """
    nSamples = int(cs*Win) + 1
    tracesByAnimal = dict.fromkeys(animalList, None)
    for rat in animalList:
        sessionInfo = batch_get_session_list(root, animalList=[rat], profile=profile)
        start, stop = SessionRange[0], SessionRange[1]
        tracesByAnimal[rat] = []
        for sess in sessionInfo['Sessions'][start:stop]:
            aligned, _, trialColor = position_aligned_on_entrance_time(root, sess, Win)
            if aligned.shape[0] < nSamples:
                print(f'bad session: {sess}')
                continue
            # Columns are trials: truncate in time, then select trials by color.
            kept = aligned[:nSamples, :].T[trialColor == 'lime']
            tracesByAnimal[rat].extend(kept.tolist())

    return tracesByAnimal



def plot_animal_median_trajectory(root,ax,animalList,profile, Win, cs, SessionRange, colors):
    """
    Plot, for each animal, the median entrance-time-aligned position over the
    last `Win` seconds before ET, shaded by +/- the standard error of the mean.

    The traces come from animal_position_aligned_on_entrance_time. An animal
    without any usable trial is reported and skipped instead of raising.

    Parameters
    ----------
    root : data root folder.
    ax : matplotlib Axes to draw on.
    animalList : animals to plot, one curve each.
    profile : profile dict selecting the sessions.
    Win : window length (s) before entrance time.
    cs : forwarded to animal_position_aligned_on_entrance_time (trace length
        is int(cs*Win)+1 samples).
    SessionRange : [start, stop] slice of each animal's sessions.
    colors : one color per animal, in the order of animalList.

    Returns
    -------
    ax
    """
    pos = animal_position_aligned_on_entrance_time(root, animalList, profile, Win, cs, SessionRange)

    for i, animal in enumerate(animalList):
        # rows: time samples, columns: trials
        data = np.array(pos[animal]).T
        if data.size == 0:
            print(f'no valid trial for {animal}')
            continue
        timeToET = np.linspace(-Win, 0, data.shape[0])
        median = np.nanpercentile(data, 50, axis=1)
        sem = scipy.stats.sem(a=data, axis=1, nan_policy='omit')

        ax.plot(timeToET, median, color=colors[i])
        ax.fill_between(x=timeToET, y1=median-sem, y2=median+sem, color=colors[i], alpha=.2)

    ax.set_ylim([param['treadmillRange'][0]+10, param['treadmillRange'][1]])
    ax.set_yticks([10, 50, 80])
    ax.set_xlim([-Win-.05, .05])
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_bounds(-Win, 0)
    ax.set_xlabel('Time to '+'$ET$'+' (s)')
    # Raw string: '\p' is an invalid escape sequence in a normal literal.
    ax.set_ylabel('Position'+r'$\pm$'+'sem (cm)')

    return ax


In [None]:
if "__file__" not in dir():
    # Interactive demo: per-animal median trajectory before entrance time for
    # the immobile-treadmill rats, using sessions 20-29 (slice [20:30]).
    profile={'Type':'Good',
             'rewardType':'Progressive',
             'initialSpeed':['0'],
             'Speed':'0',
             'Tag':['ImmobileTreadmill']
             }
    animalList=['Rat121', 'Rat122', 'Rat123', 'Rat124', 'Rat131', 'Rat132']

    SessionRange=[20,30]
    Win=2
    cs=25
    # One color per animal; the color at the top end of the colormap is dropped.
    colors=get_ordered_colors(colormap='plasma', n=len(animalList)+1)[:-1]
    
    plt.close('all')
    fig=plt.figure(figsize=(4,4))
    ax=fig.add_subplot(111)
    
    
    plot_animal_median_trajectory(root,ax,animalList,profile, Win, cs, SessionRange,colors)


------



------

# part 2:

# GENERATING THE FIGURE

Definition of Parameters

In [None]:
if "__file__" not in dir():
    # Parameters of every panel of the figure assembled in the next cell.
    # GRID 1 PARAMS
    profile1={'Type':'Good',
             'rewardType':'Progressive',
             'initialSpeed':'0',
             'Speed':'0',
             'Tag':['ImmobileTreadmill']
             }
    # NOTE(review): this result is immediately overwritten by the hard-coded
    # list on the next line.
    animalList1=batch_get_animal_list(root,profile1)
    animalList1=['Rat121','Rat122','Rat123','Rat124','Rat132','Rat131']

    TaskParamToPlot1="percentile entrance time"
    stop_dayPlot1 =30
    # Unicode glyphs marking the four example sessions on the learning curve
    # and in the panel-4 labels.
    markers1={'naive':u'\u25B6',
            'trained':u"\u25AA",
            'stupidNaive':u'\u25B0',
            'stupidTrained':u"\u2B22"}
    
    #===============================================
    
    # GRID 2 PARAMS: successive goal-time changes 7 -> 4 -> 2 -> 7
    profile21={'Type':'Good',
              'rewardType':'Progressive',
              'initialSpeed':'0',
              'Speed':'0',
              'Tag':['ImmobileTreadmill']
              }
    profile22={'Type':'Good',
              'rewardType':'Progressive',
              'initialSpeed':'0',
              'Speed':'0',
              'Tag':['ImmobileTreadmill-GT4']
              }
    profile23={'Type':'Good',
              'rewardType':'Progressive',
              'initialSpeed':'0',
              'Speed':'0',
              'Tag':['ImmobileTreadmill-GT2']
              }
    profile24={'Type':'Good',
              'rewardType':'Progressive',
              'initialSpeed':'0',
              'Speed':'0',
              'Tag':['ImmobileTreadmill-BackToGT7']
              }

    #number of sessions to plot (one entry per profile)
    N2=[20,10,10,20]
    #goal times (one entry per profile)
    GT2=[7,4,2,7]
    badAnimals2=['RatBAD']
    TaskParamToPlot2="percentile entrance time"
    wspace2=0.2
    Profiles2=(profile21,profile22,profile23,profile24)
    #================================================
    
    # GRID 3 PARAMS: entrance time examples
    Win3=2
    #plotting naive
    session3goodNaive    ='Rat123_2017_02_14_18_33'
    #plotting trained
    session3goodTrained  ='Rat123_2017_03_17_17_52'
    #plotting stupid naive animal
    session3badNiave     ='Rat132_2017_04_26_14_53'
    #plotting stupid trained animal
    session3badTrained   ='Rat132_2017_06_01_13_16'
    
    # Session numbers of the four example sessions above (shown in the panel labels).
    day123_0, day123_1, day132_0, day132_1=days=(2,25,3,28)
    #================================================
    
    # GRID 4: Trajectory examples (trial index ranges within the grid-3 sessions)
    
    #plotting naive
    trials4goodNaive=range(10,18)
    #plotting trained
    trials4goodTrained=range(50,59)
    #plotting stupid naive animal
    trials4badNiave=range(71,80)
    #plotting stupid trained animal
    trials4badTrained=range(115,125)
    #=================================================
    
    # GRID 5: AVERAGE TRAJECTORY
    SessionRange5=[20,30]
    Win5=2
    cs5=25
    
    
    #=================================================
    
    #GRID 6: Probability of correct
    GT6=7
    distBins6=np.arange(0,130, 10)
    
    
    #=================================================
    
    #GRID 7: Percent Correct Trial
    PerfParam7= '% good trials'
        
    
    #=================================================
    
    #GRID 8: correlation of correct and distance   
    param8=[run_distance, "% good trials"]
            
    
    #=================================================
    # GENERAL
    # One color per animal; the color at the top end of the colormap is dropped.
    colors=get_ordered_colors(colormap='plasma', n=len(animalList1)+1)[:-1]
    # Same values as the `param` dict of the Part 0 set-up cell.
    param={
        "goalTime":7,#needed for pavel data only
        "treadmillRange":[0,90],#pavel error conversion "treadmillRange":[0,80]
        "maxTrialDuration":15,
        "interTrialDuration":10,#None pavel
        "endTrial_frontPos":30,
        "endTrial_backPos":55, 
        "endTrial_minTimeSec":4,
        "cameraSamplingRate":25, #needed for new setup    

        "sigmaSmoothPosition":0.1,#0.33, 0.18 pavel
        "sigmaSmoothSpeed":0.3,#0.3, 0.5 pavel
        "nbJumpMax":100,#200 pavel
        "binSize":0.25,
        #parameters used to preprocess (will override the default parameters)
        }  

Plotting the figure

In [None]:
if "__file__" not in dir():
    # Assemble the full figure (panels 1-8) and save it as ImmTrd.pdf in
    # <parent of the CWD>/BehavioralPaper/Figures.
    plt.close('all')
    set_rc_params()
    figsize=(7,7)
    fig=plt.figure(figsize=figsize,dpi=600)
    
    
    ##########################################
    # 1: learning curve
    gs1= fig.add_gridspec(nrows=1, ncols=1, left=0.02, bottom=0.82, right=0.46, top=.98)
    ax1= fig.add_subplot(gs1[0])
    # NOTE(review): passes `legend=True`, so plot_dotted_learning_curve must
    # accept a `legend` keyword argument.
    plot_dotted_learning_curve(ax1, root, animalList1, profile1, TaskParamToPlot1, stop_dayPlot1,colors, legend=True)
    ax1.set_ylabel('$ET$'+' (s)')
    ######highlighting example sessions
    add_markers_to_example_sessions(root, ax1, profile1, stop_dayPlot1, markers1,days)
    
    
    ##########################################
    # 2: goal time change
    gs2= fig.add_gridspec(nrows=1, ncols=1, left=.55, bottom=0.82, right=0.98, top=.98)
    axes2=plot_goal_time_change(root, Profiles2, N2, GT2, badAnimals2, TaskParamToPlot2, wspace2, gs2[0], fig)
    axes2[0].set_ylabel('')
    for ax in axes2:
        ax.tick_params(axis='x', labelsize=8)
    # NOTE(review): the number of rats in this annotation is hard-coded.
    axes2[0].text(x=-N2[0], y=1, s='$n=4$ rats',fontsize='xx-small')
    
    
    
    ###########################################
    # 3: Traj aligned on Entrance Time
    # Rows 0-1: Rat132 sessions ("stupid" naive/trained); rows 2-3: Rat123 sessions.
    gs3= fig.add_gridspec(nrows=4, ncols=1, left=0.76, bottom=0.28, right=0.98, top=0.75)
    
    naiveAx3= fig.add_subplot(gs3[2])
    plot_trajectories_aligned_on_entrance_time(root, naiveAx3, session3goodNaive, Win3)
    
    trainedAx3= fig.add_subplot(gs3[3])
    plot_trajectories_aligned_on_entrance_time(root, trainedAx3, session3goodTrained, Win3)
    
    NaiveStupidAx3= fig.add_subplot(gs3[0])
    plot_trajectories_aligned_on_entrance_time(root, NaiveStupidAx3, session3badNiave, Win3)
    
    trainedStupidAx3= fig.add_subplot(gs3[1])
    plot_trajectories_aligned_on_entrance_time(root, trainedStupidAx3, session3badTrained, Win3)
    
    # Hide y labels and x axes everywhere, then restore the x axis of the bottom row only.
    axes3=[naiveAx3, trainedAx3, NaiveStupidAx3, trainedStupidAx3]
    for ax in axes3:
        ax.set_ylabel('')
        ax.set_yticklabels('')
        ax.xaxis.set_visible(False)
        ax.spines['bottom'].set_visible(False)
    trainedAx3.xaxis.set_visible(True)
    trainedAx3.spines['bottom'].set_visible(True)
    trainedAx3.set_xlabel('         '+'Time to '+'$ET$'+' (s)')
    
    
    ###########################################
    # 4: consecutive trajectory examples (same row layout as panel 3)
    gs4= fig.add_gridspec(nrows=4, ncols=1, left=0.02, bottom=0.28, right=0.75, top=0.75)
    
    naiveAx4= fig.add_subplot(gs4[2])
    plot_consecutive_trajectories(root, session3goodNaive, trials4goodNaive, naiveAx4)
    naiveAx4.text(x=0, y=0, s=f' {session3goodNaive[:6]} session {day123_0} ({markers1["naive"]})',fontsize=8)
    
    trainedAx4= fig.add_subplot(gs4[3])
    plot_consecutive_trajectories(root, session3goodTrained, trials4goodTrained, trainedAx4)
    trainedAx4.text(x=0, y=0, s=f' {session3goodTrained[:6]} session {day123_1} ({markers1["trained"]})',fontsize=8)
    
    NaiveStupidAx4= fig.add_subplot(gs4[0])
    plot_consecutive_trajectories(root, session3badNiave, trials4badNiave, NaiveStupidAx4)
    NaiveStupidAx4.text(x=0, y=80, s=f' {session3badNiave[:6]} session {day132_0} ({markers1["stupidNaive"]})',fontsize=8)
    
    trainedStupidAx4= fig.add_subplot(gs4[1])
    plot_consecutive_trajectories(root, session3badTrained, trials4badTrained, trainedStupidAx4)
    trainedStupidAx4.text(x=0, y=80, s=f' {session3badTrained[:6]} session {day132_1} ({markers1["stupidTrained"]})',fontsize=8)
    
    # Give all four rows the same time span and position ticks.
    axes4=[NaiveStupidAx4, trainedStupidAx4,naiveAx4, trainedAx4]
    timeMax=max([ax.get_xlim()[1] for ax in axes4])
    for ax in axes4:
        ax.set_xlim([0,timeMax])
        ax.xaxis.set_visible(False)
        ax.spines['bottom'].set_visible(False)
        ax.spines['left'].set_bounds(0,90)
        ax.set_ylabel('')
        ax.set_yticks([0,45,90])
        ax.set_yticklabels([' 0','','90'])
    
    trainedAx4.xaxis.set_visible(True)
    trainedAx4.spines['bottom'].set_visible(True)
    trainedAx4.set_xlabel('Time (s)')
    
    legend4=add_legend_for_consecutive_trajectories(NaiveStupidAx4)
    
    # Frameless Axes spanning all of gs4 with transparent ticks: it only
    # carries the shared y label for the four rows.
    totAx4=fig.add_subplot(gs4[:],frameon=False)
    totAx4.set_ylabel('Position (cm)')
    totAx4.xaxis.set_visible(False)
    totAx4.set_ylim([0,10])
    totAx4.set_yticks([10])
    totAx4.tick_params(color=(0, 0, 0, 0),labelcolor=(0, 0, 0, 0),zorder=-10)
    
    
    
    ###########################################
    # 5: average trajectories
    gs5= fig.add_gridspec(nrows=1, ncols=1, left=0.02, bottom=0.02, right=0.23, top=0.20)
    ax5= fig.add_subplot(gs5[0])
    plot_animal_median_trajectory(root,ax5,animalList1,profile1, Win5, cs5, SessionRange5,colors)

    
    ###########################################
    # 6: probability of correct
    gs6= fig.add_gridspec(nrows=1, ncols=1, left=0.76, bottom=0.02, right=0.98, top=0.20)
    ax6= fig.add_subplot(gs6[0])
    plot_cond_prob_correct(ax6, animalList1, profile1, SessionRange5, GT6, distBins6,colors)


    ###########################################
    # 7: percent correct
    gs7= fig.add_gridspec(nrows=1, ncols=1, left=0.31, bottom=0.02, right=0.5, top=0.20)
    ax7= fig.add_subplot(gs7[0])    
    plot_percent_correct(root,ax7,animalList1,profile1, PerfParam7, SessionRange5,colors)
    # Fully transparent tick labels: the ticks stay but the labels are hidden.
    ax7.set_xticklabels(ax7.get_xticklabels(),color=[0,0,0,0], rotation=0)
    ax7.set_xlabel('Rat#')

    

    ############################################
    # 8: correlation of correct and distance 
    gs8= fig.add_gridspec(nrows=1, ncols=1, left=0.5, bottom=0.02, right=.7, top=0.20)
    ax8= fig.add_subplot(gs8[0])    

    plot_correct_distance_correlation(root, ax8, animalList1, profile1, param8, SessionRange5, colors)
    ax8.set_ylabel('')
    ax8.set_yticklabels('')
    
    #############################################
    #%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    # Panel letters a-g are assigned in the order of AXES.
    AXES=(ax1,axes2[0],axes4[0],ax5,ax7,ax8,ax6)
    OFFX=(.05,)*len(AXES)
    OFFY=(.03,)*len(AXES)
    add_panel_caption(axes=AXES, offsetX=OFFX, offsetY=OFFY)
    
    fig.savefig(os.path.join(os.path.dirname(os.getcwd()),'BehavioralPaper','Figures','ImmTrd.pdf'),
                format='pdf', bbox_inches='tight')
    
    plt.show()
    plt.close('all')
    # Undo the rc changes made by set_rc_params().
    matplotlib.rcdefaults()