In [2]:
# Import libraries
from datetime import timedelta
import os

import numpy as np
import pandas as pd
import re
import matplotlib.pyplot as plt
import matplotlib.dates as mdates

from IPython.display import display, HTML, Image
%matplotlib inline

plt.style.use('ggplot')
plt.rcParams.update({'font.size': 20})

# Access data using Google BigQuery.
from google.colab import auth
from google.cloud import bigquery

In [3]:
# authenticate
auth.authenticate_user()

In [4]:
# Set up environment variables
project_id = 'handy-tiger-432207-n6'
if project_id == 'CHANGE-ME':
  raise ValueError('You must change project_id to your GCP project.')
os.environ["GOOGLE_CLOUD_PROJECT"] = project_id

# Read data from BigQuery into pandas dataframes.
def run_query(query, project_id=project_id):
  return pd.io.gbq.read_gbq(
      query,
      project_id=project_id,
      dialect='standard')

# set the dataset
# if you want to use the demo, change this to mimic_demo
dataset = 'mimiciv'

In [5]:
## Identify Hypertensive Patients (Use icd code)
query = f"""
SELECT DISTINCT
    d.subject_id,
    d.hadm_id,
    d.icd_code,
    a.admittime,
    a.dischtime,
    a.deathtime
FROM
    physionet-data.mimiciv_hosp.diagnoses_icd d
JOIN
    physionet-data.mimiciv_hosp.admissions a ON d.hadm_id = a.hadm_id
WHERE
    d.icd_code IN ('I10', 'I11', 'I12', 'I13', 'I15')

"""

# Execute the query and get the results
pt = run_query(query)
print(f'Total Records: {len(pt)}')

# Display the first few rows of the result
print(pt.head(10))

Total Records: 51704
   subject_id   hadm_id icd_code           admittime           dischtime  \
0    10106244  26713233      I10 2147-05-09 10:34:00 2147-05-12 13:43:00   
1    15443666  27961368      I10 2168-12-30 23:30:00 2169-01-05 16:02:00   
2    16073738  28380412      I10 2149-11-18 14:06:00 2149-11-19 13:30:00   
3    16908360  27443837      I10 2115-07-22 17:04:00 2115-07-24 10:00:00   
4    16111436  26234204      I10 2138-04-01 02:48:00 2138-04-03 14:00:00   
5    13848026  21165376      I10 2116-12-13 18:08:00 2116-12-22 14:31:00   
6    14746577  20661936      I10 2149-11-28 21:32:00 2149-11-30 16:28:00   
7    10586065  26668795      I10 2190-12-27 14:18:00 2190-12-30 15:27:00   
8    10626477  20688698      I10 2174-07-19 04:34:00 2174-07-23 19:57:00   
9    10874066  20626767      I10 2157-08-12 17:28:00 2157-09-09 11:35:00   

  deathtime  
0       NaT  
1       NaT  
2       NaT  
3       NaT  
4       NaT  
5       NaT  
6       NaT  
7       NaT  
8       NaT  
9 

In [6]:
## Identify Hypertensive Patients (Use blood pressure condition or icd code)
query = f"""
WITH icd_based_hypertension AS (
    -- Select patients based on hypertension-related ICD codes
    SELECT DISTINCT
        d.subject_id,
        d.hadm_id,
        a.admittime,
        a.dischtime,
        a.deathtime,
        'ICD_based' AS hypertension_criteria
    FROM
        `physionet-data.mimiciv_hosp.diagnoses_icd` d
    JOIN
        `physionet-data.mimiciv_hosp.admissions` a ON d.hadm_id = a.hadm_id
    WHERE
        d.icd_code IN ('I10', 'I11', 'I12', 'I13', 'I15')
),
bp_based_hypertension AS (
    -- Select patients based on blood pressure measurements indicating hypertension
    SELECT DISTINCT
        v.subject_id,
        v.hadm_id,
        a.admittime,
        a.dischtime,
        a.deathtime,
        'BP_based' AS hypertension_criteria
    FROM
        `physionet-data.mimiciv_icu.chartevents` v
    JOIN
        `physionet-data.mimiciv_hosp.admissions` a ON v.hadm_id = a.hadm_id
    WHERE
        (v.itemid IN (220179, 223751) AND v.valuenum >= 130)  -- Systolic BP (SBP) >= 130 mmHg
        OR
        (v.itemid IN (220180, 223752) AND v.valuenum >= 80)   -- Diastolic BP (DBP) >= 80 mmHg
)
-- Combine both criteria to identify hypertensive patients
SELECT
    subject_id,
    hadm_id,
    admittime,
    dischtime,
    deathtime,
    hypertension_criteria
FROM
    icd_based_hypertension

UNION ALL

SELECT
    subject_id,
    hadm_id,
    admittime,
    dischtime,
    deathtime,
    hypertension_criteria
FROM
    bp_based_hypertension
ORDER BY
    subject_id, admittime;
"""

# Execute the query and get the results
pt = run_query(query)
print(f'Total Records: {len(pt)}')

# Display the first few rows of the result
print(pt.head(10))

Total Records: 116069
   subject_id   hadm_id           admittime           dischtime deathtime  \
0    10000032  29079034 2180-07-23 12:35:00 2180-07-25 17:55:00       NaT   
1    10000980  26913865 2189-06-27 07:38:00 2189-07-03 03:00:00       NaT   
2    10001217  24597018 2157-11-18 22:56:00 2157-11-25 18:00:00       NaT   
3    10001217  27703517 2157-12-18 16:58:00 2157-12-24 14:55:00       NaT   
4    10001401  21544441 2131-06-04 00:00:00 2131-06-15 16:10:00       NaT   
5    10001401  26840593 2131-06-19 21:32:00 2131-07-02 18:18:00       NaT   
6    10001401  24818636 2131-07-30 21:40:00 2131-08-04 14:10:00       NaT   
7    10001401  27060146 2131-10-01 01:33:00 2131-10-05 15:45:00       NaT   
8    10001401  28058085 2131-11-13 23:15:00 2131-11-15 15:16:00       NaT   
9    10001401  27012892 2133-07-09 22:22:00 2133-07-13 18:43:00       NaT   

  hypertension_criteria  
0              BP_based  
1              BP_based  
2              BP_based  
3              BP_based  


In [7]:
## Track Subsequent Stroke Diagnosis (Use just icd code to define hypertensive patients)
query = f"""
WITH hypertensive_patients AS (
    SELECT DISTINCT
        d.subject_id,
        d.hadm_id,
        a.admittime AS hypertensive_admittime,
        a.dischtime AS hypertensive_dischtime
    FROM
        physionet-data.mimiciv_hosp.diagnoses_icd d
    JOIN
        physionet-data.mimiciv_hosp.admissions a ON d.hadm_id = a.hadm_id
    WHERE
        d.icd_code IN ('I10', 'I11', 'I12', 'I13', 'I15')
),
stroke_events AS (
    SELECT DISTINCT
        d.subject_id,
        d.hadm_id,
        d.icd_code AS stroke_icd_code,
        a.admittime AS stroke_admittime
    FROM
        physionet-data.mimiciv_hosp.diagnoses_icd d
    JOIN
        physionet-data.mimiciv_hosp.admissions a ON d.hadm_id = a.hadm_id
    WHERE
        d.icd_code IN ('I639', 'I64', 'I619', 'I679', 'I694')
)
SELECT
    hp.subject_id,
    hp.hadm_id AS hypertensive_hadm_id,
    hp.hypertensive_admittime,
    se.hadm_id AS stroke_hadm_id,
    se.stroke_admittime,
    se.stroke_icd_code
FROM
    hypertensive_patients hp
LEFT JOIN
    stroke_events se ON hp.subject_id = se.subject_id
WHERE
    se.stroke_admittime > hp.hypertensive_admittime
    AND se.stroke_admittime IS NOT NULL
"""

# Execute the query and get the results
pt = run_query(query)
print(f'Total Records: {len(pt)}')

# Display the first few rows of the result
print(pt.head(10))

Total Records: 327
   subject_id  hypertensive_hadm_id hypertensive_admittime  stroke_hadm_id  \
0    10056612              26462956    2189-08-28 18:52:00        24412612   
1    17729814              28283153    2123-03-08 08:14:00        28954621   
2    13802667              25313546    2124-04-16 08:00:00        22571100   
3    14456616              20516052    2153-06-02 16:49:00        22484462   
4    19585869              22816145    2151-05-01 16:31:00        22293542   
5    11582633              20586686    2142-11-16 23:59:00        24017176   
6    14911129              27965129    2140-06-10 21:34:00        28899456   
7    17918100              22670417    2173-03-24 03:58:00        25625919   
8    19357366              28153964    2185-07-12 20:41:00        27799668   
9    12590289              25100469    2131-03-16 01:50:00        24483581   

     stroke_admittime stroke_icd_code  
0 2191-01-01 00:35:00            I639  
1 2124-03-18 17:05:00            I639  
2 

In [8]:
## Track Subsequent Stroke Diagnosis (Use blood pressure condition or icd code to define hypertensive patients)
query = f"""
WITH icd_based_hypertension AS (
    -- Select patients based on ICD codes
    SELECT DISTINCT
        d.subject_id,
        d.hadm_id,
        a.admittime AS hypertensive_admittime,
        a.dischtime AS hypertensive_dischtime,
        'ICD_based' AS hypertension_criteria
    FROM
        `physionet-data.mimiciv_hosp.diagnoses_icd` d
    JOIN
        `physionet-data.mimiciv_hosp.admissions` a ON d.hadm_id = a.hadm_id
    WHERE
        d.icd_code IN ('I10', 'I11', 'I12', 'I13', 'I15')
),
bp_based_hypertension AS (
    -- Select patients based on elevated BP measurements
    SELECT DISTINCT
        v.subject_id,
        v.hadm_id,
        a.admittime AS hypertensive_admittime,
        a.dischtime AS hypertensive_dischtime,
        'BP_based' AS hypertension_criteria
    FROM
        `physionet-data.mimiciv_icu.chartevents` v
    JOIN
        `physionet-data.mimiciv_hosp.admissions` a ON v.hadm_id = a.hadm_id
    WHERE
        (v.itemid IN (220179, 223751) AND v.valuenum >= 130)  -- Systolic BP (SBP) >= 130 mmHg
        OR
        (v.itemid IN (220180, 223752) AND v.valuenum >= 80)   -- Diastolic BP (DBP) >= 80 mmHg
),
hypertensive_patients AS (
    -- Combine both ICD and BP-based hypertension definitions
    SELECT * FROM icd_based_hypertension
    UNION ALL
    SELECT * FROM bp_based_hypertension
),
stroke_events AS (
    -- Select patients who have had a stroke based on ICD codes
    SELECT DISTINCT
        d.subject_id,
        d.hadm_id,
        d.icd_code AS stroke_icd_code,
        a.admittime AS stroke_admittime
    FROM
        `physionet-data.mimiciv_hosp.diagnoses_icd` d
    JOIN
        `physionet-data.mimiciv_hosp.admissions` a ON d.hadm_id = a.hadm_id
    WHERE
        d.icd_code IN ('I639', 'I64', 'I619', 'I679', 'I694')
)
SELECT
    hp.subject_id,
    hp.hadm_id AS hypertensive_hadm_id,
    hp.hypertensive_admittime,
    se.hadm_id AS stroke_hadm_id,
    se.stroke_admittime,
    se.stroke_icd_code
FROM
    hypertensive_patients hp
LEFT JOIN
    stroke_events se ON hp.subject_id = se.subject_id
WHERE
    se.stroke_admittime > hp.hypertensive_admittime
    AND se.stroke_admittime IS NOT NULL
ORDER BY
    hp.subject_id, se.stroke_admittime;
"""

# Execute the query and get the results
pt = run_query(query)
print(f'Total Records: {len(pt)}')

# Display the first few rows of the result
print(pt.head(10))

Total Records: 653
   subject_id  hypertensive_hadm_id hypertensive_admittime  stroke_hadm_id  \
0    10014610              23258342    2173-12-19 11:00:00        23859571   
1    10014610              23258342    2173-12-19 11:00:00        23859571   
2    10030753              29738545    2198-08-25 21:59:00        23017050   
3    10030753              27218915    2200-10-27 00:21:00        23017050   
4    10030753              27165162    2200-11-13 22:44:00        23017050   
5    10030753              21257920    2199-11-19 06:45:00        23017050   
6    10030753              24506973    2199-07-18 20:02:00        23017050   
7    10030753              20090856    2200-05-23 23:24:00        23017050   
8    10030753              21151005    2199-05-04 17:19:00        23017050   
9    10030753              25110668    2198-07-21 21:58:00        23017050   

     stroke_admittime stroke_icd_code  
0 2174-01-05 02:09:00            I639  
1 2174-01-05 02:09:00            I639  
2 

In [9]:
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.impute import SimpleImputer
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, roc_auc_score, confusion_matrix
from imblearn.over_sampling import SMOTE
from sklearn.linear_model import LogisticRegression
from xgboost import XGBClassifier
import seaborn as sns
import matplotlib.pyplot as plt

In [17]:
## Extract related and useful features.
query = f"""
WITH icd_based_hypertension AS (
    -- Select patients based on hypertension-related ICD codes
    SELECT DISTINCT
        d.subject_id,
        d.hadm_id,
        a.admittime,
        a.dischtime,
        a.deathtime,
        'ICD_based' AS hypertension_criteria
    FROM
        `physionet-data.mimiciv_hosp.diagnoses_icd` d
    JOIN
        `physionet-data.mimiciv_hosp.admissions` a ON d.hadm_id = a.hadm_id
    WHERE
        d.icd_code IN ('I10', 'I11', 'I12', 'I13', 'I15')
),
bp_based_hypertension AS (
    -- Select patients based on blood pressure measurements indicating hypertension
    SELECT DISTINCT
        v.subject_id,
        v.hadm_id,
        a.admittime,
        a.dischtime,
        a.deathtime,
        'BP_based' AS hypertension_criteria
    FROM
        `physionet-data.mimiciv_icu.chartevents` v
    JOIN
        `physionet-data.mimiciv_hosp.admissions` a ON v.hadm_id = a.hadm_id
    WHERE
        (v.itemid IN (220179, 223751) AND v.valuenum >= 130)  -- Systolic BP (SBP) >= 130 mmHg
        OR
        (v.itemid IN (220180, 223752) AND v.valuenum >= 80)   -- Diastolic BP (DBP) >= 80 mmHg
),
hypertensive_patients AS (
    -- Combine both criteria to identify hypertensive patients
    SELECT subject_id, hadm_id, admittime, dischtime, deathtime, hypertension_criteria FROM icd_based_hypertension
    UNION ALL
    SELECT subject_id, hadm_id, admittime, dischtime, deathtime, hypertension_criteria FROM bp_based_hypertension
),
patient_icd_codes AS (
    -- Collect all ICD codes for each patient and determine if they had a stroke
    SELECT
        d.subject_id,
        STRING_AGG(DISTINCT d.icd_code, ', ') AS all_icd_codes,
        MAX(CASE WHEN d.icd_code IN ('I639', 'I64', 'I619', 'I679', 'I694') THEN 1 ELSE 0 END) AS has_stroke
    FROM
        `physionet-data.mimiciv_hosp.diagnoses_icd` d
    GROUP BY
        d.subject_id
),
patient_info AS (
    -- Get patient demographic information
    SELECT DISTINCT
        p.subject_id,
        p.gender,
        p.anchor_age AS age,
        a.race,
        p.anchor_year_group
    FROM
        `physionet-data.mimiciv_hosp.patients` p
    JOIN
        `physionet-data.mimiciv_hosp.admissions` a ON p.subject_id = a.subject_id
),
vital_signs AS (
    -- Aggregate vital signs data for each patient
    SELECT
        v.subject_id,
        MAX(CASE WHEN v.itemid IN (220045, 618) THEN v.valuenum END) AS heart_rate,  -- Heart rate
        MAX(CASE WHEN v.itemid IN (220179, 223751) THEN v.valuenum END) AS systolic_bp,  -- Systolic BP
        MAX(CASE WHEN v.itemid IN (220180, 223752) THEN v.valuenum END) AS diastolic_bp, -- Diastolic BP
        MAX(CASE WHEN v.itemid IN (223762) THEN v.valuenum END) AS temperature,  -- Temperature
        MAX(CASE WHEN v.itemid IN (220210) THEN v.valuenum END) AS resp_rate, -- Respiratory rate
        MAX(CASE WHEN v.itemid IN (220277) THEN v.valuenum END) AS oxygen_saturation -- SpO2
    FROM
        `physionet-data.mimiciv_icu.chartevents` v
    GROUP BY
        v.subject_id
),
lab_results AS (
    -- Aggregate lab results for each patient
    SELECT
        l.subject_id,
        MAX(CASE WHEN l.itemid = 50882 THEN l.valuenum END) AS glucose, -- Glucose
        MAX(CASE WHEN l.itemid = 50902 THEN l.valuenum END) AS creatinine,  -- Creatinine
        MAX(CASE WHEN l.itemid = 50931 THEN l.valuenum END) AS sodium, -- Sodium
        MAX(CASE WHEN l.itemid = 50971 THEN l.valuenum END) AS potassium -- Potassium
    FROM
        `physionet-data.mimiciv_hosp.labevents` l
    GROUP BY
        l.subject_id
),
comorbidities AS (
    -- Extract key comorbidities from ICD codes
    SELECT
        d.subject_id,
        MAX(CASE WHEN d.icd_code IN ('E119', 'E109') THEN 1 ELSE 0 END) AS diabetes,
        MAX(CASE WHEN d.icd_code IN ('N189', 'N179') THEN 1 ELSE 0 END) AS chronic_kidney_disease,
        MAX(CASE WHEN d.icd_code IN ('I509') THEN 1 ELSE 0 END) AS heart_failure
    FROM
        `physionet-data.mimiciv_hosp.diagnoses_icd` d
    GROUP BY
        d.subject_id
),
medications AS (
    -- Get medication usage details
    SELECT
        p.subject_id,
        STRING_AGG(DISTINCT m.drug, ', ') AS medications_administered
    FROM
        `physionet-data.mimiciv_hosp.prescriptions` m
    JOIN
        `physionet-data.mimiciv_hosp.patients` p ON m.subject_id = p.subject_id
    GROUP BY
        p.subject_id
)
-- Final query to join all the information together
SELECT
    hp.subject_id,
    MAX(hp.hadm_id) AS hadm_id,
    MAX(hp.admittime) AS admittime,
    MAX(hp.dischtime) AS dischtime,
    MAX(hp.deathtime) AS deathtime,
    hp.hypertension_criteria,
    MAX(pi.gender) AS gender,
    MAX(pi.age) AS age,
    MAX(pi.race) AS race,
    MAX(pi.anchor_year_group) AS anchor_year_group,
    MAX(vs.heart_rate) AS heart_rate,
    MAX(vs.systolic_bp) AS systolic_bp,
    MAX(vs.diastolic_bp) AS diastolic_bp,
    MAX(vs.temperature) AS temperature,
    MAX(vs.resp_rate) AS resp_rate,
    MAX(vs.oxygen_saturation) AS oxygen_saturation,
    MAX(lr.glucose) AS glucose,
    MAX(lr.creatinine) AS creatinine,
    MAX(lr.sodium) AS sodium,
    MAX(lr.potassium) AS potassium,
    MAX(c.diabetes) AS diabetes,
    MAX(c.chronic_kidney_disease) AS chronic_kidney_disease,
    MAX(c.heart_failure) AS heart_failure,
    MAX(med.medications_administered) AS medications_administered,
    MAX(pic.all_icd_codes) AS all_icd_codes,
    MAX(pic.has_stroke) AS has_stroke
FROM
    hypertensive_patients hp
LEFT JOIN
    patient_info pi ON hp.subject_id = pi.subject_id
LEFT JOIN
    vital_signs vs ON hp.subject_id = vs.subject_id
LEFT JOIN
    lab_results lr ON hp.subject_id = lr.subject_id
LEFT JOIN
    comorbidities c ON hp.subject_id = c.subject_id
LEFT JOIN
    medications med ON hp.subject_id = med.subject_id
LEFT JOIN
    patient_icd_codes pic ON hp.subject_id = pic.subject_id
GROUP BY
    hp.subject_id, hp.hypertension_criteria
ORDER BY hp.subject_id, MAX(hp.admittime);
"""

# Execute the query and get the results
pt = run_query(query)
print(f'Total Records: {len(pt)}')

# Display the first few rows of the result
print(pt.head(10))

Total Records: 80990
   subject_id   hadm_id           admittime           dischtime  \
0    10000032  29079034 2180-07-23 12:35:00 2180-07-25 17:55:00   
1    10000980  26913865 2189-06-27 07:38:00 2189-07-03 03:00:00   
2    10001217  27703517 2157-12-18 16:58:00 2157-12-24 14:55:00   
3    10001401  28058085 2133-07-09 22:22:00 2133-07-13 18:43:00   
4    10001725  25563031 2110-04-11 15:08:00 2110-04-14 15:00:00   
5    10001884  26184834 2131-01-07 20:39:00 2131-01-20 05:15:00   
6    10001884  29678536 2131-01-07 20:39:00 2131-01-20 05:15:00   
7    10002013  23581541 2160-05-18 07:45:00 2160-05-23 13:30:00   
8    10002013  28629319 2167-07-05 06:10:00 2167-07-05 11:44:00   
9    10002131  24065018 2128-03-17 14:53:00 2128-03-19 16:25:00   

            deathtime hypertension_criteria gender  age  \
0                 NaT              BP_based      F   52   
1                 NaT              BP_based      F   73   
2                 NaT              BP_based      F   55   
3    

In [18]:
## Count Number of patients with stroke/without stroke.
stroke_count = pt[pt['has_stroke'] == 1]['subject_id'].nunique()
non_stroke_count = pt[pt['has_stroke'] == 0]['subject_id'].nunique()

print(f"Number of patients with stroke: {stroke_count}")
print(f"Number of patients without stroke: {non_stroke_count}")

Number of patients with stroke: 776
Number of patients without stroke: 68488


In [19]:
print(pt.info())
print(pt.describe())

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 80990 entries, 0 to 80989
Data columns (total 26 columns):
 #   Column                    Non-Null Count  Dtype         
---  ------                    --------------  -----         
 0   subject_id                80990 non-null  Int64         
 1   hadm_id                   80990 non-null  Int64         
 2   admittime                 80990 non-null  datetime64[us]
 3   dischtime                 80990 non-null  datetime64[us]
 4   deathtime                 7637 non-null   datetime64[us]
 5   hypertension_criteria     80990 non-null  object        
 6   gender                    80990 non-null  object        
 7   age                       80990 non-null  Int64         
 8   race                      80990 non-null  object        
 9   anchor_year_group         80990 non-null  object        
 10  heart_rate                61512 non-null  float64       
 11  systolic_bp               61463 non-null  float64       
 12  diastolic_bp      

In [29]:
## Handling Null Values.
# Fill null values in numerical columns with the median and mode
numerical_cols = pt.select_dtypes(include=['int64', 'Int64', 'float64']).columns
for col in numerical_cols:
    if pt[col].dtype == 'float64':
        pt[col] = pt[col].fillna(pt[col].median())
    elif pt[col].dtype in ['int64', 'Int64']:
        pt[col] = pt[col].fillna(pt[col].mode()[0])

# Fill null values in categorical columns with the mode
categorical_cols = pt.select_dtypes(include=['object']).columns
for col in categorical_cols:
    pt[col] = pt[col].fillna(pt[col].mode()[0])

# Verify that there are no null values left
print(pt.isnull().sum())

subject_id                      0
hadm_id                         0
admittime                       0
dischtime                       0
deathtime                   73353
hypertension_criteria           0
gender                          0
age                             0
race                            0
anchor_year_group               0
heart_rate                      0
systolic_bp                     0
diastolic_bp                    0
temperature                     0
resp_rate                       0
oxygen_saturation               0
glucose                         0
creatinine                      0
sodium                          0
potassium                       0
diabetes                        0
chronic_kidney_disease          0
heart_failure                   0
medications_administered        0
all_icd_codes                   0
has_stroke                      0
dtype: int64


In [32]:
## Using SMOTE for Imbalanced Dataset
