In [None]:
%load_ext nb_black
%load_ext autoreload

%autoreload 2

In [None]:
import logging
import os
from pathlib import Path

import numpy as np
import pandas as pd
from requests import get

from sklearn.metrics import roc_auc_score, log_loss
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.utils import plot_model

# Configure the root logger with a WARNING threshold so INFO/DEBUG records are dropped.
logging.basicConfig(level=logging.WARN)

In [None]:
from thc_net.explainable_model.input_utils import preproc_dataset
from thc_net.explainable_model.model import build_model
from thc_net.explainable_model.random_utils import setup_seed

import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow

%matplotlib inline

In [None]:
def download(url, out, force=False, verify=True):
    """Download ``url`` to the local path ``out`` unless it is already there.

    The payload is streamed into a temporary ``<name>.part`` file next to
    ``out`` and only moved into place once the transfer succeeded, so a failed
    or interrupted request never leaves an empty/partial file that a later
    call would mistake for a finished download.

    Parameters
    ----------
    url : str
        Address of the file to fetch.
    out : pathlib.Path
        Destination file; missing parent directories are created.
    force : bool, default False
        When True, an existing file at ``out`` is deleted and fetched again.
    verify : bool, default True
        Forwarded to ``requests.get`` (TLS certificate verification).

    Raises
    ------
    Whatever ``response.raise_for_status()`` raises for a 4xx/5xx answer
    (``requests.HTTPError``), instead of saving the error page as data.
    """
    out.parent.mkdir(parents=True, exist_ok=True)
    if force and out.exists():
        print(f"Removing file at {str(out)}")
        out.unlink()

    if out.exists():
        print("File already exists.")
        return
    print(f"Downloading {url} at {str(out)} ...")

    partial = out.with_name(out.name + ".part")
    try:
        # stream=True avoids loading the whole body into memory before writing
        with get(url, verify=verify, stream=True) as response:
            response.raise_for_status()
            with partial.open(mode="wb") as file:
                for chunk in response.iter_content(100000):
                    file.write(chunk)
        # only publish the file once it has been written completely
        partial.replace(out)
    finally:
        # after a successful replace() the temp file is gone; otherwise clean it up
        if partial.exists():
            partial.unlink()


In [None]:
def plot_history(history):
    """Plot training and validation loss curves recorded during fitting.

    Parameters
    ----------
    history : object
        Must expose a ``history`` mapping of metric name -> list of per-epoch
        values (as returned by ``model.fit``). Every key containing ``"loss"``
        is plotted: keys without ``"val"`` in blue as training loss, keys with
        ``"val"`` in green as validation loss. The final value of each curve
        is shown in its legend label.

    Returns
    -------
    None
        Prints a message and returns early when no training-loss key exists.
    """
    metrics = history.history
    loss_keys = [key for key in metrics if "loss" in key and "val" not in key]
    val_loss_keys = [key for key in metrics if "loss" in key and "val" in key]

    if not loss_keys:
        print("Loss is missing in history")
        return

    # Epochs are numbered from 1; the first loss series fixes the length.
    epochs = range(1, len(metrics[loss_keys[0]]) + 1)

    plt.figure(1)
    for key in loss_keys:
        plt.plot(
            epochs,
            metrics[key],
            "b",
            label=f"Training loss ({metrics[key][-1]:.5f})",
        )
    for key in val_loss_keys:
        plt.plot(
            epochs,
            metrics[key],
            "g",
            label=f"Validation loss ({metrics[key][-1]:.5f})",
        )

    plt.title("Loss")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend()

    plt.show()

In [None]:
# Dataset configuration: bank-marketing.
# NOTE(review): the following cell reassigns all four names, so when the notebook
# is run top to bottom this configuration is overridden.
dataset_name = "bank-marketing"  # sub-directory of ./data holding the file
filename = "train_bench.csv"
target = "y"  # name of the label column
ids = []  # passed to preproc_dataset together with "Set" (presumably columns to ignore — confirm)

In [None]:
# Dataset configuration: give-me-some-credit. Running this cell replaces the
# values assigned by the preceding bank-marketing cell.
# NOTE(review): "Unamed" looks like a misspelling of a pandas "Unnamed…" column and
# "age" is an unusual identifier column — confirm both against the CSV header.
dataset_name = "give-me-some-credit"  # sub-directory of ./data holding the file
filename = "train_bench.csv"
target = "SeriousDlqin2yrs"  # name of the label column
ids = ["Unamed", "age"]  # passed to preproc_dataset together with "Set"

In [None]:
# Location of the benchmark CSV: <current working directory>/data/<dataset_name>/<filename>
out = Path.cwd() / "data" / dataset_name / filename

In [None]:
# Load the benchmark CSV into a DataFrame and display its (rows, columns) shape.
train = pd.read_csv(out)
train.shape

In [None]:
# Seed for the stratified splits below so the generated "Set" column is reproducible.
# (SEED was referenced by this cell but never defined anywhere in the notebook.)
SEED = 42

if "Set" not in train.columns:
    print("Building tailored column")
    labels = train[target].values

    # First split: hold out 10% of all rows as the test partition.
    train_valid_index, test_index = next(
        StratifiedShuffleSplit(n_splits=1, test_size=0.1, random_state=SEED).split(
            range(labels.shape[0]), labels
        )
    )
    # Second split: hold out 10% of the remaining rows as the validation partition.
    # split() yields positions *within the array it was given*, so they have to be
    # mapped back through train_valid_index to obtain row positions in `train`
    # (using them directly mislabels rows and can overlap with the test rows).
    train_pos, valid_pos = next(
        StratifiedShuffleSplit(n_splits=1, test_size=0.1, random_state=SEED).split(
            train_valid_index, labels[train_valid_index]
        )
    )
    train_index = train_valid_index[train_pos]
    valid_index = train_valid_index[valid_pos]

    # Single positional assignment via .iloc instead of chained indexing
    # (train["Set"][...] = ...), which may write to a temporary copy.
    train["Set"] = "train"
    set_col = train.columns.get_loc("Set")
    train.iloc[valid_index, set_col] = "valid"
    train.iloc[test_index, set_col] = "test"
    # train.to_csv((out.parent / "train_bench.csv").as_posix(), index=False)

In [None]:
# Index labels of the rows belonging to each partition, selected via the "Set" column.
train_indices = train.loc[train["Set"].eq("train")].index
valid_indices = train.loc[train["Set"].eq("valid")].index
test_indices = train.loc[train["Set"].eq("test")].index

In [None]:
# Preprocess the training partition. The returned `params` is displayed here and is
# passed back to preproc_dataset for the validation/test partitions in the next cell.
# ids + ["Set"] is handed over as third argument (presumably columns to ignore — confirm).
input_train, params = preproc_dataset(train.loc[train_indices], target, ids + ["Set"])
params

In [None]:
# Preprocess the validation and test partitions, passing in the `params` obtained
# from the training partition; the second return value is discarded.
input_valid, _ = preproc_dataset(
    train.loc[valid_indices], target, ids + ["Set"], params
)
input_test, _ = preproc_dataset(train.loc[test_indices], target, ids + ["Set"], params)

In [None]:
# Encoder fitted in the next cell to map target values to integer class labels.
target_encoder = LabelEncoder()

In [None]:
# Encode the target column in place (the encoder is fitted on all rows), then
# slice out the labels of each partition.
# NOTE(review): `.values[...]` indexes by position while the *_indices hold index
# labels; both coincide for the default 0..n-1 index that read_csv produced above.
train[target] = target_encoder.fit_transform(train[target].values.reshape(-1))
y_train = train[target].values[train_indices]
y_valid = train[target].values[valid_indices]
y_test = train[target].values[test_indices]

In [None]:
# Build the network from the preprocessing params.
# NOTE(review): lconv_dim / lconv_num_dim are size settings interpreted by
# thc_net's build_model — check its documentation for their exact meaning.
model = build_model(params, lconv_dim=[64, 32], lconv_num_dim=[16],)

In [None]:
# Debug check: show the input shape recorded on the layer named "output".
# NOTE(review): _build_input_shape is a private attribute and may not exist in other Keras versions.
model.get_layer("output")._build_input_shape

In [None]:
# Print the layer-by-layer summary of the model.
model.summary()

In [None]:
#!pip install pydot graphviz

In [None]:
# Render the model graph top-to-bottom with layer names and tensor shapes.
# Requires pydot and graphviz (see the install hint in the previous cell).
plot_model(
    model,
    # to_file="model.png",
    show_shapes=True,
    show_layer_names=True,
    rankdir="TB",
    expand_nested=False,
    dpi=96,
)

In [None]:
# Sanity check: shape of the training label array.
y_train.shape

In [None]:
%%time
# Train with early stopping: training halts once val_loss has not improved for
# 5 consecutive epochs, so epochs=2000 is only an upper bound.
# Labels are reshaped to column vectors of shape (n, 1) before being passed to fit().
history = model.fit(
    input_train,
    y_train.reshape(-1, 1),
    epochs=2000,
    batch_size=1024,
    validation_data=(input_valid, y_valid.reshape(-1, 1),),
    verbose=1,
    callbacks=[EarlyStopping(monitor="val_loss", patience=5, verbose=1)],
)

In [None]:
# Plot the training and validation loss curves recorded by fit().
plot_history(history)

In [None]:
# ROC AUC on the validation partition; predictions are flattened to 1-D before scoring.
model_auc = roc_auc_score(
    y_true=y_valid, y_score=model.predict(input_valid).reshape(-1),
)
model_auc

In [None]:
# BM : 0.7847761386793823
# Census : 0.9461137700867462
# give me some credit : 0.8584216818313924

In [None]:
# ROC AUC on the test partition (overwrites the validation value held in model_auc).
model_auc = roc_auc_score(y_true=y_test, y_score=model.predict(input_test).reshape(-1),)
model_auc

In [None]:
from thc_net.explainable_model.model import predict

In [None]:
# thc_net's predict() returns three per-sample sequences for the validation inputs.
# NOTE(review): presumably probabilities, raw explanations and sigmoid-aggregated
# explanations — confirm the exact semantics against thc_net.
probs, explanations, sig_agg_explanations = predict(model, input_valid)

In [None]:
# Inspect the three outputs for the first validation sample.
probs[0], explanations[0], sig_agg_explanations[0]

In [None]:
import matplotlib.pyplot as plt

# plt.rcdefaults()
import numpy as np
import matplotlib.pyplot as plt


def explain_plot(importances, columns):
    """Draw a horizontal bar chart of per-feature importances, sorted ascending.

    Parameters
    ----------
    importances : array-like of numbers
        One importance value per feature. Converted with ``np.asarray`` so
        plain lists work as well as arrays.
    columns : array-like of str
        Feature names aligned with ``importances``. Converted with
        ``np.asarray`` so it can be reordered with the sort indices.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``.
    """
    importances = np.asarray(importances)
    columns = np.asarray(columns)

    # Sort ascending: the smallest value ends up at the bottom of the chart.
    order = np.argsort(importances)
    sorted_importances = importances[order]
    y_pos = np.arange(sorted_importances.shape[0])

    plt.barh(y_pos, sorted_importances, align="center", alpha=0.5)
    plt.yticks(y_pos, columns[order])
    plt.title("Feature importance")

    plt.show()

In [None]:
# Feature names concatenated in the order bool -> numeric -> categorical, stored as
# an array so explain_plot can reorder them with an index array.
# NOTE(review): assumes this matches the feature order of the explanation vectors — confirm.
all_cols = np.array(params["bool_cols"] + params["num_cols"] + params["cat_cols"])
all_cols

In [None]:
# Plot sig_agg_explanations for the first validation sample, then show its predicted value.
explain_plot(sig_agg_explanations[0], all_cols)
probs[0].item()

In [None]:
# Same sample as above, this time plotting the `explanations` output for comparison.
explain_plot(explanations[0], all_cols)
probs[0].item()

In [None]:
# Plot sig_agg_explanations and print the predicted value for the first validation
# samples. The count is capped at the number of available predictions so the loop
# cannot run past the end when the validation set has fewer than 100 rows.
n_examples = min(100, len(probs))
for i in range(n_examples):
    # explain_plot(explanations[i], all_cols)
    explain_plot(sig_agg_explanations[i], all_cols)
    print(probs[i].item())

In [None]:
#  BM : 0.8091600443913225
# Census : 0.9467201048401863
# give me some credit : 0.8599316528022821

In [None]:
# New version V3 => missing numeric values are filled (fillna), and the activation is tanh instead of mish


In [None]:
# NEW VERSION
# Bank marketing
# valid 0.7974101623084582 test 0.8133980360868731     conv_dim=[],    lconv_dim=[128, 64, 32],    lconv_num_dim=[64, 32, 16], patience 20
# RL
# valid 0.9334586431074957 test 0.9331843177543191     conv_dim=[],    lconv_dim=[128, 64, 32],    lconv_num_dim=[64, 32, 16], patience 20

In [None]:
# Census example
# valid 0.9282381974389771 test 0.9262939626480025 conv_dim=[64], lconv_dim=[128, 64, 32] patience 50

# RL
# valid 0.9363136991351992 test 0.9431532242454923 conv_dim=[64], lconv_dim=[128, 64, 32] patience 50

# Open payments
# valid 0.9395366568006073 test 0.9370193221838594 conv_dim=[64], lconv_dim=[128, 64, 32] patience 50

# give-me-some-credit
# valid  test  conv_dim=[64], lconv_dim=[128, 64, 32] patience 50