In [None]:


# Let's start with a bunch of imports.

import random
from copy import deepcopy
from dataclasses import dataclass

import numpy as np
import torch
import torch.backends
import torch.utils.data as torchdata
from torch import optim
from torch.hub import load_state_dict_from_url
from torch.nn import CrossEntropyLoss
from torchvision.transforms import transforms
from tqdm.autonotebook import tqdm

from baal.active import get_heuristic, ActiveLearningDataset
from baal.active.active_loop import ActiveLearningLoop
from baal.bayesian.dropout import patch_module
from baal.modelwrapper import ModelWrapper

In [None]:
@dataclass
class ExperimentConfig:
    """Hyper-parameters for the active learning experiment in this notebook."""

    # Number of outer active-learning steps (train -> test -> query) to run.
    epoch: int = 5
    # Batch size passed to `train_on_dataset` and `test_on_dataset`.
    batch_size: int = 2
    # How many training samples are labelled at random before the first step.
    initial_pool: int = 10
    # Passed to `ActiveLearningLoop`; presumably the number of samples labelled per step — TODO confirm.
    query_size: int = 1
    # Learning rate of the SGD optimizer.
    lr: float = 9e-4
    # Name handed to `get_heuristic` to pick the acquisition function.
    heuristic: str = 'bald'
    # Passed as `iterations` to `ActiveLearningLoop` and as `average_predictions`
    # to `test_on_dataset`; presumably the number of stochastic forward passes — TODO confirm.
    iterations: int = 5
    # Passed as `epoch` to `train_on_dataset`, i.e. training epochs per active-learning step.
    training_duration: int = 5

In [None]:
from src.dataset.dataset import Dataset, collate_fn_padd
from src.dataset.utils import train_test_validation_split

def get_datasets(initial_pool):
    """Build the active-learning training pool and the held-out test split.

    Args:
        initial_pool: number of training samples to label at random before the
            first active learning step.

    Returns:
        A tuple ``(active_set, test)``: the training split wrapped in an
        ``ActiveLearningDataset`` and the test split.
    """
    dataset = Dataset(path_to_npy_data="data/NPY/volumes/", path_to_npy_targets="data/NPY/labels/")

    # The helper also returns a validation split; nothing in this function uses it,
    # so it is discarded explicitly.
    train, test, _valid = train_test_validation_split(dataset=dataset)

    # Wrap the training split so it can be used as the active learning pool.
    active_set = ActiveLearningDataset(train)

    # We start labeling randomly.
    active_set.label_randomly(initial_pool)
    return active_set, test

In [None]:
# Move the working directory one level up; the data paths used by `get_datasets`
# ("data/NPY/...") are relative.
# NOTE(review): this is not idempotent — re-running the cell climbs another level.
# NOTE(review): the `from src...` imports in the earlier cell execute before this
# directory change — confirm the intended cell order on a fresh kernel.
%cd ..

In [None]:
from src.models.conv_net import ConvNN

hyperparams = ExperimentConfig()
use_cuda = torch.cuda.is_available()
torch.backends.cudnn.benchmark = True

# Seed every RNG the experiment may touch. NumPy is seeded as well: the notebook
# already relies on NumPy, and library code may draw from its global generator
# (e.g. for the random initial labelling) — TODO confirm against the baal version in use.
random.seed(1337)
np.random.seed(1337)
torch.manual_seed(1337)
if not use_cuda:
    print("warning, the experiments would take ages to run on cpu")

# Get datasets
active_set, test_set = get_datasets(hyperparams.initial_pool)

# Get our model.
heuristic = get_heuristic(hyperparams.heuristic)
criterion = CrossEntropyLoss()
model = ConvNN()

# change dropout layer to MCDropout
model = patch_module(model)

if use_cuda:
    model.cuda()
# The optimizer is built on the raw module's parameters, before the model is wrapped below.
optimizer = optim.SGD(model.parameters(), lr=hyperparams.lr, momentum=0.9)

# Wraps the model into a usable API.
model = ModelWrapper(model, criterion, replicate_in_memory=False)

# for ActiveLearningLoop we use a smaller batchsize
# since we will stack predictions to perform MCDropout.
active_loop = ActiveLearningLoop(active_set,
                                 model.predict_on_dataset,
                                 heuristic,
                                 hyperparams.query_size,
                                 batch_size=1,
                                 iterations=hyperparams.iterations,
                                 use_cuda=use_cuda,
                                 verbose=False,
                                 collate_fn=collate_fn_padd)

# We will reset the weights at each active learning step so we make a copy.
init_weights = deepcopy(model.state_dict())


In [None]:


# Per-sample counter, started from the initial labelled mask and incremented by the
# current mask after every step, so earlier-labelled samples end with larger values.
labelling_progress = active_set._labelled.copy().astype(np.uint16)
for epoch in tqdm(range(hyperparams.epoch)):
    # Load the initial weights.
    model.load_state_dict(init_weights)

    # Train the model on the currently labelled dataset.
    _ = model.train_on_dataset(active_set, optimizer=optimizer, batch_size=hyperparams.batch_size,
                               use_cuda=use_cuda, epoch=hyperparams.training_duration, collate_fn=collate_fn_padd)

    # Get test NLL!
    model.test_on_dataset(test_set, hyperparams.batch_size, use_cuda,
                          average_predictions=hyperparams.iterations, collate_fn=collate_fn_padd)
    metrics = model.metrics

    # We can now label the most uncertain samples according to our heuristic.
    should_continue = active_loop.step()
    # Keep track of progress
    labelling_progress += active_set._labelled.astype(np.uint16)
    if not should_continue:
        break

    test_loss = metrics['test_loss'].value
    logs = {
        "test_nll": test_loss,
        "epoch": epoch,
        "Next Training set size": len(active_set)
    }
    # Report the step's metrics; previously the dict was built and then dropped,
    # so the loop gave no feedback on how the model was doing.
    print(logs)


In [None]:
# Persist everything the visualisation cells need: the trained weights, the
# labelling state of the pool, and the per-sample labelling counter.
model_weight = model.state_dict()
dataset = active_set.state_dict()
torch.save(
    dict(model=model_weight, dataset=dataset, labelling_progress=labelling_progress),
    'checkpoint.pth',
)
print(model.state_dict().keys(), dataset.keys(), labelling_progress)

In [None]:
# Read the checkpoint once instead of deserialising the same file three times.
# NOTE(review): recent PyTorch releases default `torch.load` to `weights_only=True`,
# which may reject the non-tensor objects stored here — confirm for the version in use.
checkpoint = torch.load("checkpoint.pth")
model_weight = checkpoint["model"]

# checkpoint["dataset"] is the *state dict* written by `active_set.state_dict()`,
# not a dataset. Restore it into the existing ActiveLearningDataset rather than
# rebinding `active_set` to a plain dict, which would make `len(active_set)` count
# dict keys and break the later `active_set._dataset` access.
active_set.load_state_dict(checkpoint["dataset"])
labelling_progress = checkpoint["labelling_progress"]
print(len(active_set))

In [None]:
# Rebuild a plain ConvNN and restore the trained weights from the checkpoint.
# map_location="cpu" lets a checkpoint written on a GPU machine load anywhere;
# the model is moved to the GPU afterwards only when one is available.
model = ConvNN()
model.load_state_dict(torch.load("checkpoint.pth", map_location="cpu")["model"])
if torch.cuda.is_available():
    model.cuda()

# modify our model to get features
from torch import nn
from torch.utils.data import DataLoader
torch.cuda.empty_cache()

# Make a feature extractor from our trained model.
class FeatureExtractor(nn.Module):
    """Wrap a network so that its output is flattened to one row per sample."""

    def __init__(self, model):
        super().__init__()
        # The wrapped network whose output is flattened in `forward`.
        self.model = model

    def forward(self, x):
        """Run the wrapped model on ``x`` and collapse every dimension after the first."""
        outputs = self.model(x)
        return outputs.flatten(start_dim=1)


features = FeatureExtractor(model)
# Inference mode: a freshly built module starts in training mode, where layers such
# as dropout or batch-norm (if the network has any) would make the features noisy.
features.eval()
# Follow whichever device the model already lives on instead of assuming CUDA.
device = next(model.parameters()).device

acc = []
# no_grad: only the forward activations are needed, so skip building autograd graphs.
with torch.no_grad():
    for x, y in DataLoader(active_set._dataset, batch_size=2, collate_fn=collate_fn_padd):
        acc.append((features(x.to(device)).cpu().numpy(), y.detach().cpu().numpy()))

# Split the per-batch (features, targets) pairs into two parallel tuples.
xs, ys = zip(*acc)


In [None]:


from sklearn.manifold import TSNE

# Compute t-SNE on the extracted features.
# A fixed random_state (the same seed value the notebook uses for its other RNGs)
# makes the embedding, and therefore the animation, identical on every run.
tsne = TSNE(n_jobs=4, random_state=1337)
transformed = tsne.fit_transform(np.vstack(xs))


In [None]:


# Join the per-batch target arrays end to end into a single array.
labels = np.concatenate(ys, axis=0)
labels.shape


In [None]:
from baal.utils.plot_utils import make_animation_from_data

# Build the animation frames from the 2-D embedding, the targets, the labelling
# counter and the class names.
frames = make_animation_from_data(
    transformed,
    labels,
    labelling_progress,
    ["CP", "NCP", "Normal"],
)


In [None]:
from IPython.display import HTML
import matplotlib.pyplot as plt
from matplotlib import animation


def plot_images(img_list):
    """Turn a sequence of images into a matplotlib animation.

    Args:
        img_list: indexable, non-empty sequence of images accepted by ``imshow``.

    Returns:
        A ``FuncAnimation`` showing one image of ``img_list`` per frame.
    """
    fig = plt.Figure(figsize=(10, 10))
    img = fig.gca().imshow(img_list[0])

    def show_frame(index):
        # Swap the pixel data of the single image artist; blitting needs the
        # changed artists returned as an iterable.
        img.set_data(img_list[index])
        return (img,)

    def init():
        return show_frame(0)

    return animation.FuncAnimation(fig, show_frame, init_func=init,
                                   frames=len(img_list), interval=60, blit=True)


# Render the animation as a JavaScript/HTML player and show it as the cell output.
animation_html = plot_images(frames).to_jshtml()
HTML(animation_html)