In [None]:
%run ./../data/load-dataset.ipynb
%run ./../utils/_callbacks.ipynb
%run ./../transformers/_load-trf-data.ipynb
%matplotlib inline

In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import ConfusionMatrixDisplay
from tensorflow.keras import Model
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.optimizers import Adam, RMSprop
from tensorflow.keras.regularizers import l1, l2, l1_l2

In [None]:
# Single seed shared by the train/val/test splits below and by Keras' global seeding helper
# (per the TF docs, set_random_seed seeds Python `random`, NumPy and TensorFlow together).
RANDOM_SEED = 0
tf.keras.utils.set_random_seed(RANDOM_SEED)

# Number of training epochs passed to model.fit.
ITERS = 5

In [None]:
# NOTE(review): df, proc_doc_col and label_col presumably come from ../data/load-dataset.ipynb and
# trf_data from ../transformers/_load-trf-data.ipynb; row i of trf_data is assumed to match row i of df — confirm.
corpus, labels = df[[proc_doc_col, label_col]].T.values
X = trf_data
y = labels.astype(int)

# 75/25 train/test split, then 70/30 train/validation split of the remainder
# (≈52.5% train / 22.5% validation / 25% test overall).
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.25, random_state=RANDOM_SEED)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=.3, random_state=RANDOM_SEED)


def to_trf_inputs(arr):
    """Unpack a stacked token array into the dict of named inputs fed to the transformer.

    Index 0/1/2 along axis 1 of ``arr`` are taken as word ids, attention mask and token-type ids.
    Presumably ``arr`` is (n_samples, 3, seq_len) — TODO confirm against the loader.
    """
    return dict(input_word_ids=arr[:, 0], input_mask=arr[:, 1], input_type_ids=arr[:, 2])


# A named helper avoids binding `_`: IPython uses `_` for the last cell output and stops
# refreshing it once user code assigns to it.
X_train, X_val, X_test = to_trf_inputs(X_train), to_trf_inputs(X_val), to_trf_inputs(X_test)
y_train, y_val, y_test = tf.constant(y_train), tf.constant(y_val), tf.constant(y_test)

In [None]:
# Sequence length of the pre-tokenised inputs (last axis of trf_data).
seq_len = trf_data.shape[-1]

# One int32 Input per named transformer input, created in the order word ids, mask, type ids.
input_layer = {
    name: Input(shape=(seq_len, ), dtype=tf.int32)
    for name in ("input_word_ids", "input_mask", "input_type_ids")
}

# Binary classification head on the transformer's pooled representation.
trf_output = trf_model(input_layer)['pooled_output']
output_layer = Dense(1, activation=tf.nn.sigmoid)(trf_output)

model = Model(input_layer, output_layer)
# Very small learning rate — presumably to fine-tune the pretrained weights gently.
optimizer = RMSprop(1e-6)
model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['binary_accuracy'])
model.summary()

In [None]:
# NOTE(review): KerasEpochCallback is defined in ../utils/_callbacks.ipynb; presumably it calls
# end_func(*end_args, **end_kwargs) at each epoch end and collects the returns in `.end_results` — confirm.
# Here that means model.evaluate on the test set after every epoch, giving a per-epoch test curve.
test_hist = KerasEpochCallback(end_func=model.evaluate,
                               end_args=(X_test, y_test),
                               end_kwargs={'verbose': False})
train_hist = model.fit(X_train, y_train,
                       validation_data=(X_val, y_val),
                       epochs=ITERS,
                       batch_size=16,
                       shuffle=True,
                       callbacks=[test_hist])

# One (loss, accuracy) row per epoch for each split.
history = train_hist.history
train_metrics = np.array(list(zip(history['loss'], history['binary_accuracy'])))
val_metrics = np.array(list(zip(history['val_loss'], history['val_binary_accuracy'])))
test_metrics = np.array(test_hist.end_results)

results = [
    {"Training Loss": tr_loss, "Training Accuracy": tr_acc,
     "Validation Loss": va_loss, "Validation Accuracy": va_acc,
     "Test Loss": te_loss, "Test Accuracy": te_acc}
    for (tr_loss, tr_acc), (va_loss, va_acc), (te_loss, te_acc)
    in zip(train_metrics, val_metrics, test_metrics)
]

# Rows are labelled by epoch number, starting at 1.
results_df = pd.DataFrame(results, index=pd.RangeIndex(1, len(results) + 1))

# Green marks the best epoch per column and red the worst (high accuracy / low loss is good).
acc_cols = ["Training Accuracy", "Validation Accuracy", "Test Accuracy"]
loss_cols = ["Training Loss", "Validation Loss", "Test Loss"]
(results_df.style
    .highlight_min(subset=acc_cols, color='lightcoral')
    .highlight_max(subset=acc_cols, color='lightgreen')
    .highlight_min(subset=loss_cols, color='lightgreen')
    .highlight_max(subset=loss_cols, color='lightcoral'))

In [None]:
# Per-epoch loss for the training, validation and test splits (column 0 of each metrics array).
# The epoch count comes from the recorded metrics rather than ITERS, so the plot stays valid
# if ITERS is edited after training or training ends early.
epochs = np.arange(1, len(train_metrics) + 1)

fig, ax = plt.subplots()
ax.plot(epochs, train_metrics.T[0], c='#1f77b4', label="Training")
ax.plot(epochs, val_metrics.T[0], c='#d62728', label="Validation")
ax.plot(epochs, test_metrics.T[0], c='#2ca02c', label="Test")
ax.set_xticks(epochs)
ax.set(title="Loss", xlabel="Epoch", ylabel="Loss")
ax.legend()
fig.tight_layout()

In [None]:
# Per-epoch accuracy for the training, validation and test splits (column 1 of each metrics array).
# The epoch count comes from the recorded metrics rather than ITERS, so the plot stays valid
# if ITERS is edited after training or training ends early.
epochs = np.arange(1, len(train_metrics) + 1)

fig, ax = plt.subplots()
ax.plot(epochs, train_metrics.T[1], c='#1f77b4', label="Training")
ax.plot(epochs, val_metrics.T[1], c='#d62728', label="Validation")
ax.plot(epochs, test_metrics.T[1], c='#2ca02c', label="Test")
ax.set_xticks(epochs)
ax.set(title="Accuracy", xlabel="Epoch", ylabel="Accuracy")
ax.legend()
fig.tight_layout()

In [None]:
# Sigmoid outputs on the test set, i.e. the predicted probability of label 1.
true_preds = model.predict(X_test).flatten()
# Threshold at 0.5 (exactly 0.5 maps to 0, as np.round's half-to-even would) and cast to int
# so the predictions share the integer label type of y_test.
y_preds = (true_preds > 0.5).astype(int)

# sklearn's signature is from_predictions(y_true, y_pred, ...): true labels must come first,
# otherwise the matrix is transposed and normalize='true' normalises over predictions instead.
# NOTE(review): display_labels assumes label 0 = "reliable" and 1 = "unreliable" — confirm against the dataset.
# Assigning the result keeps the display object's repr out of the cell output; the figure still renders.
cm_display = ConfusionMatrixDisplay.from_predictions(np.asarray(y_test), y_preds, normalize='true', colorbar=False,
                                                     cmap=plt.cm.Blues, display_labels=("reliable", "unreliable"))