In [5]:
import json
import random
import numpy as np
import pandas as pd
from pprint import pprint
#
import plotly.io as pio
import plotly.graph_objects as go
from plotly.subplots import make_subplots
#
from xgboost import XGBClassifier
from xgboost import XGBModel
from xgboost import Booster
#
from sklearn.compose import ColumnTransformer
from sklearn.feature_selection import SelectKBest, chi2
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import make_scorer
from sklearn.metrics import plot_confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import KFold
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OrdinalEncoder
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import MinMaxScaler
from sklearn.utils import shuffle

# import PDSUtilities
from PDSUtilities.xgboost import plot_importance
from PDSUtilities.xgboost import plot_tree
from PDSUtilities.plotly import ColorblindSafeColormaps
import PDSUtilities.plotly.templates

# from PDSUtilities.pandas import plot_histograms
# print("Using PDSUtilities version ", PDSUtilities.__version__)

# from plotly.offline import init_notebook_mode, iplot
# init_notebook_mode(connected = True)

# Input data is read from the local "./data/" directory (heart.csv), not from
# Kaggle's read-only "../input/" directory.

# # Colorblindness friendly colours...
# # It is important to make our work
# # as accessible as possible...
# COLORMAP = ["#005ab5", "#DC3220"]
# Human-readable plot labels keyed by column name (using the renamed
# "OldPeak" / "STSlope" spellings)...
LABELS = dict(
    Sex            = "Sex",
    Age            = "Age",
    MaxHR          = "Max HR",
    OldPeak        = "Old Peak",
    STSlope        = "ST Slope",
    RestingBP      = "Rest. BP",
    FastingBS      = "Fast. BS",
    RestingECG     = "Rest. ECG",
    Cholesterol    = "Cholesterol",
    HeartDisease   = "Heart Disease",
    ChestPainType  = "Chest Pain",
    ExerciseAngina = "Ex. Angina",
)
# Fixed random seed so the stochastic steps are reproducible...
SEED = 395147


'simple_white+DrJohnWagner'

In [2]:
# Load the heart-disease CSV from the local data directory and peek at it...
df = pd.read_csv(filepath_or_buffer = "./data/heart.csv")
df.head(n = 5)


Unnamed: 0,Age,Sex,ChestPainType,RestingBP,Cholesterol,FastingBS,RestingECG,MaxHR,ExerciseAngina,Oldpeak,ST_Slope,HeartDisease
0,40,M,ATA,140,289,0,Normal,172,N,0.0,Up,0
1,49,F,NAP,160,180,0,Normal,156,N,1.0,Flat,1
2,37,M,ATA,130,283,0,ST,98,N,0.0,Up,0
3,48,F,ASY,138,214,0,Normal,108,Y,1.5,Flat,1
4,54,M,NAP,150,195,0,Normal,122,N,0.0,Up,0


In [3]:
# Fix the egregious column naming error...
df = df.rename(columns = {"ST_Slope": "STSlope", "Oldpeak": "OldPeak"})

# Always test these things... Check column membership: indexing a missing
# column would raise a bare KeyError and the message below would never show.
assert "STSlope" in df.columns, "Ruh roh! ST_Slope is still terribly mistaken!"
assert "OldPeak" in df.columns, "Ruh roh! Oldpeak is still terribly mistaken!"

# Convert target to categorical integer codes...
target = pd.Categorical(df["HeartDisease"])
df["HeartDisease"] = target.codes

print("Datatypes")
print("---------")
print(df.dtypes)


Datatypes
---------
Age                 int64
Sex                object
ChestPainType      object
RestingBP           int64
Cholesterol         int64
FastingBS           int64
RestingECG         object
MaxHR               int64
ExerciseAngina     object
OldPeak           float64
STSlope            object
HeartDisease         int8
dtype: object


In [4]:
# Split the columns by dtype: object columns are categorical, every other
# column is numerical -- except the HeartDisease target, which is excluded.
categorical_columns = [name for name in df.columns if df[name].dtypes == object]
numerical_columns   = [
    name for name in df.columns
    if name not in categorical_columns and name != "HeartDisease"
]

assert "HeartDisease" not in numerical_columns, "Ruh roh! HeartDisease is still in numerical_columns!"

print("Categorical Columns: ", categorical_columns)
print("  Numerical Columns: ", numerical_columns)


Categorical Columns:  ['Sex', 'ChestPainType', 'RestingECG', 'ExerciseAngina', 'STSlope']
  Numerical Columns:  ['Age', 'RestingBP', 'Cholesterol', 'FastingBS', 'MaxHR', 'OldPeak']


In [5]:
def set_column_value_to_normal_distribution(df, column, value, rng = None):
    """Replace every occurrence of ``value`` in ``df[column]`` with random draws.

    The draws come from a normal distribution whose mean and (sample)
    standard deviation are computed from the rows where
    ``df[column] != value``.

    Args:
        df: DataFrame to impute. It is copied, never modified in place.
        column: Name of the column to impute.
        value: Sentinel value to replace (e.g. 0 for a missing reading).
        rng: Optional ``numpy.random.Generator``. When omitted, a generator
            is seeded from SEED *and the column name*. Each column then gets
            its own reproducible stream, instead of every call reusing the
            identical sequence of draws.

    Returns:
        A new DataFrame with the sentinel values replaced. The column is
        converted to float only when at least one value is replaced.
    """
    if rng is None:
        rng = np.random.default_rng([SEED, *str(column).encode()])
    # Compute the mask once and reuse it for the statistics and the fill...
    mask = df[column] == value
    remaining = df.loc[~mask, column]
    mean_value = remaining.mean()
    std_value  = remaining.std()
    result = df.copy()
    if mask.any():
        # Cast first so the float draws do not trigger an int->float upcast
        # during assignment...
        result[column] = result[column].astype(float)
        result.loc[mask, column] = rng.normal(
            mean_value, std_value, size = int(mask.sum())
        )
    return result

# Replace the zero readings with draws from a normal distribution
# fitted to the non-zero readings of the same column...
df = set_column_value_to_normal_distribution(df, "RestingBP", 0)
df = set_column_value_to_normal_distribution(df, "Cholesterol", 0)

# Always test... no zero reading may survive the imputation.
assert not (df["RestingBP"] == 0).any(), "Ruh roh! One or more patients has crashed again!"
assert not (df["Cholesterol"] == 0).any(), "Ruh roh! One or more patients has crashed again!"


In [4]:
# Copyright 2022 by Contributors

import numpy as np
import pandas as pd
import plotly.io as pio
import plotly.graph_objects as go
from pandas.api.types import is_integer_dtype
from PDSUtilities.pandas import get_numerical_columns
from PDSUtilities.plotly import get_colors
from PDSUtilities.plotly import apply_default
from PDSUtilities.plotly import get_font
from PDSUtilities.plotly import update_title
from PDSUtilities.plotly import update_width_and_height

def get_line(df, target, colors):
    """Build the ``line`` dict for a plotly parallel-coordinates trace.

    Args:
        df: Source DataFrame.
        target: Optional column used to colour the lines. Object columns are
            converted to category codes first.
        colors: Sequence of colours. Without a target, only ``colors[0]`` is
            used.

    Returns:
        dict with ``color`` and ``showscale``, plus ``colorscale`` (one colour
        per distinct target value) when ``target`` is given.
    """
    line = dict(
        color = colors[0],
        showscale = False,
    )
    if target is not None:
        values = df[target]
        if values.dtypes == 'O':
            values = values.astype('category').cat.codes
        line['color'] = values
        # Cycle through the palette so targets with more levels than
        # colours do not raise IndexError...
        line['colorscale'] = [
            colors[index % len(colors)] for index in range(len(np.unique(values)))
        ]
    return line

def get_dimension(df, column, labels):
    """Build one plotly ``Parcoords`` dimension dict for ``df[column]``.

    Object columns are plotted as category codes, with one tick per category.
    Integer columns with at most 8 distinct values get one tick per value.

    Args:
        df: Source DataFrame.
        column: Column to describe.
        labels: Mapping of column name to axis label. Falls back to the
            column name.

    Returns:
        dict with ``values``, ``label`` and ``name`` (plus ``tickvals`` /
        ``ticktext`` when applicable).
    """
    series = df[column]
    dimension = dict(
        values = series,
        label = labels.get(column, column),
        name = column,
    )
    if series.dtypes == 'O':
        categories = series.astype('category').cat
        dimension['values'] = categories.codes
        # One tick per category. Deriving ticks from the observed codes would
        # include -1 for missing values, shifting every label by one.
        dimension['tickvals'] = np.arange(len(categories.categories))
        dimension['ticktext'] = categories.categories
    elif is_integer_dtype(series):
        unique_values = np.sort(series.unique())
        if len(unique_values) <= 8:
            dimension['tickvals'] = unique_values
            dimension['ticktext'] = unique_values
    return dimension

# TODO: #8 add template and misc args, comments and update README.md for plot_parallel functions...
def plot_parallel_coordinates(df, target = None, columns = None, labels = {},
    width = None, height = None, title = None, colors = 0,
    font = {}, tick_font = {}, axis_font = {}, title_font = {}):
    """Plot selected columns of ``df`` as a plotly parallel-coordinates chart.

    Args:
        df: Source DataFrame.
        target: Optional column used to colour the lines. It is added as the
            first axis when not already among ``columns``.
        columns: Columns to plot, resolved through ``get_numerical_columns``.
        labels: Mapping of column name to axis label.
        width, height: Optional figure size.
        title: Optional figure title.
        colors: Palette selector passed to ``get_colors``.
        font: Base font, combined with ``get_font()`` via ``apply_default``.
        tick_font, axis_font, title_font: Per-element font overrides, each
            combined with ``font`` via ``apply_default``.

    Returns:
        The plotly ``Figure``.
    """
    font = apply_default(get_font(), font)
    colors = get_colors(colors)
    columns = get_numerical_columns(df, columns)
    # Make sure the target is plotted (as the first axis)...
    if target is not None and target not in columns:
        columns = [target] + columns
    fig = go.Figure(go.Parcoords(
        dimensions = [get_dimension(df, column, labels) for column in columns],
        line = get_line(df, target, colors),
        labelfont = apply_default(font, axis_font),
        tickfont = apply_default(font, tick_font),
        # This eliminates the range! Set color to background!
        rangefont = { 'size': 1, 'color': "#FFFFFF" }
    ))
    fig = update_width_and_height(fig, width, height)
    fig = update_title(fig, title, title_font, font)
    # TODO: accept a ``template`` argument like the correlation plots do.
    fig.update_layout(font = font)
    return fig


'simple_white+DrJohnWagner'

In [7]:
# Parallel-coordinates view of every 4th row, lines coloured by chest-pain type...
fig = plot_parallel_coordinates(
    df.iloc[::4, :],
    target = "ChestPainType",
    title = "Heart Disease Dataset Numerical Columns",
    colors = 1,
    font = { 'size': 14 },
)
fig.show()


In [21]:
# Copyright 2022 by Contributors

import numpy as np
import pandas as pd
import plotly.io as pio
import plotly.graph_objects as go
from PDSUtilities.pandas import get_categorical_columns
from PDSUtilities.plotly import get_colors
from PDSUtilities.plotly import apply_default
from PDSUtilities.plotly import get_font
from PDSUtilities.plotly import update_title
from PDSUtilities.plotly import update_width_and_height

def get_line(df, target, colors):
    """Return the ``line`` settings for a plotly parallel-categories trace.

    Without a target the lines use ``colors[0]``. With a target, the lines
    are coloured by the target's values (category codes for object columns)
    using a palette that cycles through ``colors``.
    """
    default_color = colors[0]
    if target is None:
        return dict(color = default_color, showscale = False)
    series = df[target]
    codes = series.astype('category').cat.codes if series.dtypes == 'O' else series
    level_count = len(np.unique(codes))
    palette = [colors[i % len(colors)] for i in range(level_count)]
    return dict(color = codes, showscale = False, colorscale = palette)

def plot_parallel_categories(df, target = None, columns = None, labels = {},
    width = None, height = None, title = None, colors = 0,
    font = {}, tick_font = {}, axis_font = {}, title_font = {}):
    """Plot selected columns of ``df`` as a plotly parallel-categories chart.

    Args:
        df: Source DataFrame.
        target: Optional column used to colour the ribbons. It is always
            placed as the first dimension.
        columns: Columns to plot, resolved through ``get_categorical_columns``.
        labels: Mapping of column name to dimension label.
        width, height: Optional figure size.
        title: Optional figure title.
        colors: Palette selector passed to ``get_colors``.
        font: Base font, combined with ``get_font()`` via ``apply_default``.
        tick_font, axis_font, title_font: Per-element font overrides, each
            combined with ``font`` via ``apply_default``.

    Returns:
        The plotly ``Figure``.
    """
    font = apply_default(get_font(), font)
    colors = get_colors(colors)
    columns = get_categorical_columns(df, columns)
    # Move (or add) the target to the front by building a new list. Calling
    # .remove() here could mutate a list the caller passed in (whether
    # get_categorical_columns copies it is not visible here).
    if target is not None:
        columns = [target] + [column for column in columns if column != target]
    fig = go.Figure(go.Parcats(
        dimensions = [
            dict(
                values = df[column],
                label = labels.get(column, column),
                categoryorder = "category ascending",
            ) for column in columns
        ],
        line = get_line(df, target, colors),
        labelfont = apply_default(font, axis_font),
        tickfont = apply_default(font, tick_font),
    ))
    fig = update_width_and_height(fig, width, height)
    fig = update_title(fig, title, title_font, font)
    # TODO: accept a ``template`` argument like the correlation plots do.
    fig.update_layout(font = font)
    return fig

In [22]:
# Parallel-categories view of every 4th row, ribbons coloured by chest-pain type...
fig = plot_parallel_categories(
    df.iloc[::4, :],
    target = "ChestPainType",
    title = "Heart Disease Dataset Categorical Columns",
    colors = 1,
    font = { 'size': 12 },
)
fig.show()


In [50]:
# Copyright 2022 by Contributors

import numpy as np
import pandas as pd
import plotly.io as pio
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from PDSUtilities.pandas import get_numerical_columns
from PDSUtilities.plotly import apply_default
from PDSUtilities.plotly import get_font
from PDSUtilities.plotly import get_marker
from PDSUtilities.plotly import get_colors
from PDSUtilities.plotly import update_layout
from PDSUtilities.plotly import update_title
from PDSUtilities.plotly import update_width_and_height

def get_labels(columns, labels):
    """Normalise ``labels`` to a ``{column: label}`` mapping.

    A list is paired positionally with ``columns`` (the keys are the
    stringified column names) and must have the same length. Any other
    value is returned unchanged.
    """
    if not isinstance(labels, list):
        return labels
    assert len(columns) == len(labels), "Length of labels list must match length of columns list..."
    return {f"{column}": label for column, label in zip(columns, labels)}

def get_correlation_label(correlations, columns, labels, row, col, precision, align = "middle"):
    """Return an HTML ``<span>`` describing one correlation coefficient.

    The span holds three ``<br />``-separated lines: the column label, the
    row label and the bold coefficient ``correlations.iloc[row, col]``
    rounded to ``precision``. ``align`` picks where the coefficient goes:
    "top", "bottom", or (anything else) between the two labels.
    """
    column_text = labels.get(columns[col], columns[col])
    row_text = labels.get(columns[row], columns[row])
    value_text = f"<b>{np.round(correlations.iloc[row, col], precision)}</b>"
    if align == "top":
        pieces = (value_text, column_text, row_text)
    elif align == "bottom":
        pieces = (column_text, row_text, value_text)
    else:
        pieces = (column_text, value_text, row_text)
    return "<span>" + "<br />".join(pieces) + "</span>"

def get_center(values):
    """Return the midpoint of the range spanned by ``values``."""
    lowest = min(values)
    highest = max(values)
    return 0.5*(lowest + highest)

def plot_correlation_scatter(df, target = None, columns = None, labels = {},
    width = None, height = None, title = None, precision = 4,
    template = None, colors = 0, marker = {},
    font = {}, axis_font = {}, tick_font = {}, legend_font = {}, hover_font = {},
    label_font = {}, title_font = {}):
    """Plot a scatter matrix with correlation coefficients in the upper triangle.

    Below the diagonal, subplot (r, c) scatters ``cols[c]`` (x) against
    ``rows[r]`` (y). When ``target`` is given, there is one trace per target
    value. The mirrored subplot (c, r) shows the ``DataFrame.corr()``
    coefficient as centred text, and its axes are matched to the lower
    subplot's axes.

    Args:
        df: Source DataFrame.
        target: Optional column whose values split the points into coloured,
            legend-grouped traces.
        columns: Columns to plot, resolved by ``get_numerical_columns``.
        labels: Dict or list of labels (normalised by ``get_labels``).
        width, height, title, template: Figure-level options.
        precision: Decimal places for the printed coefficients.
        colors: Palette selector passed to ``get_colors``.
        marker: Marker overrides combined with ``get_marker()``.
        font, axis_font, tick_font, legend_font, hover_font, label_font,
        title_font: Font overrides.

    Returns:
        The plotly ``Figure``.

    Raises:
        ValueError: if no columns are selected.
    """
    font = apply_default(get_font(), font)
    axis_font = apply_default(font, axis_font)
    tick_font = apply_default(font, tick_font)
    label_font = apply_default(font, label_font)
    legend_font = apply_default(font, legend_font)
    #
    marker = apply_default(get_marker(), marker)
    #
    colors = get_colors(colors)
    columns = get_numerical_columns(df, columns, target)
    if len(columns) == 0:
        # make_subplots would otherwise fail with ZeroDivisionError below...
        raise ValueError("plot_correlation_scatter needs at least one column to plot...")
    labels = get_labels(columns, labels)
    rows = list(columns)
    cols = list(columns)
    correlations = df[columns].corr()
    # Sort the target values so colours and legend order do not depend on
    # row order, and pick colours by position (cycling through the palette).
    # Indexing by the raw value broke for string or out-of-range targets.
    # For 0..k-1 integer targets, position == value, so the colours are
    # unchanged.
    values = [] if target is None else sorted(df[target].dropna().unique())
    #
    fig = make_subplots(rows = len(rows), cols = len(cols),
        horizontal_spacing = 0.1/len(cols),
        vertical_spacing = 0.1/len(rows),
        shared_xaxes = True,
        shared_yaxes = True,
    )
    for r in range(len(rows)):
        for c in range(r):
            text = get_correlation_label(
                correlations, columns, labels, r, c, precision, "middle"
            )
            for index, value in enumerate(values):
                selection = df[target] == value
                fig.add_trace(
                    go.Scatter(
                        x = df[selection][cols[c]],
                        y = df[selection][rows[r]],
                        mode = 'markers',
                        marker = get_marker(marker, color = colors[index % len(colors)]),
                        name = labels.get(target, target) + " = " + str(value),
                        legendgroup = target + " = " + str(value),
                        hoverlabel = dict(font = hover_font),
                        # Only the first lower subplot contributes legend entries...
                        showlegend = r == 1 and c == 0,
                    ),
                    row = r + 1, col = c + 1,
                )
            if target is None:
                fig.add_trace(
                    go.Scatter(
                        x = df[cols[c]],
                        y = df[rows[r]],
                        mode = 'markers',
                        marker = get_marker(marker, color = colors[0]),
                        name = rows[r] + "/" + cols[c],
                        hoverlabel = dict(font = hover_font),
                        showlegend = False,
                    ),
                    row = r + 1, col = c + 1,
                )
            # Used to center correlation text in the mirrored (upper)
            # subplot, as a text-only scatter point...
            fig.add_trace(
                go.Scatter(
                    x = [get_center(df[cols[c]])],
                    y = [get_center(df[rows[r]])],
                    mode = 'text',
                    text = text,
                    textfont = label_font,
                    showlegend = False,
                    hoverinfo = "skip",
                ),
                row = c + 1, col = r + 1,
            )
    # Point axes in upper plots to the axes
    # in the corresponding lower plots...
    for r in range(len(rows)):
        for c in range(r):
            # (r, c) corresponds to who we are pointing at...
            x, y = (len(rows) - 1)*len(cols) + c, r*len(cols)
            # So (c, r) is who we are...
            fig.update_xaxes(matches = f"x{x+1}", row = c + 1, col = r + 1)
            fig.update_yaxes(matches = f"y{y+1}", row = c + 1, col = r + 1)
    # Axis titles only on the left column and the bottom row...
    for r in range(len(rows)):
        fig.update_yaxes(
            title_text = labels.get(rows[r], rows[r]), row = r + 1, col = 1
        )
    for c in range(len(cols)):
        fig.update_xaxes(
            title_text = labels.get(cols[c], cols[c]), row = len(rows), col = c + 1
        )
    for r in range(len(rows)):
        # No grid on the diagonal and the (text-only) upper triangle...
        for c in range(r, len(cols)):
            fig.update_xaxes(showgrid = False, row = r + 1, col = c + 1)
            fig.update_yaxes(showgrid = False, row = r + 1, col = c + 1)
        for c in range(1, len(cols)):
            fig.update_yaxes(ticks = "", row = r + 1, col = c + 1)
    for r in range(len(rows) - 1):
        for c in range(len(cols)):
            fig.update_xaxes(ticks = "", row = r + 1, col = c + 1)
    fig.update_xaxes(title_font= axis_font, tickfont = tick_font, linecolor = "black")
    fig.update_yaxes(title_font= axis_font, tickfont = tick_font, linecolor = "black")
    fig.update_xaxes(linewidth = 0.5, mirror = True, zeroline = False)
    fig.update_yaxes(linewidth = 0.5, mirror = True, zeroline = False)
    #
    fig.update_layout(legend_font = legend_font)
    if target is not None:
        fig.update_layout(legend_itemsizing = 'constant')
        fig.update_layout(legend = dict(
            orientation = 'h', yanchor = 'top', xanchor = 'center', y = 1.07, x = 0.5
        ))
    #
    fig = update_width_and_height(fig, width, height)
    fig = update_title(fig, title, title_font, font)
    fig = update_layout(fig, font = font, template = template)
    return fig

In [51]:
# Build a fresh list: aliasing numerical_columns and calling .remove() on
# it would silently drop FastingBS from numerical_columns as well...
columns = [column for column in numerical_columns if column != "FastingBS"]
fig = plot_correlation_scatter(df, target = "HeartDisease", columns = columns, labels = LABELS,
    title = "Heart Disease Dataset Correlations",
    font = {'size': 12}, width = 800, height = 800)
fig.update_yaxes(ticksuffix = " ")
fig.update_layout(template = "presentation")
fig.show()

In [52]:
# Copyright 2022 by Contributors

import numpy as np
import pandas as pd
import plotly.io as pio
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from PDSUtilities.pandas import get_numerical_columns
from PDSUtilities.plotly import apply_default
from PDSUtilities.plotly import get_font
from PDSUtilities.plotly import get_marker
from PDSUtilities.plotly import update_layout
from PDSUtilities.plotly import apply_default
from PDSUtilities.plotly import get_font
from PDSUtilities.plotly import get_marker
from PDSUtilities.plotly import update_layout
from PDSUtilities.plotly import hex_to_rgb
from PDSUtilities.plotly import rgb_to_hex
from PDSUtilities.plotly import get_colors
from PDSUtilities.plotly import update_title
from PDSUtilities.plotly import update_width_and_height

def get_labels(columns, labels):
    # Lists of labels are matched to columns by position (keys become the
    # stringified column names); dicts and other values pass through as-is.
    if isinstance(labels, list):
        assert len(columns) == len(labels), \
            "Length of labels list must match length of columns list..."
        mapping = {}
        for position, column in enumerate(columns):
            mapping[f"{column}"] = labels[position]
        labels = mapping
    return labels

def get_correlation_label(correlations, columns, labels, row, col, precision, align = "middle"):
    # Three lines of HTML: column label, row label and the bold coefficient
    # (rounded to ``precision``). ``align`` chooses where the coefficient sits.
    br = "<br />"
    col_label, row_label = (labels.get(columns[i], columns[i]) for i in (col, row))
    cor_label = f"<b>{np.round(correlations.iloc[row, col], precision)}</b>"
    if align == "top":
        first, second, third = cor_label, col_label, row_label
    elif align == "bottom":
        first, second, third = col_label, row_label, cor_label
    else:
        first, second, third = col_label, cor_label, row_label
    return "<span>" + first + br + second + br + third + "</span>"

def get_color(value, colors):
    """Interpolate a hex colour for a correlation ``value`` in [-1, 1].

    ``colors[1]`` is the colour at -1 and ``colors[0]`` the colour near +1.
    Both are 3-component RGB tuples. The divisor 2.001 keeps the
    interpolation fraction just below 1 at value +1.
    """
    (rlo, glo, blo), (rhi, ghi, bhi) = colors[1], colors[0]
    fraction = (value + 1.0)/2.001
    channels = tuple(
        low + int(np.round(fraction*(high - low)))
        for low, high in ((rlo, rhi), (glo, ghi), (blo, bhi))
    )
    return rgb_to_hex(channels)

def plot_correlation_matrix(df, columns = None, labels = {},
    width = None, height = None, title = None, precision = 4,
    template = None, colors = 0, xangle = 45, yangle = 45,
    font = {}, axis_font = {}, hover_font = {}, label_font = {},
    title_font = {}):
    """Plot a correlation matrix as coloured tiles plus mirrored text.

    Below the diagonal, subplot (r, c) holds one large square marker. Its
    colour is interpolated by ``get_color`` from the ``DataFrame.corr()``
    coefficient of ``rows[r]`` and ``cols[c]``. The mirrored subplot (c, r)
    shows the same coefficient as text built by ``get_correlation_label``.

    Args:
        df: Source DataFrame.
        columns: Columns to correlate, resolved by ``get_numerical_columns``.
        labels: Dict or list of labels (normalised by ``get_labels``).
        width, height, title, template: Figure-level options.
        precision: Decimal places for the printed coefficients.
        colors: Palette selector passed to ``get_colors``. The first two
            colours are the interpolation end points.
        xangle, yangle: Rotation in degrees of the x / y axis labels. The
            y labels are nudged upwards when ``yangle == 45``.
        font, axis_font, hover_font, label_font, title_font: Font overrides.

    Returns:
        The plotly ``Figure``.

    Raises:
        ValueError: if no columns are selected.
    """
    font = apply_default(get_font(), font)
    axis_font = apply_default(font, axis_font)
    label_font = apply_default(font, label_font)
    #
    colors = get_colors(colors)
    colors = [hex_to_rgb(color) for color in colors]
    columns = get_numerical_columns(df, columns)
    if len(columns) == 0:
        # make_subplots would otherwise fail with ZeroDivisionError below...
        raise ValueError("plot_correlation_matrix needs at least one column to plot...")
    labels = get_labels(columns, labels)
    rows = list(columns)
    cols = list(columns)
    correlations = df[columns].corr()
    #
    fig = make_subplots(rows = len(rows), cols = len(cols),
        horizontal_spacing = 0.1/len(cols),
        vertical_spacing = 0.1/len(rows),
        shared_xaxes = True,
        shared_yaxes = True,
    )
    for r in range(len(rows)):
        for c in range(r):
            value = correlations.loc[rows[r], cols[c]]
            text = get_correlation_label(
                correlations, columns, labels, r, c, precision, "middle"
            )
            # Colour tile: an oversized square marker filling the subplot...
            fig.add_trace(
                go.Scatter(
                    x = [0], y = [0],
                    mode = 'markers',
                    marker = dict(
                        symbol = "square",
                        size = 1000,
                        color = get_color(value, colors),
                    ),
                    hoverlabel = dict(font = hover_font),
                    hovertemplate = f"{rows[r]}<br>{cols[c]}",
                    name = str(np.round(value, precision)),
                    showlegend = False,
                ),
                row = r + 1, col = c + 1,
            )
            # Used to center correlation text in the mirrored (upper)
            # subplot, as a text-only scatter point...
            fig.add_trace(
                go.Scatter(
                    x = [0.0],
                    y = [0.0],
                    mode = 'text',
                    text = text,
                    textfont = label_font,
                    showlegend = False,
                    hoverinfo = "skip",
                ),
                row = c + 1, col = r + 1,
            )
    # Fix every axis to [-1, 1] so the tile and the text placed at (0, 0)
    # sit in the middle of each subplot...
    fig.update_xaxes(range = [-1.0, 1.0])
    fig.update_yaxes(range = [-1.0, 1.0])
    fig.update_xaxes(showgrid = False, ticks = "",  mirror = True)
    fig.update_yaxes(showgrid = False, ticks = "",  mirror = True)
    fig.update_xaxes(linecolor = "black", linewidth = 0.5, zeroline = False)
    fig.update_yaxes(linecolor = "black", linewidth = 0.5, zeroline = False)
    # Hide the default tick labels (tiny and white). The bottom row and left
    # column get a single labelled tick below...
    fig.update_xaxes(tickfont_size = 1, tickfont_color = "#FFFFFF")
    fig.update_yaxes(tickfont_size = 1, tickfont_color = "#FFFFFF")
    # We use a single tick label as the axis label...
    for c in range(len(cols)):
        fig.update_xaxes(
            tickmode = "array",
            tickvals = [0.0],
            tickfont = axis_font,
            ticktext = [labels.get(cols[c], cols[c])],
            tickangle = -xangle,
            row = len(rows), col = c + 1,
        )
    for r in range(len(rows)):
        fig.update_yaxes(
            tickmode = "array",
            tickvals = [0.2 if yangle == 45 else 0.0],
            tickfont = axis_font,
            ticktext = [
                labels.get(rows[r], rows[r]) + "  " if yangle == 45 else " "
            ],
            tickangle = -yangle,
            row = r + 1, col = 1,
        )
    #
    fig = update_width_and_height(fig, width, height)
    fig = update_title(fig, title, title_font, font)
    fig = update_layout(fig, font = font, template = template)
    return fig

In [53]:
# Build a fresh list: aliasing numerical_columns and calling .remove() on
# it would silently drop FastingBS from numerical_columns as well...
columns = [column for column in numerical_columns if column != "FastingBS"]
fig = plot_correlation_matrix(df, columns = columns, labels = LABELS,
    title = "Heart Disease Dataset Correlations",
    font = {'size': 11}, hover_font = {'size': 14}, width = 600, height = 600)
fig.update_yaxes(ticksuffix = " ")
fig.update_layout(template = "presentation")
fig.show()

In [54]:
# Copyright 2022 by Contributors

import numpy as np
import pandas as pd
import plotly.io as pio
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from PDSUtilities.plotly import apply_default
from PDSUtilities.plotly import get_font
from PDSUtilities.plotly import get_marker
from PDSUtilities.plotly import update_layout
from PDSUtilities.plotly import hex_to_rgb
from PDSUtilities.plotly import rgb_to_hex
from PDSUtilities.plotly import get_colors
from PDSUtilities.plotly import update_title
from PDSUtilities.plotly import update_width_and_height
from PDSUtilities.pandas import get_numerical_columns

def get_labels(columns, labels):
    """Turn a positional list of labels into a ``{column: label}`` dict.

    Non-list ``labels`` (e.g. an existing dict) are returned untouched.
    Keys are the stringified column names.
    """
    if isinstance(labels, list):
        message = "Length of labels list must match length of columns list..."
        assert len(columns) == len(labels), message
        keys = [f"{column}" for column in columns]
        return dict(zip(keys, labels))
    return labels

def get_color(value, colors):
    # Blend each RGB channel from colors[1] (value -1) towards colors[0]
    # (value +1). Dividing by 2.001 keeps the fraction just under 1.
    rlo, glo, blo = colors[1]
    rhi, ghi, bhi = colors[0]
    fraction = (value + 1.0)/2.001

    def blend(low, high):
        return low + int(np.round(fraction*(high - low)))

    return rgb_to_hex((blend(rlo, rhi), blend(glo, ghi), blend(blo, bhi)))

def plot_correlation_triangle(df, columns = None, labels = {},
    width = None, height = None, title = None, precision = 4,
    template = None, colors = 0, xangle = 45, yangle = 45,
    font = {}, axis_font = {}, hover_font = {}, label_font = {},
    title_font = {}):
    """Plot the lower triangle of a correlation matrix as coloured tiles.

    The grid has one row per column except the first, and one column per
    column except the last, so only the pairs below the diagonal are drawn.
    Each tile is coloured by ``get_color`` from the ``DataFrame.corr()``
    coefficient and labelled with it, rounded to ``precision``.

    Args:
        df: Source DataFrame.
        columns: Columns to correlate, resolved by ``get_numerical_columns``.
        labels: Dict or list of labels (normalised by ``get_labels``).
        width, height, title, template: Figure-level options.
        precision: Decimal places for the printed coefficients.
        colors: Palette selector passed to ``get_colors``. The first two
            colours are the interpolation end points.
        xangle, yangle: Rotation in degrees of the x / y axis labels. The
            y labels are nudged upwards when ``yangle == 45``.
        font, axis_font, hover_font, label_font, title_font: Font overrides.

    Returns:
        The plotly ``Figure``.

    Raises:
        ValueError: if fewer than two columns are selected (there are no
            pairs to draw).
    """
    font = apply_default(get_font(), font)
    axis_font = apply_default(font, axis_font)
    label_font = apply_default(font, label_font)
    # NOTE(review): intended to make the tile text white; confirm the
    # argument order (defaults vs. overrides) of apply_default.
    label_font = apply_default(label_font, { 'color': "#FFFFFF" })
    #
    colors = get_colors(colors)
    colors = [hex_to_rgb(color) for color in colors]
    columns = get_numerical_columns(df, columns)
    if len(columns) < 2:
        # make_subplots would otherwise fail with ZeroDivisionError below...
        raise ValueError("plot_correlation_triangle needs at least two columns to plot...")
    labels = get_labels(columns, labels)
    rows = [columns[c] for c in range(1, len(columns)    )]
    cols = [columns[c] for c in range(0, len(columns) - 1)]
    correlations = df[columns].corr()
    #
    fig = make_subplots(rows = len(rows), cols = len(cols),
        horizontal_spacing = 0.1/len(cols),
        vertical_spacing = 0.1/len(rows),
        shared_xaxes = True,
        shared_yaxes = True,
    )
    for r in range(len(rows)):
        for c in range(r + 1):
            value = correlations.loc[rows[r], cols[c]]
            fig.add_trace(
                go.Scatter(
                    x = [0.0], y = [0.0],
                    mode = 'markers+text',
                    # Make a large square to fill the plot area
                    # since we can't set the background color...
                    marker = dict(
                        symbol = "square",
                        size = 1000,
                        color = get_color(value, colors),
                    ),
                    hoverlabel = dict(font = hover_font),
                    hovertemplate = f"{rows[r]}<br>{cols[c]}",
                    text = str(np.round(value, precision)),
                    name = str(np.round(value, precision)),
                    textfont = label_font,
                    showlegend = False,
                ),
                row = r + 1, col = c + 1,
            )
    # Fix every axis to [-1, 1] so the tile and text at (0, 0) are centred...
    fig.update_xaxes(range = [-1.0, 1.0])
    fig.update_yaxes(range = [-1.0, 1.0])
    fig.update_xaxes(showgrid = False, ticks = "",  mirror = True)
    fig.update_yaxes(showgrid = False, ticks = "",  mirror = True)
    fig.update_xaxes(linecolor = "black", linewidth = 0.5, zeroline = False)
    fig.update_yaxes(linecolor = "black", linewidth = 0.5, zeroline = False)
    # We use a single tick label as the axis label...
    for c in range(len(cols)):
        fig.update_xaxes(
            tickmode = "array",
            tickvals = [0.0],
            tickfont = axis_font,
            ticktext = [labels.get(cols[c], cols[c])],
            tickangle = -xangle,
            row = len(rows), col = c + 1,
        )
    for r in range(len(rows)):
        fig.update_yaxes(
            tickmode = "array",
            tickvals = [0.2 if yangle == 45 else 0.0],
            tickfont = axis_font,
            ticktext = [
                labels.get(rows[r], rows[r]) + "  " if yangle == 45 else " "
            ],
            tickangle = -yangle,
            row = r + 1, col = 1,
        )
    #
    fig = update_width_and_height(fig, width, height)
    fig = update_title(fig, title, title_font, font)
    fig = update_layout(fig, font = font, template = template)
    return fig

In [55]:
# Build a fresh list: aliasing numerical_columns and calling .remove() on
# it would silently drop FastingBS from numerical_columns as well...
columns = [column for column in numerical_columns if column != "FastingBS"]
fig = plot_correlation_triangle(df, columns = columns, labels = LABELS,
    title = "Heart Disease Dataset Correlations", precision = 2,
    font = {'size': 12}, hover_font = {'size': 14}, width = 400, height = 400)
fig.update_layout(template = "presentation")
fig.show()