In [None]:
# default_exp batchbald

In [None]:
#hide
import blackhc.project.script
from nbdev.showdoc import *

Appended /home/blackhc/PycharmProjects/blackhc.batchbald/src to paths
Switched to directory /home/blackhc/PycharmProjects/blackhc.batchbald
%load_ext autoreload
%autoreload 2


# BatchBALD Algorithm
> Greedily select candidate batches and compute their BatchBALD scores

First, we will implement two helper functions to compute the conditional entropies $H[y_i|w]$ and the entropies $H[y_i]$.

Then, we will implement BatchBALD, and finally BALD as an encore.

In [None]:
# exports
from dataclasses import dataclass
from typing import List
import torch

from toma import toma

from batchbald_redux import joint_entropy

## Computing conditional entropies and batch entropies


In [None]:
# exports

def compute_conditional_entropy(probs_N_K_C: torch.Tensor) -> torch.Tensor:
    """Compute the conditional entropy $H[y_i|w]$ for every pool element.

    For each element along the first dimension, `-p * log(p)` is summed over
    the last two dimensions and divided by K, i.e. the per-sample entropies
    are averaged over the second dimension.

    Args:
        probs_N_K_C: probability tensor with three dimensions, unpacked as
            `N, K, C`.

    Returns:
        A `torch.double` tensor of length N (allocated without a device
        argument) with one conditional entropy per pool element.
    """
    N, K, C = probs_N_K_C.shape

    entropies_N = torch.empty(N, dtype=torch.double)

    # Per the toma documentation, `toma.execute.chunked` runs the decorated
    # function right away on consecutive chunks of the first dimension
    # (initial chunk size 1024); plain `toma.chunked` only wraps it.
    @toma.execute.chunked(probs_N_K_C, 1024)
    def compute(probs_n_K_C, start: int, end: int):
        nats_n_K_C = probs_n_K_C * torch.log(probs_n_K_C)
        # 0 * log(0) evaluates to nan, but contributes nothing to an entropy.
        nats_n_K_C[probs_n_K_C == 0] = 0.

        entropies_N[start:end].copy_(-torch.sum(nats_n_K_C, dim=(1, 2)) / K)

    return entropies_N


def compute_entropy(probs_N_K_C: torch.Tensor) -> torch.Tensor:
    """Compute the entropy $H[y_i]$ for every pool element.

    For each element along the first dimension, the probabilities are first
    averaged over the second dimension and the entropy of that mean
    distribution is taken over the last dimension.

    Args:
        probs_N_K_C: probability tensor with three dimensions, unpacked as
            `N, K, C`.

    Returns:
        A `torch.double` tensor of length N (allocated without a device
        argument) with one entropy per pool element.
    """
    N, K, C = probs_N_K_C.shape

    entropies_N = torch.empty(N, dtype=torch.double)

    # Per the toma documentation, `toma.execute.chunked` runs the decorated
    # function right away on consecutive chunks of the first dimension
    # (initial chunk size 1024); plain `toma.chunked` only wraps it.
    @toma.execute.chunked(probs_N_K_C, 1024)
    def compute(probs_n_K_C, start: int, end: int):
        # Average the chunk, not the full tensor: the result has to match the
        # `start:end` slice it is copied into.
        mean_probs_n_C = probs_n_K_C.mean(dim=1)
        nats_n_C = mean_probs_n_C * torch.log(mean_probs_n_C)
        # 0 * log(0) evaluates to nan, but contributes nothing to an entropy.
        nats_n_C[mean_probs_n_C == 0] = 0.

        entropies_N[start:end].copy_(-torch.sum(nats_n_C, dim=1))

    return entropies_N
    

## Example usages

## BatchBALD

To compute BatchBALD exactly for a candidate batch, we'd have to compute $I[(y_b)_B;w] = H[(y_b)_B] - H[(y_b)_B|w]$.

As the $y_b$ are independent given $w$, we can simplify $H[(y_b)_B|w] = \sum_b H[y_b|w]$.

Furthermore, we use a greedy algorithm to build up the candidate batch, so $y_1,\dots,y_{B-1}$ will stay fixed as we determine $y_{B}$. We compute
$H[(y_b)_{B-1}, y_i] - H[y_i|w]$ for each pool element $y_i$ and add the highest scorer as $y_{B}$.

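We don't use this last optimization here: we also subtract the conditional entropies of the candidates that were already selected, so that the reported values are the actual BatchBALD scores.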


### In the paper

![BatchBALD algorithm in the paper](batchbald_algorithm.png)


### Implementation

In [None]:
# exports


@dataclass
class CandidateBatch:
    """Result of a batch acquisition: the selected candidates and their scores."""
    # Score recorded for each selected candidate, in the same order as `indices`.
    scores: List[float]
    # Positions of the selected candidates along the first dimension of the
    # probability tensor that was passed to the acquisition function.
    indices: List[int]


def get_batchbald_batch(probs_N_K_C: torch.Tensor,
                        batch_size: int,
                        num_samples: int,
                        dtype=None,
                        device=None) -> CandidateBatch:
    """Greedily build a candidate batch and record the score of every pick.

    In each round every pool element is scored as the value written by
    `DynamicJointEntropy.compute_batch` minus its own conditional entropy
    minus the summed conditional entropies of the candidates selected so far.
    The highest scorer is appended to the batch; already selected elements are
    masked with `-inf` so they cannot be picked twice.

    Args:
        probs_N_K_C: probability tensor with three dimensions, unpacked as
            `N, K, C`.
        batch_size: requested number of candidates; capped at N.
        num_samples: forwarded to `joint_entropy.DynamicJointEntropy`.
        dtype: forwarded to `joint_entropy.DynamicJointEntropy`.
        device: forwarded to `joint_entropy.DynamicJointEntropy`.

    Returns:
        A `CandidateBatch` with the scores and pool indices in selection
        order. Both lists are empty when the capped batch size is 0.
    """
    N, K, C = probs_N_K_C.shape

    batch_size = min(batch_size, N)

    candidate_indices = []
    candidate_scores = []

    if batch_size == 0:
        return CandidateBatch(candidate_scores, candidate_indices)

    # NOTE(review): DynamicJointEntropy is defined in `joint_entropy`, which is
    # not visible here; presumably it tracks the joint entropy of the batch
    # built so far — confirm against that module.
    batch_joint_entropy = joint_entropy.DynamicJointEntropy(num_samples,
                                                            batch_size - 1,
                                                            K,
                                                            C,
                                                            dtype=dtype,
                                                            device=device)

    conditional_entropies_N = compute_conditional_entropy(probs_N_K_C)

    # We always keep these on the CPU.
    scores_N = torch.empty(N, dtype=torch.double)

    for i in range(batch_size):
        if i > 0:
            # Register the candidate picked in the previous round before
            # scoring the pool again.
            latest_index = candidate_indices[-1]
            batch_joint_entropy.add_variables(
                probs_N_K_C[latest_index:latest_index + 1])
            shared_conditional_entropies = conditional_entropies_N[candidate_indices].sum()
        else:
            shared_conditional_entropies = 0.

        batch_joint_entropy.compute_batch(probs_N_K_C,
                                          output_entropies_B=scores_N)

        scores_N -= conditional_entropies_N + shared_conditional_entropies
        scores_N[candidate_indices] = -float('inf')

        # `max` only returns a (values, indices) pair when `dim` is given.
        candidate_score, candidate_index = scores_N.max(dim=0)

        candidate_indices.append(candidate_index.item())
        candidate_scores.append(candidate_score.item())

    return CandidateBatch(candidate_scores, candidate_indices)

### Example usages:

## Encore: BALD

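BALD is the same as BatchBALD, except that candidates are scored independently of each other: we compute $I[y_i;w] = H[y_i] - H[y_i|w]$ for every pool element and pick the `batch_size` highest scorers.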

In [None]:
def get_bald_batch(probs_N_K_C: torch.Tensor,
                   batch_size: int,
                   num_samples: int,
                   dtype=None,
                   device=None) -> CandidateBatch:
    """Pick the pool elements with the highest individual BALD scores.

    Each element is scored as `compute_entropy` minus
    `compute_conditional_entropy`, and the top scorers are returned.

    Args:
        probs_N_K_C: probability tensor with three dimensions.
        batch_size: requested number of candidates; capped at the size of the
            first dimension.
        num_samples: accepted but not used by this function.
        dtype: accepted but not used by this function.
        device: accepted but not used by this function.

    Returns:
        A `CandidateBatch` holding the scores and pool indices returned by
        `torch.topk`.
    """
    pool_size, _, _ = probs_N_K_C.shape
    top_k = min(batch_size, pool_size)

    # Entropy minus conditional entropy, subtracted in place.
    bald_scores_N = compute_entropy(probs_N_K_C)
    bald_scores_N.sub_(compute_conditional_entropy(probs_N_K_C))

    best = torch.topk(bald_scores_N, top_k)

    return CandidateBatch(best.values.tolist(), best.indices.tolist())

### Example usages: