<script>
    function findAncestor (el, name) {
        while ((el = el.parentElement) && el.nodeName.toLowerCase() !== name);
        return el;
    }
    function colorAll(el, textColor) {
        el.style.color = textColor;
        Array.from(el.children).forEach((e) => {colorAll(e, textColor);});
    }
    function setBackgroundImage(src, textColor) {
        var section = findAncestor(document.currentScript, "section");
        if (section) {
            section.setAttribute("data-background-image", src);
            if (textColor) colorAll(section, textColor);
        }
    }
</script>

<script>
    setBackgroundImage("/Users/pavanchaggar/Documents/ResearchDocs/Presentations/background_imgs/title-slide-background.png");
    
</script>

<h1 style="color:White;">Variational Bayesian Inference</h1>
<h2 style="color:White;" >Pavan Chaggar</h2>

<script>
    setBackgroundImage("/Users/pavanchaggar/Documents/ResearchDocs/Presentations/background_imgs/ox-math-background.png");
</script>
## Outline: 

* Introduction and Motivation 
* Variational Bayes
* Application to modelling Alzheimer's

<script>
    setBackgroundImage("/Users/pavanchaggar/Documents/ResearchDocs/Presentations/background_imgs/ox-math-background.png");
</script>

## Introduction and Motivation: Why do Bayesian Inference?

<script>
    setBackgroundImage("/Users/pavanchaggar/Documents/ResearchDocs/Presentations/background_imgs/ox-math-background.png");
</script>

Very simply, we wish to assess the evidence for some hypothesis given some data. Alternatively, we can ask: what does this model tell us about the data?

<script>
    setBackgroundImage("/Users/pavanchaggar/Documents/ResearchDocs/presentations/ox-math-background.png");
</script>
We can do this using the Bayes-Price-Laplace rule, as follows.

For observations $\mathbf{x} = x_{1:n}$ and latent variables $\mathbf{z} = z_{1:m}$, we have a joint density

$$ p(\mathbf{x}, \mathbf{z}) $$

To evaluate a particular hypothesis, we need to evaluate the posterior $ p(\mathbf{z} \mid \mathbf{x}) $, thus we decompose the joint distribution:

$$ p(\mathbf{x}, \mathbf{z}) = p(\mathbf{x} \mid \mathbf{z})p(\mathbf{z}) = p(\mathbf{z} \mid \mathbf{x})p(\mathbf{x}) $$

and we can obtain the posterior by the Bayes-Price-Laplace rule: 

$$p(\mathbf{z} \mid \mathbf{x}) = \frac{p(\mathbf{x} \mid \mathbf{z})p(\mathbf{z})}{p(\mathbf{x})} $$

<script>
    setBackgroundImage("/Users/pavanchaggar/Documents/ResearchDocs/Presentations/background_imgs/ox-math-background.png");
</script>

### Bayes-Price-Laplace Rule

$$ \underbrace{p(\mathbf{z} \mid \mathbf{x})}_{posterior} = \frac{\overbrace{p(\mathbf{x} \mid \mathbf{z})}^{likelihood}\overbrace{p(\mathbf{z})}^{prior}}{\underbrace{p(\mathbf{x})}_{evidence}} $$


- Likelihood: Probability that a particular set of parameter values generate the observations.

- Prior: Probability representing our initial beliefs about the parameter values.

- Evidence: Normalising factor; probability of observing our data (given our model). Otherwise known as the marginal likelihood.

- Posterior: Probability of a set of parameter values given the observed data, i.e. how plausible it is that the data were _caused_ by those parameters.
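
As a small illustration of how the four terms fit together, here is a minimal sketch with a discrete latent variable. The coin example, the hypothesis names and all numbers are assumptions made up for illustration; they are not part of the Alzheimer's model discussed later.

```python
# Hypothetical example: is a coin fair, or biased towards heads?
priors = {"fair": 0.5, "biased": 0.5}        # prior p(z)
heads_prob = {"fair": 0.5, "biased": 0.8}    # p(heads | z) under each hypothesis


def posterior(n_heads, n_tails):
    """Apply Bayes' rule to a sequence with the given head/tail counts."""
    # Likelihood p(x | z) of the observed sequence under each hypothesis
    likelihood = {
        z: heads_prob[z] ** n_heads * (1 - heads_prob[z]) ** n_tails
        for z in priors
    }
    # Evidence p(x): sum of likelihood * prior over all hypotheses
    evidence = sum(likelihood[z] * priors[z] for z in priors)
    # Posterior p(z | x) = likelihood * prior / evidence
    return {z: likelihood[z] * priors[z] / evidence for z in priors}


post = posterior(n_heads=8, n_tails=2)
print(post)
```

With only two hypotheses the evidence is a two-term sum, which is why this case is easy.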

<script>
    setBackgroundImage("/Users/pavanchaggar/Documents/ResearchDocs/Presentations/background_imgs/ox-math-background.png");
</script>

So, why can't we just calculate this? The problem lies in the denominator, evaluated as: 

$$ {p(\mathbf{x})} = \int p(\mathbf{x} , \mathbf{z}) d\mathbf{z} $$

Typically, this integral is difficult to solve analytically or computationally, primarily since the problem becomes intractable with increasing numbers of latent variables. 

Therefore, we need a way to evaluate the posterior distribution without computing the denominator.

To do this, we turn to approximate Bayesian inference. While there are many ways to approximate the posterior distribution, we use variational inference for computational efficiency and speed. (More on this later.) 
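
To make the cost of the evidence integral concrete, the sketch below evaluates it on a grid for a single latent variable and then counts how many grid points the same approach would need as the number of latent variables grows. The toy model (standard normal prior, unit-variance Gaussian likelihood, one observation) and the helper names are assumptions chosen because the exact evidence is known in closed form.

```python
import numpy as np


def normal_pdf(x, mean, var):
    """Density of a univariate normal distribution."""
    return np.exp(-0.5 * (x - mean) ** 2 / var) / np.sqrt(2 * np.pi * var)


# Toy model (assumed): z ~ N(0, 1), x | z ~ N(z, 1), one observation x = 1
x_obs = 1.0
z = np.linspace(-10.0, 10.0, 2001)
dz = z[1] - z[0]

# p(x) = integral of p(x | z) p(z) dz, approximated by a Riemann sum
joint = normal_pdf(x_obs, z, 1.0) * normal_pdf(z, 0.0, 1.0)
evidence_grid = float(np.sum(joint) * dz)

# For this conjugate model the evidence is exactly N(x; 0, 2)
evidence_exact = float(normal_pdf(x_obs, 0.0, 2.0))


def grid_points(points_per_dim, n_latents):
    """Number of joint-density evaluations a full grid would need."""
    return points_per_dim ** n_latents


print(evidence_grid, evidence_exact)
print(grid_points(100, 10))  # grid size for 10 latent variables
```

One latent variable needs a few thousand evaluations, but a full grid grows exponentially with the number of latent variables, which is the practical obstacle described above.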

<script>
    setBackgroundImage("/Users/pavanchaggar/Documents/ResearchDocs/Presentations/background_imgs/ox-math-background.png");
</script>

## Variational Bayesian Inference

<script>
    setBackgroundImage("/Users/pavanchaggar/Documents/ResearchDocs/Presentations/background_imgs/ox-math-background.png");
</script>

VB aims to circumvent the computational cost of evaluating the evidence by recasting inference as an optimisation problem. 

The process begins by positing a family of _approximate_ densities, $\mathfrak{D}$, over the latent variables $\mathbf{z}$. 

Then, within this family, we wish to find the member that minimises the Kullback-Leibler divergence to the true posterior. 

$$ q^{*}(\mathbf{z}) = \underset{q(\mathbf{z}) \in \mathfrak{D}}{argmin} \mathrm{KL}(q(\mathbf{z}) \mid \mid p(\mathbf{z} \mid \mathbf{x})) $$

However, this still depends on the intractable evidence...

<script>
    setBackgroundImage("/Users/pavanchaggar/Documents/ResearchDocs/Presentations/background_imgs/ox-math-background.png");
</script>

To make this clearer, we can rewrite the expression using Bayes' rule: 

$$ q^{*}(\mathbf{z}) = \underset{q(\mathbf{z}) \in \mathfrak{D}}{argmin} \mathrm{KL}\bigg(q(\mathbf{z}) \mid \mid \frac{p(\mathbf{x} \mid \mathbf{z})p(\mathbf{z})}{p(\mathbf{x})}\bigg) $$


We can avoid the evidence term with the following manipulations. First, let's break down the KL divergence, recalling that it can be expressed as the expected log of the first argument minus the expected log of the second argument, with expectations taken with respect to $q(\mathbf{z})$: 

$$ \mathrm{KL}(q(\mathbf{z}) \mid \mid p(\mathbf{z} \mid \mathbf{x})) = \mathbb{E}[\log q(\mathbf{\mathbf{z}})] - \mathbb{E}[\log p(\mathbf{x} \mid \mathbf{z})] - \mathbb{E}[\log p(\mathbf{z})] + \log p(\mathbf{x}) $$


Rearranging, and noting that the evidence $\log p(\mathbf{x})$ is constant with respect to $q$, we see that minimising the KL divergence is equivalent to maximising the free energy $\mathbf{F}$: 

\begin{align}
    \label{eqn:free-energy2}
\log p(\mathbf{x}) - \mathrm{KL}(q(\mathbf{z}) \mid \mid p(\mathbf{z} \mid \mathbf{x})) &= \mathbb{E}[\log p(\mathbf{x} \mid \mathbf{z})] - \mathbb{E}[\log q(\mathbf{z})] + \mathbb{E}[\log p(\mathbf{z})] \\
    \label{eqn:free-energy3}
    \mathbf{F} &= \underbrace{\mathbb{E}[\log p(\mathbf{x} \mid \mathbf{z})]}_{accuracy} - \underbrace{\mathrm{KL}(q(\mathbf{z}) \mid \mid p(\mathbf{z}))}_{complexity}
\end{align}  
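
The relationship between the log evidence, the free energy and the KL divergence can be checked numerically. The sketch below assumes a toy conjugate model ($z \sim \mathcal{N}(0,1)$, $x \mid z \sim \mathcal{N}(z,1)$, a single observation) and a Gaussian $q(z) = \mathcal{N}(m, s^2)$, for which every term has a closed form. The model and function names are illustrative assumptions.

```python
import math


def elbo(x, m, s2):
    """Free energy F for q(z) = N(m, s2) under the toy model."""
    # E_q[log p(x | z)] with x | z ~ N(z, 1)
    expected_log_lik = -0.5 * math.log(2 * math.pi) - 0.5 * ((x - m) ** 2 + s2)
    # E_q[log p(z)] with z ~ N(0, 1)
    expected_log_prior = -0.5 * math.log(2 * math.pi) - 0.5 * (m ** 2 + s2)
    # -E_q[log q(z)] is the entropy of a Gaussian
    entropy = 0.5 * math.log(2 * math.pi * math.e * s2)
    return expected_log_lik + expected_log_prior + entropy


def kl_to_posterior(x, m, s2):
    """KL(q || p(z | x)); the exact posterior here is N(x / 2, 1 / 2)."""
    post_mean, post_var = x / 2.0, 0.5
    return 0.5 * (math.log(post_var / s2)
                  + (s2 + (m - post_mean) ** 2) / post_var - 1.0)


def log_evidence(x):
    """log p(x); marginally x ~ N(0, 2) for this model."""
    return -0.5 * math.log(2 * math.pi * 2.0) - x ** 2 / 4.0


x = 1.0
candidates = [(0.0, 1.0), (0.5, 0.5), (2.0, 0.1)]  # (mean, variance) of q
gaps = [log_evidence(x) - (elbo(x, m, s2) + kl_to_posterior(x, m, s2))
        for m, s2 in candidates]
print(gaps)
```

The second candidate is the exact posterior, so its KL term vanishes and the free energy equals the log evidence; for every other candidate the free energy is strictly lower.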


<script>
    setBackgroundImage("/Users/pavanchaggar/Documents/ResearchDocs/Presentations/background_imgs/ox-math-background.png");
</script>

There are a number of methods with which to optimise the free energy. Here, we focus on analytic variational Bayes, which uses the calculus of variations to perform the optimisation.

Importantly, we must make a mean-field assumption such that: 

$$ q(z) = \prod_{i} q_{z_{i}}(z_i) $$


Writing $\mathbf{F}$ as a functional of $q$, we formulate an Euler-Lagrange equation for each factor that has the following solution:

$$ \log q(z_{i}) = \int q(z_{-i}) \log \big( p(x \mid z)p(z) \big) dz_{-i} + \text{const} $$ 

where $z_{-i}$ denotes all latent variables other than $z_{i}$.


For a full and proper derivation of this, see: Beal, M.J. (2003) Variational Algorithms for Approximate Bayesian Inference
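
In practice the mean-field factors are updated one at a time, each with the others held fixed, until the estimates stop changing. The sketch below shows this coordinate-wise scheme on a standard textbook case, a correlated bivariate Gaussian target approximated by two independent Gaussian factors. The target, its parameters and the function name are assumptions for illustration, not the model used later in this talk.

```python
import numpy as np


def cavi_bivariate_gaussian(mu, Lam, n_iter=50):
    """Mean-field updates for a Gaussian target with mean mu and precision Lam.

    Each factor q(z_i) is Gaussian with variance 1 / Lam[i, i]; only the
    factor means need iterating.
    """
    m = np.zeros(2)  # initial guess for the factor means
    for _ in range(n_iter):
        # Update q(z_1) holding q(z_2) fixed, then the reverse
        m[0] = mu[0] - Lam[0, 1] / Lam[0, 0] * (m[1] - mu[1])
        m[1] = mu[1] - Lam[1, 0] / Lam[1, 1] * (m[0] - mu[0])
    variances = 1.0 / np.diag(Lam)
    return m, variances


mu = np.array([1.0, -2.0])
Sigma = np.array([[1.0, 0.8],
                  [0.8, 1.0]])
Lam = np.linalg.inv(Sigma)

m, v = cavi_bivariate_gaussian(mu, Lam)
print(m)  # factor means
print(v)  # factor variances, compare with the diagonal of Sigma
```

The factor means converge to the true mean, while the factor variances come out smaller than the true marginal variances, a well-known consequence of the mean-field assumption when the latent variables are correlated.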

<script>
    setBackgroundImage("/Users/pavanchaggar/Documents/ResearchDocs/Presentations/background_imgs/ox-math-background.png");
</script>


## Applications

<script>
    setBackgroundImage("/Users/pavanchaggar/Documents/ResearchDocs/Presentations/background_imgs/ox-math-background.png");
</script>



In the context of Alzheimer's modelling, we can use the variational inference framework to estimate the posterior distributions of our model parameters given some data.

The first step in doing this is to specify a generative model (a data generating process). 

<script>
    setBackgroundImage("/Users/pavanchaggar/Documents/ResearchDocs/Presentations/background_imgs/ox-math-background.png");
</script>




### Forward Model 

We can set up a forward model in the following way and begin variational inference. 

Firstly, we assume the data, $\mathbf{y}$, are generated by a non-linear forward model with normally distributed noise:

$$ \mathbf{y} = g(\boldsymbol{\theta}) + \epsilon $$ 
$$ \epsilon \sim \mathcal{N}(0, \Phi^{-1}) $$ 



Using this forward model, we can write the log-likelihood of the data conditioned on $\Theta$, a set of independent parameters, up to an additive constant:

$$ \log p(\mathbf{y} \mid \mathbf{\Theta}) = \frac{N}{2} \log(\Phi) - \frac{1}{2}(\mathbf{y} - g(\boldsymbol{\theta}))^{T} \Phi (\mathbf{y} - g(\boldsymbol{\theta})) + \text{const} $$



We define priors with the following distributions on $\boldsymbol{\theta}$ and $\Phi$: 

$$ p(\theta) = MVN(\mathbf{m_0}, \Sigma^{-1}_0) $$ 

$$ p(\phi) = Ga(s_0, c_0) $$ 



Using these, we can construct the log posterior up to an additive constant, from which the update equations are derived. 

$$ L = \log p(\mathbf{y} \mid \mathbf{\Theta}) + \log p(\theta) + \log p(\phi) $$
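
A minimal sketch of how $L$ could be evaluated numerically is given below. Several things are assumptions made for illustration: the forward model `forward` is a made-up exponential decay standing in for $g(\boldsymbol{\theta})$, the noise precision is treated as a scalar, the Gaussian prior is parameterised by a precision matrix `P0`, and the Gamma prior takes `s0` as scale and `c0` as shape.

```python
import math
import numpy as np


def forward(theta, t):
    """Hypothetical forward model g(theta): exponential decay."""
    return theta[0] * np.exp(-theta[1] * t)


def log_likelihood(y, theta, phi, t):
    """Gaussian log-likelihood with scalar noise precision phi."""
    r = y - forward(theta, t)
    n = y.size
    return (0.5 * n * math.log(phi) - 0.5 * n * math.log(2 * math.pi)
            - 0.5 * phi * float(r @ r))


def log_prior_theta(theta, m0, P0):
    """Multivariate normal log-density with mean m0 and precision P0."""
    d = theta - m0
    _, logdet = np.linalg.slogdet(P0)
    return (0.5 * logdet - 0.5 * theta.size * math.log(2 * math.pi)
            - 0.5 * float(d @ P0 @ d))


def log_prior_phi(phi, s0, c0):
    """Gamma log-density with scale s0 and shape c0 (assumed convention)."""
    return ((c0 - 1) * math.log(phi) - phi / s0
            - c0 * math.log(s0) - math.lgamma(c0))


def log_joint(y, theta, phi, t, m0, P0, s0, c0):
    """L = log-likelihood + log prior on theta + log prior on phi."""
    return (log_likelihood(y, theta, phi, t)
            + log_prior_theta(theta, m0, P0)
            + log_prior_phi(phi, s0, c0))


t = np.linspace(0, 1, 20)
theta_true = np.array([2.0, 1.0])
y = forward(theta_true, t)                 # noise-free data for the check
m0, P0 = np.zeros(2), 1e-2 * np.eye(2)     # broad prior on theta
L_true = log_joint(y, theta_true, 10.0, t, m0, P0, 1.0, 1.0)
L_off = log_joint(y, theta_true + 1.0, 10.0, t, m0, P0, 1.0, 1.0)
print(L_true, L_off)
```

With a broad prior the likelihood dominates, so $L$ is higher at the generating parameters than at a perturbed value.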

<script>
    setBackgroundImage("/Users/pavanchaggar/Documents/ResearchDocs/Presentations/background_imgs/ox-math-background.png");
</script>


Lastly, as per the mean field assumption, we can factorise the approximate posterior into two groups:

$$ q(\Theta \mid \mathbf{y}) = q(\theta_n)q(\phi_n) $$ 



and choose conjugate distributions for their form:

$$ q(\theta \mid \mathbf{y}) = MVN(\theta; \mathbf{m}, \Lambda^{-1}) $$ 

$$ q(\phi \mid \mathbf{y}) = Ga(\phi; s, c) $$ 



Remember, these are two critical assumptions. First, the mean-field assumption ensures we can update each parameter group iteratively and independently. Second, the conjugacy allows us to derive the updates analytically!

For a full derivation of the update rules (and code), see: https://github.com/PavanChaggar/Bayesian_inference

<script>
    setBackgroundImage("/Users/pavanchaggar/Documents/ResearchDocs/Presentations/background_imgs/ox-math-background.png");
</script>


## A Software Interlude...

<script>
    setBackgroundImage("/Users/pavanchaggar/Documents/ResearchDocs/Presentations/background_imgs/ox-math-background.png");
</script>



In [None]:
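# Assumes `Model` (the base class from the author's package) provides the
# graph Laplacian via self.L() and that self.t holds the output time points.
from scipy.integrate import odeint

class NetworkFKPP(Model):
    def fkpp(self, p, t, theta):
        # Network FKPP: diffusion along the graph plus logistic growth
        k, a = theta
        du = k * (-self.L() @ p) + (a * p) * (1 - p)
        return du

    def solve(self, p, theta):
        # Numerically integrate fkpp over self.t; odeint is one possible
        # integrator, as the original slide leaves the choice open
        return odeint(self.fkpp, p, self.t, args=(theta,))

    def forward(self, u0):
        # u0 stacks the initial concentrations with the parameters [k, a]
        p = u0[:-2]
        theta = u0[-2:]

        u = self.solve(p, theta)
        return u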
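
For readers without the `Model` base class to hand, the same dynamics can be sketched in a self-contained way. Everything below is an illustrative assumption: the three-node path graph, the Laplacian built as degree matrix minus adjacency matrix, and the hand-written fixed-step Runge-Kutta integrator are stand-ins rather than the implementation behind `NetworkFKPP`.

```python
import numpy as np


def graph_laplacian(A):
    """Graph Laplacian L = D - A for a symmetric adjacency matrix A."""
    return np.diag(A.sum(axis=1)) - A


def fkpp_rhs(p, L, k, a):
    """Right-hand side of the network FKPP equation."""
    return -k * (L @ p) + a * p * (1 - p)


def integrate_rk4(p0, L, k, a, t):
    """Fixed-step classical Runge-Kutta integration over the time grid t."""
    out = np.empty((len(t), len(p0)))
    out[0] = p0
    for i in range(len(t) - 1):
        h = t[i + 1] - t[i]
        p = out[i]
        k1 = fkpp_rhs(p, L, k, a)
        k2 = fkpp_rhs(p + 0.5 * h * k1, L, k, a)
        k3 = fkpp_rhs(p + 0.5 * h * k2, L, k, a)
        k4 = fkpp_rhs(p + h * k3, L, k, a)
        out[i + 1] = p + h / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)
    return out


# Toy path graph 0 - 1 - 2 with seeding in node 0
A = np.array([[0.0, 1.0, 0.0],
              [1.0, 0.0, 1.0],
              [0.0, 1.0, 0.0]])
L = graph_laplacian(A)
t = np.linspace(0, 1, 100)
p0 = np.array([0.1, 1e-5, 1e-5])

traj = integrate_rk4(p0, L, k=5.0, a=10.0, t=t)
print(traj[-1])  # concentrations at the final time point
```

The seeded node spreads to its neighbours through the Laplacian term, and the logistic term then drives every node towards saturation.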

<script>
    setBackgroundImage("/Users/pavanchaggar/Documents/ResearchDocs/Presentations/background_imgs/ox-math-background.png");
</script>



In [None]:
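import numpy as np

# Initial concentrations: near zero everywhere, seeded in selected regions
p = np.zeros([83]) + 1e-5
mask = [25, 26, 39, 40, 66, 67, 80, 81]
p[mask] = 0.1

# True parameter values used for the simulation
k = 5
a = 10

# adjacency_matrix is assumed to have been loaded earlier
m = NetworkFKPP(adjacency_matrix)

# Time points at which the solution is returned
m.t = np.linspace(0,1,100)

u0 = np.append(p, [k, a])

sim = m.forward(u0)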

<script>
    setBackgroundImage("/Users/pavanchaggar/Documents/ResearchDocs/Presentations/background_imgs/ox-math-background.png");
</script>


<img align="right" width="500" height="750" src="forward_results.png">

Running the forward model produces the curve seen in (a).

We can generate synthetic data by adding Gaussian noise, $\mathcal{N}(0,0.1)$.

Our goal is to infer the initial parameter values underlying (a) from the data (b).

We now use variational Bayes to _invert_ our generative model and estimate the posterior distributions!
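
The noise step can be sketched as follows. The logistic curve is only a stand-in for the simulated trajectories, the seed is arbitrary, and reading the 0.1 in $\mathcal{N}(0, 0.1)$ as a standard deviation rather than a variance is an assumption.

```python
import numpy as np

rng = np.random.default_rng(seed=0)

# Stand-in for the forward simulation: a single logistic curve
t = np.linspace(0, 1, 100)
sim = 1.0 / (1.0 + 9.0 * np.exp(-10.0 * t))

# Add zero-mean Gaussian noise (0.1 taken as the standard deviation)
noise_sd = 0.1
data = sim + rng.normal(loc=0.0, scale=noise_sd, size=sim.shape)

print(np.std(data - sim))  # empirical noise level
```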

<script>
    setBackgroundImage("/Users/pavanchaggar/Documents/ResearchDocs/Presentations/background_imgs/ox-math-background.png");
</script>



In [None]:
# set priors 
p0 = np.zeros([83])
k0 = 0
a0 = 0

u_0 = np.append(p0, [k0, a0])

problem = VBProblem(model=m, data=data, init_means=u_0)

In [None]:
n=100

sol, F = problem.infer(n=n)

<script>
    setBackgroundImage("/Users/pavanchaggar/Documents/ResearchDocs/presentations/ox-math-background.png");
</script>

<img align="right" width="500" height="500" src="inference_results.png">

| Parameter      | True Value   | Inferred Value |
| :------------- | :----------: | -------------: |
|  $k$           | 5            | 2.938          |
|  $\alpha$      | 10           | 9.976          |

<script>
    setBackgroundImage("/Users/pavanchaggar/Documents/ResearchDocs/Presentations/background_imgs/ox-math-background.png");
</script>


# Work to do

* Implement different VB methods for more flexibility

* Create generative models for fewer data points

* Test other inference methods such as simulation based inference

* Write a paper, get a DPhil etc... 

<script>
    setBackgroundImage("/Users/pavanchaggar/Documents/ResearchDocs/Presentations/background_imgs/ox-math-background.png");
</script>

## Useful References

Beal, M.J. (2003)
Variational Algorithms for Approximate Bayesian Inference
https://cse.buffalo.edu/faculty/mbeal/thesis/

Blei, D. et al. (2017)
Variational Inference: A Review for Statisticians 
https://arxiv.org/pdf/1601.00670.pdf

Chappell, M., Groves, A.R., Woolrich, M.W.
The FMRIB Variational Bayes Tutorial
https://vb-tutorial.readthedocs.io/en/latest/