
<img src="https://github.com/abchapman93/DELPHI_Intro_to_NLP_Spring_2024/blob/main/media/DELPHI-long.png?raw=true" size="20%">
</br>

<h1 valign="center" align="center"><font size="+150">Introduction to NLP in Python</br>Spring 2024</font></h1>

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.insert(0, "..")

from delphi_nlp_2024 import *
from delphi_nlp_2024.quizzes.quizzes import *
from delphi_nlp_2024.helpers import *

import medspacy_pna

In [3]:
import medspacy
from IPython.display import Image

In [4]:
from medspacy.visualization import visualize_dep, visualize_ent, MedspaCyVisualizerWidget

In [5]:
import pandas as pd

In [6]:
df = pd.read_csv("../data/mimic2_pneumonia_corpus.csv")

In [7]:
df.head()

Unnamed: 0.1,Unnamed: 0,subject_id,hadm_id,admit_dt,Pneumonia,text
0,5,37,18052,3264-08-14 00:00:00,1,\n\n\n DATE: [**3264-8-14**] 10:57 AM\n ...
1,14,94,8743,2656-08-18 00:00:00,1,\n\n\n DATE: [**2656-8-19**] 4:17 PM\n ...
2,10,117,14296,3131-11-27 00:00:00,1,\n\n\n DATE: [**3131-11-28**] 1:30 PM\n ...
3,19,184,203,3251-04-30 00:00:00,1,\n\n\n DATE: [**3251-5-1**] 3:18 PM\n ...
4,18,184,17249,3251-03-19 00:00:00,1,\n\n\n DATE: [**3251-3-19**] 3:18 PM\n ...


In [8]:
len(df)

200

In [9]:
nlp = medspacy_pna.build_nlp("radiology")



In [10]:
from medspacy_pna.document_classification.document_classifier import BaseDocumentClassifier

DEFAULT_SCHEMA = "linked"

TARGET_CLASSES = {"PNEUMONIA", "CONSOLIDATION", "INFILTRATE", "OPACITY"}

ENTITY_ATTRIBUTES = {
        "is_negated": False,
        "is_hypothetical": False,
        "is_historical": False,
        "is_family": False,
        "is_uncertain": True, # The only change I'm making here is that is_uncertain will be excluded
        "is_ignored": False
}

TIER_1_CLASSES = {
    "PNEUMONIA",
    "CONSOLIDATION",
}

TIER_2_CLASSES = {
    "INFILTRATE",
    "OPACITY",
}

ALTERNATE_DIAGNOSES = {
    "ATELECTASIS",
    "PULMONARY_EDEMA",
    # "SOFT_TISSUE_ATTENUATION",
    # "PLEURAL_EFFUSION",
    # "EMPHYSEMA",
    "INTERSTITIAL_LUNG_DISEASE",
    "FIBROSIS",
}

LINK_PHRASES = [
    "may represent",
    "may be",
    "may be related to",
    "related to",
    "r/t",
    "likely",
    "likely representing",
    "likely represents",
    "consistent with",
    "compatible with",
    "c/w",
    "suggest",
    "may represent",
    "associated",
    "comptaible",
    "due to",
    "worrisome for",
    "suspicious for",
    "secondary to",
    "suggesting",
    "suggests",
]

class MimicRadiologyDocumentClassifier(BaseDocumentClassifier):
    domain = "radiology"
    schemas = (
        "full",
        "attributes", "linked", "keywords")

    def __init__(self, nlp, name="pneumonia_radiologydocumentclassifier", classification_schema=None):
        self.nlp = nlp
        self.name = name
        if classification_schema is None:
            classification_schema = DEFAULT_SCHEMA
        super().__init__(classification_schema=classification_schema)

    @property
    def relevant_classes(self):
        return TARGET_CLASSES.union(ALTERNATE_DIAGNOSES)

    @property
    def target_classes(self):
        return TARGET_CLASSES

    def is_relevant_class(self, label):
        return label in self.relevant_classes

    def link_evidence(self, doc):
        for ent in doc.ents:
            ent._.linked_ents = tuple()
        for (ent, modifier) in doc._.context_graph.edges:
            if ent.label_ in ALTERNATE_DIAGNOSES and modifier.span.text.lower() in LINK_PHRASES:
                # print(ent, modifier)
                sent = ent.sent
                span = doc[sent.start:ent.start]
                other_ents = span.ents
                for other in other_ents:
                    if other.label_ in TIER_2_CLASSES:
                        ent._.linked_ents += (other,)
                        other._.linked_ents += (ent,)

    def gather_ent_data(self, doc, link_ents=False):
        asserted_ent_labels = set()
        uncertain_ent_labels = set()
        negated_ent_labels = set()
        if link_ents:
            self.link_evidence(doc)
        for ent in doc.ents:
            if ent.label_ not in self.relevant_classes:

                continue
            is_excluded = False

            # Check if any of the attributes don't match required values (ie., is_negated == True)
            for (attr, req_value) in ENTITY_ATTRIBUTES.items():
                # This entity won't count as positive evidence, move onto the next one
                if getattr(ent._, attr) != req_value:
                    # print(ent, attr)
                    is_excluded = True
                    # print("Excluding", ent)
                    # print(attr, getattr(ent._, attr))
                    break
            # TODO: this is an additional piece of logic around alternate dx, should maybe go somewhere else
            if not is_excluded:
                if link_ents and ent.label_ in TIER_2_CLASSES and len(ent._.linked_ents):

                    is_excluded = True
            if not is_excluded:
                # print(ent)
                if ent._.is_uncertain:
                    uncertain_ent_labels.add(ent.label_)
                else:
                    asserted_ent_labels.add(ent.label_)
                    # print(ent, ent.sent, ent._.modifiers)

            elif ent._.is_negated:
                negated_ent_labels.add(ent.label_)
        return {
            "asserted": asserted_ent_labels,
            "uncertain": uncertain_ent_labels,
            "negated": negated_ent_labels
        }

    def classify_document_keywords(self, doc):
        """Classify based *only* on the presence of target entity labels."""
        ent_data = self.gather_ent_data(doc, link_ents=False)
        ent_labels = set()
        for (_, sub_ent_labels) in ent_data.items():
            ent_labels.update(sub_ent_labels)
        if ent_labels.intersection(TARGET_CLASSES):
            return "POS"
        return "NEG"

    def classify_document_attributes(self, doc, link_ents=False):
        ent_data = self.gather_ent_data(doc, link_ents=link_ents)
        if self.debug:
            print(ent_data)
        # print(ent_data)
        asserted_ent_labels = ent_data["asserted"]
        uncertain_ent_labels = ent_data["uncertain"]
        negated_ent_labels = ent_data["negated"]

        # print(negated_ent_labels)
        # print(asserted_ent_labels)

        if 0 == 1:
            pass
        # NOTE 9/27: If there is an uncertain Tier 2, bump up to Positive
        elif uncertain_ent_labels.intersection(TIER_1_CLASSES) and asserted_ent_labels.intersection(TIER_2_CLASSES):
            document_classification = "POS"
        # 9/27: prioritize possible over positive
        elif uncertain_ent_labels.intersection(TIER_1_CLASSES):
            document_classification = "POSSIBLE"
        elif asserted_ent_labels.intersection(TIER_1_CLASSES):
            document_classification = "POS"


        elif negated_ent_labels.intersection(TIER_1_CLASSES):
            document_classification = "NEG"
        elif asserted_ent_labels.intersection(TIER_2_CLASSES):
            document_classification = "POS"
        elif uncertain_ent_labels.intersection(TIER_2_CLASSES):
            document_classification = "POSSIBLE"
        else:
            document_classification = "NEG"
        return document_classification

    def classify_document_radiology_full(self, doc):
        """
        Radiology logic:
            1. Look for asserted (+ or ?) Tier 1 Evidence --> POS/POSSIBLE
            2. Look for negated (-) Tier 1 Evidence --> NEG
            3. Look for asserted (+ or ?) alternate diagnosis --> NEG
            4. Look for asserted Tier 2 evidence --> POS/POSSIBLE
            5. If nothing, return Neg --> NEG
        """
        # raise NotImplementedError("Need to sync with attribute classification")
        ent_data = self.gather_ent_data(doc, link_ents=False)
        asserted_ent_labels = ent_data["asserted"]
        uncertain_ent_labels = ent_data["uncertain"]
        negated_ent_labels = ent_data["negated"]

        if asserted_ent_labels.intersection(TIER_1_CLASSES):
            document_classification = "POS"
        # NOTE 9/27: If there is an uncertain Tier 2, bump up to Positive
        elif uncertain_ent_labels.intersection(TIER_1_CLASSES) and asserted_ent_labels.intersection(TIER_2_CLASSES):
            document_classification = "POS"
        elif uncertain_ent_labels.intersection(TIER_1_CLASSES):
            document_classification = "POSSIBLE"
        elif negated_ent_labels.intersection(TIER_1_CLASSES):
            document_classification = "NEG"
        elif asserted_ent_labels.union(uncertain_ent_labels).intersection(ALTERNATE_DIAGNOSES):
            document_classification = "NEG"
        elif asserted_ent_labels.intersection(TIER_2_CLASSES):
            document_classification = "POS"
        elif uncertain_ent_labels.intersection(TIER_2_CLASSES):
            document_classification = "POSSIBLE"
        else:
            document_classification = "NEG"
        return document_classification

    def classify_document_radiology_linked(self, doc):
        """
        """
        return self.classify_document_attributes(doc, link_ents=True)


    def _classify_document(self, doc, classification_schema=None, **kwargs):
        if classification_schema is None:
            classification_schema = self.classification_schema
        if classification_schema == "full":
            return self.classify_document_radiology_full(doc,)
        elif classification_schema == "keywords":
            return self.classify_document_keywords(doc)
        elif classification_schema == "attributes":
            return self.classify_document_attributes(doc)
        elif classification_schema == "linked":
            return self.classify_document_radiology_linked(doc)
        else:
            raise ValueError("Invalid classification_schema", classification_schema)

In [37]:
clf = MimicRadiologyDocumentClassifier(nlp, classification_schema="attributes")

In [38]:
df["doc"] = list(nlp.pipe(df["text"], disable=["pneumonia_radiologydocumentclassifier"]))

  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches 

  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches 

  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches 

  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches 

  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches 

  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)
  matches = self.matcher(doc)


In [42]:
for doc in df["doc"]:
    clf(doc)

In [44]:
df["baseline_document_classification"] = [1*(doc._.document_classification != "NEG") for doc in df["doc"]]


In [45]:
df["baseline_document_classification"].value_counts()

baseline_document_classification
0    155
1     45
Name: count, dtype: int64

In [46]:
def create_html(row):
    html = f"<h1>Subject ID: {row['subject_id']}</h1>"
    html += f"<h1>Hospital Admission ID: {row['hadm_id']}</h1>"
    html += f"<h2>True classification: {row['document_classification']}</h2>"
    html += f"<h2>Baseline NLP classification: {row['baseline_document_classification']}</h2>"
    html += visualize_ent(row["doc"], jupyter=False, sections=True, context=True)
    return html

In [47]:
df = df.rename({"Pneumonia": "document_classification"}, axis=1)

In [48]:
df = df.sample(frac=1.0).reset_index(drop=True)

In [49]:
n_train = int(0.8*200)

In [50]:
n_train

160

In [51]:
df["html"] = df.apply(create_html, axis=1)

In [52]:
train = df.iloc[:n_train].reset_index(drop=True)
test = df.iloc[n_train:].reset_index(drop=True)

In [53]:
ppv = test.query("baseline_document_classification == 1")["document_classification"].mean()
sens = test.query("document_classification == 1")["baseline_document_classification"].mean()
f1 = 2*(ppv*sens)/(ppv+sens)

In [54]:
(ppv, sens, f1)

(0.75, 0.21428571428571427, 0.3333333333333333)

Create HTML's

In [58]:
nlp("Findings: opacities")._.document_classification

  matches = self.matcher(doc)


'POS'

In [25]:
from medspacy.visualization import visualize_ent

In [26]:
from IPython.display import HTML

In [65]:
errors = train[train["document_classification"] == train["baseline_document_classification"]]


In [70]:
row = errors.iloc[6]
HTML(row["html"])

In [68]:
doc = row["doc"]
doc._.document_classification

'NEG'

In [163]:
train["split"] = "train"
test["split"] = "test"
df_out = pd.concat([train, test])

In [164]:
df_out

Unnamed: 0.1,Unnamed: 0,subject_id,hadm_id,admit_dt,document_classification,text,doc,nlp_document_classification,html,baseline_document_classification,split
0,27,286,18885,3081-12-18 00:00:00,1,\n\n\n DATE: [**3081-12-18**] 10:43 PM\n ...,"(\n\n\n , DATE, :, [, *, *, 3081, -, 12, -...",1,<h1>Subject ID: 286</h1><h1>Hospital Admission...,0,train
1,627,5587,10741,3469-09-13 00:00:00,1,\n\n\n DATE: [**3469-9-13**] 6:54 PM\n ...,"(\n\n\n , DATE, :, [, *, *, 3469, -, 9, -,...",0,<h1>Subject ID: 5587</h1><h1>Hospital Admissio...,0,train
2,1012,9090,11421,2854-10-11 00:00:00,1,\n\n\n DATE: [**2854-10-11**] 2:11 PM\n ...,"(\n\n\n , DATE, :, [, *, *, 2854, -, 10, -...",0,<h1>Subject ID: 9090</h1><h1>Hospital Admissio...,0,train
3,38,353,543,3095-07-27 00:00:00,0,\n\n\n DATE: [**3095-8-13**] 4:40 PM\n ...,"(\n\n\n , DATE, :, [, *, *, 3095, -, 8, -,...",0,<h1>Subject ID: 353</h1><h1>Hospital Admission...,0,train
4,382,3386,15597,2577-12-22 00:00:00,1,\n\n\n DATE: [**2577-12-22**] 8:02 PM\n ...,"(\n\n\n , DATE, :, [, *, *, 2577, -, 12, -...",1,<h1>Subject ID: 3386</h1><h1>Hospital Admissio...,1,train
...,...,...,...,...,...,...,...,...,...,...,...
35,42,356,20183,3233-10-02 00:00:00,1,\n\n\n DATE: [**3233-10-10**] 11:41 AM\n ...,"(\n\n\n , DATE, :, [, *, *, 3233, -, 10, -...",1,<h1>Subject ID: 356</h1><h1>Hospital Admission...,0,test
36,1,3,2075,2682-09-07 00:00:00,0,\n\n\n DATE: [**2682-9-10**] 5:22 AM\n ...,"(\n\n\n , DATE, :, [, *, *, 2682, -, 9, -,...",0,<h1>Subject ID: 3</h1><h1>Hospital Admission I...,0,test
37,329,2953,26494,3262-09-28 00:00:00,1,\n\n\n DATE: [**3262-10-12**] 2:03 PM\n ...,"(\n\n\n , DATE, :, [, *, *, 3262, -, 10, -...",0,<h1>Subject ID: 2953</h1><h1>Hospital Admissio...,0,test
38,364,3248,3881,2785-11-10 00:00:00,1,\n\n\n DATE: [**2785-11-10**] 10:16 AM\n ...,"(\n\n\n , DATE, :, [, *, *, 2785, -, 11, -...",0,<h1>Subject ID: 3248</h1><h1>Hospital Admissio...,0,test


In [165]:
df_out = df_out[
    ["subject_id", "hadm_id", "text", "document_classification", "split",
     "baseline_document_classification", "html"]
].copy()


In [166]:
df_out.to_pickle("../data/pneumonia_data/pneumonia.pkl")

In [167]:
df_out.query("split == 'train'").to_pickle("../data/pneumonia_data/train.pkl")

In [168]:
df_out.query("split == 'test'").to_pickle("../data/pneumonia_data/test.pkl")


In [169]:
from sklearn.metrics import classification_report

In [170]:
print(classification_report(train["document_classification"], train["baseline_document_classification"]))

              precision    recall  f1-score   support

           0       0.31      0.79      0.44        48
           1       0.72      0.23      0.35       112

    accuracy                           0.40       160
   macro avg       0.51      0.51      0.40       160
weighted avg       0.60      0.40      0.38       160



In [171]:
print(classification_report(test["document_classification"], test["baseline_document_classification"]))

              precision    recall  f1-score   support

           0       0.31      0.83      0.45        12
           1       0.75      0.21      0.33        28

    accuracy                           0.40        40
   macro avg       0.53      0.52      0.39        40
weighted avg       0.62      0.40      0.37        40

