# Hierarchical variational structure learning

See appendix for details.

$\texttt{log\_R} :: (Z, R, \mathbb{R}^N) \rightarrow \mathbb{R}$, where the third argument is the parameter vector $w \in \mathbb{R}^N$.

$\texttt{log\_Q} :: (Z, R, \mathbb{R}^K) \rightarrow \mathbb{R}$, where the third argument is the parameter vector $\theta \in \mathbb{R}^K$.

$\texttt{log\_Qp} :: (Z, R) \rightarrow \mathbb{R}$

Assumed usage: first draw a sample $(z, r) \sim Q$, then pass it to $\texttt{grad\_w}$ and $\texttt{grad\_theta}$ to obtain single-sample gradient estimates with respect to $w$ and $\theta$.

In [10]:
import jax
from jax import grad

def grad_w(z, r, log_R, w):
    """Return the gradient of ``log_R(z, r, w)`` with respect to ``w``.

    Intended to be evaluated at a single sample ``(z, r)`` drawn from Q,
    giving a one-sample gradient estimate for the model parameters ``w``.

    Args:
        z, r: The fixed sample; held constant while differentiating.
        log_R: Callable ``(z, r, w) -> scalar``.
        w: Parameters to differentiate with respect to (any JAX pytree).

    Returns:
        A pytree with the same structure as ``w`` holding d log_R / d w.
    """
    def log_r_at(w_):
        # Close over the sample (z, r) so only w is differentiated.
        return log_R(z, r, w_)

    return jax.grad(log_r_at)(w)

def grad_theta(z, r, log_R, log_Q, log_Qp, w, θ):
    """Single-sample score-function gradient estimate with respect to ``θ``.

    With the sample ``(z, r)`` assumed to be drawn from Q(·; θ), define

        L(θ) = log_R(z, r, w) + log_Qp(z, r) - log_Q(z, r, θ)

    (R * Qp in the numerator, Q in the denominator). The returned estimate is

        ∇θ L(θ) + L(θ) · ∇θ log_Q(z, r, θ)

    i.e. the direct derivative of L plus the score-function (REINFORCE) term
    that accounts for θ also controlling the sampling distribution.

    Args:
        z, r: The sample; held constant while differentiating.
        log_R: Callable ``(z, r, w) -> scalar``.
        log_Q: Callable ``(z, r, θ) -> scalar``.
        log_Qp: Callable ``(z, r) -> scalar``. Assumed NOT to depend on θ;
            if it does, it must take θ and be differentiated too.
        w: Parameters of log_R, held fixed here.
        θ: Parameters to differentiate with respect to (any JAX pytree).

    Returns:
        A pytree with the same structure as ``θ``.
    """
    def objective(theta):
        # Arguments follow the declared signatures: log_Qp :: (Z, R).
        return log_R(z, r, w) + log_Qp(z, r) - log_Q(z, r, theta)

    def log_q_at(theta):
        # Distinct name: must not shadow the ``log_Q`` argument.
        return log_Q(z, r, theta)

    # One pass gives both L(θ) and ∇θ L(θ).
    value, grad_objective = jax.value_and_grad(objective)(θ)
    grad_log_q = jax.grad(log_q_at)(θ)

    # tree_map lets θ be an array or any pytree (e.g. a dict of arrays).
    return jax.tree_util.tree_map(
        lambda g_l, g_q: g_l + value * g_q, grad_objective, grad_log_q
    )