# Set up

## Set up packages

In [None]:
import sys

# Make the notebook's working directory importable so that the local
# project2Lib module can be found by the cells below.
sys.path += ["."]

## Data Settings

In [None]:
# ---- Data parameters (forwarded to project2Lib.download_data) ----
small = True  # presumably selects the reduced dataset variant — confirm in project2Lib
replace_num = True  # meaning defined by project2Lib.download_data

# ---- Classifier parameters ----
num_labels = 5  # number of classes predicted by both the teacher and the student

## Set up Folder Structure

In [None]:
from pathlib import Path

# Directory for downloaded / pre-processed data; created if missing.
data_dir = Path("./data")
data_dir.mkdir(parents=True, exist_ok=True)

# Directory holding the previously trained base models. This notebook only
# consumes them, so stop immediately if it does not exist.
model_base_dir = Path("./TrainedModels")
if not model_base_dir.exists():
    raise Exception("You must first train the base model before running knowledge distillation")

# One output folder for everything produced by this KD experiment.
model_dir_final = model_base_dir / "KD"
model_dir_final.mkdir(parents=True, exist_ok=True)

# Pre-Process Data

## Download & Load Data

In [None]:
from project2Lib import download_data, load_data, load_data_as_dataframe

# Fetch the raw data into data_dir if necessary, then read it back as
# dataframes (indexed as dataset[0], dataset[1], dataset[2] for the
# train/dev/test splits by the TF-IDF cell).
download_data(small=small, replace_num=replace_num, data_dir=data_dir)
dataset = load_data_as_dataframe(data_dir=data_dir)

## Set up TF-IDF Embedding

In [None]:
from sklearn.feature_extraction.text import TfidfVectorizer

# Split names in the order the dataframes appear in `dataset`.
splits = ["train", "dev", "test"]

# Fit the TF-IDF vocabulary and IDF weights on the training split only
# (dataset[0]), so dev/test statistics do not leak into the features.
tfidf_vectorizer = TfidfVectorizer(use_idf=True, max_features=50000, ngram_range=(1, 2))
tfidf_vectorizer.fit(dataset[0]["sentence"].values)

# Embed every split with the fitted vectorizer (sparse matrices).
X = {
    name: tfidf_vectorizer.transform(dataset[i]["sentence"].values)
    for i, name in enumerate(splits)
}
# Put each sparse matrix's indices in canonical order in place; presumably
# required because TensorFlow sparse inputs expect ordered indices — TODO confirm.
for matrix in X.values():
    matrix.sort_indices()

# Integer class labels per split, row-aligned with X.
Y = {
    name: dataset[i]["label"].to_numpy()
    for i, name in enumerate(splits)
}


## Load Data for Getting Teacher Predictions

In [None]:
from project2Lib import load_embedded, download_data
import torch

# Arguments describing how to load the abstracts as embedded by the
# fine-tuned PubMedBERT model; reused later for the teacher's device.
load_args = dict(
    dataset_path=data_dir / "dataset_small_bert",
    embedding="bert",
    model_checkpoint=model_base_dir
    / "bert/microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext_nofreeze_200k/final_model/pytorch_model.bin",
    device="cuda:0" if torch.cuda.is_available() else "cpu",
)

# The embedding needs the fine-tuned BERT weights; refuse to continue without them.
if not load_args["model_checkpoint"].exists():
    raise Exception("You must first train the base model before using the embedding")

# NOTE(review): unlike the earlier cell, this call passes no data_dir/small/
# replace_num arguments, so it relies on download_data's defaults — confirm
# that this is intended.
download_data()

# Load the data as a Huggingface dataset with examples grouped by abstract.
encoded_dataset = load_embedded(data_dir=data_dir, fields=None, group_by_abstracs=True, **load_args)

## Get Teacher Predictions

In [None]:
from project2Lib import SentenceClassifier, SentenceCollator
import numpy as np  # used below; the later cell that imports numpy has not run yet

# Weights of the fine-tuned teacher sentence classifier.
teacher_path = model_base_dir.joinpath(load_args["embedding"]).joinpath("final_model/pytorch_model.bin")

# Fail early if the teacher itself has not been trained. Check the teacher's
# own checkpoint, not the BERT embedding checkpoint.
if not teacher_path.exists():
    raise Exception("You must first train the teacher model before doing KD")

# Load the pre-trained teacher and switch it to inference mode so that
# train-only layers (e.g. dropout, if present) do not perturb its logits.
model = SentenceClassifier(num_labels, 768, None)
model.load_state_dict(torch.load(teacher_path, map_location=load_args["device"]))
model.eval()


def _teacher_logits(example):
    """Run the teacher on a single example and return its logits as numpy."""
    # No autograd graph is needed for inference.
    with torch.no_grad():
        output = model(**SentenceCollator()([example]))
    return {"logits": output.logits.cpu().numpy()}


# Add the teacher's logits to every split, one example (abstract) at a time.
logit_dataset = encoded_dataset.map(_teacher_logits, batched=False)

# Flatten the per-abstract logits into one array per split; presumably of
# shape (n_sentences, num_labels) — TODO confirm against SentenceClassifier.
teacher_logits = {
    split: np.array([row for example in ds["logits"] for row in example])
    for split, ds in logit_dataset.items()
}


# Run Training

## Define KD Objective for Hyperparameter Optimization

In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from project2Lib import add_KD

from sklearn.metrics import f1_score, accuracy_score
import numpy as np

def objective(trial, base_model, num_labels, model_dir, X, teacher_logits, Y):
    """Train one knowledge-distillation (KD) student and return its dev F1.

    Intended as an Optuna objective. The trial samples ``alpha`` and ``T``,
    a student is built with ``add_KD``, and the student is trained on the
    "train" split with early stopping on the "dev" split. The best checkpoint
    is then scored on "dev".

    Args:
        trial: Optuna trial used to sample the hyperparameters.
        base_model: Student factory forwarded to ``add_KD``
            (e.g. ``get_logistic_regression``).
        num_labels: Number of classes. Only the first ``num_labels`` output
            columns of the KD model are used as class scores.
        model_dir: ``pathlib.Path`` directory where the best checkpoint is
            written as ``model.h5``.
        X: Dict with "train" and "dev" student inputs.
        teacher_logits: Dict with "train" and "dev" teacher logits, assumed
            row-aligned with ``X``.
        Y: Dict with "train" and "dev" integer labels.

    Returns:
        Weighted F1 score of the best checkpoint on the "dev" split. Higher is
        better, so the study should maximize it.
    """
    # Sample the KD hyperparameters. The temperature range matches my_hp_space.
    # A lower bound of 0 would allow T == 0, which breaks a temperature-scaled
    # softmax (logits / T) if add_KD uses the standard formulation — TODO
    # confirm in project2Lib.
    alpha = trial.suggest_float("alpha", 0, 1)
    T = trial.suggest_float("T", 0.5, 5)

    print("New trial: ", alpha, T)

    # Build the KD-wrapped student for these hyperparameters.
    model = add_KD(alpha, T, base_model, num_labels)

    # The checkpoint path is shared by every trial (and by any later run that
    # reuses model_dir), so remove a stale file first. If no checkpoint is
    # written during this trial, load_weights then fails loudly instead of
    # silently restoring another run's weights.
    file_path = model_dir.joinpath("./model.h5")
    if file_path.exists():
        file_path.unlink()

    # NOTE(review): "val_accuracy" must match the metric name the compiled KD
    # model logs — confirm against add_KD's compile call.
    checkpoint = ModelCheckpoint(file_path, monitor="val_accuracy", verbose=0, save_best_only=True, mode='max')
    early = EarlyStopping(monitor="val_accuracy", patience=10)
    redonplat = ReduceLROnPlateau(monitor="val_accuracy", mode="max", patience=3, verbose=0)
    callbacks_list = [checkpoint, early, redonplat]

    # The KD model takes [student input, teacher logits] as its inputs.
    model.fit([X["train"], teacher_logits["train"]], Y["train"],
              validation_data=([X["dev"], teacher_logits["dev"]], Y["dev"]),
              epochs=100, batch_size=1024, verbose=0, callbacks=callbacks_list)

    # Restore the best (by validation accuracy) weights.
    model.load_weights(file_path)

    # Keep only the first num_labels columns as class scores (presumably
    # add_KD appends further outputs), then take the predicted class.
    preds = model.predict([X["dev"], teacher_logits["dev"]])
    preds = np.argmax(preds[:, :num_labels], axis=-1)
    score = f1_score(Y["dev"], preds, average="weighted")

    print("Finished trial: ", alpha, T, score)

    return score

## Define Logistic Regression Model

In [None]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

def get_logistic_regression(num_labels=5, numb_features=50000):
    """Build a multinomial logistic-regression model as a single Keras Dense layer.

    The Dense layer has no activation, so the model outputs raw logits. The
    loss is therefore configured with ``from_logits=True``. The accuracy
    metric is registered as "accuracy" so its validation key is
    "val_accuracy", the name monitored by the training callbacks.

    Args:
        num_labels: Number of output classes.
        numb_features: Input dimensionality. It must equal the number of TF-IDF
            columns (the vectorizer is configured with max_features=50000).

    Returns:
        A compiled ``Sequential`` model.
    """
    from tensorflow.keras.losses import SparseCategoricalCrossentropy

    lr = Sequential()
    lr.add(Dense(num_labels, input_dim=numb_features))
    lr.compile(
        optimizer="adam",
        loss=SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )
    return lr

## Run Hyperparameter Optimization

In [None]:
import optuna

# Search space for the KD hyperparameters.
def my_hp_space(trial):
    """Sample alpha in [0, 1] and temperature T in [0.5, 5] from ``trial``."""
    space = {}
    space["alpha"] = trial.suggest_float("alpha", 0, 1)
    space["T"] = trial.suggest_float("T", 0.5, 5)
    return space

# Optimize the KD hyperparameters. The objective returns a weighted F1 score
# (higher is better), but Optuna studies minimize by default.
study = optuna.create_study(direction="maximize")
study.optimize(
    lambda trial: objective(trial, get_logistic_regression, num_labels, model_dir_final, X, teacher_logits, Y),
    n_trials=30,
)
# Study.optimize returns None; the best trial is exposed on the study itself.
best_trial = study.best_trial

## Run Final Model

In [None]:
# Best hyperparameters found by the study. Read them from the study itself,
# because Study.optimize returns None.
alpha = study.best_params["alpha"]
T = study.best_params["T"]

print("Training final model with parameters: ", alpha, T)

# Build the KD-wrapped logistic-regression student.
model = add_KD(alpha, T, get_logistic_regression, num_labels)

# The Optuna trials wrote their checkpoints to this same path, so remove the
# leftover file. If the final run never writes a checkpoint, load_weights
# then fails instead of silently loading the last trial's weights.
file_path = model_dir_final.joinpath("./model.h5")
if file_path.exists():
    file_path.unlink()

# NOTE(review): "val_accuracy" must match the metric name the compiled KD
# model logs — confirm against add_KD's compile call.
checkpoint = ModelCheckpoint(file_path, monitor="val_accuracy", verbose=0, save_best_only=True, mode='max')
early = EarlyStopping(monitor="val_accuracy", patience=10)
redonplat = ReduceLROnPlateau(monitor="val_accuracy", mode="max", patience=3, verbose=0)
callbacks_list = [checkpoint, early, redonplat]

# Train on "train" and select the checkpoint on "dev".
model.fit([X["train"], teacher_logits["train"]], Y["train"],
          validation_data=([X["dev"], teacher_logits["dev"]], Y["dev"]),
          epochs=100, batch_size=1024, verbose=0, callbacks=callbacks_list)

# Restore the best (by validation accuracy) weights.
model.load_weights(file_path)

# Evaluate on the held-out test split. The labels must come from the same
# split as the predictions.
preds = model.predict([X["test"], teacher_logits["test"]])
preds = np.argmax(preds[:, :num_labels], axis=-1)
f1 = f1_score(Y["test"], preds, average="weighted")
acc = accuracy_score(Y["test"], preds)

print("Finished model with f1 = ", f1, " and accuracy = ", acc)