import everything

In [None]:
import os
import glob
import numpy as np
from platform import system as OS
import pandas as pd
import scipy.stats
import math
import datetime
from copy import deepcopy
import matplotlib.cm as cm
import warnings
import types
warnings.filterwarnings("ignore")
import sys, time
import string
import pickle
import matplotlib.pyplot as plt
import mpl_toolkits
from scipy import stats
from matplotlib import mlab
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.animation as animation
import matplotlib.backends.backend_pdf
from sklearn.decomposition import KernelPCA
import mpl_toolkits.axes_grid1.inset_locator as inset
from matplotlib.ticker import FormatStrFormatter
import imageio



# Notebook-only setup cell. The guard checks whether the name "__file__" is
# missing from the current namespace; presumably that is only the case when the
# code runs interactively (IPython/Jupyter) rather than as a script -- confirm.
if "__file__" not in dir():
    # IPython magics: render figures inline and keep them open after each cell.
    %matplotlib inline
    %config InlineBackend.close_figures = False
    
    # NOTE: "__file__" is a string literal here, so realpath() resolves it
    # against the current working directory; its dirname is therefore the
    # directory the session is currently in.
    ThisNoteBookPath=os.path.dirname(os.path.realpath("__file__"))
    # Sibling directory "load_preprocess_rat" next to the notebook directory.
    CommonNoteBookesPath=os.path.join(os.path.split(ThisNoteBookPath)[0],"load_preprocess_rat")
    # Remember the working directory so it can be restored after the %run calls.
    CWD=os.getcwd()
    os.chdir(CommonNoteBookesPath)
    # Execute the shared helper notebooks in this namespace.
    # NOTE(review): if any %run fails, the os.chdir(CWD) below is never reached
    # and the session stays in CommonNoteBookesPath.
    %run UtilityTools.ipynb
    %run Animal_Tags.ipynb
    %run loadRat_documentation.ipynb
    %run plotRat_documentation_1_GeneralBehavior.ipynb
    %run plotRat_documentation_3_KinematicsInvestigation.ipynb
    %run RunBatchRat_3_CompareGroups.ipynb
    %run BatchRatBehavior.ipynb
    %run ../BehavioralPaper/TaskRules.ipynb
    os.chdir(CWD)
    # PARAMETERS (used if the pickles don't exist)
    param={
        "goalTime":7,#needed for pavel data only
        "treadmillRange":[0,90],#pavel error conversion "treadmillRange":[0,80]
        "maxTrialDuration":15,
        "interTrialDuration":10,#None pavel
        "endTrial_frontPos":30,
        "endTrial_backPos":55, 
        "endTrial_minTimeSec":4,
        "cameraSamplingRate":25, #needed for new setup    

        "sigmaSmoothPosition":0.1,#0.33, 0.18 pavel
        "sigmaSmoothSpeed":0.3,#0.3, 0.5 pavel
        "nbJumpMax":100,#200 pavel
        "binSize":0.25,
        #parameters used to preprocess (will override the default parameters)
    }  

    # Pick the data root directory according to the operating system name
    # returned by platform.system() (imported as OS).
    if OS()=='Linux':
        root="/data"
    elif OS()=='Windows':
        root="C:\\DATA\\"
    else:
        root="/Users/davidrobbe/Documents/Data/"
        
    # Only show log records of level ERROR and above on the root logger.
    # NOTE(review): `logging` is not imported in the visible import list;
    # presumably one of the %run notebooks brings it into scope -- confirm.
    logging.getLogger().setLevel(logging.ERROR)

    print('os:',OS(),'\nroot:',root,'\nImport successful!')

# part 1:

# DEFINITIONS

### If you only want to generate the figure, skip to part 2

In [None]:
def add_panel_caption(axes: tuple, offsetX: tuple, offsetY: tuple, **kwargs):
    """
    Add letter captions (a, b, c, ...) at the top-left corner of each Axes.

    axes    : sequence of Axes to label, in caption order. They are assumed to
              belong to the same figure (the figure of axes[0] is used).
    offsetX : per-Axes horizontal offset, in RELATIVE figure coordinates. The
              absolute value is used and the caption is shifted to the left.
    offsetY : per-Axes vertical offset, in RELATIVE figure coordinates. The
              absolute value is used and the caption is shifted upwards.
    kwargs  : forwarded unchanged to ``Axes.text``.

    Raises AssertionError when the three sequences differ in length.
    An empty ``axes`` is a no-op. Letters come from ``string.ascii_lowercase``,
    so at most 26 Axes are labelled; any further Axes are left without caption.
    """
    if not (len(axes) == len(offsetX) == len(offsetY)):
        # Raised explicitly (same exception type as the former ``assert``) so
        # that the check is not stripped when Python runs with -O.
        raise AssertionError('Bad input!')
    if len(axes) == 0:
        # Nothing to label; avoids an IndexError on axes[0] below.
        return

    fig = axes[0].get_figure()
    fbox = fig.bbox
    for ax, dx, dy, letter in zip(axes, offsetX, offsetY, string.ascii_lowercase):
        axbox = ax.get_window_extent()

        # Convert the Axes corner from display units to figure fractions by
        # dividing by the figure bounding-box extent, then apply the offsets.
        ax.text(x=(axbox.x0/fbox.xmax)-abs(dx), y=(axbox.y1/fbox.ymax)+abs(dy),
                s=letter, fontweight='extra bold', fontsize=10, ha='left', va='center',
                transform=fig.transFigure, **kwargs)


Drawing the reward magnitude

In [None]:
def plot_reward_progression_schematic(ax, minReward, gt, maxTrial, maxReward, step, color,showYAxis,fontsize=3):
    """
    Draw the reward-magnitude schematic on ``ax``.

    A fan of lines is drawn from (x, 0) to (gt, 1) for x running from
    ``minReward`` up to (but excluding) ``gt`` in increments of ``step``; the
    opacity of each line grows linearly with x, from 0 towards 0.6. After
    ``gt`` the reward follows a straight line that equals 1 at ``gt`` and 0 at
    ``maxReward``: solid up to ``maxTrial``, dotted from ``maxTrial`` to
    ``maxReward``.

    ax        : Axes to draw on.
    minReward : x position of the first fan line.
    gt        : x position where all fan lines meet at y=1.
    maxTrial  : x position where the solid declining line ends.
    maxReward : x position where the declining line reaches 0.
    step      : spacing between successive fan lines.
    color     : colour of every line.
    showYAxis : when true, show the left spine and a y axis labelled
                'zero' / 'max'; otherwise both are hidden.
    fontsize  : font size of tick labels, title and x label.

    Returns the list of (x, alpha) pairs used for the fan lines.
    Note: ``gt == minReward`` or ``maxReward == gt`` leads to a division by zero.
    """
    def rewardRatio(x):
        # linear decline: 1 at x == gt, 0 at x == maxReward
        return (-x/(maxReward-gt) + (maxReward/(maxReward-gt)))

    points = [(i, (i-minReward)/(gt-minReward)*.6) for i in np.arange(minReward, gt, step)]

    for x, line_alpha in points:
        ax.plot((x, gt), (0, 1), alpha=line_alpha, color=color, lw=.4)

    ax.plot((gt, gt), (1, 0), alpha=1, color=color, lw=.4)
    ax.plot((gt, maxTrial), (1, rewardRatio(maxTrial)), alpha=1, color=color, lw=.4)
    ax.plot((maxTrial, maxReward), (rewardRatio(maxTrial), 0), ':', alpha=1, color=color, lw=.4)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xticks([minReward, gt, maxTrial, maxReward])
    # Real booleans are required here: the strings 'off'/'on' are both truthy
    # for current Matplotlib, which would switch every tick mark ON.
    ax.tick_params(top=False, bottom=False, left=False, right=False,
                   labelleft=True, labelbottom=True, labelsize=fontsize)
    ax.xaxis.set_major_formatter(FormatStrFormatter('%g'))
    ax.set_xlim([minReward-.5, maxReward+.5])

    # y axis and left spine are shown only on request
    show_y = bool(showYAxis)
    ax.yaxis.set_visible(show_y)
    ax.spines['left'].set_visible(show_y)
    if show_y:
        ax.set_ylim([-.05, 1])
        ax.set_yticks([0, 1])
        ax.set_yticklabels(['zero', 'max'])

    #adding the text
    ax.set_title('Reward\nMagnitude', fontsize=fontsize, pad=4)
    ax.set_xlabel('$ET$', fontsize=fontsize, labelpad=0)

    return points

In [None]:
if "__file__" not in dir():
    # Notebook-only preview of the reward-magnitude schematic on a small
    # stand-alone figure.
    minReward, gt, maxTrial, maxReward = 1.5, 7, 15, 20
    step, color, showYAxis = .8, 'r', False
    ax = plt.figure(figsize=(2, 2)).add_subplot(111)

    plot_reward_progression_schematic(
        ax=ax, minReward=minReward, gt=gt, maxTrial=maxTrial,
        maxReward=maxReward, step=step, color=color, showYAxis=showYAxis,
    )

    # display, then release the figure
    plt.show()
    plt.close('all')

Drawing the Punishment duration

In [None]:
def plot_punishment_schematic(ax, minReward, gt, minPunish, maxPunish, maxTrial, color,fontsize=10):
    """
    Draw the punishment-duration schematic on ``ax``.

    Two segments are drawn: a vertical one at x=``minReward`` going from
    ``minPunish`` to ``maxPunish``, and a straight line from
    (``minReward``, ``maxPunish``) down to (``gt``, ``minPunish``).

    ax        : Axes to draw on.
    minReward : x position where the punishment is maximal.
    gt        : x position where the punishment reaches ``minPunish``.
    minPunish : smallest punishment value (lower y tick).
    maxPunish : largest punishment value (upper y tick).
    maxTrial  : unused; kept so that existing callers keep working.
    color     : colour of both segments.
    fontsize  : font size of tick labels, title and x label.
    """
    ax.plot((minReward, minReward), (minPunish, maxPunish), alpha=1, color=color, lw=.4)
    ax.plot((minReward, gt), (maxPunish, minPunish), alpha=1, color=color, lw=.4)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xticks([minReward, gt])
    # Real booleans are required here: the strings 'off'/'on' are both truthy
    # for current Matplotlib, which would switch every tick mark ON.
    ax.tick_params(top=False, bottom=False, left=False, right=False,
                   labelleft=True, labelbottom=True, labelsize=fontsize)
    ax.xaxis.set_major_formatter(FormatStrFormatter('%g'))
    # Margins follow the plotted range instead of being fixed numbers; for
    # minReward=1.5 and gt=7 this is the former [.5, 8.5].
    ax.set_xlim([minReward-1, gt+1.5])
    ax.set_ylim([0, maxPunish+1])
    ax.set_yticks([minPunish, maxPunish])
    # restrict the visible spines to the range covered by the data
    ax.spines['left'].set_bounds(minPunish, maxPunish)
    ax.spines['bottom'].set_bounds(minReward, gt)

    #adding the text
    ax.set_title('Punishment\nDuration (s)', fontsize=fontsize, pad=0)
    ax.set_xlabel('$ET$', fontsize=fontsize, labelpad=0)
    

In [None]:
if "__file__" not in dir():
    # Notebook-only preview of the punishment-duration schematic on a small
    # stand-alone figure.
    minReward, gt = 1.5, 7
    maxPunish, minPunish = 10, 1
    maxTrial, color = 15, 'r'
    ax = plt.figure(figsize=(2, 2)).add_subplot(111)

    plot_punishment_schematic(
        ax=ax, minReward=minReward, gt=gt, minPunish=minPunish,
        maxPunish=maxPunish, maxTrial=maxTrial, color=color,
    )

    # display, then release the figure
    plt.show()
    plt.close('all')

Drawing task rules 

In [None]:
def _draw_trial_status_bars(ax, trial_end, beam_end, alpha):
    """
    Draw the three stacked status bars (light, motor, beam) of one trial panel.

    Row 0 (y from 0 to 1): light, yellow from -1 to ``trial_end``, gray before
    (-10 to -1) and after (``trial_end`` to 300).
    Row 1 (y from 1 to 2): motor, from 0 to ``trial_end``.
    Row 2 (y from 2 to 3): beam, from 1.5 to ``beam_end``.
    """
    # 'none' is Matplotlib's documented "no edge" colour; the empty string used
    # previously is rejected as an invalid colour by recent releases.
    bar_style = dict(edgecolor='none', alpha=alpha)
    row = np.array([0, 1])

    ax.fill_betweenx(y=row, x1=-1, x2=trial_end, facecolor='yellow', **bar_style)
    ax.fill_betweenx(y=row, x1=-10, x2=-1, facecolor='gray', **bar_style)
    ax.fill_betweenx(y=row, x1=trial_end, x2=300, facecolor='gray', **bar_style)
    ax.fill_betweenx(y=row + 1, x1=0, x2=trial_end, facecolor='xkcd:muted blue', **bar_style)
    ax.fill_betweenx(y=row + 2, x1=1.5, x2=beam_end, facecolor='pink', **bar_style)


def _add_entrance_time_tick(ax, entrance_time):
    """Append a tick at ``entrance_time`` and label only that tick with '$ET$'."""
    ax.set_xticks(list(ax.get_xticks()) + [entrance_time])
    ax.set_xticklabels(['$ET$' if tick == entrance_time else ''
                        for tick in ax.get_xticks()])


def plot_task_rules(fig, gs, hspace, alpha, xmin, xmax, et, GT, maxP, fontsize):
    """
    Draw the task-rule timelines (correct, error and omitted trial).

    ``gs`` is split into four stacked rows: the top row is left empty and the
    three rows below show, for one trial type each, when the light, the motor
    and the beam are on.

    fig      : figure that receives the new Axes.
    gs       : object providing ``subgridspec`` (the callers in this file pass
               ``fig.add_gridspec(1,1)[0]``).
    hspace   : vertical spacing between the four rows.
    alpha    : opacity of the coloured bars.
    xmin     : left x limit; also the x position of the panel titles.
    xmax     : right x limit.
    et       : sequence of exactly three values; the first is the entrance
               time shown in the correct-trial panel, the second the one shown
               in the error-trial panel, the third is not used.
    GT       : goal time; drawn as a dashed magenta line in every panel and
               labelled 'Goal time=...' on the bottom panel.
    maxP     : time at which the light, motor and penalty bars end in the
               error-trial panel.
    fontsize : base font size of the annotations (titles use fontsize+2).

    Returns the tuple (top empty Axes, correct-trial Axes, error-trial Axes,
    omitted-trial Axes). The figure canvas is drawn once as a side effect.

    Note: the panel titles and the 1.5 / 15 timings are hard-coded for the
    protocol with a goal time of 7.
    """
    gssub = gs.subgridspec(4, 1, hspace=hspace)
    goodAx = fig.add_subplot(gssub[1, 0])
    earlyAx = fig.add_subplot(gssub[2, 0])
    omitAx = fig.add_subplot(gssub[3, 0])
    treadmillAx = fig.add_subplot(gssub[0, 0])
    axes = (goodAx, earlyAx, omitAx)
    ET, ET0, _ = et  # third entry is unused; unpacking keeps the 3-value contract

    drop = u"\U0001F322"
    cross = u"\u2715"
    symbol_style = dict(fontname='Symbola', fontsize=10, ha='center', va='center')
    label_style = dict(verticalalignment='center', fontsize=fontsize)
    title_style = dict(verticalalignment='center', fontsize=fontsize+2)

    #good trials
    ######################################
    _draw_trial_status_bars(goodAx, trial_end=ET, beam_end=ET, alpha=alpha)
    goodAx.text(x=-1, y=0.5, s='Light ON', **label_style)
    goodAx.text(x=ET+.1, y=0.5, s='Intertrial', **label_style)
    goodAx.text(x=0, y=1.5, s='Motor ON', **label_style)
    goodAx.text(x=1.5, y=2.5, s='Beam ON', **label_style)

    goodAx.vlines(x=ET, ymin=-10, ymax=3, color='k', linestyles='-', lw=1, zorder=5)
    # raw strings: '\l' and '\ ' are not valid Python escape sequences
    goodAx.text(x=xmin, y=4, s='Correct trial ' + r'$(7\leq ET<15 s)$', **title_style)
    goodAx.text(ET+.5, 1.5, drop, color='xkcd:blue', **symbol_style)

    #early trials
    ######################################
    _draw_trial_status_bars(earlyAx, trial_end=maxP, beam_end=ET0, alpha=alpha)
    # punishment: shares the beam row, from the entrance time to maxP
    earlyAx.fill_betweenx(y=np.array([2, 3]), x1=ET0, x2=maxP, facecolor='r',
                          edgecolor='none', alpha=alpha, zorder=1)
    earlyAx.text(x=ET0+.1, y=2.5, s='Penalty', zorder=3, **label_style)

    earlyAx.text(x=xmin, y=4, s='Error trial ' + r'$(1.5\leq ET<7 s)$', **title_style)
    earlyAx.vlines(x=ET0, ymin=-10, ymax=3, color='k', linestyles='-', lw=1, zorder=5)
    # drop symbol with a red cross drawn on top of it
    earlyAx.text(maxP+.5, 1.5, drop, color='xkcd:blue', zorder=6, **symbol_style)
    earlyAx.text(maxP+.5, 1.5, cross, color='r', zorder=6, **symbol_style)

    #omit trials
    ######################################
    _draw_trial_status_bars(omitAx, trial_end=15, beam_end=15, alpha=alpha)
    omitAx.text(x=xmin, y=4, s='Omitted trial ' + r'$(No\  ET)$', **title_style)
    omitAx.text(15.5, 1.5, drop, color='xkcd:blue', zorder=6, **symbol_style)
    omitAx.text(15.5, 1.5, cross, color='r', zorder=6, **symbol_style)

    #common axes
    for ax in axes:
        ax.vlines(x=GT, ymin=0, ymax=2.8, color='m', linestyles='--', linewidth=1.5, zorder=5)
        ax.set_xlim([xmin, xmax])
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['left'].set_visible(False)
        ax.yaxis.set_visible(False)
        # tick at GT (not a literal 7) so the 'Goal time' label below is
        # found for any goal time
        ax.set_xticks([-1, 0, 1.5, GT, 15])
        ax.tick_params('x', labelsize=8)
        ax.xaxis.set_major_formatter(FormatStrFormatter('%g'))
        ax.set_ylim([-.1, 5])

    #adding ET
    _add_entrance_time_tick(goodAx, ET)
    _add_entrance_time_tick(earlyAx, ET0)

    # draw once so the tick label texts of the bottom panel are populated
    fig.canvas.draw()
    bottomAx = axes[-1]
    bottomAx.set_xlabel('Time (s)', fontsize=8)
    labels = list(bottomAx.get_xticklabels())
    for i, tick in enumerate(bottomAx.get_xticks()):
        if tick == GT:
            bottomAx.get_xticklabels()[i].set_color('m')
            labels[i] = f'Goal time={GT}'
    bottomAx.set_xticklabels(labels)

    return (treadmillAx, *axes)

In [None]:
if "__file__" not in dir():
    # Notebook-only preview of the task-rule timelines on their own figure.
    alpha, hspace = 1, .4
    xmin = -2
    xmax = 16
    et = (9, 5, 15)
    GT = 7
    maxP = 10  # punishment
    fontsize = 6

    plt.close('all')
    fig = plt.figure(figsize=(3, 5), dpi=200)
    gs = fig.add_gridspec(nrows=1, ncols=1)[0]
    axes = plot_task_rules(fig=fig, gs=gs, hspace=hspace, alpha=alpha,
                           xmin=xmin, xmax=xmax, et=et, GT=GT, maxP=maxP,
                           fontsize=fontsize)
    # the first returned Axes carries no data: hide both of its axis objects
    axes[0].yaxis.set_visible(False)
    axes[0].xaxis.set_visible(False)

    plt.show()

# part 2:

# GENERATING THE FIGURE

Parameters

In [None]:
if "__file__" not in dir():
    # ---- GRID 1: task rules ----
    alpha, hspace = 1, .3
    xmin = -2
    xmax = 16
    et = (9, 4, 15)
    GT = 7
    maxP = 10  # punishment
    fontsize = 6

    # ---- GRID 2: reward magnitude ----
    minReward2, gt2, maxTrial2, maxReward2 = 1.5, 7, 15, 20
    step2, color2 = .8, 'r'
    showYAxis2 = True
    fontsize2 = 4

    # ---- GRID 3: punishment duration ----
    minReward3, gt3 = 1.5, 7
    maxPunish3, minPunish3 = 10, 1
    maxTrial3, color3 = 15, 'r'

Plotting

In [None]:
if "__file__" not in dir():
    # Assemble the complete figure from the parameters defined in the
    # previous cell, then export it as a PDF.
    plt.close('all')
    figsize = (3.1, 5)
    fig = plt.figure(figsize=figsize, dpi=100)

    # ------------------------------------------------------------------
    # 1: task rules (fills the whole figure grid)
    gs1 = fig.add_gridspec(nrows=1, ncols=1)[0]
    axes1 = plot_task_rules(fig=fig, gs=gs1, hspace=hspace, alpha=alpha,
                            xmin=xmin, xmax=xmax, et=et, GT=GT, maxP=maxP,
                            fontsize=fontsize)
    # the first Axes carries no data: hide its axis objects and every spine
    # except the top one
    axes1[0].spines['right'].set_visible(False)
    axes1[0].spines['left'].set_visible(False)
    axes1[0].spines['bottom'].set_visible(False)
    axes1[0].yaxis.set_visible(False)
    axes1[0].xaxis.set_visible(False)

    # ------------------------------------------------------------------
    # 2: reward magnitude (small inset placed in figure coordinates)
    ax2 = fig.add_axes([0.69, 0.6, 0.16, 0.06])
    pts = plot_reward_progression_schematic(
        ax2, minReward2, gt2, maxTrial2, maxReward2, step2, color2,
        showYAxis2, fontsize=fontsize2)
    ax2.spines['left'].set_linewidth(.5)
    ax2.spines['bottom'].set_linewidth(.5)
    ax2.tick_params(axis='both', pad=-3)
    ax2.xaxis.set_major_formatter(FormatStrFormatter('%g'))

    # colour bar for the inset: 10-step ramp from faint red to opaque red
    cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
        'test', [[1, 0, 0, .1], [1, 0, 0, 1]], N=10)
    norm = matplotlib.colors.Normalize(vmin=0, vmax=1)
    sm = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)
    sm.set_array([])
    cbar = fig.colorbar(mappable=sm, ax=ax2, orientation='vertical',
                        shrink=1, aspect=9, anchor=(1, 1), ticks=[0, 1])

    cbar.ax.xaxis.set_label_position('top')
    cbar.ax.set_xlabel('Training\nPhase', fontsize=fontsize2, ha='left')
    cbar.outline.set_visible(False)
    cbar.set_ticklabels(['Early', 'Late'])
    cbar.ax.tick_params(labelsize=fontsize2, right=False, pad=-3)
    cbar.ax.set_ylim([0, 1])

    # ------------------------------------------------------------------
    # 3: punishment duration (second inset, below the first one)
    ax3 = fig.add_axes([0.69, 0.40, 0.16, 0.06])
    plot_punishment_schematic(ax3, minReward3, gt3, minPunish3, maxPunish3,
                              maxTrial3, color3, fontsize2)
    ax3.spines['left'].set_linewidth(.5)
    ax3.spines['bottom'].set_linewidth(.5)
    ax3.tick_params(axis='both', pad=-3)

    # ------------------------------------------------------------------
    # panel letters on the four task-rule Axes
    AXES = axes1
    OFFX = tuple(.05 for _ in AXES)
    OFFY = tuple(.0 for _ in AXES)
    add_panel_caption(axes=AXES, offsetX=OFFX, offsetY=OFFY)

    # export next to the notebook directory: <parent of cwd>/BehavioralPaper/Figures
    fig.savefig(
        os.path.join(os.path.dirname(os.getcwd()), 'BehavioralPaper', 'Figures', 'TaskRules.pdf'),
        format='pdf', bbox_inches='tight')

    plt.show()
    plt.close('all')