In [5]:
from sentence_transformers.losses import CosineSimilarityLoss
from setfit import SetFitModel, SetFitTrainer
import pandas as pd
import datasets
import pyarrow as pa

# Fine-tune a SetFit classifier on the "train_pairs_random_4" few-shot split
# and evaluate its accuracy on the held-out pairs.

SEED = 42  # shared by the data shuffles and the trainer so runs are reproducible
TRAIN_CSV = "../data_ready/few_shot/train_pairs_random_4.csv"
EVAL_CSV = "../data_ready/unused_pairs_for_test_data.csv"
MODEL_ID = "facebook/bart-large-mnli"

# Shuffle with a fixed seed: an unseeded .sample(frac=1) yields a different
# row order (and potentially different results) on every run.
train_dataframe = pd.read_csv(TRAIN_CSV).sample(frac=1, random_state=SEED)
eval_dataframe = pd.read_csv(EVAL_CSV).sample(frac=1, random_state=SEED)

# preserve_index=False: after shuffling, the pandas index is no longer a
# RangeIndex, so Arrow would otherwise keep it as an extra
# "__index_level_0__" column in each dataset.
train_dataset = datasets.Dataset.from_pandas(train_dataframe, preserve_index=False)
test_dataset = datasets.Dataset.from_pandas(eval_dataframe, preserve_index=False)

# Load the base model from the Hub.
# NOTE(review): facebook/bart-large-mnli is not a sentence-transformers
# checkpoint, so SetFit presumably wraps it with a newly created pooling layer
# (the captured cell output appears to begin with such a warning) — confirm
# this is the intended backbone.
model = SetFitModel.from_pretrained(MODEL_ID)

# Create the trainer.
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    loss_class=CosineSimilarityLoss,
    metric="accuracy",
    batch_size=16,
    num_iterations=20,  # number of text pairs to generate for contrastive learning
    num_epochs=1,  # number of epochs to use for contrastive learning
    seed=SEED,
    # Map the CSV columns to the text/label columns expected by the trainer.
    column_mapping={"text": "text", "class": "label"},
)

# Train, evaluate, and show the result; a bare assignment as the last
# statement would leave the notebook cell with no visible output.
trainer.train()
metrics = trainer.evaluate()
print(metrics)

# Optional: publish the fine-tuned model to the Hugging Face Hub and reload it.
# trainer.push_to_hub("my-awesome-setfit-model")
# model = SetFitModel.from_pretrained("lewtun/my-awesome-setfit-model")

# Example inference:
# preds = model(["i loved the spiderman movie!", "pineapple on pizza is the worst 🤮"])

Downloading (…)995f2/.gitattributes: 100%|██████████| 445/445 [00:00<00:00, 51.4kB/s]
Downloading (…)3fb03995f2/README.md: 100%|██████████| 3.79k/3.79k [00:00<00:00, 216kB/s]
Downloading (…)b03995f2/config.json: 100%|██████████| 1.15k/1.15k [00:00<00:00, 179kB/s]
Downloading (…)fb03995f2/merges.txt: 100%|██████████| 456k/456k [00:00<00:00, 1.09MB/s]
Downloading (…)"model.safetensors";: 100%|██████████| 1.63G/1.63G [02:23<00:00, 11.4MB/s]
Downloading (…)"pytorch_model.bin";: 100%|██████████| 1.63G/1.63G [02:22<00:00, 11.4MB/s]
Downloading (…)995f2/tokenizer.json: 100%|██████████| 1.36M/1.36M [00:02<00:00, 514kB/s]
Downloading (…)okenizer_config.json: 100%|██████████| 26.0/26.0 [00:00<00:00, 6.79kB/s]
Downloading (…)fb03995f2/vocab.json: 100%|██████████| 899k/899k [00:01<00:00, 665kB/s]
No sen

: 

: 