Here is a first attempt at a description of local updates (markdown; it can, for example, be inserted at the top of localfast.ipynb). There are probably errors in it and I need to go through it again, but I think the overall structure is roughly right:
Let $n$ be a node with parent $p$ and children $c_1,\ldots,c_k$. Let $l_{p,n}$ denote the edge length from $p$ to $n$. Let $v_n$ be the value of $n$. We assume that $v_n|v_p\sim \mathcal N(v_p,l_{p,n}\Sigma(v_p,\theta_p))$ where $\Sigma(v_p,\theta_p)$ is the covariance matrix, a function of the node value $v_p$ and parameters $\theta_p$.
We wish to sample from the distribution $v_n|v_p,v_{c_1},\ldots,v_{c_k}$. We have
$$p(v_n|v_p,v_{c_1},\ldots,v_{c_k}) = \frac{p(v_{c_1},\ldots,v_{c_k}|v_p,v_n)\,p(v_n|v_p)}{p(v_{c_1},\ldots,v_{c_k}|v_p)}= \frac{p(v_{c_1}|v_n)\cdots p(v_{c_k}|v_n)\,p(v_n|v_p)}{p(v_{c_1},\ldots,v_{c_k}|v_p)}$$
using that, given $v_n$, the child values $v_{c_1},\ldots,v_{c_k}$ are independent of each other and of $v_p$.
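All parent conditional probabilities are Gaussian, and we write the densities in canonical form for exponential families, i.e. $p(v_n|v_p) = \exp(-c_p+F_p^Tv_n-\frac12 v_n^TH_pv_n)$ (see https://arxiv.org/abs/2203.04155). For a Gaussian with mean $\mu$ and covariance $\Sigma$, we have $H=\Sigma^{-1}$, $F=H\mu$ and $c=-\log \phi(\mu;0,\Sigma)$, where $\phi(x;0,\Sigma)$ is the Gaussian density; for $p(v_n|v_p)$ the mean is $v_p$ and the covariance is $l_{p,n}\Sigma(v_p,\theta_p)$. Write $H_{c_i}=l_{n,c_i}^{-1}\Sigma(v_n,\theta_n)^{-1}$. In this notation, we have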
$$\log p(v_{c_i}|v_n)
= \mathrm{constant}-\frac12(v_{c_i}-v_n)^Tl_{n,c_i}^{-1}\Sigma(v_n,\theta_n)^{-1}(v_{c_i}-v_n)
= -c_{c_i}+F_{c_i}^Tv_n-\frac12 v_n^TH_{c_i}v_n
$$
and thus
\begin{align}
& \log p(v_n|v_p,v_{c_1},\ldots,v_{c_k}) \\
& = -\big(c_p+\sum_{i=1}^k c_{c_i}\big)+\big(F_p+\sum_{i=1}^k F_{c_i}\big)^Tv_n-\frac12 v_n^T\big(H_p+\sum_{i=1}^k H_{c_i}\big)v_n
-\log p(v_{c_1},\ldots,v_{c_k}|v_p)
\ .
\end{align}
We don’t need to worry about $\log p(v_{c_1},\ldots,v_{c_k}|v_p)$, since it cancels out in the MH step.
In the MH step of the MCMC sampler, we sample a new value and parameters $v_n',\theta_n'$ based on $v_n,\theta_n$ and accept/reject it by evaluating the log-ratio
$$\log p(v_n'|v_p,v_{c_1},\ldots,v_{c_k})-\log p(v_n|v_p,v_{c_1},\ldots,v_{c_k})$$
According to the above, to evaluate this we need $\sum_{i=1}^k c_{c_i}$, $\sum_{i=1}^k F_{c_i}$, and $\sum_{i=1}^k H_{c_i}$ from the up operation, with $H_{c_i}=l_{n,c_i}^{-1}\Sigma(v_n,\theta_n)^{-1}$, $F_{c_i}=H_{c_i}v_{c_i}$, and $c_{c_i}=-\log \phi(v_{c_i};0,l_{n,c_i}\Sigma(v_n,\theta_n))$, where $\phi(x;0,\Sigma)$ is the Gaussian density. Because $\Sigma$ depends on $(v_n,\theta_n)$, the child terms have to be evaluated for both the current and the proposed state.
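As a sanity check of the canonical form, the following minimal sketch compares $-c+F^Tv_n-\frac12 v_n^THv_n$ with the Gaussian log-density of a child given its parent. The helper names (`canonical_terms`, `canonical_logpdf`) and the numerical values are made up for illustration and are not part of the notebook or of hyperiax.

```python
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

def canonical_terms(mean, cov):
    """Return (c, F, H) with log N(x; mean, cov) = -c + F @ x - 0.5 * x @ H @ x."""
    H = jnp.linalg.inv(cov)
    F = H @ mean
    c = -multivariate_normal.logpdf(mean, jnp.zeros_like(mean), cov)
    return c, F, H

def canonical_logpdf(x, c, F, H):
    """Evaluate the canonical-form log-density at x."""
    return -c + F @ x - 0.5 * x @ H @ x

# Illustrative values only (hypothetical, not taken from the notebook)
edge_length = 0.5
Sigma = jnp.array([[2.0, 0.3], [0.3, 1.0]])
v_child = jnp.array([0.7, -1.2])
v_n = jnp.array([0.1, 0.4])

# Child term seen as a function of v_n: the "mean" in canonical form is v_child
c, F, H = canonical_terms(v_child, edge_length * Sigma)
lhs = canonical_logpdf(v_n, c, F, H)
rhs = multivariate_normal.logpdf(v_child, v_n, edge_length * Sigma)
```

Here `lhs` and `rhs` should agree up to floating-point error, which is what allows the child terms to be summed as $(c, F, H)$ triples.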
This gives the operations
- Up operation:
  - Compute $H_{c_i}=l_{n,c_i}^{-1}\Sigma(v_n)^{-1}$
  - Compute $F_{c_i}=H_{c_i}v_{c_i}$
  - Compute $c_{c_i}=-\log \phi(v_{c_i};0,l_{n,c_i}\Sigma(v_n))$
- Reduce operation:
  - Sum $\sum_{i=1}^k c_{c_i}$, $\sum_{i=1}^k F_{c_i}$, and $\sum_{i=1}^k H_{c_i}$
- Down operation:
  - Compute $H_p=l_{p,n}^{-1}\Sigma(v_p)^{-1}$
  - Compute $F_p=H_pv_p$
  - Compute $c_p=-\log \phi(v_p;0,l_{p,n}\Sigma(v_p))$
- Local update:
  - Propose new $v_n'$ and $\theta_n'$
  - Compute acceptance ratio using the up and down results
  - Accept or reject the proposal
Note that $\Sigma(v_n,\theta_n)$ only needs to be inverted once to compute the up terms for all children.
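The single inversion can be sketched as follows for one node with $k$ children. This is only an illustration in plain JAX; the function name `child_terms` and the input values are hypothetical, and the log-determinant identity $\log\det(l\Sigma)=d\log l+\log\det\Sigma$ is used so that no per-child factorization is needed.

```python
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

def child_terms(v_children, edge_lengths, Sigma_n):
    """Canonical-form terms for all k children of one node.

    v_children: (k, d), edge_lengths: (k,), Sigma_n: (d, d).
    """
    d = v_children.shape[-1]
    Sigma_inv = jnp.linalg.inv(Sigma_n)                  # the only matrix inversion
    H = Sigma_inv[None] / edge_lengths[:, None, None]    # (k, d, d)
    F = jnp.einsum('kij,kj->ki', H, v_children)          # (k, d)
    _, logdet = jnp.linalg.slogdet(Sigma_n)
    quad = jnp.einsum('ki,ki->k', v_children, F)         # v^T H v for each child
    # c_i = -log phi(v_i; 0, l_i * Sigma)
    c = 0.5 * (d * jnp.log(2 * jnp.pi) + d * jnp.log(edge_lengths) + logdet + quad)
    return c, F, H

# Illustrative inputs (hypothetical)
v_children = jnp.array([[0.5, -0.2], [1.0, 0.3], [-0.7, 0.9]])
edge_lengths = jnp.array([0.5, 1.0, 2.0])
Sigma_n = jnp.array([[1.5, 0.2], [0.2, 0.8]])

c, F, H = child_terms(v_children, edge_lengths, Sigma_n)
# Reduce step: sum over children
c_sum, F_sum, H_sum = c.sum(), F.sum(axis=0), H.sum(axis=0)
# Reference value for the first child, computed the direct way
c_ref0 = -multivariate_normal.logpdf(v_children[0], jnp.zeros(2), edge_lengths[0] * Sigma_n)
```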
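Nodes can be updated sequentially or in parallel with the same result as long as no node is the parent or child of another node that is being updated at the same time. This is achieved with a red/black node partition, i.e. a two-colouring of the tree in which every edge joins a red and a black node (for a tree, colouring by the parity of the node depth).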
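A red/black partition by depth parity could look like the sketch below. It works on a plain parent-index list rather than on a hyperiax tree; the function name and the breadth-first node layout are assumptions made for the example.

```python
def red_black_partition(parents):
    """Split node indices by depth parity.

    parents[i] is the index of node i's parent, or -1 for the root.
    Returns (red, black): nodes at even depth and nodes at odd depth.
    """
    def depth(i):
        d = 0
        while parents[i] != -1:
            i = parents[i]
            d += 1
        return d

    red = [i for i in range(len(parents)) if depth(i) % 2 == 0]
    black = [i for i in range(len(parents)) if depth(i) % 2 == 1]
    return red, black

# Binary tree with 7 nodes in breadth-first order (hypothetical layout)
parents = [-1, 0, 0, 1, 1, 2, 2]
red, black = red_black_partition(parents)
```

All nodes in `red` can then be updated in one parallel sweep, followed by all nodes in `black`, since no node shares a colour with its parent or children.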

In [2]:
%load_ext autoreload
%autoreload 2

In [23]:
import hyperiax
from jax.random import PRNGKey, split
import jax
from jax import numpy as jnp
from hyperiax.execution import OrderedExecutor, UnorderedExecutor
from hyperiax.models import UpLambdaReducer, DownLambda, UpLambda, UpdateLambdaReducer
from hyperiax.models.functional import pass_up
from hyperiax.tree.topology import symmetric_topology
from hyperiax.tree import HypTree
from hyperiax.plotting import plot_tree_text, plot_tree_2d_scatter
from matplotlib import pyplot as plt
from einops import rearrange, repeat

In [16]:
key = PRNGKey(0)

In [17]:
topology = symmetric_topology(height=3, degree=2)
plot_tree_text(topology)

                 None
        ┌─────────┴─────────┐
       None                None        
   ┌────┴────┐         ┌────┴────┐     
  None      None      None      None   
 ┌─┴──┐    ┌─┴──┐    ┌─┴──┐    ┌─┴──┐  
None None None None None None None None


In [6]:
d = 2

In [7]:
tree = HypTree(topology)

tree.add_property('edge_length', shape=(1,))
tree.add_property('obs_var', shape=(1,))
tree.add_property('noise', shape=(d,))
tree.add_property('value', shape=(d,))
tree.add_property('H', shape=(d,d))
tree.add_property('F', shape=(d,))
tree.add_property('C', shape=(1,))

In [8]:
key, k_value, k_noise = split(key, 3)  # separate subkeys; do not reuse a key for two draws
tree.data['value'] = jax.random.normal(k_value, shape=tree.data['value'].shape)
tree.data['noise'] = jax.random.normal(k_noise, shape=tree.data['noise'].shape)
tree.data['edge_length'] = 1/(tree.node_depths+1) # for testing
tree.data['obs_var'] = tree.data['edge_length']#jnp.ones_like(tree.data['obs_var'])*0.01
repeated_eye = repeat(jnp.eye(d),'i j->n i j', n=len(tree))
sigmas = tree.data['obs_var'][:,:,None]*repeated_eye
tree.data['H'] = repeated_eye/tree.data['obs_var'][:,:,None]
tree.data['F'] = jnp.einsum('nij,nj->ni', tree.data['H'], tree.data['value'])
tree.data['C'] = -jax.scipy.stats.multivariate_normal.logpdf(tree.data['value'],jnp.zeros(d),sigmas)

# only leaf values matter - rest can be assumed undefined despite having value

- Up operation:
  - Compute $H_{c_i}=l_{n,c_i}^{-1}\Sigma(v_n)^{-1}$
  - Compute $F_{c_i}=H_{c_i}v_{c_i}$
  - Compute $c_{c_i}=-\log \phi(v_{c_i};0,l_{n,c_i}\Sigma(v_n))$

In [9]:
def sigma(value):
    return repeat(jnp.eye(value.shape[-1]), 'i j -> n i j', n=value.shape[0])

In [10]:
def up(value, edge_length, parent_value, params):
    # Covariance of value | parent_value: l_{n,c_i} * Sigma(v_n)
    cov = edge_length[:,:,None]*sigma(parent_value)
    H = jnp.linalg.inv(cov)
    F = jnp.einsum('bij,bj->bi', H, value)
    # c_{c_i} = -log phi(v_{c_i}; 0, l_{n,c_i} Sigma(v_n))
    C = -jax.scipy.stats.multivariate_normal.logpdf(value,jnp.zeros(value.shape[-1]),cov)
    return {'H': H, 'F': F, 'C': C}

In [11]:
def transform(child_H, child_F, child_C, params):
    return {'H': child_H, 'F': child_F, 'C': child_C}

In [12]:
model = UpLambdaReducer(up, transform, {'H': 'sum', 'F': 'sum', 'C': 'sum'})
exe = OrderedExecutor(model)

In [13]:
exe.up(tree)

In [1]:
def down():
    ## not sure about equations and what to pass
    ...
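One possible shape for the down terms, written as a standalone JAX function rather than with the hyperiax down signature (which is left open above). It follows the note that $c=-\log\phi(\mu;0,\Sigma)$ and $F=H\mu$ with the mean $\mu=v_p$, so it uses $c_p=-\log\phi(v_p;0,l_{p,n}\Sigma(v_p))$. The names `down_terms` and `identity_sigma` are hypothetical, and the identity covariance mirrors the placeholder `sigma` used for the up step.

```python
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

def identity_sigma(value):
    """Placeholder covariance: one identity matrix per node (assumption)."""
    d = value.shape[-1]
    return jnp.broadcast_to(jnp.eye(d), value.shape[:-1] + (d, d))

def down_terms(parent_value, edge_length, sigma_fn=identity_sigma):
    """Canonical-form terms of p(v_n | v_p) as a function of v_n.

    parent_value: (b, d), edge_length: (b, 1).
    """
    cov = edge_length[:, :, None] * sigma_fn(parent_value)   # (b, d, d)
    H = jnp.linalg.inv(cov)
    F = jnp.einsum('bij,bj->bi', H, parent_value)
    C = -multivariate_normal.logpdf(
        parent_value, jnp.zeros(parent_value.shape[-1]), cov)
    return {'H': H, 'F': F, 'C': C}

# Illustrative batch of two nodes (hypothetical values)
parent_value = jnp.array([[0.5, 0.2], [-1.0, 0.7]])
edge_length = jnp.array([[0.5], [2.0]])
node_value = jnp.array([[0.3, -0.1], [1.0, 2.0]])
terms = down_terms(parent_value, edge_length)
```

How these terms are attached to the tree and passed to the update step depends on the hyperiax interface and is not addressed here.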

- Local update:
  - Propose new $v_n'$ and $\theta_n'$
  - Compute acceptance ratio using the up and down results
  - Accept or reject the proposal

\begin{align}
& \log p(v_n|v_p,v_{c_1},\ldots,v_{c_k}) \\
& = -\big(c_p+\sum_{i=1}^k c_{c_i}\big)+\big(F_p+\sum_{i=1}^k F_{c_i}\big)^Tv_n-\frac12 v_n^T\big(H_p+\sum_{i=1}^k H_{c_i}\big)v_n
-\log p(v_{c_1},\ldots,v_{c_k}|v_p)
\ .
\end{align}
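The expression above can be evaluated for a whole batch of nodes with `einsum`, without transposing batched arrays. The sketch below also shows a Metropolis-Hastings accept step. It assumes a symmetric random-walk proposal and a covariance that does not change with the proposal, so the same summed $(c, F, H)$ apply to both states. All names and values are illustrative.

```python
import jax
import jax.numpy as jnp

def local_logdensity(v, C, F, H):
    """Batched -C + F^T v - 0.5 v^T H v.

    v: (b, d), C: (b,), F: (b, d), H: (b, d, d); C, F, H are the parent
    terms plus the child terms summed over children.
    """
    lin = jnp.einsum('bi,bi->b', F, v)
    quad = jnp.einsum('bi,bij,bj->b', v, H, v)
    return -C + lin - 0.5 * quad

def mh_accept(key, log_ratio):
    """Accept with probability min(1, exp(log_ratio)); symmetric proposal assumed."""
    u = jax.random.uniform(key, log_ratio.shape)
    return jnp.log(u) < log_ratio

key = jax.random.PRNGKey(0)
k_prop, k_acc = jax.random.split(key)
b, d = 4, 2
v = jnp.zeros((b, d))
H = jnp.broadcast_to(3.0 * jnp.eye(d), (b, d, d))  # e.g. parent + two children, unit edges
F = jnp.ones((b, d))
C = jnp.zeros(b)

v_prop = v + 0.1 * jax.random.normal(k_prop, v.shape)   # random-walk proposal
log_ratio = local_logdensity(v_prop, C, F, H) - local_logdensity(v, C, F, H)
accepted = mh_accept(k_acc, log_ratio)
v_new = jnp.where(accepted[:, None], v_prop, v)
```

If the covariance depends on the node value or its parameters, the child terms would have to be recomputed for the proposed state before taking the difference.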

In [59]:
def update(child_H, child_F, child_C, parent_C, parent_F, parent_H, value, leaf_mask, root_mask, **kwargs):
    v = value  # TODO: propose a new value instead
    # theta = ...
    # Batched log-density (untested sketch; note: no .T on batched arrays):
    # ll = -(parent_C + child_C) + jnp.einsum('bi,bi->b', parent_F + child_F, v) \
    #      - 0.5*jnp.einsum('bi,bij,bj->b', v, parent_H + child_H, v)
    accepted = jnp.ones(v.shape[0], dtype=bool)  # TODO: compute from the MH log-ratio

    new_val = jnp.where(accepted[:,None], v, value)
    return_val = jnp.where(leaf_mask[:,None], value, new_val)  # leaves keep their values
    return {'value': return_val}

def up2(H,C,F, **kwargs):
    return {'H': H, 'C': C, 'F': F}

model = UpdateLambdaReducer(up_fn=up2, update_fn=update, reductions={'H': 'sum', 'F': 'sum', 'C': 'sum'})

In [60]:
exe = UnorderedExecutor(model)

In [61]:
exe.update(tree, key=key)

[False False False False]
[False False False False]


TypeError: add got incompatible shapes for broadcasting: (4, 2), (2, 4).