In [None]:
# default_exp batchbald

In [None]:
#hide
import blackhc.project.script
from nbdev.showdoc import *

# BatchBALD Algorithm
> Greedy algorithm and score computation

First, we will implement two helper functions to compute conditional entropies $H[y_i|w]$ and entropies $H[y_i]$.

Then, we will implement BatchBALD and BALD.

In [None]:
# exports
from dataclasses import dataclass
from typing import List
import torch

from toma import toma

from batchbald_redux import joint_entropy

First, we are going to define a couple of sampled distributions to use for testing our code.

We are going to use $K=20$ inference samples.

In [None]:
K=20

In [None]:
import numpy as np

def get_mixture_prob_dist(p1, p2, m):
    """Linearly interpolate between two probability vectors.

    Both inputs are converted to NumPy arrays and combined as
    `(1 - m) * p1 + m * p2`, so `m = 0` yields `p1` and `m = 1` yields `p2`.
    """
    first = np.asarray(p1)
    second = np.asarray(p2)
    weight_first = 1. - m
    return weight_first * first + m * second


# p1 = [0.1, 0.2, 0.2, 0.5]
# p2 = [0.5, 0.2, 0.1, 0.2]
# y1_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]

# p1 = [0.1, 0.6, 0.2, 0.1]
# p2 = [0.1, 0.4, 0.4, 0.1]
# y2_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]

# p1 = [0.1, 0.1, 0.6, 0.2]
# p2 = [0.1, 0.1, 0.5, 0.3]
# y3_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]

# p1 = [0.1, 0.1, 0.1, 0.7]
# p2 = [0.1, 0.5, 0.1, 0.3]
# y4_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]

# Four toy variables over 4 classes. For each one, the K sampled
# distributions sweep linearly from p1 (m=0) to p2 (m=1), giving a list of
# K length-4 arrays. Each p1 puts its 0.7 mass on a different class.

# Variable 1: peaked on class 0.
p1 = [0.7, 0.1, 0.1, 0.1]
p2 = [0.3, 0.3, 0.2, 0.2]
y1_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]

# Variable 2: peaked on class 1.
p1 = [0.1, 0.7, 0.1, 0.1]
p2 = [0.2, 0.3, 0.3, 0.2]
y2_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]

# Variable 3: peaked on class 2.
p1 = [0.1, 0.1, 0.7, 0.1]
p2 = [0.2, 0.2, 0.3, 0.3]
y3_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]

# Variable 4: peaked on class 3.
p1 = [0.1, 0.1, 0.1, 0.7]
p2 = [0.3, 0.2, 0.2, 0.3]
y4_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]


def nested_to_tensor(l):
    """Turn every element of `l` into a tensor and stack them along a new leading dimension."""
    tensors = [torch.as_tensor(item) for item in l]
    return torch.stack(tensors)


# Number of identical copies of each toy variable placed in the pool.
duplicates=2 

# Stack everything into one tensor: 4 variables * `duplicates` copies along
# the first axis, then the K samples, then the 4 class probabilities
# (the next cell shows the resulting shape [8, 20, 4]).
ys_ws = nested_to_tensor([y1_ws] * duplicates + [y2_ws] * duplicates + [y3_ws] * duplicates + [y4_ws] * duplicates)

In [None]:
ys_ws.shape

torch.Size([8, 20, 4])

## Computing conditional entropies and batch entropies


In [None]:
# exports

def compute_conditional_entropy(probs_N_K_C: torch.Tensor) -> torch.Tensor:
    """Estimate the conditional entropy H[y_i | w] for every row of the input.

    For each of the N rows, the entropy of each of the K sampled class
    distributions is computed and the K entropies are averaged.

    Args:
        probs_N_K_C: probabilities of shape [N, K, C] (the shape is unpacked
            as N, K, C; the last axis is summed over as the class axis).

    Returns:
        A `torch.double` tensor of length N holding the averaged entropies
        (in nats, since the natural logarithm is used).
    """
    N, K, C = probs_N_K_C.shape

    entropies_N = torch.empty(N, dtype=torch.double)

    # toma runs `compute` immediately on chunks of the input along the first
    # axis (1024 is the initial chunk size); each call fills its own slice.
    @toma.execute.chunked(probs_N_K_C, 1024)
    def compute(probs_n_K_C, start: int, end: int):
        nats_n_K_C = probs_n_K_C * torch.log(probs_n_K_C)
        # Entropy convention: 0 * log(0) = 0. Without this, a probability that
        # is exactly zero yields 0 * -inf = NaN and poisons the whole row.
        nats_n_K_C = torch.where(probs_n_K_C == 0,
                                 torch.zeros_like(nats_n_K_C),
                                 nats_n_K_C)
        entropies_N[start:end].copy_(-torch.sum(nats_n_K_C, dim=(1, 2)) / K)

    return entropies_N


def compute_entropy(probs_N_K_C: torch.Tensor) -> torch.Tensor:
    """Compute the entropy H[y_i] of the sample-averaged prediction for every row.

    For each of the N rows, the K sampled class distributions are averaged
    and the entropy of that mean distribution is returned.

    Args:
        probs_N_K_C: probabilities of shape [N, K, C] (the shape is unpacked
            as N, K, C; axis 1 is averaged over, the last axis is summed over).

    Returns:
        A `torch.double` tensor of length N holding the entropies (in nats,
        since the natural logarithm is used).
    """
    N, K, C = probs_N_K_C.shape

    entropies_N = torch.empty(N, dtype=torch.double)

    # toma runs `compute` immediately on chunks of the input along the first
    # axis (1024 is the initial chunk size); each call fills its own slice.
    @toma.execute.chunked(probs_N_K_C, 1024)
    def compute(probs_n_K_C, start: int, end: int):
        # Average only the rows of this chunk. Using the full input here
        # would produce N values for a slice of length end - start and fail
        # as soon as the input is split into more than one chunk.
        mean_probs_n_C = probs_n_K_C.mean(dim=1)
        nats_n_C = mean_probs_n_C * torch.log(mean_probs_n_C)
        # Entropy convention: 0 * log(0) = 0 (avoids NaN for exact zeros).
        nats_n_C = torch.where(mean_probs_n_C == 0,
                               torch.zeros_like(nats_n_C),
                               nats_n_C)
        entropies_N[start:end].copy_(-torch.sum(nats_n_C, dim=1))

    return entropies_N
    

### Example usages

In [None]:
# The four toy variables are class-permutations of one another (and each is
# duplicated), so all eight conditional entropies must coincide.
conditional_entropies = compute_conditional_entropy(ys_ws)

print(conditional_entropies)

assert np.allclose(conditional_entropies, [1.2069, 1.2069, 1.2069, 1.2069, 1.2069, 1.2069, 1.2069, 1.2069],
                   atol=0.01)

tensor([1.2069, 1.2069, 1.2069, 1.2069, 1.2069, 1.2069, 1.2069, 1.2069],
       dtype=torch.float64)


In [None]:
# Same symmetry argument as for the conditional entropies: the toy variables
# only differ by a permutation of the classes, so all eight entropies match.
entropies = compute_entropy(ys_ws)

print(entropies)

assert np.allclose(entropies,
    [1.2376, 1.2376, 1.2376, 1.2376, 1.2376, 1.2376, 1.2376, 1.2376],
                   atol=0.01)


tensor([1.2376, 1.2376, 1.2376, 1.2376, 1.2376, 1.2376, 1.2376, 1.2376],
       dtype=torch.float64)


## BatchBALD

To compute BatchBALD exactly for a candidate batch, we'd have to compute $I[(y_b)_B;w] = H[(y_b)_B] - H[(y_b)_B|w]$.

As the $y_b$ are independent given $w$, we can simplify $H[(y_b)_B|w] = \sum_b H[y_b|w]$.

Furthermore, we use a greedy algorithm to build up the candidate batch, so $y_1,\dots,y_{B-1}$ will stay fixed as we determine $y_{B}$. We compute
$H[(y_b)_{B-1}, y_i] - H[y_i|w]$ for each pool element $y_i$ and add the highest scorer as $y_{B}$.

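In the implementation below we do not use this last shortcut: we also subtract the conditional entropies of the already selected elements, $\sum_{b<B} H[y_b|w]$, so that the reported values are the actual BatchBALD scores.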


### In the paper

![BatchBALD algorithm in the paper](batchbald_algorithm.png)


### Implementation

In [None]:
# exports


@dataclass
class CandidateBatch:
    """Result of an acquisition function: the selected indices and their scores.

    The two lists are parallel: `scores[i]` is the score recorded for the
    candidate at `indices[i]`.
    """
    # Acquisition score recorded for each selected candidate.
    scores: List[float]
    # Index (row of the probabilities tensor) of each selected candidate.
    indices: List[int]


def get_batchbald_batch(probs_N_K_C: torch.Tensor,
                        batch_size: int,
                        num_samples: int,
                        dtype=None,
                        device=None) -> CandidateBatch:
    """Greedily build an acquisition batch using BatchBALD scores.

    In every step each pool element is scored as the joint entropy of the
    already selected elements together with that element, minus the
    conditional entropies of all of them; the highest scorer is added to
    the batch. Already selected elements are masked out with -inf.

    Args:
        probs_N_K_C: probabilities of shape [N, K, C] (unpacked as N, K, C).
        batch_size: number of elements to select; capped at N. A value <= 0
            returns an empty batch.
        num_samples: forwarded to `joint_entropy.DynamicJointEntropy`
            (presumably the sample budget of its joint-entropy estimate --
            TODO confirm against `joint_entropy`).
        dtype: forwarded to `joint_entropy.DynamicJointEntropy`.
        device: forwarded to `joint_entropy.DynamicJointEntropy`.

    Returns:
        A `CandidateBatch` with the selected indices in selection order and
        the score each one had when it was selected.
    """
    N, K, C = probs_N_K_C.shape

    batch_size = min(batch_size, N)

    candidate_indices = []
    candidate_scores = []

    # Nothing to select: empty pool or a non-positive request.
    if batch_size <= 0:
        return CandidateBatch(candidate_scores, candidate_indices)

    # Tracks the joint distribution of the elements selected so far; at most
    # batch_size - 1 elements are ever added before the final scoring step.
    batch_joint_entropy = joint_entropy.DynamicJointEntropy(num_samples,
                                                            batch_size - 1,
                                                            K,
                                                            C,
                                                            dtype=dtype,
                                                            device=device)

    conditional_entropies_N = compute_conditional_entropy(probs_N_K_C)

    # We always keep these on the CPU.
    scores_N = torch.empty(N, dtype=torch.double)

    for i in range(batch_size):
        if i > 0:
            # Fold the element chosen in the previous step into the joint.
            latest_index = candidate_indices[-1]
            batch_joint_entropy.add_variables(
                probs_N_K_C[latest_index:latest_index + 1])

        # Sum of H[y_b | w] over the elements already in the batch.
        shared_conditional_entropies = conditional_entropies_N[candidate_indices].sum()

        # Presumably fills scores_N with the joint entropy of the current
        # batch extended by each pool element -- TODO confirm against
        # `joint_entropy`.
        batch_joint_entropy.compute_batch(probs_N_K_C,
                                          output_entropies_B=scores_N)

        scores_N -= conditional_entropies_N + shared_conditional_entropies
        # Never pick the same element twice.
        scores_N[candidate_indices] = -float('inf')

        candidate_score, candidate_index = scores_N.max(dim=0)

        candidate_indices.append(candidate_index.item())
        candidate_scores.append(candidate_score.item())

    return CandidateBatch(candidate_scores, candidate_indices)

### Example usages:

In [None]:
get_batchbald_batch(ys_ws.double(), 4, 1000, dtype=torch.double)

tensor([0.0307, 0.0307, 0.0307, 0.0307, 0.0307, 0.0307, 0.0307, 0.0307],
       dtype=torch.float64)
tensor([0.0596, 0.0596, 0.0596, 0.0596, 0.0596, 0.0596, 0.0596,   -inf],
       dtype=torch.float64)
tensor([0.0869, 0.0869, 0.0869, 0.0869, 0.0869, 0.0869,   -inf,   -inf],
       dtype=torch.float64)
tensor([0.1128, 0.1128, 0.1128, 0.1128, 0.1128,   -inf,   -inf,   -inf],
       dtype=torch.float64)


CandidateBatch(scores=[0.030715639666234917, 0.059619586271582925, 0.08691070514744759, 0.11275304532467789], indices=[7, 6, 5, 4])

## Encore: BALD

BALD is the same as BatchBALD, except that we score and pick candidates independently of each other.

In [None]:
def get_bald_batch(probs_N_K_C: torch.Tensor,
                   batch_size: int,
                   num_samples: int,
                   dtype=None,
                   device=None) -> CandidateBatch:
    """Select the `batch_size` pool elements with the highest BALD scores.

    Each element is scored on its own as `compute_entropy` minus
    `compute_conditional_entropy`, and the top scorers are returned.

    Args:
        probs_N_K_C: probabilities of shape [N, K, C] (unpacked as N, K, C).
        batch_size: number of elements to select; clamped to [0, N].
        num_samples: not used by this function; the signature mirrors
            `get_batchbald_batch`.
        dtype: not used by this function.
        device: not used by this function.

    Returns:
        A `CandidateBatch` with the selected indices and their scores, in
        the order returned by `torch.topk`.
    """
    N, K, C = probs_N_K_C.shape

    # torch.topk rejects k outside [0, N], so clamp on both sides.
    batch_size = max(0, min(batch_size, N))

    scores_N = compute_entropy(probs_N_K_C)
    scores_N -= compute_conditional_entropy(probs_N_K_C)

    candidate_scores, candidate_indices = torch.topk(scores_N, batch_size)

    return CandidateBatch(candidate_scores.tolist(), candidate_indices.tolist())

### Example usages:

In [None]:
get_bald_batch(ys_ws.double(), 4, 1000)

CandidateBatch(scores=[0.030715639666234917, 0.030715639666234917, 0.030715639666234917, 0.030715639666234917], indices=[2, 6, 3, 7])