# Population Genetics Simulation

Interactive exploration of allele frequencies, genotype frequencies, and linkage disequilibrium.

In [None]:
from popgen_sim import *
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output

# Notebook-wide matplotlib defaults: figure size, base font size, and no
# top/right axis spines on any plot drawn below.
plt.rcParams.update({
    'figure.figsize': (12, 8),
    'font.size': 10,
    'axes.spines.top': False,
    'axes.spines.right': False,
})

print('Module loaded successfully!')

---
## Interactive Simulation

Set the parameters, run the simulation, then use the slider or the play button to watch the dynamics unfold generation by generation.

In [None]:
# === PARAMETERS ===
# Single-run settings consumed by the simulation and player cells below.

params = SimParams(
    n_loci=2,              # 1 or 2
    pop_size=500,
    n_generations=100,
    n_replicates=1,        # Use 1 for interactive viewing

    freq_A=0.5,
    freq_B=0.5,
    initial_D=0.2,         # Only used if n_loci=2

    mating_system='random',  # 'random', 'assortative', 'disassortative'
    assortment_strength=1.0,
    assortment_trait='additive',  # 'additive', 'locus_a', 'locus_b'

    recomb_rate=0.1,       # Only used if n_loci=2
)

# Validate
if params.n_loci == 2:
    D_min, D_max = params.get_D_bounds()
    print(f"Valid D range: [{D_min:.3f}, {D_max:.3f}]")
    print(f"Using D = {params.initial_D}")
    # Actually compare the requested D with the reported range instead of
    # only printing both. The tolerance keeps a D chosen exactly on a bound
    # from tripping the check through floating-point rounding.
    # NOTE(review): SimParams may already enforce this itself — confirm; the
    # warning is deliberately non-fatal so it cannot break a working setup.
    tolerance = 1e-12
    if not (D_min - tolerance <= params.initial_D <= D_max + tolerance):
        print(f"WARNING: initial_D = {params.initial_D} is outside the valid D range")
print(f"Mating: {params.mating_system.value}")

In [None]:
# Run simulation
# Runs once with the `params` object built in the parameters cell. The seed is
# fixed (presumably so that re-running the cell reproduces the same
# trajectory — confirm against simulate()). `result` is what the interactive
# player cell passes to make_player().
result = simulate(params, seed=42)
# Reports the requested generation count from params; nothing is read back
# from `result` here.
print(f"Simulated {params.n_generations} generations")

In [None]:
# === INTERACTIVE PLAYER ===

def _style_trajectory_axis(ax, n_gen, ylabel, title, ylim=(0, 1)):
    """Apply the limits and labels shared by every generation-vs-value panel."""
    ax.set_xlim(0, n_gen)
    ax.set_ylim(*ylim)
    ax.set_xlabel('Generation')
    ax.set_ylabel(ylabel)
    ax.set_title(title)


def _mark_generation(ax, gen, color):
    """Draw the vertical cursor marking the generation currently displayed."""
    ax.axvline(gen, color=color, ls='-', lw=1, alpha=0.5)


def _plot_frequency_bars(ax, freqs, categories, colors, title):
    """Draw one bar per category from `freqs`, a mapping keyed by category."""
    ax.bar(categories, [freqs[c] for c in categories], color=colors)
    ax.set_ylim(0, 1)
    ax.set_ylabel('Frequency')
    ax.set_title(title)


def _draw_one_locus(axes, result, gen, n_gen):
    """Fill the three one-locus panels (allele trajectory, genotype bars,
    genotype trajectories) for generation `gen`."""
    gens = result.generations[:gen+1]
    genos = ['aa', 'Aa', 'AA']
    colors = ['#d62728', '#ff7f0e', '#2ca02c']

    # Allele frequency trajectory, with the starting frequency as a reference
    ax = axes[0]
    ax.plot(gens, result.p_A[:gen+1], 'b-', lw=2)
    ax.axhline(result.p_A[0], color='gray', ls='--', lw=0.5, alpha=0.5)
    _style_trajectory_axis(ax, n_gen, 'p(A)', 'Allele Frequency')
    _mark_generation(ax, gen, 'red')

    # Current genotype frequencies (bar)
    G = result.G_A[gen]
    _plot_frequency_bars(axes[1], G, genos, colors,
                         f'Genotype Frequencies (Gen {gen})')

    # Hardy-Weinberg expectation, written as a text label above each bar
    p_now = result.p_A[gen]
    hw_exp = [f'{(1-p_now)**2:.3f}', f'{2*p_now*(1-p_now):.3f}', f'{p_now**2:.3f}']
    for i, (g, exp) in enumerate(zip(genos, hw_exp)):
        # Annotations anchored in data coordinates are not drawn at all when
        # the anchor falls outside the axes, so keep the anchor below the
        # y-limit of 1; otherwise the label vanishes once a genotype
        # frequency exceeds 0.98.
        label_y = min(G[g] + 0.02, 0.98)
        axes[1].annotate(f'HW: {exp}', (i, label_y), ha='center', fontsize=8)

    # Genotype trajectories up to the current generation
    ax = axes[2]
    for geno, color in zip(genos, colors):
        freqs = [g[geno] for g in result.G_A[:gen+1]]
        ax.plot(gens, freqs, color=color, lw=2, label=geno)
    _style_trajectory_axis(ax, n_gen, 'Frequency', 'Genotype Trajectories')
    ax.legend(loc='upper right')
    _mark_generation(ax, gen, 'red')


def _draw_two_loci(axes, result, gen, n_gen, d_bound):
    """Fill the 2x3 two-locus grid for generation `gen`; `d_bound` is the
    precomputed half-range of the D axis."""
    gens = result.generations[:gen+1]

    # Row 1: Allele freqs, Genotypes A, Genotypes B

    # Allele frequencies, with both starting frequencies as references
    ax = axes[0, 0]
    ax.plot(gens, result.p_A[:gen+1], 'b-', lw=2, label='p(A)')
    ax.plot(gens, result.p_B[:gen+1], 'r-', lw=2, label='p(B)')
    ax.axhline(result.p_A[0], color='blue', ls='--', lw=0.5, alpha=0.3)
    ax.axhline(result.p_B[0], color='red', ls='--', lw=0.5, alpha=0.3)
    _style_trajectory_axis(ax, n_gen, 'Frequency', 'Allele Frequencies')
    ax.legend()
    _mark_generation(ax, gen, 'gray')

    # Genotypes A (bar)
    _plot_frequency_bars(axes[0, 1], result.G_A[gen],
                         ['aa', 'Aa', 'AA'],
                         ['#d62728', '#ff7f0e', '#2ca02c'],
                         f'Genotypes A (Gen {gen})')

    # Genotypes B (bar)
    _plot_frequency_bars(axes[0, 2], result.G_B[gen],
                         ['bb', 'Bb', 'BB'],
                         ['#9467bd', '#8c564b', '#e377c2'],
                         f'Genotypes B (Gen {gen})')

    # Row 2: Gamete freqs (bar), D trajectory, r² trajectory

    # Gamete frequencies (bar)
    ax = axes[1, 0]
    gametes = ['AB', 'Ab', 'aB', 'ab']
    _plot_frequency_bars(ax, result.g[gen], gametes,
                         ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728'],
                         f'Gamete Frequencies (Gen {gen})')

    # LE expectation: product of the current allele frequencies, drawn as a
    # horizontal tick over each gamete bar
    p_A, p_B = result.p_A[gen], result.p_B[gen]
    le_exp = [p_A*p_B, p_A*(1-p_B), (1-p_A)*p_B, (1-p_A)*(1-p_B)]
    for i, exp in enumerate(le_exp):
        ax.plot(i, exp, 'k_', markersize=15, mew=2)
    ax.annotate('— LE expectation', (2.5, 0.9), fontsize=8)

    # D trajectory, symmetric around zero
    ax = axes[1, 1]
    ax.plot(gens, result.D[:gen+1], 'g-', lw=2)
    ax.axhline(0, color='gray', ls='--', lw=0.5)
    ax.axhline(result.D[0], color='green', ls='--', lw=0.5, alpha=0.3)
    _style_trajectory_axis(
        ax, n_gen, 'D',
        f'Linkage Disequilibrium (D = {result.D[gen]:.4f})',
        ylim=(-d_bound*1.1, d_bound*1.1),
    )
    _mark_generation(ax, gen, 'gray')

    # r² trajectory
    ax = axes[1, 2]
    ax.plot(gens, result.r_squared[:gen+1], 'm-', lw=2)
    _style_trajectory_axis(ax, n_gen, 'r²',
                           f'LD r² = {result.r_squared[gen]:.4f}')
    _mark_generation(ax, gen, 'gray')


def make_player(result):
    """Create interactive player for simulation results.

    Args:
        result: Simulation result. The player reads `result.generations`,
            `result.params` (`n_loci`, `pop_size`, `mating_system.value`,
            `recomb_rate`), `result.p_A` and `result.G_A`; for two loci it
            also reads `result.p_B`, `result.G_B`, `result.g`, `result.D`
            and `result.r_squared`, each indexed by generation.

    Returns:
        A `widgets.VBox` holding the control row (play button, generation
        slider, speed slider) above the plot output area.
    """
    n_gen = len(result.generations) - 1
    n_loci = result.params.n_loci

    # Widgets
    play = widgets.Play(
        value=0, min=0, max=n_gen,
        step=1, interval=100,  # ms between frames
        description="Play"
    )
    slider = widgets.IntSlider(
        value=0, min=0, max=n_gen,
        description='Generation',
        continuous_update=True,
        layout=widgets.Layout(width='600px')
    )
    speed = widgets.IntSlider(
        value=100, min=20, max=500, step=20,
        description='Speed (ms)',
        layout=widgets.Layout(width='200px')
    )

    # Link play and slider
    widgets.jslink((play, 'value'), (slider, 'value'))

    # Update speed
    def update_speed(change):
        play.interval = change['new']
    speed.observe(update_speed, names='value')

    # Output area
    out = widgets.Output()

    # Frame-invariant values, computed once instead of on every redraw.
    sim_params = result.params
    title = f"N={sim_params.pop_size}, {sim_params.mating_system.value} mating"
    d_bound = None
    if n_loci == 2:
        title += f", r={sim_params.recomb_rate}"
        # Half-range of the D axis: the largest |D| over the whole run, with
        # a floor of 0.05 so the axis never collapses.
        d_bound = max(abs(min(result.D)), abs(max(result.D)), 0.05)

    # Figure used for the previous frame. It is closed explicitly before the
    # next frame so figures do not pile up on backends that do not close
    # them when shown.
    last_fig = None

    def draw(gen):
        nonlocal last_fig
        with out:
            clear_output(wait=True)
            if last_fig is not None:
                plt.close(last_fig)

            if n_loci == 1:
                fig, axes = plt.subplots(1, 3, figsize=(14, 4))
                last_fig = fig
                _draw_one_locus(axes, result, gen, n_gen)
            else:  # 2 loci
                fig, axes = plt.subplots(2, 3, figsize=(14, 8))
                last_fig = fig
                _draw_two_loci(axes, result, gen, n_gen, d_bound)

            # Suptitle with params
            fig.suptitle(title, fontsize=12, y=1.02)

            plt.tight_layout()
            plt.show()

    # Connect slider to draw
    def on_change(change):
        draw(change['new'])
    slider.observe(on_change, names='value')

    # Initial draw
    draw(0)

    # Layout
    controls = widgets.HBox([play, slider, speed])
    return widgets.VBox([controls, out])

# Display player
# make_player returns a widgets.VBox (control row + plot output). Leaving the
# call as the last expression of the cell is what makes the notebook render it.
make_player(result)

---
## Static Plots (Multiple Replicates)

Use these cells to visualize drift across many replicate runs of the same parameters.

In [None]:
# Run multiple replicates
# The settings are collected in a mapping first, then expanded into SimParams.
multi_settings = {
    'n_loci': 2,
    'pop_size': 200,
    'n_generations': 100,
    'n_replicates': 20,
    'freq_A': 0.5,
    'freq_B': 0.5,
    'initial_D': 0.2,
    'mating_system': 'random',
    'recomb_rate': 0.1,
}
params_multi = SimParams(**multi_settings)

results_multi = simulate_replicates(params_multi)
print(f"Ran {len(results_multi)} replicates")

In [None]:
# Plot the replicate results returned by simulate_replicates(params_multi)
# in the previous cell, then render the figure.
plot_replicates(results_multi)
plt.show()

---
## Compare Scenarios

In [None]:
# Compare mating systems
# `base` holds the settings shared by all three scenarios; only
# mating_system differs between them.
base = {
    'n_loci': 2,
    'pop_size': 500,
    'n_generations': 50,
    'n_replicates': 5,
    'freq_A': 0.5,
    'freq_B': 0.5,
    'initial_D': 0.0,
    'recomb_rate': 0.1,
}

# Each display label, lower-cased, is the mating_system value for that scenario.
scenarios = [
    (label, SimParams(**base, mating_system=label.lower()))
    for label in ('Random', 'Assortative', 'Disassortative')
]

compare_scenarios(scenarios, metric='D')
plt.title('LD buildup under different mating systems (starting D=0)')
plt.show()

In [None]:
# Compare population sizes
# Everything except pop_size is held fixed across the scenarios.
size_sweep_fixed = dict(
    n_loci=2,
    n_generations=100,
    n_replicates=10,
    freq_A=0.5,
    freq_B=0.5,
    initial_D=0.2,
    recomb_rate=0.1,
)
scenarios = [
    (f'N={size}', SimParams(pop_size=size, **size_sweep_fixed))
    for size in (50, 200, 1000)
]

compare_scenarios(scenarios, metric='D')
plt.title('LD decay: effect of population size')
plt.show()

In [None]:
# Compare recombination rates
# Everything except recomb_rate is held fixed across the scenarios.
rate_sweep_fixed = dict(
    n_loci=2,
    pop_size=10000,
    n_generations=50,
    n_replicates=3,
    freq_A=0.5,
    freq_B=0.5,
    initial_D=0.2,
)
scenarios = [
    (f'r={rate}', SimParams(recomb_rate=rate, **rate_sweep_fixed))
    for rate in (0.01, 0.05, 0.1, 0.5)
]

compare_scenarios(scenarios, metric='D')
plt.title('LD decay: effect of recombination rate (large N to minimize drift)')
plt.show()

---
## One-Locus Model

In [None]:
# One-locus example: drift
# With n_loci=1, make_player draws its three-panel one-locus layout (allele
# frequency, genotype bars, genotype trajectories). The two-locus fields
# (freq_B, initial_D, recomb_rate) are simply not passed here.
params_1L = SimParams(
    n_loci=1,
    pop_size=100,
    n_generations=100,
    n_replicates=1,
    freq_A=0.5,
    mating_system='random',
)

# Fixed seed, as in the two-locus run above (presumably for reproducibility —
# confirm against simulate()).
result_1L = simulate(params_1L, seed=123)
# Last expression of the cell, so the notebook renders the returned widget.
make_player(result_1L)

In [None]:
# One-locus: multiple replicates to see drift
# The settings are collected in a mapping first, then expanded into SimParams.
drift_settings = {
    'n_loci': 1,
    'pop_size': 50,
    'n_generations': 100,
    'n_replicates': 30,
    'freq_A': 0.3,
    'mating_system': 'random',
}
params_1L_multi = SimParams(**drift_settings)

results_1L = simulate_replicates(params_1L_multi)
plot_replicates(results_1L)
plt.show()

---
## Quick Reference

**SimParams fields:**
- `n_loci`: 1 or 2
- `pop_size`: Population size (N)
- `n_generations`: Generations to simulate
- `n_replicates`: Number of replicate runs
- `freq_A`, `freq_B`: Initial allele frequencies
- `initial_D`: Initial LD coefficient (2-locus only)
- `mating_system`: `'random'`, `'assortative'`, `'disassortative'`
- `assortment_strength`: Assortment strength, from 0 to 1
- `assortment_trait`: `'additive'`, `'locus_a'`, `'locus_b'`
- `recomb_rate`: Recombination rate, from 0 to 0.5 (2-locus only)

**Nomenclature:**
- `p_A`, `p_B`: Allele frequencies
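- `G_AA`, `G_Aa`, `G_aa`: Genotype frequencies (capital G); stored per generation and read as, e.g., `result.G_A[gen]['AA']`
- `g_AB`, `g_Ab`, etc.: Gamete frequencies (lowercase g); stored per generation and read as, e.g., `result.g[gen]['AB']`
- `D`, `r²`: LD measures (`result.D` and `result.r_squared`, indexed by generation)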