In [2]:
# notebooks/03_feature_selection.py
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import RFE, chi2, SelectKBest
import pandas as pd
import numpy as np

df = pd.read_csv(r'C:\Users\pc\Documents\Basel BME\Programming\Python\SPRINTS Heart Disease Project\python\data\heart_disease_clean.csv')
X = df.drop('target',axis=1)
y = df['target']

# RandomForest importance
rf = RandomForestClassifier(random_state=42)
rf.fit(X, y)
importances = pd.Series(rf.feature_importances_, index=X.columns).sort_values(ascending=False)
print(importances.head(15))

# RFE with LogisticRegression as example
from sklearn.linear_model import LogisticRegression
lr = LogisticRegression(max_iter=1000)
rfe = RFE(lr, n_features_to_select=8)
rfe.fit(X, y)
rfe_support = pd.Series(rfe.support_, index=X.columns)
print("RFE selected:", rfe_support[rfe_support].index.tolist())

# Chi-Square on non-negative features (if needed convert to positive)
from sklearn.preprocessing import MinMaxScaler
X_pos = MinMaxScaler().fit_transform(X)
chi2selector = SelectKBest(chi2, k=8)
chi2selector.fit(X_pos, y)
chi2_support = pd.Series(chi2selector.get_support(), index=X.columns)
print("Chi2 selected:", chi2_support[chi2_support].index.tolist())

# Decide final features (example combine results)
selected = list(importances.head(12).index)
print("Selected features:", selected)


thalach      0.102121
thal_3.0     0.102089
oldpeak      0.091865
ca_0.0       0.088654
age          0.080317
cp_4.0       0.080056
chol         0.069603
trestbps     0.066766
thal_7.0     0.055302
exang        0.043825
slope_1.0    0.038268
sex          0.026189
cp_3.0       0.024937
slope_2.0    0.021806
ca_1.0       0.015859
dtype: float64
RFE selected: ['oldpeak', 'cp_3.0', 'cp_4.0', 'restecg_0.0', 'slope_2.0', 'thal_3.0', 'thal_7.0', 'ca_0.0']
Chi2 selected: ['exang', 'cp_3.0', 'cp_4.0', 'slope_1.0', 'slope_2.0', 'thal_3.0', 'thal_7.0', 'ca_0.0']
Selected features: ['thalach', 'thal_3.0', 'oldpeak', 'ca_0.0', 'age', 'cp_4.0', 'chol', 'trestbps', 'thal_7.0', 'exang', 'slope_1.0', 'sex']
