## Heart Disease Prediction: Multi-class vs Binary Classification Comparison

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical
from keras.models import Sequential
from keras.layers import Dense, Input
from tensorflow.keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint
from sklearn.metrics import classification_report, accuracy_score
import sys
import random
import tensorflow as tf

# Report the interpreter and library versions this run is using.
for _name, _version in (
    ('python', sys.version),
    ('numpy', np.__version__),
    ('pandas', pd.__version__),
    ('tensorflow', tf.__version__),
):
    print(f'{_name} version: {_version}')

# Seed the Python, NumPy and TensorFlow random generators (in that order)
# with the same value to make results more consistent between runs.
for _seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
    _seed_fn(19)

## Data Loading and Preprocessing

In [None]:
# Load the heart disease dataset from the UCI repository.
# `names=` is given without `header=`, so pandas treats the first line of
# the file as data and labels the 14 columns with the names below.
column_names = [
    'age', 'sex', 'cp', 'trestbps', 'chol', 'fbs', 'restecg',
    'thalach', 'exang', 'oldpeak', 'slope', 'ca', 'thal', 'class',
]
df = pd.read_csv(
    'https://archive.ics.uci.edu/ml/machine-learning-databases/heart-disease/processed.cleveland.data',
    names=column_names,
)

# Quick look at what was parsed: per-column dtypes and (rows, columns).
print("Original dataset info:")
print(df.dtypes)
print("\nDataset shape:", df.shape)

In [None]:
# The UCI file marks missing entries with the literal string '?', which
# pandas reads as ordinary text rather than NaN. Convert those markers to
# NaN first; otherwise dropna() removes nothing and pd.to_numeric below
# raises ValueError as soon as it meets a '?'.
df = df.replace('?', np.nan)

# Drop rows with NaN values from DataFrame
df = df.dropna(axis=0)

# Transform data to numeric because ca and thal are object datatypes
# (they held the '?' markers, so pandas parsed the whole column as strings).
data = df.apply(pd.to_numeric)
print("After preprocessing:")
print(data.dtypes)
print("\nDataset shape:", data.shape)

In [None]:
# Summarise how the samples are spread over the original target labels.
target = data['class']
class_values = target.unique()

print("Original Multi-class Distribution:")
print(target.value_counts().sort_index())
print("\nUnique classes:", sorted(class_values))
print("Number of classes:", len(class_values))

In [None]:
# Features are the first 13 columns; the last column is the target.
X = data.iloc[:, :13]
y_multi = data.iloc[:, -1]  # multi-class target, kept as-is

# Binary target: 0 stays 0 (no heart disease); any positive label
# (1, 2, 3, 4) collapses to 1 (heart disease present).
y_binary = y_multi.gt(0).astype(int)

print("Binary Classification Distribution:")
print(y_binary.value_counts().sort_index())
print(
    "\nBinary classes mapping:",
    "0: No heart disease",
    "1: Heart disease (any level)",
    sep="\n",
)

## Multi-class Classification Model

In [None]:
# Split data for multi-class classification: 80% train / 20% test.
# shuffle=False makes this a plain ordered cut (the test rows are the tail of
# the frame), so random_state has no influence on the result.
multi_split = train_test_split(X, y_multi, test_size=0.2, shuffle=False, random_state=90)
X_train_multi, X_test_multi, y_train_multi, y_test_multi = multi_split

# One-hot encode the integer labels into 5 columns (classes 0, 1, 2, 3, 4).
y_train_multi_cat, y_test_multi_cat = (
    to_categorical(labels, num_classes=5)
    for labels in (y_train_multi, y_test_multi)
)

print("Multi-class training data shapes:")
for _name, _part in (
    ("X_train", X_train_multi),
    ("y_train", y_train_multi_cat),
    ("X_test", X_test_multi),
    ("y_test", y_test_multi_cat),
):
    print(f"{_name}: {_part.shape}")

In [None]:
# Multi-class network: 13 inputs -> 10 -> 8 -> 4 (ReLU) -> 5-way softmax.
# Layers are listed in forward order and handed to Sequential in one go.
model_multi = Sequential([
    Input(shape=(13,)),
    Dense(10, kernel_initializer='normal', activation='relu'),
    Dense(8, kernel_initializer='normal', activation='relu'),
    Dense(4, kernel_initializer='normal', activation='relu'),
    Dense(5, activation='softmax'),  # one probability per class
])

print("Multi-class Model Architecture:")
model_multi.summary()

In [None]:
# Compile multi-class model: Adam (learning rate 0.001) on categorical
# cross-entropy, reporting accuracy.
model_multi.compile(
    loss='categorical_crossentropy',
    optimizer=Adam(learning_rate=0.001),
    metrics=['accuracy'],
)

# Callbacks: a single checkpoint that monitors val_loss and, because
# save_best_only=True, only writes best_model_multi.keras for the best value.
multi_checkpoint = ModelCheckpoint(
    filepath='best_model_multi.keras',
    monitor='val_loss',
    save_best_only=True,
)
callbacks_list_multi = [multi_checkpoint]

In [None]:
# Train multi-class model for 60 epochs with mini-batches of 8.
# NOTE: the held-out test split is also passed as validation_data, so the
# val_* metrics in history_multi are computed on the test rows.
print("Training Multi-class Model...")
history_multi = model_multi.fit(X_train_multi, y_train_multi_cat,
                                epochs=60, batch_size=8, verbose=1,
                                validation_data=(X_test_multi, y_test_multi_cat),
                                # callbacks_list_multi is already a list of
                                # callbacks; pass it directly instead of
                                # wrapping it in a second list.
                                callbacks=callbacks_list_multi)

## Binary Classification Model

In [None]:
# Split data for binary classification: 80% train / 20% test.
# shuffle=False makes this an ordered cut, so random_state is inert here.
binary_split = train_test_split(X, y_binary, test_size=0.2, shuffle=False, random_state=90)
X_train_binary, X_test_binary, y_train_binary, y_test_binary = binary_split

print("Binary classification training data shapes:")
for _name, _part in (
    ("X_train", X_train_binary),
    ("y_train", y_train_binary),
    ("X_test", X_test_binary),
    ("y_test", y_test_binary),
):
    print(f"{_name}: {_part.shape}")

# Show how many 0s and 1s landed in each side of the split.
for _heading, _labels in (
    ("\nBinary target distribution in training set:", y_train_binary),
    ("\nBinary target distribution in test set:", y_test_binary),
):
    print(_heading)
    print(pd.Series(_labels).value_counts().sort_index())

In [None]:
# Binary network: same hidden stack as the multi-class model
# (13 inputs -> 10 -> 8 -> 4, ReLU) but ending in a single sigmoid unit.
model_binary = Sequential([
    Input(shape=(13,)),
    Dense(10, kernel_initializer='normal', activation='relu'),
    Dense(8, kernel_initializer='normal', activation='relu'),
    Dense(4, kernel_initializer='normal', activation='relu'),
    Dense(1, activation='sigmoid'),  # single output for the positive class
])

print("Binary Classification Model Architecture:")
model_binary.summary()

In [None]:
# Compile binary model: Adam (learning rate 0.001) on binary cross-entropy,
# reporting accuracy.
model_binary.compile(
    loss='binary_crossentropy',
    optimizer=Adam(learning_rate=0.001),
    metrics=['accuracy'],
)

# Callbacks: a single checkpoint that monitors val_loss and, because
# save_best_only=True, only writes best_model_binary.keras for the best value.
binary_checkpoint = ModelCheckpoint(
    filepath='best_model_binary.keras',
    monitor='val_loss',
    save_best_only=True,
)
callbacks_list_binary = [binary_checkpoint]

In [None]:
# Train binary model for 60 epochs with mini-batches of 8.
# NOTE: the held-out test split is also passed as validation_data, so the
# val_* metrics in history_binary are computed on the test rows.
print("Training Binary Classification Model...")
history_binary = model_binary.fit(X_train_binary, y_train_binary,
                                  epochs=60, batch_size=8, verbose=1,
                                  validation_data=(X_test_binary, y_test_binary),
                                  # callbacks_list_binary is already a list of
                                  # callbacks; pass it directly instead of
                                  # wrapping it in a second list.
                                  callbacks=callbacks_list_binary)

## Model Evaluation and Comparison

In [None]:
# Evaluate Multi-class Model on the test split
print("=== MULTI-CLASS MODEL EVALUATION ===")
pred_multi = model_multi.predict(X_test_multi)

# Reduce each row (predicted probabilities / one-hot truth) to the index of
# its largest entry, i.e. the integer class id.
y_pred_multi_argmax = pred_multi.argmax(axis=1)
y_test_multi_argmax = y_test_multi_cat.argmax(axis=1)

# Accuracy is stored as a percentage.
multi_accuracy = 100 * accuracy_score(y_test_multi_argmax, y_pred_multi_argmax)
print(f'Multi-class Classification Accuracy: {multi_accuracy:.2f}%')
print('\nMulti-class Classification Report:')
multi_report = classification_report(y_test_multi_argmax, y_pred_multi_argmax)
print(multi_report)

In [None]:
# Evaluate Binary Model on the test split
print("=== BINARY CLASSIFICATION MODEL EVALUATION ===")
pred_binary = model_binary.predict(X_test_binary)

# Flatten the model output to 1-D and threshold it:
# strictly above 0.5 -> 1, otherwise 0.
y_pred_binary = (pred_binary.ravel() > 0.5).astype(int)

# Accuracy is stored as a percentage.
binary_accuracy = 100 * accuracy_score(y_test_binary, y_pred_binary)
print(f'Binary Classification Accuracy: {binary_accuracy:.2f}%')
print('\nBinary Classification Report:')
binary_report = classification_report(y_test_binary, y_pred_binary)
print(binary_report)

In [None]:
# Compare Accuracies
# Prints both test accuracies, their signed difference (binary minus
# multi-class) and a verdict line. The emoji in the verdict strings were
# stored as mis-decoded UTF-8 (mojibake) and are restored here.
print("=== ACCURACY COMPARISON ===")
print(f"Multi-class Classification Accuracy: {multi_accuracy:.2f}%")
print(f"Binary Classification Accuracy: {binary_accuracy:.2f}%")
print(f"Difference: {binary_accuracy - multi_accuracy:.2f} percentage points")

if binary_accuracy > multi_accuracy:
    print("\n🏆 WINNER: Binary Classification performs better!")
    print(f"Binary classification is {binary_accuracy - multi_accuracy:.2f} percentage points more accurate.")
elif multi_accuracy > binary_accuracy:
    print("\n🏆 WINNER: Multi-class Classification performs better!")
    print(f"Multi-class classification is {multi_accuracy - binary_accuracy:.2f} percentage points more accurate.")
else:
    print("\n🤝 TIE: Both models perform equally well!")

In [None]:
# Plot training history comparison on a 2x2 grid:
# top row = multi-class model, bottom row = binary model;
# left column = accuracy, right column = loss.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

# (axes, fit history, metric key, y-axis label, panel title)
panels = (
    (ax1, history_multi, 'accuracy', 'Accuracy', 'Multi-class Model Accuracy'),
    (ax2, history_multi, 'loss', 'Loss', 'Multi-class Model Loss'),
    (ax3, history_binary, 'accuracy', 'Accuracy', 'Binary Model Accuracy'),
    (ax4, history_binary, 'loss', 'Loss', 'Binary Model Loss'),
)

for _axis, _run, _metric, _ylabel, _title in panels:
    # Training curve first, then its validation counterpart ("val_" + key).
    _axis.plot(_run.history[_metric], label=f'Training {_ylabel}')
    _axis.plot(_run.history[f'val_{_metric}'], label=f'Validation {_ylabel}')
    _axis.set_title(_title)
    _axis.set_xlabel('Epoch')
    _axis.set_ylabel(_ylabel)
    _axis.legend()
    _axis.grid(True)

plt.tight_layout()
plt.show()

In [None]:
# Final accuracy comparison bar chart: one bar per model, on a 0-100 % axis.
plt.figure(figsize=(10, 6))
models = ['Multi-class\n(5 classes)', 'Binary\n(2 classes)']
accuracies = [multi_accuracy, binary_accuracy]
colors = ['skyblue', 'lightcoral']

bars = plt.bar(models, accuracies, color=colors, alpha=0.8, edgecolor='black')
plt.title('Classification Accuracy Comparison', fontsize=16, fontweight='bold')
plt.ylabel('Accuracy (%)', fontsize=12)
plt.ylim(0, 100)

# Write each accuracy just above its bar, centred horizontally.
for bar, acc in zip(bars, accuracies):
    label_x = bar.get_x() + bar.get_width()/2
    label_y = bar.get_height() + 1
    plt.text(
        label_x,
        label_y,
        f'{acc:.2f}%',
        ha='center',
        va='bottom',
        fontsize=12,
        fontweight='bold',
    )

plt.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.show()

# Text summary of the comparison: both accuracies, the absolute gap, and a
# short list of possible explanations for whichever model scored higher.
# The check mark, bullet and handshake characters were stored as mis-decoded
# UTF-8 (mojibake) and are restored here.
print("\n" + "="*60)
print("SUMMARY OF RESULTS")
print("="*60)
print(f"Multi-class Classification (5 classes): {multi_accuracy:.2f}%")
print(f"Binary Classification (2 classes): {binary_accuracy:.2f}%")
print(f"Performance difference: {abs(binary_accuracy - multi_accuracy):.2f} percentage points")

if binary_accuracy > multi_accuracy:
    print(f"\n✅ Binary classification is MORE ACCURATE by {binary_accuracy - multi_accuracy:.2f} percentage points")
    print("\nPossible reasons for better binary performance:")
    print("• Simpler decision boundary (disease vs no disease)")
    print("• More balanced classes after grouping")
    print("• Reduced complexity eliminates confusion between disease severity levels")
elif multi_accuracy > binary_accuracy:
    print(f"\n✅ Multi-class classification is MORE ACCURATE by {multi_accuracy - binary_accuracy:.2f} percentage points")
    print("\nPossible reasons for better multi-class performance:")
    print("• Preserves important information about disease severity")
    print("• Model can learn more nuanced patterns")
    print("• Different disease levels have distinct characteristics")
else:
    print("\n🤝 Both approaches perform equally well")
print("="*60)