In [None]:
#colab
# Colab-only setup, kept disabled for local runs. Uncomment these lines when
# running on Google Colab (mounts Drive, syncs the repo, installs SetFit).
# `%pip` (rather than `!pip`) installs into the running kernel's environment.
# from google.colab import drive
# drive.mount('/content/drive', force_remount=True)
# %cd /content/drive/MyDrive/aml_final/aml_final/
# !git pull
# %pip install setfit
from pathlib import Path

from datasets import load_dataset, Dataset
from setfit import SetFitModel, Trainer, TrainingArguments
import torch, gc

from data.dataset_config import DatasetConfig
from train.active_learning import ActiveTrainer, create_random_subset
from train.active_learning_config import ActiveLearningConfig
from train.reporter import Reporter
from train.metrics import camprehesive_metrics

In [None]:
# Experiment configuration: dataset, splits, device and CSV reporters.
samples_per_cycle = 12  # presumably samples queried per active-learning cycle — TODO confirm against the experiment loop
dataset_name = "dair-ai/emotion"
dataset = load_dataset(dataset_name)
dataset_config = DatasetConfig(text_column="text", num_classes=6)

# Evaluate on a fixed, unshuffled prefix of the validation split.
eval_size = 500
train_dataset = dataset["train"]
eval_dataset = dataset["validation"].select(range(eval_size))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# dataset_name is a Hub id containing "/", so these report paths point into a
# "dair-ai/" sub-directory. Create it up front (no-op if it already exists) so
# writing the CSVs cannot fail on a missing directory — assumes Reporter writes
# to this path relative to the working directory; the paths are unchanged.
final_report_path = dataset_name + "_final.csv"
cycle_report_path = dataset_name + "_cycle.csv"
for report_path in (final_report_path, cycle_report_path):
    Path(report_path).parent.mkdir(parents=True, exist_ok=True)

final_reporter = Reporter(final_report_path, label_column=dataset_config.label_column)
cycle_reporter = Reporter(cycle_report_path, report_train_args=False, label_column=dataset_config.label_column)


def after_train_callback(trainer: Trainer, dataset: Dataset, run_id: int):
    """Write one row per active-learning cycle to the per-cycle CSV report.

    Args:
        trainer: the SetFit trainer, presumably after the cycle's training step.
        dataset: the training subset for that cycle (shadows the module-level
            `dataset` inside this function only).
        run_id: identifier of the run the cycle belongs to.
    """
    # NOTE(review): run_train does not pass this callback to ActiveTrainer, so
    # nothing visible here ever calls it — confirm how ActiveTrainer expects to
    # receive it if per-cycle reporting is wanted.
    cycle_reporter.report(trainer=trainer, dataset=dataset, run_id=run_id)

In [None]:
# SetFit schedules: a shorter one (presumably passed as `args` to run_train for
# each active-learning cycle) and a longer one for the final model.
cycle_train_args = TrainingArguments(num_iterations=10, num_epochs=(1, 8))
final_train_args = TrainingArguments(num_iterations=20, num_epochs=(1, 16))

def run_train(args, initial_train_subset, active_learning_config, **kwargs):
    """Run one active-learning experiment and append its result to the final report.

    Args:
        args: TrainingArguments used as ActiveTrainer's per-cycle `train_args`.
        initial_train_subset: starting training subset handed to ActiveTrainer.
        active_learning_config: ActiveLearningConfig for this run.
        **kwargs: extra fields forwarded verbatim to `final_reporter.report`.
            Must contain "run_id", which is also passed to ActiveTrainer.

    Raises:
        KeyError: if "run_id" is not supplied in kwargs.
    """
    if "run_id" not in kwargs:
        raise KeyError("run_train requires a 'run_id' keyword argument")

    # NOTE(review): after_train_callback / cycle_reporter are not handed to
    # ActiveTrainer here — confirm whether per-cycle reporting should be wired in.
    trainer = ActiveTrainer(
        full_train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        initial_train_subset=initial_train_subset,
        train_args=args,
        active_learning_config=active_learning_config,
        dataset_config=dataset_config,
        metric=camprehesive_metrics,
        run_id=kwargs["run_id"],
        final_model_train_args=final_train_args
    )
    trained = trainer.train()
    final_reporter.report(
        trainer=trained,
        dataset=trainer.train_subset,
        active_learning_config=active_learning_config,
        dataset_name=dataset_name,
        **kwargs,
    )

    # The notebook calls this repeatedly (one run per dataset seed), so drop this
    # run's references and release cached GPU memory before the next run;
    # otherwise models and CUDA cache blocks can pile up across the sweep.
    del trained, trainer
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

In [None]:
for dataset_seed in range(5):
    
