In [None]:
# hide

import blackhc.project.script
from nbdev.showdoc import *

Appended /home/blackhc/PycharmProjects/bald-ical/src to paths
Switched to directory /home/blackhc/PycharmProjects/bald-ical
%load_ext autoreload
%autoreload 2


# Acquisition Function: EPIG
> Greedy algorithm and score computation

First, we will implement two helper classes to compute conditional entropies $H[y_i|w]$ and entropies $H[y_i]$. 
Then, we will implement BatchBALD and BALD.

In [None]:
import math

import numpy as np
import torch
from blackhc.progress_bar import create_progress_bar
from toma import toma

from batchbald_redux.acquisition_functions.epig import * 
from batchbald_redux.joint_entropy import *

We are going to define a couple of sampled distributions to use for testing our code.

$K=20$ means 20 inference samples.

In [None]:
K = 20

In [None]:
def get_mixture_prob_dist(p1, p2, m):
    """Linearly interpolate between two probability vectors.

    Returns `(1 - m) * p1 + m * p2` as a NumPy array, so `m=0` yields `p1`
    and `m=1` yields `p2`.
    """
    first, second = np.asarray(p1), np.asarray(p2)
    weight_first = 1.0 - m
    return weight_first * first + m * second


# Build four synthetic "points", each a list of K class distributions over 4 classes.
# For point i, the K distributions interpolate from a peaked distribution (0.7 on
# class i) at m=0 to a flatter distribution at m=1.
# NOTE: p1/p2 are deliberately rebound for every point; only the y*_ws lists are used later.

# Point 1: peaked on class 0.
p1 = [0.7, 0.1, 0.1, 0.1]
p2 = [0.3, 0.3, 0.2, 0.2]
y1_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]

# Point 2: peaked on class 1.
p1 = [0.1, 0.7, 0.1, 0.1]
p2 = [0.2, 0.3, 0.3, 0.2]
y2_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]

# Point 3: peaked on class 2.
p1 = [0.1, 0.1, 0.7, 0.1]
p2 = [0.2, 0.2, 0.3, 0.3]
y3_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]

# Point 4: peaked on class 3.
p1 = [0.1, 0.1, 0.1, 0.7]
p2 = [0.3, 0.2, 0.2, 0.3]
y4_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]


def nested_to_tensor(l):
    """Convert each entry of `l` to a tensor and stack them along a new leading dimension."""
    tensors = [torch.as_tensor(entry) for entry in l]
    return torch.stack(tensors)


ys_ws = nested_to_tensor([y1_ws, y2_ws, y3_ws, y4_ws])

In [None]:
# hide

# Uniform distribution over the 4 classes, repeated for each of the K samples
# and for 4 points; used below as a sanity check for the entropy helpers.
p = [0.25, 0.25, 0.25, 0.25]
yu_ws = [p for m in range(K)]
yus_ws = nested_to_tensor([yu_ws] * 4)

In [None]:
ys_ws.shape

torch.Size([4, 20, 4])

However, our neural networks usually use a `log_softmax` as final layer. To avoid having to call `.exp_()`, which is easy to miss and annoying to debug, we will instead use a version that uses `log_probs` instead of `probs`.

In [None]:
# hide

# Make sure everything is computed correctly.
assert np.allclose(compute_conditional_entropy(yus_ws.log()), [1.3863, 1.3863, 1.3863, 1.3863], atol=0.1)
assert np.allclose(compute_entropy(yus_ws.log()), [1.3863, 1.3863, 1.3863, 1.3863], atol=0.1)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

### Examples

In [None]:
conditional_entropies = compute_conditional_entropy(ys_ws.log())

print(conditional_entropies)

assert np.allclose(conditional_entropies, [1.2069, 1.2069, 1.2069, 1.2069], atol=0.01)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

tensor([1.2069, 1.2069, 1.2069, 1.2069], dtype=torch.float64)


In [None]:
entropies = compute_entropy(ys_ws.log())

print(entropies)

assert np.allclose(entropies, [1.2376, 1.2376, 1.2376, 1.2376], atol=0.01)

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

tensor([1.2376, 1.2376, 1.2376, 1.2376], dtype=torch.float64)


## EPIG-BALD

The computation for EPIG-BALD is simple. We need to keep track of two separate (Batch)BALD terms:

$$\mathrm{I}\left[(y)_{B} ; \omega \mid(x)_{B}, D_{T}\right]-\mathrm{I}\left[(y)_{B} ; \omega \mid(x)_{B}, D_{U} \cup D_{T}\right].$$


### Example

#### Pleasing example of the case when predictions match (full overlap)

In [None]:
get_batch_eval_bald_batch(
    ys_ws.log().double(), ys_ws.log().double(), batch_size=4, num_samples=1000, dtype=torch.double
)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

BatchBALD:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

CandidateBatch(scores=[0.0, 0.0, 0.0, 0.0], indices=[0, 1, 2, 3])

In [None]:
get_batch_eval_bald_batch(
    ys_ws.log().double(), torch.zeros_like(ys_ws).double(), batch_size=4, num_samples=1000, dtype=torch.double
)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

BatchBALD:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

CandidateBatch(scores=[0.030715639666234917, 0.05961958627158248, 0.0869107051474467, 0.11275304532467878], indices=[1, 0, 2, 3])

## Additional EPIG-BALD variants

Instead of using BatchBALD, let's compute BALD directly and select the batch using either top-k, TopRandom, or Thompson-style sampling.

In [None]:
get_eval_bald_batch(ys_ws.log().double(), ys_ws.log().double(), batch_size=4, dtype=torch.double)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

CandidateBatch(scores=[0.0, 0.0, 0.0, 0.0], indices=[2, 3, 0, 1])

In [None]:
get_top_random_eval_bald_batch(
    ys_ws.log().double(), ys_ws.log().double(), batch_size=4, num_classes=10, dtype=torch.double
)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

CandidateBatch(scores=[0.0, 0.0, 0.0, 0.0], indices=[1, 0, 3, 2])

## EPIG

As part of an ablation (and to see how it performs), we can also compute the ICAL score.

In [None]:
get_eig_scores(ys_ws.log().double(), ys_ws.log().double())

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

tensor([0., 0., 0., 0.], dtype=torch.float64)

In [None]:
get_batch_eig_batch(ys_ws.log().double(), ys_ws.log().double(), batch_size=4, num_samples=1000, dtype=torch.double)

BatchBALD:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

ExactJointEntropy.compute_batch:   0%|          | 0/4 [00:00<?, ?it/s]

CandidateBatch(scores=[0.0, 0.0, 0.0, 0.0], indices=[0, 1, 2, 3])

## Real EPIG

Implement $I[Y_{acq} ; Y_{eval} \mid x_{acq}, X_{eval}, D_{train}]$.



In [None]:
def get_joint_probs_N_C_C_old(pool_probs_N_K_C: torch.Tensor, single_eval_probs_K_C: torch.Tensor):
    """Average, over the K samples, the outer product of pool and eval class probabilities.

    For every pool point n this returns the C x C matrix
    `sum_k pool[n, k, c] * eval[k, c'] / K`, where K is the first dimension
    of `single_eval_probs_K_C`.
    """
    num_samples = single_eval_probs_K_C.shape[0]

    # Move the sample axis last so the matmul contracts over it.
    pool_probs_N_C_K = pool_probs_N_K_C.transpose(1, 2)
    summed_N_C_C = torch.matmul(pool_probs_N_C_K, single_eval_probs_K_C)
    return summed_N_C_C / num_samples


def get_real_naive_epig_scores_old(
    *, pool_log_probs_N_K_C: torch.Tensor, eval_log_probs_E_K_C: torch.Tensor, dtype=None, device=None
) -> torch.Tensor:
    """Implements naive EPIG: I[Y_acq; Y_eval | x_acq, X_eval].

    Reference implementation that works in probability space (it exponentiates
    the inputs), averaging the mutual information between each pool point's
    label and each evaluation point's label over the evaluation set.

    Args:
        pool_log_probs_N_K_C: log-probabilities for N pool points, K samples, C classes.
        eval_log_probs_E_K_C: log-probabilities for E evaluation points, with the
            same K and C as the pool tensor.
        dtype: dtype used for the computation (passed to `Tensor.to`).
        device: device used for the computation (passed to `Tensor.to`).

    Returns:
        A CPU tensor of shape (N,) with the mean mutual information per pool point.
    """
    N, _, _ = pool_log_probs_N_K_C.shape
    E, _, _ = eval_log_probs_E_K_C.shape
    # Compare eval against pool (the original compared the pool shape with itself,
    # which could never fail) and actually format the message.
    assert (
        eval_log_probs_E_K_C.shape[1:] == pool_log_probs_N_K_C.shape[1:]
    ), f"{eval_log_probs_E_K_C.shape[1:]} != {pool_log_probs_N_K_C.shape[1:]}"

    pool_probs_N_K_C = pool_log_probs_N_K_C.to(dtype=dtype, device=device).exp()
    eval_probs_E_K_C = eval_log_probs_E_K_C.to(dtype=dtype, device=device).exp()

    # Marginal predictive distribution of every pool point.
    pool_probs_N_C = torch.mean(pool_probs_N_K_C, dim=1, keepdim=False)

    total_scores_N = torch.zeros((N,), dtype=dtype, device="cpu")
    for i_e in with_progress_bar(range(E), tqdm_args=dict(desc="Evaluation Set", leave=False)):
        single_eval_probs_K_C = eval_probs_E_K_C[i_e]

        joint_probs_N_C_C = get_joint_probs_N_C_C_old(pool_probs_N_K_C, single_eval_probs_K_C)

        # Marginal predictive distribution of this evaluation point.
        single_eval_probs_C = torch.mean(single_eval_probs_K_C, dim=0, keepdim=False)

        # Pointwise mutual information: log p(y, y') - log p(y) - log p(y').
        nats_N_C_C = (
            -torch.log(single_eval_probs_C)[None, None, :]
            - torch.log(pool_probs_N_C)[:, :, None]
            + torch.log(joint_probs_N_C_C)
        )

        weighted_nats_N_C_C = nats_N_C_C * joint_probs_N_C_C
        # 0 * log(0) yields NaN; such terms contribute nothing to the sum.
        weighted_nats_N_C_C[torch.isnan(weighted_nats_N_C_C)] = 0.0
        scores_N = weighted_nats_N_C_C.sum((1, 2), keepdim=False)

        # Synchronous copy: the result is consumed immediately by the CPU add, so a
        # non_blocking device-to-host transfer could be read before it has completed.
        total_scores_N += scores_N.to(device="cpu")

    total_scores_N /= E

    return total_scores_N

In [None]:
# @torch.no_grad()
# def logmatmulexp(log_A: torch.Tensor, log_B: torch.Tensor) -> torch.Tensor:
#     """Given matrix log_A of shape (batch...) ϴ×R and matrix log_B of shape R×I, calculates
#     (log_A.exp() @ log_B.exp()).log() and its backward in a numerically stable way."""
#     batch_shape = list(log_A.shape[:-2])
#     ϴ, R = log_A.shape[-2:]
#     I = log_B.shape[-1]
#     assert log_B.shape == (R, I)
#     log_A_expanded = log_A.unsqueeze(-1).expand(batch_shape + [ϴ, R, I])
#     log_B_expanded = log_B.unsqueeze(-3).expand((ϴ, R, I))
#     log_pairwise_products = log_A_expanded + log_B_expanded  # shape: (ϴ, R, I)
#     return torch.logsumexp(log_pairwise_products, dim=-2)


@torch.no_grad()
def logmatmulexp(log_A: torch.Tensor, log_B: torch.Tensor) -> torch.Tensor:
    """Compute `(log_A.exp() @ log_B.exp()).log()` in a numerically stable way.

    `log_A` has shape (batch...) x ϴ x R and `log_B` has shape R x I. Each row of
    `log_A` and each column of `log_B` is shifted by its maximum before
    exponentiating, and the shifts are added back afterwards.

    Rows of `log_A` (or columns of `log_B`) that are entirely `-inf`, i.e. all-zero
    in probability space, produce `-inf` in the result instead of NaN.

    The function runs under `torch.no_grad()`, so no gradients are tracked.
    """
    neg_inf = float("-inf")

    max_A = log_A.max(dim=-1, keepdim=True).values
    max_B = log_B.max(dim=-2, keepdim=True).values

    # An all -inf row/column has max -inf, and (-inf) - (-inf) would be NaN.
    # Shifting by 0 instead keeps those entries at -inf, so they exponentiate to 0.
    max_A = max_A.masked_fill(max_A == neg_inf, 0.0)
    max_B = max_B.masked_fill(max_B == neg_inf, 0.0)

    shifted_product = (log_A - max_A).exp() @ (log_B - max_B).exp()
    return torch.log(shifted_product) + max_A + max_B


@torch.no_grad()
def get_real_naive_epig_scores_stable(
    *,
    bootstrap_type=BootstrapType.NO_BOOTSTRAP,
    bootstrap_factor=1.0,
    pool_log_probs_N_K_C: torch.Tensor,
    eval_log_probs_E_K_C: torch.Tensor,
    dtype=None,
    device=None,
) -> torch.Tensor:
    """Implements naive EPIG: I[Y_acq; Y_eval | x_acq, X_eval].

    Works in log space, using the decomposition

        I[Y_acq; Y_eval | x_acq, X_eval]
            = H[Y_acq | x_acq] + E_p(x_eval)[H[Y_eval | x_eval] - H[Y_acq, Y_eval | x_acq, x_eval]].

    Args:
        bootstrap_type: `NO_BOOTSTRAP` iterates over every evaluation point once;
            `SINGLE_BOOTSTRAP` draws `int(E * bootstrap_factor)` evaluation indices
            with replacement and shares them across all pool points.
        bootstrap_factor: size of the bootstrap sample relative to E.
        pool_log_probs_N_K_C: log-probabilities for N pool points, K samples, C classes.
        eval_log_probs_E_K_C: log-probabilities for E evaluation points, with the
            same K and C as the pool tensor.
        dtype: dtype used for the joint-entropy computation.
        device: device used for the computation.

    Returns:
        A CPU tensor of shape (N,) with one score per pool point.

    Raises:
        NotImplementedError: for `PER_POINT_BOOTSTRAP`, which this variant does not implement.
        ValueError: for any other unknown `bootstrap_type`.
    """
    N, K, _ = pool_log_probs_N_K_C.shape
    E, _, _ = eval_log_probs_E_K_C.shape
    # Compare eval against pool (the original compared the pool shape with itself,
    # which could never fail) and actually format the message.
    assert (
        eval_log_probs_E_K_C.shape[1:] == pool_log_probs_N_K_C.shape[1:]
    ), f"{eval_log_probs_E_K_C.shape[1:]} != {pool_log_probs_N_K_C.shape[1:]}"

    if bootstrap_type == BootstrapType.NO_BOOTSTRAP:
        eval_indices = range(E)
    elif bootstrap_type == BootstrapType.SINGLE_BOOTSTRAP:
        num_bootstrap_samples = int(E * bootstrap_factor)
        # Uniform sampling with replacement over the evaluation set.
        eval_indices = torch.multinomial(
            torch.tensor(1.0).expand(E), num_samples=num_bootstrap_samples, replacement=True
        ).tolist()
    elif bootstrap_type == BootstrapType.PER_POINT_BOOTSTRAP:
        # A per-point variant would resample the evaluation indices separately for every
        # pool point; previously this fell through to an UnboundLocalError at the return.
        raise NotImplementedError("PER_POINT_BOOTSTRAP is not implemented for the stable EPIG scores")
    else:
        raise ValueError(f"Unknown bootstrap {bootstrap_type}")
    num_eval_samples = len(eval_indices)

    pool_entropies_N = compute_entropy(pool_log_probs_N_K_C).to(device=device)
    # H[Y_eval | x_eval] averaged over the full evaluation set.
    eval_label_uncertainty = compute_entropy(eval_log_probs_E_K_C).mean(dim=0, keepdim=False).to(device=device)

    total_joint_entropies_N = torch.zeros((N,), dtype=dtype, device=device)

    # Put the sample axis last so logmatmulexp contracts over K.
    pool_log_probs_N_C_K = pool_log_probs_N_K_C.transpose(1, 2).contiguous().to(dtype=dtype, device=device)
    eval_log_probs_E_K_C = eval_log_probs_E_K_C.to(dtype=dtype, device=device)

    for i_e in with_progress_bar(eval_indices, tqdm_args=dict(desc="Evaluation Set", leave=False)):
        single_eval_log_probs_K_C = eval_log_probs_E_K_C[i_e]

        # log p(y_acq, y_eval) = log( (1/K) * sum_k p(y_acq | w_k) p(y_eval | w_k) )
        joint_log_probs_N_C_C = logmatmulexp(pool_log_probs_N_C_K, single_eval_log_probs_K_C) - math.log(K)
        weighted_nats_N_C_C = joint_log_probs_N_C_C * -torch.exp(joint_log_probs_N_C_C)
        # -inf * 0 (zero-probability outcomes) yields NaN; those terms contribute nothing.
        weighted_nats_N_C_C[torch.isnan(weighted_nats_N_C_C)] = 0.0
        joint_entropy_N = weighted_nats_N_C_C.sum((1, 2), keepdim=False)
        del weighted_nats_N_C_C

        total_joint_entropies_N += joint_entropy_N

    # Average over the evaluation points actually visited. With SINGLE_BOOTSTRAP this is
    # int(E * bootstrap_factor), not E, so dividing by E mis-scaled the joint-entropy term.
    total_scores_N = pool_entropies_N - total_joint_entropies_N / num_eval_samples + eval_label_uncertainty

    # Synchronous copy: callers use the result right away, so a non_blocking
    # device-to-host transfer could be read before it has completed.
    return total_scores_N.to(device="cpu")

In [None]:
# Fix the RNG seed so the random test inputs (and the scores printed below) are
# reproducible across kernel restarts.
torch.manual_seed(0)

# Random log-probabilities: 7 pool points and 11 eval points, 13 samples, 3 classes.
# The logits are scaled by 100, which makes the softmax outputs extremely peaked
# (presumably to stress the numerical stability of the implementations below).
pool_log_probs_N_K_C = torch.log_softmax(torch.randn(7, 13, 3) * 100, dim=2)
eval_log_probs_E_K_C = torch.log_softmax(torch.randn(11, 13, 3) * 100, dim=2)

In [None]:
# Compare `get_real_naive_epig_scores` (not defined in this notebook; presumably the
# library implementation pulled in by the star imports) with the notebook's
# `get_real_naive_epig_scores_stable` on the same inputs, once per floating-point dtype.
# The two printed tensors are expected to agree.
# NOTE: device="cuda" is hard-coded, so this cell requires a CUDA-capable GPU.
for dtype in (torch.float32, torch.double):
    print(
        get_real_naive_epig_scores(
            pool_log_probs_N_K_C=pool_log_probs_N_K_C,
            eval_log_probs_E_K_C=eval_log_probs_E_K_C,
            device="cuda",
            dtype=dtype,
        ),
        get_real_naive_epig_scores_stable(
            pool_log_probs_N_K_C=pool_log_probs_N_K_C,
            eval_log_probs_E_K_C=eval_log_probs_E_K_C,
            device="cuda",
            dtype=dtype,
        ),
    )

Entropy:   0%|          | 0/7 [00:00<?, ?it/s]

Entropy:   0%|          | 0/11 [00:00<?, ?it/s]

Evaluation Set:   0%|          | 0/11 [00:00<?, ?it/s]

Entropy:   0%|          | 0/7 [00:00<?, ?it/s]

Entropy:   0%|          | 0/11 [00:00<?, ?it/s]

Evaluation Set:   0%|          | 0/11 [00:00<?, ?it/s]

tensor([0.2034, 0.2925, 0.2823, 0.2413, 0.1856, 0.1703, 0.1390],
       dtype=torch.float64) tensor([0.2034, 0.2925, 0.2823, 0.2413, 0.1856, 0.1703, 0.1390],
       dtype=torch.float64)


Entropy:   0%|          | 0/7 [00:00<?, ?it/s]

Entropy:   0%|          | 0/11 [00:00<?, ?it/s]

Evaluation Set:   0%|          | 0/11 [00:00<?, ?it/s]

Entropy:   0%|          | 0/7 [00:00<?, ?it/s]

Entropy:   0%|          | 0/11 [00:00<?, ?it/s]

Evaluation Set:   0%|          | 0/11 [00:00<?, ?it/s]

tensor([0.2034, 0.2925, 0.2823, 0.2413, 0.1856, 0.1703, 0.1390],
       dtype=torch.float64) tensor([0.2034, 0.2925, 0.2823, 0.2413, 0.1856, 0.1703, 0.1390],
       dtype=torch.float64)


In [None]:
get_real_naive_epig_scores(
    pool_log_probs_N_K_C=pool_log_probs_N_K_C, eval_log_probs_E_K_C=eval_log_probs_E_K_C, device="cuda"
), get_real_naive_epig_scores_old(
    pool_log_probs_N_K_C=pool_log_probs_N_K_C, eval_log_probs_E_K_C=eval_log_probs_E_K_C, device="cuda"
)

Entropy:   0%|          | 0/7 [00:00<?, ?it/s]

Entropy:   0%|          | 0/11 [00:00<?, ?it/s]

Evaluation Set:   0%|          | 0/11 [00:00<?, ?it/s]

Evaluation Set:   0%|          | 0/11 [00:00<?, ?it/s]

(tensor([0.2034, 0.2925, 0.2823, 0.2413, 0.1856, 0.1703, 0.1390],
        dtype=torch.float64),
 tensor([0.2034, 0.2925, 0.2823, 0.2413, 0.1856, 0.1703, 0.1390]))

In [None]:
get_real_naive_epig_scores(
    bootstrap_type=BootstrapType.PER_POINT_BOOTSTRAP,
    pool_log_probs_N_K_C=pool_log_probs_N_K_C,
    eval_log_probs_E_K_C=eval_log_probs_E_K_C,
    device="cuda",
)

Entropy:   0%|          | 0/7 [00:00<?, ?it/s]

Entropy:   0%|          | 0/11 [00:00<?, ?it/s]

Pool Set:   0%|          | 0/7 [00:00<?, ?it/s]

tensor([0.1770, 0.2326, 0.1526, 0.2323, 0.1525, 0.2118, 0.1741],
       dtype=torch.float64)

In [None]:
get_real_naive_epig_scores(
    bootstrap_type=BootstrapType.SINGLE_BOOTSTRAP,
    pool_log_probs_N_K_C=pool_log_probs_N_K_C,
    eval_log_probs_E_K_C=eval_log_probs_E_K_C,
    device="cuda",
)

Entropy:   0%|          | 0/7 [00:00<?, ?it/s]

Entropy:   0%|          | 0/11 [00:00<?, ?it/s]

Evaluation Set:   0%|          | 0/11 [00:00<?, ?it/s]

tensor([0.1971, 0.1683, 0.1589, 0.1959, 0.1634, 0.2245, 0.1531],
       dtype=torch.float64)

In [None]:
# slow

num_samples = 60000

with torch.no_grad():
    X = torch.log_softmax(torch.randn(num_samples, 100, 10), dim=2)
    Y = torch.log_softmax(torch.randn(num_samples, 100, 10), dim=2)
    get_real_naive_epig_scores(
        pool_log_probs_N_K_C=X,
        eval_log_probs_E_K_C=Y,
        dtype=torch.double,
        device="cuda",
    )

Evaluation Set:   0%|          | 0/60000 [00:00<?, ?it/s]

KeyboardInterrupt: 

# slow

num_samples = 6000

with torch.no_grad():
    get_real_naive_epig_scores_old(
        pool_log_probs_N_K_C=X,
        eval_log_probs_E_K_C=Y,
        dtype=torch.float,
        device="cuda",
    )

In [None]:
# slow

with torch.no_grad():
    get_real_naive_epig_scores(
        bootstrap_type=BootstrapType.PER_POINT_BOOTSTRAP,
        bootstrap_factor=1,
        pool_log_probs_N_K_C=torch.log_softmax(torch.randn(num_samples, 100, 10), dim=2),
        eval_log_probs_E_K_C=torch.log_softmax(torch.randn(num_samples, 100, 10), dim=2),
        device="cuda",
    )

In [None]:
# slow

with torch.no_grad():
    get_real_naive_epig_scores(
        bootstrap_type=BootstrapType.SINGLE_BOOTSTRAP,
        bootstrap_factor=0.85,
        pool_log_probs_N_K_C=torch.log_softmax(torch.randn(num_samples, 100, 10), dim=2),
        eval_log_probs_E_K_C=torch.log_softmax(torch.randn(num_samples, 100, 10), dim=2),
        device="cuda",
    )