### Note: SetFit requires transformers==4.40.2.

This conflicts with FastFit, which requires transformers>=4.41.0, so the two libraries cannot be installed in the same environment.

In [1]:
from tqdm.notebook import tqdm
import pandas as pd

In [2]:
# When True, the data-loading cell below trains on a random sample of the labelled data instead of all of it.
DEV = True
# Model identifier handed to SetFitModel.from_pretrained and reused as the save path suffix.
model_name = "avsolatorio/GIST-small-Embedding-v0"  # train when I've got a spare two hours

In [3]:
def import_labelled_data(path="data/labelled/data.json", group_relevant=True):
    """Read the labelled JSON dataset and add a binary ``relevance`` column.

    Args:
        path: Location of the JSON file; it is read with ``latin-1`` encoding.
        group_relevant: Accepted for interface compatibility only. The current
            implementation never consults it, so the ``relevance`` column is
            added regardless of its value.

    Returns:
        The loaded DataFrame with an extra ``relevance`` column holding
        ``"irrelevant"`` where ``class`` equals ``"irrelevant"`` and
        ``"relevant"`` for every other value (missing classes included).
    """
    frame = pd.read_json(path, encoding="latin-1")
    # Collapse every fine-grained class into a two-way relevant/irrelevant label.
    frame["relevance"] = [
        "irrelevant" if label == "irrelevant" else "relevant"
        for label in frame["class"]
    ]
    return frame


# Load the labelled corpus. `group_relevant` is passed through, but import_labelled_data
# adds the binary `relevance` column regardless of its value.
data = import_labelled_data(path="../../data/labelled/data.json", group_relevant=False)

# drop null classes
data = data.dropna(subset=["class"])


if DEV:
    # Development runs use a small subset. The seed makes the subset (and therefore
    # every downstream split and metric) reproducible, and min() keeps the cell
    # working when fewer than 5000 labelled rows are available.
    data = data.sample(n=min(5000, len(data)), random_state=42)


# train test split
from sklearn.model_selection import train_test_split

# 80/20 train/test, then 80/20 train/val inside the training portion
# (64% train / 16% val / 20% test overall), both with a fixed seed.
train_data, test_data = train_test_split(data, test_size=0.2, random_state=42)
train_data, val_data = train_test_split(train_data, test_size=0.2, random_state=42)

data.head()

Unnamed: 0,url,text,class,relevance
4256,https://www.conservationevidence.com/individua...,"A replicated, site comparison study (year unsp...",Subtidal Benthic Invertebrate Conservation,relevant
5408,https://www.conservationevidence.com/individua...,"A replicated, controlled study in 2003–2004 in...",Marine Fish Conservation,relevant
4979,https://www.conservationevidence.com/individua...,A before-and-after study in 1997 in a mixed ha...,Terrestrial Mammal Conservation,relevant
5505,https://www.conservationevidence.com/individua...,A review in 2015 of electrotrawling activity i...,Marine Fish Conservation,relevant
7980,https://www.ddbst.com/files/files/ddbsp/2024/D...,"Dortmund Data Bank Retrieval, Display, Plot, a...",irrelevant,irrelevant


In [4]:
from chunking import chunk_dataset_and_explode


# Chunk-size budget; the rule of thumb used here is roughly 4 characters per token.
max_len = 1024

# Chunk every split with identical settings (overlap set to 20% of max_len),
# processing them in train -> test -> val order.
train_data, test_data, val_data = (
    chunk_dataset_and_explode(frame, max_len=max_len, overlap=int(max_len * 0.2))
    for frame in (train_data, test_data, val_data)
)

In [5]:
from datasets import Dataset

# Wrap each chunked pandas split in a `datasets.Dataset`, tagged with its split name.
train_dataset, test_dataset, val_dataset = (
    Dataset.from_pandas(frame, split=split_name)
    for frame, split_name in (
        (train_data, "train"),
        (test_data, "test"),
        (val_data, "val"),
    )
)

train_dataset

Dataset({
    features: ['chunk_id', 'url', 'text', 'class', 'relevance'],
    num_rows: 110029
})

In [6]:
from setfit import sample_dataset, SetFitModel


# Few-shot training set: up to 5 examples per `relevance` label, drawn with a fixed seed.
train_dataset = sample_dataset(train_dataset, label_column='relevance', num_samples=5, seed=42)

# Small validation set for the per-epoch evaluation. Shuffle with a fixed seed before
# taking 10 rows so the subset is a reproducible random draw rather than just the head
# of the exploded chunk table, and bound the count so a short split does not raise.
val_dataset = val_dataset.shuffle(seed=42).select(range(min(10, len(val_dataset))))

  df = df.apply(lambda x: x.sample(min(num_samples, len(x)), random_state=seed))


In [7]:
# Load the pretrained model named by `model_name` as a SetFit model and declare the two class names.
# The warning printed under this cell shows the checkpoint ships no saved classification head,
# so the head is initialised with random weights and must be trained before use.
# (Original note, which seems to describe the TrainingArguments cell further down:
# same args as the huggingface TrainingArguments.)

model = SetFitModel.from_pretrained(model_name, labels=["relevant", "irrelevant"])

model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.


In [8]:
from setfit import Trainer, TrainingArguments

# Training configuration handed to the SetFit Trainer in the next cell.
# NOTE(review): the AttributeError about `eval_strategy` shown in the training output
# presumably comes from the transformers version mismatch described in the note at the
# top of the notebook (SetFit wants transformers==4.40.2) — confirm the installed version.
args = TrainingArguments(
    batch_size=16,
    num_epochs=5,
    # Evaluation and checkpointing use the same "epoch" schedule.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    seed=42,  # same seed as the data split and the few-shot sampling
)

In [9]:
# Build the SetFit trainer from the model, the training arguments and the sampled splits.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    metric="accuracy",  # the results cell later reads results["accuracy"]
    # Rename dataset columns for the trainer: "relevance" becomes "label", "text" stays "text".
    column_mapping={"relevance": "label", "text": "text"},
)

Applying column mapping to the training dataset
Applying column mapping to the evaluation dataset


Map:   0%|          | 0/10 [00:00<?, ? examples/s]

In [10]:
# Run training with the configuration above (the logged DEV run reports 60 unique pairs and 20 optimisation steps).
trainer.train()

***** Running training *****
  Num unique pairs = 60
  Batch size = 16
  Num epochs = 5
  Total optimization steps = 20


  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  attn_output = torch.nn.functional.scaled_dot_product_attention(


AttributeError: 'TrainingArguments' object has no attribute 'eval_strategy'

In [None]:
# Score the trained model on a random subset of the held-out test chunks.
# The fixed shuffle seed keeps the reported accuracy comparable between runs, and the
# min() bound stops select() from failing when fewer than 100 test rows are available.
eval_subset = test_dataset.shuffle(seed=42).select(range(min(100, len(test_dataset))))
results = trainer.evaluate(eval_subset)
results

TypeError: Trainer.evaluate() got an unexpected keyword argument 'metric'

In [None]:
# Report the "accuracy" entry of the evaluation results computed in the previous cell.
print(f'Accuracy: {results["accuracy"]}')

Accuracy: 0.94


In [None]:
# Persist the trained model under models/<model_name>. `model_name` contains a "/",
# so the target is a nested directory (models/avsolatorio/GIST-small-Embedding-v0).
# NOTE(review): with DEV = True this saves a model trained on the DEV subset only — confirm that is intended.
model.save_pretrained(f'models/{model_name}')