### Note: SetFit requires transformers==4.40.2.
There's a specific branch of the original repo that hasn't been merged yet, but it's updated with the new version - make sure to use that instead.

In [1]:
from tqdm.notebook import tqdm
import pandas as pd

In [2]:
DEV = True
model_name = "avsolatorio/GIST-small-Embedding-v0"  # train when I've got a spare two hours

In [3]:
def import_labelled_data(path="data/labelled/data.json", group_relevant=True):
    data = pd.read_json(path, encoding="latin-1")
    data["relevance"] = data["class"].apply(
        lambda x: "relevant" if x != "irrelevant" else x
    )
    return data


data = import_labelled_data(path="../../data/labelled/data.json", group_relevant=False)

# drop null classes
data = data.dropna(subset=["class"])


if DEV:
    data = data.sample(5000)


# train test split
from sklearn.model_selection import train_test_split

train_data, test_data = train_test_split(data, test_size=0.2, random_state=42)
train_data, val_data = train_test_split(train_data, test_size=0.2, random_state=42)

data.head()

Unnamed: 0,url,text,class,relevance
1941,https://www.conservationevidence.com/individua...,A replicated and controlled experiment on two ...,Bird Conservation,relevant
9204,https://budget.finance.go.ug/sites/default/fil...,LG Draft Budget Estimates 2024/25 VOTE: 921 Ru...,irrelevant,irrelevant
2202,https://www.conservationevidence.com/individua...,"A replicated, randomized, controlled study in ...",Mediterranean Farmland,relevant
7726,https://www.dbtechnologies.com/docs/299/8672/I...,Professional passive speaker MANUALE D’USO – S...,irrelevant,irrelevant
4092,https://www.conservationevidence.com/individua...,Legally protect bat habitatsA study in 2015 of...,Bat Conservation,relevant


In [4]:
from chunking import chunk_dataset_and_explode


# roughly 4 characters per token
max_len = 1024

train_data = chunk_dataset_and_explode(train_data, max_len=max_len, overlap=int(max_len * 0.2))
test_data = chunk_dataset_and_explode(test_data, max_len=max_len, overlap=int(max_len * 0.2))
val_data = chunk_dataset_and_explode(val_data, max_len=max_len, overlap=int(max_len * 0.2))

In [5]:
from datasets import Dataset

train_dataset = Dataset.from_pandas(train_data, split="train")
test_dataset = Dataset.from_pandas(test_data, split="test")
val_dataset = Dataset.from_pandas(val_data, split="val")

train_dataset

Dataset({
    features: ['chunk_id', 'url', 'text', 'class', 'relevance'],
    num_rows: 102357
})

In [6]:
from setfit import sample_dataset, SetFitModel


train_dataset = sample_dataset(train_dataset, label_column='relevance',num_samples=5,seed=42)
val_dataset = val_dataset.select(range(10))

  df = df.apply(lambda x: x.sample(min(num_samples, len(x)), random_state=seed))


In [7]:
# same args as the huggingface TrainingArguments

model = SetFitModel.from_pretrained(model_name, labels=["relevant", "irrelevant"])

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/68.0k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/52.0 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.


In [8]:
from setfit import Trainer, TrainingArguments

args = TrainingArguments(
    batch_size=16,
    num_epochs=5,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    seed=42,
)

In [9]:
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    metric="accuracy",
    column_mapping={"relevance": "label", "text": "text"},
)

Applying column mapping to the training dataset
Applying column mapping to the evaluation dataset


Map:   0%|          | 0/10 [00:00<?, ? examples/s]

In [10]:
trainer.train()

***** Running training *****
  Num unique pairs = 60
  Batch size = 16
  Num epochs = 5
  Total optimization steps = 20


AttributeError: 'TrainingArguments' object has no attribute 'eval_strategy'

In [None]:
results = trainer.evaluate(test_dataset.shuffle().select(range(100)))
results

TypeError: Trainer.evaluate() got an unexpected keyword argument 'metric'

In [None]:
print(f'Accuracy: {results["accuracy"]}')

Accuracy: 0.94


In [None]:
model.save_pretrained(f'models/{model_name}')