In [240]:
import torch
import laplace
import numpy as np
import torch.utils
import matplotlib.pyplot as plt
from main.models import ConvNet, BayesianConvNet
from main.training_models import train_model
from main.utils import entropy
from batchbald_redux import repeated_mnist, joint_entropy, batchbald
from laplace.utils.utils import normal_samples

%reload_ext autoreload
%autoreload 2

In [139]:
# Load MNIST and build train/validation loaders; each one draws batches of 32
# from the first 1000 indices of its dataset, visited in random order.
train_dataset, val_dataset = repeated_mnist.create_MNIST_dataset()

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=32,
    sampler=torch.utils.data.SubsetRandomSampler(range(1000)),
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=32,
    sampler=torch.utils.data.SubsetRandomSampler(range(1000)),
)

In [140]:
model = train_model(ConvNet(), train_loader, num_epochs=25, lr=1e-3)

# Evaluate on the validation subset. Switch to eval mode first so that any
# dropout / batch-norm layers in ConvNet (if present) use inference behaviour
# instead of training behaviour when computing the accuracy.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in val_loader:
        outputs = model(images)
        # Index of the largest logit per row is the predicted class.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the validation images: %d %%' % (100 * correct / total))

Epoch 1/25, Loss: 1.3662784099578857
Epoch 2/25, Loss: 1.1333781480789185
Epoch 3/25, Loss: 0.5565322041511536
Epoch 4/25, Loss: 0.211321160197258
Epoch 5/25, Loss: 0.1780632883310318
Epoch 6/25, Loss: 0.12224686145782471
Epoch 7/25, Loss: 0.10107780992984772
Epoch 8/25, Loss: 0.08403026312589645
Epoch 9/25, Loss: 0.45377659797668457
Epoch 10/25, Loss: 0.04758601635694504
Epoch 11/25, Loss: 0.039059266448020935
Epoch 12/25, Loss: 0.15554024279117584
Epoch 13/25, Loss: 0.006701083853840828
Epoch 14/25, Loss: 0.011465385556221008
Epoch 15/25, Loss: 0.04479203000664711
Epoch 16/25, Loss: 0.0017860461957752705
Epoch 17/25, Loss: 0.0054418547078967094
Epoch 18/25, Loss: 0.0011644844198599458
Epoch 19/25, Loss: 0.0010600830428302288
Epoch 20/25, Loss: 0.002040612045675516
Epoch 21/25, Loss: 0.006248588673770428
Epoch 22/25, Loss: 0.019717438146471977
Epoch 23/25, Loss: 0.004651464056223631
Epoch 24/25, Loss: 0.0004772499087266624
Epoch 25/25, Loss: 0.000592651660554111
Accuracy of the networ

In [219]:
# Collapse each loader into full (inputs, labels) tensors.
# Each loader is iterated exactly ONCE and its (x, y) pairs are unzipped:
# SubsetRandomSampler reshuffles on every pass, so building the inputs and the
# labels from two separate passes would pair images with the wrong labels.
x, y = (torch.cat(parts, dim=0) for parts in zip(*train_loader))

# First 50 validation examples (inputs and labels stay aligned).
x_test, y_test = (torch.cat(parts, dim=0)[:50] for parts in zip(*val_loader))

In [214]:
# Laplace approximation over all weights of the trained network, using a
# Kronecker-factored ('kron') Hessian structure.
la = laplace.Laplace(
    model,
    likelihood="classification",
    subset_of_weights='all',
    hessian_structure='kron',
    enable_backprop=True,
)

print('Fitting the Laplace approximation...')
la.fit(train_loader, progress_bar=True)

# Tune the prior precision by optimising the marginal likelihood.
print('Optimizing the prior precision...')
la.optimize_prior_precision(method='marglik')

Fitting the Laplace approximation...


[Computing Hessian]: 100%|██████████| 32/32 [00:03<00:00, 10.51it/s]


Optimizing the prior precision...


In [215]:
# GLM predictive with the probit link approximation (class probabilities for
# x_test), plus 1000 Monte-Carlo draws from the Laplace predictive.
predicted = la(x_test, link_approx='probit', pred_type="glm")
samples = la.predictive_samples(x_test, n_samples=1000)

In [216]:
# Swap the first two axes of the predictive samples; presumably
# (n_samples, N, C) -> (N, n_samples, C), putting the data axis first.
p_samples = torch.swapaxes(samples, 0, 1)

# Clamp before taking the log: an over-confident softmax can underflow to exactly
# 0 in float32, and log(0) = -inf would turn p * log(p) entropy terms into NaN.
# Clamping to the smallest positive normal value keeps every log finite.
log_probs = torch.log(p_samples.clamp_min(torch.finfo(p_samples.dtype).tiny))

In [244]:
n_samples = 1000

# Gaussian over the network outputs f from the GLM predictive:
# f_mu has shape (N, C), f_var has shape (N, C, C).
f_mu, f_var = la._glm_predictive_distribution(x_test)

# Draw n_samples Gaussian samples of f, giving f_samples of shape (n_samples, N, C).
# A softmax over the last axis would turn them into probabilities as in la.predictive_samples.
f_samples = normal_samples(f_mu, f_var, n_samples)

In [259]:
# Arrange the per-input (C x C) output covariances f_var[i] along the diagonal
# of one (N*C x N*C) block-diagonal matrix.
N = x_test.shape[0]
C = f_var.shape[-1]  # number of outputs, read from f_var instead of hard-coding 10

f_var_block = torch.block_diag(*f_var)

# f_var_block is a symmetric covariance matrix, so use the symmetric eigen-solver:
# it always returns real eigenvalues (np.linalg.eigvals may return a complex
# array) and is cheaper. detach()/cpu() make the tensor convertible to numpy.
eigvals = np.linalg.eigvalsh(f_var_block.detach().cpu().numpy())
nonzeros = np.sum(eigvals > 1e-6)
print(f'Number of non-zero eigenvalues: {nonzeros} out of {N*C}')

Number of non-zero eigenvalues: 500 out of 500


## Functional (co)variance using Jacobian

In [68]:
def compute_jacobian(x, model, la):
    """Per-example Jacobian of the model outputs with respect to its parameters.

    Each input ``x[i]`` is run through the model as a batch of one. The gradient
    of every output ``output[0, c]`` is then taken separately with respect to all
    trainable parameters. The flattened gradients are concatenated in
    ``model.parameters()`` order, and only the last ``la.n_params`` entries are
    kept. When all weights are covered (``subset_of_weights='all'``), this keeps
    everything.

    Args:
        x: batch of inputs indexed along dim 0; each ``x[i]`` is fed to the
            model as ``x[i].unsqueeze(0)``.
        model: module whose single-example output is indexed as ``output[0, c]``.
        la: object exposing ``n_params``, the number of trailing parameter
            entries to keep per gradient.

    Returns:
        Detached tensor of shape ``(x.size(0), C, la.n_params)``, where ``C`` is
        the number of model outputs inferred from the first forward pass.
        For an empty ``x`` the result has shape ``(0, 0, la.n_params)``.
    """
    num_params = la.n_params
    # Only parameters that participate in autograd can receive a gradient.
    params = [p for p in model.parameters() if p.requires_grad]

    J = None
    for i in range(x.size(0)):
        output = model(x[i].unsqueeze(0))
        if J is None:
            J = torch.zeros(x.size(0), output.shape[-1], num_params)

        for c in range(J.size(1)):
            # torch.autograd.grad returns fresh gradients for this single output
            # and, unlike .backward(), does not accumulate into p.grad. Each row
            # J[i, c] is therefore the gradient of output c alone, not the sum
            # over outputs 0..c.
            grads = torch.autograd.grad(
                output[0, c], params, retain_graph=True, allow_unused=True
            )
            flat = torch.cat([
                (torch.zeros_like(p) if g is None else g).flatten()
                for p, g in zip(params, grads)
            ])
            J[i, c, :] = flat[-num_params:]

    if J is None:  # empty batch: nothing to differentiate
        J = torch.zeros(0, 0, num_params)
    return J.detach()

# Jacobian of every training-set output w.r.t. the parameters kept by `la`.
J = compute_jacobian(x, model, la)

# Share of entries that are exactly zero (shown as the cell's output).
J.eq(0).sum() / J.numel()