# Heart Disease Detection using Bayesian (Belief) Network

In [1]:
# imports required packages

import pandas as pd
from pgmpy.models import BayesianNetwork
from pgmpy.estimators import MaximumLikelihoodEstimator
from pgmpy.inference import VariableElimination

## The Data Set

**The Data Set:** Cleveland (open source) heart disease data set (available at https://archive.ics.uci.edu/dataset/45/heart+disease)

**Attribute Type:** Categorical, Integer, Real

**Instances:** 303

**Attributes:** 13 (plus the target field **_heartdisease_**, i.e. 14 columns in total)

The **_heartdisease_** field refers to the presence of heart disease in the patient.  It is integer valued from 0 (no presence) to 4.

Attribute documentation:

- age: age in years
- sex: sex (1 = male; 0 = female)
- cp: chest pain type
    - Value 1: typical angina
    - Value 2: atypical angina
    - Value 3: non-anginal pain
    - Value 4: asymptomatic
- trestbps: resting blood pressure (in mm Hg on admission to the hospital)
- chol: serum cholesterol in mg/dl
- fbs: (fasting blood sugar > 120 mg/dl)  (1 = true; 0 = false)
- restecg: resting electrocardiographic results
    - Value 0: normal
    - Value 1: having ST-T wave abnormality (T wave inversions and/or ST elevation or depression of > 0.05 mV)
    - Value 2: showing probable or definite left ventricular hypertrophy by Estes' criteria
- thalach: maximum heart rate achieved
- exang: exercise induced angina (1 = yes; 0 = no)
- oldpeak: ST depression induced by exercise relative to rest
- slope: the slope of the peak exercise ST segment
    - Value 1: upsloping
    - Value 2: flat
    - Value 3: downsloping
- ca: number of major vessels (0-3) colored by fluoroscopy
- thal: 3 = normal; 6 = fixed defect; 7 = reversible defect
- heartdisease: diagnosis of heart disease (angiographic disease status)
    - Value 0: < 50% diameter narrowing
    - Value 1 through 4: > 50% diameter narrowing

In [2]:
# Reads the column names (a single comma-separated line) from the attributes file
# and prints them
with open("./../Data/heart_disease/processed.cleveland.attributes", "r", newline=None) as attributes_file:
    header_line = attributes_file.readline()

# readline() keeps the trailing newline, so it is stripped before splitting
attributes = header_line.rstrip('\n').split(',')

print(attributes)

['age', 'sex', 'cp', 'trestbps', 'chol', 'fbs', 'restecg', 'thalach', 'exang', 'oldpeak', 'slope', 'ca', 'thal', 'heartdisease']


In [3]:
# Loads the data set and displays it
# NOTE: The data file has no header row, so the column names are supplied
# explicitly through the 'names' parameter
data = pd.read_csv(
    "./../Data/heart_disease/processed.cleveland.data",
    names=attributes,
    header=None,
)

display(data)

Unnamed: 0,age,sex,cp,trestbps,chol,fbs,restecg,thalach,exang,oldpeak,slope,ca,thal,heartdisease
0,63.0,1.0,1.0,145.0,233.0,1.0,2.0,150.0,0.0,2.3,3.0,0.0,6.0,0
1,67.0,1.0,4.0,160.0,286.0,0.0,2.0,108.0,1.0,1.5,2.0,3.0,3.0,2
2,67.0,1.0,4.0,120.0,229.0,0.0,2.0,129.0,1.0,2.6,2.0,2.0,7.0,1
3,37.0,1.0,3.0,130.0,250.0,0.0,0.0,187.0,0.0,3.5,3.0,0.0,3.0,0
4,41.0,0.0,2.0,130.0,204.0,0.0,2.0,172.0,0.0,1.4,1.0,0.0,3.0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
298,45.0,1.0,1.0,110.0,264.0,0.0,0.0,132.0,0.0,1.2,2.0,0.0,7.0,1
299,68.0,1.0,4.0,144.0,193.0,1.0,0.0,141.0,0.0,3.4,2.0,2.0,7.0,2
300,57.0,1.0,4.0,130.0,131.0,0.0,0.0,115.0,1.0,1.2,2.0,1.0,7.0,3
301,57.0,0.0,2.0,130.0,236.0,0.0,2.0,174.0,0.0,0.0,2.0,1.0,3.0,1


## Exploratory Data Analysis (EDA)

In [4]:
# Prints a concise summary of the data set (columns, non-null counts and dtypes).
# NOTE: DataFrame.info() writes the summary itself and returns None, so it is called
# directly; wrapping it in display() would render a stray 'None' after the summary.
data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 303 entries, 0 to 302
Data columns (total 14 columns):
 #   Column        Non-Null Count  Dtype  
---  ------        --------------  -----  
 0   age           303 non-null    float64
 1   sex           303 non-null    float64
 2   cp            303 non-null    float64
 3   trestbps      303 non-null    float64
 4   chol          303 non-null    float64
 5   fbs           303 non-null    float64
 6   restecg       303 non-null    float64
 7   thalach       303 non-null    float64
 8   exang         303 non-null    float64
 9   oldpeak       303 non-null    float64
 10  slope         303 non-null    float64
 11  ca            303 non-null    object 
 12  thal          303 non-null    object 
 13  heartdisease  303 non-null    int64  
dtypes: float64(11), int64(1), object(2)
memory usage: 33.3+ KB


None

In [5]:
# Investigates why the 'ca' and 'thal' columns were loaded as 'object' instead of numeric:
# prints how often each distinct value occurs in both columns
print(data['ca'].value_counts(), data['thal'].value_counts(), sep='\n')

0.0    176
1.0     65
2.0     38
3.0     20
?        4
Name: ca, dtype: int64
3.0    166
7.0    117
6.0     18
?        2
Name: thal, dtype: int64


In [6]:
# Removes the rows that have the placeholder '?' in the 'ca' or 'thal' column

data = data[(data['ca'] != '?') & (data['thal'] != '?')]

In [7]:
# Lets pandas pick the best nullable dtype for every column (DataFrame.convert_dtypes).
# Whole-valued float columns become 'Int64' and 'oldpeak' stays 'Float64'; 'ca' and
# 'thal' were read as text (because of the '?' placeholders), so they become 'string'
# here and are converted to integers in a separate step.
data = data.convert_dtypes()

In [8]:
# Prints the dtype of every column to check the result of the conversion
print(data.dtypes)

age               Int64
sex               Int64
cp                Int64
trestbps          Int64
chol              Int64
fbs               Int64
restecg           Int64
thalach           Int64
exang             Int64
oldpeak         Float64
slope             Int64
ca               string
thal             string
heartdisease      Int64
dtype: object


In [9]:
# Converts the 'ca' and 'thal' columns from 'string' to 'int'.
# The values are stored as text such as '3.0', so they are parsed as numbers first
# and only then cast to integers.

data = data.assign(
    ca=pd.to_numeric(data['ca']).astype(int),
    thal=pd.to_numeric(data['thal']).astype(int),
)

In [10]:
# Finally, prints the dtype of every column to confirm that 'ca' and 'thal' are now integers
print(data.dtypes)

age               Int64
sex               Int64
cp                Int64
trestbps          Int64
chol              Int64
fbs               Int64
restecg           Int64
thalach           Int64
exang             Int64
oldpeak         Float64
slope             Int64
ca                int64
thal              int64
heartdisease      Int64
dtype: object


In [11]:
# Finally, displays the cleaned data set.
# NOTE: the row labels are not reset after the rows with '?' were filtered out,
# so the index keeps the labels of the original data set.
display(data)

Unnamed: 0,age,sex,cp,trestbps,chol,fbs,restecg,thalach,exang,oldpeak,slope,ca,thal,heartdisease
0,63,1,1,145,233,1,2,150,0,2.3,3,0,6,0
1,67,1,4,160,286,0,2,108,1,1.5,2,3,3,2
2,67,1,4,120,229,0,2,129,1,2.6,2,2,7,1
3,37,1,3,130,250,0,0,187,0,3.5,3,0,3,0
4,41,0,2,130,204,0,2,172,0,1.4,1,0,3,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
297,57,0,4,140,241,0,0,123,1,0.2,2,0,7,1
298,45,1,1,110,264,0,0,132,0,1.2,2,0,7,1
299,68,1,4,144,193,1,0,141,0,3.4,2,2,7,2
300,57,1,4,130,131,0,0,115,1,1.2,2,1,7,3


## Modeling

### Fitting Model

In [12]:
# Initializes the Bayesian network structure (directed edges: parent -> child):
#   - 'age', 'sex', 'cp' and 'exang' are parents of 'heartdisease'
#   - 'heartdisease' is the parent of 'restecg' and 'chol'
model = BayesianNetwork(
    [(parent, 'heartdisease') for parent in ('age', 'sex', 'cp', 'exang')]
    + [('heartdisease', child) for child in ('restecg', 'chol')]
)

In [13]:
# Fits the model: estimates the Conditional Probability Distribution (CPD)
# at each node from the data using the MaximumLikelihoodEstimator
# NOTE(review): 'age' and 'chol' enter the network with their raw (un-binned) values,
# so presumably many parent-state combinations of 'heartdisease' are never observed
# in the ~300 rows; consider discretising them before fitting -- TODO confirm.

model.fit(data, estimator=MaximumLikelihoodEstimator)

### Inferencing from Model

In [14]:
# Instantiates a VariableElimination inference engine over the fitted model;
# it is used by the queries that follow
inference = VariableElimination(model)

In [15]:
# Queries the distribution of 'heartdisease' given the evidence restecg = 1
# (resting ECG with an ST-T wave abnormality)
print(
    inference.query(
        variables=['heartdisease'],
        evidence={'restecg': 1},
    )
)

+-----------------+---------------------+
| heartdisease    |   phi(heartdisease) |
| heartdisease(0) |              0.1012 |
+-----------------+---------------------+
| heartdisease(1) |              0.0000 |
+-----------------+---------------------+
| heartdisease(2) |              0.2392 |
+-----------------+---------------------+
| heartdisease(3) |              0.2015 |
+-----------------+---------------------+
| heartdisease(4) |              0.4581 |
+-----------------+---------------------+


In [16]:
# Queries the distribution of 'heartdisease' given the evidence cp = 2
# (chest pain type: atypical angina)
print(
    inference.query(
        variables=['heartdisease'],
        evidence={'cp': 2},
    )
)

+-----------------+---------------------+
| heartdisease    |   phi(heartdisease) |
| heartdisease(0) |              0.3610 |
+-----------------+---------------------+
| heartdisease(1) |              0.2159 |
+-----------------+---------------------+
| heartdisease(2) |              0.1373 |
+-----------------+---------------------+
| heartdisease(3) |              0.1537 |
+-----------------+---------------------+
| heartdisease(4) |              0.1321 |
+-----------------+---------------------+
