In [1]:
# exact same setup as before, except we don't need the mutations, just the tree itself
import msprime as msp
import demes
import demesdraw

Ne_anc = 1e4
Ne = [1e4, 1e4]
m = [0.0001]
tau = 1000.0
demo = msp.Demography()
demo.add_population(initial_size= Ne_anc, name = "anc")
demo.add_population(initial_size = Ne[0], name = "P0")
demo.add_population(initial_size = Ne[1], name = "P1")
demo.set_symmetric_migration_rate(populations=("P0", "P1"), rate= m[0])
tmp = [f"P{i}" for i in range(2)]
demo.add_population_split(time = tau, derived=tmp, ancestral="anc")
g = demo.to_demes()
# demesdraw.tubes(g)
# print(g)
sample_size = 10
samples = {f"P{i}": sample_size for i in range(2)}
anc = msp.sim_ancestry(samples=samples, demography=demo, recombination_rate=1e-8, sequence_length=1e7)
ts = msp.sim_mutations(anc, rate=1e-8)

In [2]:
"Matrix exponential using uniformization method"

from beartype import beartype
import jax.numpy as jnp
from jax.scipy.stats import poisson
import equinox as eqx
from beartype.typing import Callable
from jaxtyping import Float, Array, ArrayLike, ScalarLike


@beartype
def expm_unif(
    t: ScalarLike,
    A: Callable[[Float[ArrayLike, "*shape"]], Float[ArrayLike, "*shape"]],
    v: Float[ArrayLike, "*shape"],
    alpha: ScalarLike,
    eps: float = 1e-6,
) -> Float[Array, "*shape"]:
    """Compute matrix exponential via uniformization.

    Args:
        t: Time scalar.
        A: Function that computes action A @ v, where A is an intensity matrix.
        alpha: max(abs(diag(A)))
        eps: Tolerance for the uniformization method.

    Returns:
        expm(t*A) @ v: Matrix exponential applied to vector v.

    Notes:
        This implements Algorithm 1 of "INEXACT UNIFORMIZATION METHOD FOR COMPUTING
        TRANSIENT DISTRIBUTIONS OF MARKOV CHAINS", Sidje et al., SIAM J. Sci. Comput.,
        2007.
    """

    def P(v):
        return v + A(v) / alpha

    THETA = 100.0
    m = jnp.ceil(alpha * t / THETA).astype(int)
    t = t / m
    s = alpha * t
    r = jnp.exp(-s)

    ell = eqx.internal.while_loop(
        lambda ell: poisson.cdf(ell, alpha * t) < 1 - eps,
        lambda ell: ell + 1,
        0,
        kind="checkpointed",
        checkpoints=10,
    )

    def body2(tup):
        k, w, f = tup
        f = (s / k) * P(f)
        w = w + f
        return (k + 1, w, f)

    def body1(tup):
        i, w = tup
        _, w, _ = eqx.internal.while_loop(
            lambda tup: tup[0] <= ell,
            body2,
            (1, w, w),
            kind="checkpointed",
            checkpoints=10,
        )
        return (i + 1, w * r)

    _, w = eqx.internal.while_loop(
        lambda tup: tup[0] <= m,
        body1,
        (1, v),
        kind="checkpointed",
        checkpoints=10,
    )

    return w

from abc import ABC, abstractmethod, abstractproperty
from dataclasses import dataclass

import interpax
import jax.numpy as jnp
from jaxtyping import Array, ArrayLike, Float, Scalar, ScalarLike


@dataclass(kw_only=True)
class CoalRate(ABC):
    @abstractproperty
    def jumps(self) -> Float[Array, "_"]:
        """Return the times at which the coalescent rate changes discontinuously."""
        pass

    @abstractmethod
    def __call__(self, t: ScalarLike) -> Scalar:
        """Evaluate the coalescent rate at time t."""
        pass

    def R(self, a: ScalarLike, b: ScalarLike) -> Scalar:
        r"""Integrated coalescent rate,

        int_a^b self(t) dt
        """
        pass


@dataclass(kw_only=True)
class PiecewiseConstant(CoalRate):
    """Piecewise constant coalescent rate."""

    c: Float[ArrayLike, "T"]
    t: Float[ArrayLike, "T"]

    @property
    def jumps(self) -> Float[Array, "T"]:
        """Return the times at which the coalescent rate changes discontinuously."""
        return jnp.array(self.t)

    @property
    def _ppoly(self) -> interpax.PPoly:
        return interpax.PPoly(self.c[None], jnp.append(self.t, jnp.inf), check=False)

    def __call__(self, t: ScalarLike) -> Scalar:
        return self._ppoly(t)

    def R(self, a: ScalarLike, b: ScalarLike) -> Scalar:
        return self._ppoly.integrate(a, b)

"Likelihood of an ARG"

import diffrax as dfx
import jax
import jax.numpy as jnp
from jax import vmap
from jax.scipy.special import xlog1py, xlogy
from jaxtyping import Array, Float, Scalar, ScalarLike

def loglik(eta: CoalRate, r: ScalarLike, data: Float[Array, "intervals 2"]) -> Scalar:
    """Compute the log-likelihood of the data given the demographic model.

    Args:
        eta: Coalescent rate at time t.
        r: float, the recombination rate.
        data: the data to compute the likelihood for. The first column is the TMRCA, and
              the second column is the span.

    Notes:
        - Successive spans that have the same TMRCA should be merged into one span:
          <tmrca, span1> + <tmrca, span1> = <tmrca, span + span>.
        - Missing data/padding indicated by span<=0.
    """
    times, spans = data.T
    i = times.argsort()
    sorted_times = times[i]

    def f(t, y, _):
        c = eta(t)
        A = jnp.array([[-r, r, 0.0], [c, -2 * c, c], [0.0, 0.0, 0.0]])
        return A.T @ y

    y0 = jnp.array([1.0, 0.0, 0.0])
    solver = dfx.Tsit5()
    term = dfx.ODETerm(f)
    ssc = dfx.PIDController(rtol=1e-6, atol=1e-6, jump_ts=eta.jumps)
    T = times.max()
    sol = dfx.diffeqsolve(
        term,
        solver,
        0.0,
        T,
        dt0=0.001,
        y0=y0,
        stepsize_controller=ssc,
        saveat=dfx.SaveAt(ts=sorted_times),
    )

    # invert the sorting so that cscs matches times
    i_inv = i.argsort()
    cscs = sol.ys[i_inv]

    @vmap
    def p(t0, csc0, t1, csc1, span):
        p_nr_t0, p_float_t0, p_coal_t0 = csc0
        p_nr_t1, p_float_t1, p_coal_t1 = csc1
        # no recomb for first span - 1 positions
        r1 = xlogy(span - 1, p_nr_t0)
        # coalescence at t1
        r2 = jnp.log(eta(t1))
        # back-coalescence process up to t1, depends to t0 >< t1
        r3 = jnp.where(
            t0 < t1, jnp.log(p_float_t0) - eta.R(t0, t1), jnp.log(p_float_t1)
        )
        return r1 + r2 + r3

    ll = p(times[:-1], cscs[:-1], times[1:], cscs[1:], spans[:-1]).sum()
    # for the last position, we only know span was at least as long
    ll += xlogy(spans[-1], cscs[-1, 0])
    return ll

In [3]:
import jax.numpy as jnp
import jax
import jax.experimental.sparse as jesp
import scipy.sparse as sp
from jax.experimental import sparse
from jax.experimental.sparse import BCOO
from jax import lax
import scipy
from jax import vmap
import phlash
from phlash.likelihood.arg import log_density


def _matvec(params, y):
    Q, Q1, q = params
    n = q.shape[0]
    yn = y[:-1].reshape(n, n)
    yd = jnp.diag(yn)
    ret = Q @ yn + yn @ Q
    ret -= Q1 * yn
    ret -= jnp.diag(q * yd)
    return jnp.append(ret.reshape(n ** 2), q.dot(yd))

def unit_vector(size, index):
    # Vector of zeros
    tmp = jnp.zeros(size)
    # ith element will be 1
    # JAX arrays are immutable so we cannot do tmp[index] = 1
    tmp = tmp.at[index].set(1)
    return tmp

# Instead using a coalesce rate c like migration_process jupyter notebook we use q
def solve_ode(time_discretization, init_vertices, q, m, BCOO_indices, index, tau):
    # edges is 0 ONLY IF you have a graph with only a SINGLE node
    num_nodes = len(q)
    shape = (num_nodes, num_nodes)
    if len(m) == 0:
        # If the graph has absolutely no edges then the graph has 0 for every edge weight
        # in the way we code Q.
        copy = jesp.BCOO(((jnp.array([], dtype=jnp.float32), jnp.empty((0, len(shape)), dtype=jnp.int32))), shape=shape)
    else:
        copy = jesp.BCOO((m.astype(float), BCOO_indices.astype(jnp.int64)), shape=shape)
        
    Q = copy + copy.transpose()
    Q1 = Q.sum(0).todense()[:, None] + Q.sum(1).todense()[None, :]
    e_i = unit_vector(num_nodes, init_vertices[0])
    e_j = unit_vector(num_nodes, init_vertices[1])
    tmp = (e_i.reshape(-1, 1)) @ jnp.array([e_j])
    y0 = jnp.append(jnp.ravel(tmp), 0)
    from functools import partial
    A = partial(_matvec, (Q, Q1, q))
    sol = vmap(expm_unif, in_axes = (0, None, None, None))(time_discretization, A, y0, 2*jnp.max(m) + jnp.max(q))
    sol = sol[:, -1]
    print(sol)
    sol = jnp.insert(sol, 0, 0.0)
    probabilities = jnp.diff(sol)
    prob_not_coal = 1 - sol[index]
    return probabilities, sol, prob_not_coal

def from_pmf(t, p):
    """Initialize a size history from a distribution function of coalescent times.

    Args:
        t: time points
        p: p[i] = probability of coalescing in [t[i], t[i + 1])
    """
    sum_initial = 0.0
    c = []
    p = jnp.clip(p, 0, jnp.inf)
    tol = 1e-4
    # difference in times
    dts = jnp.diff(t)
    p_truncated = p[:-1]

    # scan function
    def scan_fn(carry, inputs):
        sum_i = carry
        dt, p_i = inputs
        x = p_i / (1 - sum_i)
        bad = (x >= 1) | (x <= 0) | jnp.isnan(x) | jnp.isinf(x)
        x_safe = jnp.where(bad, 0., x)
        c_safe = jnp.where(bad, tol, -jnp.log1p(-x_safe) / dt)
        c_final = jnp.where(c_safe > tol, c_safe, tol)
        sum_i = sum_i + p_i
        return sum_i, c_final
    
    inputs = (dts, p_truncated)
    sum_final, c = jax.lax.scan(scan_fn, sum_initial, inputs)
    # Append the last coalescent rate (not identifiable from data)
    c = jnp.append(c, c[-1])
    return jnp.array(t), jnp.array(c)

def density(m, Ne, tau, Ne_anc, BCOO_indices, init_vertices, tmrca_span): 
    t = jnp.geomspace(1e-4, jnp.max(tmrca_span[:, 0]), 1000)
    # t = jnp.insert(jnp.geomspace(1e-4, 15.0, 1000), 0, 0.0)
    index = jnp.searchsorted(t, tau, side = "right")
    t_aug = jnp.insert(t, index, tau)
    index = index + 1
    probabilities, sol, p_not_coal = solve_ode(t_aug, init_vertices=init_vertices, q=1/(2*Ne), m=m, BCOO_indices=BCOO_indices, index=index, tau=tau)
    # print(probabilities)
    t_aug = jnp.insert(t_aug, 0, 0.0) 
    rates = jnp.diff(t_aug) * (t_aug > tau)[1:] * (1/(2*Ne_anc))
    # rates = jnp.diff(jnp.maximum(t_aug, tau)) * q_anc
    cum_rates = jnp.cumsum(rates)
    # cum_rates = (t_aug - tau).clip(0)[1:] * q_anc
    after_tau_surv_probs = jnp.exp(-cum_rates) * p_not_coal
    after_tau_epoch_probs = jnp.abs(jnp.diff(after_tau_surv_probs))

    indices = jnp.arange(len(t)+1)
    after_tau_epoch_probs = jnp.append(after_tau_epoch_probs, 0)
    # Define conditions
    conditions = [
        (indices == index-1),  
        (indices >= index)
    ]

    # Define choices corresponding to each condition
    choices = [
        probabilities[index-1]+after_tau_epoch_probs[index-1],
        after_tau_epoch_probs
    ]

    # Default case (when none of the above conditions are met, use 'a')
    result = jnp.select(conditions, choices, default=probabilities)
    final_probabilities = result.at[-1].set(1-jnp.sum(result))
    # print(final_probabilities)
    # vmap_surv_prob = 1 - np.insert(np.cumsum(final_probabilities), 0, 0)[:-1]
    # vmap_surv_prob

    t = jnp.insert(t, 0, 0.0)
    t, c = from_pmf(t, final_probabilities)

    # import phlash
    # from phlash.likelihood.arg import log_density

    # eta = phlash.size_history.SizeHistory(t=t, c=c)
    # dm = phlash.size_history.DemographicModel(
    #     eta=eta, theta=None, rho=1e-8
    # )
    eta = PiecewiseConstant(c=c, t=t)
    log_density = loglik(eta, 1e-8, jnp.array(tmrca_span, dtype=jnp.float64))
    
    # return log_density(dm, tmrca_span[None])# , vmap_surv_prob, c
    return log_density

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
import jax.numpy as jnp
import jax
def sample_tmrca_spanss(ts, subkey=jax.random.PRNGKey(1), num_pop=2):
    samples = jax.random.choice(subkey, ts.num_samples, shape=(2,), replace=False)
    sample1, sample2 = samples[0], samples[1]

    pop1 = ts.node(sample1.item(0)).population - 1
    pop2 = ts.node(sample2.item(0)).population - 1
    sample_config = (pop1, pop2)

    # Precompute all TMRCAs and spans into arrays
    tmrcas = []
    spans = []
    for tree in ts.trees():
        spans.append(tree.interval.right - tree.interval.left)
        tmrcas.append(tree.tmrca(sample1, sample2))
    
    # Convert to JAX arrays
    tmrcas = jnp.array(tmrcas)  # Shape: (num_trees,)
    spans = jnp.array(spans)    # Shape: (num_trees,)
    tmrcas_spans = jnp.stack([tmrcas, spans], axis=1)  # Shape: (num_trees, 2)

    # Merge consecutive spans with same TMRCA
    def merge_spans(carry, x):
        current_tmrca, current_span, idx, output = carry
        tmrca, span = x
        
        # Update each component individually
        new_tmrca = jnp.where(tmrca == current_tmrca, current_tmrca, tmrca)
        new_span = jnp.where(tmrca == current_tmrca, current_span + span, span)
        new_idx = jnp.where(tmrca == current_tmrca, idx, idx + 1)
        new_output = jnp.where(
            tmrca == current_tmrca, 
            output, 
            output.at[idx].set(jnp.array([current_tmrca, current_span]))
        )
        
        return (new_tmrca, new_span, new_idx, new_output), None

    init_carry = (tmrcas_spans[0, 0], 0.0, 0, jnp.full((ts.num_trees, 2), jnp.array([1.0, 0.0])))
    final_carry, _ = jax.lax.scan(merge_spans, init_carry, tmrcas_spans)
    final_tmrca, final_span, _, final_output = final_carry
    final_output = final_output.at[-1].set(jnp.array([final_tmrca, final_span]))
    is_ones = jnp.all(final_output == jnp.array([1.0, 0.0]), axis=1)
    reordered_arr = jnp.concatenate([final_output[~is_ones], final_output[is_ones]])

    return reordered_arr, sample_config


INFO:2025-07-22 00:22:25,655:jax._src.xla_bridge:945: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'


INFO:2025-07-22 00:22:25,657:jax._src.xla_bridge:945: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/opt/homebrew/lib/libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/opt/homebrew/lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache), 'libtpu.so' (no such file), '/usr/local/lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache)


In [5]:
import networkx as nx
import numpy as np
import jax.experimental.sparse as jesp
import scipy.sparse as sp
graph = nx.Graph()
graph.add_edge(0, 1)
tmp = jesp.BCOO.from_scipy_sparse(sp.triu(nx.adjacency_matrix(graph, np.arange(2)).astype(float)))

In [None]:
import jax
import jax.numpy as jnp
import optax
from functools import partial

def transform_params(params):
    m_unconstrained, q_unconstrained, tau_unconstrained, q_anc_unconstrained = params
    # Sigmoid maps to (0, 1), then scale to (0, 2)
    m = jax.nn.softplus(m_unconstrained)
    q = jax.nn.softplus(q_unconstrained)
    tau = jax.nn.softplus(tau_unconstrained)
    q_anc = jax.nn.softplus(q_anc_unconstrained)
    return (m, q, tau, q_anc)

def inverse_transform_params(params):
    """Convert constrained params back to unconstrained via inverse softplus."""
    m, q, tau, q_anc = params
    m_unconstrained = jnp.log(jnp.exp(m) - 1)  # Inverse softplus
    q_unconstrained = jnp.log(jnp.exp(q) - 1)
    tau_unconstrained = jnp.log(jnp.exp(tau) - 1)
    q_anc_unconstrained = jnp.log(jnp.exp(q_anc) - 1)
    return (m_unconstrained, q_unconstrained, tau_unconstrained, q_anc_unconstrained)

def loss_fn(params, tmp_indices, sample_config, tmrca_span):
    m, q, tau, q_anc = params
    return density(m, q, tau, q_anc, tmp_indices, sample_config, tmrca_span)

# 2. Initialize parameters
def init_params(m_shape, q_shape):
    # Initialize with reasonable values for your problem
    m = jnp.ones(m_shape) * 0.0001  # Example initialization
    q = jnp.ones(q_shape) * 10000
    tau = 1000.0  # Example initial value
    q_anc = 10000 # Example initial value

    # Convert to unconstrained space for optimization
    m_unconstrained = jnp.log(jnp.exp(m) - 1)  # Inverse softplus
    q_unconstrained = jnp.log(jnp.exp(q) - 1)
    tau_unconstrained = jnp.log(jnp.exp(tau) - 1)
    q_anc_unconstrained = jnp.log(jnp.exp(q_anc) - 1)
    
    return (m_unconstrained, q_unconstrained, tau_unconstrained, q_anc_unconstrained)
    # return (m, q, tau, q_anc)

# 4. Set up the optimization
def optimize(params, optimizer, tmp_indices, ts, num_steps=1000):
    # Initialize optimizer state
    opt_state = optimizer.init(params)
    
    # Initialize random key
    key = jax.random.PRNGKey(0)
    
    # Define the update step (now without ts)
    @jax.jit
    def step(params, opt_state, tmp_indices, sample_config, tmrca_span):
        # Compute loss and gradients
        params = transform_params(params)
        jax.debug.print("params:{}", params, ordered=True)
        loss, grads = jax.value_and_grad(loss_fn)(params, tmp_indices, sample_config, tmrca_span)

        # Update parameters
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        
        return params, opt_state, loss, grads
    
    # Training loop
    losses = []
    for step_num in range(num_steps):
        # Generate new samples outside the JIT-compiled step
        key, subkey = jax.random.split(key)
        tmrca_span, sample_config = sample_tmrca_spanss(ts, subkey)
        
        # Perform optimization step
        params, opt_state, loss, grads = step(params, opt_state, tmp_indices, sample_config, tmrca_span)
        print(inverse_transform_params(params))
        # print(loss)
        # print(grads)
        losses.append(loss)
        
        if step_num % 100 == 0:  # Print progress every 100 steps
            print(f"Step {step_num}, Loss: {loss}")
    
    return params, losses

# Example usage
# Initialize parameters
params = init_params(m_shape=jnp.array(m).shape, q_shape=jnp.array(Ne).shape)  # Adjust shapes as needed

optimizer = optax.chain(optax.clip_by_global_norm(1.0),
                        optax.adam(learning_rate=0.1)
)

with jax.debug_nans(False), jax.disable_jit(False):
    optimized_params_unconstrained, losses = optimize(
        params, optimizer, tmp.indices, ts, num_steps=1000
    )

# Convert back to constrained space for final results
optimized_params = transform_params(optimized_params_unconstrained)

print("\nOptimized parameters:")
print(f"m: {optimized_params[0]}")
print(f"q: {optimized_params[1]}")
print(f"tau: {optimized_params[2]}")
print(f"q_anc: {optimized_params[3]}")

Traced<ShapedArray(float64[1001])>with<JVPTrace> with
  primal = Traced<ShapedArray(float64[1001])>with<DynamicJaxprTrace>
  tangent = Traced<ShapedArray(float64[1001])>with<JaxprTrace> with
    pval = (ShapedArray(float64[1001]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x162874850>, in_tracers=(Traced<ShapedArray(float64[1001,1]):JaxprTrace>,), out_tracer_refs=[<weakref at 0x162a00d60; to 'JaxprTracer' at 0x162a011d0>], out_avals=[ShapedArray(float64[1001])], primitive=squeeze, params={'dimensions': (1,)}, effects=frozenset(), source_info=<jax._src.source_info_util.SourceInfo object at 0x1629ec610>, ctx=JaxprEqnContext(compute_type=None, threefry_partitionable=True, cur_abstract_mesh=None, xla_metadata={}))
