In [None]:
# default_exp batchbald

In [None]:
# hide

import blackhc.project.script
from nbdev.showdoc import *

Appended /home/blackhc/PycharmProjects/bald-ical/src to paths
Switched to directory /home/blackhc/PycharmProjects/bald-ical
%load_ext autoreload
%autoreload 2


# BatchBALD Algorithm
> Greedy algorithm and score computation

First, we will implement two helper classes to compute conditional entropies $H[y_i|w]$ and entropies $H[y_i]$. 
Then, we will implement BatchBALD and BALD.

In [None]:
# exports
import math
from dataclasses import dataclass
from typing import List

import numpy as np
import torch
from toma import toma
from blackhc.progress_bar import with_progress_bar, create_progress_bar

from batchbald_redux import joint_entropy

We are going to define a couple of sampled distributions to use for testing our code.

$K=20$ means 20 inference samples.

In [None]:
K = 20

In [None]:
def get_mixture_prob_dist(p1, p2, m):
    """Blend two distributions: weight ``1 - m`` on ``p1`` and weight ``m`` on ``p2``.

    Both inputs are converted with ``np.asarray``, so plain lists are accepted;
    the result is a NumPy array.
    """
    first = np.asarray(p1)
    second = np.asarray(p2)
    weight_first = 1.0 - m
    return weight_first * first + m * second


# Test fixtures: four variables y1..y4 over 4 classes. For each variable we
# build K sampled distributions by sweeping the mixing weight m linearly from
# 0 (pure p1) to 1 (pure p2). Each variable's p1 puts mass 0.7 on a different class.
# Note: p1 and p2 are rebound for every variable, so only the last pair
# remains bound after this cell has run.
p1 = [0.7, 0.1, 0.1, 0.1]
p2 = [0.3, 0.3, 0.2, 0.2]
y1_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]

p1 = [0.1, 0.7, 0.1, 0.1]
p2 = [0.2, 0.3, 0.3, 0.2]
y2_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]

p1 = [0.1, 0.1, 0.7, 0.1]
p2 = [0.2, 0.2, 0.3, 0.3]
y3_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]

p1 = [0.1, 0.1, 0.1, 0.7]
p2 = [0.3, 0.2, 0.2, 0.3]
y4_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]


def nested_to_tensor(l):
    """Convert every element of ``l`` to a tensor and stack them along a new first dimension."""
    converted = [torch.as_tensor(element) for element in l]
    return torch.stack(converted)


ys_ws = nested_to_tensor([y1_ws, y2_ws, y3_ws, y4_ws])

In [None]:
# hide

p = [0.25, 0.25, 0.25, 0.25]
yu_ws = [p for m in range(K)]
yus_ws = nested_to_tensor([yu_ws] * 4)

In [None]:
ys_ws.shape

torch.Size([4, 20, 4])

## Conditional Entropies and Batched Entropies

To start with, we write two functions to compute the conditional entropy $H[y_i|w]$ and the entropy $H[y_i]$ for each input sample.

In [None]:
def compute_conditional_entropy(probs_N_K_C: torch.Tensor) -> torch.Tensor:
    """Estimate the conditional entropy $H[y_i|w]$ of each input from sampled probabilities.

    Args:
        probs_N_K_C: Probabilities indexed as (input, inference sample, class).

    Returns:
        A double tensor with one entry per input: the entropy of each sampled
        distribution, averaged over the K inference samples.
    """
    num_inputs, num_samples, _ = probs_N_K_C.shape

    progress = create_progress_bar(num_inputs, tqdm_args=dict(desc="Conditional Entropy", leave=False))
    result_N = torch.empty(num_inputs, dtype=torch.double)
    progress.start()

    # `compute` is run by toma on chunks of the input; `start`/`end` are the
    # offsets of the chunk within the first dimension.
    @toma.execute.chunked(probs_N_K_C, 1024)
    def compute(probs_n_K_C, start: int, end: int):
        # p * log(p) evaluates to NaN at p == 0; apply the convention 0 * log(0) = 0.
        plogp_n_K_C = (probs_n_K_C * torch.log(probs_n_K_C)).masked_fill(probs_n_K_C == 0, 0.0)

        result_N[start:end] = -torch.sum(plogp_n_K_C, dim=(1, 2)) / num_samples
        progress.update(end - start)

    progress.finish()

    return result_N


def compute_entropy(probs_N_K_C: torch.Tensor) -> torch.Tensor:
    """Estimate the predictive entropy $H[y_i]$ of each input from sampled probabilities.

    Args:
        probs_N_K_C: Probabilities indexed as (input, inference sample, class).

    Returns:
        A double tensor with one entry per input: the entropy of the class
        distribution obtained by averaging over the K inference samples.
    """
    num_inputs = probs_N_K_C.shape[0]
    # Unpack for its side effect only: it rejects inputs that are not 3-dimensional.
    _, _, _ = probs_N_K_C.shape

    progress = create_progress_bar(num_inputs, tqdm_args=dict(desc="Entropy", leave=False))
    result_N = torch.empty(num_inputs, dtype=torch.double)
    progress.start()

    # `compute` is run by toma on chunks of the input; `start`/`end` are the
    # offsets of the chunk within the first dimension.
    @toma.execute.chunked(probs_N_K_C, 1024)
    def compute(probs_n_K_C, start: int, end: int):
        marginal_n_C = probs_n_K_C.mean(dim=1)
        # p * log(p) evaluates to NaN at p == 0; apply the convention 0 * log(0) = 0.
        plogp_n_C = (marginal_n_C * torch.log(marginal_n_C)).masked_fill(marginal_n_C == 0, 0.0)

        result_N[start:end] = -torch.sum(plogp_n_C, dim=1)
        progress.update(end - start)

    progress.finish()

    return result_N

In [None]:
# Make sure everything is computed correctly.

assert np.allclose(compute_conditional_entropy(yus_ws), [1.3863, 1.3863, 1.3863, 1.3863], atol=0.1)
assert np.allclose(compute_entropy(yus_ws), [1.3863, 1.3863, 1.3863, 1.3863], atol=0.1)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

However, our neural networks usually use a `log_softmax` as final layer. To avoid having to call `.exp_()`, which is easy to miss and annoying to debug, we will instead use a version that uses `log_probs` instead of `probs`.

In [None]:
# exports


def compute_conditional_entropy(log_probs_N_K_C: torch.Tensor) -> torch.Tensor:
    """Estimate the conditional entropy $H[y_i|w]$ of each input from sampled log-probabilities.

    Args:
        log_probs_N_K_C: Log-probabilities indexed as (input, inference sample, class),
            e.g. the output of a `log_softmax` layer.

    Returns:
        A double tensor with one entry per input: the entropy of each sampled
        distribution, averaged over the K inference samples.
    """
    N, K, C = log_probs_N_K_C.shape

    entropies_N = torch.empty(N, dtype=torch.double)

    pbar = create_progress_bar(N, tqdm_args=dict(desc="Conditional Entropy", leave=False))
    pbar.start()

    # `compute` is run by toma on chunks of the input; `start`/`end` are the
    # offsets of the chunk within the first dimension.
    @toma.execute.chunked(log_probs_N_K_C, 1024)
    def compute(log_probs_n_K_C, start: int, end: int):
        nats_n_K_C = log_probs_n_K_C * torch.exp(log_probs_n_K_C)
        # A class with probability zero has log-probability -inf, and
        # -inf * exp(-inf) = -inf * 0 is NaN. Use the convention 0 * log(0) = 0
        # so that such classes contribute nothing instead of poisoning the sum.
        nats_n_K_C[log_probs_n_K_C == float("-inf")] = 0.0

        entropies_N[start:end].copy_(-torch.sum(nats_n_K_C, dim=(1, 2)) / K)
        pbar.update(end - start)

    pbar.finish()

    return entropies_N


def compute_entropy(log_probs_N_K_C: torch.Tensor) -> torch.Tensor:
    """Estimate the predictive entropy $H[y_i]$ of each input from sampled log-probabilities.

    Args:
        log_probs_N_K_C: Log-probabilities indexed as (input, inference sample, class),
            e.g. the output of a `log_softmax` layer.

    Returns:
        A double tensor with one entry per input: the entropy of the class
        distribution obtained by averaging the probabilities over the K samples.
    """
    N, K, C = log_probs_N_K_C.shape

    entropies_N = torch.empty(N, dtype=torch.double)

    pbar = create_progress_bar(N, tqdm_args=dict(desc="Entropy", leave=False))
    pbar.start()

    # `compute` is run by toma on chunks of the input; `start`/`end` are the
    # offsets of the chunk within the first dimension.
    @toma.execute.chunked(log_probs_N_K_C, 1024)
    def compute(log_probs_n_K_C, start: int, end: int):
        # log of the mean probability over the K samples, computed in log-space.
        mean_log_probs_n_C = torch.logsumexp(log_probs_n_K_C, dim=1) - math.log(K)
        nats_n_C = mean_log_probs_n_C * torch.exp(mean_log_probs_n_C)
        # If every sample assigns probability zero to a class, its mean
        # log-probability is -inf and -inf * 0 is NaN. Use 0 * log(0) = 0.
        nats_n_C[mean_log_probs_n_C == float("-inf")] = 0.0

        entropies_N[start:end].copy_(-torch.sum(nats_n_C, dim=1))
        pbar.update(end - start)

    pbar.finish()

    return entropies_N

In [None]:
# hide

# Make sure everything is computed correctly.
assert np.allclose(compute_conditional_entropy(yus_ws.log()), [1.3863, 1.3863, 1.3863, 1.3863], atol=0.1)
assert np.allclose(compute_entropy(yus_ws.log()), [1.3863, 1.3863, 1.3863, 1.3863], atol=0.1)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

### Examples

In [None]:
conditional_entropies = compute_conditional_entropy(ys_ws.log())

print(conditional_entropies)

assert np.allclose(conditional_entropies, [1.2069, 1.2069, 1.2069, 1.2069], atol=0.01)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

tensor([1.2069, 1.2069, 1.2069, 1.2069], dtype=torch.float64)


In [None]:
entropies = compute_entropy(ys_ws.log())

print(entropies)

assert np.allclose(entropies, [1.2376, 1.2376, 1.2376, 1.2376], atol=0.01)

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

tensor([1.2376, 1.2376, 1.2376, 1.2376], dtype=torch.float64)


## BatchBALD

To compute BatchBALD exactly for a candidate batch, we'd have to compute $I[(y_b)_B;w] = H[(y_b)_B] - H[(y_b)_B|w]$.

As the $y_b$ are independent given $w$, we can simplify $H[(y_b)_B|w] = \sum_b H[y_b|w]$.

Furthermore, we use a greedy algorithm to build up the candidate batch, so $y_1,\dots,y_{B-1}$ will stay fixed as we determine $y_{B}$. We compute
$H[(y_b)_{B-1}, y_i] - H[y_i|w]$ for each pool element $y_i$ and add the highest scorer as $y_{B}$.

We don't utilize the last optimization here in order to compute the actual scores.


### In the Paper

![BatchBALD algorithm in the paper](batchbald_algorithm.png)


### Implementation

In [None]:
# exports


@dataclass
class CandidateBatch:
    """Result of an acquisition function: the selected indices and their scores.

    The two lists are parallel: ``scores[i]`` is the score recorded for the
    point at ``indices[i]``.
    """

    # Score recorded for each selected point (same order as `indices`).
    scores: List[float]
    # Positions of the selected points along the first dimension of the scored input.
    indices: List[int]

In [None]:
# exports


class BatchBALDScorer:
    """Scores every pool point with BatchBALD against an incrementally grown batch.

    The scorer keeps two pieces of state for the points appended so far:
    a joint-entropy tracker and the sum of their conditional entropies.
    `compute_scores` then evaluates, for each pool point $y_i$,
    $H[(y_b)_B, y_i] - H[y_i|w] - \\sum_b H[y_b|w]$.
    """

    # Log-probabilities of all pool points, indexed as (input, inference sample, class).
    log_probs_N_K_C: torch.Tensor
    # Conditional entropy H[y_i|w] of every pool point, computed once up front.
    conditional_entropies_N: torch.Tensor
    # Tracker for the joint entropy of the points appended so far.
    batch_joint_entropy: joint_entropy.JointEntropy
    # Running sum of the conditional entropies of the appended points. It starts
    # as the Python scalar 0 and becomes a tensor after the first append.
    batch_conditional_entropies: torch.Tensor

    def __init__(self, log_probs_N_K_C, *, max_size, num_samples: int, dtype=None, device=None):
        """Precompute the conditional entropies and set up an empty batch.

        Args:
            log_probs_N_K_C: Log-probabilities of the pool, (input, sample, class).
            max_size: Forwarded to `DynamicJointEntropy`; presumably the maximum
                number of points that will be appended — TODO confirm.
            num_samples: Forwarded to `DynamicJointEntropy`.
            dtype: Forwarded to `DynamicJointEntropy`.
            device: Forwarded to `DynamicJointEntropy`.
        """
        _, K, C = log_probs_N_K_C.shape
        self.log_probs_N_K_C = log_probs_N_K_C

        self.conditional_entropies_N = compute_conditional_entropy(self.log_probs_N_K_C)
        self.batch_conditional_entropies = 0

        self.batch_joint_entropy = joint_entropy.DynamicJointEntropy(
            num_samples, max_size, K, C, dtype=dtype, device=device
        )

    def append_to_batch(self, index: int):
        """Add the pool point at `index` to the tracked batch."""
        self.batch_joint_entropy.add_variables(self.log_probs_N_K_C[index : index + 1])
        # clone(): on the first append `0 += t` rebinds the attribute to `t` itself,
        # and later in-place additions would otherwise write into conditional_entropies_N.
        self.batch_conditional_entropies += self.conditional_entropies_N[index].clone()

    def alloc_scores(self) -> torch.Tensor:
        """Allocate an uninitialised double score buffer with one entry per pool point."""
        # We always keep these on the CPU.
        scores_N = torch.empty(
            self.log_probs_N_K_C.shape[0],
            dtype=torch.double,
            pin_memory=torch.cuda.is_available(),
        )
        return scores_N

    def compute_scores(self, out_scores_N: torch.Tensor):
        """Write the BatchBALD score of every pool point into `out_scores_N` and return it."""
        # First fill the buffer with the entropies reported by the joint-entropy tracker ...
        self.batch_joint_entropy.compute_batch(self.log_probs_N_K_C, output_entropies_B=out_scores_N)

        # ... then subtract the conditional entropies of the candidate and of the batch.
        out_scores_N -= self.conditional_entropies_N + self.batch_conditional_entropies

        return out_scores_N


def get_batch_bald_batch(
    log_probs_N_K_C: torch.Tensor, *, batch_size: int, num_samples: int, dtype=None, device=None
) -> CandidateBatch:
    """Greedily build a BatchBALD acquisition batch.

    At each step every pool point is scored against the points picked so far,
    already-picked points are masked out, and the best remaining one is added.

    Args:
        log_probs_N_K_C: Log-probabilities of the pool, (input, sample, class).
        batch_size: Number of points to pick; capped at the pool size.
        num_samples: Forwarded to `BatchBALDScorer`.
        dtype: Forwarded to `BatchBALDScorer`.
        device: Forwarded to `BatchBALDScorer`.

    Returns:
        The picked indices and the score each had when it was picked.
    """
    pool_size, _, _ = log_probs_N_K_C.shape
    batch_size = min(batch_size, pool_size)

    picked_indices, picked_scores = [], []
    if batch_size == 0:
        return CandidateBatch(picked_scores, picked_indices)

    scorer = BatchBALDScorer(
        log_probs_N_K_C,
        max_size=batch_size - 1,
        num_samples=num_samples,
        dtype=dtype,
        device=device,
    )

    # We always keep these on the CPU.
    scores_N = scorer.alloc_scores()

    for _ in with_progress_bar(range(batch_size), tqdm_args=dict(desc="BatchBALD", leave=False)):
        # From the second step on, fold the most recent pick into the scorer.
        if picked_indices:
            scorer.append_to_batch(picked_indices[-1])

        scorer.compute_scores(scores_N)
        # Never pick the same point twice.
        scores_N[picked_indices] = -float("inf")

        best_score, best_index = scores_N.max(dim=0)
        picked_indices.append(best_index.item())
        picked_scores.append(best_score.item())

    return CandidateBatch(picked_scores, picked_indices)

In [None]:
def get_batchbald_batch_plain(
    log_probs_N_K_C: torch.Tensor,
    *,
    batch_size: int,
    num_samples: int,
    dtype=None,
    device=None,
) -> CandidateBatch:
    """Greedy BatchBALD written out inline, without the `BatchBALDScorer` helper.

    Args:
        log_probs_N_K_C: Log-probabilities of the pool, (input, sample, class).
        batch_size: Number of points to pick; capped at the pool size.
        num_samples: Forwarded to `DynamicJointEntropy`.
        dtype: Forwarded to `DynamicJointEntropy`.
        device: Forwarded to `DynamicJointEntropy`.

    Returns:
        The picked indices and the score each had when it was picked.
    """
    pool_size, K, C = log_probs_N_K_C.shape
    batch_size = min(batch_size, pool_size)

    picked_indices, picked_scores = [], []
    if batch_size == 0:
        return CandidateBatch(picked_scores, picked_indices)

    conditional_entropies_N = compute_conditional_entropy(log_probs_N_K_C)

    joint_tracker = joint_entropy.DynamicJointEntropy(
        num_samples, batch_size - 1, K, C, dtype=dtype, device=device
    )

    # We always keep these on the CPU.
    scores_N = torch.empty(pool_size, dtype=torch.double, pin_memory=torch.cuda.is_available())

    for _ in with_progress_bar(range(batch_size), tqdm_args=dict(desc="BatchBALD", leave=False)):
        # From the second step on, fold the most recent pick into the joint entropy.
        if picked_indices:
            newest = picked_indices[-1]
            joint_tracker.add_variables(log_probs_N_K_C[newest : newest + 1])

        # Sum of H[y_b|w] over the points picked so far (zero for an empty batch).
        batch_conditional_entropy = conditional_entropies_N[picked_indices].sum()

        joint_tracker.compute_batch(log_probs_N_K_C, output_entropies_B=scores_N)

        scores_N -= conditional_entropies_N + batch_conditional_entropy
        # Never pick the same point twice.
        scores_N[picked_indices] = -float("inf")

        best_score, best_index = scores_N.max(dim=0)
        picked_indices.append(best_index.item())
        picked_scores.append(best_score.item())

    return CandidateBatch(picked_scores, picked_indices)

### Example

In [None]:
get_batch_bald_batch(ys_ws.log().double(), batch_size=4, num_samples=1000, dtype=torch.double)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

BatchBALD:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

CandidateBatch(scores=[0.030715639666234917, 0.05961958627158248, 0.0869107051474467, 0.11275304532467878], indices=[1, 0, 2, 3])

In [None]:
get_batchbald_batch_plain(ys_ws.log().double(), batch_size=4, num_samples=1000, dtype=torch.double)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

BatchBALD:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

CandidateBatch(scores=[0.030715639666234917, 0.05961958627158248, 0.0869107051474467, 0.11275304532467878], indices=[1, 0, 2, 3])

## BALD

BALD is the same as BatchBALD, except that we evaluate points individually, by computing $I[y_i;w]$ for each, and then take the top $B$ scorers.

### BALD scores

Sometimes, we want to obtain BALD scores for all samples as a measure of epistemic uncertainty.

In [None]:
# exports


def get_bald_scores(log_probs_N_K_C: torch.Tensor, *, dtype=None, device=None) -> torch.Tensor:
    """Compute the BALD score $H[y_i] - H[y_i|w]$ of every input.

    `dtype` and `device` are accepted for signature parity with the other
    acquisition functions but are not used here.
    """
    # Unpack for its side effect only: it rejects inputs that are not 3-dimensional.
    _, _, _ = log_probs_N_K_C.shape

    negated_conditional_entropies_N = compute_conditional_entropy(log_probs_N_K_C).neg()
    return negated_conditional_entropies_N.add_(compute_entropy(log_probs_N_K_C))

Given the scores, determining a BALD batch is then straightforward:

### Finding a BALD batch 

In [None]:
# exports


def get_top_k_scorers(scores_N: torch.Tensor, *, batch_size: int) -> CandidateBatch:
    """Return the `batch_size` highest-scoring entries (fewer if there are not enough scores)."""
    k = min(batch_size, len(scores_N))
    top = torch.topk(scores_N, k)
    return CandidateBatch(top.values.tolist(), top.indices.tolist())


def get_bald_batch(log_probs_N_K_C: torch.Tensor, *, batch_size: int, dtype=None, device=None) -> CandidateBatch:
    """Pick the `batch_size` inputs with the highest individual BALD scores."""
    # Unpack for its side effect only: it rejects inputs that are not 3-dimensional.
    _, _, _ = log_probs_N_K_C.shape

    bald_scores_N = get_bald_scores(log_probs_N_K_C, dtype=dtype, device=device)
    return get_top_k_scorers(bald_scores_N, batch_size=batch_size)

### Example

In [None]:
get_bald_batch(ys_ws.log().double(), batch_size=4)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

CandidateBatch(scores=[0.030715639666234917, 0.030715639666234917, 0.030715639666234917, 0.030715639666234695], indices=[1, 2, 0, 3])

## BALD-ICAL

The computation for BALD-ICAL is simple. We need to keep track of two separate (Batch)BALD terms:

$$\mathrm{I}\left[(y)_{B} ; \omega \mid(x)_{B}, D_{T}\right]-\mathrm{I}\left[(y)_{B} ; \omega \mid(x)_{B}, D_{U} \cup D_{T}\right].$$


In [None]:
# exports


def get_batch_eval_bald_batch(
    training_log_probs_N_K_C: torch.Tensor,
    eval_log_probs_N_K_C: torch.Tensor,
    *,
    batch_size: int,
    num_samples: int,
    dtype=None,
    device=None,
) -> CandidateBatch:
    """Greedily build a batch that maximises the difference of two BatchBALD terms.

    Two `BatchBALDScorer`s are run in lock-step, one on each set of predictions;
    each step picks the point with the largest "training score minus eval score".

    Args:
        training_log_probs_N_K_C: Log-probabilities for the first BatchBALD term.
        eval_log_probs_N_K_C: Log-probabilities for the subtracted term; must
            have the same shape as `training_log_probs_N_K_C`.
        batch_size: Number of points to pick; capped at the pool size.
        num_samples: Forwarded to both scorers.
        dtype: Forwarded to both scorers.
        device: Forwarded to both scorers.

    Returns:
        The picked indices and the score difference each had when it was picked.
    """
    assert training_log_probs_N_K_C.shape == eval_log_probs_N_K_C.shape
    pool_size, _, _ = training_log_probs_N_K_C.shape
    batch_size = min(batch_size, pool_size)

    picked_indices, picked_scores = [], []
    if batch_size == 0:
        return CandidateBatch(picked_scores, picked_indices)

    scorer_options = dict(max_size=batch_size - 1, num_samples=num_samples, dtype=dtype, device=device)
    training_scorer = BatchBALDScorer(training_log_probs_N_K_C, **scorer_options)
    pool_scorer = BatchBALDScorer(eval_log_probs_N_K_C, **scorer_options)

    # We always keep these on the CPU.
    training_scores_N = training_scorer.alloc_scores()
    pool_scores_N = pool_scorer.alloc_scores()

    for _ in with_progress_bar(range(batch_size), tqdm_args=dict(desc="BatchBALD", leave=False)):
        # From the second step on, fold the most recent pick into both scorers.
        if picked_indices:
            newest = picked_indices[-1]
            training_scorer.append_to_batch(newest)
            pool_scorer.append_to_batch(newest)

        training_scorer.compute_scores(training_scores_N)
        pool_scorer.compute_scores(pool_scores_N)

        difference_N = training_scores_N - pool_scores_N
        # Never pick the same point twice.
        difference_N[picked_indices] = -float("inf")

        best_score, best_index = difference_N.max(dim=0)
        picked_indices.append(best_index.item())
        picked_scores.append(best_score.item())

    return CandidateBatch(picked_scores, picked_indices)

### Example

#### Pleasing example of the case when predictions match (full overlap)

In [None]:
get_batch_eval_bald_batch(ys_ws.log().double(), ys_ws.log().double(), batch_size=4, num_samples=1000, dtype=torch.double)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

BatchBALD:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

CandidateBatch(scores=[0.0, 0.0, 0.0, 0.0], indices=[0, 1, 2, 3])

In [None]:
get_batch_eval_bald_batch(
    ys_ws.log().double(), torch.zeros_like(ys_ws).double(), batch_size=4, num_samples=1000, dtype=torch.double
)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

BatchBALD:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

CandidateBatch(scores=[0.030715639666234917, 0.05961958627158248, 0.0869107051474467, 0.11275304532467878], indices=[1, 0, 2, 3])

## ThompsonBALD

We compute the entropy $H[y_i]$ as in BALD, but for the conditional entropy we use the entropy of a single $\omega$ sample: we pick the highest scorer under the first sample, then move on to the next sample for the next pick, and so on.

In [None]:
# exports


def compute_each_conditional_entropy(log_probs_N_K_C: torch.Tensor) -> torch.Tensor:
    """Compute the entropy of every individual sampled distribution.

    Unlike `compute_conditional_entropy`, the result is not averaged over the
    inference samples.

    Args:
        log_probs_N_K_C: Log-probabilities indexed as (input, inference sample, class).

    Returns:
        A double tensor of shape (N, K) holding the entropy for each input and sample.
    """
    N, K, C = log_probs_N_K_C.shape

    entropies_N_K = torch.empty((N, K), dtype=torch.double)

    pbar = create_progress_bar(N, tqdm_args=dict(desc="Entropy", leave=False))
    pbar.start()

    # `compute` is run by toma on chunks of the input; `start`/`end` are the
    # offsets of the chunk within the first dimension.
    @toma.execute.chunked(log_probs_N_K_C, 1024)
    def compute(log_probs_n_K_C, start: int, end: int):
        nats_n_K_C = log_probs_n_K_C * torch.exp(log_probs_n_K_C)
        # A class with probability zero has log-probability -inf, and
        # -inf * exp(-inf) = -inf * 0 is NaN. Use the convention 0 * log(0) = 0.
        nats_n_K_C[log_probs_n_K_C == float("-inf")] = 0.0

        entropies_N_K[start:end].copy_(-torch.sum(nats_n_K_C, dim=2))
        pbar.update(end - start)

    pbar.finish()

    return entropies_N_K


def get_thompson_bald_batch(
    log_probs_N_K_C: torch.Tensor, *, batch_size: int, dtype=None, device=None
) -> CandidateBatch:
    """Pick a batch by scoring each pick against a different inference sample.

    The score of input i under sample k is $H[y_i]$ minus the entropy of the
    k-th sampled distribution of input i. The b-th pick is the highest scorer
    under sample b among the inputs not picked yet.

    Args:
        log_probs_N_K_C: Log-probabilities indexed as (input, inference sample, class).
        batch_size: Number of points to pick; capped at the pool size.
        dtype: Unused; accepted for signature parity with the other acquisition functions.
        device: Unused; accepted for signature parity with the other acquisition functions.

    Returns:
        The picked indices and the score each had when it was picked.

    Raises:
        AssertionError: If there are fewer inference samples than points to pick.
    """
    N, K, C = log_probs_N_K_C.shape

    # Clamp first: only the number of points actually picked needs its own
    # sample, so a large request against a small pool must not be rejected.
    batch_size = min(batch_size, N)
    assert K >= batch_size, "need at least one inference sample per acquired point"

    entropy_N = compute_entropy(log_probs_N_K_C)
    all_conditional_entropies_N_K = compute_each_conditional_entropy(log_probs_N_K_C)

    thompson_bald_scores_N_K = entropy_N[:, None] - all_conditional_entropies_N_K

    candidate_scores, candidate_indices = [], []
    for b in range(batch_size):
        candidate_score, candidate_index = thompson_bald_scores_N_K[:, b].max(dim=0)

        candidate_scores.append(candidate_score.item())
        candidate_indices.append(candidate_index.item())

        # Mask the whole row so this input cannot be picked under a later sample.
        thompson_bald_scores_N_K[candidate_index] = -float("inf")

    return CandidateBatch(candidate_scores, candidate_indices)

In [None]:
get_thompson_bald_batch(ys_ws.log().double(), batch_size=4, dtype=torch.double)

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

CandidateBatch(scores=[0.29714917957723086, 0.25731026605631957, 0.21965481456439662, 0.18409109389668443], indices=[2, 0, 1, 3])

## RandomBALD Baseline

We take the top $C \times B$ and randomly pick $B$ candidates from that.

In [None]:
# exports


def get_top_random_scorers(scores_N: torch.Tensor, *, num_classes: int, batch_size: int) -> CandidateBatch:
    """Randomly pick `batch_size` entries out of the top `batch_size * num_classes` scorers.

    Both the batch size and the size of the shortlist are capped at the number
    of available scores.
    """
    total = len(scores_N)
    batch_size = min(batch_size, total)
    shortlist_size = min(batch_size * num_classes, total)

    shortlist_scores, shortlist_indices = torch.topk(scores_N, shortlist_size)

    # A random permutation of the shortlist, truncated to the batch size.
    chosen = torch.randperm(shortlist_size)[:batch_size]

    return CandidateBatch(shortlist_scores[chosen].tolist(), shortlist_indices[chosen].tolist())


def get_random_bald_batch(log_probs_N_K_C: torch.Tensor, *, batch_size: int, dtype=None, device=None) -> CandidateBatch:
    """Randomly pick `batch_size` inputs out of the top `C * batch_size` BALD scorers."""
    _, _, num_classes = log_probs_N_K_C.shape

    bald_scores_N = get_bald_scores(log_probs_N_K_C, dtype=dtype, device=device)
    return get_top_random_scorers(bald_scores_N, num_classes=num_classes, batch_size=batch_size)

In [None]:
get_random_bald_batch(ys_ws.log().double(), batch_size=4, dtype=torch.double)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

CandidateBatch(scores=[0.030715639666234917, 0.030715639666234917, 0.030715639666234917, 0.030715639666234695], indices=[2, 1, 0, 3])

## Additional BALD-ICAL variants

Instead of using BatchBALD, let's compute BALD directly and use either the top-k, TopRandom or Thompson selection strategy.

In [None]:
# exports


def get_eval_bald_scores(
    training_log_probs_N_K_C: torch.Tensor,
    eval_log_probs_N_K_C: torch.Tensor,
    *,
    dtype=None,
    device=None,
) -> torch.Tensor:
    """Return the per-input difference of BALD scores: training predictions minus eval predictions.

    Both inputs must have the same shape.
    """
    assert training_log_probs_N_K_C.shape == eval_log_probs_N_K_C.shape

    bald_training_N = get_bald_scores(training_log_probs_N_K_C, dtype=dtype, device=device)
    bald_eval_N = get_bald_scores(eval_log_probs_N_K_C, dtype=dtype, device=device)

    return bald_training_N - bald_eval_N


def get_eval_bald_batch(
    training_log_probs_N_K_C: torch.Tensor,
    pool_log_probs_N_K_C: torch.Tensor,
    *,
    batch_size: int,
    dtype=None,
    device=None,
) -> CandidateBatch:
    """Pick the `batch_size` inputs with the largest BALD score difference (see `get_eval_bald_scores`)."""
    difference_N = get_eval_bald_scores(
        training_log_probs_N_K_C, pool_log_probs_N_K_C, dtype=dtype, device=device
    )
    return get_top_k_scorers(difference_N, batch_size=batch_size)


def get_top_random_eval_bald_batch(
    training_log_probs_N_K_C: torch.Tensor,
    pool_log_probs_N_K_C: torch.Tensor,
    *,
    batch_size: int,
    num_classes: int,
    dtype=None,
    device=None,
) -> CandidateBatch:
    """Randomly pick `batch_size` inputs among the top scorers by BALD score difference.

    The scores come from `get_eval_bald_scores`; the random selection is
    delegated to `get_top_random_scorers`.
    """
    difference_N = get_eval_bald_scores(
        training_log_probs_N_K_C, pool_log_probs_N_K_C, dtype=dtype, device=device
    )
    return get_top_random_scorers(difference_N, batch_size=batch_size, num_classes=num_classes)

In [None]:
get_eval_bald_batch(ys_ws.log().double(), ys_ws.log().double(), batch_size=4, dtype=torch.double)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

CandidateBatch(scores=[0.0, 0.0, 0.0, 0.0], indices=[2, 3, 0, 1])

In [None]:
get_top_random_eval_bald_batch(
    ys_ws.log().double(), ys_ws.log().double(), batch_size=4, num_classes=10, dtype=torch.double
)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

CandidateBatch(scores=[0.0, 0.0, 0.0, 0.0], indices=[1, 2, 0, 3])

## TemperedBALD

Use temperature-scaled BALD scores for importance sampling.

In [None]:
# exports


def get_sampled_tempered_scorers(scores_N: torch.Tensor, *, temperature: float, batch_size: int) -> CandidateBatch:
    """Sample a batch without replacement, with probability proportional to `score ** temperature`.

    Negative scores get probability zero.

    Args:
        scores_N: One score per input.
        temperature: Exponent applied to the (non-negative) scores.
        batch_size: Number of points to sample; capped at the number of scores.

    Returns:
        The sampled indices together with their original (untempered) scores.

    Note:
        `np.random.choice` rejects the request if the probabilities are not
        valid, e.g. when no score is positive or when fewer inputs have a
        non-zero probability than `batch_size`.
    """
    N = len(scores_N)
    batch_size = min(batch_size, N)

    # Zero out negative scores *before* exponentiation: an even temperature
    # would turn them positive and a fractional one would turn them into NaN,
    # and neither case is caught by a "< 0" check applied afterwards.
    negative_N = scores_N < 0
    # Normalise in double precision so the probabilities sum to one within
    # the tolerance NumPy checks for.
    tempered_scores_N = (scores_N.clamp(min=0) ** temperature).double()
    tempered_scores_N[negative_N] = 0.0
    partition_constant = tempered_scores_N.sum()
    p = (tempered_scores_N / partition_constant).cpu().numpy()

    candidate_indices = np.random.choice(N, size=batch_size, replace=False, p=p)
    candidate_scores = scores_N[candidate_indices]

    return CandidateBatch(candidate_scores.tolist(), candidate_indices.tolist())

In [None]:
get_sampled_tempered_scorers(get_bald_scores(ys_ws.log().double()), temperature=10, batch_size=2)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

CandidateBatch(scores=[0.030715639666234917, 0.030715639666234917], indices=[1, 0])

## ICAL

As part of an ablation (and to see how it performs), we can also compute the ICAL score.

In [None]:
# exports


def get_eig_scores(
    training_log_probs_N_K_C: torch.Tensor,
    pool_log_probs_N_K_C: torch.Tensor,
    *,
    dtype=None,
    device=None,
) -> torch.Tensor:
    """Return the per-input difference of predictive entropies: training predictions minus pool predictions.

    Both inputs must have the same shape. `dtype` and `device` are not used.
    """
    assert training_log_probs_N_K_C.shape == pool_log_probs_N_K_C.shape
    # Unpack for its side effect only: it rejects inputs that are not 3-dimensional.
    _, _, _ = training_log_probs_N_K_C.shape

    training_entropy_N = compute_entropy(training_log_probs_N_K_C)
    pool_entropy_N = compute_entropy(pool_log_probs_N_K_C)

    return training_entropy_N - pool_entropy_N

In [None]:
get_eig_scores(ys_ws.log().double(), ys_ws.log().double())

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

tensor([0., 0., 0., 0.], dtype=torch.float64)

In [None]:
# exports


# TODO: refactor the BatchBALDScorer to deduplicate some of this?
def get_batch_eig_batch(
    training_log_probs_N_K_C: torch.Tensor,
    pool_log_probs_N_K_C: torch.Tensor,
    *,
    batch_size: int,
    num_samples: int,
    dtype=None,
    device=None,
) -> CandidateBatch:
    """Greedily build a batch that maximises the difference of two joint entropies.

    Two `DynamicJointEntropy` trackers are run in lock-step, one on each set of
    predictions; each step picks the point with the largest
    "training entropy minus pool entropy".

    Args:
        training_log_probs_N_K_C: Log-probabilities for the first entropy term.
        pool_log_probs_N_K_C: Log-probabilities for the subtracted term; must
            have the same shape as `training_log_probs_N_K_C`.
        batch_size: Number of points to pick; capped at the pool size.
        num_samples: Forwarded to both trackers.
        dtype: Forwarded to both trackers.
        device: Forwarded to both trackers.

    Returns:
        The picked indices and the entropy difference each had when it was picked.
    """
    assert training_log_probs_N_K_C.shape == pool_log_probs_N_K_C.shape
    N, K, C = training_log_probs_N_K_C.shape
    batch_size = min(batch_size, N)

    picked_indices, picked_scores = [], []
    if batch_size == 0:
        return CandidateBatch(picked_scores, picked_indices)

    def make_tracker():
        return joint_entropy.DynamicJointEntropy(num_samples, batch_size, K, C, dtype=dtype, device=device)

    def make_score_buffer():
        return torch.empty(N, dtype=torch.double, pin_memory=torch.cuda.is_available())

    training_tracker = make_tracker()
    pool_tracker = make_tracker()

    # We always keep these on the CPU.
    training_scores_N = make_score_buffer()
    pool_scores_N = make_score_buffer()

    for _ in with_progress_bar(range(batch_size), tqdm_args=dict(desc="BatchBALD", leave=False)):
        # From the second step on, fold the most recent pick into both trackers.
        if picked_indices:
            newest = slice(picked_indices[-1], picked_indices[-1] + 1)
            training_tracker.add_variables(training_log_probs_N_K_C[newest])
            pool_tracker.add_variables(pool_log_probs_N_K_C[newest])

        training_tracker.compute_batch(training_log_probs_N_K_C, output_entropies_B=training_scores_N)
        pool_tracker.compute_batch(pool_log_probs_N_K_C, output_entropies_B=pool_scores_N)

        difference_N = training_scores_N - pool_scores_N
        # Never pick the same point twice.
        difference_N[picked_indices] = -float("inf")

        best_score, best_index = difference_N.max(dim=0)
        picked_indices.append(best_index.item())
        picked_scores.append(best_score.item())

    return CandidateBatch(picked_scores, picked_indices)

In [None]:
get_batch_eig_batch(ys_ws.log().double(), ys_ws.log().double(), batch_size=4, num_samples=1000, dtype=torch.double)

BatchBALD:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

CandidateBatch(scores=[0.0, 0.0, 0.0, 0.0], indices=[0, 1, 2, 3])

## CoreSetBALD

(Following the new notation.)

Instead of computing $I[Y;\omega|x]$, we use our knowledge of the labels and compute: $$I[y;\omega|x]= H[y|x] - \mathbb{E}_{p(\omega|y,x)} H[y|x,\omega].$$

In [None]:
# exports

def get_coreset_bald_scores_from_predictions(
    log_probs_N_K_C: torch.Tensor, target_probs_N_C: torch.Tensor, *, dtype=None, device=None
) -> torch.Tensor:
    """Compute CoreSetBALD scores, weighting the per-class terms by a target distribution.

    For every input and class a per-class information term is computed (the same
    quantity `get_coreset_bald_scores` evaluates at the known label); the terms
    are then averaged over classes with weights `target_probs_N_C`.

    Args:
        log_probs_N_K_C: Log-probabilities indexed as (input, inference sample, class).
        target_probs_N_C: Class weights per input, of shape (N, C).
        dtype: Unused; accepted for signature parity with the other scoring functions.
        device: Unused; accepted for signature parity with the other scoring functions.

    Returns:
        A tensor with one score per input.

    Raises:
        AssertionError: If `target_probs_N_C` does not have shape (N, C).
    """
    N, K, C = log_probs_N_K_C.shape
    # The targets are per input, so the leading dimension must be N (not K).
    assert target_probs_N_C.shape == (N, C)

    # log of the mean probability over the K samples, computed in log-space.
    log_prob_mean_N_C = torch.logsumexp(log_probs_N_K_C, dim=1) - np.log(K)

    entropy_N_C = -log_prob_mean_N_C
    conditional_entropy = -torch.mean(log_probs_N_K_C * log_probs_N_K_C.exp(), dim=1) / log_prob_mean_N_C.exp()
    mutual_bits_N_C = entropy_N_C - conditional_entropy

    cross_mutual_information = torch.sum(target_probs_N_C * mutual_bits_N_C, dim=1)

    return cross_mutual_information


def get_coreset_bald_scores(
    log_probs_N_K_C: torch.Tensor, labels_N: torch.Tensor, *, dtype=None, device=None
) -> torch.Tensor:
    """Compute the CoreSetBALD score of every input at its known label.

    Args:
        log_probs_N_K_C: Log-probabilities indexed as (input, inference sample, class).
        labels_N: Class index of the label of each input.
        dtype: dtype the gathered log-probabilities are converted to.
        device: Device the gathered log-probabilities are moved to.

    Returns:
        A tensor with one score per input.
    """
    _, K, _ = log_probs_N_K_C.shape

    # Select, for every input and sample, the log-probability of the input's label.
    label_index_N_1_1 = labels_N[:, None, None]
    gathered = joint_entropy.gather_expand(log_probs_N_K_C, dim=2, index=label_index_N_1_1)
    log_probs_N_K = gathered.squeeze(2).to(dtype=dtype, device=device)

    # log of the mean label probability over the K samples, computed in log-space.
    log_prob_mean_N = torch.logsumexp(log_probs_N_K, dim=1) - np.log(K)

    marginal_term_N = -log_prob_mean_N
    conditional_term_N = -torch.mean(log_probs_N_K * log_probs_N_K.exp(), dim=1) / log_prob_mean_N.exp()

    return marginal_term_N - conditional_term_N

In [None]:
get_coreset_bald_scores(ys_ws.log().double(), torch.tensor([0, 1, 2, 3])), [get_coreset_bald_scores(ys_ws.log().double(), torch.tensor([i, i, i, i])) for i in range(3)]

(tensor([0.0300, 0.0300, 0.0300, 0.0300], dtype=torch.float64),
 [tensor([0.0300, 0.0207, 0.0207, 0.0474], dtype=torch.float64),
  tensor([0.0474, 0.0300, 0.0207, 0.0207], dtype=torch.float64),
  tensor([0.0207, 0.0474, 0.0300, 0.0207], dtype=torch.float64)])

## BatchCoreSetBALD

(Following the new notation.)

The batch version of this acquisition function can be computed more easily:

$$ \operatorname{I}[(y)_B;\omega|(x)_B] = \operatorname{H}[(y)_B|(x)_B] - \mathbb{E}_{p(\omega|(y)_B, (x)_B)} \operatorname{H}[(y)_B|(x)_B, \omega], $$

where $p(\omega|(y)_B, (x)_B) = \frac{ p((y)_B| (x)_B, \omega) p(\omega) }{ p((y)_B| (x)_B) }$ as usual, and we make use of the independence of the $(y)_B$ given $\omega$.

We can make this efficient for computing scores in parallel by using:
$$p((y)_B|(x)_B, \omega) = p(y_B|x_B, \omega) \; p((y)_{B-1}|(x)_{B-1}, \omega).$$

### Do we have sub-modularity?

Unclear.

In [None]:
# exports


def get_batch_coreset_bald_batch(
    log_probs_N_K_C: torch.Tensor, labels_N: torch.Tensor, *, batch_size: int, dtype=None, device=None
) -> CandidateBatch:
    """Greedily build a BatchCoreSetBALD batch using the known labels.

    The per-sample joint log-likelihood of the labels picked so far is carried
    from step to step; each step scores every input as if it were added to that
    batch and picks the best one not picked yet.

    Args:
        log_probs_N_K_C: Log-probabilities indexed as (input, inference sample, class).
        labels_N: Class index of the label of each input.
        batch_size: Number of points to pick; capped at the pool size.
        dtype: dtype used for the gathered log-probabilities and the running joint.
        device: Device used for the gathered log-probabilities and the running joint.

    Returns:
        The picked indices and the score each had when it was picked.
    """
    N, K, C = log_probs_N_K_C.shape

    batch_size = min(batch_size, N)

    candidate_indices = []
    candidate_scores = []

    if batch_size == 0:
        return CandidateBatch(candidate_scores, candidate_indices)

    # Select, for every input and sample, the log-probability of the input's label.
    labels_N_1_1 = labels_N[:, None, None]
    log_probs_N_K = (
        joint_entropy.gather_expand(log_probs_N_K_C, dim=2, index=labels_N_1_1)
        .squeeze(2)
        .to(dtype=dtype, device=device)
    )

    # p((y)_{B-1}|(x)_{B-1}, \omega); log(1) = 0 for the empty batch.
    log_probs_join_batch_K = torch.zeros_like(log_probs_N_K[0], dtype=dtype, device=device)

    for i in with_progress_bar(range(batch_size), tqdm_args=dict(desc="BatchCoreSetBALD", leave=False)):
        # p((y)_B|(x)_B, \omega) = p(y_B|x_B, \omega) * p((y)_{B-1}|(x)_{B-1}, \omega)
        log_prob_joint_N_K = log_probs_N_K + log_probs_join_batch_K[None, :]
        prob_joint_N_K = log_prob_joint_N_K.exp()

        # Marginalize over \omega (but using sum not mean):
        # + np.log(K) compared to actual joint log prob
        prob_joint_sum_N = prob_joint_N_K.sum(dim=1, keepdim=False)
        # p(\omega) cancels out in:
        # p(\omega|(y)_B, (x)_B) = \frac{ p((y)_B| (x)_B, \omega) p(\omega) }{ p((y)_B| (x)_B) }
        # NOTE: the same quantities can be written with logsumexp, subtracting
        # the log of the sum from log_prob_joint_N_K instead of dividing by the sum.
        conditional_entropy_joint_N = -torch.sum(prob_joint_N_K * log_prob_joint_N_K, dim=1) / prob_joint_sum_N
        entropy_joint_N = -prob_joint_sum_N.log() + np.log(K)
        scores_N = entropy_joint_N - conditional_entropy_joint_N

        # Select candidate
        scores_N[candidate_indices] = -float("inf")

        candidate_score, candidate_index = scores_N.max(dim=0)

        candidate_indices.append(candidate_index.item())
        candidate_scores.append(candidate_score.item())

        # Update log_probs_join_batch_K
        log_probs_join_batch_K = log_prob_joint_N_K[candidate_index]

    return CandidateBatch(candidate_scores, candidate_indices)

In [None]:
get_coreset_bald_scores(ys_ws.log().double(), torch.tensor([0, 1, 2, 3])), [get_coreset_bald_scores(ys_ws.log().double(), torch.tensor([i, i, i, i])) for i in range(3)]

(tensor([0.0300, 0.0300, 0.0300, 0.0300], dtype=torch.float64),
 [tensor([0.0300, 0.0207, 0.0207, 0.0474], dtype=torch.float64),
  tensor([0.0474, 0.0300, 0.0207, 0.0207], dtype=torch.float64),
  tensor([0.0207, 0.0474, 0.0300, 0.0207], dtype=torch.float64)])

In [None]:
ys_ws.shape

torch.Size([4, 20, 4])

In [None]:
get_batch_coreset_bald_batch(ys_ws.log().double(), torch.tensor([0, 1, 2, 3]), batch_size=4, dtype=torch.double)

BatchCoreSetBALD:   0%|          | 0/4 [00:00<?, ?it/s]

CandidateBatch(scores=[0.030021323375763354, 0.10871562110954991, 0.21684316724892772, 0.3375447132429139], indices=[0, 1, 2, 3])

## DirichletBALD (Unclear/TODO)

This is inspired by Energy-Based Models and Dirichlet distributions. Instead of working with the probabilities after the Softmax layer, we use the logits directly and view them as log concentrations of a Dirichlet distribution.

We can combine this with MC Dropout to get different Dirichlet samples by averaging the log concentrations. This leads to the exact same computation as before (geometric averaging of the probabilities). 

This is different than fitting a Dirichlet distribution. Taking the log concentrations instead of probabilities does not throw away "density" information from the model.

We could recover a mutual information term using the Dirichlet assumption then?

(In general, it is not entirely clear to me how to combine them. One path could be to use a conjugate prior distribution and take the mean/mode of that. The conjugate prior of a Dirichlet distribution is the Boojum distribution, which is quite complex and does not provide an analytical solution for computing its mean or mode.)



In [None]:
# def get_dirichlet_bald_scores(logits_N_K_C: torch.Tensor, *,
#     dtype=None,
#     device=None) -> torch.Tensor: