# CDC Diabetes Health Indicators

**Task:** This midterm project aims to build a service that predicts whether a patient is healthy or has prediabetes/diabetes, using the "Diabetes Health Indicators Dataset" provided by the CDC.


🔗 Dataset page: [CDC Diabetes Health Indicators](https://archive.ics.uci.edu/dataset/891/cdc+diabetes+health+indicators)

## Summary of dataset information

Information provided on the dataset page (see the link in the previous section):

**Dataset information**
- The "Diabetes Health Indicators Dataset" is available on the UCI Machine Learning Repository.
- Created to better understand the relationship between lifestyle and diabetes in the US.
- The CDC funded the creation of the dataset. 
- Cross validation or a fixed train-test split could be used for data splits. 
- The dataset contains sensitive data such as gender, income, and education level. 
- The dataset contains **21 feature variables (categorical and integer)** and **1 target variable (binary)**.
- Each row represents a person participating in the study. 
- During preprocessing, age was bucketed into categories. The dataset has no missing values.


**Quoted from the dataset page**<br>

> *"The Diabetes Health Indicators Dataset contains healthcare statistics and lifestyle survey information about people in general along with their diagnosis of diabetes. The 35 features consist of some demographics, lab test results, and answers to survey questions for each patient. The target variable for classification is whether a patient has diabetes, is pre-diabetic, or healthy."*

**Remark on the quote above**<br>
The quote states that the dataset contains 35 features. However, the dataset page further states:

>|  | Information |
>| :--- | :--- |
>| Dataset Characteristics | Tabular, Multivariate |
>| Subject Area | Life Science |
>| Associated Tasks | Classification |
>| Feature Type | Categorical, Integer |
>| \# Instances | 253680 |
>| \# Features | 21 |

💡 We will check this discrepancy when digging into the dataset.
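
One way to check the feature count is to count the variables per role in the variables table. The sketch below assumes that the table has a `role` column with values such as `ID`, `Target` and `Feature`; the toy table is hypothetical and only stands in for the real metadata.

```python
import pandas as pd


def count_roles(variables: pd.DataFrame) -> pd.Series:
    """Count how many variables have each role (assumes a 'role' column)."""
    return variables['role'].value_counts()


# hypothetical stand-in for the variables table, not the real metadata
toy_variables = pd.DataFrame({
    'name': ['ID', 'Diabetes_binary', 'HighBP', 'BMI'],
    'role': ['ID', 'Target', 'Feature', 'Feature'],
})
role_counts = count_roles(toy_variables)
print(role_counts)
```

Applied to the downloaded table (`count_roles(variables)`), the `Feature` count can be compared directly with the 21 and 35 stated on the dataset page.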

## Downloading the dataset

**Download of the dataset is provided via**

1. Python API using the `ucimlrepo` package.
    - Code provided by the dataset page:

        ```python
        from ucimlrepo import fetch_ucirepo 
        # fetch dataset 
        cdc_diabetes_health_indicators = fetch_ucirepo(id=891) 
        # data (as pandas dataframes) 
        X = cdc_diabetes_health_indicators.data.features 
        y = cdc_diabetes_health_indicators.data.targets 
        # metadata 
        print(cdc_diabetes_health_indicators.metadata) 
        # variable information 
        print(cdc_diabetes_health_indicators.variables) 
        ```
    - The metadata returned by the code above includes a download link for the dataset in CSV format. However, no additional information is provided along with the CSV file:<br>
    https://archive.ics.uci.edu/static/public/891/data.csv
    
1. On the project page [CDC Diabetes Health Indicators](https://archive.ics.uci.edu/dataset/891/cdc+diabetes+health+indicators) there is a reference to the dataset source, which redirects to a [Kaggle dataset](https://www.kaggle.com/datasets/alexteboul/diabetes-health-indicators-dataset).

<br>


**Data download used in this project**

- ✅ In this project the `ucimlrepo` package is used to download the dataset and all relevant metadata.

- 💾 The dataset's dataframe as well as all other relevant data will be stored in the [./data_cdc_diabetes_health_indicators](data_cdc_diabetes_health_indicators) folder locally. 

- 💡 The data is downloaded once with the `ucimlrepo` package and then reused from the local copy instead of being downloaded again. This keeps the project reproducible in case the dataset becomes unavailable or changes over time.
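
The download-once-then-reuse pattern can be sketched as a small cache helper. `load_or_fetch` and `fake_fetch` are hypothetical names introduced for illustration; in the notebook the fetch step would wrap `fetch_ucirepo(id=891)`.

```python
import os
import tempfile
from typing import Callable

import pandas as pd


def load_or_fetch(path_csv: str, fetch: Callable[[], pd.DataFrame]) -> pd.DataFrame:
    """Return the cached CSV if it exists; otherwise call `fetch` once and cache the result."""
    if os.path.exists(path_csv):
        return pd.read_csv(path_csv, index_col='ID')
    df = fetch()
    os.makedirs(os.path.dirname(path_csv) or '.', exist_ok=True)
    df.to_csv(path_csv)  # the 'ID' index is written as a regular column
    return df


# hypothetical fetch function standing in for the ucimlrepo download
calls = []


def fake_fetch() -> pd.DataFrame:
    calls.append(1)
    return pd.DataFrame({'ID': [0, 1], 'BMI': [40, 25]}).set_index('ID')


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'data_891.csv')
    first = load_or_fetch(path, fake_fetch)   # downloads and writes the cache
    second = load_or_fetch(path, fake_fetch)  # reads the cache, no second download
```

Only the first call triggers the (fake) download; the second call reads the cached file.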


**Further information on the dataset**

ℹ️ More information about the features becomes available after downloading the dataset and inspecting its metadata.

---

In [18]:
## Importing libraries
import os
import json
import pandas as pd
import numpy as np
import pickle
import matplotlib.pyplot as plt
import seaborn as sns
from ucimlrepo import fetch_ucirepo 
from sklearn.model_selection import train_test_split
from IPython.utils.capture import capture_output
from IPython.display import display

from lib.eda_functions import plot_boxplots, plot_histograms, plot_boxplots_normalize_df

In [19]:
path_dir = '/workspace/data_cdc_diabetes_health_indicators'
# data set id of the cdc diabetes health indicators
dataset_id = 891
file_name = f'data_{dataset_id}.csv'
path_data_csv = os.path.join(path_dir, file_name)

💡 The variables below were set to `True` only for the first run, to download all relevant data in the subsequent cell. Afterwards they can be set to `False` to avoid downloading the data again. To reproduce the results presented here, keep the default values (`False`) so that the already downloaded files are used.

In [None]:
# Data has already been downloaded and saved as JSON and CSV.
# Therefore both variables can be set to False.
# The value True was only used for the initial download.

# modify only in case of downloading again
download = False
# only used for writing additional information to local files 
write_files = False

In [None]:
if download:
    cdc_diabetes_health_indicators = fetch_ucirepo(id=dataset_id)
    
    # create dataframe from data and targets
    X = cdc_diabetes_health_indicators.data.features
    y = cdc_diabetes_health_indicators.data.targets
    ids = cdc_diabetes_health_indicators.data.ids
    df = pd.concat([ids, y, X], axis=1)
    # use the patient ID as index (consistent with reading the cached CSV below)
    df.set_index('ID', inplace=True)

    # read variables
    variables = cdc_diabetes_health_indicators.variables

    print("keys:\n", cdc_diabetes_health_indicators.keys())
    print("metadata:\n", cdc_diabetes_health_indicators.metadata)
    display(cdc_diabetes_health_indicators.variables)
    if write_files:
        # keep the index so that 'ID' is written and can be read back via index_col='ID'
        df.to_csv(path_data_csv)
        cdc_diabetes_health_indicators.variables.to_csv(
            os.path.join(path_dir, 'variables.csv'), index=False)
        # write the following to a json file
        json_data = {}
        json_data['keys'] = list(cdc_diabetes_health_indicators.keys())
        json_data['metadata_keys'] = list(cdc_diabetes_health_indicators.metadata.keys())
        json_data['data'] = {'headers': cdc_diabetes_health_indicators.data.headers.tolist(),}
        json_data['variables'] = 'see ./variables.csv for more information' 
        json_data['metadata'] = dict(cdc_diabetes_health_indicators.metadata)

        with open(os.path.join(path_dir, 'cdc_diabetes_health_indicators.json'), 'w') as f:
            f.write(json.dumps(json_data, indent=2))
else:
    df = pd.read_csv(path_data_csv, index_col='ID')
    variables = pd.read_csv(os.path.join(path_dir, 'variables.csv'))


In [21]:
df.head()

ID,Diabetes_binary,HighBP,HighChol,CholCheck,BMI,Smoker,Stroke,HeartDiseaseorAttack,PhysActivity,Fruits,...,AnyHealthcare,NoDocbcCost,GenHlth,MentHlth,PhysHlth,DiffWalk,Sex,Age,Education,Income
0,0,1,1,1,40,1,0,0,0,0,...,1,0,5,18,15,1,0,9,4,3
1,0,0,0,0,25,1,0,0,1,0,...,0,1,3,0,0,0,0,7,6,1
2,0,1,1,1,28,0,0,0,0,1,...,1,1,5,30,30,1,0,9,4,8
3,0,1,0,1,27,0,0,0,1,1,...,1,0,2,0,0,0,0,11,3,6
4,0,1,1,1,24,0,0,0,1,1,...,1,0,2,3,0,0,0,11,5,4


---
## Information about the dataset

### Information about dataset columns

The variable information displayed below shows that the target column is `Diabetes_binary`.

In [None]:
display(variables)

### Variables (Target and Features)

| ID | Type | Description | 
| --- | --- | --- |
| ID | Integer | Patient ID |

<br>

| Target | Type | Description | 
| --- | --- | --- |
| Diabetes_binary | Binary | 0 = no diabetes<br>1 = prediabetes or diabetes |

Features are grouped in the table below by data type:
- Integer
- Categorical
- Binary

> Information for binary features (except for feature `Sex`):
> - `0` = `no` 
> - `1` = `yes`

| Features | Type | Description | 
| --- | --- | --- |
| BMI | Integer | Body Mass Index |
| MentHlth | Integer | Now thinking about your mental health, which includes stress, depression, and problems with emotions, for how many days during the past 30 days was your mental health not good? scale 0-30 days |
| PhysHlth | Integer | Now thinking about your physical health, which includes physical illness and injury, for how many days during the past 30 days was your physical health not good? scale 0-30 days |
|  |  |  |
| GenHlth | Integer (Categorical) | Would you say that in general your health is: scale 1-5<br>1 = excellent<br>2 = very good<br>3 = good<br>4 = fair<br>5 = poor |
| Age | Integer (Categorical) | Age, 13-level age category (_AGEG5YR, see codebook)<br>1 = 18-24<br>9 = 60-64<br>13 = 80 or older |
| Education | Integer (Categorical) | Education level (EDUCA, see codebook), scale 1-6<br>1 = Never attended school or only kindergarten<br>2 = Grades 1 through 8 (Elementary)<br>3 = Grades 9 through 11 (Some high school)<br>4 = Grade 12 or GED (High school graduate)<br>5 = College 1 year to 3 years (Some college or technical school)<br>6 = College 4 years or more (College graduate) |
| Income | Integer (Categorical) | Income scale (INCOME2, see codebook), scale 1-8<br>1 = less than $10,000<br>5 = less than $35,000<br>8 = $75,000 or more |
|  |  |  |
| Sex | Binary | Sex: 0 = female, 1 = male |
| HighBP | Binary | High blood pressure |
| HighChol | Binary | High cholesterol |
| CholCheck | Binary | Cholesterol check in 5 years |
| Smoker | Binary | Have you smoked at least 100 cigarettes in your entire life? [Note: 5 packs = 100 cigarettes] |
| Stroke | Binary | (Ever told) you had a stroke |
| HeartDiseaseorAttack | Binary | Coronary heart disease (CHD) or myocardial infarction (MI) |
| PhysActivity | Binary | Physical activity in past 30 days, not including job |
| Fruits | Binary | Consume fruit 1 or more times per day |
| Veggies | Binary | Consume vegetables 1 or more times per day |
| HvyAlcoholConsump | Binary | Heavy drinkers (adult men having more than 14 drinks per week and adult women having more than 7 drinks per week) |
| AnyHealthcare | Binary | Have any kind of health care coverage, including health insurance, prepaid plans such as HMO, etc. |
| NoDocbcCost | Binary | Was there a time in the past 12 months when you needed to see a doctor but could not because of cost? |
| DiffWalk | Binary | Do you have serious difficulty walking or climbing stairs? |
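
The table lists only a few of the 13 age codes. Assuming the usual 5-year buckets between the listed anchors (2 = 25-29, ..., 12 = 75-79), the codes can be mapped back to readable ranges with a small helper; `age_label` is a hypothetical name, not part of the dataset's tooling.

```python
def age_label(code: int) -> str:
    """Map a 13-level age code to its age range (assumes 5-year buckets between 18-24 and 80+)."""
    if not 1 <= code <= 13:
        raise ValueError(f"age code must be in 1..13, got {code}")
    if code == 1:
        return '18-24'
    if code == 13:
        return '80 or older'
    lower = 5 * code + 15  # code 2 -> 25, code 9 -> 60, code 12 -> 75
    return f'{lower}-{lower + 4}'


print([age_label(c) for c in (1, 2, 9, 12, 13)])
```

The helper reproduces the three anchors given in the table (1, 9 and 13); the intermediate ranges rest on the 5-year assumption and should be checked against the codebook.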




In [20]:
df.info()

<class 'pandas.core.frame.DataFrame'>
Index: 253680 entries, 0 to 253679
Data columns (total 22 columns):
 #   Column                Non-Null Count   Dtype
---  ------                --------------   -----
 0   Diabetes_binary       253680 non-null  int64
 1   HighBP                253680 non-null  int64
 2   HighChol              253680 non-null  int64
 3   CholCheck             253680 non-null  int64
 4   BMI                   253680 non-null  int64
 5   Smoker                253680 non-null  int64
 6   Stroke                253680 non-null  int64
 7   HeartDiseaseorAttack  253680 non-null  int64
 8   PhysActivity          253680 non-null  int64
 9   Fruits                253680 non-null  int64
 10  Veggies               253680 non-null  int64
 11  HvyAlcoholConsump     253680 non-null  int64
 12  AnyHealthcare         253680 non-null  int64
 13  NoDocbcCost           253680 non-null  int64
 14  GenHlth               253680 non-null  int64
 15  MentHlth              253680 non-null  

The data type of all columns is `int64`.

Creating variables for the target name and for the feature columns grouped by data type:
- feature_names_binary
- feature_names_categorical
- feature_names_integer

💡 The data type can stay `int64` for all columns. One-hot encoding will be applied to all categorical columns later.
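
The notebook does not yet specify which encoder will be used. As one option, `pandas.get_dummies` creates one indicator column per observed category and leaves the binary and integer columns untouched; the toy frame below is illustrative only.

```python
import pandas as pd

# toy frame: GenHlth is categorical (1-5), HighBP is binary, BMI is integer
df_toy = pd.DataFrame({
    'GenHlth': [1, 3, 5, 3],
    'HighBP': [0, 1, 1, 0],
    'BMI': [22, 31, 27, 40],
})
categorical = ['GenHlth']
# one indicator column per observed category; other columns are kept as-is
df_encoded = pd.get_dummies(df_toy, columns=categorical, prefix=categorical, dtype=int)
print(df_encoded.columns.tolist())
```

Since `get_dummies` only sees the categories present in the frame it is given, an encoder fitted on the `train` split (e.g. scikit-learn's `OneHotEncoder` with `handle_unknown='ignore'`) is usually the safer choice for keeping `train`, `val` and `test` columns consistent.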

In [None]:
target_name = 'Diabetes_binary'
feature_names_binary = [
    'Sex', 
    'HighBP', 
    'HighChol', 
    'CholCheck', 
    'Smoker', 
    'Stroke', 
    'HeartDiseaseorAttack',
    'PhysActivity',
    'Fruits',
    'Veggies',
    'HvyAlcoholConsump',
    'AnyHealthcare',
    'NoDocbcCost',
    'DiffWalk',
]
feature_names_integer = ['BMI']
feature_names_categorical = ['GenHlth', 'MentHlth', 'PhysHlth', 'Age', 'Education', 'Income']

feature_names = feature_names_binary + feature_names_integer + feature_names_categorical

Using assertions to check that the data columns have been split correctly into `Binary`, `Categorical` and `Integer` features:

- Checking that the total number of entries in the `feature_names_<type>` lists plus the target equals the number of columns in the dataframe
- Checking that the `feature_names_<type>` lists do not overlap
- Checking that the `feature_names_<type>` lists and the target cover all columns of the dataframe

In [None]:
# create assert in case column numbers do not match
len_all_features = len(feature_names_binary) + len(feature_names_integer) + len(feature_names_categorical)
assert df.columns.size == len_all_features + 1, \
    f'Number of columns in dataframe ({df.columns.size}) does not match the number of features ({len_all_features + 1})'

In [None]:
# create asserts in case names overlap
set_binary_intersect_integer = set(feature_names_binary).intersection(set(feature_names_integer))
set_binary_intersect_categorical = set(feature_names_binary).intersection(set(feature_names_categorical))
set_integer_intersect_categorical = set(feature_names_integer).intersection(set(feature_names_categorical))

# the target must not be listed among the features
assert target_name not in feature_names, \
    f'Target {target_name} is listed among the features'

assert set_binary_intersect_integer == set(), \
    f'Features overlap between binary and integer features: {set_binary_intersect_integer}'
assert set_binary_intersect_categorical == set(), \
    f'Features overlap between binary and categorical features: {set_binary_intersect_categorical}'
assert set_integer_intersect_categorical == set(), \
    f'Features overlap between integer and categorical features: {set_integer_intersect_categorical}'

In [None]:
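# create asserts in case the dataframe columns and the listed names differ
set_diff = set(df.columns).difference(set([target_name] + feature_names))
assert set_diff == set(), \
    f'Dataframe contains columns not covered by the other variables: {set_diff}'
set_missing = set([target_name] + feature_names).difference(set(df.columns))
assert set_missing == set(), \
    f'Listed columns are missing from the dataframe: {set_missing}'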

---

## Data Set Information

Checking dataset for
- unique values
- missing values
- duplicates

In [None]:
df.nunique()

The dataset page states that there are no missing values; we verify this.

In [None]:
indices_issnull = df.isnull().any()
sum_of_issnull_cols = df.loc[:, indices_issnull].isnull().sum()
frac_of_isnull_cols = sum_of_issnull_cols / len(df)
print("Missing data - absolute values:")
print("---")
print(sum_of_issnull_cols)
print("\n===\n")
print("Missing data - percent")
print("---")
print((frac_of_isnull_cols * 100).round(decimals=1))


✅ There is no missing data.

The patient ID is used as the dataframe index. We now check whether there are duplicate rows (ignoring the ID) and explain how we handle them.
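
`DataFrame.duplicated` compares column values only and ignores the index, so two patients with different IDs but identical values in every column count as duplicates. A minimal toy example:

```python
import pandas as pd

# two different patient IDs (10 and 11) with identical target and feature values
df_toy = pd.DataFrame(
    {'Diabetes_binary': [0, 0, 1], 'HighBP': [1, 1, 0], 'BMI': [27, 27, 27]},
    index=pd.Index([10, 11, 12], name='ID'),
)
mask = df_toy.duplicated(keep='first')  # the index (ID) is not part of the comparison
print(mask.tolist())
print(df_toy.index.is_unique)
```

The second row is flagged although all IDs are unique.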

In [22]:
df_original = df.copy(deep=True)

In [23]:
indices_duplicates = df_original.duplicated(keep='first')
num_total = len(df_original)
num_duplicates = indices_duplicates.sum()
print(f"Total number of rows: {num_total:7d}")
print(f"Number of duplicates: {num_duplicates:7d}")
print(f"Percentage of duplicates: {num_duplicates/num_total:.2%}")

Total number of rows:  253680
Number of duplicates:   24206
Percentage of duplicates: 9.54%


**Observation:**  
Almost 10% of the rows in the dataset are duplicates.

**Explanation:**
💡 Except for the `BMI` feature, which is stored as `Integer`, the dataset consists only of `Binary` and `Categorical` features. The duplicate rows are therefore most likely not duplicate patients, but different patients who happen to share the same values for all features because of these coarse data types.

**Handling:**
We now check whether the patient IDs are unique by comparing the index length with the number of unique patient IDs.

In [33]:
# True if every patient ID occurs only once
df_original.index.size == df_original.index.nunique()

True

In [None]:
df.drop_duplicates(keep='first', inplace=True)
print(f"Number of rows after dropping duplicates: {len(df):7d}")

---

## Splitting into Train and Test
- ℹ️ If no separate test dataset is available, the data must be split into `train`, `val`, and `test` before analyzing it any further:
    - 60% `train`
    - 20% `val`
    - 20% `test`

- 💡 A fixed seed (`random_state`) makes the split reproducible.
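
The 60/20/20 split is done in two calls of `train_test_split`. In the second call, the validation fraction has to be expressed relative to what remains after removing the test split, i.e. `frac_val / (frac_train + frac_val) = 0.2 / 0.8 = 0.25`. A minimal check on a toy frame of 1000 rows:

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

frac_train, frac_val, frac_test = 0.6, 0.2, 0.2
# rescale the validation share to the part left after the first split
test_size_second = frac_val / (frac_train + frac_val)  # 0.2 / 0.8 = 0.25

df_toy = pd.DataFrame({'x': np.arange(1000)})
df_rest, df_test_toy = train_test_split(df_toy, test_size=frac_test, random_state=42)
df_train_toy, df_val_toy = train_test_split(df_rest, test_size=test_size_second, random_state=42)
print(len(df_train_toy), len(df_val_toy), len(df_test_toy))
```

Passing `frac_val / frac_train` (0.333) instead would produce roughly a 53/27/20 split.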

In [None]:
seed = 42
frac_train = 0.6
frac_val = 0.2
frac_test = 0.2
df_train, df_test = train_test_split(df, test_size=frac_test, random_state=seed, shuffle=True)
# frac_val refers to the full dataset, so rescale it to the remaining train+val part
df_train, df_val = train_test_split(df_train, test_size=frac_val/(frac_train + frac_val), random_state=seed, shuffle=True)

overall_len = len(df_train) + len(df_val) + len(df_test)
assert overall_len == len(df)

In [None]:
indices_issnull = df_train.isnull().any()
sum_of_issnull_cols = df_train.loc[:,indices_issnull].isnull().sum()
frac_of_isnull_cols = sum_of_issnull_cols/len(df_train)

---

## Exploratory Data Analysis (EDA) of the train split

Since the `test` split must never be inspected, the EDA is performed on the `train` split only.

In [None]:
ax = df_train.hist(figsize=(10, 7), bins=20)
plt.tight_layout()
plt.show()

⚠️ The target column `Diabetes_binary` shows that the data is imbalanced: there are considerably more people without diabetes than with prediabetes or diabetes. This is important to keep in mind when evaluating the model.
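
With imbalanced classes, plain accuracy can look good for a useless model. A majority-class baseline makes this concrete; the 80/20 ratio below is a toy value, not the real class ratio.

```python
import numpy as np
from sklearn.dummy import DummyClassifier
from sklearn.metrics import accuracy_score, recall_score

# toy imbalanced target: 80% class 0, 20% class 1
y_toy = np.array([0] * 80 + [1] * 20)
X_toy = np.zeros((len(y_toy), 1))  # features are ignored by the dummy model

baseline = DummyClassifier(strategy='most_frequent').fit(X_toy, y_toy)
y_pred_baseline = baseline.predict(X_toy)
acc = accuracy_score(y_toy, y_pred_baseline)
rec = recall_score(y_toy, y_pred_baseline)
print(acc, rec)
```

The baseline reaches an accuracy equal to the share of the majority class while never detecting a single positive case, so per-class recall or ROC/PR curves are more informative here.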

Furthermore, most binary columns are dominated by one of their two values, and the remaining variables are not normally distributed. This is important to keep in mind when choosing the model.

In [None]:
# correlation matrix, abbreviate the column names for better readability
corr_matrix = df_train.corr()
# set upper triangle to nan including the diagonal
mask = np.triu(np.ones_like(corr_matrix, dtype=bool))
corr_matrix = corr_matrix.mask(mask)
corr_matrix_short_names = corr_matrix.rename(columns=lambda x: f'{x[:5]}...' if len(x) > 8 else x)
display(corr_matrix_short_names.style.background_gradient(cmap='seismic', axis=None, vmin=-1, vmax=1))


In [None]:
# plot correlation matrix using seaborn
fig, ax = plt.subplots(figsize=(15,10))
sns.heatmap(corr_matrix, annot=True, ax=ax, cmap='seismic', fmt='.2f', vmin=-1, vmax=1)
plt.show()

Display features with highest pairwise correlation.
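
A compact way to rank feature pairs is to index the strict upper triangle of the correlation matrix directly with `np.triu_indices`, which lists every pair once and skips the diagonal. `top_abs_corr_pairs` is a hypothetical helper shown on toy data.

```python
import numpy as np
import pandas as pd


def top_abs_corr_pairs(df: pd.DataFrame, n: int = 10) -> pd.DataFrame:
    """Return the n column pairs with the largest absolute Pearson correlation."""
    corr = df.corr()
    names = np.asarray(corr.columns)
    # strict upper triangle: each pair once, diagonal excluded
    i, j = np.triu_indices(len(names), k=1)
    pairs = pd.DataFrame({
        'feature_1': names[i],
        'feature_2': names[j],
        'abs_corr': np.abs(corr.to_numpy()[i, j]),
    })
    return pairs.sort_values('abs_corr', ascending=False).head(n).reset_index(drop=True)


rng = np.random.default_rng(0)
a = rng.normal(size=200)
df_toy = pd.DataFrame({'a': a, 'b': 2 * a + 1, 'c': rng.normal(size=200)})
result = top_abs_corr_pairs(df_toy, n=3)
print(result)
```

The perfectly linear pair (`a`, `b`) comes first with an absolute correlation of 1.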

In [None]:
# list all pairs of columns by their absolute correlation value, add the absolute correlation value
corr_matrix_abs = corr_matrix.abs()

corr_matrix_abs = corr_matrix_abs.unstack()
corr_matrix_abs = corr_matrix_abs.sort_values(ascending=False)
# corr_matrix_abs = corr_matrix_abs[corr_matrix_abs != 1]
corr_matrix_abs = corr_matrix_abs.reset_index()
corr_matrix_abs.columns = ['feature_1', 'feature_2', 'abs_corr']
corr_matrix_abs.dropna(how='any', axis=0, inplace=True)
corr_matrix_abs.head(10)


In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    accuracy_score, auc, average_precision_score, classification_report, confusion_matrix,
    precision_recall_curve, roc_curve,
    ConfusionMatrixDisplay, PrecisionRecallDisplay, RocCurveDisplay,
)
from sklearn.inspection import permutation_importance


In [None]:
def print_metrics(y_val, y_pred):
    # overall accuracy
    accuracy = accuracy_score(y_val, y_pred)
    print("Accuracy:", accuracy)
    # calculate accuracy for each of the two classes separately (equal to the per-class recall)
    accuracy_class_0 = accuracy_score(y_val[y_val==0], y_pred[y_val==0])
    accuracy_class_1 = accuracy_score(y_val[y_val==1], y_pred[y_val==1])
    print("Accuracy class 0:", accuracy_class_0)
    print("Accuracy class 1:", accuracy_class_1)
    # print the classification report
    print("Classification report:\n", classification_report(y_val, y_pred))
    # print the confusion matrix
    print("Confusion matrix:\n", confusion_matrix(y_val, y_pred))
    print("Confusion matrix normalized:\n", confusion_matrix(y_val, y_pred, normalize='true'))

def plot_confusion_matrix(y_val, y_pred):
    # plot confusion matrix and normalized confusion matrix side by side
    fig, ax = plt.subplots(1,2, figsize=(10,5))
    cm = confusion_matrix(y_val, y_pred)
    cm_norm = confusion_matrix(y_val, y_pred, normalize='true')
    ax[0].set_title('Confusion Matrix')
    ax[1].set_title('Normalized Confusion Matrix')
    ConfusionMatrixDisplay(cm).plot(ax=ax[0])
    ConfusionMatrixDisplay(cm_norm).plot(ax=ax[1])
    plt.show()

def plot_roc_and_precision_recall_curve(y_val, y_pred_proba):
    # plot roc curve and precision recall curve side by side
    fig, ax = plt.subplots(1,2, figsize=(10,4))
    fpr, tpr, thresholds = roc_curve(y_val, y_pred_proba)
    roc_auc = auc(fpr, tpr)
    RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc).plot(ax=ax[0])
    ax[0].plot([0, 1], [0, 1], 'k--')
    ax[0].set_title('ROC Curve')
    ax[0].legend(loc="lower right")

    precision, recall, thresholds = precision_recall_curve(y_val, y_pred_proba)
    average_precision = average_precision_score(y_val, y_pred_proba)
    PrecisionRecallDisplay(precision=precision, recall=recall, average_precision=average_precision).plot(ax=ax[1])
    ax[1].set_title('Precision Recall Curve')
    ax[1].legend(loc="lower left")
    plt.show()

def plot_feature_importance(model, X_train, y_train):
    # plot the feature importance
    result = permutation_importance(model, X_train, y_train, n_repeats=10, random_state=seed)
    sorted_idx = result.importances_mean.argsort()
    fig, ax = plt.subplots(figsize=(12,5))
    ax.boxplot(result.importances[sorted_idx].T, vert=False, labels=X_train.columns[sorted_idx])
    ax.set_title("Permutation Importances (train set)")
    fig.tight_layout()
    plt.show()
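
The per-class accuracies printed by `print_metrics` are the per-class recall values that also appear in the classification report. A small toy check:

```python
import numpy as np
from sklearn.metrics import accuracy_score, recall_score

y_true_toy = np.array([0, 0, 0, 1, 1, 1, 1, 0])
y_pred_toy = np.array([0, 1, 0, 1, 0, 0, 1, 0])

# accuracy restricted to the samples of one class ...
acc_0 = accuracy_score(y_true_toy[y_true_toy == 0], y_pred_toy[y_true_toy == 0])
acc_1 = accuracy_score(y_true_toy[y_true_toy == 1], y_pred_toy[y_true_toy == 1])
# ... equals the recall of that class
rec_0 = recall_score(y_true_toy, y_pred_toy, pos_label=0)
rec_1 = recall_score(y_true_toy, y_pred_toy, pos_label=1)
print(acc_0, rec_0, acc_1, rec_1)
```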

In [None]:
# training a logistic regression model

# define the target column
target_column = 'Diabetes_binary'
# define the feature columns
feature_columns = [col for col in df.columns if col != target_column]

# select the features and target of the train and validation splits
X_train = df_train[feature_columns]
y_train = df_train[target_column]
X_val = df_val[feature_columns]
y_val = df_val[target_column]

class_weight = len(y_train) / np.bincount(y_train)
class_weight = {0: class_weight[0], 1: class_weight[1]}
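
The manual weights `len(y_train) / np.bincount(y_train)` have the same class ratio as scikit-learn's `class_weight='balanced'`, which uses `n_samples / (n_classes * bincount)`. For two classes they differ only by a factor of 2. A toy comparison:

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

y_toy = np.array([0] * 90 + [1] * 10)  # toy imbalanced labels

manual = len(y_toy) / np.bincount(y_toy)  # n / n_c
balanced = compute_class_weight(class_weight='balanced', classes=np.array([0, 1]), y=y_toy)  # n / (k * n_c)
print(manual, balanced)
```

Because `LogisticRegression` multiplies the weighted loss by `C`, scaling every weight by 2 has the same effect as doubling `C`, so the `'balanced'` run and the manual-weight run are not exactly identical.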

In [None]:
# create the model
model = LogisticRegression(random_state=seed, max_iter=1000)
# train the model
model.fit(X_train, y_train)
# predict the target values
y_pred = model.predict(X_val)
y_pred_proba = model.predict_proba(X_val)[:,1]

print_metrics(y_val, y_pred)
plot_confusion_matrix(y_val, y_pred)
plot_roc_and_precision_recall_curve(y_val, y_pred_proba)
plot_feature_importance(model, X_train, y_train)

---

In [None]:
# create the model
model = LogisticRegression(random_state=seed, max_iter=1000, class_weight='balanced')
# train the model
model.fit(X_train, y_train)
# predict the target values
y_pred = model.predict(X_val)
y_pred_proba = model.predict_proba(X_val)[:,1]

print_metrics(y_val, y_pred)
plot_confusion_matrix(y_val, y_pred)
plot_roc_and_precision_recall_curve(y_val, y_pred_proba)
plot_feature_importance(model, X_train, y_train)

----

In [None]:

# create the model
model = LogisticRegression(random_state=seed, max_iter=1000, class_weight=class_weight)
# train the model
model.fit(X_train, y_train)
# predict the target values
y_pred = model.predict(X_val)
y_pred_proba = model.predict_proba(X_val)[:,1]

print_metrics(y_val, y_pred)
plot_confusion_matrix(y_val, y_pred)
plot_roc_and_precision_recall_curve(y_val, y_pred_proba)
plot_feature_importance(model, X_train, y_train)

In [None]:
# plot roc curve and precision recall curve side by side, with class-balanced sample weights
fig, ax = plt.subplots(1,2, figsize=(10,4))
# determine per-sample weights: the samples of each class sum to a total weight of 1
sample_weight = np.where(y_val == 1, 1/np.sum(y_val == 1), 1/np.sum(y_val == 0))

fpr, tpr, thresholds = roc_curve(y_val, y_pred_proba, sample_weight=sample_weight)
roc_auc = auc(fpr, tpr)
RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc).plot(ax=ax[0])
ax[0].plot([0, 1], [0, 1], 'k--')
ax[0].set_title('ROC Curve')
ax[0].legend(loc="lower right")

precision, recall, thresholds = precision_recall_curve(y_val, y_pred_proba, sample_weight=sample_weight)
average_precision = average_precision_score(y_val, y_pred_proba, sample_weight=sample_weight)
PrecisionRecallDisplay(precision=precision, recall=recall, average_precision=average_precision).plot(ax=ax[1])
ax[1].set_title('Precision Recall Curve')
ax[1].legend(loc="lower left")
plt.show()


sample_weight
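
Weighting each class uniformly leaves the ROC curve and its AUC unchanged, since the false and true positive rates are normalized within each class. Precision, and therefore average precision, does change because it mixes both classes. A toy check with random scores:

```python
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score

rng = np.random.default_rng(0)
y_toy = np.array([0] * 90 + [1] * 10)
scores = rng.random(len(y_toy)) + 0.3 * y_toy  # toy scores, positives shifted up

# one weight per class so that the samples of each class sum to 1
weights = np.where(y_toy == 1, 1 / np.sum(y_toy == 1), 1 / np.sum(y_toy == 0))

auc_plain = roc_auc_score(y_toy, scores)
auc_weighted = roc_auc_score(y_toy, scores, sample_weight=weights)
ap_plain = average_precision_score(y_toy, scores)
ap_weighted = average_precision_score(y_toy, scores, sample_weight=weights)
print(auc_plain, auc_weighted, ap_plain, ap_weighted)
```

Upweighting the minority class raises precision at every threshold, so the weighted average precision is at least as large as the unweighted one.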