In [None]:
!pip install baal



In [None]:
from baal.active import FileDataset, ActiveLearningDataset

In [None]:
from datasets import load_dataset

# Keep the loaded splits under their own name: rebinding `datasets` here would
# shadow the imported HuggingFace `datasets` package for the rest of the notebook.
raw_datasets = load_dataset("glue", "sst2", cache_dir="/tmp")
raw_train_set = raw_datasets['train']
raw_valid_set = raw_datasets['validation']
# Number of active-learning steps (train -> evaluate -> label) run by the main loop.
al_epochs = 10

Reusing dataset glue (/tmp/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


In [None]:
from baal.active import HuggingFaceDatasets, active_huggingface_dataset
from transformers import BertTokenizer

pretrained_weights = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(
    pretrained_model_name_or_path=pretrained_weights,
)

# Wrap the raw splits with the shared tokenizer: the train split becomes the
# active-learning dataset (labelled part + unlabelled `pool`), the validation
# split a plain tokenized dataset.
active_set = active_huggingface_dataset(raw_train_set, tokenizer)
valid_set = HuggingFaceDatasets(raw_valid_set, tokenizer)

# Let's start with 100 randomly labelled samples; len(active_set) only counts
# the labelled items, so it must equal 100 afterwards.
initial_label_count = 100
active_set.label_randomly(initial_label_count)
assert len(active_set) == initial_label_count
# Size of the remaining unlabelled pool.
print(len(active_set.pool))

67249


In [None]:
from copy import deepcopy
import torch
from transformers import BertForSequenceClassification
from baal.bayesian.dropout import patch_module

use_cuda = torch.cuda.is_available()

model = BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path=pretrained_weights)
# Replace the model's dropout layers with baal's versions (presumably so dropout
# stays active at prediction time for MC-dropout uncertainty; see baal docs).
model = patch_module(model)
# Snapshot the initial weights *before* moving the model to the GPU. The copy then
# lives in host memory instead of doubling the model's GPU footprint. The values
# are identical, and load_state_dict copies them back onto the GPU parameters
# when the training loop resets the model.
init_weights = deepcopy(model.state_dict())
if use_cuda:
    model.cuda()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [None]:
from baal.active import get_heuristic

# Acquisition heuristic used by the active-learning loop to rank unlabelled pool
# samples; 'certainty' names one of baal's built-in heuristics (see baal's
# heuristics documentation for its exact scoring rule).
HEURISTIC_NAME = 'certainty'
heuristic = get_heuristic(HEURISTIC_NAME)



In [None]:
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import f1_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import accuracy_score
import numpy as np 

def compute_metrics(p):
    """Compute evaluation metrics for the HuggingFace trainer.

    Args:
        p: prediction object handed over by the trainer; ``p.predictions`` holds
            per-class scores with the class axis last, ``p.label_ids`` the gold
            labels.

    Returns:
        dict with 'accuracy', 'f1', 'precision_score' and 'recall_score'
        (the trainer reports them with an ``eval_`` prefix).
    """
    labels = p.label_ids
    preds = np.argmax(p.predictions, axis=-1)
    # One call computes precision/recall/F1 for label 1 (sklearn's default
    # pos_label) with average='binary'. The former `labels=np.unique(preds)`
    # argument was ignored by sklearn in binary mode, so it is dropped.
    # zero_division=0 yields the same 0.0 sklearn already returned, but without an
    # UndefinedMetricWarning when the model never predicts the positive class.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='binary', zero_division=0)
    return {
        'accuracy': accuracy_score(labels, preds),
        'f1': f1,
        'precision_score': precision,
        'recall_score': recall,
    }

In [None]:
from transformers import TrainingArguments
from baal import BaalTransformersTrainer
from baal.active.active_loop import ActiveLearningLoop

# Number of pool samples labelled at each active-learning step.
ndata_to_label = 20
# Number of stochastic forward passes used to score the pool (passed to
# predict_on_dataset as `iterations`).
mc_iterations = 3

# Initialization for the huggingface trainer
training_args = TrainingArguments(
    output_dir='.',  # output directory
    num_train_epochs=1,  # total # of training epochs per AL step
    per_device_train_batch_size=16,  # batch size per device during training
    per_device_eval_batch_size=64,  # batch size for evaluation
    weight_decay=0.01,  # strength of weight decay
    logging_dir='.',  # directory for storing logs
)

# Create the trainer through the Baal wrapper.
baal_trainer = BaalTransformersTrainer(model=model,
                                       args=training_args,
                                       train_dataset=active_set,
                                       eval_dataset=valid_set,
                                       compute_metrics=compute_metrics,
                                       tokenizer=None)

active_loop = ActiveLearningLoop(active_set,
                                 baal_trainer.predict_on_dataset,
                                 heuristic,
                                 ndata_to_label,
                                 iterations=mc_iterations)

for al_step in range(al_epochs):
    baal_trainer.train()
    eval_metrics = baal_trainer.evaluate()
    # Score the unlabelled pool and label the `ndata_to_label` top-ranked samples.
    should_continue = active_loop.step()

    # Reset the model weights so the next step relearns from the new train set.
    baal_trainer.load_state_dict(init_weights)
    # Also drop the optimizer and LR scheduler. The HF Trainer appears to build them
    # only when they are None, so otherwise the next train() would reuse the linear
    # schedule that already decayed to 0. The model would then barely train after
    # the first step (the earlier run fell to ~chance accuracy from step 1 on).
    baal_trainer.optimizer = None
    baal_trainer.lr_scheduler = None

    # Log before a possible early stop so the last step's metrics are not lost.
    active_logs = {"epoch": al_step,
                   "pool_size": len(active_set.pool),
                   "Next Training set size": len(active_set)}
    logs = {**eval_metrics, **active_logs}
    print(logs)
    if not should_continue:
        break

# Each active step labels `ndata_to_label` (20) new samples on top of the initial
# random ones. With 100 initial labels and al_epochs = 10 steps, the labelled
# training set should end at 100 + 10 * 20 = 300 samples.
print(len(active_set))

Step,Training Loss


[73-MainThread   ] [baal.transformers_trainer_wrapper:predict_on_dataset_generator:61] 2021-05-12T17:58:39.220251Z [[32minfo     ] Start Predict                  dataset=67249


100%|██████████| 1051/1051 [12:05<00:00,  1.45it/s]


{'eval_loss': 0.6783511638641357, 'eval_accuracy': 0.5756880733944955, 'eval_f1': 0.6605504587155964, 'eval_precision_score': 0.5572755417956656, 'eval_recall_score': 0.8108108108108109, 'eval_runtime': 3.2659, 'eval_samples_per_second': 266.998, 'epoch': 0, 'eval_mem_cpu_alloc_delta': 0, 'eval_mem_gpu_alloc_delta': 0, 'eval_mem_cpu_peaked_delta': 0, 'eval_mem_gpu_peaked_delta': 479175680, 'labeled_data': array([False, False, False, ..., False, False, False]), 'Next Training set size': 120}


Step,Training Loss


[73-MainThread   ] [baal.transformers_trainer_wrapper:predict_on_dataset_generator:61] 2021-05-12T18:10:50.960538Z [[32minfo     ] Start Predict                  dataset=67229


100%|██████████| 1051/1051 [12:05<00:00,  1.45it/s]


{'eval_loss': 0.7095933556556702, 'eval_accuracy': 0.46674311926605505, 'eval_f1': 0.5730027548209367, 'eval_precision_score': 0.48372093023255813, 'eval_recall_score': 0.7027027027027027, 'eval_runtime': 3.2643, 'eval_samples_per_second': 267.132, 'epoch': 1, 'eval_mem_cpu_alloc_delta': 0, 'eval_mem_gpu_alloc_delta': -17408, 'eval_mem_cpu_peaked_delta': 0, 'eval_mem_gpu_peaked_delta': 277857792, 'labeled_data': array([False, False, False, ..., False, False, False]), 'Next Training set size': 140}


Step,Training Loss


[73-MainThread   ] [baal.transformers_trainer_wrapper:predict_on_dataset_generator:61] 2021-05-12T18:23:02.919319Z [[32minfo     ] Start Predict                  dataset=67209


100%|██████████| 1051/1051 [12:05<00:00,  1.45it/s]


{'eval_loss': 0.7086722254753113, 'eval_accuracy': 0.4690366972477064, 'eval_f1': 0.5724838411819021, 'eval_precision_score': 0.48513302034428796, 'eval_recall_score': 0.6981981981981982, 'eval_runtime': 3.2633, 'eval_samples_per_second': 267.212, 'epoch': 2, 'eval_mem_cpu_alloc_delta': 0, 'eval_mem_gpu_alloc_delta': -25600, 'eval_mem_cpu_peaked_delta': 0, 'eval_mem_gpu_peaked_delta': 277857792, 'labeled_data': array([False, False, False, ..., False, False, False]), 'Next Training set size': 160}


Step,Training Loss


[73-MainThread   ] [baal.transformers_trainer_wrapper:predict_on_dataset_generator:61] 2021-05-12T18:35:14.943013Z [[32minfo     ] Start Predict                  dataset=67189


100%|██████████| 1050/1050 [12:04<00:00,  1.45it/s]


{'eval_loss': 0.7139497995376587, 'eval_accuracy': 0.4461009174311927, 'eval_f1': 0.5621033544877607, 'eval_precision_score': 0.47040971168437024, 'eval_recall_score': 0.6981981981981982, 'eval_runtime': 3.2668, 'eval_samples_per_second': 266.924, 'epoch': 3, 'eval_mem_cpu_alloc_delta': 0, 'eval_mem_gpu_alloc_delta': 0, 'eval_mem_cpu_peaked_delta': 0, 'eval_mem_gpu_peaked_delta': 479184384, 'labeled_data': array([False, False, False, ..., False, False, False]), 'Next Training set size': 180}


Step,Training Loss


[73-MainThread   ] [baal.transformers_trainer_wrapper:predict_on_dataset_generator:61] 2021-05-12T18:47:27.188743Z [[32minfo     ] Start Predict                  dataset=67169


100%|██████████| 1050/1050 [12:04<00:00,  1.45it/s]


{'eval_loss': 0.7048604488372803, 'eval_accuracy': 0.48853211009174313, 'eval_f1': 0.5878003696857671, 'eval_precision_score': 0.49843260188087773, 'eval_recall_score': 0.7162162162162162, 'eval_runtime': 3.2672, 'eval_samples_per_second': 266.894, 'epoch': 4, 'eval_mem_cpu_alloc_delta': 0, 'eval_mem_gpu_alloc_delta': -9216, 'eval_mem_cpu_peaked_delta': 0, 'eval_mem_gpu_peaked_delta': 277857792, 'labeled_data': array([False, False, False, ..., False, False, False]), 'Next Training set size': 200}


Step,Training Loss


[73-MainThread   ] [baal.transformers_trainer_wrapper:predict_on_dataset_generator:61] 2021-05-12T18:59:39.537234Z [[32minfo     ] Start Predict                  dataset=67149


100%|██████████| 1050/1050 [12:04<00:00,  1.45it/s]


{'eval_loss': 0.706025242805481, 'eval_accuracy': 0.49770642201834864, 'eval_f1': 0.6032608695652175, 'eval_precision_score': 0.5045454545454545, 'eval_recall_score': 0.75, 'eval_runtime': 3.2658, 'eval_samples_per_second': 267.009, 'epoch': 5, 'eval_mem_cpu_alloc_delta': 0, 'eval_mem_gpu_alloc_delta': 0, 'eval_mem_cpu_peaked_delta': 0, 'eval_mem_gpu_peaked_delta': 277857792, 'labeled_data': array([False, False, False, ..., False, False, False]), 'Next Training set size': 220}


Step,Training Loss


[73-MainThread   ] [baal.transformers_trainer_wrapper:predict_on_dataset_generator:61] 2021-05-12T19:11:52.043133Z [[32minfo     ] Start Predict                  dataset=67129


100%|██████████| 1049/1049 [12:04<00:00,  1.45it/s]


{'eval_loss': 0.7072196006774902, 'eval_accuracy': 0.47706422018348627, 'eval_f1': 0.5738317757009345, 'eval_precision_score': 0.4904153354632588, 'eval_recall_score': 0.6914414414414415, 'eval_runtime': 3.2685, 'eval_samples_per_second': 266.791, 'epoch': 6, 'eval_mem_cpu_alloc_delta': 0, 'eval_mem_gpu_alloc_delta': -25600, 'eval_mem_cpu_peaked_delta': 0, 'eval_mem_gpu_peaked_delta': 277857792, 'labeled_data': array([False, False, False, ..., False, False, False]), 'Next Training set size': 240}


Step,Training Loss


[73-MainThread   ] [baal.transformers_trainer_wrapper:predict_on_dataset_generator:61] 2021-05-12T19:24:04.638878Z [[32minfo     ] Start Predict                  dataset=67109


100%|██████████| 1049/1049 [12:04<00:00,  1.45it/s]


{'eval_loss': 0.7065767049789429, 'eval_accuracy': 0.48509174311926606, 'eval_f1': 0.5838739573679333, 'eval_precision_score': 0.49606299212598426, 'eval_recall_score': 0.7094594594594594, 'eval_runtime': 3.2639, 'eval_samples_per_second': 267.168, 'epoch': 7, 'eval_mem_cpu_alloc_delta': 0, 'eval_mem_gpu_alloc_delta': 0, 'eval_mem_cpu_peaked_delta': 0, 'eval_mem_gpu_peaked_delta': 479175680, 'labeled_data': array([False, False, False, ..., False, False, False]), 'Next Training set size': 260}


Step,Training Loss


[73-MainThread   ] [baal.transformers_trainer_wrapper:predict_on_dataset_generator:61] 2021-05-12T19:36:17.435101Z [[32minfo     ] Start Predict                  dataset=67089


100%|██████████| 1049/1049 [12:03<00:00,  1.45it/s]


{'eval_loss': 0.7076452374458313, 'eval_accuracy': 0.4724770642201835, 'eval_f1': 0.5848375451263538, 'eval_precision_score': 0.4879518072289157, 'eval_recall_score': 0.7297297297297297, 'eval_runtime': 3.2654, 'eval_samples_per_second': 267.043, 'epoch': 8, 'eval_mem_cpu_alloc_delta': 0, 'eval_mem_gpu_alloc_delta': 0, 'eval_mem_cpu_peaked_delta': 0, 'eval_mem_gpu_peaked_delta': 277857792, 'labeled_data': array([False, False, False, ..., False, False, False]), 'Next Training set size': 280}


Step,Training Loss


[73-MainThread   ] [baal.transformers_trainer_wrapper:predict_on_dataset_generator:61] 2021-05-12T19:48:30.257114Z [[32minfo     ] Start Predict                  dataset=67069


100%|██████████| 1048/1048 [12:03<00:00,  1.45it/s]

{'eval_loss': 0.7080803513526917, 'eval_accuracy': 0.4701834862385321, 'eval_f1': 0.5769230769230769, 'eval_precision_score': 0.4861111111111111, 'eval_recall_score': 0.7094594594594594, 'eval_runtime': 3.2647, 'eval_samples_per_second': 267.102, 'epoch': 9, 'eval_mem_cpu_alloc_delta': 0, 'eval_mem_gpu_alloc_delta': 0, 'eval_mem_cpu_peaked_delta': 0, 'eval_mem_gpu_peaked_delta': 277857792, 'labeled_data': array([False, False, False, ..., False, False, False]), 'Next Training set size': 300}
300



