# <span style="color:blue">This notebook runs multiple animals/sessions of RAT Behavioral experiments and compares groups of animals </span> 


### This notebook is at the core of the pipeline of data processing. Do not play with it lightly inside the master folder (load_preprocess_rat)

#### 1. Only modify it if you are sure of what you are doing and that you are solving a bug
#### 2. If you do modify you MUST commit this modification using bitbucket
#### 3. If you want to play with this notebook (to understand it better), copy it to a toy folder distinct from the master folder
#### 4. If you want to modify this code (fix a bug, improve, add attributes ...), it is recommended to first duplicate it in a draft folder. Try to keep track of your changes.
#### 5. When you are ready to commit: clear all output and clean everything between hashtags



## 0. Load packages and define functions

In [None]:
#modules to find path of all sessions
import glob
import os, logging
import numpy as np
from IPython.display import clear_output, display, HTML
import matplotlib.cm as cm
import warnings
from platform import system as OS
import time
import pickle
import matplotlib.pyplot as plt
warnings.filterwarnings("ignore")
#run other notebooks
if "__file__" not in dir():
    
    ThisNoteBookPath=os.path.dirname(os.path.realpath("__file__"))
    CommunNoteBookesPath=os.path.join(os.path.split(ThisNoteBookPath)[0],"load_preprocess_rat")
    CWD=os.getcwd()
    os.chdir(CommunNoteBookesPath)
    %run Animal_Tags.ipynb
    %run UtilityTools.ipynb
    %run BatchRatBehavior.ipynb
    %run plotRat_documentation_3_KinematicsInvestigation.ipynb
    %run plotRat_documentation_1_GeneralBehavior.ipynb
    %run loadRat_documentation.ipynb
    os.chdir(CWD)
    
    if platform.system()=='Linux':
        root="/data"
    elif platform.system()=='Windows':
        root="C:\\DATA\\"
    else:
        root="/Users/davidrobbe/Documents/Data/"

    # PARAMETERS (will be used if 1. no pickles and 2. no param files (old data))
    param={
        "goalTime":7,#needed for pavel data only
        "treadmillRange":[0,90],#pavel error conversion "treadmillRange":[0,80]
        "maxTrialDuration":20,
        "interTrialDuration":10,#None pavel
        "endTrial_frontPos":30,
        "endTrial_backPos":55, 
        "endTrial_minTimeSec":4,
        "cameraSamplingRate":25, #needed for new setup    

        "sigmaSmoothPosition":0.1,#0.33, 0.18 pavel
        "sigmaSmoothSpeed":0.3,#0.3, 0.5 pavel
         "nbJumpMax":100,#200 pavel
        "binSize":0.25
    }
    print('os:',OS(),'\nroot:',root,'\nImport successful!')

## 1. Compare two profiles

In [None]:
def get_rat_group_statistic(root,animalList,profile,parameter=None,redo=False,stop_dayPlot=0,TaskParamToPlot=None):
    """Gather the per-session statistics of every animal of a group.

    For each animal the statistics are read from
    ``<root>/<animal>/Analysis/learningStats.p`` when that pickle exists, is of
    the current kind (it contains a non-empty 'sessionProfile') and holds every
    requested key; otherwise they are recomputed with ``plot_learningCurves``.
    When the pickle file exists (and ``redo`` is False), sessions whose
    'sessionProfile' entries do not match ``profile`` are dropped.

    Parameters
    ----------
    root : str
        Folder containing one sub-folder per animal.
    animalList : list of str
        Animal names (sub-folder names).
    profile : dict
        Selection criteria. A string value is split on whitespace into the
        list of accepted values. The caller's dict is not modified.
    parameter : dict, optional
        Forwarded to ``plot_learningCurves`` (defaults to an empty dict).
    redo : bool
        If True the pickle is ignored and the statistics are recomputed.
    stop_dayPlot : int
        If non-zero, keep only the first ``stop_dayPlot`` selected sessions.
    TaskParamToPlot : list of str, optional
        Names of the statistics to collect (defaults to an empty list).

    Returns
    -------
    (allResults, nSessionMax)
        ``allResults[key][animal]`` is a list of values, padded with NaN up to
        ``nSessionMax`` (the longest list). If ``plot_learningCurves`` reports
        a failure by returning a bool, ``({}, None)`` is returned.
    """
    parameter={} if parameter is None else parameter
    TaskParamToPlot=[] if TaskParamToPlot is None else TaskParamToPlot

    # shallow copy: string entries are replaced by lists below and the
    # caller's dictionary must stay untouched
    profile=profile.copy()
    allResults={key:{animal:[] for animal in animalList} for key in TaskParamToPlot}
    nSessionMax=0

    for animal in animalList:
        badID=set()   # indices of the sessions that do not match the profile
        results={}
        sessionProfile={}

        pathPickle=os.path.join(root,animal,"Analysis","learningStats.p")
        usePickle=os.path.exists(pathPickle) and (not redo)
        if usePickle:
            # NOTE(review): pickle.load is only safe on files produced by this pipeline
            try:
                with open(pathPickle,"rb") as f:
                    results=pickle.load(f)

                sessionProfile=results.pop('sessionProfile',{})
                if sessionProfile=={}:
                    #this means that the pickle file is the old kind!
                    raise NameError(animal+' :old pickle, computing again...')
                if len(set(TaskParamToPlot)-results.keys())!=0:
                    #this means not all the TaskParamToPlot keys are available in the pickle
                    raise NameError(animal+' :pickle not complete, computing again...')
            except Exception as e:
                print(repr(e))
                results={}

        if (not usePickle) or results=={}:
            results=plot_learningCurves(root,animal,PerfParamToPlot=TaskParamToPlot,profile=profile,parameter=parameter,redoStat=redo,plot=False)
            plt.close()
            # A bool signals a failure. This must be tested before any dict
            # method is called on `results`, otherwise the failure surfaces as
            # an AttributeError instead of the intended early return.
            if isinstance(results,bool):
                print('False')
                return {},None
            sessionProfile=results.pop('sessionProfile',{})

        if usePickle:
            for key in profile:
                if isinstance(profile[key],str):
                    profile[key]=profile[key].split()
                for ID,_ in enumerate(sessionProfile['Speed']):
                    try:
                        if str(sessionProfile[key][ID]) not in profile[key]:
                            badID.add(ID)
                    except (KeyError,IndexError,TypeError):
                        # criterion not recorded for this session: do not filter on it
                        pass

        results.pop('days',{})

        for key in TaskParamToPlot:
            values=results[key]
            if key=="percentile entrance time":
                # keep only the third entry of each session's array
                values=[array[2] for array in values]

            wantedResults=[value for ID,value in enumerate(values) if ID not in badID]

            if stop_dayPlot!=0:
                del wantedResults[stop_dayPlot:]

            allResults[key][animal]=wantedResults
            nSessionMax=max(nSessionMax,len(wantedResults))

    #pad with nan when fewer session than maximum
    for key in allResults:
        for animal in allResults[key]:
            nSession=len(allResults[key][animal])
            allResults[key][animal].extend([np.nan]*(nSessionMax-nSession))

    return allResults,nSessionMax

In [None]:
if "__file__" not in dir():
    # Selection criteria of the two groups
    profile1=dict(Type='Good',rewardType='Progressive',initialSpeed='10',Speed='10',Tag='Control')
    profile2=dict(Type='Good',rewardType='Progressive',initialSpeed='0',Speed='0',Tag='ImmobileTreadmill')

    animalList1=batch_get_animal_list(root,profile1)
    animalList2=batch_get_animal_list(root,profile2)

    # group name -> (colormap, colour, animals, profile)
    # colormaps: http://matplotlib.org/examples/color/colormaps_reference.html
    groups=dict(
        group1=(cm.Greys,"black",animalList1,profile1),
        group2=(cm.Reds ,"red"  ,animalList2,profile2),
    )

    TaskParamToPlot=["Mean Pairwise RMSE"]

    # statistics of the first group only
    res,N=get_rat_group_statistic(
        root,animalList1,profile1,
        parameter=param,redo=False,stop_dayPlot=0,
        TaskParamToPlot=TaskParamToPlot,
    )

In [None]:
def plot_mean_subgroup_animal(root,groups,axes=None,parameter=None,redo=False,stop_dayPlot=0,TaskParamToPlot=None,fullLegend=False):
    """Plot session-by-session statistics for several groups of animals.

    For every group and every requested statistic, the median across animals
    is drawn as a thick line. When ``fullLegend`` is False a band of +/- one
    standard deviation is added. Each animal is drawn as a dotted line, and
    labelled only when ``fullLegend`` is True. An additional
    'Number of animals' entry counts, for each session, the animals whose
    first requested statistic is not NaN.

    Parameters
    ----------
    root : forwarded to ``get_rat_group_statistic``.
    groups : dict
        Maps a group name to ``(colormap, colour, animal list, profile)``.
    axes : sequence of matplotlib axes, optional
        Must hold one axis per statistic, optionally followed by one for
        'Number of animals'. A new figure is created when None.
    parameter, redo, stop_dayPlot, TaskParamToPlot :
        Forwarded to ``get_rat_group_statistic``.
    fullLegend : bool
        See above; also switches the legend to three columns.

    Returns
    -------
    (allResults, colors, colorGroup)
        ``allResults[key][group][animal]`` holds the plotted values,
        ``colors[group]`` the per-animal colours and ``colorGroup[group]``
        the colour of the group curve.
    """
    parameter={} if parameter is None else parameter
    TaskParamToPlot=[] if TaskParamToPlot is None else list(TaskParamToPlot)

    #divide input dictionary
    colorMapGroup={name:spec[0] for name,spec in groups.items()}
    colorGroup={name:spec[1] for name,spec in groups.items()}
    animalGroup={name:spec[2] for name,spec in groups.items()}
    experimentalGroup={name:spec[3] for name,spec in groups.items()}

    allResults={key:{} for key in TaskParamToPlot}
    nSessionMax={}

    # get the data for all the groups
    for group in animalGroup:
        results,nSessionMax[group]=get_rat_group_statistic(root,animalList=animalGroup[group],profile=experimentalGroup[group],parameter=parameter,
                                                           redo=redo,stop_dayPlot=stop_dayPlot,TaskParamToPlot=TaskParamToPlot)
        for key in TaskParamToPlot:
            allResults[key][group]=results[key]

    #Compute the number of animals per day
    allResults['Number of animals']={}
    for group in animalGroup:
        counts=[]
        for n in range(nSessionMax[group]):
            perAnimal=allResults[TaskParamToPlot[0]][group].values()
            counts.append(sum(not np.isnan(sessions[n]) for sessions in perAnimal))
        # stored under a placeholder animal name so it is plotted like any other key
        allResults['Number of animals'][group]={'RatXXX':counts}

    def group_label(profile):
        # 'Tag' may be a list of tags or a whitespace-separated string; indexing
        # a string directly would return its first character only.
        tag=profile['Tag']
        if isinstance(tag,str):
            tag=tag.split()
        return tag[0]

    #PLOTTING
    def plot_one_key(xaxis,keyResult,label,colorList,color,ax):
        keyRes=np.asarray(list(keyResult.values()))
        medianKeyRes=np.nanmedian(keyRes,axis=0)
        stdKeyRes=np.nanstd(keyRes,axis=0)
        # group curve
        ax.plot(xaxis,medianKeyRes,color=color,lw=3,marker='.',markersize=20,label=label)
        if not fullLegend:
            ax.fill_between(xaxis,(medianKeyRes - stdKeyRes),(medianKeyRes + stdKeyRes),color=color,alpha=0.25)
        # one dotted curve per animal
        for animalColor,animal in zip(colorList,keyResult):
            ax.plot(xaxis,keyResult[animal],ls=':',color=animalColor,label=animal if fullLegend else None)

    ncols=3 if fullLegend else 1

    # single column of subplots: one row per entry of allResults, plus one spare row
    nbCol=1
    nbLine=len(allResults)+1

    TaskParams=[]
    if axes is None:
        TaskParams=TaskParamToPlot+['Number of animals']
        fig=plt.figure(figsize=(7, nbLine*4))
        axes=[fig.add_subplot(nbLine,nbCol,index+1) for index,_ in enumerate(TaskParams)]
    elif len(axes)==len(TaskParamToPlot):
        TaskParams=TaskParamToPlot
    elif len(axes)==len(TaskParamToPlot)+1:
        TaskParams=TaskParamToPlot+['Number of animals']

    assert len(axes)==len(TaskParams), "Bad number of axes/paramsToPlot"

    # per-animal colours: skip the first three entries of each colormap sampling
    colors={}
    for group in animalGroup:
        colors[group]=colorMapGroup[group](np.linspace(0, 1, len(animalGroup[group])+3))[3:]

    # fixed y-ranges and dashed reference lines for the known statistics
    # (no entry for "Tortuosity": its range is left automatic)
    yLimits={
        "median entrance time (sec)":[0,15],
        "mean entrance time (sec)":[0,10],
        "% good trials":[0,100],
        "standard deviation of entrance time":[0,11],
        "percentile entrance time":[0,15],
        "% good trials on last 40":[0,100],
    }
    referenceLines={
        "median entrance time (sec)":7,
        "mean entrance time (sec)":7,
        "% good trials on last 40":72.5,
    }

    for index,key in enumerate(TaskParams):
        if key=="treadmillSpeed":continue
        ax=axes[index]
        for group in allResults[key]:
            xaxis=np.arange(1,nSessionMax[group]+1,1)
            plot_one_key(xaxis,allResults[key][group],group_label(experimentalGroup[group]),
                         colors[group],colorGroup[group],ax)

        ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.,ncol=ncols)
        ax.set_xlabel("Session",fontsize=14)
        ax.set_ylabel(key,fontsize=14)
        ax.set_xlim([0,max(list(nSessionMax.values()))])
        ax.tick_params(axis='both', which='major', labelsize=12)
        if key in yLimits:
            ax.set_ylim(yLimits[key])
        if key in referenceLines:
            ax.axhline(referenceLines[key],color="b",ls="--")
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.xaxis.set_ticks_position('bottom')
        ax.yaxis.set_ticks_position('left')

    plt.tight_layout()
    plt.subplots_adjust(top=0.90)

    return allResults,colors,colorGroup

In [None]:
if "__file__" not in dir():
    # Selection criteria of the two groups
    profile1=dict(Type='Good',rewardType='Progressive',initialSpeed='10',Speed='10',Tag=['Control'])
    # (the initialSpeed and Speed criteria are disabled for the second group)
    profile2=dict(Type='Good',rewardType='Progressive',Tag=['ImmobileTreadmill'])

    animalList1=batch_get_animal_list(root,profile1)
    animalList2=batch_get_animal_list(root,profile2)
    #different control groups

    # group name -> (colormap, colour, animals, profile)
    groups=dict(
        group1=(cm.Greys,"black",animalList1,profile1),
        group2=(cm.Reds ,"red"  ,animalList2,profile2),
    )

    print("animal lists:\n",animalList1,'\n',animalList2)

In [None]:
if "__file__" not in dir():
    # The first list is immediately overridden by the assignment that follows it
    TaskParamToPlot=[
        "% good trials on last 40",
        "percentile entrance time",
        "Forward Speed Vs TreadmillSpeed",
        "Tortuosity",
        "standard deviation of entrance time",
        "Trajectory Correlation",
    ]
    TaskParamToPlot=["percentile entrance time","% good trials on last 40"]

    stop_dayPlot,fullLegend=15,False

    allResults,colors,colorGroup=plot_mean_subgroup_animal(
        root,groups,
        parameter=param,redo=False,stop_dayPlot=stop_dayPlot,
        TaskParamToPlot=TaskParamToPlot,fullLegend=fullLegend,
    )

In [None]:
if "__file__" not in dir():
    # Selection criteria of the two groups (they differ by their Tag only)
    profile1=dict(Type='Good',option=['not used'],rewardType='Progressive',
                  initialSpeed='10',Speed='10',Tag=['Control'])
    profile2=dict(Type='Good',option=['not used'],rewardType='Progressive',
                  initialSpeed='10',Speed='10',Tag=['Early-Lesion_DMS'])

    animalList1=batch_get_animal_list(root,profile1)
    animalList2=batch_get_animal_list(root,profile2)
    # Manual alternatives, currently disabled:
    #   animalList1=[]
    #   ok DLS Histo:
    #   animalList2=['Rat097', 'Rat099', 'Rat100', 'Rat114', 'Rat116']

    # group name -> (colormap, colour, animals, profile)
    groups=dict(
        group1=(cm.Greys,"black",animalList1,profile1),
        group2=(cm.Reds ,"red"  ,animalList2,profile2),
    )

    print("animal lists:\n",animalList1,'\n',animalList2)

In [None]:
if "__file__" not in dir():
    # The first list is immediately overridden by the assignment that follows it
    TaskParamToPlot=[
        "% good trials on last 40",
        "percentile entrance time",
        "Forward Speed Vs TreadmillSpeed",
        "Tortuosity",
        "standard deviation of entrance time",
        "Trajectory Correlation",
    ]
    TaskParamToPlot=["Successive Entrance Score"]

    stop_dayPlot,fullLegend=30,False

    allResults,colors,colorGroup=plot_mean_subgroup_animal(
        root,groups,
        parameter=param,redo=False,stop_dayPlot=stop_dayPlot,
        TaskParamToPlot=TaskParamToPlot,fullLegend=fullLegend,
    )

In [None]:
if "__file__" not in dir():
    # Selection criteria of the two groups
    profile1=dict(Type='Good',rewardType='Progressive',initialSpeed='10',Speed='10',Tag=['Control'])
    # (the initialSpeed and Speed criteria are disabled for the second group)
    profile2=dict(Type='Good',rewardType='Progressive',Tag=['ImmobileTreadmill'])

    animalList1=batch_get_animal_list(root,profile1)
    animalList2=batch_get_animal_list(root,profile2)

    # group name -> (colormap, colour, animals, profile)
    groups=dict(
        group1=(cm.Greys,"black",animalList1,profile1),
        group2=(cm.Reds ,"red"  ,animalList2,profile2),
    )

    print("animal lists:\n",animalList1,'\n',animalList2)

In [None]:
if "__file__" not in dir():
    TaskParamToPlot=[
        "% good trials on last 40",
        "% good trials",
        "percentile entrance time",
        "Forward Speed Vs TreadmillSpeed",
        "Tortuosity",
        "standard deviation of entrance time",
        "Trajectory Correlation",
    ]
    # disabled alternative:
    #   TaskParamToPlot=["Entropy","Mean Pairwise RMSE"]

    stop_dayPlot,fullLegend=30,False

    allResults,colors,colorGroup=plot_mean_subgroup_animal(
        root,groups,
        parameter=param,redo=False,stop_dayPlot=stop_dayPlot,
        TaskParamToPlot=TaskParamToPlot,fullLegend=fullLegend,
    )

In [None]:
if "__file__" not in dir():
    # Selection criteria of the two groups (they differ by their Tag only)
    profile1=dict(Type='Good',rewardType='Progressive',initialSpeed='10',Speed='10',Tag=['Control'])
    profile2=dict(Type='Good',rewardType='Progressive',initialSpeed='10',Speed='10',Tag=['Early-Lesion_DMS'])

    animalList1=batch_get_animal_list(root,profile1)
    animalList2=batch_get_animal_list(root,profile2)
    # disabled alternative: animalList1=[ 'Rat085', 'Rat095']
    # ok DMS Histo: the list found above is replaced by this hand-picked one
    animalList2=['Rat118','Rat119','Rat133','Rat134']

    # group name -> (colormap, colour, animals, profile)
    groups=dict(
        group1=(cm.Greys,"black",animalList1,profile1),
        group2=(cm.Reds ,"red"  ,animalList2,profile2),
    )

    print("animal lists:\n",animalList1,'\n',animalList2)

In [None]:
if "__file__" not in dir():
    # The first list is immediately overridden by the assignment that follows it
    TaskParamToPlot=[
        "% good trials on last 40",
        "percentile entrance time",
        "Forward Speed Vs TreadmillSpeed",
        "Tortuosity",
        "standard deviation of entrance time",
        "Trajectory Correlation",
    ]
    TaskParamToPlot=["Entropy","Mean Pairwise RMSE"]

    stop_dayPlot,fullLegend=30,False

    allResults,colors,colorGroup=plot_mean_subgroup_animal(
        root,groups,
        parameter=param,redo=False,stop_dayPlot=stop_dayPlot,
        TaskParamToPlot=TaskParamToPlot,fullLegend=fullLegend,
    )

In [None]:
if "__file__" not in dir():
    # Selection criteria of the two groups (they differ by their Tag only)
    profile1=dict(Type='Good',rewardType='Progressive',initialSpeed='10',Speed='10',Tag=['Control'])
    profile2=dict(Type='Good',rewardType='Progressive',initialSpeed='10',Speed='10',
                  Tag=['Early-Lesion_DLS','Early-Lesion_DMS'])

    animalList1=batch_get_animal_list(root,profile1)
    animalList2=batch_get_animal_list(root,profile2)
    # disabled alternative: animalList1=[ 'Rat085', 'Rat095']
    # ok full DS Histo: the list found above is replaced by this hand-picked one
    animalList2=['Rat117','Rat082','Rat115']

    # group name -> (colormap, colour, animals, profile)
    groups=dict(
        group1=(cm.Greys,"black",animalList1,profile1),
        group2=(cm.Reds ,"red"  ,animalList2,profile2),
    )

    print("animal lists:\n",animalList1,'\n',animalList2)

In [None]:
if "__file__" not in dir():
    # The first list is immediately overridden by the assignment that follows it
    TaskParamToPlot=[
        "% good trials on last 40",
        "percentile entrance time",
        "Forward Speed Vs TreadmillSpeed",
        "Tortuosity",
        "standard deviation of entrance time",
        "Trajectory Correlation",
    ]
    TaskParamToPlot=["Entropy","Mean Pairwise RMSE"]

    stop_dayPlot,fullLegend=30,False

    allResults,colors,colorGroup=plot_mean_subgroup_animal(
        root,groups,
        parameter=param,redo=False,stop_dayPlot=stop_dayPlot,
        TaskParamToPlot=TaskParamToPlot,fullLegend=fullLegend,
    )

In [None]:
if "__file__" not in dir():

    profile1={'Type':'Good',
             'rewardType':'Progressive',
             'initialSpeed':'10',
             'Speed':'30',
             'Tag':['Control']
             }
    profile2={'Type':'Good',
             'rewardType':'Progressive',
             #'initialSpeed':'10',
             'Speed':'30',
             'Tag':['Early-Lesion_DLS','Early-Lesion_DLS-Early-var-FastSpeed']
             }
    animalList1=batch_get_animal_list(root,profile1)
    animalList2=batch_get_animal_list(root,profile2)
    #animalList1=[ 'Rat085', 'Rat095']
    ## ok late DLS Histo
    #animalList2= ['Rat085', 'Rat106','Rat113', 'Rat141']

    # Exclude Rat148 from the second group. Filtering, unlike list.remove,
    # does not raise ValueError when the animal is not in the list.
    animalList2=[animal for animal in animalList2 if animal!='Rat148']
    groups={
        "group1":(cm.Greys,"black",animalList1,profile1),
        "group2":(cm.Reds ,"red"  ,animalList2,profile2),
        }

    print("animal lists:\n",animalList1,'\n',animalList2)

In [None]:
if "__file__" not in dir():
    # The first list is immediately overridden by the assignment that follows it
    TaskParamToPlot=[
        "% good trials on last 40",
        "percentile entrance time",
        "Forward Speed Vs TreadmillSpeed",
        "Tortuosity",
        "standard deviation of entrance time",
        "Trajectory Correlation",
    ]
    TaskParamToPlot=["Mean Pairwise RMSE","Trajectory Correlation"]

    stop_dayPlot,fullLegend=7,False

    allResults,colors,colorGroup=plot_mean_subgroup_animal(
        root,groups,
        parameter=param,redo=False,stop_dayPlot=stop_dayPlot,
        TaskParamToPlot=TaskParamToPlot,fullLegend=fullLegend,
    )

In [None]:
if "__file__" not in dir():
    # Selection criteria of the two groups
    profile1=dict(Type='Good',rewardType='Progressive',initialSpeed='10',Speed='10',Tag=['Control-AfterBreak'])
    # (the initialSpeed criterion is disabled for the second group)
    profile2=dict(Type='Good',rewardType='Progressive',Speed='10',Tag=['Late-Lesion_DMS'])

    animalList1=batch_get_animal_list(root,profile1)
    animalList2=batch_get_animal_list(root,profile2)
    # disabled alternative: animalList1=[ 'Rat085', 'Rat095']
    # ok late DLS Histo: the list found above is replaced by this hand-picked one
    animalList2=['Rat120','Rat131','Rat132']

    # group name -> (colormap, colour, animals, profile)
    groups=dict(
        group1=(cm.Greys,"black",animalList1,profile1),
        group2=(cm.Reds ,"red"  ,animalList2,profile2),
    )

    print("animal lists:\n",animalList1,'\n',animalList2)

In [None]:
if "__file__" not in dir():
    # The first list is immediately overridden by the assignment that follows it
    TaskParamToPlot=[
        "% good trials on last 40",
        "percentile entrance time",
        "Forward Speed Vs TreadmillSpeed",
        "Tortuosity",
        "standard deviation of entrance time",
        "Trajectory Correlation",
    ]
    TaskParamToPlot=[
        "Entropy",
        "Mean Pairwise RMSE",
        "Run Distance",
        "Forward Speed Vs TreadmillSpeed",
    ]

    stop_dayPlot,fullLegend=8,False

    allResults,colors,colorGroup=plot_mean_subgroup_animal(
        root,groups,
        parameter=param,redo=False,stop_dayPlot=stop_dayPlot,
        TaskParamToPlot=TaskParamToPlot,fullLegend=fullLegend,
    )

In [None]:
if "__file__" not in dir():
    # Selection criteria of the two groups
    # (the initialSpeed and Speed criteria are disabled for both groups)
    profile1=dict(Type='Good',rewardType='Progressive',Tag=['Control-Early-var-AfterBreak'])
    profile2=dict(Type='Good',rewardType='Progressive',Tag=['Early-var-Late-Lesion_DLS'])

    animalList1=batch_get_animal_list(root,profile1)
    animalList2=batch_get_animal_list(root,profile2)
    # Manual alternatives, currently disabled:
    #   animalList1=[ 'Rat085', 'Rat095']
    #   ok late DLS Histo:
    #   animalList2= ['Rat120', 'Rat131', 'Rat132']

    # group name -> (colormap, colour, animals, profile)
    groups=dict(
        group1=(cm.Greys,"black",animalList1,profile1),
        group2=(cm.Reds ,"red"  ,animalList2,profile2),
    )

    print("animal lists:\n",animalList1,'\n',animalList2)

In [None]:
if "__file__" not in dir():
    TaskParamToPlot=[
        "% good trials on last 40",
        "percentile entrance time",
        "Forward Speed Vs TreadmillSpeed",
        "Tortuosity",
        "standard deviation of entrance time",
        "Trajectory Correlation",
    ]
    # disabled alternative:
    #   TaskParamToPlot=["% good trials","percentile entrance time","standard deviation of entrance time","Trajectory Correlation"]

    stop_dayPlot,fullLegend=30,False

    allResults,colors,colorGroup=plot_mean_subgroup_animal(
        root,groups,
        parameter=param,redo=False,stop_dayPlot=stop_dayPlot,
        TaskParamToPlot=TaskParamToPlot,fullLegend=fullLegend,
    )

In [None]:
if "__file__" not in dir():
    # Selection criteria of the two groups
    # (the initialSpeed and Speed criteria are disabled for both groups)
    profile1=dict(Type='Good',rewardType='Progressive',Tag=['Control'])
    profile2=dict(Type='Good',rewardType='Progressive',Tag=['Early-Lesion_GPi'])

    animalList1=batch_get_animal_list(root,profile1)
    animalList2=batch_get_animal_list(root,profile2)
    # Manual alternatives, currently disabled:
    #   animalList1=[ 'Rat085', 'Rat095']
    #   ok late DLS Histo:
    #   animalList2= ['Rat120', 'Rat131', 'Rat132']

    # group name -> (colormap, colour, animals, profile)
    groups=dict(
        group1=(cm.Greys,"black",animalList1,profile1),
        group2=(cm.Reds ,"red"  ,animalList2,profile2),
    )

    print("animal lists:\n",animalList1,'\n',animalList2)

In [None]:
if "__file__" not in dir():
    TaskParamToPlot=[
        "% good trials on last 40",
        "percentile entrance time",
        "Forward Speed Vs TreadmillSpeed",
        "Tortuosity",
        "standard deviation of entrance time",
        "Trajectory Correlation",
    ]
    # disabled alternative:
    #   TaskParamToPlot=["% good trials","percentile entrance time","standard deviation of entrance time","Trajectory Correlation"]

    stop_dayPlot,fullLegend=30,False

    allResults,colors,colorGroup=plot_mean_subgroup_animal(
        root,groups,
        parameter=param,redo=False,stop_dayPlot=stop_dayPlot,
        TaskParamToPlot=TaskParamToPlot,fullLegend=fullLegend,
    )

## 2. Compare two profiles of different size

In [None]:
def plot_mean_animals(root,groups,axes=None,n_iteration=1e3,parameter=None,redo=False,stop_dayPlot=0,TaskParamToPlot=None,fullLegend=False,**kargs):
    """Compare two groups of different size, session by session.

    'group1' is evaluated through ``sample_size_control`` (defined elsewhere),
    which is given ``get_rat_group_statistic``, the animals of 'group1', the
    number of animals of 'group2' and ``n_iteration``.
    NOTE(review): presumably it evaluates ``n_iteration`` subsets of
    'group1' of that size - confirm against its definition.
    Each of its results is averaged over animals; results covering fewer
    sessions than the longest one are discarded. 'group2' is evaluated
    directly. For both groups the median and, unless ``fullLegend`` is True,
    a band of +/- one standard deviation are plotted.

    Parameters
    ----------
    root : forwarded to ``get_rat_group_statistic``.
    groups : dict
        Must contain 'group1' and 'group2', each mapped to
        ``(colormap, colour, animal list, profile)``.
    axes : sequence of matplotlib axes, optional
        One axis per entry of ``TaskParamToPlot``; a new figure is created
        when None.
    n_iteration : forwarded to ``sample_size_control`` as ``n``.
    parameter, redo, stop_dayPlot, TaskParamToPlot :
        Forwarded to ``get_rat_group_statistic``.
    fullLegend : bool
        Hides the standard-deviation band and uses a three-column legend.
    **kargs :
        Axes properties applied to every axis with ``ax.set``.

    Returns
    -------
    (allResults, colors, colorGroup, ALLobj)
        ``ALLobj`` is the object returned by ``sample_size_control``; its
        ``iterN`` attribute is set to the number of results that were kept.
    """
    parameter={} if parameter is None else parameter
    TaskParamToPlot=[] if TaskParamToPlot is None else TaskParamToPlot

    #divide input dictionary
    colorMapGroup={name:spec[0] for name,spec in groups.items()}
    colorGroup={name:spec[1] for name,spec in groups.items()}
    animalGroup={name:spec[2] for name,spec in groups.items()}
    experimentalGroup={name:spec[3] for name,spec in groups.items()}
    NbAnimal=len(animalGroup["group2"]) #number of animals to be chosen from the first group

    allResults={key:{} for key in TaskParamToPlot}
    nSessionMax={}
    ncols=3 if fullLegend else 1

    # group1: statistics of the size-matched subsets
    ALLobj=sample_size_control(get_rat_group_statistic,animalList=animalGroup['group1'],NbAnimal=NbAnimal,n=n_iteration,
                               root=root,profile=experimentalGroup['group1'],parameter=parameter,redo=redo,
                               stop_dayPlot=stop_dayPlot,TaskParamToPlot=TaskParamToPlot)
    ALL=ALLobj.Results

    nIterMax=[ALL[i][1] for i,_ in enumerate(ALL)]
    nSessionMax['group1']=max(nIterMax)
    for key in TaskParamToPlot:
        # mean over the animals of each subset; shorter subsets are dropped
        results={i:np.nanmean(np.asarray(list(ALL[i][0][key].values())),axis=0) for i,_ in enumerate(ALL)
                 if nIterMax[i]>=nSessionMax['group1']}
        allResults[key]['group1']=results
        ALLobj.iterN=len(results)

    # group2: every animal is used
    results,nSessionMax['group2']=get_rat_group_statistic(root,animalList=animalGroup['group2'],profile=experimentalGroup['group2'],parameter=parameter,
                                                        redo=redo,stop_dayPlot=stop_dayPlot,TaskParamToPlot=TaskParamToPlot)
    for key in TaskParamToPlot:
        allResults[key]['group2']=results[key]

    def group_label(profile):
        # 'Tag' may be a list of tags or a whitespace-separated string; indexing
        # a string directly would return its first character only.
        tag=profile['Tag']
        if isinstance(tag,str):
            tag=tag.split()
        return tag[0]

    #PLOTTING
    def plot_one_key(xaxis,keyResult,label,color,ax):
        medianKeyRes=np.nanmedian(keyResult,axis=0)
        stdKeyRes=np.nanstd(keyResult,axis=0)
        ax.plot(xaxis,medianKeyRes,color=color,lw=3,marker='.',markersize=20,label=label)
        if not fullLegend:
            ax.fill_between(xaxis,(medianKeyRes - stdKeyRes),(medianKeyRes + stdKeyRes),color=color,alpha=0.25)

    if axes is None:
        nbCol=2
        nbLine=len(allResults)//nbCol+len(allResults)%nbCol
        nbLine+=1
        fig=plt.figure(figsize=(12, nbLine*4))
        axes=[fig.add_subplot(nbLine,nbCol,index+1) for index in range(len(TaskParamToPlot))]

    # per-animal colours: skip the first three entries of each colormap sampling
    colors={}
    for group in animalGroup:
        colors[group]=colorMapGroup[group](np.linspace(0, 1, len(animalGroup[group])+3))[3:]

    # fixed y-ranges and dashed reference lines for the known statistics
    # (no entry for "Tortuosity": its range is left automatic)
    yLimits={
        "median entrance time (sec)":[0,15],
        "mean entrance time (sec)":[0,10],
        "% good trials":[0,100],
        "standard deviation of entrance time":[0,11],
        "percentile entrance time":[0,15],
        "% good trials on last 40":[0,100],
    }
    referenceLines={
        "median entrance time (sec)":7,
        "mean entrance time (sec)":7,
        "% good trials on last 40":72.5,
    }

    ax=None   # stays None when nothing is plotted
    for index,key in enumerate(TaskParamToPlot):
        if key=="treadmillSpeed":continue
        ax=axes[index]
        for group in allResults[key]:
            xaxis=np.arange(1,nSessionMax[group]+1,1)
            yaxis=np.array(list(allResults[key][group].values()),ndmin=2)
            plot_one_key(xaxis,yaxis,group_label(experimentalGroup[group]),colorGroup[group],ax)

        ax.set_xlabel("Session",fontsize=14)
        ax.set_ylabel(key,fontsize=14)
        ax.set_title(key,fontsize=14)
        ax.set_xlim([0,max(list(nSessionMax.values()))])
        ax.tick_params(axis='both', which='major', labelsize=12)
        if key in yLimits:
            ax.set_ylim(yLimits[key])
        if key in referenceLines:
            ax.axhline(referenceLines[key],color="b",ls="--")
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.xaxis.set_ticks_position('bottom')
        ax.yaxis.set_ticks_position('left')
        try:
            ax.set(**kargs)
        except Exception as e:
            # best effort: extra axes properties that cannot be applied are
            # reported instead of aborting the figure
            print("ignored axes properties:",repr(e))

    # a single legend, attached to the last axis that was drawn
    if ax is not None:
        ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.,ncol=ncols)
    plt.tight_layout()
    plt.subplots_adjust(top=0.90)

    return allResults,colors,colorGroup,ALLobj


In [None]:
if "__file__" not in dir():
    # Selection criteria of the two groups (they differ by their Tag only)
    profile1=dict(Type='Good',option=['not used'],rewardType='Progressive',
                  initialSpeed='10',Speed='10',Tag=['Control'])
    profile2=dict(Type='Good',option=['not used'],rewardType='Progressive',
                  initialSpeed='10',Speed='10',Tag=['Early-Lesion_DMS'])

    animalList1=batch_get_animal_list(root,profile1)
    animalList2=batch_get_animal_list(root,profile2)
    # Manual alternatives, currently disabled:
    #   animalList1=[ 'Rat085', 'Rat095']
    #   ok late DLS Histo:
    #   animalList2= ['Rat118', 'Rat119', 'Rat133', 'Rat134']

    # group name -> (colormap, colour, animals, profile)
    groups=dict(
        group1=(cm.Greys,"black",animalList1,profile1),
        group2=(cm.Reds ,"red"  ,animalList2,profile2),
    )

    print("animal lists:\n",animalList1,'\n',animalList2)

In [None]:
if "__file__" not in dir():
    # disabled alternative:
    #   TaskParamToPlot=["% good trials on last 40","percentile entrance time",
    #               "Forward Speed Vs TreadmillSpeed","Tortuosity","standard deviation of entrance time","Trajectory Correlation"]
    TaskParamToPlot=["Successive Entrance Score"]

    stop_dayPlot,fullLegend=30,False
    n_iteration=1e3

    a=plot_mean_animals(
        root,groups,
        n_iteration=n_iteration,parameter={},redo=False,
        stop_dayPlot=stop_dayPlot,TaskParamToPlot=TaskParamToPlot,
        fullLegend=fullLegend,title='',
    )
    # a[3] is the object returned by sample_size_control; iterN is the number of subsets kept
    print("number of removed subsets: ",n_iteration-a[3].iterN)
    # to save the figure:
    #   plt.gcf().savefig("/home/david/Downloads/b.svg",format='svg')

## 3. Compare event impact

In [None]:
def event_statistic(root,sessionDic,parameter={},redo=False,TaskParamToPlot=[]):
    """Gather per-session statistics before and after an event for each animal.

    For every animal the statistics are read from
    ``<root>/<animal>/Analysis/learningStats.p`` when that pickle exists,
    ``redo`` is False and the pickle passes the checks below; otherwise they
    are recomputed with ``plot_learningCurves``.  The values belonging to the
    pre-event and post-event sessions are then selected and aligned on the
    event: shorter pre lists are left-padded with NaN and shorter post lists
    are right-padded with NaN.

    Parameters
    ----------
    root : str
        Folder containing one sub-folder per animal.
    sessionDic : dict
        ``{animal: (preSessions, postSessions)}``; item 0 / item 1 of each
        value are the session lists before / after the event.
    parameter : dict
        Forwarded unchanged to ``plot_learningCurves`` (never mutated here).
    redo : bool
        When True the pickle is ignored and statistics are recomputed.
    TaskParamToPlot : list of str
        Names of the statistics to collect (never mutated here).

    Returns
    -------
    tuple
        ``(OutResults, nSessionMaxPre, nSessionMaxPost)`` where
        ``OutResults[key][animal]`` is the padded list of values and the two
        integers are the largest numbers of pre / post sessions found.
        If ``plot_learningCurves`` returns a bool instead of results,
        ``({}, 0, 0)`` is returned so callers can still unpack three values.
    """
    animalList=list(sessionDic.keys())
    allResultsPre={key:{animal:[] for animal in animalList} for key in TaskParamToPlot}
    allResultsPost={key:{animal:[] for animal in animalList} for key in TaskParamToPlot}
    OutResults={key:{animal:[] for animal in animalList} for key in TaskParamToPlot}
    nSessionMaxPre=0
    nSessionMaxPost=0

    # Load all statistics
    for animal in animalList:
        preSessions=sessionDic[animal][0]
        postSessions=sessionDic[animal][1]
        sessionList=list(preSessions)+list(postSessions)

        results={}
        sessionProfile={}
        pathPickle=os.path.join(root,animal,"Analysis","learningStats.p")
        if os.path.exists(pathPickle) and (not redo):
            # NOTE: pickle files must come from a trusted source.
            try:
                with open(pathPickle,"rb") as f:
                    results=pickle.load(f)

                sessionProfile=results.pop('sessionProfile',{})
                if sessionProfile=={}:
                    #this means that the pickle file is the old kind!
                    raise NameError(animal+' :old pickle, computing again...')
                if len(set(TaskParamToPlot)-results.keys())!=0:
                    #this means not all the TaskParamToPlot keys are available in the pickle
                    raise NameError(animal+' :pickle not complete, computing again...')
                if not set(sessionProfile['Sessions']).issuperset(set(sessionList)):
                    raise NameError(animal+' :pickle not up-to-date, computing again...')

            except Exception as e:
                # any problem with the pickle triggers a recomputation below
                print(repr(e))
                results={}

        if results=={}:
            # no usable pickle: compute the statistics
            results=plot_learningCurves(root,animal,PerfParamToPlot=TaskParamToPlot,profile={},parameter=parameter,redoStat=redo,plot=False)
            plt.close()
            # the bool check must come before any dict access on `results`
            if isinstance(results,bool):
                print('False')
                return {},0,0
            sessionProfile=results.pop('sessionProfile',{})

        results.pop('days',None)
        sessions=sessionProfile['Sessions']

        for key in TaskParamToPlot:
            values=results[key]
            if key=="percentile entrance time":
                # only the item at index 2 of each session entry is kept
                values=[array[2] for array in values]

            # values are stored in the same order as sessionProfile['Sessions']
            wantedResultsPre =[i for ID,i in enumerate(values) if sessions[ID] in preSessions]
            wantedResultsPost=[i for ID,i in enumerate(values) if sessions[ID] in postSessions]

            allResultsPre[key][animal] =wantedResultsPre
            allResultsPost[key][animal]=wantedResultsPost
            nSessionMaxPre=max(nSessionMaxPre,len(wantedResultsPre))
            nSessionMaxPost=max(nSessionMaxPost,len(wantedResultsPost))

    #pad with nan when fewer session than maximum
    for key in allResultsPost:
        for animal in allResultsPost[key]:
            pre=allResultsPre[key][animal]
            post=allResultsPost[key][animal]
            padPre=[np.nan]*(nSessionMaxPre-len(pre))
            padPost=[np.nan]*(nSessionMaxPost-len(post))
            OutResults[key][animal]=padPre+pre+post+padPost

    return OutResults,nSessionMaxPre,nSessionMaxPost

In [None]:
def event_statistic2(root,sessionDic,TaskParamToPlot: list,param=None):
    """Gather pre/post-event statistics using ``animal_learning_stats``.

    For every animal and every requested statistic, ``animal_learning_stats``
    is called once with the pre-event sessions and once with the post-event
    sessions.  Results are aligned on the event: shorter pre lists are
    left-padded with NaN and shorter post lists are right-padded with NaN.
    The key "percentile entrance time" is skipped, so its entries stay empty
    before padding.

    Parameters
    ----------
    root : str
        Folder containing one sub-folder per animal.
    sessionDic : dict
        ``{animal: (preSessions, postSessions)}``.
    TaskParamToPlot : list
        Names of the statistics to collect.
    param : dict or None
        Forwarded as ``parameters`` to ``animal_learning_stats``; an empty
        dict is used when None.

    Returns
    -------
    tuple
        ``(OutResults, nSessionMaxPre, nSessionMaxPost)``.
    """
    # `parameter` must be bound whether or not `param` was supplied
    parameter={} if param is None else param

    animalList=list(sessionDic.keys())
    allResultsPre={key:{animal:[] for animal in animalList} for key in TaskParamToPlot}
    allResultsPost={key:{animal:[] for animal in animalList} for key in TaskParamToPlot}
    OutResults={key:{animal:[] for animal in animalList} for key in TaskParamToPlot}
    nSessionMaxPre=0
    nSessionMaxPost=0

    for animal in animalList:
        sessionListPre=sessionDic[animal][0]
        sessionListPost=sessionDic[animal][1]

        for key in TaskParamToPlot:
            if key=="percentile entrance time":
                continue

            wantedResultsPre=animal_learning_stats(root,animal,key,goodSessions=sessionListPre, parameters=parameter)
            wantedResultsPost=animal_learning_stats(root,animal,key,goodSessions=sessionListPost,parameters=parameter)
            allResultsPre[key][animal] =wantedResultsPre
            allResultsPost[key][animal]=wantedResultsPost
            nSessionMaxPre=max(nSessionMaxPre,len(wantedResultsPre))
            nSessionMaxPost=max(nSessionMaxPost,len(wantedResultsPost))

    #pad with nan when fewer session than maximum
    for key in allResultsPost:
        for animal in allResultsPost[key]:
            nSessionPre=len(allResultsPre[key][animal])
            nSessionPost=len(allResultsPost[key][animal])
            OutResults[key][animal]=[np.nan]*(nSessionMaxPre-nSessionPre)+allResultsPre[key][animal]
            OutResults[key][animal].extend(allResultsPost[key][animal]+[np.nan]*(nSessionMaxPost-nSessionPost))

    return OutResults,nSessionMaxPre,nSessionMaxPost

In [None]:
if "__file__" not in dir():
    # Profile of the sessions before the event ...
    profile1=dict(
        Type='Good',
        rewardType='Progressive',
        initialSpeed='10',
        Speed='10',
        Tag=['Control','Control-AfterBreak'],
    )
    # ... and of the sessions after it.
    profile2=dict(
        Type='Good',
        rewardType='Progressive',
        initialSpeed='10',
        Speed='10',
        Tag=['Late-Lesion_DS','Late-Lesion_DMS','Late-Lesion_DLS'],
    )

    # Only the second item returned by event_detect is used here.
    _,sessionDic=event_detect(root,profile1,profile2)

    TaskParamToPlot=["Intertrial Displacement"]
    res2,N,M= event_statistic2(
        root,
        sessionDic=sessionDic,
        TaskParamToPlot=TaskParamToPlot,
    )
    # Same call using the pickle-based implementation:
#     res,N,M= event_statistic(root,sessionDic=sessionDic,parameter=param,redo=False,TaskParamToPlot=TaskParamToPlot)


In [None]:
def plot_groups(root,groups,axes=None,parameter={},redo=False,session_range=(-1,5),TaskParamToPlot=[],fullLegend=False):
    """Plot statistics of several groups of animals, aligned on an event.

    Parameters
    ----------
    root : str
        Folder containing one sub-folder per animal.
    groups : dict
        ``{groupName: (colormap, colour, preProfile, postProfile[, badAnimals])}``.
        The two profiles are given to ``event_detect``; ``badAnimals`` is
        optional and, when absent, ``event_detect`` is called without it.
    axes : list of matplotlib axes or None
        One axis per entry of ``TaskParamToPlot``, optionally followed by one
        more for the 'Number of animals' panel.  A new figure is created when
        None.
    parameter : dict
        Forwarded to ``event_statistic`` (never mutated here).
    redo : bool
        Forwarded to ``event_statistic``.
    session_range : tuple
        ``(pre, post)``: sessions ``-abs(pre) .. post-1`` are drawn, index 0
        being the first post-event session.
    TaskParamToPlot : list of str
        Statistics to plot (never mutated here).
    fullLegend : bool
        Forwarded to ``plot_single_key``; also selects a 3-column legend.

    Returns
    -------
    tuple
        ``(allResults, colors, colorGroup, nSessionMaxPre)``.
    """
    #divide input dictionary
    colorMapGroup={key: groups[key][0] for key in groups}
    colorGroup={key: groups[key][1] for key in groups}
    animalGroup={}
    sessionGroups={}

    for key in groups:
        spec=groups[key]
        # the list of excluded animals (5th item) is optional
        detectKwargs={'badAnimals':spec[4]} if len(spec)>4 else {}
        animalGroup[key],sessionGroups[key]=event_detect(root,spec[2],spec[3],**detectKwargs)

    allResults={key:{} for key in TaskParamToPlot}
    nSessionMaxPre={}
    nSessionMaxPost={}

    ncols=3 if fullLegend else 1

    # get the data for all the groups
    for group in animalGroup:
        results,nSessionMaxPre[group],nSessionMaxPost[group]=event_statistic(root,sessionDic=sessionGroups[group],parameter=parameter,redo=redo,TaskParamToPlot=TaskParamToPlot)
        for key in TaskParamToPlot:
            allResults[key][group]=results[key]

    #Compute the number of animals per day (non-NaN values of the first parameter)
    allResults['Number of animals']={}
    for group in animalGroup:
        perAnimal=allResults[TaskParamToPlot[0]][group]
        nSessions=nSessionMaxPre[group]+nSessionMaxPost[group]
        counts=[sum(not np.isnan(perAnimal[animal][n]) for animal in perAnimal)
                for n in range(nSessions)]
        allResults['Number of animals'][group]={'RatXXX':counts}

    #PLOTTING
    nbCol=1
    nbLine=len(allResults)//nbCol+len(allResults)%nbCol
    nbLine+=1

    TaskParams=[]
    if axes is None:
        axes=[]
        fig=plt.figure(figsize=(7, nbLine*4))
        TaskParams=TaskParamToPlot+['Number of animals']
        for index,_ in enumerate(TaskParams):
            axes.append(fig.add_subplot(nbLine,nbCol,index+1))
    elif len(axes)==len(TaskParamToPlot):
        TaskParams=TaskParamToPlot
    elif len(axes)==len(TaskParamToPlot)+1:
        TaskParams=TaskParamToPlot+['Number of animals']

    assert len(axes)==len(TaskParams), "Bad number of axes/paramsToPlot"

    # one colour per animal; the 3 first colours of the colormap are dropped
    colors={}
    for group in animalGroup:
        colors[group]=colorMapGroup[group](np.linspace(0, 1, len(animalGroup[group])+3))[3:]

    nPre=abs(session_range[0])
    nPost=session_range[1]
    xaxis=np.arange(-nPre,nPost,1)
    # displayed labels skip 0: the first post-event session is labelled 1
    tickLabels=[str(int(x)) if x<0 else str(int(x)+1) for x in xaxis]

    ax=None
    for index,key in enumerate(TaskParams):
        if key=="treadmillSpeed":continue
        ax=axes[index]
        for group in allResults[key]:
            plot_single_key(xaxis,allResults[key][group],nSessionMaxPre[group],
                            animalGroup[group],groups[group][2]['Tag']+groups[group][3]['Tag'],
                            colors[group],color=colorGroup[group],axes=ax,fullLegend=fullLegend)

        ax.set_xlabel("Session",fontsize=14)
        ax.set_ylabel(key,fontsize=14)
        ax.set_xlim([-nPre-1,nPost])
        ax.tick_params(axis='both', which='major', labelsize=12)
        if key=="median entrance time (sec)":
            ax.set_ylim([0,15])
            ax.axhline(7,color="b",ls="--")
        elif key=="mean entrance time (sec)":
            ax.set_ylim([0,10])
            ax.axhline(7,color="b",ls="--")
        elif key=="% good trials":
            ax.set_ylim([0,100])
        elif key=="standard deviation of entrance time":
            ax.set_ylim([0,11])
        elif key=="Tortuosity":pass
            #ax.set_ylim([2,10])
        elif key=="percentile entrance time":
            ax.set_ylim([0,15])
        elif key=="% good trials on last 40":
            ax.axhline(72.5,color="b",ls="--")
            ax.set_ylim([0,100])
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.xaxis.set_ticks_position('bottom')
        ax.yaxis.set_ticks_position('left')
        # vertical line marking the event, between last pre and first post session
        ax.axvline(x=-.5, linewidth=4, color='m')
        # one label per tick: recent matplotlib rejects mismatched counts
        ax.set_xticks(xaxis)
        ax.set_xticklabels(tickLabels)

    # the legend is attached to the last axis drawn, if any
    if ax is not None:
        ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.,ncol=ncols)
    plt.tight_layout()
    plt.subplots_adjust(top=0.90)
    return allResults,colors,colorGroup,nSessionMaxPre


def plot_single_key(xaxis,keyResult,zero_point,animalList,experimentalGroup,colorList,color="blue",axes=None,fullLegend=False):
    """Plot one statistic for one group: group summary plus one line per animal.

    The summary line is the NaN-ignoring median across animals; unless
    ``fullLegend`` is True, a band of +/- the NaN-ignoring standard deviation
    is drawn around it.  Each animal is then drawn as a dotted line.

    Parameters
    ----------
    xaxis : sequence of int
        Session indices relative to the event; assumed contiguous with step 1
        (the window length is derived from its first and last items).
    keyResult : dict
        ``{animal: values}``; all value lists are expected to have the same
        length.
    zero_point : int
        Index, inside each value list, that corresponds to ``xaxis`` value 0.
    animalList : list
        Nothing is drawn when this list is empty.
    experimentalGroup : object
        Used as the label of the summary line.
    colorList : sequence
        One colour per animal, paired in order with ``keyResult``.
    color : colour
        Colour of the summary line and band.
    axes : matplotlib axes or None
        Target axes; the current axes are used when None.
    fullLegend : bool
        When True each animal line is labelled with the animal name and the
        band is not drawn.

    Sessions requested outside the stored range are shown as NaN instead of
    being read with a wrapped-around or truncated slice.
    """
    if axes is None:
        axes=plt.gca()
    if animalList==[]:
        return

    data=np.asarray(list(keyResult.values()),dtype=float)
    start=int(zero_point+xaxis[0])
    stop=int(zero_point+xaxis[-1]+1)

    # copy the part of the window that exists; the rest stays NaN
    window=np.full((data.shape[0],stop-start),np.nan)
    lo=max(start,0)
    hi=min(stop,data.shape[1])
    if hi>lo:
        window[:,lo-start:hi-start]=data[:,lo:hi]

    #compute group summary (median) and spread
    medianKeyRes=np.nanmedian(window,axis=0)
    stdKeyRes=np.nanstd(window,axis=0)

    #plot group summary
    axes.plot(xaxis,medianKeyRes,color=color,lw=3,marker='.',markersize=20,label=experimentalGroup)
    if not fullLegend:
        axes.fill_between(xaxis,(medianKeyRes - stdKeyRes),(medianKeyRes + stdKeyRes),color=color,alpha=0.25)

    #plot individual animals
    for animalColor,animal,values in zip(colorList,keyResult,window):
        animalLabel=animal if fullLegend else None
        axes.plot(xaxis,values,ls=':',color=animalColor, label=animalLabel)

In [None]:
if "__file__" not in dir():

    # Group 1: sessions before / after the event, and animals to exclude.
    profile1pre={'Type':'Good',
                 'rewardType':'Progressive',
                 'initialSpeed':'10',
                 'Speed':['10'],
                 'Tag':['Control']
                }
    profile1post={'Type':'Good',
                  'rewardType':'Progressive',
                  'initialSpeed':'10',
                  'Speed':['10'],
                  'Tag':['Control-AfterBreak']
                  }
    badAnimals1=[]

    # Group 2: sessions before / after the event, and animals to exclude.
    profile2pre={'Type':'Good',
                 'rewardType':'Progressive',
                 'initialSpeed':'10',
                 'Speed':['10'],
                 'Tag':['Control','Control-AfterBreak']
                }
    profile2post={'Type':'Good',
                  'rewardType':'Progressive',
                  'initialSpeed':'10',
                  'Speed':['10'],
                  'Tag':['Late-Lesion_DS']
                  }
    badAnimals2=[]

    # One entry per group: (colormap, group colour, pre profile, post profile,
    # excluded animals).  Both groups use their badAnimals list so that the
    # plotted groups match the animal lists printed below.
    groups={
        "group1":(cm.Greys,"black",profile1pre,profile1post,badAnimals1),
        "group2":(cm.Reds ,"red"  ,profile2pre,profile2post,badAnimals2),
        }

    list1,_=event_detect(root,profile1pre,profile1post,badAnimals1)
    list2,_=event_detect(root,profile2pre,profile2post,badAnimals2)
    print('animal lists:\n',list1,'\n',list2)

In [None]:
if "__file__" not in dir():

    # Longer alternative list of parameters to plot:
#     TaskParamToPlot=["% good trials on last 40","percentile entrance time",
#                 "Forward Speed Vs TreadmillSpeed","Tortuosity","standard deviation of entrance time","Trajectory Correlation"]
    TaskParamToPlot=["Forward Speed Vs TreadmillSpeed","Lick Onset Delay"]

    # Number of sessions shown before / after the event, and legend style.
    fullLegend=True
    post_sessions=6
    pre_sessions =2

    # The fourth returned value is not needed here.
    allResults,colors,colorGroup,_=plot_groups(
        root,
        groups,
        parameter=param,
        redo=False,
        session_range=(pre_sessions,post_sessions),
        TaskParamToPlot=TaskParamToPlot,
        fullLegend=fullLegend,
    )


In [None]:
if "__file__" not in dir():

    # Group 1: sessions before / after the event (speed criteria disabled).
    profile1pre={'Type':'Good',
                 'rewardType':'Progressive',
                 #'initialSpeed':'10',
                 #'Speed':['10'],
                 'Tag':['Control-Early-var-AfterBreak']
                }
    profile1post={'Type':'Good',
                  'rewardType':'Progressive',
                  #'initialSpeed':'10',
                  #'Speed':['10'],
                  'Tag':['Early-var-Late-Lesion_DLS']
                  }

    # Group 2: sessions before / after the event (speed criteria disabled).
    profile2pre={'Type':'Good',
                 'rewardType':'Progressive',
                 #'initialSpeed':'10',
                 #'Speed':['10'],
                 'Tag':['Control-Early-var']
                }
    profile2post={'Type':'Good',
                  'rewardType':'Progressive',
                  #'initialSpeed':'10',
                  #'Speed':['10'],
                  'Tag':['Control-Early-var-AfterBreak']
                  }

    # One entry per group: (colormap, group colour, pre profile, post profile,
    # excluded animals).  plot_groups reads the fifth item, so an empty list
    # is given explicitly when no animal is excluded.
    groups={
        "group1":(cm.Greys,"black",profile1pre,profile1post,[]),
        "group2":(cm.Reds ,"red"  ,profile2pre,profile2post,[]),
        }

    list1,_=event_detect(root,profile1pre,profile1post)
    list2,_=event_detect(root,profile2pre,profile2post)
    print('animal lists:\n',list1,'\n',list2)

In [None]:
if "__file__" not in dir():

    TaskParamToPlot=["% good trials on last 40","percentile entrance time",
                "Forward Speed Vs TreadmillSpeed","Tortuosity","standard deviation of entrance time","Trajectory Correlation"]
    #TaskParamToPlot=["Trajectory Correlation"]

    # Number of sessions shown before / after the event, and legend style.
    pre_sessions =2
    post_sessions=6
    fullLegend=False

    # plot_groups returns four values; the last one is not needed here.
    allResults,colors,colorGroup,_=plot_groups(root,groups,parameter=param,redo=False,session_range=(pre_sessions,post_sessions),TaskParamToPlot=TaskParamToPlot,fullLegend=fullLegend)