<a href="https://www.kaggle.com/drjohnwagner/heart-disease-prediction-with-xgboost?scriptVersionId=85257453" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [1]:
import json
import random
import numpy as np
import pandas as pd
from igraph import Graph
import igraph
from pprint import pprint
#
import plotly.io as pio
# import plotly.express as px
import plotly.graph_objects as go
# from plotly.subplots import make_subplots
#
from xgboost import XGBClassifier
# from xgboost import plot_importance
from xgboost import plot_tree as xgb_plot_tree
from xgboost import XGBModel
from xgboost import Booster
#
from sklearn.compose import ColumnTransformer
from sklearn.feature_selection import SelectKBest, chi2
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import make_scorer
from sklearn.metrics import plot_confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import KFold
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OrdinalEncoder
from sklearn.preprocessing import MinMaxScaler
from sklearn.utils import shuffle

from plotly.offline import init_notebook_mode, iplot
init_notebook_mode(connected=True)

import matplotlib.pyplot as plt
%matplotlib inline

# https://matplotlib.org/stable/users/prev_whats_new/dflt_style_changes.html
SMALL_SIZE = 20
MEDIUM_SIZE = 24
LARGE_SIZE = 28

# Apply every notebook-wide matplotlib font size in a single update...
plt.rcParams.update({
    "font.size": SMALL_SIZE,          # default text size
    "axes.titlesize": LARGE_SIZE,     # axes title
    "axes.labelsize": MEDIUM_SIZE,    # x and y axis labels
    "xtick.labelsize": SMALL_SIZE,    # x tick labels
    "ytick.labelsize": SMALL_SIZE,    # y tick labels
    "legend.fontsize": SMALL_SIZE,    # legend text
    "figure.titlesize": LARGE_SIZE,   # figure (suptitle) title
})

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for root_dir, _subdirs, file_names in os.walk("/kaggle/input"):
    for file_name in file_names:
        print(os.path.join(root_dir, file_name))

# Colorblindness friendly colours: it is important to make our work
# as accessible as possible...
COLORMAP = ["#005ab5", "#DC3220"]

# Human-readable axis/plot labels keyed by (renamed) column name...
LABELS = dict(
    Sex = "Sex",
    Age = "Age",
    MaxHR = "Max HR",
    OldPeak = "Old Peak",
    STSlope = "ST Slope",
    RestingBP = "Rest. BP",
    FastingBS = "Fast. BS",
    RestingECG = "Rest. ECG",
    Cholesterol = "Cholesterol",
    HeartDisease = "Heart Disease",
    ChestPainType = "Chest Pain",
    ExerciseAngina = "Ex. Angina",
)

# Random seed for determinism...
SEED = 395147

# # Template settings for plotly...
# layout_axis = dict(
#     mirror=True,
#     ticks="outside",
#     showline=True,
#     title_standoff = 5,
#     showgrid = True,
# )
# pio.templates["DrJohnWagner"] = go.layout.Template(
#     layout_xaxis = layout_axis,
#     layout_yaxis = layout_axis,
#     layout_title_font_size = 18,
#     layout_font_size = 16,
# )
# pio.templates.default = "simple_white+DrJohnWagner"
# # You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# # You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/heart-failure-prediction/heart.csv


## Now with new and improved `plot_importance()` and `plot_tree()`!

## Load the Data

In [2]:
# Read the Kaggle heart-failure CSV and preview its first rows...
csv_path = "/kaggle/input/heart-failure-prediction/heart.csv"
df = pd.read_csv(csv_path)
df.head()

Unnamed: 0,Age,Sex,ChestPainType,RestingBP,Cholesterol,FastingBS,RestingECG,MaxHR,ExerciseAngina,Oldpeak,ST_Slope,HeartDisease
0,40,M,ATA,140,289,0,Normal,172,N,0.0,Up,0
1,49,F,NAP,160,180,0,Normal,156,N,1.0,Flat,1
2,37,M,ATA,130,283,0,ST,98,N,0.0,Up,0
3,48,F,ASY,138,214,0,Normal,108,Y,1.5,Flat,1
4,54,M,NAP,150,195,0,Normal,122,N,0.0,Up,0


## Fix the Column Names
Columns `ST_Slope` and `Oldpeak` do not use the same naming convention as the other columns.

We also encode `HeartDisease` as categorical codes, which changes its dtype from `int64` to `int8`...

In [3]:
# Fix the egregious column naming error...
df = df.rename(columns = {"ST_Slope": "STSlope", "Oldpeak": "OldPeak"})

# Always test these things: the new names must be present and the old ones
# gone. (Indexing df["STSlope"] would raise KeyError instead of failing the
# assertion, and len(...) > 0 would wrongly fail on an empty frame.)
assert "STSlope" in df.columns and "ST_Slope" not in df.columns, "Ruh roh! ST_Slope is still terribly mistaken!"
assert "OldPeak" in df.columns and "Oldpeak" not in df.columns, "Ruh roh! Oldpeak is still terribly mistaken!"

# Encode the target as categorical codes (stored as int8)...
target = pd.Categorical(df["HeartDisease"])
df["HeartDisease"] = target.codes

print("Datatypes")
print("---------")
print(df.dtypes)

Datatypes
---------
Age                 int64
Sex                object
ChestPainType      object
RestingBP           int64
Cholesterol         int64
FastingBS           int64
RestingECG         object
MaxHR               int64
ExerciseAngina     object
OldPeak           float64
STSlope            object
HeartDisease         int8
dtype: object


## Grab the Column Names

In [4]:
# Break the columns into two groupings: object (string) columns are
# categorical, everything else is numerical. The builtin `object` is used
# because the `np.object` alias was deprecated in NumPy 1.20 and removed in 1.24.
categorical_columns = [column for column in df.columns if df[column].dtypes == object]
numerical_columns   = [column for column in df.columns if df[column].dtypes != object]

# The target must not be used as a model feature...
if "HeartDisease" in numerical_columns:
    numerical_columns.remove("HeartDisease")

assert "HeartDisease" not in numerical_columns, "Ruh roh! HeartDisease is still in numerical_columns!"

print("Categorical Columns: ", categorical_columns)
print("  Numerical Columns: ", numerical_columns)

Categorical Columns:  ['Sex', 'ChestPainType', 'RestingECG', 'ExerciseAngina', 'STSlope']
  Numerical Columns:  ['Age', 'RestingBP', 'Cholesterol', 'FastingBS', 'MaxHR', 'OldPeak']


## Randomly Redistribute Missing Values
The columns `RestingBP` and `Cholesterol` have records with value 0. Hypothesising that these
represent missing values, we replace them with random values drawn from a normal distribution fitted to
the remaining (non-zero) values in each of those columns.

In [5]:
def set_column_value_to_normal_distribution(df, column, value, seed = None):
    """Replace sentinel entries of one column with random normal draws.

    The mean and sample standard deviation (ddof=1) are computed from the rows
    of `column` that are NOT equal to `value`. Every row equal to `value` is
    then replaced by an independent draw from that normal distribution.

    Parameters
    ----------
    df : pandas.DataFrame
        Input frame. It is not modified; a copy is returned.
    column : str
        Name of the column to repair.
    value : scalar
        Sentinel marking a missing entry (0 in this notebook).
    seed : int, optional
        Seed for numpy's default_rng. Defaults to the notebook-wide SEED, so
        re-running the notebook imputes identical values.

    Returns
    -------
    pandas.DataFrame
        Copy of `df` with the sentinel entries replaced. When at least one
        entry is replaced the column becomes float; otherwise its dtype is
        unchanged.
    """
    if seed is None:
        seed = SEED
    # A plain boolean array avoids index alignment (safe with duplicate labels).
    is_sentinel = (df[column] == value).to_numpy()
    observed = df.loc[~is_sentinel, column]
    mean_value, std_value = observed.mean(), observed.std()
    # Work on a copy so the caller's frame is left untouched...
    df = df.copy()
    n_missing = int(is_sentinel.sum())
    if n_missing:
        rng = np.random.default_rng(seed)
        # Cast first so float draws are not written into an integer column.
        df[column] = df[column].astype(float)
        # One vectorised call yields the same draws, in row order, as drawing
        # one value per matching row.
        df.loc[is_sentinel, column] = rng.normal(mean_value, std_value, size = n_missing)
    return df

# Impute the suspected missing (zero) readings, one column at a time...
for imputed_column in ("RestingBP", "Cholesterol"):
    df = set_column_value_to_normal_distribution(df, imputed_column, 0)

# Always test: no zero readings may survive the imputation...
assert not (df["RestingBP"  ] == 0).any(), "Ruh roh! One or more patients has crashed again!"
assert not (df["Cholesterol"] == 0).any(), "Ruh roh! One or more patients has crashed again!"

## Compare Numeric Columns

In [6]:
df.loc[df["HeartDisease"].eq(0)].describe(include = [np.number])

Unnamed: 0,Age,RestingBP,Cholesterol,FastingBS,MaxHR,OldPeak,HeartDisease
count,410.0,410.0,410.0,410.0,410.0,410.0,410.0
mean,50.55122,130.180488,239.142666,0.107317,148.15122,0.408049,0.0
std,9.444915,16.499585,55.466868,0.309894,23.288067,0.699709,0.0
min,28.0,80.0,85.0,0.0,69.0,-1.1,0.0
25%,43.0,120.0,202.25,0.0,134.0,0.0,0.0
50%,51.0,130.0,232.5,0.0,150.0,0.0,0.0
75%,57.0,140.0,270.0,0.0,165.0,0.6,0.0
max,76.0,190.0,564.0,1.0,202.0,4.2,0.0


In [7]:
df.loc[df["HeartDisease"].eq(1)].describe(include = [np.number])

Unnamed: 0,Age,RestingBP,Cholesterol,FastingBS,MaxHR,OldPeak,HeartDisease
count,508.0,508.0,508.0,508.0,508.0,508.0,508.0
mean,55.899606,134.424968,245.989923,0.334646,127.655512,1.274213,1.0
std,8.727056,18.918337,64.151139,0.472332,23.386923,1.151872,0.0
min,31.0,92.0,52.747549,0.0,60.0,-2.6,1.0
25%,51.0,120.0,208.0,0.0,112.0,0.0,1.0
50%,57.0,132.0,242.000602,0.0,126.0,1.2,1.0
75%,62.0,145.0,283.43555,1.0,144.25,2.0,1.0
max,77.0,200.0,603.0,1.0,195.0,6.2,1.0


## Build the model...

Run this cell if you want to perform a very limited grid search:

In [8]:
PERFORM_GRID_SEARCH: bool = True  # enable the short grid search (overridden by the next cell on Run All)

Run this cell if you **do not** want to perform grid search:

In [9]:
PERFORM_GRID_SEARCH: bool = False  # skip grid search and use the stored best parameters

Change `PERFORM_HUGE_GRID_SEARCH` to `True` if you want to do an extensive grid search (it only takes effect when `PERFORM_GRID_SEARCH` is also `True`).

Be aware: this takes many hours on a four-core CPU!

In [10]:
PERFORM_HUGE_GRID_SEARCH: bool = False  # only consulted when PERFORM_GRID_SEARCH is True

In [11]:
# Hold out a stratified 20% of the patients for validation...
heart_features = df.drop("HeartDisease", axis = 1)
heart_target = df["HeartDisease"]
xt, xv, yt, yv = train_test_split(
    heart_features, heart_target,
    test_size = 0.2, random_state = 42, shuffle = True, stratify = heart_target,
)

# Data preparation -> feature selection -> classification, as one sklearn Pipeline.
pipeline = Pipeline(steps = [
    (
        "transform",
        ColumnTransformer(transformers = [
            ("cat", OrdinalEncoder(), categorical_columns),   # string categories -> integer codes
            ("num", MinMaxScaler(), numerical_columns),       # rescale numeric columns
        ]),
    ),
    ("features", SelectKBest()),
    (
        "classifier",
        XGBClassifier(
            objective = "binary:logistic",
            eval_metric = "auc",
            use_label_encoder = False,
        ),
    ),
])

if PERFORM_GRID_SEARCH:
    # Define our search space for grid search...
    # Short search over gamma as a quick example...
    search_space = [{
        "classifier__n_estimators": [60],
        "classifier__learning_rate": [0.1],
        "classifier__max_depth": [4],
        "classifier__colsample_bytree": [0.2],
        "classifier__gamma": [i / 10.0 for i in range(3, 7)],
        "features__score_func": [chi2],
        "features__k": [10],
    }]
    if PERFORM_HUGE_GRID_SEARCH:
        # Define our search space for grid search...
        # This is a real search but takes hours...
        search_space = [{
            "classifier__n_estimators": [i*10 for i in range(1, 10)],
            "classifier__learning_rate": [0.01, 0.05, 0.1, 0.2],
            "classifier__max_depth": range(1, 10, 1),
            # colsample_bytree is a fraction of the columns and must lie in
            # (0, 1]; start at 1/20 (range(7) included an invalid 0.0).
            "classifier__colsample_bytree": [i/20.0 for i in range(1, 7)],
            "classifier__gamma": [i / 10.0 for i in range(3, 7)],
            "features__score_func": [chi2],
            "features__k": [10],
        }]
    # Define grid search...
    grid = GridSearchCV(
        pipeline,
        param_grid = search_space,
        # Define cross validation...
        cv = KFold(n_splits = 10, random_state = 917, shuffle = True),
        # Define AUC and accuracy as score; refit the best model by AUC...
        scoring = {
            "AUC": "roc_auc",
            "Accuracy": make_scorer(accuracy_score)
        },
        refit = "AUC",
        verbose = 1,
        n_jobs = -1
    )
    # Fit grid search, then evaluate the refit best model on the held-out set...
    grid_model = grid.fit(xt, yt)
    yp = grid_model.predict(xv)
    #
    print(f"Best AUC Score: {grid_model.best_score_}")
    print(f"Accuracy: {accuracy_score(yv, yp)}")
    print("Confusion Matrix: ", confusion_matrix(yv, yp))
    print("Best Parameters: ", grid_model.best_params_)

In [12]:
# Take the best parameters from the grid search when one was run; otherwise
# fall back to the previously computed best parameters, which produced an AUC
# score of 0.9244531360448315 and an accuracy of 0.8858695652173914.
parameters = grid_model.best_params_ if PERFORM_GRID_SEARCH else {
    "classifier__colsample_bytree": 0.2,
    "classifier__gamma": 0.4,
    "classifier__learning_rate": 0.1,
    "classifier__max_depth": 4,
    "classifier__n_estimators": 60,
    "features__k": 10,
    "features__score_func": chi2
}

# set_params() and fit() both return the pipeline itself...
model = pipeline.set_params(**parameters).fit(xt, yt)
yp = model.predict(xv)

print(f"Accuracy: {accuracy_score(yv, yp)}")
print("Confusion Matrix: ", confusion_matrix(yv, yp))
print("Prediction: ", yp)

Accuracy: 0.8858695652173914
Confusion Matrix:  [[73  9]
 [12 90]]
Prediction:  [1 0 1 0 0 0 0 1 0 1 1 0 0 0 0 1 0 1 0 0 0 1 0 0 1 1 0 1 1 0 0 1 0 0 1 1 1
 1 0 1 1 0 1 0 1 1 1 0 1 0 1 0 0 0 1 0 1 0 0 0 0 1 0 1 0 1 1 1 0 0 0 1 1 1
 0 1 1 1 1 1 1 1 1 0 1 0 1 1 0 1 0 1 0 0 1 1 1 1 0 1 1 0 1 0 1 1 1 0 0 1 1
 0 0 1 1 0 1 0 1 1 0 1 1 0 0 1 0 0 1 0 0 0 0 0 1 1 1 0 0 0 1 1 0 1 1 1 1 0
 1 1 1 0 1 0 1 1 0 0 1 1 1 1 0 0 1 0 1 1 0 1 1 0 1 1 1 0 0 0 0 0 1 1 1 0]


## A Better `plot_importance()`...

The `xgboost` functions `plot_importance()` and `plot_tree()` produce rather plain-looking visualisations because they rely on the default `matplotlib` styling, so I put together my own `plotly`-based versions. Here's my `plot_importance()`...

In [13]:
def plot_importance(booster, width = 0.5, xrange = None, yrange = None,
                    title = 'Feature Importance', xlabel = 'F Score',
                    ylabel = 'Features', fmap = '', max_features = None,
                    importance_type = 'weight', show_grid = True, show_values = True):
    """Plot feature importance of fitted trees as a horizontal plotly bar chart.

    Parameters
    ----------
    booster : Booster, XGBModel or dict
        Booster or XGBModel instance, or a dict like the one returned by
        Booster.get_score().
    width : float, default 0.5
        Bar width.
    xrange : tuple of 2 numbers, default None
        Range applied to the x axis via fig.update_xaxes(range=...).
    yrange : tuple of 2 numbers, default None
        Range applied to the y axis via fig.update_yaxes(range=...).
    title : str, default "Feature Importance"
        Figure title, centred. To disable, pass None.
    xlabel : str, default "F Score"
        X axis title, also used as the prefix of each bar's value text.
        To disable, pass None; the bar text then shows the bare value.
    ylabel : str, default "Features"
        Y axis title. To disable, pass None.
    fmap : str or os.PathLike, optional
        The name of a feature map file, passed to get_score().
    max_features : int, default None
        Maximum number of top features displayed. If None, all are displayed.
    importance_type : str, default "weight"
        How the importance is calculated: either "weight", "gain", or "cover"
        * "weight" is the number of times a feature appears in a tree
        * "gain" is the average gain of splits which use the feature
        * "cover" is the average coverage of splits which use the feature
          where coverage is defined as the number of samples affected by the split
    show_grid : bool, default True
        Show the x axis grid lines (y grid lines are always hidden).
    show_values : bool, default True
        Show each bar's value as text and hover text. To disable, pass False.

    Returns
    -------
    fig : plotly.graph_objects.Figure

    Raises
    ------
    ImportError
        If plotly is not installed.
    ValueError
        If `booster` has an unsupported type, the importance is empty, or
        `xrange`/`yrange` is not a tuple of 2 elements.
    """
    try:
        import plotly.graph_objects as go
    except ImportError as e:
        raise ImportError('You must install plotly to plot importance') from e

    if isinstance(booster, XGBModel):
        importance = booster.get_booster().get_score(
            importance_type = importance_type, fmap = fmap)
    elif isinstance(booster, Booster):
        importance = booster.get_score(importance_type = importance_type, fmap = fmap)
    elif isinstance(booster, dict):
        importance = booster
    else:
        raise ValueError('tree must be Booster, XGBModel or dict instance')

    if not importance:
        raise ValueError(
            'Booster.get_score() results in empty.  ' +
            'This maybe caused by having all trees as decision dumps.')

    # Ascending by score; keep only the largest max_features entries if asked.
    tuples = sorted(importance.items(), key = lambda item: item[1])
    if max_features is not None:
        tuples = tuples[-max_features:]
    labels, values = zip(*tuples)

    # xlabel doubles as the value prefix but may be None (axis title disabled).
    if xlabel is None:
        text = [str(value) for value in values]
    else:
        text = [xlabel + ": " + str(value) for value in values]
    fig = go.Figure(go.Bar(
            y = [label.upper() for label in labels],
            x = values,
            orientation = 'h',
            width = width,
            hovertext = text if show_values else [],
            text = text if show_values else [],
            textposition = 'auto',
        ))
    if xrange is not None:
        if not isinstance(xrange, tuple) or len(xrange) != 2:
            raise ValueError('xrange must be a tuple of 2 elements')
        fig.update_xaxes(range = xrange)
    if yrange is not None:
        if not isinstance(yrange, tuple) or len(yrange) != 2:
            raise ValueError('yrange must be a tuple of 2 elements')
        fig.update_yaxes(range = yrange)
    if title is not None:
        fig.update_layout(title = {"text": title, "x": 0.5, "xanchor":  "center"})
    if xlabel is not None:
        fig.update_xaxes(title_text = xlabel)
    if ylabel is not None:
        fig.update_yaxes(title_text = ylabel)
    fig.update_xaxes(showgrid = show_grid)
    fig.update_yaxes(showgrid = False)
    return fig

## A Better `plot_tree()`...
And here's my version of `plot_tree()`...

In [14]:
def get_leaf(index, line):
    """Parse a leaf body such as "leaf=0.123" into (node id, leaf value)."""
    node_id = int(index)
    leaf_value = line.split("=")[1]
    return node_id, float(leaf_value)

def get_node(index, line):
    """Parse a split body such as "[f2<0.5] yes=1,no=2,missing=1".

    Returns (node id, upper-cased feature id, threshold, yes id, no id,
    missing id).
    """
    condition, branches = line.split(" ")
    feature, threshold = condition.strip("[]").split("<")
    yes_id, no_id, missing_id = (branch.split("=")[1] for branch in branches.split(","))
    return int(index), feature.upper(), float(threshold), int(yes_id), int(no_id), int(missing_id)

def get_node_labels(nodes, features, precision):
    """Label each split node as "<feature name> < <rounded threshold>".

    Feature ids missing from `features` are shown as is.
    """
    labels = []
    for _, feature_id, threshold in nodes:
        name = features.get(feature_id, feature_id)
        labels.append(name + " < " + str(round(threshold, precision)))
    return labels

def get_leaf_labels(leaves, precision):
    """Label each leaf with its value rounded to `precision` decimals."""
    labels = []
    for _, leaf_value in leaves:
        labels.append(str(round(leaf_value, precision)))
    return labels

def get_unique_edges(index, y, n, m):
    """Turn a split node's child ids into labelled, de-duplicated edges.

    Parameters
    ----------
    index : int
        Id of the split node.
    y, n, m : int
        Ids of the "yes", "no" and "missing" children.

    Returns
    -------
    list of (parent, child, kind) tuples
        When the missing child coincides with the yes or no child, the two
        edges are merged into one labelled "Yes/Missing" or "No/Missing";
        otherwise three separate edges are returned.

    Raises
    ------
    ValueError
        If the yes and no children coincide. An explicit exception is used
        rather than ``assert`` so the check survives ``python -O``.
    """
    if y == n:
        raise ValueError(f"Node {index}: Yes and No children must differ (both are {y})")
    if n == m:
        return [(index, y, "Yes"), (index, n, "No/Missing")]
    if y == m:
        return [(index, y, "Yes/Missing"), (index, n, "No")]
    return [(index, y, "Yes"), (index, n, "No"), (index, m, "Missing")]

def get_min_max_delta(v, indent):
    """Return a padded (min, max, span) range enclosing the values in v.

    Each side is padded by ``indent`` times the data span. When all values
    are equal (zero span), the padding falls back to ``indent`` in absolute
    units so the returned span is non-zero. plot_tree divides by the span.
    """
    min_v, max_v = np.min(v), np.max(v)
    delta_v = max_v - min_v
    pad = indent*delta_v if delta_v > 0 else indent
    min_v = min_v - pad
    max_v = max_v + pad
    return min_v, max_v, max_v - min_v

def get_features(features):
    """Map upper-cased dump feature ids ("F0", "F1", ...) to display names by position."""
    return dict(("F" + str(position), name) for position, name in enumerate(features))

def get_graph(dump):
    """Parse one text dump from booster.get_dump() and lay the tree out.

    Every dump line has the form "<id>:<body>", where the body is either a
    leaf ("leaf=...") or a split ("[f<k><<threshold>] yes=..,no=..,missing=..").

    Returns
    -------
    leaves : list of (id, value)
    nodes : list of (id, FEATURE, threshold)
    edges : list of (parent, child, kind), as built by get_unique_edges
    graph : igraph.Graph built from the edges
    xy : Reingold-Tilford layout of `graph` rooted at vertex 0
    """
    indexes, bodies = [], []
    for raw_line in dump.splitlines():
        index, body = raw_line.strip().split(":")
        indexes.append(index)
        bodies.append(body)
    leaves, splits = [], []
    for index, body in zip(indexes, bodies):
        if body.startswith("leaf"):
            leaves.append(get_leaf(index, body))
        else:
            splits.append(get_node(index, body))
    edges = [
        edge
        for index, _, _, y, n, m in splits
        for edge in get_unique_edges(index, y, n, m)
    ]
    nodes = [(index, feature, float(value)) for index, feature, value, _, _, _ in splits]
    graph = Graph([(parent, child) for parent, child, _ in edges])
    xy = graph.layout_reingold_tilford(root = 0)
    return leaves, nodes, edges, graph, xy

def apply_default(parameter, default):
    """Overlay caller-supplied settings on a dict of defaults.

    Returns `default` itself (not a copy) when `parameter` is empty or None;
    otherwise returns a new dict in which `parameter`'s keys win.
    """
    return { **default, **parameter } if parameter else default

def get_edge_annotation(edge, xy, w, h, labels = {}, colors = {}, arrow = {}, label = {}, font = {}):
    """Build the two plotly annotations drawn for one tree edge.

    Parameters
    ----------
    edge : tuple (i, j, text)
        Parent index, child index and edge kind (e.g. "Yes", "No/Missing").
    xy : sequence of [x, y]
        Vertex positions in data coordinates.
    w : float
        Box width in data units; the arrow head sits on the left side of the
        child's box (x_j - w/2).
    h : float
        Unused; kept so existing callers keep working.
    labels : dict
        Maps an edge kind to its displayed text (the kind itself if absent).
    colors : dict
        Maps an edge kind to its colour; must contain the edge's kind.
    arrow, label, font : dict
        Extra arrow, label-box and font properties. They are copied and never
        modified, so they can safely be shared between edges.

    Returns
    -------
    [arrow_annotation, label_annotation] : list of two dicts
        A short arrow (0.05 x-axis units long) ending at the child's box, and
        the edge's label placed at the midpoint between parent and child.
    """
    i, j, text = edge
    xi, yi = xy[i]
    xj, yj = xy[j]
    xm, ym = (xi + xj)/2.0, (yi + yj)/2.0
    color = colors[text]
    # Copies, so the caller's dicts (and the shared {} defaults) stay untouched.
    edge_font = { **font, 'color': color }
    # arrowwidth is forced to 1, overriding any value supplied by the caller.
    edge_arrow = { **arrow, 'arrowcolor': color, 'arrowwidth': 1 }
    return [dict(
            x  = xj - w / 2.0, y  = yj, xref  = "x", yref  = "y",
            ax = xj - w / 2.0 - 0.05, ay = yj, axref = "x", ayref = "y",
            font = edge_font, **edge_arrow,
            hovertext = text,
        ), dict(
            x  = xm, y  = ym, xref  = "x", yref  = "y",
            ax = xm, ay = ym, axref = "x", ayref = "y",
            showarrow = False, **label,
            font = edge_font, text = labels.get(text, text),
        )
    ]

def get_edge_shapes(edges, xy, w, colors = {}, line = {}):
    """Build one curved SVG-path shape per edge, drawn below the boxes.

    Each cubic curve starts at the right side of the parent's box (x + w/2),
    nudged 0.02 vertically towards the child, and ends 0.02 short of the left
    side of the child's box (x - w/2). The line style is `line` with its
    colour taken from `colors[kind]`.
    """
    half_w = w/2.0
    shapes = []
    for src, dst, kind in edges:
        x_src, y_src = xy[src]
        x_dst, y_dst = xy[dst]
        start, end = x_src + half_w, x_dst - half_w
        # Horizontal control-point offset (at least 0.2) and vertical nudge...
        dx = np.maximum(end - start, 0.2)
        dy = 0.02*np.sign(y_dst - y_src)
        shapes.append(dict(
            type = 'path', layer = 'below', line = { **line, 'color': colors[kind] },
            path = f"M{start} {y_src + dy}, C {start + dx} {y_src + dy}, {end - dx} {y_dst}, {end - 0.02} {y_dst}",
        ))
    return shapes

def get_node_shapes(nodes, xy, w, h, shape = {}, line = {}):
    """Build one w-by-h box shape, centred on each split node, drawn below the traces.

    `shape` supplies the remaining shape properties (e.g. type, fillcolor) and
    `line` the outline style.
    """
    half_w, half_h = w / 2.0, h / 2.0
    return [
        dict(
            x0 = xy[index][0] - half_w, y0 = xy[index][1] - half_h,
            x1 = xy[index][0] + half_w, y1 = xy[index][1] + half_h,
            layer = "below", line = line, **shape,
        )
        for index, _, _ in nodes
    ]

def get_leaf_shapes(leaves, xy, w, h, shape = {}, line = {}):
    """Build one w-by-h box shape, centred on each leaf, drawn below the traces.

    If ``shape["type"] == "rounded"``, the value is not passed to plotly.
    Instead each box is drawn as an SVG path whose corners are quadratic
    curves. The remaining `shape` keys are applied as extra properties. Any
    other type produces plain x0/y0/x1/y1 boxes with `shape` passed through
    unchanged.

    `shape` is never modified. Popping its "type" key would make a dict
    reused across calls silently switch from rounded to plain boxes.
    """
    # Corner radii: rx in x units; ry scaled by w/h so the corners look
    # similar along both axes.
    half_w, half_h, rx, ry = w / 2.0, h / 2.0, 0.04, w*0.04/h
    centres = [(xy[i][0], xy[i][1]) for i, _ in leaves]
    if shape.get("type", "rect") == "rounded":
        extra = {key: value for key, value in shape.items() if key != "type"}
        shapes = []
        for x, y in centres:
            xi, yi, xj, yj = x - half_w, y - half_h, x + half_w, y + half_h
            rounded_bl = f" M {xi+rx}, {yi} Q {xi}, {yi} {xi}, {yi+ry}"
            rounded_tl = f" L {xi}, {yj-ry} Q {xi}, {yj} {xi+rx}, {yj}"
            rounded_tr = f" L {xj-rx}, {yj} Q {xj}, {yj} {xj}, {yj-ry}"
            rounded_br = f" L {xj}, {yi+ry} Q {xj}, {yi} {xj-rx}, {yi}Z"
            shapes.append(dict(
                xref = "x", yref = "y", type = "path",
                path = rounded_bl + rounded_tl + rounded_tr + rounded_br,
                layer = "below", line = line, **extra,
            ))
        return shapes
    return [
        dict(
            x0 = x - half_w, y0 = y - half_h, x1 = x + half_w, y1 = y + half_h,
            layer = "below", line = line, **shape,
        )
        for x, y in centres
    ]

def get_nodes_scatter_plot(nodes, xy, features, precision, font = {}):
    """Text-only scatter trace that writes each split's label at its position."""
    positions = [xy[index] for index, _, _ in nodes]
    return go.Scatter(
        x = [position[0] for position in positions],
        y = [position[1] for position in positions],
        mode = 'text',
        textfont = font,
        text = get_node_labels(nodes, features, precision),
    )

def get_leaves_scatter_plot(leaves, xy, precision, font = {}):
    """Text-only scatter trace that writes each leaf's value at its position."""
    positions = [xy[index] for index, _ in leaves]
    return go.Scatter(
        x = [position[0] for position in positions],
        y = [position[1] for position in positions],
        mode = 'text',
        textfont = font,
        text = get_leaf_labels(leaves, precision),
    )

# Non-grayscale defaults are from the vibrant colour scheme here:
# https://personal.sron.nl/~pault/
def plot_tree(booster, tree, features = {}, width = None, height = None,
    precision = 4, scale = 0.7, font = None, grayscale = False,
    node_shape = {}, node_line = {}, node_font = {},
    leaf_shape = {}, leaf_line = {}, leaf_font = {},
    edge_labels = {}, edge_colors = {}, edge_arrow = {},
    edge_line = {}, edge_label = {}, edge_font = {}):
    """Draw one tree of an XGBoost booster as a plotly figure.

    The tree is parsed from ``booster.get_dump()[tree]`` and laid out with
    igraph's Reingold-Tilford layout. The layout coordinates are swapped so
    the tree grows horizontally. Split nodes are boxes labelled
    "<feature> < <threshold>", and leaves are (by default rounded) boxes
    labelled with their value. Edges are curves coloured and labelled by
    branch kind (Yes / No / Missing and merged variants).

    Parameters
    ----------
    booster : object with ``get_dump()``
        Typically an xgboost Booster.
    tree : int
        Index of the tree in ``booster.get_dump()``.
    features : dict or list, default {}
        Display names for the dump's upper-cased feature ids. A list is
        converted so that element i names feature "F<i>". Ids missing from
        the mapping are shown as is.
    width, height : number, default None
        Figure size in pixels; estimated from the tree's size when None.
    precision : int, default 4
        Decimals shown for thresholds and leaf values.
    scale : float, default 0.7
        Box width per label character, as a multiple of the font size (a rough
        stand-in for real font metrics).
    font : dict, default None
        Base font, merged over DEFAULT_FONT and inherited by node, leaf and
        edge fonts (edges use a font 2 points smaller).
    grayscale : bool, default False
        Use the grayscale palette instead of the colour palette.
    node_shape, node_line, node_font, leaf_shape, leaf_line, leaf_font,
    edge_labels, edge_colors, edge_arrow, edge_line, edge_label, edge_font : dict
        Overrides merged over the corresponding DEFAULT_* dictionaries below.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    DEFAULT_FONT = {
        'family': "Verdana, Helvetica, Verdana, Calibri, Garamond, Cambria, Arial",
        'size': 16,
        'color': "#000000"
    }
    DEFAULT_NODE_SHAPE = {
        'type': "rect",
        'fillcolor': "#CBCBCB" if grayscale else "rgba(0,153,136,0.8)",
        'opacity': 1.0,
    }
    DEFAULT_NODE_LINE = {
        'color': "#666666" if grayscale else "rgb(238,119,51)",
        'width': 1,
        'dash': "solid", # one of 'solid', 'dot', 'dash', 'longdash', 'dashdot', 'longdashdot'
    }
    DEFAULT_LEAF_SHAPE = {
        'type': "rounded", # handled by get_leaf_shapes, not passed to plotly
        'fillcolor': "#EDEDED" if grayscale else "rgba(238,119,51,0.8)",
        'opacity': 1.0,
    }
    DEFAULT_LEAF_LINE = {
        'color': "#777777" if grayscale else "rgb(0,153,136)",
        'width': 1,
        'dash': "solid",
    }
    DEFAULT_EDGE_LABELS = {
        'Yes': "Yes",
        'No': "No",
        'Missing': "Missing",
        'Yes/Missing': "Yes/Missing",
        'No/Missing': "No/Missing"
    }
    DEFAULT_EDGE_COLORS = {
        'Yes': "#222222" if grayscale else "#005AB5",
        'No': "#777777" if grayscale else "#CC3311",
        'Missing': "#AAAAAA",
        'Yes/Missing':  "#222222" if grayscale else "#0077BB",
        'No/Missing': "#777777" if grayscale else "#DC3220",
    }
    DEFAULT_EDGE_ARROW = {
        'arrowhead': 3, # integer between 0 and 8 inclusive
        'arrowsize': 1.5, # relative to arrowwidth
        'arrowwidth': 1,
    }
    DEFAULT_EDGE_LINE = {
        'width': 1.5,
        'dash': "solid",
    }
    DEFAULT_EDGE_LABEL = {
        'align': "center",
        'bgcolor': "#FFFFFF",
        'bordercolor': "rgba(0,0,0,0)",
        'borderpad': 1,
        'borderwidth': 1,
        'opacity': 1.0,
        'textangle': 0,
        'valign': "middle",
        'visible': True,
    }
    # Merge the caller's overrides over the defaults...
    font = apply_default(font, DEFAULT_FONT)
    node_shape = apply_default(node_shape, DEFAULT_NODE_SHAPE)
    node_line = apply_default(node_line, DEFAULT_NODE_LINE)
    node_font = { **font, **node_font}
    leaf_shape = apply_default(leaf_shape, DEFAULT_LEAF_SHAPE)
    leaf_line  = apply_default(leaf_line,  DEFAULT_LEAF_LINE)
    leaf_font = { **font, **leaf_font }
    edge_labels = apply_default(edge_labels, DEFAULT_EDGE_LABELS)
    edge_colors = apply_default(edge_colors, DEFAULT_EDGE_COLORS)
    edge_arrow  = apply_default(edge_arrow,  DEFAULT_EDGE_ARROW)
    edge_line = apply_default(edge_line, DEFAULT_EDGE_LINE)
    edge_label = apply_default(edge_label, DEFAULT_EDGE_LABEL)
    edge_font   = { **font, **{ 'size': font.get('size', 16) - 2 }, **edge_font }
    #
    if isinstance(features, list):
        features = get_features(features)
    #
    dump = booster.get_dump()[tree]
    leaves, nodes, edges, graph, xy = get_graph(dump)
    # Swap the layout's coordinates so the tree grows horizontally...
    xy = [[y, x] for x, y in xy]
    # graph.bfs() returns (visited vertex ids, start offset of each layer
    # within that list, parent of each vertex)...
    vids, layers, _ = graph.bfs(0)
    # Re-centre the root vertically on its children, since the layout
    # sometimes places it off-centre. `layers` holds offsets into `vids`, so
    # the children must be looked up through `vids` rather than used as ids.
    if len(layers) > 2:
        children = vids[layers[1]:layers[2]]
        xy[0][1] = np.sum([xy[i][1] for i in children])/len(children)
    # Convert layer start offsets into the number of vertices per layer...
    layer_sizes = [layers[i] - layers[i-1] for i in range(1, len(layers))]
    tree_depth, tree_width = len(layer_sizes), np.max(layer_sizes)
    # KLUDGE: box size in pixels is estimated from the longest label instead
    # of real font metrics...
    labels = get_node_labels(nodes, features, precision) + get_leaf_labels(leaves, precision)
    w, h = scale*font.get('size', 14)*np.max([len(label) for label in labels]), 3*font.get('size', 14)
    #
    if width is None:
        width = w*tree_depth + 70*(tree_depth + 1)
    if height is None:
        height = h*tree_width + int(2.5*h*np.sqrt(2.0 + (len(nodes) + len(leaves))/tree_depth))
    #
    x = [xy[i][0] for i, _, _ in nodes] + [xy[i][0] for i, _ in leaves]
    y = [xy[i][1] for i, _, _ in nodes] + [xy[i][1] for i, _ in leaves]
    #
    min_x, max_x, delta_x = get_min_max_delta(x, 0.1)
    min_y, max_y, delta_y = get_min_max_delta(y, 0.1)
    pixels_x, pixels_y = width/delta_x, height/delta_y
    # KLUDGE: convert the box size from pixels to data units...
    w, h = w/pixels_x, h/pixels_y
    #
    nodes_scatter_plot = get_nodes_scatter_plot(nodes, xy, features, precision, node_font)
    leaves_scatter_plot = get_leaves_scatter_plot(leaves, xy, precision, leaf_font)
    #
    layout = go.Layout(
        font = font,
        showlegend = False,
        autosize = False,
        width = width,
        height = height,
        plot_bgcolor = "#FFFFFF",
        xaxis = dict(visible = False),
        yaxis = dict(visible = False),
        shapes = get_edge_shapes(
            edges, xy, w, colors = edge_colors, line = edge_line
        ) + get_node_shapes(
            nodes, xy, w, h, node_shape, node_line
        ) + get_leaf_shapes(
            leaves, xy, w, h, leaf_shape, leaf_line
        ),
        xaxis_range = [min_x, max_x],
        yaxis_range = [min_y, max_y],
    )
    #
    fig = go.Figure([nodes_scatter_plot, leaves_scatter_plot], layout)
    # Each edge gets an arrow head at the child and a label at its midpoint...
    for edge in edges:
        arrow, label = get_edge_annotation(edge, xy, w, h/3.0,
            labels = edge_labels, colors = edge_colors,
            arrow = edge_arrow, label = edge_label, font = edge_font,
        )
        fig.add_annotation(arrow)
        fig.add_annotation(label)
    return fig


## Plotting the Importance and First Five Trees
And here I demonstrate them on the `xgboost` model above...

In [15]:
classifier = pipeline["classifier"]
booster = classifier.get_booster()
# One text dump per tree; only the count is used below...
trees = booster.get_dump()

fig = plot_importance(classifier, title = "Feature Importance")
fig.update_layout(template = "presentation")
fig.update_layout(width = 700)
fig.show()

# Name the trees' features. The booster is trained on the pipeline's
# transformed matrix, not on df's columns. The ColumnTransformer emits the
# categorical columns first and then the numerical ones, and SelectKBest keeps
# only the columns flagged by get_support(). Booster feature Fi is therefore
# the i-th surviving transformed column, so the names follow that order.
transformed_columns = categorical_columns + numerical_columns
selected = pipeline["features"].get_support()
features = [LABELS[column] for column, keep in zip(transformed_columns, selected) if keep]
edge_labels = { 'Yes/Missing': "Yes" }
precision = 4
scale = 0.7

print("Plotting the first five trees:")
for tree in range(min(5, len(trees))):
    fig = plot_tree(
        booster, tree, features = features, precision = precision, scale = scale,
        grayscale = tree % 2 == 0, edge_labels = edge_labels,
    )
    fig.update_layout(
        margin = { 'l': 10, 'r': 10, 't': 50, 'b': 10 },
        title = { 'text': f"Tree {tree}", 'x': 0.5, 'xanchor': "center" },
    )
    fig.show()

Plotting the first five trees:
0.7


0.7


0.7


0.7


0.7
