In [1]:
import numpy as np
from scipy.sparse import dok_matrix
from scipy.linalg import expm
import mdptoolbox.mdp as mdp
import os

In [2]:
# Model constants. The Fokker-Planck helpers (A1..A3, D1..D3) read these
# module-level values directly.
N_pop = 10000  # population size
mutation_rate = 0.001  # per-generation probability of mutating to each one-bit neighbour
num_genotypes = 4
# Zero-padded binary labels, e.g. ['00', '01', '10', '11'] for four genotypes.
genotypes = [np.binary_repr(g, width=int(np.log2(num_genotypes))) for g in range(num_genotypes)]

def setup_parameters(L=20, N_pop=10000, mutation_rate=0.001, tau=1000, dt=1):
    """Build the 3-D state grid and the decision-epoch time axis.

    Parameters
    ----------
    L : number of cells per axis; the grid has L**3 cells of width a = 1/L.
    N_pop, mutation_rate : accepted for signature compatibility; not used here.
    tau : requested time horizon.
    dt : time between consecutive decision epochs.

    Returns
    -------
    a : cell width 1/L.
    states : list of L**3 cell-centre tuples (u1, u2, u3) with u1 varying fastest,
        so cell (i, j, k) sits at flat index i + L*j + L**2*k.
    N_states : L**3.
    t_step : number of whole dt-steps that fit in tau, int(tau / dt).
    t : float array of the t_step + 1 epoch times 0, dt, 2*dt, ..., t_step*dt.
    """
    a = 1.0 / L
    N_states = L * L * L
    # Generate state grid: list of cell centres (u1, u2, u3), u1 varying fastest.
    states = [(a * (0.5 + i), a * (0.5 + j), a * (0.5 + k)) for k in range(L) for j in range(L) for i in range(L)]
    t_step = int(tau / dt)
    # Epoch n is reached after exactly n*dt time units. np.linspace(0, tau, t_step + 1)
    # would stretch the spacing to tau/t_step whenever dt does not divide tau
    # (e.g. dt=30, tau=1000 -> spacing 30.3 and a final time of 1000 instead of 990).
    t = np.arange(t_step + 1) * float(dt)
    return a, states, N_states, t_step, t

# Drug fitness profiles: drug_lists[action][g] is the fitness of genotype index g
# (same order as `genotypes`) while drug `action` is applied.
actions = list(range(4))  # possible drug actions
drug_lists = dict([
    (0, [0.993, 0.998, 1.009, 1.003]),  # drug A
    (3, [0.995, 1.005, 1.002, 0.999]),  # drug D
    (2, [0.997, 1.001, 0.989, 1.003]),  # drug C
    (1, [1.005, 0.988, 0.999, 1.001]),  # drug B
])

def get_f(action):
    """Return the per-genotype fitness list of drug ``action`` (KeyError if unknown)."""
    profile = drug_lists[action]
    return profile

def get_s(action):
    """Selection coefficients of genotypes 1..3 relative to genotype 0 under ``action``.

    Returns the tuple (s1, s2, s3) with s_g = f_g / f_0 - 1.
    """
    fitness = get_f(action)
    reference = fitness[0]
    return tuple(fitness[g] / reference - 1 for g in (1, 2, 3))

def unpack(u):
    """Map stick-breaking coordinates u = (u1, u2, u3) to genotype frequencies.

    x1 = u1, x2 = u2 * (1 - u1), x3 = u3 * (1 - u1) * (1 - u2), and x0 takes
    whatever remains so that the four frequencies sum to 1.

    Returns (x0, x1, x2, x3).
    """
    u1, u2, u3 = u[0], u[1], u[2]
    x1 = u1
    x2 = u2 * (1 - x1)
    x3 = u3 * (1 - u1) * (1 - u2)
    return 1 - x1 - x2 - x3, x1, x2, x3

# Fokker-Planck drift coefficients A1..A3, one per coordinate (u1, u2, u3).
def A1(u, action):
    """Drift of u1 under drug ``action``.

    Sum of a mutation term (scaled by the module-level ``mutation_rate``) and a
    selection term built from the coefficients returned by ``get_s``.
    """
    s1, s2, s3 = get_s(action)
    u1, u2, u3 = u[0], u[1], u[2]
    mutation = mutation_rate * ((1 - u1) * (1 - u2) - 2 * u1)
    selection = u1 * (1 - u1) * (s1 - u2 * s2 - (1 - u2) * u3 * s3)
    return mutation + selection

def A2(u, action):
    """Drift of u2 under drug ``action`` (mutation term + selection term).

    The mutation term uses the module-level ``mutation_rate`` and divides by
    (1 - u1), so it is undefined at u1 == 1.
    """
    _, s2, s3 = get_s(action)
    u1, u2, u3 = u[0], u[1], u[2]
    mutation = mutation_rate / (1 - u1) * (
        (1 - u1) * (1 - u2) * (1 + u2) - 2 * (1 - u1) * u2 - 2 * u1 * u2
    )
    selection = u2 * (1 - u2) * (s2 - u3 * s3)
    return mutation + selection

def A3(u, action):
    """Drift of u3 under drug ``action`` (mutation term + selection term).

    The mutation term uses the module-level ``mutation_rate`` and divides by
    (1 - u1) * (1 - u2), so it is undefined when u1 == 1 or u2 == 1.
    """
    _, _, s3 = get_s(action)
    u1, u2, u3 = u[0], u[1], u[2]
    mutation = mutation_rate / ((1 - u1) * (1 - u2)) * (u1 + (1 - u1) * u2) * (1 - 2 * u3)
    selection = u3 * (1 - u3) * s3
    return mutation + selection

def D1(u):
    """Diffusion coefficient of u1: u1 * (1 - u1) / (2 * N_pop), using the module-level N_pop."""
    u1 = u[0]
    return u1 * (1 - u1) / (2 * N_pop)

def D2(u):
    """Diffusion coefficient of u2: u2 * (1 - u2) / (2 * N_pop * (1 - u1)), module-level N_pop."""
    u1, u2 = u[0], u[1]
    return u2 * (1 - u2) / (2 * N_pop * (1 - u1))

def D3(u):
    """Diffusion coefficient of u3: u3 (1 - u3) / (2 N_pop (1 - u1)(1 - u2)), module-level N_pop."""
    u1, u2, u3 = u[0], u[1], u[2]
    return u3 * (1 - u3) / (2 * N_pop * (1 - u1) * (1 - u2))

def compute_reward(state, action):
    """Reward of applying ``action`` in grid ``state``: minus the mean population fitness.

    Negated so that maximising reward minimises the population's mean fitness.
    """
    x0, x1, x2, x3 = unpack(state)
    f0, f1, f2, f3 = get_f(action)
    weighted = (x0 * f0, x1 * f1, x2 * f2, x3 * f3)
    return -sum(weighted)

# Discrete Markov Model
def _neighbour_rates(D, A, h):
    """Return (rate to the +h neighbour, rate to the -h neighbour) along one grid axis.

    Both branches reproduce the drift exactly ((up - down) * h == A) and return
    non-negative rates, as a Markov generator requires. Central differencing
    (D/h**2 +- A/(2h)) is used where it is already non-negative; in drift-dominated
    cells (|A| * h / D > 2) it would produce a negative rate, so fall back to
    upwinding, which adds |A| * h / 2 of numerical diffusion instead.
    """
    diffusive = D / h**2
    half_advective = A / (2 * h)
    if diffusive >= abs(half_advective):
        return diffusive + half_advective, diffusive - half_advective
    return diffusive + max(A, 0.0) / h, diffusive + max(-A, 0.0) / h


def _single_mutation_matrix(labels, mu):
    """Per-generation mutation matrix Q with Q[i, j] = P(genotype i -> genotype j).

    Genotypes whose labels differ at exactly one position mutate into each other
    with probability ``mu``; the diagonal holds the probability of not mutating.
    """
    n = len(labels)
    Q = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            if i != j and sum(p != q for p, q in zip(labels[i], labels[j])) == 1:
                Q[i, j] = mu
    np.fill_diagonal(Q, 1 - Q.sum(axis=1))
    return Q


def run_for_dt(dt, L=20, N_pop=10000, mutation_rate=0.001, tau=1000, n_reps=5000, discount=0.99):
    """Solve the drug-switching MDP for decision interval ``dt`` and evaluate its policy.

    1. Discretise the 3-D Fokker-Planck model (drift A1..A3, diffusion D1..D3) on an
       L x L x L grid into a rate matrix ``Omega`` per drug and convert it into the
       dt-step transition matrix ``expm(Omega * dt)``.
    2. Solve the MDP with reward ``compute_reward`` (minus mean fitness) by value iteration.
    3. Simulate ``n_reps`` Wright-Fisher replicates in which the drug is re-chosen from
       the MDP policy every ``round(dt)`` generations. Draws come from numpy's global
       RNG (``np.random.multinomial``); seed it beforehand for reproducible output.

    Parameters
    ----------
    dt : decision interval; must round to at least one generation.
    L : grid cells per axis (L**3 states).
    N_pop, mutation_rate : used by the Wright-Fisher simulation.
        NOTE(review): A1..A3 and D1..D3 read the *module-level* ``N_pop`` and
        ``mutation_rate`` instead, so non-default values here make the MDP model
        and the simulation disagree.
    tau : time horizon passed to ``setup_parameters``.
    n_reps : number of Wright-Fisher replicates.
    discount : MDP discount factor per decision epoch.

    Returns
    -------
    t, mean_fit, std_fit : epoch times and the across-replicate mean and standard
        deviation of population fitness, each of length ``t_step + 1``.

    Raises
    ------
    ValueError : if ``dt`` rounds to fewer than one generation (checked before any
        expensive work is done).
    """
    dt_int = int(round(dt))  # generations simulated between decisions
    if dt_int < 1:
        raise ValueError("dt must be >=1 generations.")

    a, states, N_states, t_step, t = setup_parameters(L, N_pop, mutation_rate, tau, dt)

    # Flat index n = i + L*j + L**2*k, so a unit step along u1, u2, u3 moves n by
    # 1, L and L**2 respectively.
    axes = ((D1, A1, 1), (D2, A2, L), (D3, A3, L * L))

    # Precompute transition matrices and rewards
    P = []
    R = np.zeros((N_states, len(actions)))
    for action in actions:
        # Build the rate matrix densely: expm below needs a dense array anyway.
        Omega = np.zeros((N_states, N_states))
        for n, u in enumerate(states):
            coords = (n % L, (n // L) % L, n // (L * L))
            for c, (D, A, stride) in zip(coords, axes):
                up, down = _neighbour_rates(D(u), A(u, action), a)
                # Jumps off the grid are omitted (no-flux boundary).
                if c < L - 1:
                    Omega[n, n + stride] = up
                if c > 0:
                    Omega[n, n - stride] = down
        # Diagonal = -(sum of off-diagonal rates) so that every row sums to zero.
        np.fill_diagonal(Omega, -Omega.sum(axis=1))

        W = expm(Omega * dt)
        W[W < 0] = 0.0                     # clip round-off negatives from expm
        W /= W.sum(axis=1, keepdims=True)  # renormalise rows to probabilities
        P.append(W)

        for idx, s in enumerate(states):
            R[idx, action] = compute_reward(s, action)

    # Solve MDP
    vi = mdp.ValueIteration(transitions=P, reward=R, discount=discount, epsilon=1e-4, max_iter=1000)
    vi.run()

    def pack(freq):
        """Inverse of ``unpack``: genotype frequencies -> (u1, u2, u3), 0 where undefined."""
        x0, x1, x2, x3 = freq
        u1 = x1
        den1 = 1 - u1
        u2 = 0 if den1 == 0 else x2 / den1
        den2 = den1 * (1 - u2)
        u3 = 0 if den2 == 0 else x3 / den2
        return u1, u2, u3

    def freq_to_state_idx(freq, L):
        """Flat index of the grid cell containing frequency vector ``freq``."""
        u1, u2, u3 = pack(freq)
        # Clip in case of numerical drift (or u == 1 exactly).
        i = min(max(int(np.floor(u1 * L)), 0), L - 1)
        j = min(max(int(np.floor(u2 * L)), 0), L - 1)
        k = min(max(int(np.floor(u3 * L)), 0), L - 1)
        return i + L * j + (L**2) * k

    def mdp_picker(freq):
        """Drug prescribed by the MDP policy for the cell containing ``freq``."""
        return vi.policy[freq_to_state_idx(freq, L)]

    # The mutation matrix does not change between replicates, so build it once.
    Q = _single_mutation_matrix(genotypes, mutation_rate)

    def run_wf(picker):
        """One Wright-Fisher replicate; returns the population fitness at every epoch."""
        counts = np.array([N_pop] + [0] * (num_genotypes - 1), int)
        freq = counts / N_pop
        fit_traj = np.zeros(t_step + 1)

        a_curr = picker(freq)
        f_vec = np.array(get_f(a_curr), float)
        fit_traj[0] = freq.dot(f_vec)

        for step in range(1, t_step + 1):
            f_vec = np.array(get_f(a_curr), float)
            for _ in range(dt_int):
                w_bar = freq.dot(f_vec)
                freq_sel = (freq * f_vec) / max(w_bar, 1e-300)  # selection
                freq_mut = Q.T @ freq_sel                        # mutation
                counts = np.random.multinomial(N_pop, freq_mut)  # genetic drift
                freq = counts / N_pop
            fit_traj[step] = freq.dot(f_vec)
            a_curr = picker(freq)

        return fit_traj

    # Run replicates
    mdp_fit = np.zeros((n_reps, t_step + 1))
    for r in range(n_reps):
        mdp_fit[r] = run_wf(mdp_picker)

    # Statistics across replicates, each of shape (t_step + 1,)
    mean_fit = mdp_fit.mean(axis=0)
    std_fit = mdp_fit.std(axis=0)
    return t, mean_fit, std_fit

In [3]:
if __name__ == "__main__":
    # run_for_dt draws its Wright-Fisher samples from numpy's global RNG; fixing the
    # seed makes the saved results reproducible.
    SEED = 2024
    output_dir = "4g-results-dt"
    os.makedirs(output_dir, exist_ok=True)
    dt_values = [1, 2, 5] + list(range(10, 101, 10))
    for dt in dt_values:
        print(f"Running for dt={dt}...")
        # Re-seed per dt so any single dt can be regenerated on its own.
        np.random.seed(SEED)
        t, mean_fit, std_fit = run_for_dt(dt, L=20)
        out_file = os.path.join(output_dir, f"4g_mean_std_dt{dt}.npz")
        np.savez_compressed(out_file, t=t, mean_fit=mean_fit, std_fit=std_fit)
        print(f"Saved {out_file}")

Running for dt=1...
Saved 4g-results-dt/4g_mean_std_dt1.npz
Running for dt=2...
Saved 4g-results-dt/4g_mean_std_dt2.npz
Running for dt=5...
Saved 4g-results-dt/4g_mean_std_dt5.npz
Running for dt=10...
Saved 4g-results-dt/4g_mean_std_dt10.npz
Running for dt=20...
Saved 4g-results-dt/4g_mean_std_dt20.npz
Running for dt=30...
Saved 4g-results-dt/4g_mean_std_dt30.npz
Running for dt=40...
Saved 4g-results-dt/4g_mean_std_dt40.npz
Running for dt=50...
Saved 4g-results-dt/4g_mean_std_dt50.npz
Running for dt=60...
Saved 4g-results-dt/4g_mean_std_dt60.npz
Running for dt=70...
Saved 4g-results-dt/4g_mean_std_dt70.npz
Running for dt=80...
Saved 4g-results-dt/4g_mean_std_dt80.npz
Running for dt=90...
Saved 4g-results-dt/4g_mean_std_dt90.npz
Running for dt=100...
Saved 4g-results-dt/4g_mean_std_dt100.npz
