## This notebook may be used to reproduce Figures 2A-2F, and Figures 4B, 4C, 4E, 4F in the main paper.

### Parameter values that need to be changed to derive the different figures are mentioned in comments as appropriate. Note that nonlinear and linear versions of the interneuronal inhibition model can be chosen.

In [1]:
# Connect to the ipyparallel cluster; `clients` is the handle used by the
# following cells to query the engines and to build a view for dispatching work.
import ipyparallel as ipp
clients = ipp.Client()

In [2]:
# Display the ids of the engines reachable through this client.
clients.ids

[0, 1, 2, 3]

In [3]:
# Direct view on the cluster engines; used below for sync_imports() and map_sync().
dview = clients.direct_view()

In [4]:
# Import the numerical dependencies locally and on the engines (see the
# "importing ... on engine(s)" messages printed by this cell).
with dview.sync_imports():
    import numpy as np
    from math import exp,sqrt,log
    import numpy.random
# Define on the engines the short names `np` and `rand` that Sim_Anim_Beh
# refers to as globals.
%px np = numpy
%px rand = numpy.random

importing numpy on engine(s)
importing exp,sqrt,log from math on engine(s)
importing numpy.random on engine(s)


In [84]:
def Sim_Anim_Beh(trials, randSeed, rin, d01=7.5, linear=False, ht=0.005):
    """Simulate the food/water motivation model and average over noisy trials.

    The state vector ``y`` has three components: the two excitatory
    motivation units (index 0: food, index 1: water) and one inhibitory
    unit (index 2).  The simulation has two phases:

    1. A noise-free relaxation phase (RK4, from t = 0 to t = 200) with the
       deficits held fixed at ``[d01, d02]``; its end point is the initial
       state of every trial.
    2. ``trials`` stochastic runs (Heun scheme, noise amplitude 0.01 on the
       two excitatory units) in which the deficit of the currently leading
       motivation is reduced at rate ``decay`` once the travel time ``tau``
       since the last change of leader has elapsed.

    The global names ``np``, ``rand`` (``numpy.random``), ``exp``, ``sqrt``
    and ``log`` must be available where this function runs.

    Parameters
    ----------
    trials : int
        Number of stochastic runs to average; must be at least 1.
    randSeed : int
        Seed passed to ``rand.seed``.
    rin : float
        Value of the model parameter ``r`` (factor on the self-excitation
        term ``beta*r*actfct(...)``).
    d01 : float, optional
        Initial food deficit.  The water deficit is
        ``d02 = sqrt(2*7.5**2 - d01**2)``.  Values used for the figures:
        7.5 (Fig 2), 8.0 (Fig 4B), 8.1 (Figs 4C and 4E), 8.35 (Fig 4F).
    linear : bool, optional
        ``False`` (default) selects the nonlinear model (Figs 2A-F, 4B, 4C);
        ``True`` selects the linear model (Figs 4E, 4F).
    ht : float, optional
        Integration time step.  Penalties are sampled when ``t`` lies within
        ``ht/2`` of an integer time, so ``1/ht`` should be an integer.

    Returns
    -------
    tuple
        ``(Ep_sum_av, time1, Ysol1, time2_av, Ysol2_av, Deficit_av, tau_dist)``:
        trial-averaged expected penalty; times and states of the relaxation
        phase; trial-averaged times, states and deficits of the stochastic
        phase (lists with one entry per time step); and the travel time
        between the two sources.

    Raises
    ------
    ValueError
        If ``trials < 1``, ``ht <= 0`` or ``d01**2 > 2*7.5**2``.
    """
    # Validate up front so that bad input fails before the long relaxation.
    if trials < 1:
        raise ValueError("trials must be at least 1")
    if ht <= 0:
        raise ValueError("ht must be positive")
    d_total_sq = 2*7.5*7.5
    if d01**2 > d_total_sq:
        raise ValueError("d01**2 must not exceed 2*7.5**2")

    rand.seed(randSeed)

    # model parameters
    k = 0.8  # leak excitatory unit
    kinh = 0.8  # leak inhibitory unit
    w_exc = 3  # excitation strength inhibitory unit
    q = 0.1  # frequency of integration
    g1 = 10  # gain excitation function
    g2 = 10  # gain inhibition function
    b1 = 0.5  # midpoint excitation function
    b2 = 0.5  # midpoint inhibition function
    beta = 3  # inhibition strength
    r = rin

    # constant offset of the linear model
    phi = 0.5*beta*(r*(1 - g1*b1/2) - (1 - g2*b2/2))

    # geometric distribution for bout times: probability of interruption and
    # maximum terminal bout time including 99 percent of all values
    lam_interrupt = 0.05
    bout_max_99 = int(log(0.01)/log(1-lam_interrupt)) + 1

    eps_offset = 1e-9  # tie-break offset when both motivations are equal
    decay = 0.15  # deficit reduction per unit time while consuming
    sigma = 0.01  # noise amplitude of the stochastic phase
    tau_dist = 4.0  # time to overcome distance between food and water sources
    t_relax = 200  # duration of the noise-free relaxation phase
    sqrt_ht = sqrt(ht)

    # activation function
    def actfct(z, g, b):
        return 1/(1+exp(-g*(z-b)))

    # deterministic right-hand side, nonlinear model
    def Fdet_nonlinear(t, y, dfct):
        rhs = np.zeros(3)
        rhs[0] = -k*y[0] + beta*r*actfct(y[0],g1,b1) - beta*actfct(y[2],g2,b2) + q*dfct[0]
        rhs[1] = -k*y[1] + beta*r*actfct(y[1],g1,b1) - beta*actfct(y[2],g2,b2) + q*dfct[1]
        rhs[2] = -kinh*y[2] + w_exc*(actfct(y[0],g1,b1) + actfct(y[1],g1,b1))
        return rhs

    # deterministic right-hand side, linear model
    def Fdet_linear(t, y, dfct):
        rhs = np.zeros(3)
        rhs[0] = (r*beta*g1/4-k)*y[0] - beta*g2/4*y[2] + q*dfct[0] + phi
        rhs[1] = (r*beta*g1/4-k)*y[1] - beta*g2/4*y[2] + q*dfct[1] + phi
        rhs[2] = -kinh*y[2] + w_exc*(1 - g1*b1/2 + g1/4*(y[0] + y[1]))
        return rhs

    Fdet = Fdet_linear if linear else Fdet_nonlinear

    # stochastic part of RHS: independent noise on the two excitatory units
    def Frand():
        noise = np.zeros(3)
        noise[0] = sigma*rand.normal(0,1)
        noise[1] = sigma*rand.normal(0,1)
        return noise

    # Runge-Kutta 4th-order increment
    def rk4(t, y, dfct):
        K0 = ht*Fdet(t, y, dfct)
        K1 = ht*Fdet(t + ht/2.0, y + K0/2.0, dfct)
        K2 = ht*Fdet(t + ht/2.0, y + K1/2.0, dfct)
        K3 = ht*Fdet(t + ht, y + K2, dfct)
        return (K0 + 2.0*K1 + 2.0*K2 + K3)/6.0

    # Heun predictor-corrector increment for the stochastic equations
    def Heun(t, y, dfct):
        drift_start = Fdet(t, y, dfct)
        noise = Frand()
        y_pred = y + ht*drift_start + noise*sqrt_ht
        drift_end = Fdet(t+ht, y_pred, dfct)
        return (ht/2.0)*(drift_start+drift_end) + noise*sqrt_ht

    # give a random one of two exactly equal motivations a tiny lead (in place)
    def break_tie(y):
        if y[0] == y[1]:
            if rand.uniform(0,1) < 0.5:
                y[0] = y[0] + eps_offset
            else:
                y[1] = y[1] + eps_offset

    ##########################################################################
    # Phase 1: noise-free relaxation with fixed deficits
    ##########################################################################
    d02 = np.sqrt(d_total_sq - d01**2)
    dfct_init = np.array([d01, d02])
    t = 0.0
    y = np.array([0.0, 0.0, 0.0])
    time1 = [t]
    Ysol1 = [y]
    while t <= t_relax:
        # a fresh array is created every step, so stored states are never shared
        y = np.maximum(y + rk4(t, y, dfct_init), 0.0)
        t = t + ht
        time1.append(t)
        Ysol1.append(y)

    # Private copy: the trials modify their state in place (tie-break), which
    # must not leak into Ysol1 or into the start state of later trials.
    y_init = y.copy()

    ##########################################################################
    # Phase 2: stochastic trials
    ##########################################################################
    Tbout_max = bout_max_99
    tend = Tbout_max + 1
    Ep_sum_av = 0
    time2_runs = []
    Ysol2_runs = []
    Deficit_runs = []

    for nn in range(trials):
        # every trial starts from its own copy of the relaxed state
        y = y_init.copy()
        dfct = dfct_init.copy()
        t = 0.0
        tau = tau_dist/2.0  # initially animal is in between the two sources
        tchange = 0.0
        nfood = 2
        nwater = 2

        Ep = []
        time2 = [t]
        Ysol2 = [y.copy()]
        Deficit = [dfct.copy()]

        while t <= tend:
            # At (numerically) integer times 1..Tbout_max, record the penalty
            # weighted by the geometric probability that the bout ends then.
            kk = int(round(t))
            if 1 <= kk <= Tbout_max and kk-ht/2 < t < kk+ht/2:
                penalty = dfct[0]*dfct[0] + dfct[1]*dfct[1]
                Ep.append(penalty*lam_interrupt*pow(1-lam_interrupt,kk-1))

            break_tie(y)
            dely1 = y[0] - y[1]

            # The leading motivation reduces its deficit once the travel time
            # since the last change of leader has elapsed.
            if t > tchange + tau:
                if y[0] > y[1]:
                    dfct[0] = dfct[0] - decay*ht
                    nfood = 1
                    nwater = 0
                elif y[1] > y[0]:
                    dfct[1] = dfct[1] - decay*ht
                    nwater = 1
                    nfood = 0
            dfct = np.maximum(dfct, 0.0)

            y = np.maximum(y + Heun(t, y, dfct), 0.0)
            break_tie(y)
            dely2 = y[0] - y[1]

            # change of leader during this step
            if np.sign(dely1)*np.sign(dely2) < 0:
                tchange_old = tchange
                tchange = t
                if nfood == 1 or nwater == 1:
                    # was consuming: the full distance has to be covered
                    tau = tau_dist
                    nfood = 0
                    nwater = 0
                else:
                    # turned around on the way: travel back what was covered
                    tau = tau_dist-tau + tchange-tchange_old

            t = t + ht
            # store snapshots so that later in-place updates cannot alter them
            time2.append(t)
            Ysol2.append(y.copy())
            Deficit.append(dfct.copy())

        time2_runs.append(time2)
        Ysol2_runs.append(Ysol2)
        Deficit_runs.append(Deficit)
        Ep_sum_av += np.sum(Ep)/trials

    # average over trials, time step by time step (all trials have equal length)
    time2_av = []
    Ysol2_av = []
    Deficit_av = []
    for kk in range(len(time2_runs[-1])):
        time2_av.append(np.sum([run[kk] for run in time2_runs])/trials)
        Ysol2_av.append(sum(run[kk] for run in Ysol2_runs)/trials)
        Deficit_av.append(sum(run[kk] for run in Deficit_runs)/trials)
    return Ep_sum_av,time1,Ysol1,time2_av,Ysol2_av,Deficit_av,tau_dist

In [85]:
# Values of r passed as the third argument (rin) of Sim_Anim_Beh, one per task.
rlist = [0.5,1,1.5,2]

import timeit
time_start = timeit.default_timer()
nr_engines = len(clients.ids)
# One distinct seed per engine.
randSeedList = [(jj+1)*1928374 for jj in range(nr_engines)]
nr_trials_tot = 4
nr_trials_per_engine = int(nr_trials_tot/nr_engines)
# The three argument sequences are consumed in parallel: task i gets
# nr_trials_per_engine trials, randSeedList[i] and rlist[i].
# NOTE(review): rlist has 4 entries while the other two sequences have
# nr_engines entries, so this presumably assumes exactly 4 engines (as in the
# recorded output of clients.ids) - confirm before running on another cluster.
results = dview.map_sync(Sim_Anim_Beh, [nr_trials_per_engine]*nr_engines, randSeedList, rlist)
time_stop = timeit.default_timer()
print('elapsed time =',time_stop - time_start,'seconds')
#print(randSeedList)

elapsed time = 8.348811512987595 seconds


In [86]:
# Unpack the per-task result tuples returned by Sim_Anim_Beh:
# index 0: averaged penalty, 1: relaxation times, 2: relaxation states,
# 3: averaged trial times, 4: averaged trial states, 5: averaged deficits.
time1 = results[0][1]

## for r=1  (Fig2A) choose:
Ysol1 = results[1][2]
## for r=2  (Fig2D) choose:
#Ysol1 = results[3][2]

time2 = results[0][3]

# Mean over all tasks of the penalty, the motivations and the deficits.
# NOTE(review): every task was run with a different r (see rlist), so these
# are averages across r values; the plotting cells read results[...] directly.
engine_ids = range(nr_engines)
Ep_sum_av = sum(results[kk][0]/nr_engines for kk in engine_ids)
Ysol2 = [sum(results[kk][4][jj]/nr_engines for kk in engine_ids)
         for jj in range(len(time2))]
Deficit = [sum(results[kk][5][jj]/nr_engines for kk in engine_ids)
           for jj in range(len(time2))]
        

In [87]:
%matplotlib notebook
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

In [88]:
# Reproduces Figs 2A and 2D: motivations during the noise-free relaxation
# phase, with time shifted so that the phase ends at t = 0.

fig = plt.figure(figsize=(8,6))
t_relax_shifted = np.asarray(time1)-200
Y_relax = np.asarray(Ysol1)
plt.plot(t_relax_shifted,Y_relax[:,0],c = 'g',lw = 3,label = r'$x_1$ (food)')
plt.plot(t_relax_shifted,Y_relax[:,1],c = 'b', ls = '--', lw = 3,label = r'$x_2$ (water)')
plt.legend(loc='lower right', frameon=True, handlelength=4, borderpad=0.4, labelspacing=0.2, fontsize=20)
# Style the axes that were just drawn on: plt.axes() would add a new, empty
# Axes on top of the plot in current Matplotlib versions.
ax = plt.gca()
plt.xlabel(r'time t ', fontsize=26)
plt.ylabel(r'motivations', fontsize=26)
plt.xlim(-200,0)
#plt.ylim(0,0.6)
ax.xaxis.set_major_locator(ticker.MultipleLocator(50))
ax.xaxis.set_minor_locator(ticker.MultipleLocator(10))
#ax.yaxis.set_major_locator(ticker.MultipleLocator(0.1))
#ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.05))
for axis in ['top','bottom','left','right']:
    ax.spines[axis].set_linewidth(2)
# labelsize replaces the former per-tick `tick.label.set_fontsize(22)` loops
# (Tick.label is no longer available in current Matplotlib).
ax.tick_params('both', length=8, width=2, which='major', labelsize=22)
ax.tick_params('both', length=6, width=1, which='minor')

#ax.annotate(r'initial motivations', xy=(0, np.asarray(Ysol1)[-1,0]), xytext=(-80, 2.4), fontsize=24,
#            arrowprops=dict(facecolor='k', edgecolor='k', shrink=0.03),
#            horizontalalignment='center', verticalalignment='bottom')

plt.tight_layout();

<IPython.core.display.Javascript object>

In [89]:
## The following is relevant for Figs 2 and 4: trial-averaged motivation
## difference x_1 - x_2 over time, plotting every `freq`-th sample.

fig = plt.figure()
freq = 5
t_plot = time2[0:-1:freq]


def motivation_difference(run_index):
    """Return x_1 - x_2 of results[run_index], sampled like t_plot."""
    sol = np.asarray(results[run_index][4])[0:-1:freq]
    return sol[:,0]-sol[:,1]


## to reproduce Fig2 choose:
## for r=1 (Fig2B):
plt.plot(t_plot,motivation_difference(1),c = 'k',lw = 3.5, ls='-')
## for r=2 (Fig2E):
#plt.plot(t_plot,motivation_difference(3),c = 'k',lw = 3.5, ls='-')

## to reproduce Fig4 choose:
#plt.plot(t_plot,motivation_difference(0),c = 'g',lw = 3.5, ls='-', label=r'$r=0.5$')
#plt.plot(t_plot,motivation_difference(1),c = 'orange',lw = 3.5, ls='--', label=r'$r=1.0$')
#plt.plot(t_plot,motivation_difference(2),c = 'b',lw = 3.5, ls=':', label=r'$r=1.5$')
#plt.plot(t_plot,motivation_difference(3),c = 'r',lw = 3.5, ls='-.', label=r'$r=2.0$')

# Style the axes that were just drawn on: plt.axes() would add a new, empty
# Axes on top of the plot in current Matplotlib versions.
ax = plt.gca()
plt.xlabel(r'time t', fontsize=26)
plt.ylabel(r'motivation difference $\Delta x$', fontsize=26)
plt.xlim(0,100)

##for Fig2 choose:
plt.ylim(-0.5,0.5)

#plt.text(40, 0.4, r'water', fontsize=28, color = 'blue')
#plt.text(40, -0.4, r'food', fontsize=28, color = 'green')
#plt.text(61, 0.77, r'max. terminal', fontsize=24)
#plt.text(82, 0.64, r'time', fontsize=24)
# zero line separating the two signs of the motivation difference
plt.plot((0, 100), (0, 0), c = 'grey', ls = 'dashed', lw = 3)
#ax.fill_between([0, 100], 0, 0.5, facecolor='g', alpha=0.1)
#ax.fill_between([0, 100], 0, -0.5, facecolor='b', alpha=0.1)

#plt.plot((91, 91), (-0.5, 0.5), c = 'k', ls = '-', lw = 3)

#ax.arrow(91, 0.6, 0, -1.5, head_width=2, head_length=0.1, fc='k', ec='k',lw=2.5)
ax.xaxis.set_major_locator(ticker.MultipleLocator(20))
ax.xaxis.set_minor_locator(ticker.MultipleLocator(5))
ax.yaxis.set_major_locator(ticker.MultipleLocator(0.5))
ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.1))
for axis in ['top','bottom','left','right']:
    ax.spines[axis].set_linewidth(2)
# labelsize replaces the former per-tick `tick.label.set_fontsize(22)` loops
# (Tick.label is no longer available in current Matplotlib).
ax.tick_params('both', length=8, width=2, which='major', labelsize=22)
ax.tick_params('both', length=6, width=1, which='minor')

#ax.annotate(r'max. terminal time', xy=(91, -0.5), xytext=(50, -0.35), fontsize=24,
#            arrowprops=dict(facecolor='k', edgecolor='k', shrink=0.02),
#            horizontalalignment='center', verticalalignment='bottom')

## Uncomment legend for Fig 4
#plt.legend(bbox_to_anchor=(0.4, 1), loc='upper center', ncol=1, fontsize=18, handlelength=3, fancybox=True, borderpad=0.4, 
#           labelspacing=0.2)

plt.tight_layout();

<IPython.core.display.Javascript object>

In [90]:
# The following is only relevant for Fig 2: trial-averaged deficits over time.

fig = plt.figure()
t_trial = np.asarray(time2)

## for Fig2 choose:
## for r=1 (Fig2C):
deficits = np.asarray(results[1][5])
## for r=2 (Fig2F):
#deficits = np.asarray(results[3][5])

plt.plot(t_trial,deficits[:,0],c = 'g',lw = 4, label=r'$d_1$ (food)')
plt.plot(t_trial,deficits[:,1],c = 'b',lw = 4, ls='--', label=r'$d_2$ (water)')

#plt.plot(np.asarray(time2),np.asarray(Deficit)[:,0],c = 'b',lw = 3.5)
#plt.plot(np.asarray(time2),np.asarray(Deficit)[:,1],c = 'r',lw = 3.5)

# Style the axes that were just drawn on: plt.axes() would add a new, empty
# Axes on top of the plot in current Matplotlib versions.
ax = plt.gca()
plt.xlabel(r'time t', fontsize=26)
plt.ylabel(r'deficits', fontsize=26)
#plt.text(61, 4.62, r'max. terminal', fontsize=24)
#plt.text(82, 4.4, r'time', fontsize=24)
#ax.arrow(91, 4.3, 0, -2.22, head_width=2, head_length=0.07, fc='k', ec='k',lw=2.5)

# vertical marker at t = 91 (the "max. terminal time" of the annotation below)
plt.plot((91, 91), (0.5, 7.5), c = 'k', ls = '-', lw = 3)

plt.xlim(0,100)

##for Fig2 choose:
plt.ylim(2.5,7.5)

ax.xaxis.set_major_locator(ticker.MultipleLocator(10))
ax.xaxis.set_minor_locator(ticker.MultipleLocator(5))
ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.1))

for axis in ['top','bottom','left','right']:
    ax.spines[axis].set_linewidth(2)
# labelsize replaces the former per-tick `tick.label.set_fontsize(22)` loops
# (Tick.label is no longer available in current Matplotlib).
ax.tick_params('both', length=8, width=2, which='major', labelsize=22)
ax.tick_params('both', length=6, width=1, which='minor')

#ax.annotate(r'max. terminal time', xy=(91, 2.5), xytext=(30, 4), fontsize=24,
#            arrowprops=dict(facecolor='k', edgecolor='k', shrink=0.01),
#            horizontalalignment='center', verticalalignment='bottom')

plt.legend(bbox_to_anchor=(0.52, 0.27), frameon=True, handlelength=4, borderpad=0.4, labelspacing=0.2, fontsize=20)
plt.tight_layout();

<IPython.core.display.Javascript object>