# MCMC using backward filtering, forward guiding for shapes

Parameter inference for trees of landmark-represented shapes, with Gaussian transitions along the edges and observations at the leaf nodes. Please refer to the notebook [mcmc_Gaussian_BFFG.ipynb](mcmc_Gaussian_BFFG.ipynb) for a simpler version with $\mathbb R^2$ data. In the present version, the node covariance is constant throughout the tree, similarly to [mcmc_Gaussian_BFFG.ipynb](mcmc_Gaussian_BFFG.ipynb). Shape-dependent node covariance will follow in a later version.

The conditioning and upwards/downwards message passing and fusing operations follow the backward filtering, forward guiding approach of Frank van der Meulen, Moritz Schauer et al., see https://arxiv.org/abs/2010.03509 and https://arxiv.org/abs/2203.04155 . The latter reference provides an accessible introduction to the scheme and the notation used in this example.

In [None]:
# enable IPython's autoreload extension so edits to imported modules (e.g. hyperiax) are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2

In [None]:
from jax.random import PRNGKey, split
import hyperiax
import jax
from jax import numpy as jnp
from hyperiax.execution import LevelwiseTreeExecutor
from hyperiax.models import DownLambda, UpDownLambda
from hyperiax.models.functional import sum_fuse_children
from hyperiax.tree.updaters import update_noise_inplace
from hyperiax.mcmc import ParameterStore, VarianceParameter
from hyperiax.mcmc.metropolis_hastings import metropolis_hastings
from hyperiax.mcmc.plotting import trace_plots

import matplotlib.pyplot as plt
from tqdm import tqdm


In [None]:
# Fixed random seed for reproducible runs. To draw a fresh seed instead, use
# `import os; seed = int(os.urandom(5).hex(), 16)`.
seed = 423
key = PRNGKey(seed)  # root PRNG key of the notebook

# Shape-related setup

In [None]:
import matplotlib.pyplot as plt 

# plotting
def plot_shape(q):
    """Scatter-plot the landmarks of a flattened shape vector ``q``.

    ``q`` is reshaped to ``(-1, d)`` landmark coordinates and the first two
    coordinates of every landmark are drawn with equal axis scaling.
    """
    landmarks = q.reshape((-1, d))
    xs, ys = landmarks[:, 0], landmarks[:, 1]
    plt.plot(xs, ys, '.')
    plt.axis('equal')

In [None]:
# Define a circular landmark shape and plot it.
d = 2   # dimension of the embedding space (usually 2)
n = 18  # number of landmarks

# n equally spaced angles (endpoint excluded so no landmark is duplicated)
phis = jnp.linspace(0, 2*jnp.pi, n, endpoint=False)
# flattened shape vector [x_0, y_0, x_1, y_1, ...] of length n*d on the unit circle
root = jnp.stack((jnp.cos(phis), jnp.sin(phis)), axis=1).reshape(-1)

# plot
plot_shape(root)

In [None]:
# diffusion and covariance specification for shape processes
def kQ12(x, params):
    """Gaussian kernel evaluated on landmark difference vectors.

    ``x`` holds difference vectors along axis 2 (so it is expected to be rank
    3); returns ``k_alpha * exp(-|x|^2 / (2 * k_sigma))`` over the two leading
    axes.
    """
    sq_norm = jnp.sum(jnp.square(x), 2)
    return params['k_alpha'] * jnp.exp(-.5 / params['k_sigma'] * sq_norm)


def kQ12_q(q1, q2, params):
    """Evaluate ``kQ12`` on all landmark pairs of two flattened configurations.

    Entry ``[i, j]`` of the result is ``kQ12(x_i - y_j)``, where ``x_i`` and
    ``y_j`` are the landmarks of ``q1`` and ``q2`` after reshaping to ``(-1, d)``.
    """
    lhs = q1.reshape((-1, d))[:, jnp.newaxis, :]
    rhs = q2.reshape((-1, d))[jnp.newaxis, :, :]
    return kQ12(lhs - rhs, params)

# evaluate k on one landmark configuration against itself; each landmark pair yields a d x d block
# the (i,j)th block of the result is kQ12(x_i - x_j)*eye(d)
def Q12(q,params):
    """Block kernel matrix of a flattened landmark configuration ``q``.

    Returns the ``(n*d, n*d)`` matrix whose ``(i, j)`` block of size
    ``d x d`` is ``kQ12(x_i - x_j) * eye(d)``. Rows and columns are ordered
    like the flattened vector ``[x_0, y_0, x_1, y_1, ...]``.

    The identity factor uses the embedding dimension ``d`` rather than a
    hard-coded 2, so the block size always matches ``q.reshape((-1, d))``.
    """
    A = jnp.einsum('ij,kl->ikjl',kQ12_q(q,q,params),jnp.eye(d))
    return A.reshape((A.shape[0]*A.shape[1],A.shape[2]*A.shape[3]))

def sigma(params):
    """Diffusion matrix: the block kernel matrix ``Q12`` evaluated at the template shape ``root``."""
    return Q12(root, params)


def a(params):
    """Covariance matrix ``sigma(params) @ sigma(params).T``."""
    s = sigma(params)
    return jnp.einsum('ij,kj->ik', s, s)

# Gaussian tree, constant node covariance

First, we initialize the tree. We set the root to the shape defined above.

In [None]:
# Build the tree with `symmetric_tree(1, 20)` and show it.
# NOTE(review): the meaning of the two builder arguments is defined in hyperiax; confirm there.
tree = hyperiax.tree.builders.symmetric_tree(1, 20)
print('Tree:', tree)

# every node gets the same edge length
edge_length = 1.
tree['edge_length'] = edge_length

d = 2  # data (embedding) dimension, same as in the shape setup

# the root carries the template shape
tree.root['value'] = root

We then define parameters for the Gaussian transition kernel.

In [None]:
# Ground-truth parameters used to simulate the data: kernel amplitude and width
# of the edge transitions, and the leaf observation-noise variance.
params = ParameterStore(dict(
    k_alpha=VarianceParameter(.1),    # kernel amplitude, governs global tree variance
    k_sigma=VarianceParameter(.25),   # kernel width, for Gaussian kernels this is proportional to the variance
    obs_var=VarianceParameter(1e-3),  # observation noise variance
))

Now follow the down transitions. First, we define the unconditional transitions, which are just Gaussian samples. The covariance is the edge length times the square $\sigma\sigma^T$ of the matrix of pairwise kernel evaluations `sigma(params)`; the kernel amplitude `k_alpha` enters through the kernel itself.

In [None]:
# Unconditional down transition. The inner function handles a single node and is
# vmapped over the batch of nodes handed over by the executor (contrast with the
# fully batched version in mcmc_Gaussian_BFFG.ipynb).
@jax.jit
def down_unconditional(noise,edge_length,parent_value,params,**args):
    """Sample child values as ``parent_value + sqrt(edge_length) * sigma(params) @ noise``."""
    diffusion = sigma(params)

    def _sample_child(z, t, x_parent):
        step = jnp.sqrt(t)*jnp.einsum('ij,j->i', diffusion, z)
        return {'value': x_parent + step}

    return jax.vmap(_sample_child)(noise, edge_length, parent_value)
downmodel_unconditional = DownLambda(down_fn=down_unconditional)
down_unconditional = LevelwiseTreeExecutor(downmodel_unconditional)

We can now draw noise and perform a downwards pass. This gives values at all nodes of the tree. Note that observation noise is not added to the leaves yet.

In [None]:
# draw noise of shape (n*d,) for every node and push it down the tree
subkey, key = split(key)
noise_shape = (n * d,)
noise_tree = hyperiax.tree.initializers.initialize_noise(tree, subkey, noise_shape)
dtree = down_unconditional.down(noise_tree, params.values())



Add uncorrelated observation noise to leaves.



In [None]:
# Copy the noise-free sample and perturb every leaf with independent Gaussian
# observation noise of variance obs_var.
leaf_tree = dtree.copy()
obs_std = jnp.sqrt(params['obs_var'].value)
for leaf in leaf_tree.iter_leaves():
    key, subkey = split(key)
    perturbation = obs_std*jax.random.normal(subkey, leaf['value'].shape)
    leaf['value'] = leaf['value'] + perturbation

Plot a generated tree.

In [None]:
# Draw the tree once per landmark, using that landmark's 2D position at each node, all on one axis.
fig, ax = plt.subplots(figsize=(10, 8))
for landmark in range(n):
    dtree.plot_tree_2d(selector=lambda z: z['value'].reshape((n, d))[landmark], ax=ax)
plt.gca().set_title('Sampled tree without leaf noise')



We now define the backwards filter through the up function. The Gaussians are parametrized in the $(c,F,H)$ format, which makes the fuse operation just a sum of the results of the up operation. See https://arxiv.org/abs/2203.04155 for details.

In [None]:
# backwards filter
@jax.jit
def up(noise,edge_length,F_T,H_T,params,**args):
    """Backward-filter (up) step for a batch of nodes in (c, F, H) form.

    For each node, the message ``(F_T, H_T)`` at the node is pulled back
    through the Gaussian edge transition with covariance
    ``edge_length * a(params)``. With ``Sigma_T = H_T^{-1}`` this gives:

      * ``Sigma_0 = Sigma_T + covar`` and ``H_0 = Sigma_0^{-1}``,
      * ``F_0 = (I + H_T @ covar)^{-1} F_T`` (= ``H_0 Sigma_T F_T``),
      * ``c_0 = -log N(v_0; 0, Sigma_0)`` with ``v_0 = Sigma_0 F_0``.

    ``noise`` is unused. It is kept in the signature, presumably because the
    executor passes node fields by name.

    Returns a dict with ``c_0``, ``F_0``, ``H_0`` and the unchanged ``F_T``, ``H_T``.
    """
    unit_covar = a(params)  # edge-length-independent part of the transition covariance

    def f(noise,edge_length,F_T,H_T):
        covar = edge_length*unit_covar  # variance scales with the edge length

        Sigma_T = jnp.linalg.inv(H_T)
        # Phi_0^{-1} = I + H_T covar. This must be a matrix product (@); an
        # elementwise product would not satisfy Sigma_T @ invPhi_0 = Sigma_T + covar.
        invPhi_0 = jnp.eye(n*d)+H_T@covar
        Sigma_0 = Sigma_T+covar  # equals Sigma_T @ invPhi_0; formed directly so it stays symmetric
        H_0 = jnp.linalg.inv(Sigma_0)
        F_0 = jnp.linalg.solve(invPhi_0,F_T)  # = H_0 @ Sigma_T @ F_T
        v_0 = Sigma_0@F_0
        c_0 = -jax.scipy.stats.multivariate_normal.logpdf(v_0,jnp.zeros(n*d),Sigma_0)

        return {'c_0': c_0, 'F_0': F_0, 'H_0': H_0, 'F_T': F_T, 'H_T': H_T}
    return jax.vmap(f)(noise,edge_length,F_T,H_T)

We initialize the tree for up by storing the observation messages $(F,H)$ at the leaves and removing the values of the inner nodes.

In [None]:
# initialize tree for up
def init_up(tree,params):
    """Prepare ``tree`` in place for the backward filter.

    Nodes that have both children and a parent (inner nodes) lose their
    ``'value'``. Every other node keeps its value and receives the observation
    message ``F_T = H_T @ value``, ``H_T = I / obs_var``. These are the leaves
    and, because it has no parent, also the root.

    NOTE(review): a leaf normalising constant ``c`` used to be computed here but
    was never stored, and ``up`` takes no ``c_T`` input. The dead computation
    was removed. Confirm that no ``c_T`` is intended at the leaves.
    """
    # the observation precision is identical for every observed node, so build it once
    H = jnp.eye(n*d)/params['obs_var'].value
    for node in tree.iter_bfs():
        if node.children and node.parent:
            del node.data['value']
        else:
            F = H@node['value']
            node.data = {**node.data, 'F_T': F, 'H_T': H}



We can now define the conditional downwards pass, i.e. the forwards guiding.

In [None]:
@jax.jit
def down_conditional(noise,edge_length,F_T,H_T,parent_value,params,**args):
    """Forward-guiding (conditional down) step for a batch of nodes.

    Combines the edge transition ``N(parent_value, edge_length * a(params))``
    with the backward message ``(F_T, H_T)`` at the node. The guided
    distribution is ``N(mu, H^{-1})`` with ``H = H_T + covar^{-1}`` and
    ``mu = H^{-1} (F_T + covar^{-1} parent_value)``. A sample is
    ``mu + L^{-T} noise``, where ``H = L L^T``.

    Returns a dict with the sampled ``'value'`` for each node.
    """
    def f(noise,edge_length,F_T,H_T,parent_value):
        x = parent_value
        var = edge_length # variance is edge length
        covar = var*a(params) # covariance matrix

        invSigma = jnp.linalg.inv(covar)
        H = H_T+invSigma
        mu = jnp.linalg.solve(H,F_T+invSigma@x)
        # H = L L^T  =>  Cov(L^{-T} z) = (L L^T)^{-1} = H^{-1}.
        # solve_triangular defaults to lower=False, which would read only the
        # diagonal of the lower factor; request the lower, transposed solve.
        L = jax.scipy.linalg.cholesky(H,lower=True)
        return {'value': mu+jax.scipy.linalg.solve_triangular(L,noise,lower=True,trans='T')}

    return jax.vmap(f)(noise,edge_length,F_T,H_T,parent_value)



We create the model and executor for the backwards filter (up) and forwards guiding (down).

In [None]:
# Model combining the backward filter (up), fusion of child messages by summation,
# and forward guiding (down), wrapped in a level-by-level executor.
updownmodel = UpDownLambda(
    up_fn=up,
    fuse_fn=sum_fuse_children(axis=0),
    down_fn=down_conditional,
)
updown = LevelwiseTreeExecutor(updownmodel)

We make an upwards pass and a downwards conditional sampling to test. Subsequently, we time the three operations (unconditional down, conditional down, and up).

In [None]:
# backwards filter and fowards guiding
utree = leaf_tree.copy()
init_up(utree,params)
utree = updown.up(utree,params.values())
utree.root['value'] = root
dtree_conditional = updown.down(utree,params.values())

# time the operations
subkey, key = split(key)
noise_tree = hyperiax.tree.initializers.initialize_noise(tree, subkey, (n*d,))
%time down_unconditional.down(noise_tree,params.values())
%time updown.up(utree,params.values())
updown.up(utree,params.values())
%time updown.up(utree,params.values())

We test the setup by sampling a number of trees and computing mean and covariance of the leaf data.

In [None]:
# do statistics on the leaf values
# (comprehension variables are named `leaf`, `sampled`, ... so they do not read like the global `n`)
leaves = jnp.array([leaf['value'] for leaf in dtree.iter_leaves()])
c = utree.root['c_0']
F = utree.root['F_0']
H = utree.root['H_0']
print("root conditional mean vs. sample mean:",jnp.linalg.solve(H,F),jnp.mean(leaves,0))
print("root conditional cov vs. sample cov:",jnp.linalg.inv(H),jnp.cov(leaves.T))

# sample statistics
K = 500 # number samples
samples = jnp.zeros((K,len(list(dtree.iter_leaves())),n*d))

# sample new noise
update_noise = lambda tree,key: update_noise_inplace(lambda node,new: new,tree,key)

# guided (conditional) samples: sampled leaf values minus the observed leaf data
for i in tqdm(range(K)):
    subkey, key = split(key)
    update_noise(utree,subkey)
    dtree = updown.down(utree,params.values())
    # collect values
    samples = samples.at[i].set(jnp.array([sampled.data['value']-observed.data['value'] for sampled,observed in zip(dtree.iter_leaves(),leaf_tree.iter_leaves())]))
print("observation noise: ")
print("mean: ",jnp.mean(samples,axis=(0,1)))
print("cov: ",jnp.cov(samples.reshape(-1,n*d).T))

# unconditional samples with observation noise added at the leaves
for i in tqdm(range(K)):
    subkey, key = split(key)
    update_noise(noise_tree,subkey)
    dtree = down_unconditional.down(noise_tree,params.values())
    # collect values
    samples = samples.at[i].set(jnp.array([leaf.data['value'] for leaf in dtree.iter_leaves()]))
    # add observation noise
    subkey,key = jax.random.split(key)
    samples = samples.at[i].set(samples[i]+jnp.sqrt(params['obs_var'].value)*jax.random.normal(subkey,samples[i].shape))
# print sample statistics for all leaves
print("leaves: ")
for i in range(samples.shape[1]):
    print("mean: ",jnp.mean(samples[:,i],0))
    print("cov: ",jnp.cov(samples[:,i].T))

# MCMC

In [None]:
import matplotlib.pyplot as plt

# Log density of the inverse-gamma distribution
def inverse_gamma_logpdf(x, alpha, beta):
    """Return ``log p(x)`` for an inverse-gamma law with shape ``alpha`` and scale ``beta``.

    ``p(x) = beta**alpha / Gamma(alpha) * x**(-alpha - 1) * exp(-beta / x)``.
    """
    log_norm = alpha * jnp.log(beta) - jax.scipy.special.gammaln(alpha)
    return log_norm - (alpha + 1) * jnp.log(x) - beta / x

# Inverse-gamma hyperparameters to visualise (shape alpha, scale beta)
alpha = 2
beta = 0.003

# Evaluation grid on (0, 0.01]; it starts at 1e-4 rather than 0 to avoid log(0)
x_values = jnp.linspace(0.0001, 0.01, 100)
y_values = inverse_gamma_logpdf(x_values, alpha, beta)

# Plot the log density
curve_label = f'Inverse Gamma Log PDF (alpha={alpha}, beta={beta})'
plt.figure(figsize=(8, 4))
plt.plot(x_values, y_values, label=curve_label)
plt.title('Inverse Gamma Log PDF')
plt.xlabel('x')
plt.ylabel('Log PDF')
plt.legend()
plt.grid(True)
plt.show()


We make two MCMC runs. First, we use the fact that the model with constant node covariance is fully Gaussian, so we can read the data likelihood directly from the results of the upwards pass. Subsequently, we also sample the state of the tree to get a likelihood approximation from the conditional downwards pass. This version is not necessary in the current model, but it points towards how inference in non-Gaussian models (e.g. non-linear diffusion processes along the edges) will look.

In [None]:
# inference for Gaussian model, likelihood from backwards filtering
def log_likelihood(state):
    """Exact data log likelihood read off the backward-filtered root.

    With ``v`` the root value and ``(c_0, F_0, H_0)`` the root message, returns
    ``-c_0 + F_0 @ v - 0.5 * v^T H_0 v``.
    """
    _, filtered_tree = state
    root_data = filtered_tree.root
    v = root_data['value']
    return -root_data['c_0'] + root_data['F_0']@v - .5*v.T@root_data['H_0']@v

def log_posterior(data,state):
    """Unnormalised log posterior: parameter log prior plus tree log likelihood (``data`` is unused here)."""
    store, _ = state
    return store.log_prior() + log_likelihood(state)

def proposal(data, state, key):
    """Propose new parameters and re-run the backward filter on the observed data.

    Returns the new ``(parameters, tree)`` state, where the tree is a fresh copy
    of ``data`` filtered upwards under the proposed parameters.
    """
    current_parameters, _ = state
    proposed = current_parameters.propose(key)

    filtered = data.copy()
    init_up(filtered, proposed)
    filtered = updown.up(filtered, proposed.values())
    return proposed, filtered

# tree values and parameters
init_params = ParameterStore({
    'k_alpha': VarianceParameter(.25), # kernel amplitude, governs global tree variance
    'k_sigma': VarianceParameter(.5), # kernel width, for Gaussian kernels this is proportional to the variance
    # observation noise variance; it is sampled here (keep_constant=False), presumably under the
    # inverse-gamma(alpha=2, beta=.003) prior plotted above.
    # NOTE(review): it may be only weakly identifiable for low-dimensional data.
    'obs_var': VarianceParameter(1e-3,alpha=2,beta=.003,keep_constant=False)
    })
print("Initial parameters: ",init_params.values())
print("data parameters: ",params.values())

# initial state
leaf_tree.root['value'] = root
utree = leaf_tree.copy()
init_up(utree,init_params)
init_state = (init_params,updown.up(utree,init_params.values()))

# Run Metropolis-Hastings. Pass the freshly split subkey, not `key`, so that later
# cells splitting `key` do not reuse the randomness consumed by the sampler.
subkey, key = split(key)
log_likelihoods,samples = metropolis_hastings(log_posterior, proposal, leaf_tree, init_state, 200, burn_in=200, rng_key=subkey, savef=lambda state: state[0])

# plot
plt.plot(log_likelihoods)
plt.xlabel("Iterations")
plt.title('Log likelihood')
trace_plots(samples)

In [None]:
# inference for Gaussian model, likelihood from forward guiding

# Crank-Nicolson noise update; the autocorrelation may depend on the node
def lambd(node):
    """Crank-Nicolson autocorrelation for ``node`` (constant here)."""
    return .9


def update_CN(tree, key):
    """Replace each node's noise by ``lambd*old + sqrt(1 - lambd**2)*new``."""
    def _cn_step(node, new):
        rho = lambd(node)
        return node['noise']*rho + jnp.sqrt((1-rho**2))*new
    return update_noise_inplace(_cn_step, tree, key)


def zero_noise(tree, key):
    """Set every node's noise to zeros of the same shape."""
    return update_noise_inplace(lambda node, new: jnp.zeros_like(node['noise']), tree, key)

# downwards pass to compute likelihoods
def down_log_likelihood(noise,value,edge_length,parent_value,params,**args):
    """Per-node log likelihood of the edge transition from the parent to the node.

    The increment ``value - parent_value`` is whitened with the amplitude-free
    square-root covariance ``sigma(params) / k_alpha``. Each whitened
    coordinate is then scored under ``N(0, (sqrt(edge_length) * k_alpha)**2)``.
    The mean over coordinates is returned as ``'log_likelihood'``.

    ``noise`` is unused but kept in the signature, presumably because the
    executor passes node fields by name.

    NOTE(review): this is a scaled pseudo-likelihood. It averages rather than
    sums over coordinates, and it omits ``-log|det(sigma(params) / k_alpha)|``,
    which depends on ``k_sigma``. Confirm this is intended when ``k_sigma`` is
    inferred.
    """
    sqrt_covar = sigma(params)/params['k_alpha'] # square root covariance without amplitude

    def node_log_likelihood(v, m, var):
        whitened = jnp.linalg.solve(sqrt_covar, v-m)
        return jax.scipy.stats.norm.logpdf(whitened, 0, jnp.sqrt(var)*params['k_alpha'])

    # variance of each edge is its edge length
    return {'log_likelihood': jnp.mean(jax.vmap(node_log_likelihood)(value,parent_value,edge_length),1)}
downmodel_log_likelihood = DownLambda(down_fn=down_log_likelihood)
down_log_likelihood = LevelwiseTreeExecutor(downmodel_log_likelihood,batch_size=100)

# log likelihood of the tree
def log_likelihood(data,state):
    """Log likelihood of a sampled tree state given the observed leaf data.

    Adds the mean per-node transition log likelihood (from
    ``down_log_likelihood``) to the mean Gaussian log density of the leaf
    residuals ``sampled - observed`` with variance ``obs_var``.
    """
    parameters,tree = state
    log_likelihood_tree = down_log_likelihood.down(tree,parameters.values())
    # the root has no parent edge; its entry is set to 0 (it still counts in the mean below)
    log_likelihood_tree.root['log_likelihood'] = 0
    tree_log_likelihood = jnp.mean(jnp.array([node['log_likelihood'] for node in log_likelihood_tree.iter_bfs()]))
    residuals = jnp.array([sample['value']-obs['value'] for sample,obs in zip(tree.iter_leaves(),data.iter_leaves())])
    leaves_log_likelihood = jnp.mean(jax.scipy.stats.norm.logpdf(residuals,0,jnp.sqrt(parameters['obs_var'].value)))
    return tree_log_likelihood+leaves_log_likelihood

def log_posterior(data,state):
    """Unnormalised log posterior: parameter log prior plus the guided-tree log likelihood."""
    store, _ = state
    return store.log_prior() + log_likelihood(data, state)

def proposal(data, state, key):
    """Jointly propose new parameters and a new tree state.

    The current tree is copied and its leaves are reset to the observed data.
    It is then backward-filtered under the proposed parameters, its noise
    receives a Crank-Nicolson update, and the tree is resampled by forward
    guiding. Returns ``(new_parameters, guided_tree)``.
    """
    param_key, noise_key = jax.random.split(key, 2)
    parameters, tree = state

    # start from the current state with the observed data restored at the leaves
    utree = tree.copy()
    for observed, leaf in zip(data.iter_leaves(), utree.iter_leaves()):
        leaf['value'] = observed['value']

    new_parameters = parameters.propose(param_key)
    init_up(utree, new_parameters)
    utree = updown.up(utree, new_parameters.values())

    # perturb the noise and resample the tree conditionally on the data
    utree_CN = update_CN(utree, noise_key)
    dtree = updown.down(utree_CN, new_parameters.values())
    return new_parameters, dtree

# tree values and parameters
init_params = ParameterStore({
    'k_alpha': VarianceParameter(.25,alpha=3,beta=.5,keep_constant=False), # kernel amplitude, governs global tree variance
    'k_sigma': VarianceParameter(.5,keep_constant=False), # kernel width, for Gaussian kernels this is proportional to the variance
    # observation noise variance; it is sampled here (keep_constant=False), presumably under the
    # inverse-gamma(alpha=2, beta=.003) prior plotted above.
    # NOTE(review): it may be only weakly identifiable for low-dimensional data.
    'obs_var': VarianceParameter(1e-3,alpha=2,beta=.003,keep_constant=False)
    })
print("Initial parameters: ",init_params.values())
print("data parameters: ",params.values())

# initial state; every consumer gets its own freshly split key so `key` is never reused
leaf_tree.root['value'] = root
subkey, key = split(key)
leaf_tree = zero_noise(leaf_tree,subkey)
utree = leaf_tree.copy()
init_up(utree,init_params)
init_state = (init_params,updown.down(updown.up(utree,init_params.values()),init_params.values()))

# Run Metropolis-Hastings
subkey, key = split(key)
log_likelihoods, samples = metropolis_hastings(log_posterior, proposal, leaf_tree, init_state, 200, burn_in=200, rng_key=subkey, savef=lambda state: state[0])

# plot
plt.plot(log_likelihoods)
plt.xlabel("Iterations")
plt.title('Log likelihood')
trace_plots(samples)