In [24]:
import numpy as np
import csv
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import sys
import statsmodels.api as sm
from pathlib import Path

from sklearn import metrics
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV, KFold
from sklearn.linear_model import LinearRegression
from sklearn.neural_network import MLPRegressor
from sklearn.tree import DecisionTreeRegressor
from sklearn.compose import make_column_selector as selector
from sklearn.compose import ColumnTransformer, make_column_transformer
from sklearn.pipeline import make_pipeline, Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler, RobustScaler, OrdinalEncoder
from sklearn.impute import SimpleImputer
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.base import BaseEstimator, TransformerMixin
from lime.lime_tabular import LimeTabularExplainer
import warnings
import os
import shutil
from libraries import create_explanations, summaryPlot, HeatMap_plot, Waterfall, Decision_plot
import dice_ml

def smape(y_true, y_pred):
    return 100/len(y_true) * np.sum(np.abs(y_pred - y_true) / (np.abs(y_true) + np.abs(y_pred)))

def predict_proba_wrapper(X):
    return model.predict(X)

if __name__ == "__main__":

    '''if(len(sys.argv)<2):
        print("ERROR! Usage: python scriptName.py fileCSV targetN modelloML\n")
              
        sys.exit(1)
    nome_script, mlModel = sys.argv'''
    mlModel = 'dt'

    if not sys.warnoptions:
        warnings.simplefilter("ignore")
        os.environ["PYTHONWARNINGS"] = "ignore"

    dataset = pd.read_csv('insurance.csv', sep=',')

    if not os.path.exists(mlModel+'/lime'):
        os.makedirs(mlModel+'/lime')

    if not os.path.exists(mlModel+'/shap'):
        os.makedirs(mlModel+'/shap')

    X = dataset.drop(columns=['charges'])
    y = dataset['charges']
    labels = X.columns.values

    categorical_features = ['sex', 'smoker', 'region']
    numeric_features = ['age', 'bmi', 'children']

    numeric_transformer = Pipeline(steps=[
                                      ('imputer', SimpleImputer(strategy='median')),
                                      ('scaler', StandardScaler())])

    categorical_transformer = Pipeline(steps=[
                                          ('imputer', SimpleImputer(strategy='constant', fill_value='missing')),
                                          ('ordinal', OrdinalEncoder(handle_unknown='error'))])    

    preprocessor = ColumnTransformer(
                                 transformers=[
                                               ('num', numeric_transformer, numeric_features),
                                               ('cat', categorical_transformer, categorical_features)])
    model_reg = ['lr',
                'dt',
                'rf',
                'gbr']

    param_lr = [{'fit_intercept':[True,False], 'normalize':[True,False], 'copy_X':[True, False]}]

    param_dt = [{'max_depth': [5,10,20]}]

    param_rf = [{'bootstrap': [True, False],
                 'max_depth': [10, 20],
                 'max_features': ['auto', 'sqrt'],
                 'min_samples_leaf': [1, 2, 4],
                 'min_samples_split': [2],}]

    param_gbr = [{'learning_rate': [0.01,0.03],
                'subsample'    : [0.5, 0.2],
                'n_estimators' : [100,200],
                'max_depth'    : [4,8]}]

    models_regression = {
        'lr': {'name': 'Linear Regression',
               'estimator': LinearRegression(),
               'param': param_lr,
              },
        'dt': {'name': 'Decision Tree',
               'estimator': DecisionTreeRegressor(random_state=42),
               'param': param_dt,
              },
        'rf': {'name': 'Random Forest',
               'estimator': RandomForestRegressor(random_state=42),
               'param': param_rf,
              },

        'gbr': {'name': 'Gradient Boosting Regressor',
                'estimator': GradientBoostingRegressor(random_state=42),
                'param': param_gbr
                },
    }

    k = 10
    kf = KFold(n_splits=k, random_state=None)
    mod_grid = GridSearchCV(models_regression[mlModel]['estimator'], models_regression[mlModel]['param'], cv=5, return_train_score = False, scoring='neg_mean_squared_error', n_jobs = 8)

    mae = []
    mse = []
    rmse = []
    mape = []

    X_preprocessed = preprocessor.fit_transform(X)

    for train_index , test_index in kf.split(X):
        data_train , data_test = X.iloc[train_index,:],X.iloc[test_index,:]
        target_train , target_test = y[train_index] , y[test_index]

        data_train_lime = preprocessor.fit_transform(data_train)
        data_test_lime = preprocessor.transform(data_test)

        model_lime = Pipeline(steps=[('regressor', mod_grid)])
        model = Pipeline(steps=[('preprocessor', preprocessor),
                ('regressor', mod_grid)])

        _ = model_lime.fit(data_train_lime, target_train)
        _ = model.fit(data_train, target_train)

        feature_names = numeric_features + categorical_features
        target_pred = model_lime.predict(data_test_lime)
    
        mae.append(metrics.mean_absolute_error(target_test, target_pred))
        mse.append(metrics.mean_squared_error(target_test, target_pred))
        rmse.append(np.sqrt(metrics.mean_squared_error(target_test, target_pred)))
        mape.append(smape(target_test, target_pred))

        explainer = LimeTabularExplainer(data_train_lime,
                                         feature_names=feature_names,
                                         class_names=['charges'],
                                         mode='regression',
                                         discretize_continuous=True)
        
        random_numbers = np.random.randint(0, 70, size=5)
        explanation_instances = []
        for i in random_numbers:
            explanation_instances.append(data_test_lime[i])

    for idx, instance in enumerate(explanation_instances):
        exp = explainer.explain_instance(instance,
                                        model_lime.predict,
                                        num_features=5,) #5 most signficant
        


        # save Lime explanation results
        exp.save_to_file(f'{mlModel}/lime/lime_explanation_{idx}.html')

In [22]:
importance = []
    
if (mlModel=='lr'):
    importance = mod_grid.best_estimator_.coef_
    coefs = pd.DataFrame(mod_grid.best_estimator_.coef_, columns=["Coefficients"], index= labels)

elif (mlModel=='dt' or mlModel=='rf' or mlModel=='gbr'):
    importance = mod_grid.best_estimator_.feature_importances_
    coefs = pd.DataFrame(mod_grid.best_estimator_.feature_importances_, columns=["Coefficients"], index= labels)

else:
    c = [None] * len(labels)
    l = mod_grid.best_estimator_.coefs_[0]
    n_l = mod_grid.best_params_['hidden_layer_sizes'][0]
    for i in range(len(labels)):
        c[i] = l[i][n_l-1]
        importance = c
        coefs = pd.DataFrame(c,
                            columns=["Coefficients"],
                            index= labels)

# plot feature importance

indexes = np.arange(len(labels))
plt.bar([x for x in range(len(importance))], importance)
plt.xticks(indexes, labels, rotation = '48')
plt.savefig(mlModel + '/bar.png')
plt.clf()
plt.cla()
plt.close()

## SHAP

In [25]:
_ = summaryPlot(model, X, preprocessor, labels, mlModel+'/shap/', 'Dot_plot', 'dot')
_ = summaryPlot(model, X, preprocessor, labels, mlModel+'/shap/', 'Violin_plot', 'violin')
ordered_labels = summaryPlot(model, X, preprocessor, labels, mlModel+'/shap/', 'Bar_plot', 'bar')
HeatMap_plot(model, X, preprocessor, mlModel+'/shap/', 'HeatMap_plot', labels)
    
# Show some specific examples
Showed_examples = 5 
idx = np.random.randint(0, X.shape[0], Showed_examples)
for i,el in enumerate(idx):
    Decision_plot(model, X, preprocessor, mlModel+'/shap/', el, f'Decision_plot{i}', labels)
    Waterfall(model, X, preprocessor, mlModel+'/shap/', el, f'Waterfall_Plot_{i}', labels)

## DiCE

In [32]:
Ncount=30

Xdice = preprocessor.fit_transform(X)

constraints={}
    
Xdice = pd.DataFrame(Xdice, columns=labels)

desc=Xdice.describe()
for i in labels:
    constraints[i]=[desc[i]['min'],desc[i]['max']]
Xdice['output'] = y
desc=Xdice.describe()
interval = [desc['output']['min'],desc['output']['max']]

X_train, X_test = train_test_split(Xdice,test_size=0.2,random_state=42)

dice_train = dice_ml.Data(dataframe=X_train, continuous_features=numeric_features, outcome_name='output')
    
m = dice_ml.Model(model=mod_grid.best_estimator_,backend='sklearn', model_type='regressor',func=None)
exp = dice_ml.Dice(dice_train,m)

query_instance = X_test.drop(columns="output")
dice_exp = exp.generate_counterfactuals(query_instance, total_CFs=Ncount, desired_range=interval, permitted_range=constraints)

data = []
for cf_example in dice_exp.cf_examples_list:
    data.append(cf_example.final_cfs_df)

df_combined = pd.concat(data, ignore_index=True)
for i in range(len(df_combined)):
    df_combined.iloc[i] = df_combined.iloc[i] - X_test.iloc[i//Ncount]
df_combined.to_csv(path_or_buf=f'{mlModel}/dice/counterfactuals.csv', index=False, sep=',')
df_combined.dtypes
df_filtered = df_combined[df_combined['output'] != 0]
count_per_column = df_filtered.apply(lambda x: (x != 0).sum())
diff_per_column = df_filtered.apply(lambda x: (abs(x)*abs(df_filtered['output'])).sum())


original_stdout = sys.stdout
with open(f'{mlModel}/dice/count.txt', 'w') as f:
    sys.stdout = f
    print('\n--------------------- Counterfactual absolute counts:-------------------------')
    print(diff_per_column)
    print('\n--------------------- Counterfactual relative counts:-------------------------')
    print(diff_per_column/count_per_column)
            
        
sys.stdout = original_stdout

  0%|          | 0/268 [00:00<?, ?it/s]


UserConfigValidationException: The category -2.412011094241716 does not occur in the training data for feature sex. Allowed categories are ['-1.7574736185132187', '-1.0405992403343882', '-0.9520152210628624', '0.5998455609901665', '-1.5081260087118864', '1.7727635939372228', '-1.0438801299370373', '-1.4925417830993029', '1.7333929187054333', '0.5801602233742723', '-1.3522837525860534', '1.050967881354419', '-0.1211299291919752', '1.4217084064537682', '-0.797813409738354', '0.2528914855100229', '0.4489246392683072', '-0.7289147280827231', '1.1592372382418392', '1.8268982723809328', '-0.49515134389397386', '0.8139236075630208', '0.20613880867227363', '-0.42051110543370634', '-2.0847423563774674', '-1.042239685135713', '-1.087351917172138', '1.123147452612699', '-1.7820802905330868', '-0.4287133294403292', '1.5562248801623817', '0.022408990923923334', '2.422379735261747', '-0.012040349903892454', '-0.7600831793078892', '-0.15803993722177792', '0.7983393819504377', '0.5276659897318865', '1.8810329508246442', '1.2494617023146903', '-0.48612889748668886', '-0.1039052587780673', '-1.0717676915595549', '0.4161157432418168', '-1.3719690902019481', '0.38330684721532526', '0.9541816380762698', '1.2675065951292603', '0.5178233209239388', '-1.0241947923211423', '0.3176890551623433', '1.498809312116022', '0.12083567900339627', '-0.4828480078840398', '-0.3220844173542334', '0.40135174002989527', '-1.5032046743079126', '1.7489771443180175', '1.4660004160895317', '-0.510735569506557', '0.689249802662355', '0.3504979511888337', '-0.06781547314892729', '0.47353131128817527', '-1.1496888196224708', '-0.6813418288443105', '1.4824048641027774', '0.8885638460232879', '-0.2892755213277424', '-0.7133305024701394', '-0.6337689296058986', '-0.08996147796680903', '-0.3926235438111892', '0.11263345499677396', '0.49813798330804454', '-0.47956711828139076', '-0.6009600335794075', '-1.1915201620562474', '0.5457108825464565', '0.564575997761688', '1.2314168095001203', '1.2133719166855503', '-0.6632969360297405', '0.058498776553063354', '0.4325201912550614', '-0.8847569842085556', '-1.102936142784721', '-0.12195015159263732', '-0.026804353115813424', '1.09443966858952', '-1.3211153013608872', '0.4399021928610222', '-1.710720941675469', '2.1385827846325993', '2.0139089797319323', '-2.1757870428509802', '-0.3081406365429745', '-0.38770220940721534', '0.3291721687716152', '-0.07601769715555018', '1.561966436967018', '-0.7912516305330559', '-0.6173644815926533', '-1.2899468501357205', '-0.9626781122714716', '0.5145424313212892', '-0.2728710733144966', '1.265866150327936', '0.6621824634405', '1.7120671362882147', '-0.07437725235422538', '-1.1423068180165108', '-0.0596132491423044', '-1.1505090420231336', '-0.8962400978178275', '2.310009266371014', '-0.4016459902184742', '-1.421182434241685', '0.08146500377160722', '-0.502533345499934', '0.202857919069624', '0.003543875708690629', '2.440424628076317', '0.9066087388378579', '-0.17608483003634792', '0.2996441623477733', '1.6644942370498028', '-0.2925564109303915', '2.548693984963738', '0.8754402876126917', '-0.23021950848005795', '-1.2571379541092293', '-0.9881050066920024', '1.686640241867684', '-0.9315096610463054', '-1.634440258413877', '-1.1652730452350546', '1.3511692799968122', '0.36690239920207945', '-0.8076560785463017', '-0.5845555855661618', '-0.13999504440720792', '0.7606091515199723', '0.06588077815902413', '-1.9133158746390513', '-1.3227557461622115', '0.004364098109353326', '0.7622495963212977', '-0.8798356498045817', '-1.5196091223211585', '-0.16788260602972505', '0.3111272759570452', '-0.9159254354337224', '-0.027624575516475538', '1.5152137601292677', '0.8098224955597098', '0.7770135995332181', '1.5201350945332417', '0.5637557753610265', '-0.6665778256323895', '0.43744152565903527', '0.49157620410274533', '0.050296552546440464', '-0.21463528286747488', '-1.0561834659469713', '1.5931348881921843', '-1.305531075748304', '-1.3990364294238031', '3.234399911917402', '-0.15229838041714194', '0.05521788695041431', '-0.9470938866588885', '0.4710706440861884', '0.019128101321274293', '-0.19905105725489178', '2.025392093341204', '-2.02814701073177', '0.48665486969877264', '-0.9938465634966384', '0.15364457502988782', '1.0165185405266033', '-0.22365772927475988', '0.7048340282749381', '0.969765863688854', '-0.043208801129059206', '2.089369440592862', '1.8449431651955028', '0.17497035744710687', '0.8426313915862002', '1.3839781760233039', '1.6398875650299336', '0.034712326933857375', '0.6129691194007627', '0.8918447356259375', '-0.6509936000198064', '-0.8535885329833889', '-0.19412972285091792', '-0.5550275791423198', '1.7186289154935128', '1.4372926320663513', '-2.1003265819900507', '0.34721706158618526', '1.8990778436392142', '-0.5263197951191406', '-1.1808572708476377', '-1.4703957782814217', '0.45548641847360527', '-0.9782623378840553', '1.055889215758393', '-0.3393090877681413', '0.34639683918552255', '0.184813026255054', '0.8606762844007702', '-1.2276099476853874', '0.6736655770497719', '-0.37129776139397014', '-0.5189377935131798', '1.0321027661391864', '0.37756529041068987', '-0.6452520432151705', '0.8344291675795777', '1.2035292478776025', '1.5307979857418508', '1.0886981117848833', '-1.445789106261553', '-0.735476507288021', '-0.6977462768575563', '-0.3745786509966192', '0.26847571112260593', '-0.5419040207317237', '-0.13671415480455887', '0.6178904538047365', '-0.4172302158310573', '-1.6196762552019561', '0.7442047035067277', '-1.8821474234138846', '-1.43020488064897', '0.09458856218220336', '1.3577310592021103', '0.9853500893014371', '-1.1685539348377036', '-0.9946667858973005', '-0.9003412098211386', '-0.6198251487946397', '2.801322484367719', '-1.6016313623873861', '0.09704922938419029', '-0.4369155534469521', '0.829507833175604', '1.5742697729769517', '0.33409350317558906', '-1.2226886132814136', '-0.10554570357939212', '-0.18346683164230812', '-1.385092648612545', '0.1438019062219401', '-0.27697218531780776', '-1.2243290580827384', '-0.3548933133807244', '0.10443123099015107', '-0.3860617646058911', '0.9246536316524279', '-0.17444438523502312', '0.399711295228571', '-0.5886566975694735', '0.6785869114537446', '0.2217230342848567', '1.141192345427269', '-0.40410665742046115', '-1.4613733318741366', '-1.683653602453614', '1.43319152006304', '1.1707203518511111', '1.6243033394173505', '-0.2613879597052247', '1.3396861663875403', '-0.5730724719568898', '-1.7279456120893768', '-2.1938319356655502', '1.301955935957076', '-0.8691727585959725', '-0.6501733776191443', '1.5923146657915217', '0.5818006681755965', '0.978788310096139', '-0.4648031150694698', '-0.3237248621555576', '-0.24006217728800566', '0.6949913594669904', '-0.8634312017913366', '1.105102559798129', '-1.2431941732979706', '0.990271423705411', '-0.09242214516879538', '2.530649092149168', '-1.0250150147218051', '0.408733741635856', '-1.5835864695728157', '1.069012774168989', '-0.5008929006986098', '0.33081261357293945', '0.2930823831424752', '0.1593861318345238', '1.110023894202103', '0.9230131868511036', '-0.744498953695306', '-0.8380043073708058', '-1.5655415767582457', '1.4479555232749617', '-1.134104594009888', '-1.4769575574867198', '-0.28435418692376857', '1.3757759520166803', '-0.45332000146019785', '-0.05879302674164229', '-0.4639828926688071', '-2.412011094241716', '-0.4106684366257592', '1.6087191138047674', '1.014878095725279', '-1.3366995269734703', '0.7360024795001043', '0.25207126310936134', '-1.096374363579423', '-1.414620655036387', '-0.7814089617251088', '1.4003826240365496', '-1.439227327056255', '-0.44839866705622405', '-0.20725328126151468', '0.8524740603941477', '-2.1782477100529665', '-0.806835856145639', '-1.9289001002516344', '-0.1908488332482689', '0.07654366936763336', '-1.404777986228439', '0.220902811884194', '-1.0094307891092213', '0.03881343893716853', '0.8590358395994458', '1.0009343149140202', '-0.3384888653674786', '1.8424824979935173', '0.2840599367351902', '-0.4467582222548992', '1.4495959680762858', '-0.43281444144364095', '1.5775506625796012', '2.0795267717849155', '0.23566681509611556', '2.602828663407448', '-2.110169250797998', '-1.7730578441258018', '0.845092058788187', '1.5808315521822498', '0.23730725989743978', '-0.12523104119528694', '0.5022390953113557', '1.0476869917517695', '1.1567765710398534', '-1.2079246100694927', '-1.2046437204668436', '2.630716225029966', '0.6269129002120215', '-1.0766890259635282', '-0.7321956176853721', '-0.0858603659634973', '-1.5237102343244695', '0.6424971258246046', '3.685522232281654', '-0.6042409231820566', '-2.2094161612781336', '-0.5353422415264256', '0.19055458305968995', '-1.2587783989105543', '-1.9600685514768013', '-1.2948681845396943', '-0.717431614473451', '-1.1259023700032649', '-1.457272219870825', '-0.6272071504006005', '-0.5574882463443068', '1.1871247998643568', '2.1697512358577655', '2.1877961286723355', '1.087057666983559', '2.3026272647650545', '0.6293735674140084', '0.8262269435729555', '-0.5681511375529166', '-1.2743626245231374', '-1.8641025305993146', '0.13724012701664204', '2.2238859143014755', '-0.7896111857317312', '1.032922988539849', '1.6136404482087412', '-1.2407335060959837', '-1.7418893929006352', '0.08802678297690529', '-1.663968264837719', '-1.1964414964602208', '0.548991772149105', '1.2191134734901856', '0.8705189532087179', '-0.699386721658881', '-0.24826440129462854', '0.2750374903279052', '3.0178611981425605', '2.4994806409240007', '-2.022405453927134', '0.166768133440484', '1.9253249604604064', '1.2346976991027687', '1.3126188271656853', '2.332155271188897', '0.960743417281569', '0.9082491836391833', '-0.14163548920853214', '0.5334075465365219', '0.5957444489868554', '-0.37047753899330743', '0.6113286745994384', '-1.995338114705279', '-1.679552490450302', '-1.1587112660297558', '-0.9290489938443185', '-1.1185203683973048', '-1.7328669464933502', '0.14872324062591397', '0.3619810647981056', '0.7827551563378546', '-1.8977316490264682', '1.982740528506766', '-1.6180358104006314', '-1.078329470764853', '1.8088533795663628', '-1.453991330268176', '-2.131495033215217', '1.0632712173643537', '-2.2717530637284664', '-0.04977058033435728', '-1.132464149208563', '1.5020902017186717', '1.3938208448312503', '-0.7756674049204728', '-1.4933620054999655', '-0.30239907973833857', '0.6457780154272542', '-1.6508447064271223', '-0.8437458641754417', '0.5096210969173165', '-0.6354093744072233', '2.007347200526634', '-0.5517466895396708', '0.465329087281553', '1.9351676292683542', '-1.7000580504668592', '-1.6918558264602361', '-0.753521400102591', '0.5965646713875169', '1.039484767745147', '2.3141103783743264', '-0.26630929410919857', '0.12821768060935704', '1.4684610832915186', '0.9074289612385206', '-0.830622305764845', '0.13067834781134396', '0.21926236708286978', '-0.3056799693409876', '-1.2120257220728043', '0.25699259751333403', '2.765232698738579', '-1.9789336666920332', '0.9410580796656736', '1.4840453089041017', '0.23894770469876403', '0.4243179672484391', '0.07162233496366009', '0.8762605100133544', '1.4118657376458215', '-0.9126445458310727', '0.28488015913585174', '-0.3565337581820492', '-1.648384039225136', '1.390539955228602', '-1.8353947465761347', '-0.6821620512449732', '0.9385974124636867', '0.7515867051126873', '-1.114419256393993', '2.7783562571491753']

## RESULTS

In [26]:
original_stdout = sys.stdout
with open('%s/res.txt' %(mlModel), 'w') as f:
    sys.stdout = f
    print('\n--------------------- Model errors and report:-------------------------')
    print('Mean Absolute Error:', np.mean(mae))
    print('Mean Squared Error:', np.mean(mse))
    print('Root Mean Squared Error:', np.mean(rmse))
    print('Mean Average Percentage Error:', np.mean(mape))
    print('\nFeature Scores: \n')
    print(coefs)
            
    print('\nBest Parameters used: ', mod_grid.best_params_)

        
sys.stdout = original_stdout
shutil.rmtree(os.getcwd() + "\__pycache__")
print('Results saved')

Results saved
