In [None]:
from Dataset.dataset import FlchainSub1, PM
import numpy as np
from sksurv.ensemble import RandomSurvivalForest
import pickle
from model import SurvCounterfactual
import time
import pandas as pd
import matplotlib.pyplot as plt
from pyswarms.utils.plotters import (plot_cost_history, plot_contour, plot_surface)
from matplotlib.ticker import FormatStrFormatter
import shap
import os

# Load the Data

In [None]:
# Name of this experiment; the plotting helpers below save their figures
# under Results/<exp_name>/.
exp_name = 'PM_Example'
# exist_ok=True keeps this cell re-runnable: without it, a second execution
# (or a Restart & Run All with results already on disk) raises FileExistsError.
os.makedirs(f'Results/{exp_name}', exist_ok=True)

In [None]:
# Load the PM dataset through the project-local `PM` wrapper (Dataset/dataset.py).
ds = PM('Dataset/PM.csv')

In [None]:
# Ask the dataset for its predefined split: fold 1 for validation, fold 0 for test.
split_parts = ds.get_train_val_test_from_splits(val_id=1, test_id=0)

# Twelve arrays come back, four per split, in train / val / test order.
(x_train, ye_train, y_train, e_train,
 x_val, ye_val, y_val, e_val,
 x_test, ye_test, y_test, e_test) = split_parts

# Quick sanity check of the feature-matrix shapes.
print(*(features.shape for features in (x_train, x_val, x_test)))

In [None]:
def get_event_time_data(subset):
    """Build a structured (event, time) array from a data subset.

    ``subset[2]`` is read as the times and ``subset[3]`` as the event
    indicators. Each pair becomes one record of dtype ``'bool,float'``,
    with the event flag (cast to ``bool``) first and the time second.
    """
    times, events = subset[2], subset[3]
    records = [(bool(flag), duration) for flag, duration in zip(events, times)]
    return np.array(records, dtype=np.dtype('bool,float'))

def get_idx_data(subset):
    """Return the item stored at position 0 of ``subset`` (the index part)."""
    idx_part = subset[0]
    return idx_part

def get_features_data(subset):
    """Return the item stored at position 1 of ``subset`` (the features part)."""
    features_part = subset[1]
    return features_part

def get_time_data(subset):
    """Return the item stored at position 2 of ``subset`` (the time part)."""
    time_part = subset[2]
    return time_part

def get_event_data(subset):
    """Return the item stored at position 3 of ``subset`` (the event part)."""
    event_part = subset[3]
    return event_part

def plot_pca_time(x, pca_mdl, labels, col_names=None, suffix='a', label='Data', size=30):
    """Scatter the (optionally projected) data in 2-D, colored by ``labels``.

    Parameters
    ----------
    x : 2-D array
        Samples to plot. Used as-is when ``pca_mdl`` is None.
    pca_mdl : object with ``transform`` or None
        Projection applied to ``x`` before plotting.
    labels : array-like
        One value per sample, mapped through the 'viridis' colormap.
    col_names : list of str, optional
        Axis names for the plotted columns; defaults to PC0/PC1/PC2.
    suffix : str
        Appended to the output file name.
    label : str
        Label attached to the scatter artists.
    size : float
        Marker size.

    Two-column data gives a single panel; anything else gives the three
    pairwise views of the first three columns. The figure is saved to
    ``Results/{exp_name}/scatter_time_{suffix}.pdf`` (``exp_name`` is a
    notebook-level global).
    """
    x_pca = x if pca_mdl is None else pca_mdl.transform(x)

    if col_names is None:
        col_names = ['PC0', 'PC1', 'PC2']

    # Column pairs to draw, one panel per pair.
    if x_pca.shape[1] == 2:
        pairs = [(0, 1)]
    else:
        pairs = [(0, 1), (0, 2), (1, 2)]

    # squeeze=False always yields a 2-D array of axes, so the single-panel and
    # three-panel cases share the same loop.
    fig, axes = plt.subplots(1, len(pairs), figsize=(7 * len(pairs), 5), squeeze=False)

    # Figure-level title: plt.title() would only decorate the last panel.
    fig.suptitle("Data distribution (PCA)")

    for panel, (i, j) in zip(axes[0], pairs):
        im = panel.scatter(x_pca[:, i], x_pca[:, j], c=labels, alpha=1, cmap='viridis', s=size, label=label)
        panel.set_xlabel(col_names[i])
        panel.set_ylabel(col_names[j])
        fig.colorbar(im, ax=panel, orientation='vertical')

    fig.tight_layout()
    fig.savefig(f'Results/{exp_name}/scatter_time_{suffix}.pdf', format='pdf', bbox_inches='tight')
    
    
def plot_pca_time_3D(x, pca_mdl, labels, col_names=None, suffix='a', label='Data', size=30):
    """Draw a 3-D scatter of the first three (projected) columns, colored by ``labels``.

    ``x`` is passed through ``pca_mdl.transform`` unless ``pca_mdl`` is None.
    When fewer than three columns are available a message is printed and
    nothing is drawn. Otherwise the figure is saved to
    ``Results/{exp_name}/scatter_time_3D_{suffix}.pdf`` (``exp_name`` is a
    notebook-level global).
    """
    projected = pca_mdl.transform(x) if pca_mdl is not None else x
    axis_names = ['PC0', 'PC1', 'PC2'] if col_names is None else col_names

    # Guard clause: a 3-D view needs at least three columns.
    if projected.shape[1] < 3:
        print('Data is less than 3D')
        return

    fig = plt.figure(figsize=(10, 10))
    ax3d = fig.add_subplot(projection='3d')

    xs, ys, zs = projected[:, 0], projected[:, 1], projected[:, 2]
    ax3d.scatter(xs, ys, zs, c=labels, alpha=1, cmap='viridis', s=size, label=label)

    # Default viewpoint; an alternative is elev=10., azim=30.
    ax3d.view_init(elev=30., azim=-60)

    ax3d.set_xlabel(axis_names[0])
    ax3d.set_ylabel(axis_names[1])
    ax3d.set_zlabel(axis_names[2])

    out_path = f'Results/{exp_name}/scatter_time_3D_{suffix}.pdf'
    plt.savefig(out_path, format='pdf', bbox_inches='tight')
    
def plot_pca_patterns(x, labels, pca_mdl, col_names=None, x_origs=None, x_cfacts=None, suffix='a', size=30, csize=60, alpha=0.5):
    """Scatter the data in 2-D, one color per pattern, with optional overlays.

    Parameters
    ----------
    x : 2-D array
        Samples to plot. Used as-is when ``pca_mdl`` is None.
    labels : numpy array
        Pattern id per sample; ``labels == p`` is used as a boolean row mask
        and ``f'C{p}'`` as the matplotlib color for pattern ``p``.
    pca_mdl : object with ``transform`` or None
        Projection applied to ``x``, ``x_origs`` and ``x_cfacts``.
    col_names : list of str, optional
        Axis names for the plotted columns; defaults to PC0/PC1/PC2.
    x_origs, x_cfacts : 2-D array, optional
        Original points (black stars) and their counterfactuals (black
        triangles) drawn on top of the pattern clouds.
    suffix : str
        Appended to the output file name.
    size, csize : float
        Marker sizes for the data points and for the overlays.
    alpha : float
        Opacity of the data points.

    Two-column data gives a single panel; anything else gives the three
    pairwise views of the first three columns. The figure is saved to
    ``Results/{exp_name}/scatter_patterns_{suffix}.pdf`` (``exp_name`` is a
    notebook-level global).
    """
    def _project(points):
        # Leave "not provided" (None) and the no-projection case untouched.
        if points is None or pca_mdl is None:
            return points
        return pca_mdl.transform(points)

    x_pca = _project(x)
    x_pca_origs = _project(x_origs)
    x_pca_cfacts = _project(x_cfacts)

    if col_names is None:
        col_names = ['PC0', 'PC1', 'PC2']

    # Column pairs to draw, one panel per pair.
    if x_pca.shape[1] == 2:
        pairs = [(0, 1)]
    else:
        pairs = [(0, 1), (0, 2), (1, 2)]

    # squeeze=False always yields a 2-D array of axes, so one loop serves both layouts.
    fig, axes = plt.subplots(1, len(pairs), figsize=(7 * len(pairs), 5), squeeze=False)
    panels = axes[0]

    # Figure-level title: plt.title() would only decorate the last panel.
    fig.suptitle("Data distribution (PCA)")

    for panel, (i, j) in zip(panels, pairs):
        for p in set(labels):
            members = labels == p
            panel.scatter(x_pca[members, i], x_pca[members, j], c=f'C{p}', alpha=alpha, s=size, label=f"Pattern {p}")

        if x_pca_origs is not None:
            panel.scatter(x_pca_origs[:, i], x_pca_origs[:, j], c='k', alpha=1, s=csize, marker='*', label="Originals")
        if x_pca_cfacts is not None:
            # Previously most of these were labelled "Originals" (copy-paste slip),
            # which produced a duplicated, misleading legend entry in the 2-D case.
            panel.scatter(x_pca_cfacts[:, i], x_pca_cfacts[:, j], c='k', alpha=1, s=csize, marker='^', label="Counterfacts")

        panel.set_xlabel(col_names[i])
        panel.set_ylabel(col_names[j])

    fig.tight_layout()
    # As before, a single legend is shown on the last panel.
    panels[-1].legend()
    fig.savefig(f'Results/{exp_name}/scatter_patterns_{suffix}.pdf', format='pdf', bbox_inches='tight')

def plot_pca_patterns_3D(x, labels, pca_mdl, col_names=None, x_origs=None, x_cfacts=None, x_cfacts1=None, suffix='a', size=30, csize=60, alpha=0.5, ax=None):
    """Draw a 3-D scatter of the data, one color per pattern, with optional overlays.

    Parameters
    ----------
    x : 2-D array
        Samples to plot. Used as-is when ``pca_mdl`` is None.
    labels : numpy array
        Pattern id per sample; ``labels == p`` is used as a boolean row mask
        and ``f'C{p}'`` as the matplotlib color for pattern ``p``.
    pca_mdl : object with ``transform`` or None
        Projection applied to ``x`` and to every overlay that is provided.
    col_names : list of str, optional
        Axis names; defaults to PC0/PC1/PC2.
    x_origs : 2-D array, optional
        Original points, drawn as black crosses.
    x_cfacts : 2-D array, optional
        Counterfactuals found without the autoencoder (green triangles).
    x_cfacts1 : 2-D array, optional
        Counterfactuals found with the autoencoder (blue squares).
    suffix : str
        Appended to the output file name.
    size, csize : float
        Marker sizes for the data points and for the overlays.
    alpha : float
        Opacity of the data points (legend markers are forced opaque).
    ax : 3-D matplotlib Axes, optional
        Axes to draw into; a new 5x5 figure is created when omitted.

    Returns
    -------
    The Axes that was drawn into, or None when the data has fewer than
    three columns (a message is printed in that case).

    The figure is saved to ``Results/{exp_name}/scatter_patterns_3D_{suffix}.pdf``
    (``exp_name`` is a notebook-level global).
    """
    def _project(points):
        # Leave "not provided" (None) and the no-projection case untouched.
        if points is None or pca_mdl is None:
            return points
        return pca_mdl.transform(points)

    x_pca = _project(x)
    x_pca_origs = _project(x_origs)
    x_pca_cfacts = _project(x_cfacts)
    x_pca_cfacts1 = _project(x_cfacts1)

    if col_names is None:
        col_names = ['PC0', 'PC1', 'PC2']

    if x_pca.shape[1] < 3:
        print('Data is less than 3D')
        return

    # Identity test: `ax == None` would invoke a custom __eq__ if one exists.
    if ax is None:
        fig = plt.figure(figsize=(5, 5))
        ax = fig.add_subplot(projection='3d')

    # Plot the data that was actually passed in. The previous version ignored
    # `x`/`labels` and read the notebook globals `X_pca_train` and
    # `explainer.labels_train`, so any other input was silently not shown.
    for p in set(labels):
        members = labels == p
        ax.scatter(x_pca[members, 0], x_pca[members, 1], x_pca[members, 2], c=f'C{p}', alpha=alpha, s=size, label=f"Pattern {p}")

    if x_pca_origs is not None:
        ax.scatter(x_pca_origs[:, 0], x_pca_origs[:, 1], x_pca_origs[:, 2], c='k', alpha=1, s=csize, marker='x', label="Original Points")
    if x_pca_cfacts is not None:
        ax.scatter(x_pca_cfacts[:, 0], x_pca_cfacts[:, 1], x_pca_cfacts[:, 2], marker='^', facecolors='limegreen', edgecolors='w', alpha=1, s=csize*2, label="Counterfactuals w/o AE")
    if x_pca_cfacts1 is not None:
        ax.scatter(x_pca_cfacts1[:, 0], x_pca_cfacts1[:, 1], x_pca_cfacts1[:, 2], marker='s', facecolors='dodgerblue', edgecolors='w', alpha=1, s=csize, label="Counterfactuals w AE")

    ax.view_init(elev=30., azim=-60)
    ax.set_xlabel(col_names[0])
    ax.set_ylabel(col_names[1])
    ax.set_zlabel(col_names[2])

    # Work on `ax` itself rather than pyplot's "current axes", so a
    # caller-supplied axes is decorated and saved correctly.
    leg = ax.legend(loc=(1.15, 0.3))
    # Newer matplotlib exposes `legend_handles`; older releases only have
    # the camel-case `legendHandles`.
    handles = getattr(leg, 'legend_handles', None)
    if handles is None:
        handles = leg.legendHandles
    # Data points may be nearly transparent; keep the legend markers readable.
    for lh in handles:
        lh.set_alpha(1)

    ax.locator_params(axis='both', nbins=4)
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
    ax.figure.savefig(f'Results/{exp_name}/scatter_patterns_3D_{suffix}.pdf', format='pdf', bbox_inches='tight')
    return ax


def find_counterfactuals(x, targets, explainer, 
                         feature_names,
                         feature_types=None,
                         ohe_features=None,
                         mask=None, 
                         norm=1, 
                         anomaly_model=None,
                         anomaly_threshold=0,
                         n_particles=500, n_iterations=100000, patience=1000, 
                         loss_distance_weight=1,
                         loss_anomaly_weight=1, 
                         loss_target_weight=1e5, 
                         loss_mutual_exclusions_weight=1e5):
    """Run ``explainer.explain`` once per (sample, target pattern) pair.

    ``x`` and ``targets`` are iterated in lockstep. For every pair the
    counterfactual returned by ``explainer.explain`` is fed back, with a
    leading batch axis, to ``explainer.predict``, and the value of
    ``explainer.optimizer.history`` at that moment is recorded. All other
    arguments are forwarded unchanged to ``explainer.explain``
    (``feature_names`` is passed as ``features_names_list``).

    Returns
    -------
    tuple
        ``(counterfactuals, predictions, histories)``: the first two as
        numpy arrays stacked over samples, the last as a plain list.
    """
    # Settings shared by every explain() call; only the sample and its
    # target pattern change from one iteration to the next.
    shared_kwargs = dict(
        features_names_list=feature_names,
        feature_types=feature_types,
        ohe_features=ohe_features,
        mask=mask,
        norm=norm,
        anomaly_model=anomaly_model,
        anomaly_threshold=anomaly_threshold,
        n_particles=n_particles,
        n_iterations=n_iterations,
        patience=patience,
        loss_distance_weight=loss_distance_weight,
        loss_anomaly_weight=loss_anomaly_weight,
        loss_target_weight=loss_target_weight,
        loss_mutual_exclusions_weight=loss_mutual_exclusions_weight,
    )

    x_cfacts, p_preds, hists = [], [], []
    for x_sample, p_target in zip(x, targets):
        x_cfact = explainer.explain(x=x_sample, target_pattern=p_target, **shared_kwargs)
        x_cfacts.append(x_cfact)
        # predict() is given a batch of one.
        p_preds.append(explainer.predict(x_cfact[np.newaxis]))
        hists.append(explainer.optimizer.history)

    return np.array(x_cfacts), np.array(p_preds), hists

In [None]:
from sklearn.decomposition import PCA

# Fit a 3-component PCA on the training features; it is used for visualisation below.
pca = PCA(n_components=3).fit(x_train)

# Training features projected onto the three components.
X_pca_train = pca.transform(x_train)

# Share of the total variance retained by the three components, in percent.
explained_pct = pca.explained_variance_ratio_.sum()*100
print('PCA explained variance: {:.2f} %'.format(explained_pct))

In [None]:
# Pairwise 2-D views of the PCA-projected training features, colored by y_train;
# saved as Results/<exp_name>/scatter_time_a.pdf.
plot_pca_time(x_train, pca_mdl=pca, col_names=None, suffix='a', labels=y_train)

In [None]:
# Same data as a single 3-D scatter; saved as Results/<exp_name>/scatter_time_3D_a.pdf.
plot_pca_time_3D(x_train, pca_mdl=pca, col_names=None, suffix='a', labels=y_train)

# Train Survival Model (Random Survival Forest)

In [None]:
# special for RSF: targets are structured arrays of (event flag, time) records
dt = np.dtype('bool,float')
y_train_surv, y_val_surv, y_test_surv = (
    np.array([(bool(flag), duration) for flag, duration in zip(events, durations)], dtype=dt)
    for events, durations in ((e_train, y_train), (e_val, y_val), (e_test, y_test))
)
print(y_train_surv.shape)

# train RSF
rsf_params = dict(
    n_estimators=15,
    min_samples_split=20,
    min_samples_leaf=10,
    max_features="sqrt",
    oob_score=True,
    n_jobs=-1,
    random_state=20,
)
rsf = RandomSurvivalForest(**rsf_params)
rsf.fit(x_train, y_train_surv)

# Concordance index on each split, plus the out-of-bag estimate.
cindex_train = rsf.score(x_train, y_train_surv)
cindex_val = rsf.score(x_val, y_val_surv)
cindex_test = rsf.score(x_test, y_test_surv)
cindex_oob = rsf.oob_score_

# Report as percentages, in the order train / val / test / oob.
for template, cindex in (('Train cindex {:.2f}', cindex_train),
                         ('Val cindex {:.2f}', cindex_val),
                         ('Test cindex {:.2f}', cindex_test),
                         ('oob cindex {:.2f}', cindex_oob)):
    print(template.format(cindex*100))

# Prepare the Data and the Survival Curves for SurvCounterFactual

In [None]:
# Predicted survival functions for every sample of each split, as plain arrays.
surv_train = rsf.predict_survival_function(x_train, return_array=True)
surv_val = rsf.predict_survival_function(x_val, return_array=True)
surv_test = rsf.predict_survival_function(x_test, return_array=True)

# Time grid of the fitted forest. Newer scikit-survival releases expose it as
# `unique_times_` and deprecate `event_times_`; prefer the new name and fall
# back so the notebook runs on both old and new versions.
event_times = rsf.unique_times_ if hasattr(rsf, 'unique_times_') else rsf.event_times_

# Prepare data for explanation: (features, time, event) for train, val and test.
xte_data = (x_train, y_train, e_train,
            x_val, y_val, e_val,
            x_test, y_test, e_test)

# Prepare the survival curves for explanation, in the same split order.
survival_curves = (surv_train, surv_val, surv_test)



# Explain

In [None]:
# Build the project-local counterfactual explainer and fit it on the data, the
# RSF survival curves and their time grid. Later cells read the per-sample
# pattern ids from `explainer.labels_train`.
# NOTE(review): the meaning of `max_depth` is defined in model.py — confirm there.
explainer = SurvCounterfactual(prefix_name=exp_name, max_depth=5)
explainer.fit(xte_data=xte_data, survival_curves=survival_curves, event_times=event_times, survival_mdl=rsf)

In [None]:
# Training points in PCA space, one color per pattern found by the explainer;
# saved as Results/<exp_name>/scatter_patterns_patterns.pdf.
plot_pca_patterns(x_train, explainer.labels_train, pca_mdl=pca, suffix='patterns', col_names=None)

In [None]:
# 3-D version of the pattern plot; large, almost transparent markers (size=200,
# alpha=0.01) so that dense regions show up as clouds.
plot_pca_patterns_3D(x_train, explainer.labels_train, pca_mdl=pca, suffix='patterns', col_names=None, alpha=0.01, size=200, csize=50)

In [None]:
# Type tag per feature column: three 'float' columns followed by seven 'bool' columns.
feature_types = ['float'] * 3 + ['bool'] * 7

# Column-index groups passed to the explainer as `ohe_features`
# (presumably one group per one-hot-encoded variable — confirm against the dataset).
ohe_features = [list(range(3, 6)), list(range(6, 10))]

# One entry per feature column, all set to 1.
# NOTE(review): presumably 1 marks a feature the search may change — confirm in model.py.
mask = [1 for _ in range(x_train.shape[1])]


# Without Autoencoder

In [None]:
# Pattern the samples currently belong to, and the pattern they should be moved to.
y_source, y_target = 7, 4

# Number of source samples to explain.
# Use x_train[explainer.labels_train==y_source].shape[0] to take every sample of the source pattern.
n_samples = 2

# First n_samples training rows labelled with the source pattern (copied so
# later edits cannot touch x_train).
source_rows = explainer.labels_train == y_source
x_sources = x_train[source_rows][:n_samples].copy()

# The same target pattern for every selected sample.
y_targets = np.array([y_target for _ in range(n_samples)])

In [None]:
%%time
# Counterfactual search WITHOUT an anomaly model: no autoencoder is supplied
# and the mutual-exclusion weight is zero.
x_cfacts, p_preds, hists = find_counterfactuals(
    x=x_sources,
    targets=y_targets,
    explainer=explainer,
    # feature description
    feature_names=ds.feature_names,
    feature_types=feature_types,
    ohe_features=ohe_features,
    mask=mask,
    # distance / anomaly settings
    norm=1,
    anomaly_model=None,
    # optimiser budget
    n_particles=100,
    n_iterations=100000,
    patience=200,
    # loss weights
    loss_distance_weight=1,
    loss_anomaly_weight=1,
    loss_target_weight=1e2,
    loss_mutual_exclusions_weight=0,
)

In [None]:
# For every one-hot group, compare the per-column sums of the source samples,
# their counterfactuals and the real training samples of the target pattern.
for i, fs in enumerate(ohe_features):
    group_names = np.array(ds.feature_names)[fs]

    # (panel title, rows whose columns are summed) in left-to-right order.
    panels = (
        (f'Source Pattern {y_source}', x_sources[:, fs]),
        (f'CounterFactual Target Pattern {y_target}', x_cfacts[:, fs]),
        (f'Data Target Pattern {y_target}', x_train[explainer.labels_train==y_target][:, fs]),
    )

    fig, ax = plt.subplots(1, 3, figsize=(9, 2))
    for panel, (title, rows) in zip(ax, panels):
        panel.bar(group_names, rows.sum(axis=0))
        panel.set_xticklabels(group_names, rotation=45)
        panel.set_title(title)
        # y tick labels rendered with the '%.d' (integer) format.
        panel.yaxis.set_major_formatter(FormatStrFormatter('%.d'))

    plt.savefig(f'Results/{exp_name}/OHE_source_target_cfacts_distributions_{i}_no_AE.pdf', format='pdf', bbox_inches='tight')

In [None]:
# Pattern plot with the source points and the counterfactuals found without the
# autoencoder overlaid; saved as Results/<exp_name>/scatter_patterns_no_AE.pdf.
plot_pca_patterns(x_train, explainer.labels_train, pca_mdl=pca, col_names=None, suffix='no_AE', x_origs=x_sources, x_cfacts=x_cfacts,size=50, csize=50, alpha=0.5)

In [None]:
# 3-D version of the overlay plot; the returned Axes is kept in `ax`.
ax=plot_pca_patterns_3D(x_train, explainer.labels_train, pca_mdl=pca, col_names=None, suffix='no_AE', x_origs=x_sources, x_cfacts=x_cfacts,size=200, csize=100, alpha=0.01)

# Autoencoder

In [None]:
from autoencoder import Autoencoder, AutoencoderDataset, AutoencoderLearner
import torch
from torch.nn import MSELoss

# Wrap the train/validation features for the autoencoder training code.
# NOTE(review): despite the `_loader` names these are AutoencoderDataset
# instances; they are only referenced by the commented-out training call below.
train_loader = AutoencoderDataset(x_train)
val_loader = AutoencoderDataset(x_val)

# Autoencoder over the full feature vector: two hidden layers of 16 units and a
# 4-dimensional latent code.
autoencoder = Autoencoder(n_features=x_train.shape[-1],
                         hidden_layers_size=[16, 16],
                         latent_size=4,
                         activation="relu",
                         last_activation="sigmoid")

# --- Training is disabled: a later cell loads pre-trained weights instead. ---
# NOTE(review): if this block is re-enabled, the weights are saved as
# 'Flchain_autoencoder.mdl', while the notebook later loads
# 'PM_New_autoencoder.mdl' — align the two file names before relying on it.
# optimizer = torch.optim.Adam(autoencoder.parameters(), lr=1e-3)
# loss_function = MSELoss()

# train_loss_list, valid_loss_list = AutoencoderLearner.run_training(autoencoder, optimizer, loss_function, 
#                                                                    train_loader, val_loader, epochs=1000,
#                                                                    early_stopping=True,
#                                                                    early_stopping_patience=50,
#                                                                    early_stopping_delta=1e-5,
#                                                                   )

# autoencoder.save_weights('Flchain_autoencoder.mdl')

# plt.figure(figsize=(8, 5))
# plt.plot(train_loss_list, label="Train Loss")
# plt.plot(valid_loss_list, label="Test Loss")
# plt.legend()
# plt.tight_layout()
# plt.show()

In [None]:
# Load pre-trained autoencoder weights from the working directory
# (the file must already exist; training above is commented out).
autoencoder.load_weights('PM_New_autoencoder.mdl')

# Autoencoder error distribution

In [None]:
# Autoencoder anomaly score of every sample, computed split by split
# in train / val / test order.
x_train_error, x_val_error, x_test_error = (
    autoencoder.anomaly_score_multi(features)
    for features in (x_train, x_val, x_test)
)

In [None]:
# Compare the anomaly-score distributions of train, val and test side by side
# (the trailing ';' suppresses the textual repr of the returned artists).
plt.boxplot([x_train_error, x_val_error, x_test_error]);

In [None]:
# Upper box-plot whisker of the test-split anomaly scores: Q3 + 1.5 * IQR.
# NOTE(review): the threshold is derived from the test split — confirm this is intended.
q1, q3 = np.quantile(x_test_error, [0.25, 0.75])
anomaly_threshold = q3 + 1.5*(q3 - q1)

# With Autoencoder

In [None]:
# Counterfactual search WITH the autoencoder as anomaly model: same sources and
# targets as before, now with the anomaly threshold and non-zero anomaly and
# mutual-exclusion weights.
x_cfacts1, p_preds1, hist1 = find_counterfactuals(
    x=x_sources,
    targets=y_targets,
    explainer=explainer,
    # feature description
    feature_names=ds.feature_names,
    feature_types=feature_types,
    ohe_features=ohe_features,
    mask=mask,
    # distance / anomaly settings
    norm=1,
    anomaly_model=autoencoder,
    anomaly_threshold=anomaly_threshold,
    # optimiser budget
    n_particles=100,
    n_iterations=100000,
    patience=200,
    # loss weights
    loss_distance_weight=1,
    loss_anomaly_weight=1e2,
    loss_target_weight=1e2,
    loss_mutual_exclusions_weight=1e2,
)

In [None]:
# Pattern plot with the source points and the counterfactuals found with the
# autoencoder overlaid; saved as Results/<exp_name>/scatter_patterns_w_AE.pdf.
plot_pca_patterns(x_train, explainer.labels_train, pca_mdl=pca, col_names=None, x_origs=x_sources, x_cfacts=x_cfacts1, suffix='w_AE', size=50, csize=50, alpha=0.1)

In [None]:
# 3-D comparison of both counterfactual sets (without AE: x_cfacts, with AE:
# x_cfacts1) against the source points; saved with suffix 'w_wo_AE'.
plot_pca_patterns_3D(x_train, explainer.labels_train, pca_mdl=pca, col_names=None, x_origs=x_sources, x_cfacts=x_cfacts, x_cfacts1=x_cfacts1, suffix='w_wo_AE', size=200, csize=70, alpha=0.01)

In [None]:
def f1(x):
    """Return the second item of ``x`` when it has one, otherwise the first."""
    return x[1] if len(x) > 1 else x[0]
# Short display names: for a name containing '_' keep the token after the first
# underscore, otherwise keep the name unchanged.
feature_names = [f1(s.split('_')) for s in ds.feature_names]

In [None]:
# For every one-hot group, compare the per-column sums of the source samples,
# the real target-pattern samples and both sets of counterfactuals.
for i, fs in enumerate(ohe_features):
    short_names = np.array(feature_names)[fs]

    # (panel title, rows whose columns are summed) in left-to-right order.
    panel_specs = (
        (f'Source Pattern {y_source}', x_sources[:, fs]),
        (f'Target Pattern {y_target}', x_train[explainer.labels_train==y_target][:, fs]),
        ('CounterFactuals w/o AE', x_cfacts[:, fs]),
        ('CounterFactuals w AE', x_cfacts1[:, fs]),
    )

    fig, ax = plt.subplots(1, 4, figsize=(12, 2))
    for panel, (title, rows) in zip(ax, panel_specs):
        panel.bar(short_names, rows.sum(axis=0))
        panel.set_xticklabels(short_names, rotation=45)
        panel.set_title(title)

    # Only the third panel (counterfactuals w/o AE) gets the '%.d' y-tick format.
    ax[2].yaxis.set_major_formatter(FormatStrFormatter('%.d'))

    plt.savefig(f'Results/{exp_name}/OHE_source_target_cfacts_distributions_{i}_w_wo_AE.pdf', format='pdf', bbox_inches='tight')

    