# Part 0:
## Import everything
Run the cell below.

In [None]:
# -*- coding: utf-8 -*-
import os
import glob
import numpy as np
from platform import system as OS
import pandas as pd
import scipy.stats
from scipy.ndimage.filters import gaussian_filter1d as smooth
import math
import datetime
from copy import deepcopy
import matplotlib.cm as cm
import warnings
warnings.filterwarnings("ignore")
import types
import inspect
import string
import sys, time
import pickle
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import mlab
from scipy import stats
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.animation as animation
import matplotlib.backends.backend_pdf
from sklearn.decomposition import KernelPCA
import mpl_toolkits.axes_grid1.inset_locator as inset
from matplotlib.ticker import FormatStrFormatter, MaxNLocator
import imageio
from set_rc_params import set_rc_params


# Interactive-only setup cell. The guard skips it when "__file__" is defined
# (presumably the case when this notebook is itself loaded with %run, so that
# only the function definitions are imported).
if "__file__" not in dir():
    %matplotlib inline
    %config InlineBackend.close_figures = False
    matplotlib.rcdefaults()
    
    # Data root folder depends on the host operating system.
    if OS()=='Linux':
        root="/data"
    elif OS()=='Windows':
        root="C:\\DATA\\"
    else:
        root="/Users/davidrobbe/Documents/Data/"
            
    # realpath() of the literal string "__file__" resolves relative to the current
    # working directory, so ThisNoteBookPath is effectively the cwd.
    ThisNoteBookPath=os.path.dirname(os.path.realpath("__file__"))
    # The shared helper notebooks live in the sibling folder "load_preprocess_rat";
    # run them from there, then restore the original working directory.
    CommonNoteBookesPath=os.path.join(os.path.split(ThisNoteBookPath)[0],"load_preprocess_rat")
    CWD=os.getcwd()
    os.chdir(CommonNoteBookesPath)
    # Presumably these define the helpers used below (Data, data_fetch, event_detect,
    # event_statistic, get_positions_array_beginning, ...) — confirm.
    %run UtilityTools.ipynb
    %run Animal_Tags.ipynb
    %run loadRat_documentation.ipynb
    %run plotRat_documentation_1_GeneralBehavior.ipynb
    %run plotRat_documentation_3_KinematicsInvestigation.ipynb
    %run RunBatchRat_3_CompareGroups.ipynb
    %run BatchRatBehavior.ipynb
    # Path of BehavioralPaper/SUPP_Ctrl2Var.ipynb (by its name, presumably this notebook).
    currentNbPath=os.path.join(os.path.split(ThisNoteBookPath)[0],'BehavioralPaper','SUPP_Ctrl2Var.ipynb')
    %run $currentNbPath
    os.chdir(CWD)

    # NOTE(review): `logging` is not imported in this file; presumably it is
    # provided by one of the %run notebooks above — confirm.
    logging.getLogger().setLevel(logging.ERROR)
    
    # Preprocessing / analysis parameters shared by the cells below.
    param={
        "goalTime":7,#needed for pavel data only
        "treadmillRange":[0,90],#pavel error conversion "treadmillRange":[0,80]
        "maxTrialDuration":15,
        "interTrialDuration":10,#None pavel
        "endTrial_frontPos":30,
        "endTrial_backPos":55, 
        "endTrial_minTimeSec":4,
        "cameraSamplingRate":25, #needed for new setup    

        "sigmaSmoothPosition":0.1,#0.33, 0.18 pavel
        "sigmaSmoothSpeed":0.3,#0.3, 0.5 pavel
        "nbJumpMax":100,#200 pavel
        "binSize":0.25,
        #parameters used to preprocess (will override the default parameters)
           }  
    # Treadmill position limits (cm) used for axis ranges.
    Y1,Y2=param['treadmillRange']

    print('os:',OS(),'\nroot:',root,'\nImport successful!')

---
---


# Part 1:

# DEFINITIONS

### If you are not sure what to do, skip to Part 2 (generating the figure)

In [None]:
def get_ordered_colors(colormap, n):
    """
    Return ``n`` colours sampled evenly along a matplotlib colormap.

    colormap: colormap name (e.g. 'RdBu') or a Colormap instance.
    n: number of colours. Samples are taken at np.linspace(0, 1, n), so n=1
        returns only the colour at the start of the map.

    Returns a list of the values returned by the colormap (RGBA tuples).
    """
    # plt.get_cmap replaces plt.cm.get_cmap, which was deprecated in
    # Matplotlib 3.7 and removed in 3.9.
    cmap = plt.get_cmap(colormap)
    return [cmap(colorVal) for colorVal in np.linspace(0, 1, n)]

In [None]:
def add_panel_caption(axes: tuple, offsetX: tuple, offsetY: tuple, **kwargs):
    """
    Write a lowercase panel letter ('a', 'b', 'c', ...) next to each Axes.

    Each letter is anchored near the top-left corner of its Axes, shifted left
    by |offsetX[i]| and up by |offsetY[i]|. Positions and offsets are fractions
    of the figure size (figure coordinates). Extra keyword arguments are passed
    on to ``Axes.text``.
    """
    assert len(axes)==len(offsetX)==len(offsetY), 'Bad input!'

    figure = axes[0].get_figure()
    figBox = figure.bbox
    for letter, axis, shiftX, shiftY in zip(string.ascii_lowercase, axes, offsetX, offsetY):
        extent = axis.get_window_extent()
        xPos = (extent.x0 / figBox.xmax) - abs(shiftX)
        yPos = (extent.y1 / figBox.ymax) + abs(shiftY)
        axis.text(x=xPos, y=yPos, s=letter,
                  fontweight='extra bold', fontsize=10, ha='left', va='center',
                  transform=figure.transFigure, **kwargs)


In [None]:
class TwoTailPermTest:
    """
    Two-tailed permutation test for a difference between the means of two groups.

    group1, group2: numpy arrays holding the data. Each can be 1-D (several
        realizations of a single value) or 2-D (realizations x time/space/... bins).
        EX: x.shape==(15,500) means 15 trials/samples over 500 time bins.
        A 1-D group is reshaped to a single column; both groups must end up
        with the same number of columns.

    nIterations: Number of shuffles. max(nIterations)=(len(x)+len(y))!/len(x)!len(y)!

    initGlobConfInterval: Initial size (in percent) of each tail of the global
        confidence band. It is also the tail size of the pointwise ("pairwise") band.

    smoothSigma: standard deviation (in bins) of the Gaussian kernel used to smooth
        the mean difference when there are multiple bins, based on the
        Fujisawa 2008 paper. Default value: 0.05.

    Outputs (attributes):
        origGroupDiff: difference of the group means (D0), smoothed when there are several bins.
        diffSurGroups: (nIterations, nBins) array of surrogate mean differences.
        pVal: pointwise P-values.
        highBand, lowBand: AKA boundary. Global bands (None for a single-bin comparison).
        significantDiff: An array of True or False, indicating whether there is a difference.
    """
    def __init__(self, group1, group2, nIterations=1000, initGlobConfInterval=5, smoothSigma=0.05):
        self.group1, self.group2 = group1, group2
        self.nIterations, self.smoothSigma = nIterations, smoothSigma
        self.initGlobConfInterval = initGlobConfInterval

        self.checkGroups()

        # origGroupDiff is also known as D0 in the definition of permutation test.
        self.origGroupDiff = self.computeGroupDiff(group1, group2)

        # Generate surrogate groups, compute difference of mean for each group, and put in a matrix.
        self.diffSurGroups = self.setDiffSurrGroups()

        # Set statistics
        self.pVal = self.setPVal()
        self.highBand, self.lowBand = self.setBands()
        self.significantDiff = self.setSignificantGroup()

    def checkGroups(self):
        """Validate the inputs and reshape 1-D groups into single-column matrices."""
        if not isinstance(self.group1, np.ndarray) or not isinstance(self.group2, np.ndarray):
            raise ValueError("In permutation test, \"group1\" and \"group2\" should be numpy arrays.")

        if not (1 <= self.group1.ndim <= 2) or not (1 <= self.group2.ndim <= 2):
            raise ValueError('In permutation test, the groups must be either vectors or matrices.')

        # Reshape each 1-D group on its own, so a 1-D group can be paired with an
        # (n, 1) matrix without touching the matrix.
        if self.group1.ndim == 1:
            self.group1 = np.reshape(self.group1, (len(self.group1), 1))
        if self.group2.ndim == 1:
            self.group2 = np.reshape(self.group2, (len(self.group2), 1))

        if self.group1.shape[1] != self.group2.shape[1]:
            raise ValueError('In permutation test, both groups must have the same number of columns (bins).')

    def computeGroupDiff(self, group1, group2):
        """Return mean(group1) - mean(group2) per column (NaNs ignored), smoothed when there are several bins."""
        meanDiff = np.nanmean(group1, axis=0) - np.nanmean(group2, axis=0)
        
        if len(self.group1[0]) == 1 and len(self.group2[0]) == 1:
            # Single-bin comparison: nothing to smooth.
            return [meanDiff]
        
        return smooth(meanDiff, sigma=self.smoothSigma, order=0, 
                    mode='constant', cval=0, truncate=4.0)

    def setDiffSurrGroups(self):
        """Return an (nIterations, nBins) matrix of mean differences between shuffled surrogate groups."""
        # Pool all realizations; generateSurrGroup shuffles this array in place.
        self.concatenatedData = np.concatenate((self.group1,  self.group2), axis=0)
        
        diffSurrGroups = np.zeros((self.nIterations, self.group1.shape[1]))
        for iteration in range(self.nIterations):
            # Generate surrogate groups
            surrGroup1, surrGroup2 = self.generateSurrGroup()
            
            # Compute the difference between mean of surrogate groups
            surrGroupDiff = self.computeSurrGroupDiff(surrGroup1, surrGroup2) 
            
            # Store individual differences in a matrix.
            diffSurrGroups[iteration, :] = surrGroupDiff

        return diffSurrGroups

    def generateSurrGroup(self):
        """Shuffle the pooled data and split it into two surrogate groups of the original sizes."""
        # np.random.shuffle permutes the ROWS (realizations) in place; each row's
        # time course is kept intact, which is what the permutation test requires.
        np.random.shuffle(self.concatenatedData)  

        # Return surrogate groups of same size.
        return self.concatenatedData[: self.group1.shape[0], :], self.concatenatedData[self.group1.shape[0]:, :]

    def computeSurrGroupDiff(self, surrGroup1, surrGroup2):
        return self.computeGroupDiff(surrGroup1, surrGroup2)
  
    def setPVal(self):
        """Pointwise two-tailed p-value: min(1, 2*P(D > D0), 2*P(D < D0)) over the surrogates."""
        positivePVals = np.sum(1*(self.diffSurGroups > self.origGroupDiff), axis=0) / self.nIterations
        negativePVals = np.sum(1*(self.diffSurGroups < self.origGroupDiff), axis=0) / self.nIterations
        return np.array([np.min([1, 2*pPos, 2*pNeg]) for pPos, pNeg in zip(positivePVals, negativePVals)])

    def setBands(self):
        """
        Compute global bands: percentile bands on the surrogates whose tails are
        shrunk (x0.95 per step) until fewer than 5% of surrogates break them.
        """
        if len(self.origGroupDiff) < 2:  # single point comparison
            return None, None
        
        alpha = 100 # Global alpha value (percent)
        highGlobCI = self.initGlobConfInterval  # global confidence interval
        lowGlobCI = self.initGlobConfInterval  # global confidence interval
        while alpha >= 5:
            highBand = np.percentile(a=self.diffSurGroups, q=100-highGlobCI, axis=0)
            lowBand = np.percentile(a=self.diffSurGroups, q=lowGlobCI, axis=0)
            
            # A surrogate counts as breaking a band only when MORE THAN ONE of its
            # bins lies outside it.
            # NOTE(review): a strict "any bin" criterion would be `> 0` — confirm `> 1` is intended.
            breaksPositive = np.sum(
                [np.sum(self.diffSurGroups[i, :] > highBand) > 1 for i in range(self.nIterations)]) 
            
            breaksNegative = np.sum(
                [np.sum(self.diffSurGroups[i, :] < lowBand) > 1 for i in range(self.nIterations)])
            
            alpha = ((breaksPositive + breaksNegative) / self.nIterations) * 100
            highGlobCI = 0.95 * highGlobCI
            lowGlobCI = 0.95 * lowGlobCI
        return highBand, lowBand           

    def setSignificantGroup(self):
        """
        Mark bins outside the global band as significant. Each run of
        pointwise-significant bins that contains such a bin is then extended
        to its full length.
        """
        if len(self.origGroupDiff) < 2:  # single point comparison
            return self.pVal <= 0.05

        # finding significant bins
        globalSig = np.logical_or(self.origGroupDiff > self.highBand, self.origGroupDiff < self.lowBand)
        pairwiseSig = np.logical_or(self.origGroupDiff > self.setPairwiseHighBand(), self.origGroupDiff < self.setPairwiseLowBand())
        
        significantGroup = globalSig.copy()
        # Start at -1 so the backward extension in setNeighborsToTrue can reach
        # bin 0 (its range stops just before `previousIndex`).
        lastIndex = -1
        for currentIndex in range(len(pairwiseSig)):
            if (globalSig[currentIndex] == True):
                lastIndex = self.setNeighborsToTrue(significantGroup, pairwiseSig, currentIndex, lastIndex)

        return significantGroup
    
    def setPairwiseHighBand(self):        
        return np.percentile(a=self.diffSurGroups, q=100 - self.initGlobConfInterval, axis=0)

    def setPairwiseLowBand(self):        
        return np.percentile(a=self.diffSurGroups, q=self.initGlobConfInterval, axis=0)


    def setNeighborsToTrue(self, significantGroup, pairwiseSig, currentIndex, previousIndex):
        """
            While the neighbors of a global point pass the local band (consecutively), set the global band to true.
            Returns the last index which was visited by the forward scan.
        """ 
        if (currentIndex < previousIndex):
            return previousIndex
        
        for index in range(currentIndex, previousIndex, -1):
            if (pairwiseSig[index] == True):
                significantGroup[index] = True
            else:
                break

        previousIndex = currentIndex
        for index in range(currentIndex + 1, len(significantGroup)):
            previousIndex = index
            if (pairwiseSig[index] == True):
                significantGroup[index] = True
            else:
                break
        
        return previousIndex
    
    def plotSignificant(self, ax: "plt.Axes", y: float, x=None, **kwargs):
        """
        Draw a horizontal segment at height ``y`` over every significant bin.

        x: bin edges; bin i spans x[i]..x[i+1], so len(significantDiff)+1 values are
           needed to draw every bin. Defaults to 1, 2, ..., len(significantDiff)+1.
        kwargs: forwarded to ``ax.plot``.
        """
        if x is None:
            # One more edge than bins, so the last significant bin is drawn too.
            x=np.arange(0,len(self.significantDiff)+1)+1
        for x0,x1,p in zip(x[:-1],x[1:],self.significantDiff):
            if p:
                ax.plot([x0,x1],[y,y],zorder=-2,**kwargs)
        

---

**Plotting the trajectories of two example sessions**

In [None]:
def plot_trajectories(data,ax):
    """
    Draw the position trace of every trial of ``data`` on ``ax``.

    A trial listed in ``data.goodTrials`` is drawn in "lime", any other trial in
    "tomato". Each trace is truncated at ``data.stopFrame[trial]``. A dashed
    black vertical line marks the median goal time, spanning the treadmill range.

    Returns a numpy array with the colour of each trial, in the iteration
    order of ``data.position``.
    """
    positions = data.position
    treadmillTime = data.timeTreadmill  # time axis aligned on the camera

    trialColors = []
    for trial in positions:
        traceColor = "lime" if trial in data.goodTrials else "tomato"
        trialColors.append(traceColor)
        lastFrame = data.stopFrame[trial]
        ax.plot(treadmillTime[trial][:lastFrame], positions[trial][:lastFrame],
                color=traceColor, lw=.5)

    ax.set_ylabel("X Position (cm)")
    ax.set_xlabel("Time (s) relative to camera start")

    treadmillRange = param['treadmillRange']
    ax.vlines(x=np.nanmedian(data.goalTime),
              ymin=treadmillRange[0], ymax=treadmillRange[1],
              colors='k', linestyle='--')

    return np.array(trialColors)



def plot_trajectories_and_distributions(root, ax, session):
    """
    Plot every trajectory of ``session`` on ``ax``. On the left, add horizontal
    histograms of the first sample of the positions for good ("lime") and other
    ("tomato") trials.

    root: data root folder.
    ax: matplotlib Axes to draw on.
    session: session name; its first 6 characters are passed as the animal name
        (e.g. 'Rat128_2017_06_02_12_26' -> 'Rat128').

    Returns ``ax``.
    """
    data=Data(root,session[:6],session,redoPreprocess=False)
    y1,y2=param['treadmillRange']
    
    # One colour per trial, in the same order as the trial axis of `position` below
    # (assumed — both come from the same `data` object).
    color=plot_trajectories(data,ax=ax)
    
    position=get_positions_array_beginning(data,onlyGood=False,raw=False)
    position=position.T  # after transposing, axis 1 indexes trials
    
    histGood,edgesGood=np.histogram(position[0,color=='lime'],30,range=(y1,y2), density=False)
    histBad,edgesBad=np.histogram(position[0,color=='tomato'],30,range=(y1,y2), density=False)
    
    # Scale the bars so that the tallest one is 3 x-axis units (s) long. Guard
    # against a session with no counts at all, which would divide by zero.
    maxBin=max(histGood.max(),histBad.max())/3
    if maxBin==0:
        maxBin=1
    ax.barh(edgesGood[:-1],-histGood/maxBin,height=np.diff(edgesGood)[0],align='edge',left=-1.02*data.cameraToTreadmillDelay, color='lime', alpha=.6)
    ax.barh(edgesBad[:-1],-histBad/maxBin,height=np.diff(edgesBad)[0],align='edge',left=-1.02*data.cameraToTreadmillDelay, color='tomato', alpha=.6)

    ax.set_xlim([-4.2,data.maxTrialDuration[0]])
    ax.set_xticks((0,np.nanmedian(data.goalTime),np.nanmedian(data.maxTrialDuration)))
    ax.xaxis.set_major_formatter(FormatStrFormatter('%g'))
    ax.set_ylim([y1,y2])
    # Middle tick at the centre of the range (identical to y2/2 when y1 == 0).
    ax.set_yticks([y1,(y1+y2)/2,y2])
    ax.set_yticklabels([y1,'',y2])
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_bounds(ax.get_xticks()[0],ax.get_xticks()[-1])
    ax.spines['left'].set_position(('data',-data.cameraToTreadmillDelay))

    ax.set_xlabel('Trial time (s)')
    ax.set_ylabel('Position (cm)')

    return ax

In [None]:
# Interactive check: trajectories and initial-position histograms for two
# example sessions (top axes: Rat128 session, bottom axes: Rat156 session).
if "__file__" not in dir():
    #the inputs
    
    fig=plt.figure(figsize=(3,6),dpi=100)
    naiveAx=fig.add_subplot(211)
    trainedAx=fig.add_subplot(212)
    #plotting VAR
    session='Rat128_2017_06_02_12_26'
    plot_trajectories_and_distributions(root, naiveAx, session)
    
    #plotting NTO
    session='Rat156_2017_09_08_14_10'
    plot_trajectories_and_distributions(root, trainedAx, session)
    
    
#     fig.savefig('/home/david/Pictures/tst.pdf', format='pdf', bbox_inches='tight')
    plt.show()
    plt.close('all')

---

**Plotting the effect of variable speed**

In [None]:
def plot_event(root, Profiles, N, badAnimals, TaskParamToPlot, ax):
    """
    Plot the median (with 25th-75th percentile error bars) of one task statistic
    over the sessions surrounding an event detected by ``event_detect``.

    Presumably the event is the switch from training profile Profiles[0] to Profiles[1].

    root: data root folder.
    Profiles: (profileBefore, profileAfter), passed to ``event_detect``.
    N: (nBefore, nAfter) number of sessions to show before and after the event.
    badAnimals: animals to exclude, passed to ``event_detect``.
    TaskParamToPlot: name of the statistic computed by ``event_statistic``;
        also used as the y-axis label.
    ax: matplotlib Axes to draw on.

    Returns ``ax``.
    Raises AssertionError when fewer sessions are available than requested.
    """
    #getting the data
    _,SessionDict=event_detect(root, Profiles[0], Profiles[1], badAnimals=badAnimals)

    Results,nSessionPre,nSessionPost=event_statistic(root,
                                                     SessionDict,
                                                     parameter=param,
                                                     redo=False,
                                                     TaskParamToPlot=[TaskParamToPlot])

    assert N[0]<nSessionPre and N[1]<nSessionPost,"fewer sessions available than requested:"

    # Keep N[0] sessions before and N[1] sessions after the event
    # (columns nSessionPre-N[0] .. nSessionPre+N[1]-1).
    data=np.array(list(Results[TaskParamToPlot].values()))
    data=data[:,
              nSessionPre-N[0] : nSessionPre+N[1]
             ]

    # Median and interquartile range across rows (presumably one row per animal).
    y=np.nanpercentile(data,50,axis=0)
    yerr=np.nanpercentile(data,(25,75),axis=0)

    # Sessions are numbered -N[0]..-1 before and 1..N[1] after the event (no session 0).
    xLabel=list(range(-N[0],N[1]+1))
    xLabel.remove(0)
    ax.errorbar(xLabel,y,yerr=abs(yerr-y), ecolor='k', fmt='k-o',
                markersize=2, elinewidth=1, linewidth=1, markerfacecolor='w',zorder=1)

    if xLabel[0]<=-10 or xLabel[-1]>=10:
        # Long ranges: label the multiples of 5 plus both ends.
        newLabel=sorted(set([label for label in xLabel if label%5==0]+[xLabel[0],xLabel[-1]]))
    else:
        # Short ranges: label every session (previously left newLabel unbound).
        newLabel=xLabel

    ax.set_xlim([xLabel[0]-1,xLabel[-1]+1])
    ax.set_xticks(newLabel)
    ax.spines['bottom'].set_bounds(xLabel[0],xLabel[-1])
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_ylabel(TaskParamToPlot)
    # Dashed line at the event.
    ax.axvline(x=0, color='k',linestyle='--',lw=.8)

    # The original `return tuple(axes)` referenced an undefined name.
    return ax

In [None]:
# Interactive check of plot_event: one statistic over the sessions around the
# transition from profile1 (speed 10) to profile2 (variable speed).
if "__file__" not in dir():

    profile1={'Type':'Good',
              'rewardType':'Progressive',
              'initialSpeed':'10',
              'Speed':'10',
              'Tag':['Control-BackTo10']
              }
    profile2={'Type':'Good',
              'rewardType':'Progressive',
              'initialSpeed':'10',
              'Speed':'var',
              'Tag':['Control-Late-var']
              }

    #number of sessions to plot
    N=[4,14]
    #goal times (not used by this cell)
    GT=[7,7]
    badAnimals=['RatBAD']
    # Only the last assignment takes effect; keep/swap to choose the statistic.
    TaskParamToPlot="standard deviation of entrance time"
    TaskParamToPlot="percentile entrance time"
    wspace=0.05  # not used by this cell
    
    
    Profiles=(profile1,profile2,)
    plt.close('all')
    fig=plt.figure(figsize=(5,4))
    ax=fig.add_subplot(111)
    
    plot_event(root, Profiles, N, badAnimals, TaskParamToPlot, ax)

---

**Plotting the probability of a correct trial as a function of treadmill speed**

In [None]:
def prob_correct_given_distance(animalList, profile, SessionRange, GT, spdBins):
    """
    Estimate, per treadmill-speed bin, the probability that min(GT) < ET < max(GT).

    For each animal, ``data_fetch`` returns per-session treadmill speeds and
    entrance times for ``profile``. Sessions SessionRange[0] .. SessionRange[-1]-1
    (slice semantics) are kept. Entrance times equal to the maximal trial
    duration are set to 0, so those trials never count as correct. Sessions
    whose speed and entrance-time arrays differ in length are skipped. Trials
    are binned by speed with the edges ``spdBins``. A bin gets a probability
    only when it holds more than 10 trials; otherwise it stays NaN.

    Returns:
        Ptotal: array of len(spdBins)-1 probabilities, all animals pooled.
        Panimal: (len(spdBins)-1, len(animalList)) array, one column per animal.
    """
    # The names of these two functions are used as keys into data_fetch's result.
    def entrance_time(data):
        # NOTE: overwrites data.entranceTime in place.
        data.entranceTime[data.entranceTime==data.maxTrialDuration]=0
        return data.entranceTime
    def treadmill_speed(data):
        return data.treadmillSpeed

    fetchers=[treadmill_speed,entrance_time]
    speedKey=fetchers[0].__name__
    etKey=fetchers[1].__name__
    sessionSlice=slice(SessionRange[0],SessionRange[-1])

    perAnimalSpeed=[]
    perAnimalET=[]
    for animal in animalList:
        fetched=data_fetch(root,animal,profile, fetchers, NbSession=100)
        perAnimalSpeed.append(fetched[speedKey][sessionSlice])
        perAnimalET.append(fetched[etKey][sessionSlice])

    lowGT, highGT = min(GT), max(GT)
    corrData={}
    spdData={}
    for animal, sessionsET, sessionsSpeed in zip(animalList, perAnimalET, perAnimalSpeed):
        corrData[animal]=[]
        spdData[animal]=[]
        for j, sessionET in enumerate(sessionsET):
            sessionSpeed=sessionsSpeed[j]
            if len(sessionET)!=len(sessionSpeed):
                continue
            corrData[animal].extend(np.logical_and(sessionET>lowGT, sessionET<highGT))
            spdData[animal].extend(sessionSpeed)

    def binned_probability(correct, speed):
        """Fraction of correct trials in each speed bin (NaN when a bin has <= 10 trials)."""
        prob=np.full(len(spdBins)-1, np.nan)
        for k,(loBin,hiBin) in enumerate(zip(spdBins[:-1],spdBins[1:])):
            inBin=correct[np.logical_and(speed>=loBin,speed<hiBin)]
            if len(inBin)>10:
                prob[k]=sum(inBin)/len(inBin)
        return prob

    Panimal=np.full((len(spdBins)-1,len(animalList)), np.nan)
    for col,animal in enumerate(animalList):
        Panimal[:,col]=binned_probability(np.array(corrData[animal]), np.array(spdData[animal]))

    pooledCorrect=np.array([flag for animal in animalList for flag in corrData[animal]])
    pooledSpeed=np.array([value for animal in animalList for value in spdData[animal]])
    Ptotal=binned_probability(pooledCorrect, pooledSpeed)

    return Ptotal, Panimal


def plot_cond_prob_correct(ax, animalList, profile, SessionRange, GT, spdBins, **kwargs):
    """
    Plot on ``ax`` the pooled probability that min(GT) < ET < max(GT) as a
    function of treadmill speed, and format the axes.

    animalList, profile, SessionRange, GT, spdBins: forwarded to
        ``prob_correct_given_distance``; ``spdBins`` are the speed bin edges.
    kwargs: forwarded to ``ax.plot`` (e.g. ``color``, ``label``).

    Returns ``ax``.
    """
    Ptotal, _=prob_correct_given_distance(animalList,
                                          profile,
                                          SessionRange,
                                          GT,
                                          spdBins)

    # Plot each value at the centre of its bin. Using both edges (rather than
    # the width of the first bin) also places non-uniform bins correctly.
    edges=np.asarray(spdBins, dtype=float)
    binCenters=(edges[:-1]+edges[1:])/2
    ax.plot(binCenters, Ptotal, **kwargs)

    ax.xaxis.set_major_formatter(FormatStrFormatter('%g'))
    ax.yaxis.set_major_formatter(FormatStrFormatter('%g'))
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_bounds(spdBins[0],spdBins[-1])
    ax.spines['left'].set_bounds(0,1)
    ax.set_ylim([-.02,1])
    ax.set_yticks([0,.5,1])
    ax.set_yticklabels([0,'',1])
    ax.set_xlim([spdBins[0]-1,spdBins[-1]])
    ax.set_xticks([spdBins[0],10,20,spdBins[-1]])
    ax.set_xlabel('Speed ($cm.s^{-1}$)')
    ax.set_ylabel("Probability")
    ax.yaxis.set_label_coords(-0.04,0.5)
    
    return ax

def add_legend_to_cond_prob_plot(ax, colors, labels):
    """
    Add a frameless legend with one line handle per (colour, label) pair,
    anchored at the upper-left corner of ``ax``.

    Returns the created Legend (previously built but discarded), so callers can
    tweak it further.
    """
    handles=[matplotlib.lines.Line2D([], [], color=color, label=label)
             for color, label in zip(colors, labels)]

    return ax.legend(handles=handles, title="", title_fontsize=6, handletextpad=.6,
                     bbox_to_anchor=(0, 1), loc=2, ncol=1, fontsize='xx-small', frameon=False)
    

In [None]:
# Interactive check of plot_cond_prob_correct: pooled P(6 < ET < 8) versus
# treadmill speed for the listed animals under the variable-speed profile.
if "__file__" not in dir():
    profile={'Type':'Good',
         'rewardType':'Progressive',
         'initialSpeed':['10','var'],
         'Speed':['var'],
         'Tag':'Control-Late-var'
                  }
    animalList=['Rat077', 'Rat078', 'Rat084', 'Rat085', 'Rat088',
                'Rat091', 'Rat095', 'Rat096', 'Rat098', 'Rat103',
                'Rat104', 'Rat110', 'Rat113', 'Rat120']

    Y1,Y2=param['treadmillRange']
    SessionRange=[0,2]  # slice bounds: sessions 0 and 1
    spdBins=range(5,31,1)  # speed bin edges, 1 unit wide
    GTrange=(6,8)
    colors=get_ordered_colors(colormap='RdBu', n=1)
    labels=[f'$ P ({GTrange[0]} < ET < {GTrange[1]})$']
    plt.close('all')
    fig=plt.figure(figsize=(2,2),dpi=200)
    ax=fig.add_subplot(111)
    
    plot_cond_prob_correct(ax, animalList, profile, SessionRange, GTrange, spdBins, color=colors[0],label=labels[0])
    add_legend_to_cond_prob_plot(ax, colors, labels)
    plt.show()


------



------

# Part 2:

# **GENERATING THE FIGURE**

**Definition of Parameters**

In [None]:
# Parameters of the supplementary figure (panel 1: effect of switching to
# variable speed; panel 2: probability of correct trials versus speed).
if "__file__" not in dir():
    Y1,Y2=param['treadmillRange']
    #================================================
    # GRID 1 PARAMS
    
    profile1={'Type':'Good',
              'rewardType':'Progressive',
              'initialSpeed':'10',
              'Speed':'10',
              'Tag':['Control-BackTo10']
              }
    profile1_={'Type':'Good',
              'rewardType':'Progressive',
              'initialSpeed':'10',
              'Speed':'var',
              'Tag':['Control-Late-var']
              }

    #number of sessions to plot
    N1=[4,14]
    #goal times (s); GT1[0] is drawn as a reference line on the ET panel of
    #the figure cell, which would otherwise raise a NameError
    GT1=[7,7]
    TaskParamToPlot1="percentile entrance time"
    TaskParamToPlot1_ ="standard deviation of entrance time"
    TaskParamToPlot1__ ="% good trials"  
    
    Profiles1=(profile1,profile1_)       
    wspace1=.5
    
    #================================================
    # GRID 2: Speed
    profile2={'Type':'Good',
              'rewardType':'Progressive',
              'initialSpeed':'10',
             'Speed':'var',
              'Tag':['Control-Late-var']
             }

    animalList2=['Rat077', 'Rat078', 'Rat084', 'Rat085', 'Rat088',
                 'Rat091', 'Rat095', 'Rat096', 'Rat098', 'Rat103',
                 'Rat104', 'Rat110', 'Rat113', 'Rat120']

    SessionRange2=[0,1]  # slice bounds: session 0 only
    spdBins2=range(5,31,1)
    GTrange2Correct=(7,15)
    GTrange2Perfect=(6,8)
    colors2=get_ordered_colors(colormap='RdBu', n=2)
    # NOTE(review): prob_correct_given_distance uses strict inequalities
    # (min(GT) < ET < max(GT)), so the '\leq' in the first label does not match
    # the computation for ET exactly equal to 7 — confirm which is intended.
    # The raw f-string keeps the LaTeX backslash without an invalid-escape warning.
    labels2=[
              rf'$ P ({GTrange2Correct[0]} \leq ET < {GTrange2Correct[1]})$',
              f'$ P ({GTrange2Perfect[0]} <    ET < {GTrange2Perfect[1]})$'
             ]

    #=================================================
    # GENERAL
    param={
        "goalTime":7,#needed for pavel data only
        "treadmillRange":[0,90],#pavel error conversion "treadmillRange":[0,80]
        "maxTrialDuration":15,
        "interTrialDuration":10,#None pavel
        "endTrial_frontPos":30,
        "endTrial_backPos":55, 
        "endTrial_minTimeSec":4,
        "cameraSamplingRate":25, #needed for new setup    

        "sigmaSmoothPosition":0.1,#0.33, 0.18 pavel
        "sigmaSmoothSpeed":0.3,#0.3, 0.5 pavel
        "nbJumpMax":100,#200 pavel
        "binSize":0.25,
        #parameters used to preprocess (will override the default parameters)
        }
    Y1,Y2=param['treadmillRange']

**Plotting the figure**

In [None]:
# Build and save the supplementary figure BehavioralPaper/Figures/SUPP_Ctrl2Var.pdf.
if "__file__" not in dir():
    plt.close('all')
    set_rc_params()
    figsize=(7,4)
    fig=plt.figure(figsize=figsize,dpi=600)
    
    
    ###########################################
    # 1: variable speed effect
    # Top row: one panel per statistic over the sessions around the event
    # detected for Profiles1 (presumably the switch to variable speed); no
    # animal is excluded (badAnimals=[]).
    gs1  = fig.add_gridspec(nrows=1, ncols=3, left=0.02, bottom=0.62, right=0.98, top=0.98,wspace=wspace1)
    ax1  = fig.add_subplot(gs1[0])
    ax1_ = fig.add_subplot(gs1[1])
    ax1__= fig.add_subplot(gs1[2])

    plot_event(root, Profiles1, N1, [], TaskParamToPlot1,   ax1  )
    plot_event(root, Profiles1, N1, [], TaskParamToPlot1_,  ax1_ )
    plot_event(root, Profiles1, N1, [], TaskParamToPlot1__, ax1__)
    
    # Reference line at the goal time; GT1 (goal times) is expected to be
    # defined by the parameter cell.
    ax1.axhline(GT1[0] ,linestyle='--',color='m',lw=1, zorder=-5)
    ax1.set_ylabel('$ET$ (s)')
    ax1.yaxis.set_major_locator(MaxNLocator(integer=True))
    ax1.set_yticks([6,7,8])
    ax1.spines['left'].set_bounds(6,8)
    # NOTE(review): the n shown here is len(animalList2), the animal list of
    # panel 2, not a count of the animals used by plot_event — confirm they match.
    ax1.text(x=N1[1], y=6, s=f'$n={len(animalList2)}$ rats',fontsize='xx-small',ha='right',va='bottom')
    
    ax1_.set_ylabel('$SD_{ET}$ (s)')
    ax1_.set_yticks([1.5,2,2.5])
    ax1_.spines['left'].set_bounds(1.5,2.5)
    
    ax1__.set_ylabel('% Correct trials')
    ax1__.set_yticks([30,40,50,60,70])
    ax1__.spines['left'].set_bounds(30,70)
    ax1__.set_ylim([28,72])

    
    # Invisible Axes spanning the whole row, used only to carry one shared x label
    # (frame off, y axis hidden, ticks and tick labels fully transparent).
    totAx1=fig.add_subplot(gs1[:],frameon=False)
    totAx1.set_xlabel('Relative session#')
    totAx1.yaxis.set_visible(False)
    totAx1.set_xlim([0,10])
    totAx1.set_xticks([10])
    totAx1.tick_params(color=(0, 0, 0, 0),labelcolor=(0, 0, 0, 0),zorder=-10)

    

    ##########################################
    #2: probability of correct for different speeds
    gs2= fig.add_gridspec(nrows=1, ncols=1, left=0.3, bottom=0.02, right=0.7, top=0.5)
    ax2= fig.add_subplot(gs2[0])
    
    plot_cond_prob_correct(ax2, animalList2, profile2, SessionRange2,
                           GTrange2Correct, spdBins2, color=colors2[0],label=labels2[0])
    plot_cond_prob_correct(ax2, animalList2, profile2, SessionRange2,
                       GTrange2Perfect, spdBins2, color=colors2[1],label=labels2[1])

    add_legend_to_cond_prob_plot(ax2, colors2, labels2)
    ax2.set_ylabel('')
    # Mark and label the speed at which the second curve (GTrange2Perfect) peaks.
    spd,data=ax2.lines[1].get_data()
    ax2.vlines(spd[data.argmax()], -10,data.max(),alpha=.25,linewidth=.5,color='k')
    ax2.set_xticks([spdBins2[0],spd[data.argmax()],spdBins2[-1]])
    

    
        
    

    #############################################
    #%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
#     AXES=(axes11[0],ax3,ax5,ax10,ax12,ax4,ax8,ax9)
#     OFFX=(.05,)*len(AXES)
#     OFFY=(.03,)*len(AXES)
#     OFFY=(.075,*OFFY[1:])
#     add_panel_caption(axes=AXES, offsetX=OFFX, offsetY=OFFY)


    # Saved next to the current working directory's parent: <parent>/BehavioralPaper/Figures/.
    fig.savefig(os.path.join(os.path.dirname(os.getcwd()),'BehavioralPaper','Figures','SUPP_Ctrl2Var.pdf'),
                format='pdf', bbox_inches='tight')
    
    plt.show()
    plt.close('all')
    matplotlib.rcdefaults()