# Zero-shot Text Classification

Your class names are likely already good descriptors of the text that you're looking to classify. With 🤗 SetFit, you can combine these class names with a pretrained Sentence Transformer model to get a strong baseline without any training samples.

This guide will show you how to perform zero-shot text classification.

In [None]:
!pip -q install transformers datasets evaluate setfit


## Testing dataset

We'll use the [dair-ai/emotion](https://huggingface.co/datasets/dair-ai/emotion) dataset to test the performance of our zero-shot model.

In [None]:
from datasets import load_dataset

test_dataset = load_dataset("dair-ai/emotion", "split", split="test")

This dataset stores the class names within the dataset `Features`, so we can extract them directly:

In [None]:
classes = test_dataset.features["label"].names
# => ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']

Otherwise, we could manually set the list of classes.
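
For a dataset that does not ship label names, a minimal sketch of the manual route could look like this. The list simply copies the six emotion names printed above, and its order is assumed to match the integer label ids used by the dataset. The `id2label` and `label2id` helpers are illustrative names, not part of SetFit.

```python
# Hypothetical manual definition, for datasets whose features carry no label names.
# The order matters: position i must correspond to integer label i.
classes = ["sadness", "joy", "love", "anger", "fear", "surprise"]

# Lookup tables in both directions are often handy later on.
id2label = dict(enumerate(classes))
label2id = {name: idx for idx, name in id2label.items()}

print(id2label[3])      # => anger
print(label2id["joy"])  # => 1
```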

## Synthetic dataset

Then, we can use [get_templated_dataset()](https://huggingface.co/docs/setfit/main/en/reference/utility#setfit.get_templated_dataset) to synthetically generate a dummy dataset given these class names.

In [None]:
from setfit import get_templated_dataset

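# 6 classes x 8 templated samples per class = 48 training rows
train_dataset = get_templated_dataset(candidate_labels=classes, sample_size=8)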

In [None]:
print(train_dataset)
# => Dataset({
#     features: ['text', 'label'],
#     num_rows: 48
# })
print(train_dataset[0])
# => {'text': 'This sentence is sadness', 'label': 0}
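
To make the idea of a templated dataset concrete, here is a rough plain-Python sketch of how rows like the one printed above could be produced. It is an illustrative stand-in and not SetFit's implementation: the `build_templated_rows` helper is hypothetical, and the template string and the value of 8 samples per class are assumptions chosen to match the output shown (48 rows for 6 classes).

```python
# Illustrative stand-in, NOT SetFit's implementation: it only mimics the
# shape of the output shown above.
classes = ["sadness", "joy", "love", "anger", "fear", "surprise"]
template = "This sentence is {}"  # assumed template, matching the printed example
sample_size = 8                   # assumed number of rows per class

def build_templated_rows(labels, template, sample_size):
    """Return sample_size {'text', 'label'} rows for every class name."""
    rows = []
    for label_id, name in enumerate(labels):
        for _ in range(sample_size):
            rows.append({"text": template.format(name), "label": label_id})
    return rows

rows = build_templated_rows(classes, template, sample_size)
print(len(rows))  # => 48
print(rows[0])    # => {'text': 'This sentence is sadness', 'label': 0}
```

Every class contributes identical sentences here, so the training signal comes entirely from the class names themselves rather than from real examples.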

## Training

We can use this dataset to train a SetFit model as usual:

In [None]:
from setfit import SetFitModel, Trainer, TrainingArguments

model = SetFitModel.from_pretrained("BAAI/bge-small-en-v1.5")

args = TrainingArguments(
    batch_size=32,
    num_epochs=1,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)
trainer.train()

```
***** Running training *****
  Num examples = 60
  Num epochs = 1
  Total optimization steps = 60
  Total train batch size = 32
{'embedding_loss': 0.2628, 'learning_rate': 3.3333333333333333e-06, 'epoch': 0.02}                                                                                 
{'embedding_loss': 0.0222, 'learning_rate': 3.7037037037037037e-06, 'epoch': 0.83}                                                                                 
{'train_runtime': 15.4717, 'train_samples_per_second': 124.098, 'train_steps_per_second': 3.878, 'epoch': 1.0}                                                     
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 60/60 [00:09<00:00,  6.35it/s]
```

Once trained, we can evaluate the model:

In [None]:
metrics = trainer.evaluate()
print(metrics)

```
***** Running evaluation *****
{'accuracy': 0.591}
```

And run predictions:

In [None]:
preds = model.predict([
    "i am just feeling cranky and blue",
    "i feel incredibly lucky just to be able to talk to her",
    "you're pissing me off right now",
    "i definitely have thalassophobia, don't get me near water like that",
    "i did not see that coming at all",
])
print([classes[idx] for idx in preds])
# => ['sadness', 'joy', 'anger', 'fear', 'surprise']

These predictions all look right!

## Baseline

To put SetFit's zero-shot performance in context, we'll compare it against a zero-shot classification model from `transformers`.

In [None]:
from transformers import pipeline
from datasets import load_dataset
import evaluate

# Prepare the testing dataset
test_dataset = load_dataset("dair-ai/emotion", "split", split="test")
classes = test_dataset.features["label"].names

# Set up the zero-shot classification pipeline from transformers
# Uses 'facebook/bart-large-mnli' by default
pipe = pipeline("zero-shot-classification", device=0)
zeroshot_preds = pipe(test_dataset["text"], batch_size=16, candidate_labels=classes)
preds = [classes.index(pred["labels"][0]) for pred in zeroshot_preds]

# Compute the accuracy
metric = evaluate.load("accuracy")
transformers_accuracy = metric.compute(predictions=preds, references=test_dataset["label"])
print(transformers_accuracy)
# => {'accuracy': 0.3765}

With 59.1% accuracy versus 37.65%, zero-shot SetFit outperforms the default zero-shot classification model in `transformers` (`facebook/bart-large-mnli`) by roughly 21 percentage points.
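
The conversion from pipeline output to label ids in the cell above relies on each result being a dict whose `labels` list is sorted by descending score, so `labels[0]` is the top prediction. The sketch below walks through that step and the accuracy computation in plain Python. No model is run: the outputs, scores, and reference labels are made up for illustration, and each `labels` list is truncated to three entries for brevity.

```python
classes = ["sadness", "joy", "love", "anger", "fear", "surprise"]

# Made-up results in the shape returned by the zero-shot pipeline:
# "labels" is ordered from highest to lowest score (truncated to 3 here).
fake_outputs = [
    {"sequence": "text a", "labels": ["joy", "love", "sadness"], "scores": [0.7, 0.2, 0.1]},
    {"sequence": "text b", "labels": ["anger", "fear", "joy"], "scores": [0.6, 0.3, 0.1]},
    {"sequence": "text c", "labels": ["fear", "surprise", "anger"], "scores": [0.5, 0.4, 0.1]},
    {"sequence": "text d", "labels": ["sadness", "joy", "love"], "scores": [0.8, 0.1, 0.1]},
]
references = [1, 3, 5, 0]  # made-up gold label ids

# Top label name -> integer id, as in the list comprehension above
preds = [classes.index(out["labels"][0]) for out in fake_outputs]

# Accuracy is the fraction of predictions that match the references
accuracy = sum(p == r for p, r in zip(preds, references)) / len(references)
print(preds)     # => [1, 3, 4, 0]
print(accuracy)  # => 0.75
```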

## Prediction latency

Besides being more accurate, SetFit is also much faster. Let's compute the latency of SetFit with `BAAI/bge-small-en-v1.5` versus the latency of `transformers` with `facebook/bart-large-mnli`. Both tests were performed on a GPU.

In [None]:
import time

start_t = time.time()
pipe(test_dataset["text"], batch_size=32, candidate_labels=classes)
delta_t = time.time() - start_t
print(f"`transformers` with `facebook/bart-large-mnli` latency: {delta_t / len(test_dataset['text']) * 1000:.4f}ms per sentence")

```
`transformers` with `facebook/bart-large-mnli` latency: 31.1765ms per sentence
```

In [None]:
import time

start_t = time.time()
model.predict(test_dataset["text"])
delta_t = time.time() - start_t
print(f"SetFit with `BAAI/bge-small-en-v1.5` latency: {delta_t / len(test_dataset['text']) * 1000:.4f}ms per sentence")

```
SetFit with `BAAI/bge-small-en-v1.5` latency: 0.4600ms per sentence
```
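
The two timing cells follow the same pattern, which could be factored into a small helper. The sketch below is one possible version: `per_item_latency_ms` and `dummy_predict` are hypothetical names, and it uses `time.perf_counter` instead of `time.time` because that clock is intended for measuring short intervals. The final lines compute the speed-up implied by the two latencies reported above.

```python
import time

def per_item_latency_ms(predict_fn, items):
    """Time one call of predict_fn(items) and return milliseconds per item."""
    start = time.perf_counter()
    predict_fn(items)
    elapsed = time.perf_counter() - start
    return elapsed / len(items) * 1000

# Dummy workload standing in for a model; swap in e.g. model.predict
def dummy_predict(texts):
    return [len(t) for t in texts]

latency = per_item_latency_ms(dummy_predict, ["a", "bb", "ccc"])
print(f"{latency:.4f}ms per item")

# Speed-up implied by the two latencies reported above
speedup = 31.1765 / 0.4600
print(f"{speedup:.1f}x")  # => 67.8x
```

For a fair comparison, both models would ideally be timed with the same batch size and after a warm-up call, since the first GPU call often includes one-off setup costs.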

So, SetFit with `BAAI/bge-small-en-v1.5` is more than 67x faster than `transformers` with `facebook/bart-large-mnli`, in addition to being more accurate:

![zero_shot_transformers_vs_setfit](https://github.com/huggingface/setfit/assets/37621491/33f574d9-c51b-4e02-8d98-6e04e18427ef)