In [29]:
import numpy as np
import pandas as pd
import plotly.express as px
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from itertools import combinations

# Load the heart dataset from the repository-relative CSV; the bare name on
# the last line makes the notebook render the resulting frame.
heart_csv_path = 'data/heart.csv'
df = pd.read_csv(heart_csv_path)
df

Unnamed: 0,age,sex,cp,trestbps,chol,fbs,restecg,thalach,exang,oldpeak,slope,ca,thal,target
0,52,1,0,125,212,0,1,168,0,1.0,2,2,3,0
1,53,1,0,140,203,1,0,155,1,3.1,0,0,3,0
2,70,1,0,145,174,0,1,125,1,2.6,0,0,3,0
3,61,1,0,148,203,0,1,161,0,0.0,2,1,3,0
4,62,0,0,138,294,1,1,106,0,1.9,1,3,2,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1020,59,1,1,140,221,0,1,164,1,0.0,2,0,2,1
1021,60,1,0,125,258,0,0,141,1,2.8,1,1,3,0
1022,47,1,0,110,275,0,0,118,1,1.0,1,1,2,0
1023,50,0,0,110,254,0,0,159,0,0.0,2,0,2,1


In [30]:
# Value range of every column: column name -> (minimum, maximum)
min_max_values = {
    name: (df[name].min(), df[name].max())
    for name in df.columns
}
min_max_values

{'age': (29, 77),
 'sex': (0, 1),
 'cp': (0, 3),
 'trestbps': (94, 200),
 'chol': (126, 564),
 'fbs': (0, 1),
 'restecg': (0, 2),
 'thalach': (71, 202),
 'exang': (0, 1),
 'oldpeak': (0.0, 6.2),
 'slope': (0, 2),
 'ca': (0, 4),
 'thal': (0, 3),
 'target': (0, 1)}

In [31]:
# preparation of data for sankey diagram

def _bin_labels(bins):
    """Build one 'lo-hi' label per bin from consecutive integer bin edges.

    Each bin is half-open, [lo, next_lo), so the displayed upper bound is
    next_lo - 1 (e.g. edges 25, 35 -> '25-34'). Deriving the labels from the
    edges keeps the two from drifting apart when the edges are edited.
    """
    return [f'{lo}-{hi - 1}' for lo, hi in zip(bins[:-1], bins[1:])]


def _bin_feature(frame, column, bins, labels):
    """Return frame[column] cut into the labelled, left-closed bins.

    pd.cut turns values outside [bins[0], bins[-1]) into NaN without any
    warning. The check below fails loudly if binning introduces NaN that were
    not already present in the source column.
    """
    binned = pd.cut(frame[column], bins=bins, labels=labels, right=False)
    assert binned.isna().sum() == frame[column].isna().sum(), (
        f'{column}: some values fall outside [{bins[0]}, {bins[-1]}) and would become NaN'
    )
    return binned


# put age into bins 25-34, 35-44, 45-54, 55-64, 65-74, 75-84
age_bins = [25, 35, 45, 55, 65, 75, 85]
age_labels = _bin_labels(age_bins)
df['age_group'] = _bin_feature(df, 'age', age_bins, age_labels)
age_counts = df['age_group'].value_counts().sort_index()

# put trestbps into bins 90-109, 110-129, 130-149, 150-169, 170-189, 190-209
trestbps_bins = [90, 110, 130, 150, 170, 190, 210]
trestbps_labels = _bin_labels(trestbps_bins)
df['trestbps_group'] = _bin_feature(df, 'trestbps', trestbps_bins, trestbps_labels)
trestbps_counts = df['trestbps_group'].value_counts().sort_index()

# put chol into bins 100-199, 200-299, 300-399, 400-499, 500-599
chol_bins = [100, 200, 300, 400, 500, 600]
chol_labels = _bin_labels(chol_bins)
df['chol_group'] = _bin_feature(df, 'chol', chol_bins, chol_labels)
chol_counts = df['chol_group'].value_counts().sort_index()

# put thalach into bins 70-99, 100-129, 130-159, 160-189, 190-219
thalach_bins = [70, 100, 130, 160, 190, 220]
thalach_labels = _bin_labels(thalach_bins)
df['thalach_group'] = _bin_feature(df, 'thalach', thalach_bins, thalach_labels)
thalach_counts = df['thalach_group'].value_counts().sort_index()

df.head(5)

Unnamed: 0,age,sex,cp,trestbps,chol,fbs,restecg,thalach,exang,oldpeak,slope,ca,thal,target,age_group,trestbps_group,chol_group,thalach_group
0,52,1,0,125,212,0,1,168,0,1.0,2,2,3,0,45-54,110-129,200-299,160-189
1,53,1,0,140,203,1,0,155,1,3.1,0,0,3,0,45-54,130-149,200-299,130-159
2,70,1,0,145,174,0,1,125,1,2.6,0,0,3,0,65-74,130-149,100-199,100-129
3,61,1,0,148,203,0,1,161,0,0.0,2,1,3,0,55-64,130-149,200-299,160-189
4,62,0,0,138,294,1,1,106,0,1.9,1,3,2,0,55-64,130-149,200-299,100-129


In [32]:
# plotly graph objects to create sankey diagram

# Node colour for each variable group
sex_color = '#636EFA'
age_color = '#EF553B'
trestbps_color = '#00CC96'
chol_color = '#AB63FA'
thalach_color = '#FFA15A'
target_color = '#19D3F3'

# One entry per Sankey column, left to right:
# (label prefix, DataFrame column, values matched in that column, node colour).
# Node labels, node colours and links are all derived from this single table,
# so the three cannot get out of sync when a stage is added or reordered.
sankey_stages = [
    ('Sex', 'sex', [0, 1], sex_color),
    ('Age', 'age_group', age_labels, age_color),
    ('Trestbps', 'trestbps_group', trestbps_labels, trestbps_color),
    ('Chol', 'chol_group', chol_labels, chol_color),
    ('Thalach', 'thalach_group', thalach_labels, thalach_color),
    ('Target', 'target', [0, 1], target_color),
]

# Node labels look like 'Age: 45-54'; a node's index in `labels` is its id in the diagram
labels = [
    f'{prefix}: {stage_value}'
    for prefix, _, stage_values, _ in sankey_stages
    for stage_value in stage_values
]
label_indices = {label: i for i, label in enumerate(labels)}

# One colour per node, in the same order as `labels`
node_colors = [
    stage_color
    for _, _, stage_values, stage_color in sankey_stages
    for _ in stage_values
]

# Links are split by the `target` column: green for target == 0, red for target == 1
target_link_colors = {0: 'rgba(0, 255, 0, 0.4)', 1: 'rgba(255, 0, 0, 0.4)'}
# Computed once instead of re-comparing the whole column for every link
target_masks = {target_val: df['target'] == target_val for target_val in target_link_colors}

sources = []
targets = []
values = []
link_colors = []

# Walk every pair of adjacent stages (Sex->Age, Age->Trestbps, ..., Thalach->Target)
# and emit one link per (source value, destination value, target value) that has rows.
# For the final pair the destination column is `target` itself, so only the
# combinations with dst_value == target_val are non-empty: one link per
# (thalach bin, target value).
for (src_prefix, src_col, src_values, _), (dst_prefix, dst_col, dst_values, _) in zip(sankey_stages, sankey_stages[1:]):
    for src_value in src_values:
        src_mask = df[src_col] == src_value
        for dst_value in dst_values:
            pair_mask = src_mask & (df[dst_col] == dst_value)
            for target_val, link_color in target_link_colors.items():
                count = int((pair_mask & target_masks[target_val]).sum())
                if count > 0:
                    sources.append(label_indices[f'{src_prefix}: {src_value}'])
                    targets.append(label_indices[f'{dst_prefix}: {dst_value}'])
                    values.append(count)
                    link_colors.append(link_color)

sankey_data = go.Sankey(
    node=dict(
        pad=15,
        thickness=20,
        line=dict(color="black", width=0.5),
        label=labels,
        color=node_colors
    ),
    link=dict(
        source=sources,
        target=targets,
        value=values,
        color=link_colors
    )
)
fig = go.Figure(data=[sankey_data])
fig.update_layout(title_text="Sankey Diagram of Heart Disease Features", font_size=10)
fig.show()

# Also save a static copy of the figure next to the notebook
fig.write_image("sankey_heart_disease.png", scale=5, width=1200, height=600)