In [1]:
%load_ext autoreload
%autoreload 2
import cbx as cbx
from cbx.dynamics.cbo import CBO
import numpy as np

import torch
import torch.nn as nn
import torchvision
from cbx.noise import anisotropic_noise

# Load data
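We load the MNIST training and test sets from a local directory and wrap each one in a `DataLoader` with a batch size of 64; only the training loader shuffles its samples.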

In [2]:
# Directory holding the datasets. Both MNIST calls below use download=False,
# so the data must already be present at this location.
data_path = "../../../datasets/"

# Transform applied to every sample when it is read from the dataset.
transform = torchvision.transforms.ToTensor()

# MNIST training split.
train_data = torchvision.datasets.MNIST(
    data_path,
    train=True,
    transform=transform,
    download=False,
)
# MNIST test split.
test_data = torchvision.datasets.MNIST(
    data_path,
    train=False,
    transform=transform,
    download=False,
)

# Mini-batch loaders of 64 samples each; only the training loader shuffles.
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=64,
    shuffle=True,
    num_workers=0,
)
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=64,
    shuffle=False,
    num_workers=0,
)

# Load model

In [3]:
from models import Perceptron
from cbx_torch_utils import flatten_parameters, get_param_properties, eval_losses, norm_torch, compute_consensus_torch, normal_torch, eval_acc
# Use the GPU when PyTorch reports one as available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Number of networks instantiated below.
N = 50
models = [Perceptron(sizes=[784,100,10]) for _ in range(N)]

# The first network serves as the reference model: its parameter names are
# collected here and it is the model handed to the objective later on.
model = models[0]
pnames = [name for name, _ in model.named_parameters()]

# Parameters of all networks gathered by `flatten_parameters` and moved to
# `device` (presumably one flat row per network -- TODO confirm against
# cbx_torch_utils).
w = flatten_parameters(models, pnames).to(device)

# Parameter properties as returned by `get_param_properties`; passed on
# unchanged to `eval_losses` and `eval_acc`.
pprop = get_param_properties(models, pnames=pnames)

In [4]:
class objective:
    """Callable loss that is evaluated on one mini-batch at a time.

    The instance keeps an iterator over ``train_loader`` and stores the
    current batch in ``self.x`` / ``self.y``. Calling the instance evaluates
    ``eval_losses`` on that batch; ``set_batch`` advances to the next batch
    and counts how many times the loader has been exhausted.

    Parameters
    ----------
    train_loader : iterable
        Iterable yielding ``(x, y)`` pairs whose elements support ``.to``.
    N : int
        Stored as ``self.N``; not used by any method of this class.
    device : str or torch.device
        Device every batch is moved to.
    model
        Model forwarded to ``eval_losses``.
    pprop
        Parameter properties forwarded to ``eval_losses``.
    """

    def __init__(self, train_loader, N, device, model, pprop):
        # Data source and the iterator currently being consumed.
        self.train_loader = train_loader
        self.data_iter = iter(train_loader)

        # Bookkeeping.
        self.N = N
        self.epochs = 0
        self.device = device

        # Everything needed to evaluate the loss.
        self.loss_fct = nn.CrossEntropyLoss()
        self.model = model
        self.pprop = pprop

        # Fetch the first batch so the instance is callable right away.
        self.set_batch()

    def __call__(self, w):
        """Return ``eval_losses`` for the current batch.

        Only ``w[0, ...]``, i.e. the first entry along the leading axis of
        ``w``, is handed to ``eval_losses``.
        """
        batch_x, batch_y = self.x, self.y
        return eval_losses(
            batch_x,
            batch_y,
            self.loss_fct,
            self.model,
            w[0, ...],
            self.pprop,
        )

    def set_batch(self):
        """Advance to the next batch and store it on ``self.device``.

        When the iterator is exhausted, a new iterator over ``train_loader``
        is created, its first batch is taken, and ``self.epochs`` is
        incremented by one.
        """
        # (None, None) is the sentinel that signals an exhausted iterator.
        inputs, targets = next(self.data_iter, (None, None))
        if inputs is None:
            self.data_iter = iter(self.train_loader)
            inputs, targets = next(self.data_iter)
            self.epochs += 1
        self.x = inputs.to(self.device)
        self.y = targets.to(self.device)

# Set up the CBX dynamics

In [5]:
# Keyword arguments that are unpacked into the CBO constructor.
kwargs = {
    'alpha': 50.0,
    'dt': 0.1,
    'sigma': 0.1,
    'lamda': 1.0,
    'term_args': {'max_time': 20},
    'verbosity': 0,
    # The batch size is set to N, the number of networks created earlier.
    'batch_args': {'batch_size': N},
    'check_f_dims': False,
}

In [6]:
# Objective evaluated on the current mini-batch of the training loader.
f = objective(train_loader, N, device, model, pprop)

# Resampling schemes and noise model handed to the dynamics below.
resamplings = [
    cbx.utils.resampling.loss_update_resampling(M=1, wait_thresh=40),
]
noise = anisotropic_noise(
    norm=norm_torch,
    sampler=normal_torch(device),
)

# CBO dynamics using the torch-based helpers for norm, copy, sampling and
# consensus computation.
dyn = CBO(
    f,
    f_dim='3D',
    # Indexing with None adds a leading axis of length one in front of w.
    x=w[None, ...],
    noise=noise,
    resamplings=resamplings,
    norm=norm_torch,
    copy=torch.clone,
    normal=normal_torch(device),
    compute_consensus=compute_consensus_torch,
    # Post-processing is replaced by a callable that does nothing.
    post_process=lambda *args: None,
    **kwargs,
)

# Scheduler targeting 'alpha' with factor 1.03 (presumably multiplies alpha
# by that factor on every update -- confirm against cbx.scheduler).
sched = cbx.scheduler.multiply(factor=1.03, name='alpha')

# Train the network

In [None]:
# Epoch count at the time of the most recent report.
e = 0

# Iterate until the objective has gone through the training loader 10 times.
while f.epochs < 10:
    # One optimisation step, one scheduler update, then the next mini-batch.
    dyn.step()
    sched.update(dyn)
    f.set_batch()

    # Nothing to report until the objective's epoch counter has changed.
    if e == f.epochs:
        continue

    e = f.epochs
    print('-' * 30)
    print('Epoch: ' + str(f.epochs))
    # Test accuracy of the best particle; [0, ...] picks the first entry
    # along its leading axis.
    acc = eval_acc(model, dyn.best_particle[0, ...], pprop, test_loader)
    print('Accuracy: ' + str(acc.item()))
    print('-' * 30)