## PyTorch implementation of CBO

This notebook gives a brief introduction to training PyTorch neural networks with consensus-based optimization (CBO). It covers the typical `torch` training loop and its integration with Weights & Biases (`wandb`). Training is performed on the canonical MNIST dataset with a shallow network.

First, we import the required libraries:

In [2]:
import os
import sys
import time
import pickle

import matplotlib.pyplot as plt

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import wandb

In [3]:
from torchmetrics import Accuracy

To import the library modules, we need to add the path of the repository root folder to `sys.path`:

In [4]:
root_path = os.path.join(os.getcwd().split('cbo-in-python')[0], 'cbo-in-python')

sys.path.append(root_path)

Next, we import a function that loads the train and test MNIST dataloaders:

In [5]:
from src.datasets import load_mnist_dataloaders

For convenience, `src.torch.models` provides a few model architectures for user experiments. Of course, one may also implement a neural network from scratch using `torch.nn`.

In [6]:
from src.torch.models import SmallMLP

Now, we import the two remaining classes needed to perform CBO:

In [7]:
from src.torch import Optimizer, Loss

One may use CUDA to accelerate computations. The command below selects the computational device (CUDA or CPU) depending on whether CUDA is available; this is the standard PyTorch idiom.

In [8]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

Now, we can load the train and test dataloaders:

In [11]:
batch_size = 60  # samples-level batching

train_dataloader, test_dataloader = load_mnist_dataloaders(train_batch_size=batch_size,
                                                           test_batch_size=batch_size)

n_batches = len(train_dataloader)

Next, we define the remaining optimization parameters:

In [16]:
epochs = 25
particles_batch_size = 10

n_particles = 100
alpha = 50
sigma = 0.4 ** 0.5
l = 1
dt = 0.1
anisotropic = True
eps = 1e-2

partial_update = False

use_multiprocessing = False

eval_freq = 100  # how often to evaluate the validation (test) accuracy
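To give some intuition for these parameters, here is a minimal NumPy sketch of one step of the standard CBO dynamics on a toy objective. Each particle drifts towards the consensus point v = Σᵢ xᵢ exp(-α f(xᵢ)) / Σᵢ exp(-α f(xᵢ)) and is perturbed by noise proportional to its distance from v. We assume that `alpha` is the weight α, `l` the drift strength λ, `sigma` the noise strength σ, `dt` the time step, and that `anisotropic` switches between componentwise and norm-scaled noise; the exact update in `src.torch.Optimizer` (including the roles of `eps` and `partial_update`) may differ, and the helpers `consensus_point` and `cbo_step` are hypothetical.

```python
import numpy as np

def consensus_point(particles, objective, alpha):
    """Weighted mean of the particles with weights exp(-alpha * f(x))."""
    values = np.array([objective(x) for x in particles])
    # Shifting by the minimum keeps exp() from underflowing for every particle
    weights = np.exp(-alpha * (values - values.min()))
    return weights @ particles / weights.sum()

def cbo_step(particles, objective, alpha, lam, sigma, dt, anisotropic, rng):
    """One Euler-Maruyama step of the CBO dynamics (hypothetical helper)."""
    v = consensus_point(particles, objective, alpha)
    diff = particles - v                          # shape (n_particles, dim)
    noise = rng.standard_normal(particles.shape)
    if anisotropic:
        # each coordinate is perturbed proportionally to its own distance to v
        diffusion = diff * noise
    else:
        # the whole vector is perturbed proportionally to its Euclidean distance to v
        diffusion = np.linalg.norm(diff, axis=1, keepdims=True) * noise
    # drift towards the consensus point plus a scaled Brownian increment
    return particles - lam * dt * diff + sigma * np.sqrt(dt) * diffusion

# Toy usage with the hyperparameters above: minimize ||x - 3||^2 in 2D
rng = np.random.default_rng(0)
particles = rng.uniform(-5, 5, size=(100, 2))     # n_particles = 100
objective = lambda x: float(np.sum((x - 3.0) ** 2))
for _ in range(200):
    particles = cbo_step(particles, objective, alpha=50, lam=1, sigma=0.4 ** 0.5,
                         dt=0.1, anisotropic=True, rng=rng)
print(consensus_point(particles, objective, alpha=50))  # the minimizer is [3, 3]
```

For a neural network, each particle would be a flattened copy of all model weights and f the training loss on the current batch, which is why no gradients are needed.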

Next, we define two helper functions used during training.

The function `evaluate` computes the model's loss and accuracy on a given batch:

In [18]:
accuracy = Accuracy(task='multiclass', num_classes=10)  # recent torchmetrics versions require `task`

In [19]:
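def evaluate(model, X_, y_):
    # `loss_fn` is defined further below, before the training loop starts
    with torch.no_grad():
        outputs = model(X_)                        # log-probabilities, shape (batch, 10)
        y_pred = torch.argmax(outputs, dim=1)      # predicted class labels
        loss = loss_fn(outputs, y_)
        acc = accuracy(y_pred.cpu(), y_.cpu())     # the metric object lives on the CPU
    return loss, acc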

The function `log` sends the metrics to `wandb`:

In [21]:
def log(loss, acc, epoch, stage='train', shift_norm=None):
    wandb.log({
        f'{stage}_loss': loss,
        f'{stage}_acc': acc,
        'epoch': epoch,
        f'{stage}_shift_norm': shift_norm,
    })

Weights & Biases (`wandb`) is an experiment-tracking tool for machine learning; see the [official website](https://wandb.ai/site) for details. The command below initializes the current experiment (replace `entity` with your own W&B user or team name):

In [None]:
wandb.init(project='CBO', entity='itukh')

We will use the provided shallow `SmallMLP` model.

In [14]:
model = SmallMLP().to(device)
model

SmallMLP(
  (model): Sequential(
    (0): Flatten(start_dim=1, end_dim=3)
    (1): Linear(in_features=784, out_features=10, bias=True)
    (2): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
    (3): ReLU()
    (4): Linear(in_features=10, out_features=10, bias=True)
    (5): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
    (6): Linear(in_features=10, out_features=10, bias=True)
    (7): LogSoftmax(dim=None)
  )
)
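As mentioned earlier, one may also build the network directly with `torch.nn`. The sketch below rebuilds the printed architecture layer by layer; the class name `ScratchMLP` is hypothetical and is not part of `src.torch.models`. We pass `dim=1` to `LogSoftmax` explicitly, since the printed `dim=None` relies on PyTorch's deprecated implicit-dimension choice.

```python
import torch
import torch.nn as nn

class ScratchMLP(nn.Module):
    """Hypothetical from-scratch counterpart of the printed SmallMLP architecture."""

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Flatten(),                        # (N, 1, 28, 28) -> (N, 784)
            nn.Linear(784, 10),
            nn.BatchNorm1d(10, affine=False),    # normalization without learnable scale/shift
            nn.ReLU(),
            nn.Linear(10, 10),
            nn.BatchNorm1d(10, affine=False),
            nn.Linear(10, 10),
            nn.LogSoftmax(dim=1),                # log-probabilities, as expected by F.nll_loss
        )

    def forward(self, x):
        return self.model(x)

net = ScratchMLP()
out = net(torch.randn(4, 1, 28, 28))  # a fake batch of 4 MNIST-shaped images
print(out.shape)  # torch.Size([4, 10])
```

Because the batch-norm layers have `affine=False`, only the three linear layers carry trainable weights (8,070 in total), which keeps the dimension of the space the CBO particles explore small.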

In order to perform the optimization, we need to define:
* `optimizer` (`src.torch.Optimizer`);
* `loss_fn` (`src.torch.Loss`).

In [17]:
optimizer = Optimizer(model, n_particles=n_particles, alpha=alpha, sigma=sigma,
                      l=l, dt=dt, anisotropic=anisotropic, eps=eps, partial_update=partial_update,
                      use_multiprocessing=use_multiprocessing,
                      particles_batch_size=particles_batch_size, device=device)
loss_fn = Loss(F.nll_loss, optimizer)
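The `particles_batch_size` argument suggests that the swarm is processed in random mini-batches of particles, in the spirit of the random batch method used in the CBO literature. Assuming that interpretation (we have not checked how `src.torch.Optimizer` implements it), the sketch below shuffles `n_particles = 100` particles into batches of `particles_batch_size = 10` and computes one consensus point per batch; the helper `batched_consensus` is hypothetical.

```python
import numpy as np

def batched_consensus(particles, values, batch_size, alpha, rng):
    """Shuffle particles into batches and compute one consensus point per batch."""
    perm = rng.permutation(len(particles))
    batches = [perm[i:i + batch_size] for i in range(0, len(perm), batch_size)]
    consensus = []
    for idx in batches:
        # weights exp(-alpha * f) within the batch, shifted by the batch minimum for stability
        w = np.exp(-alpha * (values[idx] - values[idx].min()))
        consensus.append(w @ particles[idx] / w.sum())
    return batches, np.stack(consensus)

rng = np.random.default_rng(0)
particles = rng.standard_normal((100, 5))   # 100 particles in a 5-D parameter space
values = np.sum(particles ** 2, axis=1)     # toy loss of each particle
batches, centers = batched_consensus(particles, values, batch_size=10, alpha=50, rng=rng)
print(len(batches), centers.shape)  # 10 (10, 5)
```

Each particle would then drift towards the consensus point of its own batch, so a single step only needs the loss values of the particles involved rather than a global weighted average over the whole swarm.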

Now, let's proceed with the final training loop:

In [None]:
for epoch in range(epochs):  # main loop over the training epochs
    for batch, (X, y) in enumerate(train_dataloader):  # nested loop over the batches of training samples
        X, y = X.to(device), y.to(device)  # move the data batch to the selected device
        train_loss, train_acc = evaluate(model, X, y)
        log(train_loss, train_acc, epoch)  # log the metrics to wandb
        loss_fn.backward(X, y)  # use the current training data batch
        optimizer.step()  # optimization step
        
        if batch % eval_freq == 0 or batch == n_batches - 1:  # evaluate the test accuracy
            with torch.no_grad():
                losses = []
                accuracies = []
                for X_test, y_test in test_dataloader:
                    X_test, y_test = X_test.to(device), y_test.to(device)
                    loss, acc = evaluate(model, X_test, y_test)
                    losses.append(loss.cpu())
                    accuracies.append(acc.cpu())
                val_loss, val_acc = np.mean(losses), np.mean(accuracies)

            print(f'Epoch: {epoch + 1:2}/{epochs}, batch: {batch + 1:4}/{n_batches}, train loss: {train_loss:8.3f}, train acc: {train_acc:8.3f}, val loss: {val_loss:8.3f}, val acc: {val_acc:8.3f}',
                  end='\r')
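One detail of the evaluation above: `np.mean(losses)` and `np.mean(accuracies)` average per-batch values, which equals the true test-set average only when all batches have the same size. If the test loader keeps the full 10,000-image MNIST test set without dropping the last incomplete batch (an assumption about `load_mnist_dataloaders`), batches of 60 leave a final batch of 40 images that is slightly over-weighted. A sample-weighted average, sketched here with a hypothetical helper `weighted_mean`, avoids this:

```python
import numpy as np

def weighted_mean(batch_values, batch_sizes):
    """Average per-batch metrics, weighting each batch by its number of samples."""
    values = np.asarray(batch_values, dtype=float)
    sizes = np.asarray(batch_sizes, dtype=float)
    return float(np.sum(values * sizes) / np.sum(sizes))

# Two batches: 60 samples at 100% accuracy, 40 samples at 50% accuracy
print(np.mean([1.0, 0.5]))                  # 0.75: unweighted mean of batch accuracies
print(weighted_mean([1.0, 0.5], [60, 40]))  # 0.8: fraction of the 100 samples classified correctly
```

In the training loop, this amounts to appending `len(y_test)` to a list of batch sizes next to `losses` and `accuracies`.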

Please refer to [this folder](https://github.com/Igor-Tukh/cbo-in-python/tree/master/notebooks/experiments) for more advanced usage examples (for instance, using the gamma term).