In [1]:
# exact same setup as before, except we don't need the mutations, just the tree itself
import msprime as msp
import demes
import demesdraw

Ne = 1e4
q_anc = 2.0
q = [1.0, 1.5]
m = [0.1]
tau = 4.0
demo = msp.Demography()
demo.add_population(initial_size= (Ne) / (q_anc), name = "anc")
demo.add_population(initial_size = (Ne) / (q[0]), name = "P0")
demo.add_population(initial_size = (Ne) / (q[1]), name = "P1")
demo.set_symmetric_migration_rate(populations=("P0", "P1"), rate= m[0] / (2*Ne))
tmp = [f"P{i}" for i in range(2)]
demo.add_population_split(time = tau * 2 * Ne, derived=tmp, ancestral="anc")
g = demo.to_demes()
# demesdraw.tubes(g)
# print(g)
sample_size = 10
samples = {f"P{i}": sample_size for i in range(2)}
anc = msp.sim_ancestry(samples=samples, demography=demo, recombination_rate=1e-8, sequence_length=1e8)
ts = msp.sim_mutations(anc, rate=1e-8)

In [6]:
import jax.numpy as jnp
import jax
import jax.experimental.sparse as jesp
import scipy.sparse as sp
from jax.experimental import sparse
from jax.experimental.sparse import BCOO
from jax import lax
import scipy
from jax import vmap
import phlash
from phlash.likelihood.arg import log_density


def _matvec(y, Q, Q1, q):
    n = q.shape[0]
    yn = y[:-1].reshape(n, n)
    yd = jnp.diag(yn)
    ret = Q @ yn + yn @ Q
    ret -= Q1 * yn
    ret -= jnp.diag(q * yd)
    return jnp.append(ret.reshape(n ** 2), q.dot(yd))

def matvec(y, Q, Q1, q):
    return _matvec(jnp.array(y), Q, Q1, q)

def P(v, Q, Q1, q, Lambda):
    return v + matvec(v, Q, Q1, q)/Lambda

def expm_mult_unif(v0, Q, Q1, q, t):
    # Lambda = 2*(Q.sum(1).todense().max() + q.max()) 
    # two people can migrate with 4 edges with max migration 0.1, hence first term and 2 * 1 for max coalescent
    # rate of two people on a graph.
    Lambda = 2*(0.1 * 6) + 2
    # don't need jax.scipy because none of the arguments are related to jax
    N = scipy.stats.poisson.isf(1e-5, 30 * Lambda)
    # N = 200
    w = v0
    ret = 0 * v0
    def f(accum, i):
        w, ret = accum
        ret += jax.scipy.stats.poisson.pmf(i, t * Lambda) * w
        w = P(w, Q, Q1, q, Lambda)
        # This needs to return two arguments so we make the second one to be None
        return (w, ret), None

    (w, ret), _ = jax.lax.scan(f, (w, ret), jnp.arange(1+N))

    return ret[-1]

def unit_vector(size, index):
    # Vector of zeros
    tmp = jnp.zeros(size)
    # ith element will be 1
    # JAX arrays are immutable so we cannot do tmp[index] = 1
    tmp = tmp.at[index].set(1)
    return tmp

# Instead using a coalesce rate c like migration_process jupyter notebook we use q
def solve_ode(time_discretization, init_vertices, q, m, BCOO_indices, index, tau):
    # edges is 0 ONLY IF you have a graph with only a SINGLE node
    num_nodes = len(q)
    shape = (num_nodes, num_nodes)
    if len(m) == 0:
        # If the graph has absolutely no edges then the graph has 0 for every edge weight
        # in the way we code Q.
        copy = jesp.BCOO(((jnp.array([], dtype=jnp.float32), jnp.empty((0, len(shape)), dtype=jnp.int32))), shape=shape)
    else:
        copy = jesp.BCOO((m.astype(float), BCOO_indices.astype(jnp.int64)), shape=shape)
        
    Q = copy + copy.transpose()
    Q1 = Q.sum(0).todense()[:, None] + Q.sum(1).todense()[None, :]    
    e_i = unit_vector(num_nodes, init_vertices[0])
    e_j = unit_vector(num_nodes, init_vertices[1])
    tmp = (e_i.reshape(-1, 1)) @ jnp.array([e_j])
    y0 = jnp.append(jnp.ravel(tmp), 0)
    sol = vmap(expm_mult_unif, in_axes = (None, None, None, None, 0))(y0, Q, Q1, q, time_discretization)
    sol = jnp.insert(sol, 0, 0.0)
    probabilities = jnp.diff(sol)
    prob_not_coal = 1 - sol[index]
    return probabilities, sol, prob_not_coal

def from_pmf(t, p):
    """Initialize a size history from a distribution function of coalescent times.

    Args:
        t: time points
        p: p[i] = probability of coalescing in [t[i], t[i + 1])
    """
    sum_initial = 0.0
    c = []
    p = jnp.clip(p, 0, jnp.inf)
    tol = 1e-4
    # difference in times
    dts = jnp.diff(t)
    p_truncated = p[:-1]

    # scan function
    def scan_fn(carry, inputs):
        sum_i = carry
        dt, p_i = inputs
        x = p_i / (1 - sum_i)
        bad = (x >= 1) | (x <= 0) | jnp.isnan(x) | jnp.isinf(x)
        x_safe = jnp.where(bad, 0., x)
        c_safe = jnp.where(bad, tol, -jnp.log1p(-x_safe) / dt)
        c_final = jnp.where(c_safe > tol, c_safe, tol)
        sum_i = sum_i + p_i
        return sum_i, c_final
    
    inputs = (dts, p_truncated)
    sum_final, c = jax.lax.scan(scan_fn, sum_initial, inputs)
    # Append the last coalescent rate (not identifiable from data)
    c = jnp.append(c, c[-1])
    return jnp.array(t), jnp.array(c)

def density(m, q, tau, q_anc, BCOO_indices, init_vertices, tmrca_span): 
    t = jnp.geomspace(1e-4, jnp.max(tmrca_span[:, 0]), 1000)
    # t = jnp.insert(jnp.geomspace(1e-4, 15.0, 1000), 0, 0.0)
    index = jnp.searchsorted(t, tau, side = "right")
    t_aug = jnp.insert(t, index, tau)
    index = index + 1
    probabilities, sol, p_not_coal = solve_ode(t_aug, init_vertices=init_vertices, q=q, m=m, BCOO_indices=BCOO_indices, index=index, tau=tau)
    # print(probabilities)
    t_aug = jnp.insert(t_aug, 0, 0.0) 
    rates = jnp.diff(t_aug) * (t_aug > tau)[1:] * (q_anc)
    # rates = jnp.diff(jnp.maximum(t_aug, tau)) * q_anc
    cum_rates = jnp.cumsum(rates)
    # cum_rates = (t_aug - tau).clip(0)[1:] * q_anc
    after_tau_surv_probs = jnp.exp(-cum_rates) * p_not_coal
    after_tau_epoch_probs = jnp.abs(jnp.diff(after_tau_surv_probs))

    indices = jnp.arange(len(t)+1)
    after_tau_epoch_probs = jnp.append(after_tau_epoch_probs, 0)
    # Define conditions
    conditions = [
        (indices == index-1),  
        (indices >= index)
    ]

    # Define choices corresponding to each condition
    choices = [
        probabilities[index-1]+after_tau_epoch_probs[index-1],
        after_tau_epoch_probs
    ]

    # Default case (when none of the above conditions are met, use 'a')
    result = jnp.select(conditions, choices, default=probabilities)
    final_probabilities = result.at[-1].set(1-jnp.sum(result))
    # print(final_probabilities)
    # vmap_surv_prob = 1 - np.insert(np.cumsum(final_probabilities), 0, 0)[:-1]
    # vmap_surv_prob

    t = jnp.insert(t, 0, 0.0)
    t, c = from_pmf(t, final_probabilities)
    # c = c / (2*Ne)
    # t = t * 2 * Ne

    # c = c / 2 # ASK JT!!!

    import phlash
    from phlash.likelihood.arg import log_density

    eta = phlash.size_history.SizeHistory(t=t, c=c)
    dm = phlash.size_history.DemographicModel(
        eta=eta, theta=None, rho=1e-8 * 1e4 * 4
    )
    # return -(jnp.sum(jnp.log(c_values) + log_p_values))  # Negative for minimization
    
    return log_density(dm, tmrca_span[None])# , vmap_surv_prob, c

In [7]:
import jax.numpy as jnp
import jax
def sample_tmrca_spanss(ts, subkey=jax.random.PRNGKey(1), num_pop=2):
    samples = jax.random.choice(subkey, ts.num_samples, shape=(2,), replace=False)
    sample1, sample2 = samples[0], samples[1]

    pop1 = ts.node(sample1.item(0)).population - 1
    pop2 = ts.node(sample2.item(0)).population - 1
    sample_config = (pop1, pop2)

    # Precompute all TMRCAs and spans into arrays
    tmrcas = []
    spans = []
    for tree in ts.trees():
        spans.append(tree.interval.right - tree.interval.left)
        tmrcas.append(tree.tmrca(sample1, sample2) / (2*1e4))
    
    # Convert to JAX arrays
    tmrcas = jnp.array(tmrcas)  # Shape: (num_trees,)
    spans = jnp.array(spans)    # Shape: (num_trees,)
    tmrcas_spans = jnp.stack([tmrcas, spans], axis=1)  # Shape: (num_trees, 2)

    # Merge consecutive spans with same TMRCA
    def merge_spans(carry, x):
        current_tmrca, current_span, idx, output = carry
        tmrca, span = x
        
        # Update each component individually
        new_tmrca = jnp.where(tmrca == current_tmrca, current_tmrca, tmrca)
        new_span = jnp.where(tmrca == current_tmrca, current_span + span, span)
        new_idx = jnp.where(tmrca == current_tmrca, idx, idx + 1)
        new_output = jnp.where(
            tmrca == current_tmrca, 
            output, 
            output.at[idx].set(jnp.array([current_tmrca, current_span]))
        )
        
        return (new_tmrca, new_span, new_idx, new_output), None

    init_carry = (tmrcas_spans[0, 0], 0.0, 0, jnp.full((ts.num_trees, 2), jnp.array([1.0, 0.0])))
    final_carry, _ = jax.lax.scan(merge_spans, init_carry, tmrcas_spans)
    final_tmrca, final_span, _, final_output = final_carry
    final_output = final_output.at[-1].set(jnp.array([final_tmrca, final_span]))
    is_ones = jnp.all(final_output == jnp.array([1.0, 0.0]), axis=1)
    reordered_arr = jnp.concatenate([final_output[~is_ones], final_output[is_ones]])

    return reordered_arr, sample_config


In [8]:
import networkx as nx
import numpy as np
import jax.experimental.sparse as jesp
import scipy.sparse as sp
graph = nx.Graph()
graph.add_edge(0, 1)
tmp = jesp.BCOO.from_scipy_sparse(sp.triu(nx.adjacency_matrix(graph, np.arange(2)).astype(float)))

In [9]:
import jax
import jax.numpy as jnp
import optax
from functools import partial

def transform_params(params):
    m_unconstrained, q_unconstrained, tau_unconstrained, q_anc_unconstrained = params
    # Sigmoid maps to (0, 1), then scale to (0, 2)
    m = jax.nn.softplus(m_unconstrained)
    q = jax.nn.softplus(q_unconstrained)
    tau = jax.nn.softplus(tau_unconstrained)
    q_anc = jax.nn.softplus(q_anc_unconstrained)
    return (m, q, tau, q_anc)

def inverse_transform_params(params):
    """Convert constrained params back to unconstrained via inverse softplus."""
    m, q, tau, q_anc = params
    m_unconstrained = jnp.log(jnp.exp(m) - 1)  # Inverse softplus
    q_unconstrained = jnp.log(jnp.exp(q) - 1)
    tau_unconstrained = jnp.log(jnp.exp(tau) - 1)
    q_anc_unconstrained = jnp.log(jnp.exp(q_anc) - 1)
    return (m_unconstrained, q_unconstrained, tau_unconstrained, q_anc_unconstrained)

def loss_fn(params, tmp_indices, sample_config, tmrca_span):
    m, q, tau, q_anc = params
    return -density(m, q, tau, q_anc, tmp_indices, sample_config, tmrca_span)

# 2. Initialize parameters
def init_params(m_shape, q_shape):
    # Initialize with reasonable values for your problem
    m = jnp.ones(m_shape) * 0.001  # Example initialization
    q = jnp.ones(q_shape) * 0.5
    tau = 1.0  # Example initial value
    q_anc = 0.5 # Example initial value

    # Convert to unconstrained space for optimization
    m_unconstrained = jnp.log(jnp.exp(m) - 1)  # Inverse softplus
    q_unconstrained = jnp.log(jnp.exp(q) - 1)
    tau_unconstrained = jnp.log(jnp.exp(tau) - 1)
    q_anc_unconstrained = jnp.log(jnp.exp(q_anc) - 1)
    
    return (m_unconstrained, q_unconstrained, tau_unconstrained, q_anc_unconstrained)
    # return (m, q, tau, q_anc)

# 4. Set up the optimization
def optimize(params, optimizer, tmp_indices, ts, num_steps=1000):
    # Initialize optimizer state
    opt_state = optimizer.init(params)
    
    # Initialize random key
    key = jax.random.PRNGKey(0)
    
    # Define the update step (now without ts)
    @jax.jit
    def step(params, opt_state, tmp_indices, sample_config, tmrca_span):
        # Compute loss and gradients
        params = transform_params(params)
        jax.debug.print("params:{}", params, ordered=True)
        loss, grads = jax.value_and_grad(loss_fn)(params, tmp_indices, sample_config, tmrca_span)

        # Update parameters
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        
        return params, opt_state, loss, grads
    
    # Training loop
    losses = []
    for step_num in range(num_steps):
        # Generate new samples outside the JIT-compiled step
        key, subkey = jax.random.split(key)
        tmrca_span, sample_config = sample_tmrca_spanss(ts, subkey)
        
        # Perform optimization step
        params, opt_state, loss, grads = step(params, opt_state, tmp_indices, sample_config, tmrca_span)
        print(inverse_transform_params(params))
        # print(loss)
        # print(grads)
        losses.append(loss)
        
        if step_num % 100 == 0:  # Print progress every 100 steps
            print(f"Step {step_num}, Loss: {loss}")
    
    return params, losses

# Example usage
# Initialize parameters
params = init_params(m_shape=jnp.array(m).shape, q_shape=jnp.array(q).shape)  # Adjust shapes as needed

optimizer = optax.chain(optax.clip_by_global_norm(1.0),
                        optax.adam(learning_rate=0.1)
)

with jax.debug_nans(False), jax.disable_jit(False):
    optimized_params_unconstrained, losses = optimize(
        params, optimizer, tmp.indices, ts, num_steps=1000
    )

# Convert back to constrained space for final results
optimized_params = transform_params(optimized_params_unconstrained)

print("\nOptimized parameters:")
print(f"m: {optimized_params[0]}")
print(f"q: {optimized_params[1]}")
print(f"tau: {optimized_params[2]}")
print(f"q_anc: {optimized_params[3]}")

params:(Array([0.001], dtype=float64), Array([0.5, 0.5], dtype=float64), Array(1., dtype=float64), Array(0.5, dtype=float64))
(Array([-2.24170977], dtype=float64), Array([-0.19587302, -0.19587302], dtype=float64), Array(0.37816477, dtype=float64), Array(-0.19587078, dtype=float64))
Step 0, Loss: 463768.7141677801
params:(Array([0.74492176], dtype=float64), Array([1.03748718, 1.03748718], dtype=float64), Array(1.24115397, dtype=float64), Array(1.03748783, dtype=float64))
(Array([0.20889823], dtype=float64), Array([0.71323129, 0.71369017], dtype=float64), Array(0.79255207, dtype=float64), Array(0.48321442, dtype=float64))
params:(Array([1.17320002], dtype=float64), Array([1.39638684, 1.39661861], dtype=float64), Array(1.43723147, dtype=float64), Array(1.28682687, dtype=float64))
(Array([0.86696662], dtype=float64), Array([1.21860492, 1.22507474], dtype=float64), Array(1.09970292, dtype=float64), Array(0.84488416, dtype=float64))
params:(Array([1.47697599], dtype=float64), Array([1.683146

KeyboardInterrupt: 