# Part 0:
## import everything
Run the cell below

In [None]:
import os
import glob
import numpy as np
from platform import system as OS
import pandas as pd
import scipy.stats
import math
import datetime
from copy import deepcopy
import matplotlib.cm as cm
import warnings
warnings.filterwarnings("ignore")
import sys
import pickle
import csv
import string
import matplotlib as mpl
import matplotlib.pyplot as plt
import PIL
import PyPDF2 as ppdf
from scipy import stats
from scipy.ndimage.filters import gaussian_filter as smooth
import matplotlib.animation as animation
import matplotlib.backends.backend_pdf
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.ticker import FormatStrFormatter, MaxNLocator, ScalarFormatter, FuncFormatter
from matplotlib.patches import ConnectionPatch, FancyArrowPatch
from set_rc_params import set_rc_params
import ROOT

try:
    # nest_asyncio is optional: it patches asyncio so an already-running event
    # loop (e.g. Jupyter's) can be re-entered. Failure to import/apply it is
    # tolerated, but `Exception` (not a bare except) lets KeyboardInterrupt and
    # SystemExit propagate.
    import nest_asyncio
    nest_asyncio.apply()
except Exception:
    pass

# Notebook-only setup: `__file__` exists when this code is imported/run as a
# module, but not inside Jupyter, so this cell (and its IPython magics) only
# executes in the notebook.
if "__file__" not in dir():
    %matplotlib inline
    # keep figures alive after a cell finishes so later cells can reuse them
    %config InlineBackend.close_figures = False

    # data root directory, provided by the project-local ROOT module
    root=ROOT.root
    
    # NOTE: realpath("__file__") resolves the literal name "__file__" against the
    # current working directory, so this is effectively the notebook's CWD.
    ThisNoteBookPath=os.path.dirname(os.path.realpath("__file__"))
    CommonNoteBookesPath=os.path.join(os.path.split(ThisNoteBookPath)[0],"load_preprocess_rat")
    # remember the CWD, run the shared notebooks from their own folder, then come back
    CWD=os.getcwd()
    os.chdir(CommonNoteBookesPath)
    %run UtilityTools.ipynb
    %run Animal_Tags.ipynb
    %run loadRat_documentation.ipynb
    %run Lesion_Size.ipynb
    %run plotRat_documentation_1_GeneralBehavior.ipynb
    %run plotRat_documentation_3_KinematicsInvestigation.ipynb
    %run RunBatchRat_3_CompareGroups.ipynb
    %run BatchRatBehavior.ipynb
    currentNbPath=os.path.join(os.path.split(ThisNoteBookPath)[0],'LesionPaper','MaxPosAnalysis.ipynb')
    %run $currentNbPath

    os.chdir(CWD)

    # NOTE(review): `logging` is not imported at the top of this file; it is
    # presumably brought in by one of the %run notebooks above — confirm.
    logging.getLogger().setLevel(logging.ERROR)
    
    # preprocessing parameters shared by the analysis helpers below
    param={
        "goalTime":7,#needed for pavel data only
        "treadmillRange":[0,90],#pavel error conversion "treadmillRange":[0,80]
        "maxTrialDuration":15,
        "interTrialDuration":10,#None pavel
        "endTrial_frontPos":30,
        "endTrial_backPos":55, 
        "endTrial_minTimeSec":4,
        "cameraSamplingRate":25, #needed for new setup    

        "sigmaSmoothPosition":0.1,#0.33, 0.18 pavel
        "sigmaSmoothSpeed":0.3,#0.3, 0.5 pavel
        "nbJumpMax":100,#200 pavel
        "binSize":0.25,
        #parameters used to preprocess (will override the default parameters)
    }
    # lower / upper bounds unpacked from param['treadmillRange']
    Y1,Y2=param['treadmillRange']

    print('os:',OS(),'\nroot:',root,'\nImport successful!')

---
---


# part 1:

# DEFINITIONS

### To simply generate the figure, skip to part 2

In [None]:
def add_panel_caption(axes: tuple, offsetX: tuple, offsetY: tuple, **kwargs):
    """
    Add upper-case letter captions (A, B, C, ...) to the Axes in `axes`.

    Each letter is placed at the top-left corner of its Axes, shifted left by
    |offsetX| and up by |offsetY|, in RELATIVE figure coordinates
    (`fig.transFigure`).

    Parameters
    ----------
    axes : sequence of Axes; the figure is taken from the first one.
    offsetX, offsetY : sequences of numbers, same length as `axes`
        (only their absolute values are used).
    **kwargs : extra text properties forwarded to `Axes.text`. They override
        the defaults (extra-bold weight, fontsize 10, ha='left', va='center').

    At most 26 panels are captioned: `zip` stops at the end of the alphabet.
    """
    assert len(axes)==len(offsetX)==len(offsetY), 'Bad input!'
    if len(axes) == 0:
        return  # nothing to caption (and no figure to read from)

    fig=axes[0].get_figure()
    fbox=fig.bbox
    # merge caller kwargs over the defaults instead of passing both, which used
    # to raise "got multiple values for keyword argument" (e.g. fontsize=...)
    textProps=dict(fontweight='extra bold', fontsize=10, ha='left', va='center')
    textProps.update(kwargs)
    for ax,dx,dy,s in zip(axes,offsetX,offsetY,string.ascii_uppercase):
        axbox=ax.get_window_extent()
        # convert the Axes' pixel corner into figure fractions, then offset it
        ax.text(x=(axbox.x0/fbox.xmax)-abs(dx), y=(axbox.y1/fbox.ymax)+abs(dy),
                s=s, transform=fig.transFigure, **textProps)

---

String Format for Scientific Notation

In [None]:
def SciNote(string):
    r"""
    Format a number in "real" scientific notation as a mathtext string.

    The value is rendered with '%1.2e' (3 significant digits) and rewritten as
    ``$m{\times}10^{e}$``:
    - trailing zeros of the mantissa are dropped;
    - a mantissa of exactly 1 is omitted (``$10^{e}$``);
    - a zero exponent is omitted (``$m$``).
    Non-finite values are returned unchanged (e.g. ``$nan$``).

    Ex: 'p-val={}'.format(SciNote(p))

    This reproduces the output of matplotlib's former private
    ``ScalarFormatter._formatSciNotation`` (with useMathText=True). That private
    method no longer exists in recent matplotlib releases, so it is no longer
    called here.
    """
    s = '%1.2e' % string
    mantissa, sep, expPart = s.partition('e')
    if not sep:  # 'nan' / 'inf' have no exponent part
        return "${}$".format(s)

    mantissa = mantissa.rstrip('0').rstrip('.')
    sign = expPart[0].replace('+', '')
    exponent = expPart[1:].lstrip('0')

    if mantissa == '1' and exponent:
        mantissa = ''  # write 1x10^e as 10^e
    if exponent:
        exponent = '10^{%s%s}' % (sign, exponent)

    if mantissa and exponent:
        body = r'%s{\times}%s' % (mantissa, exponent)
    else:
        body = mantissa + exponent
    return "${}$".format(body)

---

Shade and label the session periods (Before / Acute / Stable)

In [None]:
def plot_session_def(ax, sessionSlices= tuple()):
    """
    Shade on `ax` the session periods described by `sessionSlices`.

    Each slice is turned into an x-interval:
    - a missing stop becomes -1;
    - a non-negative start is shifted by +1 (negative starts are kept).
    When exactly three slices are given, they are also labelled 'Before',
    'Acute' and 'Stable' at the top of the current y-range.
    """
    def _bounds(period):
        # (left, right) x-limits of one period
        left = period.start + 1 if period.start >= 0 else period.start
        right = -1 if period.stop is None else period.stop
        return left, right

    margin = 0
    ax.get_xticks()  # result unused; call kept as in the original

    for period in sessionSlices:
        assert isinstance(period, slice), 'Bad Boy!'
        left, right = _bounds(period)
        ax.axvspan(xmin=left+margin, xmax=right-margin, zorder=-1, alpha=.8, ec= None,
                   color='xkcd:ivory')

    if len(sessionSlices) != 3:
        return

    for period, label in zip(sessionSlices, ('Before', 'Acute', 'Stable')):
        left, right = _bounds(period)
        ax.text((left+right)/2, ax.get_ylim()[1], s=label,
                fontsize='xx-small', ha='center', va='top')

---

Plot the group median and interquartile range for the Before / Acute / Stable periods, with individual animals overlaid

In [None]:
def plot_event_1on1(root, ax, Profiles, colorCode, badAnimals=None, TaskParamToPlot='% good trials',
                   x_pos=None, nPre=slice(-3,None), nPost=slice(0,3),nFin=slice(3,6),
                    seed=1, animal_plot=True):
    """
    Plot, for three periods around a lesion, the group median with a
    25th-75th percentile error bar of `TaskParamToPlot`.

    The periods are Before (`nPre`), Acute (`nPost`) and Stable (`nFin`).
    Optionally, each animal's Before -> Stable values are overlaid.

    Parameters
    ----------
    root : data root passed to the project helpers.
    ax : Axes to draw on.
    Profiles : (profile0, profile1) passed to `event_detect`.
    colorCode : dict lesion tag (from `lesion_type`) -> color; other tags are gray.
    badAnimals : animals excluded by `event_detect` (default: none).
    TaskParamToPlot : performance parameter passed to `animal_learning_stats`.
    x_pos : x positions of the three periods; by default just right of the
        current x-limit.
    nPre : slice of the pre-lesion sessions (Before).
    nPost, nFin : slices of the post-lesion sessions (Acute, Stable).
    seed : seed of the local RNG used for the (tiny) x-jitter of each animal.
    animal_plot : if True, draw each animal's Before and Stable values.

    Returns
    -------
    data : array (nAnimals, 3) of per-animal medians for the three periods.
    animalList : animals returned by `event_detect`, in the row order of `data`.
    """
    if badAnimals is None:
        badAnimals=[]

    if x_pos is None:
        diff=.35
        x_c=ax.get_xlim()[1]
        x_pos=(x_c-diff,x_c+diff,x_c+3*diff)
    diff=x_pos[1]-x_pos[0]

    #getting the data
    animalList,sessionDict=event_detect(root, Profiles[0], Profiles[1], badAnimals=badAnimals)

    data=np.full((len(animalList),3), np.nan)
    for i,animal in enumerate(animalList):
        preSession,postSession=sessionDict[animal][0], sessionDict[animal][1]
        out=animal_learning_stats(root, animal, PerfParam=TaskParamToPlot,
                                  goodSessions=[*preSession,*postSession],redo=False)
        nPreSessions=len(preSession)
        data[i,0]=np.nanmedian(out[:nPreSessions][nPre])
        data[i,1]=np.nanmedian(out[nPreSessions:][nPost])
        data[i,2]=np.nanmedian(out[nPreSessions:][nFin])

    y=np.nanpercentile(data,50,axis=0)
    yerr=np.nanpercentile(data,(25,75),axis=0)

    #plotting the errorbar: asymmetric bars from the median to the 25th/75th percentiles
    ax.errorbar(x_pos, y, abs(yerr-y),fmt='o', zorder=5, ms=2, elinewidth=1, color='k')

    #plotting individual animals (Before and Stable only)
    if animal_plot:
        # local RNG: reproducible jitter without resetting numpy's global random state
        rng=np.random.RandomState(seed)
        _coeff=50000
        for i in range(data.shape[0]):
            # only the plotted values (Before, Stable) are required; a missing
            # Acute value no longer hides the animal
            if np.any(np.isnan(data[i,[0,-1]])):
                continue
            jitPre =rng.uniform(low=x_pos[0]-diff/_coeff, high=x_pos[0]+diff/_coeff, size=1)
            jitFin =rng.uniform(low=x_pos[2]-diff/_coeff, high=x_pos[2]+diff/_coeff, size=1)

            _,tag=lesion_type(root,animalList[i])
            c=colorCode[tag] if tag in colorCode else 'gray'
            ax.scatter([jitPre,jitFin],data[i,[0,-1]], s=1, c=c, marker='o', edgecolors='none', alpha=.8)
            ax.plot([jitPre,jitFin],data[i,[0,-1]], c=c, lw=.2, alpha=.4)

    ax.set_xlim([x_pos[0]-diff/2,x_pos[-1]+diff/2])
    ax.set_xticks(x_pos)
    ax.set_xticklabels(['Before','Acute','Stable'])
    ax.xaxis.set_tick_params(rotation=0)
    ax.spines['bottom'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_ylabel(TaskParamToPlot)
    ax.set_xlabel('Sessions relative to lesion')

    return data,animalList

In [None]:
if "__file__" not in dir():

    profile1={'Type':'Good',
             'rewardType':'Progressive',
             'initialSpeed':['0','10'],
             'Speed':'10',
             'Tag':['Control', 'Control-AfterBreak', 'Control-Late-NoTimeout-BackToTimeout',
              'Control-NoTimeout-Control','Control-Sharp','IncReward-Late-Sharp',
              'Control-Sharp-AfterBreak','ImmobileTreadmill-Control']
             }
    profile2={'Type':'Good',
             'rewardType':'Progressive',
             'initialSpeed':['0','10'],
             'Speed':'10',
             'Tag':['Late-Lesion_DLS','Late-Lesion_DMS','Late-Lesion_DS',
                    'Late-Lesion_DS-Sharp','Late-Lesion_DMS-Sharp']
             }

    #animals to exclude from the analysis
    badAnimals=['RatBAD']
    TaskParamToPlot="% good trials"
    wspace=0.05
    
    
    Profiles=(profile1,profile2,)
    plt.close('all')
    fig=plt.figure(figsize=(5,4))
    ax=fig.add_subplot(111)

    ax.set_xticks([])
    # pass the exclusion list defined above (it used to be replaced by None)
    plot_event_1on1(root, ax, Profiles, colorCode={'Control':'r'}, badAnimals=badAnimals, TaskParamToPlot=TaskParamToPlot,
                       x_pos=None, nPre=slice(-3,None), nPost=slice(0,3),nFin=slice(3,6),
                        seed=1, animal_plot=True)

---

Correlate the post-lesion change of a behavioral parameter (e.g. running speed) with lesion size

In [None]:
def _late_lesion_effect(root, Profiles, badAnimals, TaskParamToPlot:str,
                        preSlice=slice(-5,None), postSlice=slice(0,5)):
    """
    Compute, per animal, the change (post - pre) of `TaskParamToPlot` around a late lesion.

    The change is the mean over the post-lesion sessions selected by
    `postSlice` minus the mean over the pre-lesion sessions selected by
    `preSlice`. Animals for which either mean is undefined (all NaN) are
    logged and skipped.

    Parameters
    ----------
    root : data root passed to the project helpers.
    Profiles : (profile0, profile1) passed to `event_detect`.
    badAnimals : iterable of animals to exclude, or None to exclude none.
    TaskParamToPlot : name of the parameter computed by `event_statistic2`.
    preSlice, postSlice : session slices relative to the pre- and post-lesion
        session lists respectively.

    Returns
    -------
    behav : list of post - pre differences, one per retained animal.
    animals : the retained animal names, in the same order as `behav`.
    """
    if badAnimals is None:
        badAnimals=[]

    _,sessionDict=event_detect(root, Profiles[0], Profiles[1], badAnimals=badAnimals)
    for key in badAnimals:
        sessionDict.pop(key, None)

    Results,nSessionPre,nSessionPost=event_statistic2(root,
                                                 sessionDict,
                                                 parameter=param,
                                                 redo=False,
                                                 TaskParamToPlot=[TaskParamToPlot])

    data=np.array(list(Results[TaskParamToPlot].values()))
    slicedPreData  = np.full(data.shape, np.nan)
    slicedPostData = np.full(data.shape, np.nan)

    # first nSessionPre columns are pre-lesion sessions, the rest post-lesion
    for row,Dnan in enumerate(data):
        preData =Dnan[:nSessionPre][preSlice]
        postData=Dnan[nSessionPre:][postSlice]

        slicedPreData[row,:len(preData)]=preData
        slicedPostData[row,:len(postData)]=postData

    yPre =np.nanmean(slicedPreData ,axis=1)
    yPost=np.nanmean(slicedPostData,axis=1)

    behav=[]
    animals=[]
    for i,animal in enumerate(Results[TaskParamToPlot].keys()):
        if np.isnan(yPre[i]) or np.isnan(yPost[i]):
            logging.error(f'{animal}: {TaskParamToPlot} not defined!')
            continue

        behav.append(yPost[i]-yPre[i])
        animals.append(animal)

    return behav, animals

def late_lesion_correlation_with_size(root, ax, Profiles, Animals, color, TaskParamToPlot:str,
                                      preSlice=slice(-5,None), postSlice=slice(0,5), Excluded='RatBAD'):
    """
    Relate each animal's post-lesion change of `TaskParamToPlot` to its lesion size.

    The change is the mean over the post-lesion sessions selected by
    `postSlice` minus the mean over the pre-lesion sessions selected by
    `preSlice`. Lesion size is `HistologyExcel('/NAS02', animal).lesion_size()`.

    The following animals are logged and skipped:
    - animals not matching `Profiles`;
    - animals with an undefined change;
    - animals whose histology cannot be read.

    Parameters
    ----------
    root : data root passed to the project helpers.
    ax : Axes for the scatter plot, or None to only compute.
    Profiles : (profile0, profile1) passed to `event_detect`.
    Animals : animals to analyse.
    color : dict lesion tag (from `lesion_type`) -> color; other tags are gray.
    TaskParamToPlot : name of the parameter computed by `event_statistic2`.
    preSlice, postSlice : session slices relative to the pre-/post-lesion sessions.
    Excluded : animal left out of the scatter plot (still returned).

    Returns
    -------
    behav, size, animals : parallel lists (change, lesion size, animal name).
    """
    _,sessionDict=event_detect(root, Profiles[0], Profiles[1])

    selected={}
    for animal in Animals:
        if animal not in sessionDict:
            logging.warning(f"{animal}: doesn't match the Profiles, skipped!")
            continue
        selected[animal]=sessionDict[animal]
    sessionDict=selected

    Results,nSessionPre,nSessionPost=event_statistic2(root,
                                                 sessionDict,
                                                 parameter=param,
                                                 redo=False,
                                                 TaskParamToPlot=[TaskParamToPlot])

    data=np.array(list(Results[TaskParamToPlot].values()))
    slicedPreData  = np.full(data.shape, np.nan)
    slicedPostData = np.full(data.shape, np.nan)

    # first nSessionPre columns are pre-lesion sessions, the rest post-lesion
    for row,Dnan in enumerate(data):
        preData =Dnan[:nSessionPre][preSlice]
        postData=Dnan[nSessionPre:][postSlice]

        slicedPreData[row,:len(preData)]=preData
        slicedPostData[row,:len(postData)]=postData

    yPre =np.nanmean(slicedPreData ,axis=1)
    yPost=np.nanmean(slicedPostData,axis=1)

    size=[]
    behav=[]
    animals=[]
    for i,animal in enumerate(Results[TaskParamToPlot].keys()):
        if np.isnan(yPre[i]) or np.isnan(yPost[i]):
            logging.error(f'{animal}: {TaskParamToPlot} not defined!')
            continue
        # lesion_size() is inside the try too: a bad histology record skips the
        # animal instead of aborting the whole analysis
        try:
            lesionSize=HistologyExcel('/NAS02',animal).lesion_size()
        except Exception as e:
            logging.error(f'{animal}: {repr(e)}')
            continue

        behav.append(yPost[i]-yPre[i])
        size.append(lesionSize)
        animals.append(animal)

    #plotting
    if ax is not None:
        for i,animal in enumerate(animals):
            if animal == Excluded: continue
            _,tag=lesion_type(root,animal)
            c=color[tag] if tag in color else 'gray'
            ax.scatter(size[i], behav[i], s=5, c=c)

        ax.set_xticks(np.arange(0,1.01,.1))
        ax.set_xticklabels(['0','','','','','0.5','','','','','1'])
        ax.spines['bottom'].set_bounds(0,1)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.set_xlim([-.02,1.02])
        ax.set_ylabel(rf'$\Delta$ {TaskParamToPlot}')
        ax.set_xlabel('Lesion size')

    return behav, size, animals

In [None]:
if "__file__" not in dir():
    # profile1: control tags, profile2: late-lesion tags (same task settings)
    profile1=dict(Type='Good',
                  rewardType='Progressive',
                  initialSpeed=['0','10'],
                  Speed='10',
                  Tag=['Control', 'Control-AfterBreak', 'Control-Late-NoTimeout-BackToTimeout',
                       'Control-NoTimeout-Control','Control-Sharp','IncReward-Late-Sharp',
                       'Control-Sharp-AfterBreak','ImmobileTreadmill-Control'])
    profile2=dict(Type='Good',
                  rewardType='Progressive',
                  initialSpeed=['0','10'],
                  Speed='10',
                  Tag=['Late-Lesion_DLS','Late-Lesion_DMS','Late-Lesion_DS',
                       'Late-Lesion_DS-Sharp','Late-Lesion_DMS-Sharp'])
    Profiles=(profile1,profile2)

    TaskParamToPlot="Forward Running Speed"
    preSlice, postSlice = slice(-5,None), slice(3,8)
    color=dict(DS='r', DMS='g', DLS='b', Control='k')

    plt.close('all')
    ax=plt.subplot(111)

    # Example call (disabled). NOTE: it omits the required `Animals` argument
    # of late_lesion_correlation_with_size and would need it to run:
    # a=late_lesion_correlation_with_size(root, ax=ax, Profiles=Profiles,color=color, TaskParamToPlot=TaskParamToPlot,
    #                                   preSlice=preSlice, postSlice=postSlice)

---

Max Pos Comparison for Animals with Speed Effect

In [None]:
def plot_event_with_animalList(root, ax, Profiles, colorCode, animalList, TaskParamToPlot="Maximum Position",
                   x_pos=None, nPre=slice(-3,None), nPost=slice(0,3),nFin=slice(3,6),
                    seed=1, animal_plot=True):
    """
    Plot, for an explicit list of animals, the group median with a 25th-75th
    percentile error bar of `TaskParamToPlot` in three periods around a lesion.

    The periods are Before (`nPre` of the pre-lesion sessions), Acute
    (`nPost`) and Stable (`nFin`), the latter two being slices of the
    post-lesion sessions. Optionally, each animal's Before -> Stable values
    are overlaid, colored by lesion tag via `colorCode` (gray for unknown tags).

    Animals of `animalList` not returned by `event_detect` for `Profiles` are
    logged and kept as NaN rows.

    Returns
    -------
    data : array (len(animalList), 3) of per-animal medians, rows in `animalList` order.
    """
    if x_pos is None:
        diff=.35
        x_c=ax.get_xlim()[1]
        x_pos=(x_c-diff,x_c+diff,x_c+3*diff)
    diff=x_pos[1]-x_pos[0]

    #getting the data
    _,sessionDict=event_detect(root, Profiles[0], Profiles[1])

    data=np.full((len(animalList),3), np.nan)
    for i,animal in enumerate(animalList):
        if animal not in sessionDict:
            logging.warning(f"{animal}: doesn't match the Profiles!")
            continue
        preSession,postSession=sessionDict[animal][0], sessionDict[animal][1]
        out=animal_learning_stats(root, animal, PerfParam=TaskParamToPlot,
                                  goodSessions=[*preSession,*postSession],redo=False)
        nPreSessions=len(preSession)
        data[i,0]=np.nanmedian(out[:nPreSessions][nPre])
        data[i,1]=np.nanmedian(out[nPreSessions:][nPost])
        data[i,2]=np.nanmedian(out[nPreSessions:][nFin])

    y=np.nanpercentile(data,50,axis=0)
    yerr=np.nanpercentile(data,(25,75),axis=0)

    #plotting the errorbar: asymmetric bars from the median to the 25th/75th percentiles
    ax.errorbar(x_pos, y, abs(yerr-y),fmt='o', zorder=5, ms=2, elinewidth=1, color='k')

    #plotting individual animals (Before and Stable only)
    if animal_plot:
        # local RNG: reproducible jitter without resetting numpy's global random state
        rng=np.random.RandomState(seed)
        _coeff=50000
        for i in range(data.shape[0]):
            if np.any(np.isnan(data[i,[0,-1]])):
                continue
            jitPre =rng.uniform(low=x_pos[0]-diff/_coeff, high=x_pos[0]+diff/_coeff, size=1)
            jitFin =rng.uniform(low=x_pos[2]-diff/_coeff, high=x_pos[2]+diff/_coeff, size=1)

            _,tag=lesion_type(root,animalList[i])
            c=colorCode[tag] if tag in colorCode else 'gray'
            ax.scatter([jitPre,jitFin],data[i,[0,-1]], s=1, c=c, marker='o', edgecolors='none', alpha=.8)
            ax.plot([jitPre,jitFin],data[i,[0,-1]], c=c, lw=.2, alpha=.4)

    ax.set_xlim([x_pos[0]-diff/2,x_pos[-1]+diff/2])
    ax.set_xticks(x_pos)
    ax.set_xticklabels(['Before','Acute','Stable'])
    ax.xaxis.set_tick_params(rotation=0)
    ax.spines['bottom'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_ylabel(TaskParamToPlot)
    ax.set_xlabel('Sessions relative to lesion')

    return data

In [None]:
if "__file__" not in dir():

    profile1={'Type':'Good',
             'rewardType':'Progressive',
             'initialSpeed':['0','10'],
             'Speed':'10',
             'Tag':['Control', 'Control-AfterBreak', 'Control-Late-NoTimeout-BackToTimeout',
              'Control-NoTimeout-Control','Control-Sharp','IncReward-Late-Sharp',
              'Control-Sharp-AfterBreak','ImmobileTreadmill-Control']
             }
    profile2={'Type':'Good',
             'rewardType':'Progressive',
             'initialSpeed':['0','10'],
             'Speed':'10',
             'Tag':['Late-Lesion_DLS','Late-Lesion_DMS','Late-Lesion_DS',
                    'Late-Lesion_DS-Sharp','Late-Lesion_DMS-Sharp']
             }

    Profiles=(profile1,profile2,)
    # ColorCode is defined in part 2 (further down); fall back to a local map
    # when this cell is run before part 2
    colors=ColorCode if 'ColorCode' in dir() else {'DS':'r','DMS':'g','DLS':'b','Control':'k'}
    # late_lesion_correlation_with_size requires the list of animals explicitly
    allAnimals,_=event_detect(root, Profiles[0], Profiles[1])

    plt.close('all')
    fig=plt.figure(figsize=(5,4))
    ax=fig.add_subplot(111)

    behav, size, animals=late_lesion_correlation_with_size(root, ax=ax, Profiles=Profiles, Animals=allAnimals,
                                                           color=colors,
                                                           TaskParamToPlot="Forward Running Speed")

    plt.close('all')
    fig=plt.figure(figsize=(5,4))
    ax=fig.add_subplot(111)

    # animals whose running speed decreased after the lesion
    animalList=[animal for animal,dSpeed in zip(animals,behav) if dSpeed<0]
    # plot_event_with_animalList returns a single (nAnimals, 3) array
    d=plot_event_with_animalList(root, ax, Profiles, colorCode={'Control':'r'}, animalList=animalList,
                                 nPre=slice(-3,None), nPost=slice(0,3),nFin=slice(3,6))


---

Plot the trajectories of an example session

In [None]:
def plot_session_median_trajectory(data,ax,trialToPlot):
    """
    Plot the across-trial average position trajectory of a session.

    NOTE: despite the function name, the curve is the *mean* over trials
    (`np.nanmean`), not the median.

    Parameters
    ----------
    data : session object; uses `position` (trial -> position array),
        `stopFrame` (trial -> last frame index), `cameraToTreadmillDelay`
        and `cameraSamplingRate`.
    ax : Axes to draw on.
    trialToPlot : trials (keys of `data.position`) to include.

    Returns
    -------
    1-D array of the mean position per frame. It is truncated at the first
    frame where more than 30% of the trials have no data, or covers every
    frame if there is no such frame.
    """
    posDict=data.position
    maxL=int(np.nanmax(list(data.stopFrame.values())))
    position=np.full((maxL,len(trialToPlot)), np.nan)
    time=np.arange(-data.cameraToTreadmillDelay,
                   (maxL-data.cameraSamplingRate)/data.cameraSamplingRate,
                   1/data.cameraSamplingRate)

    col=0
    for trial in posDict:
        if trial not in trialToPlot:
            continue
        # int(): stopFrame values may be stored as floats
        pos=posDict[trial][:int(data.stopFrame[trial])]
        position[:len(pos),col]=pos
        col+=1

    #keeping data where 70% of points exist
    nanCount=np.sum(np.isnan(position),axis=1)
    tooSparse=np.flatnonzero(nanCount>.3*position.shape[1])
    # without any too-sparse frame keep all frames (axis 0 = frames, not trials)
    maxTraj=tooSparse[0] if tooSparse.size else position.shape[0]

    meanTraj=np.nanmean(position,axis=1)[:maxTraj]
    # `time` may be shorter than the frame axis (it depends on the delay)
    nPlot=min(maxTraj,len(time))
    ax.plot(time[:nPlot], meanTraj[:nPlot], color='midnightblue', lw=2, zorder=10)

    return meanTraj

def plot_trajectories(data,ax):
    """
    Draw the position trajectory of every trial of a session on `ax`.

    Colors and layering:
    - trials listed in `hardR` (second output of
      `sequentialTrials(data)._compute_max_pos()`) are drawn in steelblue;
    - all other trials are drawn in faint salmon behind them.

    Each trial is cut after its entrance time plus the camera-to-treadmill
    delay, both converted to frames. The `hardR` trial at the middle rank of
    `peaks` (third output, assumed aligned with `hardR` — TODO confirm) is
    redrawn on top in midnightblue.

    Returns
    -------
    colors : np.ndarray of the color used for each trial, in `data.position` order.
    midTraj : the plotted positions of that middle-ranked trial.
    """
    trajectories=data.position
    timeByTrial=data.timeTreadmill #align on camera

    _,hardR,peaks=sequentialTrials(data)._compute_max_pos()

    def _end_frame(trial):
        # frames up to the entrance time, plus the camera->treadmill delay, plus one
        rate=data.cameraSamplingRate
        return int(rate*data.entranceTime[trial]) + int(rate*data.cameraToTreadmillDelay) + 1

    usedColors=[]
    for trial in trajectories:
        inHard = trial in hardR
        lineColor = "steelblue" if inHard else "salmon"
        usedColors.append(lineColor)
        end=_end_frame(trial)
        ax.plot(timeByTrial[trial][:end], trajectories[trial][:end],
                color=lineColor, lw=.5,
                zorder=1 if inHard else 0,
                alpha=1 if inHard else .2)

    middleRank=np.argsort(peaks)[len(peaks)//2]
    midTrial=hardR[middleRank]
    end=_end_frame(midTrial)
    midTraj=trajectories[midTrial][:end]
    ax.plot(timeByTrial[midTrial][:end], midTraj, color='midnightblue', lw=1, zorder=10)

    return np.array(usedColors),midTraj



def plot_trajectories_and_distributions(root, ax, session):
    """
    Load `session` and draw all its trial trajectories on a bare `ax`.

    The animal name is taken as the first 6 characters of `session`
    (e.g. 'Rat302' for 'Rat302_2018_12_12_14_27'). The axes are limited to
    x in [0, 10.5] and y in [0, 90], with every spine, tick and tick label
    hidden.

    Returns the trajectory of the middle-ranked trial from `plot_trajectories`.
    """
    sessionData=Data(root,session[:6],session,redoPreprocess=False)
    _,midTraj=plot_trajectories(sessionData,ax=ax)

    ax.set_xlim([0,10.5])
    ax.set_ylim([0,90])
    for side in ('top','right','bottom','left'):
        ax.spines[side].set_visible(False)
    ax.spines['bottom'].set_bounds(0,10)

    ax.tick_params(bottom=False, top=False, left=False, right=False,
                   labelbottom=False, labeltop=False, labelleft=False, labelright=False)

    return midTraj

In [None]:
if "__file__" not in dir():
    # example session (alternative: 'Rat302_2019_01_14_14_27')
    session='Rat302_2018_12_12_14_27'

    plt.close('all')
    ax=plt.figure(figsize=(3,3)).add_subplot(111)
    plot_trajectories_and_distributions(root, ax, session)

---

Plot the predefined histology image

In [None]:
def plot_animal_image(ax, animal):
    """
    Show the predefined histology image of `animal` on `ax`.

    The image path is chosen from a hard-coded list as the first path that
    contains `animal` as a substring. If none matches, an error is logged and
    nothing is drawn. The image is downscaled to fit within 500x500 pixels.

    Returns None.
    """
    PATHS=('/NAS02/Rat302/Histology/DS_45.jpg',)

    filePath=next((path for path in PATHS if animal in path), None)
    if filePath is None:
        logging.error(f'Bad Animal name ({animal}), path not defined!')
        return

    # `import PIL` alone does not guarantee that the PIL.Image submodule is loaded
    from PIL import Image

    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
    # The `with` block closes the file once the pixels are copied into an array.
    with Image.open(filePath) as img:
        img.thumbnail((500,500),Image.LANCZOS)
        ax.imshow(np.asarray(img))

---

Plot the normalized group time course around the lesion

In [None]:
def normalize_performance(a,nPre):
    """
    Express each animal's values relative to its own baseline.

    a : np.ndarray, animals x sessions.
    nPre : the first `nPre` sessions (columns) form the baseline.

    Returns `a` minus, row by row, the NaN-ignoring *median* of its baseline
    columns.
    """
    baseline=np.nanmedian(a[:,:nPre],axis=1)
    return a-baseline[:,np.newaxis]

def plot_normalized_time_course(root, ax, Profiles, Animals=None, TaskParamToPlot='% good trials', shift=0,
                   nPre=5, nPost=15, **kwarg):
    """
    Plot the group time course of `TaskParamToPlot` around the lesion,
    normalized to each animal's pre-lesion baseline.

    Processing steps:
    1. Each animal is normalized by subtracting the median of its `nPre`
       pre-lesion sessions.
    2. Animals with NaN in at least half of the plotted sessions are dropped.
    3. The group median, with a 25th-75th percentile error bar, is drawn at
       session numbers -nPre..-1 and 1..nPost.

    Parameters
    ----------
    root : data root passed to the project helpers.
    ax : Axes to draw on.
    Profiles : (profile0, profile1) passed to `event_detect`.
    Animals : animals to include (default: all returned by `event_detect`);
        animals not matching `Profiles` are logged and skipped.
    TaskParamToPlot : parameter computed by `event_statistic2`.
    shift : horizontal offset of the curve (e.g. to separate several groups).
    nPre, nPost : number of sessions to plot before / after the lesion.
    **kwarg : forwarded to `ax.errorbar` (e.g. color, label).

    Returns
    -------
    data : array (nReliableAnimals, nPre+nPost) of normalized values.
    animalList : array of the corresponding animal names.
    """
    #getting the data
    animalList,SessionDict=event_detect(root, Profiles[0], Profiles[1])

    if Animals is None:
        Animals=animalList

    selected={}
    for animal in Animals:
        if animal not in SessionDict:
            logging.warning(f"{animal}: doesn't match the Profiles, skipped!")
            continue
        selected[animal]=SessionDict[animal]
    SessionDict=selected

    Results,nSessionPre,nSessionPost=event_statistic2(root,
                                                      SessionDict,
                                                      parameter=param,
                                                      redo=False,
                                                      TaskParamToPlot=[TaskParamToPlot])

    assert nPre<=nSessionPre and nPost<=nSessionPost, \
        (f"fewer sessions available than requested: nPre={nPre} (available {nSessionPre}), "
         f"nPost={nPost} (available {nSessionPost})")

    data=np.array(list(Results[TaskParamToPlot].values()))

    # session numbers skip 0 (the lesion): -nPre..-1 then 1..nPost
    xData=np.append(np.arange(-nPre,0),np.arange(1,nPost+1))
    data=data[:,nSessionPre-nPre:nSessionPre+nPost]

    data=normalize_performance(data,nPre)

    # keep animals with data in more than half of the plotted sessions
    reliableAnimals=np.sum(np.isnan(data),axis=1) < data.shape[1] / 2

    data=data[reliableAnimals,:]
    animalList=np.array(list(Results[TaskParamToPlot].keys()))[reliableAnimals]

    groupData=np.nanmedian(data,axis=0)
    groupErr=np.nanpercentile(data,(25,75),axis=0)

    #plotting the errorbar
    ax.errorbar(xData + shift, groupData,
                abs(groupErr-groupData),
                fmt='-o', zorder=5, ms=2, elinewidth=1,alpha=.8, **kwarg)

    ax.axhline(0, ls=':',c='gray',lw=1)

    def _tik(x,pos):
        # label only the first/last sessions on each side of the lesion
        if x in [-nPre,-1,1,nPost]:
            return ('$%+g$' if x>0 else '$%g$')%x
        else:
            return ''
    ax.xaxis.set_major_formatter(mpl.ticker.FuncFormatter(_tik))

    ax.set_xlim([-nPre-.3,nPost+.3])
    ax.spines['bottom'].set_bounds(-nPre,nPost)
    ax.set_xticks(xData)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_ylabel('Norm. '+TaskParamToPlot)
    ax.set_xlabel('Session # relative to lesion')

    return data,animalList

def add_legend_for_rat_groups(ax,N):
    """
    Add a two-entry legend for rat groups split by the sign of their speed change.

    Parameters
    ----------
    ax : Axes receiving the legend.
    N : (nPositive, nNegative): number of rats with ΔSpeed > 0 and ΔSpeed < 0.

    Returns the Legend. The black ΔSpeed<0 entry (N[1]) comes first, then the
    gray ΔSpeed>0 entry (N[0]). Entries are laid out on 2 columns above the axes.
    """
    # '\\Delta' (escaped) gives the same label text as before, without the
    # invalid-escape warning that '\D' raises in recent Python versions
    g_marker = matplotlib.lines.Line2D([0,0], [1,1], color='gray',lw=1,
                                       markeredgecolor='gray', marker='o',
                                       markerfacecolor='gray', markersize=2,
                                       label=f'$\\Delta$Speed$>0$\n($n={N[0]}$)')
    b_marker = matplotlib.lines.Line2D([0,0], [1,1], color='k',
                                       markeredgecolor='k', marker='o',lw=1,
                                       markerfacecolor='k', markersize=2,
                                       label=f'$\\Delta$Speed$<0$\n($n={N[1]}$)')

    leg=ax.legend(handles=[b_marker,g_marker],loc=(0,.9),ncol=2,
                  facecolor=None,edgecolor=None, fontsize=5,frameon=False)
    return leg

In [None]:
if "__file__" not in dir():
    # profile1: control tags, profile2: late-lesion tags
    profile1=dict(Type='Good',
                  rewardType='Progressive',
                  initialSpeed=['10','0'],
                  Speed='10',
                  Tag=['Control', 'Control-AfterBreak', 'Control-Late-NoTimeout-BackToTimeout', 'Control-NoTimeout-Control',
                       'Control-Sharp','IncReward-Late-Sharp','Control-Sharp-AfterBreak','ImmobileTreadmill-Control'])
    profile2=dict(Type=['Good'],
                  rewardType=['Progressive'],
                  option=['not used', 'AsymmetricLesion'],
                  initialSpeed=['0', '10'],
                  Speed=['10'],
                  Tag=['Late-Lesion_DLS',
                       'Late-Lesion_DMS-Sharp',
                       'Late-Lesion_DMS',
                       'Late-Lesion_DS',
                       'Late-Lesion_DS-Sharp'])

    animalList,_=event_detect(root,profile1,profile2)
    TaskParamToPlot="Forward Running Speed"
    wspace=0.05
    Profiles=(profile1,profile2,)

    plt.close('all')
    fig=plt.figure(figsize=(5,4))
    ax=fig.add_subplot(111)

    # example time-course call (disabled):
    # a=plot_normalized_time_course(root, ax, Profiles,Animals=animalList,TaskParamToPlot=TaskParamToPlot,color='k',shift=-.1*.25)
    add_legend_for_rat_groups(ax,(5,1))

---

plotting group lesion size

In [None]:
def plot_group_lesion_bar(animalList, ax=None, x=0, color='k'):
    """
    Collect the lesion sizes of `animalList` and optionally draw them as a bar.

    Sizes come from `HistologyExcel('/NAS02', animal).lesion_size()`. Animals
    whose histology cannot be read get NaN, and a warning is logged.

    When `ax` is given, a bar at `x` shows the group median (NaNs ignored).
    Individual animals are scattered on top with a reproducible horizontal
    jitter of ±0.4 around `x`.

    Returns
    -------
    np.ndarray of lesion sizes (NaN where unavailable), in `animalList` order.
    """
    import logging

    _W=.8   #bar width
    lesionsize=[]
    for animal in animalList:
        try:
            lesionsize.append(HistologyExcel('/NAS02',animal).lesion_size())
        except Exception as e:
            # best effort: a missing/invalid record gives NaN instead of aborting
            logging.warning(f'{animal}: lesion size unavailable ({repr(e)})')
            lesionsize.append(np.nan)

    if ax is not None:
        ax.bar(x,np.nanmedian(lesionsize), ec=None, fc=color, zorder=1)

        # seed a local RNG *before* drawing, so the jitter is reproducible
        # (seeding the global RNG after the draw had no effect on it)
        rng=np.random.RandomState(3)
        _coeff=2
        x_vals=rng.uniform(low=x-_W/_coeff, high=x+_W/_coeff, size=len(animalList))
        ax.scatter(x_vals, lesionsize, s=3, c='xkcd:red', marker='o', edgecolors='none', alpha=.8, zorder=2)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['bottom'].set_visible(False)
        ax.set_xlim([x-_W,x+_W])
        ax.set_xticks([x])

    return np.array(lesionsize)

In [None]:
if "__file__" not in dir():
    plt.close('all')
    fig=plt.figure(figsize=(5,4))
    ax=fig.add_subplot(111)

    # demo: lesion-size bar of three example animals (defaults spelled out)
    plot_group_lesion_bar(['Rat304','Rat302','Rat079'], ax=ax, x=0, color='k')

---

Stef: Model

In [None]:
def penalty(x,k,mu,amp):
    '''
    Smooth (logistic) approximation of a downward step centred on `mu`.

    Returns amp / (1 + exp(2*k*(x - mu))). For k > 0 this is:
    - ~`amp` for x well below `mu`;
    - amp/2 at x == mu;
    - ~0 for x well above `mu`.
    `k` sets the steepness; as k -> inf it tends to amp * H(mu - x), with H
    the Heaviside step.

    Evaluated as amp * expit(-2*k*(x - mu)), which is mathematically identical
    but does not overflow in exp() for large |2*k*(x - mu)|.
    '''
    from scipy.special import expit
    return amp*expit(-2*k*(x-mu))


def compute_cost_trajectories(trajectory, speed, force, a, b, cost_type, 
                              vtapis, Ltread, xb, 
                              kxpenalty, ampxpenalty):
    
    n_trajectoires = trajectory.shape[1]
    inst_cost = np.zeros(trajectory.shape)
    
    if cost_type == 'speed_quadratic':
        for i_t in range(n_trajectoires):
            inst_cost[:,i_t] = a*(speed[i_t]-vtapis)**2 + b*(trajectory[i_t] - Ltread)**2
    
    elif cost_type == 'force_quadratic':   
        for i_t in range(n_trajectoires):
            inst_cost[:,i_t] = a*(force[i_t]-vtapis)**2 + b*(trajectory[i_t] - Ltread)**2
        
    elif cost_type == 'speed_heaviside':
        for i_t in range(n_trajectoires):
            inst_cost[:,i_t] = a*(speed[:,i_t]-vtapis)**2 
            + b*penalty(x=trajectory[:,i_t], k=kxpenalty, mu=xb, amp=ampxpenalty)       
    
    elif cost_type == 'force_heaviside':   
        for i_t in range(n_trajectoires):
            inst_cost[:,i_t] = a*(force[:,i_t]-vtapis)**2 
            + b*penalty(x=trajectory[i_t], k=kxpenalty, mu=xb, amp=ampxpenalty)
    return inst_cost

------



------

# Part 2:

# GENERATING THE FIGURE

Definition of Parameters

In [None]:
if "__file__" not in dir():
    # GENERAL PARAMS
    
    CtrlColor='gray'
    DLSColor='xkcd:orange'
    DMSColor='purple'
    DSColor='xkcd:green'
    
    # Lesion group -> plotting colour
    ColorCode={'DS':DSColor,
               'DMS':DMSColor,
               'DLS':DLSColor,
               'Control':CtrlColor
              }
    
    BadLateRats=('Rat223','Rat231')
    
    #===============================================
    
    # GRID 1 PARAMS
    
    TaskParamToPlot1="Forward Running Speed"
    
    # Control / pre-lesion sessions
    profilePreLesion1={'Type':'Good',
              'rewardType':'Progressive',
              'initialSpeed':['10','0'],
              'Speed':'10',
              'Tag':['Control', 'Control-AfterBreak', 'Control-Late-NoTimeout-BackToTimeout', 'Control-NoTimeout-Control',
                     'Control-Sharp','IncReward-Late-Sharp','Control-Sharp-AfterBreak','ImmobileTreadmill-Control']
              }
    
    # Late-lesion sessions, one profile per lesion target
    profileDLS1={'Type':'Good',
              'rewardType':'Progressive',
              'option':['not used', 'AsymmetricLesion'],
              'initialSpeed':['10','0'],
              'Speed':'10',
              'Tag':['Late-Lesion_DLS']
                 }
    profileDMS1={'Type':'Good',
              'rewardType':'Progressive',
              'option':['not used', 'AsymmetricLesion'],
              'initialSpeed':['10','0'],
              'Speed':'10',
              'Tag':['Late-Lesion_DMS','Late-Lesion_DMS-Sharp'],
                 }  
    profileDS1={'Type':'Good',
              'rewardType':'Progressive',
              'option':['not used', 'AsymmetricLesion'],
              'initialSpeed':['10','0'],
              'Speed':'10',
              'Tag':['Late-Lesion_DS','Late-Lesion_DS-Sharp']
                 }  

    # (pre-lesion profile, lesion profile) pairs
    ProfilesDLS1=(profilePreLesion1,profileDLS1,)
    ProfilesDMS1=(profilePreLesion1,profileDMS1,)
    ProfilesDS1 =(profilePreLesion1,profileDS1,)
    
    #================================================
    
    # GRID 2 PARAMS
    
    # Example sessions of the same animal (Rat302) before and after the lesion
    sessionBefore2='Rat302_2018_12_12_14_27'
    sessionAfter2 ='Rat302_2019_01_14_14_27'
    
    #================================================
    
    # GRID 3 PARAMS

    profileLesions3={'Type':'Good',
                     'rewardType':'Progressive',
                     'option':['not used', 'AsymmetricLesion'],
                     'initialSpeed':['0','10'],
                     'Speed':'10',
                     # Union of the DLS, DMS and DS lesion tags. dict.fromkeys
                     # de-duplicates while keeping first-seen order, so the list
                     # is identical on every run; list(set(...)) order depends
                     # on string-hash randomisation.
                     'Tag': list(dict.fromkeys((*profileDLS1['Tag'], *profileDMS1['Tag'], *profileDS1['Tag'])))
                     }

    Profiles3=(profilePreLesion1,profileLesions3)
    
    # Session slices passed to the late-lesion analyses
    # (presumably indexing pre-lesion and post-lesion session lists respectively).
    preSlice3=slice(-5,None)
    postSlice3=slice(0,3)
    finSlice3=slice(8,13)
    
    minSpdReduction3=0
    
    # NOTE(review): "Relxed" looks like a typo of "Relaxed", but this string is
    # presumably a lookup key for plotting helpers defined elsewhere; only change
    # it together with them.
    TaskParamToPlot3="Maximum Position Relxed"
    
    #================================================
    
    # GRID 4 PARAMS
    
    



Plotting the figure

In [None]:
if "__file__" not in dir():
    plt.close('all')
    set_rc_params({'axes.labelsize':'x-small'})
    figsize=(5,4)
    fig=plt.figure(figsize=figsize,dpi=600)
    
        
##=====================================================================
##===================STEF's CODE BELOW=================================
##=====================================================================
    baseFolder=os.path.dirname(currentNbPath)
    # ---- Panels A and B: precomputed model simulations --------------------
    simulations_results_folder = baseFolder+'/PickleResults/Simulations'
    diffuse_heaviside_effort_file = simulations_results_folder + '/DiffuseHeavisideSpatialCostEffort.pickle'
    file_to_open = diffuse_heaviside_effort_file
    with open(file_to_open, 'rb') as handle:
        results_simulations = pickle.load(handle)

    # Column j of each array is one simulated trajectory (they are indexed that
    # way below); tk is the time axis taken from the first column.
    xk_vect_speed_heav = results_simulations['xk_vect_speed_heav']
    tk = results_simulations['tk_vect_speed_heav'][:,0]
    xdotk_vect_speed_heav = results_simulations['xdotk_vect_speed_heav']
    acck_vect_speed_heav = results_simulations['acck_vect_speed_heav']

    # ---- Panels C and D: average trajectories and lesion sizes ------------
    data_folder = baseFolder+'/Data_for_fit/Before/'
    file_to_open = data_folder + 'clean_max_pos_traj.pickle'
    with open(file_to_open, 'rb') as handle:
        trajectories_before = pickle.load(handle)

    data_folder = baseFolder+'/Data_for_fit/Final/'
    file_to_open = data_folder + 'clean_max_pos_traj.pickle'
    with open(file_to_open, 'rb') as handle:
        trajectories_final = pickle.load(handle)
    file_to_open = data_folder + '_lesion_size_.p'
    with open(file_to_open, 'rb') as handle:
        lesion_size_final = pickle.load(handle)

    # ---- Fitted effort parameters and fitted trajectories -----------------
    fit_results_folder = baseFolder+'/PickleResults/Fit/'
    # Before lesion: effort parameter and fitted trajectory
    with open(fit_results_folder + 'a_fit_bounded_beforeXX.pickle', 'rb') as handle:
        results_a_bounded_before = pickle.load(handle)
    with open(fit_results_folder + 'xk_fit_bounded_beforeXX.pickle', 'rb') as handle:
        results_xk_bounded_before = pickle.load(handle)
    # After lesion: effort parameter and fitted trajectory
    with open(fit_results_folder + 'a_fit_bounded_finalXX.pickle', 'rb') as handle:
        results_a_bounded_final = pickle.load(handle)
    with open(fit_results_folder + 'xk_fit_bounded_finalXX.pickle', 'rb') as handle:
        results_xk_bounded_final = pickle.load(handle)
    # Per-animal change of the effort parameter (plotted against lesion size in panel D)
    with open(fit_results_folder + 'delta_a_boundedXX.pickle', 'rb') as handle:
        delta_a_median_bounded_dict = pickle.load(handle)

    # ---- Cost of the simulated trajectories (Panel B, right) --------------
    a = 10
    b = 1
    vtapis = 0.1
    Ltread = 0.9
    xb = 0.1
    kxpenalty = 1
    ampxpenalty = 10
    aux_a = [.1,1,2,5,10,100]
    NTRAJECTORIES = len(aux_a)

    # Panel A shows all NTRAJECTORIES columns. The cost (and Panel B) uses only
    # the first NTRAJECTORIES-1, i.e. the last one (alpha=100) is dropped.
    inst_cost_speed_heav=compute_cost_trajectories(trajectory=xk_vect_speed_heav[:,0:NTRAJECTORIES-1], 
                            speed=xdotk_vect_speed_heav[:,0:NTRAJECTORIES-1], 
                            force=acck_vect_speed_heav[:,0:NTRAJECTORIES-1], 
                            a=a, b=b, cost_type='speed_heaviside', 
                            vtapis=vtapis, Ltread=Ltread, xb=xb,
                            kxpenalty=kxpenalty, ampxpenalty=ampxpenalty)

    cumulative_cost_speed_heav = np.cumsum(inst_cost_speed_heav, axis=0)

    # ---- Example animal for Panel C ----------------------------------------
    example_rat = 'Rat302'
    session_before = -1
    session_final = 1
    shift_time = 1   # seconds removed from the start of the recorded trajectories
    dt = 0.04        # sampling interval (s) of the trajectories

    a_opt_er_before = results_a_bounded_before[example_rat][session_before]
    xk_opt_er_before = results_xk_bounded_before[example_rat][session_before]
    a_opt_er_final = results_a_bounded_final[example_rat][session_final]
    xk_opt_er_final = results_xk_bounded_final[example_rat][session_final]

    # *0.01 rescales to the units of the fitted trajectories (plotted on 0-0.9,
    # labelled 0-90 cm).
    trajectory_er_before = trajectories_before[example_rat][session_before][int(shift_time/dt):]*0.01
    trajectory_er_final = trajectories_final[example_rat][session_final][int(shift_time/dt):]*0.01

    # ---- Animals used in Panel D --------------------------------------------
    data_folder = baseFolder+'/Data_for_fit/'
    file_to_open = data_folder + 'AnimalList.txt'

    # The first CSV field of every non-empty row is an animal ID. Using `with`
    # closes the file (it used to be left open).
    with open(file_to_open, 'r') as crimefile:
        rats_list = [row for row in csv.reader(crimefile)]
    selected_rats = [row[0] for row in rats_list if row]

    # Animals with a known (non-NaN) final lesion size that are also listed in
    # AnimalList.txt. Sorted so the order is reproducible across runs.
    selected_rats_lesion_size = sorted(
        set(rat for rat in lesion_size_final if not np.isnan(lesion_size_final[rat][0]))
        .intersection(selected_rats))

    # Animal ID -> lesion group (second whitespace-separated token); used as a ColorCode key.
    file_to_open = data_folder + 'LesionType.txt'
    with open(file_to_open, 'r') as document:
        lestion_type_dict = {}
        for line in document:
            line = line.split()
            if not line:  # empty line?
                continue
            lestion_type_dict[line[0]] = line[1]


    axes1 = []
    linewidth = 1.5
    ##################################################################################################
    # PANEL A: simulated positions for every effort sensitivity (jet colours)
    gs = fig.add_gridspec(nrows=1, ncols=1, left=0, bottom=0.8, right=.149, top=1)
    ax=fig.add_subplot(gs[0])
    axes1.append(ax)

    # plt.get_cmap(name, lut) replaces matplotlib.cm.get_cmap, which was
    # deprecated in Matplotlib 3.7 and removed in 3.9.
    cmap = plt.get_cmap('jet', 2*NTRAJECTORIES)
    for it in range(NTRAJECTORIES):
        ax.plot(tk,xk_vect_speed_heav[:,it],linewidth=linewidth,alpha=1.0,c=cmap(2*it))
    ax.text(0.3,.78,'Spatial cost $=$ Diffuse \nEffort' + r'$\approx$' + 'Kinetic energy',fontsize=4)
    #################################################
    # COLORMAP legend for panel A. A throw-away scatter provides the mappable
    # for the colorbar; its host axes is then cleared and hidden.
    gsc = fig.add_gridspec(nrows=1,ncols=1,left=.172,bottom=0.8,top=1,right=.244)

    axc = fig.add_subplot(gsc[0])
    axc.xaxis.set_visible(False)
    axc.spines['left'].set_visible(False)
    axc.spines['bottom'].set_visible(False)
    axc.spines['top'].set_visible(False)
    axc.spines['right'].set_visible(False)
    axc.tick_params(color=(0, 0, 0, 0),labelcolor=(0, 0, 0, 0),zorder=-10)

    c = np.arange(1, 50*NTRAJECTORIES + 1)
    cmap_ = plt.get_cmap('jet', 50*NTRAJECTORIES)
    dummie_cax = axc.scatter(c, c, c=c, cmap=cmap_)
    # Clear axis
    axc.cla()

    cbaxes = inset_axes(axc, width="20%", height="60%",loc=3) 
    shift = 50
    cb=fig.colorbar(dummie_cax,cax=cbaxes,ticks=[1+shift,50*NTRAJECTORIES],aspect=10)
    cb.outline.set_edgecolor(None)
    cb.set_label('Effort\nSensitivity', labelpad=-13,y=1.5, rotation=0,fontsize='xx-small')

    cb.ax.set_yticklabels(['Low', 'High'],rotation=45,fontsize=4)
    cb.ax.yaxis.set_tick_params(size=0)
    #################################################
    for ind_ax,ax in enumerate(axes1):
        # Reference lines: start position + 0.1*t (dashed) and the level 0.1 (dotted)
        ax.plot(tk,xk_vect_speed_heav[0,0]+[0.1*t for t in tk],'--',color='Black',alpha=.8,zorder=-1,linewidth=1)
        ax.plot(tk,[0.1 for t in tk],':',color='Gray',zorder=-1,linewidth=1)

        # Fix tick positions before assigning labels. Labelling first pairs a
        # FixedFormatter with an automatic locator, which Matplotlib warns about.
        ax.set_yticks([0.1*i for i in range(10)])
        if ind_ax==0:
            ax.set_ylabel('Position (cm)')
            ax.set_yticklabels(['0'] + ['' for i in range(8)] + ['90'])
            ax.set_xlabel('Trial time (s)',labelpad=-2)
        else:
            ax.set_yticklabels(['' for i in range(10)])

        ax.set_xticks([i for i in range(8)])
        ax.set_xticklabels(['0'] + ['' for i in range(6)] + ['7'])

        ax.set_xlim(0,7.5)
        ax.set_ylim(0,0.9)
        ax.spines['bottom'].set_bounds(0,7)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

    ##################################################################################################    
    # PANEL B: fixed effort sensitivity, all but the last simulated trajectory (grey shades)
    gs = fig.add_gridspec(nrows=1, ncols=3, left=.325, bottom=0.8, right=0.905, top=1,wspace=0.45,hspace=0.3)
    cmap = plt.get_cmap('gray', 3*(NTRAJECTORIES-1))
    #################################################
    # Left: position
    ax=fig.add_subplot(gs[0])
    axes1.append(ax)

    for it in range(NTRAJECTORIES-1):
        ax.plot(tk,xk_vect_speed_heav[:,it],
                linewidth=linewidth,alpha=1.0,c=cmap(int(2.5*it)))

    ax.plot(tk,xk_vect_speed_heav[0,0]+[0.1*t for t in tk],'--',color='Black',alpha=.8,zorder=-1,linewidth=1)
    ax.plot(tk,[0.1 for t in tk],':',color='Gray',zorder=-1,linewidth=1)

    ax.set_yticks([0.1*i for i in range(10)])
    ax.set_yticklabels(['0'] + ['' for i in range(8)] + ['90'])
    ax.set_ylim(0,0.9)
    ax.set_ylabel('Position (cm)')
    ax.set_xlabel('Trial time (s)',labelpad=-2)    
    #################################################
    # Center: speed (sign flipped for display)
    ax=fig.add_subplot(gs[1])
    axes1.append(ax)

    for it in range(NTRAJECTORIES-1):
        ax.plot(tk,-xdotk_vect_speed_heav[:,it],
                linewidth=linewidth,alpha=1.0,c=cmap(int(2.5*it)))   
    ax.plot(tk,[-0.1 for t in tk],'--',color='Black',alpha=.8,zorder=-1,linewidth=1)
    ax.set_ylabel('Speed (cm/s)')
    ax.set_ylim(-0.12,0.62)
    ax.set_yticks([-0.1+0.1*i for i in range(8)])
    ax.set_yticklabels(['0'] + ['' for i in range(6)] + ['60'])  
    s='Fixed effort sensitivity'
    ax.text(-2.5,0.65,s, ha='left',va='bottom', fontsize='x-small')  
    #################################################
    # Right: cumulative cost
    ax=fig.add_subplot(gs[2])
    axes1.append(ax)

    for it in range(NTRAJECTORIES-1):
        ax.plot(tk,cumulative_cost_speed_heav[:,it],
                linewidth=linewidth,alpha=1.0,c=cmap(int(2.5*it)))
    ax.set_ylabel('Total Cost')
    ax.set_ylim(-0.1,150)
    ax.set_yticks([i*25 for i in range(7)])
    # NOTE(review): the top tick (150) is labelled '15'; confirm the intended scale.
    ax.set_yticklabels(['0'] + ['' for i in range(5)] + ['15'])

    # Shared x-axis styling for panels A and B
    for ind_ax,ax in enumerate(axes1):
        ax.set_xticks([i for i in range(8)])
        ax.set_xticklabels(['0'] + ['' for i in range(6)] + ['7'])         
        ax.set_xlim(0,7.5)
        ax.spines['bottom'].set_bounds(0,7)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
    #################################################
    # COLORMAP legend for panel B (reversed grey), built the same way as panel A's
    gsc = fig.add_gridspec(nrows=1,ncols=1,left=0.928,bottom=0.8,top=1,right=1)

    axc = fig.add_subplot(gsc[0])
    axc.xaxis.set_visible(False)
    axc.spines['left'].set_visible(False)
    axc.spines['bottom'].set_visible(False)
    axc.spines['top'].set_visible(False)
    axc.spines['right'].set_visible(False)
    axc.tick_params(color=(0, 0, 0, 0),labelcolor=(0, 0, 0, 0),zorder=-10)


    c = np.arange(1, 50*NTRAJECTORIES + 1)
    cmap_ = plt.get_cmap('gray', 50*NTRAJECTORIES).reversed()
    dummie_cax = axc.scatter(c, c, c=c, cmap=cmap_)
    # Clear axis
    axc.cla()

    cbaxes = inset_axes(axc, width="20%", height="60%",loc=3) 
    shift = 50
    cb=fig.colorbar(dummie_cax,cax=cbaxes,ticks=[1+shift,50*NTRAJECTORIES],aspect=10)
    cb.outline.set_edgecolor(None)
    cb.set_label('Total\nCost', labelpad=-13,y=1.5, rotation=0,fontsize='xx-small')

    cb.ax.set_yticklabels(['Low', 'High'],rotation=45,fontsize=4)
    cb.ax.yaxis.set_tick_params(size=0)
    ####################################################################################################
    # PANEL C (bottom row of the 2x2 grid gsC; the top row holds the max-position
    # examples drawn further down): the example animal's average trajectory
    # (solid) and fitted model trajectory (dashed), before and after the lesion.
    gsC = fig.add_gridspec(nrows=2, ncols=2, left=.11, bottom=.42, right=.51, top=.67,wspace=0.4,hspace=0.1)
    #################################################
    # Left: before the lesion
    ax=fig.add_subplot(gsC[2])
    axes1.append(ax)
    # Time axes are built as dt*index so their length always matches the data.
    # np.arange(0, dt*n, dt) can return n+1 samples because of float rounding.
    ax.plot(dt*np.arange(len(trajectory_er_before)), trajectory_er_before,
            color='midnightblue', lw=1)

    ax.plot(dt*np.arange(len(xk_opt_er_before)), xk_opt_er_before, '--', c='maroon',lw=1)

    s = r'$\alpha=$'+'$%+.2f$'%(a_opt_er_before)
    ax.text(5,0,s, ha='left',va='bottom', fontsize='xx-small')
    ax.set_yticks([0,.45,.9])
    ax.set_yticklabels([0,'',90])
    ax.set_ylim(0,0.9)
    ax.set_ylabel('Position (cm)')
    ax.set_xticks([i for i in range(11)])
    ax.set_xticklabels(['0'] + ['' for i in range(6)] + ['7'] + ['' for i in range(3)])
    ax.set_xlabel('Trial time (s)',labelpad=-1) 
    ax.set_xlim(0,10.5)
    ax.spines['bottom'].set_bounds(0,10)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    #################################################
    # Right: after the lesion (axes hidden; shares the left panel's scale)
    ax=fig.add_subplot(gsC[3])
    axes1.append(ax)
    ax.plot(dt*np.arange(len(trajectory_er_final)),trajectory_er_final, 
            color='midnightblue', lw=1)
    ax.plot(dt*np.arange(len(xk_opt_er_final)),xk_opt_er_final, '--', c='maroon', lw=1)

    s = r'$\alpha=$'+'$%+.2f$'%(a_opt_er_final)
    ax.text(5,0,s, ha='left',va='bottom', fontsize='xx-small')
    ax.set_ylim(0,0.9)
    ax.set_xlim(0,10.5)
    ax.spines['bottom'].set_bounds(0,10)
    ax.spines['bottom'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.tick_params(bottom=False, top=False, left=False, right=False,
                    labelbottom=False, labeltop=False, labelleft=False, labelright=False)
    #Adding the legend text
    ax.text(-1,-.12,'Fit',
         va='top',ha='left',fontsize=5, color='maroon')
    ax.text(-1,0,'Avg. Traj.',
         va='top',ha='left',fontsize=5, color='midnightblue')

    ##################################################################################################
    # PANEL D: change of the fitted effort parameter vs. lesion size
    gs = fig.add_gridspec(nrows=1, ncols=1, left=.64, bottom=.42, right=.89, top=.67)
    ax=fig.add_subplot(gs[0])
    axes1.append(ax)

    size = []
    effort = []
    example_rat = 'Rat302'
    for rat in selected_rats_lesion_size:
        # The example animal is drawn separately below (triangle + arrow) but
        # still contributes to the correlation.
        if rat!=example_rat:
            ax.scatter(lesion_size_final[rat], delta_a_median_bounded_dict[rat], s=5,
                    c=ColorCode[lestion_type_dict[rat]])
        size.append(lesion_size_final[rat][0])
        effort.append(delta_a_median_bounded_dict[rat])

    r2,p2=stats.pearsonr(size, effort)
    n_rats = len(size)
    s=f'$r=$'+'$%+.2f$'%(r2)+'\n'+'$p=$'+'{}'.format(SciNote(p2) + '\n' + '$n=$' + str(n_rats)+' rats')
    ax.text(1,-5,s, ha='right',va='bottom', fontsize='xx-small')  
    
    y_exr = delta_a_median_bounded_dict[example_rat]
    x_exr = lesion_size_final[example_rat][0]
    ax.scatter(x_exr,y_exr,
            c=DSColor,s=5,marker='v',zorder=10)
    # Text is passed positionally: the `s=` keyword of Axes.annotate was
    # deprecated in Matplotlib 3.3 (renamed to `text`) and later removed.
    ax.annotate('',xy=(x_exr,y_exr),xytext=(x_exr-.2,y_exr),
                arrowprops=dict(facecolor='k',arrowstyle='->',shrinkB=2))    
    ax.set_xticks(np.arange(0,1.01,.1))
    ax.set_xticklabels(['0','','','','','0.5','','','','','1'])
    ax.spines['bottom'].set_bounds(0,1)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xlim([-.02,1.02])
    ax.set_ylim([-5,12])
    ax.spines['left'].set_bounds(-5,10)
    ax.set_yticks(np.arange(-5,11,5))
    ax.set_xlabel('Lesion size')
    ax.set_ylabel(r'$\Delta \alpha$')
    
        
##=====================================================================
##===================STEF's CODE ABOVE===================================
##=====================================================================
      
    
    
    ##########################################
    # 2: Max Pos EXAMPLE
    # Top row of the panel-C grid (gsC[0], gsC[1]): one session of the example
    # animal before (sessionBefore2) and after (sessionAfter2) the lesion.
    ax2L = fig.add_subplot(gsC[0])
    ax2R = fig.add_subplot(gsC[1])

    # plot_trajectories_and_distributions is defined in another notebook.
    # NOTE(review): its return value is presumably the session's mean trajectory
    # (see the commented-out overlay below); confirm.
    MeanTrajBefore2=plot_trajectories_and_distributions(root, ax2L, sessionBefore2)
    MeanTrajAfter2 =plot_trajectories_and_distributions(root, ax2R, sessionAfter2)
    

#     ax2R.plot(np.linspace(-1,len(MeanTrajBefore2)/25-1,len(MeanTrajBefore2)),MeanTrajBefore2,
#               color='midnightblue', lw=.8, alpha=.7)
    
    ax2L.set_title('Before (#$-1$)',fontsize='x-small')
    ax2R.set_title('After (#$+10$)',fontsize='x-small')
    


    
    
    
    ##########################################
    # 5: Normalized Conditional Max Pos time course
    gs5= fig.add_gridspec(nrows=1, ncols=1, left=.11, bottom=.05, right=.51, top=.3)
    ax5= fig.add_subplot(gs5[0])
    
    # Per-animal value of TaskParamToPlot1 (forward running speed), compared
    # between preSlice3 and finSlice3 by _late_lesion_effect (defined elsewhere).
    behav3, AllAnimals3=_late_lesion_effect(root, Profiles=Profiles3,badAnimals=[],TaskParamToPlot=TaskParamToPlot1,
                                            preSlice=preSlice3, postSlice=finSlice3)

    # "Conditional": keep only animals whose value is below minSpdReduction3.
    # badAnimals3 holds the rest; it is not used further in this cell.
    goodAnimals3=[animal for i,animal in enumerate(AllAnimals3) if behav3[i]<minSpdReduction3]
    badAnimals3=list(set(AllAnimals3) - set(goodAnimals3))

    
    # data5 has one row per animal (its row count is reported as "n rats" below).
    # NOTE(review): columns are presumably sessions, the first 5 pre-lesion (cf.
    # preSlice3 and nPre below); confirm.
    data5,goodAnimals5=plot_normalized_time_course(root, ax5, Profiles3,Animals=goodAnimals3,
                                                   TaskParamToPlot=TaskParamToPlot3,color='k')
    
    # Mark the example animal (ID = first 6 characters of sessionBefore2) at
    # x=-1 and x=10 using columns 4 and 14 of its data5 row. The elementwise ==
    # assumes goodAnimals5 is a numpy array.
    ax5.scatter([-1,10],data5[np.argwhere(goodAnimals5==sessionBefore2[:6]),[4,14]][0],
                c=DSColor,s=5,marker='v', zorder=10)

    ax5.text(-5,-21,f'$n={data5.shape[0]}$ rats',
          ha='left',va='bottom',fontsize='xx-small',color='k')
    
    # Shade the session windows that are compared in the statistics below.
    plot_session_def(ax5, sessionSlices= (slice(-2,None),slice(0,2)))
    plot_session_def(ax5, sessionSlices= (slice(13,15),))

    #STATS
    # Per animal: mean of the two columns before index nPre vs. the two columns
    # after it, then vs. every column from nPre+13 on; bootstrapTest (defined
    # elsewhere) returns a label string `s` drawn above each window.
    # NOTE(review): the second window runs to the end of data5 while
    # plot_session_def marks slice(13,15); they agree only if data5 has nPre+15
    # columns; confirm.
    # These lines overwrite `a` and `b`, which earlier held the cost-model weights.
    nPre=5
    a=np.nanmean(data5.T[nPre-2:nPre,:],axis=0)
    b=np.nanmean(data5.T[nPre:nPre+2,:],axis=0)
    _,s=bootstrapTest(a-b,10000)
    ax5.text(1.5,5,s,fontsize=5,c='k',ha='center')
    
    a=np.nanmean(data5.T[nPre-2:nPre,:],axis=0)
    b=np.nanmean(data5.T[nPre+13:,:],axis=0)
    _,s=bootstrapTest(a-b,10000)
    ax5.text(14.5,5,s,fontsize=5,c='k',ha='center')
    

    
    ax5.set_ylim([-21,5])
    ax5.set_yticks(range(-20,6,5))
    ax5.set_ylabel('Norm. Max. Pos. (cm)',labelpad=0)
    ax5.spines['left'].set_bounds(-20,5)
    
    
    
    ##########################################
    # 4: % Max Pos correlation
    # Change of TaskParamToPlot3 (preSlice3 vs. postSlice=slice(5,None)) against
    # lesion size, for the animals retained in panel 5 (goodAnimals5).
    gs4= fig.add_gridspec(nrows=1, ncols=1, left=0.64, bottom=.05, right=.89, top=.3)
    ax4= fig.add_subplot(gs4[0])

    behav4, size4, animals4=late_lesion_correlation_with_size(root, ax=ax4, Profiles=Profiles3, Animals=goodAnimals5,
                                                              color=ColorCode, TaskParamToPlot=TaskParamToPlot3,
                                                              preSlice=preSlice3, postSlice=slice(5,None),
                                                              Excluded=sessionBefore2[:6])

    # Re-draw the example animal (ID = first 6 characters of sessionBefore2) as
    # a triangle with a pointer arrow.
    # NOTE(review): `Excluded` presumably keeps it out of the helper's own scatter; confirm.
    _y_=behav4[animals4.index(sessionBefore2[:6])]
    L6=HistologyExcel('/NAS02',sessionBefore2[:6]).lesion_size()

    ax4.scatter(L6,_y_,
            c=DSColor,s=5,marker='v',zorder=10)
    # Text is passed positionally: the `s=` keyword of Axes.annotate was
    # deprecated in Matplotlib 3.3 (renamed to `text`) and later removed.
    ax4.annotate('',xy=(L6,_y_),xytext=(L6-.2,_y_),
                 arrowprops=dict(facecolor='k',arrowstyle='->',shrinkB=2))

    r4,p4=stats.pearsonr(size4, behav4)
    s=f'$r=$'+'$%+.2f$'%(r4)+'\n'+'$p=$'+'{}'.format(SciNote(p4))+'\n'+f'$n={len(animals4)}$ rats'
    ax4.text(1,21,s, ha='right',va='top', fontsize='xx-small')

    ax4.set_ylim([-42,20])
    ax4.spines['left'].set_bounds(-40,20)
    ax4.set_yticks([-40,-20,0,20])
    # Raw string: "\D" in a plain literal is an invalid escape sequence
    # (SyntaxWarning since Python 3.12); the resulting label text is unchanged.
    ax4.set_ylabel(r'$\Delta$Max. Pos. (cm)')

    ############################################
    # Panel captions: add_panel_caption (defined elsewhere) is given one axes
    # per panel, with per-panel x/y offsets.
    AXES=(axes1[0],axes1[1],ax2L,axes1[-1],ax5,ax4)
    OFFX=np.array([.07]*len(AXES))
    OFFY=np.array([.03]*len(AXES))
    OFFX[[2]]=0.05

    add_panel_caption(axes=AXES, offsetX=OFFX, offsetY=OFFY)

    fig.savefig(os.path.join(os.path.dirname(os.getcwd()),'LesionPaper','Figures','MaxPosAnalysis.pdf'),
                format='pdf', bbox_inches='tight')

    plt.show()
    plt.close('all')
    # Restore Matplotlib's default rcParams (set_rc_params was applied at the top of this cell).
    matplotlib.rcdefaults()