# Multimodal Fusion for Pulmonary Embolism Classification

In this demonstration, we will recreate the results from our manuscript *Multimodal fusion with deep neural networks for leveraging CT imaging and electronic health record: a case-study in pulmonary embolism detection*. Specifically, we will build a multimodal late-fusion model that combines information from CT scans and the Electronic Medical Record (EMR) to automatically diagnose the presence of pulmonary embolism (PE).

![Workflow](./figs/workflow.png)

### Motivation

Pulmonary Embolism (PE) is a serious medical condition that hospitalizes 300,000 people in the United States every year. The gold standard diagnostic modality for PE is Computed Tomography Pulmonary Angiography (CTPA), which is interpreted by radiologists. Studies have shown that prompt diagnosis and treatment can greatly reduce morbidity and mortality. Strategies that automate accurate interpretation and timely reporting of CTPA examinations may help triage urgent cases of PE to the immediate attention of physicians, improving time to diagnosis and treatment.

Recent advancements in deep learning have led to a resurgence of medical imaging and Electronic Medical Record (EMR) models for a variety of applications, including clinical decision support, automated workflow triage, clinical prediction and more. However, very few models integrate both clinical and imaging data, even though clinicians routinely rely on the EMR to provide context when interpreting medical images.

### Fusion Strategies
![Fusion Strategies](./figs/fusion_strategies.png)

### Data
We will use a subset of RadFusion, a large-scale multimodal pulmonary embolism detection dataset consisting of 1837 CT imaging studies (comprising 600,000+ 2D slices) for 1794 patients and their corresponding EHR summary data. The full dataset with CT scans can be accessed via the following link:
- https://stanfordaimi.azurewebsites.net/datasets/3a7548a4-8f65-4ab7-85fa-3d68c9efc1bd

### References
- Huang, Shih-Cheng, et al. "PENet—a scalable deep-learning model for automated diagnosis of pulmonary embolism using volumetric CT imaging." NPJ digital medicine 3.1 (2020): 1-9.
- Huang, Shih-Cheng, et al. "Multimodal fusion with deep neural networks for leveraging CT imaging and electronic health record: a case-study in pulmonary embolism detection." Scientific reports 10.1 (2020): 1-9.
- Zhou, Yuyin, et al. "RadFusion: Benchmarking Performance and Fairness for Multimodal Pulmonary Embolism Detection from CT and EHR." arXiv preprint arXiv:2111.11665 (2021).

## Research Use Agreement

Before we can proceed to download the data, please agree to this **Research Use Agreement** by registering to download from our website:
- https://stanfordaimi.azurewebsites.net/datasets/3a7548a4-8f65-4ab7-85fa-3d68c9efc1bd


![User Agreement](./figs/UserAgreement.png)


## System Setup & Downloading the Data

In [None]:
!pip install numpy pandas scikit-learn matplotlib gdown
!gdown --id 1w0ocK3br8oqVwn6zK5qgtRaj9Ql37dtd  # /content/Demographics.csv
!gdown --id 1MEhVZ87J2IwFmkgxOi8WjdVKTdwOpDDY  # /content/INP_MED.csv
!gdown --id 1PRgFvQjqEUudeJ0FLR3DbtvqmI7t7sCT  # /content/OUT_MED.csv
!gdown --id 1EDZOYmWrvv6D3XaZrjVous95c9HdiBEx  # /content/Vitals.csv
!gdown --id 1Nlm1ZgibRv6kJBIJkQHkRh8oPqUpELnK  # /content/ICD.csv
!gdown --id 17Y9DJsolaRPyMkk_Xm3w-iCgSOxkQOyf  # /content/LABS.csv
!gdown --id 1JDb5f18uNo2hXXQqcHlRbcjswph1y98h  # /content/Vision.csv

## Data Exploration
After downloading the data, you should be able to find the following files in your directory: 
    
- Demographics.csv 
- INP_MED.csv
- OUT_MED.csv
- Vitals.csv
- ICD.csv
- LABS.csv
- Vision.csv

Let's explore the contents in each file.

In [None]:
# import libraries
import pandas as pd
import numpy as np

### Patient Demographics

The demographic features consist of one-hot encoded gender, race, and smoking habits, plus age as a numeric variable.

In [None]:
demo_df = pd.read_csv('/content/Demographics.csv')
print(demo_df.shape)
demo_df.head(5)

### Inpatient & Outpatient Medications

641 unique classes of drugs were identified for outpatient medication. Each medication was represented as both the frequency within the 12-month window and a binary label of whether the drug was prescribed to the patient.

In [None]:
out_med_df = pd.read_csv('/content/OUT_MED.csv')
print(out_med_df.shape)
out_med_df.head(5)

In [None]:
in_med_df = pd.read_csv('/content/INP_MED.csv')
print(in_med_df.shape)
in_med_df.head(5)

### ICD Codes

We excluded all ICD codes with less than 1% occurrence in the training dataset and collapsed the rest into top diagnosis categories, which resulted in a total of 141 diagnosis groups. Each diagnosis group is represented by both a binary presence/absence feature and a frequency feature. All ICD codes recorded with the same encounter number as the patient's CT exam, or within a 24-hour window prior to their CT examination, were dropped to avoid data leakage.

In [None]:
icd_df = pd.read_csv('/content/ICD.csv')
print(icd_df.shape)
icd_df.head(5)
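
The 1% filter described above can be sketched as follows. This is a minimal example on a toy table, assuming that "occurrence" means the fraction of training patients with the diagnosis group present; the table, the `train_idx` list, and the `MIN_PREVALENCE` name are hypothetical.

```python
import pandas as pd

# Hypothetical binary presence table: rows are patients, columns are diagnosis groups
icd_presence = pd.DataFrame(
    {"group_a": [1, 0, 1, 1], "group_b": [0, 0, 0, 1], "group_c": [0, 0, 0, 0]},
    index=pd.Index([0, 1, 2, 3], name="idx"),
)
train_idx = [0, 1, 2]  # assumed training patients

MIN_PREVALENCE = 0.01  # keep groups present in at least 1% of training patients

# Estimate prevalence on the training patients only, so the filter
# does not depend on validation or test data
prevalence = icd_presence.loc[train_idx].mean()
kept_groups = prevalence[prevalence >= MIN_PREVALENCE].index.tolist()

# Apply the same column selection to every patient
icd_filtered = icd_presence[kept_groups]
print(kept_groups)
```

Note that `group_b` is dropped even though one patient has it, because that patient is outside the assumed training split.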

### Lab Tests

We identified 22 lab tests and represented each test as binary presence/absence as well as the latest value of the test.

In [None]:
lab_df = pd.read_csv('/content/LABS.csv')
print(lab_df.shape)
lab_df.head(5)

### Vitals

For vitals, we included systolic and diastolic blood pressure, height, weight, body mass index (BMI), temperature, respiration rate, pulse oximetry (spO2) and heart rate. The vitals were represented with respect to their sensitivity to change, which was computed by taking the derivative of the vital values along the temporal axis.

In [None]:
vitals_df = pd.read_csv('/content/Vitals.csv')
vitals_df.head(5)
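
To make the "derivative along the temporal axis" idea concrete, here is a small sketch for a single vital sign. The measurements are invented, the finite-difference approximation is one of several reasonable choices, and the final summary statistic (mean absolute rate of change) is an assumption rather than the exact feature stored in `Vitals.csv`.

```python
import numpy as np

# Hypothetical heart-rate measurements for one patient
hours = np.array([0.0, 1.0, 2.0, 4.0])            # measurement times, in hours
heart_rate = np.array([80.0, 84.0, 90.0, 86.0])   # beats per minute

# Finite-difference derivative between consecutive measurements (bpm per hour)
rate_of_change = np.diff(heart_rate) / np.diff(hours)

# Assumed summary: mean absolute rate of change over the observation window
sensitivity = np.mean(np.abs(rate_of_change))

print(rate_of_change)
print(sensitivity)
```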

### PE CTs

The RadFusion dataset includes CTPA scans for each study. Due to time and computational constraints, we have already run inference on these CT scans using PENet and stored the prediction probabilities in **Vision.csv**. Additionally, this CSV file includes the labels (PE positive / PE negative), the type of PE (central, segmental and sub-segmental) and the train/val/test split used to develop PENet. For more information about PENet, please refer to:
- Manuscript: [https://www.nature.com/articles/s41746-020-0266-y](https://www.nature.com/articles/s41746-020-0266-y)
- GitHub: [https://github.com/marshuang80/penet](https://github.com/marshuang80/penet)

In [None]:
# TODO, remove pe_type if label = 0
vision_df = pd.read_csv('/content/Vision.csv')
vision_df.head(5)

## Process Data

We are going to pre-process the EMR data by:
- Removing any features with zero variance
- Standardizing all features to zero mean and unit variance so that they share a comparable scale

Next, we are going to combine all the EMR features into one dataframe.

In [None]:
processed_emr_dfs = []
for df in [demo_df, out_med_df, in_med_df, icd_df, lab_df, vitals_df]:
    # remove zero variance features
    df = df.loc[:,df.apply(pd.Series.nunique) != 1]
    
    # set index 
    df = df.set_index('idx')

    # normalize features
    df = df.apply(lambda x: (x - x.mean())/(x.std()))
    
    processed_emr_dfs.append(df)

emr_df = pd.concat(processed_emr_dfs, axis=1)
emr_df.head(5)

In [None]:
# define columns
EMR_FEATURE_COLS = emr_df.columns.tolist()
PE_TYPE_COL = 'pe_type'
SPLIT_COL = 'split'
VISION_PRED_COL = 'pred'
EMR_PRED_COL = 'emr_pred'
FUSION_PRED_COL = 'late_fusion_pred'
LABEL_COL = 'label'

In [None]:
# join vision information with emr dataframe
vision_df = vision_df.set_index('idx')
df = pd.concat([vision_df, emr_df], axis=1)

In [None]:
# using train and val split for cross validation
# Create data splits
df_dev = df[(df.split == 'train') | (df.split == 'val')]  # for gridsearch CV
df_train = df[df.split == 'train']
df_val = df[df.split == 'val']
df_test = df[df.split == 'test']

## Train EMR Model

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

# Uncomment and run grid search if time permits
"""
# define model
clf = LogisticRegression(
    penalty='elasticnet', solver='saga', random_state=0
)

# define grid search
param_grid = {
    "C": [0.01, 0.1, 1.0, 100], 
    "class_weight": ['balanced'],
    "max_iter": [1000],
    "l1_ratio": [0.01, 0.25, 0.5, 0.75, 0.99]
}
gsc = GridSearchCV(
    estimator=clf,
    param_grid=param_grid,
    scoring='roc_auc',
    n_jobs=-1,
    verbose=10
)

# run grid search
gsc.fit(df_dev[EMR_FEATURE_COLS], df_dev[LABEL_COL])
print(f"Best parameters: {gsc.best_params_}")
clf = gsc.best_estimator_
"""

clf = LogisticRegression(
    penalty='elasticnet', 
    solver='saga', 
    random_state=0,
    C= 0.1, 
    class_weight='balanced', 
    l1_ratio= 0.99, 
    max_iter= 1000
)
clf.fit(df_train[EMR_FEATURE_COLS], df_train[LABEL_COL])

## Test EMR Model

In [None]:
# test with best model
emr_prob = clf.predict_proba(df_test[EMR_FEATURE_COLS])

# take probability of positive class 
emr_prob = [p[1] for p in emr_prob]

df_test = df_test.assign(emr_pred = emr_prob)

## Late Fusion (Mean Aggregation)

<img src="./figs/late_fusion_mean_agg.png" width="200">

Now that we have prediction probabilities from both the EMR and Vision models, we will apply a simple late fusion strategy with mean aggregation.

In [None]:
# Late fusion by taking the average prediction probability from vision model and emr model
late_fusion_pred = np.mean(
    [df_test[EMR_PRED_COL], df_test[VISION_PRED_COL]], 
    axis=0
)
df_test = df_test.assign(late_fusion_pred = late_fusion_pred)

## Evaluate Performance

In [None]:
from sklearn import metrics
import matplotlib.pyplot as plt

plt.style.use('ggplot')
lw = 2

def plot_auc(df, label):
    # create a fresh figure on every call so each plot has the same size
    plt.figure(figsize=(7, 7))

    # PENet performance
    fpr_v, tpr_v, _ = metrics.roc_curve(
        df[LABEL_COL], 
        df[VISION_PRED_COL])
    roc_auc_v = metrics.auc(fpr_v, tpr_v)
    plt.plot(
        fpr_v, 
        tpr_v, 
        color='darkorange',
        lw=lw, 
        label='PENet ROC curve (area = %0.2f)' % roc_auc_v)

    # EMR model performance
    fpr_emr, tpr_emr, _ = metrics.roc_curve(
        df[LABEL_COL], 
        df[EMR_PRED_COL])
    roc_auc_emr = metrics.auc(fpr_emr, tpr_emr)
    plt.plot(
        fpr_emr, 
        tpr_emr,
        lw=lw, 
        label='EMR Model ROC curve (area = %0.2f)' % roc_auc_emr)

    # Fusion model performance
    fpr_fusion, tpr_fusion, _ = metrics.roc_curve(
        df[LABEL_COL], 
        df[FUSION_PRED_COL])
    roc_auc_fusion = metrics.auc(fpr_fusion, tpr_fusion)
    plt.plot(
        fpr_fusion, 
        tpr_fusion,
        lw=lw, 
        label='Fusion Model ROC curve (area = %0.2f)' % roc_auc_fusion)

    plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
    plt.xlim([0.0, 0.95])
    plt.ylim([0.0, 1.05])
    # use the current axes; plt.axes() would add a new, empty axes on top
    plt.gca().set_aspect('equal', 'datalim')

    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f'Receiver operating characteristic ({label})')
    plt.legend(loc="lower right")

    plt.show()

In [None]:
# Performance for all cases
plot_auc(df_test, 'All Cases')

In [None]:
# Performance for non-subsegmental cases
df_test_no_subseg = df_test[
    df_test[PE_TYPE_COL] != 'subsegmental']
plot_auc(df_test_no_subseg, 'No Subsegmental')

In [None]:
# Visualize histogram of Predicted Probs
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt

# style
plt.clf()
plt.style.use('ggplot')
matplotlib.rc('xtick', labelsize=10) 
matplotlib.rc('ytick', labelsize=10) 
f, (ax1, ax2, ax3) = plt.subplots(1, 3, sharey=True, figsize=(21,6), dpi=300)
bins = np.linspace(0, 1, 30)

# separate cases into positive and negative
positive_cases = df_test_no_subseg[
    df_test_no_subseg[LABEL_COL] == 1]
negative_cases = df_test_no_subseg[
    df_test_no_subseg[LABEL_COL] == 0]

# PENet
ax1.hist(
    [positive_cases[VISION_PRED_COL], negative_cases[VISION_PRED_COL]], 
    bins, 
    label=['positive','negative'], 
    width=0.01)

# EMR
ax2.hist(
    [positive_cases[EMR_PRED_COL], negative_cases[EMR_PRED_COL]], 
    bins, 
    label=['positive', 'negative'], 
    width=0.01)

# Fusion
ax3.hist(
    [positive_cases[FUSION_PRED_COL], negative_cases[FUSION_PRED_COL]], 
    bins, 
    label=['positive','negative'], 
    width=0.01)

f.tight_layout(pad=0.5)
plt.legend(loc='upper right')
ax2.set_xlabel("Predicted Probabilities", fontsize = 25)
ax1.set_ylabel("Count", fontsize = 25)
ax1.set_title('Vision Only', fontsize = 25)
ax2.set_title('EMR Only', fontsize = 25)
ax3.set_title('Fusion', fontsize = 25)
plt.show()

## Bonus: Other Fusion Strategies

![OtherFusionStrategies](./figs/other_fusion_strategies.png)

In [None]:
# Try out other fusion strategies here
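
As a possible starting point, the sketch below compares three other late fusion variants: a weighted average, max aggregation, and a small meta-classifier trained on the two single-modality probabilities. It runs on synthetic probabilities so that it is self-contained. The data, the 0.7 weight, and the helper names `weighted_average` and `max_aggregation` are all assumptions for illustration, and these variants are not necessarily the ones shown in the figure above.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

rng = np.random.default_rng(0)

# Synthetic stand-ins for the labels and the two single-modality probabilities
n = 400
labels = rng.integers(0, 2, size=n)
vision_prob = np.clip(0.5 + 0.25 * (2 * labels - 1) + rng.normal(0, 0.25, n), 0, 1)
emr_prob = np.clip(0.5 + 0.15 * (2 * labels - 1) + rng.normal(0, 0.25, n), 0, 1)

# First half is used to fit fusion parameters, second half to evaluate them
fit, held_out = slice(0, n // 2), slice(n // 2, n)

def weighted_average(p_vision, p_emr, w_vision):
    """Late fusion with a fixed weight on the vision model."""
    return w_vision * p_vision + (1 - w_vision) * p_emr

def max_aggregation(p_vision, p_emr):
    """Late fusion that keeps the more confident positive prediction."""
    return np.maximum(p_vision, p_emr)

# Learned late fusion: a meta-classifier over the two probabilities
stacked = np.column_stack([vision_prob, emr_prob])
meta = LogisticRegression().fit(stacked[fit], labels[fit])
learned_prob = meta.predict_proba(stacked[held_out])[:, 1]

scores = {
    "vision only": roc_auc_score(labels[held_out], vision_prob[held_out]),
    "emr only": roc_auc_score(labels[held_out], emr_prob[held_out]),
    "weighted average": roc_auc_score(
        labels[held_out],
        weighted_average(vision_prob[held_out], emr_prob[held_out], w_vision=0.7),
    ),
    "max aggregation": roc_auc_score(
        labels[held_out],
        max_aggregation(vision_prob[held_out], emr_prob[held_out]),
    ),
    "learned fusion": roc_auc_score(labels[held_out], learned_prob),
}
for name, auc in scores.items():
    print(f"{name}: {auc:.3f}")
```

To try this on the notebook data, the synthetic arrays could be replaced with the `pred` and `emr_pred` columns. Any fusion weights or meta-classifier should be fit on validation predictions rather than on `df_test`, so that the test AUROC remains an unbiased estimate.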