In [None]:
# default_exp batchbald

In [None]:
# hide

import blackhc.project.script
from nbdev.showdoc import *

# BatchBALD Algorithm
> Greedy algorithm and score computation

First, we will implement two helper functions to compute conditional entropies $H[y_i|w]$ and entropies $H[y_i]$.
Then, we will implement BatchBALD and BALD.

In [None]:
# exports
import math
from dataclasses import dataclass
from enum import Enum
from typing import List

import numpy as np
import scipy.stats
import torch
from blackhc.progress_bar import create_progress_bar, with_progress_bar
from toma import toma

from batchbald_redux import joint_entropy

We are going to define a couple of sampled distributions to use for testing our code.

$K=20$ means 20 inference samples.

In [None]:
K = 20

In [None]:
def get_mixture_prob_dist(p1, p2, m):
    """Blend two distributions, putting weight ``1 - m`` on ``p1`` and ``m`` on ``p2``.

    Both inputs are converted with ``np.asarray``; the result is a NumPy array.
    """
    first, second = np.asarray(p1), np.asarray(p2)
    weight_first = 1.0 - m
    return weight_first * first + m * second


# Four synthetic pool points. For each one we build K predictive distributions over 4 classes by
# interpolating linearly from p1 (mixture weight m = 0) to p2 (m = 1).
# Point 1: p1 puts most mass on class 0.
p1 = [0.7, 0.1, 0.1, 0.1]
p2 = [0.3, 0.3, 0.2, 0.2]
y1_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]

# Point 2: p1 puts most mass on class 1.
p1 = [0.1, 0.7, 0.1, 0.1]
p2 = [0.2, 0.3, 0.3, 0.2]
y2_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]

# Point 3: p1 puts most mass on class 2.
p1 = [0.1, 0.1, 0.7, 0.1]
p2 = [0.2, 0.2, 0.3, 0.3]
y3_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]

# Point 4: p1 puts most mass on class 3.
p1 = [0.1, 0.1, 0.1, 0.7]
p2 = [0.3, 0.2, 0.2, 0.3]
y4_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]


def nested_to_tensor(l):
    """Convert each element of ``l`` with ``torch.as_tensor`` and stack the results along a new first axis."""
    converted = [torch.as_tensor(item) for item in l]
    return torch.stack(converted)


ys_ws = nested_to_tensor([y1_ws, y2_ws, y3_ws, y4_ws])

  return torch.stack(list(map(torch.as_tensor, l)))


In [None]:
# hide

# Uniform reference case: each of the K draws assigns probability 0.25 to every class.
p = [0.25, 0.25, 0.25, 0.25]
yu_ws = [p for m in range(K)]
# Four identical pool points built from the uniform draws.
yus_ws = nested_to_tensor([yu_ws] * 4)

In [None]:
ys_ws.shape

torch.Size([4, 20, 4])

## Conditional Entropies and Batched Entropies

To start with, we write two functions to compute the conditional entropy $H[y_i|w]$ and the entropy $H[y_i]$ for each input sample.

In [None]:
def compute_conditional_entropy(probs_N_K_C: torch.Tensor) -> torch.Tensor:
    """Estimate the conditional entropy $H[y_i|w]$ of each input sample from probabilities.

    Args:
        probs_N_K_C: probabilities whose three axes are unpacked as (N, K, C); in this
            notebook that is (input samples, inference samples, classes).

    Returns:
        A CPU double tensor with N entries: the entropy summed over classes and averaged
        over the K inference samples.
    """
    num_points, num_draws, _ = probs_N_K_C.shape

    probs_N_K_C = probs_N_K_C.to(dtype=torch.double)
    entropies_N = torch.empty(num_points, dtype=torch.double)

    pbar = create_progress_bar(num_points, tqdm_args=dict(desc="Conditional Entropy", leave=False))
    pbar.start()

    @toma.execute.chunked(probs_N_K_C, 1024)
    def compute(probs_n_K_C, start: int, end: int):
        plogp_n_K_C = probs_n_K_C * torch.log(probs_n_K_C)
        # 0 * log(0) evaluates to NaN; such terms contribute nothing to the entropy.
        plogp_n_K_C[torch.isnan(plogp_n_K_C)] = 0.0

        chunk_entropies_n = -torch.sum(plogp_n_K_C, dim=(1, 2)) / num_draws
        entropies_N[start:end].copy_(chunk_entropies_n, non_blocking=True)
        pbar.update(end - start)

    pbar.finish()

    return entropies_N


def compute_entropy(probs_N_K_C: torch.Tensor) -> torch.Tensor:
    """Compute the entropy $H[y_i]$ of each input sample from probabilities.

    The probabilities are averaged over the K axis first; the entropy of that averaged
    distribution is what gets returned.

    Args:
        probs_N_K_C: probabilities whose three axes are unpacked as (N, K, C).

    Returns:
        A CPU double tensor with N entries.
    """
    num_points, _, _ = probs_N_K_C.shape

    probs_N_K_C = probs_N_K_C.to(dtype=torch.double)
    entropies_N = torch.empty(num_points, dtype=torch.double)

    pbar = create_progress_bar(num_points, tqdm_args=dict(desc="Entropy", leave=False))
    pbar.start()

    @toma.execute.chunked(probs_N_K_C, 1024)
    def compute(probs_n_K_C, start: int, end: int):
        marginal_n_C = probs_n_K_C.mean(dim=1)
        plogp_n_C = marginal_n_C * torch.log(marginal_n_C)
        # 0 * log(0) evaluates to NaN; treat it as a zero contribution.
        plogp_n_C[torch.isnan(plogp_n_C)] = 0.0

        chunk_entropies_n = -torch.sum(plogp_n_C, dim=1)
        entropies_N[start:end].copy_(chunk_entropies_n, non_blocking=True)
        pbar.update(end - start)

    pbar.finish()

    return entropies_N

In [None]:
# Make sure everything is computed correctly.

assert np.allclose(compute_conditional_entropy(yus_ws), [1.3863, 1.3863, 1.3863, 1.3863], atol=0.1)
assert np.allclose(compute_entropy(yus_ws), [1.3863, 1.3863, 1.3863, 1.3863], atol=0.1)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

However, our neural networks usually use a `log_softmax` as final layer. To avoid having to call `.exp_()`, which is easy to miss and annoying to debug, we will instead use a version that uses `log_probs` instead of `probs`.

In [None]:
# exports


def compute_conditional_entropy(log_probs_N_K_C: torch.Tensor) -> torch.Tensor:
    """Estimate the conditional entropy $H[y_i|w]$ of each input sample from log-probabilities.

    Args:
        log_probs_N_K_C: log-probabilities whose three axes are unpacked as (N, K, C); in
            this notebook that is (input samples, inference samples, classes).

    Returns:
        A CPU double tensor with N entries: the entropy summed over classes and averaged
        over the K inference samples.
    """
    num_points, num_draws, _ = log_probs_N_K_C.shape

    log_probs_N_K_C = log_probs_N_K_C.to(torch.double)
    entropies_N = torch.empty(num_points, dtype=torch.double)

    pbar = create_progress_bar(num_points, tqdm_args=dict(desc="Conditional Entropy", leave=False))
    pbar.start()

    @toma.execute.chunked(log_probs_N_K_C, 65536)
    def compute(log_probs_n_K_C, start: int, end: int):
        plogp_n_K_C = log_probs_n_K_C * torch.exp(log_probs_n_K_C)
        # A log-probability of -inf gives -inf * 0 = NaN; it contributes nothing to the entropy.
        plogp_n_K_C[torch.isnan(plogp_n_K_C)] = 0.0

        chunk_entropies_n = -torch.sum(plogp_n_K_C, dim=(1, 2)) / num_draws
        entropies_N[start:end].copy_(chunk_entropies_n)
        pbar.update(end - start)

    pbar.finish()

    return entropies_N


def compute_entropy(log_probs_N_K_C: torch.Tensor) -> torch.Tensor:
    """Compute the entropy $H[y_i]$ of each input sample from log-probabilities.

    The K axis is averaged in log space (``logsumexp`` minus ``log K``) before the
    entropy of the averaged distribution is taken.

    Args:
        log_probs_N_K_C: log-probabilities whose three axes are unpacked as (N, K, C).

    Returns:
        A CPU double tensor with N entries.
    """
    num_points, num_draws, _ = log_probs_N_K_C.shape

    log_probs_N_K_C = log_probs_N_K_C.to(torch.double)
    entropies_N = torch.empty(num_points, dtype=torch.double)

    pbar = create_progress_bar(num_points, tqdm_args=dict(desc="Entropy", leave=False))
    pbar.start()

    @toma.execute.chunked(log_probs_N_K_C, 65536)
    def compute(log_probs_n_K_C, start: int, end: int):
        log_marginal_n_C = torch.logsumexp(log_probs_n_K_C, dim=1) - math.log(num_draws)
        plogp_n_C = log_marginal_n_C * torch.exp(log_marginal_n_C)
        # -inf * 0 evaluates to NaN; treat it as a zero contribution.
        plogp_n_C[torch.isnan(plogp_n_C)] = 0.0

        chunk_entropies_n = -torch.sum(plogp_n_C, dim=1)
        entropies_N[start:end].copy_(chunk_entropies_n)
        pbar.update(end - start)

    pbar.finish()

    return entropies_N

In [None]:
# hide

# Make sure everything is computed correctly.
assert np.allclose(compute_conditional_entropy(yus_ws.log()), [1.3863, 1.3863, 1.3863, 1.3863], atol=0.1)
assert np.allclose(compute_entropy(yus_ws.log()), [1.3863, 1.3863, 1.3863, 1.3863], atol=0.1)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

### Examples

In [None]:
conditional_entropies = compute_conditional_entropy(ys_ws.log())

print(conditional_entropies)

assert np.allclose(conditional_entropies, [1.2069, 1.2069, 1.2069, 1.2069], atol=0.01)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

tensor([1.2069, 1.2069, 1.2069, 1.2069], dtype=torch.float64)


In [None]:
entropies = compute_entropy(ys_ws.log())

print(entropies)

assert np.allclose(entropies, [1.2376, 1.2376, 1.2376, 1.2376], atol=0.01)

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

tensor([1.2376, 1.2376, 1.2376, 1.2376], dtype=torch.float64)


## BatchBALD

To compute BatchBALD exactly for a candidate batch, we'd have to compute $I[(y_b)_B;w] = H[(y_b)_B] - H[(y_b)_B|w]$.

As the $y_b$ are independent given $w$, we can simplify $H[(y_b)_B|w] = \sum_b H[y_b|w]$.

Furthermore, we use a greedy algorithm to build up the candidate batch, so $y_1,\dots,y_{B-1}$ will stay fixed as we determine $y_{B}$. We compute
$H[(y_b)_{B-1}, y_i] - H[y_i|w]$ for each pool element $y_i$ and add the highest scorer as $y_{B}$.

We don't utilize the last optimization here in order to compute the actual scores.


### In the Paper

![BatchBALD algorithm in the paper](batchbald_algorithm.png)


### Implementation

In [None]:
# exports


@dataclass
class CandidateBatch:
    """Result of an acquisition step: the selected indices and the score recorded for each."""

    # Score recorded for each selected element.
    scores: List[float]
    # Index of each selected element along the first axis of the scored tensor.
    indices: List[int]

In [None]:
# exports


class BatchBALDScorer:
    """Incremental scorer used by the greedy BatchBALD-style selection loops in this notebook.

    It holds the log-probabilities of all candidates, their per-candidate conditional
    entropies, a joint-entropy tracker for the batch selected so far, and the running sum
    of conditional entropies of that batch.
    """

    # Log-probabilities of all candidates; unpacked as (N, K, C) in __init__.
    log_probs_N_K_C: torch.Tensor
    # Conditional entropy of every candidate, computed once in __init__.
    conditional_entropies_N: torch.Tensor
    # Joint-entropy tracker for the elements appended so far.
    batch_joint_entropy: joint_entropy.JointEntropy
    # Running sum of the conditional entropies of the appended elements.
    # Starts as the int 0 and becomes a tensor on the first append_to_batch call.
    batch_conditional_entropies: torch.Tensor

    def __init__(self, log_probs_N_K_C, *, max_size, num_samples: int, dtype=None, device=None):
        """Precompute conditional entropies and create the joint-entropy tracker.

        `max_size`, `num_samples`, `dtype` and `device` are forwarded unchanged to
        `joint_entropy.DynamicJointEntropy` together with K and C.
        """
        N, K, C = log_probs_N_K_C.shape
        self.log_probs_N_K_C = log_probs_N_K_C

        self.conditional_entropies_N = compute_conditional_entropy(self.log_probs_N_K_C)
        self.batch_conditional_entropies = 0

        self.batch_joint_entropy = joint_entropy.DynamicJointEntropy(
            num_samples, max_size, K, C, dtype=dtype, device=device
        )

    def append_to_batch(self, index: int):
        """Add candidate `index` to the tracked batch and to the running conditional-entropy sum."""
        # Slice (rather than index) so the leading axis is kept with length 1.
        self.batch_joint_entropy.add_variables(self.log_probs_N_K_C[index : index + 1])
        # On the first call this is `0 + tensor`, which creates a new tensor; the clone ensures
        # that the later in-place additions never write into conditional_entropies_N.
        self.batch_conditional_entropies += self.conditional_entropies_N[index].clone()

    def alloc_scores(self) -> torch.Tensor:
        """Allocate an uninitialised double buffer with one entry per candidate."""
        # We always keep these on the CPU.
        scores_N = torch.empty(
            self.log_probs_N_K_C.shape[0],
            dtype=torch.double,
            pin_memory=torch.cuda.is_available(),
        )
        return scores_N

    def compute_scores(self, out_scores_N: torch.Tensor):
        """Fill `out_scores_N` in place with the score of every candidate and return it.

        The buffer is first passed to `compute_batch` as `output_entropies_B` (presumably the
        joint entropy of the current batch extended by each candidate - confirm against
        `joint_entropy`); each candidate's own conditional entropy and the batch's running
        conditional-entropy sum are then subtracted.
        """
        self.batch_joint_entropy.compute_batch(self.log_probs_N_K_C, output_entropies_B=out_scores_N)

        out_scores_N -= self.conditional_entropies_N + self.batch_conditional_entropies

        return out_scores_N


def get_batch_bald_batch(
    log_probs_N_K_C: torch.Tensor, *, batch_size: int, num_samples: int, dtype=None, device=None
) -> CandidateBatch:
    """Greedily build a batch using `BatchBALDScorer`.

    Each round re-scores every candidate given the elements chosen so far, masks the
    already chosen ones with -inf and takes the arg-max.

    Args:
        log_probs_N_K_C: log-probabilities, unpacked as (N, K, C).
        batch_size: number of elements to select; clamped to N.
        num_samples, dtype, device: forwarded to `BatchBALDScorer`.

    Returns:
        The selected indices and their scores, in selection order. Empty when the
        clamped batch size is 0.
    """
    num_points, _, _ = log_probs_N_K_C.shape
    batch_size = min(batch_size, num_points)

    picked_indices = []
    picked_scores = []
    if batch_size == 0:
        return CandidateBatch(picked_scores, picked_indices)

    scorer = BatchBALDScorer(
        log_probs_N_K_C,
        max_size=batch_size - 1,
        num_samples=num_samples,
        dtype=dtype,
        device=device,
    )

    # We always keep these on the CPU.
    scores_N = scorer.alloc_scores()

    for step in with_progress_bar(range(batch_size), tqdm_args=dict(desc="BatchBALD", leave=False)):
        # From the second round on, fold the previous winner into the scorer's batch.
        if step > 0:
            scorer.append_to_batch(picked_indices[-1])

        scorer.compute_scores(scores_N)
        # Never select the same element twice.
        scores_N[picked_indices] = -float("inf")

        best_score, best_index = scores_N.max(dim=0)
        picked_indices.append(best_index.item())
        picked_scores.append(best_score.item())

    return CandidateBatch(picked_scores, picked_indices)

In [None]:
def get_batchbald_batch_plain(
    log_probs_N_K_C: torch.Tensor,
    *,
    batch_size: int,
    num_samples: int,
    dtype=None,
    device=None,
) -> CandidateBatch:
    """Greedy batch selection written out inline, without `BatchBALDScorer`.

    Args:
        log_probs_N_K_C: log-probabilities, unpacked as (N, K, C).
        batch_size: number of elements to select; clamped to N.
        num_samples, dtype, device: forwarded to `joint_entropy.DynamicJointEntropy`.

    Returns:
        The selected indices and their scores, in selection order.
    """
    num_points, num_draws, num_classes = log_probs_N_K_C.shape
    batch_size = min(batch_size, num_points)

    picked_indices = []
    picked_scores = []
    if batch_size == 0:
        return CandidateBatch(picked_scores, picked_indices)

    conditional_entropies_N = compute_conditional_entropy(log_probs_N_K_C)

    batch_joint_entropy = joint_entropy.DynamicJointEntropy(
        num_samples, batch_size - 1, num_draws, num_classes, dtype=dtype, device=device
    )

    # We always keep these on the CPU.
    scores_N = torch.empty(num_points, dtype=torch.double, pin_memory=torch.cuda.is_available())

    for step in with_progress_bar(range(batch_size), tqdm_args=dict(desc="BatchBALD", leave=False)):
        if step > 0:
            newest = picked_indices[-1]
            batch_joint_entropy.add_variables(log_probs_N_K_C[newest : newest + 1])

        # Conditional entropy shared by every candidate: the sum over the elements picked so far.
        picked_conditional_entropy = conditional_entropies_N[picked_indices].sum()

        batch_joint_entropy.compute_batch(log_probs_N_K_C, output_entropies_B=scores_N)

        scores_N -= conditional_entropies_N + picked_conditional_entropy
        # Never select the same element twice.
        scores_N[picked_indices] = -float("inf")

        best_score, best_index = scores_N.max(dim=0)
        picked_indices.append(best_index.item())
        picked_scores.append(best_score.item())

    return CandidateBatch(picked_scores, picked_indices)

### Example

In [None]:
get_batch_bald_batch(ys_ws.log().double(), batch_size=4, num_samples=1000, dtype=torch.double)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

BatchBALD:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

CandidateBatch(scores=[0.030715639666234917, 0.05961958627158248, 0.0869107051474467, 0.11275304532467878], indices=[1, 0, 2, 3])

In [None]:
get_batchbald_batch_plain(ys_ws.log().double(), batch_size=4, num_samples=1000, dtype=torch.double)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

BatchBALD:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

CandidateBatch(scores=[0.030715639666234917, 0.05961958627158248, 0.0869107051474467, 0.11275304532467878], indices=[1, 0, 2, 3])

## BALD

BALD is the same as BatchBALD, except that we evaluate points individually, by computing $I[y_i;w]$ for each, and then take the top $B$ scorers.

### BALD scores

Sometimes, we want to obtain BALD scores for all samples as a measure of epistemic uncertainty.

In [None]:
# exports

# TODO: remove unused parameters!
def get_bald_scores(log_probs_N_K_C: torch.Tensor, *, dtype=None, device=None) -> torch.Tensor:
    """Return the BALD score (entropy minus conditional entropy) of every sample.

    Args:
        log_probs_N_K_C: log-probabilities, unpacked as (N, K, C) by the entropy helpers.
        dtype, device: accepted for signature compatibility; not used here.

    Returns:
        A double tensor with N entries.
    """
    conditional_entropies_N = compute_conditional_entropy(log_probs_N_K_C)
    entropies_N = compute_entropy(log_probs_N_K_C)

    return entropies_N - conditional_entropies_N

Determining a BALD batch is then straightforward given the scores:

### Finding a BALD batch 

In [None]:
# exports


def get_top_k_scorers(scores_N: torch.Tensor, *, batch_size: int) -> CandidateBatch:
    """Select the `batch_size` highest-scoring entries (clamped to the number of scores)."""
    k = min(batch_size, len(scores_N))
    top = torch.topk(scores_N, k)

    return CandidateBatch(top.values.tolist(), top.indices.tolist())


def get_bald_batch(log_probs_N_K_C: torch.Tensor, *, batch_size: int, dtype=None, device=None) -> CandidateBatch:
    """Score every sample with `get_bald_scores` and return the top `batch_size` scorers."""
    bald_scores_N = get_bald_scores(log_probs_N_K_C, dtype=dtype, device=device)
    return get_top_k_scorers(bald_scores_N, batch_size=batch_size)

### Example

In [None]:
get_bald_batch(ys_ws.log().double(), batch_size=4)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

CandidateBatch(scores=[0.030715639666234917, 0.030715639666234917, 0.030715639666234917, 0.030715639666234695], indices=[1, 2, 0, 3])

## EPIG-BALD

The computation for EPIG-BALD is simple. We need to keep track of two separate (Batch)BALD terms:

$$\mathrm{I}\left[(y)_{B} ; \omega \mid(x)_{B}, D_{T}\right]-\mathrm{I}\left[(y)_{B} ; \omega \mid(x)_{B}, D_{U} \cup D_{T}\right].$$


In [None]:
# exports


def get_batch_eval_bald_batch(
    training_log_probs_N_K_C: torch.Tensor,
    eval_log_probs_N_K_C: torch.Tensor,
    *,
    batch_size: int,
    num_samples: int,
    dtype=None,
    device=None,
) -> CandidateBatch:
    """Greedily select a batch by the difference of two `BatchBALDScorer` scores.

    One scorer is built from `training_log_probs_N_K_C`, the other from
    `eval_log_probs_N_K_C`. Each round picks the candidate maximising the training score
    minus the eval score, and the winner is appended to both scorers.

    Args:
        training_log_probs_N_K_C, eval_log_probs_N_K_C: log-probabilities of identical shape,
            unpacked as (N, K, C).
        batch_size: number of elements to select; clamped to N.
        num_samples, dtype, device: forwarded to both scorers.

    Returns:
        The selected indices and score differences, in selection order.
    """
    assert training_log_probs_N_K_C.shape == eval_log_probs_N_K_C.shape
    num_points, _, _ = training_log_probs_N_K_C.shape

    batch_size = min(batch_size, num_points)

    picked_indices = []
    picked_scores = []
    if batch_size == 0:
        return CandidateBatch(picked_scores, picked_indices)

    scorer_options = dict(max_size=batch_size - 1, num_samples=num_samples, dtype=dtype, device=device)
    training_scorer = BatchBALDScorer(training_log_probs_N_K_C, **scorer_options)
    eval_scorer = BatchBALDScorer(eval_log_probs_N_K_C, **scorer_options)

    # We always keep these on the CPU.
    training_scores_N = training_scorer.alloc_scores()
    eval_scores_N = eval_scorer.alloc_scores()

    for step in with_progress_bar(range(batch_size), tqdm_args=dict(desc="BatchBALD", leave=False)):
        if step > 0:
            newest = picked_indices[-1]
            training_scorer.append_to_batch(newest)
            eval_scorer.append_to_batch(newest)

        training_scorer.compute_scores(training_scores_N)
        eval_scorer.compute_scores(eval_scores_N)

        difference_N = training_scores_N - eval_scores_N
        # Never select the same element twice.
        difference_N[picked_indices] = -float("inf")

        best_score, best_index = difference_N.max(dim=0)
        picked_indices.append(best_index.item())
        picked_scores.append(best_score.item())

    return CandidateBatch(picked_scores, picked_indices)

### Example

#### Pleasing example of the case when predictions match (full overlap)

In [None]:
get_batch_eval_bald_batch(
    ys_ws.log().double(), ys_ws.log().double(), batch_size=4, num_samples=1000, dtype=torch.double
)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

BatchBALD:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

CandidateBatch(scores=[0.0, 0.0, 0.0, 0.0], indices=[0, 1, 2, 3])

In [None]:
get_batch_eval_bald_batch(
    ys_ws.log().double(), torch.zeros_like(ys_ws).double(), batch_size=4, num_samples=1000, dtype=torch.double
)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

BatchBALD:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

CandidateBatch(scores=[0.030715639666234917, 0.05961958627158248, 0.0869107051474467, 0.11275304532467878], indices=[1, 0, 2, 3])

## ThompsonBALD

We compute the joint entropy as in BALD, but for the conditional entropy, we simply compute the entropy of a single $\omega$ sample and then pick the highest scorer, before we move on to another sample, and so on.

In [None]:
# exports


def compute_each_conditional_entropy(log_probs_N_K_C: torch.Tensor) -> torch.Tensor:
    """Compute the entropy of every individual inference sample, without averaging over K.

    Args:
        log_probs_N_K_C: log-probabilities, unpacked as (N, K, C).

    Returns:
        A CPU double tensor of shape (N, K): entry [i, k] is the entropy over the C axis
        of sample k for input i.
    """
    N, K, C = log_probs_N_K_C.shape

    entropies_N_K = torch.empty((N, K), dtype=torch.double)

    pbar = create_progress_bar(N, tqdm_args=dict(desc="Entropy", leave=False))
    pbar.start()

    @toma.execute.chunked(log_probs_N_K_C, 1024)
    def compute(log_probs_n_K_C, start: int, end: int):
        nats_n_K_C = log_probs_n_K_C * torch.exp(log_probs_n_K_C)
        # A log-probability of -inf gives -inf * 0 = NaN. Zero it, as the other entropy helpers
        # in this file do, so a zero-probability class does not turn the whole entropy into NaN.
        nats_n_K_C[torch.isnan(nats_n_K_C)] = 0.0

        entropies_N_K[start:end].copy_(-torch.sum(nats_n_K_C, dim=2))
        pbar.update(end - start)

    pbar.finish()

    return entropies_N_K


def get_thompson_bald_batch(
    log_probs_N_K_C: torch.Tensor, *, batch_size: int, dtype=None, device=None
) -> CandidateBatch:
    """Select a batch where acquisition step b is scored with the b-th inference sample only.

    The score of input i at step b is its entropy minus the entropy of its b-th inference
    sample. Each step takes the arg-max and then excludes that input from later steps.

    Args:
        log_probs_N_K_C: log-probabilities, unpacked as (N, K, C).
        batch_size: number of elements to select; clamped to N.
        dtype, device: accepted for signature compatibility; not used here.

    Returns:
        The selected indices and their scores, in selection order.

    Raises:
        AssertionError: if fewer than `min(batch_size, N)` inference samples are available.
    """
    N, K, C = log_probs_N_K_C.shape

    # Clamp before validating: only min(batch_size, N) inference samples are consumed below,
    # so an oversized request must not be rejected when K already covers all N inputs.
    batch_size = min(batch_size, N)
    assert K >= batch_size, "ThompsonBALD needs one inference sample per acquired element"

    entropy_N = compute_entropy(log_probs_N_K_C)
    all_conditional_entropies_N_K = compute_each_conditional_entropy(log_probs_N_K_C)

    thompson_bald_scores_N_K = entropy_N[:, None] - all_conditional_entropies_N_K

    candidate_scores, candidate_indices = [], []
    for b in range(batch_size):
        candidate_score, candidate_index = thompson_bald_scores_N_K[:, b].max(dim=0)

        candidate_scores.append(candidate_score.item())
        candidate_indices.append(candidate_index.item())

        # Mask the whole row so this input cannot win under any later inference sample.
        thompson_bald_scores_N_K[candidate_index] = -float("inf")

    return CandidateBatch(candidate_scores, candidate_indices)

In [None]:
get_thompson_bald_batch(ys_ws.log().double(), batch_size=4, dtype=torch.double)

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

CandidateBatch(scores=[0.29714917957723086, 0.25731026605631957, 0.21965481456439662, 0.18409109389668443], indices=[2, 0, 1, 3])

## RandomBALD Baseline

We take the top $C \times B$ and randomly pick $B$ candidates from that.

In [None]:
# exports


def get_top_random_scorers(scores_N: torch.Tensor, *, num_classes: int, batch_size: int) -> CandidateBatch:
    """Pick `batch_size` entries uniformly at random from the top `batch_size * num_classes` scorers.

    Both the batch size and the shortlist size are clamped to the number of scores.
    """
    num_points = len(scores_N)
    batch_size = min(batch_size, num_points)
    shortlist_size = min(batch_size * num_classes, num_points)

    shortlist = torch.topk(scores_N, shortlist_size)

    # A random permutation of shortlist positions; its first `batch_size` entries are kept.
    keep = torch.randperm(shortlist_size)[:batch_size]

    return CandidateBatch(shortlist.values[keep].tolist(), shortlist.indices[keep].tolist())


def get_random_bald_batch(log_probs_N_K_C: torch.Tensor, *, batch_size: int, dtype=None, device=None) -> CandidateBatch:
    """Compute BALD scores and sample the batch with `get_top_random_scorers`, using C as `num_classes`."""
    _, _, num_classes = log_probs_N_K_C.shape

    bald_scores_N = get_bald_scores(log_probs_N_K_C, dtype=dtype, device=device)
    return get_top_random_scorers(bald_scores_N, num_classes=num_classes, batch_size=batch_size)

In [None]:
get_random_bald_batch(ys_ws.log().double(), batch_size=4, dtype=torch.double)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

CandidateBatch(scores=[0.030715639666234917, 0.030715639666234917, 0.030715639666234695, 0.030715639666234917], indices=[1, 2, 3, 0])

## Additional EPIG-BALD variants

Instead of using BatchBALD, let's compute BALD directly and use either the top-k, TopRandom or Thompson sampling.

In [None]:
# exports


def get_eval_bald_scores(
    training_log_probs_N_K_C: torch.Tensor,
    eval_log_probs_N_K_C: torch.Tensor,
    *,
    dtype=None,
    device=None,
) -> torch.Tensor:
    """Return the BALD scores of the training predictions minus those of the eval predictions.

    Both inputs must have the same shape; `dtype` and `device` are passed on to
    `get_bald_scores`.
    """
    assert training_log_probs_N_K_C.shape == eval_log_probs_N_K_C.shape

    training_bald_N = get_bald_scores(training_log_probs_N_K_C, dtype=dtype, device=device)
    eval_bald_N = get_bald_scores(eval_log_probs_N_K_C, dtype=dtype, device=device)

    return training_bald_N - eval_bald_N


def get_eval_bald_batch(
    training_log_probs_N_K_C: torch.Tensor,
    pool_log_probs_N_K_C: torch.Tensor,
    *,
    batch_size: int,
    dtype=None,
    device=None,
) -> CandidateBatch:
    """Return the top `batch_size` scorers under `get_eval_bald_scores`."""
    score_differences_N = get_eval_bald_scores(
        training_log_probs_N_K_C, pool_log_probs_N_K_C, dtype=dtype, device=device
    )
    return get_top_k_scorers(score_differences_N, batch_size=batch_size)


def get_top_random_eval_bald_batch(
    training_log_probs_N_K_C: torch.Tensor,
    pool_log_probs_N_K_C: torch.Tensor,
    *,
    batch_size: int,
    num_classes: int,
    dtype=None,
    device=None,
) -> CandidateBatch:
    """Sample a batch with `get_top_random_scorers` from the `get_eval_bald_scores` scores."""
    score_differences_N = get_eval_bald_scores(
        training_log_probs_N_K_C, pool_log_probs_N_K_C, dtype=dtype, device=device
    )
    return get_top_random_scorers(score_differences_N, batch_size=batch_size, num_classes=num_classes)

In [None]:
get_eval_bald_batch(ys_ws.log().double(), ys_ws.log().double(), batch_size=4, dtype=torch.double)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

CandidateBatch(scores=[0.0, 0.0, 0.0, 0.0], indices=[2, 3, 0, 1])

In [None]:
get_top_random_eval_bald_batch(
    ys_ws.log().double(), ys_ws.log().double(), batch_size=4, num_classes=10, dtype=torch.double
)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

CandidateBatch(scores=[0.0, 0.0, 0.0, 0.0], indices=[3, 2, 1, 0])

## TemperedBALD

Use temperature-scaled BALD scores for importance sampling.

In [None]:
# exports


def get_sampled_tempered_scorers(scores_N: torch.Tensor, *, temperature: float, batch_size: int) -> CandidateBatch:
    """Sample a batch without replacement, with probability proportional to `score ** (1 / temperature)`.

    Negative scores get probability zero. `batch_size` is clamped to the number of scores.
    """
    num_points = len(scores_N)
    batch_size = min(batch_size, num_points)

    # If we exponentiate scores_N beforehand, we obtain a softmax function here.
    weights_N = scores_N ** (1 / temperature)
    weights_N[scores_N < 0] = 0.0
    probabilities_N = weights_N / weights_N.sum()

    # TODO: change this to use PyTorch instead of numpy?
    chosen = np.random.choice(num_points, size=batch_size, replace=False, p=probabilities_N.cpu().numpy())

    return CandidateBatch(scores_N[chosen].tolist(), chosen.tolist())

In [None]:
get_sampled_tempered_scorers(get_bald_scores(ys_ws.log().double()), temperature=10, batch_size=2)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

CandidateBatch(scores=[0.030715639666234695, 0.030715639666234917], indices=[3, 1])

## Stochastic Acquisition

Re-implementation for the final paper experiments (hopefully without bugs...)

In [None]:
# exports


def get_random_samples(scores_N: torch.Tensor, *, batch_size: int) -> CandidateBatch:
    """Pick `batch_size` indices uniformly at random without replacement.

    Only the length of `scores_N` is used; every returned score is 0.0.
    """
    num_points = len(scores_N)
    batch_size = min(batch_size, num_points)

    chosen = np.random.choice(num_points, size=batch_size, replace=False)
    return CandidateBatch([0.0] * batch_size, chosen.tolist())


def get_softmax_samples(scores_N: torch.Tensor, *, coldness: float, batch_size: int) -> CandidateBatch:
    """Perturb the scores with Gumbel noise of scale `1 / coldness` and take the top `batch_size`.

    The returned scores are the perturbed ones. A coldness of exactly 0 falls back to
    `get_random_samples`.
    """
    # As coldness -> 0, we obtain random sampling.
    if coldness == 0.0:
        return get_random_samples(scores_N, batch_size=batch_size)

    gumbel_noise_N = scipy.stats.gumbel_r.rvs(loc=0, scale=1 / coldness, size=len(scores_N), random_state=None)

    return get_top_k_scorers(scores_N + gumbel_noise_N, batch_size=batch_size)


def get_power_samples(scores_N: torch.Tensor, *, coldness: float, batch_size: int) -> CandidateBatch:
    """Apply `get_softmax_samples` to the logarithm of the scores."""
    log_scores_N = torch.log(scores_N)
    return get_softmax_samples(log_scores_N, coldness=coldness, batch_size=batch_size)


def get_softrank_samples(scores_N: torch.Tensor, *, coldness: float, batch_size: int) -> CandidateBatch:
    """Apply `get_power_samples` to the reciprocal rank of each score (highest score has rank 1)."""
    # argsort of the descending argsort gives each entry's 0-based rank; shift it to 1-based.
    ranks_N = torch.argsort(torch.argsort(scores_N, descending=True)) + 1
    reciprocal_ranks_N = 1 / ranks_N

    return get_power_samples(reciprocal_ranks_N, coldness=coldness, batch_size=batch_size)


def get_simulation_samples(scores_N: torch.Tensor, *, coldness: float, batch_size: int) -> CandidateBatch:
    """Select a batch by repeatedly taking the arg-max of progressively shrunk scores.

    The first pick is the arg-max of the unmodified scores. After every pick, all scores
    are divided by `1 + noise`, where the noise is exponential with scale `1 / coldness`,
    and the picked entry is set to -inf. A coldness of exactly 0 falls back to
    `get_random_samples`.
    """
    # As coldness -> 0, we obtain random sampling.
    if coldness == 0.0:
        return get_random_samples(scores_N, batch_size=batch_size)

    num_points = len(scores_N)
    batch_size = min(batch_size, num_points)

    picked_indices = []
    picked_scores = []
    remaining_scores_N = scores_N.clone()
    for _ in range(batch_size):
        best_score, best_index = torch.max(remaining_scores_N, dim=0)
        picked_scores.append(best_score.item())
        picked_indices.append(best_index.item())

        shrink_N = 1 + scipy.stats.expon.rvs(loc=0, scale=1 / coldness, size=num_points, random_state=None)
        remaining_scores_N = remaining_scores_N / shrink_N
        # Never select the same element twice.
        remaining_scores_N[best_index] = float("-inf")

    return CandidateBatch(scores=picked_scores, indices=picked_indices)

In [None]:
assert get_power_samples(torch.as_tensor([1, 0, 0]), coldness=1, batch_size=1).indices == [0]

In [None]:
np.unique(
    sum(
        (get_power_samples(torch.as_tensor([0.5, 0.25, 0.25]), coldness=1, batch_size=1).indices for _ in range(1000)),
        [],
    ),
    return_counts=True,
)[1] / 1000
# should be around [0.5, 0.25, 0.25]

array([0.501, 0.266, 0.233])

In [None]:
np.unique(
    sum(
        (
            get_softmax_samples(torch.as_tensor([np.log(2), 0, 0]), coldness=1, batch_size=1).indices
            for _ in range(1000)
        ),
        [],
    ),
    return_counts=True,
)[1] / 1000
# should be around [0.5, 0.25, 0.25]

array([0.487, 0.268, 0.245])

In [None]:
np.unique(
    sum((get_softrank_samples(torch.as_tensor([3, 2, 1]), coldness=1, batch_size=1).indices for _ in range(1000)), []),
    return_counts=True,
)[1] / 1000 * 11
# p = [1, 1/2, 1/3] / ((6+3+2)/6) = [1, 1/2, 1/3] * 6 / 11 = [6/11, 3/11, 2/11]

array([6.061, 2.959, 1.98 ])

In [None]:
np.unique(
    sum((get_softrank_samples(torch.as_tensor([3, 2, 1]), coldness=8, batch_size=1).indices for _ in range(1000)), []),
    return_counts=True,
)[1] / 1000 * 11

array([10.967,  0.022,  0.011])

In [None]:
np.unique(
    sum((get_softmax_samples(torch.as_tensor([3, 2, 1]), coldness=0, batch_size=1).indices for _ in range(1000)), []),
    return_counts=True,
)[1] / 1000

array([0.353, 0.313, 0.334])

In [None]:
np.unique(
    sum((get_softmax_samples(torch.as_tensor([3, 2, 1]), coldness=0, batch_size=1).indices for _ in range(1000)), []),
    return_counts=True,
)[1] / 1000

In [None]:
# Empirical selection frequencies of simulation sampling with batch_size=2. The first pick is
# always the arg-max, so index 0 accounts for exactly half of all selections.
# (`get_hypothesis_samples` is not defined in this notebook; the sampler is `get_simulation_samples`.)
np.unique(
    sum((get_simulation_samples(torch.as_tensor([3, 2, 1]), coldness=1, batch_size=2).indices for _ in range(1000)), []),
    return_counts=True,
)[1] / 1000 / 2

array([0.5  , 0.442, 0.058])

In [None]:
# exports


class StochasticMode(Enum):
    """Selects which stochastic sampler `get_stochastic_samples` dispatches to."""

    Power = "Power"
    Softmax = "Softmax"
    Softrank = "Softrank"
    Simulation = "Simulation"


def get_stochastic_samples(
    scores_N: torch.Tensor, *, coldness: float, batch_size: int, mode: StochasticMode
) -> CandidateBatch:
    """Dispatch to the stochastic sampler selected by `mode`.

    Args:
        scores_N: scores, passed through to the selected sampler.
        coldness: passed through to the selected sampler.
        batch_size: passed through to the selected sampler.
        mode: which of the four samplers to use.

    Returns:
        Whatever the selected sampler returns.

    Raises:
        ValueError: if `mode` is not a `StochasticMode` member.
    """
    if mode == StochasticMode.Power:
        return get_power_samples(scores_N, coldness=coldness, batch_size=batch_size)
    elif mode == StochasticMode.Softmax:
        return get_softmax_samples(scores_N, coldness=coldness, batch_size=batch_size)
    elif mode == StochasticMode.Softrank:
        return get_softrank_samples(scores_N, coldness=coldness, batch_size=batch_size)
    elif mode == StochasticMode.Simulation:
        return get_simulation_samples(scores_N, coldness=coldness, batch_size=batch_size)
    else:
        # Raise instead of returning the exception object, which callers would
        # otherwise mistake for a CandidateBatch.
        raise ValueError(f"Unknown mode: {mode!r}")

In [None]:
[get_stochastic_samples(torch.randn(size=(4,)).exp_(), coldness=1, batch_size=2, mode=mode) for mode in StochasticMode]

[CandidateBatch(scores=[2.982371351302047, 1.8058588693740623], indices=[3, 0]),
 CandidateBatch(scores=[4.119388972071258, 2.457687553278123], indices=[0, 1]),
 CandidateBatch(scores=[0.47892253155902625, -0.7071420806250666], indices=[2, 0]),
 CandidateBatch(scores=[4.90777063369751, 0.32464359644716917], indices=[0, 2])]

## EPIG

As part of an ablation (and to see how it performs), we can also compute the ICAL score.

In [None]:
# exports


def get_eig_scores(
    training_log_probs_N_K_C: torch.Tensor,
    pool_log_probs_N_K_C: torch.Tensor,
    *,
    dtype=None,
    device=None,
) -> torch.Tensor:
    """Return the entropy of the training predictions minus the entropy of the pool predictions.

    Both inputs must have the same shape. `dtype` and `device` are accepted for
    signature compatibility and are not used here.
    """
    assert training_log_probs_N_K_C.shape == pool_log_probs_N_K_C.shape

    training_entropies_N = compute_entropy(training_log_probs_N_K_C)
    pool_entropies_N = compute_entropy(pool_log_probs_N_K_C)

    return training_entropies_N - pool_entropies_N

In [None]:
get_eig_scores(ys_ws.log().double(), ys_ws.log().double())

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

tensor([0., 0., 0., 0.], dtype=torch.float64)

In [None]:
# exports


# TODO: refactor the BatchBALDScorer to deduplicate some of this?
def get_batch_eig_batch(
    training_log_probs_N_K_C: torch.Tensor,
    pool_log_probs_N_K_C: torch.Tensor,
    *,
    batch_size: int,
    num_samples: int,
    dtype=None,
    device=None,
) -> CandidateBatch:
    """Greedily pick a batch whose score is the training-side minus the pool-side joint entropy.

    Two `joint_entropy.DynamicJointEntropy` helpers are kept in lockstep, one fed from
    `training_log_probs_N_K_C` and one from `pool_log_probs_N_K_C`. In every round both helpers
    score all N candidates, the difference of the two score vectors is maximised, and the winner
    is added to both helpers before the next round.

    Args:
        training_log_probs_N_K_C: Tensor unpacked as (N, K, C).
        pool_log_probs_N_K_C: Tensor of the same shape as `training_log_probs_N_K_C`.
        batch_size: Number of candidates to pick; clamped to N.
        num_samples: Forwarded to `DynamicJointEntropy`
            (presumably its sampling budget — confirm in `joint_entropy`).
        dtype: Forwarded to `DynamicJointEntropy`.
        device: Forwarded to `DynamicJointEntropy`.

    Returns:
        `CandidateBatch(scores, indices)` in pick order; both lists are empty if nothing is picked.
    """
    assert training_log_probs_N_K_C.shape == pool_log_probs_N_K_C.shape
    N, K, C = training_log_probs_N_K_C.shape

    # Never try to pick more candidates than exist.
    batch_size = min(batch_size, N)

    candidate_indices = []
    candidate_scores = []

    if batch_size == 0:
        return CandidateBatch(candidate_scores, candidate_indices)

    training_joint_entropy = joint_entropy.DynamicJointEntropy(
        num_samples, batch_size, K, C, dtype=dtype, device=device
    )

    pool_joint_entropy = joint_entropy.DynamicJointEntropy(num_samples, batch_size, K, C, dtype=dtype, device=device)

    # We always keep these on the CPU.
    # They are allocated once and reused as output buffers in every round; pinned when CUDA is available.
    training_scores_N = torch.empty(
        N,
        dtype=torch.double,
        pin_memory=torch.cuda.is_available(),
    )

    pool_scores_N = torch.empty(
        N,
        dtype=torch.double,
        pin_memory=torch.cuda.is_available(),
    )

    # NOTE(review): the progress bar is labelled "BatchBALD" although this is the EIG variant.
    for i in with_progress_bar(range(batch_size), tqdm_args=dict(desc="BatchBALD", leave=False)):
        if i > 0:
            # Feed the previous round's winner to both helpers (kept as a length-1 slice, not a bare row).
            latest_index = candidate_indices[-1]
            training_joint_entropy.add_variables(training_log_probs_N_K_C[latest_index : latest_index + 1])
            pool_joint_entropy.add_variables(pool_log_probs_N_K_C[latest_index : latest_index + 1])

        # Score all N candidates; results are written into the preallocated buffers.
        training_joint_entropy.compute_batch(training_log_probs_N_K_C, output_entropies_B=training_scores_N)
        pool_joint_entropy.compute_batch(pool_log_probs_N_K_C, output_entropies_B=pool_scores_N)

        scores_N = training_scores_N - pool_scores_N
        # Mask everything already picked so it cannot win again.
        scores_N[candidate_indices] = -float("inf")

        candidate_score, candidate_index = scores_N.max(dim=0)

        candidate_indices.append(candidate_index.item())
        candidate_scores.append(candidate_score.item())

    return CandidateBatch(candidate_scores, candidate_indices)

In [None]:
get_batch_eig_batch(ys_ws.log().double(), ys_ws.log().double(), batch_size=4, num_samples=1000, dtype=torch.double)

BatchBALD:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

CandidateBatch(scores=[0.0, 0.0, 0.0, 0.0], indices=[0, 1, 2, 3])

## Information Gain

(Following the new notation.)

Instead of computing $I[Y;\omega|x]$, we use our knowledge of the labels and compute: $$I[y;\omega|x]= H[y|x] - \mathbb{E}_{p(\omega|y,x)} H[y|x,\omega].$$

In [None]:
# exports


def get_coreset_bald_scores_from_predictions(
    log_probs_N_K_C: torch.Tensor, target_probs_N_C: torch.Tensor, *, dtype=None, device=None
) -> torch.Tensor:
    """Per-point information gain, averaged over classes with the weights in `target_probs_N_C`.

    For every point and class c the gain is

        -log p(c) - E_{p(w|c)}[-log p(c|w)],   with p(c) = mean_k p(c|w_k),

    where the posterior expectation is computed by re-weighting each of the K samples with
    p(c|w_k) / p(c). The per-class gains are then summed over classes, weighted by
    `target_probs_N_C`.

    Args:
        log_probs_N_K_C: Log-probabilities, shape (N, K, C).
        target_probs_N_C: Class weights, shape (N, C).
        dtype: dtype both inputs are converted to.
        device: device both inputs are moved to.

    Returns:
        Tensor of shape (N,).
    """
    num_points, num_samples, num_classes = log_probs_N_K_C.shape
    assert target_probs_N_C.shape == (num_points, num_classes)

    log_p_N_K_C = log_probs_N_K_C.to(dtype=dtype, device=device)
    class_weights_N_C = target_probs_N_C.to(dtype=dtype, device=device)

    # log of the sample-averaged probability: logsumexp over K minus log K.
    log_marginal_N_C = log_p_N_K_C.logsumexp(dim=1) - np.log(num_samples)

    # mean_k[p log p] / p(c) is the posterior-weighted expectation of log p(c|w).
    mean_p_log_p_N_C = (log_p_N_K_C * log_p_N_K_C.exp()).mean(dim=1)
    posterior_surprise_N_C = -mean_p_log_p_N_C / log_marginal_N_C.exp()

    gain_N_C = -log_marginal_N_C - posterior_surprise_N_C

    return (class_weights_N_C * gain_N_C).sum(dim=1)


def get_coreset_bald_scores(
    log_probs_N_K_C: torch.Tensor, labels_N: torch.Tensor, *, dtype=None, device=None
) -> torch.Tensor:
    """Information gain of each point's given label.

    For every point the log-probability of its label is selected from each of the K samples, and
    the score is

        -log p(y) - E_{p(w|y)}[-log p(y|w)],   with p(y) = mean_k p(y|w_k).

    Args:
        log_probs_N_K_C: Log-probabilities, shape (N, K, C).
        labels_N: Class index per point, used to index the class dimension.
        dtype: dtype the selected log-probabilities are converted to.
        device: device the selected log-probabilities are moved to.

    Returns:
        Tensor with one score per point.
    """
    _, num_samples, _ = log_probs_N_K_C.shape

    label_index_N_1_1 = labels_N[:, None, None]
    selected_N_K_1 = joint_entropy.gather_expand(log_probs_N_K_C, dim=2, index=label_index_N_1_1)
    label_log_probs_N_K = selected_N_K_1.squeeze(2).to(dtype=dtype, device=device)

    # log of the sample-averaged label probability.
    log_marginal_N = label_log_probs_N_K.logsumexp(dim=1) - np.log(num_samples)

    surprise_N = -log_marginal_N
    mean_p_log_p_N = (label_log_probs_N_K * label_log_probs_N_K.exp()).mean(dim=1)
    posterior_surprise_N = -mean_p_log_p_N / log_marginal_N.exp()

    return surprise_N - posterior_surprise_N

In [None]:
get_coreset_bald_scores(ys_ws.log().double(), torch.tensor([0, 1, 2, 3])), [
    get_coreset_bald_scores(ys_ws.log().double(), torch.tensor([i, i, i, i])) for i in range(3)
]

(tensor([0.0300, 0.0300, 0.0300, 0.0300], dtype=torch.float64),
 [tensor([0.0300, 0.0207, 0.0207, 0.0474], dtype=torch.float64),
  tensor([0.0474, 0.0300, 0.0207, 0.0207], dtype=torch.float64),
  tensor([0.0207, 0.0474, 0.0300, 0.0207], dtype=torch.float64)])

## Batch Information Gain

(Following the new notation.)

The batch version of this acquisition function can be computed more easily:

$$ \operatorname{I}[(y)_B;\omega|(x)_B] = \operatorname{H}[(y)_B|(x)_B] - \mathbb{E}_{p(\omega|(y)_B, (x)_B)} \operatorname{H}[(y)_B|(x)_B, \omega], $$

where $p(\omega|(y)_B, (x)_B) = \frac{ p((y)_B| (x)_B, \omega) p(\omega) }{ p((y)_B| (x)_B) }$ as usual, and we make use of the independence of the $(y)_B$ given $\omega$.

We can make this efficient for computing scores in parallel by using:
$$p((y)_B|(x)_B, \omega) = p(y_B|x_B, \omega) \; p((y)_{B-1}|(x)_{B-1}, \omega).$$

### Do we have sub-modularity?

Unclear.

In [None]:
# exports


def get_batch_coreset_bald_batch(
    log_probs_N_K_C: torch.Tensor, labels_N: torch.Tensor, *, batch_size: int, dtype=None, device=None
) -> CandidateBatch:
    """Greedily build a batch that maximises the joint information gain of the given labels.

    The per-sample log-likelihood of the labels picked so far is carried from round to round. In
    each round it is extended by every candidate in turn, the joint information gain of the
    extended batch is computed for all N candidates at once, and the best candidate not yet
    picked is added.

    The scores are computed with sums over K instead of means; the `log(K)` terms that this
    leaves behind are folded in analytically (see the commented-out lines, which show the
    mean-based form each statement was derived from).

    Args:
        log_probs_N_K_C: Log-probabilities, shape (N, K, C).
        labels_N: Class index per point, used to index the class dimension.
        batch_size: Number of candidates to pick; clamped to N.
        dtype: dtype used for the selected log-probabilities and the running joint.
        device: device used for the selected log-probabilities and the running joint.

    Returns:
        `CandidateBatch(scores, indices)` in pick order; both lists are empty if nothing is picked.
    """
    # We want to compute (note this does not follow the notation from below):
    # CoreSetBALD = H[y_1, ..., y_n ] - E_p(w) p(y_1, ..., y_n | w) / p(y_1, ..., y_n) H[y_1, ..., y_n | w]
    # H[y_1, ..., y_n | w] = H[y_1, ..., y_{n-1} | w] + H[y_n | w] because y_i _||_ y_j | w
    N, K, C = log_probs_N_K_C.shape

    batch_size = min(batch_size, N)

    candidate_indices = []
    candidate_scores = []

    if batch_size == 0:
        return CandidateBatch(candidate_scores, candidate_indices)

    # Select, for every point and sample, the log-probability of that point's label.
    labels_N_1_1 = labels_N[:, None, None]
    log_probs_N_K = (
        joint_entropy.gather_expand(log_probs_N_K_C, dim=2, index=labels_N_1_1)
        .squeeze(2)
        .to(dtype=dtype, device=device)
    )

    # p((y)_{B-1}|(x)_{B-1}, \omega)
    # Starts at log 1 = 0 because the batch is still empty.
    log_probs_conditional_joint_batch_K = torch.zeros_like(log_probs_N_K[0], dtype=dtype, device=device)

    for i in with_progress_bar(range(batch_size), tqdm_args=dict(desc="BatchCoreSetBALD", leave=False)):
        # p((y)_B|(x)_B, \omega) = p(y_B|x_B, \omega) * p((y)_{B-1}|(x)_{B-1}, \omega)
        log_prob_conditional_joint_N_K = log_probs_N_K + log_probs_conditional_joint_batch_K[None, :]

        # Marginalize over w (but using sum not mean):
        # p((y)_B|(x)_B) = E_p(\omega) p((y)_B|(x)_B, \omega)
        # log_prob_joint_N_1 = log_prob_conditional_joint_N_K.logsumexp(dim=1, keepdim=True) - np.log(K)
        log_prob_joint_pK_N_1 = log_prob_conditional_joint_N_K.logsumexp(dim=1, keepdim=True)
        # \frac{ p((y)_B| (x)_B, \omega) }{ p((y)_B| (x)_B) }
        # log_ratio_N_K = log_prob_conditional_joint_N_K - log_prob_joint_N_1
        # log_ratio_N_K = log_prob_conditional_joint_N_K - log_prob_joint_pK_N_1 + np.log(K)
        log_ratio_mK_N_K = log_prob_conditional_joint_N_K - log_prob_joint_pK_N_1
        # conditional_entropy_joint_N = -torch.mean(log_ratio_N_K.exp() * log_prob_conditional_joint_N_K, dim=1)
        # conditional_entropy_joint_N =
        #       -torch.mean((log_ratio_mK_N_K + np.log(K)).exp() * log_prob_conditional_joint_N_K, dim=1)
        # The missing factor K in the ratio cancels against the 1/K of the mean, hence a plain sum.
        conditional_entropy_joint_N = -torch.sum(log_ratio_mK_N_K.exp() * log_prob_conditional_joint_N_K, dim=1)
        # entropy_joint_N = -log_prob_joint_N_1.squeeze(1)
        # entropy_joint_N = -(log_prob_joint_pK_N_1 - np.log(K)).squeeze(1)
        entropy_joint_N = -log_prob_joint_pK_N_1.squeeze(1) + np.log(K)
        scores_N = entropy_joint_N - conditional_entropy_joint_N

        # Select candidate
        # Already-picked points are masked so they cannot be selected twice.
        scores_N[candidate_indices] = -float("inf")

        candidate_score, candidate_index = scores_N.max(dim=0)

        candidate_indices.append(candidate_index.item())
        candidate_scores.append(candidate_score.item())

        # Update log_probs_conditional_joint_batch_K
        # The winner's extended joint becomes the running joint for the next round.
        log_probs_conditional_joint_batch_K = log_prob_conditional_joint_N_K[candidate_index]

    return CandidateBatch(candidate_scores, candidate_indices)

In [None]:
def get_batch_coreset_bald_batch_simpler(
    log_probs_N_K_C: torch.Tensor, labels_N: torch.Tensor, *, batch_size: int, dtype=None, device=None
) -> CandidateBatch:
    """Greedy batch selection by joint information gain of the given labels (mean-based form).

    Each round scores every candidate by

        -log p((y)_B) - E_{p(w|(y)_B)}[-log p((y)_B | w)],

    where (y)_B is the batch picked so far extended by that candidate, p((y)_B | w) factorises
    over the batch, and p((y)_B) is the average of p((y)_B | w_k) over the K samples.

    Args:
        log_probs_N_K_C: Log-probabilities, shape (N, K, C).
        labels_N: Class index per point, used to index the class dimension.
        batch_size: Number of candidates to pick; clamped to N.
        dtype: dtype used for the selected log-probabilities and the running joint.
        device: device used for the selected log-probabilities and the running joint.

    Returns:
        `CandidateBatch(scores, indices)` in pick order; both lists are empty if nothing is picked.
    """
    num_points, num_samples, _ = log_probs_N_K_C.shape
    batch_size = min(batch_size, num_points)

    picked_indices = []
    picked_scores = []
    if batch_size == 0:
        return CandidateBatch(picked_scores, picked_indices)

    # Log-probability of each point's own label under every sample.
    selected_N_K_1 = joint_entropy.gather_expand(log_probs_N_K_C, dim=2, index=labels_N[:, None, None])
    label_log_probs_N_K = selected_N_K_1.squeeze(2).to(dtype=dtype, device=device)

    # log p((y)_{B-1} | w) for the batch so far; log 1 = 0 while the batch is empty.
    batch_log_lik_K = torch.zeros_like(label_log_probs_N_K[0], dtype=dtype, device=device)

    for _ in with_progress_bar(range(batch_size), tqdm_args=dict(desc="BatchCoreSetBALD", leave=False)):
        # log p((y)_B | w) for the batch extended by each candidate in turn.
        extended_log_lik_N_K = label_log_probs_N_K + batch_log_lik_K[None, :]

        # log p((y)_B): average over the K samples, in log space.
        log_marginal_N_1 = extended_log_lik_N_K.logsumexp(dim=1, keepdim=True) - np.log(num_samples)

        # p((y)_B | w) / p((y)_B) re-weights the samples towards the posterior given the labels.
        posterior_weight_N_K = (extended_log_lik_N_K - log_marginal_N_1).exp()
        expected_surprise_N = -(posterior_weight_N_K * extended_log_lik_N_K).mean(dim=1)

        scores_N = -log_marginal_N_1.squeeze(1) - expected_surprise_N

        # Points that are already in the batch must not win again.
        scores_N[picked_indices] = -float("inf")
        best_score, best_index = scores_N.max(dim=0)

        picked_indices.append(best_index.item())
        picked_scores.append(best_score.item())

        # Carry the winner's extended likelihood into the next round.
        batch_log_lik_K = extended_log_lik_N_K[best_index]

    return CandidateBatch(picked_scores, picked_indices)

In [None]:
get_coreset_bald_scores(ys_ws.log().double(), torch.tensor([0, 1, 2, 3])).numpy(), [
    get_coreset_bald_scores(ys_ws.log().double(), torch.tensor([i, i, i, i])) for i in range(3)
]

(array([0.03002132, 0.03002132, 0.03002132, 0.03002132]),
 [tensor([0.0300, 0.0207, 0.0207, 0.0474], dtype=torch.float64),
  tensor([0.0474, 0.0300, 0.0207, 0.0207], dtype=torch.float64),
  tensor([0.0207, 0.0474, 0.0300, 0.0207], dtype=torch.float64)])

In [None]:
ys_ws.shape

torch.Size([4, 20, 4])

In [None]:
get_batch_coreset_bald_batch(ys_ws.log().double(), torch.tensor([0, 1, 2, 3]), batch_size=4, dtype=torch.double)

BatchCoreSetBALD:   0%|          | 0/4 [00:00<?, ?it/s]

CandidateBatch(scores=[0.030021323375763576, 0.10871562110954991, 0.2168431672489275, 0.3375447132429139], indices=[0, 1, 2, 3])

In [None]:
get_batch_coreset_bald_batch_simpler(ys_ws.log().double(), torch.tensor([0, 1, 2, 3]), batch_size=4, dtype=torch.double)

BatchCoreSetBALD:   0%|          | 0/4 [00:00<?, ?it/s]

CandidateBatch(scores=[0.030021323375763687, 0.10871562110954991, 0.2168431672489275, 0.3375447132429139], indices=[0, 1, 2, 3])

## CoreSet-PIG & Coreset-PIG-BALD

Combining EIG with CoreSets to use $I[y_{eval}; y_{batch} | x_{eval}; x_{batch}, D_{train}]$.

This is really easy to compute as $H[y_{batch} | x_{batch}, D_{train}] - H[y_{batch} | y_{eval}, x_{eval}; x_{batch}, D_{train}]$.

In [None]:
# exports


def get_coreset_eig_scores(
    *,
    training_log_probs_N_K_C: torch.Tensor,
    eval_log_probs_N_K_C: torch.Tensor,
    labels_N: torch.Tensor,
    dtype=None,
    device=None
) -> torch.Tensor:
    """Difference of label surprisals between the training-side and eval-side predictions.

    For each point the log of the sample-averaged probability of its label is computed once from
    `training_log_probs_N_K_C` and once from `eval_log_probs_N_K_C`; the score is

        -log p_training(y) + log p_eval(y).

    The number of samples K is taken from `training_log_probs_N_K_C` and used for both inputs.

    Args:
        training_log_probs_N_K_C: Log-probabilities, shape (N, K, C).
        eval_log_probs_N_K_C: Log-probabilities indexed with the same labels.
        labels_N: Class index per point, used to index the class dimension.
        dtype: dtype the selected log-probabilities are converted to.
        device: device the selected log-probabilities are moved to.

    Returns:
        Tensor with one score per point.
    """
    _, num_samples, _ = training_log_probs_N_K_C.shape
    label_index_N_1_1 = labels_N[:, None, None]

    def label_log_marginal_N(log_probs_N_K_C: torch.Tensor) -> torch.Tensor:
        # Pick each point's label log-probability per sample, then average over samples in log space.
        selected_N_K = (
            joint_entropy.gather_expand(log_probs_N_K_C, dim=2, index=label_index_N_1_1)
            .squeeze(2)
            .to(dtype=dtype, device=device)
        )
        return selected_N_K.logsumexp(dim=1) - np.log(num_samples)

    training_log_marginal_N = label_log_marginal_N(training_log_probs_N_K_C)
    eval_log_marginal_N = label_log_marginal_N(eval_log_probs_N_K_C)

    return -training_log_marginal_N + eval_log_marginal_N


def get_coreset_eig_bald_scores(
    *,
    training_log_probs_N_K_C: torch.Tensor,
    eval_log_probs_N_K_C: torch.Tensor,
    labels_N: torch.Tensor,
    dtype=None,
    device=None
) -> torch.Tensor:
    """Training-side minus eval-side `get_coreset_bald_scores`, using the same labels for both.

    Corresponds to I[y_eval; y_batch; W] = I[y_batch; W] - I[y_batch; W | y_eval], with the first
    term computed from `training_log_probs_N_K_C` and the second from `eval_log_probs_N_K_C`.

    Returns:
        Tensor with one score per point.
    """
    shared_kwargs = dict(labels_N=labels_N, dtype=dtype, device=device)

    training_gain_N = get_coreset_bald_scores(training_log_probs_N_K_C, **shared_kwargs)
    eval_gain_N = get_coreset_bald_scores(eval_log_probs_N_K_C, **shared_kwargs)

    return training_gain_N - eval_gain_N

In [None]:
get_coreset_eig_scores(
    training_log_probs_N_K_C=ys_ws.log().double(),
    eval_log_probs_N_K_C=ys_ws.log().double(),
    labels_N=torch.tensor([0, 1, 2, 3]),
    dtype=torch.double,
)

tensor([0., 0., 0., 0.], dtype=torch.float64)

In [None]:
get_coreset_eig_bald_scores(
    training_log_probs_N_K_C=ys_ws.log().double(),
    eval_log_probs_N_K_C=ys_ws.log().double(),
    labels_N=torch.tensor([0, 1, 2, 3]),
    dtype=torch.double,
)

tensor([0., 0., 0., 0.], dtype=torch.float64)

## SieveBALD

This is the 2-BALD approximation (leaving out $ D_{train}$):
$$I[Y_1, \ldots, Y_n;\Omega \mid x_1, \ldots,x_n] \approx \sum_i I[Y_i;\Omega\mid x_i] - \sum_{i<j} I[Y_i;Y_j \mid x_i,x_j].$$

See also https://www.notion.so/SieveBALD-using-a-marginal-total-correlation-assumption-and-or-by-forcing-it-2e4a9548d4124b6bb8e0dcbba789887a.

In [None]:
# exports


def get_sieve_bald_batch(log_probs_N_K_C: torch.Tensor, *, batch_size: int, dtype=None, device=None) -> CandidateBatch:
    """Greedy batch selection with the pairwise ("2-BALD") approximation.

    Every point starts with its entropy minus its conditional entropy. After each pick, every
    point's score is reduced by `H[Y_i] + H[Y_pick] - H[Y_i, Y_pick]`, using a joint-entropy helper
    that holds only the latest pick. The reported score of the j-th pick is the running sum of
    the (already reduced) scores of the first j picks.

    Args:
        log_probs_N_K_C: Log-probabilities, shape (N, K, C).
        batch_size: Number of candidates to pick; clamped to N.
        dtype: Forwarded to `ExactJointEntropy.empty`.
        device: Forwarded to `ExactJointEntropy.empty`.

    Returns:
        `CandidateBatch(scores, indices)` in pick order.
    """
    num_points, num_samples, _ = log_probs_N_K_C.shape
    batch_size = min(batch_size, num_points)

    picked_scores = []
    picked_indices = []

    marginal_entropies_N = compute_entropy(log_probs_N_K_C)

    # Initial scores: entropy minus conditional entropy for every point.
    remaining_scores_N = marginal_entropies_N - compute_conditional_entropy(log_probs_N_K_C)

    running_total = 0.0
    for _ in range(batch_size):
        # Plain arg-max; lazy greedy variants would also work here but are not implemented.
        best_score, best_index = remaining_scores_N.max(dim=0)

        # TODO: break here if candidate_score < 0 at this point!

        # Reported scores accumulate over the picks made so far.
        best_score += running_total
        running_total = best_score

        picked_indices.append(best_index.item())
        picked_scores.append(best_score.item())

        # -inf keeps the picked point out of every later arg-max.
        remaining_scores_N[best_index] = -float("inf")

        # Pairwise term between the new pick and every point: H[Y_i] + H[Y_pick] - H[Y_i, Y_pick].
        pair_helper = joint_entropy.ExactJointEntropy.empty(num_samples, device=device, dtype=dtype)
        pair_helper.add_variables(log_probs_N_K_C[best_index : best_index + 1])
        pair_entropies_N = pair_helper.compute_batch(log_probs_N_K_C)
        pairwise_mi_N = marginal_entropies_N + marginal_entropies_N[best_index] - pair_entropies_N

        remaining_scores_N -= pairwise_mi_N

    return CandidateBatch(picked_scores, picked_indices)

In [None]:
get_sieve_bald_batch(np.repeat(ys_ws, 2, axis=0).log().double(), batch_size=8, dtype=torch.double)

Entropy:   0%|          | 0/8 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/8 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/8 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/8 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/8 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/8 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/8 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/8 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/8 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/8 [00:00<?, ?it/s]

CandidateBatch(scores=[0.030715639666234917, 0.05961958627158248, 0.08671183981604269, 0.11199240029961555, 0.13546126772230105, 0.15711844208409942, 0.17696392338501088, 0.1949977116250352], indices=[0, 1, 2, 3, 4, 6, 7, 5])

In [None]:
ys_ws.shape

torch.Size([4, 20, 4])

In [None]:
get_batch_bald_batch(np.repeat(ys_ws, 2, axis=0).log().double(), batch_size=8, num_samples=1000000, dtype=torch.double)

Conditional Entropy:   0%|          | 0/8 [00:00<?, ?it/s]

BatchBALD:   0%|          | 0/8 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/8 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/8 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/8 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/8 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/8 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/8 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/8 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/8 [00:00<?, ?it/s]

CandidateBatch(scores=[0.030715639666234917, 0.05961958627158248, 0.0869107051474467, 0.11275304532467878, 0.1372853331853925, 0.16062670985153638, 0.18288066309757767, 0.2041378763246513], indices=[2, 0, 1, 4, 3, 5, 6, 7])

## Real EPIG

Implement $I[Y_{acq} ; Y_{eval} \mid x_{acq} ; X_{eval},  D_{train}]$.



In [None]:
def get_joint_probs_N_C_C_old(pool_probs_N_K_C: torch.Tensor, single_eval_probs_K_C: torch.Tensor):
    """Joint class probabilities of every pool point with one evaluation point.

    Entry [n, c, c'] is the average over the K samples of
    `pool_probs_N_K_C[n, k, c] * single_eval_probs_K_C[k, c']`.

    Both arguments hold probabilities, not log-probabilities.
    """
    num_samples = single_eval_probs_K_C.shape[0]

    # Move the sample axis last so the matmul contracts over it.
    pool_probs_N_C_K = pool_probs_N_K_C.transpose(1, 2)
    return torch.matmul(pool_probs_N_C_K, single_eval_probs_K_C) / num_samples


def get_real_naive_epig_scores_old(
    *, pool_log_probs_N_K_C: torch.Tensor, eval_log_probs_E_K_C: torch.Tensor, dtype=None, device=None
) -> torch.Tensor:
    """Implements naive EPIG: I[Y_acq; Y_eval | x_acq, X_eval].

    For every evaluation point the mutual information between its prediction and each pool
    point's prediction is computed from the joint class distribution (averaged over the K
    samples) and the two marginals; the result is averaged over the E evaluation points.

    Args:
        pool_log_probs_N_K_C: Log-probabilities of the pool points, shape (N, K, C).
        eval_log_probs_E_K_C: Log-probabilities of the evaluation points, shape (E, K, C);
            K and C must match the pool tensor.
        dtype: dtype used for the computation and the result.
        device: device used for the per-evaluation-point computation.

    Returns:
        CPU tensor of shape (N,) with the mean mutual information per pool point.
    """
    N, _, _ = pool_log_probs_N_K_C.shape
    E, _, _ = eval_log_probs_E_K_C.shape
    # Compare pool against eval (K and C must agree), and report the shapes actually seen.
    assert (
        pool_log_probs_N_K_C.shape[1:] == eval_log_probs_E_K_C.shape[1:]
    ), f"{pool_log_probs_N_K_C.shape[1:]} != {eval_log_probs_E_K_C.shape[1:]}"

    pool_probs_N_K_C = pool_log_probs_N_K_C.to(dtype=dtype, device=device).exp()
    eval_probs_E_K_C = eval_log_probs_E_K_C.to(dtype=dtype, device=device).exp()

    # Marginal predictive distribution of every pool point.
    pool_probs_N_C = torch.mean(pool_probs_N_K_C, dim=1, keepdim=False)

    # The accumulator lives on the CPU regardless of `device`.
    total_scores_N = torch.zeros((N,), dtype=dtype, device="cpu")
    for i_e in with_progress_bar(range(E), tqdm_args=dict(desc="Evaluation Set", leave=False)):
        single_eval_probs_K_C = eval_probs_E_K_C[i_e]

        joint_probs_N_C_C = get_joint_probs_N_C_C_old(pool_probs_N_K_C, single_eval_probs_K_C)

        # Marginal predictive distribution of this evaluation point.
        single_eval_probs_C = torch.mean(single_eval_probs_K_C, dim=0, keepdim=False)

        # log( p(y_acq, y_eval) / (p(y_acq) p(y_eval)) ), broadcast to (N, C, C).
        nats_N_C_C = (
            -torch.log(single_eval_probs_C)[None, None, :]
            - torch.log(pool_probs_N_C)[:, :, None]
            + torch.log(joint_probs_N_C_C)
        )

        weighted_nats_N_C_C = nats_N_C_C * joint_probs_N_C_C
        # Zero-probability cells give 0 * inf = NaN; they contribute nothing to the sum.
        weighted_nats_N_C_C[torch.isnan(weighted_nats_N_C_C)] = 0.0
        scores_N = weighted_nats_N_C_C.sum((1, 2), keepdim=False)

        total_scores_N += scores_N.to(device="cpu", non_blocking=True)

    total_scores_N /= E

    return total_scores_N

In [None]:
# exports


class BootstrapType(Enum):
    """Selects how the evaluation set is resampled when computing EPIG scores."""

    # Use every evaluation point exactly once.
    NO_BOOTSTRAP = 0
    # Draw a single resample (with replacement) of the evaluation set, shared by all pool points.
    SINGLE_BOOTSTRAP = 1
    # NOTE(review): presumably a separate resample per pool point — confirm against the scorers that accept it.
    PER_POINT_BOOTSTRAP = 2
    # NOTE(review): presumably a faster variant of the per-point bootstrap — confirm where it is consumed.
    FAST_PER_POINT_BOOTSTRAP = 3

In [None]:
# @torch.no_grad()
# def logmatmulexp(log_A: torch.Tensor, log_B: torch.Tensor) -> torch.Tensor:
#     """Given matrix log_A of shape (batch...) ϴ×R and matrix log_B of shape R×I, calculates
#     (log_A.exp() @ log_B.exp()).log() and its backward in a numerically stable way."""
#     batch_shape = list(log_A.shape[:-2])
#     ϴ, R = log_A.shape[-2:]
#     I = log_B.shape[-1]
#     assert log_B.shape == (R, I)
#     log_A_expanded = log_A.unsqueeze(-1).expand(batch_shape + [ϴ, R, I])
#     log_B_expanded = log_B.unsqueeze(-3).expand((ϴ, R, I))
#     log_pairwise_products = log_A_expanded + log_B_expanded  # shape: (ϴ, R, I)
#     return torch.logsumexp(log_pairwise_products, dim=-2)


@torch.no_grad()
def logmatmulexp(log_A: torch.Tensor, log_B: torch.Tensor) -> torch.Tensor:
    """Numerically stable `(log_A.exp() @ log_B.exp()).log()`.

    `log_A` has shape (batch...) ϴ×R and `log_B` has shape R×I. Each row of `log_A` and each
    column of `log_B` is shifted by its maximum before exponentiating, and the shifts are added
    back after the logarithm, so large negative log-values do not underflow to zero.

    Rows of `log_A` (or columns of `log_B`) that are entirely `-inf` — i.e. all-zero in linear
    space — yield `-inf` in the result rather than NaN.

    The function runs under `torch.no_grad()`, so no gradient is tracked through it.
    """
    neg_inf = float("-inf")

    max_A = log_A.max(dim=-1, keepdim=True).values
    max_B = log_B.max(dim=-2, keepdim=True).values

    # An all -inf slice has max -inf, and (-inf) - (-inf) is NaN. Shifting such a slice by 0
    # instead leaves it at exp(-inf) = 0, which produces the correct log(0) = -inf.
    max_A = torch.where(max_A == neg_inf, torch.zeros_like(max_A), max_A)
    max_B = torch.where(max_B == neg_inf, torch.zeros_like(max_B), max_B)

    shifted_product = (log_A - max_A).exp() @ (log_B - max_B).exp()
    return torch.log(shifted_product) + max_A + max_B


@torch.no_grad()
def get_real_naive_epig_scores_stable(
    *,
    bootstrap_type=BootstrapType.NO_BOOTSTRAP,
    bootstrap_factor=1.0,
    pool_log_probs_N_K_C: torch.Tensor,
    eval_log_probs_E_K_C: torch.Tensor,
    dtype=None,
    device=None,
) -> torch.Tensor:
    """Implements naive EPIG: I[Y_acq; Y_eval | x_acq, X_eval], working in log space.

    Uses
        I[Y_acq; Y_eval | x_acq, X_eval]
            = H[Y_acq | x_acq] + E_p(x_eval)[H[Y_eval | x_eval] - H[Y_acq, Y_eval | x_acq, x_eval]],
    where the joint entropy is averaged over the evaluation points that are visited.

    Args:
        bootstrap_type: `NO_BOOTSTRAP` visits every evaluation point once; `SINGLE_BOOTSTRAP`
            visits `int(E * bootstrap_factor)` points drawn uniformly with replacement.
        bootstrap_factor: Size of the bootstrap sample relative to E (only used for
            `SINGLE_BOOTSTRAP`).
        pool_log_probs_N_K_C: Log-probabilities of the pool points, shape (N, K, C).
        eval_log_probs_E_K_C: Log-probabilities of the evaluation points, shape (E, K, C);
            K and C must match the pool tensor.
        dtype: dtype used for the joint-entropy computation.
        device: device used for the joint-entropy computation.

    Returns:
        CPU tensor of shape (N,) with one score per pool point.

    Raises:
        NotImplementedError: For `PER_POINT_BOOTSTRAP`, which this variant does not implement.
        ValueError: For any other unsupported `bootstrap_type`.
    """
    N, K, _ = pool_log_probs_N_K_C.shape
    E, _, _ = eval_log_probs_E_K_C.shape
    # Compare pool against eval (K and C must agree), and report the shapes actually seen.
    assert (
        pool_log_probs_N_K_C.shape[1:] == eval_log_probs_E_K_C.shape[1:]
    ), f"{pool_log_probs_N_K_C.shape[1:]} != {eval_log_probs_E_K_C.shape[1:]}"

    pool_entropies_N = compute_entropy(pool_log_probs_N_K_C).to(device=device)

    total_joint_entropies_N = torch.zeros((N,), dtype=dtype, device=device)

    if bootstrap_type == BootstrapType.PER_POINT_BOOTSTRAP:
        # Fail explicitly instead of falling through to an unbound result variable.
        raise NotImplementedError(f"{bootstrap_type} is not implemented in get_real_naive_epig_scores_stable")

    # The label-uncertainty term is averaged over the full evaluation set, not the resample.
    eval_label_uncertainty = compute_entropy(eval_log_probs_E_K_C).mean(dim=0, keepdim=False)

    if bootstrap_type == BootstrapType.NO_BOOTSTRAP:
        num_eval_samples = E
        eval_range = range(E)
    elif bootstrap_type == BootstrapType.SINGLE_BOOTSTRAP:
        num_eval_samples = int(E * bootstrap_factor)
        # Uniform weights over the E evaluation points, sampled with replacement.
        eval_range = torch.multinomial(torch.tensor(1.0).expand(E), num_samples=num_eval_samples, replacement=True)
    else:
        raise ValueError(f"Unknown bootstrap {bootstrap_type}")

    # Sample axis last so the log-space matmul contracts over it.
    pool_log_probs_N_C_K = pool_log_probs_N_K_C.transpose(1, 2).contiguous().to(dtype=dtype, device=device)
    eval_log_probs_E_K_C = eval_log_probs_E_K_C.to(dtype=dtype, device=device)

    for i_e in with_progress_bar(eval_range, tqdm_args=dict(desc="Evaluation Set", leave=False)):
        single_eval_log_probs_K_C = eval_log_probs_E_K_C[i_e]

        # log of the sample-averaged joint p(y_acq, y_eval): log-sum over K, then minus log K.
        joint_log_probs_N_C_C = logmatmulexp(pool_log_probs_N_C_K, single_eval_log_probs_K_C) - np.log(K)
        weighted_nats_N_C_C = joint_log_probs_N_C_C * -torch.exp(joint_log_probs_N_C_C)
        # Zero-probability cells give NaN (0 * inf); they contribute nothing to the entropy.
        weighted_nats_N_C_C[torch.isnan(weighted_nats_N_C_C)] = 0.0
        joint_entropy_N = weighted_nats_N_C_C.sum((1, 2), keepdim=False)
        del weighted_nats_N_C_C

        total_joint_entropies_N += joint_entropy_N

    # Average over the points actually visited; this differs from E when bootstrap_factor != 1.
    total_scores_N = pool_entropies_N - total_joint_entropies_N / num_eval_samples + eval_label_uncertainty

    return total_scores_N.to(device="cpu", non_blocking=True)

In [None]:
# exports


def get_joint_probs_N_C_C(pool_probs_N_K_C: torch.Tensor, single_eval_probs_K_C: torch.Tensor):
    """Joint class probabilities of every pool point with one evaluation point.

    Entry [n, c, c'] is the average over the K samples of
    `pool_probs_N_K_C[n, k, c] * single_eval_probs_K_C[k, c']`.

    Both arguments hold probabilities, not log-probabilities.
    """
    num_samples = single_eval_probs_K_C.shape[0]

    # Sample axis last (and made contiguous) so the matmul contracts over it.
    pool_probs_N_C_K = pool_probs_N_K_C.transpose(1, 2).contiguous()
    return torch.matmul(pool_probs_N_C_K, single_eval_probs_K_C) / num_samples


def get_joint_probs_N_C_EC_transposed(pool_probs_N_C_K: torch.Tensor, eval_probs_E_K_C: torch.Tensor):
    """Joint class probabilities of every pool point with a whole chunk of evaluation points.

    The E evaluation points are laid out side by side along the last axis: columns
    `e * C : (e + 1) * C` of the result hold the (C, C) joint for evaluation point `e`, averaged
    over the K samples.

    Args:
        pool_probs_N_C_K: Pool probabilities with the sample axis last, shape (N, C, K).
        eval_probs_E_K_C: Evaluation probabilities, shape (E, K, C).

    Returns:
        Tensor of shape (N, C, E * C).
    """
    # Unpacking doubles as a rank check on the pool tensor.
    _pool_points, _pool_classes, _pool_samples = pool_probs_N_C_K.shape
    num_eval, num_samples, num_classes = eval_probs_E_K_C.shape

    # (E, K, C) -> (K, E, C) -> (K, E*C): one matmul then covers all evaluation points at once.
    eval_probs_K_EC = eval_probs_E_K_C.transpose(0, 1).reshape(num_samples, num_eval * num_classes)

    return torch.matmul(pool_probs_N_C_K, eval_probs_K_EC) / num_samples


def get_joint_probs_N_C_C_transposed(pool_probs_N_C_K: torch.Tensor, single_eval_probs_K_C: torch.Tensor):
    """Joint class probabilities of every pool point with one evaluation point.

    Same quantity as the non-transposed variant, but the pool tensor is expected with the sample
    axis already last, so no transpose is needed: the product contracts over K and is divided
    by K to average over the samples.
    """
    num_samples = single_eval_probs_K_C.shape[0]
    summed_over_samples_N_C_C = torch.matmul(pool_probs_N_C_K, single_eval_probs_K_C)
    return summed_over_samples_N_C_C / num_samples


@torch.no_grad()
def get_real_naive_epig_scores(
    *,
    bootstrap_type=BootstrapType.NO_BOOTSTRAP,
    bootstrap_factor=1.0,
    pool_log_probs_N_K_C: torch.Tensor,
    eval_log_probs_E_K_C: torch.Tensor,
    dtype=None,
    device=None,
) -> torch.Tensor:
    """Implements naive EPIG: I[Y_acq; Y_eval | x_acq, X_eval].

    For every pool point the score is estimated via

        I[Y_acq; Y_eval | x_acq, X_eval]
            = H[Y_acq | x_acq] + E_p(x_eval)[H[Y_eval | x_eval] - H[Y_acq, Y_eval | x_acq, x_eval]]

    where the expectation over eval points is an average over the evaluation
    set, or over a bootstrap resample of it.

    Args:
        bootstrap_type: How eval points are selected.
            NO_BOOTSTRAP uses every eval point once.
            SINGLE_BOOTSTRAP draws `int(E * bootstrap_factor)` eval indices with
            replacement once and shares them across all pool points.
            PER_POINT_BOOTSTRAP draws a fresh resample of that size for every
            pool point.
        bootstrap_factor: Size of the bootstrap resample relative to E.
        pool_log_probs_N_K_C: Log probabilities of the pool set, shape (N, K, C).
        eval_log_probs_E_K_C: Log probabilities of the evaluation set, shape (E, K, C).
        dtype: dtype used for the joint-probability computation (None keeps the input dtype).
        device: Device used for the joint-probability computation.

    Returns:
        Tensor with N scores, moved to the CPU.

    Raises:
        AssertionError: If the (K, C) dimensions of pool and eval set differ.
        ValueError: If `bootstrap_type` is not a known bootstrap type.
    """
    # Unpacking also checks that both inputs are 3-dimensional.
    N, _, _ = pool_log_probs_N_K_C.shape
    E, _, _ = eval_log_probs_E_K_C.shape
    # Pool and eval predictions must agree on K and C for the joint to make sense.
    assert (
        pool_log_probs_N_K_C.shape[1:] == eval_log_probs_E_K_C.shape[1:]
    ), f"{pool_log_probs_N_K_C.shape[1:]} != {eval_log_probs_E_K_C.shape[1:]}"

    pool_entropies_N = compute_entropy(pool_log_probs_N_K_C).to(device=device, non_blocking=True)

    if bootstrap_type != BootstrapType.PER_POINT_BOOTSTRAP:
        total_joint_entropies_N = torch.zeros((N,), dtype=dtype, device=device)

        pool_probs_N_C_K = (
            pool_log_probs_N_K_C.to(dtype=dtype, device=device, non_blocking=True).exp().transpose(1, 2).contiguous()
        )

        # Average eval entropy over the full evaluation set. This term is the same for
        # every pool point, so it only shifts the scores and does not change their order.
        eval_label_uncertainty = compute_entropy(eval_log_probs_E_K_C).mean(dim=0, keepdim=False)

        if bootstrap_type == BootstrapType.NO_BOOTSTRAP:
            num_eval_samples = E
            # Every eval point is used exactly once, so no gather/copy is needed.
            selected_eval_log_probs_F_K_C = eval_log_probs_E_K_C
        elif bootstrap_type == BootstrapType.SINGLE_BOOTSTRAP:
            num_eval_samples = int(E * bootstrap_factor)
            # Uniform sampling with replacement over the E eval points.
            eval_indices = torch.multinomial(
                torch.tensor(1.0).expand(E), num_samples=num_eval_samples, replacement=True
            )
            selected_eval_log_probs_F_K_C = eval_log_probs_E_K_C[eval_indices]
        else:
            raise ValueError(f"Unknown bootstrap {bootstrap_type}")

        pbar = create_progress_bar(num_eval_samples, tqdm_args=dict(desc="Evaluation Set", leave=False))
        pbar.start()

        @toma.execute.batch(1024)
        def loop(batchsize: int):
            nonlocal total_joint_entropies_N

            # toma calls `loop` again with a smaller batch size after an out-of-memory
            # failure, so both the progress bar and the accumulator have to start from
            # scratch; otherwise chunks from the aborted attempt would be counted twice.
            pbar.reset()
            total_joint_entropies_N.zero_()

            for chunked_eval_log_probs_e_K_C in selected_eval_log_probs_F_K_C.split(batchsize):
                chunked_eval_probs_e_K_C = chunked_eval_log_probs_e_K_C.to(
                    dtype=dtype, device=device, non_blocking=True
                ).exp()
                joint_probs_N_C_eC = get_joint_probs_N_C_EC_transposed(pool_probs_N_C_K, chunked_eval_probs_e_K_C)
                weighted_nats_N_C_eC = joint_probs_N_C_eC * -torch.log(joint_probs_N_C_eC)
                # 0 * -log(0) evaluates to NaN; such entries contribute no entropy.
                weighted_nats_N_C_eC[torch.isnan(weighted_nats_N_C_eC)] = 0.0
                joint_entropy_N = weighted_nats_N_C_eC.sum((1, 2), keepdim=False)
                del weighted_nats_N_C_eC

                total_joint_entropies_N += joint_entropy_N

                pbar.update(len(chunked_eval_probs_e_K_C))

        pbar.finish()

        # Average over the eval points that were actually summed; for SINGLE_BOOTSTRAP
        # this is the bootstrap sample size, which differs from E unless the factor is 1.
        total_scores_N = pool_entropies_N - total_joint_entropies_N / num_eval_samples + eval_label_uncertainty
    else:
        pool_probs_N_K_C = pool_log_probs_N_K_C.to(dtype=dtype, device=device).exp()
        eval_probs_E_C_K = eval_log_probs_E_K_C.to(dtype=dtype, device=device).exp().transpose(1, 2).contiguous()

        eval_label_uncertainty_E = compute_entropy(eval_log_probs_E_K_C)

        # Scores are accumulated in place on top of the pool entropies.
        total_scores_N = pool_entropies_N

        # Loop invariants: bootstrap sample size and the uniform sampling weights.
        num_eval_samples = int(E * bootstrap_factor)
        uniform_weights_E = torch.tensor(1.0).expand(E)

        for i_n in with_progress_bar(range(N), tqdm_args=dict(desc="Pool Set", leave=False)):
            single_pool_probs_K_C = pool_probs_N_K_C[i_n]

            # A fresh bootstrap resample of the evaluation set for every pool point.
            eval_indices = torch.multinomial(uniform_weights_E, num_samples=num_eval_samples, replacement=True)
            # For debugging:
            # num_eval_samples = E
            # eval_indices = torch.tensor(list(range(E)))

            sampled_eval_probs_F_C_K = eval_probs_E_C_K[eval_indices]

            # The roles of pool and eval are swapped in this call: the sampled eval points
            # form the batch. The entropy sums over all entries, so the axis order is irrelevant.
            joint_probs_F_C_C = get_joint_probs_N_C_C_transposed(sampled_eval_probs_F_C_K, single_pool_probs_K_C)
            weighted_nats_F_C_C = joint_probs_F_C_C * -torch.log(joint_probs_F_C_C)
            # 0 * -log(0) evaluates to NaN; such entries contribute no entropy.
            weighted_nats_F_C_C[torch.isnan(weighted_nats_F_C_C)] = 0.0
            avg_joint_entropy = weighted_nats_F_C_C.sum() / num_eval_samples
            del weighted_nats_F_C_C

            eval_label_uncertainty = eval_label_uncertainty_E[eval_indices].mean(dim=0, keepdim=False)
            total_scores_N[i_n] += eval_label_uncertainty - avg_joint_entropy

    return total_scores_N.to(device="cpu", non_blocking=True)

In [None]:
# Random test inputs: 7 pool points and 11 eval points, each with K=13 samples over C=3 classes.
# The logits are scaled by 100, which makes the softmax outputs very peaked
# (presumably to exercise the 0 * log(0) handling in the EPIG functions).
pool_log_probs_N_K_C = torch.log_softmax(100 * torch.randn(7, 13, 3), dim=-1)
eval_log_probs_E_K_C = torch.log_softmax(100 * torch.randn(11, 13, 3), dim=-1)

In [None]:
# Print the scores of the regular and the `_stable` implementation side by side,
# once per computation dtype, so that they can be compared by eye.
for dtype in (torch.float32, torch.double):
    epig_kwargs = dict(
        pool_log_probs_N_K_C=pool_log_probs_N_K_C,
        eval_log_probs_E_K_C=eval_log_probs_E_K_C,
        device="cuda",
        dtype=dtype,
    )
    naive_scores = get_real_naive_epig_scores(**epig_kwargs)
    stable_scores = get_real_naive_epig_scores_stable(**epig_kwargs)
    print(naive_scores, stable_scores)

Entropy:   0%|          | 0/7 [00:00<?, ?it/s]

Entropy:   0%|          | 0/11 [00:00<?, ?it/s]

Evaluation Set:   0%|          | 0/11 [00:00<?, ?it/s]

Entropy:   0%|          | 0/7 [00:00<?, ?it/s]

Entropy:   0%|          | 0/11 [00:00<?, ?it/s]

Evaluation Set:   0%|          | 0/11 [00:00<?, ?it/s]

tensor([0.2072, 0.1844, 0.1629, 0.1929, 0.1682, 0.2395, 0.1918],
       dtype=torch.float64) tensor([0.2072, 0.1844, 0.1629, 0.1929, 0.1682, 0.2395, 0.1918],
       dtype=torch.float64)


Entropy:   0%|          | 0/7 [00:00<?, ?it/s]

Entropy:   0%|          | 0/11 [00:00<?, ?it/s]

Evaluation Set:   0%|          | 0/11 [00:00<?, ?it/s]

Entropy:   0%|          | 0/7 [00:00<?, ?it/s]

Entropy:   0%|          | 0/11 [00:00<?, ?it/s]

Evaluation Set:   0%|          | 0/11 [00:00<?, ?it/s]

tensor([0.2072, 0.1844, 0.1629, 0.1929, 0.1682, 0.2395, 0.1918],
       dtype=torch.float64) tensor([0.2072, 0.1844, 0.1629, 0.1929, 0.1682, 0.2395, 0.1918],
       dtype=torch.float64)


In [None]:
# Show the scores of the current and the `_old` implementation as a pair (default dtype).
epig_inputs = dict(
    pool_log_probs_N_K_C=pool_log_probs_N_K_C,
    eval_log_probs_E_K_C=eval_log_probs_E_K_C,
    device="cuda",
)
get_real_naive_epig_scores(**epig_inputs), get_real_naive_epig_scores_old(**epig_inputs)

Entropy:   0%|          | 0/7 [00:00<?, ?it/s]

Entropy:   0%|          | 0/11 [00:00<?, ?it/s]

Evaluation Set:   0%|          | 0/11 [00:00<?, ?it/s]

Evaluation Set:   0%|          | 0/11 [00:00<?, ?it/s]

(tensor([0.2072, 0.1844, 0.1629, 0.1929, 0.1682, 0.2395, 0.1918],
        dtype=torch.float64),
 tensor([0.2072, 0.1844, 0.1629, 0.1929, 0.1682, 0.2395, 0.1918]))

In [None]:
# Per-point bootstrap: the evaluation set is resampled separately for every pool point,
# so the scores vary between runs.
get_real_naive_epig_scores(
    pool_log_probs_N_K_C=pool_log_probs_N_K_C,
    eval_log_probs_E_K_C=eval_log_probs_E_K_C,
    bootstrap_type=BootstrapType.PER_POINT_BOOTSTRAP,
    device="cuda",
)

Entropy:   0%|          | 0/7 [00:00<?, ?it/s]

Entropy:   0%|          | 0/11 [00:00<?, ?it/s]

Pool Set:   0%|          | 0/7 [00:00<?, ?it/s]

tensor([0.1770, 0.2326, 0.1526, 0.2323, 0.1525, 0.2118, 0.1741],
       dtype=torch.float64)

In [None]:
# Single bootstrap: one resample of the evaluation set is shared by all pool points,
# so the scores vary between runs.
get_real_naive_epig_scores(
    pool_log_probs_N_K_C=pool_log_probs_N_K_C,
    eval_log_probs_E_K_C=eval_log_probs_E_K_C,
    bootstrap_type=BootstrapType.SINGLE_BOOTSTRAP,
    device="cuda",
)

Entropy:   0%|          | 0/7 [00:00<?, ?it/s]

Entropy:   0%|          | 0/11 [00:00<?, ?it/s]

Evaluation Set:   0%|          | 0/11 [00:00<?, ?it/s]

tensor([0.1971, 0.1683, 0.1589, 0.1959, 0.1634, 0.2245, 0.1531],
       dtype=torch.float64)

In [None]:
# slow

# Large-scale smoke run: 60000 pool and 60000 eval points with K=100 samples over
# C=10 classes. The scores are discarded; X and Y stay around for later cells.
num_samples = 60000

with torch.no_grad():
    X = torch.log_softmax(torch.randn(num_samples, 100, 10), dim=-1)
    Y = torch.log_softmax(torch.randn(num_samples, 100, 10), dim=-1)
    get_real_naive_epig_scores(
        pool_log_probs_N_K_C=X,
        eval_log_probs_E_K_C=Y,
        device="cuda",
        dtype=torch.double,
    )

Evaluation Set:   0%|          | 0/60000 [00:00<?, ?it/s]

KeyboardInterrupt: 

# slow

# Same smoke run for the `_old` implementation, computed in float32.
# NOTE(review): X and Y are reused from the previous cell, where they were created with
# num_samples = 60000. Reassigning num_samples here does not shrink them; it only
# affects the cells further down that read num_samples. Confirm this is intended.
num_samples = 6000

with torch.no_grad():
    get_real_naive_epig_scores_old(
        pool_log_probs_N_K_C=X,
        eval_log_probs_E_K_C=Y,
        dtype=torch.float,
        device="cuda",
    )

In [None]:
# slow

# Smoke run of the per-point bootstrap on fresh random inputs; `num_samples` comes
# from an earlier cell and the scores are discarded.
with torch.no_grad():
    get_real_naive_epig_scores(
        pool_log_probs_N_K_C=torch.log_softmax(torch.randn(num_samples, 100, 10), dim=-1),
        eval_log_probs_E_K_C=torch.log_softmax(torch.randn(num_samples, 100, 10), dim=-1),
        bootstrap_type=BootstrapType.PER_POINT_BOOTSTRAP,
        bootstrap_factor=1,
        device="cuda",
    )

In [None]:
# slow

# Smoke run of the single bootstrap with a resample of 85% of the evaluation set size;
# `num_samples` comes from an earlier cell and the scores are discarded.
with torch.no_grad():
    get_real_naive_epig_scores(
        pool_log_probs_N_K_C=torch.log_softmax(torch.randn(num_samples, 100, 10), dim=-1),
        eval_log_probs_E_K_C=torch.log_softmax(torch.randn(num_samples, 100, 10), dim=-1),
        bootstrap_type=BootstrapType.SINGLE_BOOTSTRAP,
        bootstrap_factor=0.85,
        device="cuda",
    )

## DirichletBALD (Unclear/TODO)

This is inspired by Energy-Based Models and Dirichlet distributions. Instead of working with the probabilities after the Softmax layer, we use the logits directly and view them as log concentrations of a Dirichlet distribution.

We can combine this with MC Dropout to get different Dirichlet samples by averaging the log concentrations. This leads to the exact same computation as before (geometric averaging of the probabilities). 

This is different than fitting a Dirichlet distribution. Taking the log concentrations instead of probabilities does not throw away "density" information from the model.

We could recover a mutual information term using the Dirichlet assumption then?

(In general, it is not entirely clear to me how to combine them. One path could be to use a conjugate prior distribution and take its mean or mode. The conjugate prior of a Dirichlet distribution is the Boojum distribution, which is quite complex and does not provide an analytical solution for computing its mean or mode.)



In [None]:
# def get_dirichlet_bald_scores(logits_N_K_C: torch.Tensor, *,
#     dtype=None,
#     device=None) -> torch.Tensor: