# Training an MNIST Digits Classifier with Active Learning

This notebook provides you with a complete code example that uses active learning to train a neural network capable of classifying the MNIST digits.

## Training a Baseline Model

Load the MNIST digits ...

In [1]:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# MNIST train and test splits (downloaded into "data" when requested by
# `download=True`), with every image converted to a tensor.
train_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

# Only the training batches are shuffled; the test order stays fixed.
train_dataloader = DataLoader(train_data, shuffle=True, batch_size=128)
test_dataloader = DataLoader(test_data, shuffle=False, batch_size=1024)

... define the classifier neural network ...

In [2]:
import deeplay as dl

# ResNet18 backbone for single-channel images with pooled output.
backbone = dl.models.BackboneResnet18(
    in_channels=1,
    pool_output=True,
)

# Head mapping 512 inputs to the 10 digit classes; the empty list is
# presumably the (absent) hidden layers -- confirm against the deeplay docs.
head = dl.MultiLayerPerceptron(512, [], 10)

# Template network; the cells below create fresh copies of it with `.new()`.
classifier_template = dl.Sequential(backbone, head)

... and define, train, and test the classifier with all the training data.

In [None]:
import torchmetrics as tm

# Multiclass accuracy over the 10 digit classes; the same metric object is
# also passed to the active learning strategies in later cells.
accuracy = tm.Accuracy(task="multiclass", num_classes=10)

# Baseline classifier built from a fresh copy of the template.
classifier = dl.CategoricalClassifier(
    classifier_template.new(),
    optimizer=dl.Adam(lr=1e-3),
    num_classes=10,
    metrics=[accuracy],
).build()

# Quick run with a single epoch; the value noted for the full training
# is max_epochs=30.
trainer = dl.Trainer(max_epochs=1)  ### trainer = dl.Trainer(max_epochs=30)
trainer.fit(classifier, train_dataloader)

# Baseline test results; the comparison plot at the end of the notebook
# reads the accuracy from here.
full_results = trainer.test(classifier, test_dataloader)


## Implementing Multiple Active Learning Strategies

### Implementing a Common Configuration for All Samplings

Configure the general parameters ...

In [4]:
# Reduced settings for a quick run. The values noted for the full
# experiment are trials=5, budget_per_iteration=120, max_budget=1800.
trials = 2                # Repetitions of each sampling strategy.
budget_per_iteration = 2  # Samples annotated initially and after each round.
max_budget = 4            # Annotation budget from which `rounds` is derived.

# Number of rounds per trial.
rounds = max_budget // budget_per_iteration - 1

... and implement a function to perform one round of the active learning loop.

In [5]:
def active_learning_loop(strategy, epochs):
    """Run one active learning round and return its test accuracy.

    The strategy is trained with a fresh trainer and tested on the
    notebook-level `test_dataloader`. Afterwards `budget_per_iteration`
    samples are queried and the model is reset, so that the next call
    starts again from the initial model state.

    Parameters
    ----------
    strategy
        Built active learning strategy providing `query_and_update` and
        `reset_model`.
    epochs : int
        Maximum number of training epochs for this round.

    Returns
    -------
    The "testMulticlassAccuracy" entry of the first test result.
    """
    round_trainer = dl.Trainer(
        max_epochs=epochs,
        enable_checkpointing=False,
        enable_model_summary=False,
    )
    round_trainer.fit(strategy)

    # Evaluate before the annotated pool is extended.
    first_result = round_trainer.test(strategy, test_dataloader)[0]
    test_accuracy = first_result["testMulticlassAccuracy"]

    strategy.query_and_update(budget_per_iteration)
    strategy.reset_model()  # Back to the initial model state.
    return test_accuracy

### Uniform Random Sampling

In [6]:
import deeplay.activelearning as al
import numpy as np

# Test accuracy per (trial, round) for uniform random sampling. Filled with
# NaN so that rounds that never ran (e.g., after an interrupted cell) are
# not mistaken for results.
uniform_acc = np.full((trials, rounds), np.nan)
for trial in range(trials):
    # Every trial starts from a fresh pool with a random initial annotation.
    uniform_train_pool = al.ActiveLearningDataset(train_data)
    uniform_train_pool.annotate_random(budget_per_iteration)
    uniform_strategy = al.UniformStrategy(
        classifier_template.new(), train_pool=uniform_train_pool,
        test=test_data, batch_size=128, test_metrics=[accuracy],
    ).build()

    # Named `round_idx` so that the built-in `round()` is not overwritten
    # in the notebook namespace.
    for round_idx in range(rounds):
        uniform_acc[trial, round_idx] = \
            active_learning_loop(uniform_strategy, epochs=40)

Output()

/Users/giovannivolpe/Documents/GitHub/DeepLearningCrashCourse/py_env_dlcc/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
/Users/giovannivolpe/Documents/GitHub/DeepLearningCrashCourse/py_env_dlcc/lib/python3.12/site-packages/lightning/pytorch/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
/Users/giovannivolpe/Documents/GitHub/DeepLearningCrashCourse/py_env_dlcc/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'trai

Output()

/Users/giovannivolpe/Documents/GitHub/DeepLearningCrashCourse/py_env_dlcc/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.


TypeError: cannot pickle 'generator' object

### Uncertainty Sampling

In [None]:
# Test accuracy per (trial, round) for margin-based uncertainty sampling.
# Filled with NaN so that rounds that never ran (e.g., after an interrupted
# cell) are not mistaken for results.
uncertainty_acc = np.full((trials, rounds), np.nan)
for trial in range(trials):
    # Every trial starts from a fresh pool with a random initial annotation.
    margin_train_pool = al.ActiveLearningDataset(train_data)
    margin_train_pool.annotate_random(budget_per_iteration)
    margin_strategy = al.UncertaintyStrategy(
        classifier_template.new(), train_pool=margin_train_pool,
        criterion=al.Margin(), batch_size=128, test_metrics=[accuracy],
    ).build()

    # Named `round_idx` so that the built-in `round()` is not overwritten
    # in the notebook namespace.
    for round_idx in range(rounds):
        uncertainty_acc[trial, round_idx] = \
            active_learning_loop(margin_strategy, epochs=40)

### Adversarial Sampling

In [7]:
import torch

# Test accuracy per (trial, round) for adversarial sampling. Filled with NaN
# so that rounds that never ran (e.g., after an interrupted cell) are not
# mistaken for results.
adversarial_acc = np.full((trials, rounds), np.nan)
for trial in range(trials):
    # Discriminator head with a sigmoid output and Kaiming initialization.
    discriminator = dl.MultiLayerPerceptron(512, [512, 512], 1,
                                            out_activation=torch.nn.Sigmoid())
    discriminator.initialize(dl.initializers.Kaiming())

    # Every trial starts from a fresh pool with a random initial annotation.
    adversarial_train_pool = al.ActiveLearningDataset(train_data)
    adversarial_train_pool.annotate_random(budget_per_iteration)
    adversarial_strategy = al.AdversarialStrategy(
        backbone=backbone.new(), classification_head=head.new(),
        discriminator_head=discriminator.new(),
        train_pool=adversarial_train_pool, criterion=al.Margin(),
        batch_size=128, test_metrics=[accuracy],
    ).build()

    # Named `round_idx` so that the built-in `round()` is not overwritten
    # in the notebook namespace.
    for round_idx in range(rounds):
        adversarial_acc[trial, round_idx] = \
            active_learning_loop(adversarial_strategy, epochs=5)

/Users/giovannivolpe/Documents/GitHub/DeepLearningCrashCourse/py_env_dlcc/lib/python3.12/site-packages/lightning/pytorch/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.


Output()

/Users/giovannivolpe/Documents/GitHub/DeepLearningCrashCourse/py_env_dlcc/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.


Output()

/Users/giovannivolpe/Documents/GitHub/DeepLearningCrashCourse/py_env_dlcc/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.


: 

## Comparing the Performance of the Active Learning Strategies

In [None]:
import matplotlib.pyplot as plt

# Number of annotated samples in each round: the pool starts with
# budget_per_iteration samples and the same amount is requested after every
# round. Deriving x from `rounds` keeps it the same length as the accuracy
# arrays even when max_budget is not a multiple of budget_per_iteration.
x = budget_per_iteration * np.arange(1, rounds + 1)

# Accuracy of the baseline classifier trained on the full training set.
full_accuracy = full_results[0]["testMulticlassAccuracy_epoch"]

# Median over trials (axis 0) for each strategy.
plt.plot(x, np.median(uniform_acc, 0), label="Uniform", linestyle="--")
plt.plot(x, np.median(uncertainty_acc, 0), label="Uncertainty", linestyle="-.")
plt.plot(x, np.median(adversarial_acc, 0), label="Adversarial", linestyle="-")
plt.axhline(full_accuracy, label="Full Test Accuracy", color="black",
            linestyle=":")

plt.xlabel("Number of Annotated Samples")
plt.ylabel("Test Accuracy")
plt.yticks([0.9, 0.95, full_accuracy])
plt.ylim(0.9, 1)  # Accuracies below 0.9 fall outside the visible range.
plt.legend()
plt.show()