# This notebook trains RNN-based models

# Imports and Setup

In [None]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules
# before each cell executes, so edited modules are picked up without a restart.
%load_ext autoreload
%autoreload 2

from IPython.display import display, HTML
# Inject CSS that stretches the notebook's `.container` element to full width.
display(HTML("<style>.container { width:100% !important; }</style>"))


In [None]:
import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

from src.embeddings import get_embedding_config
from src.data_processing import preprocess_raw_datasets, PreprocessingOptions, encode_one_hot_labels
from src.data_loading import load_embeddings, load_labels
from src.tf_models.rnn_models import rnn_model_factory, load_trained_model
from src.tf_models.utils import compile_tf_model
from src.tf_datasets import create_tf_datasets

from src.constants import PATH_TF_MODELS
from src.plots import plot_tf_history


# Data Loading
Load the labels and derive the class weights. The previously computed embeddings are loaded further below, once the config is set.

In [None]:
# Labels for the train / dev / test splits.
y_train, y_dev, y_test = load_labels()

# Both weight mappings are keyed by class index, i.e. the position of the class
# in the sorted unique training labels.
# NOTE(review): assumes that index matches the one-hot column produced by
# encode_one_hot_labels — confirm.
classes = np.unique(y_train)
balanced_class_weight = dict(enumerate(compute_class_weight("balanced", classes=classes, y=y_train)))
# One weight of 1.0 per class found in the training labels, rather than a
# hard-coded list of five entries that silently breaks if the label set changes.
uniform_class_weight = {k: 1.0 for k in range(len(classes))}

# The labels are needed in one-hot form from here on.
y_train = encode_one_hot_labels(y_train)
y_dev = encode_one_hot_labels(y_dev)
y_test = encode_one_hot_labels(y_test)

print("Balanced class weights", balanced_class_weight)


# Config

In [None]:
# Settings that identify which precomputed embedding is loaded; the same values
# are also used to build the checkpoint file name.
PREPROCESSING_OPTIONS = PreprocessingOptions(remove_stop_words=False, lemmatisation=False)
EMBEDDING = "word2vec" # only "word2vec" is supported
# Embedding variant; the alternative noted by the author is "Skip_N-gram"
# (NOTE(review): confirm the exact string load_embeddings expects).
EMBEDDING_VERSION = "cbow"
VECTOR_SIZE = 25 # presumably the embedding dimensionality — TODO confirm
MAX_WORDS = 50 # presumably the maximum number of words kept per sample — TODO confirm


In [None]:
# Fetch the train / dev / test embeddings selected by the config settings.
x_embeddings_train, x_embeddings_dev, x_embeddings_test = load_embeddings(
    PREPROCESSING_OPTIONS,
    EMBEDDING_VERSION,
    VECTOR_SIZE,
    MAX_WORDS,
    embedding_type=EMBEDDING,
)


## Create tf Datasets

In [None]:
# Build one dataset per split from its (embeddings, one-hot labels) pair.
train_dataset, dev_dataset, test_dataset = create_tf_datasets(
    x_embeddings_train, y_train,
    x_embeddings_dev, y_dev,
    x_embeddings_test, y_test,
)

# The raw embeddings are not used again, so drop the notebook's references to them.
del x_embeddings_train, x_embeddings_dev, x_embeddings_test


## Train RNN models

### Set the filepath where the trained model will be saved

In [None]:
MODEL_NAME = "large_bidirectional_lstm"  # other options: rnn, lstm, bidirectional_lstm
CLASS_WEIGHT = "balanced"

# Checkpoint path: models directory prefix, then model name, class-weight mode
# and embedding config joined by underscores.
embedding_config = get_embedding_config(PREPROCESSING_OPTIONS, EMBEDDING_VERSION, VECTOR_SIZE, MAX_WORDS)
checkpoint_filepath = PATH_TF_MODELS + "_".join((MODEL_NAME, CLASS_WEIGHT, embedding_config))


In [None]:
# Use the balanced weights when requested; any other CLASS_WEIGHT value
# falls back to the uniform weights.
class_weight = balanced_class_weight if CLASS_WEIGHT == "balanced" else uniform_class_weight

# Bare expression so the notebook displays the selected mapping.
class_weight


### Start training

In [None]:
# Select the model on the dev-set score. Keras logs validation metrics under a
# "val_" prefix; monitoring the bare "macro_f1_score" tracks the training
# metric, so early stopping and best-checkpoint selection would ignore the
# validation data passed to fit() and favour the most overfitted epoch.
monitored_metric = "val_macro_f1_score"
early_stop = tf.keras.callbacks.EarlyStopping(monitor=monitored_metric, patience=2, mode="max")
checkpoint = tf.keras.callbacks.ModelCheckpoint(checkpoint_filepath, monitor=monitored_metric, save_best_only=True, mode="max")

# Build the architecture named by MODEL_NAME and train it.
model = rnn_model_factory(MODEL_NAME)
# validation_steps=30 limits each validation pass to 30 batches of dev_dataset.
history = model.fit(x=train_dataset, epochs=5, validation_data=dev_dataset, validation_steps=30, callbacks=[early_stop, checkpoint], class_weight=class_weight)


In [None]:
# Visualise the training history returned by model.fit.
plot_tf_history(history)


# Evaluate trained model

In [None]:
# Show which checkpoint is being evaluated, then replace the in-memory model
# with the one stored at that path (written by the ModelCheckpoint callback).
print(checkpoint_filepath)
model = load_trained_model(checkpoint_filepath)
# Compile the loaded model before evaluating it.
# NOTE(review): presumably this attaches the same loss/metrics used in training — confirm.
compile_tf_model(model)


In [None]:
# Evaluate the loaded model on the test split; the notebook displays the result.
model.evaluate(test_dataset)
