## PyTorch implementation of CBO

This notebook gives a brief introduction to using consensus-based optimization (CBO) with the `PyTorch` framework. It covers the typical `torch` training loop and its integration with Weights & Biases (`wandb`). Training is performed on the canonical MNIST dataset with a shallow network.

Imports:

In [1]:
import os
import sys
import time

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import wandb

sys.path.extend([os.pardir,
                 os.path.join(os.pardir, os.pardir)])

In [2]:
from src.datasets import load_mnist_dataloaders

In [3]:
from torchmetrics import Accuracy

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')  # uncomment to force CPU execution

Models:

In [5]:
def build_small_network():
    return nn.Sequential(
        nn.Flatten(1, 3),
        nn.Linear(28 ** 2, 10),
        nn.BatchNorm1d(10, affine=False),
        nn.ReLU(),
        nn.Linear(10, 10),
        nn.BatchNorm1d(10, affine=False),
        nn.Linear(10, 10),
        nn.LogSoftmax(dim=1),  # log-probabilities over the 10 classes
    )

In [6]:
class LeNet1(nn.Module):
    def __init__(self):
        super(LeNet1, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 6, 3, stride=1, padding=1),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5, stride=1, padding=0),
            nn.MaxPool2d(2, 2),
            nn.Flatten(1, 3),
            nn.Linear(400, 120),
            nn.Linear(120, 84),
            nn.Linear(84, 10),
            nn.LogSoftmax(dim=1),  # log-probabilities over the 10 classes
        )
    
    def forward(self, X):
        return self.model(X)
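The `400` input features of the first `nn.Linear` in `LeNet1` follow from the standard output-size formula for convolution and pooling layers, `floor((size + 2 * padding - kernel) / stride) + 1`. The sketch below traces a 28x28 MNIST image through the feature extractor by hand and cross-checks the result with a dummy batch; `conv_out` is a hypothetical helper written for this illustration.

```python
import torch
import torch.nn as nn

def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of Conv2d/MaxPool2d (dilation 1, floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1

# Trace a 28x28 MNIST image through the LeNet1 feature extractor by hand.
s = conv_out(28, 3, padding=1)   # Conv2d(1, 6, 3, padding=1) -> 28
s = conv_out(s, 2, stride=2)     # MaxPool2d(2, 2)            -> 14
s = conv_out(s, 5)               # Conv2d(6, 16, 5)           -> 10
s = conv_out(s, 2, stride=2)     # MaxPool2d(2, 2)            -> 5
flat_features = 16 * s * s       # 16 channels * 5 * 5 = 400
print(s, flat_features)

# Cross-check with a dummy batch pushed through the same layers.
features = nn.Sequential(
    nn.Conv2d(1, 6, 3, stride=1, padding=1),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(6, 16, 5, stride=1, padding=0),
    nn.MaxPool2d(2, 2),
    nn.Flatten(1, 3),
)
with torch.no_grad():
    out = features(torch.zeros(2, 1, 28, 28))
print(out.shape)  # torch.Size([2, 400])
```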

Instantiate the model to train and move it to the selected device:

In [7]:
model = LeNet1().to(device)

Initialize `wandb`. You will need to set `entity` to your own `wandb` username (create an account first).

In [8]:
wandb.init(project='CBO', entity='itukh')

[34m[1mwandb[0m: Currently logged in as: [33mitukh[0m. Use [1m`wandb login --relogin`[0m to force relogin


The training charts are updated in real time as training progresses. You can view them via the link printed in the output of the previous cell.

Specify the training and CBO hyperparameters:

In [9]:
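# Training params
epochs = 50
batch_size = 60
# CBO params
n_particles = 100
particles_batch_size = 10
alpha = 50            # weighting parameter of the consensus point
l = 1                 # lambda: strength of the drift towards the consensus point
sigma = 0.4 ** 0.5    # noise strength (sigma ** 2 = 0.4)
dt = 0.1              # time step
anisotropic = True    # component-wise (anisotropic) noise
eps = 1e-5
gamma = 1e-5
# Additional params
use_multiprocessing = False
n_processes = 6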
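For intuition about what `alpha`, `l`, `sigma`, `dt` and `anisotropic` control, here is a minimal sketch of one explicit time step of CBO in its usual textbook form, which may differ in detail from `src.torch.Optimizer`. The consensus point is the average of the particles weighted by `exp(-alpha * f(x_i))`; each particle drifts towards it with rate `l` (lambda) over a step `dt`, and Gaussian noise scaled by `sigma * sqrt(dt)` times the distance to the consensus point is added (component-wise when `anisotropic=True`). `cbo_step` is a hypothetical helper, and `eps`, `gamma` and `particles_batch_size` are deliberately not modelled here.

```python
import math
import torch

def cbo_step(particles, objective, alpha=50.0, l=1.0, sigma=0.4 ** 0.5, dt=0.1,
             anisotropic=True, generator=None):
    """One Euler-Maruyama step of standard CBO (illustrative sketch only).

    particles: tensor of shape (n_particles, dim);
    objective: maps particles to values of shape (n_particles,).
    Returns the updated particles and the consensus point used for the step.
    """
    values = objective(particles)                    # f(x_i) for every particle
    weights = torch.softmax(-alpha * values, dim=0)  # normalised exp(-alpha * f(x_i))
    consensus = weights @ particles                  # weighted mean, shape (dim,)
    diff = particles - consensus                     # x_i - consensus
    noise = torch.randn(particles.shape, generator=generator,
                        dtype=particles.dtype, device=particles.device)
    if anisotropic:
        diffusion = diff.abs() * noise                      # per-component scale
    else:
        diffusion = diff.norm(dim=1, keepdim=True) * noise  # one scale per particle
    new_particles = particles - l * dt * diff + sigma * math.sqrt(dt) * diffusion
    return new_particles, consensus

# Toy usage: minimise f(x) = ||x - (1, 1)||^2 with 100 particles in 2D.
gen = torch.Generator().manual_seed(0)
x = 3 * torch.randn(100, 2, generator=gen)
f = lambda p: ((p - 1.0) ** 2).sum(dim=1)
for _ in range(200):
    x, v = cbo_step(x, f, generator=gen)
print(v)  # consensus point; expected to approach the minimiser (1, 1)
```

Computing the weights with `torch.softmax` avoids overflow for large `alpha`, because the softmax is evaluated in a numerically stable way instead of exponentiating `-alpha * f` directly.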

Create a consensus-based optimizer:

In [10]:
from src.torch import Optimizer

optimizer = Optimizer(model, n_particles=n_particles, alpha=alpha, sigma=sigma,
                      l=l, dt=dt, anisotropic=anisotropic, eps=eps, gamma=gamma, 
                      use_multiprocessing=use_multiprocessing, n_processes=n_processes,
                      particles_batch_size=particles_batch_size, device=device)

Create a wrapper for the loss function:

In [11]:
from src.torch import Loss

loss_fn = Loss(F.nll_loss, optimizer)

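Above, `F.nll_loss` is the standard `torch` implementation of the negative log-likelihood loss. It expects log-probabilities as input, which is why both models end with `nn.LogSoftmax`.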
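As a quick sanity check of that pairing, the sketch below shows that `nn.LogSoftmax` followed by `F.nll_loss` gives the same value as `F.cross_entropy` applied to raw logits, and as the mean of `-log p(correct class)` computed by hand. The logits and labels are random toy data, not MNIST outputs.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)            # raw scores for 4 samples and 10 classes
targets = torch.tensor([3, 0, 9, 1])   # integer class labels

# What the models above compute: LogSoftmax over classes, then NLL.
nll = F.nll_loss(F.log_softmax(logits, dim=1), targets)

# The same quantity computed directly from logits.
ce = F.cross_entropy(logits, targets)

# By hand: mean over samples of -log p(correct class).
manual = -F.log_softmax(logits, dim=1)[torch.arange(4), targets].mean()

print(nll.item(), ce.item(), manual.item())  # agree up to floating-point error
```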

Prepare the MNIST dataloaders:

In [12]:
train_dataloader, test_dataloader = load_mnist_dataloaders(train_batch_size=batch_size,
                                                           test_batch_size=batch_size)

Update the `wandb` config to store the hyperparameter values of the current run. This step is optional.

In [13]:
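# Update the run's config in place; reassigning `wandb.config` would replace the object
wandb.config.update({
    'epochs': epochs,
    'batch_size': batch_size,

    'n_particles': n_particles,
    'particles_batch_size': particles_batch_size,
    'alpha': alpha,
    'lambda': l,
    'sigma': sigma,
    'dt': dt,
    'anisotropic': anisotropic,
    'eps': eps,
    'gamma': gamma,
})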

Write helper functions to evaluate the model and log the results to `wandb`:

In [14]:
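# `task` is required in recent torchmetrics; micro averaging gives the overall accuracy
accuracy = Accuracy(task='multiclass', num_classes=10, average='micro')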

def evaluate(model, X_, y_):
    with torch.no_grad():
        outputs = model(X_)
        y_pred = torch.argmax(outputs, dim=1)
        loss = loss_fn(outputs, y_)
        acc = accuracy(y_pred.cpu(), y_.cpu())
    return loss, acc

def log(loss, acc, epoch, stage='train', shift_norm=None):
    wandb.log({
        f'{stage}_loss': loss,
        f'{stage}_acc': acc,
        'epoch': epoch,
        f'{stage}_shift_norm': shift_norm,
    })
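For reference, the per-batch accuracy returned by `evaluate` is simply the fraction of samples whose arg-max class matches the label; with `average='micro'`, torchmetrics' multiclass accuracy reports this same global fraction. A dependency-free sketch with a hypothetical helper `batch_accuracy` and hand-made toy scores:

```python
import torch

def batch_accuracy(log_probs, targets):
    """Fraction of samples whose arg-max class equals the label."""
    y_pred = torch.argmax(log_probs, dim=1)
    return (y_pred == targets).float().mean()

# Toy log-probabilities for 4 samples and 3 classes.
log_probs = torch.log_softmax(torch.tensor([[2.0, 0.1, 0.3],
                                            [0.2, 1.5, 0.1],
                                            [0.3, 0.2, 0.1],
                                            [0.1, 0.2, 3.0]]), dim=1)
targets = torch.tensor([0, 1, 2, 2])
print(batch_accuracy(log_probs, targets))  # tensor(0.7500): third sample is wrong
```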

Profiling (optional, for debugging):

In [15]:
# import cProfile, pstats, io
# from pstats import SortKey
# pr = cProfile.Profile()
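The commented-out cell above follows the standard-library `cProfile`/`pstats` pattern. A minimal sketch of how it is typically completed, with a hypothetical `work()` standing in for the code under investigation (for example `optimizer.step()`):

```python
import cProfile
import io
import pstats
from pstats import SortKey

def work():
    # Hypothetical stand-in for the code being profiled.
    return sum(i * i for i in range(100_000))

pr = cProfile.Profile()
pr.enable()
work()
pr.disable()

s = io.StringIO()
ps = pstats.Stats(pr, stream=s).sort_stats(SortKey.CUMULATIVE)
ps.print_stats(10)  # top 10 entries by cumulative time
report = s.getvalue()
print(report)
```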

The main training loop:

In [16]:
# from viztracer import VizTracer
# tracer = VizTracer()

In [17]:
n_batches = len(train_dataloader)

In [18]:
for epoch in range(epochs):
    for batch, (X, y) in enumerate(train_dataloader):
        X, y = X.to(device), y.to(device)
        train_loss, train_acc = evaluate(model, X, y)
        log(train_loss, train_acc, epoch, shift_norm=optimizer.shift_norm)
        optimizer.zero_grad()
        loss_fn.backward(X, y, backward_gradients=True)
        # tracer.start()
        optimizer.step()
        # tracer.stop()
        # tracer.save('profile.json')
        # break

        with torch.no_grad():
            losses = []
            accuracies = []
            for X_test, y_test in test_dataloader:
                X_test, y_test = X_test.to(device), y_test.to(device)
                loss, acc = evaluate(model, X_test, y_test)
                losses.append(loss.cpu())
                accuracies.append(acc.cpu())
            val_loss, val_acc = np.mean(losses), np.mean(accuracies)
            log(val_loss, val_acc, epoch, 'val')
        
        print(f'Epoch: {epoch + 1:2}/{epochs}, batch: {batch + 1:4}/{n_batches}, train loss: {train_loss:8.3f}, train acc: {train_acc:8.3f}, val loss: {val_loss:8.3f}, val acc: {val_acc:8.3f}',
              end='\r')
    # break


Epoch:  1/50, batch:   84/1000, train loss:  111.170, train acc:    0.250, val loss:  111.589, val acc:    0.1771

KeyboardInterrupt: 
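In the loop above, `val_loss` and `val_acc` are plain `np.mean`s of per-batch values. If the test loader keeps the final incomplete batch (the `torch.utils.data.DataLoader` default is `drop_last=False`; whether `load_mnist_dataloaders` does the same is an assumption here), the 10,000 MNIST test images split into 166 batches of 60 and one of 40, so the plain mean slightly over-weights the last batch. A sketch of size-weighted averaging with a hypothetical helper `weighted_mean` and toy losses:

```python
import torch

def weighted_mean(batch_means, batch_sizes):
    """Dataset-level mean from per-batch means, weighting each batch by its size."""
    means = torch.as_tensor(batch_means, dtype=torch.float64)
    sizes = torch.as_tensor(batch_sizes, dtype=torch.float64)
    return (means * sizes).sum() / sizes.sum()

# 10,000 test images in batches of 60: 166 full batches and one of 40.
sizes = [60] * 166 + [40]
print(sum(sizes))  # 10000

# Toy per-sample losses; the last (smaller) batch is deliberately different.
per_sample = torch.cat([torch.ones(9960), torch.full((40,), 5.0)])
batch_means = [b.mean().item() for b in torch.split(per_sample, 60)]

print(sum(batch_means) / len(batch_means))       # plain mean of batch means
print(weighted_mean(batch_means, sizes).item())  # matches per_sample.mean()
```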

In [None]:
optimizer.finish()

In [None]:
wandb.finish()