In [2]:
# optional
!pip3 install numpy matplotlib seaborn tqdm numba

Collecting numpy
  Downloading numpy-1.20.1-cp39-cp39-macosx_10_9_x86_64.whl (16.1 MB)
[K     |████████████████████████████████| 16.1 MB 8.7 MB/s 
[?25hCollecting matplotlib
  Downloading matplotlib-3.3.4-cp39-cp39-macosx_10_9_x86_64.whl (8.5 MB)
[K     |████████████████████████████████| 8.5 MB 28.0 MB/s 
[?25hCollecting seaborn
  Downloading seaborn-0.11.1-py3-none-any.whl (285 kB)
[K     |████████████████████████████████| 285 kB 24.9 MB/s 
[?25hCollecting tqdm
  Downloading tqdm-4.59.0-py2.py3-none-any.whl (74 kB)
[K     |████████████████████████████████| 74 kB 7.2 MB/s 
[?25hCollecting numba
  Downloading numba-0.53.0-cp39-cp39-macosx_10_14_x86_64.whl (2.2 MB)
[K     |████████████████████████████████| 2.2 MB 17.9 MB/s 
[?25hCollecting cycler>=0.10
  Using cached cycler-0.10.0-py2.py3-none-any.whl (6.5 kB)
Collecting pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.3
  Using cached pyparsing-2.4.7-py2.py3-none-any.whl (67 kB)
Collecting kiwisolver>=1.0.1
  Downloading kiwisolver-1.3.

In [1]:
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
from numba import jit

# Matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8-*" (the old
# names were later removed), so prefer the new name whenever it is available.
plt.style.use('seaborn-v0_8-colorblind' if 'seaborn-v0_8-colorblind' in plt.style.available
              else 'seaborn-colorblind')

In [2]:
def get_recurrences(*params):
    """Build the one-generation recurrence maps of the model.

    Parameters
    ----------
    *params : seven numbers ``b3, b2, b1, b0, s, h, a`` in this order.
        As read off the formulas below, ``b3``/``b2``/``b1``/``b0`` weight an
        offspring being T_1 (``1 - b`` weights T_2) for the (focal, mate)
        T-state pairs (T_1, T_1), (T_1, T_2), (T_2, T_1) and (T_2, T_2).
        T_2 mates are weighted by ``1 - s``, and additionally by ``1 + h*a``
        for a P_1P_2 focal or ``1 + a`` for a P_2P_2 focal.

    Returns
    -------
    tuple of six callables ``(x1, x2, x3, x4, x5, x6)``
        Each takes the current frequencies of P_1P_1T_1, P_1P_2T_1, P_2P_2T_1,
        P_1P_1T_2, P_1P_2T_2, P_2P_2T_2 as six positional arguments and returns
        that class's frequency in the next generation. ``x6`` is the complement
        of the other five, so the new frequencies sum to one.

    Raises
    ------
    ValueError
        If the number of parameters is not seven.
    """
    if len(params) != 7:
        raise ValueError(
            "get_recurrences expects 7 parameters (b3, b2, b1, b0, s, h, a), "
            "got %d" % len(params))
    b3, b2, b1, b0, s, h, a = params

    @jit(nopython=True)
    def y(x4, x5, x6):
        # total mate weight of a P_1P_1 focal: T_1 mates count 1, T_2 mates (1 - s)
        return 1 - s*(x4 + x5 + x6)

    @jit(nopython=True)
    def z1(x4, x5, x6):
        # y*z1 is the total mate weight of a P_1P_2 focal (T_2 mates weighted (1-s)(1+h*a))
        return 1 + h*a*(1 - s)*(x4 + x5 + x6)/y(x4, x5, x6)

    @jit(nopython=True)
    def z2(x4, x5, x6):
        # y*z2 is the total mate weight of a P_2P_2 focal (T_2 mates weighted (1-s)(1+a))
        return 1 + a*(1 - s)*(x4 + x5 + x6)/y(x4, x5, x6)

    # In each map below the sixth input is ignored and recomputed as the
    # complement of the first five, so rounding drift cannot accumulate.

    @jit(nopython=True)
    def x1(*freqs):
        # P_1P_1T_1 offspring; m1: P_1P_1 focals, m2: P_1P_2 focals (pass P_1 w.p. 1/2)
        x1, x2, x3, x4, x5, _ = freqs
        x6 = 1 - x1 - x2 - x3 - x4 - x5
        y_n = y(x4, x5, x6)
        m1 = ((x1*b3 + x4*b1)*(2*x1 + x2) + (1 - s)*(x1*b2 + x4*b0)*(2*x4 + x5))/(2*y_n)
        m2 = ((x2*b3 + x5*b1)*(2*x1 + x2) + (1 - s)*(1 + h*a)*(2*x4 + x5)*(x2*b2 + x5*b0))/(4*y_n*z1(x4, x5, x6))
        return m1 + m2

    @jit(nopython=True)
    def x2(*freqs):
        # P_1P_2T_1 offspring; m1: P_1P_1 focals, m2: P_1P_2 focals, m3: P_2P_2 focals
        x1, x2, x3, x4, x5, _ = freqs
        x6 = 1 - x1 - x2 - x3 - x4 - x5
        y_n = y(x4, x5, x6)
        m1 = ((x1*b3 + x4*b1)*(x2 + 2*x3) + (1 - s)*(x1*b2 + x4*b0)*(x5 + 2*x6))/(2*y_n)
        m2 = ((x2*b3 + x5*b1)*(x1 + x2 + x3) + (1 - s)*(1 + h*a)*(x2*b2 + x5*b0)*(x4 + x5 + x6))/(2*y_n*z1(x4, x5, x6))
        m3 = ((x3*b3 + x6*b1)*(2*x1 + x2) + (1 - s)*(1 + a)*(x3*b2 + x6*b0)*(2*x4 + x5))/(2*y_n*z2(x4, x5, x6))
        return m1 + m2 + m3

    @jit(nopython=True)
    def x3(*freqs):
        # P_2P_2T_1 offspring; m1: P_1P_2 focals, m2: P_2P_2 focals
        x1, x2, x3, x4, x5, _ = freqs
        x6 = 1 - x1 - x2 - x3 - x4 - x5
        y_n = y(x4, x5, x6)
        m1 = ((x2*b3 + x5*b1)*(x2 + 2*x3) + (1 - s)*(1 + h*a)*(x2*b2 + x5*b0)*(x5 + 2*x6))/(4*y_n*z1(x4, x5, x6))
        m2 = ((x3*b3 + x6*b1)*(x2 + 2*x3) + (1 - s)*(1 + a)*(x3*b2 + x6*b0)*(x5 + 2*x6))/(2*y_n*z2(x4, x5, x6))
        return m1 + m2

    @jit(nopython=True)
    def x4(*freqs):
        # P_1P_1T_2 offspring: same as x1 with every b replaced by 1 - b
        x1, x2, x3, x4, x5, _ = freqs
        x6 = 1 - x1 - x2 - x3 - x4 - x5
        y_n = y(x4, x5, x6)
        m1 = ((x1*(1-b3) + x4*(1-b1))*(2*x1 + x2) + (1 - s)*(x1*(1-b2) + x4*(1-b0))*(2*x4 + x5))/(2*y_n)
        m2 = ((x2*(1-b3) + x5*(1-b1))*(2*x1 + x2) + (1 - s)*(1 + h*a)*(2*x4 + x5)*(x2*(1-b2) + x5*(1-b0)))/(4*y_n*z1(x4, x5, x6))
        return m1 + m2

    @jit(nopython=True)
    def x5(*freqs):
        # P_1P_2T_2 offspring: same as x2 with every b replaced by 1 - b
        x1, x2, x3, x4, x5, _ = freqs
        x6 = 1 - x1 - x2 - x3 - x4 - x5
        y_n = y(x4, x5, x6)
        m1 = ((x1*(1-b3) + x4*(1-b1))*(x2 + 2*x3) + (1 - s)*(x1*(1-b2) + x4*(1-b0))*(x5 + 2*x6))/(2*y_n)
        m2 = ((x2*(1-b3) + x5*(1-b1))*(x1 + x2 + x3) + (1 - s)*(1 + h*a)*(x2*(1-b2) + x5*(1-b0))*(x4 + x5 + x6))/(2*y_n*z1(x4, x5, x6))
        m3 = ((x3*(1-b3) + x6*(1-b1))*(2*x1 + x2) + (1 - s)*(1 + a)*(x3*(1-b2) + x6*(1-b0))*(2*x4 + x5))/(2*y_n*z2(x4, x5, x6))
        return m1 + m2 + m3

    def x6(*freqs):
        # P_2P_2T_2 offspring as the complement, keeping the total at one
        return 1 - x1(*freqs) - x2(*freqs) - x3(*freqs) - x4(*freqs) - x5(*freqs)

    return x1, x2, x3, x4, x5, x6

In [3]:
def nest_list(f, x0, n):
    """Return ``[x0, f(*x0), f(*f(*x0)), ...]`` with ``n`` applications of ``f``.

    ``f`` is called with the previous element unpacked as positional
    arguments, so every element must be an iterable (in this notebook a
    6-tuple of phenogenotype frequencies). The result has ``n + 1`` elements;
    ``n <= 0`` returns ``[x0]``.

    Deliberately not numba-compiled: ``f`` is an ordinary Python closure,
    which nopython mode cannot call, so a ``@jit()`` here could only use the
    deprecated object-mode fallback (an error once ``jit`` defaults to
    nopython).
    """
    ans = [x0]
    for _ in range(n):
        ans.append(f(*ans[-1]))
    return ans

# frequency conversion assuming Hardy-Weinberg equilibrium
@jit(nopython=True)
def get_phenotype_freqs_from_genotype_freqs(T2_freq, P2_freq):
    """Expand (t_2, p_2) into the six phenogenotype frequencies.

    Hardy-Weinberg proportions are used at the P locus and the T state is
    assigned independently of it. The returned 6-tuple is ordered
    P1P1T1, P1P2T1, P2P2T1, P1P1T2, P1P2T2, P2P2T2.
    """
    t1 = 1 - T2_freq
    p1 = 1 - P2_freq
    return (t1*p1*p1, t1*2*P2_freq*p1, t1*P2_freq*P2_freq,
            T2_freq*p1*p1, T2_freq*2*P2_freq*p1, T2_freq*P2_freq*P2_freq)

@jit(nopython=True)
def get_genotype_freqs_from_phenotype_freqs(*freqs):
    """Collapse six phenogenotype frequencies into ``(t_2, p_2)``.

    Input order is P1P1T1, P1P2T1, P2P2T1, P1P1T2, P1P2T2, P2P2T2. The frequencies
    may be given either as six separate arguments (``f(*state)``) or as a
    single 6-tuple (``map(f, states)``).

    Returns the frequency of T2 and the frequency of the P2 allele.
    """
    # NOTE(review): numba types both branches for a given call signature; this
    # presumably relies on it resolving len(freqs) at compile time — confirm
    # before restructuring the dual-arity handling.
    if len(freqs) == 1:
        x1e, x2e, x3e, x4e, x5e, x6e = freqs[0]
    else:
        x1e, x2e, x3e, x4e, x5e, x6e = freqs
    # t_2: all T2 classes; p_2: P2P2 homozygotes fully plus half of heterozygotes
    return x4e + x5e + x6e, x2e/2 + x3e + x5e/2 + x6e

# plot title generation
def get_param_str(parameters):
    """Format ``[b3, b2, b1, b0, s, h, a]`` as a LaTeX string for plot titles.

    Labels follow the unpacking order of ``get_recurrences`` (the previous
    version labelled the third and fourth entries ``b_2`` and ``b_1``).
    Raises IndexError if fewer than seven parameters are given.
    """
    names = ("b_3", "b_2", "b_1", "b_0", "s", "h", "a")
    fields = ["$" + name + " = " + str(parameters[k]) + "$" for k, name in enumerate(names)]
    return ", ".join(fields) + "."

# a function to plot trajectories for three initial frequencies across the range
def plot_dynamics(n, parameters):
    """Plot the six phenogenotype trajectories over ``n`` generations.

    Three stacked panels (shared x axis) start from Hardy-Weinberg
    populations with t_2 = p_2 = 0.5, 0.2 and 0.8 respectively; the legend
    sits next to the bottom panel. Returns the Figure so the caller can save it.
    """
    title = "Dynamics of the model with parameters" + "\n" + get_param_str(parameters)

    recurrences = get_recurrences(*parameters)
    starts = ((0.5, 0.5), (0.2, 0.2), (0.8, 0.8))
    x_inits = tuple(get_phenotype_freqs_from_genotype_freqs(t2, p2) for t2, p2 in starts)

    def step(*freqs):
        # one generation of the recurrence
        return tuple(rec(*freqs) for rec in recurrences)

    # per initial condition: six series (one per phenogenotype), each of length n + 1
    trajectories = []
    for x_init in x_inits:
        history = np.array(nest_list(step, x_init, n))
        trajectories.append([list(column) for column in history.T])

    labels = ("$P_1P_1T_1$", "$P_1P_2T_1$", "$P_2P_2T_1$",
              "$P_1P_1T_2$", "$P_1P_2T_2$", "$P_2P_2T_2$")
    generations = list(range(0, n + 1))

    fig, axs = plt.subplots(3, 1, sharex=True, figsize=(6,12))
    fig.suptitle(title)

    for ax, trajectory in zip(axs, trajectories):
        for series, label in zip(trajectory, labels):
            ax.plot(generations, series, label=label)

    bottom = axs[2]
    bottom.xaxis.set_major_locator(matplotlib.ticker.MaxNLocator(integer=True))
    bottom.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    panel_titles = ("Intermediate frequencies of $T_2$ and $P_2$",
                    "Low frequencies of $T_2$ and $P_2$",
                    "High frequencies of $T_2$ and $P_2$")
    for ax, panel_title in zip(axs, panel_titles):
        ax.set(xlim=(-1, n + 0.2), ylabel="Frequency", title=panel_title)
    bottom.set(xlabel="$n$")

    return fig

We have defined the recurrences and some helper functions. Now we try several parameter values to study the behavior of the system.

# Unbiased transmission


In [4]:
# Unbiased transmission (b_2 = b_1 = 1/2): trajectories for h = 1, 0.5 and 0.
b3, b2, b1, b0, s, h, a = 1.0, 0.5, 0.5, 0, 0.4, 1, 3
parameters = [b3, b2, b1, b0, s, h, a]
for h_val, pdf_name in ((h, "plot1.pdf"), (0.5, "plot2.pdf"), (0.0, "plot3.pdf")):
    parameters[-2] = h_val
    plot = plot_dynamics(64, parameters)
    plot.savefig(pdf_name, bbox_inches='tight')


In addition to the three initial conditions used above, we now build the analogue of Laland's plots. The idea is simple:

1. Discretize the state space $[0,1] \times [0,1]$ of the frequencies of $T_2$ and $P_2$ with a regular grid.

2. For each grid point, define a population whose phenogenotype frequencies follow the usual Hardy-Weinberg proportions at the $P$ locus, with $T$ assigned independently (cf. function `get_phenotype_freqs_from_genotype_freqs` above).

3. Evolve each population for a set number of generations, recording the frequencies of $T_2$ and $P_2$ (denoted $t_2$ and $p_2$). Each population can therefore be pictured as a point moving in the state space $[0,1] \times [0,1]$.

4. Plot the locations of all populations simultaneously at a few chosen generations. At generation $\tau = 0$ this is the regular grid itself; as $\tau \to \infty$ the points collect at the equilibria of the system.

In [5]:
def get_grid_points(step):
    """Return a regular grid of ``(t_2, p_2)`` pairs covering ``[0, 1] x [0, 1]``.

    Ticks are ``0, step, 2*step, ...`` up to (at most) 1. Points are ordered
    row by row: ``p_2`` is fixed per row while ``t_2`` varies fastest.
    ``np.arange`` with a float step overshoots the stop value whenever the step
    does not divide 1 (e.g. ``step=0.3`` yields 1.2). Such ticks would be invalid
    frequencies, so they are dropped and round-off at the top is clamped to 1.
    """
    ticks = np.arange(0, 1 + step, step)
    ticks = np.minimum(ticks[ticks <= 1 + 1e-9], 1.0)
    return [(tx, ty) for ty in ticks for tx in ticks]

def get_equilibria(n, parameters, s):
    """Return where each grid population sits in (t_2, p_2) after ``n`` generations.

    ``s`` is the grid step handed to ``get_grid_points``. It is not the selection
    coefficient, which is inside ``parameters``. Each grid point becomes a
    Hardy-Weinberg population, is iterated ``n`` times and mapped back to
    ``(t_2, p_2)``. End points outside the unit square (or NaN) are dropped.
    For ``n == 0`` the grid itself is returned unfiltered.

    Returns two lists: the t_2 coordinates and the p_2 coordinates.
    """
    recurrences = get_recurrences(*parameters)

    def step(*freqs):
        return tuple(rec(*freqs) for rec in recurrences)

    grid = get_grid_points(s)
    if n == 0:
        survivors = grid
    else:
        survivors = []
        for g_init in grid:
            start = get_phenotype_freqs_from_genotype_freqs(*g_init)
            t2, p2 = get_genotype_freqs_from_phenotype_freqs(*nest_list(step, start, n)[-1])
            if 0 <= t2 <= 1 and 0 <= p2 <= 1:
                survivors.append((t2, p2))
    return [pt[0] for pt in survivors], [pt[1] for pt in survivors]
    
def plot_equilibria(n, parameter_list, s, multiple = False):
    """Scatter the (t_2, p_2) positions of all grid populations after ``n`` generations.

    One square panel per parameter vector in ``parameter_list``, titled with
    its parameters. ``s`` is the grid step passed to ``get_equilibria`` (not
    the selection coefficient). With ``multiple=True`` the positions after
    ``int(n/8)`` and ``int(n/4)`` generations are overlaid (alpha 0.3), and a
    legend is added beside the last panel. Returns the Figure.
    """
    N = len(parameter_list)
    # squeeze=False keeps a 2-D array of Axes even for a single panel, so the
    # loop and axs[-1] below also work when only one parameter set is given
    fig, axs = plt.subplots(N, 1, figsize=(7, 7*N), squeeze=False)
    axs = axs[:, 0]
    size = 30
    for ax, params in zip(axs, parameter_list):
        if multiple:
            n1 = int(n/8)
            ax.scatter(*get_equilibria(n1, params, s), alpha = 0.3, label="$t = "+ str(n1)+"$", s = size)
            n2 = int(n/4)
            ax.scatter(*get_equilibria(n2, params, s), alpha = 0.3, label="$t = "+ str(n2)+"$", s = size)

        ax.scatter(*get_equilibria(n, params, s), alpha = 0.8, label="$t = "+ str(n)+"$", s = size)
        ax.set(xlim = (-0.1,1.1), ylim = (-0.1,1.1), xlabel="$t_2$", ylabel="$p_2$", title = get_param_str(params))
        ax.yaxis.label.set_size(18)
        ax.xaxis.label.set_size(18)
        ax.set_aspect(1)

    if multiple:
        axs[-1].legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    return fig




In [6]:
# Unbiased transmission: where a fine grid of initial (t_2, p_2) ends up, for h = 1, 0.5, 0.
b3, b2, b1, b0, s, h, a = 1.0, 0.5, 0.5, 0, 0.4, 1, 3
parameters = [[b3, b2, b1, b0, s, h_val, a] for h_val in (1, 0.5, 0)]
s = 0.01  # from here on `s` is reused as the grid step, not the selection coefficient
plot = plot_equilibria(200, parameters, s)
plot.savefig("plot4.pdf", bbox_inches='tight')

In [7]:
# Same panels, overlaying the snapshots at t = 25 and t = 50.
plot = plot_equilibria(n=200, parameter_list=parameters, s=s, multiple=True)
plot.savefig("plot4_multiple.pdf", bbox_inches='tight')

# Biased transmission

In [8]:
# Biased transmission (b_2 = b_1 = 0.3): trajectories for h = 1, 0.5 and 0.
b3, b2, b1, b0, s, h, a = 1.0, 0.3, 0.3, 0, 0.4, 1, 3
parameters = [b3, b2, b1, b0, s, h, a]
for h_val, pdf_name in ((h, "plot5.pdf"), (0.5, "plot6.pdf"), (0, "plot7.pdf")):
    parameters[-2] = h_val
    plot = plot_dynamics(48, parameters)
    plot.savefig(pdf_name, bbox_inches='tight')

In [9]:
# Biased transmission: grid end points after 20 generations (alone, then with snapshots).
b3, b2, b1, b0, s, h, a = 1.0, 0.3, 0.3, 0, 0.4, 1, 3
parameters = [[b3, b2, b1, b0, s, h_val, a] for h_val in (1, 0.5, 0)]
s = 0.01  # grid step from here on
for is_multiple, pdf_name in ((False, "plot8.pdf"), (True, "plot8_multiple.pdf")):
    plot = plot_equilibria(20, parameters, s, multiple=is_multiple)
    plot.savefig(pdf_name, bbox_inches='tight')

# Maternal Transmission

In [10]:
# Maternal transmission (b_3 = b_2 = 0.8, b_1 = b_0 = 0.2): trajectories for h = 1, 0.5 and 0.
b3, b2, b1, b0, s, h, a = 0.8, 0.8, 0.2, 0.2, 0.4, 1, 3
parameters = [b3, b2, b1, b0, s, h, a]
for h_val, pdf_name in ((h, "plot9.pdf"), (0.5, "plot10.pdf"), (0, "plot11.pdf")):
    parameters[-2] = h_val
    plot = plot_dynamics(30, parameters)
    plot.savefig(pdf_name, bbox_inches='tight')

In [11]:
# Maternal transmission: grid end points after 20 generations.
# Relies on b3, b2, b1, b0, s (= 0.4) and a as set in the previous cell.
parameters = [[b3, b2, b1, b0, s, h_val, a] for h_val in (1, 0.5, 0)]
s = 0.01  # grid step from here on
for is_multiple, pdf_name in ((False, "plot12.pdf"), (True, "plot12_multiple.pdf")):
    plot = plot_equilibria(20, parameters, s, multiple=is_multiple)
    plot.savefig(pdf_name, bbox_inches='tight')

# Laland Fixation table

In [12]:
def get_generations(dynamics, tresholds):
    """Return, per threshold, the first generation whose value strictly exceeds it.

    ``dynamics`` holds one value per generation (index 0 is the initial
    state). ``tresholds`` must be increasing, because each scan resumes where
    the previous threshold stopped. A threshold that is never exceeded is
    reported as ``">1000"`` (the label assumes the 1000-generation runs of the
    fixation table). The result always has exactly ``len(tresholds)`` entries.
    """
    ans = []
    i = 0
    n_steps = len(dynamics)
    for threshold in tresholds:
        while i < n_steps and dynamics[i] <= threshold:
            i += 1
        ans.append(i if i < n_steps else ">1000")
    return ans


In [13]:
# Fixation table: start at t_2 = 0.01, p_2 = 0.1 (Hardy-Weinberg), iterate 1000
# generations with h = 0.5, and classify the fate of T_2 for each (s, a) and b.
bs = [0.5, 0.495, 0.475, 0.45, 0.4, 0.25]

s_as = [(0.4, 1.5), (0.4, 2), (0.4, 3), (0.2, 1.5), (0.2, 2), (0.2, 3), (0.1, 1.5), (0.1, 2), (0.1, 3), (0, 1.5), (0, 2), (0, 3), (-0.1, 1.5), (-0.1, 2), (-0.1, 3)]

freqs_init = (0.01, 0.1)
x_init = get_phenotype_freqs_from_genotype_freqs(*freqs_init)
# tresholds[0]: a final t_2 below it counts as "lost";
# tresholds[1:]: levels whose first-passage generations are tabulated
tresholds = (0.01, 0.99, 0.999, 0.9999)

values = [[0 for j in range(len(bs))] for i in range(len(s_as))]

# enumerate() has no len(), so tqdm needs total= to display real progress
# (without it the bar rendered with max=1.0 and no total)
for i, (s, a) in tqdm(enumerate(s_as), total=len(s_as)):
    for j, b in enumerate(bs):
        parameters = [1, b, b, 0, s, 0.5, a]
        x1, x2, x3, x4, x5, x6 = get_recurrences(*parameters)
        # one step of recurrence
        def step(*freqs):
            return x1(*freqs), x2(*freqs), x3(*freqs), x4(*freqs), x5(*freqs), x6(*freqs)
        # (t_2, p_2) for generations 0..1000
        dynamics = list(map(get_genotype_freqs_from_phenotype_freqs, nest_list(step, x_init, 1000)))
        final_t2 = dynamics[-1][0]
        if final_t2 < tresholds[0]:
            values[i][j] = "lost"
        elif final_t2 < tresholds[1]:
            values[i][j] = ">1000"
        else:
            values[i][j] = get_generations([t2[0] for t2 in dynamics], tresholds[1:])


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




In [14]:
def get_row_val(i, val):
    """Return the LaTeX text for line ``i`` (1, 2 or 3) of one fixation-table cell.

    ``val`` is an entry of ``values``: ``"lost"``, ``">1000"``, or the list
    from ``get_generations`` (one generation count per fixation threshold).
    "lost" and ">1000" appear once, on line 1, and lines 2-3 are blank;
    otherwise line ``i`` shows ``val[i - 1]``. Every ">1000" is typeset in
    math mode, because a bare ``>`` in LaTeX text mode renders as an inverted
    question mark under the default OT1 font encoding.
    Returns ``None`` for any other ``i``.
    """
    if i not in (1, 2, 3):
        return None
    if val == "lost":
        return val if i == 1 else " "
    if val == ">1000":
        return "$>1000$" if i == 1 else " "
    entry = val[i - 1]
    return "$>1000$" if entry == ">1000" else entry

# Emit the table body as LaTeX: three lines per (s, a) pair, one column per b.
for i, row in enumerate(values):
    s_val, a_val = s_as[i]
    r1 = f"$s = {s_val}$ & " + " & ".join(str(get_row_val(1, val)) for val in row) + "\\\\"
    r2 = f"$a = {a_val}$ & " + " & ".join(str(get_row_val(2, val)) for val in row) + "\\\\"
    r3 = "        & " + " & ".join(str(get_row_val(3, val)) for val in row) + "\\\\ [5pt]"
    print(r1, r2, r3, "\n", sep="\n")

$s = 0.4$ & lost & lost & lost & lost & lost & 31\\
$a = 1.5$ &   &   &   &   &   & 36\\
        &   &   &   &   &   & 41\\ [5pt]


$s = 0.4$ & lost & lost & lost & lost & lost & 30\\
$a = 2$ &   &   &   &   &   & 35\\
        &   &   &   &   &   & 40\\ [5pt]


$s = 0.4$ & lost & lost & lost & lost & lost & 28\\
$a = 3$ &   &   &   &   &   & 33\\
        &   &   &   &   &   & 37\\ [5pt]


$s = 0.2$ & lost & lost & lost & 322 & 65 & 21\\
$a = 1.5$ &   &   &   & 471 & 81 & 25\\
        &   &   &   & 621 & 96 & 28\\ [5pt]


$s = 0.2$ & lost & lost & lost & 212 & 59 & 20\\
$a = 2$ &   &   &   & 298 & 74 & 24\\
        &   &   &   & 385 & 89 & 28\\ [5pt]


$s = 0.2$ & $>1000$ & lost & lost & 135 & 52 & 19\\
$a = 3$ &   &   &   & 187 & 66 & 23\\
        &   &   &   & 238 & 79 & 26\\ [5pt]


$s = 0.1$ & $>1000$ & $>1000$ & 177 & 90 & 46 & 19\\
$a = 1.5$ &   &   & 231 & 114 & 56 & 22\\
        &   &   & 286 & 138 & 67 & 25\\ [5pt]


$s = 0.1$ & 399 & 313 & 138 & 79 & 43 & 18\\
$a = 2$ & 548 & 