# Zero-shot with LLM, SetFit and Argilla

This tutorial covers text classification with zero-shot LLM labeling through the Together API. The idea is to have the LLM pre-label a sample of our dataset and then use Argilla to check whether the model provided good labels. For comparison, we then train a SetFit model on the reviewed labels, alongside a traditional machine learning baseline that uses TF-IDF vectorization and logistic regression. There are variations on this approach, including zero-shot classification with BART models (multi-label zero-shot).

**Using an OpenAI-compatible API via Together:**

1. **Model Configuration**: A system prompt is defined, outlining the task of categorizing news articles into predefined categories: "World", "Sports", "Business", and "Sci/Tech". The output is required to adhere strictly to the JSON format.

2. **Text Classification**: The provided text is classified into one of the predefined categories using an open-source LLM.

3. **Evaluation**: The LLM predicts categories for a sample of news articles. These predictions are reviewed by hand in Argilla, and the resulting labels are used to train a SetFit model that is evaluated against the ground-truth labels of the AG News test set.

**Traditional Machine Learning Approach:**

We compare the results of the zero-shot/few-shot approach with a traditional ML/NLP approach trained on a much larger labeled dataset (the full AG News training split).

**Comparison with Traditional Machine Learning Approach:**

Both approaches aim to classify news articles into predefined categories, but they differ in their underlying methodologies. The LLM-based approach uses an instruction-tuned open-source language model (Nous-Hermes-2-Mistral-7B-DPO) to label text without any training examples. The traditional machine learning approach relies on TF-IDF vectorization and logistic regression, and it requires much more labeled training data.
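As a rough illustration of the TF-IDF half of the traditional pipeline, here is a minimal standard-library sketch. It mirrors scikit-learn's `TfidfVectorizer` defaults as we understand them: the default `token_pattern`, lowercasing, raw term counts, smoothed idf `ln((1 + n) / (1 + df)) + 1`, and L2 normalization. It omits the `stop_words='english'` filtering used later in the notebook. The helper names and toy headlines are hypothetical. Logistic regression is then fitted on vectors like these.

```python
import math
import re
from collections import Counter

# Same regex as scikit-learn's default token_pattern: words of 2+ characters
TOKEN_RE = re.compile(r"(?u)\b\w\w+\b")


def tokenize(doc):
    return TOKEN_RE.findall(doc.lower())


def tfidf(docs):
    """Return one L2-normalized {term: weight} dict per document."""
    tokenized = [tokenize(d) for d in docs]
    n = len(docs)
    # Document frequency: in how many documents each term appears
    df = Counter(term for toks in tokenized for term in set(toks))
    # Smoothed idf, as with smooth_idf=True
    idf = {t: math.log((1 + n) / (1 + df_t)) + 1 for t, df_t in df.items()}
    vectors = []
    for toks in tokenized:
        tf = Counter(toks)
        weights = {t: count * idf[t] for t, count in tf.items()}
        norm = math.sqrt(sum(w * w for w in weights.values()))
        vectors.append({t: w / norm for t, w in weights.items()} if norm else weights)
    return vectors


docs = [
    "Stocks rally on lower oil prices",
    "Oil prices fall as stocks slip",
    "Team wins the championship final",
]
vecs = tfidf(docs)
for term in ("rally", "stocks", "oil"):
    print(term, round(vecs[0][term], 3))
```

Terms that occur in several headlines ("stocks", "oil") end up with lower weights than terms unique to one headline ("rally"). This is what lets a linear classifier focus on distinctive vocabulary.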

In [3]:
!pip install openai -q


In [None]:
import json

In [None]:
from openai import OpenAI

In [1]:
from google.colab import userdata
TOGETHER_API_KEY = userdata.get('TOGETHER_API_KEY')

In [124]:
system_prompt = """

You are a sophisticated classification engine tasked with categorizing news articles.
Your primary function is to evaluate the core message of each article and assign it to one of the following categories: "World" for global news covering politics and similar topics,
"Sports" for news related to sports, "Business" for articles on business, economics, or finance,
and "Sci/Tech" for content focused on technology and science. Upon analyzing a text input, you will provide an explanation for the category chosen.
 Your output will adhere strictly to the JSON format, specifically:
 {"prediction":"your selected prediction", "explanation":"your explanation"}.
 It is imperative that your output is VALID JSON and contains no other elements. Output it as string not markdown or code.

"""

In [125]:
text = """
Stocks Rally on Lower Oil Prices Stocks rallied in quiet trading Wednesday
as lower oil prices brought out buyers, countering a pair of government reports
that gave a mixed picture of the economy.
"""

In [126]:
# Point the OpenAI client at Together's OpenAI-compatible endpoint
client = OpenAI(base_url="https://api.together.xyz/v1", api_key=TOGETHER_API_KEY)

completion = client.chat.completions.create(
  model="NousResearch/Nous-Hermes-2-Mistral-7B-DPO", # open-source model served by Together
  messages=[
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": f'Classify following text: {text}'}
  ],
  temperature=0.2,
)

print(completion.choices[0].message)

ChatCompletionMessage(content='{"prediction": "Business", "explanation": "The text discusses stocks, oil prices, and the economy, which are all related to business and finance."}', role='assistant', function_call=None, tool_calls=None)


In [127]:
json.loads(completion.choices[0].message.content.strip())

{'prediction': 'Business',
 'explanation': 'The text discusses stocks, oil prices, and the economy, which are all related to business and finance.'}

In [128]:
def classify(text):
    """Zero-shot classify `text` with the LLM; returns a dict with 'prediction' and 'explanation'."""
    completion = client.chat.completions.create(
        model="NousResearch/Nous-Hermes-2-Mistral-7B-DPO",  # open-source model served by Together
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f'Classify following text: {text}'},
        ],
        temperature=0.2,
    )
    json_response = completion.choices[0].message.content.strip()
    try:
        prediction = json.loads(json_response)
    except json.JSONDecodeError:
        # For some examples the model returns malformed JSON
        return {"prediction": None, "explanation": f"Wrong JSON format: {json_response}"}
    return prediction
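`json.loads` fails if the model wraps its answer in a markdown code fence or adds a sentence around the JSON, even though the system prompt forbids both. In that case `classify` falls back to a `None` prediction. The following is a minimal sketch of a more lenient parser. It assumes the reply contains a single JSON object. `parse_llm_json` is a hypothetical helper, not part of the OpenAI client.

```python
import json
import re

# Strips a leading ```/```json fence and a trailing ``` fence, if present
FENCE_RE = re.compile(r"^```(?:json)?\s*|\s*```$")


def parse_llm_json(raw):
    """Parse an LLM reply that should be a JSON object; tolerate fences and surrounding text."""
    text = FENCE_RE.sub("", raw.strip())
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    # Fall back to the span between the first "{" and the last "}"
    start, end = text.find("{"), text.rfind("}")
    if start != -1 and end > start:
        try:
            return json.loads(text[start:end + 1])
        except json.JSONDecodeError:
            pass
    return None


print(parse_llm_json('```json\n{"prediction": "Sports", "explanation": "x"}\n```'))
```

Inside `classify`, one could call `parse_llm_json(json_response)` and return the "Wrong JSON format" placeholder only when it returns `None`.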

In [74]:
classify(text)

{'prediction': 'business',
 'explanation': 'The text discusses stock market performance, oil prices, and government reports on the economy, which are all related to business, economics, and finance.'}
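The same article came back as `'Business'` in one call and `'business'` in another, so label casing from the LLM is not stable. Here is a small sketch that maps the LLM's label onto the four categories named in the system prompt and sets anything off-list to `None`. The helper `normalize_prediction` is hypothetical.

```python
CATEGORIES = ["World", "Sports", "Business", "Sci/Tech"]
# Lowercased label -> canonical spelling used in the system prompt
_CANONICAL = {c.lower(): c for c in CATEGORIES}


def normalize_prediction(result):
    """Return a copy of `result` whose 'prediction' is a canonical category or None."""
    label = result.get("prediction")
    if not isinstance(label, str):
        return {**result, "prediction": None}
    return {**result, "prediction": _CANONICAL.get(label.strip().lower())}


print(normalize_prediction({"prediction": "business", "explanation": "..."}))
```

For example, `dataset_news.map(lambda ex: normalize_prediction(classify(ex["text"])))` would produce canonical labels before they are lowercased for Argilla.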

In [62]:
!pip install argilla setfit datasets -qqq


In [122]:
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.metrics import classification_report

In [78]:
import pandas as pd
from datasets import load_dataset, Dataset
from sentence_transformers import SentenceTransformer
import argilla as rg

# Load the data
data_train = pd.read_parquet('https://github.com/SDS-AAU/SDS-master/raw/master/M2/data/ag_news_unlabelled.pq')

# Convert to Hugging Face dataset
dataset_news = Dataset.from_pandas(data_train.sample(30).reset_index(drop=True))

In [None]:
# Pre-label the 30 sampled (unlabelled) articles with the zero-shot LLM classifier
train_ds_with_preds = dataset_news.map(lambda example: classify(example["text"]))

pd.set_option('display.max_colwidth', None)
train_ds_with_preds.to_pandas().head(15)

In [135]:
# Define and apply the encoder
encoder = SentenceTransformer("intfloat/multilingual-e5-base", device='cuda')
encoded_dataset = train_ds_with_preds.map(lambda batch: {"vectors": encoder.encode(batch["text"], convert_to_tensor=True)}, batched=True)



In [136]:
# Turn vectors into a dictionary
encoded_dataset = encoded_dataset.map(
    lambda r: {"vectors": {"multilingual-e5-base": r["vectors"]}}
)
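Attaching named vectors to each record lets Argilla surface records that lie close together in embedding space, which helps when reviewing similar articles in bulk. This notebook does not show which metric Argilla uses. The sketch below assumes cosine similarity and uses toy 3-dimensional vectors in place of the real E5 embeddings. The helpers `cosine` and `most_similar` are hypothetical.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv) if nu and nv else 0.0


def most_similar(query_idx, records, name="multilingual-e5-base", k=3):
    """Indices of the k records whose named vector is closest to records[query_idx]."""
    q = records[query_idx]["vectors"][name]
    scores = [(cosine(q, r["vectors"][name]), i) for i, r in enumerate(records) if i != query_idx]
    return [i for _, i in sorted(scores, reverse=True)[:k]]


# Toy records shaped like the {"vectors": {"multilingual-e5-base": [...]}} rows above
records = [
    {"vectors": {"multilingual-e5-base": [1.0, 0.0, 0.0]}},
    {"vectors": {"multilingual-e5-base": [0.9, 0.1, 0.0]}},
    {"vectors": {"multilingual-e5-base": [0.0, 1.0, 0.0]}},
    {"vectors": {"multilingual-e5-base": [0.0, 0.0, 1.0]}},
]
print(most_similar(0, records, k=1))
```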


In [96]:
# Initialize Argilla
rg.init(api_url="https://rjuro-unistra.hf.space", api_key="owner.apikey", workspace="admin")


In [137]:
encoded_dataset.to_pandas().prediction.unique()

array(['Sports', 'World', 'Sci/Tech', 'Business'], dtype=object)

In [None]:
encoded_dataset[0]

In [138]:
records = []

for example in encoded_dataset:
    # Create a record carrying the LLM's zero-shot prediction; annotations are added by hand in Argilla.
    # If the LLM returned malformed JSON, `classify` set the prediction to None, so log no prediction.
    prediction = [(example["prediction"].lower(), 1.0)] if example["prediction"] else None
    record = rg.TextClassificationRecord(
        inputs={"text": example["text"], "explanation": example["explanation"]},
        prediction=prediction,
        vectors=example["vectors"],
    )
    records.append(record)

# Create a dataset in Argilla
rg.log(records, "news-llm-embeddings")


BulkResponse(dataset='news-llm-embeddings', processed=30, failed=0)

In [154]:
from setfit import SetFitModel, SetFitTrainer

# Initialize SetFitModel
model = SetFitModel.from_pretrained("intfloat/multilingual-e5-base").to('cuda')


model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.


In [139]:
# Load the hand-labelled dataset from Argilla
train_ds = rg.load("news-llm-embeddings").prepare_for_training()
test_ds = load_dataset("ag_news", split="test")



In [140]:
train_ds.features

{'id': Value(dtype='string', id=None),
 'text': Value(dtype='string', id=None),
 'label': ClassLabel(names=['business', 'sci/tech', 'sports', 'world'], id=None)}

In [141]:
test_ds.features

{'text': Value(dtype='string', id=None),
 'label': ClassLabel(names=['World', 'Sports', 'Business', 'Sci/Tech'], id=None)}

In [142]:
# Map Argilla label ids (alphabetical: business, sci/tech, sports, world)
# onto AG News label ids (World=0, Sports=1, Business=2, Sci/Tech=3)
label_mapping_a_to_b = {0: 2, 1: 3, 2: 1, 3: 0}

# Replace a single example's label id using the mapping
def apply_label_mapping(example, label_mapping):
    example['label'] = label_mapping[example['label']]
    return example

# Align the Argilla training labels with the AG News test labels
train_ds = train_ds.map(lambda x: apply_label_mapping(x, label_mapping_a_to_b))
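Instead of hard-coding `label_mapping_a_to_b`, you can derive the mapping from the two `ClassLabel` name lists shown above by matching names case-insensitively. Here is a minimal sketch. `build_label_mapping` is a hypothetical helper, and the name lists are copied from the `features` outputs.

```python
def build_label_mapping(source_names, target_names):
    """Map each id in `source_names` to the id of the same (case-insensitive) name in `target_names`."""
    target_index = {name.lower(): i for i, name in enumerate(target_names)}
    # Raises KeyError if a source label has no counterpart in the target list
    return {i: target_index[name.lower()] for i, name in enumerate(source_names)}


argilla_names = ["business", "sci/tech", "sports", "world"]  # train_ds label names
ag_news_names = ["World", "Sports", "Business", "Sci/Tech"]  # test_ds label names
mapping = build_label_mapping(argilla_names, ag_news_names)
print(mapping)
```

In the notebook, the two lists would come from `train_ds.features['label'].names` and `test_ds.features['label'].names`. A `KeyError` signals a label that exists on only one side.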

Map:   0%|          | 0/20 [00:00<?, ? examples/s]

In [155]:
# Create SetFitTrainer and train
# (SetFitTrainer is deprecated since setfit 1.0 in favor of Trainer and TrainingArguments)
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    batch_size=16,
    num_iterations=5,
    num_epochs=3,  # adjust as needed
)

trainer.train()
metrics = trainer.evaluate()
print(metrics)


***** Running training *****
  Num unique pairs = 200
  Batch size = 16
  Num epochs = 3
  Total optimization steps = 39




***** Running evaluation *****


{'accuracy': 0.854078947368421}
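You can reproduce the counts in the training log by hand. This sketch assumes that SetFit generates one positive and one negative sentence pair per training example for each of the `num_iterations` iterations. That assumption matches the log, but check the setfit source for your version. It also assumes `train_ds` holds the 20 examples reported by the label-mapping `map` step. The helper `setfit_training_steps` is hypothetical.

```python
import math


def setfit_training_steps(n_examples, num_iterations, batch_size, num_epochs):
    # Assumption: 2 pairs (one positive, one negative) per example per iteration
    num_pairs = 2 * num_iterations * n_examples
    steps_per_epoch = math.ceil(num_pairs / batch_size)
    return num_pairs, steps_per_epoch * num_epochs


print(setfit_training_steps(20, 5, 16, 3))  # (200, 39)
```

That gives 200 pairs and ceil(200 / 16) = 13 steps per epoch, so 3 epochs take 39 optimization steps.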


In [156]:
# Predict and evaluate
predicted_labels = model.predict(test_ds['text'])

In [157]:
print(classification_report(test_ds['label'], predicted_labels))

              precision    recall  f1-score   support

           0       0.91      0.83      0.87      1900
           1       0.96      0.93      0.95      1900
           2       0.79      0.83      0.81      1900
           3       0.77      0.82      0.79      1900

    accuracy                           0.85      7600
   macro avg       0.86      0.85      0.86      7600
weighted avg       0.86      0.85      0.86      7600



In [146]:
# Load AG News dataset for logistic regression
dataset = load_dataset("ag_news", split={'train': 'train', 'test': 'test'})

# Training and test sets
train_texts = dataset['train']['text']
train_labels = dataset['train']['label']
test_texts = dataset['test']['text']
test_labels = dataset['test']['label']

# Create and train the logistic regression model
model_lg = make_pipeline(TfidfVectorizer(stop_words='english'), LogisticRegression(max_iter=1000))
model_lg.fit(train_texts, train_labels)

# Predict and evaluate
predicted_labels = model_lg.predict(test_texts)
print(classification_report(test_labels, predicted_labels))

              precision    recall  f1-score   support

           0       0.93      0.91      0.92      1900
           1       0.96      0.98      0.97      1900
           2       0.89      0.88      0.88      1900
           3       0.89      0.90      0.89      1900

    accuracy                           0.92      7600
   macro avg       0.92      0.92      0.92      7600
weighted avg       0.92      0.92      0.92      7600
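To compare the two classifiers per class, the sketch below puts the per-class F1 scores from both reports side by side. Rows 0 to 3 follow AG News's `ClassLabel` order: World, Sports, Business, Sci/Tech. The numbers are copied from the reports above, so they carry the same two-decimal rounding. All variable names are illustrative.

```python
# Per-class F1 from the two classification reports (AG News label order)
CLASS_NAMES = ["World", "Sports", "Business", "Sci/Tech"]
setfit_f1 = [0.87, 0.95, 0.81, 0.79]  # SetFit trained on the 20 examples reviewed in Argilla
logreg_f1 = [0.92, 0.97, 0.88, 0.89]  # TF-IDF + logistic regression on the full training split

gaps = {name: round(l - s, 2) for name, s, l in zip(CLASS_NAMES, setfit_f1, logreg_f1)}
for name, s, l in zip(CLASS_NAMES, setfit_f1, logreg_f1):
    print(f"{name:<9} SetFit={s:.2f}  LogReg={l:.2f}  gap={gaps[name]:+.2f}")

# Every class has 1,900 test examples, so macro and weighted averages coincide
macro_gap = sum(logreg_f1) / len(logreg_f1) - sum(setfit_f1) / len(setfit_f1)
print(f"macro F1 gap: {macro_gap:.3f}")
```

Most of the gap sits in Sci/Tech and Business. Sports is nearly as easy for the few-shot model as for the fully supervised baseline.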

