In [1]:
from autogluon.tabular import TabularDataset, TabularPredictor

In [2]:
from sklearn.model_selection import train_test_split

# Target column of data.csv; every later cell refers to it through `label`.
label = 'cnt'

data = TabularDataset('data.csv')
# Fixed seed so the split (sklearn default test_size=0.25) is reproducible.
train_data, test_data = train_test_split(data, random_state=1)

# Only a cell's last expression is rendered automatically, so the preview needs an
# explicit display() (IPython builtin) to appear alongside the describe() output.
display(train_data.head())

train_data[label].describe()

count    13060.000000
mean      1139.233844
std       1083.850522
min          0.000000
25%        257.000000
50%        837.000000
75%       1664.250000
max       7860.000000
Name: cnt, dtype: float64

In [15]:
# Save into the same directory the load cell reads from ("autogloun" spelling kept so
# existing saved models still match). Without `path`, AutoGluon writes to a new
# timestamped folder and the load cell would not see the model trained here.
predictor = TabularPredictor(label=label, path="./autogloun_model/").fit(train_data)

count    13931.000000
mean      1140.577274
std       1081.964648
min          0.000000
25%        259.000000
50%        842.000000
75%       1663.000000
max       7860.000000
Name: cnt, dtype: float64

In [3]:
# Reload the persisted predictor (directory name spelled exactly as it was saved).
predictor = TabularPredictor.load(path="./autogloun_model/")

In [4]:
# Predict on the held-out split with the target column removed; preview a few rows.
y_pred = predictor.predict(test_data.drop(label, axis=1))
y_pred.head()

14999     727.184204
5504     1056.075317
10259    1249.485840
15150    1785.752197
345       201.123642
Name: cnt, dtype: float32

In [5]:
# Score the held-out split; the returned metrics dict is the cell's output.
predictor.evaluate(data=test_data, silent=True)

{'root_mean_squared_error': -831.1946741681302,
 'mean_squared_error': -690884.5863654641,
 'mean_absolute_error': -585.3470244839151,
 'r2': 0.4172016926865758,
 'pearsonr': 0.6484387217849894,
 'median_absolute_error': -422.31573486328125}

In [6]:
# Per-model test scores on the held-out split, alongside the validation scores.
predictor.leaderboard(data=test_data)

Unnamed: 0,model,score_test,score_val,eval_metric,pred_time_test,pred_time_val,fit_time,pred_time_test_marginal,pred_time_val_marginal,fit_time_marginal,stack_level,can_infer,fit_order
0,WeightedEnsemble_L2,-831.194674,-851.193578,root_mean_squared_error,0.848012,0.24898,41.171174,0.006,0.003001,0.307,2,True,12
1,ExtraTreesMSE,-834.894921,-889.042982,root_mean_squared_error,0.582,0.202999,4.401082,0.582,0.202999,4.401082,1,True,7
2,RandomForestMSE,-849.867244,-903.845828,root_mean_squared_error,0.627,0.186997,6.382169,0.627,0.186997,6.382169,1,True,5
3,LightGBMLarge,-854.687917,-864.160039,root_mean_squared_error,0.034,0.010999,1.474072,0.034,0.010999,1.474072,1,True,11
4,XGBoost,-862.286051,-862.676594,root_mean_squared_error,0.029998,0.005999,0.842525,0.029998,0.005999,0.842525,1,True,9
5,NeuralNetFastAI,-865.317246,-862.00727,root_mean_squared_error,0.138012,0.023,15.598369,0.138012,0.023,15.598369,1,True,8
6,LightGBMXT,-867.184388,-862.525423,root_mean_squared_error,0.390996,0.013001,1.533518,0.390996,0.013001,1.533518,1,True,3
7,CatBoost,-867.430746,-862.014261,root_mean_squared_error,0.052001,0.002981,1.606015,0.052001,0.002981,1.606015,1,True,6
8,LightGBM,-867.474643,-864.203128,root_mean_squared_error,0.023,0.006999,0.965998,0.023,0.006999,0.965998,1,True,4
9,NeuralNetTorch,-900.914608,-888.359756,root_mean_squared_error,0.04,0.011,18.416183,0.04,0.011,18.416183,1,True,10


In [27]:
import shap
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# Kernel SHAP explanation of one held-out row for the AutoGluon predictor.
# Column names and order come from the predictor itself instead of being
# hard-coded, so this works for any number/naming of features.
FEATURE_NAMES = predictor.features()  # original (pre-processing) input columns
RANDOM_STATE = 1                      # same seed as the train/test split
N_BACKGROUND = 100                    # background rows for KernelExplainer


def wrapped_model(x):
    """Predict with `predictor` on a 2-D array whose columns follow FEATURE_NAMES.

    KernelExplainer calls the model with plain arrays, so the column names that
    AutoGluon matches its inputs on are restored here before predicting.

    Returns a 1-D numpy array with one prediction per row of ``x``.
    """
    batch = pd.DataFrame(x, columns=FEATURE_NAMES)
    return predictor.predict(batch).to_numpy()


# Separate name so the train/test split used by the earlier cells is not clobbered.
holdout_data = TabularDataset('Test.csv')
# Select features by name (not "drop the label"), which also fixes column order.
background = (
    pd.read_csv('./Train.csv')[FEATURE_NAMES]
    .sample(n=N_BACKGROUND, random_state=RANDOM_STATE)
)
to_be_explained = holdout_data[FEATURE_NAMES].iloc[0].to_numpy()

shap_explainer = shap.KernelExplainer(wrapped_model, background)
shap_values = shap_explainer.shap_values(to_be_explained)

shap_relevance = pd.Series(np.abs(np.ravel(shap_values)), index=FEATURE_NAMES)
shap_span = shap_relevance.max() - shap_relevance.min()
# Min-max scale to [0, 1]; an all-equal vector would otherwise divide by zero.
shap_norm_relevance = (
    (shap_relevance - shap_relevance.min()) / shap_span
    if shap_span > 0
    else pd.Series(0.0, index=shap_relevance.index)
)

display(shap_relevance.sort_values(ascending=False))

fig, ax = plt.subplots(figsize=(8, 0.35 * len(FEATURE_NAMES) + 1))
shap_norm_relevance.plot.barh(ax=ax)
ax.invert_yaxis()  # keep the first feature at the top
ax.set(
    title='Kernel SHAP relevance (min-max normalised |SHAP value|)',
    xlabel='normalised relevance',
    ylabel='feature',
);

{'root_mean_squared_error': -864.7775674115269,
 'mean_squared_error': -747840.241098198,
 'mean_absolute_error': -608.0744023789903,
 'r2': 0.3791673724300947,
 'pearsonr': 0.6166151071994901,
 'median_absolute_error': -434.9698486328125}

In [28]:
import lime.lime_tabular
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# LIME explanation of the same held-out row as the SHAP cell, for comparison.
# The predictor is a regressor (RMSE/R² metrics, float predictions), so LIME must
# run in regression mode on `predictor.predict`; `predict_proba` is only meaningful
# for classification predictors.
FEATURE_NAMES = predictor.features()  # original (pre-processing) input columns
RANDOM_STATE = 1                      # same seed as the train/test split
N_BACKGROUND = 100                    # rows LIME uses for feature statistics


def wrapped_net(x):
    """Predict with `predictor` on a 2-D array whose columns follow FEATURE_NAMES.

    LIME passes perturbed samples as numpy arrays, so the column names that
    AutoGluon matches its inputs on are restored before predicting.

    Returns a 1-D numpy array of predictions, as LIME's regression mode expects.
    """
    batch = pd.DataFrame(x, columns=FEATURE_NAMES)
    return predictor.predict(batch).to_numpy()


background = (
    pd.read_csv('./Train.csv')[FEATURE_NAMES]
    .sample(n=N_BACKGROUND, random_state=RANDOM_STATE)
    .to_numpy()
)
# NOTE(review): every column is treated as continuous here; pass
# `categorical_features=` if any feature is categorical — TODO confirm.
lime_explainer = lime.lime_tabular.LimeTabularExplainer(
    background,
    feature_names=[str(name) for name in FEATURE_NAMES],
    verbose=True,
    mode='regression',
    random_state=RANDOM_STATE,
)

# First row of Test.csv, the same instance the SHAP cell explains (not a training row).
to_be_explained = pd.read_csv('Test.csv')[FEATURE_NAMES].iloc[0].to_numpy()

lime_explanation = lime_explainer.explain_instance(
    to_be_explained, wrapped_net, num_features=len(FEATURE_NAMES)
)

# local_exp[1] holds (feature_index, weight) pairs; map them back onto feature
# positions (missing features count as 0) so they line up with FEATURE_NAMES.
lime_weights = dict(lime_explanation.local_exp[1])
lime_relevance = pd.Series(
    [abs(lime_weights.get(i, 0.0)) for i in range(len(FEATURE_NAMES))],
    index=FEATURE_NAMES,
)
lime_span = lime_relevance.max() - lime_relevance.min()
# Min-max scale to [0, 1]; an all-equal vector would otherwise divide by zero.
lime_norm_relevance = (
    (lime_relevance - lime_relevance.min()) / lime_span
    if lime_span > 0
    else pd.Series(0.0, index=lime_relevance.index)
)

display(lime_relevance.sort_values(ascending=False))

fig, ax = plt.subplots(figsize=(8, 0.35 * len(FEATURE_NAMES) + 1))
lime_norm_relevance.plot.barh(ax=ax)
ax.invert_yaxis()  # keep the first feature at the top
ax.set(
    title='LIME relevance (min-max normalised |weight|)',
    xlabel='normalised relevance',
    ylabel='feature',
);

Unnamed: 0,model,score_test,score_val,eval_metric,pred_time_test,pred_time_val,fit_time,pred_time_test_marginal,pred_time_val_marginal,fit_time_marginal,stack_level,can_infer,fit_order
0,WeightedEnsemble_L2,-864.777567,-851.193578,root_mean_squared_error,1.10807,0.24898,41.171174,0.005999,0.003001,0.307,2,True,12
1,NeuralNetFastAI,-871.452567,-862.00727,root_mean_squared_error,0.148617,0.023,15.598369,0.148617,0.023,15.598369,1,True,8
2,CatBoost,-874.392598,-862.014261,root_mean_squared_error,0.059017,0.002981,1.606015,0.059017,0.002981,1.606015,1,True,6
3,XGBoost,-874.580263,-862.676594,root_mean_squared_error,0.030994,0.005999,0.842525,0.030994,0.005999,0.842525,1,True,9
4,LightGBMXT,-874.625472,-862.525423,root_mean_squared_error,0.054903,0.013001,1.533518,0.054903,0.013001,1.533518,1,True,3
5,LightGBM,-877.855435,-864.203128,root_mean_squared_error,0.014,0.006999,0.965998,0.014,0.006999,0.965998,1,True,4
6,LightGBMLarge,-880.022788,-864.160039,root_mean_squared_error,0.023,0.010999,1.474072,0.023,0.010999,1.474072,1,True,11
7,NeuralNetTorch,-907.496117,-888.359756,root_mean_squared_error,0.037001,0.011,18.416183,0.037001,0.011,18.416183,1,True,10
8,ExtraTreesMSE,-911.525525,-889.042982,root_mean_squared_error,0.826442,0.202999,4.401082,0.826442,0.202999,4.401082,1,True,7
9,RandomForestMSE,-928.204409,-903.845828,root_mean_squared_error,0.835002,0.186997,6.382169,0.835002,0.186997,6.382169,1,True,5
