# Part 0:
## Import everything
Run the cell below.

In [None]:
import os
import glob
import numpy as np
from platform import system as OS
import pandas as pd
import scipy.stats
import math
import datetime
from copy import deepcopy
import matplotlib.cm as cm
import warnings
warnings.filterwarnings("ignore")
import sys
import pickle
import matplotlib as mpl
import matplotlib.pyplot as plt
import PIL
from scipy import stats
import matplotlib.animation as animation
import matplotlib.backends.backend_pdf
import mpl_toolkits.axes_grid1.inset_locator as inset
from matplotlib.ticker import FormatStrFormatter
from matplotlib.patches import ConnectionPatch
from set_rc_params import set_rc_params
import ROOT


# Runs only when executed directly in the notebook kernel. When this notebook is
# loaded via %run, `__file__` is presumably defined, so this cell is skipped.
if "__file__" not in dir():
    %matplotlib inline
    # keep figures open after a cell finishes so later cells can keep drawing on them
    %config InlineBackend.close_figures = False

    root=ROOT.root
    
    # realpath("__file__") resolves the literal name "__file__" against the current
    # working directory, so this is effectively the directory the notebook runs in.
    ThisNoteBookPath=os.path.dirname(os.path.realpath("__file__"))
    # sibling folder holding the shared loading/plotting notebooks
    CommonNoteBookesPath=os.path.join(os.path.split(ThisNoteBookPath)[0],"load_preprocess_rat")
    CWD=os.getcwd()
    # run the shared notebooks from their own folder, then restore the original cwd
    os.chdir(CommonNoteBookesPath)
    %run UtilityTools.ipynb
    %run Animal_Tags.ipynb
    %run loadRat_documentation.ipynb
    %run plotRat_documentation_1_GeneralBehavior.ipynb
    %run plotRat_documentation_3_KinematicsInvestigation.ipynb
    %run RunBatchRat_3_CompareGroups.ipynb
    %run BatchRatBehavior.ipynb
    # NOTE(review): if this file is itself LesionPaper/ExampleRats.ipynb, this loads
    # its own definitions; the `__file__` guard presumably prevents recursion — confirm.
    currentNbPath=os.path.join(os.path.split(ThisNoteBookPath)[0],'LesionPaper','ExampleRats.ipynb')
    %run $currentNbPath

    os.chdir(CWD)

    # NOTE(review): `logging` is not imported at the top of this file; it is presumably
    # brought into the namespace by one of the %run notebooks above — confirm.
    logging.getLogger().setLevel(logging.ERROR)
    
    # Preprocessing parameters (per the trailing comment, these override the defaults).
    # The "pavel" comments record the values used for that other dataset.
    param={
        "goalTime":7,#needed for pavel data only
        "treadmillRange":[0,90],#pavel error conversion "treadmillRange":[0,80]
        "maxTrialDuration":15,
        "interTrialDuration":10,#None pavel
        "endTrial_frontPos":30,
        "endTrial_backPos":55, 
        "endTrial_minTimeSec":4,
        "cameraSamplingRate":25, #needed for new setup    

        "sigmaSmoothPosition":0.1,#0.33, 0.18 pavel
        "sigmaSmoothSpeed":0.3,#0.3, 0.5 pavel
        "nbJumpMax":100,#200 pavel
        "binSize":0.25,
        #parameters used to preprocess (will override the default parameters)
    }
    # lower / upper bounds of the treadmill position range
    Y1,Y2=param['treadmillRange']

    print('os:',OS(),'\nroot:',root,'\nImport successful!')

---
---


# Part 1:

# DEFINITIONS

### If you are unsure what to do, skip to Part 2.

---

Plotting the trajectories of example sessions.

In [None]:
def plot_session_median_trajectory(data,ax):
    """Plot the across-trial median position trajectory of one session.

    Every trial's position trace (cut at its ``stopFrame``) is stacked into a
    frames x trials array padded with NaN. The median over trials is drawn up
    to the first frame at which more than 30% of the trials have no data.

    Parameters
    ----------
    data : Data
        Session object. Reads ``position`` (trial -> position samples),
        ``stopFrame`` (trial -> number of frames to keep),
        ``cameraSamplingRate`` (frames per second) and
        ``cameraToTreadmillDelay`` (time of frame 0 is ``-cameraToTreadmillDelay`` s).
    ax : matplotlib.axes.Axes
        Axes on which the median trajectory is drawn (navy, lw=2).
    """
    posDict=data.position
    fs=data.cameraSamplingRate
    maxL=int(np.nanmax(list(data.stopFrame.values())))

    # one row per video frame, one column per trial; shorter trials stay NaN-padded
    position=np.full((maxL,len(posDict)),np.nan)
    for i,trial in enumerate(posDict):
        pos=posDict[trial][:data.stopFrame[trial]]
        position[:len(pos),i]=pos

    # Build the time axis from the frame index so it always has exactly maxL
    # samples. A float-step np.arange ending at (maxL-fs)/fs only matched maxL
    # when the delay was exactly 1 s, and can be off by one through rounding.
    time=np.arange(maxL)/fs-data.cameraToTreadmillDelay

    # keep frames where at least 70% of the trials still have data
    nanSum=np.sum(np.isnan(position),axis=1)
    tooSparse=np.flatnonzero(nanSum>.3*position.shape[1])
    # no sparse frame -> keep every frame (shape[0]; shape[1] is the trial count)
    maxTraj=tooSparse[0] if tooSparse.size else position.shape[0]

    ax.plot(time[:maxTraj], np.nanmedian(position,axis=1)[:maxTraj], color='navy', lw=2)

def plot_trajectories(data,ax):
    """Draw each trial's position trace of a session on ``ax``.

    Trials found in ``data.goodTrials`` are drawn in green, all others in
    salmon. Each trace is cut at the trial's ``stopFrame``. A translucent gray
    band covers x in [0, 7] over y in [0, 90].

    Returns
    -------
    numpy.ndarray
        The color string used for each trial, in ``data.position`` order.
    """
    traces=data.position
    trialTimes=data.timeTreadmill #align on camera
    traceColors=[]
    for trialName in traces:
        end=data.stopFrame[trialName]
        traceColor="salmon" if trialName not in data.goodTrials else "xkcd:green"
        traceColors.append(traceColor)
        ax.plot(trialTimes[trialName][:end], traces[trialName][:end],
                color=traceColor, lw=.5)

    # shaded reference window
    ax.fill_betweenx(y=(0,90), x1=0, x2=7,
                     facecolor='gray', edgecolor=None, alpha=.4)

    return np.array(traceColors)



def plot_trajectories_and_distributions(root, ax, session, showText=True):
    """Plot one session's trial trajectories, their median and its entrance times.

    Draws every trial's position trace (``plot_trajectories``), the median
    trajectory over trials (``plot_session_median_trajectory``) and a
    horizontal box plot of ``data.entranceTime`` centred at y=5.

    Parameters
    ----------
    root : str
        Data root directory, forwarded to ``Data``.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    session : str
        Session name. Its first 6 characters are used as the animal name
        (assumes names starting like ``'Rat250'`` — TODO confirm).
    showText : bool, optional
        If True, draw a labelled frame with minimal tick labels; if False,
        hide all spines, ticks and tick labels (for compact panels).
    """
    data=Data(root,session[:6],session,redoPreprocess=False)

    # The per-trial colors returned by plot_trajectories and the position array
    # from get_positions_array_beginning were never used, so that extra
    # computation has been dropped.
    plot_trajectories(data,ax=ax)
    plot_session_median_trajectory(data,ax)

    # entrance-time distribution: horizontal box, whiskers at the 5th/95th
    # percentiles, no caps or fliers, drawn above the trajectories
    props={'color':'k', 'linewidth':1}
    ax.boxplot(x=data.entranceTime,whis=[5,95],vert=False,
               positions=[5], widths=5,
               showcaps=False, showfliers=False,
               medianprops=props, boxprops=props, whiskerprops=props, zorder=5
              )

    # settings shared by both layouts
    ax.set_xlim([-1,15.2])
    ax.set_ylim([0,90])
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    if showText:
        # a tick every 1 s / 10 cm, but only 0, 7 and 15 s and 0 and 90 cm are labelled
        ax.set_xticks(range(16))
        ax.set_xticklabels([0,'','','','','','',7,'','','','','','','',15])
        ax.set_yticks(range(0,91,10))
        ax.set_yticklabels([0,'','','','','','','','',90])
        ax.spines['bottom'].set_bounds(0,15)
        ax.set_xlabel('Trial time (s)',labelpad=0)
        ax.set_ylabel('Position (cm)',labelpad=-10)
    else:
        ax.spines['bottom'].set_visible(False)
        ax.spines['left'].set_visible(False)
        ax.tick_params(bottom=False, top=False, left=False, right=False,
                      labelbottom=False, labeltop=False, labelleft=False, labelright=False)
    
#======================================
def plot_pre_post_traj(root, gs, animal, preProfile, postProfile, preSession, postSession, showText=False):
    """Plot one row of example sessions: pre-condition panels, then post-condition panels.

    Parameters
    ----------
    root : str
        Data root directory.
    gs : matplotlib.gridspec.GridSpec
        Grid attached to a figure (``gs.figure``) whose column count equals
        ``len(preSession) + len(postSession)``; panels fill it in order.
    animal : str
        Animal whose sessions are plotted.
    preProfile, postProfile : dict
        Session-selection profiles passed to ``batch_get_session_list``.
    preSession, postSession : sequence of int
        Indices into the respective session lists (this notebook uses
        ``(0, -1)``: first and last session).
    showText : bool, optional
        If True, the first pre-condition panel gets labelled axes; all other
        panels are always drawn bare.

    Returns
    -------
    list of matplotlib.axes.Axes
        The created axes in grid order (pre panels first).

    Raises
    ------
    ValueError
        If the number of requested sessions differs from the grid's column count.
    """
    nPanels=len(preSession)+len(postSession)
    nCols=gs.get_geometry()[1]
    # explicit check instead of `assert`, which is stripped under `python -O`
    if nPanels != nCols:
        raise ValueError(f"{nPanels} sessions requested but the grid has {nCols} columns")

    axes=[]
    # (profile, session indices, whether this group's first panel gets labels)
    groups=((preProfile, preSession, showText), (postProfile, postSession, False))
    for profile, sessionIds, labelFirst in groups:
        sessionList=batch_get_session_list(root, animalList=[animal], profile=profile)['Sessions']
        for k, sessionId in enumerate(sessionIds):
            # The next grid slot is the number of panels already drawn. This also
            # works when preSession is empty; the old `i+j+1` raised NameError there.
            ax=gs.figure.add_subplot(gs[len(axes)])
            plot_trajectories_and_distributions(root, ax, sessionList[sessionId],
                                                showText=labelFirst and k==0)
            axes.append(ax)

    return axes

In [None]:
# Quick check of plot_pre_post_traj: one row of 4 panels for one animal,
# pre-condition sessions on the left and post-condition sessions on the right.
if "__file__" not in dir():
    #the inputs
    fig=plt.figure(figsize=(8,1.5))
    gs= fig.add_gridspec(nrows=1, ncols=4, left=0.02, bottom=0.02, right=0.98, top=.98, wspace=.1)
    

    
    # session-selection criteria passed to batch_get_session_list
    profile1pre={'Type':'Good',
             'rewardType':'Progressive',
             'initialSpeed':'10',
             'Speed':'10',
             'Tag':['Control']
             }
    
    profile1post={'Type':'Good',
             'rewardType':'Progressive',
             'initialSpeed':'10',
             'Speed':'10',
             'Tag':['Control-AfterBreak']
             }
    animalList1Ctrl='Rat250'
    
    # indices into each session list: (0,-1) -> first and last session
    sessionIndex1pre =(0,-1)
    sessionIndex1post=(0,-1)
    
    plot_pre_post_traj(root, gs, animalList1Ctrl, profile1pre, profile1post, sessionIndex1pre, sessionIndex1post)
    # dashed vertical line at mid-width separating the pre and post panels
    gs.figure.add_artist(ConnectionPatch(xyA=(.5,0), xyB=(.5,1), coordsA='figure fraction', coordsB='figure fraction',
                                     ls='--',lw=2))

    
    plt.show()
    plt.close('all')

------



------

# Part 2:

# GENERATING THE FIGURE

Definition of parameters

In [None]:
# Parameters for the four rows of the example-rats figure. Each row compares
# sessions selected by a "pre" profile with sessions selected by a "post"
# profile for a single animal. Profiles are passed to batch_get_session_list.
if "__file__" not in dir():
    # GRID 1 PARAMS
    # Control vs. Control-AfterBreak sessions of Rat250
    
    profile1pre={'Type':'Good',
             'rewardType':'Progressive',
             'initialSpeed':'10',
             'Speed':'10',
             'Tag':['Control']
             }
    
    profile1post={'Type':'Good',
             'rewardType':'Progressive',
             'initialSpeed':'10',
             'Speed':'10',
             'Tag':['Control-AfterBreak']
             }
    
    animalList1Ctrl='Rat250'
    
    # indices into each session list: (0,-1) -> first and last session
    sessionIndex1pre =(0,-1)
    sessionIndex1post=(0,-1)
    

    
    #===============================================
    
    # GRID 2 PARAMS
    # Control-AfterBreak vs. Late-Lesion_DLS sessions of Rat250
    # (no session indices here; the figure cell reuses sessionIndex1pre/post)
    
    profile2pre={'Type':'Good',
             'rewardType':'Progressive',
             'initialSpeed':'10',
             'Speed':'10',
             'Tag':['Control-AfterBreak']
             }
    
    profile2post={'Type':'Good',
             'rewardType':'Progressive',
             'initialSpeed':'10',
             'Speed':'10',
             'Tag':['Late-Lesion_DLS']
             }
    
    animalList2Ctrl='Rat250'
    
    
    
    #================================================
    
    # GRID 3 PARAMS
    # Control-AfterBreak vs. Late-Lesion_DMS sessions of Rat217
    
    profile3pre={'Type':'Good',
             'rewardType':'Progressive',
             'initialSpeed':'10',
             'Speed':'10',
             'Tag':['Control-AfterBreak']
             }
    
    profile3post={'Type':'Good',
             'rewardType':'Progressive',
             'initialSpeed':'10',
             'Speed':'10',
             'Tag':['Late-Lesion_DMS']
             }
    
    animalList3Ctrl='Rat217'
    

    
    #================================================
    
    # GRID 4 PARAMS
    # Control-Sharp vs. Late-Lesion_DS-Sharp sessions of Rat304
    
    profile4pre={'Type':'Good',
             'rewardType':'Progressive',
             'initialSpeed':'10',
             'Speed':'10',
             'Tag':['Control-Sharp']
             }
    
    profile4post={'Type':'Good',
             'rewardType':'Progressive',
             'initialSpeed':'10',
             'Speed':'10',
             'Tag':['Late-Lesion_DS-Sharp']
             }
    
    animalList4Ctrl='Rat304'


Plotting the figure

In [None]:
# (empty scratch cell: the stray, incomplete "ax." expression here was a SyntaxError)

In [None]:
# Builds the example-rats figure: four stacked single-row grids (top to bottom:
# Control, DLS, DMS, DS), each 0.2 of the figure height, and saves it as a PDF.
if "__file__" not in dir():
    plt.close('all')
    set_rc_params()
    figsize=(4,5)
    fig=plt.figure(figsize=figsize,dpi=600)
    
    
    ##########################################
    # 1: Control
    gs1= fig.add_gridspec(nrows=1, ncols=4, left=0.02, bottom=0.6, right=0.99, top=.8)
    
    axes1=plot_pre_post_traj(root, gs1, animalList1Ctrl, profile1pre, profile1post,
                             sessionIndex1pre, sessionIndex1post,showText=True)
    # dashed vertical line at mid-width, from the bottom of the figure to the top
    # of this row, separating the pre (left) and post (right) panels of all rows
    gs1.figure.add_artist(ConnectionPatch(xyA=(.5,0), xyB=(.5,.8), coordsA='figure fraction', coordsB='figure fraction',
                                     ls='--',lw=2))

    # NOTE(review): rows 2-4 reuse sessionIndex1pre / sessionIndex1post; confirm
    # that the same session indices are intended for every row.
    
    ##########################################
    # 2: DLS
    gs2= fig.add_gridspec(nrows=1, ncols=4, left=0.02, bottom=0.4, right=0.99, top=.6)
    
    axes2=plot_pre_post_traj(root, gs2, animalList2Ctrl, profile2pre, profile2post, sessionIndex1pre, sessionIndex1post)
#     axes2[0].clear()


    ##########################################
    # 3: DMS
    gs3= fig.add_gridspec(nrows=1, ncols=4, left=0.02, bottom=0.2, right=0.99, top=.4)
    
    axes3=plot_pre_post_traj(root, gs3, animalList3Ctrl, profile3pre, profile3post, sessionIndex1pre, sessionIndex1post)
#     axes3[0].clear()


    ##########################################
    # 4: DS
    gs4= fig.add_gridspec(nrows=1, ncols=4, left=0.02, bottom=0.0, right=0.99, top=.2)
    
    axes4=plot_pre_post_traj(root, gs4, animalList4Ctrl, profile4pre, profile4post, sessionIndex1pre, sessionIndex1post)
#     axes4[0].clear()


    
    #############################################
    #%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    # TODO: panel captions (disabled; ax1, ax2, ... are not defined in this notebook)
#     AXES=(axes4[0],ax1,ax2,ax5,ax7,ax8,ax6)
#     OFFX=np.array([.07]*len(AXES))
#     OFFY=np.array([.01]*len(AXES))
#     OFFX[5]=0.03
#     OFFX[[0,1,2,4,6]]=0.05
    
#     add_panel_caption(axes=AXES, offsetX=OFFX, offsetY=OFFY)
    
    figDir=os.path.join(os.path.dirname(os.getcwd()),'LesionPaper','Figures')
    # savefig does not create missing directories
    os.makedirs(figDir, exist_ok=True)
    fig.savefig(os.path.join(figDir,'ExampleRats.pdf'),
                format='pdf', bbox_inches='tight')
    
    plt.show()
    plt.close('all')
    # restore matplotlib's default rcParams (presumably undoing set_rc_params())
    mpl.rcdefaults()