In [None]:
# hide

import blackhc.project.script
from nbdev.showdoc import *

Appended /home/blackhc/PycharmProjects/bald-ical/src to paths
Switched to directory /home/blackhc/PycharmProjects/bald-ical
%load_ext autoreload
%autoreload 2


# Acquisition Function: CoreSet
> Greedy algorithm and score computation

First, we will check the two helper functions that compute conditional entropies $H[y_i|w]$ and entropies $H[y_i]$. 
Then, we will look at the CoreSet variants of BALD and BatchBALD, which make use of the known labels.

In [None]:
import math

import numpy as np
import torch
from blackhc.progress_bar import create_progress_bar
from toma import toma

from batchbald_redux.acquisition_functions.coreset import * 
from batchbald_redux.joint_entropy import *

We are going to define a couple of sampled distributions to use for testing our code.

$K=20$ means 20 inference samples.

In [None]:
K = 20

In [None]:
def get_mixture_prob_dist(p1, p2, m):
    """Blend two probability vectors element-wise.

    The result is `(1 - m) * p1 + m * p2` as a NumPy array, so `m = 0` gives `p1`
    and `m = 1` gives `p2`.
    """
    first, second = np.asarray(p1), np.asarray(p2)
    weight_first = 1.0 - m
    return weight_first * first + m * second


# Four test points over 4 classes. For each point, the K inference samples interpolate linearly
# (m from 0 to 1) between a peaked prediction `p1` (0.7 on one class) and a flatter prediction `p2`.
# `p1` and `p2` are rebound for every point; only the resulting `y*_ws` lists are used afterwards.

# Point 1: peaked on class 0.
p1 = [0.7, 0.1, 0.1, 0.1]
p2 = [0.3, 0.3, 0.2, 0.2]
y1_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]

# Point 2: peaked on class 1.
p1 = [0.1, 0.7, 0.1, 0.1]
p2 = [0.2, 0.3, 0.3, 0.2]
y2_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]

# Point 3: peaked on class 2.
p1 = [0.1, 0.1, 0.7, 0.1]
p2 = [0.2, 0.2, 0.3, 0.3]
y3_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]

# Point 4: peaked on class 3.
p1 = [0.1, 0.1, 0.1, 0.7]
p2 = [0.3, 0.2, 0.2, 0.3]
y4_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]


def nested_to_tensor(l):
    """Convert every element of `l` to a tensor and stack them along a new leading dimension."""
    converted = [torch.as_tensor(entry) for entry in l]
    return torch.stack(converted)


ys_ws = nested_to_tensor([y1_ws, y2_ws, y3_ws, y4_ws])

  return torch.stack(list(map(torch.as_tensor, l)))


In [None]:
# hide

# Uniform reference case: every one of the K inference samples assigns 0.25 to each of the 4 classes.
p = [0.25, 0.25, 0.25, 0.25]
yu_ws = [p for m in range(K)]
# Repeat the same uniform predictions for 4 points.
yus_ws = nested_to_tensor([yu_ws] * 4)

In [None]:
ys_ws.shape

torch.Size([4, 20, 4])

However, our neural networks usually use a `log_softmax` as the final layer. To avoid having to call `.exp_()`, which is easy to miss and annoying to debug, we will instead use a version that uses `log_probs` instead of `probs`.

In [None]:
# hide

# Make sure everything is computed correctly.
assert np.allclose(compute_conditional_entropy(yus_ws.log()), [1.3863, 1.3863, 1.3863, 1.3863], atol=0.1)
assert np.allclose(compute_entropy(yus_ws.log()), [1.3863, 1.3863, 1.3863, 1.3863], atol=0.1)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

### Examples

In [None]:
conditional_entropies = compute_conditional_entropy(ys_ws.log())

print(conditional_entropies)

assert np.allclose(conditional_entropies, [1.2069, 1.2069, 1.2069, 1.2069], atol=0.01)

Conditional Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

tensor([1.2069, 1.2069, 1.2069, 1.2069], dtype=torch.float64)


In [None]:
entropies = compute_entropy(ys_ws.log())

print(entropies)

assert np.allclose(entropies, [1.2376, 1.2376, 1.2376, 1.2376], atol=0.01)

Entropy:   0%|          | 0/4 [00:00<?, ?it/s]

tensor([1.2376, 1.2376, 1.2376, 1.2376], dtype=torch.float64)


## Information Gain

(Following the new notation.)

Instead of computing $I[Y;\omega|x]$, we use our knowledge of the labels and compute: $$I[y;\omega|x]= H[y|x] - \mathbb{E}_{p(\omega|y,x)} H[y|x,\omega].$$

In [None]:
get_coreset_bald_scores(ys_ws.log().double(), torch.tensor([0, 1, 2, 3])), [
    get_coreset_bald_scores(ys_ws.log().double(), torch.tensor([i, i, i, i])) for i in range(3)
]

(tensor([0.0300, 0.0300, 0.0300, 0.0300], dtype=torch.float64),
 [tensor([0.0300, 0.0207, 0.0207, 0.0474], dtype=torch.float64),
  tensor([0.0474, 0.0300, 0.0207, 0.0207], dtype=torch.float64),
  tensor([0.0207, 0.0474, 0.0300, 0.0207], dtype=torch.float64)])

## Batch Information Gain

(Following the new notation.)

The batch version of this acquisition function can be computed more easily:

$$ \operatorname{I}[(y)_B;\omega|(x)_B] = \operatorname{H}[(y)_B|(x)_B] - \mathbb{E}_{p(\omega|(y)_B, (x)_B)} \operatorname{H}[(y)_B|(x)_B, \omega], $$

where $p(\omega|(y)_B, (x)_B) = \frac{ p((y)_B| (x)_B, \omega) p(\omega) }{ p((y)_B| (x)_B) }$ as usual, and we make use of the independence of the $(y)_B$ given $\omega$.

We can make this efficient for computing scores in parallel by using:
$$p((y)_B|(x)_B, \omega) = p(y_B|x_B, \omega) \; p((y)_{B-1}|(x)_{B-1}, \omega).$$

### Do we have sub-modularity?

Unclear.

In [None]:
def get_batch_coreset_bald_batch_simpler(
    log_probs_N_K_C: torch.Tensor, labels_N: torch.Tensor, *, batch_size: int, dtype=None, device=None
) -> CandidateBatch:
    """Greedily build a batch that maximizes the label-aware joint information gain.

    In each round, every point is scored as if it were added to the batch selected so far:

        score = H[(y)_B] - E_{p(w)} [ p((y)_B | w) / p((y)_B) * H[(y)_B | w] ]

    where the entropies are evaluated at the given labels. The best-scoring point that has not
    been picked yet is added to the batch.

    Args:
        log_probs_N_K_C: 3-dimensional tensor (N points, K inference samples, C classes) of
            log-probabilities.
        labels_N: label index per point; used to pick one entry along the class dimension.
        batch_size: maximum number of points to select (clamped to N).
        dtype: dtype used for the internal computation.
        device: device used for the internal computation.

    Returns:
        A `CandidateBatch` holding, in selection order, the joint score after each pick and the
        index of the picked point. Both lists are empty when `batch_size` ends up as 0.
    """
    num_points, num_samples, _num_classes = log_probs_N_K_C.shape
    batch_size = min(batch_size, num_points)

    picked_indices = []
    picked_scores = []
    if batch_size == 0:
        return CandidateBatch(picked_scores, picked_indices)

    # log p(y_i | x_i, w_k) for the provided label of each point.
    # NOTE(review): presumably `gather_expand` broadcasts the (N, 1, 1) index over the K samples
    # before gathering along the class dimension -- confirm against `joint_entropy`.
    label_index_N_1_1 = labels_N[:, None, None]
    gathered_N_K_1 = joint_entropy.gather_expand(log_probs_N_K_C, dim=2, index=label_index_N_1_1)
    label_log_probs_N_K = gathered_N_K_1.squeeze(2).to(dtype=dtype, device=device)

    # log p((y)_{B-1} | (x)_{B-1}, w) of the batch picked so far; the empty batch has probability 1.
    batch_log_probs_K = torch.zeros_like(label_log_probs_N_K[0], dtype=dtype, device=device)

    log_num_samples = np.log(num_samples)

    for _ in with_progress_bar(range(batch_size), tqdm_args=dict(desc="BatchCoreSetBALD", leave=False)):
        # The labels are independent given w, so extending the batch by one point multiplies
        # the likelihoods: p((y)_B | w) = p(y_B | w) * p((y)_{B-1} | w).
        extended_log_probs_N_K = label_log_probs_N_K + batch_log_probs_K[None, :]

        # Marginalize over w with a log-mean-exp over the K samples:
        # p((y)_B) = E_{p(w)} p((y)_B | w)
        marginal_log_probs_N_1 = extended_log_probs_N_K.logsumexp(dim=1, keepdim=True) - log_num_samples

        # Importance weights p((y)_B | w) / p((y)_B) that turn the average over p(w) into an
        # expectation over the posterior p(w | (y)_B).
        posterior_weights_N_K = (extended_log_probs_N_K - marginal_log_probs_N_1).exp()
        expected_conditional_entropy_N = -torch.mean(posterior_weights_N_K * extended_log_probs_N_K, dim=1)
        marginal_entropy_N = -marginal_log_probs_N_1.squeeze(1)
        scores_N = marginal_entropy_N - expected_conditional_entropy_N

        # Points that are already in the batch must not be selected again.
        scores_N[picked_indices] = -float("inf")

        best_score, best_index = scores_N.max(dim=0)
        picked_indices.append(best_index.item())
        picked_scores.append(best_score.item())

        # The extended batch becomes the base for the next round.
        batch_log_probs_K = extended_log_probs_N_K[best_index]

    return CandidateBatch(picked_scores, picked_indices)

In [None]:
get_coreset_bald_scores(ys_ws.log().double(), torch.tensor([0, 1, 2, 3])).numpy(), [
    get_coreset_bald_scores(ys_ws.log().double(), torch.tensor([i, i, i, i])) for i in range(3)
]

(array([0.03002132, 0.03002132, 0.03002132, 0.03002132]),
 [tensor([0.0300, 0.0207, 0.0207, 0.0474], dtype=torch.float64),
  tensor([0.0474, 0.0300, 0.0207, 0.0207], dtype=torch.float64),
  tensor([0.0207, 0.0474, 0.0300, 0.0207], dtype=torch.float64)])

In [None]:
ys_ws.shape

torch.Size([4, 20, 4])

In [None]:
get_batch_coreset_bald_batch(ys_ws.log().double(), torch.tensor([0, 1, 2, 3]), batch_size=4, dtype=torch.double)

BatchCoreSetBALD:   0%|          | 0/4 [00:00<?, ?it/s]

CandidateBatch(scores=[0.030021323375763576, 0.10871562110954991, 0.2168431672489275, 0.3375447132429139], indices=[0, 1, 2, 3])

In [None]:
get_batch_coreset_bald_batch_simpler(ys_ws.log().double(), torch.tensor([0, 1, 2, 3]), batch_size=4, dtype=torch.double)

BatchCoreSetBALD:   0%|          | 0/4 [00:00<?, ?it/s]

CandidateBatch(scores=[0.030021323375763687, 0.10871562110954991, 0.2168431672489275, 0.3375447132429139], indices=[0, 1, 2, 3])

## CoreSet-PIG & CoreSet-PIG-BALD

Combining EIG with CoreSets to use $I[y_{eval}; y_{batch} | x_{eval}, x_{batch}, D_{train}]$.

This is really easy to compute as $H[y_{batch} | x_{batch}, D_{train}] - H[y_{batch} | y_{eval}, x_{eval}, x_{batch}, D_{train}]$.

In [None]:
get_coreset_eig_scores(
    training_log_probs_N_K_C=ys_ws.log().double(),
    eval_log_probs_N_K_C=ys_ws.log().double(),
    labels_N=torch.tensor([0, 1, 2, 3]),
    dtype=torch.double,
)

tensor([0., 0., 0., 0.], dtype=torch.float64)

In [None]:
get_coreset_eig_bald_scores(
    training_log_probs_N_K_C=ys_ws.log().double(),
    eval_log_probs_N_K_C=ys_ws.log().double(),
    labels_N=torch.tensor([0, 1, 2, 3]),
    dtype=torch.double,
)

tensor([0., 0., 0., 0.], dtype=torch.float64)