In [None]:
import random
from dataclasses import dataclass

import numpy as np
import torch

from main.active_learning import run_active_learning
from main.models import ConvNet
from main.prepare_data import create_repeated_MNIST_dataloaders, create_MNIST_dataloaders
from main.utils import save_experiment

%reload_ext autoreload
%autoreload 2

In [None]:
# set configurations for redundant MNIST experiment
@dataclass
class ActiveLearningConfigRedundant:
    """Hyperparameters for the redundant (repeated-digit) MNIST experiment.

    Same fields as ActiveLearningConfigMNIST, except that ``extract_pool`` is 0
    and two extra fields control how the redundant dataset is built.
    """

    # posterior / Hessian approximation settings
    # NOTE(review): these look like options forwarded to a Laplace-style
    # approximation inside run_active_learning — confirm there.
    subset_of_weights: str = "last_layer"
    hessian_structure: str = "kron"
    backend: str = "AsdlGGN"
    temperature: float = 1.0

    # active-learning loop
    max_training_samples: int = 100
    acquisition_batch_size: int = 5
    al_method: str = "max_logdet_S"

    # data / model sizes
    test_batch_size: int = 512
    num_classes: int = 10
    num_initial_samples: int = 40

    # optimisation
    training_iterations: int = 6 * 4096
    scoring_batch_size: int = 64
    train_batch_size: int = 64

    # dataset construction
    extract_pool: int = 0
    num_repeats: int = 10
    samples_per_digit: int = 50

# set configurations
@dataclass
class ActiveLearningConfigMNIST:
    """Hyperparameters for the standard MNIST active-learning experiment.

    Field names, order and defaults match ActiveLearningConfigRedundant,
    except for ``extract_pool``. The redundant-dataset fields are absent.
    """

    # NOTE(review): these look like options forwarded to a Laplace-style
    # approximation inside run_active_learning — confirm there.
    subset_of_weights: str = 'last_layer'
    hessian_structure: str = 'kron'
    backend: str = 'AsdlGGN'
    # float literal so the default matches the annotation (and the redundant config)
    temperature: float = 1.0
    max_training_samples: int = 100
    acquisition_batch_size: int = 5
    al_method: str = 'max_logdet_S'
    test_batch_size: int = 512
    num_classes: int = 10
    num_initial_samples: int = 40
    training_iterations: int = 4096 * 6
    scoring_batch_size: int = 64
    train_batch_size: int = 64
    # Number of samples to extract from the dataset. The original author calls
    # this "a bit of a hack"; its exact use is defined in create_MNIST_dataloaders.
    extract_pool: int = 55000


# Descriptive experiment name; each saved run id is this plus a run suffix.
experiment_name = "mnist_logdet_hardlabel"

# Hyperparameters for the standard MNIST experiment.
config = ActiveLearningConfigMNIST()

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device_type = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_type)
print(f"Using device: {device}")

# Number of independent repetitions, and whether each one is persisted.
num_runs = 5
save_results = True

In [None]:
# Base seed for the repeated runs. Run i is seeded with BASE_SEED + i, so each
# run is reproducible on its own while the runs still differ from one another.
BASE_SEED = 0


def seed_everything(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch global random number generators.

    torch.manual_seed seeds all devices (CPU and CUDA). This does not force
    deterministic cuDNN kernels, so GPU results may still vary slightly.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


for i in range(num_runs):
    # Reseed before building the loaders, in case data splitting / initial-sample
    # selection draws from the global RNGs.
    seed_everything(BASE_SEED + i)

    # load data
    train_loader, test_loader, pool_loader, active_learning_data = create_MNIST_dataloaders(config)

    # get results
    results = run_active_learning(
        train_loader=train_loader,
        test_loader=test_loader,
        pool_loader=pool_loader,
        active_learning_data=active_learning_data,
        model_constructor=ConvNet,
        config=config,
        device=device,
    )

    # save results and configuration; run ids are 1-based
    if save_results:
        experiment_id = f"{experiment_name}_{i + 1}"
        save_experiment(config, results, experiment_id)