In [None]:
#colab
# Colab-only environment setup, deliberately disabled: the lines are wrapped in a
# string literal that is evaluated and discarded, so none of them run here.
# To use on Colab, remove the surrounding ''' quotes so IPython executes the magics:
# mount Google Drive, cd into the project checkout, pull the latest code, and
# install setfit.
# NOTE(review): `! pip install setfit` is unpinned; consider `%pip install -q setfit==<version>`
# so reruns resolve the same SetFit API this notebook was written against.
'''
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
%cd /content/drive/MyDrive/aml_final/aml_final/
! git pull
! pip install setfit
'''

import gc
from typing import Optional

import torch
from datasets import load_dataset, Dataset, concatenate_datasets
from setfit import SetFitModel, Trainer, TrainingArguments, sample_dataset

from train.reporter import Reporter
from train.metrics import comprehensive_metrics

In [None]:
# Load dair-ai/emotion and create the shared objects used by the training cells.
dataset = load_dataset("dair-ai/emotion")
full_train_dataset = dataset["train"]
eval_dataset = dataset["validation"]
# Collects per-run results; presumably persisted to model_comparison.csv (see train.reporter).
reporter = Reporter("model_comparison.csv")

# Read the class count from the label schema instead of hard-coding it
# (6 for dair-ai/emotion), so it can never disagree with the data.
num_classes = full_train_dataset.features["label"].num_classes
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:

def create_random_subset(
    dataset: Dataset,
    num_samples: int = 16,
    seed: Optional[int] = None,
    max_attempts: int = 1000,
) -> Dataset:
    """Draw ``num_samples`` random rows from ``dataset`` that cover every class.

    Rejection sampling: shuffle, take the first ``num_samples`` rows, and retry
    until all ``num_classes`` labels appear at least once.

    Args:
        dataset: Source dataset with a ``"label"`` column.
        num_samples: Number of rows in the returned subset.
        seed: Optional base seed; attempt ``i`` shuffles with ``seed + i`` so the
            result is reproducible. ``None`` keeps the unseeded behaviour.
        max_attempts: Maximum number of draws before giving up.

    Returns:
        A ``Dataset`` of ``num_samples`` rows containing every label.

    Raises:
        ValueError: If ``num_samples < num_classes`` (full coverage is impossible),
            if ``dataset`` itself lacks some class, or if no covering draw is found
            within ``max_attempts``.
    """
    # Both of these conditions previously caused unbounded recursion.
    if num_samples < num_classes:
        raise ValueError(
            f"num_samples={num_samples} cannot cover all {num_classes} classes"
        )
    if len(set(dataset["label"])) < num_classes:
        raise ValueError(f"dataset does not contain all {num_classes} classes")

    for attempt in range(max_attempts):
        attempt_seed = None if seed is None else seed + attempt
        subset = dataset.shuffle(seed=attempt_seed).select(range(num_samples))
        if len(set(subset["label"])) >= num_classes:
            return subset
        print(f"Random subset attempt {attempt + 1} missed a class; resampling")

    raise ValueError(
        f"no subset of {num_samples} rows covering all {num_classes} classes "
        f"found in {max_attempts} attempts"
    )


def create_train_sets(
    full_train_dataset: Dataset,
    num_samples: int = 16,
    seed: Optional[int] = None,
):
    """Build the balanced / unbalanced training sets for the imbalance experiment.

    Args:
        full_train_dataset: Pool to draw from; needs ``"label"`` and ``"text"`` columns.
        num_samples: Total budget of the balanced set. It is divided by
            ``num_classes`` because ``sample_dataset`` presumably samples per label.
            It is also the number of extra label-0 / label-1 rows added in the
            unbalanced sets.
        seed: Optional seed for the initial shuffle (``None`` = unseeded, as before).

    Returns:
        Tuple ``(small_dataset, unbalanced_0, unbalanced_1, big_dataset, random_subset)``:

        - small_dataset: ``num_samples // num_classes`` rows per label.
        - unbalanced_0: small_dataset plus ``num_samples`` additional label-0 rows.
        - unbalanced_1: small_dataset plus ``num_samples`` additional label-1 rows.
        - big_dataset: small_dataset plus both sets of additional rows.
        - random_subset: ``num_samples`` rows of big_dataset covering every label.

    Raises:
        ValueError: If ``num_samples < num_classes`` (the balanced set would be empty).
    """
    if num_samples < num_classes:
        raise ValueError(
            f"num_samples={num_samples} is smaller than num_classes={num_classes}; "
            "the balanced set would be empty"
        )
    shuffled = full_train_dataset.shuffle(seed=seed)
    small_dataset = sample_dataset(
        shuffled, label_column="label", num_samples=num_samples // num_classes
    )

    # Exclude texts already in small_dataset, so the extra rows are genuinely new
    # examples rather than duplicates of the balanced set.
    seen_texts = set(small_dataset["text"])
    extra_0 = shuffled.filter(
        lambda e: e["label"] == 0 and e["text"] not in seen_texts
    ).select(range(num_samples))
    extra_1 = shuffled.filter(
        lambda e: e["label"] == 1 and e["text"] not in seen_texts
    ).select(range(num_samples))

    big_dataset = concatenate_datasets([small_dataset, extra_0, extra_1])
    unbalanced_0 = concatenate_datasets([small_dataset, extra_0])
    unbalanced_1 = concatenate_datasets([small_dataset, extra_1])
    random_subset = create_random_subset(big_dataset, num_samples=num_samples)
    return small_dataset, unbalanced_0, unbalanced_1, big_dataset, random_subset

In [None]:
def run_train(dataset: Dataset, reporter: Reporter, model_name: str, num_iterations: int = 20):
    """Train a SetFit model with a differentiable head on ``dataset`` and report it.

    Evaluation uses the notebook-level ``eval_dataset`` and ``comprehensive_metrics``.
    The model is created on ``device`` with a head of ``num_classes`` outputs.
    Memory is released after the run, even if training or reporting fails.

    Args:
        dataset: Training split with ``"text"`` and ``"label"`` columns.
        reporter: Receives the trainer and training set via ``reporter.report``.
        model_name: Pretrained model id passed to ``SetFitModel.from_pretrained``.
        num_iterations: Forwarded to ``TrainingArguments(num_iterations=...)``.
    """
    args = TrainingArguments(num_iterations=num_iterations)

    def model_init():
        # Builds the model the Trainer will train.
        return SetFitModel.from_pretrained(
            model_name,
            use_differentiable_head=True,
            head_params={"out_features": num_classes},
        ).to(device)

    trainer = Trainer(
        model_init=model_init,
        train_dataset=dataset,
        eval_dataset=eval_dataset,
        args=args,
        metric=comprehensive_metrics,
        column_mapping={"text": "text", "label": "label"},
    )
    try:
        trainer.train()
        reporter.report(trainer, dataset, model_name=model_name)
    finally:
        # gc.collect() cannot reclaim the model while `trainer` is still bound.
        # Drop the local reference first, then return cached CUDA blocks so the
        # next run starts with free GPU memory.
        # NOTE(review): if Reporter keeps a reference to the trainer, it stays alive.
        del trainer
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

In [None]:
# Few-shot run: 5 random draws (seeds 0-4) of the training pool, each trained
# with both backbones; results go to `reporter`.
# NOTE(review): sample_dataset appears to sample `num_samples` rows *per label*
# (create_train_sets divides by num_classes for that reason), so the total is
# num_samples * num_classes. Confirm this is the intended budget.
num_samples = 8
for seed in range(5):
    # Use a dedicated name so the global `dataset` (the DatasetDict) is not rebound
    # to a per-seed Dataset.
    train_subset = sample_dataset(full_train_dataset, num_samples=num_samples, seed=seed)
    for model_name in ["sentence-transformers/all-mpnet-base-v2", "WhereIsAI/UAE-Large-V1"]:
        run_train(train_subset, reporter, model_name)

In [None]:
# Larger-budget run: same protocol as the 8-sample experiment, with 64 samples.
# NOTE(review): sample_dataset appears to sample `num_samples` rows *per label*,
# so the total is num_samples * num_classes. Confirm this is the intended budget.
num_samples = 64
for seed in range(5):
    # Use a dedicated name so the global `dataset` (the DatasetDict) is not rebound
    # to a per-seed Dataset.
    train_subset = sample_dataset(full_train_dataset, num_samples=num_samples, seed=seed)
    for model_name in ["sentence-transformers/all-mpnet-base-v2", "WhereIsAI/UAE-Large-V1"]:
        run_train(train_subset, reporter, model_name)