# Set up

## Set up packages

In [None]:
# Make modules in the current working directory (such as project2Lib)
# importable from this notebook.
import sys

sys.path.extend(["."])

## Data & Model Settings

In [None]:
# --- Data parameters ---

# Interactive switch between the 20k ("small") and the 200k dataset.
# Only the exact answer "yes" (case-insensitive) selects the small one.
# NOTE(review): `small` is not read anywhere else in this notebook; the
# dataset paths selected later are hard-coded to `dataset_small_*`. Confirm.
small = "yes" == input("Use small dataset? [yes/no]").lower()

# Numbers are always replaced with "@", as in the original paper.
# NOTE(review): `replace_num` is not passed to any loader in this notebook.
replace_num = True

# The classification task always has five labels.
num_labels = 5

## Set up Basic Folder Structure

In [None]:
from pathlib import Path

# Directory that holds (or will hold) the downloaded and pre-processed data.
data_dir = Path("./data")
data_dir.mkdir(parents=True, exist_ok=True)

# Directory with the previously trained base models. It is deliberately not
# created here: this notebook builds on an already trained base model, so a
# missing directory means that earlier step has not been run.
model_base_dir = Path("./TrainedModels")
if not model_base_dir.exists():
    # FileNotFoundError is a subclass of Exception, so existing handlers still match.
    raise FileNotFoundError(
        "You must first train the base model before creating the hierarchical model"
    )

## Select Base Model and Get Corresponding Arguments

In [None]:
from project2Lib import load_embedded, download_data
from project2Lib import TokenCollator, TokenClassifier, SentenceCollator, SentenceClassifier
import torch

# Choose the embedding the hierarchical model is built on: BERT sentence
# embeddings (PubMedBERT checkpoint) or word2vec token embeddings. Each
# branch defines:
#   load_args  -- keyword arguments later forwarded to `load_embedded`
#   model_init -- factory building a fresh classifier for a search trial
#   collator   -- batch collator matching that classifier
# NOTE(review): both dataset paths point at the `dataset_small_*` folders
# regardless of the `small` answer given earlier -- confirm this is intended.
use_bert = input("Use BERT-model? [yes/no]").lower() == "yes"
if use_bert:
    # 768-dimensional sentence vectors from the trained PubMedBERT checkpoint.
    load_args = dict(
        dataset_path=data_dir / "dataset_small_bert",
        embedding="bert",
        model_checkpoint=model_base_dir / "bert/microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext_nofreeze_200k/final_model/pytorch_model.bin",
        device="cuda:0" if torch.cuda.is_available() else "cpu",
    )

    def model_init(trial):
        return SentenceClassifier(num_labels=num_labels, sentence_size=768, trial=trial)

    collator = SentenceCollator()
else:
    # 200-dimensional token vectors from the trained word2vec model.
    load_args = dict(
        dataset_path=data_dir / "dataset_small_w2v",
        embedding="w2v",
        model_checkpoint=model_base_dir / "w2v_200_lemmatization.bin",
    )

    def model_init(trial):
        return TokenClassifier(num_labels=num_labels, token_size=200, trial=trial)

    collator = TokenCollator()


# Pre-Process Data

## Set up final folder structure for model

In [None]:
# Per-embedding output folders: "hp" receives the hyper-parameter search
# runs, "final_model" the model trained with the best configuration.
embedding_dir = model_base_dir / load_args["embedding"]
model_dir_hp = embedding_dir / "hp"
model_dir_final = embedding_dir / "final_model"
for folder in (model_dir_hp, model_dir_final):
    folder.mkdir(parents=True, exist_ok=True)

## Download & Load Data

In [None]:
from project2Lib import load_embedded, download_data

# The embedding model selected above must already exist on disk.
if not load_args["model_checkpoint"].exists():
    # FileNotFoundError is a subclass of Exception, so existing handlers still match.
    raise FileNotFoundError("You must first train the base model before using the embedding")

# Download the data if necessary.
download_data()

# Load the embedded data as a Hugging Face dataset. The keyword
# `group_by_abstracs` is passed as-is; it must match load_embedded's signature.
encoded_dataset = load_embedded(data_dir=data_dir, fields=None, group_by_abstracs=True, **load_args)

# Run Training

## Set up Training

In [None]:
from transformers import TrainingArguments, Trainer
from project2Lib import compute_f1, bert_init, TokenCollator, TokenClassifier

# Baseline training arguments for the hyper-parameter search. Values that
# appear in the search space (learning rate, epochs, seed, batch sizes) are
# replaced per trial. No checkpoints are written during the search.
args = TrainingArguments(
    model_dir_hp,
    evaluation_strategy="epoch",
    save_strategy="no",
    learning_rate=1e-4,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    num_train_epochs=5,
    weight_decay=0.01,
    load_best_model_at_end=False,
    metric_for_best_model="f1",
    push_to_hub=False,
)

# For performance reasons, each model trained during hyper-parameter tuning
# sees only a tenth of the training data. `contiguous=False` takes every 10th
# example, so the subset is spread over the whole split. The final training
# run swaps the full training split back in.
hp_train_dataset = encoded_dataset["train"].shard(num_shards=10, index=0, contiguous=False)

trainer = Trainer(
    model_init=model_init,
    args=args,
    train_dataset=hp_train_dataset,
    eval_dataset=encoded_dataset["dev"],
    data_collator=collator,
    compute_metrics=compute_f1,
)


## First Run Hyper-Parameter Tuning

In [None]:
# Search space for the hyper-parameter search, which runs 30 individual
# trials and keeps the best-performing combination.
def my_hp_space(trial):
    """Sample one hyper-parameter configuration from ``trial``.

    Args:
        trial: search trial exposing ``suggest_float``, ``suggest_int`` and
            ``suggest_categorical`` (Optuna-style interface).

    Returns:
        dict mapping hyper-parameter names to the sampled values. The
        learning-rate, epoch, seed and batch-size keys are
        ``TrainingArguments`` fields. ``hidden_size``, ``dropout_p``,
        ``num_layers`` and ``sentence_size`` are not; they are presumably
        read by the classifier from the ``trial`` passed to ``model_init``
        (TODO confirm in project2Lib).
    """
    return {
        # Sample the learning rate log-uniformly. The range spans two orders
        # of magnitude, and a uniform draw would put ~90% of trials above 1e-4.
        "learning_rate": trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True),
        "num_train_epochs": trial.suggest_int("num_train_epochs", 3, 7),
        "seed": trial.suggest_int("seed", 1, 40),
        "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [16, 32, 64]),
        # Single option: recorded with each trial but effectively fixed.
        "per_device_eval_batch_size": trial.suggest_categorical("per_device_eval_batch_size", [16]),
        "hidden_size": trial.suggest_categorical("hidden_size", [64, 128, 256]),
        "dropout_p": trial.suggest_float("dropout_p", 0.1, 0.75),
        "num_layers": trial.suggest_int("num_layers", 1, 3),
        "sentence_size": trial.suggest_categorical("sentence_size", [64, 128, 256]),
    }

n_trials = 30

# Rank trials by dev-set F1, the same metric configured as
# `metric_for_best_model`. Without `compute_objective`, the Trainer maximizes
# the sum of all reported evaluation metrics instead.
# `best_run` (run id, objective, hyper-parameters) is kept for inspection.
best_run = trainer.hyperparameter_search(
    n_trials=n_trials,
    direction="maximize",
    hp_space=my_hp_space,
    compute_objective=lambda metrics: metrics["eval_f1"],
)

## Then Load Best Found Hyper-Parameters and Train Final Model

In [None]:
from project2Lib import load_best_bert
from transformers.trainer_utils import IntervalStrategy

# Rebuild the training arguments from the best search run stored under
# model_dir_hp. Evaluate and checkpoint every epoch so that the best
# checkpoint is reloaded at the end of training. Which metric decides "best"
# depends on `metric_for_best_model` in `top_args` -- confirm it is carried over.
_, top_args = load_best_bert(model_dir_hp)
top_args["evaluation_strategy"] = IntervalStrategy("epoch")
top_args["save_strategy"] = IntervalStrategy("epoch")
top_args["load_best_model_at_end"] = True
top_args["output_dir"] = model_dir_final

# Reconfigure the existing trainer for the final run: new arguments, the
# full training split, and the dev split for per-epoch evaluation.
trainer.args = TrainingArguments(**top_args)
trainer.train_dataset = encoded_dataset["train"]
trainer.eval_dataset = encoded_dataset["dev"]

# NOTE(review): the trainer was built with `model_init`, so train() presumably
# re-creates the model via model_init(trial=None). Architecture
# hyper-parameters from the search (hidden_size, dropout_p, num_layers, ...)
# only take effect if the classifier or load_best_bert handles this -- confirm.
# Train on the full training split and persist the model and trainer state.
trainer.train()
trainer.save_model()
trainer.save_state()


## Test Final Model and Save Results

In [None]:
import json

# Evaluate the final model on the held-out test split.
results = trainer.evaluate(encoded_dataset["test"])

# Write the metrics next to the final model. The `with` block guarantees the
# file is flushed and closed even if serialization fails.
with open(model_dir_final.joinpath("results.json"), "w") as results_file:
    json.dump(results, results_file)
