# Overview

The FastFit algorithm uses a pre-trained Sentence Transformer (ST) as the base model. It fine-tunes the base model with an in-batch contrastive loss that embeds both the texts and their class names into a shared embedding space, so that each text lies close to its own class name.
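
To make the idea of an in-batch contrastive objective concrete, here is a minimal pure-Python sketch. It is a simplified illustration and not FastFit's actual loss: the function name `in_batch_contrastive_loss`, the cosine similarity, the `temperature` value, and the toy 2-d embeddings are all assumptions made for this example. Each text is scored against every class-name embedding in the batch, and the loss is the cross-entropy of picking its own class name.

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def in_batch_contrastive_loss(text_embs, class_embs, labels, temperature=0.1):
    """Mean cross-entropy where each text's positive is its own class name
    and the other class names act as negatives (hypothetical helper)."""
    total = 0.0
    for emb, label in zip(text_embs, labels):
        logits = [cosine(emb, c) / temperature for c in class_embs]
        # log-sum-exp with max subtraction for numerical stability
        m = max(logits)
        log_norm = m + math.log(sum(math.exp(l - m) for l in logits))
        total += log_norm - logits[label]
    return total / len(text_embs)

# Toy embeddings (assumed values): two class names and two texts.
class_embs = [[1.0, 0.0], [0.0, 1.0]]
labels = [0, 1]
aligned = [[0.9, 0.1], [0.2, 0.8]]   # texts close to their own class names
swapped = [[0.1, 0.9], [0.8, 0.2]]   # texts close to the wrong class names

loss_aligned = in_batch_contrastive_loss(aligned, class_embs, labels)
loss_swapped = in_batch_contrastive_loss(swapped, class_embs, labels)
print(loss_aligned, loss_swapped)
```

The loss is small when texts sit near their own class names and large when they sit near another class, which is the behaviour fine-tuning tries to produce.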

In [1]:
!pip install -q -U transformers==4.39.3
!pip install -q -U datasets==2.18.0
!pip install -q -U fast-fit==1.2.1

In [2]:
import os
from huggingface_hub import login
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
login(token=user_secrets.get_secret("HUGGINGFACE_TOKEN"))

os.environ["WANDB_API_KEY"]=user_secrets.get_secret("WANDB_API_KEY")
os.environ["MODEL"]="sentence-transformers/paraphrase-mpnet-base-v2"
os.environ["DATASET"]="SetFit/amazon_massive_intent_en-US"
os.environ["FITMODEL"]="fastfit-mpnet-v2-amazon-mi"

Token has not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


# Loading dataset

In [3]:
from datasets import Dataset, load_dataset

ds=load_dataset(os.getenv("DATASET"))

In [4]:
import pandas as pd

df=pd.DataFrame(ds["test"])

# Helper function to select random rows
def select_random_rows(group):
    return group.sample(n=10, random_state=42)


# find classes with more than 30 rows
label_counts=df["label_text"].value_counts()
label_counts=label_counts[label_counts>30]

# restrict df to those classes only
df=df[df["label_text"].isin(label_counts.index)].reset_index(drop=True)

assert set(df["label_text"].value_counts().index.to_list())==set(label_counts.index.to_list()), "Some labels were lost"

# select 10 random rows per unique value in label_text column
train_df=df.groupby("label_text", group_keys=False).apply(select_random_rows)

# create eval dataframe by dropping the train data and selecting 10 random rows per class
eval_df=(df.drop(train_df.index).groupby("label_text",group_keys=False).apply(select_random_rows))

# create test dataframe by dropping both train and eval data
test_df=df.drop(train_df.index.to_list()+eval_df.index.to_list())

# keep only the text and label columns and reset the index
cols_to_keep=["text", "label_text"]
train_df=train_df[cols_to_keep].reset_index(drop=True)
eval_df=eval_df[cols_to_keep].reset_index(drop=True)
test_df=test_df[cols_to_keep].reset_index(drop=True)

# save the file
test_df.to_pickle("test_df.pkl")
train_df.to_pickle("train_df.pkl")
eval_df.to_pickle("eval_df.pkl")

train_ds=Dataset.from_pandas(train_df)
eval_ds=Dataset.from_pandas(eval_df)
test_ds=Dataset.from_pandas(test_df)

print(train_df.shape, eval_df.shape, test_df.shape)
train_df.head()

(350, 2) (350, 2) (1879, 2)




| | text | label_text |
|---|---|---|
| 0 | do i have any alarms set for six am tomorrow | alarm_query |
| 1 | what is the wake up time for my alarm i have s... | alarm_query |
| 2 | please tell me what alarms are on | alarm_query |
| 3 | please list all my alarms | alarm_query |
| 4 | what times do my alarms go off | alarm_query |
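
The split logic in the cell above (10 rows per class for training, 10 more for evaluation, the remainder for testing) can be sketched without pandas as follows. This is an illustrative sketch on synthetic data: `k_shot_split`, the toy `rows`, and the rule of skipping classes with fewer than `2 * k` rows are assumptions for this example, whereas the notebook itself keeps classes with more than 30 rows.

```python
import random
from collections import defaultdict

def k_shot_split(rows, k, seed=42):
    """Split (text, label) rows into disjoint train/eval/test index lists,
    with k rows per label in train and in eval (hypothetical helper)."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for idx, (_, label) in enumerate(rows):
        by_label[label].append(idx)

    train_idx, eval_idx, test_idx = [], [], []
    for indices in by_label.values():
        if len(indices) < 2 * k:
            continue  # assumption: skip classes too small for both samples
        rng.shuffle(indices)
        train_idx += indices[:k]
        eval_idx += indices[k:2 * k]
        test_idx += indices[2 * k:]
    return train_idx, eval_idx, test_idx

# Synthetic data: 3 intents with 30 utterances each.
rows = [(f"utterance {i}", f"intent_{i % 3}") for i in range(90)]
train_idx, eval_idx, test_idx = k_shot_split(rows, k=10)
print(len(train_idx), len(eval_idx), len(test_idx))
```

Because every index lands in exactly one list, no training utterance can leak into the evaluation or test sets. In the notebook, the printed shapes `(350, 2)` for train and eval are consistent with 35 classes at 10 rows each.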


# Training

In [5]:
from fastfit import FastFit, FastFitTrainer

# Load the base Sentence Transformer model and set up the trainer
trainer = FastFitTrainer(
    model_name_or_path=os.getenv("MODEL"),
    label_column_name="label_text",
    text_column_name="text",
    num_train_epochs=1,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    max_text_length=128,
    dataloader_drop_last=False,
    num_repeats=1,
    optim="adafactor",
    clf_loss_factor=0.1,
    fp16=True,
    train_dataset=train_ds,
    validation_dataset=eval_ds,
    test_dataset=test_ds,
)

model=trainer.train()


***** train metrics *****
  epoch                    =        1.0
  total_flos               =        0GF
  train_loss               =     2.7326
  train_runtime            = 0:00:21.80
  train_samples            =        350
  train_samples_per_second =     16.055
  train_steps_per_second   =      0.505


In [6]:
# Calculate metrics on the validation and test datasets
eval_metrics = trainer.evaluate()
print("Eval Accuracy: {:.2f}".format(eval_metrics["eval_accuracy"] * 100))

test_metrics = trainer.test()
print("Test Accuracy: {:.2f}".format(test_metrics["eval_accuracy"] * 100))

***** eval metrics *****
  epoch                   =        1.0
  eval_accuracy           =     0.7343
  eval_loss               =     4.5141
  eval_runtime            = 0:00:00.73
  eval_samples            =        350
  eval_samples_per_second =    475.297
  eval_steps_per_second   =      8.148
Eval Accuracy: 73.43
***** test metrics *****
  epoch                   =        1.0
  eval_accuracy           =     0.6525
  eval_loss               =     5.2223
  eval_runtime            = 0:00:03.79
  eval_samples_per_second =    495.108
  eval_steps_per_second   =      7.905
  test_samples            =       1879
Test Accuracy: 65.25


In [7]:
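# push_to_hub raised the TypeError shown below when given a repo id, so the call is left commented out
# trainer.push_to_hub("aisuko/"+os.getenv("FITMODEL"))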

TypeError: FastFitTrainer.push_to_hub() takes 1 positional argument but 2 were given

In [8]:
model.save_pretrained("aisuko/"+os.getenv("FITMODEL"))

# Inference

In [9]:
from tqdm import tqdm
from sklearn.metrics import classification_report
from transformers import AutoTokenizer, pipeline

# Step 1: Load the fine-tuned model from disk
model = FastFit.from_pretrained('aisuko/'+os.getenv("FITMODEL"))
tokenizer = AutoTokenizer.from_pretrained(os.getenv("MODEL"))
classifier = pipeline("text-classification", model=model, tokenizer=tokenizer, device="cuda")

# Step 2: Run predictions to calculate class level metrics
predictions = []
for row in tqdm(test_ds):
    predictions.append(classifier(row["text"])[0]["label"])

test_df["fastfit_predictions"] = predictions
print("FastFit Class Level Metrics:")
print(classification_report(test_df["label_text"], test_df["fastfit_predictions"]))

[ERROR|base.py:1052] 2024-05-25 05:23:03,252 >> The model 'FastFit' is not supported for text-classification. Supported models are ['AlbertForSequenceClassification', 'BartForSequenceClassification', 'BertForSequenceClassification', 'BigBirdForSequenceClassification', 'BigBirdPegasusForSequenceClassification', 'BioGptForSequenceClassification', 'BloomForSequenceClassification', 'CamembertForSequenceClassification', 'CanineForSequenceClassification', 'LlamaForSequenceClassification', 'ConvBertForSequenceClassification', 'CTRLForSequenceClassification', 'Data2VecTextForSequenceClassification', 'DebertaForSequenceClassification', 'DebertaV2ForSequenceClassification', 'DistilBertForSequenceClassification', 'ElectraForSequenceClassification', 'ErnieForSequenceClassification', 'ErnieMForSequenceClassification', 'EsmForSequenceClassification', 'FalconForSequenceClassification', 'FlaubertForSequenceClassification', 'FNetForSequenceClassification', 'FunnelForSequenceClassification', 'GemmaForSe

FastFit Class Level Metrics:
                          precision    recall  f1-score   support

             alarm_query       0.64      0.50      0.56        14
               alarm_set       0.58      0.86      0.69        21
       audio_volume_mute       0.85      0.92      0.88        12
          calendar_query       0.45      0.57      0.50       106
         calendar_remove       0.69      0.96      0.80        47
            calendar_set       0.80      0.43      0.56       189
          cooking_recipe       0.80      0.92      0.86        52
          datetime_query       0.81      0.71      0.76        68
             email_query       0.66      0.92      0.77        99
         email_sendemail       0.71      0.66      0.69        94
          general_quirky       0.32      0.10      0.15       149
              iot_coffee       0.83      0.94      0.88        16
     iot_hue_lightchange       0.80      0.25      0.38        16
        iot_hue_lightoff       0.67      0.96 
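
As a reading aid for the report, the `f1-score` column is the harmonic mean of `precision` and `recall`. The sketch below recomputes it for a few rows copied from the output above; `f1_score` is a helper written for this example (it is not the scikit-learn function), and because the inputs are already rounded to two decimals, a recomputed value could in principle differ from the report by 0.01.

```python
def f1_score(precision, recall):
    """Harmonic mean of precision and recall (0.0 when both are zero)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

# (precision, recall, reported f1) copied from the classification report
report_rows = {
    "alarm_query": (0.64, 0.50, 0.56),
    "alarm_set": (0.58, 0.86, 0.69),
    "audio_volume_mute": (0.85, 0.92, 0.88),
    "general_quirky": (0.32, 0.10, 0.15),
}

recomputed = {
    label: round(f1_score(p, r), 2) for label, (p, r, _) in report_rows.items()
}
for label, (p, r, reported) in report_rows.items():
    print(f"{label:20s} recomputed={recomputed[label]:.2f} reported={reported:.2f}")
```

The `general_quirky` row shows why the harmonic mean matters: a recall of 0.10 drags the F1 down to 0.15 even though precision is 0.32.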




# Acknowledgements

* https://medium.com/towards-artificial-intelligence/few-shot-nlp-intent-classification-d29bf85548aa
* https://github.com/IBM/fastfit