In [None]:
# Colab-only bootstrap, kept disabled by wrapping it in a bare string literal
# (the string is evaluated and discarded, so none of these commands run).
# The commands mount Google Drive, change into the project checkout, pull the
# latest commits and install `setfit`. To use them on Colab, move the lines
# into their own cell above the imports and run that cell first.
'''
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
%cd /content/drive/MyDrive/aml_final/aml_final/
! git pull
! pip install setfit
'''
from datasets import load_dataset
from setfit import SetFitModel, Trainer, TrainingArguments
import torch, gc, time

from data.dataset_config import DatasetConfig
from train.active_learning import create_random_subset
from train.reporter import Reporter
from train.metrics import camprehesive_metrics

In [None]:
# --- Experiment configuration ------------------------------------------------
# Size of the labelled subset passed to the trainer as `train_dataset`.
num_samples = 100
# Number of texts used for the inference-timing run. This count has its own
# name so it is not overwritten by the sampled texts further down (the
# original rebound one name from an int to a list of texts).
num_inference_samples = 1000
dataset_config = DatasetConfig(text_column="text", num_classes=6)

dataset = load_dataset("dair-ai/emotion")
subset = create_random_subset(dataset["train"], dataset_config, num_samples)
# Only the text column is kept: these are the raw inputs handed to
# `predict_proba` when measuring inference time.
inference_samples = create_random_subset(
    dataset["train"], dataset_config, num_inference_samples
)[dataset_config.text_column]

# Prefer the GPU when one is visible to torch, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): "time.csv" is presumably the reporter's output file -- confirm
# against train.reporter.Reporter.
final_reporter = Reporter("time.csv", label_column=dataset_config.label_column)

# NOTE(review): per the SetFit docs a tuple for `num_epochs` is
# (embedding epochs, classifier epochs) -- confirm for the installed version.
cycle_args = TrainingArguments(num_iterations=10, num_epochs=(1, 8))
# Sentence-embedding backbones to benchmark, one training run each.
models = [
    "thenlper/gte-small",
    "sentence-transformers/all-mpnet-base-v2",
    "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
    "WhereIsAI/UAE-Large-V1",
]

In [None]:
# Benchmark every backbone in `models`: fine-tune on `subset`, time training
# and inference separately, evaluate on the validation split, and hand one
# record per model to `final_reporter`.
for model_name in models:
    model = SetFitModel.from_pretrained(
        model_name,
        use_differentiable_head=True,
        head_params={"out_features": dataset_config.num_classes},
    ).to(device)
    trainer = Trainer(
        model=model,
        args=cycle_args,
        train_dataset=subset,
        eval_dataset=dataset["validation"],
        metric=camprehesive_metrics,
    )

    # CUDA kernels are launched asynchronously, so the device is synchronised
    # before each clock read; otherwise the timer can stop while work is still
    # queued. perf_counter() is used because it is monotonic and intended for
    # measuring intervals, unlike time.time().
    on_cuda = device.type == "cuda"

    start_time = time.perf_counter()
    trainer.train()
    if on_cuda:
        torch.cuda.synchronize()
    train_time = time.perf_counter() - start_time

    start_time = time.perf_counter()
    probs = trainer.model.predict_proba(inference_samples)
    if on_cuda:
        torch.cuda.synchronize()
    inference_time = time.perf_counter() - start_time

    metrics = trainer.evaluate()
    final_reporter.report(
        dataset=subset,
        model=model_name,
        train_time=train_time,
        # The misspelled keyword is kept on purpose: it is part of the
        # reporter's existing interface / output schema.
        infernce_time=inference_time,
        # Report how many texts were timed, not the texts themselves.
        # NOTE(review): assumes the reporter wants the count -- confirm
        # against train.reporter.Reporter.report.
        inference_samples=len(inference_samples),
        **metrics,
    )

    # Drop every reference to this model before collecting; gc.collect() alone
    # frees nothing while `model`/`trainer` are still bound, and the next
    # from_pretrained() would load with the previous model still resident.
    del model, trainer, probs
    gc.collect()
    if on_cuda:
        # Return cached, now-unused GPU blocks to the driver.
        torch.cuda.empty_cache()
    