# 03-01 : Zero-Shot Classification 

This notebook experiments with the `get_templated_dataset` function from the `setfit` library and tests the zero-shot classification capabilities of a SetFit model trained only on template-generated examples.

## References

- [SetFit: Zero-shot Text Classification](https://huggingface.co/docs/setfit/en/how_to/zero_shot)
- [SetFit: get_templated_dataset](https://huggingface.co/docs/setfit/v1.0.3/en/reference/utility#setfit.get_templated_dataset)
- [Suggestions for Data Annotation with SetFit in Zero-shot Text Classification](https://github.com/huggingface/cookbook/blob/c2e1869b9608a7fa52278be5a587bcbf530383b5/notebooks/en/labelling_feedback_setfit.ipynb)

In [1]:
import pandas as pd
import numpy as np

from sklearn.preprocessing import MultiLabelBinarizer
from setfit import get_templated_dataset
from setfit import SetFitModel, Trainer, TrainingArguments



In [2]:
data_path = '../../data'
input_path = f'{data_path}/input/labelled_tweets/csv_labels'
train_input_file = f'{input_path}/train.csv'
test_input_file = f'{input_path}/test.csv'
val_input_file = f'{input_path}/val.csv'

## 1. Load Data

In [3]:
df_train = pd.read_csv(train_input_file)
df_val = pd.read_csv(val_input_file)
df_test = pd.read_csv(test_input_file)

## 2. Preprocessing

### 2.1. Labels to List

In [4]:
df_train['labels_list'] = df_train['labels'].str.split(' ')
df_test['labels_list'] = df_test['labels'].str.split(' ')
df_val['labels_list'] = df_val['labels'].str.split(' ')

### 2.2. Multi-label Binarization

In [5]:
# get the list of label values
labels = pd.concat([df_train.labels_list, 
                    df_val.labels_list, 
                    df_test.labels_list])

# initialize MultiLabelBinarizer
labels_lookup = MultiLabelBinarizer()

# learn the vocabulary
labels_lookup = labels_lookup.fit(labels)

# show the vocabulary
vocab = labels_lookup.classes_
print(f'Vocabulary size: {len(vocab)}')
print(f'Vocabulary: {vocab}')


Vocabulary size: 12
Vocabulary: ['conspiracy' 'country' 'ineffective' 'ingredients' 'mandatory' 'none'
 'pharma' 'political' 'religious' 'rushed' 'side-effect' 'unnecessary']
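
To make the fitted vocabulary concrete, here is a small sketch of how `MultiLabelBinarizer` maps label lists to multi-hot vectors and back. The `toy_` names and the three sample label lists are made up for illustration and are not taken from the dataset.

```python
from sklearn.preprocessing import MultiLabelBinarizer

# three made-up label lists (illustrative only)
toy_labels = [["side-effect", "rushed"], ["none"], ["pharma", "rushed"]]

# fit learns the sorted set of distinct labels
toy_binarizer = MultiLabelBinarizer().fit(toy_labels)
print(toy_binarizer.classes_)

# transform yields one row per sample and one column per class
toy_encoded = toy_binarizer.transform(toy_labels)
print(toy_encoded)

# inverse_transform recovers the label names from the multi-hot rows
toy_decoded = toy_binarizer.inverse_transform(toy_encoded)
print(toy_decoded)
```

Because `classes_` is sorted alphabetically, the column order of the encoded rows follows the order of the vocabulary printed above.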


In [6]:
# update the data frame with a `labels_encoded` column
df_train['labels_encoded'] = labels_lookup.transform(df_train.labels_list).tolist()
df_val['labels_encoded'] = labels_lookup.transform(df_val.labels_list).tolist()
df_test['labels_encoded'] = labels_lookup.transform(df_test.labels_list).tolist()

In [7]:
# add the one-hot encoded labels as columns to the data frames
df_train = df_train.join(pd.DataFrame(labels_lookup.transform(df_train.labels_list), 
                                     columns=labels_lookup.classes_, 
                                     index=df_train.index))

df_val = df_val.join(pd.DataFrame(labels_lookup.transform(df_val.labels_list),
                                    columns=labels_lookup.classes_,
                                    index=df_val.index))

df_test = df_test.join(pd.DataFrame(labels_lookup.transform(df_test.labels_list),
                                    columns=labels_lookup.classes_,
                                    index=df_test.index))

## 3. Test get_templated_dataset

In [8]:
candidate_labels = vocab.tolist()
template = 'This vaccine concern is about {}'

In [9]:
test_dataset = get_templated_dataset(
    candidate_labels=candidate_labels,
    sample_size=8,
    template=template,
    multi_label=True)

test_dataset.to_dict()

{'text': ['This vaccine concern is about conspiracy',
  'This vaccine concern is about conspiracy',
  'This vaccine concern is about conspiracy',
  'This vaccine concern is about conspiracy',
  'This vaccine concern is about conspiracy',
  'This vaccine concern is about conspiracy',
  'This vaccine concern is about conspiracy',
  'This vaccine concern is about conspiracy',
  'This vaccine concern is about country',
  'This vaccine concern is about country',
  'This vaccine concern is about country',
  'This vaccine concern is about country',
  'This vaccine concern is about country',
  'This vaccine concern is about country',
  'This vaccine concern is about country',
  'This vaccine concern is about country',
  'This vaccine concern is about ineffective',
  'This vaccine concern is about ineffective',
  'This vaccine concern is about ineffective',
  'This vaccine concern is about ineffective',
  'This vaccine concern is about ineffective',
  'This vaccine concern is about ineffective'
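
The output above suggests what `get_templated_dataset` does when only `candidate_labels` is given: each label is substituted into the template and the resulting sentence is repeated `sample_size` times. The following is a rough pure-Python approximation of that behaviour, not the library's code. The function name `build_templated_rows` is hypothetical, and the one-hot label layout for `multi_label=True` is an assumption about how the library encodes labels.

```python
def build_templated_rows(candidate_labels, template, sample_size, multi_label=False):
    """Approximate the rows of a templated dataset (hypothetical helper)."""
    texts, labels = [], []
    for label_id, label_name in enumerate(candidate_labels):
        sentence = template.format(label_name)
        if multi_label:
            # assumption: one-hot vector with a 1 at this label's position
            label = [int(i == label_id) for i in range(len(candidate_labels))]
        else:
            # single-label case: the integer index of the label
            label = label_id
        texts.extend([sentence] * sample_size)
        labels.extend([label] * sample_size)
    return {"text": texts, "label": labels}


rows = build_templated_rows(
    candidate_labels=["conspiracy", "country", "ineffective"],
    template="This vaccine concern is about {}",
    sample_size=2,
    multi_label=True,
)
for text, label in zip(rows["text"], rows["label"]):
    print(label, text)
```

With the 12 labels of this notebook and `sample_size=8`, the same expansion gives 12 × 8 = 96 rows, which matches the `0/96` example count shown in the training log.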

## 4. Model Training

In [11]:
#model_name = 'BAAI/bge-small-en-v1.5'
model_name = 'sentence-transformers/paraphrase-mpnet-base-v2'

In [12]:
def train_model(candidate_labels, template, multi_label=False, model_name="all-MiniLM-L6-v2"):
    """Train a SetFit classifier on synthetic, template-generated examples only."""
    # build a synthetic training set: `sample_size` templated sentences per label
    train_dataset = get_templated_dataset(
        candidate_labels=candidate_labels,
        sample_size=8,
        template=template,
        multi_label=multi_label
    )

    # load the pretrained body; multi-label uses one binary classifier per label
    model_kwargs = {"multi_target_strategy": "one-vs-rest"} if multi_label else {}
    model = SetFitModel.from_pretrained(model_name, **model_kwargs)

    # fine-tune on the templated dataset with the default training arguments
    trainer = Trainer(
        model=model,
        train_dataset=train_dataset
    )
    trainer.train()
    return model

### 4.1. Train a multi-label classifier

In [13]:
# get the classification model
model = train_model(
    candidate_labels=candidate_labels, 
    template=template, 
    multi_label=True,
    model_name=model_name)

model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.


Map:   0%|          | 0/96 [00:00<?, ? examples/s]

***** Running training *****
  Num unique pairs = 8448
  Batch size = 16
  Num epochs = 1
  Total optimization steps = 528
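
The figures in this log can be reproduced with a little arithmetic. The sketch below assumes that pairs are drawn from the 96 templated examples (12 labels × 8 samples), that two examples form a positive pair when they share a label, and that the smaller group of pairs is oversampled until positives and negatives are balanced. That last point is an assumption about SetFit's default sampling strategy, not something the log states.

```python
from math import comb

num_labels = 12    # size of the label vocabulary
sample_size = 8    # templated sentences per label
batch_size = 16    # from the training log

num_examples = num_labels * sample_size              # templated training rows
all_pairs = comb(num_examples, 2)                    # every unordered pair
positive_pairs = num_labels * comb(sample_size, 2)   # pairs sharing a label
negative_pairs = all_pairs - positive_pairs          # pairs with different labels

# assumption: the smaller group is oversampled to the size of the larger one
unique_pairs = 2 * max(positive_pairs, negative_pairs)
steps = unique_pairs // batch_size

print(num_examples, positive_pairs, negative_pairs, unique_pairs, steps)
```

Under these assumptions the script prints `96 336 4224 8448 528`, which agrees with the `Num unique pairs` and `Total optimization steps` values in the log.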




## 5. Make Predictions

In [14]:
data = df_test[:5]

In [15]:
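def get_predictions(texts, model, labels):
    # return one list of per-label probabilities for each input text;
    # `labels` is currently unused and only needed by the commented-out
    # variant below, which pairs each score with its label name
    probas = model.predict_proba(texts, as_numpy=True)
    return probas.tolist()
    # for pred in probas:
    #     yield [{"label": label, "score": score} for label, score in zip(labels, pred)]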
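
Raw probability lists are hard to read without their label names. As a sketch, the helpers below pair each score with its label and keep either the top-k labels or those at or above a threshold. The function names `top_k_labels` and `labels_above` and the 0.5 threshold are illustrative choices, not part of SetFit. The scores are the first row of predictions printed later in this notebook, rounded to four decimals.

```python
label_names = [
    "conspiracy", "country", "ineffective", "ingredients", "mandatory", "none",
    "pharma", "political", "religious", "rushed", "side-effect", "unnecessary",
]
# first prediction row from the notebook output, rounded to four decimals
first_row = [0.0589, 0.0064, 0.0051, 0.0123, 0.0078, 0.0062,
             0.0804, 0.0134, 0.0055, 0.0070, 0.0096, 0.0105]


def top_k_labels(scores, labels, k=2):
    """Return the k (label, score) pairs with the highest scores."""
    ranked = sorted(zip(labels, scores), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


def labels_above(scores, labels, threshold=0.5):
    """Return the labels whose score is at least the threshold."""
    return [label for label, score in zip(labels, scores) if score >= threshold]


print(top_k_labels(first_row, label_names))
print(labels_above(first_row, label_names))
```

For this row no score reaches 0.5, so the threshold variant returns an empty list, while the two highest-scoring labels are `pharma` and `conspiracy`. A fixed 0.5 cutoff may therefore be too strict for scores of this magnitude, and a ranking or a lower threshold might be worth trying.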

In [16]:
y_pred = list(get_predictions(data.text.values, model, candidate_labels))
y_pred[:5]

[[0.058900424129974226,
  0.006439022990380861,
  0.005081788219557832,
  0.012280360696443077,
  0.007821252784898222,
  0.006176522814892339,
  0.08039143365121408,
  0.01338885843353766,
  0.005539749772679079,
  0.006989365918959436,
  0.009635772295716006,
  0.010520266009283172],
 [0.024651297876294116,
  0.009563608843233254,
  0.044754771974295886,
  0.006569937650931341,
  0.01788917737873444,
  0.007851320421729029,
  0.0062969367190919626,
  0.009027307022265549,
  0.008291270815444028,
  0.008526647686099297,
  0.007935537992653821,
  0.016064096455197113],
 [0.004783939401488553,
  0.007850681968734891,
  0.04473901070064493,
  0.021152890471366106,
  0.00610784862872318,
  0.007830827931005004,
  0.013756584165896271,
  0.005118114310428941,
  0.010434421524870487,
  0.009960489671893582,
  0.056581654897407845,
  0.006821455108851918],
 [0.006787426250942632,
  0.007546201426519945,
  0.014644516460107167,
  0.005360994440365425,
  0.01020591377211537,
  0.00502485220637