# SetFit for Multilabel Text Classification

In this notebook, we'll learn how to do few-shot text classification on a multilabel dataset with SetFit.

## Setup

If you're running this notebook on Colab or some other cloud platform, you will need to install the `setfit` library. Uncomment the following cell and run it:

In [1]:
# %pip install setfit

In [24]:
from huggingface_hub import notebook_login

notebook_login()


Then you need to install Git-LFS, which you can do by uncommenting and running the following command:

In [3]:
# !apt install git-lfs

Finally, you may need to configure Git on your system by providing details about who you are:

In [None]:
# !git config --global user.email "you@example.com"
# !git config --global user.name "Your Name"

This notebook is designed to work with any multilabel [text classification dataset](https://huggingface.co/models?pipeline_tag=text-classification&sort=downloads) and pretrained [Sentence Transformer](https://huggingface.co/models?library=sentence-transformers&sort=downloads) on the Hub. Change the values below to try a different dataset / model!

To be able to share your model with the community, you also need to be authenticated with the Hugging Face Hub (sign up [here](https://huggingface.co/join) if you haven't already!). The `notebook_login()` cell above stores your authentication token: run it and input an [access token](https://huggingface.co/docs/hub/security-tokens) associated with your account.

Next, we load the labeled data, a CSV file of news-article URLs with five binary labels, and fetch each article's text with the `get_article` helper from the `scraper.scrape_utils` module.

In [1]:
from scraper.scrape_utils import wayback_scraper, get_article
import numpy as np
import pandas as pd

In [2]:
df = pd.read_csv('train_data/v1/data.csv')

In [3]:
# Map the boolean label values to 0/1 integers
di = {True: 1, False: 0}

In [4]:
df = df.replace(di)

In [5]:
df

|   | url | t_mil | t_loc | t_milcas | t_civcas | t_isis_vic | date |
|---|---|---|---|---|---|---|---|
| 0 | https://www.cnn.com/2019/03/23/middleeast/isis... | 1 | 1 | 0 | 0 | 0 | 03/23/19 |
| 1 | https://www.cnn.com/2017/06/25/asia/philippine... | 1 | 1 | 1 | 0 | 1 | 08/24/17 |
| 2 | https://www.cnn.com/2017/06/06/middleeast/raqq... | 1 | 1 | 0 | 0 | 1 | 06/06/17 |
| 3 | https://fox2now.com/news/isis-in-the-crosshair... | 1 | 1 | 0 | 0 | 0 | 08/16/16 |
| 4 | https://www.cnn.com/2016/12/12/middleeast/palm... | 1 | 1 | 1 | 0 | 1 | 09/12/16 |
| 5 | https://www.fox13seattle.com/news/obama-were-i... | 1 | 1 | 1 | 1 | 0 | 11/09/14 |
| 6 | https://www.nbcnews.com/storyline/isis-terror/... | 1 | 1 | 0 | 0 | 0 | 12/21/15 |
| 7 | http://fingfx.thomsonreuters.com/gfx/rngs/PHIL... | 1 | 1 | 1 | 0 | 1 | 05/25/17 |
| 8 | https://www.defense.gov/News/News-Stories/Arti... | 1 | 1 | 0 | 0 | 1 | 03/13/17 |
| 9 | https://www.cnbc.com/2014/10/25/iraq-notches-v... | 1 | 1 | 1 | 0 | 0 | 10/25/14 |


In [6]:
text = []
for url in df['url'].array:
    text.append(get_article(url))

In [7]:
df['text'] = text

In [8]:
# Mark articles with no extracted text as missing so dropna removes them
df['text'] = df['text'].replace('', np.nan)

In [9]:
df.dropna(subset=['text'], inplace=True)
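The loop that fills `text` above calls `get_article` once per URL, so if it raises on an unreachable page the whole cell stops. A minimal sketch of a more defensive loop, assuming only that the fetcher maps a URL to a string; `fetch_all` and the stand-in `fake_fetch` are hypothetical, and in the notebook you would pass `get_article` instead:

```python
def fetch_all(urls, fetch):
    """Fetch the text for each URL, storing None instead of raising on failure.

    `fetch` is any callable mapping a URL to a string (e.g. get_article).
    """
    texts = []
    for url in urls:
        try:
            text = fetch(url)
        except Exception as exc:  # broad on purpose: any failure means "no text"
            print(f"skipping {url}: {exc}")
            text = None
        # Treat empty or whitespace-only results as missing too
        texts.append(text if text and text.strip() else None)
    return texts


# Hypothetical stand-in fetcher for illustration only; in the notebook, pass get_article
def fake_fetch(url):
    if url == "https://example.com/bad":
        raise ConnectionError("unreachable")
    return {"https://example.com/ok": "Article body", "https://example.com/empty": ""}[url]


texts = fetch_all(
    ["https://example.com/ok", "https://example.com/empty", "https://example.com/bad"],
    fake_fetch,
)
```

Assigning `df['text'] = fetch_all(df['url'], get_article)` would store `None` for failed or empty articles, which the `dropna` call then removes.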

In [10]:
df

|   | url | t_mil | t_loc | t_milcas | t_civcas | t_isis_vic | date | text |
|---|---|---|---|---|---|---|---|---|
| 0 | https://www.cnn.com/2019/03/23/middleeast/isis... | 1 | 1 | 0 | 0 | 0 | 03/23/19 | Eastern Syria CNN —\n\nISIS has lost its final... |
| 1 | https://www.cnn.com/2017/06/25/asia/philippine... | 1 | 1 | 1 | 0 | 1 | 08/24/17 | Iligan, Philippines CNN —\n\nDuring the rainy ... |
| 2 | https://www.cnn.com/2017/06/06/middleeast/raqq... | 1 | 1 | 0 | 0 | 1 | 06/06/17 | Story highlights US-backed forces have been pu... |
| 3 | https://fox2now.com/news/isis-in-the-crosshair... | 1 | 1 | 0 | 0 | 0 | 08/16/16 | This is an archived article and the informatio... |
| 4 | https://www.cnn.com/2016/12/12/middleeast/palm... | 1 | 1 | 1 | 0 | 1 | 09/12/16 | Story highlights ISIS in control of ancient ci... |
| 5 | https://www.fox13seattle.com/news/obama-were-i... | 1 | 1 | 1 | 1 | 0 | 11/09/14 | article\n\n\n\n\n\n(CNN) -- The decision to in... |
| 6 | https://www.nbcnews.com/storyline/isis-terror/... | 1 | 1 | 0 | 0 | 0 | 12/21/15 | A group of five men crouch behind a berm. Peri... |
| 8 | https://www.defense.gov/News/News-Stories/Arti... | 1 | 1 | 0 | 0 | 1 | 03/13/17 | Iraqi security forces are battling Islamic Sta... |
| 9 | https://www.cnbc.com/2014/10/25/iraq-notches-v... | 1 | 1 | 1 | 0 | 0 | 10/25/14 | Iraqi government forces and Shi'ite militias s... |
| 10 | https://www.cnbc.com/2015/04/07/isis-in-damasc... | 1 | 1 | 0 | 0 | 1 | 04/07/15 | The arrival of ISIS fighters in Syria's capita... |


In [11]:
from datasets import load_dataset, Dataset

model_id = "sentence-transformers/paraphrase-mpnet-base-v2"
#dataset = load_dataset("ethos", "multilabel")
dataset = Dataset.from_pandas(df)


In [12]:
dataset

Dataset({
    features: ['url', 't_mil', 't_loc', 't_milcas', 't_civcas', 't_isis_vic', 'date', 'text', '__index_level_0__'],
    num_rows: 19
})

## Loading and sampling the dataset

Most datasets on the Hub have many more labeled examples than those one encounters in few-shot settings. To simulate training on a limited number of examples, you can subsample the training set to have at least 8 labeled examples per feature (label). Our dataset has only 19 articles, so the sampling cell below is left commented out and we use every article.

Note that if your dataset has differently formatted labels, you may need to adapt this section.

In [13]:
import numpy as np

features = dataset.column_names
features.remove("text")
features.remove("url")
features.remove("date")
features.remove("__index_level_0__")
features

['t_mil', 't_loc', 't_milcas', 't_civcas', 't_isis_vic']

In [14]:
dataset[4]

{'url': 'https://www.cnn.com/2016/12/12/middleeast/palmyra-syria-isis-russia/index.html',
 't_mil': 1,
 't_loc': 1,
 't_milcas': 1,
 't_civcas': 0,
 't_isis_vic': 1,
 'date': '09/12/16',
 'text': 'Story highlights ISIS in control of ancient city, state media reports Militant group in 2015 blew up ancient treasures there\n\nCNN —\n\nISIS fighters were in fierce clashes with Syrian regime troops Sunday in the ancient city of Palmyra, where the militant group infamously blew up temples and monuments last year, a monitor said.\n\nSyrian news agency SANA reported that over 4,000 militants swarmed the city from “various directions,” despite having suffered heavy losses from bombardments by the Syrian air force. The Russian Defense Ministry had earlier reported that its aircraft had also taken part in the air campaign.\n\nThe UK-based Syrian Observatory for Human Rights (SOHR) confirmed that Palmyra had fallen to ISIS on Sunday after Syrian armed forces pulled out from the desert city, the or

In [15]:
#num_samples = 8
#samples = np.concatenate(
#    [np.random.choice(np.where(dataset["train"][f])[0], num_samples) for f in features]
#)
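The commented-out cell above draws with `np.random.choice`, which samples with replacement by default, so the same article can be picked more than once; it also indexes `dataset["train"]`, which our single `Dataset` does not have. A minimal stdlib sketch of per-label sampling without replacement, assuming labels are stored as multi-hot lists like the `labels` column built in the next step; `sample_per_label` is a hypothetical helper:

```python
import random


def sample_per_label(label_rows, num_samples, seed=0):
    """Hypothetical helper: sorted row indices giving each label up to `num_samples` positives.

    label_rows: multi-hot lists, one per row, e.g. [[1, 0, 1], [0, 1, 0]].
    Sampling is without replacement, so a label with fewer positive rows than
    `num_samples` contributes all of them. Because rows can carry several
    labels, some labels may end up with more than `num_samples` positives.
    """
    rng = random.Random(seed)
    selected = set()
    for j in range(len(label_rows[0])):
        positives = [i for i, row in enumerate(label_rows) if row[j] == 1]
        selected.update(rng.sample(positives, min(num_samples, len(positives))))
    return sorted(selected)


rows = [[1, 0], [1, 1], [0, 1], [1, 0], [0, 0]]
idx = sample_per_label(rows, num_samples=2, seed=42)
```

On a larger dataset, `sample_per_label(dataset["labels"], 8)` would return indices that could be passed to `dataset.select`.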

We encode the five label columns in a single `labels` feature: a list holding one 0/1 value per label, in the order given by `features`.

In [16]:
def encode_labels(record):
    return {"labels": [record[feature] for feature in features]}


dataset = dataset.map(encode_labels)



Next, we split the data: the first 14 articles form the training set and the remaining 5 form the evaluation set, since our dataset has no predefined test split.

In [17]:
train_dataset = dataset.select(range(14))
eval_dataset = dataset.select(range(14,19))
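This split takes the first 14 rows in file order. If the CSV happens to be ordered in some way (by source or date, say), the 5 evaluation articles may not be representative. A minimal sketch of a seeded random split over row indices using only the standard library; `random_split_indices` is a hypothetical helper whose output can be passed to `dataset.select`:

```python
import random


def random_split_indices(n, n_eval, seed=42):
    """Hypothetical helper: shuffle row indices 0..n-1 and split them into (train, eval)."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    # Sort each part so rows keep their original relative order
    return sorted(indices[n_eval:]), sorted(indices[:n_eval])


train_idx, eval_idx = random_split_indices(19, 5)
```

`dataset.select(train_idx)` and `dataset.select(eval_idx)` would then replace the two `range` selections; fixing the seed keeps the split reproducible between runs.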

Okay, now we have the dataset, let's load and train a model!

## Fine-tuning the model

To train a SetFit model, the first thing to do is download a pretrained checkpoint from the Hub. We can do so by using the `from_pretrained()` method associated with the `SetFitModel` class.

**Note that the `multi_target_strategy` parameter here signals to both the model and the trainer to expect a multi-labelled dataset.**
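With `multi_target_strategy="one-vs-rest"`, the head fits one independent binary classifier per label. In setfit this is understood to be backed by scikit-learn's `OneVsRestClassifier` around a logistic regression, but treat that as an assumption and check the source of your installed version. The pure-Python sketch below only illustrates the idea: it trains a tiny logistic regression per label on toy 2-D vectors that stand in for Sentence Transformer embeddings; all names are hypothetical.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def train_logistic(X, y, lr=1.0, epochs=2000):
    """Fit one binary logistic regression with full-batch gradient descent."""
    w = [0.0] * len(X[0])
    b = 0.0
    n = len(X)
    for _ in range(epochs):
        grad_w = [0.0] * len(w)
        grad_b = 0.0
        for x, target in zip(X, y):
            err = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b) - target
            for k, xk in enumerate(x):
                grad_w[k] += err * xk
            grad_b += err
        w = [wi - lr * g / n for wi, g in zip(w, grad_w)]
        b -= lr * grad_b / n
    return w, b


def fit_one_vs_rest(X, Y):
    """Train an independent binary classifier for each label column of Y."""
    return [train_logistic(X, [row[j] for row in Y]) for j in range(len(Y[0]))]


def predict_one_vs_rest(models, X, threshold=0.5):
    """Switch label j on whenever classifier j's probability reaches the threshold."""
    return [
        [int(sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b) >= threshold) for w, b in models]
        for x in X
    ]


# Toy 2-D vectors standing in for Sentence Transformer embeddings (illustrative only)
X = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9], [1.0, 1.0]]
Y = [[1, 0], [1, 0], [0, 1], [0, 1], [1, 1]]
models = fit_one_vs_rest(X, Y)
preds = predict_one_vs_rest(models, X)
```

The last row, `[1.0, 1.0]`, gets both labels because the two classifiers decide independently; that independence is what lets one-vs-rest emit any combination of labels, including none.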

In [18]:
from setfit import SetFitModel

model = SetFitModel.from_pretrained(model_id, multi_target_strategy="one-vs-rest")

model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.


Here, we've downloaded a pretrained Sentence Transformer from the Hub and added a logistic classification head to create the SetFit model. As indicated in the message, we need to train this model on some labeled examples. We can do so by using the `SetFitTrainer` class as follows:

In [19]:
from sentence_transformers.losses import CosineSimilarityLoss
from setfit import SetFitTrainer

trainer = SetFitTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    batch_size=1,
    loss_class=CosineSimilarityLoss,
    num_iterations=20,
    column_mapping={"text": "text", "labels": "label"},
)

The main arguments to notice in the trainer are the following:

* `loss_class`: The loss function to use for contrastive learning with the Sentence Transformer body.
* `num_iterations`: The number of times text pairs are generated from the training set for contrastive learning; larger values produce more pairs and longer training.
* `column_mapping`: The `SetFitTrainer` expects the inputs to be found in `text` and `label` columns. This mapping renames our `text` and `labels` columns so the training and evaluation datasets have the expected format.
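To make `loss_class` and `num_iterations` more concrete: sentence-transformers' `CosineSimilarityLoss` computes the cosine similarity of the two embeddings in a pair and takes the mean squared error against the pair's target (1.0 for similar, 0.0 for dissimilar), and the trainer derives those pairs from the labels, repeating pair generation `num_iterations` times. The stdlib sketch below illustrates both pieces. The pairing rule in the hypothetical `multilabel_pairs` (a positive partner for each label a text carries, plus one negative partner sharing none of its labels) is an assumption and may differ from setfit's own generator.

```python
import math
import random


def multilabel_pairs(texts, labels, seed=0):
    """Build (text_a, text_b, target) pairs. Assumed rule, not setfit's exact one."""
    rng = random.Random(seed)
    pairs = []
    for i, (text, row) in enumerate(zip(texts, labels)):
        # One positive pair per label the text carries
        for j, on in enumerate(row):
            if on:
                partners = [k for k, other in enumerate(labels) if other[j] and k != i]
                if partners:
                    pairs.append((text, texts[rng.choice(partners)], 1.0))
        # One negative pair with a text that shares none of its labels
        negatives = [
            k for k, other in enumerate(labels)
            if k != i and not any(a and b for a, b in zip(row, other))
        ]
        if negatives:
            pairs.append((text, texts[rng.choice(negatives)], 0.0))
    return pairs


def cosine(u, v):
    """Cosine similarity of two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def cosine_similarity_loss(vector_pairs, targets):
    """Mean squared error between each pair's cosine similarity and its target."""
    errors = [(cosine(u, v) - t) ** 2 for (u, v), t in zip(vector_pairs, targets)]
    return sum(errors) / len(errors)


pairs = multilabel_pairs(["a", "b", "c", "d"], [[1, 0], [1, 0], [0, 1], [0, 1]])
```

During training the texts in each pair are embedded by the Sentence Transformer body, and minimizing this loss pulls texts with shared labels together and pushes unrelated texts apart.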

Now that we've created a trainer, we can train it!

In [20]:
trainer.train()

Applying column mapping to training dataset
***** Running training *****
  Num examples = 1100
  Num epochs = 1
  Total optimization steps = 1100
  Total train batch size = 1

The final step is to compute the model's performance using the `evaluate()` method. The default metric is subset accuracy: the fraction of samples for which all 5 labels are predicted correctly.
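Subset accuracy is strict with five labels per article: a single wrong label makes the whole article count as wrong, and with only 5 evaluation articles a score of 0.4 means 2 of them had every label right. A stdlib sketch contrasting it with per-label (Hamming) accuracy; the label vectors are made up for illustration and are not the notebook's evaluation results:

```python
def subset_accuracy(y_true, y_pred):
    """Fraction of rows whose predicted label vector matches exactly."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def hamming_accuracy(y_true, y_pred):
    """Fraction of individual label decisions that are correct."""
    total = sum(len(t) for t in y_true)
    correct = sum(a == b for t, p in zip(y_true, y_pred) for a, b in zip(t, p))
    return correct / total


# Made-up label vectors (5 labels each), for illustration only
y_true = [[1, 1, 0, 0, 0], [1, 1, 1, 0, 1]]
y_pred = [[1, 1, 0, 0, 0], [1, 1, 0, 0, 0]]
```

Here subset accuracy is 0.5 while per-label accuracy is 0.8. scikit-learn's `accuracy_score` returns subset accuracy for multilabel indicator arrays, and its `hamming_loss` equals one minus the per-label accuracy computed here.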

In [21]:
metrics = trainer.evaluate()
metrics

Applying column mapping to evaluation dataset
***** Running evaluation *****


{'accuracy': 0.4}

Once the model is trained, you can save it locally and push it to the Hub:

In [22]:
model.save_pretrained("models/CONTACT_setfit_v1")

In [29]:
trainer.model.push_to_hub("CONTACT_setfit_v1")


'https://huggingface.co/PaulKMandal/CONTACT_setfit_v1/tree/main/'

You can now share this model with all your friends, family, and favorite pets: anyone can load it with the identifier `your-username/the-name-you-picked`, for instance:

In [30]:
from setfit import SetFitModel

model = SetFitModel.from_pretrained("PaulKMandal/CONTACT_setfit_v1")



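Run inference on the evaluation articles. The model predicts `t_mil` and `t_loc` for all five of them and none of the other three labels. Every article in the tables above has both of those labels set, so the classifier may simply be predicting the most common labels.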

In [34]:
preds = model(eval_dataset['text'])
preds

tensor([[1, 1, 0, 0, 0],
        [1, 1, 0, 0, 0],
        [1, 1, 0, 0, 0],
        [1, 1, 0, 0, 0],
        [1, 1, 0, 0, 0]])

In [35]:
# Map each prediction row to label names, using the `features` list defined earlier
[[f for f, p in zip(features, ps) if p] for ps in preds]

[['t_mil', 't_loc'],
 ['t_mil', 't_loc'],
 ['t_mil', 't_loc'],
 ['t_mil', 't_loc'],
 ['t_mil', 't_loc']]
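Hard 0/1 outputs hide how close each label was to the decision boundary. If your setfit version exposes `model.predict_proba` (check its documentation), you can inspect per-label probabilities and choose your own threshold. A small sketch of the decoding step, assuming one probability per label per article; `decode_labels` is a hypothetical helper and the probabilities are made up:

```python
def decode_labels(probabilities, label_names, threshold=0.5):
    """Return, for each row, the names of labels whose probability is >= threshold."""
    return [
        [name for name, p in zip(label_names, row) if p >= threshold]
        for row in probabilities
    ]


label_names = ["t_mil", "t_loc", "t_milcas", "t_civcas", "t_isis_vic"]
# Made-up probabilities for two articles, for illustration only
probs = [[0.97, 0.91, 0.42, 0.05, 0.38], [0.88, 0.95, 0.12, 0.02, 0.61]]
```

Lowering the threshold from 0.5 to 0.4 adds `t_milcas` to the first article here, which shows how sensitive multilabel predictions can be to this single cut-off.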