In [None]:
#colab
# Colab-only environment setup, deliberately disabled: the commands sit inside a bare
# ''' string literal, so running this cell outside Colab is a no-op.
# To use it on Colab, delete the surrounding ''' quotes. It then mounts Google Drive,
# changes into the project folder, pulls the latest code and installs setfit.
# NOTE(review): consider `%pip install -q setfit==<version>` to pin the version used
# for the reported results.
'''
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
%cd /content/drive/MyDrive/aml_final/aml_final/
! git pull
! pip install setfit
'''

In [None]:
import torch
from datasets import load_dataset, Dataset, concatenate_datasets
from setfit import SetFitModel, Trainer, TrainingArguments, sample_dataset

from train.reporter import Reporter
from train.metrics import camprehesive_metrics

In [None]:
# Load SST-2 from the Hugging Face hub.
# The "train" split is the pool that few-shot training sets are sampled from.
# The "validation" split is used as the fixed evaluation set for every run.
dataset = load_dataset("sst2")
full_train_dataset, eval_dataset = dataset["train"], dataset["validation"]

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Number of label classes; the training-set helpers below work with labels 0 and 1.
num_classes = 2
# Sentence-embedding backbone that SetFit fine-tunes.
model_name = "BAAI/bge-base-en-v1.5"

def model_init():
    """Build a fresh SetFit model for one training run.

    Each call loads a new copy of ``model_name`` with a differentiable
    classification head sized to ``num_classes`` outputs. The model is moved to
    ``device`` before being returned. It is handed to ``Trainer(model_init=...)``,
    so every run gets its own freshly loaded model.
    """
    model = SetFitModel.from_pretrained(
        model_name,
        use_differentiable_head=True,
        head_params={"out_features": num_classes},
    )
    return model.to(device)

def create_random_subset(dataset: Dataset, num_samples: int = 16, max_attempts: int = 100) -> Dataset:
    """Draw ``num_samples`` random rows from ``dataset`` that contain every class.

    The dataset is reshuffled and its first ``num_samples`` rows are taken. This is
    repeated until the draw contains at least ``num_classes`` distinct labels
    (``num_classes`` is the notebook-level constant).

    Args:
        dataset: Source dataset with a ``"label"`` column.
        num_samples: Number of rows in the returned subset.
        max_attempts: Upper bound on the number of draws. The previous recursive
            version retried forever (ending in ``RecursionError``) when no valid
            draw was possible.

    Returns:
        A ``Dataset`` with ``num_samples`` rows covering at least ``num_classes`` labels.

    Raises:
        ValueError: If the request can never be satisfied (too few classes in
            ``dataset``, or ``num_samples < num_classes``), or if no valid draw is
            found within ``max_attempts`` attempts.
    """
    # Fail fast on requests that no amount of reshuffling can satisfy.
    available_classes = len(set(dataset["label"]))
    if available_classes < num_classes or num_samples < num_classes:
        raise ValueError(
            f"Cannot draw {num_samples} rows covering {num_classes} classes from a "
            f"dataset with {available_classes} distinct labels."
        )
    for _ in range(max_attempts):
        subset = dataset.shuffle().select(range(num_samples))
        if len(set(subset["label"])) >= num_classes:
            return subset
        print("Random subset is missing a class; resampling.")
    raise ValueError(
        f"No subset of {num_samples} rows covering {num_classes} classes found "
        f"in {max_attempts} attempts."
    )


def create_train_sets(full_train_dataset: Dataset, num_samples: int = 16):
    """Build the five training sets compared in this experiment.

    All sets come from a single random shuffle of ``full_train_dataset``. The
    dataset is expected to have a ``"label"`` column with labels ``0`` and ``1``.

    * ``small_dataset``: balanced, ``num_samples // num_classes`` rows per label.
    * ``unbalanced_0``: ``small_dataset`` plus ``num_samples`` extra label-0 rows.
    * ``unbalanced_1``: ``small_dataset`` plus ``num_samples`` extra label-1 rows.
    * ``big_dataset``: ``small_dataset`` plus both extra pools.
    * ``random_subset``: ``num_samples`` random rows of ``big_dataset`` that
      contain every class.

    The extra rows are taken from shuffled positions after those used for
    ``small_dataset``, so no example appears twice within a set. The previous
    version drew them independently and could duplicate rows.

    Args:
        full_train_dataset: Pool of labelled training examples.
        num_samples: Total size of ``small_dataset`` / ``random_subset``, and the
            size of each extra per-label pool.

    Returns:
        Tuple ``(small_dataset, unbalanced_0, unbalanced_1, big_dataset, random_subset)``.

    Raises:
        ValueError: If label 0 or label 1 has fewer than
            ``num_samples // num_classes + num_samples`` rows.
    """
    shuffled = full_train_dataset.shuffle()
    per_class = num_samples // num_classes
    needed = per_class + num_samples

    # Read the label column once (instead of two full `filter` passes) and collect
    # each label's positions in shuffled order.
    labels = shuffled["label"]
    positions_0 = [pos for pos, label in enumerate(labels) if label == 0]
    positions_1 = [pos for pos, label in enumerate(labels) if label == 1]
    if len(positions_0) < needed or len(positions_1) < needed:
        raise ValueError(
            f"Need at least {needed} rows per label, got "
            f"{len(positions_0)} (label 0) and {len(positions_1)} (label 1)."
        )

    # Sorting positions of an already-shuffled dataset yields a random label interleaving.
    small_dataset = shuffled.select(sorted(positions_0[:per_class] + positions_1[:per_class]))
    # Disjoint from small_dataset: these positions come after the ones used above.
    extra_0 = shuffled.select(positions_0[per_class:needed])
    extra_1 = shuffled.select(positions_1[per_class:needed])

    big_dataset = concatenate_datasets([small_dataset, extra_0, extra_1])
    unbalanced_0 = concatenate_datasets([small_dataset, extra_0])
    unbalanced_1 = concatenate_datasets([small_dataset, extra_1])
    random_subset = create_random_subset(big_dataset, num_samples=num_samples)
    return small_dataset, unbalanced_0, unbalanced_1, big_dataset, random_subset


In [None]:
# Result collector shared by every run of the experiment loop below. Each run calls
# reporter.report(...). Presumably the results are written to this CSV; confirm in
# train/reporter.py.
reporter = Reporter("colin_setfit_eval.csv")

In [None]:
# Experiment grid: num_repeats random draws x 5 training-set variants x each
# contrastive-pair budget in iteration_grid. Every combination trains a fresh model
# (model_init) and is recorded by the reporter.
num_samples = 16
num_repeats = 5
iteration_grid = [10, 20, 30]  # SetFit num_iterations values (was range(10, 31, 10))

for i in range(num_repeats):
    # Pass num_samples explicitly; previously the default of 16 was used regardless
    # of the value set above.
    small_dataset, unbalanced_0, unbalanced_1, big_dataset, random_subset = create_train_sets(
        full_train_dataset, num_samples=num_samples
    )
    train_sets = [small_dataset, unbalanced_0, unbalanced_1, big_dataset, random_subset]
    # `train_dataset` (not `dataset`) so the DatasetDict loaded earlier is not clobbered.
    for train_dataset in train_sets:
        for iterations in iteration_grid:
            args = TrainingArguments(num_iterations=iterations)
            trainer = Trainer(
                model_init=model_init,
                train_dataset=train_dataset,
                eval_dataset=eval_dataset,
                args=args,
                metric=camprehesive_metrics,
                column_mapping={"sentence": "text", "label": "label"},
            )
            trainer.train()
            reporter.report(trainer, train_dataset, model_name=model_name, i=i, iterations=iterations)