# 1. Import libraries and variables

In [1]:
import pandas as pd
import numpy as np
import os 
import pickle
import xgboost as xgb
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.compose import ColumnTransformer
from os import environ
import psycopg2
from contextlib import contextmanager
import time 
import boto3
import os
from model_utils.model_config import ModelConfig
from telep.config.enums import ModelName
from model_utils.file_utils import S3Storage
from normalized_protocol_names.api import NormalizedProtocolNames
from model_prepare import ModelPrepare
from model_train import ModelTrain

HOME_DIR = os.getcwd()

# Local working directories under the current working directory.
# makedirs(exist_ok=True) is a no-op when the directory already exists, so
# re-running this cell is safe and there is no check-then-create race.
DATA_DIR = os.path.join(HOME_DIR, 'derived_data/')
MODEL_DIR = os.path.join(HOME_DIR, 'models/')
for _local_dir in (DATA_DIR, MODEL_DIR):
    os.makedirs(_local_dir, exist_ok=True)

# Model registry location; trained models are only uploaded when REGISTER_MODEL is True.
REGISTRY_BUCKET = "s3://prod.ml-model-registry.*company-data-covered*"
REGISTER_MODEL = False  # set this to True if we want to register


# Key column shared by the label and feature frames, and the labels to train on
# (one model is trained per entry of LABEL_COLS).
ID_COL = 'care_request_id'
LABEL_COLS = ['label_iv', 'label_catheter', 'label_rx_admin']

# Training window, used as the bounds of the BETWEEN filters in the label queries
# and passed to features.sql.
# To drive these from the environment instead:
# start_date = environ['TRAIN_START_DATE']
# end_date = environ['TRAIN_END_DATE']
start_date = "2020-07-01"
end_date = "2022-06-27"

# Author recorded in the model config when a model is registered.
# author_email = environ['AUTHOR_EMAIL']
author_email = "tarun.narasimhan@*company-data-covered*.com"

In [2]:
# Get an S3 client. Session() is created with no arguments here; pass
# profile_name=... if your own AWS credential setup needs a named profile.
session = boto3.Session()
s3 = session.client("s3")

# Get the latest risk protocol mapping version. The value is printed below and
# later stored with each registered model (risk_protocol_mapping_version).
normalized_protocol_names = NormalizedProtocolNames(s3)
latest_rp_version = normalized_protocol_names.get_latest_version()
print(f'latest risk protocol mapping version = {latest_rp_version}')

latest risk protocol mapping version = DFXqOQRbDgZAE1K4XOHN7N8jefJ0cvoH


In [3]:
# Keyword arguments for psycopg2.connect, read from the REDSHIFT_* environment
# variables when this cell runs. A missing variable raises KeyError here.
CONNECTION_ARGS = dict(
    host=environ['REDSHIFT_HOST'],
    database=environ['REDSHIFT_DATABASE'],
    user=environ['REDSHIFT_USER'],
    password=environ['REDSHIFT_PASSWORD'],
    port=environ['REDSHIFT_PORT'],
)

@contextmanager
def get_connection():
    '''Yield an open psycopg2 connection built from CONNECTION_ARGS.

    The connection is closed when the `with` block exits, whether or not
    the body raised.
    '''
    connection = psycopg2.connect(**CONNECTION_ARGS)
    try:
        yield connection
    finally:
        connection.close()

def get_df_from_query(query : str, params=None):
    '''Run a SQL query and return the result as a pandas dataframe.

    A connection is opened through `get_connection()` for this single query
    and closed again before the function returns.

    Args:
        query (string): String containing SQL query
        params (any, optional): Query parameters, forwarded unchanged as the
            `params` argument of `pd.read_sql`. Defaults to None.

    Returns:
        pd.DataFrame
    '''
    with get_connection() as connection:
        return pd.read_sql(query, connection, params=params)

# Keywords searched for (as plain substrings, case-insensitively) in a risk
# protocol name by ProtocolMap.get; when several match, the longest one wins.
# Entries are lower-case because they are compared against protocol.lower().
SUBSTRINGS = [
    'headache',
    'gastric',
    'self harm',
    'neuro',
    'gyne',
    'wound',
    'heart',
    'calling',
    'head injury',
    'abdominal',
    'laceration',
    'urin',
    'bite',
    'rect',
    'partner',
    'lab',
    'diarrhea',
    'hallucinat',
    'post',
    'burn',
    'nausea',
    'weakness',
    'home',
    'boarding',
    'assessment',
    'suicid',
    'cough',
    'wellness',
    'pacemaker',
    'rash',
    'rn',
    'cellulitis',
    'seizure',
    'hpotension',  # NOTE(review): looks like a misspelling of 'hypotension' (also listed) — confirm it matches real protocol names
    'covid',
    'overdose',
    'fever',
    'hypertension',
    'hypotension',
    'syncope',
    'dizz',
    'dental',
    'allerg',
    'advanced care',
    'chest',
    'stool',
    'flu',
    'numb',
    'leth',
    'follow',
    'extremity',
    'anxiety',
    'catheter',
    'testic',
    'leg',
    'ear',
    'ems',
    'hospice',
    'vision',
    'domestic',
    'breath',
    'general',
    'spine',
    'sugar',
    'dehy',
    'constip',
    'preg',
    'throat',
    'back',
    'fall',
    'sinus',
    'syncompe',  # NOTE(review): looks like a misspelling of 'syncope' (also listed) — confirm it matches real protocol names
    'education',
    'nose',
    'confus',
    'palpitat'
]

# Merges some matched keywords into another keyword: ProtocolMap.get looks up
# the winning substring here and returns the mapped value when one exists,
# otherwise the substring itself. Every key and value also appears in SUBSTRINGS.
SUBSTRING_MAP = {
    'cough': 'flu',
    'headache': 'flu',
    'leth': 'weakness',
    'post': 'follow',
    'cellulitis': 'rash',
    'leg': 'extremity',
    'partner': 'calling',
    'suicid': 'self harm',
    'laceration': 'wound'
 }

class ProtocolMap:
    '''Map a free-text risk protocol name to a coarse keyword.

    Args:
        substrings (iterable of str, optional): Lower-case keywords to search
            for. Defaults to the module-level SUBSTRINGS.
        substring_map (dict, optional): Keyword -> replacement keyword applied
            to the winning match. Defaults to the module-level SUBSTRING_MAP.
    '''

    def __init__(self, substrings=None, substring_map=None):
        # None means "use the module-level table"; it is resolved at lookup
        # time so ProtocolMap() keeps reading the current SUBSTRINGS / SUBSTRING_MAP.
        self._substrings = substrings
        self._substring_map = substring_map

    def get(self, protocol):
        '''Return the keyword for `protocol`, or None if there is none.

        Every keyword that occurs in the lower-cased protocol name is
        collected; the longest one wins (the earliest-listed on a length tie)
        and is then passed through the substring map.

        Args:
            protocol: Protocol name. Anything that is not a string (None, NaN,
                numbers, ...) yields None.

        Returns:
            str or None
        '''
        if not isinstance(protocol, str):
            return None

        substrings = SUBSTRINGS if self._substrings is None else self._substrings
        substring_map = SUBSTRING_MAP if self._substring_map is None else self._substring_map

        # Lower-case once instead of once per candidate substring.
        lowered = protocol.lower()
        matches = [s for s in substrings if s in lowered]
        if not matches:
            return None

        best_match = max(matches, key=len)
        return substring_map.get(best_match, best_match)


# 2. Labels
**Note**: nulls in `is_medication_administered` come from the left join of care requests onto patient prescription data, and are therefore handled as negatives.

In [4]:
# Label `label_rx_admin`: True when at least one non-oral medication was
# administered for the care request.
# The date bounds are bound as query parameters (%(name)s placeholders, the
# same mechanism used for features.sql) instead of being formatted into the SQL text.
q = '''
WITH x AS (
	SELECT
		dm.care_request_id,
		dhp.clinical_encounter_id,
		dhp.medication_id,
		dhp.vaccine_route,
		CASE WHEN vaccine_route = 'PO' THEN
			TRUE
		WHEN dhp.clinical_encounter_id IS NULL THEN
			FALSE
		ELSE
			FALSE
		END AS was_oral,
		CASE WHEN was_administered IS NULL THEN FALSE ELSE was_administered END as was_administered,
		CASE WHEN was_oral=FALSE
			AND was_administered THEN
			TRUE
		ELSE
			FALSE
		END AS was_non_oral_admin -- this is TRUE iff a medication was administered AND it was non-oral
	FROM
		core.care_request_care_delivery_mart dm
		LEFT JOIN core.dh_prescriptions dhp ON dm.clinical_encounter_id = dhp.clinical_encounter_id
	WHERE
		dm.complete_datetime_local BETWEEN %(start_date)s AND %(end_date)s
		AND is_acute_care_visit = 'Y'
	)
SELECT
	care_request_id,
	bool_or(was_non_oral_admin) AS is_non_oral_admin
FROM
	x
GROUP BY
	care_request_id
ORDER BY
	care_request_id
;
'''
rx = get_df_from_query(q, {"start_date": start_date, "end_date": end_date})

# handle nulls as negatives: anything that is not exactly True becomes False
rx['label_rx_admin'] = rx['is_non_oral_admin'].eq(True)

  df = pd.read_sql(query, conn, params=params)


In [5]:
# Labels `label_iv` and `label_catheter`, derived from procedure codes (and,
# for IV, from the prescription route).
# The date bounds are bound as query parameters (%(name)s placeholders, the
# same mechanism used for features.sql) instead of being formatted into the SQL text.
q = '''
	SELECT
		dm.care_request_id,
		CASE
			WHEN POSITION('96360' IN procedure_codes_aggregated) THEN 1
			WHEN POSITION('96361' IN procedure_codes_aggregated) THEN 1
			WHEN POSITION('96365' IN procedure_codes_aggregated) THEN 1
			WHEN POSITION('96374' IN procedure_codes_aggregated) THEN 1
			WHEN POSITION('J7030' IN procedure_codes_aggregated) THEN 1
			ELSE 0 END
		AS label_iv_proc,
		CASE WHEN vaccine_route='IV' THEN 1 ELSE 0 END
		AS label_iv_med,
		CASE WHEN label_iv_proc=1 OR label_iv_med=1 THEN 1 ELSE 0 END AS label_iv,
		CASE
			WHEN POSITION('51701' IN procedure_codes_aggregated) THEN 1
			WHEN POSITION('51702' IN procedure_codes_aggregated) THEN 1
			WHEN POSITION('51703' IN procedure_codes_aggregated) THEN 1
			WHEN POSITION('51705' IN procedure_codes_aggregated) THEN 1
			WHEN POSITION('99507' IN procedure_codes_aggregated) THEN 1
			WHEN POSITION('A4338' IN procedure_codes_aggregated) THEN 1
			WHEN POSITION('P9612' IN procedure_codes_aggregated) THEN 1
			ELSE 0 END
		AS label_catheter
	FROM
		core.care_request_care_delivery_mart dm
		LEFT JOIN core.dh_prescriptions pp ON dm.clinical_encounter_id = pp.clinical_encounter_id
	WHERE
		is_acute_care_visit='Y'
	AND
		complete_datetime_local BETWEEN %(start_date)s AND %(end_date)s
	;
'''
# The intermediate label_iv_proc / label_iv_med columns are only needed inside
# the query. A care request with multiple medications comes back as several
# rows, so collapse to one row per care_request_id, keeping a label at 1 if
# any of its rows has it.
proc = (
    get_df_from_query(q, {"start_date": start_date, "end_date": end_date})
    .drop(columns=['label_iv_proc', 'label_iv_med'])
    .drop_duplicates()
    .groupby('care_request_id', as_index=False)
    .max()
)

  df = pd.read_sql(query, conn, params=params)


In [6]:
# Join the medication label with the procedure labels (default inner join on
# care_request_id) and turn the 0/1 procedure labels into booleans.
resp = rx.merge(proc, on='care_request_id').astype(
    {'label_iv': bool, 'label_catheter': bool}
)

print("Number of label rows:", len(resp))

Number of label rows: 425301


# 3. Features

In [7]:
# Read the feature query from disk and run it for the training window; the
# dates are handed to the query as parameters.
query_path = 'features.sql'
with open(query_path) as sql_file:
    query = sql_file.read()
args = dict(start_date=start_date, end_date=end_date)

ft = get_df_from_query(query, args)

  df = pd.read_sql(query, conn, params=params)


In [8]:
# The feature query can return several rows per care request (additional
# duplicates coming from medical and social hx); keep only the first row for
# each care_request_id (drop_duplicates default).
ft = ft.drop_duplicates('care_request_id')

In [9]:
# Keep only care requests that have both labels and features.
df = resp.merge(ft, on='care_request_id', how='inner')

print("Number of feature-response rows: ", len(df))

Number of feature-response rows:  424495


# 4. Preprocessing and Training

In [10]:
# Derived features: calendar month of the creation date, and the coarse
# keyword extracted from the standardized risk protocol name.
protocol_map = ProtocolMap()
df["month"] = df["created_date"].apply(lambda created: created.month)
df["protocol_keyword"] = df["risk_protocol_standardized"].apply(protocol_map.get)

In [11]:
# Preview the first rows of the merged label/feature frame.
# NOTE(review): the rendered output shows record-level data — consider clearing it before sharing.
df.head()

Unnamed: 0,care_request_id,is_non_oral_admin,label_rx_admin,label_iv,label_catheter,created_date,market_short_name,place_of_service,risk_protocol_standardized,risk_protocol,...,x_med,y_med,z_med,responses,prev_iv,prev_catheter,prev_rx_admin,patient_visit_count,month,protocol_keyword
0,293936,False,False,False,False,2020-06-24 22:30:45.835530+00:00,RIC,Home,*company-data-covered* Acute Care - follow up visit,*company-data-covered* Acute Care - follow up visit,...,0.0,0.0,0.0,"{""questions"":[{""weight_yes"":10,""weight_no"":0,""...",0.0,0.0,0.0,2.0,6,follow
1,294208,False,False,False,False,2020-06-25 16:11:28.237281+00:00,HOU,Home,Skin rash(cellulitis)/skin abscesses - extremi...,Skin rash(cellulitis)/skin abscesses - extremi...,...,,,,"{""questions"":[{""weight_yes"":5.5,""weight_no"":0,...",,,,1.0,6,rash
2,295670,False,False,False,False,2020-06-27 17:39:19.431410+00:00,POR,Home,Weakness / Lethargy / Dehydration,Dehydration,...,,,,"{""questions"":[{""weight_yes"":5.5,""weight_no"":0,...",,,,1.0,6,weakness
3,295936,False,False,False,False,2020-06-28 00:58:29.729062+00:00,RIC,Home,*company-data-covered* Acute Care - follow up visit,*company-data-covered* Acute Care - follow up visit,...,0.0,0.0,2.0,"{""questions"":[{""weight_yes"":5.5,""weight_no"":0,...",0.0,0.0,0.0,3.0,6,follow
4,295951,False,False,False,False,2020-06-28 01:48:26.801550+00:00,DEN,Home,*company-data-covered* Acute Care - follow up visit,*company-data-covered* Acute Care - follow up visit,...,,,,"{""questions"":[{""weight_yes"":10,""weight_no"":0,""...",0.0,0.0,0.0,2.0,6,follow


In [12]:
def _impute_and_scale():
    """Return a fresh pipeline: SimpleImputer() (default settings) followed by StandardScaler()."""
    return Pipeline(
        [
            ("impute", SimpleImputer()),
            ("scaler", StandardScaler()),
        ]
    )


def _build_column_transformer():
    """Return a new, unfitted ColumnTransformer for the model features.

    Categorical columns are one-hot encoded (categories unseen during fit are
    ignored at transform time); patient_age and risk_score are imputed then
    standardized; month is standardized. Columns not listed here are dropped
    (ColumnTransformer's default `remainder`).
    """
    return ColumnTransformer(
        [
            ("protocol", OneHotEncoder(handle_unknown="ignore"), ["protocol_keyword"]),
            ("age", _impute_and_scale(), ["patient_age"]),
            ("risk_score", _impute_and_scale(), ["risk_score"]),
            ("pos", OneHotEncoder(handle_unknown="ignore"), ["place_of_service"]),
            ("market", OneHotEncoder(handle_unknown="ignore"), ["market_short_name"]),
            ("month", StandardScaler(), ["month"]),
            ("gender", OneHotEncoder(handle_unknown="ignore"), ["patient_gender"]),
        ],
        verbose_feature_names_out=False,
    )


# Fitted model per label column, filled in by the loop below.
trained_models = {}

for LABEL_COL in LABEL_COLS:
    print(f"*** Starting {LABEL_COL} preprocessing ***")
    # Drop the other label columns in one step. `drop` returns a new frame, so
    # `df` itself is left untouched for the next label.
    other_label_cols = [
        col for col in df.columns
        if col.startswith("label_") and col != LABEL_COL
    ]
    df_label = df.drop(columns=other_label_cols)

    # do random splitting
    Preparer = ModelPrepare(df_label, ID_COL, LABEL_COL)
    Preparer.random_splitter()
    trainX, validX, testX = Preparer.return_features(return_pandas=True)
    trainY, validY, testY = Preparer.return_responses()
    trainY, validY, testY = np.int64(trainY), np.int64(validY), np.int64(testY)

    # Fit the transformer on the training split only, then apply it to the
    # validation and test splits.
    column_trans = _build_column_transformer()
    trainX = column_trans.fit_transform(trainX)
    validX = column_trans.transform(validX)
    testX = column_trans.transform(testX)

    # train the model
    _, c = np.unique(trainY, return_counts=True)
    print("Class balance:", list(np.round(c/np.sum(c), 2)))

    print(f"*** Starting {LABEL_COL} training ***")
    Trainer = ModelTrain(
        xgb.XGBClassifier,
        objective = 'binary:logistic',
        use_label_encoder=False
    )

    Trainer.tune_hyperparams(
        trainX,
        trainY,
        validX,
        validY,
        hyper_eval_metric = 'auroc',
        hyper_eval_metric_max = True,
        eval_set=[(trainX, trainY), (validX,validY)],
        verbose=False
    )
    model = Trainer.get_model()
    trained_models[LABEL_COL] = model

    if REGISTER_MODEL:
        # Save model config object to model registry s3 bucket
        print(f"*** Saving {LABEL_COL} model config object to s3 model registry")

        model_name = ModelName[LABEL_COL]
        curr_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())

        # NOTE(review): the validation split is passed as `test_set`, and
        # testX/testY are never used after being transformed above — confirm
        # whether (testX, testY) was intended here.
        model_config = ModelConfig(
            model_name=model_name, # must use one of the enums from ModelName
            model=model,
            training_set=(trainX, trainY),
            test_set=(validX, validY),
            column_transformer=column_trans,
            risk_protocol_mapping_version=latest_rp_version,
            author_email=author_email,
            description=f'{LABEL_COL} model - {curr_time}'
        )

        latest_model_dir = model_config.save_to_model_registry(
            model_registry_home=REGISTRY_BUCKET,
            s3=s3
        )
        print(f'{LABEL_COL} model was saved to the following path:\n{latest_model_dir}')

*** Starting label_iv preprocessing ***
Random seed: 123
Feature rows: 297146 63674 63675
Response rows: 297146 63674 63675
Class balance: [0.94, 0.06]
*** Starting label_iv training ***

Training hyperparameter set 1/1




Finished hyperparameter set 1/1; valid auroc=0.869
*** Starting label_catheter preprocessing ***
Random seed: 123
Feature rows: 297146 63674 63675
Response rows: 297146 63674 63675
Class balance: [0.98, 0.02]
*** Starting label_catheter training ***

Training hyperparameter set 1/1




Finished hyperparameter set 1/1; valid auroc=0.932
*** Starting label_rx_admin preprocessing ***
Random seed: 123
Feature rows: 297146 63674 63675
Response rows: 297146 63674 63675
Class balance: [0.9, 0.1]
*** Starting label_rx_admin training ***

Training hyperparameter set 1/1




Finished hyperparameter set 1/1; valid auroc=0.782


In [13]:
# Closing banner so the end of a full run is easy to spot in the output.
print(''' 
*******************************
Finished training Tele-P models
*******************************
''')

 
*******************************
Finished training Tele-P models
*******************************

