# Classifying Chest X-rays with Cross-Modal Data Programming

This tutorial demonstrates how to use the *cross-modal data programming* technique described in Dunnmon and Ratner, et al. (2019) to build a Convolutional Neural Network (CNN) model with no hand-labeled data that performs similarly to a CNN supervised using several thousand data points labeled by radiologists.  We begin by setting up our environment and importing the packages we need.

In [3]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import sys, os
# Setting path to Snorkel MeTaL: the relative directory below is appended to
# the module search path so that later `from metal...` imports can resolve.
sys.path.append('../../metal')
# Making sure CUDA devices are visible! Only device index 0 is exposed.
os.environ['CUDA_VISIBLE_DEVICES']='0'

# Importing pandas for data processing
import pandas as pd

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


First, we set up the data dictionary and load data that we've already split for you into an (approximately) 80% train split, 10% development split, and 10% test split.

In [20]:
# Setting up data dictionary and defining data splits.
# `data` maps a split name to a DataFrame holding only the label, image-path
# and report-text columns read from that split's CSV.
data = {}
splits = ['train','dev','test']

for split in splits:
    data[split] = pd.read_csv(f'data/{split}_entries.csv')[['label','xray_paths','text']]
    # Sum of the labels over the row count is reported as the "Abnormal" share
    # (assumes labels are 0/1 with 1 = abnormal -- confirm against the CSVs).
    perc_pos = sum(data[split]['label'])/len(data[split])
    print(f'{len(data[split])} {split} examples: {100*perc_pos:0.1f}% Abnormal')

2630 train examples: 63.8% Abnormal
376 dev examples: 63.0% Abnormal
378 test examples: 61.6% Abnormal


You can see an example of a single data point below -- note that the raw label convention is 1 for abnormal, 0 for normal.

In [34]:
# Show the first training row: its report text, its X-ray image path(s), and its raw label.
sample = data['train'].iloc[0]
print('RAW TEXT:\n \n',sample['text'],'\n')
print('IMAGE PATHS: \n \n', sample['xray_paths'],'\n')
print('LABEL:', sample['label'])

RAW TEXT:
 
 COMPARISON: Chest x-XXXX XXXX INDICATION: XXXX in bathtub FINDINGS: The lungs and pleural spaces show no acute abnormality. Hyperexpanded lungs. Calcified right upper lobe granuloma, unchanged. Heart size and pulmonary vascularity within normal limits. No displaced rib fractures. IMPRESSION: 1. Hyperexpansion without acute pulmonary abnormality. 

IMAGE PATHS: 
 
 ./data/openi/xrays/CXR2824_IM-1245-13001.png 

LABEL: 1


We now define our *labeling functions*: simple, heuristic functions written by a domain expert (e.g., a radiologist) that correctly label a report as normal or abnormal with probability better than random chance.  Note that we use our labeled *development set* to ensure that we do not include any poorly performing heuristics!

In [35]:
import re
import numpy

# Values returned by the labeling functions (LFs) below.
# 0 is reserved for "no vote", so "normal" is encoded as 2 here rather than 0.
ABSTAIN = 0
FALSE = 2 # normal
TRUE = 1 # abnormal


def get_text(c, name=None):
    """Return the document text behind candidate ``c`` as one string.

    The sentences are read from ``c[0].context.document.sentences`` and their
    ``.text`` values are joined with single spaces. If ``name`` is given, only
    sentences whose ``.name`` equals ``name`` contribute; otherwise every
    sentence is used.
    """
    sentences = c[0].context.document.sentences
    if name is not None:
        sentences = [sent for sent in sentences if sent.name == name]
    return " ".join(sent.text for sent in sentences)


def LF_report_is_short(c):
    """Vote normal (FALSE) when the full report text is under 280 characters; abstain otherwise."""
    if len(get_text(c)) < 280:
        return FALSE
    return ABSTAIN


negative_inflection_words = ["but", "however", "otherwise"]


def LF_negative_inflection_words_in_report(c):
    """Vote abnormal (TRUE) when the report contains a contrastive word.

    The check is a case-insensitive substring test against each entry of
    ``negative_inflection_words``; the LF abstains when none is present.
    """
    lowered = get_text(c).lower()
    for word in negative_inflection_words:
        if word in lowered:
            return TRUE
    return ABSTAIN


def LF_is_seen_or_noted_in_report(c):
    """Vote abnormal (TRUE) when the lower-cased report contains "is seen" or "noted"; abstain otherwise."""
    lowered = get_text(c).lower()
    if "is seen" in lowered or "noted" in lowered:
        return TRUE
    return ABSTAIN

def LF_disease_in_report(c):
    """Vote abnormal (TRUE) when the lower-cased report contains "disease"; abstain otherwise."""
    lowered = get_text(c).lower()
    if "disease" in lowered:
        return TRUE
    return ABSTAIN


def LF_recommend_in_report(c):
    """Vote abnormal (TRUE) when the lower-cased report contains "recommend"; abstain otherwise."""
    lowered = get_text(c).lower()
    if "recommend" in lowered:
        return TRUE
    return ABSTAIN


def LF_mm_in_report(c):
    """Vote abnormal (TRUE) when the lower-cased report contains "mm" or "cm".

    This is a plain substring test, so the two-letter sequences also match
    inside longer words. The LF abstains when neither is present.
    """
    lowered = get_text(c).lower()
    if "mm" in lowered or "cm" in lowered:
        return TRUE
    return ABSTAIN


# Lower-case terms whose presence anywhere in a report is taken as evidence of
# an abnormal study (used by LF_abnormal_mesh_terms_in_report).
# "fracture" and "surgical" are separate entries: without the comma between
# them Python concatenates the adjacent literals into "fracturesurgical",
# which never occurs in a report.
abnormal_mesh_terms = [
    "opacity",
    "cardiomegaly",
    "calcinosis",
    "hypoinflation",
    "calcified granuloma",
    "thoracic vertebrae",
    "degenerative",
    "hyperdistention",
    "catheters",
    "granulomatous",
    "nodule",
    "fracture",
    "surgical",
    "instruments",
    "emphysema",
]


def LF_abnormal_mesh_terms_in_report(c):
    """Vote abnormal (TRUE) when any entry of ``abnormal_mesh_terms`` is a substring of the lower-cased report; abstain otherwise."""
    lowered = get_text(c).lower()
    for mesh in abnormal_mesh_terms:
        if mesh in lowered:
            return TRUE
    return ABSTAIN


words_indicating_normalcy = ["clear", "no", "normal", "unremarkable", "free", "midline"]

def LF_consistency_in_report(c):
    """Vote on the FINDINGS section by counting pieces that lack reassurance.

    The FINDINGS text is split on "."; a piece counts against normalcy when,
    lower-cased, it contains none of ``words_indicating_normalcy`` (substring
    match) or when it contains "not". Fewer than two such pieces yields FALSE
    (normal); two or more yields TRUE (abnormal). This LF never abstains.
    """
    findings = get_text(c, name="FINDINGS")

    def lacks_normalcy(piece):
        lowered = piece.lower()
        has_normal_word = any(word in lowered for word in words_indicating_normalcy)
        return (not has_normal_word) or ("not" in lowered)

    flagged = sum(1 for piece in findings.split(".") if lacks_normalcy(piece))
    return TRUE if flagged >= 2 else FALSE


# Category strings used by LF_positive_MeshTerm, which compiles each entry
# (from index 1 onward) as a case-insensitive regular expression.
# Index 0 ("normal") is skipped there because that loop starts at 1.
categories = [
    "normal",
    "opacity",
    "cardiomegaly",
    "calcinosis",
    "lung/hypoinflation",
    "calcified granuloma",
    "thoracic vertebrae/degenerative",
    "lung/hyperdistention",
    "spine/degenerative",
    "catheters, indwelling",
    "granulomatous disease",
    "nodule",
    "surgical instruments",
    "scoliosis",
    "osteophyte",
    "spondylosis",
    "fractures, bone",
]


def LF_normal(report):
    """Vote normal (FALSE) when any "."-delimited piece of the report contains
    "No acute cardiopulmonary abnormality" (case-insensitive); abstain otherwise."""
    pattern = re.compile("No acute cardiopulmonary abnormality", re.IGNORECASE)
    pieces = get_text(report).split(".")
    if any(pattern.search(piece) for piece in pieces):
        return FALSE
    return ABSTAIN

# Case-insensitive pattern for hedging language. LF_positive_MeshTerm and
# LF_fracture skip any sentence that matches it, so an uncertain mention does
# not produce an abnormal vote.
reg_equivocation = re.compile(
    (
        r"unlikely|likely|suggests|questionable|concerning|possibly|potentially|"
        r"could represent|may represent|may relate|cannot exclude|can't exclude|may be"
    ),
    re.IGNORECASE,
)


def LF_positive_MeshTerm(report):
    """Vote abnormal (TRUE) on an un-negated, un-hedged mention of a category.

    Every entry of ``categories`` except the first ("normal") is used as a
    case-insensitive regex. A "."-delimited piece of the report triggers TRUE
    when it mentions the category, the mention is not preceded (within up to
    ten tokens) by "No", "without" or "resolution", and the piece contains no
    equivocation phrase. The LF abstains when no piece qualifies.
    """
    sentences = get_text(report).split(".")
    for category in categories[1:]:
        reg_pos = re.compile(category, re.IGNORECASE)
        # In a raw string whitespace is written \s. The previous \\s matched a
        # literal backslash followed by "s", so the negation never fired and
        # phrases such as "No nodule" were voted abnormal.
        reg_neg = re.compile(
            r"(No|without|resolution)\s([a-zA-Z0-9\-,_]*\s){0,10}" + category,
            re.IGNORECASE,
        )
        for s in sentences:
            if (
                reg_pos.search(s)
                and (not reg_neg.search(s))
                and (not reg_equivocation.search(s))
            ):
                return TRUE
    return ABSTAIN


def LF_fracture(report):
    """Vote abnormal (TRUE) on an un-negated, un-hedged mention of "fracture".

    A "."-delimited piece of the report triggers TRUE when it contains
    "fracture" (case-insensitive), the mention is not preceded (within up to
    ten tokens) by "No", "without" or "resolution", and the piece contains no
    equivocation phrase. The LF abstains when no piece qualifies.
    """
    text = get_text(report)
    reg_pos = re.compile("fracture", re.IGNORECASE)
    # In a raw string whitespace is written \s. The previous \\s matched a
    # literal backslash followed by "s", so "No displaced rib fractures" was
    # never recognised as a negated mention.
    reg_neg = re.compile(
        r"(No|without|resolution)\s([a-zA-Z0-9\-,_]*\s){0,10}fracture", re.IGNORECASE
    )
    for s in text.split("."):
        if (
            reg_pos.search(s)
            and (not reg_neg.search(s))
            and (not reg_equivocation.search(s))
        ):
            return TRUE
    return ABSTAIN

def LF_calcinosis(report):
    """Vote abnormal (TRUE) when one "."-delimited piece of the report contains
    both "calc" and one of "arter", "aorta", "muscle" or "tissue"
    (case-insensitive); abstain otherwise."""
    pieces = get_text(report).split(".")
    calc_pattern = re.compile("calc", re.IGNORECASE)
    site_pattern = re.compile("arter|aorta|muscle|tissue", re.IGNORECASE)
    for piece in pieces:
        if calc_pattern.search(piece) is None:
            continue
        if site_pattern.search(piece) is not None:
            return TRUE
    return ABSTAIN


def LF_degen_spine(report):
    """Vote abnormal (TRUE) when one "."-delimited piece of the report contains
    both "degen" and "spine" (case-insensitive); abstain otherwise."""
    pieces = get_text(report).split(".")
    degen_pattern = re.compile("degen", re.IGNORECASE)
    spine_pattern = re.compile("spine", re.IGNORECASE)
    for piece in pieces:
        if degen_pattern.search(piece) is None:
            continue
        if spine_pattern.search(piece) is not None:
            return TRUE
    return ABSTAIN


def LF_lung_hypoinflation(report):
    """Vote abnormal (TRUE) when the report describes reduced lung volume.

    A "."-delimited piece of the report triggers TRUE when it contains
    "hypoinflation", "collapse", or one of "low"/"decrease"/"diminish"
    followed within up to four tokens by "volume" (all case-insensitive).
    The LF abstains when no piece qualifies.
    """
    text = get_text(report)
    # In a raw string whitespace is written \s. The previous \\s matched a
    # literal backslash followed by "s", so the "low ... volume" alternative
    # could never match.
    reg_01 = re.compile(
        (
            r"hypoinflation|collapse|(low|decrease|diminish)\s"
            r"([a-zA-Z0-9\-,_]*\s){0,4}volume"
        ),
        re.IGNORECASE,
    )
    for s in text.split("."):
        if reg_01.search(s):
            return TRUE
    return ABSTAIN

def LF_lung_hyperdistention(report):
    """Vote abnormal (TRUE) when any "."-delimited piece of the report contains
    "increased volume", "hyperexpan" or "inflated" (case-insensitive);
    abstain otherwise."""
    pieces = get_text(report).split(".")
    pattern = re.compile("increased volume|hyperexpan|inflated", re.IGNORECASE)
    if any(pattern.search(piece) for piece in pieces):
        return TRUE
    return ABSTAIN


def LF_catheters(report):
    """Vote abnormal (TRUE) when any "."-delimited piece of the report contains
    " line" (with the leading space), "catheter" or "PICC" (case-insensitive);
    abstain otherwise."""
    pieces = get_text(report).split(".")
    pattern = re.compile(" line|catheter|PICC", re.IGNORECASE)
    if any(pattern.search(piece) for piece in pieces):
        return TRUE
    return ABSTAIN


def LF_surgical(report):
    """Vote abnormal (TRUE) when any "."-delimited piece of the report contains
    "clip" (case-insensitive); abstain otherwise."""
    pieces = get_text(report).split(".")
    pattern = re.compile("clip", re.IGNORECASE)
    if any(pattern.search(piece) for piece in pieces):
        return TRUE
    return ABSTAIN


def LF_granuloma(report):
    """Vote abnormal (TRUE) when any "."-delimited piece of the report contains
    "granuloma" (case-insensitive); abstain otherwise."""
    pieces = get_text(report).split(".")
    pattern = re.compile("granuloma", re.IGNORECASE)
    if any(pattern.search(piece) for piece in pieces):
        return TRUE
    return ABSTAIN

In [None]:

# The earlier cell only does `import numpy`, so bind the `np` alias used below.
import numpy as np

from metal.analysis import single_lf_summary, confusion_matrix

# Testing single LF: point this at any labeling function defined above.
# (The previous target, lf_impression_section_positive, is not defined in this notebook.)
lf_test = LF_normal

# Computing labels: one LF vote per development document
Y_lf = np.array([lf_test(doc) for doc in dev_docs])
single_lf_summary(Y_lf, Y=Y_dev)

In [None]:
# Print confusion matrix of the dev gold labels (Y_dev) against the single LF's votes (Y_lf)
conf = confusion_matrix(Y_dev, Y_lf)

In [None]:
# Labeling functions to apply to every document; each one is defined above
# and returns TRUE, FALSE or ABSTAIN.
LFs = [
    LF_report_is_short,
    LF_consistency_in_report,
    LF_negative_inflection_words_in_report,
    LF_is_seen_or_noted_in_report,
    LF_disease_in_report,
    LF_abnormal_mesh_terms_in_report,
    LF_recommend_in_report,
    LF_mm_in_report,
    LF_normal,
    LF_positive_MeshTerm,
    LF_fracture,
    LF_calcinosis,
    LF_degen_spine,
    LF_lung_hypoinflation,
    LF_lung_hyperdistention,
    LF_catheters,
    LF_surgical,
    LF_granuloma,
]

In [None]:
from scipy.sparse import csr_matrix
import dask
from dask.diagnostics import ProgressBar
from eeg_utils import evaluate_lf_on_docs, create_label_matrix
import pickle

# Resetting LFs: when True, label matrices are recomputed even if a cached
# pickle already exists.
clobber_lfs = True
Ls_file = 'Ls_0p3.pkl'
Ys_file = 'Ys_0p3.pkl'

# Get lf names (the list of labeling functions is named `LFs`)
lf_names = [lf.__name__ for lf in LFs]

# Label vectors in [train, dev, test] order; the train slot is an empty list
Ys = [[], Y_dev, Y_test]

if clobber_lfs or (not os.path.exists(Ls_file)):
    print('Computing label matrices...')
    # One label matrix per split, in [train, dev, test] order
    Ls = [
        create_label_matrix(LFs, docs)
        for docs in (train_docs, dev_docs, test_docs)
    ]
    with open(Ls_file,'wb') as af:
        pickle.dump(Ls, af)

    print('Creating label vectors...')
    # Write the label vectors (not the label matrices) to Ys_file
    with open(Ys_file,'wb') as af:
        pickle.dump(Ys, af)
else:
    print('Loading pre-computed label matrices...')
    # NOTE: unpickling can execute arbitrary code; only load a cache file that
    # this notebook wrote itself.
    with open(Ls_file,'rb') as af:
        Ls = pickle.load(af)

In [None]:
from metal.analysis import lf_summary

# Analyzing LF stats on the dev split (Ls[1]) against its gold labels;
# the bare `df_lf` expression displays the summary as the cell output.
df_lf = lf_summary(Ls[1], Y=Y_dev, lf_names=lf_names)
df_lf

In [None]:
from  metal.contrib.visualization.analysis import view_conflicts

# Viewing conflicts between labeling functions on the dev label matrix (Ls[1])
view_conflicts(Ls[1], normalize=True)

In [None]:
from metal.label_model import LabelModel
from metal.utils import LogWriter
from metal.tuners import RandomSearchTuner

# Creating metal label model
#label_model = LabelModel(k=2, seed=123)

# Creating search space: both hyperparameters are declared with 'scale': 'log'
search_space = {
        'l2': {'range': [0.0001, 0.1], 'scale':'log'},           # log range
        'lr': {'range': [0.0001, 0.01], 'scale': 'log'},  # log range
        }

# Random-search tuner over LabelModel; log_writer_class is explicitly None here
searcher = RandomSearchTuner(LabelModel, log_dir='./run_logs',
               log_writer_class=None)

In [None]:
%%time
# Training label model: the dev split (Ls[1], Ys[1]) is passed as the second
# argument to the search, the train label matrix Ls[0] is the training
# argument, and at most 20 configurations are tried.
label_model = searcher.search(search_space, (Ls[1],Ys[1]), \
        train_args=[Ls[0]], init_args=[],
        init_kwargs={'k':2, 'seed':123}, train_kwargs={'n_epochs':100},
        max_search=20)

In [None]:
# Saving best model
# NOTE(review): _save_best_model is underscore-prefixed (non-public) on the
# tuner -- confirm this is the intended way to persist the model.
searcher._save_best_model(label_model)

In [None]:
# Getting scores of the tuned label model on the dev split (Ls[1], Ys[1])
scores = label_model.score((Ls[1], Ys[1]), metric=['accuracy','precision', 'recall', 'f1'])

In [None]:
from metal.label_model.baselines import MajorityLabelVoter

# Checking if we beat majority vote: score the baseline on the same dev split
# and metrics as the label model. Note that this rebinds `scores`.
mv = MajorityLabelVoter(seed=123)
scores = mv.score((Ls[1], Ys[1]), metric=['accuracy', 'precision', 'recall', 'f1'])

In [None]:
# Getting probabilistic labels from the label model for every split
# Y_train_ps stands for "Y[labels]_train[split]_p[redicted]s[oft]"
Y_train_ps = label_model.predict_proba(Ls[0])
Y_dev_ps = label_model.predict_proba(Ls[1])
Y_test_ps = label_model.predict_proba(Ls[2])
# Same [train, dev, test] order as Ls
Y_ps = [Y_train_ps, Y_dev_ps, Y_test_ps]

In [None]:
# Running some analysis: compare the label model's dev predictions with the dev gold labels
from metal.contrib.visualization.analysis import plot_predictions_histogram
Y_dev_p = label_model.predict(Ls[1])
plot_predictions_histogram(Y_dev_p, Ys[1], title="Label Distribution")

In [None]:
from  metal.contrib.visualization.analysis  import plot_probabilities_histogram

# Looking at the histogram of the first column of the label model's
# probabilistic labels (the dev split is what is plotted here).
# NOTE(review): presumably column 0 is the "abnormal" class -- confirm.
plot_probabilities_histogram(Y_dev_ps[:,0], title="Probabilistic Label Distribution")