In [23]:
import os
import sys
import pickle
import argparse
import numpy as np

import matplotlib.pyplot as plt

import tqdm

import torch
from torch import nn
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from torch.utils.data import DataLoader, Subset
from torch.optim import Adam

from ivon import IVON

sys.path.append("..")
from lib.models import get_model
from lib.datasets import get_dataset
from lib.utils import get_quick_loader, predict_test, flatten, predict_nll_hess, train_model, predict_train2, train_network
from lib.variances import get_covariance_from_iblr, get_covariance_from_adam, get_pred_vars_optim, get_pred_vars_laplace

In [24]:
# Load the 'MOON' dataset through the project's dataset factory; `noise=0.1`
# and `return_transform=True` are forwarded to it unchanged.
ds_train, ds_test, transform_train = get_dataset('MOON', return_transform=True, noise=0.1)
input_size = ds_train[0][0].numel()  # number of scalar features in one sample

# Collect the labels of each split once. The class count is derived from the
# training labels rather than walking the training set a second time.
tr_targets = torch.asarray([target for _, target in ds_train])
te_targets = torch.asarray([target for _, target in ds_test])
nc = len(torch.unique(tr_targets))
n_train = len(ds_train)

# NOTE(review): the trailing positional arguments ('cuda', 42) look like a
# device and a seed -- confirm against lib.models.get_model.
model = get_model('small_mlp', nc, input_size, 'cuda', 42)

# Optimizer hyperparameters; `n_epochs` is also read by later cells.
learning_rate = 2
mc_samples = 4
n_epochs = 1000
hess_init = 0.1
optimizer = IVON(
    model.parameters(),
    lr=learning_rate,
    ess=n_train,  # the training-set size is passed as IVON's `ess` argument
    mc_samples=mc_samples,
    hess_init=hess_init,
    weight_decay=1e-3,
)

In [25]:
# Dataloaders -- every loader draws batches of the same size.
loader_batch_size = 256
eval_loader_kwargs = {"batch_size": loader_batch_size, "shuffle": False}

# Training: a plain DataLoader handed to the project's quick-loader helper.
trainloader = get_quick_loader(DataLoader(ds_train, batch_size=loader_batch_size), device='cuda')

# Evaluation and variance computation: fixed, unshuffled order.
trainloader_eval = DataLoader(ds_train, **eval_loader_kwargs)  # train evaluation
testloader_eval = DataLoader(ds_test, **eval_loader_kwargs)  # test evaluation
trainloader_vars = DataLoader(ds_train, **eval_loader_kwargs)  # variance computation

In [26]:

# Training objective: cross-entropy loss.
loss_fn = nn.CrossEntropyLoss()

# Cosine-annealed learning rate whose period (T_max) equals the number of epochs.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=n_epochs,
)

In [27]:
# Train the model with the project's `train_model` helper; it returns the model
# and the recorded losses.
# NOTE(review): `_` is never assigned in the visible cells. In IPython it holds
# the previous cell's output, so the value passed here depends on execution
# history and breaks "Restart & Run All" reproducibility (and is a NameError in
# a plain script). Replace it with the explicit value `train_model` expects in
# this position -- TODO confirm against lib.utils.train_model.
# NOTE(review): `n_train-1`, 'cuda' and False are passed positionally; their
# meaning is not visible here -- verify against the helper's signature.
model, losses = train_model(model, loss_fn, optimizer, scheduler, trainloader, n_epochs, n_train-1, _, 'cuda', False)

  1%|          | 11/1000 [00:00<00:09, 101.13it/s]

Epoch 1, Loss: 0.18068774044513702


 12%|█▏        | 121/1000 [00:01<00:08, 103.22it/s]

Epoch 101, Loss: 0.007188075687736273


 22%|██▏       | 220/1000 [00:02<00:07, 105.42it/s]

Epoch 201, Loss: 0.047951746731996536


 32%|███▏      | 320/1000 [00:03<00:06, 104.54it/s]

Epoch 301, Loss: 0.005375325679779053


 42%|████▏     | 421/1000 [00:03<00:03, 191.53it/s]

Epoch 401, Loss: 0.009578786790370941


 55%|█████▍    | 545/1000 [00:03<00:01, 290.35it/s]

Epoch 501, Loss: 0.05345800146460533


 65%|██████▌   | 652/1000 [00:04<00:01, 329.76it/s]

Epoch 601, Loss: 0.006005279719829559


 72%|███████▏  | 722/1000 [00:04<00:00, 336.22it/s]

Epoch 701, Loss: 0.004140423610806465


 85%|████████▍ | 847/1000 [00:05<00:00, 237.91it/s]

Epoch 801, Loss: 0.01831910014152527


 94%|█████████▍| 945/1000 [00:05<00:00, 293.32it/s]

Epoch 901, Loss: 0.008790882304310799


100%|██████████| 1000/1000 [00:05<00:00, 180.16it/s]


In [28]:
def plot_decision_boundary(model, X, y, filename='decision_boundary.pdf', device='cuda'):
    """Draw a classifier's decision regions over 2-D data and save them as a PDF.

    A dense grid (step 0.01) covering the range of the first two columns of
    ``X``, padded by 0.5 on every side, is classified by ``model``. The
    predicted class of each grid point (argmax over dim 1 of the model output)
    is drawn as a filled contour, with the samples of ``X`` scattered on top
    and coloured by ``y``.

    Args:
        model: Torch module called on a float32 tensor of shape ``(N, 2)``;
            its output is reduced with ``argmax(dim=1)``.
        X: Sample coordinates, a torch tensor or array-like indexable as
            ``X[:, 0]`` and ``X[:, 1]``. Tensors are detached and moved to
            the CPU before plotting.
        y: Per-sample values used to colour the scatter points (tensor or
            array-like).
        filename: Output path. The figure is always written in PDF format,
            regardless of the file extension.
        device: Device the grid tensor is moved to before the forward pass.

    Side effects:
        Calls ``model.eval()`` and does not restore the previous mode.
        Writes ``filename`` to disk.
    """

    def _as_numpy(values):
        # np.arange / matplotlib expect NumPy data; feeding torch tensors
        # directly triggers conversion warnings and fails for tensors that
        # do not live on the CPU.
        if torch.is_tensor(values):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    X = _as_numpy(X)
    y = _as_numpy(y)

    # Define the grid range (plain Python floats, so np.arange gets scalars)
    x_min, x_max = float(X[:, 0].min()) - 0.5, float(X[:, 0].max()) + 0.5
    y_min, y_max = float(X[:, 1].min()) - 0.5, float(X[:, 1].max()) + 0.5
    xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.01),
                         np.arange(y_min, y_max, 0.01))

    # Create the grid points: one (x, y) row per grid cell
    grid = np.c_[xx.ravel(), yy.ravel()]

    # Predict on the grid
    model.eval()
    with torch.no_grad():
        grid_tensor = torch.tensor(grid, dtype=torch.float32).to(device)
        preds = model(grid_tensor).argmax(dim=1).cpu().numpy()

    # Reshape predictions to match the grid
    Z = preds.reshape(xx.shape)

    # Plot on an explicit figure/axes pair instead of pyplot's implicit state
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        ax.contourf(xx, yy, Z, alpha=0.8, cmap=plt.cm.coolwarm)
        ax.scatter(X[:, 0], X[:, 1], c=y, edgecolor='k', cmap=plt.cm.coolwarm)
        ax.set_title("Decision Boundary")
        ax.set_xlabel("Feature 1")
        ax.set_ylabel("Feature 2")

        # Save to file
        fig.savefig(filename, format='pdf')
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)

In [29]:
# Slice the whole training set once; elements 0 and 1 of the result are passed
# to the plotter as the features and the labels.
train_samples = ds_train[:]
plot_decision_boundary(
    model,
    train_samples[0],
    train_samples[1],
    "small_mlp_playground.pdf",
)

  xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.01),
  np.arange(y_min, y_max, 0.01))


In [30]:
# Print the module's repr to show the layer structure of the network trained above.
print(model)

MLP(
  (hidden_layers): ModuleList(
    (0): Linear(in_features=2, out_features=32, bias=True)
    (1): Linear(in_features=32, out_features=16, bias=True)
  )
  (output_layer): Linear(in_features=16, out_features=2, bias=True)
)
