In [51]:
import pandas as pd

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, ConfusionMatrixDisplay, confusion_matrix
from sklearn.preprocessing import LabelBinarizer

from matplotlib import pyplot as plt

Фиксируем random_state для повторяемости эксперимента

In [28]:
random_state = 42

About Dataset

The Diabetes prediction dataset is a collection of medical and demographic data from patients, along with their diabetes status (positive or negative). The data includes features such as age, gender, body mass index (BMI), hypertension, heart disease, smoking history, HbA1c level, and blood glucose level. This dataset can be used to build machine learning models to predict diabetes in patients based on their medical history and demographic information. This can be useful for healthcare professionals in identifying patients who may be at risk of developing diabetes and in developing personalized treatment plans. Additionally, the dataset can be used by researchers to explore the relationships between various medical and demographic factors and the likelihood of developing diabetes.

In [29]:
data = pd.read_csv('diabetes_prediction_dataset.csv')

In [30]:
data.sample(10, random_state=random_state)

Unnamed: 0,gender,age,hypertension,heart_disease,smoking_history,bmi,HbA1c_level,blood_glucose_level,diabetes
75721,Female,13.0,0,0,No Info,20.82,5.8,126,0
80184,Female,3.0,0,0,No Info,21.0,5.0,145,0
19864,Male,63.0,0,0,former,25.32,3.5,200,0
76699,Female,2.0,0,0,never,17.43,6.1,126,0
92991,Female,33.0,0,0,not current,40.08,6.2,200,1
76434,Female,70.0,0,0,never,23.89,6.5,200,0
84004,Female,51.0,0,0,current,27.32,5.0,158,0
80917,Female,12.0,0,0,No Info,27.32,4.8,158,0
60767,Female,45.0,0,0,No Info,27.32,6.2,145,0
50074,Female,19.0,0,0,former,27.32,6.2,90,0


In [31]:
data.describe(include='all')

Unnamed: 0,gender,age,hypertension,heart_disease,smoking_history,bmi,HbA1c_level,blood_glucose_level,diabetes
count,100000,100000.0,100000.0,100000.0,100000,100000.0,100000.0,100000.0,100000.0
unique,3,,,,6,,,,
top,Female,,,,No Info,,,,
freq,58552,,,,35816,,,,
mean,,41.885856,0.07485,0.03942,,27.320767,5.527507,138.05806,0.085
std,,22.51684,0.26315,0.194593,,6.636783,1.070672,40.708136,0.278883
min,,0.08,0.0,0.0,,10.01,3.5,80.0,0.0
25%,,24.0,0.0,0.0,,23.63,4.8,100.0,0.0
50%,,43.0,0.0,0.0,,27.32,5.8,140.0,0.0
75%,,60.0,0.0,0.0,,29.58,6.2,159.0,0.0


Пропусков не видно, но они есть (smoking_history: No Info)

In [32]:
data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 100000 entries, 0 to 99999
Data columns (total 9 columns):
 #   Column               Non-Null Count   Dtype  
---  ------               --------------   -----  
 0   gender               100000 non-null  object 
 1   age                  100000 non-null  float64
 2   hypertension         100000 non-null  int64  
 3   heart_disease        100000 non-null  int64  
 4   smoking_history      100000 non-null  object 
 5   bmi                  100000 non-null  float64
 6   HbA1c_level          100000 non-null  float64
 7   blood_glucose_level  100000 non-null  int64  
 8   diabetes             100000 non-null  int64  
dtypes: float64(3), int64(4), object(2)
memory usage: 6.9+ MB


Preprocessing

Смело удаляю дубликаты, т.к. возможно совпадение возроста и пола у разных людей, но не индекса массы тела, уровня гликированного гемоглобина и уровня глюкозы в крови одновременно.

In [33]:
data.drop_duplicates(inplace=True)

Иных заменяем на самый часто встречающийся пол - Female, т.к. количество замен = 18.

In [36]:
def encoder(tmp):
    if tmp == 'Male':
        return 0
    else:
        return 1

In [37]:
data['gender'] = data['gender'].transform(encoder)

In [38]:
# ==========
data.drop(columns='smoking_history', inplace=True)

Разделим данные на тренировочную и тестовую выборки (для валидации модель сама выделит кусок из тренировочного)

In [116]:
x_train, x_test, y_train, y_test = train_test_split(
    data.drop(columns='diabetes'),
    data['diabetes'],
    test_size=0.25,
    random_state=random_state
)

In [92]:
clf = RandomForestClassifier(class_weight='balanced')

In [93]:
clf.fit(x_train, y_train)

In [94]:
print(classification_report(y_train, clf.predict(x_train)))

              precision    recall  f1-score   support

           0       1.00      0.99      1.00     65794
           1       0.95      1.00      0.97      6315

    accuracy                           0.99     72109
   macro avg       0.97      1.00      0.98     72109
weighted avg       1.00      0.99      0.99     72109



In [95]:
print(classification_report(y_test, clf.predict(x_test)))

              precision    recall  f1-score   support

           0       0.97      0.99      0.98     21870
           1       0.88      0.69      0.77      2167

    accuracy                           0.96     24037
   macro avg       0.92      0.84      0.88     24037
weighted avg       0.96      0.96      0.96     24037



In [71]:
confusion_matrix = confusion_matrix(list(y_test), clf.predict(x_test))
ConfusionMatrixDisplay(confusion_matrix, display_labels=[0, 1]).plot()

plt.show()

TypeError: 'numpy.ndarray' object is not callable

In [54]:
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree

In [59]:
clf1 = DecisionTreeClassifier(max_depth=10, min_samples_leaf=4)
clf1.fit(
    data.drop(columns='diabetes'),
    data['diabetes']
)

In [60]:
cnt = 0

In [None]:
cnt += 1
fig = plt.figure(figsize=(25,20))
_ = tree.plot_tree(clf1,
                   feature_names=data.columns[:-1],  
                   class_names=data.columns[-1],
                   filled=True)
fig.savefig(f'decisionTree{cnt}.png', dpi=600)

In [64]:
data.drop(columns=['gender', 'heart_disease'], inplace=True)

In [109]:
data = tmp.copy()

In [72]:
from sklearn.model_selection import RandomizedSearchCV

In [75]:
params = {
        'criterion': ["gini", "entropy", "log_loss"],
        'max_depth': list(range(5, 10)),
        'min_samples_leaf': list(range(4, 8)),
        'bootstrap': [True, False],
}
clf = RandomizedSearchCV(estimator=RandomForestClassifier(), param_distributions=params, random_state=random_state)

In [76]:
clf.fit(x_train, y_train)

In [78]:
print(classification_report(y_test, clf.predict(x_test)))


              precision    recall  f1-score   support

           0       0.97      1.00      0.98     21870
           1       1.00      0.67      0.80      2167

    accuracy                           0.97     24037
   macro avg       0.98      0.84      0.89     24037
weighted avg       0.97      0.97      0.97     24037



In [105]:
from sklearn.linear_model import LogisticRegressionCV

In [117]:
clf = LogisticRegressionCV(random_state=random_state)

In [118]:
clf.fit(x_train, y_train)

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

In [119]:
print(classification_report(y_test, clf.predict(x_test)))

              precision    recall  f1-score   support

           0       0.96      0.99      0.98     21870
           1       0.86      0.63      0.73      2167

    accuracy                           0.96     24037
   macro avg       0.91      0.81      0.85     24037
weighted avg       0.95      0.96      0.95     24037

