<a href="https://colab.research.google.com/github/aeyuan/moran_lesson/blob/master/moran.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ❗Important: Run the code at the bottom first❗

Otherwise, nothing else will work and you will just get a bunch of errors. You can run a code cell by clicking on it and pressing `Shift+Enter`.

# 1: Moran process with two alleles from a single mutation event

The code below simulates the Moran process that we just discussed. We have a population of 100 cells with two alleles, A and B. Initially, 99 of the cells are type A and 1 cell is type B (i.e., B is a newly arising mutation). Type A cells always have a growth rate (fitness) of 1, and type B cells have a growth rate of $r$.
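To make one time step concrete, here is a minimal sketch of the probabilities that the number of B cells goes up, goes down, or stays the same when there are $i$ mutants of fitness $r$ in a population of $N$: one cell is picked to divide with probability proportional to its fitness, and one cell, picked uniformly at random, is replaced by the offspring. The function name `step_probabilities` is hypothetical and not part of the notebook; the formulas mirror the ones used in the notebook's `moran` function.

```python
def step_probabilities(i, N, r):
    """Return (p_up, p_down, p_same) for the mutant count in one Moran step.

    i: current number of mutants, N: population size,
    r: mutant fitness (non-mutants have fitness 1).
    """
    total_fitness = r * i + (N - i)
    # A mutant divides and its offspring replaces a non-mutant
    p_up = (r * i / total_fitness) * ((N - i) / N)
    # A non-mutant divides and its offspring replaces a mutant
    p_down = ((N - i) / total_fitness) * (i / N)
    # Otherwise the offspring replaces a cell of its own type: no change
    p_same = 1 - p_up - p_down
    return p_up, p_down, p_same


p_up, p_down, p_same = step_probabilities(i=1, N=100, r=1.03)
print(f"up: {p_up:.5f}, down: {p_down:.5f}, ratio: {p_up / p_down:.2f}")
```

Because `p_up / p_down` simplifies to exactly $r$, the mutant count does a random walk that is only slightly biased toward going up; starting from a single cell, it can still hit 0 by chance.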

Play around with the parameter $r$.

For the other parameters, set:
* `n_trials = 500`
* `population_size = 100`
* `initial_n_mutants = 1`
* `max_iter = 10**5` (in Python, $10^5$ is written `10**5`)

Try setting $r = 1.0, 1.03, 1.06, 1.09,$ and $1.12$, and record the number of trials in which the new allele B takes over the population in each case. Later, we will combine all of our results. To run the code, click on the code cell and press `Shift+Enter`.

In [None]:
n_trials = 500
population_size = 100
initial_n_mutants = 1
r = 1.03
max_iter = 10 ** 5

plot_selection(n_trials, population_size, initial_n_mutants, r, max_iter, seed)

## Discussion

* What do you notice about the relationship between $r$ and the fraction of trials in which the mutation takes over the population? Do most of the beneficial mutations in this simulation end up taking over? Discuss with your partners.

* The simulation runs until the mutant allele either takes over (i.e., reaches a frequency of 100%) or goes extinct, and then it stops. Compare the stopping times of mutations that take over with those of mutations that go extinct. What do you see?

* If time allows: try increasing $r$ above 1.12. How high does $r$ need to be for the mutation to take over half of the time?
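If you want to compare your class results with theory afterwards, the Moran process has a known closed-form probability that $i$ mutants of fitness $r$ eventually take over a population of size $N$: $\rho = (1 - r^{-i}) / (1 - r^{-N})$, which reduces to $i/N$ when $r = 1$. The sketch below is a minimal illustration of that formula; the function name `fixation_probability` is hypothetical and not part of the notebook, and your simulated counts will scatter around these expectations.

```python
def fixation_probability(r, N, i=1):
    """Probability that i mutants of fitness r eventually take over a
    population of size N in the Moran process (non-mutants have fitness 1)."""
    if r == 1:
        return i / N  # neutral case: every cell is equally likely to be the ancestor
    return (1 - r ** -i) / (1 - r ** -N)


n_trials = 500
for r in [1.0, 1.03, 1.06, 1.09, 1.12]:
    rho = fixation_probability(r, N=100)
    print(f"r = {r:.2f}: P(takeover) = {rho:.3f}, "
          f"expected takeovers in {n_trials} trials = {n_trials * rho:.1f}")
```

For a population of 100 cells, $r^{-N}$ is already small once $r$ is noticeably above 1, so a single mutant's chance of taking over is roughly $1 - 1/r$.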

# 2: Moran process with constant mutations

In this simulation, we will follow a population with two alleles, A and B. The population starts out as 100% type A, but cells mutate to type B at a low rate. This differs from part (1), where we started with 99% A and 1% B and no mutations occurred during the simulation.

At each time step, each cell mutates with probability `mu` (per gene). Because `r = 1.1`, type B mutants grow 10% faster than non-mutant cells. (A mutation flips the allele, so a B cell can also mutate back to A, but at this mutation rate that is rare.)
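As a rough sense of scale, the sketch below computes how many new mutations to expect with the parameters used in the cell below, drawing the per-step mutation count from a binomial distribution in the same way the notebook's helper code does. The use of `np.random.default_rng` here is an assumption for illustration; the notebook itself uses the legacy `np.random` functions.

```python
import numpy as np

pop_size = 500      # values taken from the cell below
mu = 10**-6
n_genes = 1
n_steps = 150000

# Expected number of new mutations per time step, across the whole population
per_step = pop_size * mu * n_genes
print(f"expected mutations per step: {per_step:.1e}")
print(f"expected mutation events over the run: {per_step * n_steps:.0f}")

# One simulated sequence of per-step mutation counts
rng = np.random.default_rng(0)
n_muts = rng.binomial(pop_size, mu * n_genes, size=n_steps)
print(f"mutation events in one simulated run: {n_muts.sum()}")
```

So a new B mutant appears on average about once every 2,000 steps. Each one starts as a single cell, just like in part (1), so most of them are expected to go extinct before spreading.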

Run the code below. This cell takes a couple of minutes to run. Record the fraction of replicates in which the mutation roughly takes over (>80% frequency by the end of the run).

In [None]:
n_replicates = 10      # The number of independent replicates
n_genes = 1            # The number of genes undergoing mutation
mu = 10**-6            # The mutation rate (probability of mutation per cell per gene per timestep)
pop_size = 500         # The population size
r = 1.1                # The fitness of mutants (nonmutants have fitness=1)
n_steps = 150000       # The number of timesteps to run the simulation

color = None           # None uses matplotlib's default color cycle
plot_mutation_selection(seed, color, n_replicates, n_genes, mu, pop_size, r, n_steps)

## Discussion

How is this process different from the process in part (1)?

Write on the board the fraction of replicates in which you saw the mutation take over.

# 3: Moran process with mutations at multiple genes

This cell has the same settings as above, except that there are now 5 genes and we run only a single replicate. The different colors therefore represent mutants at different genes.

The random number seed is set to 1 so that we all see the same dynamics. Later, you can play around with different random number seeds and parameter values to see what happens.
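With several genes, a cell can carry mutations at more than one gene. According to the docstring of the notebook's `moran_clonal_int`, fitness effects are additive: each mutated gene adds $r - 1$ to a baseline fitness of 1. The sketch below lists the genotypes and their fitness for two genes, using `itertools.product` in place of the notebook's `get_bin_perms` (both produce the same genotypes in the same order).

```python
from itertools import product

n_genes = 2
r = 1.1

# Every genotype is a tuple of 0/1 flags, one per gene (1 = mutant at that gene)
genotypes = list(product([0, 1], repeat=n_genes))
for g in genotypes:
    # Additive fitness: each mutated gene adds (r - 1) on top of baseline fitness 1
    fitness = 1 + (r - 1) * sum(g)
    print(g, round(fitness, 3))
```

A cell that is mutant at both genes therefore outgrows cells that are mutant at only one of them.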

Run the cell.


In [None]:
seed = 1               # The random number seed
n_replicates = 1       # The number of independent replicates
n_genes = 5            # The number of genes undergoing mutation
mu = 10**-6            # The mutation rate (probability of mutation per cell per gene per timestep)
pop_size = 500         # The population size
r = 1.1                # The fitness of mutants (nonmutants have fitness=1)
n_steps = 150000       # The number of timesteps to run the simulation

color = None
plot_mutation_selection(seed, color, n_replicates, n_genes, mu, pop_size, r, n_steps)

## Discussion

* If you used `seed=1` above, you should see mutants at the blue gene rise to around 50% in frequency and then fall back down to near-extinction. This behavior (almost) never happens when there is only one gene (as in part 2), but here it is not unusual\*. What might be happening\*\*?

* If you used `seed=1` above, you should see the mutant frequencies at the blue and orange genes tracing each other. Why do you think this is?



\* I tried seeds 1 through 10 and saw this behavior with seed = 1, 3, 4, 6, and 7.

\*\* Hint: If you run the simulation with `seed=7`, the green and purple alleles appear to mirror each other between time steps 40,000 and 70,000. What does this suggest?
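One way to think about curves that trace each other: the plotted frequency for each gene counts every cell that carries a mutation at that gene, whatever else it carries. The sketch below uses a hypothetical snapshot of a two-gene population of 500 cells (the numbers are made up for illustration) and sums genotype counts the way the notebook's ledger does.

```python
import numpy as np

# Hypothetical snapshot of a 2-gene population of 500 cells.
# counts[a, b] = number of cells with allele a at gene 0 and allele b at gene 1
counts = np.zeros((2, 2), dtype=int)
counts[0, 0] = 300   # wild type at both genes
counts[1, 1] = 200   # mutant at both genes (one lineage carrying both mutations)

# Mutant count at each gene = sum over all genotypes with a 1 at that gene
mutants_gene0 = counts[1, :].sum()
mutants_gene1 = counts[:, 1].sum()
print(mutants_gene0, mutants_gene1)   # 200 200
```

If both mutations sit in the same lineage, the two curves move together. If instead they sit in different lineages that together occupy most of the population, one lineage can grow only by replacing cells, including cells of the other lineage, so their curves can move in opposite directions.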


# Important: Run this cell first.

Otherwise nothing else in this notebook will work.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
import time

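def moran(N, num_A_init, r=1, max_iter=np.inf):
    """
    Simulate a two-type Moran process until one type takes over, the other
    goes extinct, or max_iter steps have run.

    Args:
        N (int): population size (i.e. total number of individuals)
        num_A_init (int): initial number of individuals of type "A"
        r (numeric): fitness of individuals of type "A"
            (fitness of type "B" individuals is assumed to be 1)
        max_iter (numeric): maximum number of iterations

    Returns:
        list of the number of type "A" individuals after each step. Note that
        plot_selection passes the mutant count as num_A_init, so here "A"
        plays the role of the mutant (called B in part 1).
    """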
    num_A = [num_A_init]
    n_iter = 0
    while (num_A[-1] > 0) and (num_A[-1] < N) and n_iter < max_iter:
        n_iter += 1
        i = num_A[-1]
        # P(an A cell divides) * P(a non-A cell is replaced)
        prob_inc_A = (r * i) * (N - i) / ((r * i + N - i) * N)
        # P(a non-A cell divides) * P(an A cell is replaced)
        prob_dec_A = (N - i) * i / ((r * i + N - i) * N)
        weights = [prob_inc_A, prob_dec_A, 1 - (prob_inc_A + prob_dec_A)]
        # Draw a single outcome (a scalar, so the comparisons below are plain)
        decision = np.random.choice(['inc_A', 'dec_A', 'same'], p=weights)
        if decision == 'inc_A':
            num_A.append(i + 1)
        elif decision == 'dec_A':
            num_A.append(i - 1)
        else:
            num_A.append(i)
    return num_A

def plot_selection(n_trials, population_size, init_num_A, r, max_iter, seed):
    np.random.seed(seed)
    trajectories = []
    total_time = []
    A_fixes = []
    for i in range(n_trials):
        traj = moran(population_size, init_num_A, r, max_iter)
        trajectories.append(traj)
        total_time.append(len(traj))
        A_fixes.append(traj[-1] == population_size)
    total_time = np.array(total_time)
    A_fixes = np.array(A_fixes)

    plt.figure(figsize=[10,6])
    for t_idx, traj in enumerate(trajectories):
        plt.semilogx(np.array(traj)/population_size)
        if traj[-1] == population_size:
            plt.scatter([len(traj)-1], [1.05], s=50, marker="|",
                        color='blue')
        if traj[-1] == 0:
            plt.scatter([len(traj)-1], [-0.05], s=30, marker="|",
                        color='red')
    plt.xlabel("time steps", fontsize=16)
    plt.xticks(size=14)
    plt.ylabel("mutant allele frequency", fontsize=16)
    plt.yticks(size=14)
    plt.title(f'Trajectories (mutation took over in {sum(A_fixes)} of {n_trials} trials)',
              fontsize=16)
    plt.show()

    # Histograms of stopping times by outcome, drawn on shared bins
    bins = np.histogram_bin_edges(total_time, bins=30)
    plt.figure(figsize=[10, 4.5])
    plt.hist(total_time[A_fixes], bins=bins, color='blue', alpha=0.6,
             label='mutation took over')
    plt.hist(total_time[~A_fixes], bins=bins, color='red', alpha=0.6,
             label='mutation went extinct')
    plt.legend(fontsize=14)
    plt.xlabel("stopping time", fontsize=16)
    plt.xticks(size=14)
    plt.ylabel("trials", fontsize=16)
    plt.yticks(size=14)
    plt.title('Stopping time distributions', fontsize=16)
    plt.show()

### moran process with clonal interference

def recursive_fxn(a, list_, n):
    if len(list_) == n: # base case
        a.append(tuple(list_))
    else: # recursive case
        recursive_fxn(a, list_ + [0], n)
        recursive_fxn(a, list_ + [1], n)
def get_bin_perms(n):
    # get a list of all binary vectors of length n
    a = []
    recursive_fxn(a, [], n)
    return a

def get_gtype_mut_to(gtype_mut_from):
    # Flip the allele (0 <-> 1) at one randomly chosen locus
    mut_locus = np.random.choice(len(gtype_mut_from))
    gtype_mut_to = list(gtype_mut_from)
    gtype_mut_to[mut_locus] = 1 - gtype_mut_to[mut_locus]
    return tuple(gtype_mut_to)

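def get_dim_slices(n_dims):
    # For each dimension i, build an index that selects every genotype carrying
    # the mutant allele (1) at locus i, i.e. one "side" of the 2 x 2 x ... x 2
    # genotype tensor
    a = []
    start = tuple(slice(2) for _ in range(n_dims))
    for i in range(n_dims):
        a_ = list(start)
        a_[i] = 1
        # Must be a tuple: NumPy treats a list as an array index, not as a
        # multidimensional index, so a list of slices fails
        a.append(tuple(a_))
    return a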

def moran_clonal_int(n_loci, mu, pop_size, r, n_iter):
    """
    Args:
        n_loci (int): number of loci under mutation and selection
        mu (numeric): for each cell and each locus at each iteration, that
                        cell will suffer a mutation at that locus with
                        probability mu (a mutation flips the allele, so
                        back-mutation is possible)
        pop_size (int): the number of cells in the population.
        r (numeric): relative fitness of mutant over wild-type; fitness effects
                        in this simulation are additive: for a two-locus
                        population where r=1.15, wild-type cells have fitness=1,
                        mutants at one of the two loci have fitness=1.15, and
                        cells with mutations at both loci have fitness=1.3.
        n_iter (int): number of birth/replacement iterations to run
    
    Returns:
        a matrix where rows are iterations and columns are number of mutant
        individuals at each locus
    """
    # initialization
    mu = mu * n_loci # per-cell probability of a mutation at any of the loci
    state = np.zeros([2 for _ in range(n_loci)]) # number of cells of each genotype
    state[tuple(0 for _ in range(n_loci))] = pop_size # start with all wild-type cells
    gtypes = get_bin_perms(n_loci) # list of possible genotypes
    n_gtypes = len(gtypes) # number of possible genotypes. n_gtypes = 2 ** n_loci
    advantage = np.sum(np.array(gtypes), axis=1) # number of mutated loci per genotype
    weights = (r - 1) * advantage + 1 # additive fitness = weight for cell division
    dim_slices = get_dim_slices(n_loci) # index tuples for counting mutants at each locus
    ledger = np.zeros([n_iter, n_loci]) # ledger to record mutant counts at each locus
    for iteration in range(n_iter):
        # record in ledger
        for i in range(n_loci):
            ledger[iteration, i] = np.sum(state[dim_slices[i]])
        # a single moran step
        p_divide = np.ravel(state) * weights
        p_divide = p_divide / np.sum(p_divide)
        idx_divide = np.random.choice(n_gtypes, p=p_divide)
        p_replaced = np.ravel(state) / pop_size
        idx_replaced = np.random.choice(n_gtypes, p=p_replaced)
        state[gtypes[idx_divide]] += 1
        state[gtypes[idx_replaced]] -= 1
        # a single mutation step
        n_muts = np.random.binomial(pop_size, mu)
        for _ in range(n_muts):
            p_mut = np.ravel(state) / pop_size
            gtype_mut_from = gtypes[np.random.choice(n_gtypes, p=p_mut)]
            gtype_mut_to = get_gtype_mut_to(gtype_mut_from)
            state[gtype_mut_from] -= 1
            state[gtype_mut_to] += 1
    return ledger

# make plots for moran process with clonal interference

def plot_mutation_selection(seed, color, n_replicates, n_loci, mu, pop_size,
                            r, n_steps):
    np.random.seed(seed)
    plt.figure(figsize=[10, 6])
    for i in range(n_replicates):
        ledger = moran_clonal_int(n_loci, mu, pop_size, r, n_steps)
        # one line per locus; color=None uses matplotlib's default color cycle
        plt.plot(ledger / pop_size, color=color)
    plt.ylim([0, 1])
    plt.xlabel("time steps", fontsize=16)
    plt.xticks(size=14)
    plt.ylabel("mutant allele frequency", fontsize=16)
    plt.yticks(size=14)
    replicate_word = 'replicate' if n_replicates == 1 else 'replicates'
    gene_word = 'gene' if n_loci == 1 else 'genes'
    title = (f'{n_replicates} {replicate_word} of mutation-selection process '
             f'with {n_loci} {gene_word}')
    plt.title(title, fontsize=16)
    plt.show()

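# Default seed taken from the clock (last 6 digits), so results differ between
# runs unless a cell sets `seed` explicitly, as part 3 does
seed = int(str(time.time()).replace('.','')[-6:])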