In [1]:
import torch

import numpy as np
from main.models import ConvNet
from main.utils import save_experiment
from dataclasses import dataclass
from main.active_learning import run_active_learning
from main.prepare_data import create_repeated_MNIST_dataloaders, create_MNIST_dataloaders

%reload_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [11]:
# Settings for the redundant (repeated-digit) MNIST experiment.
@dataclass
class ActiveLearningConfigRedundant:
    """Default hyper-parameters for the redundant-MNIST active-learning run.

    Only defaults live here; override per experiment with keyword arguments,
    e.g. ``ActiveLearningConfigRedundant(al_method="random")``. Field order is
    the positional order of the generated ``__init__``, so append new fields
    at the end rather than inserting them.
    """

    # Posterior-approximation settings. The names mirror laplace-torch's
    # arguments; presumably forwarded there -- confirm in main.active_learning.
    subset_of_weights: str = "last_layer"
    hessian_structure: str = "kron"
    backend: str = "AsdlGGN"
    temperature: float = 1.0

    max_training_samples: int = 100    # target labelled-set size (stopping rule lives in run_active_learning)
    acquisition_batch_size: int = 5    # number of pool points added per acquisition step
    al_method: str = "bald"            # acquisition strategy name
    test_batch_size: int = 512
    num_classes: int = 10
    num_initial_samples: int = 40      # size of the labelled set before the first acquisition
    training_iterations: int = 6 * 4096
    scoring_batch_size: int = 64       # presumably the batch size used when scoring the pool -- confirm
    train_batch_size: int = 64
    extract_pool: int = 0              # presumably 0 disables pool extraction -- TODO confirm in main.prepare_data
    num_repeats: int = 10              # presumably copies per sample in the redundant dataset -- TODO confirm
    samples_per_digit: int = 50        # presumably distinct samples kept per digit -- TODO confirm

# Settings for the standard MNIST active-learning experiment.
@dataclass
class ActiveLearningConfigMNIST:
    """Default hyper-parameters for the standard-MNIST active-learning run.

    Only defaults live here; override per experiment with keyword arguments,
    e.g. ``ActiveLearningConfigMNIST(al_method="bald")``. Field order is the
    positional order of the generated ``__init__``, so append new fields at
    the end rather than inserting them.
    """

    # Posterior-approximation settings. The names mirror laplace-torch's
    # arguments; presumably forwarded there -- confirm in main.active_learning.
    subset_of_weights: str = 'last_layer'
    hessian_structure: str = 'kron'
    backend: str = 'AsdlGGN'
    # Float literal so the stored value matches its annotation and
    # ActiveLearningConfigRedundant.temperature (the config is saved with results).
    temperature: float = 1.0
    # NOTE(review): the recorded run ended at 550 labelled samples with this set
    # to 500 -- check whether run_active_learning's stopping rule should overshoot.
    max_training_samples: int = 500
    acquisition_batch_size: int = 100  # number of pool points added per acquisition step
    al_method: str = 'random'          # acquisition strategy name
    test_batch_size: int = 512
    num_classes: int = 10
    num_initial_samples: int = 50      # size of the labelled set before the first acquisition
    training_iterations: int = 4096 * 6
    scoring_batch_size: int = 64       # presumably the batch size used when scoring the pool -- confirm
    train_batch_size: int = 64
    # Number of samples to extract from the dataset (flagged by the author as a
    # workaround); exact semantics live in main.prepare_data -- confirm there.
    extract_pool: int = 55000


# Active configuration for this notebook: switch the class or override fields here.
config = ActiveLearningConfigMNIST()

# Derive the run name from the config so saved results cannot be mislabelled
# when al_method / acquisition_batch_size change (currently 'random_batch100').
# Add a descriptive prefix here if you switch datasets or variants.
experiment_name = f"{config.al_method}_batch{config.acquisition_batch_size}"

# Seed the global torch and numpy RNGs so Restart & Run All reproduces the run.
# NOTE(review): GPU kernels and any other RNG used inside main.* may still be
# non-deterministic -- confirm if bit-exact reproducibility is required.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

save_results = True  # pass config + results to save_experiment after each run
num_runs = 1         # number of repetitions; dataloaders are rebuilt for every run

Using device: cpu


In [12]:
# Run the experiment `num_runs` times. Every run's results are kept in
# `all_results` (in run order); `results` still holds the last run.
# NOTE(review): this always uses create_MNIST_dataloaders, even if `config` is an
# ActiveLearningConfigRedundant -- switch to create_repeated_MNIST_dataloaders
# for that experiment (check its return signature first).
all_results = []
for i in range(num_runs):
    # load data (rebuilt for every run)
    train_loader, test_loader, pool_loader, active_learning_data = create_MNIST_dataloaders(config)

    # run the active-learning loop
    results = run_active_learning(
        train_loader=train_loader,
        test_loader=test_loader,
        pool_loader=pool_loader,
        active_learning_data=active_learning_data,
        model_constructor=ConvNet,
        config=config,
        device=device,
    )
    all_results.append(results)

    # save results and configuration under a 1-based run suffix, e.g. 'random_batch100_1'
    if save_results:
        experiment_id = f"{experiment_name}_{i + 1}"
        save_experiment(config, results, experiment_id)


2024-07-31 21:06:00,706 - INFO - Training set size: 50, Test set accuracy: 72.86, Test set loss: -0.0099


Dataset indices:  [38649  1233 38751 44978 38122 12116 34369 37413 44653 33994 38486 16998
 29566 35250 54139 38590 53804  7607  7426  3871 18561 46286 49109 21232
 25977 49842 38781  2156  3739 42128 31330 38157  8059 45137 44118 45070
 50384 57634  1434 57407 52557 26042 17920 19958 47786  2165 32123 39908
 41481  6235 37053  3523 23196 13003 32230 33308 27364 21878 32449 31902
 56215 24003 17606 31746 16109  8564 16475 57768 31850  8083 21661 35140
 16812 48084 13482 47923 52246 47516 51399 25023 11063 16714 36700 27872
 10883 34324 10776  8489 43243 55488  6918 41930   346 35881 23003 39028
 41756 30693 53418 34980]
Scores:  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Labels:  tensor([2, 3, 1, 2, 4, 7, 0, 5, 2, 4, 5, 6, 4, 2, 9, 


2024-07-31 21:06:34,473 - INFO - Training set size: 150, Test set accuracy: 85.07, Test set loss: -0.0313


Dataset indices:  [59662 53673 46858 20569  6256   293 40996 28268  1160  9023 25919 55051
 18453 13486 20295 28378 18068 27204 23793 15304 10573 52816 27610 49047
 17757 17182 49460 56171 17540 57380 42828 58901 24816 22266 37846 27718
 37084 26779 48449 30585 45179 35856 56026  5733 17244 32377 15846  7241
 28772 37792 34229  5652    88  3277  6458 41435 17663 49602  9889 14123
  1846 47105 21046  9266 11637  2913 45812 57182 36336 15265  3982 16441
  6490 31303 41954 51238 29449 38209 31660  5973 55197 16217 25507  8517
 43690  3755   559 11205 40133 18805 34224 46684 50513 30897 33630 59044
 18775 15327 56121 33881]
Scores:  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Labels:  tensor([5, 4, 7, 5, 4, 0, 3, 2, 4, 7, 0, 3, 3, 2, 7, 


2024-07-31 21:07:04,571 - INFO - Training set size: 250, Test set accuracy: 89.63, Test set loss: -0.0312


Dataset indices:  [ 4388 37293  5345 50370 24748 51560  4296 34600 50128 55107 29011 52617
 43389 44058 13749 59578 41431 16453 15082 59628 41467 56539 34127 38404
 41251 25359 47229 40889  4617 58994 16783 53210 11068 27451 25983 40979
 30424 44140  7906 43025 24612   291 49696  2198 56445  1746   780 20465
 28334  8261 40751 30632 25367 22183 54576 36242 12127  3912 50304 30805
  9845 16347  7001 41175 59129 52881 27483 20705 53135  4329 46389 32534
 57172 18781 23628 27321  1628 57011 46689 20718 59232  5257 38953  7930
 59359 23035  8381 30974 17968 58035 24807 56619 13890  7055  1361 16400
 32295 28719 52167 49347]
Scores:  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Labels:  tensor([8, 3, 2, 7, 3, 7, 7, 7, 9, 8, 4, 0, 1, 3, 5, 


2024-07-31 21:07:38,685 - INFO - Training set size: 350, Test set accuracy: 91.03, Test set loss: -0.0364


Dataset indices:  [23934 28043 12961 48889 59956 41414 15964 33635 55770  8731  5024 23494
 21884 59608 19877 46897 16047 37735 40514 44923 23423  3794 42998 32386
  9739 11066 21298 18940 59098  9782 18474 47865 14669 35576 50721 25521
 38168 42570 39715 40416 39814 33790 22144  7194  3931 13745 10956 17528
 15861 35652 19220 30577 18361 57166 14902 35080 32400 37338 29727 24471
  1823 42309 53641 28314 12590 45096 50189 46003 33538 12602 26277 46321
 13786 56329 38487  4518 28574 29768 35156  2270 17525  1368 45732 22178
 30810 25469 49934 32319 35661 29998  2089 36772 50924 44602 44586 11018
 21992 20279 11362 31781]
Scores:  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Labels:  tensor([2, 4, 7, 7, 5, 7, 1, 7, 6, 5, 9, 2, 1, 4, 8, 


2024-07-31 21:08:09,912 - INFO - Training set size: 450, Test set accuracy: 92.22, Test set loss: -0.0353


Dataset indices:  [ 9395 44166 28197 59438 48812 22578  4655 17117 12048 44405 27581 13734
 25828 46041 58147 40752 15967 28685  8937 27493 44676  3581 39140 12822
 45989 15303 14137 11699 10954 54798 25626  9322 41236 22602 11924  4632
  8604 51321 39566 44281 52570 18037 37629 59587  6808 58293  6754 42930
 19743 18597  3644   633 17228 24691  1037 32041 53324  6074 20528 31069
 10771 44734 57279 52126  1941 11573 28858 42896  5219 17446 41582  1906
  3128 36154 13789 53974 50939 32905 27062 25934  9029 14319 55859 25466
 43458 19666 32681 35418 19818 56077 16511  9542  7220  6349  1600 38674
 28358 17977 33513 38078]
Scores:  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Labels:  tensor([7, 2, 9, 8, 5, 3, 0, 7, 7, 4, 9, 0, 9, 0, 1, 


2024-07-31 21:08:37,191 - INFO - Training set size: 550, Test set accuracy: 93.30, Test set loss: -0.0328
Training Set Size: 550it [02:56,  2.83it/s]

Experiment saved in experiments\random_batch100_1



