# Simplified decision tree generation for rule-based inference, with the CART algorithm

Decision-tree-based methods are a very powerful machine learning technique, particularly for tabular data.  
They build decision trees from your data in an iterative fashion.  
For tabular data (not images, time series, or language, but plain Excel-like tables in which every row is independent of the others and contains several predictor variables), decision-tree-based algorithms such as XGBoost and LightGBM in fact perform at a state-of-the-art level.  
This scikit-learn page is a good introduction to tree-based models: https://scikit-learn.org/stable/modules/tree.html

However, the best-performing decision tree algorithms generate very complex trees that can be very deep, and they often combine many trees, each one fitted on the residual errors of the previous ones. It is impossible for a human to understand such models as a whole.  

In some cases, it is useful to have human-interpretable models, in order to:
- Check whether you can trust the algorithm
- Be able to justify your choices to regulators (typically financial regulators in credit risk problems)
- Discover business rules from the data and use them to better approach your customers  
- Reduce maintenance effort by using simpler (but less accurate), close-to-business rule-based models that are more robust to new kinds of data than more complex (though more accurate) machine learning models

There are many techniques for human interpretability / explainability of machine learning models, including LIME and SHAP (Shapley values). Those techniques are interesting because they let you keep your model's complexity (and therefore its level of performance) and still explain what is happening.
But even with LIME or SHAP it can sometimes be difficult to explain some decisions, and you still have the model maintenance / retraining issue.

In this notebook, we'll see a much simpler technique, based on DecisionTreeClassifier from scikit-learn (CART algorithm) plus some additional code we wrote, in order to generate simple, human-understandable rules from data. To generate those rules, though, we'll have to sacrifice performance: more precisely, our simple rule-based model will only be able to predict a subsample of instances, not all of them.
We want those rules to be simple (no more than 2 criteria) and precise (as in a good precision score). Exhaustiveness and recall will not be our priority at all.




Thanks to this original notebook from which I took data preparation, description and DecisionTree calling code : https://www.kaggle.com/mariosfish/default-of-credit-card-clients-lr-dt-nn

I just added my simple tree-based rules generation code (see § Simplified decision trees)


## Attribute Information

Below is a description of the attributes used in our model, for a better understanding of the data:

- `LIMIT_BAL`: Amount of the given credit (NT dollar). It includes both the individual consumer credit and his/her family (supplementary) credit.
- `SEX`: Gender (1 = male; 2 = female).
- `EDUCATION`: Education (1 = graduate school; 2 = university; 3 = high school; 4 = others).
- `MARRIAGE`: Marital status (1 = married; 2 = single; 3 = others).
- `AGE`: Age (year).
- `PAY_1`: the repayment status in September, 2005. The measurement scale for the repayment status is: -1 = pay duly; 1 = payment delay for one month; 2 = payment delay for two months; . . .; 8 = payment delay for eight months; 9 = payment delay for nine months and above.
- `PAY_2`: the repayment status in August, 2005. The measurement scale for the repayment status is: -1 = pay duly; 1 = payment delay for one month; 2 = payment delay for two months; . . .; 8 = payment delay for eight months; 9 = payment delay for nine months and above.
- `PAY_3`: the repayment status in July, 2005. The measurement scale for the repayment status is: -1 = pay duly; 1 = payment delay for one month; 2 = payment delay for two months; . . .; 8 = payment delay for eight months; 9 = payment delay for nine months and above.
- `PAY_4`: the repayment status in June, 2005. The measurement scale for the repayment status is: -1 = pay duly; 1 = payment delay for one month; 2 = payment delay for two months; . . .; 8 = payment delay for eight months; 9 = payment delay for nine months and above.
- `PAY_5`: the repayment status in May, 2005. The measurement scale for the repayment status is: -1 = pay duly; 1 = payment delay for one month; 2 = payment delay for two months; . . .; 8 = payment delay for eight months; 9 = payment delay for nine months and above.
- `PAY_6`: the repayment status in April, 2005. The measurement scale for the repayment status is: -1 = pay duly; 1 = payment delay for one month; 2 = payment delay for two months; . . .; 8 = payment delay for eight months; 9 = payment delay for nine months and above.
- `BILL_AMT1`: Amount of bill statement (NT dollar). Amount of bill statement in September, 2005.
- `BILL_AMT2`: Amount of bill statement (NT dollar). Amount of bill statement in August, 2005.
- `BILL_AMT3`: Amount of bill statement (NT dollar). Amount of bill statement in July, 2005.
- `BILL_AMT4`: Amount of bill statement (NT dollar). Amount of bill statement in June, 2005.
- `BILL_AMT5`: Amount of bill statement (NT dollar). Amount of bill statement in May, 2005.
- `BILL_AMT6`: Amount of bill statement (NT dollar). Amount of bill statement in April, 2005.
- `PAY_AMT1`: Amount of previous payment (NT dollar). Amount paid in September, 2005.
- `PAY_AMT2`: Amount of previous payment (NT dollar). Amount paid in August, 2005.
- `PAY_AMT3`: Amount of previous payment (NT dollar). Amount paid in July, 2005.
- `PAY_AMT4`: Amount of previous payment (NT dollar). Amount paid in June, 2005.
- `PAY_AMT5`: Amount of previous payment (NT dollar). Amount paid in May, 2005.
- `PAY_AMT6`: Amount of previous payment (NT dollar). Amount paid in April, 2005.
- `dpnm`: Default payment next month.(Yes = 1, No = 0)

## Models

We will create 2 models :
- Decision tree  (classic one)
- Simplified rules (that's my custom code that generates those rules) 

## Import libraries/packages 

In [None]:
### General libraries ###
import pandas as pd
from pandas.api.types import CategoricalDtype
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import seaborn as sns
import numpy as np
import warnings
warnings.filterwarnings('ignore')
import graphviz 
from graphviz import Source
from IPython.display import SVG

##################################

### ML Models ###
from sklearn.linear_model import LogisticRegression
from sklearn import tree
from sklearn.tree import export_text
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.preprocessing import StandardScaler

##################################

### Metrics ###
from sklearn import metrics
from sklearn.metrics import f1_score,confusion_matrix, mean_squared_error, mean_absolute_error, classification_report, roc_auc_score, roc_curve, precision_score, recall_score

## Part 1: Load and clean the data

In this section we load the data from the CSV file and check for any "impurities", such as null values or duplicate rows. If any appear, we remove them from the data set. We also plot the correlations of the class column with all the other columns.

In [None]:
# Load the data.
data=pd.read_csv('../input/default of credit card clients.csv')

# Information
data.info()

Since the `ID` column is for indexing purposes only, we remove it from the data set.

In [None]:
# Drop "ID" column.
data=data.drop(['ID'], axis=1)

Now we check for duplicate rows. If there are any, we remove them from the data set, since they provide only redundant information.

In [None]:
# Check for duplicate rows.
print(f"There are {data.duplicated().sum()} duplicate rows in the data set.")

# Remove duplicate rows.
data=data.drop_duplicates()
print("The duplicate rows were removed.")

We also check for null values.

In [None]:
# Check for null values: count every null cell across all columns.
print(f"There are {data.isna().sum().sum()} cells with null values in the data set.")

Below is the plot of the correlation matrix for the data set.

In [None]:
plt.figure(figsize=(20,20))
sns.heatmap(data.corr(), annot=True, cmap='rainbow', linewidths=0.5, fmt='.2f');

## Part 2: Pre-processing

In this part we prepare the data for our models. This means choosing the columns that will be our independent variables and the column that holds the class we want to predict. Once that is done, we split the data into train and test sets. No standardization is applied, since tree-based models do not depend on feature scaling.

In [None]:
# Distinguish attribute columns and class column.
X=data[data.columns[:-1]]
y=data['dpnm']

In [None]:
# Split to train and test sets. 
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=25)

## Part 3: Modeling

In this section we build and try 2 models:

 - Decision tree (classical)
 - Simplified decision trees : that's my custom code for generating simple tree-based rules


## Decision tree

In [None]:
# Initialize a decision tree estimator.
tr = tree.DecisionTreeClassifier(max_depth=3, criterion='gini', random_state=25)

# Train the estimator.
tr.fit(X_train, y_train)

In [None]:
# Plot the tree.
dot_data = tree.export_graphviz(tr, out_file=None, feature_names=X.columns, filled=True, rounded=True, special_characters=True)  
graph = graphviz.Source(dot_data)  
graph 

In [None]:
# Make predictions.
tr_pred=tr.predict(X_test)

# CV score
#tr_cv=cross_val_score(tr, X_train, y_train, cv=10).mean()

## Metrics for Decision tree

In [None]:
# Accuracy: 1 is perfect prediction.
print('Accuracy: %.3f' % tr.score(X_test, y_test))

# Precision
print('Precision: %.3f' % precision_score(y_test, tr_pred))

# Recall
print('Recall: %.3f' % recall_score(y_test, tr_pred))

# f1 score: best value at 1 (perfect precision and recall) and worst at 0.
print('F1 score: %.3f' % f1_score(y_test, tr_pred))

=> Note the precision of 0.671 with this technique: when the model predicts a payment default, it is right 67% of the time.  
=> That means it is wrong 33% of the time: is that acceptable? You need to define an acceptable threshold with the business.
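
As a reminder of what this number measures, precision is TP / (TP + FP): among the instances predicted as positive, the share that really are positive. The sketch below recomputes it by hand from a confusion matrix; the labels are made up for illustration and are not taken from the credit card data.

```python
import numpy as np
from sklearn.metrics import confusion_matrix, precision_score

# Made-up labels, for illustration only.
y_true_demo = np.array([1, 0, 1, 1, 0, 0, 1, 0])
y_pred_demo = np.array([1, 0, 0, 1, 1, 0, 1, 0])

# For binary labels {0, 1}, ravel() returns tn, fp, fn, tp in that order.
tn, fp, fn, tp = confusion_matrix(y_true_demo, y_pred_demo).ravel()
precision_manual = tp / (tp + fp)

# Both values should agree.
print(precision_manual, precision_score(y_true_demo, y_pred_demo))
```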

In [None]:
# Plot confusion matrix for Decision tree.
tr_matrix = confusion_matrix(y_test,tr_pred)
sns.set(font_scale=1.3)
plt.subplots(figsize=(8, 8))
sns.heatmap(tr_matrix, annot=True, cbar=False, cmap='twilight', linewidths=0.5, fmt="d")
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.title('Confusion Matrix for Decision tree');

In [None]:
# Predict probabilities for the test data.
tr_probs = tr.predict_proba(X_test)

# Keep Probabilities of the positive class only.
tr_probs = tr_probs[:, 1]

# Compute the AUC Score.
auc_tr = roc_auc_score(y_test, tr_probs)
print('AUC: %.2f' % auc_tr)

# Simplified decision trees

## Define parameters

Now we'll use a random forest to generate multiple simple trees, each of which gives us a set of rules to predict payment default.  

But we'll constrain the random forest with particular hyperparameters:

- A very short maximum depth, to keep the trees simple
- No bootstrap, so that each tree is trained on the whole training set

Then we'll have some code to explore all the resulting trees and keep only those that satisfy the following criteria:
- A maximum Gini impurity (https://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity) for at least one leaf of the tree: the lower it is, the lower the risk of an incorrect classification. You can try several values of MAX_GINI_TODISPLAY and look for the lowest one that still yields results.
- A minimum percentage of samples falling in that leaf: the idea is that a rule must apply to a minimum share of instances in order to be useful
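
To get a feel for the Gini threshold, recall that for a two-class leaf in which a fraction p of the samples belongs to the majority class, the Gini impurity is 1 - p² - (1 - p)² = 2p(1 - p). The small sketch below converts between the two; the function names are ours and purely illustrative.

```python
import math

def gini_binary(p_majority):
    """Gini impurity of a two-class node whose majority class has share p_majority."""
    return 2 * p_majority * (1 - p_majority)

def majority_share_from_gini(gini):
    """Inverse of gini_binary for 0 <= gini <= 0.5 (majority share >= 0.5)."""
    return (1 + math.sqrt(1 - 2 * gini)) / 2

for g in (0.1, 0.2, 0.3):
    share = majority_share_from_gini(g)
    print(f"gini={g:.1f} -> majority class share of about {share:.3f}")
```

A leaf with a Gini impurity below 0.2 therefore has roughly 89% or more of its training samples in a single class.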

In [None]:
!pip install pydotplus

In [None]:
from io import StringIO
import pydotplus
from IPython.display import Image
from sklearn.ensemble import RandomForestClassifier

In [None]:
MAX_DEPTH = 2
MAX_GINI_TODISPLAY = 0.2
MIN_PERCENT_SAMPLES_TODISPLAY = 0.02
min_nb_samples_todisplay = int(MIN_PERCENT_SAMPLES_TODISPLAY * X_train.shape[0])

## Train the random forest

In [None]:
model_simple = RandomForestClassifier(random_state=42, bootstrap=False, max_depth=MAX_DEPTH, max_features=3, n_estimators=30000, min_impurity_decrease=0)

model_simple.fit(X_train, y_train)
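
Because `bootstrap=False`, every tree sees the full training set, so the variety between trees comes only from the random subset of features (`max_features`) considered at each split. The sketch below illustrates this on synthetic data by counting which feature each tree uses at its root; the data set and the names `X_demo`, `forest_demo` and `root_features` are assumptions made for the illustration.

```python
from collections import Counter

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Synthetic data standing in for the credit card table (illustrative only).
X_demo, y_demo = make_classification(
    n_samples=500, n_features=10, n_informative=5, random_state=0
)

forest_demo = RandomForestClassifier(
    n_estimators=50, max_depth=2, max_features=3, bootstrap=False, random_state=0
).fit(X_demo, y_demo)

# Index of the feature used at the root (node 0) of each tree.
root_features = Counter(int(est.tree_.feature[0]) for est in forest_demo.estimators_)
print(root_features)
```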

## Display only trees that respect simplicity constraints we defined in our parameters

Note that this code is probably not optimal and could be improved to print directly only the rule paths that lead to the low-Gini leaf in each tree.

In [None]:
for tree_todisplay in model_simple.estimators_:
    # Collect the Gini impurity and sample count of every leaf.
    # In each node record, node[0] and node[1] are the indices of the left and right children
    # (-1 for a leaf), node[4] is the Gini impurity and node[5] the number of samples.
    
    end_node_gini_indices_and_nb_samples = [ [node[4], node[5]] for node in tree_todisplay.tree_.__getstate__()['nodes'] if ((node[0] == -1) and (node[1] == -1))] 
    
    display_current_tree = 0
    for (end_node_gini_indice, end_node_gini_nb_samples) in end_node_gini_indices_and_nb_samples:        
        if ((end_node_gini_indice < MAX_GINI_TODISPLAY) and (end_node_gini_nb_samples > min_nb_samples_todisplay)):            
            display_current_tree = 1
        
    if (display_current_tree == 1):
        dot_data = StringIO()

        tree.export_graphviz(
            tree_todisplay,
            out_file=dot_data,
            feature_names=X.columns,

            class_names = sorted(np.unique(y_train.astype(str))),

            max_depth=3,
            filled=True,
        )
        g = pydotplus.graph_from_dot_data(
            dot_data.getvalue()
        )    
        g.set_size('"5,5!"')

        display(Image(g.create_png()))
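
One possible way to follow up on the note above is to walk each fitted tree and print only the paths that end in a qualifying leaf. The sketch below does this with the public arrays of `tree_` (`children_left`, `children_right`, `feature`, `threshold`, `impurity`, `n_node_samples`). The function name `low_gini_rules` and the synthetic demo data are assumptions for illustration; this is not the code used to produce the trees above.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

def low_gini_rules(fitted_tree, feature_names, max_gini, min_samples):
    """Return one text rule per leaf that meets the impurity and size limits."""
    t = fitted_tree.tree_
    rules = []

    def walk(node, conditions):
        if t.children_left[node] == -1:  # -1 marks a leaf
            if t.impurity[node] < max_gini and t.n_node_samples[node] > min_samples:
                predicted = fitted_tree.classes_[np.argmax(t.value[node][0])]
                condition_text = " AND ".join(conditions) if conditions else "TRUE"
                rules.append(
                    f"IF {condition_text} THEN class {predicted} "
                    f"(gini={t.impurity[node]:.3f}, samples={t.n_node_samples[node]})"
                )
            return
        name = feature_names[t.feature[node]]
        thr = t.threshold[node]
        walk(t.children_left[node], conditions + [f"{name} <= {thr:.2f}"])
        walk(t.children_right[node], conditions + [f"{name} > {thr:.2f}"])

    walk(0, [])
    return rules

# Demo on synthetic data (illustrative only).
X_demo, y_demo = make_classification(
    n_samples=400, n_features=4, n_informative=3, n_redundant=1, random_state=0
)
demo_tree = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X_demo, y_demo)
names = [f"f{i}" for i in range(X_demo.shape[1])]

for rule in low_gini_rules(demo_tree, names, max_gini=0.3, min_samples=8):
    print(rule)
```

In the loop over `model_simple.estimators_`, the same function could be called with `list(X.columns)`, `MAX_GINI_TODISPLAY` and `min_nb_samples_todisplay`.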
        

## Visually inspect resulting trees and extract interesting rules

Now we can visually inspect the trees generated above: each tree has at least one leaf whose Gini impurity is below the maximum we defined.  
Our rules are the sets of conditions that lead from the top of a tree to its low-Gini leaf.  

For example, on the first tree you have this rule :
- IF  BILL_AMT1 <= 54416 AND PAY_AMT2 > 4553 :  THEN class 0 (no payment default) is most probable one with gini impurity of 0.184

And on next trees, you have :
- IF  BILL_AMT1 <= 54416 AND PAY_AMT1 > 4553 :  THEN class 0 is most probable one with gini impurity of 0.186
- IF  BILL_AMT2 <= 4959.5 AND PAY_AMT1 > 4053.5 :  THEN class 0 is most probable one with gini impurity of 0.189
- IF  BILL_AMT3 <= 49144.5 AND PAY_AMT1 > 4522.0 : THEN class 0 is most probable one with gini impurity of 0.184
- IF  BILL_AMT1 <= 54416 AND PAY_AMT5 > 4317.5 : THEN class 0 is most probable one with gini impurity of 0.199

You will see that many trees are redundant with one another, with thresholds changing only slightly (for example, for the first rule, there is also a tree with PAY_AMT2 > 4692.5 instead of 4553: and indeed, why 4553 rather than 4692.5?).  
This is where decision trees sometimes provide overly complex solutions, when fewer trees would give the same performance.  
You can talk to the business, come up with a figure that makes sense, and then check how y is distributed in the test set.  



In [None]:
df_test = pd.concat([pd.DataFrame(X_test, columns=X.columns).reset_index(drop=True), y_test.reset_index(drop=True)], axis=1)

Distribution of y on test set for first rule

In [None]:
df_criteria = df_test[(df_test['BILL_AMT1'] <= 54416) & (df_test['PAY_AMT2'] > 4553)]
df_criteria['dpnm'].value_counts()

Precision of our prediction in test set, for first rule

In [None]:
df_criteria['dpnm'].value_counts()[0] / df_criteria.shape[0]

Precision of our prediction in test set if we change rule to the following, more human readable rule that could be determined with business people :  
IF BILL_AMT1 <= 55000 AND PAY_AMT2 > 4500 : THEN class 0 is most probable one with gini impurity of 0.186

In [None]:
df_criteria = df_test[(df_test['BILL_AMT1'] <= 55000) & (df_test['PAY_AMT2'] > 4500)]
df_criteria['dpnm'].value_counts()[0] / df_criteria.shape[0]

=> We see that it's mostly equivalent (even slightly better)

Now let's build a simple model with all the rules we have determined so far :

IF  (BILL_AMT1 <= 55000 AND PAY_AMT2 > 4500 )  
   OR (BILL_AMT1 <= 54416 AND PAY_AMT1 > 4553)  
   OR (BILL_AMT2 <= 4959.5 AND PAY_AMT1 > 4053.5)  
   OR (BILL_AMT3 <= 49144.5 AND PAY_AMT1 > 4522.0)  
   OR (BILL_AMT1 <= 54416 AND PAY_AMT5 > 4317.5)  
   
THEN we predict class 0  (= no payment default)  

In [None]:
df_criteria = df_test[  ((df_test['BILL_AMT1'] <= 55000) & (df_test['PAY_AMT2'] > 4500))
                       | ((df_test['BILL_AMT1'] <= 54416) & (df_test['PAY_AMT1'] > 4553))
                       | ((df_test['BILL_AMT2'] <= 4959.5) & (df_test['PAY_AMT1'] > 4053.5))
                       | ((df_test['BILL_AMT3'] <= 49144.5) & (df_test['PAY_AMT1'] > 4522.0))
                       | ((df_test['BILL_AMT1'] <= 54416) & (df_test['PAY_AMT5'] > 4317.5))                                           
                    ]

df_criteria['dpnm'].value_counts()[0] / df_criteria.shape[0]

Share of the test set that the rules cover and classify correctly:

In [None]:
df_criteria['dpnm'].value_counts()[0] / df_test.shape[0]

Not bad! With a set of 5 simple rules, we correctly predict about **20%** of the test set, with **88%** precision on the instances the rules cover.  
So we can make about 20% of decisions without any machine learning deployment.
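
Since the model is just a handful of thresholds, it can be shipped as a plain function. Here is a minimal sketch using the five rules listed above; the names `RULES` and `predict_no_default` and the three sample customers are illustrative assumptions.

```python
import pandas as pd

# (bill column, max bill amount, payment column, min payment amount)
RULES = [
    ("BILL_AMT1", 55000, "PAY_AMT2", 4500),
    ("BILL_AMT1", 54416, "PAY_AMT1", 4553),
    ("BILL_AMT2", 4959.5, "PAY_AMT1", 4053.5),
    ("BILL_AMT3", 49144.5, "PAY_AMT1", 4522.0),
    ("BILL_AMT1", 54416, "PAY_AMT5", 4317.5),
]

def predict_no_default(df):
    """Boolean Series: True where at least one rule fires (predicted class 0)."""
    covered = pd.Series(False, index=df.index)
    for bill_col, bill_max, pay_col, pay_min in RULES:
        covered |= (df[bill_col] <= bill_max) & (df[pay_col] > pay_min)
    return covered

# Three made-up customers, for illustration only.
sample = pd.DataFrame({
    "BILL_AMT1": [30000, 80000, 10000],
    "BILL_AMT2": [28000, 79000, 4000],
    "BILL_AMT3": [27000, 78000, 3000],
    "PAY_AMT1": [5000, 6000, 100],
    "PAY_AMT2": [0, 6000, 100],
    "PAY_AMT5": [0, 6000, 100],
})
print(predict_no_default(sample).tolist())
```

Rows for which the function returns False are simply left undecided by the rules.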

# Conclusion

This was just a simple demonstration of how you can extract simple rules from data, using machine learning for modeling but without any machine learning in deployment.
In some cases, those techniques can be very useful to increase human explainability, and it will be much easier to justify those rules to regulators. It will also save ML model maintenance time.  

Note that other powerful techniques can be used for interpretability, such as LIME and SHAP (Shapley values).  
You could still train a machine learning model on the remaining 80% of decisions that the rules do not cover, for a hybrid approach.
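
One possible shape for such a hybrid is sketched below on synthetic data: a rule decides the rows it covers, and a fitted classifier handles the rest. The rule itself, the feature index, and the choice of LogisticRegression as the fallback model are illustrative assumptions and are not part of the notebook above.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Synthetic data, for illustration only.
X_demo, y_demo = make_classification(n_samples=600, n_features=5, random_state=0)

# Hypothetical rule: rows with feature 0 above its 80th percentile are predicted class 0.
threshold = np.quantile(X_demo[:, 0], 0.8)
covered = X_demo[:, 0] > threshold

# Train the fallback model only on the rows the rule does not cover.
fallback = LogisticRegression(max_iter=1000).fit(X_demo[~covered], y_demo[~covered])

def hybrid_predict(X):
    """Rule first; the fallback model decides only where the rule is silent."""
    rule_hits = X[:, 0] > threshold
    preds = np.empty(len(X), dtype=int)
    preds[rule_hits] = 0
    if (~rule_hits).any():
        preds[~rule_hits] = fallback.predict(X[~rule_hits])
    return preds, rule_hits

preds, rule_hits = hybrid_predict(X_demo)
print(f"rule coverage: {rule_hits.mean():.0%}")
```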