In [1]:
# default_exp batchbald

In [2]:
# hide

import blackhc.project.script
from nbdev.showdoc import *

Appended /home/blackhc/PycharmProjects/bald-ical/src to paths
Switched to directory /home/blackhc/PycharmProjects/bald-ical
%load_ext autoreload
%autoreload 2


# BatchBALD Algorithm
> Greedy algorithm and score computation

First, we will implement two helper functions to compute conditional entropies $H[y_i|w]$ and entropies $H[y_i]$.
Then, we will implement BatchBALD and BALD.

In [3]:
# exports
import math
from dataclasses import dataclass
from typing import List

import torch
from toma import toma
from tqdm.auto import tqdm

from batchbald_redux import joint_entropy

We are going to define a couple of sampled distributions to use for testing our code.

$K=20$ means 20 inference samples.

In [4]:
# Number of inference samples (K) used by the synthetic test distributions below.
K: int = 20

In [5]:
import numpy as np


def get_mixture_prob_dist(p1, p2, m):
    """Return the convex combination ``(1 - m) * p1 + m * p2`` as a numpy array.

    ``m = 0`` yields ``p1`` and ``m = 1`` yields ``p2``; intermediate values
    interpolate linearly between the two distributions.
    """
    weight_first = 1.0 - m
    first, second = np.asarray(p1), np.asarray(p2)
    return weight_first * first + m * second


# Four synthetic pool points over 4 classes. For each point, the K inference
# samples interpolate (m evenly spaced in [0, 1]) from a distribution with 0.7
# on one class (p1) towards a flatter distribution (p2). The i-th point puts its
# 0.7 peak on the i-th class, so all points are class-permutations of each other.
p1 = [0.7, 0.1, 0.1, 0.1]
p2 = [0.3, 0.3, 0.2, 0.2]
y1_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]

p1 = [0.1, 0.7, 0.1, 0.1]
p2 = [0.2, 0.3, 0.3, 0.2]
y2_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]

p1 = [0.1, 0.1, 0.7, 0.1]
p2 = [0.2, 0.2, 0.3, 0.3]
y3_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]

p1 = [0.1, 0.1, 0.1, 0.7]
p2 = [0.3, 0.2, 0.2, 0.3]
y4_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]


def nested_to_tensor(l):
    """Convert every element of ``l`` to a tensor and stack them along a new first dimension."""
    tensors = [torch.as_tensor(element) for element in l]
    return torch.stack(tensors)


# Stack the four pool points into a single (N=4, K, C=4) tensor.
ys_ws = nested_to_tensor((y1_ws, y2_ws, y3_ws, y4_ws))

In [6]:
# hide

# Uniform distribution over 4 classes, identical for all K samples and for all
# 4 pool points; both its entropy and conditional entropy equal log(4).
p = [0.25] * 4
yu_ws = [p] * K
yus_ws = nested_to_tensor([yu_ws for _ in range(4)])

In [7]:
# (N pool points, K inference samples, C classes)
ys_ws.size()

torch.Size([4, 20, 4])

## Conditional Entropies and Batched Entropies


In [8]:
# exports


def compute_conditional_entropy(probs_N_K_C: torch.Tensor) -> torch.Tensor:
    """Compute the expected conditional entropy H[y_n|w] for every pool point.

    Args:
        probs_N_K_C: class probabilities of shape (N, K, C) for N points,
            K inference samples and C classes.

    Returns:
        A double tensor of shape (N,) holding
        ``-1/K * sum_{k,c} p[n, k, c] * log p[n, k, c]`` (in nats), using the
        convention ``0 * log 0 = 0``.
    """
    N, K, C = probs_N_K_C.shape

    entropies_N = torch.empty(N, dtype=torch.double)

    # Using the progress bar as a context manager closes it even if a chunk raises.
    with tqdm(total=N, desc="Conditional Entropy", leave=False) as pbar:

        # The decorator runs `compute` right away on row chunks [start, end) of the
        # input, starting with chunks of 1024 rows (toma may use smaller chunks).
        @toma.execute.chunked(probs_N_K_C, 1024)
        def compute(probs_n_K_C, start: int, end: int):
            nats_n_K_C = probs_n_K_C * torch.log(probs_n_K_C)
            # 0 * log(0) evaluates to NaN; by convention it contributes 0.
            nats_n_K_C[probs_n_K_C == 0] = 0.0

            entropies_N[start:end].copy_(-torch.sum(nats_n_K_C, dim=(1, 2)) / K)
            pbar.update(end - start)

    return entropies_N


def compute_entropy(probs_N_K_C: torch.Tensor) -> torch.Tensor:
    """Compute the entropy H[y_n] of the mean predictive distribution for every pool point.

    Args:
        probs_N_K_C: class probabilities of shape (N, K, C) for N points,
            K inference samples and C classes.

    Returns:
        A double tensor of shape (N,) holding ``-sum_c pbar[n, c] * log pbar[n, c]``
        (in nats), where ``pbar`` is the mean over the K samples, using the
        convention ``0 * log 0 = 0``.
    """
    N, K, C = probs_N_K_C.shape

    entropies_N = torch.empty(N, dtype=torch.double)

    # Using the progress bar as a context manager closes it even if a chunk raises.
    with tqdm(total=N, desc="Entropy", leave=False) as pbar:

        # The decorator runs `compute` right away on row chunks [start, end) of the
        # input, starting with chunks of 1024 rows (toma may use smaller chunks).
        @toma.execute.chunked(probs_N_K_C, 1024)
        def compute(probs_n_K_C, start: int, end: int):
            mean_probs_n_C = probs_n_K_C.mean(dim=1)
            nats_n_C = mean_probs_n_C * torch.log(mean_probs_n_C)
            # 0 * log(0) evaluates to NaN; by convention it contributes 0.
            nats_n_C[mean_probs_n_C == 0] = 0.0

            entropies_N[start:end].copy_(-torch.sum(nats_n_C, dim=1))
            pbar.update(end - start)

    return entropies_N

In [9]:
# Make sure everything is computed correctly: for a uniform distribution over
# 4 classes, both the conditional entropy and the entropy equal log(4) ≈ 1.3863 nats.
# The tolerance is tight enough to catch real mistakes (0.1 would not).
assert np.allclose(
    compute_conditional_entropy(yus_ws), [1.3863, 1.3863, 1.3863, 1.3863], atol=1e-3
)
assert np.allclose(compute_entropy(yus_ws), [1.3863, 1.3863, 1.3863, 1.3863], atol=1e-3)

HBox(children=(HTML(value='Conditional Entropy'), FloatProgress(value=0.0, max=4.0), HTML(value='')))

HBox(children=(HTML(value='Entropy'), FloatProgress(value=0.0, max=4.0), HTML(value='')))

In [10]:
# exporti
# Not publishing these at the moment.


def compute_conditional_entropy_from_logits(logits_N_K_C: torch.Tensor) -> torch.Tensor:
    """Compute the expected conditional entropy H[y_n|w] from log-probabilities.

    Args:
        logits_N_K_C: log class probabilities of shape (N, K, C). Despite the
            name, these are expected to be normalized log-probabilities (the
            notebook's check passes ``probs.log()``), not raw logits.

    Returns:
        A double tensor of shape (N,) holding
        ``-1/K * sum_{k,c} exp(l[n, k, c]) * l[n, k, c]`` (in nats), treating
        entries with ``l = -inf`` (probability 0) as contributing 0.
    """
    N, K, C = logits_N_K_C.shape

    entropies_N = torch.empty(N, dtype=torch.double)

    # Using the progress bar as a context manager closes it even if a chunk raises.
    with tqdm(total=N, desc="Conditional Entropy", leave=False) as pbar:

        # The decorator runs `compute` right away on row chunks [start, end) of the
        # input, starting with chunks of 1024 rows (toma may use smaller chunks).
        @toma.execute.chunked(logits_N_K_C, 1024)
        def compute(logits_n_K_C, start: int, end: int):
            nats_n_K_C = logits_n_K_C * torch.exp(logits_n_K_C)
            # log p = -inf gives -inf * 0 = NaN; like 0 * log 0 it contributes 0.
            nats_n_K_C[logits_n_K_C == -float("inf")] = 0.0

            entropies_N[start:end].copy_(-torch.sum(nats_n_K_C, dim=(1, 2)) / K)
            pbar.update(end - start)

    return entropies_N


def compute_entropy_from_logits(logits_N_K_C: torch.Tensor) -> torch.Tensor:
    """Compute the entropy H[y_n] of the mean predictive distribution from log-probabilities.

    Args:
        logits_N_K_C: log class probabilities of shape (N, K, C). Despite the
            name, these are expected to be normalized log-probabilities (the
            notebook's check passes ``probs.log()``), not raw logits.

    Returns:
        A double tensor of shape (N,) holding ``-sum_c exp(m[n, c]) * m[n, c]``
        (in nats), where ``m[n, c] = logsumexp_k(l[n, k, c]) - log K`` is the log
        of the mean probability. Classes with ``m = -inf`` contribute 0.
    """
    N, K, C = logits_N_K_C.shape

    entropies_N = torch.empty(N, dtype=torch.double)

    # Using the progress bar as a context manager closes it even if a chunk raises.
    with tqdm(total=N, desc="Entropy", leave=False) as pbar:

        # The decorator runs `compute` right away on row chunks [start, end) of the
        # input, starting with chunks of 1024 rows (toma may use smaller chunks).
        @toma.execute.chunked(logits_N_K_C, 1024)
        def compute(logits_n_K_C, start: int, end: int):
            # Log of the mean over the K samples, computed in log space.
            mean_logits_n_C = torch.logsumexp(logits_n_K_C, dim=1) - math.log(K)
            nats_n_C = mean_logits_n_C * torch.exp(mean_logits_n_C)
            # A class with probability 0 in every sample has log-mean -inf, and
            # -inf * 0 = NaN; by convention its contribution is 0.
            nats_n_C[mean_logits_n_C == -float("inf")] = 0.0

            entropies_N[start:end].copy_(-torch.sum(nats_n_C, dim=1))
            pbar.update(end - start)

    return entropies_N

In [11]:
# hide

# Make sure everything is computed correctly: the log-space versions must also
# give log(4) ≈ 1.3863 nats for the uniform distribution over 4 classes.
assert np.allclose(
    compute_conditional_entropy_from_logits(yus_ws.log()),
    [1.3863, 1.3863, 1.3863, 1.3863],
    atol=1e-3,
)
assert np.allclose(
    compute_entropy_from_logits(yus_ws.log()),
    [1.3863, 1.3863, 1.3863, 1.3863],
    atol=1e-3,
)

HBox(children=(HTML(value='Conditional Entropy'), FloatProgress(value=0.0, max=4.0), HTML(value='')))

HBox(children=(HTML(value='Entropy'), FloatProgress(value=0.0, max=4.0), HTML(value='')))

### Examples

In [12]:
# The four pool points are class-permutations of each other, so all four
# conditional entropies agree.
conditional_entropies = compute_conditional_entropy(ys_ws)

print(conditional_entropies)

assert np.allclose(conditional_entropies, [1.2069, 1.2069, 1.2069, 1.2069], atol=1e-3)

HBox(children=(HTML(value='Conditional Entropy'), FloatProgress(value=0.0, max=4.0), HTML(value='')))

tensor([1.2069, 1.2069, 1.2069, 1.2069], dtype=torch.float64)


In [13]:
# The four pool points are class-permutations of each other, so all four
# entropies agree.
entropies = compute_entropy(ys_ws)

print(entropies)

assert np.allclose(entropies, [1.2376, 1.2376, 1.2376, 1.2376], atol=1e-3)

HBox(children=(HTML(value='Entropy'), FloatProgress(value=0.0, max=4.0), HTML(value='')))

tensor([1.2376, 1.2376, 1.2376, 1.2376], dtype=torch.float64)


## BatchBALD

To compute BatchBALD exactly for a candidate batch, we'd have to compute $I[(y_b)_B;w] = H[(y_b)_B] - H[(y_b)_B|w]$.

As the $y_b$ are independent given $w$, we can simplify $H[(y_b)_B|w] = \sum_b H[y_b|w]$.

Furthermore, we use a greedy algorithm to build up the candidate batch, so $y_1,\dots,y_{B-1}$ will stay fixed as we determine $y_{B}$. We compute
$H[(y_b)_{B-1}, y_i] - H[y_i|w]$ for each pool element $y_i$ and add the highest scorer as $y_{B}$.

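We don't use this last simplification in the implementation: we also subtract the conditional entropies $\sum_{b<B} H[y_b|w]$ of the already-selected points, so that the reported scores are the actual BatchBALD values $I[(y_b)_{B-1}, y_i; w]$ rather than values shifted by a constant.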


### In the Paper

![BatchBALD algorithm in the paper](batchbald_algorithm.png)


### Implementation

In [14]:
# exports


@dataclass
class CandidateBatch:
    """Result of a batch acquisition.

    The two lists are parallel: ``scores[i]`` is the acquisition score
    recorded for pool index ``indices[i]``.
    """

    # Acquisition score recorded for each selected point.
    scores: List[float]
    # Pool indices of the selected points.
    indices: List[int]

In [22]:
# exports


class BatchBALDScorer:
    """Incrementally scores pool points for greedy BatchBALD acquisition.

    Keeps a joint-entropy tracker over the points appended so far together with
    the running sum of their conditional entropies, so that each greedy step
    needs a single ``compute_scores`` call.
    """

    # Class probabilities of the pool, shape (N, K, C).
    probs_N_K_C: torch.Tensor
    # H[y_n|w] for every pool point, shape (N,).
    conditional_entropies_N: torch.Tensor
    # Joint-entropy tracker over the points appended so far.
    batch_joint_entropy: joint_entropy.JointEntropy
    # Sum of H[y_b|w] over the points appended so far (0-d tensor).
    batch_conditional_entropies: torch.Tensor

    def __init__(
        self, probs_N_K_C, max_size, num_samples: int, dtype=None, device=None
    ):
        """
        Args:
            probs_N_K_C: class probabilities of shape (N, K, C).
            max_size: forwarded to ``joint_entropy.DynamicJointEntropy``; the
                functions in this notebook pass ``batch_size - 1``, the number of
                points that will be appended via ``append_to_batch``.
            num_samples: forwarded to ``joint_entropy.DynamicJointEntropy``.
            dtype: forwarded to ``joint_entropy.DynamicJointEntropy``.
            device: forwarded to ``joint_entropy.DynamicJointEntropy``.
        """
        _, K, C = probs_N_K_C.shape
        self.probs_N_K_C = probs_N_K_C

        self.conditional_entropies_N = compute_conditional_entropy(self.probs_N_K_C)
        # An explicit 0-d tensor (matching the annotation) with the same dtype and
        # device as the per-point conditional entropies it accumulates.
        self.batch_conditional_entropies = torch.zeros(
            (),
            dtype=self.conditional_entropies_N.dtype,
            device=self.conditional_entropies_N.device,
        )

        self.batch_joint_entropy = joint_entropy.DynamicJointEntropy(
            num_samples, max_size, K, C, dtype=dtype, device=device
        )

    def append_to_batch(self, index: int):
        """Add pool point ``index`` to the batch tracked by this scorer."""
        self.batch_joint_entropy.add_variables(self.probs_N_K_C[index : index + 1])
        # In-place add into our own zero-initialized tensor: no aliasing with
        # conditional_entropies_N, so no clone is needed.
        self.batch_conditional_entropies += self.conditional_entropies_N[index]

    def alloc_scores(self) -> torch.Tensor:
        """Allocate an uninitialized double score buffer of shape (N,) on the CPU."""
        # We always keep these on the CPU; pinned memory when CUDA is available.
        scores_N = torch.empty(
            self.probs_N_K_C.shape[0],
            dtype=torch.double,
            pin_memory=torch.cuda.is_available(),
        )
        return scores_N

    def compute_scores(self, out_scores_N: torch.Tensor):
        """Fill ``out_scores_N`` in place with the score of adding each pool point.

        The score is the joint entropy computed by ``batch_joint_entropy`` for
        the current batch plus each candidate, minus the candidate's conditional
        entropy and the summed conditional entropies of the batch.

        Returns:
            ``out_scores_N`` itself.
        """
        self.batch_joint_entropy.compute_batch(
            self.probs_N_K_C, output_entropies_B=out_scores_N
        )

        out_scores_N -= self.conditional_entropies_N + self.batch_conditional_entropies

        return out_scores_N


def get_batchbald_batch(
    probs_N_K_C: torch.Tensor,
    batch_size: int,
    num_samples: int,
    dtype=None,
    device=None,
) -> CandidateBatch:
    """Greedily select a batch of pool points with high BatchBALD scores.

    Args:
        probs_N_K_C: class probabilities of shape (N, K, C).
        batch_size: number of points to select; clipped to N. Values <= 0
            yield an empty batch.
        num_samples: forwarded to the joint-entropy tracker.
        dtype: forwarded to the joint-entropy tracker.
        device: forwarded to the joint-entropy tracker.

    Returns:
        A ``CandidateBatch`` whose ``scores[i]`` is the score of the batch formed
        by the first ``i + 1`` selected indices (joint entropy minus the sum of
        their conditional entropies).
    """
    N, _, _ = probs_N_K_C.shape

    batch_size = min(batch_size, N)

    candidate_indices = []
    candidate_scores = []

    # Nothing to select. Also guards against a negative batch_size, which would
    # otherwise construct a scorer with a negative max_size.
    if batch_size <= 0:
        return CandidateBatch(candidate_scores, candidate_indices)

    batchbald_scorer = BatchBALDScorer(
        probs_N_K_C,
        max_size=batch_size - 1,
        num_samples=num_samples,
        dtype=dtype,
        device=device,
    )

    # We always keep these on the CPU.
    scores_N = batchbald_scorer.alloc_scores()

    for i in tqdm(range(batch_size), desc="BatchBALD", leave=False):
        if i > 0:
            # Condition on the point chosen in the previous iteration.
            latest_index = candidate_indices[-1]
            batchbald_scorer.append_to_batch(latest_index)

        batchbald_scorer.compute_scores(scores_N)
        # Never pick the same point twice.
        scores_N[candidate_indices] = -float("inf")

        candidate_score, candidate_index = scores_N.max(dim=0)

        candidate_indices.append(candidate_index.item())
        candidate_scores.append(candidate_score.item())

    return CandidateBatch(candidate_scores, candidate_indices)


def get_batchbald_batch_plain(
    probs_N_K_C: torch.Tensor,
    batch_size: int,
    num_samples: int,
    dtype=None,
    device=None,
) -> CandidateBatch:
    """Same greedy BatchBALD selection as ``get_batchbald_batch``, written inline.

    Does not use ``BatchBALDScorer``; instead it re-sums the conditional
    entropies of the selected points at every step.

    Args:
        probs_N_K_C: class probabilities of shape (N, K, C).
        batch_size: number of points to select; clipped to N. Values <= 0
            yield an empty batch.
        num_samples: forwarded to ``joint_entropy.DynamicJointEntropy``.
        dtype: forwarded to ``joint_entropy.DynamicJointEntropy``.
        device: forwarded to ``joint_entropy.DynamicJointEntropy``.

    Returns:
        A ``CandidateBatch`` of selected indices and their scores.
    """
    N, K, C = probs_N_K_C.shape

    batch_size = min(batch_size, N)

    candidate_indices = []
    candidate_scores = []

    # Nothing to select. Also guards against a negative batch_size, which would
    # otherwise construct a joint-entropy tracker with a negative max size.
    if batch_size <= 0:
        return CandidateBatch(candidate_scores, candidate_indices)

    conditional_entropies_N = compute_conditional_entropy(probs_N_K_C)

    batch_joint_entropy = joint_entropy.DynamicJointEntropy(
        num_samples, batch_size - 1, K, C, dtype=dtype, device=device
    )

    # We always keep these on the CPU.
    scores_N = torch.empty(N, dtype=torch.double, pin_memory=torch.cuda.is_available())

    for i in tqdm(range(batch_size), desc="BatchBALD", leave=False):
        if i > 0:
            # Condition on the point chosen in the previous iteration.
            latest_index = candidate_indices[-1]
            batch_joint_entropy.add_variables(
                probs_N_K_C[latest_index : latest_index + 1]
            )

        # Sum of H[y_b|w] over the points selected so far (0 for an empty batch).
        shared_conditional_entropies = conditional_entropies_N[candidate_indices].sum()

        batch_joint_entropy.compute_batch(probs_N_K_C, output_entropies_B=scores_N)

        scores_N -= conditional_entropies_N + shared_conditional_entropies
        # Never pick the same point twice.
        scores_N[candidate_indices] = -float("inf")

        candidate_score, candidate_index = scores_N.max(dim=0)

        candidate_indices.append(candidate_index.item())
        candidate_scores.append(candidate_score.item())

    return CandidateBatch(candidate_scores, candidate_indices)

### Example

In [23]:
get_batchbald_batch(ys_ws.double(), batch_size=4, num_samples=1000, dtype=torch.double)

HBox(children=(HTML(value='Conditional Entropy'), FloatProgress(value=0.0, max=4.0), HTML(value='')))

HBox(children=(HTML(value='BatchBALD'), FloatProgress(value=0.0, max=4.0), HTML(value='')))

HBox(children=(HTML(value='ExactJointEntropy.compute_batch'), FloatProgress(value=0.0, max=4.0), HTML(value=''…

HBox(children=(HTML(value='ExactJointEntropy.compute_batch'), FloatProgress(value=0.0, max=4.0), HTML(value=''…

HBox(children=(HTML(value='ExactJointEntropy.compute_batch'), FloatProgress(value=0.0, max=4.0), HTML(value=''…

HBox(children=(HTML(value='ExactJointEntropy.compute_batch'), FloatProgress(value=0.0, max=4.0), HTML(value=''…

CandidateBatch(scores=[0.030715639666234917, 0.05961958627158248, 0.08691070514744625, 0.11275304532467789], indices=[1, 0, 2, 3])

In [21]:
get_batchbald_batch_plain(ys_ws.double(), batch_size=4, num_samples=1000, dtype=torch.double)

HBox(children=(HTML(value='Conditional Entropy'), FloatProgress(value=0.0, max=4.0), HTML(value='')))

HBox(children=(HTML(value='BatchBALD'), FloatProgress(value=0.0, max=4.0), HTML(value='')))

HBox(children=(HTML(value='ExactJointEntropy.compute_batch'), FloatProgress(value=0.0, max=4.0), HTML(value=''…

HBox(children=(HTML(value='ExactJointEntropy.compute_batch'), FloatProgress(value=0.0, max=4.0), HTML(value=''…

HBox(children=(HTML(value='ExactJointEntropy.compute_batch'), FloatProgress(value=0.0, max=4.0), HTML(value=''…

HBox(children=(HTML(value='ExactJointEntropy.compute_batch'), FloatProgress(value=0.0, max=4.0), HTML(value=''…

CandidateBatch(scores=[0.030715639666234917, 0.05961958627158248, 0.08691070514744625, 0.11275304532467789], indices=[1, 0, 2, 3])

## BALD

BALD is the same as BatchBALD, except that we evaluate points individually, by computing $I[y_i;w]$ for each, and then take the top $B$ scorers.

### BALD scores

Sometimes, we want to obtain BALD scores for all samples as a measure of epistemic uncertainty.

In [16]:
# exports


def get_bald_scores(probs_N_K_C: torch.Tensor, dtype=None, device=None) -> torch.Tensor:
    """Compute the BALD score I[y_n; w] = H[y_n] - H[y_n|w] for every pool point.

    Args:
        probs_N_K_C: class probabilities of shape (N, K, C).
        dtype: accepted but currently unused.
        device: accepted but currently unused.

    Returns:
        A double tensor of shape (N,) with one score per pool point.
    """
    # Unpacking enforces a rank-3 input; the sizes themselves are not needed.
    _N, _K, _C = probs_N_K_C.shape

    conditional_entropies_N = compute_conditional_entropy(probs_N_K_C)
    entropies_N = compute_entropy(probs_N_K_C)

    return entropies_N - conditional_entropies_N

Given the scores, determining a BALD batch is straightforward:

### Finding a BALD batch 

In [17]:
# exports


def get_bald_batch(
    probs_N_K_C: torch.Tensor, batch_size: int, dtype=None, device=None
) -> CandidateBatch:
    """Select the ``batch_size`` pool points with the highest individual BALD scores.

    Args:
        probs_N_K_C: class probabilities of shape (N, K, C).
        batch_size: number of points to select; clipped to N. Values <= 0
            yield an empty batch (``torch.topk`` would reject a negative k).
        dtype: forwarded to ``get_bald_scores``.
        device: forwarded to ``get_bald_scores``.

    Returns:
        A ``CandidateBatch`` with scores in descending order and their indices.
    """
    N, _, _ = probs_N_K_C.shape

    batch_size = min(batch_size, N)

    if batch_size <= 0:
        return CandidateBatch([], [])

    scores_N = get_bald_scores(probs_N_K_C, dtype=dtype, device=device)

    candidate_scores, candidate_indices = torch.topk(scores_N, batch_size)

    return CandidateBatch(candidate_scores.tolist(), candidate_indices.tolist())

### Example

In [18]:
get_bald_batch(ys_ws.double(), batch_size=4)

HBox(children=(HTML(value='Conditional Entropy'), FloatProgress(value=0.0, max=4.0), HTML(value='')))

HBox(children=(HTML(value='Entropy'), FloatProgress(value=0.0, max=4.0), HTML(value='')))

CandidateBatch(scores=[0.030715639666234917, 0.030715639666234695, 0.030715639666234695, 0.030715639666234695], indices=[3, 2, 0, 1])

## BALD-ICAL

The computation for BALD-ICAL is simple: we keep track of two separate (Batch)BALD terms, one computed from `training_probs_N_K_C` and one from `pool_probs_N_K_C`, and take their difference:

$$\mathrm{I}\left[(y)_{B} ; \omega \mid(x)_{B}, D_{T}\right]-\mathrm{I}\left[(y)_{B} ; \omega \mid(x)_{B}, D_{U} \cup D_{T}\right].$$


In [29]:
# exports


def get_batchbaldical_batch(
    training_probs_N_K_C: torch.Tensor,
    pool_probs_N_K_C: torch.Tensor,
    batch_size: int,
    num_samples: int,
    dtype=None,
    device=None,
) -> CandidateBatch:
    """Greedily select a batch maximizing the difference of two BatchBALD scores.

    At each step the score of a candidate is its BatchBALD score under
    ``training_probs_N_K_C`` minus its BatchBALD score under
    ``pool_probs_N_K_C``.

    Args:
        training_probs_N_K_C: class probabilities of shape (N, K, C) for the
            first (subtracted-from) term.
        pool_probs_N_K_C: class probabilities of the same shape for the second
            (subtracted) term.
        batch_size: number of points to select; clipped to N. Values <= 0
            yield an empty batch.
        num_samples: forwarded to both ``BatchBALDScorer`` instances.
        dtype: forwarded to both ``BatchBALDScorer`` instances.
        device: forwarded to both ``BatchBALDScorer`` instances.

    Returns:
        A ``CandidateBatch`` of selected indices and their difference scores.

    Raises:
        AssertionError: if the two probability tensors differ in shape.
    """
    assert training_probs_N_K_C.shape == pool_probs_N_K_C.shape
    N, _, _ = training_probs_N_K_C.shape

    batch_size = min(batch_size, N)

    candidate_indices = []
    candidate_scores = []

    # Nothing to select. Also guards against a negative batch_size, which would
    # otherwise construct scorers with a negative max_size.
    if batch_size <= 0:
        return CandidateBatch(candidate_scores, candidate_indices)

    training_batchbald_scorer = BatchBALDScorer(
        training_probs_N_K_C,
        max_size=batch_size - 1,
        num_samples=num_samples,
        dtype=dtype,
        device=device,
    )

    pool_batchbald_scorer = BatchBALDScorer(
        pool_probs_N_K_C,
        max_size=batch_size - 1,
        num_samples=num_samples,
        dtype=dtype,
        device=device,
    )

    # We always keep these on the CPU.
    training_scores_N = training_batchbald_scorer.alloc_scores()
    pool_scores_N = pool_batchbald_scorer.alloc_scores()

    for i in tqdm(range(batch_size), desc="BatchBALD", leave=False):
        if i > 0:
            # Both scorers condition on the same previously chosen point.
            latest_index = candidate_indices[-1]
            training_batchbald_scorer.append_to_batch(latest_index)
            pool_batchbald_scorer.append_to_batch(latest_index)

        training_batchbald_scorer.compute_scores(training_scores_N)
        pool_batchbald_scorer.compute_scores(pool_scores_N)

        scores_N = training_scores_N - pool_scores_N
        # Never pick the same point twice.
        scores_N[candidate_indices] = -float("inf")

        candidate_score, candidate_index = scores_N.max(dim=0)

        candidate_indices.append(candidate_index.item())
        candidate_scores.append(candidate_score.item())

    return CandidateBatch(candidate_scores, candidate_indices)

### Example

#### Pleasing example of the case when predictions match (full overlap)

In [32]:
get_batchbaldical_batch(ys_ws.double(), ys_ws.double(), batch_size=4, num_samples=1000, dtype=torch.double)

HBox(children=(HTML(value='Conditional Entropy'), FloatProgress(value=0.0, max=4.0), HTML(value='')))

HBox(children=(HTML(value='Conditional Entropy'), FloatProgress(value=0.0, max=4.0), HTML(value='')))

HBox(children=(HTML(value='BatchBALD'), FloatProgress(value=0.0, max=4.0), HTML(value='')))

HBox(children=(HTML(value='ExactJointEntropy.compute_batch'), FloatProgress(value=0.0, max=4.0), HTML(value=''…

HBox(children=(HTML(value='ExactJointEntropy.compute_batch'), FloatProgress(value=0.0, max=4.0), HTML(value=''…

HBox(children=(HTML(value='ExactJointEntropy.compute_batch'), FloatProgress(value=0.0, max=4.0), HTML(value=''…

HBox(children=(HTML(value='ExactJointEntropy.compute_batch'), FloatProgress(value=0.0, max=4.0), HTML(value=''…

HBox(children=(HTML(value='ExactJointEntropy.compute_batch'), FloatProgress(value=0.0, max=4.0), HTML(value=''…

HBox(children=(HTML(value='ExactJointEntropy.compute_batch'), FloatProgress(value=0.0, max=4.0), HTML(value=''…

HBox(children=(HTML(value='ExactJointEntropy.compute_batch'), FloatProgress(value=0.0, max=4.0), HTML(value=''…

HBox(children=(HTML(value='ExactJointEntropy.compute_batch'), FloatProgress(value=0.0, max=4.0), HTML(value=''…

CandidateBatch(scores=[0.0, 0.0, 0.0, 0.0], indices=[0, 1, 2, 3])

In [34]:
get_batchbaldical_batch(
    ys_ws.double(),
    torch.ones_like(ys_ws).double(),
    batch_size=4,
    num_samples=1000,
    dtype=torch.double,
)

HBox(children=(HTML(value='Conditional Entropy'), FloatProgress(value=0.0, max=4.0), HTML(value='')))

HBox(children=(HTML(value='Conditional Entropy'), FloatProgress(value=0.0, max=4.0), HTML(value='')))

HBox(children=(HTML(value='BatchBALD'), FloatProgress(value=0.0, max=4.0), HTML(value='')))

HBox(children=(HTML(value='ExactJointEntropy.compute_batch'), FloatProgress(value=0.0, max=4.0), HTML(value=''…

HBox(children=(HTML(value='ExactJointEntropy.compute_batch'), FloatProgress(value=0.0, max=4.0), HTML(value=''…

HBox(children=(HTML(value='ExactJointEntropy.compute_batch'), FloatProgress(value=0.0, max=4.0), HTML(value=''…

HBox(children=(HTML(value='ExactJointEntropy.compute_batch'), FloatProgress(value=0.0, max=4.0), HTML(value=''…

HBox(children=(HTML(value='ExactJointEntropy.compute_batch'), FloatProgress(value=0.0, max=4.0), HTML(value=''…

HBox(children=(HTML(value='ExactJointEntropy.compute_batch'), FloatProgress(value=0.0, max=4.0), HTML(value=''…

HBox(children=(HTML(value='ExactJointEntropy.compute_batch'), FloatProgress(value=0.0, max=4.0), HTML(value=''…

HBox(children=(HTML(value='ExactJointEntropy.compute_batch'), FloatProgress(value=0.0, max=4.0), HTML(value=''…

CandidateBatch(scores=[0.030715639666234917, 0.05961958627158248, 0.08691070514744625, 0.11275304532467789], indices=[1, 0, 2, 3])