<a href="https://colab.research.google.com/github/AnzorGozalishvili/active_learning_playground/blob/main/notebooks/active_learning_experiments_on_sms_spam_classification_problem_using_baal_library.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers

In [2]:
!pip install -U datasets
!pip install -U baal

# Load `sms spam` dataset from huggingface datasets

In [1]:
import datasets

In [2]:
sms_spam_dataset = datasets.load_dataset('sms_spam', )

Reusing dataset sms_spam (/root/.cache/huggingface/datasets/sms_spam/plain_text/1.0.0/53f051d3b5f62d99d61792c91acefe4f1577ad3e4c216fb0ad39e30b9f20019c)


  0%|          | 0/1 [00:00<?, ?it/s]

In [3]:
sms_spam_dataset.shape

{'train': (5574, 2)}

## Let's split our dataset into Train/Test splits (since it only contains train set)

Let's hold out 500 samples for the test set, since the overall size is only around 5,500.

In [4]:
RANDOM_SEED = 42

In [5]:
splitted_sms_spam_dataset = sms_spam_dataset['train'].train_test_split(test_size=500, shuffle=True, seed=RANDOM_SEED)

Loading cached split indices for dataset at /root/.cache/huggingface/datasets/sms_spam/plain_text/1.0.0/53f051d3b5f62d99d61792c91acefe4f1577ad3e4c216fb0ad39e30b9f20019c/cache-49fa1f1338b1121b.arrow and /root/.cache/huggingface/datasets/sms_spam/plain_text/1.0.0/53f051d3b5f62d99d61792c91acefe4f1577ad3e4c216fb0ad39e30b9f20019c/cache-c1d02dfc7f23e5bf.arrow


In [6]:
splitted_sms_spam_dataset.shape

{'test': (500, 2), 'train': (5074, 2)}

In [7]:
train_ds, test_ds = splitted_sms_spam_dataset['train'], splitted_sms_spam_dataset['test']

In [8]:
train_ds.shape, test_ds.shape

((5074, 2), (500, 2))

# Load small pretrained Language Model from huggingface transformers library

In [9]:
import transformers

In [10]:
model = transformers.AutoModelForSequenceClassification.from_pretrained("distilbert-base-cased", num_labels=2)

Some weights of the model checkpoint at distilbert-base-cased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.bias', 'vocab_projector.weight', 'vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-cased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classifier

In [11]:
model

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0): TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
       

In [12]:
tokenizer = transformers.AutoTokenizer.from_pretrained("distilbert-base-cased")

In [13]:
tokenizer('sample sentence')

{'input_ids': [101, 6876, 5650, 102], 'attention_mask': [1, 1, 1, 1]}

In [14]:
tokenizer.decode(tokenizer('sample sentence')['input_ids'])

'[CLS] sample sentence [SEP]'

# Convert Huggingface Dataset into ActiveLearningDataset

In [15]:
from baal.active.dataset.nlp_datasets import active_huggingface_dataset

In [16]:
train_ds[0]

{'label': 0, 'sms': 'Well I might not come then...\n'}

In [17]:
active_set = active_huggingface_dataset(dataset=train_ds, tokenizer=tokenizer, target_key='label', input_key='sms')

# Wrap test set for evaluation purposes using Huggingface dataset wrapper defined in BaaL

In [18]:
from baal.active.dataset.nlp_datasets import HuggingFaceDatasets

In [19]:
eval_ds = HuggingFaceDatasets(dataset=test_ds, tokenizer=tokenizer, target_key='label', input_key='sms')

# Define Active Learning Experiment Configurations

In [20]:
from dataclasses import dataclass

In [21]:
@dataclass
class ExperimentConfig:
    """Hyperparameters for the active learning experiment in this notebook.

    Only some fields are consumed by the cells visible here; the others are
    marked below so a reader does not assume they influence the run.
    """
    # Passed to TrainingArguments as num_train_epochs; 4500//128 evaluates to 35.
    epoch: int = 4500//128
    # Not referenced by the visible cells: TrainingArguments uses literal batch sizes.
    batch_size: int = 32
    # Number of samples labelled at random before the first training round.
    initial_pool: int = 500
    # Not referenced by the visible cells: ActiveLearningLoop is built with a literal 10.
    query_size: int = 128
    # Not referenced by the visible cells: no learning rate is passed to TrainingArguments.
    lr: float = 0.001
    # Name handed to baal's get_heuristic() to pick the acquisition function.
    heuristic: str = 'bald'
    # Not referenced by the visible cells: ActiveLearningLoop is built with iterations=3.
    iterations: int = 40
    # Upper bound on the number of train -> acquire rounds run by the main loop.
    training_duration: int = 2

In [22]:
hyperparams = ExperimentConfig()

In [23]:
hyperparams

ExperimentConfig(epoch=35, batch_size=32, initial_pool=500, query_size=128, lr=0.001, heuristic='bald', iterations=40, training_duration=2)

In [24]:
active_set.can_label = False

In [25]:
active_set.label_randomly(hyperparams.initial_pool)

In [26]:
active_set.n_labelled

500

In [27]:
active_set[0]

{'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0]),
 'input_ids': tensor([  101,  1573,  1175,   112,   188,   170,  3170,  1115,  2502,  1114,
          1103,  3713, 13452,   119,  1135,   112,   188,  1175,  1177,  1152,
          1169,  5309,  1147,  2174, 26063, 22267,  1116,   119,  8790,  2227,
          9124,   102,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     

In [28]:
import random
import torch

In [29]:
# Seed the Python and PyTorch random number generators with the shared seed.
torch.manual_seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

# Enable cuDNN's benchmark mode (see the PyTorch docs on torch.backends.cudnn).
torch.backends.cudnn.benchmark = True

# Record whether a GPU is available; later cells branch on this flag.
use_cuda = torch.cuda.is_available()
if not use_cuda:
    print("warning, the experiments would take ages to run on cpu")

# Show the flag as the cell output.
use_cuda

True

In [30]:
# Our dataset sizes: labelled training pool, raw test split, wrapped test split
len(active_set), len(test_ds), len(eval_ds)

(500, 500, 500)

In [31]:
from baal.active import get_heuristic

In [32]:
# Get the acquisition heuristic named in the config ('bald' by default).
heuristic = get_heuristic(hyperparams.heuristic)

In [33]:
from baal.bayesian.dropout import patch_module
from copy import deepcopy

In [34]:
# change dropout layer to MCDropout
model = patch_module(model)

Checking if dropout layer was replaced by the BaaL implementation of Dropout

In [35]:
type(list(model.named_modules())[-1][1])

baal.bayesian.dropout.Dropout

In [36]:
if use_cuda:
    model.cuda()
init_weights = deepcopy(model.state_dict())

In [37]:
from transformers import TrainingArguments
from baal.transformers_trainer_wrapper import BaalTransformersTrainer
from baal.active.active_loop import ActiveLearningLoop

In [38]:
import numpy as np

In [39]:
from datasets import load_metric

metrics = {
    "accuracy": load_metric("accuracy"),
    "f1": load_metric("f1"),
    "precision": load_metric("precision"),
    "recall": load_metric("recall"),
}

def compute_metrics(eval_pred):
    """Score a batch of predictions with every metric in the module-level ``metrics`` dict.

    Args:
        eval_pred: A ``(logits, labels)`` pair. The predicted class for each
            row is the argmax of ``logits`` over the last axis.

    Returns:
        A single dict built by merging the result dict returned by each
        metric's ``compute(predictions=..., references=...)`` call.
    """
    logits, labels = eval_pred
    predicted_classes = np.argmax(logits, axis=-1)

    scores = {}
    for metric in metrics.values():
        scores.update(metric.compute(predictions=predicted_classes, references=labels))
    return scores

In [40]:
compute_metrics(([[0.9, 0.2], [0.1, 0.9]], [1, 1]))

{'accuracy': 0.5, 'f1': 0.6666666666666666, 'precision': 1.0, 'recall': 0.5}

In [41]:
# Configuration for the Hugging Face trainer used inside every active learning step.
# Note: the batch sizes below are literals, and no learning rate is passed, so
# hyperparams.batch_size and hyperparams.lr do not affect training.
training_args = TrainingArguments(
    output_dir='.',  # checkpoints are written here (the run log shows ./checkpoint-500, ./checkpoint-1000)
    do_train=True,
    do_eval=True,
    evaluation_strategy="steps",  # the run log shows evaluations at steps 500 and 1000
    num_train_epochs=hyperparams.epoch,  # total number of training epochs per AL step (35 with the defaults)
    per_device_train_batch_size=16,  # batch size per device during training
    per_device_eval_batch_size=64,  # batch size for evaluation
    weight_decay=0.01,  # strength of weight decay
    logging_dir='.',  # directory for storing logs
    )

In [42]:
# create the trainer through Baal Wrapper
baal_trainer = BaalTransformersTrainer(model=model,
                                       args=training_args,
                                       train_dataset=active_set,
                                       eval_dataset=eval_ds,
                                       compute_metrics=compute_metrics,
                                       tokenizer=None)

In [43]:
# Build the acquisition loop over the active set, using the trainer's
# predict_on_dataset for predictions and the heuristic selected above.
# The literal 10 is the number of samples labelled per step (the run log goes
# 500 -> 510 -> 520); hyperparams.query_size is not used here.
# NOTE(review): iterations=3 is presumably forwarded to predict_on_dataset as the
# number of stochastic forward passes -- confirm against the Baal docs.
# hyperparams.iterations is not used here either.
active_loop = ActiveLearningLoop(active_set,
                                 baal_trainer.predict_on_dataset,
                                 heuristic, 10, iterations=3)

In [48]:
# Each pass of this loop is one active learning round (train, then acquire),
# not a training epoch; `epoch` is only a counter and is not read in the body.
for epoch in range(hyperparams.training_duration):
    # Train on the currently labelled subset of the active set.
    baal_trainer.train()

    # Label the next batch of pool samples; the return value is used below
    # to decide whether to stop early.
    should_continue = active_loop.step()

    # We reset the model weights to relearn from the new train set.
    baal_trainer.load_state_dict(init_weights)
    # Clear the scheduler reference so it is not carried over into the next round.
    baal_trainer.lr_scheduler = None
    if not should_continue:
        break

# Each active step adds 10 samples to the labelled data. With training_duration=2
# the loop runs two steps, so 20 samples are added to the initial 500
# (the recorded output below prints 520).
print(len(active_set))

***** Running training *****
  Num examples = 500
  Num Epochs = 35
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 1120


Step,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
500,0.0209,0.143336,0.978,0.916031,0.9375,0.895522
1000,0.0,0.19269,0.974,0.900763,0.921875,0.880597


***** Running Evaluation *****
  Num examples = 500
  Batch size = 64
Saving model checkpoint to ./checkpoint-500
Configuration saved in ./checkpoint-500/config.json
Model weights saved in ./checkpoint-500/pytorch_model.bin
***** Running Evaluation *****
  Num examples = 500
  Batch size = 64
Saving model checkpoint to ./checkpoint-1000
Configuration saved in ./checkpoint-1000/config.json
Model weights saved in ./checkpoint-1000/pytorch_model.bin


Training completed. Do not forget to share your model on huggingface.co/models =)




[905-MainThread  ] [baal.transformers_trainer_wrapper:predict_on_dataset_generator:67] [2m2021-12-16T13:28:56.557407Z[0m [[32m[1minfo     [0m] [1mStart Predict                 [0m [36mdataset[0m=[35m4574[0m


100%|██████████| 72/72 [01:39<00:00,  1.39s/it]
***** Running training *****
  Num examples = 510
  Num Epochs = 35
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 1120


Step,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
500,0.0314,0.172448,0.978,0.916031,0.9375,0.895522
1000,0.0,0.218804,0.974,0.902256,0.909091,0.895522


***** Running Evaluation *****
  Num examples = 500
  Batch size = 64
Saving model checkpoint to ./checkpoint-500
Configuration saved in ./checkpoint-500/config.json
Model weights saved in ./checkpoint-500/pytorch_model.bin
***** Running Evaluation *****
  Num examples = 500
  Batch size = 64
Saving model checkpoint to ./checkpoint-1000
Configuration saved in ./checkpoint-1000/config.json
Model weights saved in ./checkpoint-1000/pytorch_model.bin


Training completed. Do not forget to share your model on huggingface.co/models =)




[905-MainThread  ] [baal.transformers_trainer_wrapper:predict_on_dataset_generator:67] [2m2021-12-16T13:37:44.524479Z[0m [[32m[1minfo     [0m] [1mStart Predict                 [0m [36mdataset[0m=[35m4564[0m


100%|██████████| 72/72 [01:39<00:00,  1.38s/it]

520





In [51]:
set(active_set.labelled_map)

{0, 1, 2, 3}

In [65]:
labelling_progress = active_set.labelled_map.astype(np.uint16)

In [66]:
# Gather the model weights, the active-set state and the labelling progress,
# then write them to disk as one checkpoint file.
model_weight = model.state_dict()
dataset = active_set.state_dict()
torch.save(
    {
        'model': model_weight,
        'dataset': dataset,
        'labelling_progress': labelling_progress,
    },
    'checkpoint.pth',
)
# Echo what was saved: parameter names, dataset-state keys and the progress array.
print(model_weight.keys(), dataset.keys(), labelling_progress)

odict_keys(['distilbert.embeddings.word_embeddings.weight', 'distilbert.embeddings.position_embeddings.weight', 'distilbert.embeddings.LayerNorm.weight', 'distilbert.embeddings.LayerNorm.bias', 'distilbert.transformer.layer.0.attention.q_lin.weight', 'distilbert.transformer.layer.0.attention.q_lin.bias', 'distilbert.transformer.layer.0.attention.k_lin.weight', 'distilbert.transformer.layer.0.attention.k_lin.bias', 'distilbert.transformer.layer.0.attention.v_lin.weight', 'distilbert.transformer.layer.0.attention.v_lin.bias', 'distilbert.transformer.layer.0.attention.out_lin.weight', 'distilbert.transformer.layer.0.attention.out_lin.bias', 'distilbert.transformer.layer.0.sa_layer_norm.weight', 'distilbert.transformer.layer.0.sa_layer_norm.bias', 'distilbert.transformer.layer.0.ffn.lin1.weight', 'distilbert.transformer.layer.0.ffn.lin1.bias', 'distilbert.transformer.layer.0.ffn.lin2.weight', 'distilbert.transformer.layer.0.ffn.lin2.bias', 'distilbert.transformer.layer.0.output_layer_norm.

In [44]:
checkpoint = torch.load('checkpoint.pth')

In [45]:
checkpoint.keys()

dict_keys(['model', 'dataset', 'labelling_progress'])

In [46]:
active_set.load_state_dict(checkpoint['dataset'])

In [55]:
model.load_state_dict(checkpoint['model'])

<All keys matched successfully>

In [56]:
labelling_progress = checkpoint['labelling_progress']

## Visualization

Now that our active learning experiment is completed, we can visualize it!

## Get t-SNE features.
We will use MulticoreTSNE to get a t-SNE representation of our dataset. This will allow us to visualize the labelling progress.

In [58]:
# modify our model to get features
from torch import nn
from torch.utils.data import DataLoader


# Make a feature extractor from our trained model.
class FeatureExtractor(nn.Module):
    """Wrap a sequence classifier so that it returns one feature vector per input.

    The wrapped ``model`` must expose ``distilbert`` and ``pre_classifier``
    sub-modules; the classification head itself is never called.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask):
        """Return the ``pre_classifier`` output averaged over dimension 1.

        Takes the first element of the backbone's output, presumably the
        per-token hidden states with shape (batch, seq_len, hidden) -- TODO confirm.
        The average runs over every position in dimension 1; ``attention_mask``
        is only passed to the backbone and is not used to weight the mean.
        """
        hidden_states = self.model.distilbert(input_ids, attention_mask)[0]
        projected = self.model.pre_classifier(hidden_states)
        return projected.mean(dim=1)


In [59]:
features = FeatureExtractor(model)
# Use whatever device the model already lives on. The notebook only moves the
# model to the GPU when CUDA is available, so a hard-coded .cuda() on the
# inputs would fail on a CPU-only machine.
feature_device = next(model.parameters()).device
acc = []  # accumulates one (embeddings, labels) pair per batch
# Inference only: no_grad avoids building an autograd graph for every batch.
# NOTE(review): the model was patched with Baal's MC Dropout earlier, so dropout
# presumably stays active here and the embeddings are stochastic -- confirm
# whether that is intended for the visualisation.
with torch.no_grad():
    for x in DataLoader(active_set._dataset, batch_size=10):
        batch_embeddings = features(x['input_ids'].to(feature_device),
                                    x['attention_mask'].to(feature_device))
        acc.append((batch_embeddings.detach().cpu().numpy(),
                    x['label'].detach().cpu().numpy()))

# Split the accumulated pairs into a tuple of feature arrays and a tuple of label arrays.
xs, ys = zip(*acc)

In [None]:
!pip install MulticoreTSNE

In [62]:
from MulticoreTSNE import MulticoreTSNE as TSNE

# Compute t-SNE on the extracted features.
# t-SNE starts from a random initialisation, so pass the notebook-wide seed;
# without it the 2-D layout changes on every run.
tsne = TSNE(n_jobs=4, random_state=RANDOM_SEED)
transformed = tsne.fit_transform(np.vstack(xs))

In [63]:
labels = np.concatenate(ys)
labels.shape

(5074,)

To make the animation, BaaL has `baal.utils.plot_utils.make_animation_from_data` which takes a set of features, their labels
 and the array containing the progress we created earlier.

In [64]:
from baal.utils.plot_utils import make_animation_from_data

# Create frames to animate the process.
frames = make_animation_from_data(transformed, labels, labelling_progress, ["ham", "spam"])

In [65]:
from IPython.display import HTML
import matplotlib.pyplot as plt
from matplotlib import animation

def plot_images(img_list):
    """Build a matplotlib animation that steps through ``img_list``.

    Args:
        img_list: Indexable sequence of images accepted by ``Axes.imshow``;
            element 0 is drawn first, so the sequence must not be empty.

    Returns:
        The ``animation.FuncAnimation`` object, with one frame per image and a
        60 ms interval between frames.
    """
    fig = plt.Figure(figsize=(10, 10))
    img = fig.gca().imshow(img_list[0])

    def show_frame(index):
        # Swap the displayed pixels and return the artist that changed (needed for blitting).
        img.set_data(img_list[index])
        return (img,)

    def show_first():
        # Initial state of the animation: the first image.
        return show_frame(0)

    return animation.FuncAnimation(fig, show_frame, init_func=show_first,
                                   frames=len(img_list), interval=60, blit=True)

HTML(plot_images(frames).to_jshtml())

### Conclusion

And that's it! Using a couple lines of code, we were able to run our active learning experiment and plot
the progress on a t-SNE representation.
