In [None]:
# nb_black reformats each code cell with Black when it is executed.
%load_ext nb_black
# autoreload lets edits to imported project modules take effect without restarting the kernel.
%load_ext autoreload

# Mode 2: reload all imported modules before every cell execution.
%autoreload 2

In [None]:
import logging
import os
from pathlib import Path

import numpy as np
import pandas as pd
from requests import get
from sklearn.metrics import roc_auc_score, log_loss
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.utils import plot_model

# Only warnings and above reach the root logger's output.
logging.basicConfig(level=logging.WARNING)

In [None]:
from thc_net.explainable_model.input_utils import preproc_dataset
from thc_net.explainable_model.model import build_model
from thc_net.explainable_model.random_utils import setup_seed

import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow

%matplotlib inline

In [None]:
# Presumably seeds the random number generators so that runs are repeatable.
# NOTE(review): which libraries are seeded, and with what value, is decided inside
# thc_net.explainable_model.random_utils.setup_seed — confirm there.
setup_seed()

In [None]:
def download(url, out, force=False, verify=True):
    """Download ``url`` to the local file ``out`` unless that file already exists.

    The response body is streamed into a temporary ``<name>.part`` file next to
    ``out`` and only renamed to ``out`` once the transfer has completed. A failed
    or interrupted download therefore never leaves an empty or truncated ``out``
    behind that a later call would mistake for a finished download.

    Parameters
    ----------
    url : str
        Address of the resource to fetch.
    out : pathlib.Path
        Destination file. Missing parent directories are created.
    force : bool, default False
        When True, an existing ``out`` is deleted and downloaded again.
    verify : bool, default True
        Forwarded to ``requests.get`` (TLS certificate verification).

    Returns
    -------
    None

    Raises
    ------
    Whatever ``requests`` raises for connection problems, and the error raised by
    ``response.raise_for_status()`` when the server answers with an error status.
    """
    out.parent.mkdir(parents=True, exist_ok=True)
    if force and out.exists():
        print(f"Removing file at {str(out)}")
        out.unlink()

    if out.exists():
        print("File already exists.")
        return
    print(f"Downloading {url} at {str(out)} ...")

    tmp_path = out.with_name(out.name + ".part")
    try:
        # stream=True keeps the body out of memory; it is consumed chunk by chunk below.
        response = get(url, verify=verify, stream=True)
        try:
            # Fail loudly instead of saving an HTML error page as if it were the data.
            response.raise_for_status()
            with tmp_path.open(mode="wb") as file:
                for chunk in response.iter_content(100000):
                    file.write(chunk)
        finally:
            response.close()
        # Publish the file only after every chunk has been written.
        tmp_path.replace(out)
    except BaseException:
        # Covers Ctrl-C as well: never leave a partial file lying around.
        if tmp_path.exists():
            tmp_path.unlink()
        raise


In [None]:
def plot_history(history):
    """Plot the training and validation loss curves stored in a fit history.

    Parameters
    ----------
    history : object
        Anything exposing a ``history`` mapping from metric name to a per-epoch
        list of values, such as the object returned by ``model.fit``. Keys that
        contain ``"loss"`` are plotted: those without ``"val"`` in blue as
        training curves, those with ``"val"`` in green as validation curves.

    Returns
    -------
    None
        Prints a message and returns early when no training-loss key is present;
        otherwise shows the figure.
    """
    metrics = history.history
    train_loss_keys = [key for key in metrics if "loss" in key and "val" not in key]
    val_loss_keys = [key for key in metrics if "loss" in key and "val" in key]

    if not train_loss_keys:
        print("Loss is missing in history")
        return

    # Epochs are numbered from 1; every curve is assumed to have one value per epoch.
    epochs = range(1, len(metrics[train_loss_keys[0]]) + 1)

    plt.figure(1)
    for key in train_loss_keys:
        values = metrics[key]
        # The legend entry carries the final value of the curve.
        plt.plot(epochs, values, "b", label=f"Training loss ({values[-1]:.5f})")
    for key in val_loss_keys:
        values = metrics[key]
        plt.plot(epochs, values, "g", label=f"Validation loss ({values[-1]:.5f})")

    plt.title("Loss")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend()

    plt.show()

In [None]:
# Benchmark datasets this notebook can run against.
#   target: name of the label column in the CSV
#   ids:    extra columns handed to preproc_dataset together with "Set"
#           (presumably columns to leave out of the features — confirm in preproc_dataset)
DATASETS = {
    "portoseguro": {"target": "target", "ids": []},
    "homesite-quote-conversion": {"target": "QuoteConversion_Flag", "ids": []},
    "bnp-cardif": {"target": "target", "ids": []},
    "cat-in-the-dat-ii": {"target": "target", "ids": []},
    "albert": {"target": "target", "ids": []},
    "census-income": {"target": "taxable income amount", "ids": []},
    "road-safety": {"target": "Sex_of_Driver_df_res", "ids": []},
    "open-payments": {"target": "status", "ids": []},
    "bank-marketing": {"target": "y", "ids": []},
    # NOTE(review): "Unamed" looks like a misspelling of a pandas "Unnamed: ..." column,
    # and listing "age" among the ids is surprising — check against the CSV header.
    "give-me-some-credit": {"target": "SeriousDlqin2yrs", "ids": ["Unamed", "age"]},
}

# Change this single value to switch dataset. The default is the dataset that a
# top-to-bottom run of the former one-cell-per-dataset layout ended up with.
DATASET_NAME = "give-me-some-credit"

dataset_name = DATASET_NAME
filename = "train_bench.csv"  # every dataset uses the same file name inside its own folder
target = DATASETS[DATASET_NAME]["target"]
ids = list(DATASETS[DATASET_NAME]["ids"])  # copy, so later edits to `ids` do not touch the registry

In [None]:
# Benchmark CSV location, relative to the current working directory:
# ./data/<dataset_name>/<filename>.
# The `download` helper defined earlier is not called in any visible cell, so the
# file has to be on disk already.
out = Path(os.getcwd()) / "data" / dataset_name / filename

In [None]:
# Load the whole file; the cell output is its (rows, columns) shape.
train = pd.read_csv(out)
train.shape

In [None]:
# Create the train/valid/test assignment only when the CSV does not ship one.
if "Set" not in train.columns:
    print("Building tailored column")
    # `SEED` is not defined in any visible cell: reuse it when the session provides
    # one, otherwise fall back to a fixed value so the split stays reproducible.
    split_seed = globals().get("SEED", 0)
    labels = train[target].values
    splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.1, random_state=split_seed)

    # First cut: 10% of all rows become the test set, stratified on the target.
    train_valid_index, test_index = next(splitter.split(np.arange(len(train)), labels))

    # Second cut: 10% of the remaining rows become the validation set.
    # split() yields positions *within the array it was given*, so they must be
    # mapped back through train_valid_index to obtain row positions of `train`.
    inner_train, inner_valid = next(
        splitter.split(train_valid_index, labels[train_valid_index])
    )
    train_index = train_valid_index[inner_train]
    valid_index = train_valid_index[inner_valid]

    # Build the whole column first and assign it once; this avoids chained
    # assignment such as train["Set"][idx] = ..., which may write to a copy.
    split_labels = np.full(len(train), "train", dtype=object)
    split_labels[valid_index] = "valid"
    split_labels[test_index] = "test"
    train["Set"] = split_labels
    # Uncomment to persist the split next to the data:
    # train.to_csv((out.parent / "train_bench.csv").as_posix(), index=False)

In [None]:
# Index labels of the rows belonging to each partition, read from the "Set" column.
set_column = train["Set"]
train_indices = train.loc[set_column == "train"].index
valid_indices = train.loc[set_column == "valid"].index
test_indices = train.loc[set_column == "test"].index

In [None]:
# Preprocess the training rows. No `params` is passed here, and the `params` returned
# is handed back to the later calls for the other splits.
# NOTE(review): `ids + ["Set"]` is presumably the list of columns to keep out of the
# features — confirm in preproc_dataset.
input_train, params = preproc_dataset(train.loc[train_indices], target, ids + ["Set"])
params

In [None]:
# Validation and test rows are preprocessed with the `params` obtained from the
# training rows (presumably so that they are transformed identically — confirm in
# preproc_dataset). The second return value is not needed here.
input_valid, _ = preproc_dataset(
    train.loc[valid_indices], target, ids + ["Set"], params
)
input_test, _ = preproc_dataset(train.loc[test_indices], target, ids + ["Set"], params)

In [None]:
# Maps the raw target values to integer class codes 0..n_classes-1.
target_encoder = LabelEncoder()

In [None]:
# Fit on the full target column (all splits) and overwrite that column of `train`
# with the encoded codes.
train[target] = target_encoder.fit_transform(train[target].values.reshape(-1))
# `.values[...]` indexes by position. That matches the index labels collected earlier
# because `train` comes from read_csv without an index column (default 0..n-1 index).
y_train = train[target].values[train_indices]
y_valid = train[target].values[valid_indices]
y_test = train[target].values[test_indices]

In [None]:
# Build the network from the preprocessing `params`.
# NOTE(review): what lconv_dim / lconv_num_dim control (layer widths?) is defined in
# build_model — confirm there.
model = build_model(params, lconv_dim=[256, 64, 32], lconv_num_dim=[128, 16],)

In [None]:
# Debug peek at the input shape recorded on the layer named "output".
# `_build_input_shape` is a private attribute.
model.get_layer("output")._build_input_shape

In [None]:
# Text overview of the layers.
model.summary()

In [None]:
# plot_model in the next cell needs pydot and graphviz; uncomment to install them.
# !pip install pydot graphviz

In [None]:
# Draw the layer graph with tensor shapes and layer names, laid out top to bottom.
plot_model(
    model,
    # to_file="model.png",
    show_shapes=True,
    show_layer_names=True,
    rankdir="TB",
    expand_nested=False,
    dpi=96,
)

In [None]:
# Number of training labels.
y_train.shape

In [None]:
# Inverse-frequency weight per class: total number of training labels divided by
# the number of labels in that class, so rarer classes get larger weights.
class_sizes = np.unique(y_train, return_counts=True)[1]
counts = class_sizes.sum() / class_sizes

In [None]:
# One weight per class code, in the form expected by model.fit(class_weight=...).
# `counts` is ordered by sorted class value, so position i is the weight of class
# code i as long as every code 0..n-1 occurs in the training labels.
# Built from all entries of `counts` rather than hard-coding classes 0 and 1.
class_weight = {label: weight for label, weight in enumerate(counts)}
class_weight

In [None]:
%%time
# Train the model; %%time reports how long the whole cell takes.
# NOTE(review): EarlyStopping is created without restore_best_weights, so after it
# triggers the model keeps the weights of the last epoch run rather than those of
# the best val_loss epoch — confirm this is intended.
history = model.fit(
    input_train,
    y_train.reshape(-1, 1),  # labels as a column vector
    epochs=2000,  # upper bound on epochs; the EarlyStopping callback can end training sooner
    batch_size=1024,
    validation_data=(input_valid, y_valid.reshape(-1, 1),),
    verbose=1,
    # Stop once val_loss has not improved for 10 consecutive epochs.
    callbacks=[EarlyStopping(monitor="val_loss", patience=10, verbose=1)],
    class_weight=class_weight
)

In [None]:
# Training vs. validation loss per epoch.
plot_history(history)

In [None]:
# ROC AUC on the validation split.
model_auc = roc_auc_score(
    y_true=y_valid, y_score=model.predict(input_valid).reshape(-1),
)
model_auc

In [None]:
# Scores noted down from earlier runs, one line per dataset — presumably the
# validation ROC AUC produced by the cell above.
# porto seguro 0.6232119567809998
# homesite-quote-conversion 0.9523864163447144
# cat in dat II 0.7664026194097782
# albert 0.7204053242303463
# census-income 0.937909089697099
# open payments 0.9259053788101843
# give-me-some-credit 0.8579704956990437
# bank-marketing 0.7783336895486428

In [None]:
# ROC AUC on the test split (reuses the name `model_auc`, replacing the validation score).
model_auc = roc_auc_score(y_true=y_test, y_score=model.predict(input_test).reshape(-1),)
model_auc

In [None]:
# Scores noted down from earlier runs, one line per dataset — presumably the
# test ROC AUC produced by the cell above.
# porto seguro 0.6271530652897266
# homesite-quote-conversion 0.9551354942949817
# cat in dat II 0.759840080492485
# albert 0.7160934658045495
# census-income 0.9349738109864048
# open payments 0.9274364959503505
# give-me-some-credit 0.8559444719192335
# bank-marketing 0.804706871986707

In [None]:
from thc_net.explainable_model.model import predict, encode

In [None]:
# NOTE(review): `predict` appears to return one probability and one explanation
# array per test row — confirm in thc_net.explainable_model.model.
probs, explanations = predict(model, input_test)

In [None]:
# `probs` is reassigned here with the first value returned by `encode`;
# `encoded_output` is only inspected for its shape below.
probs, encoded_output = encode(model, input_test)

In [None]:
y_test.shape

In [None]:
encoded_output.shape

In [None]:
# First test row: its entry in `probs` and the shape of its explanation array.
probs[0], explanations[0].shape

In [None]:
import matplotlib.pyplot as plt

# plt.rcdefaults()
import numpy as np
import matplotlib.pyplot as plt


def explain_plot(importances, columns, top_k=10):
    """Show a horizontal bar chart of the features with the largest absolute importance.

    Parameters
    ----------
    importances : array-like of float, 1-D
        One signed importance value per feature.
    columns : array-like of str, 1-D
        Feature names, in the same order as ``importances``. Plain lists are
        accepted as well as numpy arrays.
    top_k : int, default 10
        Number of features to display. When fewer features are available, all
        of them are shown.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``.
    """
    importances = np.asarray(importances)
    columns = np.asarray(columns)

    # Sorting the negated magnitudes ascending yields descending |importance|.
    selection = np.argsort(-np.absolute(importances))[:top_k]
    # Signed values are plotted, so the bar direction shows the sign.
    performance = importances[selection]
    y_pos = np.arange(performance.shape[0])

    plt.barh(y_pos, performance, align="center", alpha=0.5)
    plt.yticks(y_pos, columns[selection])
    plt.title("Feature importance")

    plt.show()

In [None]:
# Feature names gathered from the preprocessing params: boolean, then numeric, then
# categorical columns.
# NOTE(review): assumes the explanation vectors follow this same column order — confirm in the model code.
all_cols = np.array(params["bool_cols"] + params["num_cols"] + params["cat_cols"])
all_cols

In [None]:
# Explanation chart for the first test row; the cell output is that row's entry in `probs`.
explain_plot(explanations[0], all_cols)
probs[0].item()

In [None]:
# Explanation chart and `probs` entry for the first test rows: at most 20, fewer
# when fewer explanations are available (avoids an IndexError on a small split).
for i in range(min(20, len(explanations))):
    explain_plot(explanations[i], all_cols)
    print(probs[i].item())

In [None]:
from sklearn.metrics import confusion_matrix

In [None]:
# Class codes present in the test labels and how many rows each has.
np.unique(y_test, return_counts=True)

In [None]:
# Confusion matrix on the test split at a 0.5 decision threshold.
# y_test holds LabelEncoder codes (0..n_classes-1), whereas target_encoder.classes_
# holds the original label values, so the rows/columns are requested by code.
# For targets that were already 0/1 this gives the same matrix as before.
confusion_matrix(
    y_true=y_test,
    # Cast the boolean threshold result to 0/1 so it has the same type as y_test.
    y_pred=(model.predict(input_test).reshape(-1) >= 0.5).astype(int),
    labels=np.arange(len(target_encoder.classes_)),
)