# `torch_blue` basic tutorial

This tutorial walks you through the basic process of implementing a Bayesian Neural
Network (BNN). It assumes basic familiarity with implementing neural networks in
PyTorch and mostly highlights the (typically small) differences in usage.
Specifically, this tutorial demonstrates how to implement the same model showcased in the
[PyTorch Quickstart tutorial](https://docs.pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html)
as a BNN, while pointing to more advanced features to read up on in the documentation.

This tutorial also provides several blocks with optional, more detailed information,
like the one below. They aim to give you a basic understanding of why certain things
are required or how they work. If you just want a trainable model, feel free to skip
them and come back later when you have more time.

> **Variational Inference**
>
> `torch_blue.vi` implements BNNs via a variational inference method often known as
> Bayes-by-backprop. Fundamentally, it treats the weights as random variables by
> representing the distribution of each weight with the parameters of an assumed
> distribution type. More simply put, and for the most common usage, each weight is
> described by the mean and standard deviation (std) of a Gaussian distribution. During
> a forward pass, one sample is drawn from each weight distribution, and these samples
> are used as the weights for that single forward pass. To obtain an output
> distribution, this forward pass is run several times with different samples. Training
> the model means training all these means and stds.
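
The sampling idea can be made concrete with a minimal sketch in plain Python (no PyTorch, to keep the arithmetic visible). It models a single linear neuron whose weights are independent Gaussians. The names `sample_weights`, `forward`, and `predict` are hypothetical, and this is not how `torch_blue` is implemented internally; it only illustrates that repeated forward passes with freshly sampled weights yield a distribution over outputs.

```python
import random
import statistics

# Variational parameters: one (mean, std) pair per weight.
weight_means = [0.5, -1.0, 2.0]
weight_stds = [0.1, 0.2, 0.05]


def sample_weights(rng: random.Random) -> list[float]:
    """Draw one concrete weight vector from the per-weight Gaussians."""
    return [rng.gauss(mu, sigma) for mu, sigma in zip(weight_means, weight_stds)]


def forward(x: list[float], weights: list[float]) -> float:
    """A single deterministic forward pass with fixed weights."""
    return sum(w * xi for w, xi in zip(weights, x))


def predict(x: list[float], samples: int, rng: random.Random) -> list[float]:
    """Run several forward passes, each with freshly sampled weights."""
    return [forward(x, sample_weights(rng)) for _ in range(samples)]


rng = random.Random(0)
outputs = predict([1.0, 1.0, 1.0], samples=1000, rng=rng)
print(f"output mean: {statistics.mean(outputs):.2f}")
print(f"output std:  {statistics.stdev(outputs):.2f}")
```

In this sketch, training would adjust `weight_means` and `weight_stds`, and the spread of `outputs` is what would be reported as predictive uncertainty.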

This tutorial uses type hinting, so a couple of extra imports are needed. We also
import `torch_blue.vi`, which plays a role very similar to `torch.nn` but for BNN layers.

In [None]:
from typing import Callable

import torch
from torch import Tensor, nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

from torch_blue import vi

Setting up datasets is completely unchanged from PyTorch.

In [None]:
# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

Creating a dataloader also remains the same.

In [None]:
batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for x, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {x.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break


Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64


## Creating Models

To set up a BNN you will have to replace non-Bayesian PyTorch layers with Bayesian
`torch_blue` layers. This affects only layers that have weights. Currently, `torch_blue`
supports linear layers (`VILinear`), convolution layers (`VIConv1d`, `VIConv2d`,
`VIConv3d`) and Transformer layers (`VITransformer`, `VIMultiheadAttention`). These can
all be imported from `torch_blue.vi`. As you might have noticed, this is as simple as
importing from `torch_blue.vi` instead of `torch.nn` and prepending `VI` to the layer
name.

Note that while there are many more layer types, most of them do not need to be made
Bayesian, either because they do not have weights (like activation functions and
pooling layers) or because making them Bayesian does not necessarily make sense given
their purpose (mostly normalization layers like layer norms). `torch_blue` layers
combine seamlessly with PyTorch layers. However, there is one important thing to
consider:

<div class="alert alert-block alert-warning">
<b>Important:</b> When implementing a Bayesian model with torch_blue, the outermost
layer must always inherit from vi.VIModule instead of nn.Module. This may be nothing
more than a wrapper that forwards the inputs and outputs, but it is essential for
functionality.
</div>

In [None]:
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")

# Define model
class NeuralNetwork(vi.VIModule):
    def __init__(self) -> None:
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            vi.VILinear(28 * 28, 512),
            nn.ReLU(),
            vi.VILinear(512, 512),
            nn.ReLU(),
            vi.VILinear(512, 10),
        )

    def forward(self, x_: Tensor) -> Tensor:
        x_ = self.flatten(x_)
        logits = self.linear_relu_stack(x_)
        return logits

model = NeuralNetwork().to(device)
print(model)

Using cpu device
NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): VILinear()
    (1): ReLU()
    (2): VILinear()
    (3): ReLU()
    (4): VILinear()
  )
)


**TLDR:** Modules should be subclasses of `vi.VIModule` instead of `nn.Module` and
layers with weights need to be replaced with Bayesian layers, which can be found in
`torch_blue.vi` and simply add `VI` in front of their usual name.

> **Why do I need to inherit from `vi.VIModule`?**
>
> As mentioned in the explanation on variational inference, several forward passes need
> to be run to obtain an output distribution. `torch_blue` automates this process by
> duplicating the inputs along an additional sample dimension and vectorizing the
> forward pass over this dimension, which makes the process relatively efficient,
> especially when used with GPUs. However, this duplication needs to be performed only
> once and not for each layer. Therefore, `torch_blue` automatically detects the
> outermost `VIModule` and makes it perform this task. Additionally, loss calculation
> requires knowledge of the probability of the actual weight samples. The outermost
> module also gathers this information and integrates it into the output.

## Optimizing the Model Parameters

A Bayesian model needs to be trained on a Bayesian loss that also optimizes for the
correct uncertainties. In `torch_blue` this is the `KullbackLeiblerLoss`, which is
minimal if the weight distributions are identical to the ideal weight distribution
implied by the data.

It works for any task, but requires a task-specific predictive distribution. This
distribution specifies your expectation of how your outputs are distributed. For
regression tasks this will typically be `MeanFieldNormal` (from
`torch_blue.vi.distributions`), which makes the loss behave similarly to `MSELoss` (from
`torch.nn`). Since the task here is classification, we will use a `Categorical`
distribution (again from `torch_blue.vi.distributions`), which makes the loss
behave similarly to `CrossEntropyLoss` (from `torch.nn`). Finally, the loss requires the
total number of data samples (**not** batches!).
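
To see why the loss needs the total number of data samples, here is a minimal sketch of the standard minibatch estimate of the negative evidence lower bound (ELBO) with a categorical likelihood. It is written in plain Python, the function names are hypothetical, and `torch_blue` may normalize its loss differently. The point is only that the data term computed from one batch has to be scaled up to the full dataset so that it is weighed correctly against the weight-distribution term.

```python
import math


def categorical_log_likelihood(logits: list[float], target: int) -> float:
    """Log-probability of `target` under softmax(logits).

    Its negative is the cross-entropy for a single example.
    """
    max_logit = max(logits)  # subtract the max for numerical stability
    log_norm = max_logit + math.log(sum(math.exp(v - max_logit) for v in logits))
    return logits[target] - log_norm


def negative_elbo_estimate(
    batch_logits: list[list[float]],
    batch_targets: list[int],
    log_q_weights: float,
    log_prior_weights: float,
    dataset_size: int,
) -> float:
    """Single-sample minibatch estimate of the negative ELBO.

    log_q_weights:     log-probability of the sampled weights under the
                       variational distribution.
    log_prior_weights: log-probability of the same weights under the prior.
    """
    batch_ll = sum(
        categorical_log_likelihood(logits, target)
        for logits, target in zip(batch_logits, batch_targets)
    )
    # Rescale the batch log-likelihood to the whole dataset.
    scale = dataset_size / len(batch_targets)
    kl_estimate = log_q_weights - log_prior_weights
    return kl_estimate - scale * batch_ll


logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]
targets = [0, 2]
loss = negative_elbo_estimate(logits, targets, -10.0, -25.0, dataset_size=60000)
print(f"negative ELBO estimate: {loss:.1f}")
```

In this sketch, passing the number of batches instead of the number of samples would shrink the data term by a factor of the batch size and shift the balance towards the prior.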

The optimizer can be set as usual.

> **Advanced note:** If you are optimizing for computational speed you might want to
> have a look at `vi.AnalyticalKullbackLeiblerLoss`, which cannot be used in all cases,
> but may reduce computational cost.

In [None]:
from torch_blue.vi.distributions import Categorical

predictive_distribution = Categorical()
loss_fn = vi.KullbackLeiblerLoss(
    predictive_distribution, dataset_size=len(training_data)
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


> **Why Adam and not SGD?**
>
> Variational inference can be somewhat prone to exploding gradients, particularly
> in early epochs. This can be fixed by carefully adjusting hyperparameters or
> introducing learning rate warm-up. Adam also avoids this issue, likely because it
> normalizes each update by a running estimate of the gradient magnitude, which has an
> effect similar to gradient clipping. It is the simplest approach for the purposes of
> this tutorial.
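
The robustness of Adam to large gradients can be illustrated with a minimal sketch of a single Adam step next to a plain SGD step, written in plain Python with the default hyperparameters of `torch.optim.Adam`. The helper names are hypothetical. On the first step, the bias-corrected moments reduce the update to roughly `lr * sign(grad)`, so its size does not grow with the gradient, whereas the SGD step does.

```python
import math


def sgd_step(grad: float, lr: float) -> float:
    """Parameter change from one plain SGD step."""
    return -lr * grad


def adam_first_step(
    grad: float,
    lr: float,
    beta1: float = 0.9,
    beta2: float = 0.999,
    eps: float = 1e-8,
) -> float:
    """Parameter change from the first Adam step (moments start at zero)."""
    m = (1 - beta1) * grad          # first moment estimate
    v = (1 - beta2) * grad * grad   # second moment estimate
    m_hat = m / (1 - beta1)         # bias correction for step t = 1
    v_hat = v / (1 - beta2)
    return -lr * m_hat / (math.sqrt(v_hat) + eps)


for grad in (0.01, 1.0, 1e6):
    print(
        f"grad={grad:>9}: "
        f"SGD step={sgd_step(grad, lr=1e-3):.3e}, "
        f"Adam step={adam_first_step(grad, lr=1e-3):.3e}"
    )
```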

Besides the changed model and loss, the training loop is unchanged.

In [None]:
def train(
    dataloader: DataLoader,
    model: vi.VIModule,
    loss_fn: Callable,
    optimizer: torch.optim.Optimizer,
) -> None:
    size = len(dataloader.dataset)
    model.train()
    for batch, (x, y) in enumerate(dataloader):
        x, y = x.to(device), y.to(device)

        # Compute prediction error
        pred = model(x)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(x)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


When checking the shape of your output, you will notice an additional dimension. This is
a sample dimension that `torch_blue` generates automatically. To properly represent the
output distribution, a BNN needs to be run several times on the same input and the
results need to be aggregated. `torch_blue` automates this aggregation in the predictive
distribution. Inside your model you can ignore this dimension due to vectorization.
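
As a rough illustration of what aggregating the sample dimension can look like for classification, the sketch below averages the per-sample class probabilities and picks the most likely class. It uses plain Python lists with shape `[samples][classes]` for a single input. The function names are hypothetical, and the aggregation that `torch_blue` performs in its predictive distribution may differ in detail.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    """Convert logits to probabilities in a numerically stable way."""
    max_logit = max(logits)
    exps = [math.exp(v - max_logit) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def aggregate_samples(sample_logits: list[list[float]]) -> list[float]:
    """Average class probabilities over the sample dimension."""
    probs = [softmax(logits) for logits in sample_logits]
    num_samples = len(probs)
    num_classes = len(probs[0])
    return [sum(p[c] for p in probs) / num_samples for c in range(num_classes)]


# Three sampled forward passes for one input with three classes.
sample_logits = [
    [2.0, 1.0, 0.0],
    [1.5, 1.8, 0.1],
    [2.2, 0.4, 0.3],
]
mean_probs = aggregate_samples(sample_logits)
predicted_class = max(range(len(mean_probs)), key=mean_probs.__getitem__)
print("mean probabilities:", [round(p, 3) for p in mean_probs])
print("predicted class:", predicted_class)
```

How far the individual rows disagree (here the second sample prefers class 1) is a simple signal of model uncertainty.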

<div class="alert alert-block alert-warning">
<b>Important:</b> The output of torch_blue models is not a normal Tensor, but a VIReturn
object. This behaves the same as a Tensor, but performing any operations on it before
passing it to the loss may break loss computation. Therefore, try to integrate all
operations into your model.
</div>

> **A Note on samples**
>
> You may want to change the number of samples per forward pass for better sampling or
> lower computational load. This can be done by passing the keyword argument `samples`
> (default=10) to the forward call of the model, which specifies the number of samples
> to use. It is automatically caught by `torch_blue` in the vectorization process.
> We have found that 10 tends to strike a good balance between accuracy and compute.


> **What is `VIReturn`?**
>
> `VIReturn` is a thin wrapper around `torch.Tensor`. All it does is add the attribute
> `log_probs`, which is used to store the weight probability information mentioned in
> the block on `VIModule` and pass it to the loss. Since PyTorch is not aware of this
> attribute, any operation that creates a new Tensor works but loses the `log_probs`
> attribute, which is required for loss computation.
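
The same effect can be reproduced with a small analogy in plain Python. The class `Tagged` below is hypothetical and unrelated to the actual `VIReturn` implementation. It is a `float` subclass carrying an extra attribute, and arithmetic on it returns a plain `float` that no longer has the attribute.

```python
class Tagged(float):
    """A float that carries extra metadata, loosely analogous to VIReturn."""

    def __new__(cls, value: float, log_probs: float) -> "Tagged":
        instance = super().__new__(cls, value)
        instance.log_probs = log_probs
        return instance


output = Tagged(2.0, log_probs=-3.5)
print(type(output).__name__, hasattr(output, "log_probs"))

# Arithmetic works, but the result is a new plain float without the metadata.
scaled = output * 0.5
print(type(scaled).__name__, hasattr(scaled, "log_probs"))
```

This is why operations such as rescaling or reshaping belong inside the model's `forward` rather than between the model call and the loss.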

The only change in the test loop is that the predictive distribution is used to
aggregate the sample dimension.

In [None]:
def test(dataloader: DataLoader, model: vi.VIModule, loss_fn: Callable) -> None:
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0.0, 0.0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            samples = model(x)
            test_loss += loss_fn(samples, y).item()

            pred = predictive_distribution.predictive_parameters_from_samples(
                samples
            )
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(
        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )

Time to bring it home and train the model.

In [None]:
epochs = 5
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 1979140.625000  [   64/60000]
loss: 1849943.125000  [ 6464/60000]
loss: 1786325.250000  [12864/60000]
loss: 1750446.375000  [19264/60000]
loss: 1701688.750000  [25664/60000]
loss: 1652999.125000  [32064/60000]
loss: 1606158.250000  [38464/60000]
loss: 1572845.000000  [44864/60000]
loss: 1523578.000000  [51264/60000]
loss: 1481647.375000  [57664/60000]
Test Error: 
 Accuracy: 84.7%, Avg loss: 1457582.563694 

Epoch 2
-------------------------------
loss: 1448494.500000  [   64/60000]
loss: 1409645.875000  [ 6464/60000]
loss: 1359520.500000  [12864/60000]
loss: 1324419.750000  [19264/60000]
loss: 1279386.500000  [25664/60000]
loss: 1240244.875000  [32064/60000]
loss: 1194377.250000  [38464/60000]
loss: 1163121.750000  [44864/60000]
loss: 1127201.250000  [51264/60000]
loss: 1083609.875000  [57664/60000]
Test Error: 
 Accuracy: 84.6%, Avg loss: 1068581.152070 

Epoch 3
-------------------------------
loss: 1059136.000000  [   64/60000]
loss: 10