#### Using SetFit to train an embedding model for classification with a tiny labeled set

In [1]:
!pip install setfit==1.1.0 transformers==4.42.2 peft==0.10.0

#### SetFit with an example dataset

In [4]:
from setfit import SetFitModel, Trainer, TrainingArguments, sample_dataset
from datasets import load_dataset

# Initializing a new SetFit model
model = SetFitModel.from_pretrained("Alibaba-NLP/gte-base-en-v1.5", trust_remote_code=True)

# Preparing the dataset: 8 labeled examples per class (16 in total) for training, the full test split for evaluation
dataset = load_dataset("SetFit/sst2")
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8)
test_dataset = dataset["test"]

# Preparing the training arguments
args = TrainingArguments(
    batch_size=32,
    num_epochs=10,
)

# Preparing the trainer
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
)
trainer.train()

# Evaluating
metrics = trainer.evaluate(test_dataset)
print(metrics)
# => {'accuracy': 0.8978583196046128} in the run shown below

# Performing inference
preds = model.predict([
    "It's a charming and often affecting journey.",
    "It's slow -- very, very slow.",
    "A sometimes tedious film.",
])
print(preds)
# => tensor([1, 0, 0]): integer class ids (1 = positive, 0 = negative), since no `labels=` were passed to from_pretrained

model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.
Repo card metadata block was not found. Setting CardData to empty.


***** Running training *****
  Num unique pairs = 144
  Batch size = 32
  Num epochs = 10


| Step | Training Loss |
|---:|---:|
| 1 | 0.2399 |
| 50 | 0.0509 |


***** Running evaluation *****


{'accuracy': 0.8978583196046128}
tensor([1, 0, 0])
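
The training log reports `Num unique pairs = 144` for 16 training examples (8 per class), and the loss table ends at step 50. The sketch below is a back-of-the-envelope reconstruction that reproduces both numbers. It assumes SetFit's default oversampling strategy works like this: every same-class combination with replacement (an example paired with itself included) is a positive pair, every cross-class combination is a negative pair, and the smaller group is oversampled to match the larger one. This is an inferred model of the sampler, not SetFit's own code; `expected_pairs` and `expected_steps` are hypothetical helpers.

```python
from itertools import combinations
from math import ceil


def expected_pairs(class_counts):
    """Estimate the number of contrastive pairs under the assumed oversampling scheme."""
    # Positive pairs: combinations with replacement within each class, n * (n + 1) / 2.
    positives = sum(n * (n + 1) // 2 for n in class_counts)
    # Negative pairs: one pair for every two examples drawn from different classes.
    negatives = sum(a * b for a, b in combinations(class_counts, 2))
    # Oversampling (assumed): the smaller group is repeated until both groups are equally large.
    return 2 * max(positives, negatives)


def expected_steps(class_counts, batch_size, num_epochs):
    """Optimizer steps, assuming each epoch walks through all pairs once."""
    return ceil(expected_pairs(class_counts) / batch_size) * num_epochs


# SST-2 run above: 8 examples per class, batch_size=32, num_epochs=10.
print(expected_pairs([8, 8]), expected_steps([8, 8], 32, 10))  # 144 50
# Own-dataset run below: 10 examples per class in the 20-row training split, batch_size=10.
print(expected_pairs([10, 10]), expected_steps([10, 10], 10, 10))  # 220 220
```

The same arithmetic also matches the second run's `Num unique pairs = 220`, whose loss table is logged every 50 steps up to step 200 of roughly 220.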


#### SetFit with your own dataset

In [2]:
import pandas as pd
import torch
from datasets import Dataset
from setfit import SetFitModel, Trainer, TrainingArguments
from sklearn.model_selection import train_test_split

Load the labeled posts: `text`, a binary `label` (1 = relevant, 0 = irrelevant) and its name in `label_text`

In [4]:
dataset = pd.read_csv('dataset.csv')
dataset.head()

|   | text | label | label_text |
|---|---|---|---|
| 0 | Approved Disability\nHello Everyone 👋\n\nI fin... | 0 | irrelevant |
| 1 | ARTICLE: Neurologists reveal 15 subtle migrain... | 0 | irrelevant |
| 2 | Anyone else get chills (or sweats) with their ... | 1 | relevant |
| 3 | What's your migraine remedy?\nI'm on my 3rd da... | 1 | relevant |
| 4 | LPT Request: Migraine relief tips\nSo I've bee... | 1 | relevant |


In [5]:
train, test = train_test_split(dataset, test_size=0.33)
train.head(), test.head()

(                                                 text  label  label_text
 13  Lower back pain. What's stronger than ibuprofe...      1    relevant
 16  Anyone else get chills (or sweats) with their ...      1    relevant
 27  The other night, my wife (29F), son (13m) and ...      0  irrelevant
 29  Mom almost died from one & her dad died from o...      0  irrelevant
 25  \nI am very sorry. It’s extremely difficult an...      0  irrelevant,
                                                  text  label  label_text
 18  I had my first migraine since giving birth wit...      0  irrelevant
 10  Chronic UTI has made me ( 29 F) suicidal.\nSui...      1    relevant
 1   ARTICLE: Neurologists reveal 15 subtle migrain...      0  irrelevant
 26  Nausea has been a new symptom for me since las...      1    relevant
 15  Pain medication for herniated disc flare up?\n...      1    relevant)

In [6]:
# Share of posts labeled relevant (label 1) in the training split
train['label'].sum()/len(train)

0.5

In [7]:
# Share of posts labeled relevant (label 1) in the test split
test['label'].sum()/len(test)

0.6
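
The unstratified split leaves 50% relevant posts in training but 60% in the test set, and rerunning the cell reshuffles both. A minimal sketch of a reproducible, stratified alternative uses the `stratify` and `random_state` parameters of `train_test_split`. The 30-row frame below is hypothetical stand-in data (15 relevant, 15 irrelevant), not the notebook's `dataset.csv`, and the `toy_` names are chosen so the cell does not overwrite the real `train`/`test` splits.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical stand-in for dataset.csv: 30 posts, 15 relevant (1) and 15 irrelevant (0).
toy = pd.DataFrame({
    "text": [f"post {i}" for i in range(30)],
    "label": [i % 2 for i in range(30)],
})
toy["label_text"] = toy["label"].map({0: "irrelevant", 1: "relevant"})

# stratify keeps the label ratio the same in both splits; random_state makes the split repeatable.
toy_train, toy_test = train_test_split(
    toy, test_size=0.33, stratify=toy["label"], random_state=42
)

print(len(toy_train), len(toy_test))  # 20 10
print(toy_train["label"].mean(), toy_test["label"].mean())  # 0.5 0.5
```

On the real data the ratios can only match as closely as the class counts allow, but the split no longer changes between runs.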

Convert the train and test DataFrames to Hugging Face `Dataset` objects for training

In [8]:
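# reset_index(drop=True) stops the shuffled pandas index from being stored as an extra `__index_level_0__` column
train_dataset = Dataset.from_pandas(train.reset_index(drop=True))
test_dataset = Dataset.from_pandas(test.reset_index(drop=True))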
train_dataset, test_dataset

(Dataset({
     features: ['text', 'label', 'label_text'],
     num_rows: 20
 }),
 Dataset({
     features: ['text', 'label', 'label_text'],
     num_rows: 10
 }))

In [9]:
torch.cuda.empty_cache()  # release cached GPU memory left over from earlier cells
# Initializing a new SetFit model
model = SetFitModel.from_pretrained("Alibaba-NLP/gte-base-en-v1.5", trust_remote_code=True)

# Preparing the training arguments
args = TrainingArguments(
    batch_size=10,
    num_epochs=10,
)

# Preparing the trainer
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
)
trainer.train()

# Evaluating
metrics = trainer.evaluate(test_dataset)
print(metrics)

A new version of the following files was downloaded from https://huggingface.co/Alibaba-NLP/new-impl:
- configuration.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


A new version of the following files was downloaded from https://huggingface.co/Alibaba-NLP/new-impl:
- modeling.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.


***** Running training *****
  Num unique pairs = 220
  Batch size = 10
  Num epochs = 10

| Step | Training Loss |
|---:|---:|
| 1 | 0.3052 |
| 50 | 0.1441 |
| 100 | 0.0032 |
| 150 | 0.0018 |
| 200 | 0.0015 |


***** Running evaluation *****


{'accuracy': 0.8}
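
The test split has only 10 posts, so `{'accuracy': 0.8}` means 8 of 10 correct, and each misclassified post moves accuracy by 0.1. As an illustration of how loosely that estimate is pinned down, the sketch below computes a Wilson score interval, a standard binomial confidence interval swapped in here; it is not part of SetFit's evaluation, and `wilson_interval` is a hypothetical helper.

```python
from math import sqrt


def wilson_interval(correct, total, z=1.96):
    """Approximate 95% Wilson score interval for a binomial proportion (z=1.96)."""
    p = correct / total
    denom = 1 + z**2 / total
    center = (p + z**2 / (2 * total)) / denom
    half_width = z * sqrt(p * (1 - p) / total + z**2 / (4 * total**2)) / denom
    return center - half_width, center + half_width


low, high = wilson_interval(correct=8, total=10)
print(f"accuracy 0.80, 95% interval = [{low:.2f}, {high:.2f}]")
# accuracy 0.80, 95% interval = [0.49, 0.94]
```

An interval this wide suggests that differences of one or two posts between runs say little about which configuration is better.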



Make predictions on new, unlabeled posts

In [11]:
preds = model.predict([
    "That’s amazing! So glad it worked out for you. I got tested years ago and it came back negative. I thought about going in a second time but my migraines come all different times of the day.",
    """Such a wholesome ending to your awful situation. I’ve never been outside of the states and am super naive when it comes to traveling but I know if I were to travel somehow far away from home I would probably have something similar like this happen to me knowing my luck haha. But it’s so cool that the dude was there to help nurse you back to health and to be a friend.

Did you guys add each other to Facebook for the occasional check in or happy birthday wish or something? You should totally send him a postcard or leave him a nice email to read letting him know you were thinking of him since retelling this story! I’m sure it would make his day 🙂‍↕️

Thanks for sharing 💚💚""",
    "I REALLY don't want to re-live it. Just know it involves ice cream induced dysentery whilst living in a shanty shack in Managua, Nicaragua. Actually almost died.",
    """If you wake up with migraines…
I wanted to share my success story for those who didn’t know this as I didn’t. I was waking up with the majority of my migraines. I would have to call out and stay home from work often because of this. I had a few doctors in the past tell me that this was abnormal, but didn’t really explain further. Well finally, a second neurologist I saw about 2 years ago immediately referred me to a sleep study. It came back that I had mild sleep apnea and low oxygen levels when I sleep. I started CPAP therapy a little over a month ago, and I have not woken up with a single migraine since then.

With that said, I just want to make sure if the majority of the migraines anyone has are upon waking, you might want to check with a doctor and see if you can get a sleep study to make sure that isn’t the cause :)"""
])
print(preds)

tensor([0, 0, 0, 1])
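
The model returns integer class ids because no label names were given when it was created. A minimal pure-Python sketch maps them back using the `label`/`label_text` pairs seen in `dataset.csv`; `to_label_names` and `id2label` are hypothetical helpers, and in the notebook itself `dict(zip(dataset['label'], dataset['label_text']))` would build the same mapping from the loaded DataFrame. Alternatively, SetFit 1.x accepts `labels=["irrelevant", "relevant"]` in `SetFitModel.from_pretrained`, in which case `predict` should return the names directly.

```python
def to_label_names(pred_ids, id2label):
    """Map integer class ids (e.g. from preds.tolist()) to readable label names."""
    return [id2label[int(i)] for i in pred_ids]


# Hypothetical mapping built from (label, label_text) rows like those in dataset.head().
rows = [(0, "irrelevant"), (0, "irrelevant"), (1, "relevant"), (1, "relevant")]
id2label = dict(rows)

# Stand-in for tensor([0, 0, 0, 1]).tolist() from the cell above.
pred_ids = [0, 0, 0, 1]
print(to_label_names(pred_ids, id2label))
# ['irrelevant', 'irrelevant', 'irrelevant', 'relevant']
```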
