# Active learning example with small-text and Rubrix

Adapted from the Rubrix tutorial on active learning with small-text: https://rubrix.readthedocs.io/en/stable/tutorials/active_learning_with_small_text.html

In [1]:
import datasets

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Configs
# Hugging Face Hub id of the weakly supervised AG News dataset. It is not referenced
# again in the cells shown here; presumably `load_data` reads the same dataset
# (its log output names it) -- TODO confirm.
DATASET = "bergr7/weakly_supervised_ag_news"
# Checkpoint used for both the tokenizer and the classifier
TRANSFORMER_MODEL = "distilbert-base-uncased"
# Class names of AG News. They are read from the dataset builder's metadata so the
# full dataset does not have to be downloaded and prepared just to get the label names.
LABELS = datasets.load_dataset_builder("ag_news").info.features["label"].names
# Number of records per labeling batch (the initial batch and every queried batch)
NUM_SAMPLES = 5

Using custom data configuration default
Found cached dataset ag_news (/root/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548)
100%|██████████| 2/2 [00:00<00:00, 33.81it/s]


In [6]:
from app.model.train_model import load_data

# Load the data through the project helper. The result is a DatasetDict with
# train / validation / test splits (see the next cell's output).
# NOTE(review): the exact meaning of split=False is defined in app.model.train_model -- confirm there.
ag_news_data = load_data(split=False)

Using custom data configuration bergr7--weakly_supervised_ag_news-6f78f309523478bd
Found cached dataset csv (/root/.cache/huggingface/datasets/bergr7___csv/bergr7--weakly_supervised_ag_news-6f78f309523478bd/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a)
100%|██████████| 3/3 [00:00<00:00, 410.59it/s]


In [7]:
# Inspect the available splits, their columns and their sizes
ag_news_data

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 37340
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 24000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 7600
    })
})

In [8]:
from transformers import AutoTokenizer

# Load the tokenizer that belongs to the configured transformer checkpoint
tokenizer = AutoTokenizer.from_pretrained(TRANSFORMER_MODEL)


def tokenize(examples):
    """Tokenize the "text" field of a batch of examples.

    Every sequence is padded to the tokenizer's maximum length and truncated
    if it is longer, so all encoded rows end up with the same length.
    """
    texts = examples["text"]
    return tokenizer(texts, padding="max_length", truncation=True)


# Encode every split in batches; the raw "text" column is dropped afterwards
data_tokenized = ag_news_data.map(tokenize, batched=True, remove_columns=["text"])
# Hand rows back in "torch" format from now on
data_tokenized.set_format("torch")

Downloading: 100%|██████████| 28.0/28.0 [00:00<00:00, 18.9kB/s]
Downloading: 100%|██████████| 483/483 [00:00<00:00, 381kB/s]
Downloading: 100%|██████████| 232k/232k [00:00<00:00, 429kB/s]  
Downloading: 100%|██████████| 466k/466k [00:00<00:00, 759kB/s]  
100%|██████████| 38/38 [00:10<00:00,  3.71ba/s]
100%|██████████| 24/24 [00:07<00:00,  3.05ba/s]
100%|██████████| 8/8 [00:02<00:00,  3.12ba/s]


In [10]:
from small_text.integrations.transformers import TransformersDataset
from small_text.base import LABEL_UNLABELED


# Build one (input_ids, attention_mask, label) triple per training example.
# The pool starts out without labels, so every triple carries LABEL_UNLABELED.
data = [
    (
        example["input_ids"][None],  # [None] prepends a batch axis of size 1
        example["attention_mask"][None],
        LABEL_UNLABELED,
    )
    for example in data_tokenized["train"]
]

# Wrap the triples in the dataset type handed to the active learner
dataset = TransformersDataset(data)



In [11]:
# Sanity check: the pool should hold one entry per row of the train split
len(dataset.data)

37340

In [12]:
# Build the held-out evaluation set from the "validation" split (the variables are
# called *_test, but the data comes from validation). Unlike the training pool,
# these examples keep their integer labels.
data_test = [
    (
        example["input_ids"][None],  # [None] prepends a batch axis of size 1
        example["attention_mask"][None],
        int(example["label"]),
    )
    for example in data_tokenized["validation"]
]
dataset_test = TransformersDataset(data_test)

In [13]:
import torch
from small_text.integrations.transformers.classifiers.factories import TransformerBasedClassificationFactory
from small_text.integrations.transformers import TransformerModelArguments
from small_text.query_strategies import LeastConfidence
from small_text.active_learner import PoolBasedActiveLearner


# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook also runs on machines without a CUDA device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Define our classifier
clf_factory = TransformerBasedClassificationFactory(
    TransformerModelArguments(TRANSFORMER_MODEL),
    # Derive the number of classes from the label schema instead of hard-coding it
    num_classes=len(LABELS),
    kwargs={"device": device}
)

# Define our query strategy
query_strategy = LeastConfidence()

# Use the active learner with a pool containing all unlabeled data
active_learner = PoolBasedActiveLearner(clf_factory, query_strategy, dataset)

In [15]:
from small_text.initialization import random_initialization
import numpy as np

# Seed NumPy's global random generator so the initial draw is reproducible
np.random.seed(42)


# NUM_SAMPLES (the size of every labeling batch) is defined once in the config cell
# at the top of the notebook. It is deliberately not redefined here, so the initial
# batch size cannot drift apart from the size used by the listener condition.

# Randomly draw an initial subset from the data pool
initial_indices = random_initialization(dataset, NUM_SAMPLES)

In [16]:
# Pool positions selected for the initial labeling batch
initial_indices

array([30992,  8317, 17797, 23232, 12463])

In [18]:
import rubrix as rb
import os

# Resolve the Rubrix API endpoint: the RUBRIX_API_URL environment variable takes
# precedence, otherwise a local instance on port 6900 is assumed.
RUBRIX_URL = os.environ.get("RUBRIX_API_URL", "http://localhost:6900")

# Point the Rubrix client at that endpoint
rb.init(api_url=RUBRIX_URL)

In [19]:
# Choose a name for the dataset
DATASET_NAME = "test_with_active_learning_test"

# Define labeling schema
settings = rb.TextClassificationSettings(label_schema=LABELS)

# Create dataset with a label schema
rb.configure_dataset(name=DATASET_NAME, settings=settings)

# Read the text column once. Indexing ag_news_data["train"]["text"] inside the
# comprehension would rebuild the complete column for every single record.
train_texts = ag_news_data["train"]["text"]

# Create records from the initial batch. The record id is the position of the
# example in the pool, and batch_id 0 marks the initial batch for the listener query.
records = [
    rb.TextClassificationRecord(
        text=train_texts[idx],
        metadata={"batch_id": 0},
        id=idx,
    )
    for idx in initial_indices
]

# Log initial records to Rubrix
rb.log(records, DATASET_NAME)

100%|██████████| 5/5 [00:01<00:00,  4.31it/s]

5 records logged to http://rubrix:80/datasets/rubrix/test_with_active_learning_test





BulkResponse(dataset='test_with_active_learning_test', processed=5, failed=0)

In [20]:
# Inspect the first record of the initial batch (no prediction or annotation yet)
records[0]

TextClassificationRecord(text='Baseball and its fans recover from 1994 strike Ten years after the World Series was canceled and fans left in droves, Major League Baseball will tell you it has never been healthier.', inputs={'text': 'Baseball and its fans recover from 1994 strike Ten years after the World Series was canceled and fans left in droves, Major League Baseball will tell you it has never been healthier.'}, prediction=None, prediction_agent=None, annotation=None, annotation_agent=None, multi_label=False, explanation=None, id=30992, metadata={'batch_id': 0}, status='Default', event_timestamp=None, metrics=None, search_keywords=None)

In [22]:
from rubrix.listeners import listener
from sklearn.metrics import accuracy_score

# Define some helper variables
# Map each label name to its integer class id. The ids follow the order of LABELS,
# and enumerate() covers every label instead of silently truncating at a fixed count.
LABEL2INT = {label: index for index, label in enumerate(LABELS)}
# Accuracy on the held-out set, appended once per completed loop iteration
ACCURACIES = []

# Set up the active learning loop with the listener decorator
@listener(
    dataset=DATASET_NAME,
    query="status:Validated AND metadata.batch_id:{batch_id}",
    condition=lambda search: search.total==NUM_SAMPLES,
    execution_interval_in_seconds=3,
    batch_id=0
)
def active_learning_loop(records, ctx):
    """Run one active learning iteration for a fully annotated batch.

    The function feeds the annotations of the current batch to the active
    learner, queries the next batch, logs it to Rubrix under the next
    ``batch_id`` and records the classifier's accuracy on ``dataset_test``.

    Args:
        records: Annotated Rubrix records of the current batch. Their ``id`` is
            the position of the example in the pool and their ``annotation`` is
            a label name contained in ``LABEL2INT``.
        ctx: Listener context. ``ctx.query_params["batch_id"]`` is read to tell
            the initial batch from later ones and is incremented for the next batch.

    Side effects:
        Updates ``active_learner``, logs new records to ``DATASET_NAME``,
        appends to ``ACCURACIES`` and prints progress messages.
    """
    batch_id = ctx.query_params["batch_id"]

    # 1. Update active learner
    print(f"Updating with batch_id {batch_id} ...")
    # NOTE(review): this message is printed after the batch has already been annotated.
    print('Please go to rubrix to label the data...')
    # Key the integer labels by pool index. Ids are normalized with int() so the
    # lookup works whether the ids come back as int or as str.
    labels_by_id = {int(rec.id): LABEL2INT[rec.annotation] for rec in records}

    print(f"{NUM_SAMPLES} records have been labeled updating active learner...")
    # initial update
    if batch_id == 0:
        # Indices and labels are taken from the same mapping, so they stay aligned
        indices = np.array(list(labels_by_id.keys()))
        y = np.array(list(labels_by_id.values()))
        active_learner.initialize_data(indices, y)
    # update with the prior queried indices
    else:
        # update() pairs y[i] with the i-th index of the previous query(). The records
        # are not guaranteed to arrive in that order, so the labels are re-ordered
        # explicitly by the indices the learner queried last.
        y = np.array([labels_by_id[int(idx)] for idx in active_learner.indices_queried])
        active_learner.update(y)
    print("Done!")


    # 2. Query active learner
    print("Querying new data points ...")
    queried_indices = active_learner.query(num_samples=NUM_SAMPLES)
    next_batch_id = batch_id + 1
    # The listener query is parameterized by batch_id, so advancing it here makes
    # the listener wait for the batch that is logged below.
    ctx.query_params["batch_id"] = next_batch_id
    # Read the text column once instead of rebuilding it for every queried index
    train_texts = ag_news_data["train"]["text"]
    new_records = [
        rb.TextClassificationRecord(
            text=train_texts[idx],
            metadata={"batch_id": next_batch_id},
            id=idx,
        )
        for idx in queried_indices
    ]

    # 3. Log the batch to Rubrix
    rb.log(new_records, DATASET_NAME)

    # 4. Evaluate current classifier on the held-out set (dataset_test, which is
    #    built from the "validation" split)
    print("Evaluating current classifier ...")
    accuracy = accuracy_score(
        dataset_test.y,
        active_learner.classifier.predict(dataset_test),
    )
    ACCURACIES.append(accuracy)
    print("Done!")

    print("Waiting for annotations ...")

In [None]:
# Start the listener. While it runs, annotate each batch in the Rubrix UI;
# every fully validated batch triggers one iteration of active_learning_loop.
active_learning_loop.start()

In [24]:
# Stop the listener once enough batches have been annotated
active_learning_loop.stop()

In [None]:
import pandas as pd

# Plot the accuracy recorded after each iteration; the trailing ';' hides the Axes repr
pd.Series(ACCURACIES).plot(xlabel="Iteration", ylabel="Accuracy");

In [21]:
# Raw accuracy values, one entry per completed loop iteration
ACCURACIES

[0.24695833333333334, 0.24845833333333334]

In [22]:
# Show the classifier currently held by the active learner (only its object repr is displayed)
active_learner.classifier

<small_text.integrations.transformers.classifiers.classification.TransformerBasedClassification at 0x7f7a9c5a0880>

In [47]:
# Load the dataset back from Rubrix and inspect its first record as a `datasets` row
rb.load(DATASET_NAME).to_datasets()[0]

{'text': 'Musharraf ally elected as new Pakistan PM ISLAMABAD - Pakistan #39;s Parliament elected former Finance Minister Shaukat Aziz as Prime Minister yesterday amid an opposition boycott of the vote.',
 'inputs': {'text': 'Musharraf ally elected as new Pakistan PM ISLAMABAD - Pakistan #39;s Parliament elected former Finance Minister Shaukat Aziz as Prime Minister yesterday amid an opposition boycott of the vote.'},
 'prediction': None,
 'prediction_agent': None,
 'annotation': 'World',
 'annotation_agent': 'rubrix',
 'multi_label': False,
 'explanation': None,
 'id': '1813',
 'metadata': {'batch_id': 1},
 'status': 'Validated',
 'event_timestamp': None,
 'metrics': {'text_length': 192}}