In [None]:
%pip install small-text[transformers]==1.3.0
%pip install setfit>=0.5.0
%pip install datasets
%pip install matplotlib

In [None]:
import logging

import datasets

import matplotlib.pyplot as plt

# Plot styling: larger fonts and thicker lines than matplotlib's defaults,
# applied one rc group at a time (dict order == the order they are set).
for rc_group, rc_params in {
    'figure': {'titlesize': 22},
    'axes': {'titlesize': 22, 'labelsize': 20, 'linewidth': 1.2},
    'xtick': {'labelsize': 14},
    'ytick': {'labelsize': 14},
    'legend': {'fontsize': 16},
    'lines': {'linewidth': 2},
}.items():
    plt.rc(rc_group, **rc_params)

# Hugging Face `datasets`: only report errors.
datasets.logging.set_verbosity_error()
# Patching the verbosity getter to report NOTSET disables the progress bar in
# notebooks: https://github.com/huggingface/datasets/issues/2651
datasets.logging.get_verbosity = lambda: logging.NOTSET

# SetFit's model and trainer loggers: drop everything below ERROR.
for logger in map(logging.getLogger, ('setfit.modeling', 'setfit.trainer')):
    logger.setLevel(logging.ERROR)

In [None]:
import datasets
import logging
import numpy as np
from small_text import TextDataset
from transformers import AutoTokenizer
from small_text.integrations.transformers.classifiers.setfit import SetFitModelArguments
from small_text.integrations.transformers.classifiers.factories import SetFitClassificationFactory
from small_text import (
    PoolBasedActiveLearner,
    PredictionEntropy,
    TransformerBasedClassificationFactory,
    TransformerModelArguments,
    random_initialization_balanced,
    TransformersDataset
)

import gc
import torch
from sklearn.metrics import accuracy_score



In [None]:
# Same progress-bar workaround as the setup cell, repeated so this cell also
# behaves when run on its own.
datasets.logging.get_verbosity = lambda: logging.NOTSET

# Candidate datasets; assign a different one to `raw_dataset` to switch tasks.
dataset1 = datasets.load_dataset("SetFit/toxic_conversations_50k")
dataset2 = datasets.load_dataset("SetFit/tweet_eval_stance_abortion")
dataset3 = datasets.load_dataset("SetFit/catalonia_independence_es")

SEED = 42       # shuffle seed for the subsamples below
N_TRAIN = 100   # size of the (unlabeled) training pool
N_TEST = 20     # size of the evaluation split
N_PREVIEW = 10  # rows printed as a sanity check


def _subsample(split, n, seed=SEED):
    """Shuffle ``split`` deterministically and keep at most ``n`` rows.

    ``min`` keeps this valid for datasets that have fewer than ``n`` rows.
    """
    return split.shuffle(seed=seed).select(range(min(n, len(split))))


raw_dataset = dataset1

# Count classes on the FULL training split, before subsampling: a small random
# subsample of an imbalanced dataset can miss a class entirely, which would
# make `num_classes` / `target_labels` too small.
num_classes = np.unique(raw_dataset['train']['label']).shape[0]

raw_dataset['train'] = _subsample(raw_dataset['train'], N_TRAIN)
raw_dataset['test'] = _subsample(raw_dataset['test'], N_TEST)

print('First 10 training samples:\n')
# Select the preview rows once instead of re-materializing whole columns per row.
preview = raw_dataset['train'].select(range(min(N_PREVIEW, len(raw_dataset['train']))))
for label, text in zip(preview['label'], preview['text']):
    print(label, ' ', text)

In [None]:
# The same checkpoint name is reused for the classifier (TransformerModelArguments
# further down), so the tokenizer's vocabulary matches the model's.
transformer_model_name = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(transformer_model_name)

In [None]:
target_labels = np.arange(num_classes)

# Token budget per text, passed to TransformersDataset.from_arrays as max_length.
MAX_LENGTH = 60


def _to_transformers_dataset(split):
    """Tokenize one split of ``raw_dataset`` into a small-text ``TransformersDataset``.

    Train and test must share the tokenizer, ``max_length`` and
    ``target_labels``; building both through this one helper keeps them in sync.
    """
    return TransformersDataset.from_arrays(split['text'],
                                           split['label'],
                                           tokenizer,
                                           max_length=MAX_LENGTH,
                                           target_labels=target_labels)


train = _to_transformers_dataset(raw_dataset['train'])
test = _to_transformers_dataset(raw_dataset['test'])

In [None]:
# simulates an initial labeling to warm-start the active learning process
def initialize_active_learner(active_learner, y_train, n_samples=20):
    """Warm-start ``active_learner`` with an initial labeled set.

    Simulates a first labeling round: ``n_samples`` pool indices are drawn by
    small-text's ``random_initialization_balanced`` and their labels are taken
    directly from ``y_train`` (the ground truth acts as the annotator).

    Args:
        active_learner: learner whose ``initialize_data`` receives the drawn
            indices and their labels.
        y_train: label array for the whole training pool, indexable with an
            index array.
        n_samples: size of the initial labeled set (default 20, the value that
            was previously hard-coded).

    Returns:
        The pool indices that were labeled.
    """
    indices_initial = random_initialization_balanced(y_train, n_samples=n_samples)
    active_learner.initialize_data(indices_initial, y_train[indices_initial])

    return indices_initial


transformer_model = TransformerModelArguments(transformer_model_name)

# Fall back to CPU so the notebook still runs on machines without a GPU
# (training is much slower there). Hard-coding 'cuda' fails on CPU-only hosts.
clf_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# `kwargs` are forwarded by the factory to each classifier it creates.
clf_factory = TransformerBasedClassificationFactory(transformer_model,
                                                    num_classes,
                                                    kwargs={'device': clf_device,
                                                            'mini_batch_size': 32,
                                                            'class_weight': 'balanced'})
query_strategy = PredictionEntropy()

active_learner = PoolBasedActiveLearner(clf_factory, query_strategy, train)
indices_labeled = initialize_active_learner(active_learner, train.y)


In [None]:
from scipy.stats import entropy
from small_text.utils.data import list_length
import pandas as pd
num_queries = 5


def get_entropy_labeled(row, model, indices_labeled):
    """Return the labeled pool index whose prediction has the highest entropy.

    Args:
        row: dataset to score (at the call sites: ``train[indices_labeled]``);
            its i-th row must correspond to ``indices_labeled[i]``.
        model: classifier exposing ``predict_proba``.
        indices_labeled: pool indices aligned with the rows of ``row``.

    Returns:
        A length-1 ``int64`` array holding the pool index of the row with the
        largest predictive entropy (ties resolve to the first occurrence, NaN
        entropies are ignored), or an empty ``int64`` array when nothing can be
        scored.

    Raises:
        ValueError: if ``indices_labeled`` and the predictions differ in length.
    """
    proba = np.asarray(model.predict_proba(row))
    if proba.shape[0] == 0:
        return np.empty(0, dtype=np.int64)

    indices_labeled = np.asarray(indices_labeled)
    if len(indices_labeled) != proba.shape[0]:
        raise ValueError(
            f'got {proba.shape[0]} predictions for {len(indices_labeled)} indices')

    # scipy normalizes each row and computes its Shannon entropy in one call.
    entropies = entropy(proba, axis=1)
    if np.isnan(entropies).all():
        return np.empty(0, dtype=np.int64)

    return np.array([indices_labeled[np.nanargmax(entropies)]], dtype=np.int64)
  

def evaluate(active_learner, train, test, indices_labeled):
    """Print train/test accuracy and find the most uncertain labeled sample.

    Args:
        active_learner: learner whose current ``classifier`` is evaluated.
        train: the labeled training subset (``train[indices_labeled]`` at the
            call sites), row-aligned with ``indices_labeled``.
        test: held-out dataset.
        indices_labeled: pool indices of the rows in ``train``.

    Returns:
        ``(test_acc, most_uncertain)`` where ``most_uncertain`` is the result of
        ``get_entropy_labeled`` for the labeled rows.
    """
    clf = active_learner.classifier

    y_pred_train = clf.predict(train)
    most_uncertain = get_entropy_labeled(train, clf, indices_labeled)
    y_pred_test = clf.predict(test)

    # sklearn's signature is accuracy_score(y_true, y_pred).
    train_acc = accuracy_score(train.y, y_pred_train)
    test_acc = accuracy_score(test.y, y_pred_test)

    print('Train accuracy: {:.2f}'.format(train_acc))
    print('Test accuracy: {:.2f}'.format(test_acc))

    return test_acc, most_uncertain


# Test accuracy after warm start, then after every query round.
# (`test_acc` instead of `eval`, which shadowed the builtin.)
results = []
test_acc, mix_query = evaluate(active_learner, train[indices_labeled], test, indices_labeled)
results.append(test_acc)

# Number of new samples labeled per active-learning iteration.
query_size = 5

for i in range(num_queries):
    # Let the query strategy pick the next `query_size` pool samples to label.
    indices_queried = active_learner.query(num_samples=query_size)

    # Simulate user interaction here. Replace this for real-world usage.
    y = train.y[indices_queried]

    # Return the labels for the current query to the active learner.
    active_learner.update(y)

    # Re-submit the label of the labeled sample the model was most uncertain about.
    # NOTE(review): the label is read from train.y, the same source used for every
    # label already given to the learner, so the label value itself is unchanged;
    # the visible effect is the extra retrain=True each iteration — confirm intent.
    if len(mix_query) > 0:
        active_learner.update_label_at(mix_query[0], train.y[mix_query[0]], retrain=True)

    indices_labeled = np.concatenate([indices_queried, indices_labeled])

    print('---------------')
    print(f'Iteration #{i} ({len(indices_labeled)} samples)')
    test_acc, mix_query = evaluate(active_learner, train[indices_labeled], test, indices_labeled)
    results.append(test_acc)

Train accuracy: 0.98
Test accuracy: 0.90
---------------
Iteration #0 (50 samples)
Train accuracy: 0.98
Test accuracy: 0.95
---------------
Iteration #1 (55 samples)
Train accuracy: 0.93
Test accuracy: 0.95
---------------
Iteration #2 (60 samples)
Train accuracy: 0.90
Test accuracy: 0.95
---------------
Iteration #3 (65 samples)
Train accuracy: 0.86
Test accuracy: 0.85
---------------
Iteration #4 (70 samples)
Train accuracy: 0.97
Test accuracy: 0.95
