In [3]:
import torch
import laplace
import numpy as np
import torch.utils
import matplotlib.pyplot as plt
from main.models import ConvNet
from main.training_models import train_model
from batchbald_redux import repeated_mnist, joint_entropy, batchbald
from laplace.marglik_training import marglik_training
from laplace.curvature import AsdlGGN

%reload_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Load MNIST and build loaders restricted to indices 0..999 of each split.
train_dataset, val_dataset = repeated_mnist.create_MNIST_dataset()

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=32,
    sampler=torch.utils.data.SubsetRandomSampler(range(1000)),
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=32,
    sampler=torch.utils.data.SubsetRandomSampler(range(1000)),
)

# Collapse the training subset into one tensor for inputs and one for labels.
# SubsetRandomSampler draws a new permutation every time the loader is iterated,
# so both tensors must come from the SAME pass over the loader; two separate
# passes would leave x[i] and y[i] pointing at different examples.
x, y = (torch.cat(parts, dim=0) for parts in zip(*train_loader))

# A single (randomly ordered) validation batch of up to 32 examples.
x_test, y_test = next(iter(val_loader))

In [5]:
# Instantiate the ConvNet classifier.
# NOTE(review): only plain training happens in this cell; `marglik_training` is
# imported at the top of the notebook but is not called here.
model = ConvNet()

# Train on `train_loader` for 20 epochs. This call is the cell's last expression,
# so whatever `train_model` returns is displayed as the cell output.
train_model(model, train_loader, num_epochs=20)

Epoch 1/20, Loss: 0.5403611660003662
Epoch 6/20, Loss: 0.046635936945676804
Epoch 11/20, Loss: 0.01454768143594265
Epoch 16/20, Loss: 0.003275579772889614


ConvNet(
  (features): Sequential(
    (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1))
    (1): GELU(approximate='none')
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1))
    (4): GELU(approximate='none')
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Flatten(start_dim=1, end_dim=-1)
  )
  (classifier): Sequential(
    (0): Linear(in_features=512, out_features=32, bias=True)
    (1): GELU(approximate='none')
    (2): Linear(in_features=32, out_features=10, bias=True)
  )
  (_last_layer): Linear(in_features=32, out_features=10, bias=True)
)

In [6]:
# Measure top-1 accuracy of `model` over everything `val_loader` yields.
correct = 0
total = 0

with torch.no_grad():  # evaluation only, no gradients needed
    for images, labels in val_loader:
        # Predicted class = index of the largest output along the class dimension.
        predicted = model(images).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 1000 test images: %d %%' % (100 * correct / total))

Accuracy of the network on the 1000 test images: 93 %


In [7]:
# Wrap the trained classifier in a last-layer Laplace approximation with a
# Kronecker-structured Hessian computed through the ASDL GGN backend.
la = laplace.Laplace(
    model,
    likelihood='classification',
    subset_of_weights='last_layer',
    hessian_structure='kron',
    backend=AsdlGGN,
)

# Fit the approximation on the training loader, then tune the prior precision
# with the 'marglik' method (verbose=True prints the optimized value).
la.fit(train_loader=train_loader)
la.optimize_prior_precision(method='marglik', verbose=True)



Optimized prior precision is tensor([6.2196]).




In [26]:
# Concatenate the inputs of every validation batch and keep the first 32 rows.
x_test = torch.cat([images for images, _ in val_loader], dim=0)[:32]

In [44]:
# Run the project's `max_joint_eig` routine on the Laplace-wrapped model over
# the 32 validation inputs held in `x_test`.
from main.bald_sampling import max_joint_eig
K = 500  # forwarded as the `K` argument; presumably a sample count — TODO confirm in main.bald_sampling

# NOTE(review): the recorded run of this cell ends in KeyboardInterrupt, so it
# appears to be long-running at K=500 / batch_size=10.
max_joint_eig(model=la, data=x_test, K=K, batch_size=10)

det_y: tensor(-40.9118) det_theta: tensor(-796.5087) det_joint: tensor(-857.2563) eig: tensor(9.9179)
eig: tensor(9.9179)
det_y: tensor(-39.0422) det_theta: tensor(-792.7565) det_joint: tensor(-851.6059) eig: tensor(9.9036)
eig: tensor(9.9036)
det_y: tensor(-40.4950) det_theta: tensor(-794.9068) det_joint: tensor(-854.1402) eig: tensor(9.3692)
eig: tensor(9.3692)
det_y: tensor(-40.9354) det_theta: tensor(-794.5879) det_joint: tensor(-855.5502) eig: tensor(10.0134)
eig: tensor(10.0134)
det_y: tensor(-39.6517) det_theta: tensor(-799.7205) det_joint: tensor(-859.0887) eig: tensor(9.8582)
eig: tensor(9.8582)
det_y: tensor(-40.6233) det_theta: tensor(-792.5306) det_joint: tensor(-852.1053) eig: tensor(9.4757)
eig: tensor(9.4757)
det_y: tensor(-39.8192) det_theta: tensor(-793.7623) det_joint: tensor(-853.3255) eig: tensor(9.8720)
eig: tensor(9.8720)
det_y: tensor(-85.9473) det_theta: tensor(-794.5605) det_joint: tensor(-917.5234) eig: tensor(18.5078)
eig: tensor(18.5078)
det_y: tensor(-84.67

KeyboardInterrupt: 

In [201]:
import cProfile
import pstats


def profile_wrapper():
    """Invoke max_joint_eig with a reduced workload (K=100, batch_size=3) for profiling."""
    max_joint_eig(model=la, data=x_test, K=100, batch_size=3)


# Profile the wrapper and dump the raw profiler data to 'output.prof'.
cProfile.run('profile_wrapper()', 'output.prof')

# Render the profile, ordered by cumulative time, into a plain-text report.
with open('output_stats.txt', 'w') as stream:
    stats = pstats.Stats('output.prof', stream=stream)
    stats.sort_stats('cumulative')
    stats.print_stats()

selected: [20]
selected: [20, 1]
selected: [20, 1, 23]
