# Active learning loop

# Imports

In [1]:
import numpy as np
import os

# torch
import torch

# HF
from datasets import load_dataset, Features, Value, ClassLabel, Version
from transformers import AutoTokenizer

# small text
from small_text import random_initialization
from small_text.active_learner import PoolBasedActiveLearner
from small_text.base import LABEL_UNLABELED
from small_text.integrations.transformers import TransformersDataset, TransformerModelArguments
from small_text.integrations.transformers.classifiers.factories import TransformerBasedClassificationFactory
from small_text.query_strategies import BreakingTies

# rubrix
import rubrix as rb
from rubrix.listeners import listener

# metrics
from sklearn.metrics import accuracy_score

  from .autonotebook import tqdm as notebook_tqdm


# Constants

In [138]:
DATASET_AG_NEWS = "bergr7/weakly_supervised_ag_news"
TRANSFORMER_MODEL = "distilbert-base-uncased"
TOKENIZER = AutoTokenizer.from_pretrained(TRANSFORMER_MODEL)
LABELS = load_dataset('ag_news')["train"].features["label"].names
NUM_SAMPLES_ITER = 5
NUM_CLASSES = len(LABELS)
DATASET_NAME = "active_learning_loop"

100%|██████████| 2/2 [00:00<00:00, 453.54it/s]


# Load data

In [117]:
# labeled data - train.csv contains weak labels
labeled_data_files = {
    "train": "train.csv",
    "validation": "validation.csv", 
    "test": "test.csv"
}

# unlabeled data - not covered by LFs
unlabeled_data_files = {"unlabeled": "unlabeled_train.csv"}

# Define schema
labeled_features = Features(
    {
        "text": Value("string"),
        "label": ClassLabel(
            num_classes=4,
            names=['World', 'Sports', 'Business', 'Sci/Tech']
        )
    }
)
unlabeled_features = Features({"text": Value("string")})

# load data
labeled_dataset = load_dataset(
    DATASET_AG_NEWS,
    data_files=labeled_data_files,
    features=labeled_features
)

unlabeled_dataset = load_dataset(
    DATASET_AG_NEWS,
    data_files=unlabeled_data_files,
    features=unlabeled_features
)

100%|██████████| 3/3 [00:00<00:00, 965.76it/s]
100%|██████████| 1/1 [00:00<00:00, 926.71it/s]


In [60]:
labeled_dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 37340
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 24000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 7600
    })
})

In [61]:
unlabeled_dataset

DatasetDict({
    unlabeled: Dataset({
        features: ['text'],
        num_rows: 58660
    })
})

# Training data


In [118]:
# text
weak_train_text = [row["text"] for row in labeled_dataset["train"]][:100]
unlabeled_train_text = [row["text"] for row in unlabeled_dataset["unlabeled"]][:200]
# labels
weak_train_labels = [row["label"] for row in labeled_dataset["train"]][:100]
unlabeled_no_labels = [LABEL_UNLABELED for _ in unlabeled_train_text][:200]
# dataset
train_text = np.array(weak_train_text + unlabeled_train_text)
train_labels = np.array(weak_train_labels + unlabeled_no_labels)

# Preprocessing the dataset

Wrap the dataset in a specific data class by small-text -> `TransformersDataset` since we are gonna use `distilbert-base-uncased`

`TransformersDataset` should contain the tokenized input text.


In [119]:
# create dataset
dataset = TransformersDataset.from_arrays(
    train_text,
    train_labels,
    tokenizer=TOKENIZER,
    target_labels=[0, 1, 2, 3]
)



In [64]:
# visualize an input - weak label
print(f"input_ids: {dataset.data[37339][0]}")
print(f"token_type_ids: {dataset.data[37339][1]}")
print(f"label: {dataset.data[37339][2]}")

# visualize an input - unlabeled
print(f"input_ids: {dataset.data[50000][0]}")
print(f"token_type_ids: {dataset.data[50000][1]}")
print(f"label: {dataset.data[50000][2]}")

IndexError: list index out of range

In [120]:
# validation dataset
val_text = [row["text"] for row in labeled_dataset["validation"]][:20]
val_labels = [row["label"] for row in labeled_dataset["validation"]][:20]

In [121]:
val_dataset = TransformersDataset.from_arrays(
    val_text,
    val_labels,
    tokenizer=TOKENIZER,
    target_labels=[0, 1, 2, 3]
)

>Notes

- cleaning??
- test set???

# Set up active learner

- Component 1 -> classifier
- Component 2 -> query strategy (sampling)

In [122]:
# hyper-params
params = dict(
        {
            "lr": 3e-5,
            "num_epochs": 1,
            "mini_batch_size": 32,
            "model_selection": True,
            "device": "cuda"
        
    }
)

# classifier
model_factory = TransformerBasedClassificationFactory(
    TransformerModelArguments(TRANSFORMER_MODEL),
    num_classes=NUM_CLASSES,
    kwargs=params
)

In [123]:
# define query strategy
query_strategy = BreakingTies()

In [124]:
def initialize_active_learner(active_learner, train_labels, weak_train_labels):

    indices_initial = [i for i in range(len(weak_train_labels))]
    active_learner.initialize_data(indices_initial, train_labels[indices_initial])

    return indices_initial

In [129]:
active_learner = PoolBasedActiveLearner(
    model_factory,
    query_strategy,
    dataset,
    reuse_model=True
)

In [130]:
# this takes quite a bit of time... (1hr!!) - we initilize the active learner with the weak labels
indices_labeled = initialize_active_learner(active_learner, train_labels, weak_train_labels)

Since most query strategies, including ours, require a trained model, we randomly draw a subset from the data pool to initialize our AL system. After obtaining the labels for this batch of instances, the active learner will use them to create the first classifier.

# Active Learning loop with Rubrix

Configure a Rubrix dataset and set up active learning loop

In [126]:
# init rubrix

RUBRIX_URL = os.getenv("RUBRIX_API_URL", "http://localhost:6900")

rb.init(api_url=RUBRIX_URL)

In [127]:
settings = rb.TextClassificationSettings(label_schema=LABELS)
# configure dataset
rb.configure_dataset(name=DATASET_NAME, settings=settings)

In [131]:
# first batch querying with active learner
batch_1_indices = active_learner.query(num_samples=NUM_SAMPLES_ITER)

records = [
    rb.TextClassificationRecord(
        text=train_text[idx],
        metadata={"batch_id": 0},
        id=idx,
    )
    for idx in batch_1_indices
]

# Log initial records to Rubrix
rb.log(records, DATASET_NAME)

100%|██████████| 5/5 [00:01<00:00,  4.96it/s]

5 records logged to http://rubrix:80/datasets/rubrix/bg_test_3





BulkResponse(dataset='bg_test_3', processed=5, failed=0)

In [132]:
LABEL2INT = labeled_dataset["train"].features["label"].str2int

In [133]:
ACCURACIES = []

In [134]:
@listener(dataset=DATASET_NAME,
    query="status:Validated AND metadata.batch_id:{batch_id}",
    condition=lambda search: search.total==NUM_SAMPLES_ITER,
    execution_interval_in_seconds=3,
    batch_id=0
)
def active_learning_loop(records, ctx):
    
    # 1. Update active learner
    print(f"Updating with batch_id {ctx.query_params['batch_id']} ...")
    y = np.array([LABEL2INT(rec.annotation) for rec in records])

    # update with the prior queried indices
    active_learner.update(y)
    print("Done!")

    # 2. Query active learner
    print("Querying new data points ...")
    queried_indices = active_learner.query(num_samples=NUM_SAMPLES_ITER)
    new_batch = ctx.query_params["batch_id"] + 1
    new_records = [
        rb.TextClassificationRecord(
            text=train_text[idx],
            metadata={"batch_id": new_batch},
            id=idx,
        )
        for idx in queried_indices
    ]

    # 3. Log the batch to Rubrix
    rb.log(new_records, DATASET_NAME)

    # 4. Evaluate current classifier on the test set
    print("Evaluating current classifier ...")
    accuracy = accuracy_score(
        val_dataset.y,
        active_learner.classifier.predict(val_dataset),
    )

    ACCURACIES.append(accuracy)
    ctx.query_params["batch_id"] = new_batch
    print("Done!")

    print("Waiting for annotations ...")

# Start active learning loop

In [135]:
active_learning_loop.start()

Updating with batch_id 0 ...
Done!
Querying new data points ...


100%|██████████| 5/5 [00:00<00:00,  6.40it/s]


5 records logged to http://rubrix:80/datasets/rubrix/bg_test_3
Evaluating current classifier ...
Done!
Waiting for annotations ...
Updating with batch_id 1 ...
Done!
Querying new data points ...


100%|██████████| 5/5 [00:00<00:00, 60.16it/s]


5 records logged to http://rubrix:80/datasets/rubrix/bg_test_3
Evaluating current classifier ...
Done!
Waiting for annotations ...
Updating with batch_id 2 ...
Done!
Querying new data points ...


100%|██████████| 5/5 [00:00<00:00,  7.38it/s]


5 records logged to http://rubrix:80/datasets/rubrix/bg_test_3
Evaluating current classifier ...
Done!
Waiting for annotations ...


In [137]:
active_learning_loop.stop()