In [None]:
%run ./../data/load-tokenized-dataset.ipynb
%run ./../utils/_callbacks.ipynb
%run ./../doc2vec/_load-d2v-model.ipynb
%matplotlib inline

In [None]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import ConfusionMatrixDisplay
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.optimizers import SGD 

In [None]:
# --- Reproducibility -------------------------------------------------------
# One seed shared by TensorFlow's global RNG and the train/val/test splits
# (passed as `random_state` further down).
RANDOM_SEED = 0
tf.random.set_seed(seed=RANDOM_SEED)

# --- Training --------------------------------------------------------------
# Number of epochs handed to `model.fit`; also the x-range of the metric plots.
ITERS = 5

# --- Feature switches ------------------------------------------------------
# INFER:  True  -> re-infer every document vector with `d2v_model.infer_vector`
#         False -> reuse the vectors already stored in `d2v_model.dv.vectors`
# NORMED: True  -> L2-normalise each document vector before splitting
INFER, NORMED = False, True

In [None]:
# Processed documents and their labels, taken column-wise from the loaded dataframe.
corpus, labels = df[[proc_doc_col, label_col]].to_numpy().T
y = labels.astype(int)

# Document embeddings: either re-inferred per document or read from the trained model.
# NOTE(review): the stored-vector branch assumes `d2v_model.dv.vectors` rows are in the
# same order as the rows of `df` — confirm against the doc2vec training notebook.
if INFER:
    X = [d2v_model.infer_vector(doc) for doc in tqdm(corpus, disable=SILENT)]
else:
    X = d2v_model.dv.vectors

# Optionally rescale every embedding (row) to unit L2 norm.
if NORMED:
    X = tf.math.l2_normalize(X, axis=1).numpy()

# Two-stage split: 25% held out for test, then 30% of the remainder for validation
# (i.e. 52.5% train / 22.5% validation / 25% test overall).
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=.25, random_state=RANDOM_SEED
)
X_train, X_val, y_train, y_val = train_test_split(
    X_train, y_train, test_size=.3, random_state=RANDOM_SEED
)

# Wrap every split as a constant tensor for Keras.
X_train, X_val, X_test = map(tf.constant, (X_train, X_val, X_test))
y_train, y_val, y_test = map(tf.constant, (y_train, y_val, y_test))

In [None]:
# Fully-connected binary classifier over doc2vec embeddings:
# two ReLU hidden layers as wide as the embedding, then a single sigmoid unit.
model = Sequential()
model.add(Input(shape=(d2v_model.vector_size,), dtype=tf.float32))
model.add(Dense(d2v_model.vector_size, activation=tf.nn.relu))
model.add(Dense(d2v_model.vector_size, activation=tf.nn.relu))
model.add(Dense(1, activation=tf.nn.sigmoid))

model.summary()

In [None]:
# Plain SGD (learning rate 0.1) on binary cross-entropy, tracking binary accuracy.
model.compile(
    loss='binary_crossentropy',
    optimizer=SGD(1e-1),
    metrics=['binary_accuracy'],
)

# NOTE(review): KerasEpochCallback comes from _callbacks.ipynb; presumably it calls
# `model.evaluate(X_test, y_test, verbose=False)` at the end of each epoch and appends
# the returned [loss, accuracy] to `end_results` — confirm there.
test_hist = KerasEpochCallback(
    end_func=model.evaluate,
    end_args=(X_test, y_test),
    end_kwargs={'verbose': False},
)
train_hist = model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=ITERS,
    batch_size=8,
    shuffle=True,
    callbacks=[test_hist],
)

# One row per epoch, columns = (loss, accuracy), for each of the three splits.
train_metrics = np.column_stack(
    (train_hist.history['loss'], train_hist.history['binary_accuracy'])
)
val_metrics = np.column_stack(
    (train_hist.history['val_loss'], train_hist.history['val_binary_accuracy'])
)
test_metrics = np.array(test_hist.end_results)

# Flatten the three (loss, accuracy) pairs of every epoch into one record.
results = [
    dict(zip(
        ("Training Loss", "Training Accuracy",
         "Validation Loss", "Validation Accuracy",
         "Test Loss", "Test Accuracy"),
        np.concatenate(epoch_rows),
    ))
    for epoch_rows in zip(train_metrics, val_metrics, test_metrics)
]

results_df = pd.DataFrame(results)
results_df.index = results_df.index + 1  # number epochs from 1 instead of 0

# Colour each column's extremes: green marks the best epoch (highest accuracy /
# lowest loss), red the worst. The Styler is the cell's displayed output.
(
    results_df.style
    .highlight_min(subset=["Training Accuracy", "Validation Accuracy", "Test Accuracy"], color='lightcoral')
    .highlight_max(subset=["Training Accuracy", "Validation Accuracy", "Test Accuracy"], color='lightgreen')
    .highlight_min(subset=["Training Loss", "Validation Loss", "Test Loss"], color='lightgreen')
    .highlight_max(subset=["Training Loss", "Validation Loss", "Test Loss"], color='lightcoral')
)

In [None]:
# Loss per epoch for the three splits (column 0 of each *_metrics array).
# A dedicated figure is created explicitly so this plot never lands on axes left
# over from another cell (plt.gca() reuses whatever figure is current).
fig, ax = plt.subplots()
epoch_ticks = np.arange(1, ITERS + 1)  # epochs are reported 1-based

for split_metrics, split_colour, split_name in (
    (train_metrics, '#1f77b4', "Training"),
    (val_metrics, '#d62728', "Validation"),
    (test_metrics, '#2ca02c', "Test"),
):
    ax.plot(epoch_ticks, split_metrics.T[0], marker='o', c=split_colour, label=split_name)

ax.set(title="Loss", xlabel="Epoch", ylabel="Loss")
ax.set_xticks(epoch_ticks)
ax.legend()
fig.tight_layout()

In [None]:
# Accuracy per epoch for the three splits (column 1 of each *_metrics array).
# A dedicated figure is created explicitly so this plot never lands on axes left
# over from another cell (plt.gca() reuses whatever figure is current).
fig, ax = plt.subplots()
epoch_ticks = np.arange(1, ITERS + 1)  # epochs are reported 1-based

for split_metrics, split_colour, split_name in (
    (train_metrics, '#1f77b4', "Training"),
    (val_metrics, '#d62728', "Validation"),
    (test_metrics, '#2ca02c', "Test"),
):
    ax.plot(epoch_ticks, split_metrics.T[1], marker='o', c=split_colour, label=split_name)

ax.set(title="Accuracy", xlabel="Epoch", ylabel="Accuracy")
ax.set_xticks(epoch_ticks)
ax.legend()
fig.tight_layout()

In [None]:
# Sigmoid outputs of the trained model on the test split, flattened to 1-D.
true_preds = model.predict(X_test).flatten()
# Hard 0/1 predictions obtained by rounding the sigmoid outputs.
y_preds = true_preds.round()

# from_predictions expects (y_true, y_pred) in that order. With normalize='true'
# each row (true class) sums to 1, so swapping the arguments would transpose the
# matrix and normalise over the predicted class instead.
# NOTE(review): display_labels assume label 0 = "reliable" and 1 = "unreliable" — confirm
# against the dataset's label encoding.
ConfusionMatrixDisplay.from_predictions(y_test, y_preds, normalize='true', colorbar=False,
                                        cmap=plt.cm.Blues, display_labels=("reliable", "unreliable"))