# Set up

## Set up packages

In [None]:
# Add the current working directory to the module search path so that
# project2Lib (imported by later cells) can be found
import sys
sys.path.append(".")

## Data & Model Settings

In [None]:
# ---- Fixed settings ----

# Numbers are always replaced with "@", as done in the original paper
replace_num = True

# The number of labels is always 5
num_labels = 5

# ---- Interactive settings (prompted in this order) ----

# Dataset size: only the exact answer "yes" (any casing) selects the 20k
# dataset; every other answer selects the 200k dataset
dataset_answer = input("Use small dataset? [yes/no]")
small = dataset_answer.lower() == "yes"

# BERT variant: domain-specific PubMedBERT on "yes", general-purpose BERT otherwise
model_checkpoint = (
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
    if input("Use pubmed-BERT? [yes/no]").lower() == "yes"
    else "bert-base-uncased"
)

# Fine-tuning: BERT is frozen only when the answer is exactly "no" (any casing);
# every other answer leaves it trainable
finetune_answer = input("Fine-tune bert? [yes/no]")
freeze_bert = finetune_answer.lower() == "no"

## Set up Folder Structure

In [None]:
from pathlib import Path

# Folder for the dataset files; created if it does not exist yet
data_dir = Path("./data")
data_dir.mkdir(parents=True, exist_ok=True)

# One folder per experiment configuration (checkpoint, freeze mode, dataset size)
freeze_tag = "freeze" if freeze_bert else "nofreeze"
size_tag = "20k" if small else "200k"
base_path = f"./TrainedModels/bert/{model_checkpoint}_{freeze_tag}_{size_tag}"

# Sub-folder for the hyper-parameter tuning runs
model_dir_hp = Path(base_path) / "hp"
model_dir_hp.mkdir(parents=True, exist_ok=True)

# Sub-folder for the final model
model_dir_final = Path(base_path) / "final_model"
model_dir_final.mkdir(parents=True, exist_ok=True)

# Pre-Process Data

## Download & Load Data

In [None]:
from project2Lib import download_data, load_data

# Download the data if necessary, passing on the dataset size and the
# number-replacement setting chosen above
# NOTE(review): presumably download_data skips files that already exist in
# data_dir - confirm in project2Lib
download_data(data_dir=data_dir, small=small, replace_num=replace_num)

# Load the data from data_dir as a Hugging Face dataset
# NOTE(review): later cells expect "train", "dev" and "test" splits and a
# "sentence" column - confirm against load_data
dataset = load_data(data_dir=data_dir)

## Tokenize Data

In [None]:
from transformers import AutoTokenizer

# Load the tokenizer that belongs to the selected BERT checkpoint.
# Both selectable checkpoints are "uncased" variants, so the tokenizer is
# expected to lower-case the sentences as part of pre-processing.
tokenizer = AutoTokenizer.from_pretrained(
    model_checkpoint,
    use_fast=True,
    model_max_length=512,
)


def _tokenize_sentences(batch):
    """Tokenize the "sentence" column of a batch, truncating over-long inputs."""
    return tokenizer(batch["sentence"], truncation=True)


# Encode every split of the dataset, processing the examples in batches
encoded_dataset = dataset.map(_tokenize_sentences, batched=True)

# Run Training

## Set up Training

In [None]:
from transformers import TrainingArguments, Trainer
from project2Lib import compute_f1, bert_init

# Default training arguments for the tuning phase: evaluate after every
# epoch, never write checkpoints
args = TrainingArguments(
    output_dir=model_dir_hp,
    evaluation_strategy="epoch",
    save_strategy="no",
    learning_rate=1e-3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=10,
    weight_decay=0.01,
    load_best_model_at_end=False,
    metric_for_best_model="f1",
    push_to_hub=False,
)


def _init_model():
    """Build a fresh model for the selected checkpoint, freeze mode and label count."""
    return bert_init(model_checkpoint, freeze_bert, num_labels)


# During hyper-parameter tuning each model is trained on only one of ten
# shards of the training split, for performance reasons
hp_train_subset = encoded_dataset["train"].shard(index=1, num_shards=10)

# Set up the trainer; the model is created through _init_model
trainer = Trainer(
    model_init=_init_model,
    args=args,
    train_dataset=hp_train_subset,
    eval_dataset=encoded_dataset["dev"],
    tokenizer=tokenizer,
    compute_metrics=compute_f1,
)


## First Run Hyper-Parameter Tuning

In [None]:
# Tune hyper parameters over 15 individual runs and select 
# the best performing combinations
def my_hp_space(trial, small):
    """Define the hyper-parameter search space sampled for one trial.

    Args:
        trial: Trial object exposing ``suggest_float``, ``suggest_int`` and
            ``suggest_categorical`` (Optuna-style interface).
        small: ``True`` when the small dataset is used. This selects a wider
            learning-rate range (1e-5 to 1e-2) and 3 to 8 epochs; otherwise
            the range is 1e-6 to 1e-4 with exactly one epoch.

    Returns:
        dict: Sampled values keyed by "learning_rate", "num_train_epochs",
        "seed" and "per_device_train_batch_size".
    """
    # Both branches must be parenthesised: without the parentheses the
    # right-hand side parses as a three-element tuple
    # (a, (b if small else c), d), so the two-name unpacking raises ValueError.
    l_lr, u_lr = (1e-5, 1e-2) if small else (1e-6, 1e-4)
    l_epochs, u_epochs = (3, 8) if small else (1, 1)

    return {
        # The learning rate is sampled on a log scale between its bounds
        "learning_rate": trial.suggest_float("learning_rate", l_lr, u_lr, log=True),
        "num_train_epochs": trial.suggest_int("num_train_epochs", l_epochs, u_epochs),
        "seed": trial.suggest_int("seed", 1, 40),
        "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [8, 16, 32]),
    }


# Number of trials in the search. The search itself runs for either dataset
# size; the size only changes the ranges produced by my_hp_space.
n_trials = 15


def _hp_space_for_dataset(trial):
    """Bind the dataset-size flag so the space can be called with only a trial."""
    return my_hp_space(trial, small)


# Search for the combination that maximizes the evaluation objective
trainer.hyperparameter_search(
    hp_space=_hp_space_for_dataset,
    direction="maximize",
    n_trials=n_trials,
)

## Then Load Best Found Hyper-Parameters and Train Final Model

In [None]:
from project2Lib import load_best_bert
from transformers.trainer_utils import IntervalStrategy

# Fetch the training arguments of the best tuning run
# NOTE(review): presumably top_args is a plain dict of TrainingArguments
# fields - confirm against load_best_bert
_, top_args = load_best_bert(model_dir_hp)

# Overrides for the final run: evaluate and checkpoint after every epoch,
# keep a single checkpoint, restore the best model at the end and write to
# the final-model folder
final_run_overrides = {
    "evaluation_strategy": IntervalStrategy("epoch"),
    "save_strategy": IntervalStrategy("epoch"),
    "load_best_model_at_end": True,
    "output_dir": model_dir_final,
    "save_total_limit": 1,
}
for option_name, option_value in final_run_overrides.items():
    top_args[option_name] = option_value

# Reuse the existing trainer with the new arguments and the full training split
# NOTE(review): for the 200k dataset the search space fixes the epoch count
# at one, so top_args presumably carries a single epoch - confirm
trainer.args = TrainingArguments(**top_args)
trainer.train_dataset = encoded_dataset["train"]

# Train on the full dataset, then save the model and the trainer state
trainer.train()
trainer.save_model()
trainer.save_state()


## Test Final Model and Save Results

In [None]:
import json

# Evaluate the final model on the test split
results = trainer.evaluate(encoded_dataset["test"])

# Write the metrics to results.json inside the final-model folder.
# The context manager closes the file, so the content is flushed to disk
# right away instead of depending on garbage collection of an unnamed handle.
with open(model_dir_final.joinpath("results.json"), "w") as results_file:
    json.dump(results, results_file)
