In [57]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from keras import Sequential
from keras.layers import Dense, Input, Dropout
from keras.optimizers import Adam
from sklearn.metrics import classification_report, accuracy_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.multioutput import MultiOutputClassifier
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
import joblib

In [2]:
df = pd.read_csv("extracted-data/master_data_all_leads.csv", header=[0], index_col=[0])

In [3]:
df

Unnamed: 0,age,sex,smoker,systolic_blood_pressure,diastolic_blood_pressure,bundle branch block,cardiomyopathy,dysrhythmia,healthy control,heart failure (nyha 2),...,median_qrs_durations,mean_t_durations,std_t_durations,median_t_durations,mean_qt_intervals,std_qt_intervals,median_qt_intervals,vlf_power,lf_power,hf_power
0,81.0,1.0,0.0,140.0,80.0,0.0,0.0,0.0,0.0,0.0,...,87.0,277.865385,6.348921,279.0,306.423077,6.298187,307.0,0.0,0.0,0.0
1,81.0,1.0,0.0,140.0,80.0,0.0,0.0,0.0,0.0,0.0,...,38.0,103.884615,14.653240,100.5,121.750000,18.770758,123.0,0.0,0.0,0.0
2,81.0,1.0,0.0,140.0,80.0,0.0,0.0,0.0,0.0,0.0,...,14.0,27.269231,0.879584,27.0,41.288462,1.214283,42.0,0.0,0.0,0.0
3,81.0,1.0,0.0,140.0,80.0,0.0,0.0,0.0,0.0,0.0,...,1.0,57.596154,1.348213,57.5,58.615385,1.346703,59.0,0.0,0.0,0.0
4,81.0,1.0,0.0,140.0,80.0,0.0,0.0,0.0,0.0,0.0,...,90.0,243.923077,18.162563,237.0,273.538462,36.812655,297.5,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
8230,61.0,1.0,,,,0.0,0.0,0.0,0.0,0.0,...,100.0,249.807339,8.996918,254.0,281.513761,9.046886,285.0,0.0,0.0,0.0
8231,61.0,1.0,,,,0.0,0.0,0.0,0.0,0.0,...,101.0,267.342282,3.644932,268.0,299.201342,3.656345,299.0,0.0,0.0,0.0
8232,61.0,1.0,,,,0.0,0.0,0.0,0.0,0.0,...,99.0,263.449664,2.361424,263.0,295.919463,2.341676,296.0,0.0,0.0,0.0
8233,61.0,1.0,,,,0.0,0.0,0.0,0.0,0.0,...,94.0,155.228188,89.717313,116.0,183.577181,89.811524,144.0,0.0,0.0,0.0


In [4]:
diseases = ['myocardial infarction', 'healthy control',
            'valvular heart disease', 'dysrhythmia', 'heart failure (nyha 2)',
            'heart failure (nyha 3)', 'heart failure (nyha 4)',
            'palpitation', 'cardiomyopathy', 'stable angina', 'hypertrophy',
            'bundle branch block', 'unstable angina', 'myocarditis']

In [5]:
# Features
X = df.drop(columns=[*diseases, 'age', 'sex',
            'smoker', 'systolic_blood_pressure', 'diastolic_blood_pressure'])

# Targets
y = df[diseases]

In [6]:
X

Unnamed: 0,mean_r_peaks_amplitude,std_r_peaks_amplitude,median_r_peaks_amplitude,mean_rr_intervals,std_rr_intervals,median_rr_intervals,mean_heart_rate,std_heart_rate,median_heart_rate,rmssd,...,median_qrs_durations,mean_t_durations,std_t_durations,median_t_durations,mean_qt_intervals,std_qt_intervals,median_qt_intervals,vlf_power,lf_power,hf_power
0,0.443254,0.013023,0.443747,733.764706,9.436243,733.0,81.783607,1.051872,81.855389,11.131936,...,87.0,277.865385,6.348921,279.0,306.423077,6.298187,307.0,0.0,0.0,0.0
1,-0.218650,0.026060,-0.220316,733.784314,9.421115,733.0,81.781367,1.049283,81.855389,10.976338,...,38.0,103.884615,14.653240,100.5,121.750000,18.770758,123.0,0.0,0.0,0.0
2,-0.239967,0.020062,-0.237549,733.764706,9.398766,733.0,81.783496,1.047410,81.855389,11.001818,...,14.0,27.269231,0.879584,27.0,41.288462,1.214283,42.0,0.0,0.0,0.0
3,-0.197383,0.013107,-0.197509,733.764706,9.434165,733.0,81.783608,1.052207,81.855389,11.112156,...,1.0,57.596154,1.348213,57.5,58.615385,1.346703,59.0,0.0,0.0,0.0
4,-0.052284,0.209993,-0.139005,733.725490,31.606700,732.0,81.926021,3.524795,81.967213,54.297330,...,90.0,243.923077,18.162563,237.0,273.538462,36.812655,297.5,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
8230,1.186274,0.051521,1.183949,1088.185185,505.843797,811.5,64.285357,21.508580,73.937406,827.893487,...,100.0,249.807339,8.996918,254.0,281.513761,9.046886,285.0,0.0,0.0,0.0
8231,0.867600,0.018980,0.866156,804.932432,59.311097,804.5,75.141017,8.471008,74.580514,97.706558,...,101.0,267.342282,3.644932,268.0,299.201342,3.656345,299.0,0.0,0.0,0.0
8232,1.039915,0.021605,1.038964,804.932432,59.458441,804.0,75.144326,8.499183,74.626866,97.959245,...,99.0,263.449664,2.361424,263.0,295.919463,2.341676,296.0,0.0,0.0,0.0
8233,0.597420,0.022094,0.596290,804.925676,58.814998,804.0,75.129345,8.361182,74.626866,96.920298,...,94.0,155.228188,89.717313,116.0,183.577181,89.811524,144.0,0.0,0.0,0.0


In [7]:
y

Unnamed: 0,myocardial infarction,healthy control,valvular heart disease,dysrhythmia,heart failure (nyha 2),heart failure (nyha 3),heart failure (nyha 4),palpitation,cardiomyopathy,stable angina,hypertrophy,bundle branch block,unstable angina,myocarditis
0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
8230,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
8231,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
8232,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
8233,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


### Splitting the dataset

In [8]:
# Train-test split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

### Scaling the data

In [9]:
# Create a pipeline that first imputes missing values and then scales the data
pipeline = Pipeline([
    # You can choose 'median' or 'most_frequent' as well
    ('imputer', SimpleImputer(strategy='mean')),
    ('scaler', StandardScaler())
])

X_train_scaled = pipeline.fit_transform(X_train)
X_test_scaled = pipeline.transform(X_test)

## Model Training

In [10]:
def get_classification_report(y_pred):
    for i, disease in enumerate(y.columns):
        print(f"Evaluating model for disease: {disease}")
        print(f"Accuracy: {accuracy_score(y_test.iloc[:, i], y_pred[:, i])}")
        print("Classification Report:")
        print(classification_report(
            y_test.iloc[:, i], y_pred[:, i], zero_division=0)
        )
        print("\n")

### Sequential Model

In [11]:
# Define the neural network model
sequential_model = Sequential([
    Input(shape=(X_train_scaled.shape[1],)),
    Dense(256, activation='relu'),  # Increased number of neurons
    Dropout(0.5),                   # Added dropout for regularization
    Dense(128, activation='relu'),
    Dropout(0.5),
    Dense(64, activation='relu'),
    Dense(y_train.shape[1], activation='sigmoid')
])


# Compile the model
sequential_model.compile(optimizer=Adam(
    learning_rate=0.001), loss='binary_crossentropy', metrics=['accuracy'])

# Train the model
history = sequential_model.fit(
    X_train_scaled, y_train, epochs=100, batch_size=64, validation_split=0.2)

# Evaluate the model
evaluation = sequential_model.evaluate(X_test_scaled, y_test)
print(f"Test loss: {evaluation[0]}, Test accuracy: {evaluation[1]}")

# Optionally, plot training history

# plt.plot(history.history['accuracy'])
# plt.plot(history.history['val_accuracy'])
# plt.title('Model accuracy')
# plt.ylabel('Accuracy')
# plt.xlabel('Epoch')
# plt.legend(['Train', 'Validation'], loc='upper left')
# plt.show()

# plt.plot(history.history['loss'])
# plt.plot(history.history['val_loss'])
# plt.title('Model loss')
# plt.ylabel('Loss')
# plt.xlabel('Epoch')
# plt.legend(['Train', 'Validation'], loc='upper left')
# plt.show()

Epoch 1/100
[1m83/83[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 5ms/step - accuracy: 0.3975 - loss: 0.3929 - val_accuracy: 0.7041 - val_loss: 0.1309
Epoch 2/100
[1m83/83[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - accuracy: 0.7306 - loss: 0.1305 - val_accuracy: 0.7041 - val_loss: 0.1230
Epoch 3/100
[1m83/83[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - accuracy: 0.7262 - loss: 0.1249 - val_accuracy: 0.7132 - val_loss: 0.1177
Epoch 4/100
[1m83/83[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - accuracy: 0.7265 - loss: 0.1225 - val_accuracy: 0.7215 - val_loss: 0.1146
Epoch 5/100
[1m83/83[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - accuracy: 0.7097 - loss: 0.1230 - val_accuracy: 0.7291 - val_loss: 0.1126
Epoch 6/100
[1m83/83[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - accuracy: 0.7288 - loss: 0.1179 - val_accuracy: 0.7276 - val_loss: 0.1123
Epoch 7/100
[1m83/83[0m [32m━━━

In [12]:
df

Unnamed: 0,age,sex,smoker,systolic_blood_pressure,diastolic_blood_pressure,bundle branch block,cardiomyopathy,dysrhythmia,healthy control,heart failure (nyha 2),...,median_qrs_durations,mean_t_durations,std_t_durations,median_t_durations,mean_qt_intervals,std_qt_intervals,median_qt_intervals,vlf_power,lf_power,hf_power
0,81.0,1.0,0.0,140.0,80.0,0.0,0.0,0.0,0.0,0.0,...,87.0,277.865385,6.348921,279.0,306.423077,6.298187,307.0,0.0,0.0,0.0
1,81.0,1.0,0.0,140.0,80.0,0.0,0.0,0.0,0.0,0.0,...,38.0,103.884615,14.653240,100.5,121.750000,18.770758,123.0,0.0,0.0,0.0
2,81.0,1.0,0.0,140.0,80.0,0.0,0.0,0.0,0.0,0.0,...,14.0,27.269231,0.879584,27.0,41.288462,1.214283,42.0,0.0,0.0,0.0
3,81.0,1.0,0.0,140.0,80.0,0.0,0.0,0.0,0.0,0.0,...,1.0,57.596154,1.348213,57.5,58.615385,1.346703,59.0,0.0,0.0,0.0
4,81.0,1.0,0.0,140.0,80.0,0.0,0.0,0.0,0.0,0.0,...,90.0,243.923077,18.162563,237.0,273.538462,36.812655,297.5,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
8230,61.0,1.0,,,,0.0,0.0,0.0,0.0,0.0,...,100.0,249.807339,8.996918,254.0,281.513761,9.046886,285.0,0.0,0.0,0.0
8231,61.0,1.0,,,,0.0,0.0,0.0,0.0,0.0,...,101.0,267.342282,3.644932,268.0,299.201342,3.656345,299.0,0.0,0.0,0.0
8232,61.0,1.0,,,,0.0,0.0,0.0,0.0,0.0,...,99.0,263.449664,2.361424,263.0,295.919463,2.341676,296.0,0.0,0.0,0.0
8233,61.0,1.0,,,,0.0,0.0,0.0,0.0,0.0,...,94.0,155.228188,89.717313,116.0,183.577181,89.811524,144.0,0.0,0.0,0.0


In [13]:
# Predict the probabilities on the test set
y_pred_prob = sequential_model.predict(X_test_scaled)

# Binarize the predictions with a threshold of 0.5
y_pred = (y_pred_prob > 0.5).astype(int)

# Print the classification report
# print(classification_report(y_test, y_pred, target_names=[i for i in diseases], zero_division=0))
get_classification_report(y_pred)

[1m52/52[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step
Evaluating model for disease: myocardial infarction
Accuracy: 0.8585306618093503
Classification Report:
              precision    recall  f1-score   support

         0.0       0.81      0.75      0.78       544
         1.0       0.88      0.91      0.90      1103

    accuracy                           0.86      1647
   macro avg       0.84      0.83      0.84      1647
weighted avg       0.86      0.86      0.86      1647



Evaluating model for disease: healthy control
Accuracy: 0.9101396478445659
Classification Report:
              precision    recall  f1-score   support

         0.0       0.94      0.95      0.95      1420
         1.0       0.68      0.65      0.67       227

    accuracy                           0.91      1647
   macro avg       0.81      0.80      0.81      1647
weighted avg       0.91      0.91      0.91      1647



Evaluating model for disease: valvular heart disease
Accuracy: 0.98

### Random Forest Model

In [14]:
# Initialize the base model
base_model_random_forest = RandomForestClassifier(random_state=42)

# Initialize the multi-output classifier
random_forest_model = MultiOutputClassifier(
    base_model_random_forest, n_jobs=-1
)

In [15]:
# Train the model
random_forest_model.fit(X_train_scaled, y_train)

# Predict on the test set
y_pred = random_forest_model.predict(X_test_scaled)

get_classification_report(y_pred)

Evaluating model for disease: myocardial infarction
Accuracy: 0.9483910139647844
Classification Report:
              precision    recall  f1-score   support

         0.0       0.96      0.88      0.92       544
         1.0       0.94      0.98      0.96      1103

    accuracy                           0.95      1647
   macro avg       0.95      0.93      0.94      1647
weighted avg       0.95      0.95      0.95      1647



Evaluating model for disease: healthy control
Accuracy: 0.9672131147540983
Classification Report:
              precision    recall  f1-score   support

         0.0       0.96      1.00      0.98      1420
         1.0       0.99      0.77      0.87       227

    accuracy                           0.97      1647
   macro avg       0.98      0.88      0.92      1647
weighted avg       0.97      0.97      0.97      1647



Evaluating model for disease: valvular heart disease
Accuracy: 0.9957498482088646
Classification Report:
              precision    recall  

### SVM Model

In [16]:
# Initialize the base model
base_model_svm = SVC(random_state=42)

# Initialize the multi-output classifier
svm_model = MultiOutputClassifier(base_model_svm, n_jobs=-1)

In [17]:
# Train the model
svm_model.fit(X_train_scaled, y_train)

# Predict on the test set
y_pred = svm_model.predict(X_test_scaled)

get_classification_report(y_pred)

Evaluating model for disease: myocardial infarction
Accuracy: 0.7941712204007286
Classification Report:
              precision    recall  f1-score   support

         0.0       0.78      0.52      0.63       544
         1.0       0.80      0.93      0.86      1103

    accuracy                           0.79      1647
   macro avg       0.79      0.72      0.74      1647
weighted avg       0.79      0.79      0.78      1647



Evaluating model for disease: healthy control
Accuracy: 0.8894960534304797
Classification Report:
              precision    recall  f1-score   support

         0.0       0.89      0.99      0.94      1420
         1.0       0.83      0.25      0.39       227

    accuracy                           0.89      1647
   macro avg       0.86      0.62      0.66      1647
weighted avg       0.88      0.89      0.86      1647



Evaluating model for disease: valvular heart disease
Accuracy: 0.9860352155434122
Classification Report:
              precision    recall  

### kNN Model

In [18]:
# Initialize the base model
base_model_knn = KNeighborsClassifier()

# Initialize the multi-output classifier
knn_model = MultiOutputClassifier(base_model_knn, n_jobs=-1)

In [19]:
# Train the model
knn_model.fit(X_train_scaled, y_train)

# Predict on the test set
y_pred = knn_model.predict(X_test_scaled)

get_classification_report(y_pred)

Evaluating model for disease: myocardial infarction
Accuracy: 0.8099574984820886
Classification Report:
              precision    recall  f1-score   support

         0.0       0.76      0.62      0.68       544
         1.0       0.83      0.91      0.86      1103

    accuracy                           0.81      1647
   macro avg       0.80      0.76      0.77      1647
weighted avg       0.81      0.81      0.80      1647



Evaluating model for disease: healthy control
Accuracy: 0.9028536733454766
Classification Report:
              precision    recall  f1-score   support

         0.0       0.93      0.96      0.94      1420
         1.0       0.69      0.54      0.60       227

    accuracy                           0.90      1647
   macro avg       0.81      0.75      0.77      1647
weighted avg       0.90      0.90      0.90      1647



Evaluating model for disease: valvular heart disease
Accuracy: 0.9860352155434122
Classification Report:
              precision    recall  

### Decision Tree Model

In [20]:
# Initialize the base model
base_model_decision_tree = DecisionTreeClassifier()

# Initialize the multi-output classifier
decision_tree_model = MultiOutputClassifier(base_model_decision_tree, n_jobs=-1)

In [21]:
# Train the model
decision_tree_model.fit(X_train_scaled, y_train)

# Predict on the test set
y_pred = decision_tree_model.predict(X_test_scaled)

get_classification_report(y_pred)

Evaluating model for disease: myocardial infarction
Accuracy: 0.9143897996357013
Classification Report:
              precision    recall  f1-score   support

         0.0       0.86      0.88      0.87       544
         1.0       0.94      0.93      0.94      1103

    accuracy                           0.91      1647
   macro avg       0.90      0.91      0.90      1647
weighted avg       0.91      0.91      0.91      1647



Evaluating model for disease: healthy control
Accuracy: 0.949605343047966
Classification Report:
              precision    recall  f1-score   support

         0.0       0.97      0.97      0.97      1420
         1.0       0.82      0.82      0.82       227

    accuracy                           0.95      1647
   macro avg       0.89      0.89      0.89      1647
weighted avg       0.95      0.95      0.95      1647



Evaluating model for disease: valvular heart disease
Accuracy: 0.994535519125683
Classification Report:
              precision    recall  f1

## Let's test the models

In [22]:
X_test

Unnamed: 0,mean_r_peaks_amplitude,std_r_peaks_amplitude,median_r_peaks_amplitude,mean_rr_intervals,std_rr_intervals,median_rr_intervals,mean_heart_rate,std_heart_rate,median_heart_rate,rmssd,...,median_qrs_durations,mean_t_durations,std_t_durations,median_t_durations,mean_qt_intervals,std_qt_intervals,median_qt_intervals,vlf_power,lf_power,hf_power
706,0.199892,0.034022,0.198356,952.700000,22.366865,952.0,63.013546,1.476507,63.025280,15.374321,...,114.0,298.140496,94.015148,346.0,334.545455,94.485645,379.0,0.0,0.0,0.0
6025,1.190115,0.036537,1.190729,873.335766,31.072150,872.0,68.788394,2.428198,68.807339,32.696465,...,60.0,232.528986,1.357705,233.0,266.137681,1.425401,266.0,0.0,0.0,0.0
7666,0.451657,0.028745,0.450948,1091.761468,29.799298,1092.0,54.998276,1.510978,54.945055,17.144808,...,86.0,321.818182,1.440845,322.0,385.736364,5.169275,385.0,0.0,0.0,0.0
6954,1.740954,0.045024,1.741112,992.116667,69.040228,996.0,60.782753,4.421877,60.241207,35.678537,...,96.0,309.115702,17.752285,311.0,371.950413,19.526566,377.0,0.0,0.0,0.0
5602,-0.385131,0.026491,-0.390854,685.089820,51.629939,680.0,87.856592,3.768368,88.235294,72.393545,...,4.0,182.857143,3.873422,183.0,183.857143,3.873422,184.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
8058,-0.599309,0.029243,-0.591962,798.744966,248.164514,783.0,81.703591,22.899216,76.628352,377.445718,...,3.0,34.986667,1.492142,35.0,37.993333,1.547026,38.0,0.0,0.0,0.0
296,0.564520,0.023634,0.564430,1012.247788,85.442475,1038.0,59.791789,6.149660,57.803468,132.716837,...,106.0,283.166667,7.059791,284.0,310.807018,6.834994,311.0,0.0,0.0,0.0
7716,-1.035530,0.028232,-1.039479,868.028986,15.694521,868.0,69.144655,1.246930,69.124424,11.804366,...,2.0,240.956835,25.811400,246.0,242.661871,25.772540,248.0,0.0,0.0,0.0
57,-0.076566,0.340947,-0.371246,760.853333,30.227225,764.0,78.985041,3.181118,78.534031,34.396426,...,2.0,131.384106,17.065073,129.0,146.059603,30.413972,131.0,0.0,0.0,0.0


In [62]:
# Step 1: Select a random row from the DataFrame
random_row = df.sample()

# Step 2: Extract features (X) from the selected row
# Assuming disease_columns contains the disease labels
X_sample = random_row.drop(columns=[*diseases, 'age', 'sex',
                                    'smoker', 'systolic_blood_pressure', 'diastolic_blood_pressure'])
# X_sample = random_row

# Fit and transform the preprocessing pipeline on the selected row
X_sample_preprocessed = pipeline.transform(X_sample)

# Step 3: Make predictions using each trained model
predictions = {}

# Sequential Model
sequential_prediction = sequential_model.predict(X_sample)
predictions['Sequential'] = sequential_prediction

# Random Forest Model
random_forest_prediction = random_forest_model.predict(X_sample_preprocessed)
predictions['Random Forest'] = random_forest_prediction

# SVM Model
svm_prediction = svm_model.predict(X_sample_preprocessed)
predictions['SVM'] = svm_prediction

# KNN Model
knn_prediction = knn_model.predict(X_sample_preprocessed)
predictions['KNN'] = knn_prediction

# Decision Tree Model
decision_tree_prediction = decision_tree_model.predict(X_sample_preprocessed)
predictions['Decision Tree'] = decision_tree_prediction

# Step 4: Interpret the predictions
print(f"Random Row:\n{random_row[diseases]}")

print(f"\nCorrect output: {random_row[diseases].to_numpy()}")

print("\nPredictions:")
for model, prediction in predictions.items():
    print(f"{model} Prediction:", prediction)

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step


Random Row:
      myocardial infarction  healthy control  valvular heart disease  \
5497                    0.0              0.0                     0.0   

      dysrhythmia  heart failure (nyha 2)  heart failure (nyha 3)  \
5497          1.0                     0.0                     0.0   

      heart failure (nyha 4)  palpitation  cardiomyopathy  stable angina  \
5497                     0.0          0.0             0.0            0.0   

      hypertrophy  bundle branch block  unstable angina  myocarditis  
5497          0.0                  0.0              0.0          0.0  

Correct output: [[0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]

Predictions:
Sequential Prediction: [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]]
Random Forest Prediction: [[0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
SVM Prediction: [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
KNN Prediction: [[0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
Decision Tree Prediction: [[0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 

In [52]:
# Out of all these models, the random forest model is working the best, so we will select the random forest model as the final model
model = random_forest_model

In [53]:
joblib.dump(model, './models/random_forest_model.joblib')

['./models/random_forest_model.joblib']

In [54]:
loaded_model = joblib.load('./models/random_forest_model.joblib')

In [55]:
loaded_model

In [56]:
loaded_model.predict(X_sample_preprocessed)

array([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])