In [4]:
import numpy as np
from scipy.sparse import dok_matrix
from math import log
from scipy.linalg import expm
import mdptoolbox.mdp as mdp
import os

In [5]:
# Module-level model constants.
# N_pop and mutation_rate are read as globals by the drift/diffusion helpers
# (A1, A2, D1, D2); num_genotypes sizes the mutation matrix in the simulation.
N_pop = 5000           # population size
mutation_rate = 0.001  # mutation rate between any pair of genotypes
num_genotypes = 3      # number of genotypes tracked

def setup_parameters(L=100, N_pop=5000, mutation_rate=0.001, tau=1000, dt=1):
    """Build the state grid and the time axis for one policy interval `dt`.

    Parameters
    ----------
    L : int
        Number of grid cells per axis; the unit square is split into L x L cells.
    N_pop, mutation_rate :
        Accepted for signature compatibility with callers; not used here.
    tau : number
        Time horizon. Only `int(tau / dt)` whole policy epochs are kept.
    dt : number
        Length of one policy epoch.

    Returns
    -------
    a : float
        Grid spacing, 1 / L.
    states : list[tuple[float, float]]
        Cell centres (u1, u2); index n = j * L + i with u1 = a*(0.5+i) varying
        fastest and u2 = a*(0.5+j) slowest.
    N_states : int
        L * L.
    t_step : int
        Number of whole epochs, int(tau / dt).
    t : numpy.ndarray
        Float array of length t_step + 1 holding the epoch boundaries
        0, dt, 2*dt, ..., t_step*dt.
    """
    # Derived parameters
    a = 1.0 / L
    N_states = L * L
    # Generate state grid: list of (u1, u2)
    states = [(a * (0.5 + i), a * (0.5 + j)) for j in range(L) for i in range(L)]
    t_step = int(tau / dt)
    # Epoch k ends at k*dt. The previous np.linspace(0, tau, t_step + 1) stretched
    # the axis out to tau whenever dt did not divide tau (e.g. tau=1000, dt=30
    # gave a spacing of 30.3 instead of 30), mislabelling every recorded point.
    t = np.arange(t_step + 1, dtype=float) * dt
    return a, states, N_states, t_step, t

# Drug fitness profiles
# Action ids: one per available drug.
actions = list(range(4))

# Per-drug fitness of each genotype, keyed by action id.
# Each value is [f0, f1, f2], indexed by genotype.
# The trailing "alt:" vectors are other profiles the author left noted in
# comments; they are not used anywhere in the code.
drug_lists = {
    # Drug A
    0: [0.994, 0.997, 1.009],
    # Drug D   (alt: [1.005, 0.988, 0.999])
    3: [0.993, 1.005, 1.002],
    # Drug C   (alt: [0.997, 1.001, 0.989])
    2: [0.999, 1.003, 0.994],
    # Labelled "Drug 3" in the original notes; presumably "Drug B" — TODO confirm
    1: [1.000, 0.991, 1.008],
}

def get_f(action):
    """Look up the genotype fitness list stored in `drug_lists` for `action`.

    Raises KeyError if `action` is not a key of `drug_lists`.
    """
    profile = drug_lists[action]
    return profile

def get_s(action):
    """Return (s1, s2): fitness of genotypes 1 and 2 relative to genotype 0, minus 1.

    The fitness values come from `get_f(action)`.
    """
    fitness = get_f(action)
    baseline = fitness[0]
    return fitness[1] / baseline - 1, fitness[2] / baseline - 1

def unpack(u):
    """Convert grid coordinates (u1, u2) into genotype frequencies (x0, x1, x2).

    u1 is the frequency of genotype 1 itself; u2 is the share of genotype 2
    among everything that is not genotype 1, so x2 = u2 * (1 - x1).
    The three returned values sum to 1.
    """
    u1, u2 = u[0], u[1]
    x1 = u1
    x2 = u2 * (1 - x1)
    return 1 - x1 - x2, x1, x2

# Fokker-Planck
def A1(u, action):
    """Drift coefficient along the first coordinate u[0] under `action`.

    Sum of a mutation term (built from the module-level `mutation_rate`) and a
    selection term built from the coefficients returned by `get_s(action)`.
    """
    s1, s2 = get_s(action)
    u1, u2 = u[0], u[1]
    mutation_term = mutation_rate - 3 * mutation_rate * u1
    selection_term = u1 * (1 - u1) * (s1 - u2 * s2)
    return mutation_term + selection_term

def A2(u, action):
    """Drift coefficient along the second coordinate u[1] under `action`.

    Sum of a mutation term (built from the module-level `mutation_rate`, scaled
    by 1 / (1 - u[0])) and a selection term that uses only s2 from
    `get_s(action)`. Divides by (1 - u[0]), so u[0] must not equal 1.
    """
    _, s2 = get_s(action)
    u1, u2 = u[0], u[1]
    mutation_term = (mutation_rate - 2 * mutation_rate * u2) / (1 - u1)
    selection_term = u2 * (1 - u2) * s2
    return mutation_term + selection_term

def D1(u):
    """Diffusion coefficient along u[0]: u1 * (1 - u1) / (2 * N_pop).

    Reads the module-level `N_pop`.
    """
    u1 = u[0]
    return u1 * (1 - u1) / (2 * N_pop)

def D2(u):
    """Diffusion coefficient along u[1]: u2 * (1 - u2) / (2 * N_pop * (1 - u1)).

    Reads the module-level `N_pop`. Divides by (1 - u[0]), so u[0] must not
    equal 1.
    """
    u1, u2 = u[0], u[1]
    numerator = u2 * (1 - u2)
    denominator = 2 * N_pop * (1 - u1)
    return numerator / denominator

# NOTE(review): get_s is already defined earlier in this cell with the same
# behaviour; this second definition shadows it and is a candidate for removal.
def get_s(action):
    """Return the pair (f[1]/f[0] - 1, f[2]/f[0] - 1) for f = get_f(action)."""
    f = get_f(action)
    s1, s2 = (f[k] / f[0] - 1 for k in (1, 2))
    return s1, s2

def compute_reward(state, action):
    """Reward for being in `state` while applying `action`.

    The state (u1, u2) is converted to genotype frequencies with `unpack`, the
    frequency-weighted mean of the drug's fitness values is formed, and its
    negative is returned, so a higher reward means a lower mean fitness.
    """
    freqs = unpack(state)
    f0, f1, f2 = get_f(action)
    mean_fitness = freqs[0] * f0 + freqs[1] * f1 + freqs[2] * f2
    return -mean_fitness

# Discrete Markov Model
def run_for_dt(dt, L=100, N_pop=5000, mutation_rate=0.001, tau=1000, n_reps=10000, discount=0.99, seed=None):
    """Solve the drug-scheduling MDP for policy interval `dt` and evaluate it by simulation.

    Pipeline:
      1. Get the L x L grid of (u1, u2) states from `setup_parameters`.
      2. For every action, build a nearest-neighbour rate matrix from the
         Fokker-Planck drift/diffusion coefficients, exponentiate it over `dt`
         to obtain the per-epoch transition matrix, and fill the reward column.
      3. Solve the MDP with `mdp.ValueIteration`.
      4. Run `n_reps` Wright-Fisher simulations in which the drug is re-chosen
         from the MDP policy every `dt` generations.

    Parameters
    ----------
    dt : number
        Policy epoch length: the time over which each rate matrix is
        exponentiated and the number of generations (rounded to int, must be
        >= 1) a drug is held in the simulation.
    L : int
        Grid cells per axis (the dense transition matrices are L*L x L*L).
    N_pop : int
        Population size, used by both the diffusion coefficients and the
        multinomial drift step.
    mutation_rate : float
        Mutation rate, used by both the drift coefficients and the mutation
        matrix of the simulation.
    tau : number
        Time horizon forwarded to `setup_parameters`.
    n_reps : int
        Number of independent simulation replicates.
    discount : float
        Discount factor passed to value iteration.
    seed : int or None
        If given, the simulation draws from `np.random.default_rng(seed)` so
        the run is reproducible. If None, the global `np.random` state is used,
        as before.

    Returns
    -------
    t : array of time points from `setup_parameters` (length t_step + 1)
    mean_fit : mean population fitness across replicates at each time point
    std_fit : standard deviation across replicates at each time point

    Raises
    ------
    ValueError
        From the simulation step if `dt` rounds to less than 1.
    """
    a, states, N_states, t_step, t = setup_parameters(L, N_pop, mutation_rate, tau, dt)

    # Unseeded runs keep using the legacy global RNG so existing behaviour
    # (including any external np.random.seed call) is unchanged.
    rng = np.random if seed is None else np.random.default_rng(seed)

    # Fokker-Planck coefficients evaluated with THIS call's N_pop and
    # mutation_rate. They use the same formulas as the module-level A1/A2/D1/D2,
    # but those helpers read module globals, so calling them here would silently
    # ignore non-default arguments and make the MDP model disagree with the
    # simulation below (which does use the arguments).
    def drift_u1(u, s1, s2):
        return mutation_rate - 3*mutation_rate*u[0] + u[0]*(1-u[0])*(s1 - u[1]*s2)

    def drift_u2(u, s2):
        return (mutation_rate - 2*mutation_rate*u[1])/(1-u[0]) + u[1]*(1-u[1])*s2

    def diffusion_u1(u):
        return u[0] * (1 - u[0]) / (2 * N_pop)

    def diffusion_u2(u):
        return u[1] * (1 - u[1]) / (2 * N_pop * (1 - u[0]))

    # Precompute transition matrices and rewards
    P = []
    R = np.zeros((N_states, len(actions)))
    for action in actions:
        s1, s2 = get_s(action)

        # Build rate matrix
        Omega = dok_matrix((N_states, N_states), dtype=np.float64)

        for n in range(N_states):
            u = states[n]  # current state (u1, u2)

            # Central-difference split of each axis into a symmetric diffusion
            # part and an antisymmetric drift part.
            # NOTE(review): diff - adv (or diff + adv) can be negative where
            # drift dominates diffusion, in which case Omega is not a proper
            # generator and the clipping after expm may remove more than
            # round-off — confirm this is acceptable for the chosen L and N_pop.
            diff_1 = diffusion_u1(u)/a**2
            adv_1 = drift_u1(u, s1, s2)/(2*a)
            diff_2 = diffusion_u2(u)/a**2
            adv_2 = drift_u2(u, s2)/(2*a)

            out_rate = 0.0  # running sum of this row's off-diagonal rates

            # RIGHT neighbor: m = n+1 if not on the right boundary.
            if (n % L) != (L - 1):
                rate = diff_1 + adv_1
                Omega[n, n + 1] = rate
                out_rate += rate

            # LEFT neighbor: m = n-1 if not on the left boundary.
            if (n % L) != 0:
                rate = diff_1 - adv_1
                Omega[n, n - 1] = rate
                out_rate += rate

            # UP neighbor: m = n+L if not on the top boundary.
            if n < L * (L - 1):
                rate = diff_2 + adv_2
                Omega[n, n + L] = rate
                out_rate += rate

            # DOWN neighbor: m = n-L if not on the bottom boundary.
            if n >= L:
                rate = diff_2 - adv_2
                Omega[n, n - L] = rate
                out_rate += rate

            # Diagonal makes the row sum to zero; accumulated above instead of
            # slicing the sparse row again for every state.
            Omega[n, n] = -out_rate

        W = expm(Omega.toarray()*dt)
        W[W < 0] = 0.0                         # clip negatives
        row_sums = W.sum(axis=1, keepdims=True)
        W = W / row_sums                       # renormalize rows

        P.append(W)
        for i, s in enumerate(states):
            R[i, action] = compute_reward(s, action)

    # Solve MDP
    vi = mdp.ValueIteration(transitions=P, reward=R, discount=discount, epsilon=1e-4, max_iter=1000)
    vi.run()

    # Helper to map frequency to state idx
    def freq_to_state_idx(freq):
        # Inverse of `unpack`: u1 = x1, u2 = x2 / (1 - x1), then bin onto the grid.
        u1, u2 = freq[1], freq[2]
        u1 = np.clip(u1, 0, 1)
        den = 1 - u1
        u2 = np.clip(u2/den if den > 0 else 0, 0, 1)
        i = min(max(int(np.floor(u1 * L)), 0), L-1)
        j = min(max(int(np.floor(u2 * L)), 0), L-1)
        return j * L + i

    # Picker functions
    def mdp_picker(freq):
        return vi.policy[freq_to_state_idx(freq)]

    def run_wf(picker, dt):
        """
        Simulate Wright–Fisher dynamics where the drug can only change every `dt` generations.
        Each outer step corresponds to one policy epoch of length `dt` generations.
        We log fitness at the end of each epoch (including epoch 0 at t=0).
        """
        dt_int = int(round(dt))
        if dt_int < 1:
            raise ValueError("dt must be >= 1 when used as generations per policy step.")

        # symmetric mutation matrix Q (j→i)
        Q = np.full((num_genotypes, num_genotypes), mutation_rate)
        np.fill_diagonal(Q, 1 - mutation_rate*(num_genotypes - 1))

        # start fully at genotype 0 (adjust if you prefer another init)
        counts = np.array([N_pop] + [0]*(num_genotypes-1), dtype=int)
        freq   = counts / N_pop

        fit_traj = np.zeros(t_step + 1)

        # Choose initial drug based on initial freq and log baseline fitness
        a_curr = picker(freq)
        f_vec  = np.array(get_f(a_curr), float)
        fit_traj[0] = freq.dot(f_vec)

        # For each policy epoch, hold drug fixed for dt_int biological generations
        for step in range(1, t_step + 1):
            # run dt generations under the same drug a_curr
            f_vec = np.array(get_f(a_curr), float)
            for _ in range(dt_int):
                # selection
                w_bar    = freq.dot(f_vec)
                freq_sel = (freq * f_vec) / max(w_bar, 1e-300)
                # mutation
                freq_mut = Q.T @ freq_sel
                # drift
                counts = rng.multinomial(N_pop, freq_mut)
                freq   = counts / N_pop

            # record fitness at the end of this epoch under current drug
            fit_traj[step] = freq.dot(f_vec)

            # pick next drug for the *next* epoch
            a_curr = picker(freq)

        return fit_traj

    # Run replicates
    mdp_fit = np.zeros((n_reps, t_step + 1))
    for r in range(n_reps):
        mdp_fit[r] = run_wf(mdp_picker, dt)

    # Compute statistics across replicates
    mean_fit = mdp_fit.mean(axis=0)  # (t_step+1,)
    std_fit  = mdp_fit.std(axis=0)   # (t_step+1,)
    return t, mean_fit, std_fit

In [6]:
if __name__ == "__main__":
    # Sweep the policy-update interval and store one compressed archive per dt.
    output_dir = "3g-results-dt"
    os.makedirs(output_dir, exist_ok=True)

    # dt = 1, 2, 5, then 10, 20, ..., 100
    dt_values = [1, 2, 5, *range(10, 101, 10)]

    for dt in dt_values:
        print(f"Running for dt = {dt}...")

        # Each call builds and solves the MDP, then runs the simulations.
        t, mean_fit, std_fit = run_for_dt(dt)

        out_file = os.path.join(output_dir, f"3g_mean_std_dt{dt}.npz")
        np.savez_compressed(
            out_file,
            t=t,
            mean_fit=mean_fit,
            std_fit=std_fit,
        )
        print(f"Saved results to {out_file}")

Running for dt = 1...
Saved results to 3g-results-dt/3g_mean_std_dt1.npz
Running for dt = 2...
Saved results to 3g-results-dt/3g_mean_std_dt2.npz
Running for dt = 5...
Saved results to 3g-results-dt/3g_mean_std_dt5.npz
Running for dt = 10...
Saved results to 3g-results-dt/3g_mean_std_dt10.npz
Running for dt = 20...
Saved results to 3g-results-dt/3g_mean_std_dt20.npz
Running for dt = 30...
Saved results to 3g-results-dt/3g_mean_std_dt30.npz
Running for dt = 40...
Saved results to 3g-results-dt/3g_mean_std_dt40.npz
Running for dt = 50...
Saved results to 3g-results-dt/3g_mean_std_dt50.npz
Running for dt = 60...
Saved results to 3g-results-dt/3g_mean_std_dt60.npz
Running for dt = 70...
Saved results to 3g-results-dt/3g_mean_std_dt70.npz
Running for dt = 80...
Saved results to 3g-results-dt/3g_mean_std_dt80.npz
Running for dt = 90...
Saved results to 3g-results-dt/3g_mean_std_dt90.npz
Running for dt = 100...
Saved results to 3g-results-dt/3g_mean_std_dt100.npz
