In [41]:
import pickle
import os
import sys #sys.exit()
import argparse
from pprint import pprint
import random
from copy import deepcopy
import csv
import datetime

import pandas as pd
import numpy as np
import torch
import torch.backends
from torch import optim
from torch.hub import load_state_dict_from_url
from torch.nn import CrossEntropyLoss
from torchvision import datasets
from torchvision.models import vgg16
from torchvision.transforms import transforms
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from baal.active import get_heuristic, ActiveLearningDataset
from baal.active.active_loop import ActiveLearningLoop
from baal.bayesian.dropout import patch_module
from baal.modelwrapper import ModelWrapper
from baal.utils.metrics import Accuracy
from baal.active.heuristics import BALD

import aug_lib

from baal_extended.ExtendedActiveLearningDataset_2 import ExtendedActiveLearningDataset

In [42]:
parser = argparse.ArgumentParser()
parser.add_argument("--epoch", default=3, type=int)
parser.add_argument("--batch_size", default=32, type=int)  # 32
parser.add_argument("--initial_pool", default=10, type=int)   # 1000, we will start training with only 1000(org)+1000(aug)=2000 labeled data samples out of the 50k (org) and
parser.add_argument("--query_size", default=1, type=int)    # request 100(org)+100(aug)=200 new samples to be labeled at every cycle
parser.add_argument("--lr", default=0.001)
parser.add_argument("--heuristic", default="bald", type=str)
parser.add_argument("--iterations", default=2, type=int)     # 20 sampling for MC-Dropout to kick paths with low weights for optimization
parser.add_argument("--shuffle_prop", default=0.05, type=float)
parser.add_argument("--learning_epoch", default=2, type=int) # 20
parser.add_argument("--augment", default=2, type=int)

_StoreAction(option_strings=['--augment'], dest='augment', nargs=None, const=None, default=2, type=<class 'int'>, choices=None, help=None, metavar=None)

In [43]:
def get_datasets(initial_pool, n_augmentations):
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(3 * [0.5], 3 * [0.5]),
        ]
    )
    aug_transform = transforms.Compose(
        [
            aug_lib.TrivialAugment(),
            transforms.ToTensor(),
            transforms.Normalize(3 * [0.5], 3 * [0.5]),
        ]
    )
    test_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(3 * [0.5], 3 * [0.5]),
        ]
    )
    # Note: We use the test set here as an example. You should make your own validation set.
    train_ds = datasets.CIFAR10(
        ".", train=True, transform=transform, target_transform=None, download=True
    )
    aug_train_ds = datasets.CIFAR10(
        ".", train=True, transform=aug_transform, target_transform=None, download=True
    )
    test_set = datasets.CIFAR10(
        ".", train=False, transform=test_transform, target_transform=None, download=True
    )

    #active_set = ActiveLearningDataset(train_ds, pool_specifics={"transform": test_transform})
    eald_set = ExtendedActiveLearningDataset(train_ds)
    eald_set.augment_n_times(n_augmentations, augmented_dataset=aug_train_ds)

    # We start labeling randomly.
    eald_set.label_randomly(initial_pool)
    return eald_set, test_set

In [44]:
def generate_pickle_file(dt_string, active_set, epoch, oracle_indices, uncertainty):    
    pickle_filename = dt_string + (
        f"_uncertainty_epoch={epoch}" f"_labelled={len(active_set)}.pkl"
    )
    dir_path = os.path.join(os.getcwd(), "uncertainties")
    isExist = os.path.exists("uncertainties")
    if not isExist:
        os.makedirs(dir_path)
    pickle_file_path = os.path.join(dir_path, pickle_filename)
    print("Saving file " + pickle_file_path)
    pickle.dump(
        {
            "oracle_indices": oracle_indices,
            "uncertainty": uncertainty,
            "labelled_map": active_set.labelled_map,
        },
        open(pickle_file_path, "wb")
    )
    return dir_path, pickle_file_path

In [45]:
def generate_excel_file(augment, dt_string, active_set, epoch, pickle_dir_path, df_lab_img): 
    excel_filename = dt_string + (
        f"_uncertainty_epoch={epoch}" f"_labelled={len(active_set)}.xlsx"
    )
    excel_path = os.path.join(pickle_dir_path, excel_filename)

    uncertainties_std = df_lab_img.transpose()
    if augment == 1:
        uncertainties_std.columns = ['original', 'aug1', 'std']
    if augment == 2:   
        uncertainties_std.columns = ['original', 'aug1', 'aug2', 'std']

    uncertainties_std.to_excel(excel_path)

In [46]:
args, unknown = parser.parse_known_args()
use_cuda = torch.cuda.is_available()
torch.backends.cudnn.benchmark = True
random.seed(1337)
torch.manual_seed(1337)
if not use_cuda:
    print("warning, the experiments would take ages to run on cpu")

now = datetime.datetime.now()
dt_string = now.strftime("%d_%m_%Y_%Hx%M")
csv_filename = "uncertainties/metrics_cifarnet_" + dt_string + "_.csv"
with open(csv_filename, "w+", newline="") as out_file:
    csvwriter = csv.writer(out_file)
    csvwriter.writerow(
    (
        "epoch",
        "test_acc",
        "train_acc",
        "test_loss",
        "train_loss",
        "Next training size",
        "amount original images labelled",
        "amount augmented images labelled"
    )
    )

hyperparams = vars(args)

active_set, test_set = get_datasets(hyperparams["initial_pool"], hyperparams["augment"])

heuristic = get_heuristic(hyperparams["heuristic"], hyperparams["shuffle_prop"])
criterion = CrossEntropyLoss()
model = vgg16(num_classes=10)

# change dropout layer to MCDropout
model = patch_module(model)

if use_cuda:
    model.cuda()
else: 
    print("WARNING! NO CUDA IN USE!")
optimizer = optim.SGD(model.parameters(), lr=hyperparams["lr"], momentum=0.9)

# Wraps the model into a usable API.
model = ModelWrapper(model, criterion, replicate_in_memory=False)
model.add_metric(name='accuracy', initializer=lambda : Accuracy())

logs = {}
logs["epoch"] = 0

# for prediction we use a smaller batchsize
# since it is slower
active_loop = ActiveLearningLoop(
    active_set,
    model.predict_on_dataset,
    heuristic,
    hyperparams.get("query_size", 1),
    batch_size=10,
    iterations=hyperparams["iterations"],
    use_cuda=use_cuda,
)
# We will reset the weights at each active learning step.
init_weights = deepcopy(model.state_dict())

layout = {
    "Loss/Accuracy": {
        "Loss": ["Multiline", ["loss/train", "loss/test"]],
        "Accuracy": ["Multiline", ["accuracy/train", "accuracy/test"]],
    },
}

writer = SummaryWriter("vgg_mcdropout_cifar10_org+aug_3")    # baal-serhiy/experiments/vgg_mcdropout_cifar10_org+aug_3
writer.add_custom_scalars(layout)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
orig len50000
augmented n times0


In [48]:
for epoch in tqdm(range(args.epoch)):
    # if we are in the last round we want to train for longer epochs to get a more comparable result
    # if epoch == args.epoch:
    #     hyperparams["learning_epoch"] = 75
    # Load the initial weights.
    model.load_state_dict(init_weights)
    model.train_on_dataset(
        active_set,
        optimizer,
        hyperparams["batch_size"],
        hyperparams["learning_epoch"],
        use_cuda,
    )

    # Validation!
    model.test_on_dataset(test_set, hyperparams["batch_size"], use_cuda)
    metrics = model.metrics

    # get origin amount of labelled augmented/unaugmented images
    if(epoch == 0):
        with open(csv_filename, "a+", newline="") as out_file:
            csvwriter = csv.writer(out_file)
            csvwriter.writerow(
                (
                -1,
                0,
                0,
                0,
                0,
                active_set.n_labelled,
                active_set.n_unaugmented_images_labelled,
                active_set.n_augmented_images_labelled
                )
            )

    # replacement for step
    #pool = active_set._dataset  # len(active_set._dataset) 100000
    pool = active_set.pool # Returns a new Dataset made from unlabelled samples 
    print("1 pool length ="+str(len(pool))) # 149970

    # orig_s2 = int((len(pool)/3))
    # aug1_s1 = int(len(pool)/3)
    # aug1_s2 = int((len(pool)/3)*2)
    # aug2_s1 = int((len(pool)/3)*2)
    # aug2_s2 = int(len(pool))


    # original = uncertainty[0:orig_s2]
    # aug1 = uncertainty[aug1_s1:aug1_s2]
    # aug2 = uncertainty[aug2_s1:aug2_s2]

    # print("3 original length "+str(len(original)))
    # print("4 aug1 length "+str(len(aug1))) # 
    # print("5 aug2 length "+str(len(aug2))) # 49990

    #sys.exit()

    if len(pool) > 0:
        probs = model.predict_on_dataset(
            pool,
            batch_size=hyperparams["batch_size"],
            iterations=hyperparams["iterations"],
            use_cuda=use_cuda,
        )

        #if probs is not None and (isinstance(probs, types.GeneratorType) or len(probs) > 0):
        # -> "isinstance(...) needed when using predict_..._Generator"
        if probs is not None and len(probs) > 0:
            # 1. Get uncertainty
            uncertainty = active_loop.heuristic.get_uncertainties(probs)
            oracle_indices = np.argsort(uncertainty)
            active_set.labelled_map

            print("2 oracle_indices length "+str(len(oracle_indices))) # 149970

            pickle_dir_path, pickle_file_path = generate_pickle_file(dt_string, active_set, epoch, oracle_indices, uncertainty)

            mypickle = pd.read_pickle(pickle_file_path)

            uncertainty = mypickle['uncertainty']
            oracle_indices = mypickle['oracle_indices']
            labelled_map = mypickle['labelled_map']

            orig_s2 = int((len(pool)/3)-1)
            aug1_s1 = int(len(pool)/3)
            aug1_s2 = int((len(pool)/3)*2-1)
            aug2_s1 = int((len(pool)/3)*2)
            aug2_s2 = int(len(pool)-1)

            original = uncertainty[0:orig_s2]
            aug1 = uncertainty[aug1_s1:aug1_s2]
            aug2 = uncertainty[aug2_s1:aug2_s2]

            print("3 original length "+str(len(original)))
            print("4 aug1 length "+str(len(aug1)))
            print("5 aug2 length "+str(len(aug2)))

            if hyperparams["augment"] == 1:
                matrix = np.vstack([original, aug1])
            if hyperparams["augment"] == 2:   
                matrix = np.vstack([original, aug1, aug2])

            # 2. Calc standard deviation
            df_lab_img = pd.DataFrame(matrix)
            df_lab_img.std() # here
            df_lab_img = pd.DataFrame(np.vstack([matrix, df_lab_img.std()]))

            uncertainties_std = df_lab_img.transpose()
            if hyperparams["augment"] == 1:
                uncertainties_std.columns = ['original', 'aug1', 'std']
            if hyperparams["augment"] == 2:   
                uncertainties_std.columns = ['original', 'aug1', 'aug2', 'std']

            generate_excel_file(hyperparams["augment"], dt_string, active_set, epoch, pickle_dir_path, df_lab_img)
            
            # 3. Map std uncertainties to uncertainty array
            std_array = df_lab_img.std()
            for i in range(len(uncertainty)): # 150000
                uncertainty[i] = std_array[i % (len(pool)/3-1)]
            oracle_indices = np.argsort(uncertainty)
            print("6 oracle_indices length "+str(len(oracle_indices)))
            active_set.labelled_map
            # to_label -> indices sortiert von größter zu niedrigster uncertainty
            # uncertainty -> alle std uncertainties des pools
            to_label = heuristic.reorder_indices(uncertainty)
            print("7 to_label length "+str(len(to_label)))
            to_label = oracle_indices[np.array(to_label)] # len(to_label) = 150000
            print("8 to_label length "+str(len(to_label)))
            if len(to_label) > 0:
                active_set.label(to_label[: hyperparams.get("query_size", 1)])
            else: break
        else:
            break
    else: 
        break

  0%|          | 0/3 [00:00<?, ?it/s]

[7372-MainThread ] [baal.modelwrapper:train_on_dataset:83] 2022-12-18T14:26:29.002084Z [info     ] Starting training              dataset=33 epoch=2
[7372-MainThread ] [baal.modelwrapper:train_on_dataset:94] 2022-12-18T14:26:54.559674Z [info     ] Training complete              train_loss=2.2929248809814453
[7372-MainThread ] [baal.modelwrapper:test_on_dataset:123] 2022-12-18T14:26:54.565647Z [info     ] Starting evaluating            dataset=10000
[7372-MainThread ] [baal.modelwrapper:test_on_dataset:133] 2022-12-18T14:27:01.642135Z [info     ] Evaluation complete            test_loss=2.302551507949829
1 pool length =149967
[7372-MainThread ] [baal.modelwrapper:predict_on_dataset_generator:232] 2022-12-18T14:27:01.771165Z [info     ] Start Predict                  dataset=149967

  0%|          | 0/4687 [00:00<?, ?it/s]
  0%|          | 1/4687 [00:11<15:08:12, 11.63s/it]
  0%|          | 8/4687 [00:11<1:23:35,  1.07s/it] 
  0%|          | 15/4687 [00:11<37:00,  2.10it/s] 
  0%|       

 33%|███▎      | 1/3 [02:07<04:14, 127.37s/it]

[7372-MainThread ] [baal.modelwrapper:train_on_dataset:83] 2022-12-18T14:28:36.374063Z [info     ] Starting training              dataset=36 epoch=2
[7372-MainThread ] [baal.modelwrapper:train_on_dataset:94] 2022-12-18T14:29:02.197153Z [info     ] Training complete              train_loss=2.2934036254882812
[7372-MainThread ] [baal.modelwrapper:test_on_dataset:123] 2022-12-18T14:29:02.203148Z [info     ] Starting evaluating            dataset=10000
[7372-MainThread ] [baal.modelwrapper:test_on_dataset:133] 2022-12-18T14:29:09.301525Z [info     ] Evaluation complete            test_loss=2.3022549152374268
1 pool length =149964
[7372-MainThread ] [baal.modelwrapper:predict_on_dataset_generator:232] 2022-12-18T14:29:09.446527Z [info     ] Start Predict                  dataset=149964

  0%|          | 0/4687 [00:00<?, ?it/s]
  0%|          | 1/4687 [00:11<15:19:24, 11.77s/it]
  0%|          | 8/4687 [00:11<1:24:36,  1.08s/it] 
  0%|          | 15/4687 [00:11<37:27,  2.08it/s] 
  0%|      

 67%|██████▋   | 2/3 [04:14<02:07, 127.34s/it]

[7372-MainThread ] [baal.modelwrapper:train_on_dataset:83] 2022-12-18T14:30:43.696410Z [info     ] Starting training              dataset=39 epoch=2
[7372-MainThread ] [baal.modelwrapper:train_on_dataset:94] 2022-12-18T14:31:10.032914Z [info     ] Training complete              train_loss=2.291487216949463
[7372-MainThread ] [baal.modelwrapper:test_on_dataset:123] 2022-12-18T14:31:10.040917Z [info     ] Starting evaluating            dataset=10000
[7372-MainThread ] [baal.modelwrapper:test_on_dataset:133] 2022-12-18T14:31:17.090607Z [info     ] Evaluation complete            test_loss=2.302438259124756
1 pool length =149961
[7372-MainThread ] [baal.modelwrapper:predict_on_dataset_generator:232] 2022-12-18T14:31:17.222653Z [info     ] Start Predict                  dataset=149961

  0%|          | 0/4687 [00:00<?, ?it/s]
  0%|          | 1/4687 [00:11<15:23:04, 11.82s/it]
  0%|          | 8/4687 [00:11<1:24:56,  1.09s/it] 
  0%|          | 15/4687 [00:12<37:36,  2.07it/s] 
  0%|        

100%|██████████| 3/3 [06:22<00:00, 127.49s/it]


In [None]:
len(aug1)

In [None]:
len(uncertainty)

In [None]:
len(to_label)

In [None]:
len(uncertainty)

149970

In [49]:
to_label # array[99999, 34555, ...]

array([ 67213, 109634,  79077, ..., 109828, 110968,  43832], dtype=int64)

In [50]:
pool = active_set.pool
len(pool)

149958

In [51]:
to_label = heuristic.reorder_indices(uncertainty)
to_label

array([122215,  72229,  22243, ...,  73562,  23576, 123548], dtype=int64)

In [52]:
to_label = oracle_indices[np.array(to_label)]
to_label

array([ 67213, 109634,  79077, ..., 109828, 110968,  43832], dtype=int64)

In [55]:
if (hyperparams["augment"] != 1) and (hyperparams["augment"] != 2):
  print("WARNING! Supporting only augmentation 1 and 2, for more write more code!")
  sys.exit()