## PyTorch implementation of CBO

This notebook gives a brief introduction to consensus-based optimization (CBO) in the `PyTorch` framework. It covers the typical `torch` training loop and how to integrate it with Weights & Biases (`wandb`). Training is performed on the canonical MNIST dataset with a shallow network.

Imports:

In [14]:
import os
import sys

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import wandb

sys.path.extend([os.pardir,
                 os.path.join(os.pardir, os.pardir)])

In [2]:
from src.datasets import load_mnist_dataloaders

In [3]:
from torchmetrics import Accuracy

Define a model to train:

In [16]:
model = nn.Sequential(
    nn.Flatten(1, 3),
    nn.Linear(28 ** 2, 10),
    nn.BatchNorm1d(10, affine=False),
    nn.ReLU(),
    nn.Linear(10, 10),
    nn.BatchNorm1d(10, affine=False),
    nn.Linear(10, 10),
    nn.LogSoftmax(dim=1),  # normalize over the class dimension
)
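The shape bookkeeping behind this model can be checked on dummy data. The sketch below assumes that the dataloaders yield MNIST batches of shape `(batch, 1, 28, 28)`; the loader itself is not shown in this notebook, so this is an assumption. It only illustrates what `nn.Flatten(1, 3)` and a log-softmax over the class dimension produce, and does not reproduce the model above.

```python
import torch
import torch.nn as nn

# Hypothetical MNIST-shaped batch: (batch, channels, height, width)
x = torch.zeros(60, 1, 28, 28)

# Flatten(1, 3) merges dimensions 1..3 and keeps the batch dimension,
# which is why the first linear layer expects 28 ** 2 input features
flat = nn.Flatten(1, 3)(x)
print(flat.shape)

# Log-softmax over dim=1 yields log-probabilities per sample:
# their exponentials sum to 1 along the class dimension
log_probs = nn.LogSoftmax(dim=1)(torch.randn(60, 10))
print(log_probs.exp().sum(dim=1)[:3])
```

Log-probabilities are the input format that `F.nll_loss` expects, which is why the model ends with a log-softmax layer.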

Initialize `wandb`. Set `entity` to your `wandb` login (you need to create an account first).

In [17]:
wandb.init(project='CBO', entity='itukh')


The charts for your run are updated in real time during training. You can view them via the link printed in the output of the previous cell.

Specify values of training and CBO hyperparameters:

In [18]:
# Training params
epochs = 50
batch_size = 60
# CBO params
n_particles = 100
alpha = 50
l = 1  # lambda
sigma = 0.4 ** 0.5
dt = 0.1
anisotropic = True
eps = 1e-2
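To give some intuition for these hyperparameters, here is a minimal sketch of one step of a textbook CBO scheme on a toy quadratic objective. It is not the code of `src.torch.Optimizer`, whose internals are not shown in this notebook; the function name `cbo_step`, the argument name `lam` (standing in for `l`) and the toy objective are assumptions made for illustration, and the role of `eps` is not covered.

```python
import torch


def cbo_step(particles, objective, alpha=50.0, lam=1.0, sigma=0.4 ** 0.5,
             dt=0.1, anisotropic=True):
    """One Euler-Maruyama step of a textbook CBO scheme (illustrative only).

    particles: tensor of shape (n_particles, dim)
    objective: callable mapping (n_particles, dim) -> (n_particles,)
    """
    values = objective(particles)
    # Gibbs weights proportional to exp(-alpha * f); softmax is numerically
    # stable, so a large alpha does not underflow
    weights = torch.softmax(-alpha * values, dim=0)
    consensus = (weights[:, None] * particles).sum(dim=0)

    drift = particles - consensus
    noise = torch.randn_like(particles)
    if anisotropic:
        # Noise scaled component-wise by the offset from the consensus point
        diffusion = drift * noise
    else:
        # Noise scaled by the Euclidean distance to the consensus point
        diffusion = drift.norm(dim=1, keepdim=True) * noise
    new_particles = particles - lam * dt * drift + sigma * dt ** 0.5 * diffusion
    return new_particles, consensus


# Toy usage: minimize a quadratic with minimum at `target` (hypothetical example)
torch.manual_seed(0)
target = torch.tensor([1.0, -2.0])
f = lambda x: ((x - target) ** 2).sum(dim=1)

x = torch.randn(100, 2) * 3
initial_mean = f(x).mean()
initial_spread = x.std(dim=0).max()
for _ in range(200):
    x, consensus = cbo_step(x, f)
print(consensus, f(consensus[None, :]))
```

In this scheme `alpha` controls how strongly the consensus point is biased towards the best particles, `lam` pulls the particles towards it, and `sigma` sets the amount of exploration noise, which vanishes as the particles reach consensus.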

Create a consensus-based optimizer:

In [19]:
from src.torch import Optimizer

optimizer = Optimizer(model, n_particles=n_particles, alpha=alpha, sigma=sigma,
                      l=l, dt=dt, anisotropic=anisotropic, eps=eps)

Create a wrapper for the loss function:

In [20]:
from src.torch import Loss

loss_fn = Loss(F.nll_loss, optimizer)

Above, `F.nll_loss` is the standard `torch` implementation of the negative log-likelihood loss.
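As a quick illustration of what `F.nll_loss` computes, the sketch below compares it with a manual computation and with `F.cross_entropy` on raw logits. The batch and the labels are made up for the example.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)           # hypothetical batch: 4 samples, 10 classes
targets = torch.tensor([3, 0, 9, 1])  # hypothetical class labels

# nll_loss expects log-probabilities, e.g. the output of a log-softmax layer
log_probs = F.log_softmax(logits, dim=1)
nll = F.nll_loss(log_probs, targets)

# Manual computation: take the log-probability of the true class and average
manual = -log_probs[torch.arange(4), targets].mean()

# cross_entropy applies log_softmax and nll_loss in one call
ce = F.cross_entropy(logits, targets)

print(nll, manual, ce)
```

All three values agree up to floating-point error.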

Prepare the MNIST dataloaders:

In [21]:
train_dataloader, test_dataloader = load_mnist_dataloaders(train_batch_size=batch_size,
                                                           test_batch_size=batch_size)

Update the `wandb` config to save the hyperparameter values for the current run. This step is optional.

In [22]:
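# Update the config of the active run in place; assigning a new dict to
# `wandb.config` would only rebind the module attribute
wandb.config.update({
    'epochs': epochs,
    'batch_size': batch_size,

    'n_particles': n_particles,
    'alpha': alpha,
    'lambda': l,
    'sigma': sigma,
    'dt': dt,
    'anisotropic': anisotropic,
    'eps': eps,
})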

Write helper functions to evaluate your model and log the results to `wandb`:

In [23]:
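# Note: torchmetrics >= 0.11 requires the task to be given explicitly,
# e.g. Accuracy(task='multiclass', num_classes=10)
accuracy = Accuracy()

def evaluate(model, X_, y_):
    """Return the loss and the accuracy of `model` on the batch (X_, y_)."""
    with torch.no_grad():
        outputs = model(X_)
        y_pred = torch.argmax(outputs, dim=1)
        loss = loss_fn(outputs, y_)
        acc = accuracy(y_pred, y_)
    return loss, acc

def log(loss, acc, epoch, stage='train', shift_norm=None):
    """Log the metrics of one step to `wandb`, prefixed with `stage`."""
    wandb.log({
        f'{stage}_loss': loss,
        f'{stage}_acc': acc,
        'epoch': epoch,
        f'{stage}_shift_norm': shift_norm,
    })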
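The accuracy used in `evaluate` is the fraction of samples whose highest-scoring class matches the label. A plain `torch` sketch of the same quantity on made-up data, which may be handy if the installed `torchmetrics` version has a different `Accuracy` signature:

```python
import torch

# Hypothetical log-probabilities for 3 samples and 2 classes
outputs = torch.tensor([[0.1, 0.9],
                        [0.8, 0.2],
                        [0.3, 0.7]]).log()
targets = torch.tensor([1, 0, 0])  # hypothetical labels

# Predicted class is the index of the largest log-probability
y_pred = torch.argmax(outputs, dim=1)

# Fraction of correct predictions
acc = (y_pred == targets).float().mean()
print(y_pred, acc)
```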

The main training loop:

In [24]:
n_batches = len(train_dataloader)

In [26]:
for epoch in range(epochs):
    for batch, (X, y) in enumerate(train_dataloader):
        # Metrics on the current batch, measured before the update
        train_loss, train_acc = evaluate(model, X, y)
        log(train_loss, train_acc, epoch, shift_norm=optimizer.shift_norm)
        optimizer.zero_grad()  # optional
        loss_fn.backward(X, y)
        optimizer.step()

        # Validation over the whole test set, repeated after every training batch
        with torch.no_grad():
            losses = []
            accuracies = []
            for X_test, y_test in test_dataloader:
                loss, acc = evaluate(model, X_test, y_test)
                losses.append(loss)
                accuracies.append(acc)
            val_loss, val_acc = np.mean(losses), np.mean(accuracies)
            log(val_loss, val_acc, epoch, 'val')

        print(f'Epoch: {epoch + 1:2}/{epochs}, batch: {batch + 1:4}/{n_batches}, '
              f'train loss: {train_loss:8.3f}, train acc: {train_acc:8.3f}, '
              f'val loss: {val_loss:8.3f}, val acc: {val_acc:8.3f}',
              end='\r')

Epoch:  1/50, batch:   41/1000, train loss:    2.238, train acc:    0.183, val loss:    2.266, val acc:    0.180
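The validation metrics in the loop are plain means of per-batch values. If the last test batch is smaller than the others (whether it is depends on the dataloader, which is not shown here), a mean weighted by batch size gives the exact per-sample average. A minimal sketch, with the helper name `weighted_mean` and the numbers chosen for illustration:

```python
def weighted_mean(values, sizes):
    """Average per-batch values, weighting each batch by its number of samples."""
    total = sum(sizes)
    return sum(v * n for v, n in zip(values, sizes)) / total


# Hypothetical per-batch accuracies for a full batch and a smaller last batch
batch_accuracies = [0.9, 0.5]
batch_sizes = [60, 40]

plain = sum(batch_accuracies) / len(batch_accuracies)
weighted = weighted_mean(batch_accuracies, batch_sizes)
print(plain, weighted)
```

In the loop above, the batch sizes could be collected as `len(y_test)` alongside each loss and accuracy.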


In [56]:
wandb.finish()

Run history:

epoch             ▁▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇██
train_acc         ▁▃▅▆▃▄▄▅▄▆▆▆▆▅▇▅▇▆▇▆▅▇▆▇▇▆▅▇█▇▇▇▇▇▇▇▇█▇█
train_loss        █▇▆▄▆▅▅▄▄▂▄▃▃▄▂▃▂▃▂▃▄▂▃▂▂▃▄▁▁▃▂▂▂▂▁▂▁▁▂▁
train_shift_norm  █▁▄▁▇▄▁▂▂▁▁▂▁▁▁▅▇▁▁▂▅▁▁▂▂▅▁▁▂▂▂▁▁▂▂█▄▁▁▁
val_acc           ▁▃▃▅▄▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇█████████████████
val_loss          █▇▆▅▅▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

Run summary:

epoch             7.0
train_acc         0.76667
train_loss        0.72038
train_shift_norm  0.0261
val_acc           0.82515
val_loss          0.55655
