In [5]:
# flow_marked_pp_minimal.py
from __future__ import annotations
from dataclasses import dataclass
from typing import Tuple
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax import struct

# -----------------------------
# Small utilities
# -----------------------------


def searchsorted_right(edges: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
    return jnp.searchsorted(edges, x, side="right")


def bin_spike_times_to_bins(
    spike_times: jnp.ndarray, time_edges: jnp.ndarray
) -> jnp.ndarray:
    idx = searchsorted_right(time_edges, spike_times) - 1
    idx = jnp.clip(idx, 0, time_edges.shape[0] - 2)
    return idx


def interp_position_at_times(
    position_time: jnp.ndarray, position: jnp.ndarray, query_times: jnp.ndarray
) -> jnp.ndarray:
    # Linear interpolation on the time axis
    idx_r = searchsorted_right(position_time, query_times)
    idx_l = jnp.clip(idx_r - 1, 0, position_time.shape[0] - 1)
    idx_r = jnp.clip(idx_r, 0, position_time.shape[0] - 1)
    t0, t1 = position_time[idx_l], position_time[idx_r]
    w = jnp.where(t1 > t0, (query_times - t0) / (t1 - t0 + 1e-9), 0.0)
    p0, p1 = position[idx_l], position[idx_r]
    return (1 - w)[:, None] * p0 + w[:, None] * p1  # (N,P)


# -----------------------------
# Intensity network λ0(x)
# -----------------------------


class LambdaNet(nn.Module):
    """Position-only baseline rate λ0(x). Output > 0 via softplus."""

    hidden: int = 128
    depth: int = 2

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        h = x
        for _ in range(self.depth):
            h = nn.silu(nn.Dense(self.hidden)(h))
        return nn.softplus(nn.Dense(1)(h)).squeeze(-1) + 1e-6  # (N,)


# -----------------------------
# Conditional RealNVP-style flow for marks: m -> z
# (custom, tiny; no external flow libs)
# -----------------------------


class CouplingNN(nn.Module):
    """MLP that outputs shift and log_scale for the transformed dims."""

    out_dim: int
    hidden: int = 128
    depth: int = 2

    @nn.compact
    def __call__(
        self, masked: jnp.ndarray, ctx: jnp.ndarray
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        # masked: (N, d_mask); ctx: (N, C)
        h = jnp.concatenate([masked, ctx], axis=-1)
        for _ in range(self.depth):
            h = nn.gelu(nn.Dense(self.hidden)(h))
        raw = nn.Dense(self.out_dim)(h)  # 2 * d_trans
        shift, log_scale = jnp.split(raw, 2, axis=-1)
        log_scale = jnp.clip(log_scale, -5.0, 5.0)  # stabilize
        return shift, log_scale


class MarkFlow(nn.Module):
    """Conditional flow: log p(m | x) via coupling layers conditioned on x."""

    mark_dim: int
    ctx_dim: int = 32
    n_layers: int = 6
    hidden: int = 128

    def setup(self):
        # Build alternating masks and their FIXED integer indices up front
        masks = []
        idx_m_list = []
        idx_t_list = []
        for i in range(self.n_layers):
            mask = np.zeros(self.mark_dim, dtype=bool)
            mask[i % 2 :: 2] = True
            masks.append(mask)
            idx_m_list.append(np.where(mask)[0].astype(np.int32))  # conditioned dims
            idx_t_list.append(np.where(~mask)[0].astype(np.int32))  # transformed dims
        self.masks = tuple(masks)  # optional (for debugging)
        self.idx_m = tuple(idx_m_list)  # tuple of np.int32 arrays
        self.idx_t = tuple(idx_t_list)

        # Context net (from position x)
        self.ctx_net = nn.Sequential([nn.Dense(self.ctx_dim), nn.tanh])

        # One conditioner per layer; output = (shift, log_scale) for transformed dims
        self.conds = tuple(
            CouplingNN(out_dim=2 * len(self.idx_t[i]), hidden=self.hidden)
            for i in range(self.n_layers)
        )

    def __call__(self, m: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
        """
        Return log p(m | x) using base N(0, I) and affine coupling transforms.
        m: (N, D), x: (N, P)
        """
        ctx = self.ctx_net(x)  # (N, C)
        z = m
        sum_log_det = jnp.zeros((m.shape[0],))
        for idx_m, idx_t, cond in zip(self.idx_m, self.idx_t, self.conds):
            # Use PRECOMPUTED integer indices (static) — no jnp.where in jit
            z_m = z[:, idx_m]  # (N, d_m)
            z_t = z[:, idx_t]  # (N, d_t)

            shift, log_scale = cond(z_m, ctx)  # (N, d_t) each
            log_scale = jnp.clip(log_scale, -5.0, 5.0)
            z_t_new = (z_t - shift) * jnp.exp(-log_scale)
            sum_log_det = sum_log_det - jnp.sum(log_scale, axis=-1)

            z = z.at[:, idx_t].set(z_t_new)

        D = self.mark_dim
        log_pz = -0.5 * (jnp.sum(z * z, axis=-1) + D * jnp.log(2.0 * jnp.pi))
        return log_pz + sum_log_det


# -----------------------------
# Config & state
# -----------------------------


@dataclass
class ModelConfig:
    position_dim: int = 2
    mark_dim: int = 8
    hidden: int = 128
    depth_lambda: int = 2
    ctx_dim: int = 32
    n_flow_layers: int = 6
    lr_lambda: float = 1e-3
    lr_flow: float = 1e-3


@struct.dataclass
class ModelState:
    # pytree leaves (OK inside jit)
    params_lambda: dict
    params_flow: dict
    opt_lambda: optax.OptState
    opt_flow: optax.OptState

    # static (non-pytree) fields
    tx_lambda: optax.GradientTransformation = struct.field(pytree_node=False)
    tx_flow: optax.GradientTransformation = struct.field(pytree_node=False)
    lambda_net: LambdaNet = struct.field(pytree_node=False)
    mark_flow: MarkFlow = struct.field(pytree_node=False)
    cfg: ModelConfig = struct.field(pytree_node=False)


def init_model(rng: jax.random.KeyArray, cfg: ModelConfig) -> ModelState:
    rng_lam, rng_flow = jax.random.split(rng, 2)
    lambda_net = LambdaNet(hidden=cfg.hidden, depth=cfg.depth_lambda)
    mark_flow = MarkFlow(
        mark_dim=cfg.mark_dim,
        ctx_dim=cfg.ctx_dim,
        n_layers=cfg.n_flow_layers,
        hidden=cfg.hidden,
    )

    # Dummy shapes to initialize params
    x_dummy = jnp.zeros((1, cfg.position_dim))
    m_dummy = jnp.zeros((1, cfg.mark_dim))

    params_lambda = lambda_net.init(rng_lam, x_dummy)
    params_flow = mark_flow.init(rng_flow, m_dummy, x_dummy)

    tx_lambda = optax.adamw(cfg.lr_lambda)
    tx_flow = optax.adamw(cfg.lr_flow)
    opt_lambda = tx_lambda.init(params_lambda)
    opt_flow = tx_flow.init(params_flow)

    return ModelState(
        params_lambda=params_lambda,
        params_flow=params_flow,
        opt_lambda=opt_lambda,
        opt_flow=opt_flow,
        tx_lambda=tx_lambda,
        tx_flow=tx_flow,
        lambda_net=lambda_net,
        mark_flow=mark_flow,
        cfg=cfg,
    )


# -----------------------------
# Training: marked Poisson likelihood (no history => integral = Δt * λ0)
# -----------------------------


@jax.jit
def loss_step(
    state: ModelState,
    x_spk: jnp.ndarray,
    m_spk: jnp.ndarray,
    x_time: jnp.ndarray,
    dt: float,
) -> Tuple[ModelState, dict]:
    """
    One step: minimize
      - E_spikes [ log λ0(x) + log p_flow(m|x) ] + E_time [ λ0(x) * dt ]
    """

    dt = jnp.asarray(dt)

    # Lambda params update (appears in events + integral)
    def lam_total_loss(params_lambda):
        lam_spk = state.lambda_net.apply(params_lambda, x_spk)
        lam_time = state.lambda_net.apply(params_lambda, x_time)
        # events: -log λ
        events = -jnp.log(lam_spk + 1e-12).mean()
        # integral: + E_t[ λ * dt ]
        integral = lam_time.mean() * dt
        return events + integral

    # Flow params update (only in mark terms)
    def flow_loss(params_flow):
        lp_mark = state.mark_flow.apply(params_flow, m_spk, x_spk)  # log p(m|x)
        return -lp_mark.mean()

    lam_value, lam_grads = jax.value_and_grad(lam_total_loss)(state.params_lambda)
    flow_value, flow_grads = jax.value_and_grad(flow_loss)(state.params_flow)

    upd_lam, opt_lam = state.tx_lambda.update(
        lam_grads, state.opt_lambda, state.params_lambda
    )
    upd_flow, opt_flow = state.tx_flow.update(
        flow_grads, state.opt_flow, state.params_flow
    )

    params_lambda = optax.apply_updates(state.params_lambda, upd_lam)
    params_flow = optax.apply_updates(state.params_flow, upd_flow)

    new_state = ModelState(
        params_lambda=params_lambda,
        params_flow=params_flow,
        opt_lambda=opt_lam,
        opt_flow=opt_flow,
        tx_lambda=state.tx_lambda,
        tx_flow=state.tx_flow,
        lambda_net=state.lambda_net,
        mark_flow=state.mark_flow,
        cfg=state.cfg,
    )
    info = dict(loss_lambda=lam_value, loss_flow=flow_value)
    return new_state, info


def train(
    rng: jax.random.KeyArray,
    position_time: jnp.ndarray,  # (T_pos,)
    position: jnp.ndarray,  # (T_pos, P)
    spike_times: jnp.ndarray,  # (S,)
    marks: jnp.ndarray,  # (S, M)
    cfg: ModelConfig,
    steps: int = 2000,
    batch_spikes: int = 1024,
    batch_time: int = 2048,
) -> ModelState:
    state = init_model(rng, cfg)
    T = position_time.shape[0]
    dt = float(jnp.mean(jnp.diff(position_time)))
    rng_np = np.random.default_rng(0)

    # Precompute spike positions
    st = spike_times
    x_spk_all = interp_position_at_times(position_time, position, st)

    for step in range(steps):
        # sample spikes
        if st.shape[0] > 0:
            sel_s = rng_np.choice(
                st.shape[0],
                size=min(batch_spikes, st.shape[0]),
                replace=st.shape[0] < batch_spikes,
            )
            x_spk = x_spk_all[sel_s]
            m_spk = marks[sel_s]
        else:
            x_spk = jnp.zeros((1, cfg.position_dim))
            m_spk = jnp.zeros((1, cfg.mark_dim))

        # sample time points for integral
        sel_t = rng_np.choice(T, size=min(batch_time, T), replace=T < batch_time)
        x_time = position[sel_t]

        state, info = loss_step(state, x_spk, m_spk, x_time, dt)
        if (step + 1) % 200 == 0:
            print(
                f"[{step+1:04d}] loss_lambda={float(info['loss_lambda']):.4f}  loss_flow={float(info['loss_flow']):.4f}"
            )

    return state


# -----------------------------
# Non-local decoding (no history; add exp(h(t)) as a scalar multiplier if you like)
# -----------------------------


def nonlocal_log_likelihood(
    state: ModelState,
    time_edges: jnp.ndarray,  # (N+1,)
    X_grid: jnp.ndarray,  # (B, P)
    spike_times: jnp.ndarray,  # (S,)
    marks: jnp.ndarray,  # (S, M)
    block_size: int = 128,
) -> jnp.ndarray:
    """
    Returns LL over (time bin, position bin): shape (N, B).
    For each bin n and position x_b:
        Sum_{spikes in n} [ log λ0(x_b) + log p_flow(m_i | x_b) ] - Δt_n * λ0(x_b)
    """
    N = time_edges.shape[0] - 1
    B = X_grid.shape[0]
    dt = time_edges[1:] - time_edges[:-1]  # (N,)
    out = jnp.zeros((N, B))

    # Bin spikes
    inb = jnp.logical_and(spike_times >= time_edges[0], spike_times < time_edges[-1])
    st = spike_times[inb]
    m = marks[inb]
    idx = (
        bin_spike_times_to_bins(st, time_edges)
        if st.shape[0] > 0
        else jnp.array([], dtype=jnp.int32)
    )

    # Integral term: - Δt * λ0(x_b)
    for s in range(0, B, block_size):
        xb = X_grid[s : s + block_size]
        lam = state.lambda_net.apply(state.params_lambda, xb)  # (b,)
        out = out.at[:, s : s + block_size].add(-dt[:, None] * lam[None, :])

    # Spike terms
    if st.shape[0] > 0:
        for s in range(0, B, block_size):
            xb = X_grid[s : s + block_size]  # (b,P)
            # tile spikes across positions
            S = m.shape[0]
            b = xb.shape[0]
            Xrep = jnp.repeat(xb[None, :, :], S, axis=0).reshape(S * b, -1)
            Mrep = jnp.repeat(m[:, None, :], b, axis=1).reshape(S * b, -1)
            # scores
            lam_rep = state.lambda_net.apply(state.params_lambda, Xrep).reshape(
                S, b
            )  # (S,b)
            lp_mark = state.mark_flow.apply(state.params_flow, Mrep, Xrep).reshape(S, b)
            contrib = jnp.log(lam_rep + 1e-12) + lp_mark  # (S,b)
            add_blk = jax.ops.segment_sum(contrib, idx, num_segments=N)  # (N,b)
            out = out.at[:, s : s + block_size].add(add_blk)

    return out  # (N,B)


# -----------------------------
# Example usage (toy, to check shapes)
# -----------------------------
if __name__ == "__main__":
    rng = jax.random.PRNGKey(0)

    # Fake positions on a 2D arena
    T = 4000
    P = 2
    M = 8
    t_pos = jnp.linspace(0.0, 100.0, T)
    pos = jnp.stack([jnp.sin(0.05 * t_pos), jnp.cos(0.05 * t_pos)], axis=-1)  # (T,2)

    # Fake spikes + marks
    S = 6000
    st = jnp.sort(jax.random.uniform(rng, (S,)) * t_pos[-1])
    marks = jax.random.normal(rng, (S, M))

    cfg = ModelConfig(position_dim=P, mark_dim=M, hidden=128, n_flow_layers=6)
    state = train(rng, t_pos, pos, st, marks, cfg, steps=800)

    # Non-local decode: 20 ms bins over a 5 s window; 20x20 grid
    edges = jnp.linspace(10.0, 15.0, 251)
    xs = jnp.linspace(-1.5, 1.5, 20)
    grid = jnp.stack(jnp.meshgrid(xs, xs), axis=-1).reshape(-1, 2)
    ll = nonlocal_log_likelihood(state, edges, grid, st, marks, block_size=64)
    print("LL shape:", ll.shape)

[0200] loss_lambda=-2.6851  loss_flow=11.2947
[0400] loss_lambda=-2.6858  loss_flow=11.1851
[0600] loss_lambda=-2.6890  loss_flow=11.0938
[0800] loss_lambda=-2.6888  loss_flow=11.0561
LL shape: (250, 400)
