# MLSS2019 -- Bayesian active learning

In this tutorial we will learn what basic building blocks are needed
to endow (deep) neural networks with uncertainty estimates, and how
these estimates can be used for active learning in expert-in-the-loop pipelines.



In [None]:
import tqdm
import numpy as np

import torch
import torch.nn.functional as F

%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

from torch.utils.data import TensorDataset, DataLoader
from torch.nn import Linear, Conv2d

In [None]:
np.random.randint(0x7fff_ffff)

In [None]:
random_state = np.random.RandomState(722_257_201)

In [None]:
device = torch.device("cpu")

<br>

In [None]:
a = np.r_[:120].reshape(2, 3, 4, 5)

In [None]:
a.transpose(3, 0, 1, 2)

##### Some service functions
`dataset_from_numpy` converts numpy arrays into torch tensors and places them on the specified compute device.

`plot1d` is a procedure to plot a $1$-d function (to keep aesthetics in one place).

In [None]:
from mls import dataset_from_numpy

from utils import plot1d

<br>

## Uncertainty estimation in Neural networks

Consider a function
$$
    f
    \colon \mathbb{R} \to \mathbb{R}
    \colon x \mapsto x \sin x
    \,. $$

In [None]:
def f(start, stop, num=50, noise=0.01):
    """Simple toy 1-D function taken from Yarin's post_.

    .. _post: http://www.cs.ox.ac.uk/people/yarin.gal/website/blog_3d801aa532c1ce.html
    """
    X = np.linspace(start, stop, num)
    y = X * np.sin(X)

    if noise > 0:
        y += noise * random_state.randn(num)

    return X[:, np.newaxis], y[:, np.newaxis]

Generate the initial small dataset $S_0 = (x_i, y_i)_{i=1}^{m_0}$ of $m_0 = 20$
samples with $y_i = f(x_i)$ and $x_i$ on a regularly spaced grid over $[-6, +6]$.

In [None]:
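# `zip` groups the X's and the y's of the segments returned by `f`, and
# `np.concatenate` joins them; extra `f(...)` segments can be added to `zip`
X_train, y_train = map(np.concatenate, zip(
    f(-6.0, +6.0, num=20, noise=0.0)
))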

X_test, y_test = f(-10., +10., num=251, noise=0.0)

Plot the test function together with the training sample $S_0$.

In [None]:
plot1d("Train-test 1d function")

Let's build a small model

In [None]:
model = torch.nn.Sequential(
    Linear(1, 512, bias=True),
    torch.nn.ReLU(),
    torch.nn.BatchNorm1d(512, affine=False),

    Linear(512, 512, bias=True),
    torch.nn.ReLU(),
    torch.nn.BatchNorm1d(512, affine=False),

    Linear(512, 1, bias=True),
)

<br>

Move model to device and convert numpy arrays to tensors (both train and test datasets).

In [None]:
model.to(device)

train = dataset_from_numpy(X_train, y_train, device=device, dtype=torch.float)

test = dataset_from_numpy(X_test, y_test, device=device, dtype=torch.float)

A simple fit loop

In [None]:
def fit(model, dataset, batch_size=32, n_epochs=1,
        loss_fn=F.nll_loss, verbose=False):
    """Fit the model with a stochastic gradient method (Adam) on the dataset."""
    model.to(device)

    # an optimizer for model's parameters
    optim = torch.optim.Adam(model.parameters(), lr=1e-3)  # , weight_decay=1e-5)

    # minibatch generator for the training loop
    feed = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    for epoch in tqdm.tqdm(range(n_epochs), disable=not verbose):

        model.train()
        for X, y in feed:
            # forward pass
            output = model(X.to(device))

            # criterion: batch-average loss
            loss = loss_fn(output, y.to(device), reduction="mean")

            # get gradients with backward pass
            optim.zero_grad()
            loss.backward()

            # gradient step (Adam update)
            optim.step()

    return model

<br>

... and let's run it!

In [None]:
fit(model, train, n_epochs=2000, loss_fn=F.mse_loss, verbose=True)

<br>

Let's compute the predictions

In [None]:
def apply(model, feed):
    """Collect model's outputs on the dataset without autograd."""
    model.eval()

    # disable gradients (huge speed up!)
    with torch.no_grad():

        # compute and collect the outputs
        return torch.cat([
            model(X.to(device)).cpu() for X, *rest in feed
        ], dim=0)

In [None]:
test_feed = DataLoader(test, batch_size=512, shuffle=False)

y_pred = apply(model, test_feed)

<br>

Here is what the model predictions look like:

In [None]:
plot1d("prediction", y_pred.numpy()[np.newaxis])

The model appears to fit the curve adequately, at least over the training range.

Can we somehow add uncertainty estimation to it?

<br>

## Bayesian Inference

A Bayesian version of any model treats the weights not as fixed parameters
to be optimized, but rather as random variables whose distribution
**after** observing the data is to be sought. It is built from two blocks:

* a prior distribution $p(\omega)$ on the parameters $\omega$ -- this
  reflects our belief prior to observing data.

* $p(y \mid x, \omega)$ -- the output distribution (Gaussian, Categorical,
  Poisson), the parameters of which are modeled by some function of the
  input $f_\omega\colon \mathcal{X} \to \mathcal{Y}$. We construct the
  **likelihood** with this building block.

We usually assume that, given the parameters $\omega$ (and latent factors,
e.g. source components in mixture models), the targets are conditionally
independent: $
    p(D \mid \omega)
        = \prod_i p(y_i \mid x_i, \omega)
$.

From these blocks, under these assumptions, we wish to find $
p(\omega\mid D) = \tfrac{p(D\mid \omega) p(\omega)}{p(D)}
$ -- the **posterior parameter distribution**, and, by marginalizing out
$\omega$, the **posterior predictive distribution** $
p(y\mid x, D) = \mathbb{E}_{\omega\sim p(\omega\mid D)} \, p(y\mid x, \omega)
$.

Unless the prior and the likelihood are conjugate, the posterior in
Bayesian inference is typically intractable, so it is common to resort
to **Variational Inference** or **Monte Carlo** approximations, or both.

### Variational Inference

Let's specify a parametric family $Q$ of densities over the latent variables,
i.e. the parameters of the network. Each $q(\omega) \in Q$ is a candidate approximation
to the exact intractable posterior $p(\omega \mid D)$. 

We need to measure the difference between $q(\omega)$ and $p(\omega\mid D)$
(and its gradient) using only cheap operations. By assumption,
we can neither sample from $p(\omega\mid D)$ nor evaluate its density.

We can, however,
* evaluate the joint density $p(D, \omega) = p(D\mid \omega) p(\omega)$, i.e. the unnormalized $p(\omega\mid D)$
* sample from $q(\omega)$ and evaluate its density

Our goal is to find the candidate that best approximates the conditional density
of the latent variables given the observed ones. One notion of **proximity**
between distributions is the Kullback-Leibler (KL) divergence, which for distributions
$Q$ and $P$ with densities $q$ and $p$, respectively, is given by

$$
    \mathrm{KL}(q(\omega) \| p(\omega))
%         = \mathbb{E}_{\omega \sim Q} \log \tfrac{dQ}{dP}(\omega)
        = \mathbb{E}_{\omega \sim q(\omega)}
            \log \tfrac{q(\omega)}{p(\omega)}
    \,. $$

An unbiased Monte Carlo estimate of this expectation requires only samples from
$q(\omega)$ and evaluations of the densities $q(\omega)$ and $p(\omega)$.
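A minimal numerical sketch of this estimate, assuming for illustration that both $q$ and $p$ are univariate Gaussians (so that the exact KL is also available in closed form). The helper names `log_normal`, `kl_mc` and `kl_exact` are hypothetical and not part of the tutorial code: averaging $\log q(\omega) - \log p(\omega)$ over draws $\omega \sim q$ should approach the exact value.

```python
import numpy as np


def log_normal(x, mu, sigma):
    """Log-density of N(mu, sigma^2) at x."""
    return -0.5 * np.log(2 * np.pi * sigma**2) - 0.5 * ((x - mu) / sigma) ** 2


def kl_mc(mu_q, sigma_q, mu_p, sigma_p, n_samples=100_000, seed=0):
    """Unbiased Monte Carlo estimate of KL(q || p) using draws from q only."""
    rng = np.random.default_rng(seed)
    omega = rng.normal(mu_q, sigma_q, size=n_samples)  # omega ~ q
    log_ratio = log_normal(omega, mu_q, sigma_q) - log_normal(omega, mu_p, sigma_p)
    return log_ratio.mean()


def kl_exact(mu_q, sigma_q, mu_p, sigma_p):
    """Closed-form KL between two univariate Gaussians, for comparison."""
    return (np.log(sigma_p / sigma_q)
            + (sigma_q**2 + (mu_q - mu_p) ** 2) / (2 * sigma_p**2) - 0.5)


print(kl_mc(0.5, 0.8, 0.0, 1.0), kl_exact(0.5, 0.8, 0.0, 1.0))
```

Only samples from $q$ and pointwise density evaluations are used, which is exactly what makes such estimates usable when $q$ is the variational approximation and $p$ the prior.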

For any distribution $q$ (density or probability mass function) and
the marginal likelihood (evidence) of the data $p(D)$ we have

$$
    \log p(D)
%         = \mathbb{E}_{\omega \sim q} \log \tfrac{p(\omega, D)}{p(\omega \mid D)}
%         = \mathbb{E}_{\omega \sim q} \log \tfrac{p(\omega, D)}{q(\omega)}
%         + \mathbb{E}_{\omega \sim q} \log \tfrac{q(\omega)}{p(\omega \mid D)}
%         = \mathbb{E}_{\omega \sim q} \log p(D\mid \omega)
%         + \mathbb{E}_{\omega \sim q} \log \tfrac{p(\omega)}{q(\omega)}
%         + \mathbb{E}_{\omega \sim q} \log \tfrac{q(\omega)}{p(\omega \mid D)}
        = \mathbb{E}_{\omega \sim q} \log p(D\mid \omega)
        - \mathrm{KL}(q(\omega)\| p(\omega))
        + \mathrm{KL}(q(\omega)\| p(\omega \mid D))
    \,, $$

provided $q(\omega) \ll p(\omega \mid D)$.

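Since $\log p(D)$ does not depend on $q$, minimizing $\mathrm{KL}(q(\omega)\| p(\omega \mid D))$
is equivalent to maximizing the **evidence lower bound** (ELBO), which differs from the
negative of that KL only by the constant $\log p(D)$: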

$$
    q^* \in
    \arg\max_{q\in Q}
        \mathrm{ELBO}(q) = 
            \mathbb{E}_{\omega \sim q} \log p(D\mid \omega)
            - \mathrm{KL}(q(\omega)\| p(\omega))
    \,. $$

* the expected log-likelihood -- favours densities $q$ that place their mass on
parameters $\omega$ that explain the observed data under the specified
model.

* the negative KL-divergence -- encourages variational densities
not to stray too far from the prior.
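The decomposition of $\log p(D)$ into the ELBO plus the gap $\mathrm{KL}(q(\omega)\| p(\omega \mid D))$ can be checked on a toy conjugate model where every term has a closed form. The sketch below assumes, purely for illustration, a prior $\omega \sim \mathcal{N}(0, 1)$, a likelihood $y_i \mid \omega \sim \mathcal{N}(\omega, 1)$ and a Gaussian $q = \mathcal{N}(m, s^2)$; the helper names are hypothetical. For every $q$ the two printed sums agree with $\log p(D)$, and the ELBO reaches $\log p(D)$ exactly when $q$ is the true posterior.

```python
import numpy as np


def log_evidence(y):
    """log p(D) when omega ~ N(0, 1) and y_i | omega ~ N(omega, 1): y ~ N(0, I + 11^T)."""
    n = len(y)
    cov = np.eye(n) + np.ones((n, n))
    _, logdet = np.linalg.slogdet(cov)
    quad = y @ np.linalg.solve(cov, y)
    return -0.5 * (n * np.log(2 * np.pi) + logdet + quad)


def kl_gauss(m1, s1, m2, s2):
    """KL(N(m1, s1^2) || N(m2, s2^2))."""
    return np.log(s2 / s1) + (s1**2 + (m1 - m2) ** 2) / (2 * s2**2) - 0.5


def elbo(y, m, s):
    """E_q log p(D | omega) - KL(q || prior) for q = N(m, s^2), in closed form."""
    exp_loglik = np.sum(-0.5 * np.log(2 * np.pi) - 0.5 * ((y - m) ** 2 + s**2))
    return exp_loglik - kl_gauss(m, s, 0.0, 1.0)


y = np.array([0.3, -1.2, 2.0, 0.7])
n = len(y)
post_m, post_s = y.sum() / (n + 1), np.sqrt(1.0 / (n + 1))  # exact posterior

for m, s in [(0.0, 1.0), (0.5, 0.3), (post_m, post_s)]:
    gap = kl_gauss(m, s, post_m, post_s)  # KL(q || p(omega | D)) >= 0
    print(f"ELBO={elbo(y, m, s):.6f}  ELBO+KL={elbo(y, m, s) + gap:.6f}  "
          f"log p(D)={log_evidence(y):.6f}")
```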

Thus the goal of **VI** in our case is to find

$$
    q^* \in \arg\min_{q\in Q} \mathrm{KL}(q(\omega)\| p(\omega \mid D))
    \,. $$

Although this divergence from the posterior is itself intractable, the identity
above, which follows from Bayes' rule, shows that it can be minimized by
maximizing the ELBO instead.

> We consider the mean-field variational family of densities $Q$, where the parameters of the model are assumed to be mutually independent random variables each governed by a distinct factor in the variational density:

$$
    q(\omega)
        = \prod_j q_j(\omega_j)
    \,. $$
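For intuition, here is a sketch of a mean-field family with Gaussian factors $q_j = \mathcal{N}(\mu_j, \sigma_j^2)$, an assumption made only for this illustration (the tutorial itself uses a dropout-based family below). Sampling draws every coordinate independently, and the log-density splits into a sum of per-factor terms. Parameterizing by `log_sigma` keeps $\sigma_j$ positive; the function names are hypothetical.

```python
import numpy as np


def sample_mean_field(mu, log_sigma, rng):
    """omega ~ prod_j N(mu_j, sigma_j^2): every coordinate is drawn independently."""
    return mu + np.exp(log_sigma) * rng.standard_normal(mu.shape)


def log_q_mean_field(omega, mu, log_sigma):
    """log q(omega) = sum_j log q_j(omega_j) for a fully factorized Gaussian."""
    z = (omega - mu) / np.exp(log_sigma)
    return np.sum(-0.5 * np.log(2 * np.pi) - log_sigma - 0.5 * z**2)


rng = np.random.default_rng(0)
mu = np.array([0.1, -0.4, 2.0])
log_sigma = np.log(np.array([0.5, 1.0, 0.2]))
omega = sample_mean_field(mu, log_sigma, rng)
print(omega, log_q_mean_field(omega, mu, log_sigma))
```

The factorized log-density coincides with that of a joint Gaussian with a diagonal covariance, which is precisely what the mean-field assumption encodes.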

For a well-written review of variational inference see

> Blei, D. M., Kucukelbir, A., & McAuliffe, J. D. (2017). Variational inference: A review for statisticians. [Journal of the American Statistical Association, 112(518), 859-877](https://www.tandfonline.com/doi/full/10.1080/01621459.2017.1285773).

[arXiv:1601.00670](https://arxiv.org/pdf/1601.00670.pdf)

<br>

Since we clearly need to give our neural networks stochastic
parameters and `pytorch` does not allow this out-of-the-box,
we will have to build it ourselves.

We shall make:
* a base class that tags derived classes as being capable of sampling their
parameters $\omega_j$ from $q_\theta(\omega_j)$
* a procedure that crawls over the model's components, requests random draws
  and collects parameter samples.

In [None]:
class VariationalModule(torch.nn.Module):
    def sample(self):
        raise NotImplementedError("Derived classes must implement a parameter "
                                  "sampler from their own distribution.")


def named_variational_modules(module, prefix=""):
    """Iterate over all submodules that are `VariationalModule`s."""
    for name, mod in module.named_modules(prefix=prefix):
        if isinstance(mod, VariationalModule):
            yield name, mod


def named_parameter_samples(module, prefix=""):
    """Returns an iterator over parameter draws from all stochastic
    modules in the network, yielding both the name of the parameter
    as well as the parameter itself.

    Yields
    ------
    (string, torch.Tensor):
        Tuple containing the name and sampled parameter
    """

    for name, mod in named_variational_modules(module, prefix):

        # names follow the `state_dict` convention: "submodule.parameter"
        par_prefix = name + ('.' if name else '')
        for key, par in mod.sample().items():
            yield par_prefix + key, par

<br>

### (implicit) Variational approximation via Bernoulli Dropout

We can use the classical Bernoulli dropout ([Hinton et al. 2012](https://arxiv.org/abs/1207.0580))
as a variational approximation ([Gal, Y. 2016](http://www.cs.ox.ac.uk/people/yarin.gal/website/thesis/thesis.pdf)).

```
A simple regularization method allows uncertainty estimation essentially for free!
```

#### Dense linear layer

The *classical dropout* acts upon the inputs into a linear layer.

For input
$
    x\in \mathbb{R}^{[\mathrm{in}]}
$ and layer parameters $
    W\in \mathbb{R}^{[\mathrm{out}] \times [\mathrm{in}]}
$
and $
    b\in \mathbb{R}^{[\mathrm{out}]}
$ the resulting effect is

$$
    \tilde{x} = x \odot m
    \,, \\
    y = \tilde{x} W^\top + b
%     = b + \sum_i x_i m_i W_i
    \,, $$

where $\odot$ is the elementwise product and $m\in \mathbb{R}^{[\mathrm{in}]}$
with $m_j \sim \mathcal{Ber}\bigl(\bigl\{0, \tfrac1{1-p}\bigr\}, 1-p\bigr)$,
i.e. equals $\tfrac1{1-p}$ with probability $1-p$ and $0$ otherwise.
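A quick numerical check of this mask, sketched in plain NumPy rather than with the `torch` code used in the tutorial: the mask only takes the values $0$ and $\tfrac1{1-p}$, and the rescaling makes $\mathbb{E}\, m_j = 1$, so the dropped-out input $\tilde{x}$ equals $x$ on average.

```python
import numpy as np

rng = np.random.default_rng(0)
p = 0.3                                   # dropout rate
x = np.array([1.0, -2.0, 0.5, 4.0])

# m_j equals 1 / (1 - p) with probability 1 - p and 0 otherwise
masks = rng.binomial(1, 1 - p, size=(200_000, x.size)) / (1 - p)
x_tilde = x * masks                       # one dropped-out copy of x per row

print(np.unique(masks))                   # the two possible mask values
print(x_tilde.mean(axis=0))               # close to x, since E[m_j] = 1
```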

Other variants of Bernoulli dropout have been studied as well (like `dropConnect`
which zeroes out individual elements of the weight matrix), see
[Gal, Y. 2016](http://www.cs.ox.ac.uk/people/yarin.gal/website/thesis/thesis.pdf).

Let's implement this fused operation. `torch` has
* `F.dropout` for Bernoulli dropout
* `F.linear` for $y = x W^\top + b$

In [None]:
def stochastic_linear(layer, input, p=0.5):
    """Apply dropout and then affine transformation."""

    ## Exercise
    input = F.dropout(input, p, True)
    output = F.linear(input, layer.weight, layer.bias)

    return output

<br>

The multiplicative effect of the random mask $m$ on $x$ can be
equivalently seen as a random **on/off** switch effect on the
**columns** of the matrix $W$:

$$
    y_j
%         = b_j + \sum_i \tilde{x}_i W_{ji}
%         = b_j + \sum_i x_i m_i W_{ji}
        = b_j + \sum_i x_i \breve{W}_{ji}
    \,, $$

where each column $\breve{W}_i$ is, independently, either
$\mathbf{0}\in \mathbb{R}^{[\mathrm{out}]}$ with probability $p$ or
the corresponding (learnable) column $M_i \in \mathbb{R}^{[\mathrm{out}]}$ of the weight matrix, rescaled by $\tfrac1{1-p}$:

$$
    \breve{W}_i \sim
\begin{cases}
    \mathbf{0}
        & \text{ w. prob } p \,, \\
    \tfrac1{1-p} M_i
        & \text{ w. prob } 1-p \,.
\end{cases}
$$
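The equivalence is easy to confirm numerically. The NumPy sketch below (with arbitrary, illustrative sizes) applies a single mask $m$ either to the input $x$ or to the columns of the learnable matrix $M$, stored as `[out x in]` like in `pytorch`, and both routes give the same output $y$.

```python
import numpy as np

rng = np.random.default_rng(1)
n_in, n_out, p = 6, 3, 0.5
x = rng.standard_normal(n_in)
M = rng.standard_normal((n_out, n_in))    # learnable weight, [out x in]
b = rng.standard_normal(n_out)

m = rng.binomial(1, 1 - p, size=n_in) / (1 - p)  # one mask entry per input feature

y_inputs = (x * m) @ M.T + b               # dropout applied to the inputs
W_breve = M * m[np.newaxis, :]             # the same mask switches columns of M on/off
y_columns = x @ W_breve.T + b

print(np.allclose(y_inputs, y_columns))    # both views give the same output
```

Each column of `W_breve` is either zero or the matching column of `M` rescaled by $\tfrac1{1-p}$, as in the case distinction above.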

**(NOTE)** `pytorch` stores the matrix of the linear layer as $
W \in \mathbb{R}^{
    [\mathrm{out}]
    \times [\mathrm{in}]
}
$.

In [None]:
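def dropout_columns(weight, p=0.5):
    """Apply dropout with rate `p` to columns of `weight`."""

    ## Exercise
    # keep-probability `1 - p` for every column (shape `1 x in`)
    keep = torch.full_like(weight[:1, :], 1 - p)

    # the `1 x in` mask broadcasts over rows, zeroing entire columns,
    # and the surviving columns are rescaled by `1 / (1 - p)`
    mask = torch.bernoulli(keep)
    sample = weight * mask / keep

    return sample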

<br>

The implementation of the stochastic `Linear+Dropout` layer should:

1. dropout the inputs and apply the linear (affine) transformation on forward pass
2. randomly zero entire columns of the weight matrix $W$, when queried for a sample

In [None]:
class LinearBernoulli(Linear, VariationalModule):
    """Linear layer with dropout on inputs."""
    def __init__(self, in_features, out_features, bias=True, p=0.5):
        super().__init__(in_features, out_features, bias=bias)

        self.p = p

    def forward(self, input):        
        """Apply dropout and then affine transformation."""

        return stochastic_linear(self, input, self.p)

    def sample(self):
        """Sample the weight from the variational distribution."""

        weight = dropout_columns(self.weight, self.p)
        return {"weight": weight, "bias": self.bias}


**(NB)**
This effect amounts to a special variational approximation $q_{\theta}(W)$:

$$
    q_\theta(W)
        = \prod_i q_\theta(W_i)
        = \prod_i \{
            p \delta_{\mathbf{0}}(W_i)
            + (1 - p) \delta_{\tfrac1{1-p} M_i}(W_i)
        \}
    \,, $$

where $\delta_\mathbf{a}$ is a **point-mass** distribution at $\mathbf{a}$.

<br>

### Estimating predictive uncertainty

There are two approaches to estimating predictive uncertainty
for the sample $\tilde{S} = (\tilde{x}_i)_{i=1}^m$ with $\tilde{x}_i \in \mathcal{X}$.

#### Point-estimate approach ([blog: Gal, Y. 2016](http://www.cs.ox.ac.uk/people/yarin.gal/website/blog_3d801aa532c1ce.html)):

* for $i = 1... m$ do:

  1. draw an iid sample of parameters $\Omega = (\omega_b)_{b=1}^B \sim q_\theta(\omega)$
  2. compute $\hat{y}_{bi} = f_{\omega_b}(\tilde{x}_i)$ for $b=1 .. B$.


* compute mean and variance of $\hat{y}_{bi}$ along $b$

In [None]:
def point_estimate(model, dataset, n_samples=1, verbose=False):
    """Draw pointwise samples with stochastic forward pass."""

    outputs = []
    for _ in tqdm.tqdm(range(n_samples), disable=not verbose):

        outputs.append(apply(model, dataset))

    return torch.stack(outputs, dim=0)

* uses stochastic forward passes -- no need for extra code or classes
* predictive distributions at adjacent inputs are independent

#### Sample function approach ([blog: Gal, Y. 2016](http://www.cs.ox.ac.uk/people/yarin.gal/website/blog_2248.html)):

* for $b = 1... B$ do:

  1. draw a realization $f_b\colon \mathcal{X} \to \mathcal{Y}$
  from the process $\{f_\omega\}_{\omega \sim q_\theta(\omega)}$
  2. get $\hat{y}_{bi} = f_b(\tilde{x}_i)$ for $i=1 .. m$


* compute mean and variance of $\hat{y}_{bi}$ along $b$

The implementation performs the discussed steps verbatim:
1. sample $f \sim f_{\omega \sim q_\theta(\omega)}$ independently
2. compute the outputs for each sample realization.

In [None]:
def sample_function(model, dataset, n_samples=1, verbose=False):
    """Draw a realization of a random function."""

    outputs = []
    for _ in tqdm.tqdm(range(n_samples), disable=not verbose):
        det_model = realization(model)

        outputs.append(apply(det_model, dataset))

    return torch.stack(outputs, dim=0)

**(note)** we have not yet implemented the `realization(...)` function
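To see the practical difference between the two approaches, consider a toy NumPy sketch with a single dropout-linear layer (the weight `W` and the helpers `point_estimate_toy` and `sample_function_toy` are hypothetical) evaluated at five identical inputs. Point estimates draw a fresh mask for every input, so identical inputs get different outputs within a pass; a function sample fixes one column mask per realization, so each realization is a single deterministic function.

```python
import numpy as np

rng = np.random.default_rng(0)
p = 0.5
W = rng.standard_normal((1, 8))          # hypothetical [out x in] weight of one layer
x = np.ones((5, 8))                      # five identical inputs


def point_estimate_toy(x, n_samples):
    """Stochastic forward passes: a fresh input mask for every row and pass."""
    outs = []
    for _ in range(n_samples):
        m = rng.binomial(1, 1 - p, size=x.shape) / (1 - p)
        outs.append((x * m) @ W.T)
    return np.stack(outs)                # [n_samples x n_inputs x out]


def sample_function_toy(x, n_samples):
    """Function realizations: one column mask per realization, shared by all rows."""
    outs = []
    for _ in range(n_samples):
        m = rng.binomial(1, 1 - p, size=(1, W.shape[1])) / (1 - p)
        outs.append(x @ (W * m).T)
    return np.stack(outs)


pe = point_estimate_toy(x, 20)
sf = sample_function_toy(x, 20)
# are the outputs at identical inputs identical within each draw?
print(np.allclose(sf, sf[:, :1]), np.allclose(pe, pe[:, :1]))
```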

<br>

### Making our model Bayesian

We need to be able to sample a realization of our Bayesian model
regarded as a random function. For this the following steps seem 
sufficient:

1. draw random parameters ($\omega \sim q_\theta(\omega)$)
2. create a deterministic clone of the network $\bar{f}_\cdot$
3. set $\bar{f}_\cdot$'s parameters to $\omega$.

In [None]:
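def realization(model):
    """Draw a deterministic realization of the stochastic `model`."""
    # start from the current parameters and buffers, then overwrite
    # the stochastic ones with a fresh draw from q
    parameters = model.state_dict()
    parameters.update(named_parameter_samples(model))

    # a deterministic clone on the same device as the original
    det_model = model.deterministic()
    det_model.to(next(model.parameters()).device)

    det_model.load_state_dict(parameters, strict=False)

    return det_model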

In light of the functionality implemented above and discussed 
in the reminder on uncertainty, our model object must be capable
of:
* performing a (stochastic) forward pass
* producing a deterministic clone of itself on-demand

In [None]:
class SimpleModel(torch.nn.Module):
    def __init__(self, l_linear=LinearBernoulli):
        super().__init__()
        
        self.body = torch.nn.Sequential(
            Linear(1, 512, bias=True),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(512, affine=False),

            l_linear(512, 512, bias=True),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(512, affine=False),

            l_linear(512, 1, bias=True),
        )
    
    def forward(self, input):
        return self.body(input)

    def deterministic(self):
        """Returns a deterministic version of self."""
        return type(self)(l_linear=Linear)

<br>

Let's create a new instance and retrain the model.

In [None]:
model = SimpleModel()
fit(model, train, n_epochs=2000, loss_fn=F.mse_loss, verbose=True)

Compute the mean and variance of the predictive distribution using Monte Carlo:

$$
\begin{align}
    \mathbb{E}_{y\sim p(y\mid x, D)} \, g(y)
        &\overset{\text{BI}}{=}
            \mathbb{E}_{\omega\sim p(\omega\mid D)}
                \mathbb{E}_{y\sim p(y\mid x, \omega)} \, g(y) 
        \\
        &\overset{\text{VI}}{\approx}
            \mathbb{E}_{\omega\sim q_\theta(\omega)}
                \mathbb{E}_{y\sim p(y\mid x, \omega)} \, g(y)
        \\
        &\overset{\text{MC}}{\approx}
%             \hat{\mathbb{E}}_{\omega \sim \mathcal{W}}
%                 \mathbb{E}_{y\sim p(y\mid x, \omega)} \, g(y)
            \frac1{\lvert \mathcal{W}\rvert} \sum_{\omega \in \mathcal{W}}
                \mathbb{E}_{y\sim p(y\mid x, \omega)} \, g(y)
    \,,
\end{align}
$$

where $\mathcal{W} = (\omega_b)_{b=1}^B \sim q_\theta(\omega)$
-- iid samples from the variational approximation.

**(NB)** recall that we assume $p(y \mid x, \omega)$ to be tractable.

In [None]:
outputs_pe = point_estimate(model, test_feed, n_samples=101, verbose=True)

mean_pe, std_pe = outputs_pe.mean(dim=0), outputs_pe.std(dim=0)

In [None]:
outputs_sf = sample_function(model, test_feed, n_samples=101, verbose=True)

mean_sf, std_sf = outputs_sf.mean(dim=0), outputs_sf.std(dim=0)

<br>

In [None]:
plot1d("Sample function", outputs_sf.numpy())

Let's compare point estimates with function sampling

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(12, 5))  # canvas for the 1-d plot

ax.plot(X_test[:, 0], outputs_sf[..., 0].numpy().T, c="C0", alpha=0.05)
ax.plot(X_test, mean_sf.numpy(), c="fuchsia", lw=3, label="mean")

ax.plot(X_test, y_test, lw=2, color="k", alpha=0.5, label="test")
ax.scatter(X_train, y_train, c="k", s=20, label="train")

ax.set_title("Sample functions")
plt.legend(ncol=3);

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(12, 5))  # canvas for the 1-d plot

ax.plot(X_test[:, 0], outputs_pe[..., 0].numpy().T, c="C0", alpha=0.05)
ax.plot(X_test, mean_pe.numpy(), c="fuchsia", lw=3, label="mean")

ax.plot(X_test, y_test, lw=2, color="k", alpha=0.5, label="test")
ax.scatter(X_train, y_train, c="k", s=20, label="train")

ax.set_title("Point estimates")
plt.legend(ncol=3);

<br>

## Bayesian Active learning

### General idea

* Data labelling is costly and time consuming
* unlabeled instances are essentially free

**Goal** Achieve high performance with fewer labels by
identifying the best instances to learn from

Essential blocks of active learning:

* a model $m$ capable of quantifying uncertainty, and an acquisition function $\mathbf{a}$ that uses it to score unlabelled instances
* a labelling **oracle**, e.g. a human expert

Essential steps of active learning:

1. fit $m$ on $\mathcal{S}_{\mathrm{labelled}}$
2. get exact or approximate $
    \mathcal{S}^* \in \arg \max_{U \subseteq \mathcal{S}_\mathrm{unlabelled}}
        \mathbf{a}(U; m)
$ **without** access to targets and satisfying **budget constraints**
3. request an **oracle** to provide labels for $\mathcal{S}^*$
4. update $
\mathcal{S}_{\mathrm{labelled}}
    \leftarrow \mathcal{S}^*
        \cup \mathcal{S}_{\mathrm{labelled}}
$
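The four steps can be sketched as a generic pool-based loop. This is an illustration under stated assumptions, not the tutorial's own loop: `fit`, `score` and `oracle` are hypothetical user-supplied callables, and step 2 is approximated by picking the `budget` top-scoring instances individually rather than maximizing over subsets.

```python
import numpy as np


def active_learning_loop(X_lab, y_lab, X_pool, fit, score, oracle,
                         budget, n_rounds):
    """Pool-based active learning (a sketch).

    `fit(X, y) -> model`, `score(model, X) -> scores` and `oracle(X) -> labels`
    are user-supplied callables (hypothetical names, not tutorial code).
    """
    for _ in range(n_rounds):
        if len(X_pool) == 0:
            break
        model = fit(X_lab, y_lab)                     # 1. fit on the labelled set
        scores = score(model, X_pool)                 # 2. score the pool, no targets used
        top = np.argsort(scores)[-budget:]            #    keep the `budget` best instances
        y_new = oracle(X_pool[top])                   # 3. ask the oracle for labels
        X_lab = np.concatenate([X_lab, X_pool[top]])  # 4. grow the labelled set ...
        y_lab = np.concatenate([y_lab, y_new])
        X_pool = np.delete(X_pool, top, axis=0)       #    ... and shrink the pool
    return fit(X_lab, y_lab), X_lab, y_lab, X_pool


# toy usage: the "acquisition" simply prefers inputs far from the origin
X_lab, y_lab = np.array([[0.0]]), np.array([0.0])
X_pool = np.arange(1, 11, dtype=float)[:, None]
toy_model, X_lab, y_lab, X_pool = active_learning_loop(
    X_lab, y_lab, X_pool,
    fit=lambda X, y: None,
    score=lambda model, X: np.abs(X[:, 0]),
    oracle=lambda X: 2 * X[:, 0],
    budget=3, n_rounds=2,
)
print(X_lab[:, 0], X_pool[:, 0])
```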

### Acquisition functions

There are many acquisition criteria (borrowed from [Gal17a](http://proceedings.mlr.press/v70/gal17a.html)):
* Max entropy (plain uncertainty)
* Maximal information about parameters and predictions (mutual information)
* Variance ratios
* Mean standard deviation


The _entropy_ (all densities and/or probability mass functions can be conditional):

$$
    \mathbb{H}(p(y))
        = - \mathbb{E}_{y\sim p} \log p(y)
    \,. $$

**BALD** (Bayesian Active Learning by Disagreement) acquisition criterion is the posterior
mutual information between model's predictions at some point $x$ and model's parameters:

$$
    \mathbb{I}(y\,; \omega \mid x, D_\mathrm{train})
        = \mathbb{H}\bigl(
            \mathbb{E}_{\omega \sim p(\omega\mid D_\mathrm{train})}
                p(y\mid x, \omega, D_\mathrm{train})
        \bigr)
        - \mathbb{E}_{\omega \sim p(\omega\mid D_\mathrm{train})}
            \mathbb{H}(p(y\mid x, \omega, D_\mathrm{train}))
    \,. \tag{mi-bi} $$

The Variational Inference approximation of **BALD** is given by
the same expression, except with $q_\theta(\omega)$ instead of
$p(\omega\mid D_\mathrm{train})$:

$$
    \mathbb{I}_\mathrm{vi}(y\,; \omega \mid x, D_\mathrm{train})
        = \mathbb{H}\bigl(
            \mathbb{E}_{\omega \sim q_\theta(\omega)}
                p(y\mid x, \omega, D_\mathrm{train})
        \bigr)
        - \mathbb{E}_{\omega \sim q_\theta(\omega)}
            \mathbb{H}(p(y\mid x, \omega, D_\mathrm{train}))
    \,. \tag{mi-vi} $$

Consider an iid sample $\mathcal{W} = (\omega_b)_{b=1}^B \sim q_\theta(\omega)$
of size $B$. The Monte Carlo approximation of the mutual information is

$$
    \mathbb{I}_\mathrm{MC}(y\,; \omega \mid x, D_\mathrm{train})
        = \mathbb{H}\bigl(
            \hat{\mathbb{E}}_{\omega \sim \mathcal{W}}
                p(y\mid x, \omega, D_\mathrm{train})
        \bigr)
        - \hat{\mathbb{E}}_{\omega \sim \mathcal{W}}
            \mathbb{H}(p(y\mid x, \omega, D_\mathrm{train}))
    \,, \tag{mi-mc} $$

where $\hat{\mathbb{E}}_{\omega \sim \mathcal{W}} h(\omega) = \tfrac1B \sum_{b=1}^B h(\omega_b)$
denotes the expectation with respect to the empirical probability measure induced
by the sample $\mathcal{W}$.
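A minimal NumPy sketch of (mi-mc) for classification, assuming the per-draw class probabilities $p(y\mid x_i, \omega_b)$ are stacked into an array of shape `[B, n, K]` (the names `entropy_np` and `bald_mc` are hypothetical). It mirrors the same "entropy of the mean minus mean entropy" decomposition as the tutorial's `mutual_information` helper: the estimate is zero when all draws agree, even if each draw is uncertain, and equals $\log 2$ when two confident draws disagree on a binary label.

```python
import numpy as np


def entropy_np(p, axis=-1):
    """Shannon entropy in nats, with the convention 0 log 0 = 0."""
    logp = np.log(np.where(p > 0, p, 1.0))
    return -np.sum(p * logp, axis=axis)


def bald_mc(proba):
    """Monte Carlo BALD estimate (mi-mc) for proba of shape [B x n x K]."""
    ent_of_mean = entropy_np(proba.mean(axis=0))   # H(E_W p(y | x, omega))
    mean_of_ent = entropy_np(proba).mean(axis=0)   # E_W H(p(y | x, omega))
    return ent_of_mean - mean_of_ent


# two draws that agree (but are uncertain) vs. two confident draws that disagree
agree = np.array([[[0.5, 0.5]], [[0.5, 0.5]]])
disagree = np.array([[[1.0, 0.0]], [[0.0, 1.0]]])
print(bald_mc(agree), bald_mc(disagree), np.log(2))
```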

<br>

#### Block: entropy and mutual information

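For a discrete distribution $p = (p_y)_{y=1}^n$ over $n$ classes we can use the KL-divergence
to compute the entropy:
$$
    \mathbb{H}(p(y))
        = \log n - \mathbb{E}_{y\sim p} \log \tfrac{p_y}{\tfrac1n}
        = \log n - \mathrm{KL}\bigl(p \,\|\, \tfrac{\mathbf{1}}n\bigr)
        = - \mathrm{KL}(p \,\|\, \mathbf{1})
    \,, $$
i.e. $\log n$ minus the Kullback-Leibler divergence of $p$ from the uniform distribution or,
equivalently, minus the (generalized) KL-divergence of $p$ from the unnormalized all-ones
vector $\mathbf{1}$, which is the form evaluated by `entropy` below.

**(NB)** We do this because pytorch has a **numerically stable** KL-div implementation,
which, in particular, treats $0 \log 0$ as $0$ rather than producing `nan`.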
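The point about stability is easy to demonstrate. In the NumPy sketch below (an illustration, not the `torch` code; both helper names are hypothetical), evaluating $-\sum_y p_y \log p_y$ literally yields `nan` as soon as some class has zero probability, because $0 \log 0$ is computed as $0 \cdot (-\infty)$. Handling the convention $0 \log 0 = 0$ explicitly in the $\log n - \mathrm{KL}$ form gives the correct value.

```python
import numpy as np


def entropy_naive(p):
    """-sum p log p evaluated literally: 0 * log 0 becomes nan."""
    return -np.sum(p * np.log(p), axis=-1)


def entropy_via_kl(p):
    """H(p) = log n - KL(p || uniform), with 0 log 0 = 0 handled explicitly."""
    n = p.shape[-1]
    logp = np.log(np.where(p > 0, p, 1.0))   # log(1) = 0 at the zero entries
    kl_uniform = np.sum(p * (logp + np.log(n)), axis=-1)
    return np.log(n) - kl_uniform


p = np.array([0.5, 0.5, 0.0])                # one class has zero probability
with np.errstate(divide="ignore", invalid="ignore"):
    print(entropy_naive(p))                  # nan
print(entropy_via_kl(p), np.log(2))          # equal up to rounding
```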

In [None]:
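def entropy(proba):
    """Entropy (in nats) of the distributions along the last dimension."""

    # naive version, produces `nan` when some probabilities are exactly zero:
    # return - torch.sum(proba * torch.log(proba), dim=-1)

    # elementwise `p * (log p - log 1)`, with `0 log 0` treated as `0`
    kl = F.kl_div(torch.zeros_like(proba), proba, reduction="none")
    return - kl.sum(dim=-1)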

We will implement a procedure that computes relevant estimates of the posterior
predictive distribution, namely
* its sample approximation $
\hat{p}(y\mid x, D) = \hat{\mathbb{E}}_{\omega \sim\mathcal{W}} p(y \mid x, \omega)
$

* its sample entropy $\hat{H}(y\mid x, D) = \mathbb{H}\bigl(\hat{p}(y\mid x, D)\bigr)$

* the _mutual information_ $\hat{I}(y; \omega\mid x, D)$:
$$
    \mathbb{I}(y; \omega)
        = \mathbb{H}\bigl(
            \mathbb{E}_{\omega} p(y\mid \omega)
        \bigr)
        - \mathbb{E}_{\omega}
            \mathbb{H}(p(y\mid \omega))
    \,. $$



In [None]:
def mutual_information(proba):
    # average across function samples
    avg_proba = proba.mean(dim=0)

    # mutual information components for `H E_w p(., w) - E_w H p(., w)`
    ent_exp = entropy(avg_proba)
    exp_ent = entropy(proba).mean(dim=0)

    mut_info = ent_exp - exp_ent

    return avg_proba, ent_exp, mut_info

<br>

#### Block: prediction, evaluation and acquisition

In [None]:
def predict(model, dataset, n_samples=1, kind="function"):
    feed = torch.utils.data.DataLoader(dataset, batch_size=512, shuffle=False)

    # Monte-Carlo function samples
    if kind == "function":
        logits = sample_function(model, feed, n_samples=n_samples)

    else:
        logits = point_estimate(model, feed, n_samples=n_samples)

    # logit-scores should be transformed to a proper distribution
    proba = F.softmax(logits, dim=-1)

    return mutual_information(proba)

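**(NB)** when working with floating point numbers always pay
attention to numerical stability issues. For example, evaluating

$$
\operatorname{softmax}(z)
    = \bigl( \tfrac{e^{z_i}}{\sum_j e^{z_j} } \bigr)_{i=1}^n
    \,, $$
literally overflows for large logits $z_i$; here you would use the *log-sum-exp* trick,
i.e. subtract $\max_j z_j$ from every logit first, which leaves the result unchanged.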
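A NumPy sketch of the log-sum-exp trick, with illustrative helper names: for large logits the literal formula computes $\infty / \infty$ and returns `nan`, whereas shifting by the maximum keeps every exponent at most $0$ and yields well-defined probabilities.

```python
import numpy as np


def softmax_naive(z):
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def log_softmax(z):
    """log softmax via log-sum-exp: subtracting max(z) leaves the result unchanged."""
    z_max = z.max(axis=-1, keepdims=True)
    lse = z_max + np.log(np.exp(z - z_max).sum(axis=-1, keepdims=True))
    return z - lse


def softmax_stable(z):
    return np.exp(log_softmax(z))


z = np.array([1000.0, 1000.0, 990.0])
with np.errstate(over="ignore", invalid="ignore"):
    print(softmax_naive(z))      # exp overflows to inf, and inf / inf = nan
print(softmax_stable(z))         # finite probabilities summing to 1
```

Working with `log_softmax` directly is also convenient when the next step is a logarithm anyway, e.g. in cross-entropy or entropy computations.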

<br>

As the measure of uncertainty for classification we can use entropy of the predictive distribution.

For regression, as we have seen in the $1d$ examples above, we used the variance.

In [None]:
def evaluate(model, dataset, n_samples=1):
    proba, entropy, mutual_info = predict(model, dataset, n_samples=n_samples)

    predicted = proba.numpy().argmax(axis=-1)
    target = dataset.tensors[1].cpu().numpy()

    return confusion_matrix(target, predicted)

<br>

**BALD** acquisition function works like this:
* estimate mutual information $\mathbb{I}_\mathrm{vi}(y \,; \omega \mid x, D_\mathrm{train})$
  for each instance $x$ in $\mathcal{S}_\mathrm{pool}$
* pick the top $\ell$ instances

> Gal, Y., Islam, R. & Ghahramani, Z. (2017). Deep Bayesian Active Learning with Image Data. Proceedings of the 34th International Conference on Machine Learning, in [PMLR 70:1183-1192](http://proceedings.mlr.press/v70/gal17a.html)


In [None]:
def acquisition(model, dataset, n_points=10, n_samples=1):
    proba, entropy, mutual_info = predict(model, dataset, n_samples=n_samples)

    indices = mutual_info.argsort()

    return indices[-n_points:], proba, entropy, mutual_info

**(NB)** A drawback of the pointwise top-$\ell$ procedure above is
that although it acquires individually informative instances, together
they may be **jointly** poorly informative, e.g. when several of them are near-duplicates.

This can be corrected by seeking the set $\mathcal{S} = (x_i)_{i=1}^\ell$ with the highest
joint mutual information. This leads to a combinatorial explosion in computation
and memory requirements; however, there are working solutions, like random sampling
of subsets $\mathcal{S}$ of size $\ell$ from $\mathcal{S}_\mathrm{pool}$ or greedy
maximization of this *submodular* criterion:

$$
\begin{align}
    \mathbb{I}_\mathrm{vi}(\{y_1, ..., y_\ell\}\,; \omega \mid \{x_1, ..., x_\ell\}, D_\mathrm{train})
        &= \mathbb{H}\bigl(
            \mathbb{E}_{\omega \sim q_\theta(\omega)}
                p(\{y_1, ..., y_\ell\}\mid \{x_1, ..., x_\ell\}, \omega, D_\mathrm{train})
        \bigr)
        \\
        % conditional independence given \omega
        % &- \mathbb{E}_{\omega \sim q_\theta(\omega)}
        %     \mathbb{H}(p(\{y_1, ..., y_\ell\}\mid \{x_1, ..., x_\ell\}, \omega, D_\mathrm{train}))
        &- \sum_{i=1}^\ell \mathbb{E}_{\omega \sim q_\theta(\omega)}
            \mathbb{H}(p(y_i\mid x_i, \omega, D_\mathrm{train}))
        \,. \tag{mi-vi-batch}
\end{align}
$$

> Kirsch, A., van Amersfoort, J., & Gal, Y. (2019). BatchBALD: Efficient and Diverse Batch Acquisition for Deep Bayesian Active Learning. arXiv preprint [arXiv:1906.08158](https://arxiv.org/abs/1906.08158.pdf)
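To make the greedy option concrete, here is a small NumPy sketch in the spirit of BatchBALD; it is not the authors' implementation, and the names are hypothetical. It estimates (mi-vi-batch) by Monte Carlo over $B$ parameter draws, using conditional independence given $\omega$ to build the joint $p(y_1, \dots, y_\ell \mid \omega_b)$, and adds one point at a time. The joint outcomes are enumerated exactly, so memory grows as $K^\ell$ and the sketch only suits tiny problems. In the toy input, pointwise top-$2$ picks two duplicates, whereas the greedy batch criterion skips the duplicate.

```python
import numpy as np


def entropy_np(p, axis=-1):
    """Shannon entropy in nats, with the convention 0 log 0 = 0."""
    logp = np.log(np.where(p > 0, p, 1.0))
    return -np.sum(p * logp, axis=axis)


def greedy_batch_mi(proba, n_points):
    """Greedily grow a batch maximizing the MC estimate of (mi-vi-batch).

    proba: [B x n x K] class probabilities p(y | x_i, omega_b).
    """
    B, n, K = proba.shape
    cond_ent = entropy_np(proba).mean(axis=0)  # E_W H(p(y_i | x_i, omega)), [n]
    joint = np.ones((B, 1))                    # p(y_S | x_S, omega_b) for the empty batch
    selected = []
    for _ in range(n_points):
        best, best_score = None, -np.inf
        for i in range(n):
            if i in selected:
                continue
            # given omega_b the labels are independent: multiply in p(y_i | omega_b)
            cand = (joint[:, :, None] * proba[:, i, None, :]).reshape(B, -1)
            score = entropy_np(cand.mean(axis=0)) - cond_ent[selected + [i]].sum()
            if score > best_score:
                best, best_score = i, score
        selected.append(best)
        joint = (joint[:, :, None] * proba[:, best, None, :]).reshape(B, -1)
    return selected


# points 0 and 1 are duplicates; point 2 is less informative on its own,
# but its disagreement pattern across the B = 4 draws is different
c0, c1 = [1.0, 0.0], [0.0, 1.0]
s0, s1 = [0.95, 0.05], [0.05, 0.95]
proba = np.array([[c0, c0, s0], [c0, c0, s1], [c1, c1, s0], [c1, c1, s1]])

mi = entropy_np(proba.mean(axis=0)) - entropy_np(proba).mean(axis=0)
print(sorted(np.argsort(mi)[-2:].tolist()))   # pointwise top-2
print(greedy_batch_mi(proba, 2))              # greedy batch selection
```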

<br>

## Bayesian Active Learning with images

We will partially replicate Figure 1 in [Gal et al. (2017): p. 4](http://proceedings.mlr.press/v70/gal17a.html),

> Gal, Y., Islam, R. & Ghahramani, Z. (2017). Deep Bayesian Active Learning with Image Data. Proceedings of the 34th International Conference on Machine Learning, in PMLR 70:1183-1192

### Prepare datasets

Prepare the datasets from the `train` part of [MNIST](http://yann.lecun.com/exdb/mnist/):
* ($\mathcal{S}_\mathrm{train}$) initial **training**:
  $\approx 20$ images, purposefully highly imbalanced classes (even absent classes)
* ($\mathcal{S}_\mathrm{valid}$) our **validation**:
  $5000$ images, stratified
* ($\mathcal{S}_\mathrm{pool}$) acquisition **pool**:
  all remaining images

In [None]:
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

dataset = datasets.MNIST("./data",
                         train=True, download=True, transform=transform)

test = datasets.MNIST("./data",
                      train=False, download=True, transform=transform)

**(NB)** We may just as well use [Kuzushiji-MNIST](https://github.com/rois-codh/kmnist)

Create an imbalanced class label distribution:

In [None]:
n_train = 20

distribution = random_state.dirichlet([0.1] * 10)
initial_imbalance = np.round(distribution * n_train).astype(int)

initial_imbalance

Get indices for each subsample.

In [None]:
targets = dataset.targets.numpy()

ix_rest, ix_valid = train_test_split(
    np.arange(len(targets)), stratify=targets, test_size=5000,
    shuffle=True, random_state=random_state)

Select the specified number of instances from each class.

In [None]:
indices = []
for label, freq in enumerate(initial_imbalance):
    ix = np.flatnonzero(targets[ix_rest] == label)
    indices.extend(ix[:freq])

ix_train = np.take(ix_rest, indices)
ix_pool = np.delete(ix_rest, indices)

Split the dataset:
* we collect the images one by one, because the torchvision MNIST
  dataset only supports indexing a single element at a time.

In [None]:
def collect_images(dataset, indices):
    """Collect the (transformed) images and labels at `indices` of a torchvision dataset."""
    pairs = (dataset[i] for i in tqdm.tqdm(indices))

    data, target = zip(*pairs)

    return torch.stack(data, dim=0), torch.tensor(target)

S_train = TensorDataset(*collect_images(dataset, ix_train))
S_valid = TensorDataset(*collect_images(dataset, ix_valid))
S_pool  = TensorDataset(*collect_images(dataset, ix_pool ))

<br>

Apart from evaluation and the acquisition criteria, we will need to be
able to manipulate the datasets for the **main active learning** loop.

We begin by implementing the following primitives:
* `take` collects the instances at the specified indices into a **new dataset** (object)
* `append` adds one dataset to another
* `delete` drops the instances at the specified locations from a copy of the **dataset**

In [None]:
def take(pool, indices):
    """Copy the specified samples from the pool."""

    mask = torch.zeros(len(pool), dtype=torch.bool)
    mask[indices] = True

    return TensorDataset(*pool[mask])


def delete(pool, indices):
    """Drop the specified samples from the pool."""

    mask = torch.ones(len(pool), dtype=torch.bool)
    mask[indices] = False

    return TensorDataset(*pool[mask])


def append(train, new):
    """Append new samples to the train dataset."""
    tensors = [
        torch.cat(pair, dim=0)
        for pair in zip(train.tensors, new.tensors)
    ]

    return TensorDataset(*tensors)
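Building `take` and `delete` from a boolean mask has a side effect worth knowing about: repeated indices collapse into a single copy, and the selected instances come back in pool order rather than in the order of `indices`. In particular, `take` and `delete` with the same indices always partition the pool. A NumPy sketch of the same masking logic (torch tensors behave the same way under boolean-mask indexing), with a stand-in array for the pool:

```python
import numpy as np

pool = np.arange(10) * 10                # stand-in for the pool tensors
idx = np.array([7, 2, 7])                # unsorted, with a duplicate

mask = np.zeros(len(pool), dtype=bool)
mask[idx] = True
taken = pool[mask]                       # like `take`: pool order, no duplicates
remaining = pool[~mask]                  # like `delete`: everything else

print(taken, remaining)
print(pool[idx])                         # integer indexing keeps order and duplicates
```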

<br>

#### $2$-d Convolutional layer

Typically, in convolutional neural networks the dropout acts upon the feature
(channel) information and not on the spatial dimensions. Thus entire channels
are dropped out and for $
    x \in \mathbb{R}^{
        [\mathrm{in}]
        \times h
        \times w}
$ and $
    y \in \mathbb{R}^{
        [\mathrm{out}]
        \times h'
        \times w'}
$ the full effect of the `Dropout+Conv2d` layer is

$$
    y_{lij} = ((x \odot m) \ast W_l)_{ij} + b_l
        = b_l + \sum_k \sum_{pq} x_{k i_p j_q} m_k W_{lkpq}
    \,, \tag{conv-2d} $$
    
where $m_k \sim \mathcal{Ber}\bigl(\bigl\{0, \tfrac1{1-p}\bigr\}, 1-p\bigr)$ are i.i.d.,
and the indices $i_p$ and $j_q$ denote the spatial location in $x$ that corresponds
to the $(p, q)$ element of the kernel $
    W\in \mathbb{R}^{
        [\mathrm{out}]
        \times [\mathrm{in}]
        \times k_h
        \times k_w}
$ relative to the $(i, j)$ coordinates in $y$.
The exact values of $i_p$ and $j_q$ depend on the configuration of the
convolutional layer, e.g. stride, kernel size and dilation.
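A quick numerical check of (conv-2d): because $m_k$ is constant over the whole input channel $k$, masking the input channels has the same effect as masking the matching input-channel slices of the kernel. This is what allows `sample()` in `Conv2dBernoulli` below to fold the dropout mask into `weight` through a mask of shape `[1, in, 1, 1]`. The sketch below uses arbitrary sizes, and the helper name is ours. The equivalence holds for a single mask shared by everything the sampled kernel is applied to.

```python
import torch
import torch.nn.functional as F


def channel_dropout_equivalence(p=0.3, seed=0):
    """Compare conv(x * m, W) with conv(x, W * m) for one channel mask m."""
    g = torch.Generator().manual_seed(seed)

    x = torch.randn(1, 4, 9, 9, generator=g)       # one sample: [batch, in, h, w]
    weight = torch.randn(6, 4, 3, 3, generator=g)  # kernel: [out, in, k_h, k_w]
    bias = torch.randn(6, generator=g)

    # i.i.d. per-channel mask m_k with values in {0, 1/(1-p)}
    keep = torch.full((1, 4, 1, 1), 1 - p)
    m = torch.bernoulli(keep, generator=g) / (1 - p)

    # drop whole input channels, then convolve
    y_input_masked = F.conv2d(x * m, weight, bias)
    # fold the same mask into the in-channel axis of the kernel instead
    y_weight_masked = F.conv2d(x, weight * m, bias)
    return y_input_masked, y_weight_masked


y_a, y_b = channel_dropout_equivalence()
print(torch.allclose(y_a, y_b, atol=1e-5))  # True
```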

In [None]:
class Conv2dBernoulli(torch.nn.Conv2d, VariationalModule):
    """2-d convolutional layer with channel-wise (Bernoulli) dropout on inputs."""
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, p=0.05,
                 padding_mode='zeros'):
        super().__init__(in_channels, out_channels, kernel_size, stride=stride,
                         padding=padding, dilation=dilation, groups=groups,
                         bias=bias, padding_mode=padding_mode)

        self.p = p

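    def forward(self, input):
        """Apply dropout and then the convolution."""

        # Exercise
        # `training=True` keeps the channel dropout active even in `eval` mode,
        # so every forward pass draws a fresh mask
        return super().forward(F.dropout2d(input, self.p, True))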

    def sample(self):
        """Sample the weight from the variational distribution."""

        # Exercise
        p = torch.full_like(self.weight[:1, :, :1, :1], 1 - self.p)
        weight = self.weight * torch.bernoulli(p) / p

        return {"weight": weight, "bias": self.bias}

**(NB)** For more on convolutions see the [Convolution arithmetic](https://github.com/vdumoulin/conv_arithmetic) repo.

<br>

Much like the `SimpleModel` class in the $1$-d study above,
let's implement a simple deep convolutional network.

In [None]:
class CNNModel(torch.nn.Module):
    """A convolutional net."""
    def __init__(self, l_conv2d=Conv2dBernoulli, l_linear=LinearBernoulli):
        super().__init__()

        self.conv1 = Conv2d(1, 20, 5, 1)
        self.conv2 = l_conv2d(20, 50, 5, 1)
        self.fc1 = l_linear(4 * 4 * 50, 500)
        self.fc2 = l_linear(500, 10)

    def forward(self, x):
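        # [1, 28, 28] -> conv 5x5 -> [20, 24, 24] -> pool -> [20, 12, 12]
        x = F.avg_pool2d(F.relu(self.conv1(x)), 2, 2)
        # -> conv 5x5 -> [50, 8, 8] -> pool -> [50, 4, 4]
        x = F.avg_pool2d(F.relu(self.conv2(x)), 2, 2)
        x = F.relu(self.fc1(x.reshape(-1, 4 * 4 * 50)))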
        return F.log_softmax(self.fc2(x), dim=1)

    def deterministic(self):
        """Return a deterministic model of the same architecture (fresh weights)."""
        return type(self)(l_conv2d=Conv2d, l_linear=Linear)

<br>

##### Some service functions

* `label_counts` computes the label frequencies;
* `display_counts` prepares a compact text representation of the frequencies.

In [None]:
def label_counts(labels, n_labels=10):
    return np.bincount(labels.numpy(), minlength=n_labels)

def display_counts(labels):
    body = [f"{n:2d}" if n > 0 else " ." for n in label_counts(labels)]
    return "[ " + ' '.join(body) + " ]"

<br>

### Actual Bayesian Active Learning

To start the active learning, we first instantiate the model.

In [None]:
model = CNNModel()
model

<br>

Let's code the active learning loop. Recall that it
consists of the following steps:

1. fit on **train**
2. evaluate on **holdout**
3. acquire from **pool**
4. add to **train**


In [None]:
n_epochs, n_samples = 20, 11
n_active, n_points = 50, 10

scores = []
balances = [label_counts(S_train.tensors[1])]

balance_str = display_counts(S_train.tensors[1])
print(f">>> # {len(S_train):4d}: (starting) {balance_str}")

for step in range(n_active):
    # 1. fit
    fit(model, S_train, n_epochs=n_epochs)

    # 2. track validation score
    score_matrix = evaluate(model, S_valid, n_samples=n_samples)

    # 3. acquire new instances
    indices, proba, ent, mutual_info = acquisition(
        model, S_pool, n_points=n_points, n_samples=n_samples)

    # 4. query the pool for the chosen instances
    S_requested = take(S_pool, indices)
    S_train = append(S_train, S_requested)
    S_pool = delete(S_pool, indices)

    # 5. (optional) report accuracy and the statistics on the acquired batch
    tp = score_matrix.diagonal()
    accuracy = tp.sum() / score_matrix.sum()

    balance_str = display_counts(S_requested.tensors[1])
    print(f">>> # {len(S_train):4d}: (acquired) {balance_str}"
          f" (Accuracy) {accuracy:.2%}")

    balances.append(label_counts(S_train.tensors[1]))
    scores.append(score_matrix)


fit(model, S_train, n_epochs=n_epochs)
scores.append(evaluate(model, S_valid, n_samples=n_samples))

In [None]:
balances = np.stack(balances, axis=0)
scores = np.stack(scores, axis=0)

<br>

Let's see the dynamics of the frequency of each class in $\mathcal{S}_\mathrm{train}$

In [None]:
plt.plot(balances);

The dynamics of the one-vs-rest precision / recall scores on $\mathcal{S}_\mathrm{valid}$:

In [None]:
tp = scores.diagonal(axis1=-2, axis2=-1)
fp, fn = scores.sum(axis=-2) - tp, scores.sum(axis=-1) - tp

fig, ax = plt.subplots(1, 2, figsize=(12, 5))
ax[0].plot(tp / (tp + fp))
ax[0].set_title("Precision (ovr)")

ax[1].plot(tp / (tp + fn))
ax[1].set_title("Recall (ovr)")

plt.show()
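The cell above assumes that `evaluate` returns confusion matrices with the true label along rows and the predicted label along columns, which is how `fp` and `fn` are computed there. The sketch below wraps the same arithmetic into a helper and checks it on a toy $3$-class matrix. The name `ovr_scores` is ours. It also works on a stack of matrices such as `scores`.

```python
import numpy as np


def ovr_scores(cm):
    """One-vs-rest precision, recall and accuracy from confusion matrices.

    Assumes rows index the true label and columns the predicted label.
    """
    tp = cm.diagonal(axis1=-2, axis2=-1)
    fp = cm.sum(axis=-2) - tp   # predicted as class k, truly another class
    fn = cm.sum(axis=-1) - tp   # truly class k, predicted as another class
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    accuracy = tp.sum(-1) / cm.sum((-2, -1))
    return precision, recall, accuracy


cm_toy = np.array([
    [5, 1, 0],
    [2, 3, 1],
    [0, 0, 4],
])
precision_toy, recall_toy, accuracy_toy = ovr_scores(cm_toy)
# precision: [5/7, 3/4, 4/5], recall: [5/6, 1/2, 1], accuracy: 12/16
```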

The accuracy as a function of active learning iteration.

In [None]:
plt.plot(tp.sum(-1) / scores.sum((-2, -1)), label='Accuracy');

<br>

Volume of data used

In [None]:
f"train : pool = {len(S_train)} : {len(S_pool)}"

Let the test set confusion matrix be the ultimate judge:

In [None]:
S_test = TensorDataset(*collect_images(test, np.r_[:len(test)]))

score_matrix = evaluate(model, S_test, n_samples=51)

True positives, and false positives / negatives

In [None]:
tp = score_matrix.diagonal(axis1=-2, axis2=-1)
fp, fn = score_matrix.sum(axis=-2) - tp, score_matrix.sum(axis=-1) - tp

accuracy

In [None]:
f"(accuracy) {tp.sum() / score_matrix.sum():.2%}"

one-vs-rest precision

In [None]:
{l: f"{p:.2%}" for l, p in enumerate(tp / (tp + fp))}

one-vs-rest recall

In [None]:
{l: f"{p:.2%}" for l, p in enumerate(tp / (tp + fn))}

In [None]:
# stop "Run All" here: the cells below are exploratory scratch work
assert False

In [None]:
score_matrix = score

<br>

In [None]:
plt.plot(mutual_info.numpy())

In [None]:
plt.plot(ent.numpy())
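`ent` and `mutual_info` come from `acquisition`, which is defined earlier in the notebook. As a self-contained reminder of what these two quantities usually mean in Bayesian active learning, the sketch below computes them from $T$ Monte Carlo forward passes. The predictive entropy is $\mathbb{H}\bigl[\tfrac1T \sum_t p_t\bigr]$, and the mutual information (BALD score) is $\mathbb{H}\bigl[\tfrac1T \sum_t p_t\bigr] - \tfrac1T \sum_t \mathbb{H}[p_t]$. The helper `entropy_and_mi` and the input layout `[T, n_points, n_classes]` are assumptions, not necessarily what `acquisition` does internally.

```python
import math

import torch


def entropy_and_mi(logp):
    """Predictive entropy and mutual information from MC log-probabilities.

    logp: tensor of shape [T, n_points, n_classes] with the class
          log-probabilities of T stochastic forward passes (assumed layout).
    """
    T = logp.shape[0]
    # log of the MC-averaged predictive distribution, log((1/T) sum_t p_t)
    logp_mean = torch.logsumexp(logp, dim=0) - math.log(T)

    # entropy of the averaged prediction, H[(1/T) sum_t p_t]
    entropy = -(logp_mean.exp() * logp_mean).sum(dim=-1)

    # average entropy of the individual predictions, (1/T) sum_t H[p_t]
    expected_entropy = -(logp.exp() * logp).sum(dim=-1).mean(dim=0)

    # mutual information between the label and the weights (BALD score)
    return entropy, entropy - expected_entropy


# point 0: both passes confidently predict class 0 (agreement)
# point 1: the passes confidently predict different classes (disagreement)
tiny = 1e-6
probs_demo = torch.tensor([
    [[1 - tiny, tiny], [1 - tiny, tiny]],   # pass t = 0
    [[1 - tiny, tiny], [tiny, 1 - tiny]],   # pass t = 1
])  # shape [T=2, n_points=2, n_classes=2]
ent_demo, mi_demo = entropy_and_mi(probs_demo.log())
```

The mutual information is large only where the individual passes disagree. The entropy is also large when all passes agree on an uncertain prediction.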

In [None]:
from collections import defaultdict

groups = defaultdict(list)
for l, e in zip(S_valid.tensors[1], ent):
    groups[int(l)].append(float(e))

In [None]:
plt.hist(groups.values())

<br>

# Interesting stuff that didn't make the cut

Dense network for MNIST, in case convolutions are too slow on CPU.

In [None]:
class DNNModel(torch.nn.Sequential):
    """A fully connected net."""
    def __init__(self, l_linear=LinearBernoulli):
        super().__init__(
            Linear(784, 256),
            torch.nn.ReLU(),
            l_linear(256, 256),
            torch.nn.ReLU(),
            l_linear(256, 256),
            torch.nn.ReLU(),            
            l_linear(256, 10),
        )

    def forward(self, input):
        output = super().forward(input.flatten(1))
        return F.log_softmax(output, dim=-1)

    def deterministic(self):
        """Return a deterministic model of the same architecture (fresh weights)."""
        return type(self)(l_linear=Linear)

<br>

For **any** $q(\omega)$ (point-mass, products, mixtures, Frank-Wolfe
boosted convex ensembles, even input-dependent, **anything**) and any
$\phi$
$$
    \overbrace{
        \log p(D; \phi)
    }^{\text{unconditional likelihood}}
        = \underbrace{
            \mathbb{E}_{\omega \sim q} \log p(D\mid \omega; \phi)
        }_{\text{expected conditional likelihood}}
        - \overbrace{
            \mathrm{KL}(q(\omega)\| p(\omega; \phi))
        }^{\text{staying close to prior}}
        + \underbrace{
            \mathrm{KL}(q(\omega)\| p(\omega \mid D; \phi))
        }_{\text{posterior approximation}}
    \,. \tag{master-identity} $$
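The identity is easy to verify numerically on a conjugate toy model where every term has a closed form. The sketch below assumes a prior $p(\omega; \phi) = \mathcal{N}(\omega \mid 0, \phi)$, where the hyperparameter $\phi$ is the prior variance. It also assumes a likelihood $y_i \mid \omega \sim \mathcal{N}(\omega, 1)$ and an arbitrary Gaussian $q(\omega) = \mathcal{N}(m, s^2)$. The data are made up.

```python
import numpy as np


def kl_normal(m0, v0, m1, v1):
    """KL( N(m0, v0) || N(m1, v1) ) for univariate normals (v = variance)."""
    return 0.5 * (np.log(v1 / v0) + (v0 + (m0 - m1) ** 2) / v1 - 1.0)


def master_identity_terms(y, phi, m, s2):
    """Both sides of (master-identity) for q(omega) = N(m, s2)."""
    n = len(y)

    # unconditional likelihood: y ~ N(0, I + phi 11^T)
    cov = np.eye(n) + phi * np.ones((n, n))
    _, logdet = np.linalg.slogdet(cov)
    log_evidence = -0.5 * (n * np.log(2 * np.pi) + logdet
                           + y @ np.linalg.solve(cov, y))

    # exact posterior p(omega | D; phi) = N(post_mean, post_var)
    post_var = 1.0 / (1.0 / phi + n)
    post_mean = post_var * y.sum()

    # E_q log p(D | omega), using E_q (y_i - omega)^2 = (y_i - m)^2 + s2
    expected_loglik = -0.5 * (n * np.log(2 * np.pi)
                              + np.sum((y - m) ** 2) + n * s2)

    rhs = (expected_loglik
           - kl_normal(m, s2, 0.0, phi)              # staying close to prior
           + kl_normal(m, s2, post_mean, post_var))  # posterior approximation
    return log_evidence, rhs


y_toy = np.array([1.2, 0.4, 2.1, 0.9, 1.7, -0.3, 1.1])
lhs, rhs = master_identity_terms(y_toy, phi=2.5, m=-0.3, s2=0.7)
print(np.isclose(lhs, rhs))  # True for any choice of (m, s2)
```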

<br>

### Variational approximation via Gaussian Mean Field Family

* sampling from $q_\theta(\omega) = q_\theta(b) \otimes q_\theta(W)$ where
$q_\theta(b) = \delta_{b - \beta}$, $q_\theta(W) = \otimes_{ij} q_\theta(W_{ij})$
and $
q_\theta(W_{ij})
    = \mathcal{N}\bigl(
        W_{ij} \big\vert
            \mu_{ij}, \sigma^2_{ij}
        \bigr)
$;
* **local** reparameterization trick $y = x W + b$ implies that $
    y_j \sim \mathcal{N}\bigl(
            \beta_j + \sum_i x_i \mu_{ij},
            \sum_i \sigma^2_{ij} \lvert x_i \rvert^2
        \bigr)
$ and $y_j \bot y_k$, $j\neq k$;
* computing the KL-divergence of $q_\theta(W)$ from $p(W)$ ignoring $b$: $
    \mathop{KL}\bigl(q_\theta(W) \| p(W)\bigr)
        = \mathbb{E}_{W \sim q_\theta}
            \log \tfrac{q_\theta(W)}{p(W)}
$
  * diffuse $p(W_{ij}) \propto \mathop{const}$
  * scale-free $p(W_{ij}) \propto \tfrac1{\lvert W_{ij} \rvert}$
  * proper Gaussian $p(W_{ij}) = \mathcal{N}(W_{ij} \mid 0, \nu^{-1})$

For this to work, it is necessary to implement a penalty "collector".

In [None]:
def penalties(module):
    """Returns an iterator over all penalties in the network.

    Yields
    ------
    torch.Tensor:
        The (scalar) value of the penalty.

    Note
    ----
    Penalties from shared modules are returned only once.
    """
    for _, mod in named_variational_modules(module):
        yield mod.penalty

For all $\phi$
$$
    \log p(D; \phi)
    \geq \max_{\theta}
        \mathrm{ELBO}(\theta, \phi)
            = \mathbb{E}_{\omega \sim q_\theta(\omega)} \log p(D \mid \omega; \phi)
            - \mathbb{E}_{\omega \sim q_\theta(\omega)} \log \tfrac{q_\theta(\omega)}{p(\omega; \phi)}
    \,, $$
which follows from (master-identity), since the posterior KL term is non-negative.
This naturally yields a coordinate-wise ascent algorithm:
* **(E)** step: maximize over $\theta$ with $\phi$ fixed;
* **(M)** step: maximize over $\phi$ with $\theta$ fixed.
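The sketch below runs this coordinate ascent on a conjugate toy model, chosen as an assumption so that both steps have closed forms. The prior is $p(\omega; \phi) = \mathcal{N}(\omega \mid 0, \phi)$, with the prior variance $\phi$ as the hyperparameter, and the likelihood is $y_i \mid \omega \sim \mathcal{N}(\omega, 1)$. The E-step sets $q_\theta = \mathcal{N}(m, s^2)$ to the exact posterior, which makes the bound tight. The M-step maximizes $-\mathrm{KL}(q_\theta \| p(\cdot; \phi))$, which gives $\phi = m^2 + s^2$. Each iteration cannot decrease $\log p(D; \phi)$. The iterates converge to the marginal maximum-likelihood estimate $\max(0, \bar y^2 - 1/n)$.

```python
import numpy as np


def log_evidence(y, phi):
    """log p(D; phi) for omega ~ N(0, phi) and y_i | omega ~ N(omega, 1)."""
    n = len(y)
    cov = np.eye(n) + phi * np.ones((n, n))
    _, logdet = np.linalg.slogdet(cov)
    return -0.5 * (n * np.log(2 * np.pi) + logdet + y @ np.linalg.solve(cov, y))


def em_prior_variance(y, phi=5.0, n_iter=500):
    """Coordinate ascent on ELBO(theta, phi) for the toy conjugate model."""
    n, history = len(y), []
    for _ in range(n_iter):
        # (E) step, phi fixed: the optimal q is the exact posterior N(m, s2)
        s2 = 1.0 / (1.0 / phi + n)
        m = s2 * y.sum()
        # (M) step, q fixed: maximizing -KL(q || N(0, phi)) gives phi = m^2 + s2
        phi = m ** 2 + s2
        history.append(log_evidence(y, phi))
    return phi, history


y_toy = np.array([1.2, 0.4, 2.1, 0.9, 1.7, -0.3, 1.1])
phi_hat, history = em_prior_variance(y_toy)
# marginal maximum-likelihood estimate of the prior variance
phi_mle = max(0.0, y_toy.mean() ** 2 - 1.0 / len(y_toy))
print(phi_hat, phi_mle)
```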

<br>

### Diffuse prior $p(W_{ij}) \propto \mathop{const}$

Against a **diffuse prior** the KL-divergence is just
the **negative entropy** (up to a constant).

Since $q_\theta(W)$ is a diagonal multivariate normal
$$
    KL\bigl( q_\theta(W) \big\| p(W) \bigr)
        = \sum_{ij} KL\bigl( q_\theta(W_{ij}) \big\| p(W_{ij}) \bigr)
        = \mathop{const} - \sum_{ij} \mathbb{H}(q_\theta(W_{ij}))
    \,, $$

where $\mathbb{H}(q)$ is the [(differential) entropy](https://en.wikipedia.org/wiki/Differential_entropy#Differential_entropies_for_various_distributions).

For a multivariate Gaussian $
    p(z) = \mathcal{N}_n\bigl(
        z\,\big\vert\, \mu, \Sigma
    \bigr)
$ we have
$$
    \mathbb{H}(p)
        = - \mathbb{E}_{z\sim p} \log p(z)
        = \tfrac12 \log \det \bigl(2 \pi e \Sigma \bigr)
%         = \tfrac12 \log \det \Sigma + \tfrac{n}2 \log 2 \pi e
    \,. $$

Hence the entropy for a univariate Gaussian is
$$
    \mathbb{H}(q_\theta(W_{ij}))
        = - \mathbb{E}_{W_{ij} \sim q_\theta(W_{ij})} \log q_\theta(W_{ij})
        = \tfrac12 \log \{2 \pi e \, \sigma^2\}
    \,. $$

In [None]:
import math

def kl_div_diffuse(log_sigma2, weight=None):
    """KL against a diffuse prior: minus the total entropy (up to a constant)."""
    const = 0.5 * math.log(2 * math.pi * math.e) * log_sigma2.numel()
    entropy = const + 0.5 * torch.sum(log_sigma2)

    return - entropy
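As a sanity check of the entropy formula, and of `kl_div_diffuse` up to its sign, the sketch below compares the closed form against `torch.distributions.Normal.entropy()`. The helper name `neg_entropy_closed_form` is ours and mirrors the cell above.

```python
import math

import torch
from torch.distributions import Normal


def neg_entropy_closed_form(log_sigma2):
    """-sum_ij H(N(mu_ij, sigma2_ij)) = -sum_ij 0.5 * log(2 pi e sigma2_ij)."""
    const = 0.5 * math.log(2 * math.pi * math.e) * log_sigma2.numel()
    return -(const + 0.5 * torch.sum(log_sigma2))


g = torch.Generator().manual_seed(0)
log_sigma2_toy = torch.randn(3, 4, generator=g)

neg_h_closed = neg_entropy_closed_form(log_sigma2_toy)
# the entropy of a normal does not depend on its mean, so zeros suffice
neg_h_ref = -Normal(torch.zeros(3, 4), torch.exp(0.5 * log_sigma2_toy)).entropy().sum()
print(neg_h_closed.item(), neg_h_ref.item())
```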

<br>

### Proper prior $p(W_{ij}) = \mathcal{N}(W_{ij} \mid 0, \nu^{-1})$

The KL-divergence between two multivariate Gaussians is given by

$$
    KL\bigl(
        \mathcal{N}_m(\mu_0, \Sigma_0)
        \big\| \mathcal{N}_m(\mu_1, \Sigma_1)
    \bigr)
        = \frac12 \Bigl\{
            \log \frac{\det \Sigma_1}{\det \Sigma_0}
            + \mathop{tr} \bigl( \Sigma_1^{-1} \Sigma_0 \bigr)
            + \bigl(\mu_0 - \mu_1 \bigr)^\top \Sigma_1^{-1} \bigl(\mu_0 - \mu_1 \bigr)
            - m
        \Bigr\}
    \,. $$

The KL divergence from a zero-mean isotropic normal $\mathcal{N}_m(0, \nu^{-1} I_m)$ (a standard normal when $\nu = 1$) is
$$
    KL\bigl(
        \mathcal{N}_m(\mu_0, \Sigma_0)
        \big\| \mathcal{N}_m(0, \nu^{-1} I_m)
    \bigr)
        = \frac12 \Bigl\{
            \nu \, (\mathop{tr} \Sigma_0 + \mu_0^\top \mu_0)
            - m \log \nu
            - \log \det \Sigma_0
            - m
        \Bigr\}
    \,. $$
<!--
$$
    \frac12 \Bigl\{
        - \log \det \nu \Sigma_0
        + \nu \mathop{tr} \bigl( \Sigma_0 \bigr)
        + \nu \bigl(\mu_0 - \mu_1 \bigr)^\top \bigl(\mu_0 - \mu_1 \bigr)
        - m
    \Bigr\}
    \,. $$
-->

Therefore, if $q_\theta(W)$ is a diagonal multivariate normal, then we get
(put $m=1$, $\mu_0 = \mu_{ij}$, $\Sigma_0 = \sigma^2_{ij}$):

$$
    KL\bigl( q_\theta(W) \big\| p(W) \bigr)
        = \sum_{ij} KL\bigl( q_\theta(W_{ij}) \big\| p(W_{ij}) \bigr)
        = \frac12 \sum_{ij} \bigl(
            \nu \sigma^2_{ij} + \nu \mu_{ij}^2
            - \log \sigma^2_{ij} - \log \nu - 1
        \bigr)
    \,. $$

In [None]:
import math

def kl_div_proper(log_sigma2, weight, nu=1.0):
    """KL divergence of the mean-field Gaussian from N(0, 1/nu), closed form."""
    const = 0.5 * (math.log(nu) + 1) * log_sigma2.numel()
    nu_term = 0.5 * nu * (torch.exp(log_sigma2) + weight * weight)
    kl_div = torch.sum(nu_term - 0.5 * log_sigma2) - const

    return kl_div
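The sketch below cross-checks the closed form above against `torch.distributions.kl_divergence` for `Normal` distributions, with the prior scale set to $\nu^{-1/2}$. The helpers `kl_div_proper_sketch` and `kl_reference` are our stand-alone versions for illustration. Note the factor $\tfrac12$ on $\log \sigma^2_{ij}$.

```python
import math

import torch
from torch.distributions import Normal, kl_divergence


def kl_div_proper_sketch(log_sigma2, weight, nu=1.0):
    """0.5 * sum_ij (nu s2_ij + nu mu_ij^2 - log s2_ij - log nu - 1)."""
    const = 0.5 * (math.log(nu) + 1) * log_sigma2.numel()
    nu_term = 0.5 * nu * (torch.exp(log_sigma2) + weight * weight)
    return torch.sum(nu_term - 0.5 * log_sigma2) - const


def kl_reference(log_sigma2, weight, nu=1.0):
    """The same divergence, computed by torch.distributions."""
    q = Normal(weight, torch.exp(0.5 * log_sigma2))
    prior = Normal(torch.zeros_like(weight), torch.full_like(weight, nu ** -0.5))
    return kl_divergence(q, prior).sum()


g = torch.Generator().manual_seed(0)
mu_toy = torch.randn(5, 3, generator=g)
log_sigma2_toy = torch.randn(5, 3, generator=g) - 2.0

print(kl_div_proper_sketch(log_sigma2_toy, mu_toy, nu=4.0).item(),
      kl_reference(log_sigma2_toy, mu_toy, nu=4.0).item())
```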

In [None]:
def kl_div_1811_00596(log_sigma2, weight):
    """Penalty from arxiv:1811.00596."""
    # get $- \log \alpha_{ij}$
    neg_log_alpha = 2 * torch.log(abs(weight) + 1e-12) - log_sigma2

    # `softplus` is $x \mapsto \log(1 + e^x)$
    kl_div_approx = torch.sum(0.5 * F.softplus(neg_log_alpha))

    return kl_div_approx

<br>

### Scale-free prior $p(W_{ij}) \propto \tfrac1{\lvert W_{ij} \rvert}$

This prior gives us the so-called Variational Dropout
([Molchanov et al. 2017](https://arxiv.org/abs/1701.05369),
[Kingma et al. 2015](https://papers.nips.cc/paper/5666-variational-dropout-and-the-local-reparameterization-trick)).

Provided $\mu_{ij} \neq 0$, with $\alpha_{ij} = \tfrac{\sigma^2_{ij}}{\mu_{ij}^2}$
we have

$$
    \mathcal{N}(\mu_{ij}, \sigma^2_{ij})
    \overset{D}{\sim} \mu_{ij} \cdot \mathcal{N}(1, \alpha_{ij})
    \,. $$

Therefore our variational approximation $q_\theta(W)$ can be regarded as
the so-called Gaussian Dropout: the weights are subject to multiplicative
noise $W_{ij} = \mu_{ij} \cdot \varepsilon_{ij}$ with $\varepsilon_{ij}
\sim \mathcal{N}(1, \alpha_{ij})$.

Under this variational family:
$$
    KL\bigl( q_\theta(W_{ij}) \big\| p(W_{ij}) \bigr)
%         = \mathop{const} - \tfrac12 \log \sigma^2_{ij}
%         + \mathbb{E}_{W_{ij} \sim q_\theta(W_{ij})} \log \lvert W_{ij} \rvert
        = \mathop{const} - \tfrac12 \log \sigma^2_{ij}
        + \tfrac12 \log \mu_{ij}^2
        + \mathbb{E}_{\xi \sim \mathcal{N}(0, 1)}
            \log \lvert 1 + \tfrac{\sigma_{ij}}{\lvert \mu_{ij} \rvert} \xi \rvert
    \,. $$

Unfortunately, there is no closed-form expression for this divergence, and thus
we resort to the approximation from the ICML'17 paper by [Molchanov et al. 2017](https://arxiv.org/abs/1701.05369):

$$
    KL\bigl( q_\theta(W_{ij}) \big\| p(W_{ij}) \bigr)
%         = \mathop{const}
%         + \mathbb{E}_{\xi \sim \mathcal{N}(1, \alpha_{ij})}
%                         \log{\lvert \xi \rvert}
%                     - \tfrac12 \log \alpha_{ij}
        \approx
            \tfrac12 \log (1 + e^{-\log \alpha}) 
            + k_1(1 - \sigma(k_2 + k_3 \log \alpha))
    \,, $$

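where $\sigma(\cdot)$ is the logistic sigmoid (not the standard deviation $\sigma_{ij}$), $\alpha = \alpha_{ij}$,
and $k_1 = 0.63576$, $k_2 = 1.87320$, $k_3 = 1.48695$. The additive constant $C = -k_1$
of the original approximation has been absorbed, so the expression vanishes as $\alpha \to \infty$.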

In [None]:
def kl_div_scale_free(log_sigma2, weight):

    # get $- \log \alpha_{ij}$
    neg_log_alpha = 2 * torch.log(abs(weight) + 1e-12) - log_sigma2
    
    # Use the identity 1 - \sigma(z) = \sigma(- z)
    sigmoid = torch.sigmoid(1.48695 * neg_log_alpha - 1.87320)

    # `softplus` is $x \mapsto \log(1 + e^x)$
    kl_div_approx_ij = 0.5 * F.softplus(neg_log_alpha) + 0.63576 * sigmoid
    kl_div_approx = torch.sum(kl_div_approx_ij)

    return kl_div_approx
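The sketch below evaluates the same approximation as a function of $\log \alpha_{ij}$. The helper `kl_scale_free_approx` is ours and takes $\log \alpha$ directly. The penalty decreases monotonically in $\alpha_{ij}$ and vanishes as $\alpha_{ij} \to \infty$. As $\alpha_{ij} \to 0$ it behaves like $-\tfrac12 \log \alpha_{ij} + k_1$, so on its own it pushes towards larger multiplicative noise.

```python
import torch
import torch.nn.functional as F


def kl_scale_free_approx(log_alpha):
    """Molchanov et al. (2017) approximation, as in `kl_div_scale_free` above."""
    neg_log_alpha = -log_alpha
    # 1 - sigmoid(k2 + k3 log alpha) = sigmoid(-k3 log alpha - k2)
    sigmoid = torch.sigmoid(1.48695 * neg_log_alpha - 1.87320)
    return 0.5 * F.softplus(neg_log_alpha) + 0.63576 * sigmoid


log_alpha = torch.linspace(-8.0, 8.0, 101, dtype=torch.float64)
kl = kl_scale_free_approx(log_alpha)
print(kl[0].item(), kl[50].item(), kl[-1].item())
```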

<br>

### The local reparameterization trick

Proposed in
[Kingma et al. 2015](https://papers.nips.cc/paper/5666-variational-dropout-and-the-local-reparameterization-trick),
this trick
* samples the layer's outputs directly, so the parameters are sampled only implicitly;
* reduces the variance of the stochastic gradient, generally leading to much faster convergence.

The effect of a linear layer on its input is given by the equation
$$
    y = x W + b
    \,, \tag{obvious} $$
with $W \in \mathbb{R}^{m\times n}$, $y\in \mathbb{R}^n$,
$x\in \mathbb{R}^m$ and $b \in \mathbb{R}^n$.

Key observation:
> any non-trivial linear combination of jointly Gaussian
random variables is a Gaussian random variable.

Since the layer's weights $W$ are jointly Gaussian, this means that the distribution
of $y$ is also Gaussian. We can see this if we look closely at the output: the bias term is effectively
fixed to $\beta$ and the weights are

$$
    W_{ij} \sim
        \mathcal{N}(M_{ij}, \sigma^2_{ij})
    \,,\, W_{ij} \bot W_{kl}
    \,,\, ij\neq kl
    \,,
$$

which means that
$$
    y_j = \beta_j + \sum_i x_i W_{ij}
    \Rightarrow
    y_j \sim \mathcal{N}\bigl(
        \beta_j + \sum_i x_i M_{ij}, \sum_i \sigma^2_{ij} x_i^2
    \bigr)
    \tag{local-reparametrization}
    \,. $$

Note that for $j \neq k$ the outputs $y_j$ and $y_k$ depend on disjoint columns of
independent weights. Hence distinct $y_j$ and $y_k$ are independent (given $x$), and in particular uncorrelated.

Collecting into a single multivariate Gaussian random vector we get the 
following stochastic forward pass:

$$
    y \sim \mathcal{N}_n\bigl(
        x M + \beta\,,\, \mathrm{diag}(s^2)
    \bigr)
    \,, $$

with $s^2_j = \sum_{i=1}^m \sigma^2_{ij} x_i^2$.
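The sketch below is a Monte Carlo check of (local-reparametrization). It uses arbitrary toy sizes and the math convention $W \in \mathbb{R}^{m \times n}$; note that `torch.nn.Linear` stores the transposed matrix. Sampling full weight matrices and sampling the outputs directly should agree in the mean $xM + \beta$ and in the per-output variance $\sum_i \sigma^2_{ij} x_i^2$, and the outputs should be uncorrelated. The helper `lrt_check` is ours, for illustration only.

```python
import torch


def lrt_check(m_in=4, n_out=3, n_draws=200_000, seed=0):
    """Sample y = x W + beta in two ways; return samples and exact moments."""
    g = torch.Generator().manual_seed(seed)
    kw = dict(generator=g, dtype=torch.float64)

    x = torch.randn(m_in, **kw)
    M = torch.randn(m_in, n_out, **kw)        # weight means M_ij
    sigma2 = torch.rand(m_in, n_out, **kw)    # weight variances sigma^2_ij
    beta = torch.randn(n_out, **kw)           # bias, fixed to beta

    # (a) sample whole weight matrices W ~ N(M, sigma2), then y = x W + beta
    W = M + sigma2.sqrt() * torch.randn(n_draws, m_in, n_out, **kw)
    y_weights = torch.einsum("i,tij->tj", x, W) + beta

    # (b) local reparameterization: sample each y_j from its Gaussian marginal
    mean = x @ M + beta
    var = (x * x) @ sigma2
    y_local = mean + var.sqrt() * torch.randn(n_draws, n_out, **kw)
    return y_weights, y_local, mean, var


y_w, y_l, mean_lrt, var_lrt = lrt_check()
print(y_w.mean(0), y_l.mean(0), mean_lrt)   # should be close
print(y_w.var(0), y_l.var(0), var_lrt)      # should be close
```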

In [None]:
def stochastic_linear_lrp(layer, input):
    """Forward pass for the linear layer with the local reparameterization trick."""

    ## Exercise
    # Get the mean
    mu = F.linear(input, layer.weight, layer.bias)
#     if not layer.training:
#         # uncomment to make the pass deterministic (mean only) in `eval` mode
#         return mu

    # Add the resulting effect of weight randomness
    s2 = F.linear(input * input, torch.exp(layer.log_sigma2), None)
    return mu + torch.randn_like(s2) * torch.sqrt(s2 + 1e-20)

<br>

### Linear layer with Gaussian dropout and the trick 

Now let's combine the trick for the Gaussian approximation $q_\theta(\omega)$
with one of the divergences given above into a stochastic linear layer:

* we need to write the forward pass, a parameter sampler and the penalty.

In [None]:
class LinearGaussian(torch.nn.Linear, VariationalModule):
    """Linear layer with Gaussian Mean Field weight distribution."""

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias=bias)

        self.log_sigma2 = torch.nn.Parameter(
            torch.Tensor(*self.weight.shape))

        self.reset_variational_parameters()

    def reset_variational_parameters(self):
        """Initialize the log-variance."""
        self.log_sigma2.data.normal_(-8, 0.01)

    def forward(self, input):
        """Forward pass with the local reparameterization trick."""

        return stochastic_linear_lrp(self, input)

    def sample(self):
        """Return a sample from $q_{\theta_m}(\omega_m)$."""
        
        ## Exercise
        stdev = torch.exp(0.5 * self.log_sigma2)
        weight = torch.normal(self.weight, std=stdev)

        return {"weight": weight, "bias": self.bias}

    @property
    def penalty(self):
        """KL divergence between $q_{\theta_m}(\omega_m)$ and a prior on $\omega_m$."""

        ## Exercise
        # return kl_div_diffuse(self.log_sigma2, self.weight)
        # return kl_div_scale_free(self.log_sigma2, self.weight)
        # return kl_div_proper(self.log_sigma2, self.weight, nu=1e0)

        return kl_div_1811_00596(self.log_sigma2, self.weight)

<br>

### Linear layer with Bernoulli Dropout and Gaussian approximation

We can also fuse the classical Bernoulli dropout ([Hinton et al. 2012](https://arxiv.org/abs/1207.0580))
with the Gaussian variational approximation ([Gal, Y. 2016](http://www.cs.ox.ac.uk/people/yarin.gal/website/thesis/thesis.pdf)).

The forward pass of the fused model is simply a composition of dropout
on inputs and the local reparameterization trick.

In [None]:
def forward_bernoulli_lrp(layer, input, p=0.2):
    """Forward pass with Bernoulli dropout and the local reparameterization trick."""
    input = F.dropout(input, p, layer.training)

    return stochastic_linear_lrp(layer, input)

<br>

However, for the Kullback-Leibler divergence we need to identify how
$W$ are effectively distributed. Indeed, the variational approximation
$q_{\theta}(W)$ is essentially a spike-and-slab mixture: each row $W_i$
of $W\in \mathbb{R}^{m \times n}$ is either $\mathbf{0}$ with probability
$p$ or a Gaussian vector in $\mathbb{R}^n$:

$$
    W_i \sim
\begin{cases}
    \mathbf{0}
        & \text{ w. prob } p \\
    M_i
        & \text{ otherwise } \\
\end{cases}
    \,, \text{indep.}
    \,, i=1, \ldots, m
    \,, M_i \sim \mathcal{N}_n\bigl(
        M_i \big\vert \mu_i, \mathop{diag} \sigma^2_i
    \bigr)
    \,. $$

Under some assumptions and benign relaxations
[Gal, Y. 2016 (eq. (6.3) p.109, Prop. 4 p.149)](http://www.cs.ox.ac.uk/people/yarin.gal/website/thesis/thesis.pdf)
the divergence can be approximated by

$$
    KL\bigl( q_\theta(W) \big\| p(W) \bigr)
        = \sum_i KL\bigl( q_\theta(W_i) \big\| p(W_i) \bigr)
        \approx \mathop{const}
        % + \frac{mn} 2 p \{\tau^{-1} + \log \tau\}
        % + m (p \log p + (1-p) \log (1-p))
        + \frac{1-p}2 \sum_{ij} \bigl(
            \sigma^2_{ij} + \mu_{ij}^2 - \log \sigma^2_{ij}
        \bigr)
    \,. $$

<!--
Proposition 4 p.149 is about the approximation of the divergence of
a mixture from the standard Gaussian:

$$
    KL\bigl(
        \sum_k \pi_k \mathcal{N}_n(\mu_k, \Sigma_k)
        \big\| \mathcal{N}_n(0, I_n)
    \bigr)
        \approx \mathop{const}
        + \frac12 \sum_k \pi_k \{
            \mathop{tr}\Sigma_k
            + \mu_k^\top \mu_k
            - \log\det\Sigma_k
        \}
        - \mathbb{H}(\pi)
    \,. $$
-->

We are ready to implement this layer.

In [None]:
class LinearFusedGaussianBernoulli(LinearGaussian):
    """Linear layer with spike-slab (Gaussian) weight distribution."""
    def __init__(self, in_features, out_features, bias=True, p=0.5):
        super().__init__(in_features, out_features, bias=bias)

        self.p = p

    def forward(self, input):
        input = F.dropout(input, self.p, self.training)

        return super().forward(input)

    def sample(self):
        p = torch.full_like(self.weight[:1], 1 - self.p)
        mask = torch.bernoulli(p) / p

        stdev = torch.exp(0.5 * self.log_sigma2)
        weight = torch.normal(self.weight, std=stdev)
        return {
            "weight": weight * mask,
            "bias": self.bias
        }

    @property
    def penalty(self):
        """Approximate KL divergence."""
        return (1 - self.p) * kl_div_proper(self.log_sigma2, self.weight)

<br>

Reference: [Keras initializers](https://github.com/keras-team/keras/blob/c10d24959b0ad615a21e671b180a1b2466d77a2b/keras/initializers.py#L341)

Initializing a matrix $W \in\mathbb{R}^{
[\mathrm{out}]
\times [\mathrm{in}]
}$. The sizes $[\mathrm{in}]$ and $[\mathrm{out}]$ are known as `fan-in` and `fan-out`, respectively.

Kaiming He:
* (normal) $w_{ij} \sim \mathcal{N}(0, \sigma^2)$ with $
    \sigma = \tfrac{\sqrt2}{\sqrt{
        [\mathrm{in}] (1+a^2)
    }}
$ ($a$ -- negative slope of the ReLU)
* (uniform) $w_{ij} \sim \mathcal{U}[-\sigma, +\sigma]$ with $
\sigma = \tfrac{\sqrt6}{\sqrt{
        [\mathrm{in}] (1+a^2)
}}$

Xavier Glorot:
* (uniform) $w_{ij} \sim \mathcal{U}[-b, +b]$ with $
b = \mathrm{gain} \cdot \sqrt{\tfrac6{
        [\mathrm{out}] + [\mathrm{in}]
    }}
$
* (normal) $w_{ij} \sim \mathcal{N}(0, \sigma^2)$ with $
\sigma = \mathrm{gain} \cdot \sqrt{\tfrac2{
        [\mathrm{out}] + [\mathrm{in}]
    }}
$


Based on this piece on [weight initialization](https://towardsdatascience.com/weight-initialization-in-neural-networks-a-journey-from-the-basics-to-kaiming-954fb9b47c79)
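These formulas match the conventions of PyTorch's `torch.nn.init.kaiming_normal_` (with `mode="fan_in"` and `nonlinearity="leaky_relu"`) and `torch.nn.init.xavier_uniform_`. For an `[out, in]` weight, the fan-in is the number of columns. The sketch below compares empirical statistics of freshly initialized tensors with the formulas. The helper names are ours and the sizes are arbitrary.

```python
import math

import torch


def he_normal_std(fan_in, a=0.0):
    """Kaiming He (normal): sqrt(2) / sqrt(fan_in * (1 + a^2))."""
    return math.sqrt(2.0) / math.sqrt(fan_in * (1 + a ** 2))


def glorot_uniform_bound(fan_in, fan_out, gain=1.0):
    """Xavier Glorot (uniform): gain * sqrt(6 / (fan_out + fan_in))."""
    return gain * math.sqrt(6.0 / (fan_out + fan_in))


fan_out, fan_in, slope = 256, 1024, 0.1   # W has shape [out, in], as in torch.nn.Linear

w_he = torch.empty(fan_out, fan_in)
torch.nn.init.kaiming_normal_(w_he, a=slope, mode="fan_in", nonlinearity="leaky_relu")

w_glorot = torch.empty(fan_out, fan_in)
torch.nn.init.xavier_uniform_(w_glorot, gain=1.0)

print(w_he.std().item(), he_normal_std(fan_in, slope))
print(w_glorot.abs().max().item(), glorot_uniform_bound(fan_in, fan_out))
```

The uniform variant has standard deviation $b / \sqrt{3}$, which is why its bound carries $\sqrt6$ where the normal variant carries $\sqrt2$.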

<hr>

In [None]:
state_dict = {
    k: v for k, v in zip(
        ['body.0.weight', 'body.0.bias',
         'body.2.weight', 'body.2.bias',
         'body.4.weight', 'body.4.bias'],
        model.state_dict().values()
    )
}

In [None]:
model = SimpleModel()

In [None]:
model.load_state_dict(state_dict)

In [None]:
dataset = TensorDataset(*train)

feed = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

<hr>

In [None]:
model.train()

y_pred = torch.stack([
    apply(model, test_feed) for _ in range(501)
], dim=0)

In [None]:
y_pred, y_std = y_pred.mean(dim=0), y_pred.std(dim=0)

In [None]:
fig, ax = canvas1d()

ax.plot(X_test, y_test, lw=2, color="k", alpha=0.25, label="test")
ax.plot(X_test, y_pred.numpy(), c="C0", lw=2, label="predict")

ax.scatter(X_train, y_train, c="k", s=20, label="train")

# confidence bands
for m in [0.25, 0.50, 0.75, 1.00]:
    y_hi = (y_pred + 2 * y_std * m).numpy()
    y_lo = (y_pred - 2 * y_std * m).numpy()
    ax.fill_between(X_test[:, 0], y_lo[:, 0], y_hi[:, 0],
                    color="C3", alpha=0.08)

plt.legend(ncol=2)
plt.show()

<hr>

In [None]:
def draw_bands(X, mean, std, ax=None, bands=(0.50, 1.00, 1.50, 2.00), **kwargs):
    ax = plt.gca() if ax is None else ax

    for band in sorted(bands):
        y_hi = (mean + std * band).numpy()
        y_lo = (mean - std * band).numpy()
        ax.fill_between(X[:, 0], y_lo[:, 0], y_hi[:, 0], **kwargs)

    return ax

In [None]:
fig, ax = canvas1d()

ax.plot(X_test, y_test, lw=2, color="k", alpha=0.25, label="test")

ax.plot(X_test, mean_sf.numpy(), c="C0", lw=2, label="mean-sf")

ax.scatter(X_train, y_train, c="k", s=20, label="train")

draw_bands(X_test, mean_sf, std_sf, ax=ax, color="C3", alpha=0.08, zorder=-5)

plt.legend(ncol=2)
plt.show()

<hr>

The key issue with point-estimates is that the values at each $x$, taken jointly,
need not correspond to a function in the parametric family modelled by the network:
there may be no $\omega \in \mathop{supp}q_\theta(\omega)$ such that
$f_\omega(x_i) = y_i$ for all $i$.