# How to build a specialized text classifier without days of human labeling

In this notebook, we will learn how to build a text classifier with AI and human feedback, saving time and resources.

## Getting started

### Deploy the Argilla server

If you have already deployed Argilla, you can skip this step. Otherwise, you can quickly deploy Argilla following [this guide](https://docs.argilla.io/latest/getting_started/quickstart/).

### Install dependencies

In [None]:
!pip install argilla
!pip install "distilabel[hf-inference-endpoints]"
!pip install setfit

## Configure the Argilla dataset

In [None]:
import os
from time import sleep

import argilla as rg
from datasets import load_dataset

# Connect to the Argilla server.
# The URL and API key can be overridden with the ARGILLA_API_URL / ARGILLA_API_KEY
# environment variables; the fallbacks are the values this notebook shipped with.
client = rg.Argilla(
    api_url=os.environ.get("ARGILLA_API_URL", "http://localhost:6900"),
    api_key=os.environ.get("ARGILLA_API_KEY", "argilla.apikey"),
)

In [None]:
# Dataset configuration: one text field to read and one single-label question to answer.
labels = ["positive", "neutral", "negative"]

# The field annotators will see for every record.
review_field = rg.TextField(
    name="text",
    title="Text",
    description="Provide a concise response to the prompt",
)

# The question annotators answer; its options are the `labels` defined above.
sentiment_question = rg.LabelQuestion(
    name="label",
    title="Emotion",
    description="Provide a single label for the emotion of the text",
    labels=labels,
)

settings = rg.Settings(
    guidelines="We are dealing with input data about detailed customer reviews of 3 corresponding labels: positive, negative, and neutral.",
    fields=[review_field],
    questions=[sentiment_question],
    # NOTE(review): presumably maps a source column named "labels" onto the
    # "label" question when raw rows are logged - confirm against the Argilla docs.
    mapping={"labels": "label"},
)

# Describe the dataset locally, then create it on the server.
dataset_name = "pc-component"
unsaved_dataset = rg.Dataset(name=dataset_name, settings=settings)
# `dataset` is bound to whatever `create()` returns, exactly as in a chained call.
dataset = unsaved_dataset.create()

# Load the source rows and turn each one into an Argilla record.
# Every record carries a suggestion here, although suggestions are not required.
hf_dataset = load_dataset("argilla/pc-components-reviews", split="train")

records = []
for sample in hf_dataset:
    # Any value outside the configured label set falls back to "neutral".
    suggested_label = sample["labels"]
    if suggested_label not in labels:
        suggested_label = "neutral"

    records.append(
        rg.Record(
            fields={"text": sample["text"]},
            suggestions=[rg.Suggestion("label", suggested_label)],
        )
    )

dataset.records.log(records)

## Active auto-labeling

In [None]:
import random
from collections import Counter, defaultdict
from distilabel.steps.tasks import ArgillaLabeller
from distilabel.llms.huggingface import InferenceEndpointsLLM


# Shared state for the few-shot examples:
#   example_records_dict - label -> list of example records collected for that label
#   counter              - how many examples have been collected per label
#   max_samples_per_label - upper bound on examples collected per label
example_records_dict = defaultdict(list)
counter = Counter()
max_samples_per_label = 16


def get_example_records(num_samples_per_label):
    """Return a shuffled list of example records drawn from ``example_records_dict``.

    For every label currently present in the module-level ``example_records_dict``,
    the first ``num_samples_per_label`` records (in insertion order) are taken.
    The combined list is shuffled in place with ``random.shuffle`` before being
    returned, so the order depends on the state of the ``random`` module.

    Args:
        num_samples_per_label: maximum number of records to take per label.

    Returns:
        A new list with the selected records in random order.
    """
    picked = [
        example
        for per_label in example_records_dict.values()
        for example in per_label[:num_samples_per_label]
    ]
    random.shuffle(picked)
    return picked

# Simulate record annotation for few-shot: treat each record's existing suggestion
# as if it were a human response, keeping at most `max_samples_per_label` per label.

# A response is attributed to a user *id*, not to the user object itself,
# so resolve the id of the logged-in user once, outside the loop.
current_user_id = client.me.id

for record in dataset.records(with_responses=True, with_suggestions=True):
    label_value = record.suggestions["label"].value

    # This label already has enough examples.
    if counter[label_value] >= max_samples_per_label:
        continue
    counter[label_value] += 1

    # The response is only attached to the local record object; nothing is logged
    # back to the server in this cell.
    record.responses.add(
        rg.Response(question_name="label", value=label_value, user_id=current_user_id)
    )
    example_records_dict[label_value].append(record.to_dict())

    # Stop fetching records once every configured label has reached its quota.
    if all(counter[label] >= max_samples_per_label for label in labels):
        break

In [None]:
# Build the LLM backend first, then hand it to the ArgillaLabeller task.
llm_model_id = "meta-llama/Llama-3.1-8B-Instruct"
llm = InferenceEndpointsLLM(
    model_id=llm_model_id,
    tokenizer_id=llm_model_id,
)

# Start with one example record per label as few-shot context.
few_shot_examples = get_example_records(1)

labeller = ArgillaLabeller(llm=llm, example_records=few_shot_examples)
labeller.load()

In [None]:
# Create the loop to start annotating and getting improved suggestions.
# The loop polls Argilla for a pending record, asks the labeller for a suggestion,
# and logs the record back with that suggestion attached. It runs until the cell
# is interrupted (Kernel -> Interrupt), which is handled cleanly below.

# Seconds to wait before polling again when there is nothing pending.
POLL_INTERVAL_SECONDS = 5

try:
    while True:
        pending_records = list(
            dataset.records(
                query=rg.Query(filter=rg.Filter(("status", "==", "pending"))),
                limit=1,
            )
        )
        if not pending_records:
            sleep(POLL_INTERVAL_SECONDS)
            continue

        # `process` is a generator; the first item it yields holds one result per input.
        results = next(
            labeller.process(
                [
                    {
                        "record": record,
                        "fields": dataset.fields,
                        "question": dataset.questions[0],
                        "guidelines": dataset.guidelines,
                    }
                    for record in pending_records
                ]
            )
        )
        for record, suggestion in zip(pending_records, results):
            record.suggestions.add(rg.Suggestion(**suggestion["suggestion"]))

        dataset.records.log(pending_records)
except KeyboardInterrupt:
    # Interrupting the kernel is the intended way to leave the loop.
    print("Auto-labeling loop stopped.")


## Train your small classifier

In [None]:
from datasets import Dataset
from setfit import sample_dataset

# Helper function to split the dataset
def sample_and_split(dataset, label_column, num_samples):
    """Split ``dataset`` into a sampled training set and the remaining evaluation set.

    The training set is whatever ``setfit.sample_dataset`` returns for the given
    ``label_column`` and ``num_samples``. The evaluation set is every row of
    ``dataset`` whose ``"id"`` does not appear in the training set.

    Args:
        dataset: dataset with an ``"id"`` column identifying each row.
        label_column: name of the column holding the labels.
        num_samples: forwarded to ``sample_dataset`` as ``num_samples``.

    Returns:
        A ``(train_dataset, eval_dataset)`` tuple.
    """
    train_dataset = sample_dataset(
        dataset, label_column=label_column, num_samples=num_samples
    )
    # Build the id set once instead of once per filtered row.
    train_ids = set(train_dataset["id"])
    eval_dataset = dataset.filter(lambda x: x["id"] not in train_ids)
    return train_dataset, eval_dataset

# Retrieve the data from Argilla
exported_records = client.datasets(dataset_name).records.to_datasets()

# Keep only records that actually have a response for the "label" question and
# use the first response as the label. The truthiness check skips both a missing
# value (None) and an empty response list, which would otherwise fail on `[0]`.
annotated_dataset = Dataset.from_list([
    {"text": record["text"], "label": record["label.responses"][0], "id": i}
    for i, record in enumerate(exported_records)
    if record.get("label.responses")
])

# Value passed to `sample_and_split` as `num_samples`.
NUM_TRAIN_SAMPLES = 8
train_dataset, eval_dataset = sample_and_split(
    annotated_dataset, "label", NUM_TRAIN_SAMPLES
)


In [None]:
from setfit import SetFitModel, Trainer

# Function to train our SetFit model
def train_model(model_name, train_dataset, eval_dataset):
    """Train a SetFit model and evaluate it.

    Args:
        model_name: identifier passed to ``SetFitModel.from_pretrained``.
        train_dataset: dataset handed to the ``Trainer`` for training.
        eval_dataset: dataset passed to ``Trainer.evaluate`` after training.

    Returns:
        A ``(model, results)`` tuple: the trained model and the value returned
        by ``Trainer.evaluate``.
    """
    classifier = SetFitModel.from_pretrained(model_name)
    setfit_trainer = Trainer(model=classifier, train_dataset=train_dataset)
    setfit_trainer.train()

    evaluation = setfit_trainer.evaluate(eval_dataset)
    return classifier, evaluation

# Train the classifier on the sampled split and print what the evaluation returned.
base_model_name = "TaylorAI/bge-micro-v2"
model, results = train_model(base_model_name, train_dataset, eval_dataset)
print(results)
