In [None]:
# default_exp active_learning

In [None]:
#hide
import blackhc.project.script
from nbdev.showdoc import *

Appended /home/blackhc/PycharmProjects/blackhc.batchbald/src to paths
Switched to directory /home/blackhc/PycharmProjects/blackhc.batchbald
%load_ext autoreload
%autoreload 2


# Active learning
> Everything needed for active learning

## Active Learning Data

For active learning, we need to split the available training data between a training set and a pool set of (unlabelled) data. We score the pool set using our model and acquisition function, and move samples from it into the training set little by little.

In [None]:
# exports
from typing import Dict, List
import numpy as np
import torch.utils.data as data
import torch
import collections


In [None]:
# exports

class ActiveLearningData:
    """Partition `dataset` into an active (training) part and an available (pool) part.

    Membership is tracked with two boolean masks over the base dataset.
    `active_dataset` and `available_dataset` are `Subset` views whose index
    arrays are recomputed from those masks after every change.
    """
    active_dataset: data.Dataset
    available_dataset: data.Dataset

    def __init__(self, dataset: data.Dataset):
        super().__init__()
        num_samples = len(dataset)

        self.dataset = dataset
        # Initially nothing is active and every sample is available.
        self.active_mask = np.full((num_samples,), False)
        self.available_mask = np.full((num_samples,), True)

        # The subsets are created without indices; `_update_indices` fills them in.
        self.active_dataset = data.Subset(self.dataset, None)
        self.available_dataset = data.Subset(self.dataset, None)

        self._update_indices()

    def _update_indices(self):
        """Recompute the index arrays of both subsets from the current masks."""
        active_indices = np.flatnonzero(self.active_mask)
        available_indices = np.flatnonzero(self.available_mask)

        self.active_dataset.indices = active_indices
        self.available_dataset.indices = available_indices

    def get_dataset_indices(self, available_indices: List[int]) -> List[int]:
        """Translate positions within the available dataset into indices of the base dataset."""
        return self.available_dataset.indices[available_indices]

    def acquire(self, available_indices):
        """Move the given available samples into the active dataset."""
        base_indices = self.get_dataset_indices(available_indices)

        self.active_mask[base_indices] = True
        self.available_mask[base_indices] = False
        self._update_indices()

    def make_unavailable(self, available_indices):
        """Remove the given samples from the available dataset without making them active."""
        base_indices = self.get_dataset_indices(available_indices)

        self.available_mask[base_indices] = False
        self._update_indices()

    def get_random_available_indices(self, size) -> torch.LongTensor:
        """Return `size` distinct random positions within the available dataset."""
        pool_size = len(self.available_dataset)
        assert 0 <= size <= pool_size
        return torch.randperm(pool_size)[:size]

    def extract_dataset(self, size) -> data.Dataset:
        """Randomly pick `size` available samples, make them unavailable and return them as a dataset.

        Useful for extracting a validation set."""
        random_indices = self.get_random_available_indices(size)
        return self.extract_dataset_from_indices(random_indices)

    def extract_dataset_from_indices(self, available_indices) -> data.Dataset:
        """Make the given available samples unavailable and return them as a dataset.

        Useful for extracting a validation set."""
        # Resolve to base indices first: positions in the available dataset
        # shift as soon as the samples are made unavailable.
        base_indices = self.get_dataset_indices(available_indices)

        self.make_unavailable(available_indices)
        return data.Subset(self.dataset, base_indices)


In [None]:
# Render the nbdev documentation entry for each ActiveLearningData method listed below.
show_doc(ActiveLearningData.get_dataset_indices)
show_doc(ActiveLearningData.acquire)
show_doc(ActiveLearningData.make_unavailable)
show_doc(ActiveLearningData.get_random_available_indices)
show_doc(ActiveLearningData.extract_dataset)
show_doc(ActiveLearningData.extract_dataset_from_indices)

<h4 id="ActiveLearningData.get_dataset_indices" class="doc_header"><code>ActiveLearningData.get_dataset_indices</code><a href="__main__.py#L23" class="source_link" style="float:right">[source]</a></h4>

> <code>ActiveLearningData.get_dataset_indices</code>(**`available_indices`**:`List`\[`int`\])



<h4 id="ActiveLearningData.acquire" class="doc_header"><code>ActiveLearningData.acquire</code><a href="__main__.py#L27" class="source_link" style="float:right">[source]</a></h4>

> <code>ActiveLearningData.acquire</code>(**`available_indices`**)



<h4 id="ActiveLearningData.make_unavailable" class="doc_header"><code>ActiveLearningData.make_unavailable</code><a href="__main__.py#L34" class="source_link" style="float:right">[source]</a></h4>

> <code>ActiveLearningData.make_unavailable</code>(**`available_indices`**)



<h4 id="ActiveLearningData.get_random_available_indices" class="doc_header"><code>ActiveLearningData.get_random_available_indices</code><a href="__main__.py#L40" class="source_link" style="float:right">[source]</a></h4>

> <code>ActiveLearningData.get_random_available_indices</code>(**`size`**)



<h4 id="ActiveLearningData.extract_dataset" class="doc_header"><code>ActiveLearningData.extract_dataset</code><a href="__main__.py#L45" class="source_link" style="float:right">[source]</a></h4>

> <code>ActiveLearningData.extract_dataset</code>(**`size`**)

Extract a dataset randomly from the available dataset and make those indices unavailable.

Useful for extracting a validation set.

<h4 id="ActiveLearningData.extract_dataset_from_indices" class="doc_header"><code>ActiveLearningData.extract_dataset_from_indices</code><a href="__main__.py#L51" class="source_link" style="float:right">[source]</a></h4>

> <code>ActiveLearningData.extract_dataset_from_indices</code>(**`available_indices`**)

Extract a dataset from the available dataset and make those indices unavailable.

Useful for extracting a validation set.

## Additional helpers

In [None]:
# exports

def get_balanced_sample_indices(target_classes: List, num_classes, n_per_digit=2) -> List[int]:
    """Pick random sample indices with at most `n_per_digit` samples per class.

    Samples are visited in a random order and kept while their class has not
    reached `n_per_digit` picks yet. The search stops as soon as
    `num_classes * n_per_digit` indices have been collected; fewer indices are
    returned when the targets run out first.
    """
    shuffled_indices = torch.randperm(len(target_classes))

    if n_per_digit == 0:
        return []

    picks_per_class = collections.defaultdict(int)
    selected = []

    for candidate in shuffled_indices.tolist():
        label = int(target_classes[candidate])

        # This class already has its full share; skip the sample.
        if picks_per_class[label] == n_per_digit:
            continue

        selected.append(candidate)
        picks_per_class[label] += 1

        if len(selected) == num_classes * n_per_digit:
            break

    return selected

def get_subset_base_indices(dataset: data.Subset, indices: List[int]):
    """Map positions within the subset `dataset` to indices of its underlying dataset.

    Returns a list of plain Python ints, one per entry in `indices`.
    """
    base_indices = []
    for position in indices:
        base_indices.append(int(dataset.indices[position]))
    return base_indices


def get_base_indices(dataset: data.Dataset, indices: List[int]):
    """Resolve `indices` through every nested `Subset` layer down to the innermost dataset.

    When `dataset` is not a `Subset`, `indices` is returned unchanged.
    """
    # Peel off one Subset layer per iteration, translating the indices as we go.
    while isinstance(dataset, data.Subset):
        indices = get_subset_base_indices(dataset, indices)
        dataset = dataset.dataset
    return indices


class RandomFixedLengthSampler(data.Sampler):
    """
    Sometimes, you really want to do more with little data without increasing the number of epochs.

    This sampler takes a `dataset` and draws `target_length` samples from it (with repetition).

    If `target_length` is smaller than the dataset, every sample is yielded exactly once
    (in random order) instead, so that no data is lost by accident. Note that `len()` still
    reports `target_length` in that case. An empty dataset yields nothing.
    """

    def __init__(self, dataset: data.Dataset, target_length):
        super().__init__(dataset)
        self.dataset = dataset
        self.target_length = target_length

    def __iter__(self):
        """Return an iterator over randomly ordered indices into `dataset` (plain ints)."""
        dataset_length = len(self.dataset)

        # Nothing to draw from; this also avoids a modulo by zero further down.
        if dataset_length == 0:
            return iter(())

        # Ensure that we don't lose data by accident.
        if self.target_length < dataset_length:
            return iter(torch.randperm(dataset_length).tolist())

        # Permute a range padded up to a multiple of the dataset length, so every sample
        # has the same number of duplicates before truncation. Truncating the permutation
        # (rather than the range) avoids always favouring the start of the dataset.
        padding = -self.target_length % dataset_length
        permutation = torch.randperm(self.target_length + padding)

        return iter((permutation[: self.target_length] % dataset_length).tolist())

    def __len__(self):
        """Return `target_length`, regardless of the size of the dataset."""
        return self.target_length
