# MOET structure learning via variational inference

We start with an unnormalized measure:
$$
\begin{align}
&P(dz, x; w) :: M~\mathbb{R} \\
&P(dz, x; w) = \textbf{do}\{z \leftarrow P_z;~\textbf{return}~R(z; w, x)\}
\end{align}
$$

We define the density of the unnormalized measure (and assume $P_z$ has access to its own density):
$$
\begin{align}
P(z, x; w) = P_z(z) \cdot R(z; w, x)
\end{align}
$$

where $P_z$ is a prior over _trees_ ($z$) and $R$ is a reward function that returns the _score of the data $x$_ under a probabilistic circuit with tree structure $z$ and weights $w$.

## Variational inference

We seek to learn a proposal $Q(dz; \theta)$ over tree structures using variational inference. Starting from the log marginal likelihood of the data as the objective, we rewrite it as an expectation under the proposal and apply Jensen's inequality to arrive at the evidence lower bound (ELBO):

$$L(w, \theta; x) = \log P(x; w) = \log E_{z \sim Q(dz; \theta)}\left[\frac{P(z, x; w)}{Q(z; \theta)}\right] \geq \underbrace{E_{z \sim Q(dz; \theta)}\left[\log P(z, x; w) - \log Q(z; \theta)\right]}_{\mathrm{ELBO}}$$
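
The two steps behind this bound can be spelled out as follows. This is a sketch that assumes $Q(z; \theta) > 0$ wherever $P(z, x; w) > 0$ (a support condition not stated above), with the integral read as a sum when $z$ ranges over a discrete set of trees:

$$
\begin{align}
P(x; w) &= \int P(z, x; w)~dz = \int Q(z; \theta) \frac{P(z, x; w)}{Q(z; \theta)}~dz = E_{z \sim Q(dz; \theta)}\left[\frac{P(z, x; w)}{Q(z; \theta)}\right] \\
\log E_{z \sim Q(dz; \theta)}\left[\frac{P(z, x; w)}{Q(z; \theta)}\right] &\geq E_{z \sim Q(dz; \theta)}\left[\log \frac{P(z, x; w)}{Q(z; \theta)}\right] \quad \text{(Jensen's inequality, since } \log \text{ is concave)}
\end{align}
$$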

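Let $\mathcal{L}(z, x; w, \theta) = \log P(z, x; w) - \log Q(z; \theta)$. From the ELBO, we derive two gradient estimators: one for the circuit weights $w$, and a REINFORCE (score-function) estimator for the proposal parameters $\theta$: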

$$
\begin{align}
\nabla_w \mathrm{ELBO}(w, \theta; x) &= E_{z \sim Q(dz; \theta)}[\nabla_w \log R(z; w, x)] \\
\nabla_\theta \mathrm{ELBO}(w, \theta; x) &= E_{z \sim Q(dz; \theta)}[\nabla_\theta \mathcal{L}(z, x; w, \theta) + \mathcal{L}(z, x; w, \theta) \times \nabla_\theta \log Q(z; \theta)]
\end{align}
$$
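
To make the $\theta$ estimator concrete, here is a minimal sketch on a toy problem. Everything in it is an illustrative assumption rather than part of the model above: $z$ takes only three values (standing in for three candidate trees), $Q(z; \theta)$ is a softmax over logits $\theta$, and the values in `log_p` are made-up stand-ins for $\log P(z, x; w)$. The sketch compares a Monte Carlo average of $\nabla_\theta \mathcal{L} + \mathcal{L} \times \nabla_\theta \log Q$ against the closed-form gradient of the ELBO, which is available here only because the toy space is small enough to enumerate.

```python
import math
import random


def softmax(theta):
    """Turn logits theta into the probabilities Q(z; theta)."""
    m = max(theta)
    exps = [math.exp(t - m) for t in theta]
    total = sum(exps)
    return [e / total for e in exps]


def exact_elbo_grad(theta, log_p):
    """Closed-form gradient of the ELBO w.r.t. theta (softmax proposal)."""
    q = softmax(theta)
    ell = [lp - math.log(qz) for lp, qz in zip(log_p, q)]  # L(z) for every z
    elbo = sum(qz * lz for qz, lz in zip(q, ell))
    # d ELBO / d theta_k = Q(k) * (L(k) - ELBO) for a softmax proposal.
    return [qk * (lk - elbo) for qk, lk in zip(q, ell)]


def reinforce_elbo_grad(theta, log_p, n_samples, rng):
    """Monte Carlo average of grad L + L * grad log Q over z ~ Q."""
    q = softmax(theta)
    k = len(theta)
    grad = [0.0] * k
    samples = rng.choices(range(k), weights=q, k=n_samples)
    for z in samples:
        ell = log_p[z] - math.log(q[z])
        for j in range(k):
            # Score of the softmax: d log Q(z) / d theta_j.
            score = (1.0 if j == z else 0.0) - q[j]
            # grad_theta L(z) = -score, since only -log Q(z) depends on theta.
            grad[j] += (-score + ell * score) / n_samples
    return grad


# Hypothetical toy inputs: three "trees" with made-up log joint densities.
theta = [0.2, -0.5, 1.0]
log_p = [-1.0, -3.0, -0.5]

exact = exact_elbo_grad(theta, log_p)
estimate = reinforce_elbo_grad(theta, log_p, 200_000, random.Random(0))
print("exact:   ", exact)
print("estimate:", estimate)
```

With enough samples the two gradients should agree closely. The term $\nabla_\theta \mathcal{L}$ has zero expectation under $Q$, so it affects only the variance of the estimate, not its mean.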

### Improved gradient estimator efficiency via auxiliary variables

The performance of the REINFORCE estimator depends greatly on the structure of $R(z; w, x)$. Because REINFORCE will only provide non-zero gradient signal when (TODO)

In other words, if $Q(dz; \theta)$ spreads its probability mass away from the support of the unnormalized $P$, the estimator loses efficiency. We can correct for this effect by constructing a proposal $Q$ that, by construction, only builds valid (on-support) trees $z$.

We write:

$$
\begin{align}
&Q(dz; \theta) :: M~\mathcal{Z} \\
&Q(dz; \theta) = \textbf{do}\{ r \leftarrow Q_r(\theta); z \leftarrow Q_z(r); \textbf{return}~z \}
\end{align}
$$

The REINFORCE estimator requires that we evaluate the density of $Q$. However, the distribution over $z$ that $Q$ now defines is a projective pushforward of a joint distribution over $(r, z)$: the pushforward forgets $r$, which is equivalent to marginalizing $r$ out of the joint distribution.
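
Written out, the marginalization reads as follows, where $Q_r(r; \theta)$ and $Q_z(z; r)$ denote the densities of the two sampling steps (notation assumed here for illustration), and the integral is read as a sum if $r$ is discrete:

$$
\begin{align}
Q(z; \theta) = \int Q_r(r; \theta) \cdot Q_z(z; r)~dr
\end{align}
$$

Evaluating this exactly requires integrating over every $r$ that could have produced $z$, which may be expensive or unavailable in closed form.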

We can estimate the density of $Q(z; \theta)$ by constructing an importance sampling estimate of the density. First, we define the joint:

$$
\begin{align}
&Q'(dr, dz; \theta) :: D~(\mathcal{R} \times \mathcal{Z}) \\
&Q'(dr, dz; \theta) = \textbf{do}\{ r \leftarrow Q_r(\theta); z \leftarrow Q_z(r); \textbf{return}~(r, z) \}
\end{align}
$$

where $D$ denotes a measure with normalized density. Now, we construct a sampling estimator:

$$
\begin{align}
&\chi(dz, dw; \theta) :: M~(\mathcal{Z} \times \mathbb{R}) \\
&\chi(dz, dw; \theta) = \textbf{do}\{ (r, z) \leftarrow Q'(\theta); \textbf{return}~(z, \frac{Q''(r; z)}{Q'(r, z)}) \}
\end{align}
$$

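The term $\frac{1}{w} = \frac{Q''(r; z)}{Q'(r, z)}$ is the reciprocal of an importance weight $w$ (not to be confused with the circuit weights $w$ above). For the property below to hold, $Q''(r; z)$ must be a normalized density over $r$ for each $z$, supported only where $Q'(r, z) > 0$. Under the sampler $Q(dz; \theta) = \textbf{proj}_0~\chi(dz, dw; \theta)$, the reciprocal weight is an unbiased estimate of the reciprocal marginal density. Conditional on the returned $z$, averaging over the auxiliary variable $r$ gives:

$$
\begin{align}
\mathbb{E}_{r \sim Q'(dr \mid z; \theta)}\left[\frac{1}{w}\right] = \frac{1}{Q(z; \theta)}
\end{align}
$$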
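
The identity can be checked numerically on a small discrete example. The sketch below reads the expectation as taken over the auxiliary variable $r$ conditional on the returned $z$ (an assumption about the intended reading). All the probability tables are made-up illustrative values: `Q_R` stands in for $Q_r$, `Q_Z_GIVEN_R` for $Q_z$, and `Q_REV` for $Q''(r; z)$, chosen as an arbitrary normalized distribution over $r$ for each $z$.

```python
import random

# Hypothetical toy tables: r in {0, 1}, z in {0, 1, 2}.
Q_R = [0.3, 0.7]                      # Q_r(r)
Q_Z_GIVEN_R = [[0.5, 0.4, 0.1],       # Q_z(z; r=0)
               [0.1, 0.3, 0.6]]       # Q_z(z; r=1)
Q_REV = [[0.6, 0.4],                  # Q''(r; z=0), each row sums to 1
         [0.5, 0.5],                  # Q''(r; z=1)
         [0.2, 0.8]]                  # Q''(r; z=2)


def joint(r, z):
    """Density of the joint Q'(r, z)."""
    return Q_R[r] * Q_Z_GIVEN_R[r][z]


def marginal(z):
    """Exact marginal Q(z), available here because r is enumerable."""
    return sum(joint(r, z) for r in range(len(Q_R)))


def expected_reciprocal_weight(z):
    """Exact E[1/w] over r drawn from the conditional Q'(r | z)."""
    total = 0.0
    for r in range(len(Q_R)):
        conditional = joint(r, z) / marginal(z)
        total += conditional * Q_REV[z][r] / joint(r, z)
    return total


def sample_chi(rng):
    """One draw from chi: sample (r, z) from Q', return (z, 1/w)."""
    r = rng.choices(range(len(Q_R)), weights=Q_R)[0]
    z = rng.choices(range(3), weights=Q_Z_GIVEN_R[r])[0]
    return z, Q_REV[z][r] / joint(r, z)


def monte_carlo_reciprocal(n_samples, rng):
    """Average the sampled 1/w separately for each value of z."""
    sums, counts = [0.0] * 3, [0] * 3
    for _ in range(n_samples):
        z, inv_w = sample_chi(rng)
        sums[z] += inv_w
        counts[z] += 1
    return [s / c for s, c in zip(sums, counts)]


targets = [1.0 / marginal(z) for z in range(3)]
exact_values = [expected_reciprocal_weight(z) for z in range(3)]
mc_values = monte_carlo_reciprocal(100_000, random.Random(0))
print("1 / Q(z):      ", targets)
print("exact E[1/w]:  ", exact_values)
print("sampled E[1/w]:", mc_values)
```

The exact conditional expectation matches $1 / Q(z; \theta)$ for any normalized choice of `Q_REV`. Different choices change only the variance of the sampled estimate.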