# Predict the outcome of an encounter, given historic data
##<img src="https://databricks.com/wp-content/themes/databricks/assets/images/header_logo_2x.png" alt="logo" width="150"/> 
Distributed ML with MLflow and Hyperopt
In this notebook, we train a model to predict whether a patient is at risk of a given condition, using the patient's encounter history and demographic information. 

<ol>
  <li> **Data**: We use the dataset created in the previous step from simulated patient records.</li>
  <li> **Parameters**: Users can specify the target condition (to be predicted), the number of comorbid conditions to include, and the number of days of records before the most recent encounter to use.</li>
  <li> **Model Training**: We use [*hyperopt*](https://docs.databricks.com/applications/machine-learning/automl/hyperopt/index.html#hyperopt) with SparkTrials for distributed hyperparameter tuning and trying different algorithms</li>
  <li> **Model tracking and management**: Using [*MLflow*](https://docs.databricks.com/applications/mlflow/index.html#mlflow), we track our training experiments and log the best model.</li>
</ol>

In [2]:
%run ./00-rwe-etl-delta

## 1. Specify paths and parameters

In [4]:
# Expose the notebook parameters as Databricks text widgets: (name, default value, label)
for widget_name, default_value, widget_label in (
    ('condition', 'drug overdose', 'Condition to model'),
    ('num_conditions', '10', '# of comorbidities to include'),
    ('num_days', '90', '# of days to use'),
):
    dbutils.widgets.text(widget_name, default_value, widget_label)

In [5]:
# Widgets always return strings; the two numeric parameters are converted to int
condition = dbutils.widgets.get('condition')
num_conditions, num_days = (int(dbutils.widgets.get(name)) for name in ('num_conditions', 'num_days'))

In [6]:
# Root directory (on DBFS) of the Delta tables produced by the ETL notebook
delta_path = "dbfs:/tmp/rwe-ehr/delta"

## 2. Data preparation
To create the training data, we need to extract a dataset with both positive (affected) and negative (not affected) labels.

In [8]:
from pyspark.sql import Window
from pyspark.sql.functions import *
from pyspark.sql.types import *

### 2.1 Load Tables

In [10]:
# Load the encounter-level Delta table written by the ETL notebook. The format is set
# explicitly: without it Spark falls back to `spark.sql.sources.default` (parquet in
# open-source Spark), which would read the raw data files and ignore the Delta log.
patient_encounters = spark.read.format('delta').load(delta_path + '/patient_encounters')
# Alternative if the table is registered in the metastore: spark.table('rwd.patient_encounters')

### 2.2 Create a list of patients to include

In [12]:
# Every distinct patient in the encounter table
all_patients = patient_encounters.select('PATIENT').dropDuplicates()

# Cases: patients with at least one encounter whose reason mentions the target condition.
# Both sides are lower-cased so the match does not depend on how the widget value was typed,
# and `contains` does a literal substring match (no LIKE wildcard interpretation of `%`/`_`).
positive_patients = (
  patient_encounters
  .where(lower(col('REASONDESCRIPTION')).contains(condition.lower()))
  .select('PATIENT')
  .dropDuplicates()
  .withColumn('label', lit(1))
)

# Controls: as many patients without a matching encounter as there are cases.
# NOTE(review): `limit` without an ordering returns an arbitrary (not random) subset that
# may change if the DataFrame is recomputed; consider `.orderBy(rand(seed))` before it.
num_positive = positive_patients.count()
negative_patients = (
  all_patients
  .join(positive_patients, on=['PATIENT'], how='left_anti')
  .limit(num_positive)
  .withColumn('label', lit(0))
)

# Balanced, labelled list of patients to include in training (columns matched by name)
patients_to_study = positive_patients.unionByName(negative_patients).cache()
patients_to_study.groupBy('label').count().show()

In [13]:
# All encounters of the selected patients, each row tagged with the patient's label
patients_data_df = patient_encounters.join(patients_to_study, 'PATIENT')

### 2.3 Limit encounters to those within the given window of time
We also add each patient's age at the time of their most recent encounter, which is used as the diagnosis date.

In [15]:
# Each patient's encounters, latest first
w = Window.partitionBy('PATIENT').orderBy(desc('START_TIME'))

# One row per patient with the start time of their most recent encounter
patients_recent_diag_date_df = (
  patients_data_df
  .select('PATIENT', 'START_TIME')
  .withColumn('rank', row_number().over(w))
  .where(col('rank') == 1)
  .select('PATIENT', col('START_TIME').alias('most_recent_diag_date'))
  .dropDuplicates()
)

In [16]:
# Columns carried forward as model inputs, plus the label
cols = ['ORGANIZATION', 'ENCOUNTERCLASS', 'BIRTHDATE', 'ETHNICITY', 'GENDER', 'REASONDESCRIPTION', 'age_at_diag_date', 'label']

# Days between an encounter and the patient's most recent encounter, and the patient's
# age (in 365-day years) at that most recent encounter
days_before_recent = datediff(col('most_recent_diag_date'), col('START_TIME'))
age_at_recent = datediff(col('most_recent_diag_date'), col('BIRTHDATE')) / 365

# Keep only encounters that fall within `num_days` of the most recent one
patients_data_limit_days = (
  patients_data_df
  .join(patients_recent_diag_date_df, on=['PATIENT'])
  .withColumn('days_diff', days_before_recent)
  .withColumn('age_at_diag_date', age_at_recent)
  .where(col('days_diff') < num_days)
  .select(cols)
)

In [17]:
# Preview the time-windowed encounter rows with the age column added
display(patients_data_limit_days)

ORGANIZATION,ENCOUNTERCLASS,BIRTHDATE,ETHNICITY,GENDER,REASONDESCRIPTION,age_at_diag_date,label
ac8356a5-78f8-3a63-8a1e-59e832fd54e7,outpatient,1969-11-04,swedish,F,,49.32602739726028,0
ac8356a5-78f8-3a63-8a1e-59e832fd54e7,outpatient,1969-11-04,swedish,F,,49.32602739726028,0
4861d01f-019c-3dac-a153-8334e50919f9,outpatient,1981-08-16,french_canadian,M,,37.72876712328768,1
f4e7709c-02f6-37ca-aeea-8247d74e88e7,wellness,1954-08-11,italian,F,,64.34246575342466,1
c44f361c-2efb-3050-8f97-0354a12e2920,wellness,1999-06-27,polish,M,,19.17808219178082,0
3d10019f-c88e-3de5-9916-6107b9c0263d,wellness,1975-06-24,chinese,M,,42.057534246575344,0
f1fbcbfb-fcfa-3bd2-b7f4-df20f1b3c3a4,wellness,1974-07-08,puerto_rican,M,,44.53150684931507,0
6f122869-a856-3d65-8db9-099bf4f5bbb8,wellness,1962-01-21,scottish,M,,57.15068493150685,0
f1fbcbfb-fcfa-3bd2-b7f4-df20f1b3c3a4,outpatient,2000-01-17,south_american,F,,19.235616438356164,0
f1fbcbfb-fcfa-3bd2-b7f4-df20f1b3c3a4,wellness,2000-01-17,south_american,F,,19.235616438356164,0


### 2.4 Add comorbidity features

In [19]:
#create a dataframe of comorbid conditions
comorbid_conditions = (
  positive_patients.join(patient_encounters, ['PATIENT'])
  .where(col('REASONDESCRIPTION').isNotNull())
  .dropDuplicates(['PATIENT', 'REASONDESCRIPTION'])
  .groupBy('REASONDESCRIPTION').count()
  .orderBy('count', ascending=False)
  .limit(num_conditions)
)

display(comorbid_conditions)

REASONDESCRIPTION,count
Drug overdose,488
Chronic pain,357
Chronic intractable migraine without aura,356
Impacted molars,348
Viral sinusitis (disorder),330
Acute viral pharyngitis (disorder),237
Acute bronchitis (disorder),205
Normal pregnancy,196
Anemia (disorder),101
Hyperlipidemia,97


In [20]:
# Weight each candidate comorbidity by its patient count relative to the most frequent one.
# (`max` is pyspark.sql.functions.max, brought in by the wildcard import, not the builtin.)
comorbidity_list = (
  comorbid_conditions
  .withColumn('weight', col('count') / max('count').over(Window.partitionBy()))
  .collect()
)

# One weighted indicator column per comorbidity: `weight` when the encounter's reason
# contains the comorbidity name, otherwise 0 (also 0 when REASONDESCRIPTION is null).
# The first (most frequent) entry is skipped: for case patients it is normally the target
# condition itself, and an indicator for it would leak the label.
encounter_features = patients_data_limit_days
for idx, comorbidity in enumerate(comorbidity_list[1:]):
    has_comorbidity = col('REASONDESCRIPTION').like('%' + comorbidity['REASONDESCRIPTION'] + '%').cast('int')
    encounter_features = encounter_features.withColumn(
        "comorbidity_%d" % idx,
        coalesce(lit(comorbidity['weight']) * has_comorbidity, lit(0)),
    )

# Cache once, after all columns are added, rather than every intermediate frame
encounter_features = encounter_features.cache()

In [21]:
# Preview the encounter rows together with the weighted comorbidity_* indicator columns
display(encounter_features)

ORGANIZATION,ENCOUNTERCLASS,BIRTHDATE,ETHNICITY,GENDER,REASONDESCRIPTION,age_at_diag_date,label,comorbidity_0,comorbidity_1,comorbidity_2,comorbidity_3,comorbidity_4,comorbidity_5,comorbidity_6,comorbidity_7,comorbidity_8
ac8356a5-78f8-3a63-8a1e-59e832fd54e7,outpatient,1969-11-04,swedish,F,,49.32602739726028,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
ac8356a5-78f8-3a63-8a1e-59e832fd54e7,outpatient,1969-11-04,swedish,F,,49.32602739726028,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4861d01f-019c-3dac-a153-8334e50919f9,outpatient,1981-08-16,french_canadian,M,,37.72876712328768,1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
f4e7709c-02f6-37ca-aeea-8247d74e88e7,wellness,1954-08-11,italian,F,,64.34246575342466,1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
c44f361c-2efb-3050-8f97-0354a12e2920,wellness,1999-06-27,polish,M,,19.17808219178082,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3d10019f-c88e-3de5-9916-6107b9c0263d,wellness,1975-06-24,chinese,M,,42.057534246575344,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
f1fbcbfb-fcfa-3bd2-b7f4-df20f1b3c3a4,wellness,1974-07-08,puerto_rican,M,,44.53150684931507,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
6f122869-a856-3d65-8db9-099bf4f5bbb8,wellness,1962-01-21,scottish,M,,57.15068493150685,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
f1fbcbfb-fcfa-3bd2-b7f4-df20f1b3c3a4,outpatient,2000-01-17,south_american,F,,19.235616438356164,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
f1fbcbfb-fcfa-3bd2-b7f4-df20f1b3c3a4,wellness,2000-01-17,south_american,F,,19.235616438356164,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


## 3. Model selection and training

In [23]:
import pandas as pd
import numpy as np

from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier

from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score

import mlflow
import mlflow.sklearn

from hyperopt import fmin, tpe, hp, SparkTrials, STATUS_OK, Trials, space_eval
np.random.seed(42)

### 3.1 Create dataset for training using scikit-learn

In [25]:
# Build the scikit-learn design matrix from the engineered Spark features.
# Every engineered column is used, including the comorbidity_* indicators, except
# REASONDESCRIPTION: the label is defined from that very column, so one-hot encoding it
# would leak the target into the features (the comorbidity_* columns encode it safely).
# NOTE(review): rows are per encounter and BIRTHDATE/ORGANIZATION are one-hot encoded
# as-is, so one patient's rows can land in different CV folds; consider grouped CV.
feature_cols = [c for c in encounter_features.columns if c != 'REASONDESCRIPTION']
encounter_features_pdf = encounter_features.select(feature_cols).toPandas()
# One-hot encode the categorical columns; dtype=float keeps X a plain numeric array
data_pdf = pd.get_dummies(encounter_features_pdf, dtype=float)
X = data_pdf.drop(columns=['label']).values
y = data_pdf['label'].values

### 3.2 Define a training function for model selection using hyperopt

In [27]:
def train(params):
    """Score, or fit and log, one candidate classifier described by a hyperopt sample.

    Args:
      params: dict with keys
        - 'type': 'svm', 'lgr' or 'rndf', selecting SVC, LogisticRegression or
          RandomForestClassifier;
        - 'tune': if True, only cross-validate (hyperopt objective mode); if False, also fit
          on all of X, y and log the model to MLflow;
        - any other keys are passed as keyword arguments to the classifier constructor.
        The caller's dict is left unchanged.

    Returns:
      tune=True:  {'loss': -mean cross-validated accuracy, 'status': STATUS_OK}
      tune=False: the MLflow artifact URI of the logged model.

    Raises:
      ValueError: if params['type'] is not one of the supported classifier types.
    """
    np.random.seed(42)
    # Work on a copy so popping the control keys does not corrupt the caller's dict
    model_params = dict(params)
    classifier_type = model_params.pop('type')
    tune = model_params.pop('tune')

    if classifier_type == 'svm':
        clf = SVC(**model_params)
    elif classifier_type == 'lgr':
        clf = LogisticRegression(**model_params)
    elif classifier_type == 'rndf':
        clf = RandomForestClassifier(**model_params)
    else:
        raise ValueError("Unknown classifier type: {!r}".format(classifier_type))

    accuracy = cross_val_score(clf, X, y).mean()

    if tune:
        # hyperopt minimises the loss, so negate the accuracy
        return {'loss': -accuracy, 'status': STATUS_OK}

    clf.fit(X, y)
    mlflow.log_metric('cv_accuracy', accuracy)
    mlflow.sklearn.log_model(clf, 'model_clf')
    return mlflow.get_artifact_uri(artifact_path='model_clf')

### 3.3 Define search space

In [29]:
# Candidate model families, in this order (the index hyperopt reports for
# 'classifier_type' refers to it): logistic regression, SVM, random forest.
# 'tune': True makes `train` return a hyperopt loss instead of logging a model.
# NOTE: random-forest depth could also be searched, e.g. 'max_depth': hp.choice('max_depth', [10, 20]).
search_space = hp.choice(
    'classifier_type',
    [
        dict(type='lgr', tune=True),
        dict(type='svm', tune=True, C=hp.choice('C', [1.0, 10])),
        dict(type='rndf', tune=True),
    ],
)

All hyperopt trials are automatically tracked by MLflow.

In [31]:
# Evaluate up to 32 configurations with TPE, two trials at a time on the Spark cluster,
# inside a single parent MLflow run
spark_trials = SparkTrials(parallelism=2)

with mlflow.start_run():
    best_result = fmin(
        train,
        search_space,
        algo=tpe.suggest,
        max_evals=32,
        trials=spark_trials,
    )


### 3.4 Train the model with the best parameters

In [33]:
# Turn hyperopt's index-based result back into concrete parameters, then call `train`
# in non-tuning mode to fit on all data and get the URI of the logged model
params = {**space_eval(search_space, best_result), 'tune': False}
model_uri = train(params)

## 4. Load the model for scoring

In [35]:
# Reload the logged model from MLflow and pick a few training rows (indices 1-9) to score
model = mlflow.sklearn.load_model(model_uri)
sample_data = X[1:10]

In [36]:
# Predicted labels (1 = at risk of the target condition) for the sample rows
model.predict(sample_data)