In [None]:
# Colab-only setup, currently disabled: the commands below are wrapped in a
# triple-quoted string so that running this cell executes none of them.
# Remove the surrounding quotes to enable them when working on Google Colab.
# The enclosed commands mount Google Drive, change into the project folder on
# the drive, pull the latest state of the repository and install `setfit`.
# The lines starting with `#` inside the string are part of the string itself
# (an optional token-based `git clone` for the first checkout).
'''
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
%cd /content/drive/MyDrive/aml_final
#with open("github_token.txt", "r") as f:
#  token = f.read()
#! git clone https://{token}@github.com/Tryner/aml_final.git #clone repo
%cd aml_final/
! git pull
! pip install setfit
'''

In [None]:
import torch
from datasets import load_dataset, Dataset
from setfit import SetFitModel, Trainer, TrainingArguments

from train.active_learning import ActiveTrainer
from train.active_learning_config import ActiveLearningConfig
from data.dataset_config import DatasetConfig
from train.reporter import Reporter
from train.metrics import camprehesive_metrics

In [None]:
# Load the "sst2" dataset; its "train" and "validation" splits are used by the
# training loop further down.
dataset = load_dataset("sst2")

# Deliberately tiny active-learning settings so the example finishes quickly.
# These values speed up training and are not advisable for real experiments.
active_learning_config = ActiveLearningConfig(
    samples_per_cycle=2,
    unlabeled_samples=20,
    balancing_factor=0.5,
)

# Dataset description, constructed without arguments (project defaults).
dataset_config = DatasetConfig()

# A single epoch, again only to speed up training; not advisable.
train_args = TrainingArguments(num_epochs=1)

In [None]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Reporter invoked once per run by the training loop.
final_reporter = Reporter(
    "example_final_report.csv",
    label_column=dataset_config.label_column,
)

# Reporter invoked from `after_train_callback`; training arguments are not
# reported here (report_train_args=False).
cycle_reporter = Reporter(
    "example_cycle_report.csv",
    report_train_args=False,
    label_column=dataset_config.label_column,
)

def model_init():
    """Create a fresh SetFit model with a differentiable head and move it to `device`.

    The pretrained checkpoint is taken from ``active_learning_config.model_name``
    and the head's ``out_features`` from ``dataset_config.num_classes``.

    Returns:
        The result of calling ``.to(device)`` on the newly loaded model.
    """
    head_params = {"out_features": dataset_config.num_classes}
    model = SetFitModel.from_pretrained(
        active_learning_config.model_name,
        use_differentiable_head=True,
        head_params=head_params,
    )
    return model.to(device)
def after_train_callback(trainer: Trainer, dataset: Dataset):
    """Forward a trainer and its dataset to ``cycle_reporter.report``.

    This function is handed to ``ActiveTrainer`` as ``after_train_callback``;
    presumably it is called after each training cycle (confirm in
    ``train.active_learning``).

    Args:
        trainer: The trainer to report on.
        dataset: The dataset passed along with the trainer to the reporter.
    """
    cycle_reporter.report(
        trainer=trainer,
        dataset=dataset,
    )


In [None]:
# Run the complete active-learning experiment twice and report once per run.
for run_id in range(2):
    # A new ActiveTrainer is built for every run from the shared configuration.
    trainer = ActiveTrainer(
        model_init=model_init,
        full_train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
        train_args=train_args,
        active_learning_config=active_learning_config,
        dataset_config=dataset_config,
        after_train_callback=after_train_callback,
        metric=camprehesive_metrics,
    )

    # train() returns the object that is passed to the final report as `trainer`.
    t = trainer.train()

    final_reporter.report(
        trainer=t,
        dataset=trainer.train_subset,
        active_learning_config=active_learning_config,
        # The remaining arguments are free-form kwargs, so anything can be put here.
        dataset_name="sst2",
        run_id=run_id,
    )
