In [None]:
import datasets

In [None]:
# Load the 'ag_news' dataset; later cells read its "train" and "test" splits
ag_news_data = datasets.load_dataset('ag_news')

In [None]:
from transformers import AutoTokenizer

# Name of the pretrained checkpoint. The same constant is passed to the
# classifier factory further down, so tokenizer and model stay in sync.
TRANSFORMER_MODEL = "distilbert-base-uncased"

# Instantiate the tokenizer that belongs to the chosen checkpoint
tokenizer = AutoTokenizer.from_pretrained(TRANSFORMER_MODEL)

def tokenize(examples, max_length=64):
    """Tokenize the ``"text"`` field of a batch of examples.

    Args:
        examples: Mapping with a ``"text"`` entry holding the raw input text(s).
        max_length: Length every sequence is padded or truncated to.
            Defaults to 64, the value that was previously hard-coded.

    Returns:
        Whatever the module-level ``tokenizer`` returns for the given texts.
    """
    return tokenizer(
        examples["text"],
        padding="max_length",
        max_length=max_length,
        truncation=True,
    )

# Tokenize the dataset in batches. The raw "text" column is removed from the
# result, which is why later cells read the texts from ag_news_data instead.
data_tokenized = ag_news_data.map(tokenize, batched=True, remove_columns=["text"])

In [None]:
# Switch the output format of the tokenized dataset to "torch".
# NOTE(review): the [None] indexing used when building the small-text datasets
# presumably relies on rows being returned as tensors - confirm before changing.
data_tokenized.set_format("torch")

In [None]:
from small_text.integrations.transformers import TransformersDataset
from small_text.base import LABEL_UNLABELED


# Assemble one (input_ids, attention_mask, label) tuple per tokenized training
# example; every label is the LABEL_UNLABELED placeholder at this point
data = list(
    # Indexing with [None] adds an extra leading dimension, i.e. a batch size of 1
    (example["input_ids"][None], example["attention_mask"][None], LABEL_UNLABELED)
    for example in data_tokenized["train"]
)

# Wrap the tuples in the dataset class imported from small-text
dataset = TransformersDataset(data)

In [None]:
# Build the evaluation set from the "test" split in the same tuple layout as
# the training pool, but keep the true label (converted to a built-in int)
data_test = list(
    (example["input_ids"][None], example["attention_mask"][None], int(example["label"]))
    for example in data_tokenized["test"]
)
dataset_test = TransformersDataset(data_test)

In [None]:
from small_text.integrations.transformers.classifiers.factories import TransformerBasedClassificationFactory
from small_text.integrations.transformers import TransformerModelArguments
from small_text.query_strategies import BreakingTies
from small_text.active_learner import PoolBasedActiveLearner


# Define our classifier factory; it is built from the same checkpoint name
# that was used for the tokenizer.
# NOTE(review): num_classes is hard-coded to 4 - presumably the number of
# labels in the ag_news label feature; verify if the dataset is swapped.
clf_factory = TransformerBasedClassificationFactory(
    TransformerModelArguments(TRANSFORMER_MODEL),
    num_classes=4,
    # If you have a cuda device, specify it here.
    # Otherwise, just remove the following line.
    # kwargs={"device": "cuda"}
)

# Define our query strategy
query_strategy = BreakingTies()

# Use the active learner with a pool containing all unlabeled data
# (dataset holds the training split with LABEL_UNLABELED for every example)
active_learner = PoolBasedActiveLearner(clf_factory, query_strategy, dataset)

In [None]:
from small_text.initialization import random_initialization
import numpy as np
# Size of every batch: the initial random draw and, later on, each query
NUM_SAMPLES = 10

# Seed NumPy's global random state so that re-running the notebook is reproducible
np.random.seed(42)

# Pick the first batch of pool indices at random
initial_indices = random_initialization(dataset, NUM_SAMPLES)

In [None]:
import rubrix as rb

# Point the Rubrix client at the server; the URL is hard-coded here, so
# adjust it to wherever your Rubrix instance is reachable
rb.init(api_url="http://rubrix:80")

In [None]:
# Choose a name for the dataset
DATASET_NAME = "test_with_active_learning"

# Define labeling schema from the label names of the training split
labels = ag_news_data["train"].features["label"].names
settings = rb.TextClassificationSettings(label_schema=labels)

# Create dataset with a label schema
rb.configure_dataset(name=DATASET_NAME, settings=settings)

# Create records from the initial batch.
# Rows are indexed first and the "text" field second: the previous form,
# ag_news_data["train"]["text"][idx], looked up the entire text column once
# per record just to pick a single entry from it.
records = [
    rb.TextClassificationRecord(
        text=ag_news_data["train"][int(idx)]["text"],
        # batch_id 0 marks the initial batch; the listener query filters on it
        metadata={"batch_id": 0},
        # Cast to a built-in int so the record id is a plain Python integer
        # (initial_indices presumably holds NumPy integers - harmless otherwise)
        id=int(idx),
    )
    for idx in initial_indices
]

# Log initial records to Rubrix
rb.log(records, DATASET_NAME)

In [None]:
from rubrix.listeners import listener
from sklearn.metrics import accuracy_score

# Define some helper variables
# LABEL2INT: the str2int method of the label feature; the loop uses it to turn
# each record's annotation into an integer label
LABEL2INT = ag_news_data["train"].features["label"].str2int
# ACCURACIES: the loop appends one test-set accuracy per execution
ACCURACIES = []

# Set up the active learning loop with the listener decorator
@listener(
    dataset=DATASET_NAME,
    query="status:Validated AND metadata.batch_id:{batch_id}",
    condition=lambda search: search.total==NUM_SAMPLES,
    execution_interval_in_seconds=3,
    batch_id=0
)
def active_learning_loop(records, ctx):
    """Run one round of active learning for an annotated batch.

    Registered through ``@listener`` with a query for validated records of the
    current ``batch_id`` and a condition requiring exactly ``NUM_SAMPLES`` hits.

    Steps performed:
      1. Feed the annotations to ``active_learner`` (``initialize_data`` for
         batch 0, ``update`` for every later batch).
      2. Query ``NUM_SAMPLES`` new pool indices and increment ``batch_id``.
      3. Log the queried examples to Rubrix as new records.
      4. Append the current test-set accuracy to ``ACCURACIES``.

    Args:
        records: Records matching the listener query; each one must expose
            ``annotation`` (a label name) and ``id`` (a pool index).
        ctx: Listener context; ``ctx.query_params["batch_id"]`` is read and
            incremented here.
    """

    # 1. Update active learner
    print(f"Updating with batch_id {ctx.query_params['batch_id']} ...")
    y = np.array([LABEL2INT(rec.annotation) for rec in records])

    # initial update
    if ctx.query_params["batch_id"] == 0:
        # Indices and labels are both derived from `records`, so they line up
        indices = np.array([rec.id for rec in records])
        active_learner.initialize_data(indices, y)
    # update with the prior queried indices
    else:
        # NOTE(review): `y` follows the order of `records`; this assumes the
        # records arrive in the same order as the previously queried indices
        # - TODO confirm, otherwise labels are assigned to the wrong examples.
        active_learner.update(y)
    print("Done!")

    # 2. Query active learner
    print("Querying new data points ...")
    queried_indices = active_learner.query(num_samples=NUM_SAMPLES)
    ctx.query_params["batch_id"] += 1
    train_split = ag_news_data["train"]
    new_records = [
        rb.TextClassificationRecord(
            # Row-first indexing reads a single example; the previous
            # train["text"][idx] form looked up the whole text column once
            # per record.
            text=train_split[int(idx)]["text"],
            metadata={"batch_id": ctx.query_params["batch_id"]},
            # Built-in int instead of whatever integer type the query returned
            id=int(idx),
        )
        for idx in queried_indices
    ]

    # 3. Log the batch to Rubrix
    rb.log(new_records, DATASET_NAME)

    # 4. Evaluate current classifier on the test set
    print("Evaluating current classifier ...")
    accuracy = accuracy_score(
        dataset_test.y,
        active_learner.classifier.predict(dataset_test),
    )
    ACCURACIES.append(accuracy)
    print("Done!")

    print("Waiting for annotations ...")

In [None]:
# Start the listener defined above; annotate the logged records in Rubrix to
# trigger the loop
active_learning_loop.start()

In [None]:
import pandas as pd

# Plot the test accuracy recorded after each loop execution
# (the trailing semicolon suppresses the textual repr of the returned axes)
pd.Series(ACCURACIES).plot(xlabel="Iteration", ylabel="Accuracy");

In [None]:
# Display the raw accuracy values collected so far
ACCURACIES

In [None]:
# Stop the listener once you are done annotating
active_learning_loop.stop()