# STEM Cell Population Wright-Fisher Algorithm with Constant Environment

In this example notebook we show how to forward-simulate a population of STEM cells within a fixed-size tumor. We assume the total number of cells in the population is constant at all times — only the counts of the different cell species change over time:

- wild type (WT)
- with cell intrinsic mutations that increase fitness (A)
- with mutations that confer an evolutionary advantage depending on environmental factors, such as cytokine levels (B).

For the purposes of this notebook, we consider environmental conditions under which B cells always have a selective advantage over their wild-type counterparts.

In [1]:
# Load necessary libraries
import os

import matplotlib
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from matplotlib import pyplot as plt
from plotly.subplots import make_subplots
from scipy.stats import gamma

import cmmlinflam as ci

## Define STEM cell population

In [2]:
# Set initial population state WT - A - B
initial_population = [99, 1, 0]

# Set baseline growth rate
alpha = 0.5

# Set selective advantages for mutated cells
s = 0.1
r = 0.01

# Set mutation rates
mu_A = 0.002
mu_B = 0.003

# Coalesce into parameter vector: [WT, A, B, alpha, s, r, mu_A, mu_B].
# Build a NEW list: binding `parameters` to `initial_population` and calling
# `extend` would also append the rates to `initial_population`, corrupting
# every later use of it (e.g. its sum would no longer be the cell count).
parameters = list(initial_population) + [alpha, s, r, mu_A, mu_B]

In [3]:
# Build the Wright-Fisher simulator for the WT / A / B stem-cell species
algorithm = ci.StemWF()

# Generation window for the forward simulation
start_time, end_time = 1, 100

# Run the simulation over the window with the parameter vector defined above
output_algorithm = algorithm.simulate_fixed_times(parameters, start_time, end_time)

# Generation grid (both endpoints included) used as the x-axis of the plots
times = [t for t in range(start_time, end_time + 1)]

## Plot output of Wright-Fisher for the different species of cells

In [4]:
# Plot styling shared by the figures below: trace k is drawn with colours[k].
colours = ['blue', 'red', 'green', 'purple', 'orange', 'black', 'gray', 'pink']

# Cell species, in the WT - A - B order used for the initial population
# (the plotting cell indexes the simulation output columns in this order).
species = ['WT', 'A', 'B']

In [5]:
# Trace names - represent the type of cells for the simulation
trace_name = ['{} cell counts'.format(name) for name in species]

# Names of panels: one per species, plus a final panel combining all of them
panels = ['{} only'.format(name) for name in species]
panels.append('Combined')

# Total number of cells (constant in this model). Taken from the initial
# WT/A/B counts in `parameters[0:3]` rather than by summing
# `initial_population`, which may carry extra entries.
total_cells = sum(parameters[0:3])

# Two-column grid of panels
n_rows = (len(panels) + 1) // 2


def _panel_position(index):
    """Return the 1-based (row, col) of the ``index``-th panel in the 2-column grid."""
    return index // 2 + 1, index % 2 + 1


fig = make_subplots(rows=n_rows, cols=2, subplot_titles=tuple(panels))

# The combined panel is always the last one in `panels`
combined_row, combined_col = _panel_position(len(panels) - 1)

# Draw each species in its own panel and again in the combined panel.
# `idx` is used as loop variable so the selective advantage `s` defined
# with the parameters is not overwritten.
for idx in range(len(species)):
    row, col = _panel_position(idx)
    fig.add_trace(
        go.Scatter(
            y=output_algorithm[:, idx],
            x=times,
            mode='lines',
            name=trace_name[idx],
            line_color=colours[idx]
        ),
        row=row,
        col=col
    )
    fig.add_trace(
        go.Scatter(
            y=output_algorithm[:, idx],
            x=times,
            mode='lines',
            name=trace_name[idx],
            line_color=colours[idx],
            showlegend=False
        ),
        row=combined_row,
        col=combined_col
    )

# Per-panel reference line at the total population and axis titles
for p in range(len(panels)):
    row, col = _panel_position(p)
    fig.add_hline(
        y=total_cells,
        line_dash='dot',
        annotation_text='Total population', fillcolor='black',
        annotation_position='top right',
        row=row,
        col=col)

    fig.update_yaxes(ticks='outside', tickcolor='black', ticklen=7.5, title_text='Percentage (%) of population', row=row, col=col)
    fig.update_xaxes(ticks='outside', tickcolor='black', ticklen=7.5, title_text='Number of Generations', row=row, col=col)

# Axis styling for every subplot. The y ticks are placed at fixed percentages
# of the total population so the "Percentage (%)" labels stay correct for any
# population size.
percent_ticks = [0, 25, 50, 75, 100]
fig.update_xaxes(linecolor='black')
fig.update_yaxes(
    linecolor='black',
    range=[0, total_cells + 10],
    tickvals=[total_cells * pct / 100 for pct in percent_ticks],
    ticktext=[str(pct) for pct in percent_ticks],
)

fig.update_layout(
    title='Counts of different cell types over time: IC = {}, α = {}, s = {}, r = {}, μA = {}, μB = {}'.format(parameters[0:3], parameters[3], parameters[4], parameters[5], parameters[6], parameters[7]),
    width=1100,
    height=300 * n_rows,
    plot_bgcolor='white',
)

# write_image will not create missing directories, so ensure the target exists
os.makedirs('images', exist_ok=True)
fig.write_image('images/Stem-counts-wf.pdf')
fig.show()

## Compute mean time to illness and mean population state at onset

In [6]:
# Select stopping criterion
# NOTE(review): the format is interpreted by `StemWF.simulate_criterion`; it looks
# like one entry per species (WT, A, B) for a threshold and a comparison — confirm there.
criterion = [[0, None, None], ['more', None, None]]

# Select number of simulations
num_simulations = 1000

# `np.int` was removed in NumPy 1.24; the builtin `int` is the type it aliased.
computation_time = np.empty(num_simulations, dtype=int)
final_state = np.empty((num_simulations, 3), dtype=int)

# `sim` is used as loop index so the selective advantage `s` is not overwritten
# (the transition-probability cell below reads the model parameters).
for sim in range(num_simulations):
    computation_time[sim], final_state[sim, :] = algorithm.simulate_criterion(parameters, criterion)

mean_computation_time = np.mean(computation_time)
mean_final_state = np.mean(final_state, axis=0)

print('Average time to illness: ', mean_computation_time)
print('Average system state right before illness: ', mean_final_state)

Average time to illness:  53.327
Average system state right before illness:  [ 0.    96.741  3.259]


In [7]:
# Compute transition probabilities (plotted in the next cell)
sep_algo = ci.StemWF()

# Read the model parameters back from `parameters` ([WT, A, B, alpha, s, r, mu_A, mu_B])
# rather than from the globals `s` / `initial_population`: `s` may have been
# overwritten by loop variables in earlier cells.
wt0, a0, b0, alpha_base, sel_A, sel_B, mut_A, mut_B = parameters[:8]

sep_algo.N = int(wt0 + a0 + b0)

sep_algo.alpha_A = alpha_base + sel_A
sep_algo.alpha_B = alpha_base + sel_B
sep_algo.alpha_WT = alpha_base

sep_algo.mu_A = mut_A
sep_algo.mu_B = mut_B

# Transition-probability methods, in the column order of `trans_prob`
transition_methods = [
    sep_algo._prob_A_to_B,
    sep_algo._prob_A_to_WT,
    sep_algo._prob_B_to_A,
    sep_algo._prob_B_to_WT,
    sep_algo._prob_WT_to_A,
    sep_algo._prob_WT_to_B,
]

# Assuming no Bs in the population: the third state argument is fixed at 0 and the
# N cells are split as (i, N - i) between the first two arguments.
# NOTE(review): which species those two arguments count is defined by StemWF — confirm.
trans_prob = np.empty((sep_algo.N + 1, len(transition_methods)))

for i in range(sep_algo.N + 1):
    for c, prob in enumerate(transition_methods):
        trans_prob[i, c] = prob(i, sep_algo.N - i, 0)

In [8]:
# Trace names - one per column of `trans_prob`, in the same order
trace_name = ['A->B', 'A->WT', 'B->A', 'B->WT', 'WT->A', 'WT->B']

# x grid: the value i used as first state argument for each row of `trans_prob`
# TODO(review): add an x-axis title once the meaning of that argument is confirmed
state_grid = list(range(trans_prob.shape[0]))

fig = go.Figure()

# Add traces of the transition probabilities
for c in range(trans_prob.shape[1]):
    fig.add_trace(
        go.Scatter(
            y=trans_prob[:, c],
            x=state_grid,
            mode='lines',
            name=trace_name[c],
            line_color=colours[c]
        )
    )

fig.update_layout(
    title='Transition probabilities for edge case with no B cells',
    width=1000,
    height=600,
    plot_bgcolor='white',
    xaxis=dict(linecolor='black'),
    yaxis=dict(linecolor='black', title_text='Transition probability'),
)
fig.show()