# Part 0:
## Import everything
Run the cell below.

In [None]:
# -*- coding: utf-8 -*-
import os
import glob
import numpy as np
from platform import system as OS
import pandas as pd
import scipy.stats
import math
import datetime
from copy import deepcopy
import matplotlib.cm as cm
import warnings
import types
warnings.filterwarnings("ignore")
import sys, time
import pickle
import matplotlib.pyplot as plt
from scipy import stats
from matplotlib import mlab
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.animation as animation
import matplotlib.backends.backend_pdf
from sklearn.decomposition import KernelPCA
import mpl_toolkits.axes_grid1.inset_locator as inset
from matplotlib.ticker import FormatStrFormatter
import imageio



# Notebook setup cell: configure inline plotting, pick the data root for the
# current OS, load the shared helper notebooks, and define the default
# preprocessing parameters.
# NOTE(review): this block only runs when `__file__` is not defined, which is
# presumably the case in an interactive Jupyter session (not when run as a
# plain script). It relies on IPython magics, so it is not plain Python.
if "__file__" not in dir():
    %matplotlib inline
    # Keep figures open after the cell finishes so later cells can still draw on them.
    %config InlineBackend.close_figures = False
    # Data root directory, chosen by operating system.
    # NOTE(review): the Linux path has no trailing separator while the others do.
    if OS()=='Linux':
        root="/data"
    elif OS()=='Windows':
        root="C:\\DATA\\"
    else:
        root="/Users/davidrobbe/Documents/Data/"
            
    # realpath() is given the string literal "__file__" (not a variable), so it
    # resolves against the current working directory; dirname() of that is the
    # CWD itself, i.e. the directory this notebook was launched from.
    ThisNoteBookPath=os.path.dirname(os.path.realpath("__file__"))
    # Sibling directory "load_preprocess_rat" next to this notebook's folder.
    CommonNoteBookesPath=os.path.join(os.path.split(ThisNoteBookPath)[0],"load_preprocess_rat")
    CWD=os.getcwd()
    # %run resolves the notebook names relative to the CWD, so temporarily move
    # into the shared directory; the definitions they create (e.g. data_fetch,
    # get_rat_group_statistic, batch_get_animal_list) land in this namespace.
    # NOTE(review): if any %run fails, the os.chdir(CWD) below is not reached.
    os.chdir(CommonNoteBookesPath)
    %run UtilityTools.ipynb
    %run Animal_Tags.ipynb
    %run loadRat_documentation.ipynb
    %run plotRat_documentation_1_GeneralBehavior.ipynb
    %run plotRat_documentation_3_KinematicsInvestigation.ipynb
    %run RunBatchRat_3_CompareGroups.ipynb
    %run BatchRatBehavior.ipynb
    %run ../BehavioralPaper/CtrlTrd.ipynb
    os.chdir(CWD)
    # NOTE(review): `logging` is not imported in this notebook's import cell;
    # presumably one of the %run notebooks above brings it into scope — confirm.
    # Only show ERROR-level (and above) log messages.
    logging.getLogger().setLevel(logging.ERROR)
    
    # PARAMETERS (used if the pickles don't exist).
    # Trailing "pavel" comments record the values used for Pavel's data set.
    param={
        "goalTime":7,#needed for pavel data only
        "treadmillRange":[0,90],#pavel error conversion "treadmillRange":[0,80]
        "maxTrialDuration":15,
        "interTrialDuration":10,#None pavel
        "endTrial_frontPos":30,
        "endTrial_backPos":55, 
        "endTrial_minTimeSec":4,
        "cameraSamplingRate":25, #needed for new setup    

        "sigmaSmoothPosition":0.1,#0.33, 0.18 pavel
        "sigmaSmoothSpeed":0.3,#0.3, 0.5 pavel
        "nbJumpMax":100,#200 pavel
        "binSize":0.25,
        #parameters used to preprocess (will override the default parameters)
    }  

    print('os:',OS(),'\nroot:',root,'\nImport successful!')

---
---


# Part 1:

# DEFINITIONS

### If you are unsure what to do, skip ahead to Part 2

**Utility Functions**

In [None]:
def get_ordered_colors(colormap, n):
    """Return ``n`` RGBA colors sampled evenly along a matplotlib colormap.

    Parameters
    ----------
    colormap : str or matplotlib.colors.Colormap
        Name of a registered colormap (e.g. 'plasma') or a Colormap instance.
    n : int
        Number of colors. Samples are taken at ``np.linspace(0, 1, n)``, so for
        ``n >= 2`` the first and last colors are the two ends of the colormap.

    Returns
    -------
    list of tuple
        ``n`` RGBA tuples (an empty list when ``n`` is 0).
    """
    # plt.get_cmap is the stable pyplot entry point. The previously used
    # plt.cm.get_cmap (matplotlib.cm.get_cmap) was deprecated in Matplotlib 3.7
    # and removed in 3.9.
    cmap = plt.get_cmap(colormap)
    return [cmap(value) for value in np.linspace(0, 1, n)]

**Group learning curve of entrance times (ET)**

In [None]:
def plot_dotted_learning_curve(ax, root, animalList, profile, TaskParamToPlot,
                               stop_dayPlot, colors, legend=True, seed=3):
    """Plot a group learning curve with per-animal jittered dots.

    Draws the per-session median (white-filled markers joined by a black line)
    with interquartile-range error bars. Each individual value is overlaid as a
    small dot, horizontally jittered around its session number. A dashed
    magenta line marks the goal time.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    root : str
        Data root directory, passed through to the data loaders.
    animalList : list of str
        Animal names (e.g. 'Rat103'). The first one is used to read the goal
        time. Legend labels drop the first 3 characters.
    profile : dict
        Session-selection profile passed to the data loaders.
    TaskParamToPlot : str
        Name of the per-session statistic to plot; also used as the y-label.
    stop_dayPlot : int
        Number of sessions to plot (x runs from 1 to ``stop_dayPlot``).
    colors : sequence
        One color per animal, for the dots and the legend entries.
    legend : bool
        If True, add a per-animal legend to the right of the axes.
    seed : int
        Seed for the horizontal jitter, so the figure is reproducible.

    Returns
    -------
    The same ``ax``.

    Notes
    -----
    Relies on the module-level ``param`` dict and on ``get_rat_group_statistic``
    / ``data_fetch`` being defined by the setup cell.
    """
    # Results[TaskParamToPlot] maps each entry (presumably one per animal) to
    # its per-session values.
    Results, _ = get_rat_group_statistic(root,
                                         animalList,
                                         profile,
                                         parameter=param,
                                         redo=False,
                                         stop_dayPlot=stop_dayPlot,
                                         TaskParamToPlot=[TaskParamToPlot])

    # The goal time is read from the first animal only (assumed to be the same
    # for the whole group).
    goalTime = data_fetch(root, animal=animalList[0], profile=profile,
                          PerfParam=[lambda data: data.goalTime[-1]],
                          NbSession=0).values()
    goalTime = list(goalTime)[-1]
    # NOTE(review): depending on data_fetch, goalTime may be a scalar or a
    # per-session sequence. Indexing a scalar with [-1] would raise, so
    # normalise to 1-D before drawing the line and its label.
    goal_levels = np.atleast_1d(goalTime)

    x = np.arange(stop_dayPlot) + 1                  # 1-based session numbers
    data = np.array(list(Results[TaskParamToPlot].values()))
    y = np.nanpercentile(data, 50, axis=0)           # per-session median
    yerr = np.nanpercentile(data, (25, 75), axis=0)  # per-session IQR bounds

    # Error bars are given as distances from the median (errorbar's convention).
    ax.errorbar(x, y, yerr=abs(yerr - y), ecolor='k', fmt='k-o', elinewidth=1,
                markersize=4, markerfacecolor='w', zorder=2)

    # Use a private RNG seeded like the previous np.random.seed(seed) call.
    # The jitter is identical, but NumPy's global random state is no longer
    # reset as a side effect of drawing a plot.
    rng = np.random.RandomState(seed)
    sigma = .3  # half-width of the uniform horizontal jitter, in sessions
    for pts, day in zip(data.T, x):
        jitter = rng.uniform(low=day - sigma, high=day + sigma, size=len(pts))
        ax.scatter(jitter, pts, s=2, c=colors, marker='o', zorder=1)

    ax.set_xlim([x[0] - 1, x[-1] + 1])
    # Tick session 1 and every 5th session.
    ax.set_xticks([1] + list(range(5, stop_dayPlot + 1, 5)))
    ax.spines['bottom'].set_bounds(x[0], x[-1])
    ax.set_ylim([0, 10])
    # NOTE(review): 7 is hard-coded, presumably the usual goal time.
    ax.set_yticks([0, 7, 10])
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.hlines(y=goal_levels, xmin=x[0], xmax=x[-1], linestyle='--', lw=1, color='m')
    ax.text(x=x[0], y=goal_levels[-1], s='Goal time', verticalalignment='center',
            color='m', backgroundcolor='w', fontsize=8)
    ax.set_xlabel('Session#')
    ax.set_ylabel(TaskParamToPlot)

    if legend:
        # Proxy artists: a transparent line with a colored dot marker per
        # animal, labelled with the name minus its first 3 characters.
        handles = [matplotlib.lines.Line2D([], [], color=[0, 0, 0, 0],
                                           marker='o', markerfacecolor=colors[i],
                                           markeredgecolor='None',
                                           markersize=2, label=animal[3:])
                   for i, animal in enumerate(animalList)]
        ax.legend(handles=handles, title="Rat#", title_fontsize=6, handletextpad=.6,
                  bbox_to_anchor=(.99, .5), loc=6, ncol=1, fontsize=4)

    return ax

def add_markers_to_example_sessions(root, ax, profile1, stop_dayPlot1, markers1, days):
    """Annotate the example sessions of Rat123 and Rat132 on a learning-curve axis.

    For each of the two example rats, the per-session "median entrance time
    (sec)" is fetched for the first ``stop_dayPlot1`` sessions. A glyph from
    ``markers1`` is then drawn in black, centered on the chosen session's
    data point.

    Parameters
    ----------
    root : str
        Data root directory passed to ``data_fetch``.
    ax : matplotlib Axes
        Axes to annotate.
    profile1 : dict
        Session-selection profile passed to ``data_fetch``.
    stop_dayPlot1 : int
        Number of sessions to fetch per rat.
    markers1 : dict
        Marker text keyed by 'naive', 'trained' (Rat123) and 'stupidNaive',
        'stupidTrained' (Rat132).
    days : 4-sequence
        1-based session numbers ``(day123_0, day123_1, day132_0, day132_1)``.
    """
    day123_0, day123_1, day132_0, day132_1 = days
    metric = "median entrance time (sec)"

    # (rat, ((session, marker key, font size), ...)); 'stupidNaive' is drawn
    # slightly larger than the other markers.
    annotation_plan = (
        ('Rat123', ((day123_0, 'naive', 7), (day123_1, 'trained', 7))),
        ('Rat132', ((day132_0, 'stupidNaive', 9), (day132_1, 'stupidTrained', 7))),
    )

    for rat_name, sessions in annotation_plan:
        per_session = data_fetch(root, animal=rat_name, profile=profile1,
                                 PerfParam=[metric],
                                 NbSession=stop_dayPlot1)[metric]
        for session, marker_key, size in sessions:
            # Sessions are 1-based; the fetched values are indexed from 0.
            ax.text(x=session, y=per_session[session - 1], s=markers1[marker_key],
                    va='center', ha='center', color='k', fontsize=size)


In [None]:
# Example cell: plot the "percentile entrance time" learning curve over the
# first 30 sessions for a hard-coded list of control-group rats.
# Runs only when `__file__` is undefined (presumably an interactive notebook
# session).
if "__file__" not in dir():

    # Session-selection profile passed to the data loaders.
    profile={'Type':'Good',
             'rewardType':'Progressive',
             'initialSpeed':'10',
             'Speed':'10',
             'Tag':['Control']
             }
    # NOTE(review): this result is immediately overwritten by the hard-coded
    # list below; the call only matters for any side effects it may have —
    # confirm whether it can be dropped.
    animalList=batch_get_animal_list(root,profile)
    animalList=['Rat103','Rat104','Rat110','Rat113','Rat120','Rat137','Rat138','Rat139','Rat140','Rat149',
                'Rat150','Rat151','Rat152','Rat161','Rat162','Rat163','Rat164','Rat165','Rat166','Rat215',
                'Rat216','Rat217','Rat218','Rat219','Rat220','Rat221','Rat222','Rat223','Rat224','Rat225',
                'Rat226','Rat227','Rat228','Rat229','Rat230','Rat231','Rat232','Rat246','Rat247','Rat248',
                'Rat249','Rat250','Rat251','Rat252','Rat253','Rat254','Rat255','Rat256','Rat257','Rat258',
                'Rat259','Rat260','Rat261','Rat262','Rat263','Rat264','Rat265','Rat297','Rat298','Rat299',
                'Rat300','Rat305','Rat306','Rat307','Rat308']

    TaskParamToPlot="percentile entrance time"
    stop_dayPlot =30
    # Sample one extra color and drop the last one, so the colormap's end value
    # (1.0) is never used — presumably because the pale end of 'plasma' is hard
    # to see on a white background.
    colors=get_ordered_colors(colormap='plasma', n=len(animalList)+1)[:-1]
    
    plt.close('all')
    fig=plt.figure(figsize=(10,5))
    ax=fig.add_subplot(111);
    
    plot_dotted_learning_curve(ax, root, animalList, profile, TaskParamToPlot, stop_dayPlot,colors)
    # Override the default y-label (which is TaskParamToPlot).
    ax.set_ylabel('Entrance Times (s)')