# Zero-Shot Classification and Few-Shot Fine-Tuning for News Articles

## Introduction

This tutorial demonstrates a modern text classification workflow: zero-shot labeling with a Large Language Model (LLM), human-in-the-loop validation and data management in Argilla, and efficient few-shot fine-tuning with SetFit. We will classify news articles into four categories: World, Sports, Business, and Sci/Tech.

**Workflow Overview:**

1.  **Zero-Shot Classification with LLM**: We'll use an LLM via the Together API to automatically label a sample of news articles without any prior training examples.
2.  **Human Review and Correction with Argilla**: We'll use Argilla to review and correct the LLM-generated labels, ensuring data quality and creating a gold-standard dataset.
3.  **Few-Shot Fine-tuning with SetFit**: We'll fine-tune a Sentence Transformers model with SetFit on the human-validated data from Argilla. SetFit is designed for efficient few-shot learning.
4.  **Evaluation and Comparison**: We'll evaluate the SetFit model on the AG News test set and compare it with a traditional TF-IDF + Logistic Regression model trained on the full AG News training set.

Let's begin by installing the necessary libraries.

In [None]:
#| label: install-libraries
#| echo: false
!pip install openai datasets sentence-transformers argilla setfit -q

## 1. Setting up the Environment and API Keys

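First, we import the required libraries, read the Together API key, and initialize the Argilla client. (The OpenAI-compatible client for the Together API is created in the next section.) You need an Argilla instance, for example one hosted on a Hugging Face Space, and a Together API key. Store the key as a Colab secret named `TOGETHER_API_KEY`, and replace the Argilla URL and API key placeholders below with your own values.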

In [None]:
#| label: setup-environment
import json
import pandas as pd
from openai import OpenAI
from datasets import load_dataset, Dataset
from sentence_transformers import SentenceTransformer
import argilla as rg
from setfit import SetFitModel, Trainer, TrainingArguments
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.metrics import classification_report

# API key for Together API (use your own API key)
from google.colab import userdata
TOGETHER_API_KEY = userdata.get('TOGETHER_API_KEY')

# Initialize the Argilla client (Argilla 2.x SDK)
client = rg.Argilla(
    api_url="https://<your_space>.hf.space", # Replace with your Argilla API URL
    api_key="xxxxx-xxx", # Replace with your Argilla API Key
)
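The setup cell above assumes Google Colab, where `google.colab.userdata` is available. If you run the notebook elsewhere, that import fails. A minimal sketch of a fallback that reads the key from an environment variable instead is shown below; `get_secret` is a hypothetical helper, not part of any library used here.

```python
import os
from typing import Optional


def get_secret(name: str) -> Optional[str]:
    """Hypothetical helper: read a secret from Colab userdata, else from an environment variable."""
    try:
        from google.colab import userdata  # Only importable inside Google Colab
    except ImportError:
        return os.environ.get(name)
    try:
        return userdata.get(name)
    except Exception:
        # userdata.get raises if the secret is missing or the notebook has no access to it
        return os.environ.get(name)


# Outside Colab, set the key first, e.g. `export TOGETHER_API_KEY=...` in your shell
TOGETHER_API_KEY = get_secret("TOGETHER_API_KEY")
```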

## 2. Zero-Shot Classification with LLM

In this section, we define the categories for news classification and set up an LLM-based classifier using the Together API. A system prompt describes the classification task and asks the LLM for JSON-formatted output, which is easy to parse.

In [None]:
#| label: define-categories-prompt
# Define news categories for classification
categories = ["World", "Sports", "Business", "Sci/Tech"]

# Define system prompt for the LLM - instructs the LLM for zero-shot classification
system_prompt = """
You are a sophisticated classification engine tasked with categorizing news articles.
Your primary function is to evaluate the core message of each article and assign it to one of the following categories:
"World" for global news covering politics and similar topics,
"Sports" for news related to sports,
"Business" for articles on business, economics, or finance,
and "Sci/Tech" for content focused on technology and science.

Upon analyzing a text input, you will provide an explanation for the category chosen.
Your output will adhere strictly to the JSON format, specifically:
{"prediction":"your selected prediction", "explanation":"your explanation"}.
It is imperative that your output is VALID JSON and contains no other elements.
"""

# Create an OpenAI-compatible client using Together API
llm_client = OpenAI(base_url="https://api.together.xyz/v1", api_key=TOGETHER_API_KEY)

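# Function to classify text using the LLM
def classify(text):
    completion = llm_client.chat.completions.create(
        model="meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",  # Open-weight Llama 3.3 70B served by Together; confirm this model ID is still offered
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"Classify the following text: {text}"}
        ],
        temperature=0.2,  # Lower temperature for more consistent outputs
    )
    json_response = completion.choices[0].message.content.strip()
    try:
        prediction = json.loads(json_response)  # Parse JSON response
    except json.JSONDecodeError:
        # Fallback for incorrectly formatted JSON
        return {"prediction": None, "explanation": f"Error parsing JSON: {json_response}"}
    if not isinstance(prediction, dict):
        # Valid JSON, but not the requested {"prediction": ..., "explanation": ...} object
        return {"prediction": None, "explanation": f"Unexpected JSON: {json_response}"}
    # Normalize keys so callers can always rely on both being present
    return {"prediction": prediction.get("prediction"), "explanation": prediction.get("explanation", "")}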

# Example news article for testing
text_example = """
Stocks Rally on Lower Oil Prices. Stocks rallied in quiet trading Wednesday
as lower oil prices brought out buyers, countering a pair of government reports
that gave a mixed picture of the economy.
"""

# Test the classification function
result = classify(text_example)
print(f"Example classification:\nText: {text_example}\nResult: {result}")
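Chat models sometimes wrap their JSON in Markdown code fences or add a sentence around it, which makes `json.loads` fail and sends the article to the error fallback in `classify`. A more tolerant parser can recover many of these replies. The sketch below is a hypothetical helper, `parse_llm_json`, assuming replies follow the `{"prediction": ..., "explanation": ...}` shape requested in the system prompt; it also sets the prediction to `None` when the label is not one of the four categories.

```python
import json
import re

CATEGORIES = ["World", "Sports", "Business", "Sci/Tech"]
FENCE = "`" * 3  # A Markdown code fence, built here to keep the literal out of the source


def parse_llm_json(raw: str, categories: list = CATEGORIES) -> dict:
    """Hypothetical helper: parse an LLM reply, tolerating code fences and surrounding prose."""
    text = raw.strip()
    # Remove a leading fence (optionally tagged "json") and a trailing fence, if present
    text = re.sub(rf"^{FENCE}(?:json)?\s*|\s*{FENCE}$", "", text)
    # Keep only the outermost {...} span in case the model added prose around it
    start, end = text.find("{"), text.rfind("}")
    if start == -1 or end < start:
        return {"prediction": None, "explanation": f"No JSON object found: {raw}"}
    try:
        parsed = json.loads(text[start : end + 1])
    except json.JSONDecodeError:
        return {"prediction": None, "explanation": f"Error parsing JSON: {raw}"}
    prediction = parsed.get("prediction")
    if prediction not in categories:
        prediction = None  # Unknown labels are treated like parse failures
    return {"prediction": prediction, "explanation": parsed.get("explanation", "")}
```

Inside `classify`, you could return `parse_llm_json(json_response, categories)` instead of calling `json.loads` directly.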

Let's apply this classification function to a sample of news articles. We load the unlabelled AG News articles from a Parquet file (`ag_news_unlabelled.pq`) and draw a random sample of 20 for demonstration.

In [None]:
#| label: classify-sample-data
# Load a sample of news articles for demonstration
data_train = pd.read_parquet('https://github.com/SDS-AAU/SDS-master/raw/master/M2/data/ag_news_unlabelled.pq')
dataset_news = Dataset.from_pandas(data_train.sample(20, random_state=42).reset_index(drop=True))  # Sample 20 articles (fixed seed for reproducibility)

# Apply zero-shot classification to our sample
print("Classifying news articles with LLM...")
news_with_preds = []
for example in dataset_news:
    result = classify(example["text"])
    news_with_preds.append({
        "text": example["text"],
        "label": result.get("prediction"),  # None if the LLM reply could not be parsed
        "explanation": result.get("explanation", ""),
    })

Here are a few examples of the LLM's predictions.

In [None]:
#| label: display-sample-predictions
# Display sample predictions
print("\nSample of LLM predictions:")
for i, item in enumerate(news_with_preds[:3]):
    print(f"\nArticle {i+1}:")
    print(f"Text: {item['text'][:100]}...")
    print(f"Prediction: {item['label']}")
    print(f"Explanation: {item['explanation']}")
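Before sending the predictions for review, it can help to check how the LLM's labels are distributed and how many replies are unusable (failed parses or labels outside the four categories). A minimal sketch, using a hypothetical helper `summarize_predictions` and a small toy input in the same shape as `news_with_preds`:

```python
from collections import Counter


def summarize_predictions(preds, categories):
    """Hypothetical helper: count LLM labels per category and the number of unusable labels."""
    counts = Counter(item["label"] for item in preds)
    per_category = {category: counts.get(category, 0) for category in categories}
    # None (failed parse) and any label outside `categories` both count as unusable
    unusable = sum(n for label, n in counts.items() if label not in categories)
    return per_category, unusable


# Toy input in the same shape as `news_with_preds`
toy_preds = [{"label": "World"}, {"label": "Sports"}, {"label": "World"}, {"label": None}, {"label": "Tech"}]
per_category, unusable = summarize_predictions(toy_preds, ["World", "Sports", "Business", "Sci/Tech"])
print(per_category, unusable)  # {'World': 2, 'Sports': 1, 'Business': 0, 'Sci/Tech': 0} 2
```

In the notebook, `summarize_predictions(news_with_preds, categories)` gives the same overview for the real sample.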

## 3. Human Review and Data Logging to Argilla with Embeddings

Now, we will use Argilla to create a dataset for human review of the LLM predictions. We also generate sentence embeddings for each news article to enable semantic search and similarity features in Argilla. We use `SentenceTransformer` to create these embeddings.

In [None]:
#| label: create-argilla-dataset
# Initialize sentence transformer model for embeddings
print("\nGenerating vector embeddings...")
model = SentenceTransformer("TaylorAI/bge-micro-v2")  # 384-dimensional embeddings

# Create a dataset in Argilla whose settings include a vector field for the embeddings
print("\nCreating Argilla dataset...")

# Configure dataset settings: displayed fields, the labeling question, and the vector field
settings = rg.Settings(
    guidelines="Classify news articles into one of the categories: World, Sports, Business, or Sci/Tech.",
    fields=[
        rg.TextField(
            name="text",
        ),
        rg.TextField(
            name="explanation",
            title="LLM Explanation",
        ),
    ],
    questions=[
        rg.LabelQuestion(
            name="label",
            title="Category",
            labels=categories
        ),
    ],
    # Vector settings for embeddings
    vectors=[
        rg.VectorField(
            name="sentence_embedding",
            title="Sentence Embedding",
            dimensions=384  # Using bge-micro embeddings which are 384-dimensional
        )
    ]
)

# Create the dataset, or reuse it if it already exists
dataset_name = "news"
dataset = client.datasets(name=dataset_name)  # Returns None (with a warning) if no dataset has this name
if dataset is not None:
    print(f"Dataset '{dataset_name}' already exists")
else:
    # Create new dataset
    dataset = rg.Dataset(
        name=dataset_name,
        settings=settings,
        client=client,
    )
    dataset.create()
    print(f"Created new dataset '{dataset_name}'")

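We prepare one record per article, including its sentence embedding, and log the records to the Argilla dataset. Argilla maps each key of a record to the field, question, or vector with the same name, so the LLM's label becomes a pre-filled suggestion for the `label` question that annotators can accept or correct. The embeddings are stored in the `sentence_embedding` vector field, which Argilla can use for similarity search and exploration in its UI.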

In [None]:
#| label: log-records-argilla
# Prepare flat record dicts: each key matches a field, question, or vector name in the settings
print("Preparing records with embeddings before logging...")
records_with_vectors = []

for item in news_with_preds:
    # Skip failed parses and labels outside the allowed categories, which would not match the label question
    if item["label"] in categories:
        # Generate embedding for this text
        embedding = model.encode(item["text"]).tolist()

        # Flat record: "label" becomes a suggestion, "sentence_embedding" a vector
        records_with_vectors.append({
            "text": item["text"],
            "explanation": item["explanation"],
            "label": item["label"],
            "sentence_embedding": embedding  # Vector included directly
        })

# Log records WITH vectors to Argilla in a single operation
if records_with_vectors:
    dataset.records.log(records_with_vectors)
    print(f"Logged {len(records_with_vectors)} records with embeddings to Argilla")
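Argilla validates logged records against the dataset settings, so a record with a suggestion that is not an allowed label, or a vector of the wrong length, can make the `log` call fail. A quick pre-flight check can catch such records before logging. `validate_record` below is a hypothetical helper, assuming flat record dicts like the ones built above.

```python
def validate_record(record, field_names, question_labels, vector_name, dimensions):
    """Hypothetical helper: return a list of problems with a flat record dict (empty means OK)."""
    problems = []
    # Every text field must be present and non-empty
    for name in field_names:
        if not isinstance(record.get(name), str) or not record[name]:
            problems.append(f"field '{name}' is missing or empty")
    # Suggestions, if present, must be one of the question's allowed labels
    for question, labels in question_labels.items():
        value = record.get(question)
        if value is not None and value not in labels:
            problems.append(f"suggestion '{value}' for '{question}' is not an allowed label")
    # The vector must exist and match the dimensions declared in the settings
    vector = record.get(vector_name)
    if vector is None or len(vector) != dimensions:
        problems.append(f"vector '{vector_name}' must have {dimensions} values")
    return problems
```

For example, `[r for r in records_with_vectors if validate_record(r, ["text", "explanation"], {"label": categories}, "sentence_embedding", 384)]` lists the records that would need attention.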

After running the above cells, open the Argilla UI to review the LLM predictions. Correct any misclassifications and submit a response for each record you review, since only records with human responses are used for training. Once annotation is complete, we can load the hand-labeled data for SetFit fine-tuning.
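The `sentence_embedding` vectors logged above let Argilla find records similar to a given one. The underlying idea is nearest-neighbor search under a similarity measure such as cosine similarity. The sketch below illustrates that idea in plain Python on toy 2-dimensional vectors; `top_k_similar` is a hypothetical helper and not a description of how Argilla implements its search.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def top_k_similar(query, vectors, k=3):
    """Hypothetical helper: (index, similarity) pairs for the k vectors closest to `query`."""
    scored = [(i, cosine(query, v)) for i, v in enumerate(vectors)]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Toy vectors standing in for 384-dimensional sentence embeddings
neighbors = top_k_similar([1, 0], [[1, 0], [0, 1], [1, 1]], k=2)
```

With real data, the query could be `model.encode(query_text).tolist()` and `vectors` the `sentence_embedding` values from `records_with_vectors`.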

## 4. Few-Shot Fine-tuning with SetFit

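Now we will fine-tune a SetFit model on the data annotated in Argilla. SetFit first fine-tunes a Sentence Transformers body contrastively on pairs of labeled examples and then trains a classification head on the resulting embeddings, which lets it learn from only a handful of examples per class. That makes it a good fit for our small, human-validated dataset.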
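SetFit's first stage learns from pairs rather than single examples: pairs with the same label are treated as similar (target 1.0) and pairs with different labels as dissimilar (target 0.0). The sketch below is a simplified illustration of building such pairs; `make_contrastive_pairs` is a hypothetical helper and not SetFit's actual sampling strategy, which the library handles internally during `trainer.train()`.

```python
import itertools
import random


def make_contrastive_pairs(texts, labels, num_pairs, seed=0):
    """Hypothetical helper: sample (text_a, text_b, target) pairs, 1.0 for same label, 0.0 otherwise."""
    rng = random.Random(seed)
    # All unordered index pairs (i, j) with i < j
    all_pairs = list(itertools.combinations(range(len(texts)), 2))
    chosen = rng.sample(all_pairs, min(num_pairs, len(all_pairs)))
    return [(texts[i], texts[j], 1.0 if labels[i] == labels[j] else 0.0) for i, j in chosen]


texts = ["Stocks rally", "Oil prices fall", "Team wins final", "Striker signs deal"]
labels = ["Business", "Business", "Sports", "Sports"]
pairs = make_contrastive_pairs(texts, labels, num_pairs=10)
```

Because the number of distinct pairs grows quadratically, 20 annotated articles already yield 20 * 19 / 2 = 190 distinct pairs, which is part of why this approach works with few examples.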

First, we retrieve the annotated dataset from Argilla and load the AG News test dataset for evaluation.

In [None]:
#| label: load-annotated-data-test-data
# Retrieve the dataset from Argilla
retrieved_dataset = client.datasets(name="news", workspace="argilla")

# Convert the hand-labeled records to a Hugging Face Dataset
train_ds = retrieved_dataset.records.to_datasets()

# Load the AG News test dataset
test_ds = load_dataset("ag_news", split="test")

We convert the Argilla dataset to a Pandas DataFrame to process the labels and remove records without human annotations.

In [None]:
#| label: prepare-training-data
# Convert to pandas for easier manipulation
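train_ds_df = train_ds.to_pandas()
# Remove records without a human response (missing or empty response lists)
train_ds_df = train_ds_df.dropna(subset=["label.responses"])
train_ds_df = train_ds_df[train_ds_df["label.responses"].map(len) > 0]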

Let's inspect the label responses.

In [None]:
#| label: inspect-label-responses
# Show the first annotator's response for each record
train_ds_df['label.responses'].map(lambda t: t[0])
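Taking `t[0]` uses the first annotator's response, which is the only one when a single person annotates. If several annotators label the same records, a majority vote is a common alternative. The sketch below is a hypothetical helper, `resolve_label`, that breaks ties in favor of the earliest response.

```python
from collections import Counter


def resolve_label(responses):
    """Hypothetical helper: majority vote over annotator responses; ties go to the earliest response."""
    if len(responses) == 0:  # len() also works for the array-like values pandas may hold
        return None
    counts = Counter(responses)
    top = max(counts.values())
    # Walk responses in order so that ties resolve to the label submitted first
    for label in responses:
        if counts[label] == top:
            return label
```

It could replace the `lambda t: t[0]` above, e.g. `train_ds_df['label.responses'].map(resolve_label)`.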

We also check the features of the test dataset to understand its structure.

In [None]:
#| label: check-test-dataset-features
# Check test dataset features
test_ds.features

For faster evaluation during training, we select a random subset of 50 examples from the test set.

In [None]:
#| label: subset-test-data
# Select a subset of the test dataset for SetFit evaluation
test_df_setfit = test_ds.shuffle(seed=42).select(range(50))
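A random subset of 50 articles is not guaranteed to contain the four classes in equal proportions. If you prefer a balanced evaluation subset, one option is to sample a fixed number of indices per class. `stratified_indices` is a hypothetical helper; its output could be passed to `Dataset.select`, e.g. `test_ds.select(stratified_indices(test_ds['label'], per_class=12))`.

```python
import random
from collections import defaultdict


def stratified_indices(labels, per_class, seed=42):
    """Hypothetical helper: pick up to `per_class` random indices for each label value."""
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    rng = random.Random(seed)  # Local RNG so the global random state is untouched
    chosen = []
    for label in sorted(by_class):
        members = by_class[label]
        chosen.extend(rng.sample(members, min(per_class, len(members))))
    return sorted(chosen)
```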

To prepare the labels for training, we create a mapping from label names to numerical indices, consistent with the AG News dataset.

In [None]:
#| label: create-label-mapping
# Create a mapping between label names and indices
mapping = {name: idx for idx, name in enumerate(test_ds.features['label'].names)}  # e.g. 'World' -> 0

Now, we map the human-annotated labels from Argilla to numerical indices and convert the processed DataFrame back to a Dataset format, ready for SetFit training.

In [None]:
#| label: map-labels-dataset-format
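# Map the first annotator's response to the AG News label index
train_ds_df['label'] = train_ds_df['label.responses'].map(lambda t: mapping[t[0]])

# Convert back to Dataset format, keeping only the columns SetFit needs
train_ds_prepared = Dataset.from_pandas(train_ds_df[['text', 'label']], preserve_index=False)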

We load a pre-trained Sentence Transformers checkpoint as the body of a new SetFit model and pass the label names in AG News index order, so that predictions are returned as label names.

In [None]:
#| label: load-setfit-model
# Load a SetFit model from Hugging Face Hub
model = SetFitModel.from_pretrained(
    "sentence-transformers/paraphrase-mpnet-base-v2",
    labels=['World', 'Sports', 'Business', 'Sci/Tech'] # Label names in AG News index order (0-3), so predict() returns names
)

We configure the training arguments, initialize the SetFit Trainer, train the model on our prepared training dataset, and evaluate it on the test subset.

In [None]:
#| label: train-setfit-model
# SetFit training configuration
args = TrainingArguments(
    batch_size=16,
    num_epochs=3,
    eval_strategy="epoch",  # setfit >= 1.1; older versions call this evaluation_strategy
    save_strategy="epoch",
    load_best_model_at_end=True,
    report_to="none", # Avoid logging to experiment trackers for simplicity in this tutorial
)

# Initialize and train the SetFit model
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds_prepared,
    eval_dataset=test_df_setfit,
    metric="accuracy", # Evaluate using accuracy
)

# Train the model
trainer.train()
metrics = trainer.evaluate()
print(metrics) # Print evaluation metrics

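After training, we predict labels on the full AG News test set (7,600 articles), not just the 50-example subset, so that the results are directly comparable with the Logistic Regression baseline below.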

In [None]:
#| label: predict-setfit
# Predict on the full test set with the trained model (returns label names)
predicted_labels = trainer.model.predict(test_ds['text'])

For evaluation, we create a reverse mapping to convert numerical labels back to their names.

In [None]:
#| label: create-reverse-mapping
# Create reverse mapping for label evaluation
mapping_reverse = {v: k for k, v in mapping.items()}

Finally, we evaluate the performance of the SetFit model using a classification report, showing precision, recall, F1-score, and support for each class.

In [None]:
#| label: evaluate-setfit
# Evaluate SetFit model performance
print(classification_report([mapping_reverse[x] for x in test_ds['label']], predicted_labels))
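`classification_report` computes, for each class, precision = TP / (TP + FP), recall = TP / (TP + FN), F1 as their harmonic mean, and support as the number of true examples of that class. The plain-Python sketch below reproduces these per-class numbers to make the definitions concrete; `per_class_report` is a hypothetical helper, and it returns 0.0 where a metric is undefined, which matches scikit-learn's default apart from the warning scikit-learn emits.

```python
from collections import Counter


def per_class_report(y_true, y_pred):
    """Hypothetical helper: per-class precision, recall, F1, and support."""
    labels = sorted(set(y_true) | set(y_pred))
    true_pos = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    predicted = Counter(y_pred)  # TP + FP per label
    actual = Counter(y_true)     # TP + FN per label (the "support")
    report = {}
    for label in labels:
        precision = true_pos[label] / predicted[label] if predicted[label] else 0.0
        recall = true_pos[label] / actual[label] if actual[label] else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        report[label] = {"precision": precision, "recall": recall, "f1": f1, "support": actual[label]}
    return report


report = per_class_report(["World", "World", "Sports", "Sports"], ["World", "Sports", "Sports", "Sports"])
```

Applied to the notebook, `per_class_report([mapping_reverse[x] for x in test_ds['label']], list(predicted_labels))` should agree with the per-class rows of the report above.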

## 5. Comparison with Logistic Regression

To put the SetFit results in context, we train a traditional TF-IDF + Logistic Regression baseline on the full AG News training set (120,000 labeled articles). SetFit, by contrast, was trained only on the handful of human-validated articles from Argilla.

In [None]:
#| label: train-logistic-regression
# Load the full AG News dataset (train and test splits) for the logistic regression baseline
ag_news = load_dataset("ag_news")

# Training and test sets
train_texts = ag_news['train']['text']
train_labels = ag_news['train']['label']
test_texts = ag_news['test']['text']
test_labels = ag_news['test']['label']

# Create and train the logistic regression model
model_lg = make_pipeline(TfidfVectorizer(stop_words='english'), LogisticRegression(max_iter=1000))
model_lg.fit(train_texts, train_labels)

# Predict and evaluate the logistic regression model
predicted_labels = model_lg.predict(test_texts)
print(classification_report(test_labels, predicted_labels))
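The baseline's `TfidfVectorizer` turns each article into a sparse vector of term weights. With scikit-learn's defaults, a term's weight is its raw count in the document times a smoothed inverse document frequency, idf(t) = ln((1 + n) / (1 + df(t))) + 1, where n is the number of documents and df(t) the number of documents containing t; each document vector is then L2-normalized. The sketch below mirrors those defaults on a toy corpus; `tfidf` is a hypothetical helper, and it uses a plain lowercase whitespace split without stop-word removal, unlike the pipeline above.

```python
import math
from collections import Counter


def tfidf(docs):
    """Hypothetical helper: TF-IDF vectors (as dicts) with raw counts, smooth idf, and L2 norm."""
    tokenized = [doc.lower().split() for doc in docs]
    n = len(tokenized)
    # Document frequency: number of documents containing each term
    df = Counter(term for tokens in tokenized for term in set(tokens))
    idf = {term: math.log((1 + n) / (1 + d)) + 1 for term, d in df.items()}
    vectors = []
    for tokens in tokenized:
        weights = {term: count * idf[term] for term, count in Counter(tokens).items()}
        norm = math.sqrt(sum(w * w for w in weights.values()))
        vectors.append({term: w / norm for term, w in weights.items()} if norm else weights)
    return vectors, idf


vecs, idf = tfidf(["stocks rally", "stocks fall"])
```

Terms that occur in every document, like "stocks" here, get the minimum idf of 1, so rarer and more distinctive terms dominate each normalized vector.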

## Conclusion

This tutorial demonstrated a complete text classification workflow: zero-shot labeling with an LLM, human review and correction in Argilla, and few-shot fine-tuning with SetFit. It showed how LLM pre-labeling combined with a human-in-the-loop review step can produce a small, high-quality labeled dataset, and how a classifier can be trained on it with very little annotation effort. Comparing SetFit with the Logistic Regression baseline, which sees the full 120,000-article training set, indicates how much of that baseline's performance a few-shot model can reach from a handful of human-validated examples.

Further steps could involve:

-   Expanding the human-annotated dataset in Argilla for potentially better SetFit performance.
-   Experimenting with different LLMs for zero-shot classification and comparing their performance.
-   Exploring different Sentence Transformers models and SetFit training configurations.
-   Using Argilla's vector search capabilities to explore the dataset semantically.

This approach provides a robust and adaptable framework for text classification tasks, especially in scenarios where labeled data is scarce or expensive to obtain.