In [1]:
 %load_ext autoreload
import cbx as cbx
from cbx.dynamics.cbo import CBO
import numpy as np

import torch
import torch.nn as nn
import torchvision
from cbx.noise import anisotropic_noise
from torch.func import functional_call, stack_module_state, vmap

# Load data
We load the train and test data.

In [2]:
data_path = "../../../datasets/"
transform = torchvision.transforms.ToTensor()
train_data = torchvision.datasets.MNIST(data_path, train=True, transform=transform, download=False)
test_data = torchvision.datasets.MNIST(data_path, train=False, transform=transform, download=False)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=64,shuffle=True, num_workers=0)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=64,shuffle=False, num_workers=0)

# Load model

In [3]:
from models import Perceptron
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Perceptron().to(device)

In [4]:
from collections import OrderedDict
N = 50
models = [Perceptron() for _ in range(N)]
params, buffers = stack_module_state(models)
pnames = params.keys()
pshapes = {}
pcuts = []
pprop = OrderedDict()
for p in pnames:
    a = 0
    if len(pprop)>0:
        a = pprop[next(reversed(pprop))][-1]
    pprop[p] = (params[p][0,...].shape, a, a + params[p][0,...].numel())
    
def flatten_parameters(params, pnames):
    return torch.concatenate([params[pname].view(N,-1).detach() for pname in pnames], dim=-1)

w = flatten_parameters(params,pnames).to(device)

In [5]:
def eval_model(x, model, w, pprop):
    params = {p: w[pprop[p][-2]:pprop[p][-1]].view(pprop[p][0]) for p in pprop}
    return functional_call(model, (params, {}), x)

def eval_models(x, model, w, pprop):
    return vmap(eval_model, (None, None, 0, None))(x, model, w, pprop)

def eval_loss(x, y, loss_fct, model, w, pprop):
    with torch.no_grad():
        return loss_fct(eval_model(x, model, w, pprop), y)
    
def eval_losses(x, y, loss_fct, model, w, pprop):
    return vmap(eval_loss, (None, None, None, None, 0, None))(x, y, loss_fct, model, w, pprop)

def eval_acc(model, w, pprop):
    res = 0
    num_img = 0
    for (x,y) in iter(test_loader):
        res += torch.sum(eval_model(x, model, w, pprop).argmax(axis=1)==y)
        num_img += x.shape[0]
    return res/num_img

In [6]:
class objective:
    def __init__(self, train_loader, N, device, model, pprop):
        self.train_loader = train_loader
        self.data_iter = iter(train_loader)
        self.N = N
        self.epochs = 0
        self.device = device   
        self.loss_fct = nn.CrossEntropyLoss()
        self.model = model
        self.pprop = pprop
        self.set_batch()
        
    def __call__(self, w):   
        return eval_losses(self.x, self.y, self.loss_fct, self.model, w[0,...], self.pprop)
    
    def set_batch(self,):
        (x,y) = next(self.data_iter, (None, None))
        if x is None:
            self.data_iter = iter(self.train_loader)
            (x,y) = next(self.data_iter)
            self.epochs += 1
        self.x = x.to(self.device)
        self.y = y.to(self.device)

# Set up CBX Dynamic

In [7]:
kwargs = {'alpha':50.0,
        'dt': 0.1,
        'sigma': 0.1,
        'lamda': 1.0,
        'term_args':{'max_time': 20},
        'verbosity':2,
        'batch_args':{'batch_size':N},
        #'batch_size': M,
        'check_f_dims':False}

In [15]:
def norm_torch(x, axis, **kwargs):
    return torch.linalg.norm(x, dim=axis, **kwargs)  


resamplings = [cbx.utils.resampling.loss_update_resampling(M=1, wait_thresh=40)]
f = objective(train_loader, N, device, model, pprop)
def cc(f, x, alpha):
    energy = f(x) # update energy
    weights = - alpha * energy
    coeffs = torch.exp(weights - torch.logsumexp(weights, axis=(-1,), keepdims=True))[...,None]
    return (x * coeffs).sum(axis=-2, keepdims=True), energy

def normal_torch(mean, std, size):
    return torch.normal(mean, std, size).to(device)
noise = anisotropic_noise(norm = norm_torch, sampler = normal_torch)
dyn = CBO(f, f_dim='3D', x=w[None,...], noise=noise, 
          resamplings=resamplings, 
          norm=norm_torch,
          copy=torch.clone,
          normal=normal_torch,
          compute_consensus=cc,
          **kwargs)
sched = cbx.scheduler.multiply(factor=1.03, name='alpha')

# Train the network

In [16]:
while f.epochs < 10:
    dyn.step()
    sched.update(dyn)
    f.set_batch()
    if dyn.it%10 == 0:
        print('Cur Best energy: ' + str(dyn.best_cur_energy))
        print('Best energy: ' + str(dyn.best_energy))
        print('Alpha: ' + str(dyn.alpha))
        print('Sigma: ' + str(dyn.sigma))
    if e != f.epochs:
        e = f.epochs
        print(30*'-')
        print('Epoch: ' +str(f.epochs))
        print('Accuracy: ' + str(eval_acc(dyn.best_particle[None,...])))
        print(30*'-')

TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

In [20]:
import numpy as np

In [21]:
x = np.zeros(5)

In [22]:
x.where

AttributeError: 'numpy.ndarray' object has no attribute 'where'