# MLSS2019: Bayesian Deep Learning

In this tutorial we will learn what basic building blocks are needed
to endow (deep) neural networks with uncertainty estimates, and how
this can be used in active learning or expert-in-the-loop pipelines.

The plan of the tutorial
1. [Setup and imports](#Setup-and-imports)
2. [Easy uncertainty in networks](#Easy-uncertainty-in-networks)
   1. [Adding stochasticity](#Adding-stochasticity)
   2. [Implementing function sampling with the DropoutLinear Layer](#Implementing-function-sampling-with-the-DropoutLinear-Layer)
   3. [Implementing DropoutLinear](#Implementing-DropoutLinear)
   4. [Comparing sample functions to point-estimates](#Comparing-sample-functions-to-point-estimates)
3. [A brief reminder on Bayesian and Variational Inference](#A-brief-reminder-on-Bayesian-and-Variational-Inference)
4. [Bayesian Active Learning with images](#Bayesian-Active-Learning-with-images)
   1. [the Acquisition Function](#the-Acquisition-Function)
   2. [Dropout $2$-d Convolutional layer and the model](#Dropout-$2$-d-Convolutional-layer-and-the-model)
   3. [Actively learning MNIST](#Actively-Learning-MNIST)

<br>

## Setup and imports

In this section we import the necessary modules and functions and
define the computational device.

First, we install some boilerplate service code for this tutorial.

In [None]:
!pip install -q --upgrade git+https://github.com/ivannz/mlss2019-bayesian-deep-learning.git@developer

Next, we import numpy for computing, matplotlib for plotting, and tqdm for progress bars.

In [None]:
import tqdm
import numpy as np

%matplotlib inline
import matplotlib.pyplot as plt

For deep learning we will be using [pytorch](https://pytorch.org/).

If you are unfamiliar with it, it is basically `numpy` with autograd,
native GPU support, and tools for building, training, and serializing models.
<!-- (and with `axis` argument replaced with `dim` :) -->

There are good introductory tutorials on `pytorch`, like this
[one](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html).

In [None]:
import torch
import torch.nn.functional as F

from torch.nn import Linear, Conv2d

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

We will also need some functionality from scikit-learn.

In [None]:
from sklearn.metrics import confusion_matrix

Next we import the boilerplate code.

* a procedure that implements a minibatch SGD **fit** loop
* a function that **evaluates** the model on the provided dataset

In [None]:
from mlss2019bdl import fit

```python
# pseudocode
def fit(model, dataset, criterion, ...):
    for epoch in epochs:
        for batch in dataset:
            loss = criterion(model, batch)  # forward pass

            grad = loss.backward()          # gradient via back propagation

            adam_step(grad)
```
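For reference, here is one way such a loop could be written as runnable code. It is a minimal sketch under our own assumptions (Adam, MSE loss, a `TensorDataset` of `(input, target)` pairs); `fit_sketch` is a hypothetical name, and this is not the actual implementation of `mlss2019bdl.fit`.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


def fit_sketch(model, dataset, n_epochs=100, batch_size=32, lr=1e-3, weight_decay=0.0):
    """Hypothetical minibatch Adam loop with MSE loss (not `mlss2019bdl.fit`)."""
    optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model.train()
    for epoch in range(n_epochs):
        for inputs, targets in loader:
            optim.zero_grad()
            loss = F.mse_loss(model(inputs), targets)  # forward pass
            loss.backward()                            # gradient via back propagation
            optim.step()                               # Adam update of the parameters

    return model


# usage: recover y = 3x + 1 with a single linear layer
torch.manual_seed(0)
X = torch.linspace(-1, 1, 64).unsqueeze(1)
y = 3 * X + 1
net = fit_sketch(torch.nn.Linear(1, 1), TensorDataset(X, y),
                 n_epochs=200, batch_size=16, lr=0.1)
```

With `weight_decay > 0`, `torch.optim.Adam` adds `weight_decay * parameter` to each gradient, i.e. a classic $\ell_2$ penalty (as opposed to the decoupled decay of `AdamW`).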

In [None]:
from mlss2019bdl import predict

```python
# pseudocode
def predict(model, dataset, ...):
    for input_batch in dataset:
        output.append(model(input_batch))  # forward pass
    
    return concatenate(output)
```
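A runnable counterpart of this pseudocode could look as follows. Again this is a sketch under our own assumptions (a `TensorDataset` whose first element is the input); `predict_sketch` is a hypothetical name, not `mlss2019bdl.predict`. Note that `model.eval()` switches off standard dropout, but layers that latch `training=True`, like the `ActiveDropout` introduced below, stay stochastic regardless.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def predict_sketch(model, dataset, batch_size=256):
    """Hypothetical batched forward pass (not `mlss2019bdl.predict`)."""
    model.eval()  # standard dropout is switched off in evaluation mode
    outputs = []
    with torch.no_grad():  # no autograd graph is needed at prediction time
        for batch in DataLoader(dataset, batch_size=batch_size, shuffle=False):
            outputs.append(model(batch[0]))  # batch is a tuple; inputs come first

    return torch.cat(outputs, dim=0)


# usage: 1000 inputs, processed in batches of 256 and concatenated back
torch.manual_seed(0)
lin = torch.nn.Linear(1, 1)
X_demo = torch.randn(1000, 1)
out = predict_sketch(lin, TensorDataset(X_demo))
```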

<br>

## Easy uncertainty in networks

Suppose we have the following model: a 3-layer fully connected
network with LeakyReLU activations.

In [None]:
model = torch.nn.Sequential(
    Linear(1, 512, bias=True),
    torch.nn.LeakyReLU(),

    Linear(512, 512, bias=True),
    torch.nn.LeakyReLU(),

    Linear(512, 1, bias=True),
)

model.to(device)

Generate the initial small dataset $S_0 = (x_i, y_i)_{i=1}^{m_0}$
with $y_i = g(x_i)$, $x_i$ on a regularly spaced grid, and
$g \colon \mathbb{R} \to \mathbb{R} \colon x \mapsto \tfrac{x^2}4 + \sin \frac\pi2 x$.
<!--
`dataset_from_numpy` **converts** numpy arrays into torch tensors,
**places** them on the specified compute device, **and packages**
into a dataset
-->

In [None]:
from mlss2019bdl import dataset_from_numpy

X_train = np.linspace(-6.0, +6.0, num=20)[:, np.newaxis]
y_train = np.sin(X_train * np.pi / 2) + 0.25 * X_train**2

train = dataset_from_numpy(X_train, y_train, device=device)

In [None]:
X_domain = np.linspace(-10., +10., num=251)[:, np.newaxis]

domain = dataset_from_numpy(X_domain, device=device)

Let's fit our model on `train` using MSE loss and $\ell_2$ penalty
on weights (`weight_decay`):
$$
    \tfrac1{2 m} \|f_\omega(x) - y\|_2^2 + \lambda \|\omega\|_2^2
    \,, $$
where $\omega$ are all the learnable parameters of the model $f_\omega(\cdot)$.
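Written out explicitly, the objective is the halved mean squared residual plus $\lambda$ times the squared norm of every parameter tensor (weights and biases alike, since $\omega$ denotes all learnable parameters). A sketch for illustration only; `penalized_mse` is a hypothetical helper and not how `fit` is necessarily implemented.

```python
import torch


def penalized_mse(model, X, y, lam):
    """Hypothetical helper: (1 / 2m) ||f(X) - y||^2 + lam * ||omega||^2."""
    m = X.shape[0]
    data_term = (model(X) - y).pow(2).sum() / (2 * m)
    penalty = sum(p.pow(2).sum() for p in model.parameters())  # all learnable params
    return data_term + lam * penalty


lin = torch.nn.Linear(1, 1)
with torch.no_grad():
    lin.weight.fill_(1.0)
    lin.bias.fill_(0.0)

X_demo = torch.tensor([[1.0], [2.0]])
y_demo = torch.tensor([[1.0], [1.0]])
loss = penalized_mse(lin, X_demo, y_demo, lam=0.1)  # residuals (0, 1): 1/4 + 0.1 * 1
```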

<br>

Fit, ...

In [None]:
fit(model, train, criterion="mse", n_epochs=2000, verbose=True, weight_decay=1e-4)

..., compute the predictions, ...

In [None]:
y_pred = predict(model, domain)

..., and plot them.

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(12, 5))

ax.scatter(X_train, y_train, c="black", s=40, label="train")

ax.plot(X_domain, y_pred.numpy(), c="C0", lw=2, label="prediction")

plt.legend();

This model seems to fit the train set adequately. However, there is no
way to assess how confident this model is in its predictions.
Indeed, the prediction $\hat{y}_x = f_\omega(x)$ is a deterministic function
of the input $x$ and the learnt parameters $\omega$.

<br>

### Adding stochasticity

One inexpensive way to make any network into a stochastic function of its
input is to add dropout before any parameterized layer like `linear`
or `convolutional`, [Hinton et al. 2012](https://arxiv.org/abs/1207.0580).
Essentially, dropout applies a Bernoulli mask to the features of the input.

In [Gal, Y. (2016)](http://www.cs.ox.ac.uk/people/yarin.gal/website/thesis/thesis.pdf)
it has been shown that this simple, somewhat ad-hoc way of adding
uncertainty quantification to networks through dropout
([Hinton et al. 2012](https://arxiv.org/abs/1207.0580))
is a special case of Variational Inference.

```
A simple stochastic regularization method allows uncertainty estimation essentially for free!
```

#### (task) Always Active Dropout

For input
$
    x\in \mathbb{R}^{[\mathrm{in}]}
$ the dropout layer acts like this:

$$
    y_j = x_j \, m_j
    \,, $$

where $m\in \mathbb{R}^{[\mathrm{in}]}$ with $
m_j \sim \pi_p(m_j)
    = \mathcal{Ber}\bigl(\bigl\{0, \tfrac1{1-p}\bigr\}, 1-p\bigr)
$,
i.e. equals $\tfrac1{1-p}$ with probability $1-p$ and $0$ otherwise.

`pytorch` has a function for this, `F.dropout(input, p, training)`. It multiplies
each element of the `input` tensor by an independent Bernoulli random variable,
rescaled by $\tfrac1{1-p}$ as above. The argument `p` has the same meaning as above
(the probability of zeroing an element). The boolean argument `training` toggles the
effect: if `False`, the input is returned as-is, otherwise the mask is applied.
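A quick check of this behaviour (a sketch with an arbitrary rate `p = 0.25`): with `training=True` every entry becomes either $0$ or $\tfrac1{1-p}$, so the mean stays close to that of the input, while with `training=False` the input passes through unchanged. A standard `torch.nn.Dropout` layer behaves like the latter in evaluation mode, which is why the layer below latches `training=True`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
p = 0.25
x = torch.ones(10_000)

y_on = F.dropout(x, p=p, training=True)    # Bernoulli mask applied
y_off = F.dropout(x, p=p, training=False)  # identity

# surviving entries are rescaled to 1 / (1 - p), so E[m_j] = 1
print(torch.unique(y_on))
print(y_on.mean().item())

# a standard dropout layer in evaluation mode is also the identity
layer = torch.nn.Dropout(p=p).eval()
print(torch.equal(layer(x), x))
```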

In [None]:
class ActiveDropout(torch.nn.Dropout):
    # There is no need to redefine __init__(...), since
    #  we are directly inheriting from `Dropout`.

    def forward(self, input):
        """We need to permanently latch the `training` toggle
        to `True` in order to enable stochastic forward pass in
        evaluation mode (`model.eval()`).
        """

        ## Exercise: self.p - contains the specified dropout rate

        return F.dropout(input, p=self.p, training=True)

<br>

#### (task) Rebuilding the model

Let's recreate the model above with this freshly minted dropout layer.
Then fit it and plot its prediction uncertainty due to forward-pass stochasticity.

In [None]:
def build_model(p=0.5):
    """Build a model with dropout layers' rate set to `p`."""

    return torch.nn.Sequential(
        ## Exercise: Use ActiveDropout before the linear layers of
        #  our first network. Note that dropping out input features
        #  is not a good idea!

        Linear(1, 512, bias=True),
        torch.nn.LeakyReLU(),

        ActiveDropout(p),
        Linear(512, 512, bias=True),
        torch.nn.LeakyReLU(),

        ActiveDropout(p),
        Linear(512, 1, bias=True),
    )

<br>

In [None]:
model = build_model(p=0.5)
model.to(device)

fit(model, train, criterion="mse", n_epochs=2000, verbose=True, weight_decay=1e-5)

<br>

#### (task) Sampling the random output

Let's take the test sample $\tilde{S} = (\tilde{x}_i)_{i=1}^m$ with $\tilde{x}_i \in \mathcal{X}$,
and repeat the stochastic forward pass $B$ times at each $x\in \tilde{S}$:

* for $b = 1 .. B$ do:

  1. draw $y_{bi} \sim f_\omega(\tilde{x}_i)$ for $i = 1 .. m$.

In [None]:
def point_estimate(model, dataset, n_samples=1, verbose=False):
    """Draw pointwise samples with stochastic forward pass."""

    ## Exercise: collect the random predictions over the dataset
    ##  in a list, and then `stack` them into a B x m x d tensor,
    ##  where d is the dimension of the prediction output.

    outputs = []
    for _ in tqdm.tqdm(range(n_samples), disable=not verbose):
        # every stochastic forward pass inside `predict` draws fresh dropout masks
        outputs.append(predict(model, dataset))

    return torch.stack(outputs, dim=0)


samples = point_estimate(model, domain, n_samples=101, verbose=True)
samples.shape

```python
samples.shape  # should be 101 x 251 x 1
```

<br>

The approximate $95\%$ band of the predictions (mean $\pm 1.96$ standard deviations, i.e. assuming the predictive distribution is roughly Gaussian) is

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(12, 5))

ax.scatter(X_train, y_train, c="black", s=40, label="train")

mean, std = samples.mean(dim=0).numpy(), samples.std(dim=0).numpy()
ax.plot(X_domain, mean + 1.96 * std, c="k")
ax.plot(X_domain, mean - 1.96 * std, c="k");

Let's inspect the draws $y_{bi}$ as $B$ functional samples:
$(\tilde{x}_i, y_{bi})_{i=1}^m$ is the $b$-th sample path.

Below we plot $5$ random paths.

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(12, 5))

ax.scatter(X_train, y_train, c="black", s=40, label="train")
ax.plot(X_domain[:, 0], samples[:5, :, 0].numpy().T, c="C0", lw=1, alpha=0.25);

It is clear that they are very erratic!

Computing stochastic forward passes with a new mask each time is equivalent
to drawing a new **independent** prediction for each point $x\in \tilde{S}$,
without taking into account that predictions at adjacent points should, in fact,
be correlated.

For example, if we were interested in the uncertainty at some particular point,
this would be okay: **fast and simple**. In contrast, if we were interested in
the uncertainty of an integral **path-dependent** measure of the whole estimated
function, or were doing **optimization** of the unknown true function taking
estimation uncertainty into account, then this erratic behaviour
of the paths is undesirable.
See, e.g., [blog: Gal, Y. 2016](http://www.cs.ox.ac.uk/people/yarin.gal/website/blog_2248.html)

#### Question(s) (to ponder in your spare time)

* what will happen if you change the default dropout rate in `ActiveDropout` layer?
  Try to rebuild the model with different $p \in (0, 1)$ using `build_model(p)`,
  and then plot the predictive bands.

In [None]:
pass

<br>

### Implementing function sampling with the DropoutLinear Layer

The naive implementation of `ActiveDropout` above defines the predictive
distribution $y\sim p(y \mid x)$ as $y=f_\omega(x; m)$ for $m \sim \pi_p(m)$,
where $\pi_p(m)$ denotes the distribution of Bernoulli dropout masks
$\mathcal{Ber}\bigl(\bigl\{0, \tfrac1{1-p}\bigr\}, 1-p\bigr)$.

We need to implement some extra functionality on top of `pytorch`
in order to draw realizations from the distribution over functions
induced by a network, i.e. $
\bigl\{
    f_\omega\colon \mathcal{X}\to\mathcal{Y}
\bigr\}_{\omega \sim q(\omega)}
$
where $q(\omega)$ is a distribution over the parameters.

<br>

#### Freeze/unfreeze interface

First, we create a base **trait-class** `FreezableWeight` that adds an
interface for freezing and unfreezing a layer's random **weight**
parameter.

In [None]:
class FreezableWeight(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.unfreeze()

    def unfreeze(self):
        self.register_buffer("frozen_weight", None)

    def is_frozen(self):
        """Check if a frozen weight is available."""
        return isinstance(self.frozen_weight, torch.Tensor)

    def freeze(self):
        """Sample from the distribution and freeze."""
        raise NotImplementedError()

Next, we declare a pair of functions:
* `freeze()` instructs a compatible layer (module) to sample and freeze its randomness
* `unfreeze()` requests the layer to undo this

In [None]:
def unfreeze(module):
    for mod in module.modules():
        if isinstance(mod, FreezableWeight):
            mod.unfreeze()

    return module

In [None]:
def freeze(module):
    for mod in module.modules():
        if isinstance(mod, FreezableWeight):
            mod.freeze()

    return module

<br>

#### Sampling realizations

The algorithm to sample a random function is:
* for $b = 1 .. B$ do:

  1. draw an independent realization $f_b\colon \mathcal{X} \to \mathcal{Y}$
  from the process $\{f_\omega\}_{\omega \sim q(\omega)}$
  2. get $\hat{y}_{bi} = f_b(\tilde{x}_i)$ for $i=1 .. m$

* compute the mean and variance of $\hat{y}_{bi}$ along $b$

In [None]:
def sample_function(model, dataset, n_samples=1, verbose=False):
    """Draw a realization of a random function."""

    ## Exercise: collect the `frozen` predictions over the dataset
    ##  in a list, and then `stack` them into a B x m x d tensor,
    ##  where d is the dimension of the prediction output.

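    outputs = []
    try:
        for _ in tqdm.tqdm(range(n_samples), disable=not verbose):
            # `freeze` draws new frozen weights, i.e. a new function f_b
            outputs.append(predict(freeze(model), dataset))

    finally:
        # restore the stochastic behaviour, even if prediction fails
        unfreeze(model)

    return torch.stack(outputs, dim=0)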

**(note)** although the loops in both functions look
similar, conceptually the functions differ:
<strong>
```python
def point_estimate(f, S):
    for x in S:
        for w in f.q:
            yield f(x, w)


def sample_function(f, S):
    for w in f.q:  # thanks to freeze
        for x in S:
            yield f(x, w)
```
</strong>
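The difference is easy to see on a toy random function that is not a network: $f(x, w) = w x$ with the slope $w \sim \mathcal{N}(0, 1)$ playing the role of the random parameters. This is a sketch; all names are hypothetical and prefixed with `toy_` so as not to clobber the notebook's own functions. Function samples are straight lines, while pointwise samples are not.

```python
import numpy as np

toy_rng = np.random.default_rng(0)
toy_S = np.linspace(-1.0, 1.0, num=5)


def toy_f(x, w):
    """A random linear function: the slope w plays the role of the parameters."""
    return w * x


def toy_point_estimate(S, n_samples):
    # a fresh parameter draw for every (sample, point) pair
    return np.array([[toy_f(x, toy_rng.normal()) for x in S] for _ in range(n_samples)])


def toy_sample_function(S, n_samples):
    # one parameter draw per path, shared by all points of that path
    paths = []
    for _ in range(n_samples):
        w = toy_rng.normal()
        paths.append([toy_f(x, w) for x in S])
    return np.array(paths)


pe = toy_point_estimate(toy_S, n_samples=3)
sf = toy_sample_function(toy_S, n_samples=3)

# on a regular grid a straight line has zero second differences
print(np.abs(np.diff(sf, n=2, axis=1)).max())  # zero up to rounding
print(np.abs(np.diff(pe, n=2, axis=1)).max())  # clearly non-zero
```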

<br>

### Implementing `DropoutLinear`

Now we will merge the `Dropout` and `Linear` layers into one, which

1. (on forward pass) **drops out** the inputs, if necessary, and **applies** the linear (affine) transform
2. (on freeze) **randomly zeros** columns in a copy of the weight matrix $W$

Preferably, we will try to preserve the interface, so that the resulting
object is backwards compatible with `Linear`.

This way we would be able to draw realizations from the induced
distribution over functions defined by the network $
\bigl\{
    f_\omega\colon \mathcal{X}\to\mathcal{Y}
\bigr\}_{\omega \sim q(\omega)}
$
where $q(\omega)$ is a distribution over the network parameters.

<br>

#### (task) Fused dropout-linear operation

On the inputs into a linear layer dropout acts like this: for input
$
    x\in \mathbb{R}^{[\mathrm{in}]}
$ and layer weights $
    W\in \mathbb{R}^{[\mathrm{out}] \times [\mathrm{in}]}
$
and bias $
    b\in \mathbb{R}^{[\mathrm{out}]}
$ the resulting effect is

$$\begin{align}
    \tilde{x} &= x \odot m
    \,, \\
    y &= \tilde{x} W^\top + b
    \,,
\end{align}$$

where $\odot$ is the elementwise product and $m\in \mathbb{R}^{[\mathrm{in}]}$
with $m_j \sim \pi_p(m_j) = \mathcal{Ber}\bigl(\bigl\{0, \tfrac1{1-p}\bigr\}, 1-p\bigr)$,
i.e. equals $\tfrac1{1-p}$ with probability $1-p$ and $0$ otherwise.

Let
$
    x\in \mathbb{R}^{[\mathrm{in}]}
$, $
    W\in \mathbb{R}^{[\mathrm{out}] \times [\mathrm{in}]}
$
and $
    b\in \mathbb{R}^{[\mathrm{out}]}
$
* `F.dropout(x, p, on/off)` -- Bernoulli dropout $x\mapsto x\odot m$
  for $m\sim \mathcal{Ber}\bigl(\bigl\{0, \tfrac1{1-p}\bigr\}, 1-p\bigr)$
* `F.linear(x, W, b)` -- affine transformation $x \mapsto x W^\top + b$

**(NOTE)** the weight of a linear layer in `pytorch` has shape
$[\mathrm{out}] \times [\mathrm{in}]$.
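A quick check of this shape convention and of what `F.linear` computes (a sketch; the sizes are arbitrary):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
layer = torch.nn.Linear(in_features=3, out_features=2)
print(layer.weight.shape)  # torch.Size([2, 3]), i.e. [out] x [in]

x_batch = torch.randn(5, 3)                          # 5 inputs in R^[in]
y_out = F.linear(x_batch, layer.weight, layer.bias)  # shape 5 x [out]

# F.linear(x, W, b) computes x W^T + b
print(torch.allclose(y_out, x_batch @ layer.weight.T + layer.bias))  # True
```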

In [None]:
def DropoutLinear_forward(self, input):
    ## Exercise: If not frozen, then apply always active dropout,
    #  then linear transformation. If frozen, apply the transform
    #  using the frozen weight

    # linear with frozen weight
    if self.is_frozen():
        return F.linear(input, self.frozen_weight, self.bias)

    # stochastic pass as in `ActiveDropout` + Linear
    input = F.dropout(input, self.p, True)

    return F.linear(input, self.weight, self.bias)

<br>

#### Parameter freezer

For input
$
    x\in \mathbb{R}^{[\mathrm{in}]}
$ and layer parameters $
    W\in \mathbb{R}^{[\mathrm{out}] \times [\mathrm{in}]}
$
and $
    b\in \mathbb{R}^{[\mathrm{out}]}
$ the effect in `DropoutLinear` is

$$
    y_j
        = \bigl[(x \odot m) W^\top + b\bigr]_j
        = b_j + \sum_i x_i m_i W_{ji}
        = b_j + \sum_i x_i \breve{W}_{ji}
    \,, $$

where each column $\breve{W}_i$ of $\breve{W}$ is, independently, either
$\mathbf{0} \in \mathbb{R}^{[\mathrm{out}]}$ with probability $p$ or
the rescaled $i$-th column $W_i \in \mathbb{R}^{[\mathrm{out}]}$ of the learnable weight:

$$
    \breve{W}_i \sim
\begin{cases}
    \mathbf{0}
        & \text{ w. prob } p \,, \\
    \tfrac1{1-p} W_i
        & \text{ w. prob } 1-p \,.
\end{cases}
$$

Thus the multiplicative effect of the random mask $m$ on $x$ can be
equivalently seen as a random **on/off** switch effect on the
**columns** of the matrix $W$.
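This equivalence can be verified directly (a sketch with arbitrary sizes): masking the inputs and masking the corresponding columns of $W$ with the same $m$ give the same output. Broadcasting a mask of shape `[in]` against `W` of shape `[out] x [in]` scales column $i$ by $m_i$, which is the trick `DropoutLinear_freeze` below relies on with its one-row mask.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
p = 0.5
W = torch.randn(4, 6)  # [out] x [in]
b = torch.randn(4)
x = torch.randn(6)

# one Bernoulli mask over the inputs, with values in {0, 1 / (1 - p)}
m = torch.bernoulli(torch.full((6,), 1 - p)) / (1 - p)

y_masked_input = F.linear(x * m, W, b)    # (x * m) W^T + b
y_masked_columns = F.linear(x, W * m, b)  # column i of W scaled by m_i

print(torch.allclose(y_masked_input, y_masked_columns))  # True
```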

In [None]:
def DropoutLinear_freeze(self):
    """Apply dropout with rate `p` to columns of `weight` and freeze it."""
    # we leverage torch's broadcasting semantics and draw a one-row
    #  (rescaled) Bernoulli mask that we later multiply the weight by.

    # let's draw the new weight
    prob = torch.full_like(self.weight[:1, :], 1 - self.p)
    feature_mask = torch.bernoulli(prob) / prob

    frozen_weight = self.weight * feature_mask

    # and store it
    self.register_buffer("frozen_weight", frozen_weight)

**(note)**
The parameter distribution of the layer we're building is

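$$
    q(\breve{W})
        = \prod_i q(\breve{W}_i)
        = \prod_i \bigl\{
            p \delta_{\mathbf{0}} (\breve{W}_i)
            + (1 - p) \delta_{\tfrac1{1-p} W_i}(\breve{W}_i)
        \bigr\}
    \,, $$

where $\breve{W}_i$ is the $i$-th column of the random weight $\breve{W}$,
$W_i$ is the $i$-th column of the learnable weight $W$, and $\delta_x$ is a
**point-mass** distribution at $x$.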

<br>

Assemble the layer

In [None]:
class DropoutLinear(Linear, FreezableWeight):
    """Linear layer with dropout on inputs."""
    def __init__(self, in_features, out_features, bias=True, p=0.5):
        super().__init__(in_features, out_features, bias=bias)

        self.p = p

    forward = DropoutLinear_forward

    freeze = DropoutLinear_freeze

<br>

### Comparing sample functions to point-estimates 

Let's rewrite the model builder function:

In [None]:
def build_model(p=0.5):
    """Build a model with dropout layers' rate set to `p`."""

    return torch.nn.Sequential(
        ## Exercise: Plug the `DropoutLinear` layer into our second network.

        Linear(1, 512, bias=True),
        torch.nn.LeakyReLU(),

        DropoutLinear(512, 512, bias=True, p=p),
        torch.nn.LeakyReLU(),

        DropoutLinear(512, 1, bias=True, p=p),
    )

Let's create a new instance and retrain the model.

In [None]:
model = build_model(p=0.5)
model.to(device)

fit(model, train, criterion="mse", n_epochs=2000, verbose=True, weight_decay=1e-5)

... and obtain two estimates: pointwise and functional.

In [None]:
samples_pe = point_estimate(model, domain, n_samples=51, verbose=True)
samples_sf = sample_function(model, domain, n_samples=51, verbose=True)

samples_pe.shape, samples_sf.shape

<br>

Let's compare <span style="color:#ff7f0e">**point estimates**</span>
with <span style="color:#1f77b4">**function sampling**</span>.

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(12, 5))

ax.plot(X_domain[:, 0], samples_pe[:15, :, 0].numpy().T,
        c="C1", lw=1, alpha=0.5)

ax.plot(X_domain[:, 0], samples_sf[:15, :, 0].numpy().T,
        c="C0", lw=2, alpha=0.5)

ax.scatter(X_train, y_train, c="black", s=40,
           label="train", zorder=+10);

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(12, 5))

ax.scatter(X_train, y_train, c="black", s=40, label="train")

mean, std = samples_sf.mean(dim=0).numpy(), samples_sf.std(dim=0).numpy()
ax.plot(X_domain, mean + 1.96 * std, c="C0")
ax.plot(X_domain, mean - 1.96 * std, c="C0");

mean, std = samples_pe.mean(dim=0).numpy(), samples_pe.std(dim=0).numpy()
ax.plot(X_domain, mean + 1.96 * std, c="C1")
ax.plot(X_domain, mean - 1.96 * std, c="C1");

Pros of `point-estimate`:
* uses stochastic forward passes -- no need for extra code and classes

Cons of `point-estimate`:
* predictive distributions at adjacent inputs are independent

#### Question(s) (to ponder in your spare time)

* what happens when you increase the number of samples path-wise and pointwise,
  and inspect their statistics?

<br>

## A brief reminder on Bayesian and Variational Inference

Bayesian Inference is a principled framework of reasoning about uncertainty.

In Bayesian Inference (**BI**) we *assume* that the observation
data $D$ follows a *model* $m$ with data generating distribution
$p(D\mid m, \omega)$ *governed by unknown parameters* $\omega$.
The goal of **BI** is to reason about the model and/or its parameters,
and about new data, given the observed data $D$ and our assumptions, i.e.
to seek the **posterior** parameter and predictive distributions:

$$\begin{align}
    p(d \mid D, m)
        % &= \mathbb{E}_{
        %     \omega \sim p(\omega \mid D, m)
        % } p(d \mid D, \omega, m)
        &= \int p(d \mid D, \omega, m) p(\omega \mid D, m) d\omega
    \,, \\
    p(\omega \mid D, m)
        &= \frac{p(D\mid \omega, m) \, \pi(\omega \mid m)}{p(D\mid m)}
    \,.
\end{align}
$$

* the **prior** distribution $\pi(\omega \mid m)$ reflects our belief
  before having made the observations

* the data distribution $p(D \mid \omega, m)$ reflects our assumptions
  about the data generating process, and determines the parameter
  **likelihood** (Gaussian, Categorical, Poisson)

Unless the prior and the likelihood are conjugate, the posterior in
Bayesian inference is typically intractable, and it is common to resort
to **Variational Inference** or **Monte Carlo** approximations.

The key idea of Variational Inference is to seek an approximation $q(\omega)$
to the intractable posterior $p(\omega \mid D, m)$ via a variational
optimization problem over some tractable family of distributions $\mathcal{Q}$:

$$
    q^*(\omega)
        \in \arg \min_{q\in \mathcal{Q}} \mathrm{KL}(q(\omega) \| p(\omega \mid D, m))
    \,, $$

where the Kullback-Leibler divergence between $Q$ and $P$ ($Q\ll P$)
with densities $q$ and $p$, respectively, is given by

$$
    \mathrm{KL}(q(\omega) \| p(\omega))
%         = \mathbb{E}_{\omega \sim Q} \log \tfrac{dQ}{dP}(\omega)
        = \mathbb{E}_{\omega \sim q(\omega)}
            \log \tfrac{q(\omega)}{p(\omega)}
    \,. \tag{kl-div} $$
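As a sanity check of (kl-div), a sketch comparing the closed-form KL divergence between two Gaussians, as returned by `torch.distributions.kl_divergence`, with its Monte Carlo estimate $\hat{\mathbb{E}}_{\omega \sim q} \log \tfrac{q(\omega)}{p(\omega)}$. The particular distributions are arbitrary choices for illustration.

```python
import math

import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
q = Normal(loc=0.0, scale=1.0)
p = Normal(loc=1.0, scale=2.0)

# KL(N(mu_q, s_q) || N(mu_p, s_p)) = log(s_p / s_q) + (s_q^2 + (mu_q - mu_p)^2) / (2 s_p^2) - 1/2
closed_form = math.log(2.0 / 1.0) + (1.0**2 + (0.0 - 1.0)**2) / (2 * 2.0**2) - 0.5
exact = kl_divergence(q, p)

# Monte Carlo estimate of E_{w ~ q} [log q(w) - log p(w)]
w = q.sample((100_000,))
mc = (q.log_prob(w) - p.log_prob(w)).mean()

print(closed_form, exact.item(), mc.item())  # all approximately equal
```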


Note that the family of variational approximations $\mathcal{Q}$ can be
structured **arbitrarily**: point masses, products, mixtures, input-dependent
distributions, hierarchical structures, or any other valid distribution.

Although computing the divergence w.r.t. the unknown posterior
is still intractable, it is possible to do away with it
through the following identity, which is based on Bayes' rule.

For **any** $q(\omega) \ll p(\omega \mid D, m)$ and any model $m$

$$
    \overbrace{
        \log p(D \mid m)
    }^{\text{evidence}}
        = \underbrace{
            \mathbb{E}_{\omega \sim q} \log p(D\mid \omega, m)
        }_{\text{expected conditional likelihood}}
        - \overbrace{
            \mathrm{KL}(q(\omega)\| \pi(\omega \mid m))
        }^{\text{proximity to prior belief}}
        + \underbrace{
            \mathrm{KL}(q(\omega)\| p(\omega \mid D, m))
        }_{\text{posterior approximation}}
    \,. \tag{master-identity} $$

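Since the left-hand side does not depend on $q$, minimizing the KL divergence to the posterior
is equivalent to the following maximization problem with respect to $q(\omega)$ (and, since
that KL divergence is nonnegative, the objective is a lower bound on the evidence):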

$$
    q^* \in
    \arg\max_{q\in \mathcal{Q}}
        \mathrm{ELBO}(q) = 
            \mathbb{E}_{\omega \sim q} \log p(D\mid \omega, m)
            - \mathrm{KL}(q(\omega)\| \pi(\omega \mid m))
    \,. $$

* the expected likelihood -- favours $q$ that place their mass on
parameters $\omega$ that explain the observed data under the specified
model $m$.

* the negative KL-divergence -- encourages variational densities
not to stray too far from the prior belief under the model $m$.
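The master identity and the resulting bound are easy to check numerically when the parameter takes finitely many values, so that all integrals become sums. A sketch with made-up numbers for the prior, the likelihood of the observed data, and the variational distribution:

```python
import numpy as np

# a toy model whose parameter omega takes 3 values; all numbers are made up
prior = np.array([0.5, 0.3, 0.2])       # pi(omega | m)
likelihood = np.array([0.1, 0.4, 0.7])  # p(D | omega, m) for the observed D
q_var = np.array([0.2, 0.5, 0.3])       # an arbitrary variational q(omega)

evidence = np.sum(likelihood * prior)      # p(D | m)
posterior = likelihood * prior / evidence  # Bayes' rule


def kl(a, b):
    """KL divergence between two discrete distributions with full support."""
    return np.sum(a * np.log(a / b))


elbo = np.sum(q_var * np.log(likelihood)) - kl(q_var, prior)
print(np.log(evidence), elbo + kl(q_var, posterior))  # the two sides agree
print(elbo <= np.log(evidence))                       # the ELBO is a lower bound
```

Setting `q_var = posterior` makes the last KL term vanish, so the ELBO then equals the log-evidence.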

If the assumed likelihood $p(D \mid \omega, m)$ and the prior $\pi(\omega\mid m)$
have their own parameters $\phi$, then the lower bound

$$
    \log p_\phi(D \mid m)
        \geq \mathrm{ELBO}(q, \phi)
            = \mathbb{E}_{\omega \sim q(\omega)} \log p_\phi(D \mid \omega, m)
            - \mathbb{E}_{\omega \sim q(\omega)} \log \frac{q(\omega)}{\pi_\phi(\omega)}
    \,, $$

naturally yields a coordinate-wise ascent algorithm:
* **(E)** step wrt $q$, fixed $\phi$
* **(M)** step wrt $\phi$, fixed $q$

Typically, the variational approximation yields high-dimensional
integrals that are computationally heavy. To make the computations
faster without sacrificing much precision, we may use sampling,
or Monte Carlo methods. For the predictive distribution, for example,
we have

$$
\begin{align}
    \mathbb{E}_{y\sim p(y\mid x, D, m)} \, g(y)
        &\overset{\text{BI}}{=}
            \mathbb{E}_{\omega\sim p(\omega \mid D, m)}
                \mathbb{E}_{y\sim p(y\mid x, \omega, D, m)} \, g(y) 
        \\
        &\overset{\text{VI}}{\approx}
            \mathbb{E}_{\omega\sim q(\omega)}
                \mathbb{E}_{y\sim p(y\mid x, \omega, D, m)} \, g(y)
        \\
        &\overset{\text{MC}}{\approx}
%             \hat{\mathbb{E}}_{\omega \sim \mathcal{W}}
%                 \mathbb{E}_{y\sim p(y\mid x, \omega, D, m)} \, g(y)
            \frac1{\lvert \mathcal{W}\rvert} \sum_{\omega \in \mathcal{W}}
                \mathbb{E}_{y\sim p(y\mid x, \omega, D, m)} \, g(y)
    \,,
\end{align}
$$

where $\mathcal{W} = (\omega_b)_{b=1}^B \sim q(\omega)$
-- iid samples from the variational approximation.

**(note)** If $p(y \mid x, \omega, D, m)$ yields "heavy" integrals, then
we apply Monte Carlo to it too.

A good summary of Bayesian Inference can be found in [this lecture](http://mlg.eng.cam.ac.uk/zoubin/talks/lect1bayes.pdf), [this paper](https://arxiv.org/abs/1206.7051.pdf), and [this review](https://arxiv.org/abs/1601.00670.pdf).
It is also possible to consult [wiki](https://en.wikipedia.org/wiki/Bayesian_inference) and references therein.

<br>

In [None]:
pass

<br>

## Bayesian Active Learning with images

* Data labelling is costly and time-consuming
* Unlabelled instances are essentially free

**Goal:** achieve high performance with fewer labels by
identifying the best instances to learn from

Essential blocks of active learning:

* a **model** $m$ capable of quantifying uncertainty (preferably a Bayesian model)
* an **acquisition function** $a\colon \mathcal{M} \times \mathcal{X}^* \to \mathbb{R}$
  that for any finite set of inputs $S\subset \mathcal{X}$ quantifies their usefulness
  to the model $m\in \mathcal{M}$
* a labelling **oracle**, e.g. a human expert

The main loop of active learning:

1. fit $m$ on $\mathcal{S}_{\mathrm{labelled}}$

2. get an exact (or approximate) solution $
    \mathcal{S}^* \in \arg \max_{S \subseteq \mathcal{S}_\mathrm{unlabelled}}
        a(m, S)
$ that satisfies the **budget constraints** (like $\lvert S \rvert \leq \ell$, or other
economically motivated ones) and is computed **without** access to the targets.

3. request the **oracle** to provide labels for each $x\in \mathcal{S}^*$

4. update $
\mathcal{S}_{\mathrm{labelled}}
    \leftarrow \mathcal{S}^*
        \cup \mathcal{S}_{\mathrm{labelled}}
$ and go to 1.
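The loop above can be sketched generically, with the model, the acquisition score, and the oracle passed in as callables. This is a minimal sketch under our own assumptions: the acquisition is pointwise (top-`budget` scores, as for the additive criteria discussed below), and every name is hypothetical. The toy usage replaces the Bayesian model with a "model" that memorises the labelled inputs and scores pool points by their distance to the nearest labelled input (pure exploration, no uncertainty).

```python
import numpy as np


def active_learning_loop(fit_fn, score_fn, oracle_fn, X_pool, labelled, n_rounds=3, budget=5):
    """Hypothetical pool-based active learning loop with a pointwise acquisition score."""
    labelled = list(labelled)
    targets = {i: oracle_fn(i) for i in labelled}

    for _ in range(n_rounds):
        # 1. fit the model on the labelled subset
        model = fit_fn(X_pool[labelled], np.array([targets[i] for i in labelled]))

        # 2. score the unlabelled pool (targets are not used) and keep the top `budget`
        unlabelled = np.setdiff1d(np.arange(len(X_pool)), labelled)
        if unlabelled.size == 0:
            break
        scores = score_fn(model, X_pool[unlabelled])
        chosen = unlabelled[np.argsort(scores)[-budget:]]

        # 3. query the oracle and 4. grow the labelled set
        for i in chosen.tolist():
            targets[i] = oracle_fn(i)
            labelled.append(i)

    return labelled, targets


X_toy = np.linspace(0.0, 1.0, num=50)[:, np.newaxis]


def toy_fit(X, y):
    return X  # the "model" just memorises the labelled inputs


def toy_score(model, X):
    return np.abs(X - model.T).min(axis=1)  # distance to the nearest labelled input


def toy_oracle(i):
    return float(X_toy[i, 0] > 0.5)


labelled_idx, labels = active_learning_loop(
    toy_fit, toy_score, toy_oracle, X_toy, labelled=[0], n_rounds=3, budget=5)
```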

We already have a Bayesian model that can be used to reason
about uncertainty, so let's focus on the acquisition function.

<br>

### the Acquisition Function

There are many acquisition criteria (borrowed from [Gal17a](http://proceedings.mlr.press/v70/gal17a.html)):
* Classification
  * Max entropy (plain uncertainty)
  * Maximal information about parameters and predictions (mutual information)
  * Variance ratios
  * Mean standard deviation
  * **BALD**
* Regression
  * predictive variance

**BALD** (Bayesian Active Learning by Disagreement) acquisition
criterion is based on the posterior mutual information between model's
predictions $y_x$ at some point $x$ and model's parameters $\omega$:

$$\begin{align}
    a(m, S)
        &= \sum_{x\in S} a(m, \{x\})
        \\
    a(m, \{x\})
        &= \mathbb{I}(y_x; \omega \mid x, m, D)
\end{align}
    \,, \tag{bald} $$

with the [**Mutual Information**](https://en.wikipedia.org/wiki/Mutual_information#Relation_to_Kullback%E2%80%93Leibler_divergence)
(**MI**)
$$
    \mathbb{I}(y_x; \omega \mid x, m, D)
        = \mathbb{H}\bigl(
            \mathbb{E}_{\omega \sim q(\omega\mid m, D)}
                p(y_x \,\mid\, x, \omega, m, D)
        \bigr)
        - \mathbb{E}_{\omega \sim q(\omega\mid m, D)}
            \mathbb{H}\bigl(
                p(y_x \,\mid\, x, \omega, m, D)
            \bigr)
    \,, \tag{mi} $$

and the [(differential) **entropy**](https://en.wikipedia.org/wiki/Differential_entropy#Differential_entropies_for_various_distributions)
(all densities and/or probability mass functions can be conditional):

$$
    \mathbb{H}(p(y))
        = - \mathbb{E}_{y\sim p} \log p(y)
    \,. $$

Instead of the exact formula for **MI** we shall use its **Monte Carlo** (**MC**)
approximation, since the expectations are analytically or numerically
tractable only in simple low dimensional cases.
<!-- probability integrals are still integrals -->

Consider an iid sample $\mathcal{W} = (\omega_b)_{b=1}^B \sim q(\omega \mid m, D)$
of size $B$. The **MC** approximation of the mutual information is

$$
    \mathbb{I}_\mathrm{MC}(y_x; \omega \mid x, m, D)
        = \mathbb{H}\bigl(
            \hat{\mathbb{E}}_{\omega \sim \mathcal{W}}
                p(y_x \,\mid\, x, \omega, m, D)
        \bigr)
        - \hat{\mathbb{E}}_{\omega \sim \mathcal{W}}
            \mathbb{H}\bigl(
                p(y_x \,\mid\, x, \omega, m, D)
            \bigr)
    \,, \tag{mi-mc} $$

where $\hat{\mathbb{E}}_{\omega \sim \mathcal{W}} h(\omega) = \tfrac1B \sum_b h(\omega_b)$
denotes the expectation with respect to the empirical probability measure induced
by the sample $\mathcal{W}$.

<br>

#### (task) implementing entropy

For categorical (discrete) random variables $y \sim \mathcal{Cat}(\mathbf{p})$,
$\mathbf{p} \in \{ \mu \in [0, 1]^d \colon \sum_k \mu_k = 1\}$, the entropy is

$$
    \mathbb{H}(p(y))
        = - \mathbb{E}_{y\sim p(y)} \log p(y)
        = - \sum_k p_k \log p_k
    \,. $$

**(note)** although in calculus $0 \cdot \log 0 = 0$ (because
$\lim_{p\downarrow 0} p \cdot \log p = 0$), in floating point
arithmetic $\log 0 = -\infty$ and $0 \cdot (-\infty) = \mathrm{NaN}$. So you need to add
some **really tiny float number** to the argument of $\log$.

In [None]:
def entropy(proba):
    """Compute the entropy along the last dimension."""

    ## Exercise: get the entropy of a tensor with distributions
    #  along the last axis.

    # the tiny float keeps `0 * log(0)` from evaluating to `0 * (-inf) = nan`
    return - torch.sum(proba * torch.log(proba + 1e-20), dim=-1)

<br>

#### (task) implementing mutual information

Consider a tensor $p_{bik}$ of probabilities $p(y_{x_i}=k \mid x_i, \omega_b, m, D)$
with $\omega_b \sim q(\omega \mid m, D)$.

Let's implement a procedure that computes the **MC** estimate 
of the posterior predictive distribution

$$
\hat{p}(y_x\mid x, m, D)
    = \hat{\mathbb{E}}_{\omega \sim \mathcal{W}}
        \,p(y_x \mid x, \omega, m, D)
    \,, $$

its **entropy** $
    \mathbb{H}\bigl(\hat{p}(y\mid x, m, D)\bigr)
$ and **mutual information** $
    \mathbb{I}_\mathrm{MC}(y_x ; \omega\mid x, m, D)
$

In [None]:
def mutual_information(proba):
    ## Exercise: compute a Monte Carlo estimator of the predictive
    ##   distribution, its entropy and MI `H E_w p(., w) - E_w H p(., w)`

    proba_avg = proba.mean(dim=0)

    entropy_expected = entropy(proba_avg)
    expected_entropy = entropy(proba).mean(dim=0)

    mut_info = entropy_expected - expected_entropy

    return proba_avg, entropy_expected, mut_info
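A tiny worked example shows why the mutual information, rather than the predictive entropy alone, is used for acquisition. In this sketch (the probabilities and the `_demo` names are made up for illustration) there are $B=2$ parameter samples, $m=2$ points and $K=2$ classes. At the first point both samples are uncertain; at the second both are confident but disagree. The predictive entropy is $\log 2$ at both points, but only the second has non-zero mutual information.

```python
import math

import torch


def entropy_demo(proba):
    # same formula as `entropy` above, with a tiny float inside the log
    return -torch.sum(proba * torch.log(proba + 1e-20), dim=-1)


def mutual_information_demo(proba):
    """proba: B x m x K tensor of p(y_x = k | x, omega_b)."""
    proba_avg = proba.mean(dim=0)
    return entropy_demo(proba_avg) - entropy_demo(proba).mean(dim=0)


proba_demo = torch.tensor([
    [[0.5, 0.5], [1.0, 0.0]],  # omega_1
    [[0.5, 0.5], [0.0, 1.0]],  # omega_2
])

print(entropy_demo(proba_demo.mean(dim=0)))  # log 2 at both points
print(mutual_information_demo(proba_demo))   # approx. [0, log 2]
print(math.log(2.0))
```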

<br>

#### (task) implementing BALD acquisition

The acquisition function that we will implement takes in the
sample mutual information and returns the indices of selected
points.



Note that $a(m, S)$ is additively separable, i.e. equals $\sum_{x\in S} a(m, \{x\})$.
This implies that

$$
\begin{align}
    \max_{S \subseteq \mathcal{S}_\mathrm{unlabelled},\, \lvert S \rvert = \ell} a(m, S)
        &= \max_{z \in \mathcal{S}_\mathrm{unlabelled}}
            \max_{F \subseteq \mathcal{S}_\mathrm{unlabelled} \setminus \{z\},\, \lvert F \rvert = \ell - 1}
            \sum_{x\in F \cup \{z\}} a(m, \{x\})
        \\
        &= \max_{z \in \mathcal{S}_\mathrm{unlabelled}}
            a(m, \{z\})
            + \max_{F \subseteq \mathcal{S}_\mathrm{unlabelled} \setminus \{z\},\, \lvert F \rvert = \ell - 1}
                \sum_{x\in F} a(m, \{x\})
\end{align}
    \,. $$

Therefore selecting the $\ell$ `most interesting` points from $\mathcal{S}_\mathrm{unlabelled}$
is trivial: applying this decomposition recursively, the optimal $S$ consists of
the $\ell$ points with the largest individual scores $a(m, \{x\})$.


In [None]:
def acq_bald(mutual_info, n_points=10):
    ## Exercise: implement the acquisition

    indices = mutual_info.argsort()

    return indices[-n_points:]

    pass
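
For comparison, a short sketch: `torch.topk` selects the same set of indices (up to ties and ordering) and returns only the largest entries, in descending order. The function names below are hypothetical and not part of the exercise.

```python
import torch

def acq_bald_argsort(mutual_info, n_points=10):
    # same selection as `acq_bald` above: the tail of the ascending argsort
    return mutual_info.argsort()[-n_points:]

def acq_bald_topk(mutual_info, n_points=10):
    # indices of the `n_points` largest scores, in descending order of score
    return torch.topk(mutual_info, n_points).indices

mutual_info = torch.tensor([0.1, 0.7, 0.3, 0.9, 0.5])
print(acq_bald_argsort(mutual_info, 2), acq_bald_topk(mutual_info, 2))
```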

<br>

**(note)** A drawback of the `pointwise` top-$\ell$ procedure above is
that, although it acquires individually informative instances, altogether
they might end up **being** `jointly poorly informative`, e.g. when the
batch contains several near-duplicates. This can be corrected by seeking
the highest mutual information among finite sets
$S \subseteq \mathcal{S}_\mathrm{unlabelled}$ of size $\ell$.
Such an acquisition function is called **batch-BALD**
([Kirsch et al.; 2019](https://arxiv.org/abs/1906.08158.pdf)):

$$\begin{align}
    a(m, S)
        &= \mathbb{I}\bigl((y_x)_{x\in S}; \omega \mid S, m \bigr)
        = \mathbb{H} \bigl(
            \mathbb{E}_{\omega \sim q(\omega\mid m)} p\bigl((y_x)_{x\in S}\mid S, \omega, m \bigr)
        \bigr)
        - \mathbb{E}_{\omega \sim q(\omega\mid m)} \mathbb{H}\bigl(
            p\bigl((y_x)_{x\in S}\mid S, \omega, m \bigr)
        \bigr)
\end{align}
    \,. \tag{batch-bald} $$

This criterion requires an exponentially large number of computations and
amount of memory (for $K$ classes the joint distribution of $(y_x)_{x\in S}$
has $K^\ell$ outcomes). However, there are working solutions, such as random
sampling of subsets $S$ of size $\ell$ from $\mathcal{S}_\mathrm{unlabelled}$,
or greedy maximization of this *submodular* criterion.

> Kirsch, A., van Amersfoort, J., & Gal, Y. (2019). BatchBALD: Efficient and Diverse Batch Acquisition for Deep Bayesian Active Learning. arXiv preprint [arXiv:1906.08158](https://arxiv.org/abs/1906.08158.pdf)
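
The greedy variant can be sketched for tiny problems by building the MC estimate of the joint predictive distribution explicitly. This is an illustration under stated assumptions, not the implementation of Kirsch et al.: it enumerates all $K^{\lvert S \rvert}$ label configurations, so it is only feasible for very small batches, and the names `batch_mutual_information` and `acq_batch_bald_greedy` are hypothetical. In the toy data, inputs 0 and 1 are exact duplicates: pointwise BALD ranks them highest, while the greedy batch criterion skips the duplicate.

```python
import torch

def entropy_sketch(proba):
    return - torch.sum(proba * torch.log(proba + 1e-20), dim=-1)

def batch_mutual_information(proba, subset):
    """MC estimate of I((y_x)_{x in S}; omega) for a small subset S.

    `proba` has shape [B, N, K]: B posterior samples, N inputs, K classes.
    """
    n_samples = proba.shape[0]
    # p((y_x)_{x in S} | omega_b) = prod_{x in S} p(y_x | omega_b), flattened
    joint = torch.ones(n_samples, 1, dtype=proba.dtype)
    for i in subset:
        joint = (joint[:, :, None] * proba[:, i, None, :]).reshape(n_samples, -1)

    # labels are independent given omega, so E_w H(joint) is a sum over S
    expected_entropy = entropy_sketch(proba[:, subset]).sum(dim=-1).mean(dim=0)
    return entropy_sketch(joint.mean(dim=0)) - expected_entropy

def acq_batch_bald_greedy(proba, n_points):
    """Grow S greedily, adding the point that maximizes the joint MI."""
    selected, candidates = [], list(range(proba.shape[1]))
    for _ in range(n_points):
        scores = [batch_mutual_information(proba, selected + [j]).item()
                  for j in candidates]
        best = candidates[max(range(len(scores)), key=scores.__getitem__)]
        selected.append(best)
        candidates.remove(best)
    return selected

# 4 posterior samples, 3 inputs, 2 classes; inputs 0 and 1 are duplicates
proba = torch.tensor([
    [[1., 0.], [1., 0.], [.9, .1]],
    [[1., 0.], [1., 0.], [.1, .9]],
    [[0., 1.], [0., 1.], [.9, .1]],
    [[0., 1.], [0., 1.], [.1, .9]],
])
pointwise = [batch_mutual_information(proba, [i]).item() for i in range(3)]
print(pointwise)
print(acq_batch_bald_greedy(proba, 2))
```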

<br>

#### (task*) Bias-corrected estimator of entropy and mutual information

The first term in the **MC** estimate of the mutual information is the
so-called **plug-in** estimator of the entropy:

$$
    \hat{H}
        = \mathbb{H}(\hat{p}) = - \sum_k \hat{p}_k \log \hat{p}_k
    \,, $$

where $\hat{p}_k = \tfrac1B \sum_b p_{bk}$ is the full sample estimator
of the probabilities.

It is known that this plug-in estimate is biased
(see [blog: Nowozin, 2015](http://www.nowozin.net/sebastian/blog/estimating-discrete-entropy-part-1.html)
and references therein, also this [notebook](https://colab.research.google.com/drive/1z9ZDNM6NFmuFnU28d8UO0Qymbd2LiNJW)). <!--($\log$ + Jensen)-->
In order to correct for small-sample bias we can use
[jackknife resampling](https://en.wikipedia.org/wiki/Jackknife_resampling).
It derives an estimate of the finite sample bias from the leave-one-out
estimators of the entropy and is relatively computationally cheap
(see [blog: Nowozin, 2015](http://www.nowozin.net/sebastian/blog/estimating-discrete-entropy-part-2.html),
[Miller, R. G. (1974)](http://www.math.ntu.edu.tw/~hchen/teaching/LargeSample/references/Miller74jackknife.pdf) and these [notes](http://people.bu.edu/aimcinto/jackknife.pdf)).

The jackknife correction of a plug-in estimator $\mathbb{H}(\cdot)$
is computed thus: given a sample $(p_b)_{b=1}^B$ with $p_b$ -- discrete distribution on $1..K$
* for each $b=1.. B$
  * get the leave-one-out estimator: $\hat{p}_k^{-b} = \tfrac1{B-1} \sum_{j\neq b} p_{jk}$
  * compute the plug-in entropy estimator: $\hat{H}^{-b} = \mathbb{H}(\hat{p}^{-b})$
* then compute the bias-corrected entropy estimator $
\hat{H}_J
    = \hat{H} + (B - 1) \bigl\{
        \hat{H} - \tfrac1B \sum_b \hat{H}^{-b}
    \bigr\}
$

**(note)** when we knock the $i$-th data point out of the sample mean
$\mu = \tfrac1n \sum_i x_i$ and recompute the mean $\mu_{-i}$ we get
the following relation
$$ \mu_{-i}
    = \frac1{n-1} \sum_{j\neq i} x_j
    = \frac{n}{n-1} \mu - \tfrac1{n-1} x_i
    = \mu + \frac{\mu - x_i}{n-1}
    \,. $$
This makes it possible to compute all $B$ leave-one-out estimates of a
discrete probability distribution at once, in a single vectorized operation.
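
A quick self-contained check of the identity on random distributions: the vectorized formula agrees with explicitly dropping each sample in a loop, and every leave-one-out estimate is still a proper distribution.

```python
import torch

torch.manual_seed(0)
B, K = 6, 4
proba = torch.softmax(torch.randn(B, K, dtype=torch.float64), dim=-1)
proba_avg = proba.mean(dim=0)

# vectorized: all B leave-one-out means at once, shape [B, K]
proba_loo = proba_avg + (proba_avg - proba) / (B - 1)

# explicit loop over the left-out sample, for comparison
explicit = torch.stack([
    torch.cat([proba[:b], proba[b + 1:]]).mean(dim=0) for b in range(B)
])
```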

In [None]:
def mutual_information(proba, jackknife=True):
    """MC estimates of the predictive distribution, its entropy and the MI.

    `proba` has shape [B, N, K]: B posterior samples, N inputs, K classes.
    """
    ## Exercise: MC estimate of the predictive distribution, entropy and MI
    ##  mutual information `H E_w p(., w) - E_w H p(., w)` with jackknife
    ##  correction.

    n_samples = len(proba)
    proba_avg = proba.mean(dim=0)

    # plug-in estimate of entropy
    entropy_expected = entropy(proba_avg)

    # jackknife correction: all B leave-one-out means at once, shape [B, N, K];
    #  the clamp guards against tiny negative values due to round-off
    if jackknife:
        proba_loo = proba_avg + (proba_avg - proba) / (n_samples - 1)
        expected_entropy_loo = entropy(proba_loo.clamp(min=0.)).mean(dim=0)
        entropy_expected = entropy_expected + (n_samples - 1) * (
            entropy_expected - expected_entropy_loo)

    # expected entropy is unbiased
    expected_entropy = entropy(proba).mean(dim=0)

    mut_info = entropy_expected - expected_entropy

    pass

    return proba_avg, entropy_expected, mut_info
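
The effect of the correction can be seen in a small simulation, sketched under an assumption of our own: each $p_b$ is drawn from a symmetric Dirichlet distribution, so the true mean distribution is uniform and its entropy is exactly $\log K$. Since no distribution has higher entropy than the uniform one, the plug-in estimate can only err downwards; the jackknife should remove most of that bias. The correction is re-implemented inline so that the snippet is self-contained.

```python
import math
import torch

def entropy_sketch(proba):
    return - torch.sum(proba * torch.log(proba + 1e-20), dim=-1)

torch.manual_seed(0)
R, B, K = 2000, 10, 10   # replicates, posterior samples, classes

# p_b ~ Dirichlet(1, ..., 1): the true mean distribution is uniform, H = log K
proba = torch.distributions.Dirichlet(torch.ones(K)).sample((R, B)).double()
proba_avg = proba.mean(dim=1, keepdim=True)              # [R, 1, K]

plug_in = entropy_sketch(proba_avg[:, 0])                 # [R]
proba_loo = proba_avg + (proba_avg - proba) / (B - 1)     # [R, B, K]
jackknife = plug_in + (B - 1) * (plug_in - entropy_sketch(proba_loo).mean(dim=1))

bias_plug_in = plug_in.mean().item() - math.log(K)
bias_jackknife = jackknife.mean().item() - math.log(K)
print(f"plug-in bias {bias_plug_in:+.4f}, jackknife bias {bias_jackknife:+.4f}")
```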

<br>

### Dropout $2$-d Convolutional layer and the model

Typically, in convolutional neural networks the dropout acts upon the feature
(channel) information and not on the spatial dimensions. Thus entire channels
are dropped out and for $
    x \in \mathbb{R}^{
        [\mathrm{in}]
        \times h
        \times w}
$ and $
    y \in \mathbb{R}^{
        [\mathrm{out}]
        \times h'
        \times w'}
$ the full effect of the `Dropout+Conv2d` layer is

$$
    y_{lij} = ((x \odot m) \ast W_l)_{ij} + b_l
        = b_l + \sum_k \sum_{pq} x_{k i_p j_q} m_k W_{lkpq}
    \,, \tag{conv-2d} $$
    
where i.i.d. $m_k \sim \mathcal{Ber}\bigl(\bigl\{0, \tfrac1{1-p}\bigr\}, 1-p\bigr)$,
and the indices $i_p$ and $j_q$ give the spatial location in $x$ that corresponds
to the element $(p, q)$ of the kernel $
    W\in \mathbb{R}^{
        [\mathrm{out}]
        \times [\mathrm{in}]
        \times \tilde{h}
        \times \tilde{w}}
$ relative to the coordinates $(i, j)$ in $y$.
The exact values of $i_p$ and $j_q$ depend on the configuration of the
convolutional layer, e.g. stride, kernel size and dilation.

<br>

#### (task) Implementing `DropoutConv2d`

For images we don't usually use the `F.dropout` from our previous
experiments, because when applied to the data $
x \in \mathbb{R}^{
    [\mathrm{in}]
    \times h
    \times w
}
$ it would affect random pixels within each input feature: 

$$
\mathrm{F.dropout}(x)
    \colon x \mapsto x \odot m
    = \bigl( x_{kij} \, m_{kij} \bigr)_{kij}
    \,. $$

Please use `F.dropout2d` instead. You would also need to invoke `F.conv2d`.
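
A small self-contained check of the difference, using an all-ones input so that the dropout masks are visible in the output (the surviving entries are scaled to $\tfrac1{1-p} = 2$). `F.dropout` produces feature maps that vary across pixels, while `F.dropout2d` keeps or drops each (sample, channel) map as a whole.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.ones(8, 16, 5, 5)   # [batch, channels, height, width]

y_pixel = F.dropout(x, p=0.5, training=True)      # independent mask per pixel
y_channel = F.dropout2d(x, p=0.5, training=True)  # one mask per (sample, channel)

def varies_within_channel(y):
    """True for every (sample, channel) whose spatial map is not constant."""
    flat = y.flatten(2)
    return flat.max(dim=-1).values != flat.min(dim=-1).values

print(varies_within_channel(y_pixel).float().mean().item(),
      varies_within_channel(y_channel).float().mean().item())
```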

**(note)**
* to view documentation on something, type in `something?` (with one question mark)
* to view code of something type in `something??` (with two question marks).

In [None]:
class DropoutConv2d(Conv2d, FreezableWeight):
    """2d Convolutional layer with dropout on input features."""
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros',
                 p=0.5):

        super().__init__(in_channels, out_channels, kernel_size, stride=stride,
                         padding=padding, dilation=dilation, groups=groups,
                         bias=bias, padding_mode=padding_mode)

        self.p = p

    def forward(self, input):
        """Apply feature dropout and then forward pass through the convolution."""
        # Exercise: write a forward pass similar to `DropoutLinear.forward`,
        #  but first, take a look at the code `Conv2d.forward` in a new cell.
        #  It will help you understand what to pass to `F.conv2d`. For cleaner
        #  code you may use a super-method.

        # convolution with the frozen (pre-masked) weight
        if self.is_frozen():
            return F.conv2d(input, self.frozen_weight, self.bias, self.stride,
                            self.padding, self.dilation, self.groups)

        return super().forward(F.dropout2d(input, self.p, True))

        pass

    def freeze(self):
        """Sample the weight from the parameter distribution and freeze it."""
        ## Exercise: much like in `DropoutLinear.freeze` dropout input
        #  filters in the convolutional kernel.

        prob = torch.full_like(self.weight[:1, :, :1, :1], 1 - self.p)
        feature_mask = torch.bernoulli(prob) / prob

        frozen_weight = self.weight * feature_mask

        self.register_buffer("frozen_weight", frozen_weight)
        
        pass
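
The `freeze` method relies on the identity behind (conv-2d): zeroing input channel $k$ before the convolution is the same as zeroing the $k$-th input slice of the kernel. This holds when one channel mask is shared by every sample in the batch, which is what a frozen weight implies. A self-contained numerical check with toy shapes:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
p = 0.5
x = torch.randn(2, 3, 8, 8)         # [batch, in, h, w]
weight = torch.randn(4, 3, 3, 3)    # [out, in, kernel h, kernel w]
bias = torch.randn(4)

# one Bernoulli mask value per input channel, scaled by 1 / (1 - p)
mask = torch.bernoulli(torch.full((3,), 1 - p)) / (1 - p)

# (a) mask the input channels, then convolve
y_input = F.conv2d(x * mask[None, :, None, None], weight, bias)

# (b) mask the matching input slices of the kernel, as `freeze` does
y_kernel = F.conv2d(x, weight * mask[None, :, None, None], bias)
```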

**(note)** For more on convolutions see
[Convolution arithmetic](https://github.com/vdumoulin/conv_arithmetic) 
repo.

<br>

Much like the `SimpleModel` class above in the $1$-d section,
let's implement a simple deep convolutional network.

In [None]:
class CNNModel(torch.nn.Module):
    """A simple convolutional net."""
    def __init__(self, p=0.5):
        super().__init__()

        self.conv_block = torch.nn.Sequential(
            Conv2d(1, 20, 5, 1),
            torch.nn.ReLU(),
            torch.nn.AvgPool2d(2),

            DropoutConv2d(20, 50, 5, 1, p=p),
            torch.nn.ReLU(),
            torch.nn.AvgPool2d(2)
        )

        self.fc1 = DropoutLinear(4 * 4 * 50, 400, p=p)
        self.out = DropoutLinear(400, 10, p=p)

    def forward(self, input):
        """Take images and compute their class logits."""
        x = self.conv_block(input).flatten(1)
        return self.out(F.relu(self.fc1(x)))
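
The input size `4 * 4 * 50` of `fc1` follows from the shape arithmetic on $28 \times 28$ MNIST images. A quick sketch that checks it, with plain `torch.nn.Conv2d` standing in for both convolutional layers (dropout does not change shapes):

```python
import torch

conv_block = torch.nn.Sequential(
    torch.nn.Conv2d(1, 20, 5, 1),    # 28x28 -> 24x24
    torch.nn.ReLU(),
    torch.nn.AvgPool2d(2),           # 24x24 -> 12x12
    torch.nn.Conv2d(20, 50, 5, 1),   # 12x12 -> 8x8
    torch.nn.ReLU(),
    torch.nn.AvgPool2d(2),           # 8x8 -> 4x4
)

images = torch.zeros(2, 1, 28, 28)   # a dummy batch of MNIST-sized images
features = conv_block(images).flatten(1)
print(features.shape)
```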

To fit the classifier that outputs raw logit scores, we typically
normalize outputs via `F.log_softmax` and then feed them into
`F.nll_loss`, which computes the negative $\log$-likelihood of a
categorical distribution.

`F.cross_entropy` is essentially `log_softmax + nll_loss` fused into one
call, so it is the more convenient option. What really hurts numerical
stability is computing `torch.log(F.softmax(...))` in two separate steps:
the softmax of a very negative logit underflows to zero, and its $\log$ to
$-\infty$, whereas `F.log_softmax` avoids this. It is good practice to pay
attention to numerical stability, especially when working with `float32`.

In [None]:
def cross_entropy(model, X, y):
    return F.cross_entropy(model(X), y)  # + coef * sum(penalties(model))
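
A small self-contained illustration of the stability issue, on a single confident but wrong prediction in `float32`: taking the $\log$ of the softmax yields an infinite loss, while `F.log_softmax` and `F.cross_entropy` both give the exact value.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[0.0, 200.0]])   # a very confident prediction
target = torch.tensor([0])              # ... that happens to be wrong

naive = F.nll_loss(torch.log(F.softmax(logits, dim=-1)), target)
stable = F.nll_loss(F.log_softmax(logits, dim=-1), target)
fused = F.cross_entropy(logits, target)

print(naive.item(), stable.item(), fused.item())   # inf 200.0 200.0
```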

<br>

### Actively Learning MNIST

We will partially replicate Figure 1 in [Gal et al. (2017): p. 4](http://proceedings.mlr.press/v70/gal17a.html),

> Gal, Y., Islam, R. & Ghahramani, Z. (2017). Deep Bayesian Active Learning with Image Data. Proceedings of the 34th International Conference on Machine Learning, in [PMLR 70:1183-1192](http://proceedings.mlr.press/v70/gal17a.html)


Prepare the datasets from the `train` part of [MNIST](http://yann.lecun.com/exdb/mnist/):
* ($\mathcal{S}_\mathrm{train}$) initial **training**:
  **empty** -- learn from scratch
  <strike>21 images, purposefully highly imbalanced classes (even absent ones)</strike>
* ($\mathcal{S}_\mathrm{valid}$) our **validation**:
  $5000$ images, stratified
* ($\mathcal{S}_\mathrm{pool}$) acquisition **pool**:
  all remaining images

The true test sample of MNIST is in $\mathcal{S}_\mathrm{test}$ -- we
will use it to evaluate the final performance.

In [None]:
from mlss2019bdl.dataset import get_dataset

S_train, S_pool, S_valid, S_test = get_dataset(
    n_train=0, n_valid=5000, name="MNIST",
    random_state=722_257_201, path="./data")

**(note)** We could just as well use [Kuzushiji-MNIST](https://github.com/rois-codh/kmnist) instead.

Now, a function to plot images in a small dataset. 

In [None]:
from mlss2019bdl.flex import plot

def display(dataset, title=None, show_balance=True, figsize=None):
    images, targets = dataset.tensors
    if not show_balance:
        balance = ""

    else:
        body = [f"{n:2d}" if n > 0 else " *"
                for n in label_counts(targets)]
        balance = "(freq) [ " + ' '.join(body) + " ]"

    # a canvas
    fig, ax = plt.subplots(1, 1, figsize=figsize)

    # show the images
    plot(ax, images, cmap=plt.cm.bone)

    # produce a title
    title = "" if title is None else title
    title = title + (" " if title else "")
    ax.set_title(f"{title}{balance}")

    plt.show()
    plt.close()


def label_counts(labels, n_labels=10):
    return np.bincount(labels.cpu().numpy(), minlength=n_labels)

In [None]:
display(S_train, title="Train")

<br>

#### Necessary components for the loop

We need to be able to manipulate the datasets for the **main active learning**
loop. We begin by implementing the following primitives:
* `take` collects the instances at the specified indices into a **new dataset** (object)
* `append` adds one dataset to another
* `delete` drops the instances at the specified locations from a copy of the **dataset**
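
The helpers imported below come from the tutorial's `mlss2019bdl` package, and their actual implementation may differ. For intuition, here is a hypothetical sketch of what such primitives could look like for a `torch.utils.data.TensorDataset`, applied to a toy pool with toy names so as not to overwrite `S_train` and `S_pool`:

```python
import torch
from torch.utils.data import TensorDataset

def take_sketch(dataset, indices):
    """Collect the instances at `indices` into a new TensorDataset."""
    return TensorDataset(*(t[indices] for t in dataset.tensors))

def append_sketch(dataset, other):
    """Concatenate two TensorDatasets along the sample dimension."""
    return TensorDataset(*(torch.cat([a, b], dim=0)
                           for a, b in zip(dataset.tensors, other.tensors)))

def delete_sketch(dataset, indices):
    """Return a copy of `dataset` without the instances at `indices`."""
    keep = torch.ones(len(dataset), dtype=torch.bool)
    keep[indices] = False
    return TensorDataset(*(t[keep] for t in dataset.tensors))

# a toy pool of 5 two-dimensional "images" with labels 0..4, and an empty train set
toy_pool = TensorDataset(torch.arange(10.).reshape(5, 2), torch.arange(5))
toy_train = TensorDataset(torch.empty(0, 2), torch.empty(0, dtype=torch.long))

indices = torch.tensor([1, 3])
toy_requested = take_sketch(toy_pool, indices)
toy_train = append_sketch(toy_train, toy_requested)
toy_pool = delete_sketch(toy_pool, indices)
```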

In [None]:
from mlss2019bdl.dataset import take, delete, append

For the **main active learning** loop, besides manipulating the datasets,
we shall also need functions to **predict and acquire**, and to evaluate
the holdout **performance**.


In [None]:
def predict_proba(model, dataset, n_samples=1):
    logits = sample_function(model, dataset, n_samples=n_samples)

    # logit-scores should be transformed into a proper distribution
    return F.softmax(logits, dim=-1)

In [None]:
def acquire(model, dataset, n_points=10, n_samples=1):
    proba = predict_proba(model, dataset, n_samples=n_samples)

    _, _, mutual_info = mutual_information(proba)

    return acq_bald(mutual_info, n_points)

In [None]:
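from sklearn.metrics import confusion_matrix

def evaluate(model, dataset, n_samples=1):
    proba = predict_proba(model, dataset, n_samples=n_samples)

    # average over the posterior samples to get the predictive distribution
    proba_avg = proba.mean(dim=0)

    predicted = proba_avg.argmax(dim=-1).cpu().numpy()
    target = dataset.tensors[1].cpu().numpy()

    # fix the label set, so that matrices from different steps can be stacked
    return confusion_matrix(target, predicted, labels=np.arange(10))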

<br>

#### (task) Implementing the active learning step

Let's code the core of the active learning loop:

1. fit on **train**, then (optional) evaluate on **holdout**
2. acquire from **pool**
3. add to **train** (removing from **pool**)


In [None]:
def active_learning_step(model, S_train, S_pool,
                         n_epochs=5, n_points=10, n_samples=11):
    ## Exercise: implement the fit-acquire loop

    # 1. fit on S_train using `cross_entropy`, set `weight_decay` to 1e-4
    fit(model, S_train, criterion=cross_entropy, n_epochs=n_epochs, weight_decay=1e-4)

    # 2. acquire new instances from S_pool
    indices = acquire(model, S_pool, n_points=n_points, n_samples=n_samples)

    # 3. query the pool for the chosen instances, then take-append-delete
    S_requested = take(S_pool, indices)
    S_train = append(S_train, S_requested)
    S_pool = delete(S_pool, indices)

    pass

    return model, S_train, S_pool, S_requested

<br>

In [None]:
model = CNNModel(p=0.5)

model.to(device)

<br>

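Now we run the active learning loop: at each of the `n_active` iterations we fit the model on $\mathcal{S}_\mathrm{train}$, acquire `n_points` images from $\mathcal{S}_\mathrm{pool}$, and move them into $\mathcal{S}_\mathrm{train}$.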

In [None]:
n_epochs, n_samples = 5, 11
n_active, n_points = 75, 10

display(S_train, title="initial train")

scores = []
balances = [label_counts(S_train.tensors[1])]
for step in range(n_active):
    model, S_train, S_pool, S_requested = active_learning_step(
        model, S_train, S_pool, n_epochs=n_epochs,
        n_points=n_points, n_samples=n_samples)

    # (optional) track validation score
    score_matrix = evaluate(model, S_valid, n_samples=n_samples)

    # (optional) report accuracy and the statistics on the acquired batch
    balances.append(label_counts(S_train.tensors[1]))
    scores.append(score_matrix)

    accuracy = score_matrix.diagonal().sum() / score_matrix.sum()
    display(S_requested, title=f"# {len(S_train)} (Acc.) {accuracy:.1%}")

Train on the final $\mathcal{S}_\mathrm{train}$ and evaluate the result.

In [None]:
balances = np.stack(balances, axis=0)

fit(model, S_train, criterion=cross_entropy, n_epochs=n_epochs, weight_decay=1e-4)
scores.append(evaluate(model, S_valid, n_samples=n_samples))

scores = np.stack(scores, axis=0)

display(S_train, title="final train", figsize=(16, 9))

<br>

### Results

Let's see the dynamics of the frequency of each class in $\mathcal{S}_\mathrm{train}$

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(12, 7))

lines = ax.plot(balances, lw=2)
plt.legend(lines, list(range(10)), ncol=2);

The dynamics of *one-versus-rest* precision / recall scores on
$\mathcal{S}_\mathrm{valid}$. For binary classification:

$$ \begin{align}
\mathrm{Precision}
    &= \frac{\mathrm{TP}}{\mathrm{TP} + \mathrm{FP}}
        \approx \mathbb{P}(y = 1 \mid \hat{y} = 1)
    \,, \\
\mathrm{Recall}
    &= \frac{\mathrm{TP}}{\mathrm{TP} + \mathrm{FN}}
        \approx \mathbb{P}(\hat{y} = 1 \mid y = 1)
    \,.
\end{align}$$

In [None]:
tp = scores.diagonal(axis1=-2, axis2=-1)
fp, fn = scores.sum(axis=-2) - tp, scores.sum(axis=-1) - tp
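
The cell above assumes that `confusion_matrix` is the scikit-learn function, whose rows index the true labels and columns the predicted ones. Hence `sum(axis=-2)` (over true labels) counts the predictions of each class, and `sum(axis=-1)` the true instances of each class. A toy 3-class check of the resulting one-vs-rest formulas, built with numpy only (the helper name is hypothetical):

```python
import numpy as np

def ovr_precision_recall(cm):
    """One-vs-rest precision and recall; rows = true, columns = predicted."""
    tp = cm.diagonal()
    fp = cm.sum(axis=0) - tp   # predicted as k, but the true label differs
    fn = cm.sum(axis=1) - tp   # true label is k, but predicted otherwise
    return tp / (tp + fp), tp / (tp + fn)

y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 0])

toy_cm = np.zeros((3, 3), dtype=int)
np.add.at(toy_cm, (y_true, y_pred), 1)   # count each (true, predicted) pair

precision, recall = ovr_precision_recall(toy_cm)
print(toy_cm, precision, recall)
```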

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(12, 7))

lines = ax.plot(tp / (tp + fp), lw=2)
ax.set_title("Precision (ovr)")
ax.legend(lines, list(range(10)), ncol=2);

plt.show()

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(12, 7))

lines = ax.plot(tp / (tp + fn), lw=2)
ax.set_title("Recall (ovr)")
ax.legend(lines, list(range(10)), ncol=2)

plt.show()

The accuracy as a function of active learning iteration.

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(12, 7))

ax.plot(tp.sum(-1) / scores.sum((-2, -1)),
        label='Accuracy', lw=2)
ax.legend()

plt.show()

Volume of data used

In [None]:
f"train : pool = {len(S_train)} : {len(S_pool)}"

Let the test set confusion matrix be the ultimate judge:

In [None]:
score_matrix = evaluate(model, S_test, n_samples=51)

True positives, and false positives / negatives

In [None]:
tp = score_matrix.diagonal(axis1=-2, axis2=-1)
fp, fn = score_matrix.sum(axis=-2) - tp, score_matrix.sum(axis=-1) - tp

accuracy

In [None]:
f"(accuracy) {tp.sum() / score_matrix.sum():.2%}"

one-v-rest precision

In [None]:
{l: f"{p:.2%}" for l, p in enumerate(tp / (tp + fp))}

ovr recall

In [None]:
{l: f"{p:.2%}" for l, p in enumerate(tp / (tp + fn))}

In [None]:
assert False

<br>

In [None]:
proba = predict_proba(model, S_test, n_samples=21)
proba_avg, ent, mi = mutual_information(proba)

In [None]:
from torch.utils.data import TensorDataset

# indices of the 256 test images with the lowest predictive entropy
indices = ent.argsort().cpu().numpy()[:256]

display(TensorDataset(*S_test[indices]), title="lowest entropy (test)",
        figsize=(16, 9))

In [None]:
_, ent_n, mi_n = mutual_information(proba, jackknife=False)
_, ent_j, mi_j = mutual_information(proba, jackknife=True)

In [None]:
plt.scatter(mi_n.numpy(), mi_n.numpy()-mi_j.numpy(), s=5)

<br>

# Interesting stuff that didn't make the cut

Dense network for MNIST, in case convolutions are too slow on CPU.

In [None]:
class DNNModel(torch.nn.Module):
    """A fully connected net."""
    def __init__(self, p=0.5):
        super().__init__()

        self.body = torch.nn.Sequential(
            Linear(784, 256),
            torch.nn.ReLU(),

            DropoutLinear(256, 256, p=p),
            torch.nn.ReLU(),

            DropoutLinear(256, 256, p=p),
            torch.nn.ReLU(),

            DropoutLinear(256, 10, p=p),
        )

    def forward(self, input):
        return self.body(input.flatten(1))

The key issue with point-estimates is that values at each $x$ jointly do not
correspond to a function in the parametric family modelled by the network:
there may be no $\omega \in \mathop{supp}q_\theta(\omega)$ such that
$f_\omega(x_i) = y_i$.

<br>

### Variational approximation via Gaussian mean field family

* sampling from $q_\theta(\omega) = q_\theta(b) \otimes q_\theta(W)$ where
$q_\theta(b) = \delta_{b - \beta}$, $q_\theta(W) = \otimes_{ij} q_\theta(W_{ij})$
and $
q_\theta(W_{ij})
    = \mathcal{N}\bigl(
        W_{ij} \big\vert
            \mu_{ij}, \sigma^2_{ij}
        \bigr)
$;
* **local** reparameterization trick $y = x W + b$ implies that $
    y_j \sim \mathcal{N}\bigl(
            \beta_j + \sum_i x_i \mu_{ij},
            \sum_i \sigma^2_{ij} \lvert x_i \rvert^2
        \bigr)
$ and $y_j \bot y_k$, $j\neq k$;
* computing the KL-divergence of $q_\theta(W)$ from $p(W)$ ignoring $b$: $
    \mathop{KL}\bigl(q_\theta(W) \| p(W)\bigr)
        = \mathbb{E}_{W \sim q_\theta}
            \log \tfrac{q_\theta(W)}{p(W)}
$
  * objective diffuse prior $p(W_{ij}) \propto \mathop{const}$
  * objective scale-free prior $p(W_{ij}) \propto \tfrac1{\lvert W_{ij} \rvert}$
  * subjective prior $p(W_{ij}) = \mathcal{N}(W_{ij} \mid 0, \nu^{-1})$

For this to work, it is necessary to implement
another **trait-class** and a penalty "collector".

In [None]:
class VariatonalApproximation(FreezableWeight):
    def penalty(self):
        raise NotImplementedError()

In [None]:
def penalties(module):
    for mod in module.modules():
        if isinstance(mod, VariatonalApproximation):
            yield mod.penalty()

<br>

### Diffuse prior $p(W_{ij}) \propto \mathop{const}$

Against a **diffuse prior** the KL-divergence is just
the **negative entropy** (up to a constant).

Since $q_\theta(W)$ is a diagonal multivariate normal
$$
    KL\bigl( q_\theta(W) \big\| p(W) \bigr)
        = \sum_{ij} KL\bigl( q_\theta(W_{ij}) \big\| p(W_{ij}) \bigr)
        = \mathop{const} - \sum_{ij} \mathbb{H}(q_\theta(W_{ij}))
    \,. $$

For a multivariate Gaussian $
    p(z) = \mathcal{N}_n\bigl(
        z\,\big\vert\, \mu, \Sigma
    \bigr)
$ we have
$$
    \mathbb{H}(p)
        = - \mathbb{E}_{z\sim p} \log p(z)
        = \tfrac12 \log \det \bigl(2 \pi e \Sigma \bigr)
%         = \tfrac12 \log \det \Sigma + \tfrac{n}2 \log 2 \pi e
    \,. $$

Hence the entropy for a univariate Gaussian is
$$
    \mathbb{H}(q_\theta(W_{ij}))
        = - \mathbb{E}_{W_{ij} \sim q_\theta(W_{ij})} \log q_\theta(W_{ij})
        = \tfrac12 \log \{2 \pi e \, \sigma^2\}
    \,. $$

In [None]:
import math

def kl_div_diffuse(log_sigma2, weight=None):
    """KL from a diffuse prior: the negative entropy, up to a constant."""
    const = 0.5 * math.log(2 * math.pi * math.e) * log_sigma2.numel()
    entropy = const + 0.5 * torch.sum(log_sigma2)

    return - entropy
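
A quick numerical check of the entropy formula against `torch.distributions.Normal.entropy`, on arbitrary toy values for the means and log-variances:

```python
import math
import torch
from torch.distributions import Normal

torch.manual_seed(0)
weight = torch.randn(2, 3, dtype=torch.float64)
log_sigma2 = torch.tensor([[-8.0, -2.0, 0.0], [0.5, 1.5, -4.0]], dtype=torch.float64)

# closed form: sum_ij 0.5 * log(2 pi e sigma_ij^2), as in `kl_div_diffuse`
closed_form = (0.5 * math.log(2 * math.pi * math.e) * log_sigma2.numel()
               + 0.5 * log_sigma2.sum())

# reference value from torch.distributions
reference = Normal(weight, torch.exp(0.5 * log_sigma2)).entropy().sum()
```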

<br>

### Proper prior $p(W_{ij}) = \mathcal{N}(W_{ij} \mid 0, \nu^{-1})$

The KL-divergence between two multivariate Gaussians is given by

$$
    KL\bigl(
        \mathcal{N}_m(\mu_0, \Sigma_0)
        \big\| \mathcal{N}_m(\mu_1, \Sigma_1)
    \bigr)
        = \frac12 \Bigl\{
            \log \frac{\det \Sigma_1}{\det \Sigma_0}
            + \mathop{tr} \bigl( \Sigma_1^{-1} \Sigma_0 \bigr)
            + \bigl(\mu_0 - \mu_1 \bigr)^\top \Sigma_1^{-1} \bigl(\mu_0 - \mu_1 \bigr)
            - m
        \Bigr\}
    \,. $$

KL divergence from a zero-mean isotropic normal distribution $\mathcal{N}_m(0, \nu^{-1} I_m)$ is
$$
    KL\bigl(
        \mathcal{N}_m(\mu_0, \Sigma_0)
        \big\| \mathcal{N}_m(0, \nu^{-1} I_m)
    \bigr)
        = \frac12 \Bigl\{
            \nu \, (\mathop{tr} \Sigma_0 + \mu_0^\top \mu_0)
            - m \log \nu
            - \log \det \Sigma_0
            - m
        \Bigr\}
    \,. $$
<!--
$$
    \frac12 \Bigl\{
        - \log \det \nu \Sigma_0
        + \nu \mathop{tr} \bigl( \Sigma_0 \bigr)
        + \nu \bigl(\mu_0 - \mu_1 \bigr)^\top \bigl(\mu_0 - \mu_1 \bigr)
        - m
    \Bigr\}
    \,. $$
-->

Therefore, if $q_\theta(W)$ is a diagonal multivariate normal then we get
(put $m=1$, $\mu_0 = \mu_{ij}$, $\Sigma_0 = \sigma^2_{ij}$):

$$
    KL\bigl( q_\theta(W) \big\| p(W) \bigr)
        = \sum_{ij} KL\bigl( q_\theta(W_{ij}) \big\| p(W_{ij}) \bigr)
        = \frac12 \sum_{ij} \bigl(
            \nu \sigma^2_{ij} + \nu \mu_{ij}^2
            - \log \sigma^2_{ij} - \log \nu - 1
        \bigr)
    \,. $$

In [None]:
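import math

def kl_div_proper(log_sigma2, weight, nu=1.0):
    """KL divergence from the Gaussian prior N(0, 1 / nu), summed over weights."""
    const = 0.5 * (math.log(nu) + 1) * log_sigma2.numel()
    nu_term = 0.5 * nu * (torch.exp(log_sigma2) + weight * weight)
    # note the factor 1/2 in front of the log-variance term
    kl_div = torch.sum(nu_term - 0.5 * log_sigma2) - const

    return kl_div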
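
The closed form can be checked against `torch.distributions.kl_divergence` for two normal distributions, with the prior standard deviation $\nu^{-1/2}$. The sketch re-declares the formula under a hypothetical name, with the $\tfrac12$ factor on $\log \sigma^2_{ij}$ taken from the equation above:

```python
import math
import torch
from torch.distributions import Normal, kl_divergence

def kl_div_proper_sketch(log_sigma2, weight, nu=1.0):
    # 0.5 * sum_ij (nu sigma^2 + nu mu^2 - log sigma^2 - log nu - 1)
    const = 0.5 * (math.log(nu) + 1) * log_sigma2.numel()
    nu_term = 0.5 * nu * (torch.exp(log_sigma2) + weight * weight)
    return torch.sum(nu_term - 0.5 * log_sigma2) - const

torch.manual_seed(0)
nu = 2.5
weight = torch.randn(3, 4, dtype=torch.float64)
log_sigma2 = torch.randn(3, 4, dtype=torch.float64) - 2.0

q = Normal(weight, torch.exp(0.5 * log_sigma2))
prior = Normal(torch.zeros_like(weight), torch.full_like(weight, nu ** -0.5))
reference = kl_divergence(q, prior).sum()
```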

In [None]:
def kl_div_1811_00596(log_sigma2, weight):
    r"""Penalty from arxiv:1811.00596.

    The precision parameter `\nu` in the prior is optimized away.
    """
    # get $- \log \alpha_{ij}$
    neg_log_alpha = 2 * torch.log(abs(weight) + 1e-12) - log_sigma2

    # `softplus` is $x \mapsto \log(1 + e^x)$
    kl_div_approx = torch.sum(0.5 * F.softplus(neg_log_alpha))

    return kl_div_approx

<br>

### Scale-free prior $p(W_{ij}) \propto \tfrac1{\lvert W_{ij} \rvert}$

This prior gives us the so-called Variational Dropout
([Molchanov et al. 2017](https://arxiv.org/abs/1701.05369),
[Kingma et al. 2015](https://papers.nips.cc/paper/5666-variational-dropout-and-the-local-reparameterization-trick)).

Provided that $\mu_{ij} \neq 0$, for $\alpha_{ij} = \tfrac{\sigma^2_{ij}}{\mu_{ij}^2}$
we have the following equality in distribution:

$$
    \mathcal{N}(\mu_{ij}, \sigma^2_{ij})
    \overset{D}{=} \mu_{ij} \cdot \mathcal{N}(1, \alpha_{ij})
    \,. $$
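
A quick Monte Carlo sketch of this equality on arbitrary toy values of $\mu$ and $\sigma^2$: additive Gaussian noise and multiplicative $\mathcal{N}(1, \alpha)$ noise give the same first two moments, and both distributions are Gaussian.

```python
import torch

torch.manual_seed(0)
mu, sigma2 = -1.5, 0.8
alpha = sigma2 / mu ** 2

n = 200_000
additive = mu + sigma2 ** 0.5 * torch.randn(n)             # N(mu, sigma^2)
multiplicative = mu * (1 + alpha ** 0.5 * torch.randn(n))  # mu * N(1, alpha)

for name, w in [("additive", additive), ("multiplicative", multiplicative)]:
    print(f"{name:>14}: mean {w.mean().item():+.3f}, var {w.var().item():.3f}")
```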

Therefore our variational approximation $q_\theta(W)$ can be regarded as
the so-called Gaussian Dropout: the weights $W_{ij}$ are subject to multiplicative
noise $W_{ij} = \mu_{ij} \cdot \varepsilon_{ij}$ for $\varepsilon_{ij}
\sim \mathcal{N}(1, \alpha_{ij})$.

Under this variational family:
$$
    KL\bigl( q_\theta(W_{ij}) \big\| p(W_{ij}) \bigr)
%         = \mathop{const} - \tfrac12 \log \sigma^2_{ij}
%         + \mathbb{E}_{W_{ij} \sim q_\theta(W_{ij})} \log \lvert W_{ij} \rvert
        = \mathop{const} - \tfrac12 \log \sigma^2_{ij}
        + \tfrac12 \log \mu_{ij}^2
        + \mathbb{E}_{\xi \sim \mathcal{N}(0, 1)}
            \log \lvert 1 + \tfrac{\sigma_{ij}}{\lvert \mu_{ij} \rvert} \xi \rvert
    \,. $$

Unfortunately there is no closed-form expression for this divergence, and thus
we have to resort to the approximation from the ICML'17 paper [Molchanov et al. 2017](https://arxiv.org/abs/1701.05369):

$$
    KL\bigl( q_\theta(W_{ij}) \big\| p(W_{ij}) \bigr)
%         = \mathop{const}
%         + \mathbb{E}_{\xi \sim \mathcal{N}(1, \alpha_{ij})}
%                         \log{\lvert \xi \rvert}
%                     - \tfrac12 \log \alpha_{ij}
        \approx
            \tfrac12 \log (1 + e^{-\log \alpha}) 
            + k_1(1 - \sigma(k_2 + k_3 \log \alpha))
    \,, $$

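where $\sigma(\cdot)$ denotes the logistic sigmoid, and $k_1$, $k_2$ and $k_3$ are approximately $0.63576$, $1.87320$ and $1.48695$, respectively. The constant $C = -k_1$ of the original approximation is already folded into the $k_1\bigl(1 - \sigma(\cdot)\bigr)$ term.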

In [None]:
def kl_div_scale_free(log_sigma2, weight):

    # get $- \log \alpha_{ij}$
    neg_log_alpha = 2 * torch.log(abs(weight) + 1e-12) - log_sigma2
    
    # Use the identity 1 - \sigma(z) = \sigma(- z)
    sigmoid = torch.sigmoid(1.48695 * neg_log_alpha - 1.87320)

    # `softplus` is $x \mapsto \log(1 + e^x)$
    kl_div_approx_ij = 0.5 * F.softplus(neg_log_alpha) + 0.63576 * sigmoid
    kl_div_approx = torch.sum(kl_div_approx_ij)

    return kl_div_approx

<br>

### The local reparameterization trick

Proposed in
[Kingma et al. 2015](https://papers.nips.cc/paper/5666-variational-dropout-and-the-local-reparameterization-trick)
this trick
* allows sampling the parameters implicitly, by sampling the layer's outputs directly
* reduces variance of the stochastic gradient, generally leading to much faster convergence

The effect of a linear layer on its input is given by the equation
$$
    y = x W + b
    \,, \tag{obvious} $$
with $W \in \mathbb{R}^{m\times n}$, $y\in \mathbb{R}^n$,
$x\in \mathbb{R}^m$ and $b \in \mathbb{R}^n$.

Key observation:
> any non-trivial linear combination of Gaussian
random variables is a Gaussian random variable

Since the layer's $W$ are jointly Gaussian, this means that the distribution
of $y$ is also Gaussian. We can see this if we look closely at the output:
the bias term is effectively fixed to $\beta$ and the weights are

$$
    W_{ij} \sim
        \mathcal{N}(M_{ij}, \sigma^2_{ij})
    \,,\, W_{ij} \bot W_{kl}
    \,,\, ij\neq kl
    \,,
$$

which means that
$$
    y_j = \beta_j + \sum_i x_i W_{ij}
    \Rightarrow
    y_j \sim \mathcal{N}\bigl(
        \beta_j + \sum_i x_i M_{ij}, \sum_i \sigma^2_{ij} x_i^2
    \bigr)
    \tag{local-reparametrization}
    \,. $$

Note that independence of $W_{ij}$ implies zero correlation between them, which
means that distinct $y_j$ and $y_k$ are uncorrelated.

Collecting into a single multivariate Gaussian random vector we get the 
following stochastic forward pass:

$$
    y \sim \mathcal{N}_n\bigl(
        x M + \beta\,,\, \mathrm{diag}(s^2)
    \bigr)
    \,, $$

with $s^2_j = \sum_{i=1}^m \sigma^2_{ij} x_i^2$.

In [None]:
def stochastic_linear_lrp(layer, input):
    """Forward pass for the linear layer with the local reparameterization trick."""

    ## Exercise: implement the always active local reparametrization trick.
    #  (note) you might want to add 1e-20 inside `.sqrt()`

    # Get the mean
    mu = F.linear(input, layer.weight, layer.bias)

    # Add the resulting effect of weight randomness
    s2 = F.linear(input * input, torch.exp(layer.log_sigma2), None)
    output =  mu + torch.randn_like(s2) * torch.sqrt(s2 + 1e-20)

    pass

    return output
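
A self-contained Monte Carlo sketch comparing the two ways of sampling a linear layer's outputs on toy values: drawing a full weight matrix per sample, as in the global reparameterization, versus drawing the outputs directly, as in the local trick. Both should match the analytic mean $x M + \beta$ and variance $s^2_j$, but the local version draws `n_draws x out` noise values instead of `n_draws x out x in`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_draws, in_features, out_features = 100_000, 3, 2

x = torch.tensor([[0.5, -1.0, 2.0]])                   # one input row, [1, in]
M = torch.randn(out_features, in_features)             # weight means, [out, in]
beta = torch.randn(out_features)                       # deterministic bias
log_sigma2 = torch.full((out_features, in_features), -1.0)

# global reparameterization: draw a whole weight matrix for every sample
W = M + torch.exp(0.5 * log_sigma2) * torch.randn(n_draws, out_features, in_features)
y_global = torch.einsum("i,noi->no", x[0], W) + beta   # [n_draws, out]

# local reparameterization: draw the outputs directly
mu = F.linear(x, M, beta)                              # [1, out]
s2 = F.linear(x * x, torch.exp(log_sigma2))            # [1, out]
y_local = mu + torch.sqrt(s2) * torch.randn(n_draws, out_features)
```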

<br>

### Linear layer with Gaussian dropout and the trick 

Now let's combine the trick for the Gaussian approximation $q_\theta(\omega)$
and one of the divergences, given above, into a stochastic linear layer:

* we need to write the forward pass and a parameter sampler

In [None]:
class GaussianLinear(Linear, VariatonalApproximation):
    """Linear layer with Gaussian Mean Field weight distribution."""

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias=bias)

        self.log_sigma2 = torch.nn.Parameter(
            torch.Tensor(*self.weight.shape))

        self.reset_variational_parameters()

    def reset_variational_parameters(self):
        """Initialize the log-variance."""
        self.log_sigma2.data.normal_(-8, 0.01)

    def forward(self, input):
        """Forward pass with the local reparameterization trick."""
        if self.is_frozen():
            return F.linear(input, self.frozen_weight, self.bias)

        return stochastic_linear_lrp(self, input)

    def freeze(self):
        r"""Sample the weight from $q_{\theta_m}(\omega_m)$ and freeze it."""
        
        ## Exercise: sample the weights from the variational approximation
        stdev = torch.exp(0.5 * self.log_sigma2)
        weight = torch.normal(self.weight, std=stdev)

        self.register_buffer("frozen_weight", weight)

    def penalty(self):
        r"""KL divergence between $q_{\theta_m}(\omega_m)$ and a prior on $\omega_m$."""

        ## Exercise
        # return kl_div_diffuse(self.log_sigma2, self.weight)
        # return kl_div_scale_free(self.log_sigma2, self.weight)
        # return kl_div_proper(self.log_sigma2, self.weight, nu=1e0)

        return kl_div_1811_00596(self.log_sigma2, self.weight)

A convolution can be represented as a matrix-vector product of the doubly
block-Toeplitz embedding of the kernel and the unravelled input. As such,
it is an implicit linear layer with a block-structured weight matrix.
However, the local reparameterization trick comes with a caveat: if the
kernel itself is assumed to have the specified variational distribution,
then the outputs will be spatially correlated, because the same weight
block is reused at every location:

$$
    cov(y_{f\beta}, y_{k\omega})
        = \delta_{f=k} \sum_{c \alpha}
            \sigma^2_{fc \alpha}
            x_{c i_\beta(\alpha)}
            x_{c i_\omega(\alpha)}
    \,, $$

where $i_\beta(\alpha)$ is the location in $x$ for the output location
$\beta$ and kernel offset $\alpha$ (depends on stride and dilation).

In contrast, if the Toeplitz embedding blocks are instead assumed to be iid
draws from the variational distribution, then the covariance becomes

$$
    cov(y_{f\beta}, y_{k\omega})
        = \delta_{f\beta = k\omega} \sum_{c \alpha}
            \sigma^2_{fc \alpha}
            \lvert x_{c i_\omega(\alpha)} \rvert^2
    \,. $$

Molchanov et al. (2017) implicitly assume that the kernels are iid draws
from the variational distribution at different spatial locations. This
effectively zeroes the spatial cross-correlation in the output and reduces
the variance of the gradient in the SGVB method.

In [None]:
class GaussianConv2d(Conv2d, VariatonalApproximation):
    """Convolutional layer with Gaussian Mean Field weight distribution."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, p=0.05,
                 padding_mode='zeros'):
        super().__init__(in_channels, out_channels, kernel_size, stride=stride,
                         padding=padding, dilation=dilation, groups=groups,
                         bias=bias, padding_mode=padding_mode)

        self.log_sigma2 = torch.nn.Parameter(
            torch.Tensor(*self.weight.shape))

        self.reset_variational_parameters()

    reset_variational_parameters = LinearGaussian.reset_variational_parameters

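    def forward(self, input):
        """Forward pass with the local reparameterization trick."""
        if self.is_frozen():
            # deterministic pass with the frozen kernel; the bias is kept,
            # as in the stochastic branch below
            return F.conv2d(input, self.frozen_weight, self.bias,
                            self.stride, self.padding, self.dilation, self.groups)

        # output mean: the usual convolution with the mean kernel (and bias)
        mu = super().forward(input)

        # output variance: squared inputs convolved with the kernel variances,
        # i.e. kernels at different locations are treated as iid draws
        s2 = F.conv2d(input * input, torch.exp(self.log_sigma2), None,
                      self.stride, self.padding, self.dilation, self.groups)
        return mu + torch.randn_like(s2) * torch.sqrt(s2 + 1e-20)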

    # we reuse implementation, without inheritance
    freeze = GaussianLinear.freeze

    penalty = GaussianLinear.penalty

<br>

### Linear layer with Bernoulli Dropout and Gaussian approximation

We can also fuse classical Bernoulli dropout ([Hinton et al. 2012](https://arxiv.org/abs/1207.0580))
with a Gaussian variational approximation ([Gal, Y. 2016](http://www.cs.ox.ac.uk/people/yarin.gal/website/thesis/thesis.pdf)).

The forward pass of the fused model is simply a composition of dropout
on the inputs and the local reparameterization trick. For the Kullback-Leibler
divergence, however, we need to identify how $W$ is effectively distributed.
Indeed, the variational approximation $q_{\theta}(W)$ is essentially a
spike-and-slab mixture: each row $W_i$ of $W\in \mathbb{R}^{m \times n}$ is
either $\mathbf{0}$ with probability $p$ or a Gaussian vector in $\mathbb{R}^n$:

$$
    W_i \sim
\begin{cases}
    \mathbf{0}
        & \text{ w. prob } p \\
    M_i
        & \text{ otherwise } \\
\end{cases}
    \,, \text{indep.}
    \,, i=1, \ldots, m
    \,, M_i \sim \mathcal{N}_n\bigl(
        M_i \big\vert \mu_i, \mathop{diag} \sigma^2_i
    \bigr)
    \,. $$
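To see why dropout on the inputs yields this spike-and-slab form, here is a small NumPy sketch in the row-vector convention $y = x W$ with $x \in \mathbb{R}^m$ (an assumption, chosen to match rows $W_i \in \mathbb{R}^n$ above). It ignores the $1/(1-p)$ rescaling applied by `F.dropout`, which can be absorbed into $M$. Masking input $i$ is the same as zeroing row $W_i$:

```python
import numpy as np

rng = np.random.default_rng(0)
m, n, p = 6, 3, 0.5

x = rng.normal(size=m)                       # a single input, as a row vector
mu = rng.normal(size=(m, n))                 # slab means, one row mu_i per input
sigma2 = np.full((m, n), 0.1)                # slab variances

z = rng.binomial(1, 1 - p, size=m)           # z_i = 0 (dropped) with probability p
M = mu + np.sqrt(sigma2) * rng.normal(size=(m, n))  # M_i ~ N(mu_i, diag sigma2_i)

y_dropout = (x * z) @ M                      # dropout on the inputs, then x M
W = z[:, None] * M                           # rows W_i = 0 wherever z_i = 0
y_weight = x @ W                             # the same output through W

print(np.allclose(y_dropout, y_weight))
```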

Under some assumptions and benign relaxations
[Gal, Y. 2016 (eq. (6.3) p.109, Prop. 4 p.149)](http://www.cs.ox.ac.uk/people/yarin.gal/website/thesis/thesis.pdf)
the divergence can be approximated by

$$
    KL\bigl( q_\theta(W) \big\| p(W) \bigr)
        = \sum_i KL\bigl( q_\theta(W_i) \big\| p(W_i) \bigr)
        \approx \mathop{const}
        % + \frac{mn} 2 p \{\tau^{-1} + \log \tau\}
        % + m (p \log p + (1-p) \log (1-p))
        + \frac{1-p}2 \sum_{ij} \bigl(
            \sigma^2_{ij} + \mu_{ij}^2 - \log \sigma^2_{ij}
        \bigr)
    \,. $$
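Dropping the constant, the approximate divergence is cheap to evaluate. The minimal NumPy sketch below (the helper name `approx_kl_penalty` is hypothetical, and it takes $\sigma^2$ directly rather than its logarithm) also shows that each term $\sigma^2_{ij} + \mu^2_{ij} - \log\sigma^2_{ij}$ is smallest at $\mu_{ij} = 0$, $\sigma^2_{ij} = 1$, so the penalty pulls the slab towards a standard Gaussian, with weight $1 - p$:

```python
import numpy as np


def approx_kl_penalty(mu, sigma2, p):
    """(1 - p) / 2 * sum_ij (sigma2_ij + mu_ij^2 - log sigma2_ij), constant dropped."""
    return 0.5 * (1 - p) * np.sum(sigma2 + mu ** 2 - np.log(sigma2))


mu, sigma2 = np.zeros((4, 3)), np.ones((4, 3))
base = approx_kl_penalty(mu, sigma2, p=0.5)   # 0.5 * 0.5 * 12 = 3.0

# moving the means or the variances away from (0, 1) increases the penalty
print(base,
      approx_kl_penalty(mu + 0.1, sigma2, p=0.5),
      approx_kl_penalty(mu, 2.0 * sigma2, p=0.5),
      approx_kl_penalty(mu, 0.5 * sigma2, p=0.5))
```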

<!--
Proposition 4 p.149 is about the approximation of the divergence of
a mixture from the standard Gaussian:

$$
    KL\bigl(
        \sum_k \pi_k \mathcal{N}_n(\mu_k, \Sigma_k)
        \big\| \mathcal{N}_n(0, I_n)
    \bigr)
        \approx \mathop{const}
        + \frac12 \sum_k \pi_k \{
            \mathop{tr}\Sigma_k
            + \mu_k^\top \mu_k
            - \log\det\Sigma_k
        \}
        - \mathbb{H}(\pi)
    \,. $$
-->

In [None]:
class LinearFusedGaussianBernoulli(LinearGaussian):
    """Linear layer with spike-and-slab (Gaussian) weight distribution."""
    def __init__(self, in_features, out_features, bias=True, p=0.5):
        super().__init__(in_features, out_features, bias=bias)

        self.p = p

    def forward(self, input):
        """Forward pass with Bernoulli dropout, then the Gaussian
        local reparameterization trick.
        """
        # dropout stays on outside of training too, so that repeated
        # forward passes give different function samples
        input = F.dropout(input, self.p, training=True)

        return super().forward(input)

    def sample(self):
        par = super().sample()

        # one Bernoulli keep-mask per input feature, i.e. per column of the
        # `[out, in]` weight, rescaled by 1 / (1 - p) just like `F.dropout`
        prob = torch.full_like(self.weight[:1], 1 - self.p)
        par["weight"] *= torch.bernoulli(prob) / prob

        return par

    @property
    def penalty(self):
        """Approximate KL divergence."""
        return (1 - self.p) * kl_div_proper(self.log_sigma2, self.weight)

<br>

## Troubles with initializers

Consider initializing a weight matrix $W \in \mathbb{R}^{[\mathrm{out}] \times [\mathrm{in}]}$.
The dimensions $[\mathrm{in}]$ and $[\mathrm{out}]$ are known as `fan-in` and `fan-out`, respectively.

Kaiming He:
* (normal) $w_{ij} \sim \mathcal{N}(0, \sigma^2)$ with $\sigma = \tfrac{\sqrt2}{\sqrt{[\mathrm{in}] (1+a^2)}}$, where $a$ is the negative slope of the (leaky) ReLU
* (uniform) $w_{ij} \sim \mathcal{U}[-b, +b]$ with $b = \tfrac{\sqrt6}{\sqrt{[\mathrm{in}] (1+a^2)}}$

Xavier Glorot:
* (uniform) $w_{ij} \sim \mathcal{U}[-b, +b]$ with $
b = \mathrm{gain} \cdot \sqrt{\tfrac6{
        [\mathrm{out}] + [\mathrm{in}]
    }}
$
* (normal) $w_{ij} \sim \mathcal{N}(0, \sigma^2)$ with $
\sigma = \mathrm{gain} \cdot \sqrt{\tfrac2{
        [\mathrm{out}] + [\mathrm{in}]
    }}
$
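Since $\mathcal{U}[-b, +b]$ has variance $b^2 / 3$, each uniform variant has the same standard deviation as its normal counterpart ($\sqrt6 / \sqrt3 = \sqrt2$). A quick NumPy check follows; the helper names `kaiming` and `glorot` are hypothetical (in PyTorch these schemes are provided by `torch.nn.init.kaiming_normal_`, `kaiming_uniform_`, `xavier_normal_` and `xavier_uniform_`):

```python
import numpy as np


def kaiming(fan_in, a=0.0):
    """(std of the normal variant, bound of the uniform variant)."""
    std = np.sqrt(2.0 / (fan_in * (1 + a ** 2)))
    bound = np.sqrt(6.0 / (fan_in * (1 + a ** 2)))
    return std, bound


def glorot(fan_in, fan_out, gain=1.0):
    """(std of the normal variant, bound of the uniform variant)."""
    std = gain * np.sqrt(2.0 / (fan_in + fan_out))
    bound = gain * np.sqrt(6.0 / (fan_in + fan_out))
    return std, bound


fan_out, fan_in = 256, 512
rng = np.random.default_rng(0)

for std, bound in [kaiming(fan_in), kaiming(fan_in, a=0.2), glorot(fan_in, fan_out)]:
    # an [out] x [in] matrix drawn from U[-bound, +bound]
    W = rng.uniform(-bound, bound, size=(fan_out, fan_in))
    print(f"target std {std:.4f}, bound / sqrt(3) {bound / np.sqrt(3):.4f}, "
          f"empirical std {W.std():.4f}")
```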


The formulas above follow this piece on [weight initialization](https://towardsdatascience.com/weight-initialization-in-neural-networks-a-journey-from-the-basics-to-kaiming-954fb9b47c79).

In [None]:
import torch
from torch.nn import init


class Linear(torch.nn.Linear):
    """`torch.nn.Linear` with Glorot (Xavier) uniform weights and zero bias."""
    def reset_parameters(self):
        init.xavier_uniform_(self.weight)
        if self.bias is not None:
            init.zeros_(self.bias)

<br>

# Trunk

How about heteroskedasticity?

In [None]:
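def heteroskedastic_mse_loss(model, X, y):
    """Gaussian negative log-likelihood (up to a constant); the model
    predicts the mean and the log-variance in its two output columns."""
    output = model(X)
    mean, log_sigma2 = output[..., [0]], output[..., [1]]

    value = F.mse_loss(mean, y, reduction="none")

    # the `+ log_sigma2` term stops the model from shrinking the loss
    # by simply inflating the predicted variance
    return 0.5 * torch.mean(value * torch.exp(- log_sigma2) + log_sigma2)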
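For reference, with a Gaussian likelihood, predicted mean $\hat{y}$ and predicted log-variance $s = \log\sigma^2$, the per-sample negative log-likelihood is, up to an additive constant, $\tfrac12 \bigl( (y - \hat{y})^2 e^{-s} + s \bigr)$. The $+s$ term is what prevents the model from driving the loss to zero by inflating the variance: the scaled squared error alone decreases monotonically in $s$. A minimal NumPy check on a grid (the residual value is an arbitrary assumption):

```python
import numpy as np

residual = 1.5                              # an arbitrary value of y - y_hat
s = np.linspace(-5.0, 5.0, 10_001)          # candidate log-variances, step 1e-3

nll = 0.5 * (residual ** 2 * np.exp(-s) + s)
scaled_mse = residual ** 2 * np.exp(-s)     # the loss without the log-variance term

s_best = s[np.argmin(nll)]
print(s_best, np.log(residual ** 2))        # the minimizer is s = log(residual^2)
```

Setting the derivative $\tfrac12 (1 - (y - \hat{y})^2 e^{-s})$ to zero gives the same minimizer analytically.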

A $1$-d classification?

In [None]:
def log_loss(model, X, y):
    return F.binary_cross_entropy_with_logits(model(X), y)

fit(model, train, criterion=log_loss, batch_size=32, n_epochs=2000, verbose=True)

This setting induces
a distribution of the observed data $D$ conditional on the model
$f$ and its parameters $\omega$: $p(y \mid x, \omega, f)$.

In the Bayesian approach, $\omega$ is treated not as a set of parameters
to be optimized over, but as random variables whose distribution **after**
observing the data is to be sought. The building blocks are:

* a prior distribution $p(\omega)$ on the parameters $\omega$ -- this
  reflects our belief prior to observing data.

* $p(y \mid x, \omega)$ -- the output distribution (Gaussian, Categorical,
  Poisson), the parameters of which are modeled by some function of the
  input $f_\omega\colon \mathcal{X} \to \mathcal{Y}$. We construct the
  **likelihood** with this building block.

We usually assume that, given the parameters $\omega$ (and latent factors,
e.g. source components in mixture models), the targets are conditionally
independent: $
    p(D \mid \omega)
        = \prod_i p(y_i \mid x_i, \omega)
$.

From these building blocks and under these assumptions we wish to find
the **posterior parameter distribution** $
p(\omega\mid D) = \tfrac{p(D\mid \omega) p(\omega)}{p(D)}
$ and, having integrated out $\omega$, the **posterior predictive
distribution** $p(y\mid x, D)$.
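Written out, the posterior predictive distribution averages the likelihood over the posterior; when samples from the posterior (or from an approximation to it) are available, it is commonly approximated by Monte Carlo:

$$
    p(y\mid x, D)
        = \int p(y\mid x, \omega)\, p(\omega\mid D)\, d\omega
        \approx \frac1T \sum_{t=1}^T p(y\mid x, \omega_t)
    \,, \quad \omega_t \sim p(\omega\mid D)
    \,. $$

The evidence $p(D) = \int p(D\mid \omega)\, p(\omega)\, d\omega$ in the denominator of Bayes' rule is what usually makes the exact posterior intractable for neural networks.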

We need to measure the difference between $q^*_D(z)$ and $p(z\mid x)$
(and its gradient) using only cheap operations. By assumption,
we can neither sample from $p(z\mid x)$ nor evaluate its density.

We can:
* evaluate the density $p(x, z)$, i.e. the unnormalized $p(z\mid x)$;
* sample from $q^*_D(z)$ and evaluate its density.

Our goal is to find the candidate that best approximates the conditional
density of the latent variables given the observed variables. One notion
of **proximity** between distributions is the KL divergence. The
Kullback-Leibler divergence of $Q$ from $P$, with densities $q$ and $p$
respectively, is given by

$$
    \mathrm{KL}(q(\omega) \| p(\omega))
%         = \mathbb{E}_{\omega \sim Q} \log \tfrac{dQ}{dP}(\omega)
        = \mathbb{E}_{\omega \sim q(\omega)}
            \log \tfrac{q(\omega)}{p(\omega)}
    \,. $$

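Using only samples from $q(z\mid x)$ and evaluations of $q(z\mid x)$ and
$p(z, x)$, we can get an unbiased Monte Carlo estimate of this divergence
up to the additive constant $\log p(x)$, since
$\log p(z\mid x) = \log p(z, x) - \log p(x)$.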
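A minimal NumPy sketch of such a Monte Carlo estimate for two univariate Gaussians, where the closed form is available for comparison (the helper names and parameter values are illustrative assumptions):

```python
import numpy as np


def log_normal_pdf(w, mean, std):
    return -0.5 * np.log(2 * np.pi * std ** 2) - 0.5 * ((w - mean) / std) ** 2


def kl_normal(m_q, s_q, m_p, s_p):
    """Closed form KL( N(m_q, s_q^2) || N(m_p, s_p^2) )."""
    return np.log(s_p / s_q) + (s_q ** 2 + (m_q - m_p) ** 2) / (2 * s_p ** 2) - 0.5


def kl_monte_carlo(m_q, s_q, m_p, s_p, n_samples, rng):
    """Unbiased estimate: the average of log q(w) - log p(w) over w ~ q."""
    w = rng.normal(m_q, s_q, size=n_samples)
    return np.mean(log_normal_pdf(w, m_q, s_q) - log_normal_pdf(w, m_p, s_p))


rng = np.random.default_rng(0)
exact = kl_normal(0.5, 0.8, 0.0, 1.0)
estimate = kl_monte_carlo(0.5, 0.8, 0.0, 1.0, 200_000, rng)
print(exact, estimate)
```

Replacing $\log p$ by an unnormalized log-density, such as $\log p(z, x)$ in place of $\log p(z\mid x)$, only shifts the estimate by a constant, which does not affect its gradient with respect to the parameters of $q$.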

<br>

Recall that the Kullback-Leibler divergence of distribution $q$
from $p$ is given by

$$
    KL(q\| p)
        = \mathbb{E}_{x\sim q(x)} \log \tfrac{q(x)}{p(x)}
    \,. $$

We can compute the entropy via the Kullback-Leibler divergence of $p(y)$
from the uniform distribution over the $d$ categories:

$$\begin{align}
    \mathbb{H}\bigl(p(y)\bigr)
        &= - \sum_k p_k \log p_k
        = - \bigl( \sum_k p_k \log p_k \pm \log \tfrac1d \bigr)
%         = - \bigl( \sum_k p_k \log \tfrac{p_k}{\tfrac1d} \bigr) - \log \tfrac1d
        \\
%         &=  \log d - \bigl( \sum_k p_k \log \tfrac{p_k}{\tfrac1d} \bigr)
        &= \log d - KL(p\|\tfrac1d)
\end{align}
    \,. $$
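The identity can be checked numerically. A minimal NumPy sketch, with hypothetical helper names and the convention $0 \log 0 = 0$:

```python
import numpy as np


def entropy_np(proba, axis=-1):
    """H(p) = -sum_k p_k log p_k, with the convention 0 log 0 = 0."""
    log_p = np.log(np.where(proba > 0, proba, 1.0))  # log 1 = 0 where p_k = 0
    return -np.sum(proba * log_p, axis=axis)


def kl_from_uniform(proba, axis=-1):
    """KL(p || uniform over d categories) = sum_k p_k log(d p_k)."""
    d = proba.shape[axis]
    log_p = np.log(np.where(proba > 0, proba, 1.0))
    return np.sum(proba * (log_p + np.log(d)), axis=axis)


p = np.array([0.5, 0.25, 0.25, 0.0])
print(entropy_np(p), np.log(len(p)) - kl_from_uniform(p))  # both 1.5 * log(2)
```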

In [None]:
## Obscure -- DON'T USE
def entropy(proba):
    """Compute the entropy along the last dimension."""

    ## Exercise
    return - torch.kl_div(torch.tensor(0.).to(proba), proba).sum(dim=-1)

The Variational Inference approximation of **BALD** is given by
the same expression, except with $q_\theta(\omega)$ instead of
$p(\omega\mid D_\mathrm{train})$:

$$
    \mathbb{I}_\mathrm{vi}(y\,; \omega \mid x, D_\mathrm{train})
        = \mathbb{H}\bigl(
            \mathbb{E}_{\omega \sim q_\theta(\omega)}
                p(y\mid x, \omega, D_\mathrm{train})
        \bigr)
        - \mathbb{E}_{\omega \sim q_\theta(\omega)}
            \mathbb{H}(p(y\mid x, \omega, D_\mathrm{train}))
    \,. \tag{mi-vi} $$
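Given $T$ draws $\omega_t \sim q_\theta(\omega)$, both expectations in (mi-vi) can be replaced by sample averages. A minimal NumPy sketch of this plug-in estimate (helper names are hypothetical); it computes the same quantity as the second output of `ext_h` below with `jackknife=False`:

```python
import numpy as np


def entropy_np(proba, axis=-1):
    """H(p) = -sum_k p_k log p_k, with the convention 0 log 0 = 0."""
    log_p = np.log(np.where(proba > 0, proba, 1.0))
    return -np.sum(proba * log_p, axis=axis)


def bald_mc(proba):
    """Plug-in Monte Carlo estimate of (mi-vi).

    `proba` has shape (T, ..., d): proba[t] = p(y | x, omega_t) for T draws
    omega_t ~ q_theta(omega); the estimate has shape (...).
    """
    entropy_of_mean = entropy_np(proba.mean(axis=0))   # H(E_q p(y | x, omega))
    mean_of_entropy = entropy_np(proba).mean(axis=0)   # E_q H(p(y | x, omega))
    return entropy_of_mean - mean_of_entropy


# samples that agree carry no information about omega ...
agree = np.tile([0.7, 0.3], (10, 1))
# ... while confident samples that disagree give the maximum, log 2 for two classes
disagree = np.array([[1.0, 0.0], [0.0, 1.0]] * 5)
print(bald_mc(agree), bald_mc(disagree))
```

By concavity of the entropy the estimate is never negative, so points with a large value are those on which the sampled models disagree most.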

http://ruder.io/word-embeddings-softmax/

In [None]:
# import numpy as np
import tensorflow as tf

# %matplotlib inline
# import matplotlib.pyplot as plt

In [None]:
from tensorflow.keras.layers import Input, Dense, Dropout

inputs = Input(shape=(1,))
x = Dense(512, activation="relu")(inputs)
x = Dropout(0.5)(x, training=True)
x = Dense(512, activation="relu")(x)
x = Dropout(0.5)(x, training=True)
outputs = Dense(1)(x)

model = tf.keras.Model(inputs, outputs)
model.compile(loss="mean_squared_error", optimizer="rmsprop")

In [None]:
model.fit(X_train[:, 0], y_train[:, 0], epochs=2000, verbose=0)

In [None]:
# do stochastic forward passes on x_test:
samples = [model.predict(X_test[:, 0]) for _ in range(501)]
m = np.mean(samples, axis=0).flatten()  # predictive mean
std = np.std(samples, axis=0).flatten()  # predictive standard deviation

In [None]:
plt.plot(X_test, np.r_[samples][..., 0].T, c="k", alpha=0.05);

In [None]:
# plot mean and uncertainty
plt.plot(X_train, y_train, 'or')
plt.plot(X_test, m, 'gray')

plt.fill_between(X_test[:, 0], m - 2*std, m + 2*std,
                 color='gray', alpha=0.3)  # +/- two std (about 95% if Gaussian)

In [None]:
assert False

<br>

In [None]:
import numpy as np

# %matplotlib inline
import matplotlib.pyplot as plt

from scipy.special import logsumexp

In [None]:
def softmax(logits, axis=-1):
    return np.exp(logits - logsumexp(logits, axis=axis, keepdims=True))

def entropy(proba, axis=-1):
    out = np.multiply(proba, np.log(proba), where=proba>0,
                      out=np.zeros_like(proba))

    return -out.sum(axis=axis)

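def ext_h(proba, jackknife=True):
    """Entropy of the average distribution and the mutual information
    (BALD) from samples `proba` stacked along axis 0.

    With `jackknife=True` the plug-in entropy of the average is
    bias-corrected by the leave-one-out jackknife.
    """
    # broadcasting from the tail!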
    proba_avg = proba.mean(axis=0)

    entropy_expected = entropy(proba_avg)

    if jackknife:
        proba_loo = proba_avg + (proba_avg - proba) / (len(proba) - 1)
        expected_entropy_loo = entropy(proba_loo).mean(axis=0)
        entropy_expected += (len(proba) - 1) * (entropy_expected - expected_entropy_loo)

    expected_entropy = entropy(proba).mean(axis=0)

    return entropy_expected, entropy_expected - expected_entropy

In [None]:
# prob = np.r_[0.1, 0.05, 0.2, 0.15, 0.2, 0.1, 0.05, 0.05, 0.05, 0.05]
prob = softmax(np.random.randn(20) * 1e0)
prob

In [None]:
from scipy.stats import dirichlet

ht = -sum(p * np.log(p) for p in prob if p > 0)

d = dirichlet(prob*1e-1)

n_replications = 1001
for n_samples in [11, 51, 101, 1001]:
    proba = d.rvs(size=(n_samples, n_replications))

    hp, mip = ext_h(proba, False)
    hj, mij = ext_h(proba, True)

    plt.figure(figsize=(12, 5))
    plt.hist(hp, bins=21, label="plug-in", alpha=0.5)
    plt.hist(hj, bins=21, label="jackknife", alpha=0.5)

    plt.axvline(ht, c="red")
    plt.axvline(hp.mean(), c="cyan", label="mean plug-in")
    plt.axvline(hj.mean(), c="olive", label="mean jackknife")

    plt.legend()
    plt.show()