In [1]:
from tqdm.notebook import tqdm
import pandas as pd

In [2]:
# When True, the labelled data is subsampled (see the `if DEV:` branch below) for quicker runs.
DEV = True
# Model identifier passed to FastFitTrainer as model_name_or_path; also reused to build
# the models/<model_name> output and save paths.
model_name = "avsolatorio/GIST-small-Embedding-v0"  # train when I've got a spare two hours

In [3]:
def import_labelled_data(path="data/labelled/data.json", group_relevant=True):
    """Load the labelled JSON file and add a binary ``relevance`` column.

    Args:
        path: Location of the JSON file, read with ``pd.read_json``.
        group_relevant: Accepted for backward compatibility but not used by
            this function; the ``relevance`` column is always added.

    Returns:
        The loaded DataFrame with an extra ``relevance`` column:
        ``"irrelevant"`` where ``class`` is ``"irrelevant"``, ``"relevant"``
        for any other non-null class, and null where ``class`` is null.
    """
    # NOTE(review): the file is decoded as latin-1 - confirm this matches how it was written.
    data = pd.read_json(path, encoding="latin-1")
    # na_action="ignore" keeps unlabelled rows null instead of silently
    # marking them "relevant" (NaN != "irrelevant" evaluates to True).
    data["relevance"] = data["class"].map(
        lambda label: "irrelevant" if label == "irrelevant" else "relevant",
        na_action="ignore",
    )
    return data


data = import_labelled_data(path="../../data/labelled/data.json", group_relevant=False)

# drop null classes
data = data.dropna(subset=["class"])


if DEV:
    # Seeded like the splits below so the DEV subset is the same on every run;
    # capped at len(data) because sample() raises when n exceeds the row count.
    data = data.sample(n=min(5000, len(data)), random_state=42)


# train test split
# The split is done on whole rows here; chunking happens afterwards in a later cell.
from sklearn.model_selection import train_test_split

# 80/20 train/test, then 80/20 of the remaining train rows into train/val.
train_data, test_data = train_test_split(data, test_size=0.2, random_state=42)
train_data, val_data = train_test_split(train_data, test_size=0.2, random_state=42)

data.head()

Unnamed: 0,url,text,class,relevance
9496,https://dce-uae.com/wp-content/uploads/2023/04...,FASHION ARENA PRAGUE OUTLET AMSTERDAM I GOTHEN...,irrelevant,irrelevant
3783,https://www.conservationevidence.com/individua...,This study is summarised as evidence for the f...,Shrubland and Heathland Conservation,relevant
4919,https://www.conservationevidence.com/individua...,Provide supplementary food to increase reprodu...,Terrestrial Mammal Conservation,relevant
1378,https://www.conservationevidence.com/individua...,"A replicated, randomised, controlled study fro...",Farmland Conservation,relevant
3811,https://www.conservationevidence.com/individua...,This study is summarised as evidence for the f...,Peatland Conservation,relevant


In [4]:
from chunking import chunk_dataset_and_explode


# roughly 4 characters per token
max_len = 2048
# Neighbouring chunks share 20% of max_len.
chunk_overlap = int(max_len * 0.2)

# Apply the same chunking to every split, in the order train, test, val.
train_data, test_data, val_data = (
    chunk_dataset_and_explode(split_frame, max_len=max_len, overlap=chunk_overlap)
    for split_frame in (train_data, test_data, val_data)
)

In [5]:
from datasets import Dataset

# Wrap each chunked DataFrame in a Dataset tagged with its split name.
train_dataset, test_dataset, val_dataset = (
    Dataset.from_pandas(split_frame, split=split_name)
    for split_frame, split_name in (
        (train_data, "train"),
        (test_data, "test"),
        (val_data, "val"),
    )
)

train_dataset

Dataset({
    features: ['chunk_id', 'url', 'text', 'class', 'relevance'],
    num_rows: 46603
})

In [6]:
from fastfit import sample_dataset, FastFitTrainer


# Few-shot training set: 20 examples per `relevance` label.
train_dataset = sample_dataset(train_dataset, label_column='relevance',num_samples_per_label=20,seed=42)
# Shuffle, then keep at most 30 validation and 1000 test rows. The size is capped at
# the number of available rows because select() fails on out-of-range indices.
val_dataset = val_dataset.shuffle(seed=42).select(range(min(30, len(val_dataset))))
test_dataset = test_dataset.shuffle(seed=42).select(range(min(1000, len(test_dataset))))

In [7]:
# same args as the huggingface TrainingArguments

#! had to modify FastFitTrainer at /fastfit/train.py, line 879, to add trust_remote_code=True to the loading of the 'accuracy' metric
#! don't know why it's not default, since accuracy is the default in fastfit

#* note that since SetFit uses evaluation_strategy as the argument name rather than eval_strategy
#* I had to change it in the FastFitTrainer call below
#* if using the latest transformers version (transformers>=4.41.0), use eval_strategy

#! another change in FastFitTrainer, also at line 879; commented out the fixed version above
#! since load_metric is deprecated in favour of evaluate.load()
#! using evaluate means we can use evaluate.combine(), which lets us calculate multiple metrics at once
#! just used an if/else to check if the metric_name is a list or not, then called load or combine accordingly

# Gather the trainer settings in one mapping, grouped by concern, then unpack
# them into the constructor.
trainer_settings = dict(
    # model and where its outputs go
    model_name_or_path=model_name,
    output_dir=f"models/{model_name}",
    overwrite_output_dir=True,
    # datasets and the columns to read from them
    train_dataset=train_dataset,
    validation_dataset=val_dataset,
    test_dataset=test_dataset,
    label_column_name="relevance",
    text_column_name="text",
    max_text_length=2048,
    # training schedule
    num_train_epochs=20,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_repeats=2,
    seed=42,
    # evaluate once per epoch, reporting both metrics
    evaluation_strategy="epoch",
    metric_name=["precision", "accuracy"],
)

trainer = FastFitTrainer(**trainer_settings)





Flattening the indices:   0%|          | 0/30 [00:00<?, ? examples/s]

Flattening the indices:   0%|          | 0/1000 [00:00<?, ? examples/s]

Running tokenizer on dataset to infer max length for both query and document:   0%|          | 0/40 [00:00<?, …

Running tokenizer on dataset to infer max length for both query and document:   0%|          | 0/30 [00:00<?, …

Running tokenizer on dataset to infer max length for both query and document:   0%|          | 0/1000 [00:00<?…

Running tokenizer on dataset:   0%|          | 0/40 [00:00<?, ? examples/s]

Running tokenizer on dataset:   0%|          | 0/30 [00:00<?, ? examples/s]

Running tokenizer on dataset:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [8]:

#! another fastfit library modification
#! in /fastfit/train.py, line 971, change ignore_keys_for_eval from type set to a list
#! since it gets concatenated to a list later on

# Run training with the trainer configured above; the return value is kept as `model`
# and is later used by the model.save_pretrained(...) cell.
# NOTE(review): the saved output of this cell ends in KeyboardInterrupt, so on that run
# `model` was not assigned here - confirm training finishes before relying on `model`.
model = trainer.train()

	per_device_train_batch_size: 16 (from args) != 8 (from trainer_state.json)


  0%|          | 0/100 [00:00<?, ?it/s]

  attn_output = torch.nn.functional.scaled_dot_product_attention(


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 4.124362468719482, 'eval_precision': 1.0, 'eval_accuracy': 1.0, 'eval_runtime': 45.5148, 'eval_samples_per_second': 0.659, 'eval_steps_per_second': 0.044, 'epoch': 4.0}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 3.9607436656951904, 'eval_precision': 1.0, 'eval_accuracy': 1.0, 'eval_runtime': 20.8665, 'eval_samples_per_second': 1.438, 'eval_steps_per_second': 0.096, 'epoch': 5.0}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 3.925065517425537, 'eval_precision': 1.0, 'eval_accuracy': 1.0, 'eval_runtime': 14.9964, 'eval_samples_per_second': 2.0, 'eval_steps_per_second': 0.133, 'epoch': 6.0}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 3.86073899269104, 'eval_precision': 1.0, 'eval_accuracy': 1.0, 'eval_runtime': 15.1379, 'eval_samples_per_second': 1.982, 'eval_steps_per_second': 0.132, 'epoch': 7.0}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 3.850677967071533, 'eval_precision': 1.0, 'eval_accuracy': 1.0, 'eval_runtime': 15.0178, 'eval_samples_per_second': 1.998, 'eval_steps_per_second': 0.133, 'epoch': 8.0}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 3.8455774784088135, 'eval_precision': 1.0, 'eval_accuracy': 1.0, 'eval_runtime': 14.9971, 'eval_samples_per_second': 2.0, 'eval_steps_per_second': 0.133, 'epoch': 9.0}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 3.8384861946105957, 'eval_precision': 1.0, 'eval_accuracy': 1.0, 'eval_runtime': 14.9307, 'eval_samples_per_second': 2.009, 'eval_steps_per_second': 0.134, 'epoch': 10.0}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 3.8319671154022217, 'eval_precision': 1.0, 'eval_accuracy': 1.0, 'eval_runtime': 14.9831, 'eval_samples_per_second': 2.002, 'eval_steps_per_second': 0.133, 'epoch': 11.0}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 3.826968193054199, 'eval_precision': 1.0, 'eval_accuracy': 1.0, 'eval_runtime': 14.9447, 'eval_samples_per_second': 2.007, 'eval_steps_per_second': 0.134, 'epoch': 12.0}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 3.8231022357940674, 'eval_precision': 1.0, 'eval_accuracy': 1.0, 'eval_runtime': 14.8871, 'eval_samples_per_second': 2.015, 'eval_steps_per_second': 0.134, 'epoch': 13.0}


KeyboardInterrupt: 

In [None]:
# Evaluate with the trainer and keep the returned metrics; results["eval_accuracy"] is read below.
# NOTE(review): presumably scored on the validation_dataset passed to FastFitTrainer - confirm.
results = trainer.evaluate()

  0%|          | 0/5 [00:00<?, ?it/s]

***** eval metrics *****
  epoch                   =        5.0
  eval_accuracy           =       0.95
  eval_loss               =     3.8564
  eval_runtime            = 0:00:04.39
  eval_samples            =         20
  eval_samples_per_second =      4.546
  eval_steps_per_second   =      1.137


In [None]:
# Report the accuracy entry from the metrics returned by trainer.evaluate().
print(f'Accuracy: {results["eval_accuracy"]}')

Accuracy: 0.95


In [None]:
# Save the model under models/<model_name>, the same path given to the trainer as output_dir.
# NOTE(review): depends on `model` having been assigned by the trainer.train() cell.
model.save_pretrained(f'models/{model_name}')

In [None]:
# Run the trainer's test pass; the returned metrics dict is displayed as the cell output.
# NOTE(review): presumably scored on the test_dataset passed to FastFitTrainer - confirm.
trainer.test()

  0%|          | 0/25 [00:00<?, ?it/s]

***** test metrics *****
  epoch                   =        5.0
  eval_accuracy           =       0.88
  eval_loss               =     3.9806
  eval_runtime            = 0:00:22.43
  eval_samples_per_second =      4.457
  eval_steps_per_second   =      1.114
  test_samples            =        100


{'eval_loss': 3.9805710315704346,
 'eval_accuracy': 0.88,
 'eval_runtime': 22.4386,
 'eval_samples_per_second': 4.457,
 'eval_steps_per_second': 1.114,
 'epoch': 5.0,
 'test_samples': 100}

In [None]:
# Draw one random row from the chunked test DataFrame (test_data, not the test_dataset
# given to the trainer) and display it.
# NOTE(review): no random_state is passed, so a different row may be drawn on each run.
example = test_data.sample(1)
example

Unnamed: 0,chunk_id,url,text,class,relevance
11677,5044,https://www.conservationevidence.com/individua...,Release translocated/captive-bred mammals in f...,Terrestrial Mammal Conservation,relevant
