In [1]:
import torch
import laplace
import matplotlib.pyplot as plt
import numpy as np

from batchbald_redux import repeated_mnist
from main.models import ConvNet
from main.training_models import train_model
from main.bald_sampling import compute_entropy, compute_conditional_entropy

%reload_ext autoreload
%autoreload 2

In [28]:
# Load MNIST and build the loaders used throughout the notebook.
SEED = 0
N_TRAIN = 80       # number of training examples drawn by the sampler
N_VAL = 1000       # number of validation examples drawn by the sampler
BATCH_SIZE = 32

# SubsetRandomSampler shuffles with torch's global RNG, so seed it here; otherwise
# the training order and the held-out batch drawn below differ on every run.
torch.manual_seed(SEED)

train_dataset, val_dataset = repeated_mnist.create_MNIST_dataset()

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    sampler=torch.utils.data.SubsetRandomSampler(range(N_TRAIN)),
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    sampler=torch.utils.data.SubsetRandomSampler(range(N_VAL)),
)

In [29]:
# Use a single validation batch (batch_size=32) as a small held-out evaluation set.
test_batch = next(iter(val_loader))
x_test, y_test = test_batch

In [36]:
# Train a fresh ConvNet on the small training pool.
model = train_model(ConvNet(), train_loader, num_epochs=20)

# Rough out-of-sample check: accuracy on the single held-out batch.
model.eval()
with torch.no_grad():
    y_pred = model(x_test)
    predicted_labels = y_pred.argmax(dim=1)
    accuracy = (predicted_labels == y_test).float().mean()
    print('Accuracy:', accuracy.item())

Epoch 1/20, Loss: 2.223040819168091
Epoch 2/20, Loss: 1.9746544361114502
Epoch 3/20, Loss: 1.6898173093795776
Epoch 4/20, Loss: 1.3834806680679321
Epoch 5/20, Loss: 0.9185072779655457
Epoch 6/20, Loss: 0.694076657295227
Epoch 7/20, Loss: 0.503821611404419
Epoch 8/20, Loss: 0.13865363597869873
Epoch 9/20, Loss: 0.251486599445343
Epoch 10/20, Loss: 0.14383897185325623
Epoch 11/20, Loss: 0.033987682312726974
Epoch 12/20, Loss: 0.08825220167636871
Epoch 13/20, Loss: 0.02064281888306141
Epoch 14/20, Loss: 0.014900978654623032
Epoch 15/20, Loss: 0.020352903753519058
Epoch 16/20, Loss: 0.005083861760795116
Epoch 17/20, Loss: 0.0029607515316456556
Epoch 18/20, Loss: 0.0059080119244754314
Epoch 19/20, Loss: 0.0046522533521056175
Epoch 20/20, Loss: 0.0045122671872377396
Accuracy: 0.65625


In [42]:
# Last-layer Laplace approximation with a Kronecker-factored Hessian
# (temperature=1e-3); the prior precision is then tuned by marginal-likelihood
# optimisation using the linearised (GLM) predictive with the probit link.
laplace_config = {
    'likelihood': 'classification',
    'subset_of_weights': 'last_layer',
    'hessian_structure': 'kron',
    'temperature': 1e-3,
}
la = laplace.Laplace(model, **laplace_config)
la.fit(train_loader)
la.optimize_prior_precision(
    pred_type='glm',
    method='marglik',
    link_approx='probit',
    verbose=True,
)

Optimized prior precision is tensor([11.0047]).




## Testing BALD

### Computing entropy (first term in BALD)

In [48]:
from main.bald_sampling import compute_entropy, compute_entropy_weights

In [49]:
# Entropy of the predictive distribution (the first BALD term) at each test
# point, computed by two helpers from main.bald_sampling.
entropy_kwargs = dict(la_model=la, data=x_test)
ent = compute_entropy(**entropy_kwargs)
ent_w = compute_entropy_weights(**entropy_kwargs)

In [50]:
# Entropy from compute_entropy: one value per example in x_test.
ent

tensor([1.6447, 1.5470, 1.7109, 1.3268, 1.3466, 1.6111, 1.5681, 1.6632, 1.3571,
        1.6855, 1.4234, 1.7456, 0.7215, 1.9799, 1.7025, 1.6693, 1.3543, 1.6446,
        1.7003, 1.4787, 1.9724, 1.7405, 1.8122, 1.4522, 1.3558, 0.7002, 1.4526,
        1.5812, 1.3970, 1.5918, 1.6652, 1.2756])

In [46]:
# Entropy from compute_entropy_weights for the same test batch, for comparison with `ent`.
ent_w

tensor([2.0600, 2.1252, 2.0115, 2.1458, 2.0465, 2.1216, 2.0717, 2.0842, 2.0722,
        2.1786, 2.0694, 2.1631, 1.9006, 2.1552, 2.0533, 1.9263, 2.1923, 2.0608,
        1.9933, 2.1623, 2.0291, 2.0901, 2.1564, 2.0800, 1.8979, 1.8932, 2.0041,
        2.1570, 2.1095, 2.0825, 2.1016, 2.0035])

### Computing conditional entropy

In [277]:
# Overwrite the parameters of a model's last torch.nn.Linear layer from one flat vector.
def set_last_linear_layer_combined(model, new_weights_and_bias):
    """Overwrite the weight (and bias, if present) of the last ``nn.Linear`` in ``model``.

    "Last" means last in ``model.modules()`` order, i.e. registration order,
    which is not necessarily the order in which layers run in ``forward``.

    Args:
        model: module containing at least one ``torch.nn.Linear``.
        new_weights_and_bias: 1-D tensor holding the row-major flattened weight
            (``out_features * in_features`` values) followed by the bias
            (``out_features`` values; omitted when the layer has no bias).
            Values are copied, so the layer never aliases this tensor, and they
            are cast to the layer's own dtype/device.

    Returns:
        The ``torch.nn.Linear`` layer that was updated.

    Raises:
        ValueError: if ``model`` has no linear layer, or if the vector does not
            have the expected shape.
    """
    # Find the last linear layer
    last_linear_layer = None
    for module in model.modules():
        if isinstance(module, torch.nn.Linear):
            last_linear_layer = module

    if last_linear_layer is None:
        raise ValueError("No linear layer found in the model")

    out_features, in_features = last_linear_layer.weight.shape
    n_weight = out_features * in_features
    has_bias = last_linear_layer.bias is not None

    # Check that the flat vector has exactly one entry per parameter
    expected_shape = (n_weight + out_features,) if has_bias else (n_weight,)
    if tuple(new_weights_and_bias.shape) != expected_shape:
        raise ValueError(f"Input tensor shape {new_weights_and_bias.shape} doesn't match the expected shape {expected_shape}")

    # copy_ writes into the existing parameter storage and casts dtype/device.
    # Rebinding `.data` to a slice view would instead make the layer share
    # memory with the caller's tensor, e.g. a row of posterior samples.
    with torch.no_grad():
        last_linear_layer.weight.copy_(new_weights_and_bias[:n_weight].reshape(out_features, in_features))
        if has_bias:
            last_linear_layer.bias.copy_(new_weights_and_bias[n_weight:])

    return last_linear_layer

In [278]:
# Conditional-entropy term of BALD, estimated by sampling in weight space:
#   1. draw last-layer weights w ~ p(w | D) from the Laplace posterior;
#   2. for each draw, load w into the network and evaluate a predictive
#      distribution p(y | x, w) at every x (the loop in the next cell uses the
#      Laplace GLM/probit predictive for this);
#   3. average the per-draw entropies H[p(y | x, w)] over all draws.
#
# Start from a freshly fitted last-layer Laplace approximation.
la = laplace.Laplace(
    model,
    likelihood='classification',
    subset_of_weights='last_layer',
    hessian_structure='kron',
    temperature=1e-3,
)
la.fit(train_loader)
la.optimize_prior_precision(
    pred_type='glm', method='marglik', link_approx='probit', verbose=True
)

Optimized prior precision is tensor([36.8790]).


In [279]:
# Sample from the posterior and estimate the conditional entropy E_w[H[y | x, w]]
# on the test batch. For each last-layer weight sample: load it into the network,
# refit the Laplace approximation around it, and record the entropy (in nats) of
# the resulting GLM/probit predictive at every test point.
n_posterior_samples = 50
posterior_weights = la.sample(n_samples=n_posterior_samples)
entropies = torch.zeros(posterior_weights.shape[0], x_test.shape[0])

# The loop overwrites la's MAP last-layer weights and refits `la` around each
# sample. Snapshot the weights and prior precision so both can be restored
# afterwards, including after a KeyboardInterrupt. Without this, later cells
# using `la` would see a posterior sample instead of the MAP fit.
map_state = {name: t.detach().clone() for name, t in la.model.state_dict().items()}
map_prior_precision = la.prior_precision.clone()

try:
    for i, weights in enumerate(posterior_weights):
        # Set the weights in the model
        set_last_linear_layer_combined(la.model, weights)

        # Refit the approximation and re-optimise the prior precision around this sample
        la.fit(train_loader)
        la.optimize_prior_precision(pred_type='glm', method='marglik', link_approx='probit', verbose=True)

        # Compute the predictive distribution
        probs = la(x_test, pred_type='glm', link_approx='probit')

        # entr(p) = -p * ln(p), defined as 0 at p = 0, so zero-probability classes give no NaNs
        entropies[i] = torch.special.entr(probs).sum(dim=-1).detach()
finally:
    # Put back the MAP weights, refit around them, and restore the prior precision
    la.model.load_state_dict(map_state)
    la.fit(train_loader)
    la.prior_precision = map_prior_precision

Optimized prior precision is tensor([16.4651]).
Optimized prior precision is tensor([17.4157]).
Optimized prior precision is tensor([16.7425]).
Optimized prior precision is tensor([17.2752]).
Optimized prior precision is tensor([16.7267]).
Optimized prior precision is tensor([15.8028]).


KeyboardInterrupt: 

In [None]:
# BALD score = predictive entropy (`ent`) minus the conditional entropy averaged
# over the sampled weights. Negative Monte Carlo estimates are clipped to zero,
# and the print reports how many were clipped.
entropies_avg = entropies.mean(dim=0)

n_clipped = (ent < entropies_avg).sum().item()
print(f'BALD has {n_clipped} zeros out of {ent.shape[0]} samples.')

bald = (ent - entropies_avg).clamp(min=0)
bald

BALD has 0 zeros out of 32 samples.


tensor([1.4202, 1.6857, 0.9458, 1.6858, 1.4193, 1.3705, 1.0489, 1.5071, 0.8167,
        1.5586, 1.3837, 1.3683, 1.4766, 1.7605, 1.7392, 0.9636, 1.2336, 0.7664,
        0.8247, 0.9784, 1.7967, 1.4989, 1.0402, 0.9998, 1.7251, 1.8253, 1.4178,
        1.7722, 1.1391, 0.9367, 1.2260, 1.3115])

In [None]:
from main.bald_sampling import compute_bald

# Fit a fresh last-layer Laplace approximation, then score the test batch with
# compute_bald using 50 posterior samples.
# NOTE(review): the output shows one prior-precision optimisation per sample, so
# compute_bald appears to refit `la` internally and may leave it away from the
# MAP fit. Confirm before reusing `la` in later cells.
laplace_kwargs = dict(
    likelihood='classification',
    subset_of_weights='last_layer',
    hessian_structure='kron',
    temperature=1e-3,
)
la = laplace.Laplace(model, **laplace_kwargs)
la.fit(train_loader)
la.optimize_prior_precision(pred_type='glm', method='marglik', link_approx='probit', verbose=True)

bald = compute_bald(la, x_test, train_loader, n_samples=50)
bald

Optimized prior precision is tensor([29.4391]).
Optimized prior precision is tensor([14.3348]).
Optimized prior precision is tensor([14.2572]).
Optimized prior precision is tensor([14.9633]).
Optimized prior precision is tensor([14.3452]).
Optimized prior precision is tensor([15.4623]).
Optimized prior precision is tensor([15.0644]).
Optimized prior precision is tensor([14.3005]).
Optimized prior precision is tensor([14.8923]).
Optimized prior precision is tensor([14.0761]).
Optimized prior precision is tensor([14.5439]).
Optimized prior precision is tensor([14.1053]).
Optimized prior precision is tensor([15.2824]).
Optimized prior precision is tensor([15.5675]).
Optimized prior precision is tensor([13.0978]).
Optimized prior precision is tensor([12.8305]).
Optimized prior precision is tensor([14.6344]).
Optimized prior precision is tensor([14.2011]).
Optimized prior precision is tensor([14.5757]).
Optimized prior precision is tensor([14.7716]).
Optimized prior precision is tensor([15.

tensor([1.4505, 1.7123, 0.9928, 1.6769, 1.3852, 1.3303, 1.0606, 1.4980, 0.7725,
        1.5536, 1.3738, 1.3918, 1.4747, 1.7254, 1.7494, 0.9781, 1.1987, 0.7730,
        0.8604, 1.0272, 1.7938, 1.4872, 1.1301, 1.0425, 1.7671, 1.8071, 1.4647,
        1.7551, 1.2098, 0.9509, 1.2345, 1.3743])

The scale of these values looks plausible compared with the BALD scores produced by batchbald_redux.
This estimate samples in parameter (weight) space. Next, try sampling in function space (f-space) to compute the *conditional entropy*, analogous to how Houlsby et al. describe it. If that yields the same values, it could be much quicker.

In [None]:
def compute_conditional_entropy(la_model, data, train_loader, refit=True, n_samples=50):
    """Monte Carlo estimate of the conditional-entropy term E_w[H[y | x, w]] of BALD.

    Draws ``n_samples`` last-layer weight vectors from the Laplace posterior and
    loads each one into the network. For each draw it computes the entropy (in
    nats) of the Laplace GLM/probit predictive at every input in ``data``, then
    averages these entropies over the draws.

    ``la_model`` is left as it was on entry: its original last-layer weights are
    restored afterwards, even if the loop is interrupted. When ``refit`` is True,
    the approximation is also refit around those weights and its prior precision
    is reset.

    Note: this redefines (shadows) the ``compute_conditional_entropy`` imported
    from ``main.bald_sampling`` in the first cell.

    Args:
        la_model: fitted last-layer ``laplace`` approximation.
        data: batch of inputs to score.
        train_loader: loader used to refit ``la_model`` around each sample.
        refit: if True, refit the approximation and re-optimise the prior
            precision around every weight sample.
        n_samples: number of posterior weight samples (must be >= 1).

    Returns:
        Tensor with one averaged entropy per input in ``data``.

    Raises:
        ValueError: if ``n_samples`` < 1. Without this check, averaging over an
            empty set of samples would silently return NaNs.
    """
    if n_samples < 1:
        raise ValueError("n_samples must be at least 1")

    posterior_weights = la_model.sample(n_samples=n_samples)

    # Snapshot what the loop overwrites. Values are cloned because parameters
    # may be updated in place.
    map_state = {name: t.detach().clone() for name, t in la_model.model.state_dict().items()}
    map_prior_precision = la_model.prior_precision.clone()

    per_sample_entropies = []
    try:
        for weights in posterior_weights:
            # Set the sampled weights in the model
            set_last_linear_layer_combined(la_model.model, weights)

            if refit:
                # Refit the approximation and re-optimise the prior precision around this sample
                la_model.fit(train_loader)
                la_model.optimize_prior_precision(pred_type='glm', method='marglik', link_approx='probit', verbose=False)

            # Compute the predictive distribution
            probs = la_model(data, pred_type='glm', link_approx='probit')

            # entr(p) = -p * ln(p), defined as 0 at p = 0, so zero-probability classes give no NaNs
            per_sample_entropies.append(torch.special.entr(probs).sum(dim=-1).detach())
    finally:
        # Restore the caller's approximation so later uses of la_model see the original fit
        la_model.model.load_state_dict(map_state)
        if refit:
            la_model.fit(train_loader)
            la_model.prior_precision = map_prior_precision

    return torch.stack(per_sample_entropies).mean(dim=0)

In [None]:
# Predictive entropy of the current `la` on the test batch.
# NOTE(review): these values sit close to ln(10) ≈ 2.30 (near-uniform over the 10
# MNIST classes) and well above `ent` computed earlier. The compute_bald call above
# may have left `la` fitted around a posterior weight sample rather than the MAP
# weights. Confirm before comparing.
e = compute_entropy(la_model=la, data=x_test)
e

tensor([2.2419, 2.2292, 2.2521, 2.2292, 2.2303, 2.2296, 2.2470, 2.2322, 2.2734,
        2.2292, 2.2525, 2.2308, 2.2298, 2.2304, 2.2292, 2.2655, 2.2714, 2.2305,
        2.2593, 2.2741, 2.2292, 2.2292, 2.2609, 2.2346, 2.2292, 2.2292, 2.2292,
        2.2294, 2.2766, 2.2786, 2.2330, 2.2293])

In [None]:
# Weight-space estimate of the conditional entropy: 10 posterior samples,
# refitting the Laplace approximation around each one.
ce = compute_conditional_entropy(
    la_model=la,
    data=x_test,
    train_loader=train_loader,
    refit=True,
    n_samples=10,
)
ce



tensor([2.2745, 2.2632, 2.2592, 2.2210, 2.2492, 2.2516, 2.2389, 2.2652, 2.2517,
        2.2552, 2.2774, 2.2649, 2.2410, 2.2399, 2.2166, 2.2865, 2.2378, 2.2657,
        2.2585, 2.2596, 2.2477, 2.2480, 2.2557, 2.2832, 2.2004, 2.2211, 2.2495,
        2.2393, 2.2542, 2.2793, 2.2512, 2.2730])