# Sickle-Cell / HbS Gene Monte-Carlo Simulation
## Interactive Explorer

This notebook provides an interactive interface for the Monte-Carlo simulations described in:

1. **Paper 1**: Habibzadeh F. *On the feasibility of malaria hypothesis.* Scientific Reports 2024; 14:5800.
2. **Paper 2**: Habibzadeh F. *The effect on the equilibrium sickle cell allele frequency of the probable protection conferred by malaria and sickle cell gene against other infectious diseases.* Scientific Reports 2024; 14:15399.

---

### Background

The **malaria hypothesis** proposes that carriers of the hemoglobin S (HbS) gene (sickle cell trait, AS genotype) have a survival advantage against fatal malaria. This creates a **balanced polymorphism**: the gene frequency rises until the benefit of heterozygote protection against malaria is offset by the loss of homozygotes (SS) to sickle cell disease complications.

The simulation models a small Neolithic tribe (25 couples) transitioning from a hunter-gatherer to an agricultural lifestyle near malaria-endemic water. Key features include:

- **Discrete, individual-based** stochastic model (not continuous approximation)
- **Realistic demography**: variable family sizes, hunter→farmer transition, logistic population growth
- **Generational overlap**: some parents may mate with the next generation
- **Genetic drift**: natural fluctuations in small populations
- **Paper 2 extension**: additional mortality from other diseases, with protection from both malaria and HbS
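
The logistic population growth listed above can be written out in a few lines. The sketch below restates the closed-form logistic curve used by the `growth` function in the simulation engine; `logistic_population` is an illustrative name, and the default arguments (25 initial couples, a cap of 1000 couples, growth starting at generation 5, rate 0.15) are the values the explorer passes to the simulation.

```python
import math

def logistic_population(generation, couples=25, couples_max=1000,
                        growth_start=5, r=0.15):
    """Population size (individuals) under logistic growth in the number of couples."""
    t = max(generation - growth_start, 0)    # no growth before growth_start
    e = math.exp(r * t)
    n_couples = couples_max * couples * e / (couples_max + couples * (e - 1))
    return 2 * int(n_couples + 0.5)          # round to whole couples, two people each

for g in (0, 5, 20, 40, 100):
    print(g, logistic_population(g))
```

The population stays at 50 individuals until growth starts and then approaches the cap of 2000 individuals.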

### Fitness equations

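For the malaria-only model (Paper 1), let $M_m$ be the mortality from malaria, $P_{hetero}$ and $P_{homo}$ the factors by which the AS and SS genotypes reduce that mortality (`P_minor` and `P_major` in the code), and $M_{SS}$ the mortality from sickle cell disease. The genotype fitnesses are: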

$$W_{AA} = 1 - M_m, \quad W_{AS} = 1 - \frac{M_m}{P_{hetero}}, \quad W_{SS} = 1 - \left[M_{SS} + (1-M_{SS})\frac{M_m}{P_{homo}}\right]$$

The equilibrium frequency of the HbS allele is:

$$p_{eq} = \frac{W_{AA} - W_{AS}}{W_{AA} - 2W_{AS} + W_{SS}}$$

Use the sliders below to explore how different parameters affect the gene frequency dynamics.
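
As a worked example of the fitness equations, the sketch below evaluates them at the explorer's default slider values ($M_m = 0.15$, both protection factors equal to 10, $M_{SS} = 0.85$). The function name `equilibrium_frequency` is illustrative and not part of the notebook's code.

```python
def equilibrium_frequency(M_m, P_hetero, P_homo, M_SS):
    """Return (W_AA, W_AS, W_SS, p_eq) for the malaria-only model."""
    W_AA = 1 - M_m
    W_AS = 1 - M_m / P_hetero
    W_SS = 1 - (M_SS + (1 - M_SS) * M_m / P_homo)
    denom = W_AA - 2 * W_AS + W_SS
    if abs(denom) < 1e-12:
        raise ValueError("fitnesses are additive; no interior equilibrium")
    return W_AA, W_AS, W_SS, (W_AA - W_AS) / denom

W_AA, W_AS, W_SS, p_eq = equilibrium_frequency(0.15, 10.0, 10.0, 0.85)
print(f"W_AA={W_AA:.3f}  W_AS={W_AS:.3f}  W_SS={W_SS:.3f}  p_eq={100 * p_eq:.1f}%")
```

With these defaults the heterozygote is the fittest genotype ($W_{AS} > W_{AA} > W_{SS}$), and the equilibrium allele frequency comes out at roughly 13.9%.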

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
import ipywidgets as widgets
from concurrent.futures import ProcessPoolExecutor
import time
import warnings
warnings.filterwarnings('ignore')

## Core Simulation Engine

The cell below contains the simulation functions. It uses the same logic as the C programs but leverages Python lists and a Park-Miller PRNG for reproducibility.
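
The Park-Miller "minimal standard" generator is the recurrence $x_{n+1} = 16807\,x_n \bmod (2^{31}-1)$. The sketch below compares the plain modular form with the bit-folding form used in `ParkMillerRNG.next`, and checks both against the test value published by Park and Miller (starting from seed 1, the state after 10,000 steps is 1043618065). The function names `lehmer_next` and `folded_next` are illustrative.

```python
M = 2**31 - 1   # modulus (a Mersenne prime)
A = 16807       # multiplier

def lehmer_next(state):
    """Reference form: one multiplication and one modulo."""
    return (A * state) % M

def folded_next(state):
    """Folding form: uses 2**31 == 1 (mod M) to avoid the division."""
    lo = A * (state & 0xFFFF)
    hi = A * (state >> 16)
    lo += (hi & 0x7FFF) << 16   # low 15 bits of hi, shifted into place
    lo += hi >> 15              # high bits of hi wrap around as +1 each
    if lo > M:
        lo -= M
    return lo

s_ref = s_fold = 1
for _ in range(10_000):
    s_ref = lehmer_next(s_ref)
    s_fold = folded_next(s_fold)
print(s_ref, s_fold)
```

Both forms should print the same number, which gives a quick sanity check that the folding arithmetic was ported correctly.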

In [None]:
# ============================================================
#  Constants and genotype codes
# ============================================================
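HLTHY = 0          # AA genotype (no HbS allele)
MINOR = 1          # AS genotype (heterozygote, sickle cell trait)
MAJOR = 2          # SS genotype (homozygote, sickle cell disease)
DEAD  = 4          # bit flag OR-ed onto the genotype code when an individual dies
MALARIA_FLAG = 8   # bit flag marking malaria infection (Paper 2 mode only)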
NO_CHILD = -1
HUNTER = 0
FARMER = 1

# Family-size distributions as (number of children, cumulative probability in %)
CHILD_DIST_HUNTER = [(2, 10), (3, 25), (4, 75), (5, 90), (6, 100)]
CHILD_DIST_FARMER = [(3, 10), (4, 25), (5, 75), (6, 90), (7, 100)]

# ============================================================
#  Park-Miller PRNG
# ============================================================
class ParkMillerRNG:
    def __init__(self, seed):
        self.state = int(seed) & 0x7FFFFFFF
        if self.state == 0: self.state = 1
    def next(self):
        lo = 16807 * (self.state & 0xFFFF)
        hi = 16807 * (self.state >> 16)
        lo += (hi & 0x7FFF) << 16
        lo += hi >> 15
        if lo > 0x7FFFFFFF: lo -= 0x7FFFFFFF
        self.state = lo
        return self.state
    def randint(self, n):
        return self.next() % n

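def growth(generation, couples, couples_max, growth_start, r):
    """Return the population size (individuals, always even) for a generation.

    The number of couples follows a logistic curve from `couples` towards
    `couples_max` with rate `r`; there is no growth before `growth_start`.
    """
    if generation < growth_start: generation = growth_start
    exp_val = np.exp(r * (generation - growth_start))
    coups = couples_max * couples * exp_val / (couples_max + couples * (exp_val - 1))
    return int(coups + 0.5) * 2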

def shuffle_array(arr, size, rng):
    """Shuffle the first `size` elements of `arr` in place (7 passes of random swaps)."""
    for _ in range(7):
        for i in range(size):
            j = rng.randint(size)
            arr[i], arr[j] = arr[j], arr[i]

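def manage_population(parent, N, rng):
    """Move living individuals to the front of `parent` and return their count.

    If the count is odd, a copy of a randomly chosen survivor is appended so
    that every individual can be paired.
    """
    i, j = 0, N - 1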
    while i < j:
        while i < j and (parent[i] & DEAD) != DEAD: i += 1
        while j > i and (parent[j] & DEAD) == DEAD: j -= 1
        if i < j and (parent[j] & DEAD) != DEAD:
            parent[i], parent[j] = parent[j], parent[i]
    if (parent[i] & DEAD) == DEAD: i -= 1
    if i % 2 == 0:
        i += 1
        parent[i] = parent[rng.randint(i)]
    return i + 1

def manage_offspring(offspring, childsize):
    i, j = 0, childsize - 1
    while i < j:
        while i < j and offspring[i] != NO_CHILD: i += 1
        while j > i and offspring[j] == NO_CHILD: j -= 1
        if i < j and offspring[j] != NO_CHILD:
            offspring[i], offspring[j] = offspring[j], offspring[i]
    if offspring[i] == NO_CHILD: i -= 1
    return i + 1

def get_childnum(rng, lifestyle):
    prob = rng.randint(100)
    dist = CHILD_DIST_FARMER if lifestyle == FARMER else CHILD_DIST_HUNTER
    for num, cum_prob in dist:
        if prob < cum_prob: return num
    return dist[-1][0]

In [None]:
def run_single(seed, params):
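    """Run one repetition with the given parameters dict.

    Returns (generations_data, gene_aborted). Each row of generations_data holds
    [f_gene, f_minor, f_major, f_dead] in percent for one generation; gene_aborted
    is the generation at which the run was marked as aborted (HbS allele lost or
    population collapsed), or 0 otherwise.
    """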
    rng = ParkMillerRNG(seed)
    mode = params['mode']
    PREC = 1000
    MAX_GEN = params['max_gen']
    _CHILD = 7
    _MORE_CHILD = params.get('more_child', 0)
    CHILD = _CHILD + _MORE_CHILD
    COUPLES_MAX = params['couples_max']
    childsize = CHILD * COUPLES_MAX

    M_m = params['M_malaria']
    p_minor = params['P_minor']
    p_major = params['P_major']
    m_ss = params['M_SS']
    more_fertile = params['more_fertile']
    overlap = params['overlap']
    overlap_start = params['overlap_start']
    growth_start = params['growth_start']
    life_style_gen = params['life_style_gen']
    couples = params['couples']
    r_growth = params['r_growth']

    # Paper 2 extras
    if mode == 'paper2':
        PR_m_val = params['PR_m']
        M_cond = M_m / PR_m_val
        M_other = params['M_other']
        P_m_other = params['P_m_other']
        P_minor_other = params['P_minor_other']
        P_major_other = params['P_major_other']

    N = growth(0, couples, COUPLES_MAX, growth_start, r_growth)
    parent = [HLTHY] * N
    parent[rng.randint(N)] = MINOR  # mutation

    gene_aborted = 0
    prevparent = list(parent)
    prevpopsize = N
    generations_data = np.zeros((MAX_GEN + 1, 4))

    for gen in range(MAX_GEN + 1):
        if mode == 'paper2':
            n_minor = sum(1 for x in parent[:N] if (x & 3) == MINOR)
            n_major = sum(1 for x in parent[:N] if (x & 3) == MAJOR)
        else:
            n_minor = sum(1 for x in parent[:N] if x == MINOR)
            n_major = sum(1 for x in parent[:N] if x == MAJOR)

        f_gene = (n_minor + 2*n_major) * 100.0 / (2*N)
        f_minor = n_minor * 100.0 / N
        f_major = n_major * 100.0 / N

        if mode == 'paper2':
            for i in range(N):
                if rng.randint(PREC) < PR_m_val * PREC:
                    parent[i] |= MALARIA_FLAG

        n_dead_hlthy = n_dead_minor = n_dead_major = 0
        ii = 0
        for i in range(N):
            if mode == 'paper1':
                if parent[i] == HLTHY and rng.randint(PREC) < M_m * PREC:
                    parent[i] |= DEAD; n_dead_hlthy += 1
                if parent[i] == MINOR:
                    if rng.randint(PREC) < (M_m/p_minor)*PREC:
                        parent[i] |= DEAD; n_dead_minor += 1
                if parent[i] == MAJOR:
                    if rng.randint(PREC) < m_ss*PREC or rng.randint(PREC) < (M_m/p_major)*PREC:
                        parent[i] |= DEAD; n_dead_major += 1
            else:
                geno = parent[i] & 3
                has_mal = (parent[i] & MALARIA_FLAG) != 0
                if has_mal:
                    if geno == HLTHY:
                        if rng.randint(PREC) < M_cond*PREC or rng.randint(PREC) < (M_other/P_m_other)*PREC:
                            parent[i] |= DEAD; n_dead_hlthy += 1
                    elif geno == MINOR:
                        if rng.randint(PREC) < (M_cond/p_minor)*PREC or rng.randint(PREC) < (M_other/(P_minor_other*P_m_other))*PREC:
                            parent[i] |= DEAD; n_dead_minor += 1
                    elif geno == MAJOR:
                        if rng.randint(PREC) < (M_cond/p_major)*PREC or rng.randint(PREC) < m_ss*PREC or rng.randint(PREC) < (M_other/(P_major_other*P_m_other))*PREC:
                            parent[i] |= DEAD; n_dead_major += 1
                else:
                    if geno == HLTHY:
                        if rng.randint(PREC) < M_other*PREC:
                            parent[i] |= DEAD; n_dead_hlthy += 1
                    elif geno == MINOR:
                        if rng.randint(PREC) < (M_other/P_minor_other)*PREC:
                            parent[i] |= DEAD; n_dead_minor += 1
                    elif geno == MAJOR:
                        if rng.randint(PREC) < m_ss*PREC or rng.randint(PREC) < (M_other/P_major_other)*PREC:
                            parent[i] |= DEAD; n_dead_major += 1

            if (gene_aborted == 0 and gen >= overlap_start
                    and rng.randint(PREC) < overlap*PREC and ii < prevpopsize):
                parent[i] = prevparent[ii]; ii += 1

        n_dead = sum(1 for x in parent[:N] if (x & DEAD) == DEAD)
        f_dead = n_dead * 100.0 / N
        generations_data[gen] = [f_gene, f_minor, f_major, f_dead]

        n_minor_alive = sum(1 for x in parent[:N] if (x & ~MALARIA_FLAG) == MINOR)
        n_major_alive = sum(1 for x in parent[:N] if (x & ~MALARIA_FLAG) == MAJOR)
        if gene_aborted == 0 and n_minor_alive == 0 and n_major_alive == 0:
            gene_aborted = gen

        popsize = N - n_dead
        if popsize <= 1:
            if gene_aborted == 0: gene_aborted = gen
            if N - n_dead > 0:
                nm = int((n_minor - n_dead_minor) * N / (N - n_dead) + 0.5)
                nj = int((n_major - n_dead_major) * N / (N - n_dead) + 0.5)
            else: nm = nj = 0
            parent = [HLTHY]*N
            for idx in range(min(nj, N)): parent[idx] = MAJOR
            for idx in range(nj, min(nj+nm, N)): parent[idx] = MINOR
            shuffle_array(parent, N, rng)
            popsize = N
        else:
            popsize = manage_population(parent, N, rng)

        shuffle_array(parent, popsize, rng)
        offspring = [NO_CHILD] * childsize
        lifestyle = FARMER if gen >= life_style_gen else HUNTER

        for i in range(0, popsize-1, 2):
            j_idx = i + 1
            childnum = get_childnum(rng, lifestyle)
            mc = _MORE_CHILD
            pi = parent[i] & 3 if mode == 'paper2' else parent[i]
            pj = parent[j_idx] & 3 if mode == 'paper2' else parent[j_idx]
            if (pi == MINOR or pj == MINOR) and rng.randint(PREC) < more_fertile*PREC:
                mc = 0
            chbase = i * CHILD // 2
            if pi == HLTHY:
                if pj == HLTHY:
                    for k in range(_MORE_CHILD, childnum+_MORE_CHILD): offspring[chbase+k] = HLTHY
                elif pj == MINOR:
                    for k in range(mc, childnum+_MORE_CHILD): offspring[chbase+k] = rng.randint(2)
                elif pj == MAJOR:
                    for k in range(_MORE_CHILD, childnum+_MORE_CHILD): offspring[chbase+k] = MINOR
            elif pi == MINOR:
                if pj == HLTHY:
                    for k in range(mc, childnum+_MORE_CHILD): offspring[chbase+k] = rng.randint(2)
                elif pj == MINOR:
                    for k in range(mc, childnum+_MORE_CHILD):
                        rv = rng.randint(4)
                        offspring[chbase+k] = HLTHY if rv==0 else (MINOR if rv<=2 else MAJOR)
                elif pj == MAJOR:
                    for k in range(mc, childnum+_MORE_CHILD): offspring[chbase+k] = 1+rng.randint(2)
            elif pi == MAJOR:
                if pj == HLTHY:
                    for k in range(_MORE_CHILD, childnum+_MORE_CHILD): offspring[chbase+k] = MINOR
                elif pj == MINOR:
                    for k in range(mc, childnum+_MORE_CHILD): offspring[chbase+k] = 1+rng.randint(2)
                elif pj == MAJOR:
                    for k in range(_MORE_CHILD, childnum+_MORE_CHILD): offspring[chbase+k] = MAJOR

        prevparent = list(parent[:popsize])
        prevpopsize = popsize
        shuffle_array(prevparent, prevpopsize, rng)
        total_offspring = manage_offspring(offspring, childsize)
        shuffle_array(offspring, total_offspring, rng)
        shuffle_array(offspring, total_offspring, rng)
        N = growth(gen+1, couples, COUPLES_MAX, growth_start, r_growth)
        parent = [HLTHY] * max(N, len(parent))
        for i in range(N):
            if i < total_offspring and offspring[i] != NO_CHILD:
                parent[i] = offspring[i]
            else:
                parent[i] = offspring[rng.randint(total_offspring)]

    return generations_data, gene_aborted

In [None]:
def run_batch(params, repeats=200, seed_base=42):
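    """Run `repeats` simulations and return per-generation means and SDs.

    Statistics cover only the runs that were not aborted; returns None if
    every run aborted.
    """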
    master_rng = np.random.RandomState(seed_base)
    seeds = master_rng.randint(1, 2**31-1, size=repeats)
    
    all_results = []
    all_aborted = []
    for s in seeds:
        res, ab = run_single(s, params)
        all_results.append(res)
        all_aborted.append(ab)
    
    n_gen = params['max_gen'] + 1
    f_abort = sum(1 for a in all_aborted if a != 0)
    n = repeats - f_abort
    if n == 0:
        return None
    
    Sx = np.zeros((n_gen, 4))
    S2x = np.zeros((n_gen, 4))
    for res, ab in zip(all_results, all_aborted):
        if ab == 0:
            Sx += res
            S2x += res**2
    
    means = Sx / n
    # Sample variance across surviving runs; guard n == 1 to avoid dividing by zero
    variances = np.maximum((S2x - Sx**2/n)/max(n-1, 1) + 1e-12, 0)
    sds = np.sqrt(variances)
    
    p_abort = f_abort / repeats
    return {'means': means, 'sds': sds, 'f_abort': f_abort,
            'n_valid': n, 'p_abort': p_abort, 'repeats': repeats}

## Interactive Explorer

Adjust the sliders below and click **Run Simulation** to see how parameter changes affect the equilibrium gene frequency and temporal dynamics.

> **Note**: For speed in the notebook, the default is 200 repeats. Increase for smoother curves (at the cost of longer computation time).
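
To see why more repeats give smoother mean curves, recall that the standard error of a mean over $n$ independent runs is $\mathrm{SD}/\sqrt{n}$. The sketch below uses a hypothetical run-to-run standard deviation of 4 percentage points; both that value and the function name `standard_error` are illustrative.

```python
import math

def standard_error(sd, n_runs):
    """Standard error of the mean over n_runs independent repetitions."""
    return sd / math.sqrt(n_runs)

sd = 4.0  # hypothetical run-to-run SD of the gene frequency, in percentage points
for n_runs in (200, 800, 2000):
    print(f"{n_runs:5d} runs -> SE of mean = {standard_error(sd, n_runs):.3f}")
```

Quadrupling the number of repeats halves the noise in the mean curve. Here `n_runs` should be read as the number of surviving (non-aborted) runs. The shaded bands in the plots are drawn at ±1.96 SD rather than ±1.96 SE, so they reflect run-to-run variability and do not narrow as repeats increase.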

In [None]:
# ============================================================
#  Interactive Widget UI
# ============================================================

style = {'description_width': '180px'}
layout = widgets.Layout(width='500px')

mode_w = widgets.Dropdown(options=['paper1', 'paper2'], value='paper1',
                          description='Mode:', style=style, layout=layout)
repeats_w = widgets.IntSlider(value=200, min=50, max=2000, step=50,
                              description='MC Repeats:', style=style, layout=layout)
couples_w = widgets.IntSlider(value=25, min=5, max=500, step=5,
                              description='Initial Couples:', style=style, layout=layout)
max_gen_w = widgets.IntSlider(value=100, min=20, max=300, step=10,
                              description='Generations:', style=style, layout=layout)
M_malaria_w = widgets.FloatSlider(value=0.15, min=0.01, max=0.50, step=0.01,
                                  description='Malaria Mortality (M):', style=style, layout=layout)
P_minor_w = widgets.FloatSlider(value=10.0, min=1.0, max=30.0, step=0.5,
                                description='Protection AS (P_minor):', style=style, layout=layout)
P_major_w = widgets.FloatSlider(value=10.0, min=1.0, max=30.0, step=0.5,
                                description='Protection SS (P_major):', style=style, layout=layout)
M_SS_w = widgets.FloatSlider(value=0.85, min=0.10, max=1.0, step=0.05,
                             description='SS Mortality (M_SS):', style=style, layout=layout)
overlap_w = widgets.FloatSlider(value=0.05, min=0.0, max=0.20, step=0.01,
                                description='Overlap Fraction:', style=style, layout=layout)
more_fertile_w = widgets.FloatSlider(value=0.10, min=0.0, max=0.50, step=0.05,
                                     description='More Fertile (AS):', style=style, layout=layout)

# Paper 2 specific
PR_m_w = widgets.FloatSlider(value=0.4, min=0.05, max=0.90, step=0.05,
                             description='Malaria Prevalence:', style=style, layout=layout)
M_other_w = widgets.FloatSlider(value=0.25, min=0.0, max=0.50, step=0.05,
                                description='Other Mortality (M_O):', style=style, layout=layout)
P_m_other_w = widgets.FloatSlider(value=1.5, min=1.0, max=5.0, step=0.1,
                                  description='Malaria prot. other:', style=style, layout=layout)
P_minor_other_w = widgets.FloatSlider(value=3.0, min=1.0, max=10.0, step=0.5,
                                      description='AS prot. other:', style=style, layout=layout)
P_major_other_w = widgets.FloatSlider(value=3.0, min=1.0, max=10.0, step=0.5,
                                      description='SS prot. other:', style=style, layout=layout)

paper2_box = widgets.VBox([PR_m_w, M_other_w, P_m_other_w, P_minor_other_w, P_major_other_w],
                          layout=widgets.Layout(border='1px solid #ccc', padding='10px',
                                               margin='5px 0', display='none'))

def toggle_paper2(change):
    paper2_box.layout.display = '' if change['new'] == 'paper2' else 'none'
    if change['new'] == 'paper2':
        more_fertile_w.value = 0.0
    else:
        more_fertile_w.value = 0.10

mode_w.observe(toggle_paper2, names='value')

run_btn = widgets.Button(description='Run Simulation', button_style='success',
                         layout=widgets.Layout(width='200px', height='40px'))
output_area = widgets.Output()

def on_run(b):
    with output_area:
        clear_output(wait=True)
        params = {
            'mode': mode_w.value,
            'max_gen': max_gen_w.value,
            'couples': couples_w.value,
            'couples_max': 1000,
            'M_malaria': M_malaria_w.value,
            'P_minor': P_minor_w.value,
            'P_major': P_major_w.value,
            'M_SS': M_SS_w.value,
            'overlap': overlap_w.value,
            'overlap_start': 1,
            'growth_start': 5,
            'life_style_gen': 5,
            'r_growth': 0.15,
            'more_child': 0,
            'more_fertile': more_fertile_w.value,
        }
        if mode_w.value == 'paper2':
            params.update({
                'PR_m': PR_m_w.value,
                'M_other': M_other_w.value,
                'P_m_other': P_m_other_w.value,
                'P_minor_other': P_minor_other_w.value,
                'P_major_other': P_major_other_w.value,
            })

        # Theoretical equilibrium (Paper 1 formula)
        M_m = params['M_malaria']
        W_AA = 1 - M_m
        W_AS = 1 - M_m / params['P_minor']
        W_SS = 1 - (params['M_SS'] + (1-params['M_SS'])*M_m/params['P_major'])
        denom = W_AA - 2*W_AS + W_SS
        p_eq = (W_AA - W_AS) / denom * 100 if abs(denom) > 1e-12 else 0

        print(f'Running {repeats_w.value} reps ({mode_w.value})...', flush=True)
        t0 = time.time()
        stats = run_batch(params, repeats=repeats_w.value)
        elapsed = time.time() - t0

        if stats is None:
            print('All repetitions aborted!')
            return

        gens = np.arange(params['max_gen']+1)
        m = stats['means']
        s = stats['sds']

        fig, axes = plt.subplots(1, 3, figsize=(16, 5))
        
        # Gene frequency
        axes[0].plot(gens, m[:,0], 'g-', lw=2)
        axes[0].fill_between(gens, m[:,0]-1.96*s[:,0], m[:,0]+1.96*s[:,0], color='g', alpha=0.15)
        axes[0].axhline(y=p_eq, color='gray', ls='--', alpha=0.7, label=f'Theor. Eq.={p_eq:.1f}%')
        if mode_w.value == 'paper1':
            axes[0].set_ylim(0, max(p_eq*1.8, 5))
        axes[0].set_xlabel('Generation'); axes[0].set_ylabel('Gene Freq. (%)')
        axes[0].set_title('Gene Frequency'); axes[0].legend(); axes[0].grid(alpha=0.3)

        # Hetero + Homo
        axes[1].plot(gens, m[:,1], 'b-', lw=2, label='Heterozygote (AS)')
        axes[1].plot(gens, m[:,2], 'r-', lw=2, label='Homozygote (SS)')
        axes[1].fill_between(gens, m[:,1]-1.96*s[:,1], m[:,1]+1.96*s[:,1], color='b', alpha=0.1)
        axes[1].fill_between(gens, m[:,2]-1.96*s[:,2], m[:,2]+1.96*s[:,2], color='r', alpha=0.1)
        axes[1].set_xlabel('Generation'); axes[1].set_ylabel('Frequency (%)')
        axes[1].set_title('Genotype Frequencies'); axes[1].legend(); axes[1].grid(alpha=0.3)

        # Mortality
        axes[2].plot(gens, m[:,3], 'k-', lw=2)
        axes[2].fill_between(gens, m[:,3]-1.96*s[:,3], m[:,3]+1.96*s[:,3], color='k', alpha=0.1)
        axes[2].set_xlabel('Generation'); axes[2].set_ylabel('Mortality (%)')
        axes[2].set_title('Overall Mortality'); axes[2].grid(alpha=0.3)

        plt.suptitle(f"{'Paper 1 (malaria only)' if mode_w.value=='paper1' else 'Paper 2 (multi-disease)'}"
                     f" — {stats['n_valid']}/{stats['repeats']} survived"
                     f" ({stats['p_abort']*100:.1f}% aborted) — {elapsed:.1f}s",
                     fontsize=13, fontweight='bold', y=1.02)
        plt.tight_layout()
        plt.show()

        eq_val = m[-1, 0]
        print(f'\nSimulated equilibrium f_gene (last gen): {eq_val:.2f}%')
        print(f'Theoretical equilibrium (Paper 1 formula): {p_eq:.2f}%')
        print(f'Gene survival probability: {(1-stats["p_abort"])*100:.1f}%')

run_btn.on_click(on_run)

ui = widgets.VBox([
    widgets.HTML('<h3>Simulation Parameters</h3>'),
    mode_w, repeats_w, couples_w, max_gen_w,
    widgets.HTML('<b>Mortality & Protection</b>'),
    M_malaria_w, P_minor_w, P_major_w, M_SS_w,
    widgets.HTML('<b>Demographic</b>'),
    overlap_w, more_fertile_w,
    widgets.HTML('<b>Paper 2: Multi-disease parameters</b>'),
    paper2_box,
    run_btn,
    output_area
])

display(ui)

## Theoretical vs Simulated Equilibrium

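The cell below computes and plots the theoretical equilibrium gene frequency as a function of malaria mortality and protection factor. The same protection factor is applied to the AS and SS genotypes, and $M_{SS}$ is fixed at 85%. The red star marks the default parameter set, for comparison with the simulated equilibrium reported by the explorer above.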

In [None]:
# Theoretical equilibrium surface
M_vals = np.linspace(0.05, 0.40, 50)
P_vals = np.linspace(2, 20, 50)
M_grid, P_grid = np.meshgrid(M_vals, P_vals)

M_SS_val = 0.85
W_AA = 1 - M_grid
W_AS = 1 - M_grid / P_grid
W_SS = 1 - (M_SS_val + (1 - M_SS_val) * M_grid / P_grid)
P_eq = (W_AA - W_AS) / (W_AA - 2 * W_AS + W_SS) * 100

fig, ax = plt.subplots(figsize=(10, 7))
cs = ax.contourf(M_grid * 100, P_grid, P_eq, levels=20, cmap='viridis')
plt.colorbar(cs, label='Equilibrium Gene Frequency (%)')
ax.set_xlabel('Malaria Mortality (%)')
ax.set_ylabel('Protection Factor')
ax.set_title('Theoretical Equilibrium Gene Frequency\n(malaria-only model, M_SS = 85%)')
ax.plot(15, 10, 'r*', markersize=15, label='Default (M=15%, P=10)')
ax.legend()
plt.tight_layout()
plt.show()