In [1]:
import numpy as np
from scipy.sparse import dok_matrix
from math import log2
from scipy.linalg import expm
import mdptoolbox.mdp as mdp
import os

In [2]:
# ------------------ Global simulation params ------------------
N_pop         = 10000     # Wright-Fisher population size (individuals resampled each generation)
mutation_rate = 0.001     # per Hamming-1 edge (discrete-time mutation prob)
tau           = 1000      # total horizon; run_wf advances one generation per step
dt            = 1         # CTMC step used for W = expm(Omega * dt)
# round() rather than int(): e.g. int(0.3 / 0.1) == 2 because 0.3 / 0.1 == 2.9999999999999996
t_step        = round(tau / dt)
t             = np.linspace(0, tau, t_step+1)

# The WF simulator draws from the legacy global np.random stream; seed it so that
# Restart & Run All reproduces the saved trajectories.
SEED = 42
np.random.seed(SEED)

num_genotypes = 8
assert (log2(num_genotypes)).is_integer(), "num_genotypes must be a power of two"
B = int(log2(num_genotypes))    # number of bits
d = num_genotypes - 1           # dimension of u-space

# ------------------ Actions / fitness (8 genotypes) ------------------
# Action k applies drug drug_names[k]; drug_lists[k][i] is the fitness of genotype i
# under that drug (consumed by get_f below).
drug_names = ['A', 'B', 'C', 'D']
drug_lists = dict(enumerate([
    [0.993, 0.998, 1.009, 1.003, 1.007, 1.001, 0.992, 0.997],  # Drug A
    [0.995, 1.005, 1.002, 0.999, 1.005, 0.994, 0.999, 1.001],  # Drug B
    [0.997, 1.001, 0.989, 1.003, 1.003, 0.998, 1.010, 0.997],  # Drug C
    [1.005, 0.988, 0.999, 1.001, 0.995, 1.011, 1.000, 0.999],  # Drug D
]))
actions = list(range(len(drug_lists)))

def get_f(action):
    """Fitness vector for drug `action` as a float array (entry i = genotype i).

    Raises AssertionError if the table row does not have num_genotypes entries.
    """
    fitness = np.asarray(drug_lists[action], dtype=float)
    assert len(fitness) == num_genotypes
    return fitness

def get_s_vec(action):
    """Selection coefficients relative to genotype 0 under drug `action`.

    Returns s with s[i-1] = f_i / f_0 - 1 for genotypes i = 1..M-1 (length M-1).
    """
    fitness = get_f(action)
    reference = fitness[0]
    return fitness[1:] / reference - 1.0

def genotype_bitstrings(M):
    """Binary labels for genotypes 0..M-1, zero-padded to int(log2(M)) characters."""
    width = int(log2(M))
    labels = []
    for code in range(M):
        labels.append(f'{code:0{width}b}')
    return labels

# Binary label for every genotype index ('000' ... '111' when num_genotypes == 8).
genotypes = genotype_bitstrings(num_genotypes)

# ------------- Discrete-time Hamming-1 mutation matrix (row-stochastic) -------------
def build_hamming_U(M, mu):
    """Row-stochastic one-generation mutation matrix on M genotypes.

    Genotype i mutates to genotype j with probability `mu` when the integers i and j
    differ in exactly one bit (Hamming distance 1). The diagonal holds the probability
    of not mutating, so every row sums to 1. The bit test works on the integer codes
    directly, so it stays correct even when M is not a power of two.

    Parameters
    ----------
    M : int
        Number of genotypes.
    mu : float
        Mutation probability per Hamming-1 edge.

    Returns
    -------
    np.ndarray, shape (M, M)
        U[i, j] = probability that genotype i becomes genotype j in one generation.

    Raises
    ------
    ValueError
        If mu is negative, or if some row's total mutation probability exceeds 1
        (the diagonal would be negative and U would not be stochastic).
    """
    codes = np.arange(M)
    diff = codes[:, None] ^ codes[None, :]
    # diff is a nonzero power of two  <=>  i and j differ in exactly one bit
    one_bit = (diff != 0) & ((diff & (diff - 1)) == 0)
    U = np.where(one_bit, float(mu), 0.0)
    off_diag = U.sum(axis=1)
    if mu < 0 or np.any(off_diag > 1.0):
        raise ValueError(
            f"mu={mu} gives an invalid mutation matrix for M={M} "
            f"(max off-diagonal row mass {off_diag.max() if M else 0.0})"
        )
    np.fill_diagonal(U, 1.0 - off_diag)
    return U

U_ham = build_hamming_U(num_genotypes, mutation_rate)
# Expected one-generation change in x from mutation alone. With U row-stochastic
# (U[i, j] = P(i -> j)), x' = U^T x, so E[dx] = (U^T - I) x. This matches run_wf's
# `U_ham.T @ freq_sel`. For the uniform Hamming-1 matrix U_ham == U_ham.T, so the
# values are unchanged, but the transpose keeps the drift correct if U_ham is ever
# made non-symmetric.
Mmat  = U_ham.T - np.eye(num_genotypes)

# ------------------ Stick-breaking helpers (general M) ------------------
def x_from_u(u):
    """Stick-breaking map: coordinates u (len M-1) -> frequencies x (len M).

    x[k] = u[k-1] * prod_{j<k-1} (1 - u[j]) for k = 1..M-1. Genotype 0 takes
    whatever stick remains: x[0] = prod_j (1 - u[j]).
    """
    u = np.asarray(u, float)
    # remaining[k] = length of stick left before genotype k+1 takes its share
    remaining = np.concatenate(([1.0], np.cumprod(1.0 - u)))
    x = np.empty(u.size + 1, float)
    x[1:] = u * remaining[:-1]
    x[0] = remaining[-1]
    return x

def S_prefix(u):
    """Remaining-stick lengths S_k = prod_{j<k} (1 - u_j) for k = 0..len(u)-1.

    S_0 = 1. The result has the same length as u (empty for empty u).
    """
    u = np.asarray(u, float)
    leading_one = np.concatenate(([1.0], np.cumprod(1.0 - u)))
    return leading_one[:u.size]

def pack_from_x(x):
    """Inverse stick-breaking map: frequencies x (len M) -> coordinates u (len M-1).

    u[k-1] = x[k] / (1 - sum_{j=1}^{k-1} x[j]) for k = 1..M-1: the share of the stick
    left after genotypes 1..k-1 that genotype k takes. If no stick remains
    (denominator <= 0), the coordinate is set to 0.

    The result is clipped to [0, 1]. Frequencies built from integer counts
    (counts / N) can carry round-off that would otherwise push a coordinate slightly
    outside the unit interval, where x_from_u would produce negative frequencies.
    """
    x = np.asarray(x, float)
    M = x.size
    u = np.zeros(M-1, float)
    cum = 0.0
    for i in range(1, M):
        denom = 1.0 - cum
        u[i-1] = 0.0 if denom <= 0 else x[i] / denom
        cum += x[i]
    return np.clip(u, 0.0, 1.0)

# ------------------ Drift and diffusion in u-space ------------------
def drift_u_mut(u):
    """Mutation part of the u-space drift A^(mut)(u).

    Pushes the x-space mutation drift dx = Mmat @ x(u) through the stick-breaking
    chain rule. For u-index k (x-index k+1):

        A_k = (dx[k+1] + u_k * (dx[1] + ... + dx[k])) / S_k,

    where S_k = prod_{j<k} (1 - u_j) comes from S_prefix.
    """
    u = np.asarray(u, float)
    dx = Mmat @ x_from_u(u)
    # before[k] = dx[1] + ... + dx[k]; before[0] = 0
    before = np.concatenate([[0.0], np.cumsum(dx[1:])])
    return (dx[1:] + u * before[:-1]) / S_prefix(u)

def drift_u_sel(u, action):
    """Selection part of the u-space drift A^(sel)(u) under drug `action`.

    For u-index k (x-index k+1):

        A_k = (u_k / S_k) * (S_k * s_k - T_k),   T_k = sum_{j>=k} x[j+1] * s_j,

    with s from get_s_vec (fitness relative to genotype 0) and S_k from S_prefix.
    This is the chain-rule image of the x-space drift dx_i = x_i (s_i - sbar),
    where sbar = sum_j x_j s_j, i.e. the selection change to first order in s.
    """
    u = np.asarray(u, float)
    s = get_s_vec(action)
    x = x_from_u(u)
    S = S_prefix(u)
    weighted = x[1:] * s
    # reverse running sum: tail[k] = weighted[k] + weighted[k+1] + ... + weighted[-1]
    tail = np.cumsum(weighted[::-1])[::-1]
    return (u / S) * (S * s - tail)

def drift_u(u, action):
    """Total u-space drift under drug `action`: mutation drift plus selection drift."""
    mutation_part = drift_u_mut(u)
    selection_part = drift_u_sel(u, action)
    return mutation_part + selection_part

def diffusion_u(u):
    """Diagonal diffusion coefficients D_k(u) = u_k (1 - u_k) / (2 N_pop S_k).

    Only the diagonal is returned; cross-diffusion between u-axes is not modeled.
    """
    u = np.asarray(u, float)
    binomial_term = u * (1.0 - u)
    return binomial_term / (2.0 * N_pop * S_prefix(u))

# ------------------ L^d grid helpers ------------------
def index_to_multi(n, L, d):
    """Decode flat index n into d base-L digits, least-significant digit first.

    digits[r] == (n // L**r) % L; inverse of multi_to_index.
    """
    digits = np.empty(d, int)
    remaining = n
    for r in range(d):
        remaining, digits[r] = divmod(remaining, L)
    return digits

def multi_to_index(idx, L):
    """Encode base-L digits (least-significant first) into a flat index."""
    return sum((L**r) * digit for r, digit in enumerate(idx))

def build_states(L):
    """Cell centers of a uniform L^d grid on [0, 1]^d.

    Row n of the returned array is the center of the cell whose per-axis digits are
    index_to_multi(n, L, d). Also returns the cell width a = 1/L.
    """
    a = 1.0 / L
    centers = a * (0.5 + np.arange(L))
    points = [centers[index_to_multi(n, L, d)] for n in range(L**d)]
    return np.asarray(points), a

# ------------------ CTMC Ω and transition W for one action ------------------
def _neighbor_rates(D_r, A_r, a):
    """Nonnegative (forward, backward) jump rates along one grid axis.

    Central differencing (D/a^2 +- A/(2a)) is kept whenever both rates are
    nonnegative. If |A| > 2D/a, central differencing yields a negative rate, which
    is not a valid CTMC generator. Then fall back to first-order upwinding:
    D/a^2 + max(+-A, 0)/a. Both choices give mean displacement rate A; upwinding
    adds numerical diffusion |A| a / 2.
    """
    base = D_r / a**2
    fwd = base + A_r / (2 * a)
    bwd = base - A_r / (2 * a)
    if fwd < 0 or bwd < 0:
        fwd = base + max(A_r, 0.0) / a
        bwd = base + max(-A_r, 0.0) / a
    return fwd, bwd


def build_transition_rate_matrix(action, states, L, a):
    """Assemble the CTMC generator Omega on the L^d u-grid for one action.

    Each cell jumps to its +-1 neighbours along every axis, with rates from the
    drift drift_u and the diagonal diffusion diffusion_u. Jumps that would leave the
    grid are omitted. The diagonal is minus the row's total outflow, so every row
    sums to zero and all off-diagonal entries are nonnegative.

    Parameters
    ----------
    action : key into drug_lists.
    states : sequence of grid-cell centers (as returned by build_states).
    L : int, cells per axis.
    a : float, cell width.

    Returns
    -------
    scipy.sparse.csr_matrix of shape (len(states), len(states)).
    """
    N_states = len(states)
    Omega = dok_matrix((N_states, N_states), dtype=np.float64)
    strides = [L**r for r in range(d)]

    for n in range(N_states):
        u = states[n]                  # vector length d
        A = drift_u(u, action)         # length d
        D = diffusion_u(u)             # length d
        idx = index_to_multi(n, L, d)

        outflow = 0.0
        for r in range(d):
            fwd, bwd = _neighbor_rates(D[r], A[r], a)
            if idx[r] < L-1:  # forward neighbor
                Omega[n, n + strides[r]] = fwd
                outflow += fwd
            if idx[r] > 0:    # backward neighbor
                Omega[n, n - strides[r]] = bwd
                outflow += bwd

        # diagonal: rows sum to zero (tracked directly instead of re-summing the dok row)
        Omega[n, n] = -outflow
    return Omega.tocsr()

def build_transition_matrix(Omega, dt):
    """One-step transition matrix W = expm(Omega * dt), cleaned to be row-stochastic.

    The matrix exponential is taken on the dense generator. Any negative round-off
    entries are zeroed, then each row is renormalized to sum to one.
    """
    scaled_generator = Omega.toarray() * dt
    W = expm(scaled_generator)
    W = np.where(W < 0, 0.0, W)
    row_totals = W.sum(axis=1, keepdims=True)
    return W / row_totals

# ------------------ Reward (-mean fitness) ------------------
def compute_reward(u_state, action):
    """Reward of grid state u_state under drug `action`: the negated mean fitness x(u) . f."""
    freqs = x_from_u(u_state)
    mean_fitness = freqs.dot(get_f(action))
    return -mean_fitness

def build_W_and_R(actions, states, L, a, dt):
    """Per-action transition matrices and the state-by-action reward table.

    Parameters
    ----------
    actions : sequence of keys into drug_lists.
    states : sequence of grid-cell centers (from build_states).
    L, a : grid resolution and cell width.
    dt : time step for W = expm(Omega * dt).

    Returns
    -------
    P : list of np.ndarray
        P[j] is the (N_states, N_states) row-stochastic matrix for actions[j].
    R : np.ndarray, shape (N_states, len(actions))
        R[i, j] = compute_reward(states[i], actions[j]).

    Columns of R and entries of P are indexed by position j in `actions`, not by the
    action value, so `actions` need not be exactly [0, 1, ..., len-1].
    NOTE(review): an MDP policy computed from (P, R) therefore yields positions into
    `actions`; callers that pass the policy entry straight to get_f rely on
    actions[j] == j.
    """
    N_states = len(states)
    P = []
    R = np.zeros((N_states, len(actions)), float)
    for j, act in enumerate(actions):
        Omega = build_transition_rate_matrix(act, states, L, a)
        P.append(build_transition_matrix(Omega, dt))
        R[:, j] = [compute_reward(us, act) for us in states]
    return P, R

# ------------------ Policy mapping and WF simulator ------------------
def freq_to_state_idx(freq, L):
    """Flat index of the L^d grid cell that contains pack_from_x(freq)."""
    coords = pack_from_x(freq)                     # length d
    cell = (coords * L).astype(int)
    cell = np.minimum(cell, L - 1)                 # u == 1 falls in the last cell
    return multi_to_index(cell, L)

def make_mdp_picker(policy_flat, L):
    """Wrap a flat per-state policy as a callable freq -> action (int).

    The returned function locates the grid cell of `freq` with freq_to_state_idx
    and returns policy_flat at that cell.
    """
    def picker(freq):
        state = freq_to_state_idx(freq, L)
        return int(policy_flat[state])
    return picker

def run_wf(picker):
    """Simulate one Wright-Fisher trajectory under a feedback drug policy.

    Starts with all N_pop individuals carrying genotype 0. Each generation:
    1. pick a drug from the current frequencies;
    2. apply selection;
    3. apply Hamming-1 mutation (U_ham.T);
    4. resample N_pop individuals multinomially.

    Returns fit_traj of length t_step + 1:
    - fit_traj[0] is the initial mean fitness under the drug picked for it.
    - fit_traj[g] is the mean fitness after generation g, scored with the drug
      applied in generation g.
    """
    initial_counts = np.zeros(num_genotypes, int)
    initial_counts[0] = N_pop
    freq = initial_counts / N_pop

    fit_traj = np.zeros(t_step + 1)
    fit_traj[0] = freq.dot(get_f(picker(freq)))

    for gen in range(1, t_step + 1):
        fitness = get_f(picker(freq))
        after_selection = (freq * fitness) / freq.dot(fitness)
        after_mutation = U_ham.T @ after_selection
        freq = np.random.multinomial(N_pop, after_mutation) / N_pop   # genetic drift
        fit_traj[gen] = freq.dot(fitness)

    return fit_traj

# ------------------ Run for a single L (MDP + WF reps) ------------------
def run_for_L(L, discount=0.99, epsilon=1e-4, max_iter=1000, n_reps=10000):
    """Solve the discretized MDP at grid resolution L and evaluate its policy.

    Builds the L^d grid and per-action transition/reward tables, and solves them
    with mdptoolbox ValueIteration. It then runs n_reps Wright-Fisher replicates
    under the resulting policy.

    Returns (t, mean_fit, std_fit):
    - t is the global time grid;
    - mean_fit is the across-replicate mean of the mean-fitness trajectories;
    - std_fit is their standard deviation (numpy default, ddof=0).
    """
    states, a = build_states(L)
    P, R = build_W_and_R(actions, states, L, a, dt)

    solver = mdp.ValueIteration(transitions=P, reward=R, discount=discount,
                                epsilon=epsilon, max_iter=max_iter)
    solver.run()
    picker = make_mdp_picker(np.array(solver.policy, int), L)

    trajectories = np.empty((n_reps, t_step + 1))
    for rep in range(n_reps):
        trajectories[rep] = run_wf(picker)

    return t, trajectories.mean(axis=0), trajectories.std(axis=0)




In [None]:
# ------------------ Main: loop over multiple L and save ------------------
# For each grid resolution L: solve the MDP, simulate replicates under its policy,
# and save t / mean_fit / std_fit to <output_dir>/8g_mean_std_L{L}.npz.
# NOTE(review): the recorded cell output shows files under "8g-results/", but
# output_dir is "8g-results-L" — presumably the output is from an earlier run;
# re-run the cell to refresh it.
if __name__ == "__main__":
    output_dir = "8g-results-L"
    os.makedirs(output_dir, exist_ok=True)

    # Cells per u-axis; the MDP has L**d states, so cost grows quickly with L.
    L_values = [2, 3]

    for L in L_values:
        print(f"Running for L = {L} (states = {L**d}) ...")
        t_arr, mean_fit, std_fit = run_for_L(L)
        out_file = os.path.join(output_dir, f"8g_mean_std_L{L}.npz")
        np.savez_compressed(out_file, t=t_arr, mean_fit=mean_fit, std_fit=std_fit)
        print(f"Saved results to {out_file}")

Running for L = 2 (states = 128) ...
Saved results to 8g-results/8g_mean_std_L2.npz
Running for L = 3 (states = 2187) ...
Saved results to 8g-results/8g_mean_std_L3.npz
