In [52]:
import numpy as np
from numpy import array, sqrt, zeros, exp, log
from numpy.random import normal as ndist
from numpy.random import randn

import autograd.numpy as anp

import matplotlib.pyplot as plt

In [53]:
# First-draft standalone versions of the model functions (written quickly, without much care)
def G1(ϕ, ν1):
    """Map a standard normal ν1 to the first latent state, given ϕ = (μ, ρ, γ, σ)."""
    mean, rho, gamma, _noise = ϕ
    stationary_scale = gamma / sqrt(1 - rho**2)
    return mean + stationary_scale * ν1

def Gt(ϕ, νt, xtm1):
    """Advance the AR(1) latent state from xtm1 using the standard normal νt."""
    mean, rho, gamma, _noise = ϕ
    deviation = xtm1 - mean
    return mean + rho*deviation + gamma*νt

def Ht(ϕ, xt):
    """Noise-free observation of the latent state xt (ϕ is unused)."""
    noiseless_observation = exp(xt)
    return noiseless_observation

def xbar(t, ϕ, ηt, yt):
    """Latent state implied by observation yt and noise ηt, i.e. log(yt - σ ηt).

    NOTE(review): `t` and `ϕ` are unused; σ is read from the module-level global
    `σ` (assigned in a later cell), not from ϕ — confirm this is intended.
    """
    noiseless = yt - σ*ηt
    return log(noiseless)

def Cbar1(ϕ, ν, η, y):
    """Constraint for the first time-step: generated state minus data-implied state.

    NOTE(review): uses index 1 for the first step; if ν, η, y are 0-indexed this
    should presumably be index 0 — confirm.
    """
    generated = G1(ϕ, ν[1])
    implied = xbar(1, ϕ, η[1], y[1])
    return generated - implied

def Cbar(t, ϕ, ν, η, y):
    """Constraint for time-step t: state generated from the implied x_{t-1} minus implied x_t."""
    ν_t = ν[t]
    x_prev = xbar(t-1, ϕ, η[t-1], y[t-1])
    return Gt(ϕ, ν_t, x_prev) - xbar(t, ϕ, η[t], y[t])

In [54]:
# Parameter values used to simulate data; θ is laid out as ϕ = (μ, ρ, γ, σ).
μ, ρ, γ = -0.5, 0.9, 0.4
T = 1000
σ = 0.01
θ = array([μ, ρ, γ, σ])

In [60]:
class SSM:
    """State Space Model from the Manifold Lifting paper.

    With all arrays 0-indexed and ϕ = (μ, ρ, γ, σ), the latent states are
        x[0] = μ + γ / sqrt(1 - ρ²) · ν[0]
        x[t] = μ + ρ (x[t-1] - μ) + γ ν[t],   t = 1, ..., T-1,
    and the observations are y[t] = exp(x[t]) + self.σ · η[t]. Inverting the
    observation map gives x̄[t] = log(y[t] - self.σ η[t]); the constraint
    C̄(ξ) = 0 requires the states generated from (ϕ, ν) to equal x̄.
    """

    def __init__(self, T, σ, autograd=False):
        """
        Args:
            T (int): Number of time-steps; dictates the dimensionality of the model.
            σ (float): Noise scale; dictates tightness around the data manifold.
            autograd (bool): If True use `autograd.numpy` for array operations so
                the constraint can be differentiated with autograd, else `numpy`.
        """
        self.T  = T
        self.σ  = σ
        self.ad = autograd
        self.np = anp if autograd else np

    def G1(self, ϕ, ν1):
        """Generates the initial latent variable x[0] from ϕ and the normal ν1."""
        μ, ρ, γ, σ = ϕ
        # Use the selected backend's sqrt so autograd can differentiate through ρ.
        return μ + (γ/self.np.sqrt(1 - ρ**2))*ν1

    def Gt(self, t, ϕ, xtm1, νt):
        """Generates latent variable x[t] (t >= 1) from the previous state xtm1."""
        μ, ρ, γ, σ = ϕ
        return μ + ρ*(xtm1 - μ) + γ*νt

    def x̄(self, t, ϕ, ηt):
        """Latent variable x[t] that produces observation y[t] given noise ηt.

        Requires `self.y` (set by `generate_y_given_x`); the result is finite only
        when y[t] - σ ηt > 0.
        """
        return self.np.log(self.y[t] - self.σ*ηt)

    def C̄1(self, ϕ, ν, η):
        """Constraint function for the first (index 0) latent variable/observation."""
        return self.G1(ϕ, ν[0]) - self.x̄(0, ϕ, η[0])

    def C̄t(self, t, ϕ, ν, η):
        """Constraint function for t in {1 .. T-1} (0-indexed)."""
        return self.Gt(t, ϕ, self.x̄(t-1, ϕ, η[t-1]), ν[t]) - self.x̄(t, ϕ, η[t])

    def C̄(self, ξ):
        """Overall constraint function.

        Args:
            ξ: Flat array of length 2T + 4 laid out as (ϕ, ν, η), where
                ϕ = (μ, ρ, γ, σ) has 4 entries and ν, η have T entries each.
        Returns:
            Array of shape (T,) that is zero exactly on the data manifold.
        """
        T = self.T
        ϕ, ν, η = ξ[:4], ξ[4:T + 4], ξ[T + 4:2*T + 4]
        # Each per-step constraint is a scalar; 0-d values cannot be concatenated,
        # so stack them into a length-T vector instead.
        return self.np.stack(
            [self.C̄1(ϕ, ν, η)] + [self.C̄t(t, ϕ, ν, η) for t in range(1, T)]
        )

    def generate_x_given_param(self, ϕtrue):
        """Simulates x[0:T] given ϕtrue = (μ, ρ, γ, σ).

        Stores the states in `self.true_x` and the parameters in `self.ϕtrue`.
        x[0] is drawn directly from N(μ, γ² / (1 - ρ²)), so ν[0] is not used.
        """
        μ, ρ, γ, σ = ϕtrue
        x = zeros(self.T)
        ν = randn(self.T)
        x[0] = ndist(loc=μ, scale=γ/sqrt(1 - ρ**2))
        for t in range(1, self.T):
            x[t] = μ + ρ*(x[t-1] - μ) + γ*ν[t]
        self.true_x = x
        self.ϕtrue  = ϕtrue

    def generate_y_given_x(self):
        """Generates data y = exp(true_x) + σ·noise from the true latent variables."""
        self.y = self.np.exp(self.true_x) + self.σ*randn(self.T)

In [61]:
# 100 time-steps, noise scale 0.1, autograd-compatible numpy backend.
ssm = SSM(T=100, σ=0.1, autograd=True)

In [63]:
# Simulate latent states from the true parameters θ, then noisy observations of them.
ssm.generate_x_given_param(ϕtrue=θ)
ssm.generate_y_given_x()

# Graham's implementation: transforms, distributions, priors and state space model helpers

In [46]:
from collections import namedtuple
import numpy as onp
import jax
import jax.numpy as np
from jax.scipy.special import ndtr, ndtri, logit, expit
import argparse
from functools import partial
from abc import ABC, abstractmethod
from functools import wraps

In [2]:
# Interval on the real line; endpoints are plain bounds (open/closed not tracked).
RealInterval = namedtuple("RealInterval", ["lower", "upper"])
reals = RealInterval(lower=-onp.inf, upper=onp.inf)
positive_reals = RealInterval(lower=0, upper=onp.inf)
negative_reals = RealInterval(lower=-onp.inf, upper=0)
# Since openness of endpoints is not represented, the non-strict sets are aliases.
nonnegative_reals = positive_reals
nonpositive_reals = negative_reals

### Transforms

In [3]:
class ElementwiseMonotonicTransform:
    """Monotonic scalar bijection applied independently to each array element.

    Args:
        forward (Callable): Scalar map from `domain` to `image`.
        backward (Callable): Inverse of `forward`.
        domain (RealInterval): Domain of `forward`.
        image (RealInterval): Image of `forward`.
        val_and_grad_forward (Callable, optional): Returns the value and derivative
            of `forward` at a scalar; defaults to `jax.value_and_grad(forward)`.
    """

    def __init__(self, forward, backward, domain, image, val_and_grad_forward=None):
        self._forward = forward
        self._backward = backward
        self.domain = domain
        self.image = image
        if val_and_grad_forward is None:
            val_and_grad_forward = jax.value_and_grad(forward)
        self._val_and_grad_forward = val_and_grad_forward

    def forward(self, u):
        """Apply the forward map."""
        return self._forward(u)

    def backward(self, x):
        """Apply the inverse map."""
        return self._backward(x)

    def forward_and_det_jacobian(self, u):
        """Return `forward(u)` and the determinant of its Jacobian.

        The Jacobian of an elementwise map is diagonal, so its determinant is the
        product of the elementwise derivatives. The sign is kept (each decreasing
        element contributes a factor < 0); take the absolute value for densities.
        For large arrays the product may over/underflow.
        """
        if onp.isscalar(u) or u.shape == ():
            return self._val_and_grad_forward(u)
        # Flatten so each vmapped call sees a scalar (value_and_grad needs a scalar
        # output), then restore the input shape on the transformed values.
        x, dx_du = jax.vmap(self._val_and_grad_forward)(u.reshape(-1))
        return x.reshape(u.shape), dx_du.prod()

    def __call__(self, u):
        return self._forward(u)

In [4]:
def unbounded_to_lower_bounded(lower):
    """Construct transform from reals to lower-bounded interval.

    Forward map `u -> exp(u) + lower`, inverse `x -> log(x - lower)`.

    Args:
        lower (float): Lower-bound of image of transform.
    Returns:
        ElementwiseMonotonicTransform: Increasing bijection onto `(lower, inf)`.
    """

    def to_bounded(u):
        return np.exp(u) + lower

    def to_unbounded(x):
        return np.log(x - lower)

    return ElementwiseMonotonicTransform(
        forward=to_bounded,
        backward=to_unbounded,
        domain=reals,
        image=RealInterval(lower, onp.inf),
    )

def unbounded_to_upper_bounded(upper):
    """Construct transform from reals to upper-bounded interval.

    Forward map `u -> upper - exp(u)` (decreasing), inverse `x -> log(upper - x)`.

    Args:
        upper (float): Upper-bound of image of transform.
    Returns:
        ElementwiseMonotonicTransform: Bijection onto `(-inf, upper)`.
    """

    def to_bounded(u):
        return upper - np.exp(u)

    def to_unbounded(x):
        return np.log(upper - x)

    return ElementwiseMonotonicTransform(
        forward=to_bounded,
        backward=to_unbounded,
        domain=reals,
        image=RealInterval(-onp.inf, upper),
    )

def unbounded_to_lower_and_upper_bounded(lower, upper):
    """Construct transform from reals to bounded interval.

    Forward map is a scaled logistic sigmoid, inverse the matching logit.

    Args:
        lower (float): Lower-bound of image of transform.
        upper (float): Upper-bound of image of transform.
    Returns:
        ElementwiseMonotonicTransform: Increasing bijection onto `(lower, upper)`.
    """
    width = upper - lower

    def to_bounded(u):
        return lower + width * expit(np.asarray(u, np.float64))

    def to_unbounded(x):
        return logit((np.asarray(x, np.float64) - lower) / width)

    return ElementwiseMonotonicTransform(
        forward=to_bounded,
        backward=to_unbounded,
        domain=reals,
        image=RealInterval(lower, upper),
    )


def diagonal_affine_map(location, scale):
    """Construct the elementwise affine transform `x -> location + scale * x`.

    Args:
        location (float): Shift applied after scaling.
        scale (float): Multiplicative factor (also the constant derivative).
    Returns:
        ElementwiseMonotonicTransform: Affine map from the reals to the reals.
    """

    def shift_and_scale(x):
        return location + scale * x

    def unshift_and_unscale(y):
        return (y - location) / scale

    def value_and_derivative(x):
        # The derivative of an affine map is the constant `scale`.
        return shift_and_scale(x), scale

    return ElementwiseMonotonicTransform(
        forward=shift_and_scale,
        backward=unshift_and_unscale,
        domain=reals,
        image=reals,
        val_and_grad_forward=value_and_derivative,
    )

def standard_normal_to_uniform(lower, upper):
    """Transform standard normal variates to Uniform(lower, upper) via the normal CDF."""
    width = upper - lower

    def to_uniform(n):
        return lower + ndtr(n) * width

    def to_normal(u):
        return ndtri((u - lower) / width)

    return ElementwiseMonotonicTransform(
        forward=to_uniform,
        backward=to_normal,
        domain=reals,
        image=RealInterval(lower, upper),
    )

def standard_normal_to_exponential(rate):
    """Transform standard normal variates to Exponential(rate) as -log(Φ(n)) / rate.

    The forward map is decreasing in `n`.
    """

    def to_exponential(n):
        return -np.log(ndtr(n)) / rate

    def to_normal(e):
        return ndtri(np.exp(-e * rate))

    return ElementwiseMonotonicTransform(
        forward=to_exponential,
        backward=to_normal,
        domain=reals,
        image=nonnegative_reals,
    )

def standard_normal_to_half_normal(scale):
    """Transform standard normal variates to half-normal variates with given scale.

    Uses the half-normal inverse CDF Φ⁻¹((p + 1) / 2) applied to p = Φ(n).
    """

    def to_half_normal(n):
        return ndtri((ndtr(n) + 1) / 2) * scale

    def to_normal(h):
        return ndtri(2 * ndtr(h / scale) - 1)

    return ElementwiseMonotonicTransform(
        forward=to_half_normal,
        backward=to_normal,
        domain=reals,
        image=nonnegative_reals,
    )

def standard_normal_to_truncated_normal(location, scale, lower, upper):
    """Transform standard normal variates to a normal truncated to [lower, upper].

    The normal CDF of `n` is rescaled into the CDF range of the truncation bounds
    and mapped back through the inverse normal CDF.
    """
    cdf_lower = ndtr((lower - location) / scale)
    cdf_upper = ndtr((upper - location) / scale)
    mass = cdf_upper - cdf_lower

    def to_truncated(n):
        return ndtri(cdf_lower + ndtr(n) * mass) * scale + location

    def to_normal(t):
        return ndtri((ndtr((t - location) / scale) - cdf_lower) / mass)

    return ElementwiseMonotonicTransform(
        forward=to_truncated,
        backward=to_normal,
        domain=reals,
        image=RealInterval(lower, upper),
    )


def standard_normal_to_beta(shape_a, shape_b):
    """Transform standard normal variates to Beta(shape_a, shape_b) variates.

    Only the closed-form cases shape_b == 1 (CDF x**a) and shape_a == 1
    (CDF 1 - (1 - x)**b) are supported.

    Raises:
        ValueError: If neither shape parameter equals 1.
    """
    if shape_b != 1 and shape_a != 1:
        raise ValueError("Transform only defined for shape_a == 1 or shape_b == 1")

    if shape_b == 1:

        def icdf(u):
            return u ** (1 / shape_a)

        def cdf(x):
            return x ** shape_a

    else:

        def icdf(u):
            return 1 - (1 - u) ** (1 / shape_b)

        def cdf(x):
            return 1 - (1 - x) ** shape_b

    def to_beta(n):
        return icdf(ndtr(n))

    def to_normal(x):
        return ndtri(cdf(x))

    return ElementwiseMonotonicTransform(
        forward=to_beta,
        backward=to_normal,
        domain=reals,
        image=RealInterval(0, 1),
    )

### Distributions

In [5]:
class Distribution:
    """Probability distribution with density with respect to Lebesgue measure."""

    def __init__(
        self,
        neg_log_dens,
        log_normalizing_constant,
        sample,
        support,
        from_standard_normal_transform=None,
    ):
        """
        Args:
            neg_log_dens (Callable[[ArrayLike], float]): Negative logarithm of a
                (potentially unnormalised) density for the distribution with respect
                to the Lebesgue measure, evaluated elementwise.
            log_normalizing_constant (ArrayLike): Logarithm of the normalising
                constant for `neg_log_dens`, such that
                    def dens(x): exp(-neg_log_dens(x) - log_normalizing_constant)
                is a normalised probability density function.
            sample (Callable[[Generator, Tuple[int, ...]], ArrayLike]): Function
                drawing samples of a given shape using a random generator.
            support (object): Object defining support of distribution.
            from_standard_normal_transform (Callable[[ArrayLike], ArrayLike]):
                Optional function mapping standard normal variate(s) to variate(s)
                from this distribution; may be `None`.
        """
        self.support = support
        self.sample = sample
        self._neg_log_dens = neg_log_dens
        self.log_normalizing_constant = log_normalizing_constant
        self.from_standard_normal_transform = from_standard_normal_transform

    def neg_log_dens(self, x, include_normalizing_constant=False):
        """Negative log density at `x`, summed over elements for non-scalar `x`.

        When `include_normalizing_constant` is True the constant is added to each
        element's value before summing.
        """
        value = self._neg_log_dens(x)
        if include_normalizing_constant:
            value = value + self.log_normalizing_constant
        is_scalar_input = onp.isscalar(x) or x.shape == ()
        return value if is_scalar_input else value.sum()

In [6]:
def pullback_distribution(distribution, transform):
    """Pullback a distribution through a differentiable transform.
    Given a distribution `μ` and (differentiable) transform `F` constructs a
    distribution `ν` such that `ν` is the pullback of `μ` under `F` or equivalently `μ`
    is the pushforward of `ν` under `F`, i.e. `F#ν = μ`.
    Args:
        distribution (Distribution): Distribution `μ` to pullback.
        transform (Transform): Transform `F` to pullback distribution through. Its
            image must equal the support of `distribution`.
    Returns:
        Distribution: Pullback distribution `ν`, supported on the domain of
        `transform`.
    """

    def transformed_neg_log_dens(u):
        x, det_dx_du = transform.forward_and_det_jacobian(u)
        # Change of variables uses |det ∂x/∂u|. The absolute value matters for
        # decreasing transforms (e.g. `upper - exp(u)`), whose derivative is
        # negative and would otherwise make the log NaN.
        return distribution.neg_log_dens(x) - np.log(np.abs(det_dx_du))

    def transformed_sample(rng, shape=()):
        # Sample in the original space and map back through the inverse transform.
        x_samples = distribution.sample(rng, shape)
        return onp.asarray(transform.backward(x_samples))

    assert (
        distribution.support == transform.image
    ), "Support of distribution does not match transform image"

    transformed_support = transform.domain

    if distribution.from_standard_normal_transform is not None:

        def transformed_from_standard_normal_transform(n):
            return transform.backward(distribution.from_standard_normal_transform(n))

    else:
        transformed_from_standard_normal_transform = None

    return Distribution(
        neg_log_dens=transformed_neg_log_dens,
        log_normalizing_constant=distribution.log_normalizing_constant,
        sample=transformed_sample,
        support=transformed_support,
        from_standard_normal_transform=transformed_from_standard_normal_transform,
    )


In [7]:
def uniform(lower, upper):
    """Construct uniform distribution with support on real-interval.
    Args:
        lower (float): Lower-bound of support.
        upper (float): Upper-bound of support.
    Returns:
        Distribution: Uniform distribution object.
    """

    def neg_log_dens(x):
        # The density is constant on the support. Return zeros shaped like `x`
        # (rather than a bare scalar 0) so `Distribution.neg_log_dens` can `.sum()`
        # array inputs and counts the normalising constant once per element.
        return np.zeros(np.shape(x))

    log_normalizing_constant = np.log(upper - lower)

    def sample(rng, shape=()):
        return rng.uniform(low=lower, high=upper, size=shape)

    support = RealInterval(lower, upper)

    from_standard_normal_transform = standard_normal_to_uniform(lower, upper)

    return Distribution(
        neg_log_dens=neg_log_dens,
        log_normalizing_constant=log_normalizing_constant,
        sample=sample,
        support=support,
        from_standard_normal_transform=from_standard_normal_transform,
    )

In [8]:
def normal(location, scale):
    """Construct normal distribution with support on real-line.
    Args:
        location (float): Location parameter (mean of distribution).
        scale (float): Scale parameter (standard deviation of distribution).
    Returns:
        Distribution: Normal distribution object.
    """

    def neg_log_dens(x):
        standardized = (x - location) / scale
        return standardized ** 2 / 2

    def sample(rng, shape=()):
        return rng.normal(loc=location, scale=scale, size=shape)

    return Distribution(
        neg_log_dens=neg_log_dens,
        log_normalizing_constant=np.log(2 * np.pi) / 2 + np.log(scale),
        sample=sample,
        support=reals,
        # Normal variates are an affine map of standard normal ones.
        from_standard_normal_transform=diagonal_affine_map(location, scale),
    )


In [9]:
def log_normal(location, scale):
    """Construct log-normal distribution with support on positive reals.
    Args:
        location (float): Location parameter (mean of log of random variable).
        scale (float): Scale parameter (standard deviation of log of random variable).
    Returns:
        Distribution: Log-normal distribution object.
    """

    def neg_log_dens(x):
        log_x = np.log(x)
        # Normal term on log(x) plus the log-Jacobian of the exp map.
        return ((log_x - location) / scale) ** 2 / 2 + log_x

    def sample(rng, shape=()):
        return onp.exp(rng.normal(loc=location, scale=scale, size=shape))

    def from_normal(n):
        return np.exp(location + scale * n)

    def to_normal(x):
        return (np.log(x) - location) / scale

    return Distribution(
        neg_log_dens=neg_log_dens,
        log_normalizing_constant=np.log(2 * np.pi) / 2 + np.log(scale),
        sample=sample,
        support=positive_reals,
        from_standard_normal_transform=ElementwiseMonotonicTransform(
            forward=from_normal,
            backward=to_normal,
            domain=reals,
            image=positive_reals,
        ),
    )

In [10]:
def half_normal(scale):
    """Construct half-normal distribution with support on non-negative reals.
    Args:
        scale (float): Scale parameter.
    Returns:
        Distribution: Half-normal distribution object.
    """

    def neg_log_dens(x):
        return (x / scale) ** 2 / 2

    def sample(rng, shape=()):
        # Absolute value of a zero-mean normal draw.
        return abs(rng.normal(loc=0, scale=scale, size=shape))

    return Distribution(
        neg_log_dens=neg_log_dens,
        log_normalizing_constant=np.log(np.pi / 2) / 2 + np.log(scale),
        sample=sample,
        support=nonnegative_reals,
        from_standard_normal_transform=standard_normal_to_half_normal(scale),
    )

In [11]:
def reparametrize_to_unbounded_support(prior_spec):
    """Re-express a prior specification in terms of an unbounded latent variable.

    If the support is bounded below and/or above, the distribution is pulled back
    through a bounding transform from the reals, and that transform is applied
    before any existing `prior_spec.transform`. A specification whose support is
    the whole real line is returned unchanged.
    """
    support = prior_spec.distribution.support
    has_lower = support.lower != -np.inf
    has_upper = support.upper != np.inf
    if has_lower and has_upper:
        bounding_transform = unbounded_to_lower_and_upper_bounded(
            support.lower, support.upper
        )
    elif has_lower:
        bounding_transform = unbounded_to_lower_bounded(support.lower)
    elif has_upper:
        bounding_transform = unbounded_to_upper_bounded(support.upper)
    else:
        return prior_spec
    distribution = pullback_distribution(prior_spec.distribution, bounding_transform)
    outer_transform = prior_spec.transform
    if outer_transform is None:
        transform = bounding_transform
    else:

        def transform(u):
            return outer_transform(bounding_transform(u))

    return PriorSpecification(
        shape=prior_spec.shape, distribution=distribution, transform=transform
    )

In [12]:
def reparametrize_to_standard_normal(prior_spec):
    """Re-express a prior specification in terms of a standard normal variate.

    Uses the distribution's `from_standard_normal_transform`, applied before any
    existing `prior_spec.transform`; the new distribution is `normal(0, 1)`.
    """
    normal_to_target = prior_spec.distribution.from_standard_normal_transform
    outer_transform = prior_spec.transform
    if outer_transform is None:
        transform = normal_to_target
    else:

        def transform(u):
            return outer_transform(normal_to_target(u))

    return PriorSpecification(
        shape=prior_spec.shape, distribution=normal(0, 1), transform=transform
    )

In [13]:
def set_up_prior(prior_specs):
    """Construct helper functions for a prior given by a dict of specifications.

    Args:
        prior_specs (Dict[str, PriorSpecification]): Mapping from parameter name to
            prior specification. Dict iteration order fixes the layout of the flat
            unbounded latent vector `u`.
    Returns:
        Tuple of four functions:
            compute_dim_u(data) -> int: total length of `u`.
            generate_params(u, data) -> dict: parameters generated from `u`.
            prior_neg_log_dens(u, data): negative log prior density of `u`.
            sample_from_prior(rng, data, num_sample=None): draws of `u`.
        If `data.get("parametrization") == "normal"`, parameters whose distribution
        has a standard normal transform are reparametrized in terms of standard
        normal variates; all others are reparametrized to unbounded support.
    """

    def get_shape(spec, data):
        return spec.shape(data) if callable(spec.shape) else spec.shape

    def num_elements(shape):
        # Entries of `u` used by a parameter of this shape. A scalar (shape ())
        # still uses one entry: the product over an empty shape is 1. NumPy is
        # used so the result is a plain Python int.
        return int(onp.prod(shape))

    def reparametrized_prior_specs(data):
        for name, spec in prior_specs.items():
            if (
                data.get("parametrization") == "normal"
                and spec.distribution.from_standard_normal_transform is not None
            ):
                yield name, reparametrize_to_standard_normal(spec)
            else:
                yield name, reparametrize_to_unbounded_support(spec)

    def reparametrized_prior_specs_and_u_slices(u, data):
        i = 0
        for name, spec in reparametrized_prior_specs(data):
            shape = get_shape(spec, data)
            size = num_elements(shape)
            u_slice = u[i] if shape == () else u[i : i + size].reshape(shape)
            i += size
            yield name, spec, u_slice

    def compute_dim_u(data):
        return sum(
            num_elements(get_shape(spec, data))
            for _, spec in reparametrized_prior_specs(data)
        )

    def generate_params(u, data):
        params = {}
        for name, spec, u_slice in reparametrized_prior_specs_and_u_slices(u, data):
            if spec.transform is not None:
                params[name] = spec.transform(u_slice)
            else:
                params[name] = u_slice
        return params

    def prior_neg_log_dens(u, data):
        nld = 0
        for _, spec, u_slice in reparametrized_prior_specs_and_u_slices(u, data):
            nld += spec.distribution.neg_log_dens(u_slice)
        return nld

    def sample_from_prior(rng, data, num_sample=None):
        u_slices = []
        for _, spec in reparametrized_prior_specs(data):
            shape = get_shape(spec, data)
            if num_sample is None:
                u_slices.append(
                    np.atleast_1d(spec.distribution.sample(rng, shape).flatten())
                )
            else:
                shape = (num_sample,) + shape
                u_slices.append(
                    np.atleast_2d(spec.distribution.sample(rng, shape).reshape((num_sample, -1)))
                )

        return np.concatenate(u_slices, -1)

    return compute_dim_u, generate_params, prior_neg_log_dens, sample_from_prior

In [14]:
# Prior for one parameter: its shape (or a callable of the data returning it), its
# distribution, and an optional transform applied to the latent value. The default
# distribution is a single normal(0, 1) instance created once at definition time.
PriorSpecification = namedtuple(
    typename="PriorSpecification",
    field_names=["shape", "distribution", "transform"],
    defaults=[(), normal(0, 1), None],
)

In [15]:
# Priors for μ (mean level), σ (state noise scale) and ϕ (autoregressive
# coefficient, kept in (-1, 1)). The order of entries fixes the layout of `u`.
prior_specifications = dict(
    [
        ("μ", PriorSpecification(distribution=normal(0, 1))),
        ("σ", PriorSpecification(distribution=half_normal(1))),
        ("ϕ", PriorSpecification(distribution=uniform(-1, 1))),
    ]
)

In [16]:
(
    compute_dim_u,
    generate_params,
    prior_neg_log_dens,
    sample_from_prior,
) = set_up_prior(prior_specs=prior_specifications)

In [17]:
def generate_x_0(params, v_0, data):
    """Initial state from the AR(1) stationary distribution, driven by normal v_0."""
    mu, sigma, phi = params["μ"], params["σ"], params["ϕ"]
    stationary_scale = sigma / (1 - phi ** 2) ** 0.5
    return mu + stationary_scale * v_0


def forward_func(params, v, x, data):
    """AR(1) transition: next state from current state x and normal innovation v."""
    mu = params["μ"]
    return mu + params["ϕ"] * (x - mu) + params["σ"] * v


def observation_func(params, n, x, data):
    """Observation y = exp(x / 2) * n for state x and observation noise n."""
    volatility = np.exp(x / 2)
    return volatility * n


def inverse_observation_func(params, n, y, data):
    """State x such that exp(x / 2) * n == y."""
    ratio = y / n
    return 2 * np.log(ratio)


In [18]:
def construct_state_space_model_generators(
    generate_params, generate_x_0, forward_func, observation_func
):
    """Construct functions to generate obs. and state sequences for state space models.
    Args:
        generate_params (Callable[[ArrayLike, Dict], Dict]): Function which generates a
            dictionary of model parameters given a 1D array of unbounded global latent
            variables and data dictionary.
        generate_x_0 (Callable[[Dict, ArrayLike, Dict], ArrayLike]): Function which
            generates the initial latent state given a dictionary of model parameters,
            an array of unbounded local latent variables and a data dictionary.
        forward_func (Callable[[Dict, ArrayLike, ArrayLike, Dict], ArrayLike]): Function
            which generates the next state in the latent state sequence, given a
            dictionary of model parameters, an array of unbounded local latent
            variables, the current latent state and a data dictionary.
        observation_func (Callable[[Dict, ArrayLike, ArrayLike, Dict], ArrayLike]):
            Function which generates the observation of a latent state, given a
            dictionary of model parameters, an array of unbounded local latent
            (observation noise) variables, the current latent state and a data
            dictionary.
    Returns:
        generate_from_model (
                Callable[[ArrayLike, ArrayLike, Dict], Tuple[Dict, ArrayLike]]):
            Function which given two array arguments and a data dictionary, the first
            array corresponding to all unbounded global latent variables and the second
            corresponding to all unbounded local latent variables, returns a dictionary
            of model parameters and an array corresponding to the generated latent state
            sequence.
        generate_y (
                Callable[[ArrayLike, ArrayLike, ArrayLike, Dict], ArrayLike]):
            Function which given three arrays and a data dictionary, the first array
            corresponding to all unbounded global latent variables, the second
            corresponding to all unbounded local latent variables and the third
            corresponding to all unbounded observation noise variables, returns an array
            corresponding to all observed variables.
    """

    def generate_from_model(u, v, data):
        params = generate_params(u, data)
        x_0 = generate_x_0(params, v[0], data)

        def step(x, v_t):
            x_next = forward_func(params, v_t, x, data)
            # Carry the new state forward and also emit it as the scan output.
            return x_next, x_next

        _, x_rest = jax.lax.scan(step, x_0, v[1:])
        return params, np.concatenate((np.asarray(x_0)[None], x_rest))

    def generate_y(u, v, n, data):
        params, x = generate_from_model(u, v, data)
        # Map over time steps only; params and data are closed over so they are
        # not batched (and non-array entries of data never reach vmap).
        return jax.vmap(lambda n_t, x_t: observation_func(params, n_t, x_t, data))(
            n, x
        )

    return generate_from_model, generate_y


In [19]:
generate_from_model, generate_y = construct_state_space_model_generators(
    generate_params, generate_x_0, forward_func, observation_func
)

In [20]:
def extended_prior_neg_log_dens(q, data):
    """Negative log density of the extended prior over q = (u, v, n).

    u is scored by the prior; v (state noise) and n (observation noise) are
    standard normal, with one entry each per observation in data["y_obs"].
    """
    dim_u = compute_dim_u(data)
    end_v = dim_u + data["y_obs"].shape[0]
    u = q[:dim_u]
    v = q[dim_u:end_v]
    n = q[end_v:]
    prior_term = prior_neg_log_dens(u, data)
    return prior_term + (v ** 2).sum() / 2 + (n ** 2).sum() / 2


In [21]:
def posterior_neg_log_dens(q, data):
    """Negative log posterior density (up to a constant) of q = (u, v).

    The states x are generated from (u, v); the likelihood term corresponds to
    y_obs[t] | x[t] ~ N(0, exp(x[t])).
    """
    dim_u = compute_dim_u(data)
    u = q[:dim_u]
    v = q[dim_u:]
    _, x = generate_from_model(u, v, data)
    prior_term = prior_neg_log_dens(u, data)
    standardized_y = data["y_obs"] / np.exp(x / 2)
    likelihood_term = 0.5 * (standardized_y ** 2).sum() + (x / 2).sum()
    return prior_term + (v ** 2).sum() / 2 + likelihood_term

In [22]:
def constr_split(u, v, n, y, data):
    """Constraint vector and implied states for the split (u, v, n) parametrization.

    The states implied by the observations are x = 2 log(y / n). The constraint
    compares them with the states generated by the AR(1) model from (u, v).

    Returns:
        Tuple of the constraint vector (length len(y)) and x.
    """
    params = generate_params(u, data)
    mu, sigma, phi = params["μ"], params["σ"], params["ϕ"]
    x = 2 * np.log(y / n)
    initial = (mu + (sigma / (1 - phi ** 2) ** 0.5) * v[0] - x[0])[None]
    transitions = mu + phi * (x[:-1] - mu) + sigma * v[1:] - x[1:]
    return np.concatenate((initial, transitions)), x

In [23]:
def jacob_constr_split_blocks(u, v, n, y, data):
    """Hand-derived blocks of the constraint Jacobian plus the constraint itself.

    The constraint c is the same vector as the first output of `constr_split`:
        c[0] = μ + σ / sqrt(1 - ϕ²) v[0] - x[0]
        c[t] = μ + ϕ (x[t-1] - μ) + σ v[t] - x[t],   with x = 2 log(y / n).

    Returns:
        ((dc_du, dc_dv, dc_dn, dx_dy), c) where
            dc_du: array stacked along axis 1 with columns for μ, σ, ϕ (in that
                order), each of length len(y).
            dc_dv: diagonal of ∂c/∂v (∂c/∂v is diagonal).
            dc_dn: tuple (diagonal, sub-diagonal) of the bidiagonal ∂c/∂n.
            dx_dy: elementwise derivative 2 / y of x with respect to y.
            c: constraint vector.

    NOTE(review): `jax.jvp` with a vector of ones yields, for each parameter, the
    sum of its partial derivatives over all entries of u. This equals the per-entry
    derivative only if each parameter is a scalar depending on a single entry of u
    (as for the μ, σ, ϕ prior above) — confirm before reusing with other priors.
    """
    dim_u = compute_dim_u(data)
    dim_y = y.shape[0]
    params, dparams_du = jax.jvp(
        lambda u_: generate_params(u_, data), (u,), (np.ones(dim_u),)
    )
    x = 2 * np.log(y / n)
    dx_dy = 2 / y
    # Shared subexpressions of the initial-state constraint c[0].
    one_minus_ϕ_sq = 1 - params["ϕ"] ** 2
    sqrt_one_minus_ϕ_sq = one_minus_ϕ_sq ** 0.5
    v_0_over_sqrt_one_minus_ϕ_sq = v[0] / sqrt_one_minus_ϕ_sq
    x_minus_μ = x[:-1] - params["μ"]
    dc_du = np.stack(
        (
            # ∂c/∂μ: 1 for c[0], (1 - ϕ) for the transition constraints.
            np.concatenate(
                (
                    dparams_du["μ"][None],
                    (1 - params["ϕ"]) * dparams_du["μ"] * np.ones(dim_y - 1),
                )
            ),
            # ∂c/∂σ: v[0] / sqrt(1 - ϕ²) for c[0], v[t] otherwise.
            np.concatenate(
                (
                    dparams_du["σ"][None] * v_0_over_sqrt_one_minus_ϕ_sq,
                    dparams_du["σ"] * v[1:],
                )
            ),
            # ∂c/∂ϕ: ϕ σ v[0] (1 - ϕ²)^(-3/2) for c[0], (x[t-1] - μ) otherwise.
            np.concatenate(
                (
                    dparams_du["ϕ"][None]
                    * params["ϕ"]
                    * params["σ"]
                    * v_0_over_sqrt_one_minus_ϕ_sq
                    / one_minus_ϕ_sq,
                    x_minus_μ * dparams_du["ϕ"],
                )
            ),
        ),
        1,
    )
    # Diagonal of ∂c/∂v.
    dc_dv = np.concatenate(
        [params["σ"][None] / sqrt_one_minus_ϕ_sq, params["σ"] * np.ones(dim_y - 1)]
    )
    # ∂c/∂n is bidiagonal: ∂c[t]/∂n[t] = 2 / n[t] (via -x[t]) and
    # ∂c[t]/∂n[t-1] = -2 ϕ / n[t-1] (via ϕ x[t-1]).
    dc_dn = 2 / n, -2 * params["ϕ"] / n[:-1]
    c = np.concatenate(
        (
            (params["μ"] + params["σ"] * v_0_over_sqrt_one_minus_ϕ_sq - x[0])[None],
            params["μ"] + params["ϕ"] * x_minus_μ + params["σ"] * v[1:] - x[1:],
        )
    )
    return (dc_du, dc_dv, dc_dn, dx_dy), c


In [24]:
def sample_initial_states(rng, data, num_chain=4, algorithm="chmc"):
    """Sample initial states from prior.

    Each state is u (from the prior) followed by v (standard normal, one entry per
    observation). For "chmc" the observation noise n is also appended, chosen as
    y_obs / exp(x / 2) so that the generated observations match y_obs exactly.

    Returns:
        List of `num_chain` flat NumPy arrays.
    """
    dim_y = data["y_obs"].shape[0]

    def draw_one():
        u = sample_from_prior(rng, data)
        v = rng.standard_normal(dim_y)
        if algorithm != "chmc":
            return onp.concatenate((u, v))
        _, x = generate_from_model(u, v, data)
        n = data["y_obs"] / onp.exp(x / 2)
        return onp.concatenate((u, v, onp.asarray(n)))

    return [draw_one() for _ in range(num_chain)]


In [25]:
def set_up_argparser_with_standard_arguments(description):
    """Build an argument parser with the options shared by the experiments.

    Args:
        description (str): Description shown in the parser's help message.
    Returns:
        argparse.ArgumentParser: Parser with output/data directory, algorithm,
        parametrization, chain, step-size, metric and projection-solver options.
    """
    parser = argparse.ArgumentParser(description=description)
    # (flag, add_argument keyword arguments); list order fixes the --help layout.
    standard_arguments = [
        (
            "--output-root-dir",
            dict(
                default="results",
                help="Root directory to make experiment output subdirectory in",
            ),
        ),
        (
            "--data-dir",
            dict(default="data", help="Directory containing dataset files"),
        ),
        (
            "--algorithm",
            dict(
                default="chmc",
                choices=("chmc", "hmc"),
                help="Which algorithm to perform inference with, from: chmc, hmc",
            ),
        ),
        (
            "--prior-parametrization",
            dict(
                choices=("unbounded", "normal"),
                default="unbounded",
                help=(
                    "Parameterization to use for prior distribution. Default is to define all "
                    "parameters as transforms of unbounded variables. Alternatively parameters "
                    "may be expressed as transforms of standard normal variates where possible."
                ),
            ),
        ),
        (
            "--seed",
            dict(type=int, default=202101, help="Seed for random number generator"),
        ),
        (
            "--num-chain",
            dict(type=int, default=4, help="Number of independent chains to sample"),
        ),
        (
            "--num-warm-up-iter",
            dict(
                type=int,
                default=1000,
                help="Number of chain iterations in adaptive warm-up sampling stage",
            ),
        ),
        (
            "--num-main-iter",
            dict(
                type=int,
                default=2500,
                help="Number of chain iterations in main sampling stage",
            ),
        ),
        (
            "--max-tree-depth",
            dict(
                type=int,
                default=10,
                help="Maximum depth of binary trajectory tree in each dynamic HMC iteration",
            ),
        ),
        (
            "--step-size-adaptation-target",
            dict(
                type=float,
                default=0.8,
                help="Target acceptance statistic for step size adaptation",
            ),
        ),
        (
            "--step-size-reg-coefficient",
            dict(
                type=float,
                default=0.05,
                help="Regularisation coefficient for step size adaptation",
            ),
        ),
        (
            "--metric-type",
            dict(
                choices=("diagonal", "dense"),
                default="diagonal",
                help=(
                    "Metric type to adaptively tune during warm-up stage when using HMC "
                    "algorithm. If 'diagonal' a diagonal metric matrix representation is used "
                    "with diagonal entries set to reciprocals of estimates of the marginal "
                    "posterior variances. If 'dense' a dense metric matrix representation is "
                    "used corresponding to the inverse of an estimate of the posterior "
                    "covariance matrix."
                ),
            ),
        ),
        (
            "--projection-solver",
            dict(
                choices=("newton", "quasi-newton", "newton-line-search"),
                default="newton",
                help=(
                    "Iterative method to solve projection onto manifold when using CHMC "
                    "algorithm."
                ),
            ),
        ),
        (
            "--projection-solver-max-iters",
            dict(
                type=int,
                default=50,
                help="Maximum number of iterations to try in projection solver",
            ),
        ),
        (
            "--projection-solver-warm-up-constraint-tol",
            dict(
                type=float,
                default=1e-6,
                help="Warm-up stage tolerance for constraint function norm in projection solver",
            ),
        ),
        (
            "--projection-solver-warm-up-position-tol",
            dict(
                type=float,
                default=1e-5,
                help="Warm-up stage tolerance for change in position norm in projection solver",
            ),
        ),
        (
            "--projection-solver-main-constraint-tol",
            dict(
                type=float,
                default=1e-9,
                help="Main stage tolerance for constraint function norm in projection solver",
            ),
        ),
        (
            "--projection-solver-main-position-tol",
            dict(
                type=float,
                default=1e-8,
                help="Main stage tolerance for change in position norm in projection solver",
            ),
        ),
    ]
    for flag, options in standard_arguments:
        parser.add_argument(flag, **options)
    return parser

In [51]:
def maximum_norm(vct):
    """Return the L-infinity norm (largest absolute entry) of an array."""
    absolute_values = abs(vct)
    return absolute_values.max()

In [26]:
def add_ssm_specific_args(parser):
    """Add flags choosing manual or automatic constraint/Jacobian functions.

    The two flags are mutually exclusive and both write to
    `use_manual_constraint_and_jacobian`, which defaults to True.

    Args:
        parser (argparse.ArgumentParser): Parser to extend (modified in place).
    Returns:
        argparse.ArgumentParser: The same parser, to allow chaining.
    """
    group = parser.add_mutually_exclusive_group(required=False)
    group.add_argument(
        "--use-manual-constraint-and-jacobian",
        dest="use_manual_constraint_and_jacobian",
        action="store_true",
        help=(
            "Use manually specified split constraint and Jacobian functions rather "
            "than automatically generated functions."
        ),
    )
    group.add_argument(
        "--use-auto-constraint-and-jacobian",
        dest="use_manual_constraint_and_jacobian",
        action="store_false",
        help=(
            "Use automatically generated split constraint and Jacobian functions rather"
            " than manually specified functions."
        ),
    )
    parser.set_defaults(use_manual_constraint_and_jacobian=True)
    return parser

In [36]:
def construct_trace_func(generate_params, data, dim_u, dim_v=None):
    """Construct a function extracting quantities to record from a chain state.

    Args:
        generate_params (Callable[[ArrayLike, Dict], Dict]): Maps unbounded global
            latent variables `u` and a data dictionary to model parameters.
        data (Dict): Data dictionary bound to `generate_params`.
        dim_u (int): Number of leading entries of `state.pos` holding `u`.
        dim_v (int, optional): If given, the next `dim_v` entries of `state.pos`
            are also recorded under `"v"`.
    Returns:
        Callable: `trace_func(state)` returning the generated parameters together
        with `"u"` (and `"v"` when `dim_v` is given).
    """

    jitted_generate_params = jax.jit(partial(generate_params, data=data))

    if dim_v is None:

        def trace_func(state):
            u = state.pos[:dim_u]
            params = jitted_generate_params(u)
            return {**params, "u": u}

    else:

        def trace_func(state):
            u, v = state.pos[:dim_u], state.pos[dim_u : dim_u + dim_v]
            params = jitted_generate_params(u)
            return {**params, "u": u, "v": v}

    return trace_func

In [41]:
def get_ssm_constrained_system_class_and_kwargs(
    use_manual_constraint_and_jacobian,
    generate_params,
    generate_x_0,
    forward_func,
    inverse_observation_func,
    constr_split,
    jacob_constr_split_blocks,
):
    """Choose the constrained system class and its constructor keyword arguments.

    Args:
        use_manual_constraint_and_jacobian (bool): If true, select the system
            built from manually specified split constraint / Jacobian functions.
            Otherwise select the automatically generated variant, built from
            the generative model functions.
        generate_params, generate_x_0, forward_func, inverse_observation_func:
            Generative model functions. Only forwarded when
            `use_manual_constraint_and_jacobian` is false.
        constr_split, jacob_constr_split_blocks: Split constraint function and
            its Jacobian block function. Only forwarded when
            `use_manual_constraint_and_jacobian` is true.

    Returns:
        Tuple[type, dict]: The system class and the keyword arguments to pass
            when constructing it.
    """
    if not use_manual_constraint_and_jacobian:
        auto_kwargs = {
            "generate_params": generate_params,
            "generate_x_0": generate_x_0,
            "forward_func": forward_func,
            "inverse_observation_func": inverse_observation_func,
        }
        return AutoPartiallyInvertibleStateSpaceModelSystem, auto_kwargs
    manual_kwargs = {
        "constr_split": constr_split,
        "jacob_constr_split_blocks": jacob_constr_split_blocks,
    }
    return PartiallyInvertibleStateSpaceModelSystem, manual_kwargs

In [48]:
def cache_in_state(*depends_on):
    """Decorator factory memoizing a system method's result inside a chain state.

    The decorated method must have the signature `method(self, state)`. Its
    return value is stored in `state._cache` under a key derived from the
    system and the method. The first time a key is seen for a state, it is
    registered in `state._dependencies` under each name in `depends_on`. This
    lets the state clear the cached value when any of those attributes changes.
    While a cached value other than `None` exists, it is returned and the
    method is not called again.

    For `ChainState` instances whose `_call_counts` attribute is not `None`, the
    counter for the key is incremented every time the wrapped method is actually
    evaluated, i.e. on each cache miss.

    Args:
       *depends_on: Names of the state attributes the method's result depends
           on, for example 'pos' or 'mom'.

    Returns:
        Callable: Decorator to apply to the system method.
    """

    def cache_in_state_decorator(method):
        @wraps(method)
        def wrapper(self, state):
            cache_key = _cache_key_func(self, method)
            key_is_new = cache_key not in state._cache
            if key_is_new:
                # Register the key against each dependency so the cached entry is
                # invalidated when any of those state attributes is updated.
                for dependency_name in depends_on:
                    state._dependencies[dependency_name].add(cache_key)
            if key_is_new or state._cache[cache_key] is None:
                state._cache[cache_key] = method(self, state)
                if state._call_counts is not None:
                    state._call_counts[cache_key] += 1
            return state._cache[cache_key]

        return wrapper

    return cache_in_state_decorator

def cache_in_state_with_aux(depends_on, auxiliary_outputs):
    """Memoizing decorator for system methods that may also return auxiliary outputs.

    It behaves like `cache_in_state`: the value returned by a `method(self, state)`
    is cached in the `ChainState` so it is not recomputed while the state
    variables it depends on are unchanged. In addition, the wrapped method may
    return extra auxiliary values alongside its primary output. Each auxiliary
    value corresponds to another system method decorated with `cache_in_state` or
    `cache_in_state_with_aux`. It is written into that method's cache entry, so a
    later call to that method can reuse it instead of recomputing it.

    A typical use is automatic differentiation. The primal value is computed
    alongside the derivative (forward mode), or in the forward pass before the
    derivative (reverse mode). Caching it means a subsequent call for the primal
    reuses it, provided the relevant state variables have not changed in between.

    For `ChainState` instances whose `_call_counts` attribute is not `None`, the
    counter for the primary key is incremented every time the wrapped method is
    actually evaluated, i.e. on each cache miss.

    Args:
        depends_on (str or Tuple[str]): Name or names of the state variables the
            returned value(s) depend on, for example 'pos' or 'mom'. The cache
            entries for the primary and auxiliary outputs are registered against
            these names, so they are cleared when any of them changes.
        auxiliary_outputs (str or Tuple[str]): Name or names of the auxiliary
            outputs the method may additionally return. When auxiliary outputs
            are returned, the method must return a tuple. Its first entry is the
            primary output and the remaining entries are the auxiliary outputs,
            in the order given here. If the primary output is itself a tuple, it
            must be wrapped in another tuple even when no auxiliary outputs are
            returned.

    Returns:
        Callable: Decorator to apply to the system method.
    """
    depends_on = (depends_on,) if isinstance(depends_on, str) else depends_on
    auxiliary_outputs = (
        (auxiliary_outputs,)
        if isinstance(auxiliary_outputs, str)
        else auxiliary_outputs
    )

    def cache_in_state_with_aux_decorator(method):
        @wraps(method)
        def wrapper(self, state):
            primary_key = _cache_key_func(self, method)
            all_keys = [primary_key]
            all_keys.extend(_cache_key_func(self, name) for name in auxiliary_outputs)
            # Only keys not yet in the cache need registering against dependencies.
            for cache_key in all_keys:
                if cache_key in state._cache:
                    continue
                for dependency_name in depends_on:
                    state._dependencies[dependency_name].add(cache_key)
            if primary_key not in state._cache or state._cache[primary_key] is None:
                outputs = method(self, state)
                if not isinstance(outputs, tuple):
                    state._cache[primary_key] = outputs
                else:
                    # A tuple is (primary, *auxiliary); zip stops at the shorter.
                    for cache_key, value in zip(all_keys, outputs):
                        state._cache[cache_key] = value
                if state._call_counts is not None:
                    state._call_counts[primary_key] += 1
            return state._cache[primary_key]

        return wrapper

    return cache_in_state_with_aux_decorator


In [49]:
class System(ABC):
    r"""Base class for Hamiltonian systems.

    The Hamiltonian function \(h\) is assumed to have the general form
    \[ h(q, p) = h_1(q) + h_2(q, p) \]
    where \(q\) and \(p\) are the position and momentum variables respectively,
    and \(h_1\) and \(h_2\) are Hamiltonian component functions. The exact
    Hamiltonian flow for the \(h_1\) component can always be computed as it
    depends only on the position variable. However, depending on the form of
    \(h_2\), the corresponding exact Hamiltonian flow may or may not be
    simulable.

    By default \(h_1\) is assumed to correspond to the negative logarithm of an
    unnormalized density on the position variables with respect to the Lebesgue
    measure, with the corresponding distribution on the position space being
    the target distribution it is wished to draw approximate samples from.
    """

    def __init__(self, neg_log_dens, grad_neg_log_dens=None):
        """
        Args:
            neg_log_dens (Callable[[array], float]): Function which given a
                position array returns the negative logarithm of an
                unnormalized probability density on the position space with
                respect to the Lebesgue measure, with the corresponding
                distribution on the position space being the target
                distribution it is wished to draw approximate samples from.
            grad_neg_log_dens (
                    None or Callable[[array], array or Tuple[array, float]]):
                Function which given a position array returns the derivative of
                `neg_log_dens` with respect to the position array argument.
                Optionally the function may instead return a 2-tuple of values
                with the first being the array corresponding to the derivative
                and the second being the value of the `neg_log_dens` evaluated
                at the passed position array. If `None` is passed (the default)
                an automatic differentiation fallback will be used to attempt
                to construct the derivative of `neg_log_dens` automatically.
        """
        self._neg_log_dens = neg_log_dens
        self._grad_neg_log_dens = autodiff_fallback(
            grad_neg_log_dens, neg_log_dens, "grad_and_value", "grad_neg_log_dens"
        )

    @cache_in_state("pos")
    def neg_log_dens(self, state):
        """Negative logarithm of unnormalized density of target distribution.

        The value is cached in `state` until `state.pos` changes.

        Args:
            state (mici.states.ChainState): State to compute value at.
        Returns:
            float: Value of computed negative log density.
        """
        return self._neg_log_dens(state.pos)

    @cache_in_state_with_aux("pos", "neg_log_dens")
    def grad_neg_log_dens(self, state):
        """Derivative of negative log density with respect to position.

        If the underlying gradient function returns a `(gradient, value)` tuple,
        the value is also cached as the result of `neg_log_dens(state)`.

        Args:
            state (mici.states.ChainState): State to compute value at.
        Returns:
            array: Value of `neg_log_dens(state)` derivative with respect to
                `state.pos`.
        """
        return self._grad_neg_log_dens(state.pos)

    def h1(self, state):
        """Hamiltonian component depending only on position.
        Args:
            state (mici.states.ChainState): State to compute value at.
        Returns:
            float: Value of `h1` Hamiltonian component.
        """
        return self.neg_log_dens(state)

    def dh1_dpos(self, state):
        """Derivative of `h1` Hamiltonian component with respect to position.
        Args:
            state (mici.states.ChainState): State to compute value at.
        Returns:
            array: Value of computed `h1` derivative.
        """
        return self.grad_neg_log_dens(state)

    def h1_flow(self, state, dt):
        """Apply exact flow map corresponding to `h1` Hamiltonian component.

        `state` argument is modified in place: `state.mom` is decremented by
        `dt * dh1_dpos(state)`.

        Args:
            state (mici.states.ChainState): State to start flow at.
            dt (float): Time interval to simulate flow for.
        """
        state.mom -= dt * self.dh1_dpos(state)

    @abstractmethod
    def h2(self, state):
        """Hamiltonian component depending on momentum and optionally position.
        Args:
            state (mici.states.ChainState): State to compute value at.
        Returns:
            float: Value of `h2` Hamiltonian component.
        """

    @abstractmethod
    def dh2_dmom(self, state):
        """Derivative of `h2` Hamiltonian component with respect to momentum.
        Args:
            state (mici.states.ChainState): State to compute value at.
        Returns:
            array: Value of `h2(state)` derivative with respect to `state.mom`.
        """

    def h(self, state):
        """Hamiltonian function for system.
        Args:
            state (mici.states.ChainState): State to compute value at.
        Returns:
            float: Value of Hamiltonian.
        """
        return self.h1(state) + self.h2(state)

    def dh_dpos(self, state):
        """Derivative of Hamiltonian with respect to position.

        `dh2_dpos` is optional: subclasses whose `h2` depends on position define
        it, and it is added to `dh1_dpos` when present.

        Args:
            state (mici.states.ChainState): State to compute value at.
        Returns:
            array: Value of `h(state)` derivative with respect to `state.pos`.
        """
        if hasattr(self, "dh2_dpos"):
            return self.dh1_dpos(state) + self.dh2_dpos(state)
        else:
            return self.dh1_dpos(state)

    def dh_dmom(self, state):
        """Derivative of Hamiltonian with respect to momentum.
        Args:
            state (mici.states.ChainState): State to compute value at.
        Returns:
            array: Value of `h(state)` derivative with respect to `state.mom`.
        """
        return self.dh2_dmom(state)

    @abstractmethod
    def sample_momentum(self, state, rng):
        """
        Sample a momentum from its conditional distribution given a position.
        Args:
            state (mici.states.ChainState): State defining position to
               condition on.
            rng: Random number generator to draw the sample with (presumably a
               `numpy.random.Generator`; confirm against subclasses).
        Returns:
            mom (array): Sampled momentum.
        """

In [52]:
class _AbstractDifferentiableGenerativeModelSystem(System):
    """Base class for constrained systems for differentiable generative models.

    Compared to the in-built Mici constrained system classes, this class uses
    'matrix-free' implementations of the operations involving the constraint
    function Jacobian and Gram matrix, which lets it exploit any structure they
    have. It also JIT compiles the iterative solvers used for projection steps
    to improve performance.

    The Hamiltonian components are
        h1(q) = neg_log_dens(q) + log_det_sqrt_gram(q)
        h2(q, p) = 0.5 * p @ p
    """

    def __init__(
        self,
        neg_log_dens,
        grad_neg_log_dens,
        constr,
        jacob_constr_blocks,
        decompose_gram,
        lmult_by_jacob_constr,
        rmult_by_jacob_constr,
        lmult_by_inv_gram,
        lmult_by_inv_jacob_product,
        log_det_sqrt_gram,
        lmult_by_pinv_jacob_constr=None,
        normal_space_component=None,
    ):
        """
        Args:
            neg_log_dens (Callable[[array], float]): Negative log density on
                the position space, passed on to `System.__init__`.
            grad_neg_log_dens (None or Callable): Gradient of `neg_log_dens`,
                optionally returning a `(gradient, value)` tuple. Passed on to
                `System.__init__`.
            constr (Callable[[array], array]): Constraint function `constr(q)`.
            jacob_constr_blocks (Callable): Called as `jacob_constr_blocks(q)`.
                Returns `(jac_blocks, c)`, where `jac_blocks` is a tuple of
                arrays representing the constraint Jacobian at `q` and `c` is
                the constraint value at `q`.
            decompose_gram (Callable): Called as `decompose_gram(*jac_blocks)`.
                Returns a tuple `gram_components` consumed by
                `lmult_by_inv_gram`.
            lmult_by_jacob_constr (Callable): Called as
                `lmult_by_jacob_constr(*jac_blocks, vct)`. Left-multiplies a
                position-space vector by the constraint Jacobian.
            rmult_by_jacob_constr (Callable): Called as
                `rmult_by_jacob_constr(*jac_blocks, vct)`. Right-multiplies a
                constraint-space vector by the constraint Jacobian, giving a
                position-space vector.
            lmult_by_inv_gram (Callable): Called as
                `lmult_by_inv_gram(*jac_blocks, *gram_components, vct)`. Only
                used to build the default `lmult_by_pinv_jacob_constr`.
            lmult_by_inv_jacob_product (Callable): Called as
                `lmult_by_inv_jacob_product(*jac_blocks_l, *jac_blocks_r, vct)`.
                Presumably it left-multiplies by the inverse of
                (Jacobian at one point) @ (Jacobian at another point)^T.
                TODO confirm against implementing subclasses.
            log_det_sqrt_gram (Callable): Called as `log_det_sqrt_gram(q)`.
                Returns `(val, (c, jac_blocks, gram_components))`, where `val`
                is the term added to `h1`. It is differentiated with `has_aux`.
            lmult_by_pinv_jacob_constr (None or Callable): Called as
                `lmult_by_pinv_jacob_constr(jac_blocks, gram_components, vct)`,
                with the tuples passed whole rather than unpacked. Defaults to
                `rmult_by_jacob_constr` applied to `lmult_by_inv_gram(vct)`.
            normal_space_component (None or Callable): Called as
                `normal_space_component(jac_blocks, gram_components, vct)`.
                Defaults to `lmult_by_pinv_jacob_constr` applied to
                `lmult_by_jacob_constr(*jac_blocks, vct)`.
        """

        if lmult_by_pinv_jacob_constr is None:

            # Default: right-multiply the inverse-Gram solve by the Jacobian.
            def lmult_by_pinv_jacob_constr(jacob_constr_blocks, gram_components, vct):
                return rmult_by_jacob_constr(
                    *jacob_constr_blocks,
                    lmult_by_inv_gram(*jacob_constr_blocks, *gram_components, vct,),
                )

        if normal_space_component is None:

            # Default: pseudo-inverse of the Jacobian applied to (Jacobian @ vct).
            def normal_space_component(jacob_constr_blocks, gram_components, vct):
                return lmult_by_pinv_jacob_constr(
                    jacob_constr_blocks,
                    gram_components,
                    lmult_by_jacob_constr(*jacob_constr_blocks, vct),
                )

        def quasi_newton_projection(
            q,
            jacob_constr_blocks_prev,
            gram_components_prev,
            dt,
            constraint_tol,
            position_tol,
            divergence_tol,
            max_iters,
            norm,
        ):
            """Quasi-Newton method to solve projection onto manifold.

            The Jacobian blocks and Gram components stay fixed at the values
            passed in (`*_prev`) rather than being recomputed at each iterate.

            Returns `(q, mu / dt, i, norm_delta_q, error)`, where:
              - `mu` is the total displacement subtracted from the initial `q`;
              - `i` is the number of iterations run;
              - `norm_delta_q` is the norm of the last step;
              - `error` is the constraint norm at the start of the last
                iteration.
            """

            # Loop carry: (q, mu, i, norm_delta_q, error).
            def body_func(val):
                q, mu, i, _, _ = val
                c = constr(q)
                # Error is measured at q before this iteration's update.
                error = norm(c)
                delta_mu = lmult_by_pinv_jacob_constr(
                    jacob_constr_blocks_prev, gram_components_prev, c
                )
                mu += delta_mu
                q -= delta_mu
                i += 1
                return q, mu, i, norm(delta_mu), error

            # Keep iterating unless max_iters is reached, the error diverged
            # (above divergence_tol or NaN), or both tolerances are met.
            def cond_func(val):
                _, _, i, norm_delta_q, error = val
                diverged = np.logical_or(error > divergence_tol, np.isnan(error))
                converged = np.logical_and(
                    error < constraint_tol, norm_delta_q < position_tol
                )
                return np.logical_not(
                    np.logical_or((i >= max_iters), np.logical_or(diverged, converged))
                )

            # Initial norm_delta_q=inf and error=-1.0 make neither the
            # divergence nor the convergence test pass before the first
            # iteration, for the usual non-negative tolerance values.
            q, mu, i, norm_delta_q, error = lax.while_loop(
                cond_func, body_func, (q, np.zeros_like(q), 0, np.inf, -1.0)
            )
            return q, mu / dt, i, norm_delta_q, error

        def newton_projection(
            q,
            jacob_constr_blocks_prev,
            dt,
            constraint_tol,
            position_tol,
            divergence_tol,
            max_iters,
            norm,
        ):
            """Newton method to solve projection onto manifold.

            The Jacobian blocks are recomputed at each iterate `q`. They are
            combined with the previous state's blocks in
            `lmult_by_inv_jacob_product`, and the step is mapped back to
            position space with `rmult_by_jacob_constr` using the previous
            state's blocks. The return value has the same layout as for
            `quasi_newton_projection`.
            """

            # Loop carry: (q, mu, i, norm_delta_q, error).
            def body_func(val):
                q, mu, i, _, _ = val
                jac_blocks, c = jacob_constr_blocks(q)
                # Error is measured at q before this iteration's update.
                error = norm(c)
                delta_mu = rmult_by_jacob_constr(
                    *jacob_constr_blocks_prev,
                    lmult_by_inv_jacob_product(
                        *jac_blocks, *jacob_constr_blocks_prev, c
                    ),
                )
                mu += delta_mu
                q -= delta_mu
                i += 1
                return q, mu, i, norm(delta_mu), error

            # Same termination criteria as in quasi_newton_projection.
            def cond_func(val):
                _, _, i, norm_delta_q, error = val
                diverged = np.logical_or(error > divergence_tol, np.isnan(error))
                converged = np.logical_and(
                    error < constraint_tol, norm_delta_q < position_tol
                )
                return np.logical_not(
                    np.logical_or((i >= max_iters), np.logical_or(diverged, converged))
                )

            q, mu, i, norm_delta_q, error = lax.while_loop(
                cond_func, body_func, (q, np.zeros_like(q), 0, np.inf, -1.0)
            )
            return q, mu / dt, i, norm_delta_q, error

        def newton_projection_with_line_search(
            q,
            jacob_constr_blocks_prev,
            dt,
            constraint_tol,
            position_tol,
            divergence_tol,
            max_iters,
            max_line_search_iters,
            norm,
        ):
            """Newton method with line search to solve projection onto manifold.

            Each outer iteration computes the same Newton direction as
            `newton_projection`. It then backtracks, halving the step size,
            until the constraint norm does not exceed its value at the current
            iterate or `max_line_search_iters` trial steps have been evaluated.
            The last trial step is accepted even if it did not reduce the error.

            Returns
            `(new_q, (q - new_q) / dt, i, norm_delta_q, error, num_constr_calls)`.
            Unlike the other solvers, `error` here is the constraint norm at the
            accepted (post-step) position. `num_constr_calls` counts the
            `constr` evaluations made in the line searches.
            """

            # Outer loop carry: (q, norm_delta_q, error, i, num_constr_calls).
            def body_func(loop_state):
                q, _, _, i, num_constr_calls = loop_state
                jac_blocks, c = jacob_constr_blocks(q)
                error = norm(c)
                search_direction = -rmult_by_jacob_constr(
                    *jacob_constr_blocks_prev,
                    lmult_by_inv_jacob_product(
                        *jac_blocks, *jacob_constr_blocks_prev, c
                    ),
                )

                # Inner carry: (next step size to try, trial q, trial error,
                # number of trials evaluated so far).
                def inner_body_func(inner_loop_state):
                    step_size, _, _, j = inner_loop_state
                    new_q = q + step_size * search_direction
                    new_error = norm(constr(new_q))
                    return step_size * 0.5, new_q, new_error, j + 1

                def inner_cond_func(inner_loop_state):
                    _, _, new_error, j = inner_loop_state
                    return np.logical_and(j < max_line_search_iters, new_error > error)


                # Evaluate the full (unit) step outside the loop so the
                # while_loop carry starts with concrete array values instead of
                # the None placeholders.
                (step_size, new_q, new_error, j) = inner_body_func((1., None, None, 0))
                (_, new_q, new_error, j) = lax.while_loop(
                    inner_cond_func, inner_body_func, (step_size, new_q, new_error, j)
                )
                return new_q, norm(new_q - q), new_error, i + 1, num_constr_calls + j

            def cond_func(loop_state):
                _, norm_delta_q, error, i, _ = loop_state
                diverged = np.logical_or(error > divergence_tol, np.isnan(error))
                converged = np.logical_and(
                    error < constraint_tol, norm_delta_q < position_tol
                )
                return np.logical_not(
                    np.logical_or((i >= max_iters), np.logical_or(diverged, converged))
                )

            new_q, norm_delta_q, error, i, num_constr_calls = lax.while_loop(
                cond_func, body_func, (q, np.inf, -1., 0, 0)
            )
            return new_q, (q - new_q) / dt, i, norm_delta_q, error, num_constr_calls

        # JIT compile everything. log_det_sqrt_gram returns (value, aux), so
        # has_aux=True differentiates only the value. For the solvers, `norm`
        # is a Python callable and is therefore marked as a static argument.
        self._constr = jax.jit(constr)
        self._jacob_constr_blocks = jax.jit(jacob_constr_blocks)
        self._decompose_gram = jax.jit(decompose_gram)
        self._log_det_sqrt_gram = jax.jit(log_det_sqrt_gram)
        self._val_and_grad_log_det_sqrt_gram = jax.jit(
            jax.value_and_grad(log_det_sqrt_gram, has_aux=True)
        )
        self._lmult_by_jacob_constr = jax.jit(lmult_by_jacob_constr)
        self._rmult_by_jacob_constr = jax.jit(rmult_by_jacob_constr)
        self._lmult_by_pinv_jacob_constr = jax.jit(lmult_by_pinv_jacob_constr)
        self._lmult_by_inv_jacob_product = jax.jit(lmult_by_inv_jacob_product)
        self._normal_space_component = jax.jit(normal_space_component)
        self._quasi_newton_projection = jax.jit(
            quasi_newton_projection, static_argnames="norm"
        )
        self._newton_projection = jax.jit(newton_projection, static_argnames="norm")
        self._newton_projection_with_line_search = jax.jit(
            newton_projection_with_line_search, static_argnames="norm"
        )
        super().__init__(neg_log_dens=neg_log_dens, grad_neg_log_dens=grad_neg_log_dens)

    def precompile_jax_functions(self, q, solver_norm=maximum_norm):
        """Call each stored function once on `q` so JIT compilation happens up front.

        The tolerances and iteration limits passed to the projection solvers are
        placeholder values. Only the argument shapes and dtypes, and the static
        `solver_norm`, affect compilation. All results are discarded.

        Args:
            q (array): Position array with the shape used during sampling.
            solver_norm (Callable): Norm function to compile the projection
                solvers for. Compiled solvers are keyed on this static argument.
        """
        self._neg_log_dens(q)
        self._grad_neg_log_dens(q)
        self._constr(q)
        jac_blocks, c = self._jacob_constr_blocks(q)
        gram_components = self._decompose_gram(*jac_blocks)
        self._log_det_sqrt_gram(q)
        self._val_and_grad_log_det_sqrt_gram(q)
        self._lmult_by_jacob_constr(*jac_blocks, q)
        self._rmult_by_jacob_constr(*jac_blocks, c)
        self._lmult_by_pinv_jacob_constr(jac_blocks, gram_components, c)
        self._lmult_by_inv_jacob_product(*jac_blocks, *jac_blocks, c)
        self._normal_space_component(jac_blocks, gram_components, q)
        self._quasi_newton_projection(
            q, jac_blocks, gram_components, 1., 0.1, 0.1, 1., 10, solver_norm
        )
        self._newton_projection(q, jac_blocks, 1., 0.1, 0.1, 1., 10, solver_norm)
        self._newton_projection_with_line_search(
            q, jac_blocks, 1., 0.1, 0.1, 1., 10, 10, solver_norm)

    @cache_in_state("pos")
    def constr(self, state):
        """Constraint function value at `state.pos`, cached in `state`."""
        return convert_to_numpy_pytree(self._constr(state.pos))

    @cache_in_state_with_aux("pos", "constr")
    def jacob_constr_blocks(self, state):
        """Constraint Jacobian blocks at `state.pos`.

        The jitted function returns `(jac_blocks, c)`. The constraint value `c`
        is also cached as the result of `constr`, and only the blocks are
        returned.
        NOTE(review): the aux caching relies on `convert_to_numpy_pytree`
        preserving the outer tuple type — confirm.
        """
        return convert_to_numpy_pytree(self._jacob_constr_blocks(state.pos))

    @cache_in_state("pos")
    def gram_components(self, state):
        """Gram matrix decomposition components at `state.pos`.

        Computed from the (possibly cached) Jacobian blocks.
        """
        return convert_to_numpy_pytree(
            self._decompose_gram(*self.jacob_constr_blocks(state))
        )

    @cache_in_state_with_aux(
        "pos", ("constr", "jacob_constr_blocks", "gram_components"),
    )
    def log_det_sqrt_gram(self, state):
        """Log-determinant term included in `h1`, evaluated at `state.pos`.

        The constraint value, Jacobian blocks and Gram components computed
        alongside it are also cached.
        """
        val, (constr, jacob_constr_blocks, gram_components) = self._log_det_sqrt_gram(
            state.pos
        )
        return convert_to_numpy_pytree(
            (val, constr, jacob_constr_blocks, gram_components)
        )

    @cache_in_state_with_aux(
        "pos",
        ("log_det_sqrt_gram", "constr", "jacob_constr_blocks", "gram_components"),
    )
    def grad_log_det_sqrt_gram(self, state):
        """Gradient of `log_det_sqrt_gram` with respect to `state.pos`.

        The value itself and the auxiliary constraint, Jacobian and Gram
        outputs are also cached.
        """
        # value_and_grad with has_aux returns ((value, aux), grad).
        (
            (val, (constr, jacob_constr_blocks, gram_components)),
            grad,
        ) = self._val_and_grad_log_det_sqrt_gram(state.pos)
        return convert_to_numpy_pytree(
            (grad, val, constr, jacob_constr_blocks, gram_components)
        )

    def h1(self, state):
        """Position-dependent Hamiltonian component.

        Equal to the negative log density plus `log_det_sqrt_gram`.
        """
        return self.neg_log_dens(state) + self.log_det_sqrt_gram(state)

    def dh1_dpos(self, state):
        """Derivative of `h1` with respect to position."""
        return self.grad_neg_log_dens(state) + self.grad_log_det_sqrt_gram(state)

    def h2(self, state):
        """Kinetic energy `0.5 * mom @ mom` (identity metric)."""
        return 0.5 * state.mom @ state.mom

    def dh2_dmom(self, state):
        """Derivative of `h2` with respect to momentum: `state.mom` itself."""
        return state.mom

    def dh2_dpos(self, state):
        """`h2` does not depend on position: returns zeros shaped like `state.pos`."""
        return 0 * state.pos

    def dh_dpos(self, state):
        """Derivative of the Hamiltonian with respect to position.

        Equal to `dh1_dpos`, since `dh2_dpos` is identically zero.
        """
        return self.dh1_dpos(state)

    def h2_flow(self, state, dt):
        """Exact `h2` flow: advance `state.pos` by `dt * mom` (modifies `state`)."""
        state.pos += dt * self.dh2_dmom(state)

    def dh2_flow_dmom(self, dt):
        """Return `(dt * I, I)`.

        Presumably these are the derivatives of the `h2` flow map's position and
        momentum outputs with respect to momentum. Confirm against the Mici
        integrator that calls this.
        """
        return (dt * IdentityMatrix(), IdentityMatrix())

    def normal_space_component(self, state, vct):
        """Component of `vct` in the normal space of the constraint manifold at `state`.

        Returned as a NumPy array.
        """
        return onp.asarray(
            self._normal_space_component(
                self.jacob_constr_blocks(state), self.gram_components(state), vct
            )
        )

    def lmult_by_jacob_constr(self, state, vct):
        """Left-multiply `vct` by the constraint Jacobian at `state`."""
        return onp.asarray(
            self._lmult_by_jacob_constr(*self.jacob_constr_blocks(state), vct)
        )

    def rmult_by_jacob_constr(self, state, vct):
        """Right-multiply `vct` by the constraint Jacobian at `state`."""
        return onp.asarray(
            self._rmult_by_jacob_constr(*self.jacob_constr_blocks(state), vct)
        )

    def lmult_by_pinv_jacob_constr(self, state, vct):
        """Left-multiply `vct` by the pseudo-inverse of the constraint Jacobian at `state`."""
        return onp.asarray(
            self._lmult_by_pinv_jacob_constr(
                self.jacob_constr_blocks(state), self.gram_components(state), vct
            )
        )

    def lmult_by_inv_jacob_product(self, state_1, state_2, vct):
        """Apply the inverse Jacobian product formed from the Jacobians at `state_1` and `state_2`.

        The Jacobian at `state_1` supplies the left factor and the Jacobian at
        `state_2` the right factor; see `lmult_by_inv_jacob_product` in
        `__init__`.
        """
        return onp.asarray(
            self._lmult_by_inv_jacob_product(
                *self.jacob_constr_blocks(state_1),
                *self.jacob_constr_blocks(state_2),
                vct,
            )
        )

    def project_onto_cotangent_space(self, mom, state):
        """Remove the normal-space component from `mom`.

        `-=` updates `mom` in place when it is a mutable array, and the same
        object is returned.
        """
        mom -= self.normal_space_component(state, mom)
        return mom

    def sample_momentum(self, state, rng):
        """Draw a standard normal momentum shaped like `state.pos`.

        The draw is then projected onto the cotangent space at `state`. `rng`
        must provide `standard_normal`, presumably a `numpy.random.Generator`.
        """
        mom = rng.standard_normal(state.pos.shape)
        mom = self.project_onto_cotangent_space(mom, state)
        return mom

In [58]:
def standard_normal_neg_log_dens(q):
    """Negative log density, up to an additive constant, of a standard normal vector.

    Args:
        q: Array treated as a vector of independent standard normal variables.

    Returns:
        Half the sum of the squared entries of `q`.
    """
    sum_of_squares = onp.sum(q ** 2)
    return 0.5 * sum_of_squares

def standard_normal_grad_neg_log_dens(q):
    """Gradient and value of `standard_normal_neg_log_dens` at `q`.

    Returns:
        Tuple `(q, value)`: `q` itself is the gradient of half the squared norm,
        and `value` is the negative log density at `q`.
    """
    return q, standard_normal_neg_log_dens(q)

In [59]:
class PartiallyInvertibleStateSpaceModelSystem(
    _AbstractDifferentiableGenerativeModelSystem
):
    """System class for scalar state space models with invertible observation functions.

    Generative model is assumed to be of the form
        params = generate_params(u, data)
        x[0] = generate_x_0(params, v[0], data)
        y[0] = observation_func(params, n[0], x[0], data)
        for t in range(1, dim_y):
            x[t] = forward_func(params, v[t], x[t - 1], data)
            y[t] = observation_func(params, n[t], x[t], data)
    If `inverse_observation_func` corresponds to the inverse of `observation_func` in
    its third argument,
        observation_func(
            params, n, inverse_observation_func(params, n, y, data), data) == y
    then we can define a 'split' constraint function for the generative model as follows
        def constr_split(u, v, n, y, data):
            params = generate_params(u, data)
            x = [
                inverse_observation_func(params, n[t], y[t], data)
                for t in range(dim_y)
            ]
            return array(
                [generate_x_0(params, v[0], data) - x[0]] +
                [
                    forward_func(params, v[t], x[t-1], data) - x[t]
                    for t in range(1, dim_y)
                ]
            ), x
    where `y` is a `(dim_y,)` shaped 1D array of observed variables, `u` is a `(dim_u,)`
    shaped 1D array of global latent variables, `v` is a `(dim_y,)` shaped 1D array of
    local latent variables, `n` is a `(dim_y,)` shaped 1D array of observation noise
    variables and `data` is a dictionary of fixed values / data used by model.

    The position vector is the concatenation `q = (u, v, n)`. The constraint Jacobian
    is represented by the blocks `(dc_du, dc_dv, dc_dn, dx_dy)`:
      - `dc_du` is a dense `(dim_y, dim_u)` array;
      - `dc_dv` is the diagonal of the (diagonal) derivative with respect to `v`;
      - `dc_dn` is a pair `(diagonal, sub-diagonal)` of the lower-bidiagonal
        derivative with respect to `n`;
      - `dx_dy` is the diagonal of the derivative of `x` with respect to `y`.
    """

    def __init__(
        self,
        constr_split,
        jacob_constr_split_blocks,
        data,
        dim_u,
        neg_log_dens=standard_normal_neg_log_dens,
        grad_neg_log_dens=standard_normal_grad_neg_log_dens,
    ):
        """
        Args:
            constr_split (Callable): Split constraint function, called as
                `constr_split(u, v, n, y, data)`. Returns a pair `(c, x)`: the
                constraint values and the latent states recovered from `y`.
            jacob_constr_split_blocks (None or Callable): Called with the same
                arguments as `constr_split`. Returns
                `((dc_du, dc_dv, dc_dn, dx_dy), c)`. If `None`, a default based
                on JAX forward-mode differentiation of `constr_split` is used.
            data (dict): Fixed model data. Must contain the observations under
                the key `"y_obs"`, whose leading dimension gives `dim_y`.
            dim_u (int): Dimension of the global latent variables `u`.
            neg_log_dens (Callable): Negative log density on the position
                space. Defaults to the standard normal.
            grad_neg_log_dens (Callable): Gradient (and value) of
                `neg_log_dens`.
        """
        dim_y = data["y_obs"].shape[0]

        if jacob_constr_split_blocks is None:

            # JAX-based default. It relies on the structure of the split
            # constraint:
            #   - c[t] depends on v only through v[t];
            #   - c[t] depends on n only through n[t] and n[t - 1];
            #   - x[t] depends on y only through y[t].
            # This lets the diagonal / bidiagonal blocks be extracted with a few
            # Jacobian-vector products rather than full Jacobians.
            def jacob_constr_split_blocks(u, v, n, y, data):
                # Dense derivative with respect to the global variables u.
                dc_du = jax.jacfwd(lambda u_: constr_split(u_, v, n, y, data)[0])(u)
                one_vct = np.ones(dim_y)
                # Alternating +1 / -1 vector used to separate the two diagonals
                # of the bidiagonal dc/dn below.
                alt_vct = (-1.0) ** np.arange(dim_y)
                # Diagonal of dx/dy via a JVP with a vector of ones.
                _, dx_dy = jax.jvp(
                    lambda y_: constr_split(u, v, n, y_, data)[1], (y,), (one_vct,)
                )
                # Constraint value and diagonal of dc/dv via a JVP with ones.
                c, dc_dv = jax.jvp(
                    lambda v_: constr_split(u, v_, n, y, data)[0], (v,), (one_vct,)
                )
                # With D0 / D1 the diagonal / sub-diagonal of dc/dn, and D1[0] = 0:
                #   dc_dn_1[t] = D0[t] + D1[t]
                #   dc_dn_a[t] * alt_vct[t] = D0[t] - D1[t]
                _, dc_dn_1 = jax.jvp(
                    lambda n_: constr_split(u, v, n_, y, data)[0], (n,), (one_vct,)
                )
                _, dc_dn_a = jax.jvp(
                    lambda n_: constr_split(u, v, n_, y, data)[0], (n,), (alt_vct,)
                )
                # Solve the two equations above for D0 and D1[1:].
                dc_dn = (
                    (dc_dn_1 + dc_dn_a * alt_vct) / 2,
                    (dc_dn_1[1:] - dc_dn_a[1:] * alt_vct[1:]) / 2,
                )
                return (dc_du, dc_dv, dc_dn, dx_dy), c

        # Position q is split as (u, v, n): u has length dim_u, v has length
        # dim_y, and n is the remainder.
        def constr(q):
            u, v, n = np.split(q, (dim_u, dim_u + dim_y))
            return constr_split(u, v, n, data["y_obs"], data)[0]

        def jacob_constr_blocks(q):
            u, v, n = np.split(q, (dim_u, dim_u + dim_y))
            return jacob_constr_split_blocks(u, v, n, data["y_obs"], data)

        # J @ vct: dense u block, diagonal v block, and lower-bidiagonal n block.
        # Padding shifts the sub-diagonal contribution down by one entry.
        def lmult_by_jacob_constr(dc_du, dc_dv, dc_dn, dx_dy, vct):
            vct_u, vct_v, vct_n = (
                vct[:dim_u],
                vct[dim_u : dim_u + dim_y],
                vct[dim_u + dim_y :],
            )
            return (
                dc_du @ vct_u
                + dc_dv * vct_v
                + dc_dn[0] * vct_n
                + np.pad(dc_dn[1] * vct_n[:-1], (1, 0))
            )

        # vct @ J (i.e. J^T @ vct), concatenated over the u, v and n blocks.
        def rmult_by_jacob_constr(dc_du, dc_dv, dc_dn, dx_dy, vct):
            return np.concatenate(
                (
                    vct @ dc_du,
                    vct * dc_dv,
                    vct * dc_dn[0] + np.pad(dc_dn[1] * vct[1:], (0, 1)),
                )
            )

        # The Gram matrix is J J^T = T + dc_du dc_du^T, where T is symmetric
        # tridiagonal with diagonal b and off-diagonals a. Returns a, b and the
        # lower Cholesky factor of the Woodbury capacitance matrix
        # I + dc_du^T T^{-1} dc_du.
        def decompose_gram(dc_du, dc_dv, dc_dn, dx_dy):
            a = dc_dn[0][:-1] * dc_dn[1]
            b = dc_dv ** 2 + dc_dn[0] ** 2 + np.pad(dc_dn[1] ** 2, (1, 0))
            cap_mtx = np.eye(dim_u) + dc_du.T @ jax.vmap(
                tridiagonal_solve, (None, None, None, 1), 1
            )(a, b, a, dc_du)
            chol_cap_mtx = cholesky(cap_mtx)
            return (a, b, chol_cap_mtx)

        # Woodbury identity:
        #   (T + U U^T)^{-1} v = T^{-1} (v - U C^{-1} U^T T^{-1} v)
        # with U = dc_du and C the capacitance matrix.
        def lmult_by_inv_gram(dc_du, dc_dv, dc_dn, dx_dy, a, b, chol_cap_mtx, vct):
            return tridiagonal_solve(
                a,
                b,
                a,
                vct
                - dc_du
                @ sla.cho_solve(
                    (chol_cap_mtx, True), dc_du.T @ tridiagonal_solve(a, b, a, vct)
                ),
            )

        # Solves (J_l J_r^T) w = vct, where J_l and J_r are given by the *_l and
        # *_r blocks, using the same tridiagonal-plus-low-rank Woodbury
        # structure. a is the sub-diagonal and c the super-diagonal of the
        # tridiagonal part, assuming tridiagonal_solve takes
        # (lower, diagonal, upper, rhs) — confirm. The capacitance matrix is not
        # symmetric in general, so a general solve is used instead of Cholesky.
        def lmult_by_inv_jacob_product(
            dc_du_l, dc_dv_l, dc_dn_l, dx_dy_l, dc_du_r, dc_dv_r, dc_dn_r, dx_dy_r, vct
        ):
            a = dc_dn_l[1] * dc_dn_r[0][:-1]
            b = (
                dc_dv_l * dc_dv_r
                + dc_dn_l[0] * dc_dn_r[0]
                + np.pad(dc_dn_l[1] * dc_dn_r[1], (1, 0))
            )
            c = dc_dn_l[0][:-1] * dc_dn_r[1]
            cap_mtx = np.eye(dim_u) + dc_du_r.T @ jax.vmap(
                tridiagonal_solve, (None, None, None, 1), 1
            )(a, b, c, dc_du_l)
            return tridiagonal_solve(
                a,
                b,
                c,
                vct
                - dc_du_l
                @ np.linalg.solve(cap_mtx, dc_du_r.T @ tridiagonal_solve(a, b, c, vct)),
            )

        # By the matrix determinant lemma,
        #   log det(J J^T) / 2 = log det(C) / 2 + log det(T) / 2,
        # and log det(C) / 2 is the sum of the logs of its Cholesky diagonal.
        # The -sum(log|dx_dy|) term presumably corrects for the constraint being
        # expressed in terms of x rather than y — confirm. The constraint value,
        # Jacobian blocks and Gram components are returned as auxiliary outputs.
        def log_det_sqrt_gram(q):
            (dc_du, dc_dv, dc_dn, dx_dy), c = jacob_constr_blocks(q)
            (a, b, chol_cap_mtx,) = decompose_gram(dc_du, dc_dv, dc_dn, dx_dy)
            return (
                np.log(chol_cap_mtx.diagonal()).sum()
                + tridiagonal_pos_def_log_det(a, b) / 2
                - np.log(abs(dx_dy)).sum(),
                (c, (dc_du, dc_dv, dc_dn, dx_dy), (a, b, chol_cap_mtx,)),
            )

        super().__init__(
            neg_log_dens=neg_log_dens,
            grad_neg_log_dens=grad_neg_log_dens,
            constr=constr,
            jacob_constr_blocks=jacob_constr_blocks,
            decompose_gram=decompose_gram,
            lmult_by_jacob_constr=lmult_by_jacob_constr,
            rmult_by_jacob_constr=rmult_by_jacob_constr,
            lmult_by_inv_gram=lmult_by_inv_gram,
            lmult_by_inv_jacob_product=lmult_by_inv_jacob_product,
            log_det_sqrt_gram=log_det_sqrt_gram,
        )


In [66]:
import mici

In [67]:
def run_experiment(
    args,
    data,
    rng,
    experiment_name,
    var_names,
    var_trace_func,
    sample_initial_states,
    constrained_system_class,
    constrained_system_kwargs,
    euclidean_system_class=mici.systems.EuclideanMetricSystem,
    euclidean_system_kwargs=None,
    posterior_neg_log_dens=None,
    extended_prior_neg_log_dens=None,
    dir_prefix=None,
    precompile_jax_functions=True,
):
    """Set up Mici objects, sample chains and summarise the results of an experiment.

    Args:
        args: Parsed command-line arguments. This function reads
            `args.algorithm`, `args.prior_parametrization` and `args.num_chain`,
            and forwards `args` to the set-up / sampling helpers.
        data: Dictionary of model data. NOTE: mutated in place, because key
            `"parametrization"` is set to `args.prior_parametrization`. The same
            dict object may also be referenced elsewhere (e.g. as
            `constrained_system_kwargs["data"]`), so it is deliberately not copied.
        rng: Random number generator forwarded to `set_up_mici_objects` and
            `sample_initial_states`.
        experiment_name: Model name used in messages; also forwarded to
            `set_up_output_directory`.
        var_names: Names of traced model variables to include in the summary.
        var_trace_func: Trace function for model variables, applied per state.
        sample_initial_states: Callable `(rng, data, num_chain, algorithm)`
            returning an iterable of initial positions.
        constrained_system_class: Mici system class for constrained algorithms.
        constrained_system_kwargs: Keyword arguments for that class. Mutated in
            place when `extended_prior_neg_log_dens` is given.
        euclidean_system_class: Mici system class for unconstrained algorithms.
        euclidean_system_kwargs: Keyword arguments for that class. Replaced when
            `posterior_neg_log_dens` is given.
        posterior_neg_log_dens: If not None, used (with `data` bound) to build
            the Euclidean system's `neg_log_dens` / `grad_neg_log_dens`.
        extended_prior_neg_log_dens: If not None, used (with `data` bound) to
            build the constrained system's `neg_log_dens` / `grad_neg_log_dens`.
        dir_prefix: Forwarded to `set_up_output_directory`.
        precompile_jax_functions: If True, compile system JAX functions before
            sampling so compile time is excluded from chain run times.

    Returns:
        Tuple `(final_states, traces, stats, summary_dict)`.

    Raises:
        ValueError: If both `posterior_neg_log_dens` and
            `euclidean_system_kwargs` are None.
    """
    # `np` may be bound to `jax.numpy` in this notebook, and it provides no
    # `errstate`/`seterr`, so use NumPy proper for floating point error control.
    import numpy as onp

    print(
        f"Running experiment with {experiment_name} model using "
        f"{args.algorithm.upper()} algorithm for inference"
    )

    # Set up output directory and logger

    output_dir = set_up_output_directory(args, experiment_name, dir_prefix)
    set_up_logger(output_dir)

    print(f"Results will be saved to {output_dir}")

    # Add parametrization flag to data dictionary (in place, see docstring)

    data["parametrization"] = args.prior_parametrization

    # Set up Mici objects
    if posterior_neg_log_dens is not None:
        (
            neg_log_dens_hmc,
            grad_neg_log_dens_hmc,
        ) = mlift.construct_mici_system_neg_log_dens_functions(
            partial(posterior_neg_log_dens, data=data)
        )
        euclidean_system_kwargs = {
            "neg_log_dens": neg_log_dens_hmc,
            "grad_neg_log_dens": grad_neg_log_dens_hmc,
        }
    elif euclidean_system_kwargs is None:
        raise ValueError(
            "One of either posterior_neg_log_dens or euclidean_system_kwargs must "
            "not be None "
        )

    if extended_prior_neg_log_dens is not None:
        (
            neg_log_dens_chmc,
            grad_neg_log_dens_chmc,
        ) = mlift.construct_mici_system_neg_log_dens_functions(
            partial(extended_prior_neg_log_dens, data=data)
        )
        constrained_system_kwargs["neg_log_dens"] = neg_log_dens_chmc
        constrained_system_kwargs["grad_neg_log_dens"] = grad_neg_log_dens_chmc

    system, integrator, sampler, adapters, monitor_stats = set_up_mici_objects(
        args,
        rng,
        constrained_system_class,
        constrained_system_kwargs,
        euclidean_system_class,
        euclidean_system_kwargs,
    )

    def hamiltonian_and_call_count_trace_func(state):
        """Trace per-method call counts, the Hamiltonian and the neg. log density."""
        # Keys of `state._call_counts` are tuples whose first element is a
        # dotted method name; only the final component is kept.
        call_counts = {
            name.split(".")[-1] + "_calls": val
            for (name, _), val in state._call_counts.items()
        }
        return {
            **call_counts,
            "hamiltonian": system.h(state),
            "neg_log_dens": system.neg_log_dens(state),
        }

    trace_funcs = [var_trace_func, hamiltonian_and_call_count_trace_func]

    # Initialise chain states

    print("Sampling initial states ...")
    init_positions = sample_initial_states(rng, data, args.num_chain, args.algorithm)
    init_states = [
        mici.states.ChainState(pos=q, mom=None, dir=1, _call_counts={})
        for q in init_positions
    ]

    # Precompile JAX functions to avoid compilation time appearing in chain run times

    if precompile_jax_functions:
        print("Pre-compiling JAX functions ...")
        compile_time = precompile_system_jax_functions(init_states[0].pos, args, system)
        print(f"Total compile time: {compile_time:.0f} seconds")
    else:
        compile_time = 0

    # Sample chains while ignoring NumPy floating point overflow warnings, which
    # would otherwise be printed over the progress bars. The context manager
    # restores the previous error settings afterwards instead of changing them
    # globally for the rest of the session.

    with onp.errstate(over="ignore"):
        final_states, traces, stats, sampling_time = sample_chains(
            args, sampler, init_states, trace_funcs, adapters, output_dir, monitor_stats
        )

    print(f"Integrator step size: {integrator.step_size:.2g}")
    print(f"Total sampling time: {sampling_time:.0f} seconds")

    # Compute and display summary of time spent on different operation

    print("Computing chain operation times ...")
    operation_times = compute_operation_times(system, final_states)
    print(
        f"Total operation time: {operation_times['total_operation_time']:.3g} seconds"
    )

    # Compute, display and save summary of statistics of traced chain variables

    summary_vars = var_names + ["neg_log_dens", "hamiltonian"]
    summary, summary_dict = compute_and_save_summary(
        output_dir,
        summary_vars,
        traces,
        total_compile_time=compile_time,
        total_sampling_time=sampling_time,
        final_integrator_step_size=integrator.step_size,
        **operation_times,
    )

    print(summary)

    return final_states, traces, stats, summary_dict

In [27]:
 parser = set_up_argparser_with_standard_arguments(
        "Run stochastic-volatility model simulated data experiment"
    )

In [28]:
# Presumably registers state-space-model-specific options on `parser`
# (helper defined elsewhere — TODO confirm).
add_ssm_specific_args(parser)

In [29]:
# Load every array from the simulated-data archive into a plain dict. The
# `with` block closes the underlying .npz file explicitly. dict() reads all
# arrays eagerly, so they remain valid after the file is closed.
with np.load("sv-simulated-data.npz") as npz_file:
    data = dict(npz_file)

In [30]:
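# Optional manual override of the prior parametrization, e.g. data['parametrization'] = 'normal'
# (run_experiment overwrites this key with args.prior_parametrization).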

In [31]:
# Dimension of the `u` latent block, computed from the model data by a helper
# defined elsewhere (presumably the global/parameter variables — TODO confirm).
dim_u = compute_dim_u(data)

In [33]:
# Length of the leading axis of the observed-data array `y_obs`.
dim_y = data["y_obs"].shape[0]

In [35]:
# NumPy Generator with a fixed seed so that runs are reproducible.
rng = onp.random.default_rng(seed=1234)

In [39]:
# Trace function for model variables (helper defined elsewhere); `dim_v` is set
# to the number of entries along the leading axis of y_obs.
trace_func = construct_trace_func(generate_params, data, dim_u, dim_v=dim_y)

In [60]:
# Build the constrained-system class and its keyword arguments for the
# state-space model. NOTE(review): the meaning of the leading positional `True`
# is defined by get_ssm_constrained_system_class_and_kwargs (not visible here);
# consider passing it by keyword.
constrained_system_class, constrained_system_kwargs = (
    get_ssm_constrained_system_class_and_kwargs(
        True,
        generate_params,
        generate_x_0,
        forward_func,
        inverse_observation_func,
        constr_split,
        jacob_constr_split_blocks,
    )
)

In [63]:
# Add the data dict (the same object, not a copy) and dim_u to the
# constrained-system keyword arguments, mutating the dict in place.
constrained_system_kwargs["data"] = data
constrained_system_kwargs["dim_u"] = dim_u