In [None]:
# Colab-only setup.
# The commands below are wrapped in a triple-quoted string, so none of them
# execute when this cell is run as-is; remove the two ''' lines to enable them
# in Google Colab.
# When enabled, the cell mounts Google Drive, changes into the project folder
# on the drive, pulls the latest commits and installs setfit.
# The three lines that are additionally commented out inside the string are the
# one-time clone of the repository using a token read from github_token.txt.
'''
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
%cd /content/drive/MyDrive/aml_final
#with open("github_token.txt", "r") as f:
#  token = f.read()
#! git clone https://{token}@github.com/Tryner/aml_final.git #clone repo
%cd aml_final/
! git pull
! pip install setfit
'''

In [None]:

import torch
from datasets import load_dataset, Dataset
from setfit import SetFitModel, Trainer, TrainingArguments

from train.active_learning import ActiveTrainer
from train.active_learning_config import ActiveLearningConfig
from data.dataset_config import DatasetConfig
from data.load_datasets import select_dataset, load
from train.reporter import Reporter
from train.metrics import comprehensive_metrics

In [None]:
# --- Data -------------------------------------------------------------------
# select_dataset() decides which dataset to use; load() returns its name
# together with the dataset object (indexed by split name further down).
dataset_choice = select_dataset()
dataset_name, dataset = load(dataset_choice)

# --- Experiment settings ----------------------------------------------------
active_learning_config = ActiveLearningConfig(
    samples_per_cycle=4,
    unlabeled_samples=40,
    balancing_factor=0.5,
    model_name="thenlper/gte-small",
)
# No arguments: DatasetConfig is used with its own defaults.
dataset_config = DatasetConfig()
train_args = TrainingArguments(num_iterations=20)

In [None]:
# Two CSV reporters, both named after the selected dataset and both keyed on
# the label column configured in dataset_config.

# Written once per run, after the whole active-learning run has finished.
final_reporter = Reporter(
    dataset_name+"_final.csv",
    label_column=dataset_config.label_column,
)

# Written from after_train_callback; training arguments are not reported here.
cycle_reporter = Reporter(
    dataset_name+"_cycle.csv",
    report_train_args=False,
    label_column=dataset_config.label_column,
)

def after_train_callback(trainer: Trainer, dataset: Dataset, run_id: int):
    """Record one training result in the per-cycle report.

    Forwards its arguments unchanged, as keyword arguments, to
    ``cycle_reporter.report``.

    Args:
        trainer: The trainer to report on.
        dataset: The dataset associated with that trainer.
        run_id: Identifier of the current run.

    Returns:
        None.
    """
    cycle_reporter.report(
        trainer=trainer,
        dataset=dataset,
        run_id=run_id,
    )

In [None]:
# Repeat the experiment three times (run_id = 0, 1, 2), building a fresh
# ActiveTrainer for every run.
for run_id in range(3):
    trainer = ActiveTrainer(
        full_train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
        train_args=train_args,
        active_learning_config=active_learning_config,
        dataset_config=dataset_config,
        # Presumably invoked by ActiveTrainer after each training cycle —
        # TODO confirm in train.active_learning.
        after_train_callback=after_train_callback,
        metric=comprehensive_metrics,
        run_id=run_id,
    )

    # Whatever train() returns is what gets reported as the final trainer.
    t = trainer.train()

    # One row per run in the final report, using the training subset held by
    # the ActiveTrainer at the end of the run.
    final_reporter.report(
        trainer=t,
        dataset=trainer.train_subset,
        active_learning_config=active_learning_config,
        dataset_name=dataset_name,
        run_id=run_id,
    )
