In [None]:
# hide

import blackhc.project.script
from nbdev.showdoc import *

Appended /home/blackhc/PycharmProjects/bald-ical/src to paths
Switched to directory /home/blackhc/PycharmProjects/bald-ical
%load_ext autoreload
%autoreload 2


# Acquisition Function: SieveBALD
> Greedy algorithm and score computation

First, we will implement two helper classes to compute conditional entropies $H[y_i|w]$ and entropies $H[y_i]$. 
Then, we will implement BatchBALD and BALD.

In [None]:
import math

import numpy as np
import torch
from blackhc.progress_bar import create_progress_bar
from toma import toma

from batchbald_redux.acquisition_functions.sievebald import * 
from batchbald_redux.acquisition_functions.batchbald import * 
from batchbald_redux.joint_entropy import *

We are going to define a couple of sampled distributions to use for testing our code.

$K=20$ means 20 inference samples.

In [None]:
# Number of inference samples per candidate point used by the test distributions below.
K = 20

In [None]:
def get_mixture_prob_dist(p1, p2, m):
    """Linearly interpolate between two probability vectors.

    Args:
        p1: array-like probabilities returned unchanged when `m == 0`.
        p2: array-like probabilities returned unchanged when `m == 1`.
        m: mixing weight applied to `p2`; `p1` receives weight `1 - m`.

    Returns:
        A NumPy array holding `(1 - m) * p1 + m * p2`.
    """
    first = np.asarray(p1)
    second = np.asarray(p2)
    weight_first = 1.0 - m
    return weight_first * first + m * second


# One mixing weight per inference sample, sweeping from the peaked
# distribution (weight 0) to the flatter one (weight 1).
mixture_weights = np.linspace(0, 1, K)

# Candidate 1: probability mass concentrated on class 0.
p1, p2 = [0.7, 0.1, 0.1, 0.1], [0.3, 0.3, 0.2, 0.2]
y1_ws = [get_mixture_prob_dist(p1, p2, weight) for weight in mixture_weights]

# Candidate 2: probability mass concentrated on class 1.
p1, p2 = [0.1, 0.7, 0.1, 0.1], [0.2, 0.3, 0.3, 0.2]
y2_ws = [get_mixture_prob_dist(p1, p2, weight) for weight in mixture_weights]

# Candidate 3: probability mass concentrated on class 2.
p1, p2 = [0.1, 0.1, 0.7, 0.1], [0.2, 0.2, 0.3, 0.3]
y3_ws = [get_mixture_prob_dist(p1, p2, weight) for weight in mixture_weights]

# Candidate 4: probability mass concentrated on class 3.
p1, p2 = [0.1, 0.1, 0.1, 0.7], [0.3, 0.2, 0.2, 0.3]
y4_ws = [get_mixture_prob_dist(p1, p2, weight) for weight in mixture_weights]


def nested_to_tensor(l):
    """Convert each element of `l` to a tensor and stack them along a new first dimension.

    Args:
        l: iterable whose elements `torch.as_tensor` can convert; the resulting
            tensors must share one shape so that `torch.stack` accepts them.

    Returns:
        A tensor with one leading entry per element of `l`.
    """
    converted = [torch.as_tensor(item) for item in l]
    return torch.stack(converted)


# Stack the four sampled distributions into a single tensor of shape
# (4 candidate points, K inference samples, 4 class probabilities).
ys_ws = nested_to_tensor([y1_ws, y2_ws, y3_ws, y4_ws])

In [None]:
# hide

# Uniform reference case: every inference sample of every candidate predicts
# the same flat distribution over the four classes.
p = [0.25] * 4
yu_ws = [p] * K
yus_ws = nested_to_tensor([yu_ws for _ in range(4)])

In [None]:
# Expect (4 candidate points, K = 20 inference samples, 4 classes).
ys_ws.shape

torch.Size([4, 20, 4])

However, our neural networks usually use a `log_softmax` as their final layer. To avoid having to call `.exp_()`, which is easy to miss and annoying to debug, we will instead use versions of these functions that take `log_probs` instead of `probs`.

In [None]:
# hide

# Sanity check on the uniform case: a flat distribution over four classes has
# entropy ln(4) ~= 1.3863, and both helpers must report it for every candidate.
expected_uniform_entropies = [1.3863] * 4
assert np.allclose(compute_conditional_entropy(yus_ws.log()), expected_uniform_entropies, atol=0.1)
assert np.allclose(compute_entropy(yus_ws.log()), expected_uniform_entropies, atol=0.1)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

### Examples

In [None]:
# Conditional entropies H[y_i | w] for the four sampled distributions.
conditional_entropies = compute_conditional_entropy(ys_ws.log())
print(conditional_entropies)

# The four candidates are class-permutations of one another, so they share a
# single expected value (the reference number corresponds to K = 20 mixture weights).
expected_conditional_entropies = [1.2069] * 4
assert np.allclose(conditional_entropies, expected_conditional_entropies, atol=0.01)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

tensor([1.2093, 1.2093, 1.2093, 1.2093], dtype=torch.float64)


In [None]:
# Entropies H[y_i] for the four sampled distributions.
entropies = compute_entropy(ys_ws.log())
print(entropies)

# The four candidates are class-permutations of one another, so they share a
# single expected value.
expected_entropies = [1.2376] * 4
assert np.allclose(entropies, expected_entropies, atol=0.01)

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

tensor([1.2376, 1.2376, 1.2376, 1.2376], dtype=torch.float64)


## SieveBALD

This is the 2-BALD approximation (leaving out $D_{train}$):
$$I[Y_1, \ldots, Y_n;\Omega \mid x_1, \ldots,x_n] \approx \sum_i I[Y_i;\Omega\mid x_i] - \sum_{i<j} I[Y_i;Y_j \mid x_i,x_j].$$

See also https://www.notion.so/SieveBALD-using-a-marginal-total-correlation-assumption-and-or-by-forcing-it-2e4a9548d4124b6bb8e0dcbba789887a.

In [None]:
%%time
get_sieve_bald_batch(np.repeat(ys_ws, 10, axis=0).log().double(), batch_size=5, dtype=torch.double)

Entropy:   0%|          | 0/40 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/40 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/40 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/40 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/40 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/40 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/40 [00:00<?, ?it/s]

CPU times: user 1.79 s, sys: 67 ms, total: 1.85 s
Wall time: 1.46 s


CandidateBatch(scores=[0.030715639666234917, 0.05961958627158248, 0.08671183981604269, 0.11199240029961555, 0.13546126772230105], indices=[0, 1, 2, 3, 4])

In [None]:
# Confirm the base tensor still has 4 candidates, K = 20 samples and 4 classes
# (the acquisition calls operate on a repeated copy, not on ys_ws itself).
ys_ws.shape

torch.Size([4, 20, 4])

In [None]:
%%time
get_batch_bald_batch(np.repeat(ys_ws, 10, axis=0).log().double(), batch_size=5, num_samples=1000000, dtype=torch.double)

Conditional Entropy:   0%|          | 0/40 [00:00<?, ?it/s]

BatchBALD:   0%|          | 0/5 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/40 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/40 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/40 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/40 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/40 [00:00<?, ?it/s]

CPU times: user 1.56 s, sys: 64.3 ms, total: 1.63 s
Wall time: 1.23 s


CandidateBatch(scores=[0.030715639666234917, 0.05961958627158248, 0.0869107051474467, 0.11275304532467878, 0.1372853331853925], indices=[10, 0, 1, 20, 2])