In [1]:
import sys
import os
import json

# Location of the local exlib checkout. Defaults to the shared path used so far;
# set the EXLIB_PATH environment variable to point at a different checkout.
exlib_path = os.environ.get("EXLIB_PATH", "/shared_data0/chaenyk/exlib")
src_path = os.path.join(exlib_path, "src")
# Put the checkout's src/ directory first on the import path (presumably so the
# later `import exlib` resolves to it). The guard keeps a re-run of this cell
# from stacking duplicate entries at the front of sys.path.
if sys.path[:1] != [src_path]:
    sys.path.insert(0, src_path)
import pyarrow as pa
import pyarrow_hotfix
import torch
import yaml
import argparse
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from datasets import load_dataset
from collections import namedtuple
import matplotlib.pyplot as plt

import exlib
from tqdm.auto import tqdm
# NOTE(review): pyarrow_hotfix exists to block automatic loading of
# `PyExtensionType` data (the CVE-2023-47248 mitigation). The two lines below
# switch that auto-loading back on and remove the hotfix. Presumably this is
# needed by the dataset loaders used here -- confirm, and only read Arrow/Parquet
# files from trusted sources while this is in effect.
pa.PyExtensionType.set_auto_load(True)
pyarrow_hotfix.uninstall()
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

from sepsis import Timeseries, get_llm_generated_answer, isolate_individual_features, is_claim_relevant, distill_relevant_features, calculate_expert_alignment_score

In [2]:
# Input file for this notebook. The loading cell parses it one JSON object per
# line (JSON Lines), despite the `.json` extension.
# NOTE(review): absolute, user-specific path -- adjust when running elsewhere.
data_path = '/shared_data0/chaenyk/arpa_h/tfix/arpah_textified.json'

In [52]:
import json

# Keep only records whose description contains more than this many ':'
# characters. Timestamped events are written as "<time>: <name>, <value>", so
# the count serves as a proxy for how many events a record holds.
MIN_COLON_COUNT = 10

print("Loading data...")
data = []
# The file is read as JSON Lines: one JSON object per line. The encoding is
# pinned so parsing does not depend on the machine's locale settings.
with open(data_path, "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if not line:
            # A blank line (e.g. an extra empty line at the end of the file)
            # would make json.loads raise, so skip it.
            continue
        data.append(json.loads(line))

# `or ''` also covers records where 'event_desc' is present but null.
filtered_data = [
    p for p in data
    if (p.get('event_desc') or '').count(':') > MIN_COLON_COUNT
]

Loading data...


In [53]:
# Preview only the first two records: displaying the whole list floods the
# notebook output and writes every patient record into the saved file.
filtered_data[:2]

[{'label': [False, False, False, True, True],
  'event_desc': 'Gender/Male; AGE/60-80; 0.02: CREATININE, 0.65; 0.02: HEMOGLOBIN, 7.3; 0.02: MEAN CELLULAR HEMOGLOBIN, 32.0; 0.02: MEAN CELLULAR HEMOGLOBIN CONCENTRATION, 34.0; 0.02: WBC, 0.7; 0.02: SODIUM, 133.0; 0.02: # BAND NEUTROPHILS, 0.0; 0.02: POTASSIUM, 3.7; 0.02: # NEUTROPHILS MANUAL, 0.24; 0.02: PLATELETS, 17.0; 0.02: % NEUTROPHILS MANUAL, 34.8; 0.02: UREA NITROGEN, 11.0; 0.02: % BAND NEUTROPHILS, 0.0'},
 {'label': [False, False, False, False, True],
  'event_desc': 'Gender/Male; AGE/20-40; 0.07: propofol 1000 mg in 100 mL infusion premix, 20.0; 0.07: propofol 1000 mg in 100 mL infusion premix, 20.0; 0.17: RESPIRATIONS, 16.0; 0.17: PULSE OXIMETRY, 100.0; 0.17: PULSE, 70.0; 0.17: R OR POSTOP FIO2, 100.0; 0.18: R OR GLASGOW COMA SCALE SCORE, 15.0; 0.18: R OR GLASGOW COMA SCALE BEST MOTOR RESPONSE, 6.0; 0.18: R OR GLASGOW COMA SCALE BEST VERBAL RESPONSE, 5.0; 0.18: R OR GLASGOW COMA SCALE EYE OPENING, 4.0; 0.25: R IP FIO2, 100.0'},


In [54]:
# Index of the record used for the rest of the walkthrough.
SAMPLE_INDEX = 30
# NOTE: from here on `data` is a single record (dict), no longer the full list.
# A shallow copy is taken because a later cell reassigns data['event_desc'];
# without the copy that write would also change the entry inside filtered_data,
# and re-running the notebook cells would start from already-modified input.
data = dict(filtered_data[SAMPLE_INDEX])
data

{'label': [False, False, True, True, True],
 'event_desc': 'Gender/Male; AGE/40-60; 0.0: GLUCOSE POINT OF CARE, 336.0; 0.02: WEIGHT/SCALE, 5440.0; 0.02: R IP IDEAL BODY WEIGHT, 68.4; 0.02: PULSE OXIMETRY, 92.0; 0.02: HEIGHT, 68.0; 0.02: PULSE, 84.0; 0.02: R AN ADJUSTED BODY WEIGHT, 102.73; 0.02: TEMPERATURE, 99.6; 0.02: RESPIRATIONS, 22.0; 0.05: R OR GLASGOW COMA SCALE SCORE, 15.0; 0.05: R OR GLASGOW COMA SCALE BEST MOTOR RESPONSE, 6.0; 0.05: R OR GLASGOW COMA SCALE EYE OPENING, 4.0; 0.05: R OR GLASGOW COMA SCALE BEST VERBAL RESPONSE, 5.0; 0.27: HEMOGLOBIN, 10.4; 0.27: MEAN CELLULAR HEMOGLOBIN CONCENTRATION, 32.0; 0.27: MEAN CELLULAR HEMOGLOBIN, 27.0; 0.27: # NEUTROPHILS, 16.1; 0.27: WBC, 18.7; 0.27: % NEUTROPHILS, 86.3; 0.27: PLATELETS, 276.0'}

In [55]:
# Rewrite the description so every segment has the form "<time>: <name>, <value>".
# Segments that do not start with a digit (e.g. "Gender/Male", "AGE/40-60") carry
# no timestamp; they are rewritten as "0.0: <name>, <value>".
parts = []
for part in data['event_desc'].split('; '):
    if not part.strip():
        # An empty segment (empty description or a trailing "; ") has no first
        # character to inspect and nothing to parse later, so it is dropped.
        continue
    if not part[0].isdigit():
        # Only untimestamped segments get '/' turned into ', '; timestamped
        # names such as "WEIGHT/SCALE" are left untouched.
        part = '0.0: ' + part.replace('/', ', ')
    parts.append(part)

data['event_desc'] = '; '.join(parts)
print(data['event_desc'])

0.0: Gender, Male; 0.0: AGE, 40-60; 0.0: GLUCOSE POINT OF CARE, 336.0; 0.02: WEIGHT/SCALE, 5440.0; 0.02: R IP IDEAL BODY WEIGHT, 68.4; 0.02: PULSE OXIMETRY, 92.0; 0.02: HEIGHT, 68.0; 0.02: PULSE, 84.0; 0.02: R AN ADJUSTED BODY WEIGHT, 102.73; 0.02: TEMPERATURE, 99.6; 0.02: RESPIRATIONS, 22.0; 0.05: R OR GLASGOW COMA SCALE SCORE, 15.0; 0.05: R OR GLASGOW COMA SCALE BEST MOTOR RESPONSE, 6.0; 0.05: R OR GLASGOW COMA SCALE EYE OPENING, 4.0; 0.05: R OR GLASGOW COMA SCALE BEST VERBAL RESPONSE, 5.0; 0.27: HEMOGLOBIN, 10.4; 0.27: MEAN CELLULAR HEMOGLOBIN CONCENTRATION, 32.0; 0.27: MEAN CELLULAR HEMOGLOBIN, 27.0; 0.27: # NEUTROPHILS, 16.1; 0.27: WBC, 18.7; 0.27: % NEUTROPHILS, 86.3; 0.27: PLATELETS, 276.0


In [56]:
# Split the description into three parallel lists, one item per event.
# The source string is deliberately not called `input`, which would shadow the
# builtin for the rest of the kernel session.
event_desc = data['event_desc']
entries = event_desc.split(';')

# NOTE: these three names are consumed by later cells when building Timeseries.
time = []
measure = []
value = []

for entry in entries:
    entry = entry.strip()
    if not entry:
        continue
    # Expected shape: "<time>: <measure>, <value>". Split on the first ':' and
    # then on the first ',' so any later occurrences stay inside the value.
    t_part, colon, rest = entry.partition(':')
    m_part, comma, v_part = rest.partition(',')
    if not colon or not comma:
        # Same exception type as a failed tuple unpacking, but the message
        # names the offending entry.
        raise ValueError(
            f"Malformed event entry (expected 'time: measure, value'): {entry!r}"
        )
    time.append(t_part.strip())
    measure.append(m_part.strip())
    value.append(v_part.strip())

In [57]:
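# ### Extract Explanation
# Live LLM call, left disabled: the cells below assign hard-coded response text
# to `analysis_result` instead. Uncomment these lines to generate a fresh answer.
# example = Timeseries(time=time, measurement=measure, value=value)
# analysis_result = get_llm_generated_answer(example=example)
# print(analysis_result)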

In [31]:
# NOTE(review): this hard-coded response describes a male aged 20-40 with a
# "No" prediction, which does not match the record selected above (AGE/40-60).
# The next cell assigns `analysis_result` again, so this value is replaced
# before any later cell reads it when the notebook is run top to bottom.
analysis_result = "Prediction: No\n\nExplanation:The patient is a male aged 20-40 years old. The initial data shows that the patient was given acetaminophen, diphenhydramine, and metoclopramide. The Glasgow Coma Scale (GCS) score was 15, indicating a normal level of consciousness. The patient's respiratory rate was 23 breaths per minute, pulse oximetry was 96%, pulse was 114 beats per minute, and heart rate was 112 beats per minute. Later, the pulse oximetry dropped to 94%, heart rate decreased to 108 beats per minute, pulse decreased to 110 beats per minute, and respiratory rate increased to 36 breaths per minute.Based on the provided data, there are no significant changes in neurological status or disproportionately severe symptoms compared to initial vital signs. Therefore, the answer is no, the patient is not at high risk of developing sepsis within the next 12 hours."

In [58]:
# Hard-coded explanation text for the selected record, used in place of a live
# get_llm_generated_answer() call (that call is commented out in an earlier cell).
# NOTE(review): the text states the patient's age is 50, while the record only
# gives the bucket AGE/40-60 -- the exact figure does not appear in the record.
analysis_result = "Prediction: Yes\n\nExplanation: The patient exhibits several risk factors and early warning signs that suggest a high risk of developing sepsis within the next 12 hours. Firstly, the patient's age is 50, which is a moderate risk factor for sepsis. The glucose level is significantly elevated at 336 mg/dL, which can indicate stress or infection. The Glasgow Coma Scale (GCS) score is 15, which is normal, but the individual components show some concern: the best motor response is 6, eye opening is 4, and verbal response is 5, suggesting potential neurological changes. The respiratory rate is 22, which is on the higher side, and the pulse oximetry is 92%, indicating possible hypoxia. The platelet count is 276,000/µL, which is within normal range, but the white blood cell count is elevated at 18.7, suggesting an inflammatory or infectious process. The temperature is 99.6°F, close to the fever threshold. These factors, combined with the elevated glucose and potential hypoxia, suggest a heightened risk of sepsis, warranting close monitoring and further investigation for infection."

In [27]:
### Split the explanation into atomic claims
raw_atomic_claims = isolate_individual_features(analysis_result)

# Header first, then one numbered line per claim (numbering starts at 1).
numbered_claims = (
    f"{idx}. {text}" for idx, text in enumerate(raw_atomic_claims, start=1)
)
print("Extracted atomic claims:", *numbered_claims, sep="\n")

Extracted atomic claims:
2. The patient's age is 50, which is a moderate risk factor for sepsis.
3. The glucose level is significantly elevated at 336 mg/dL, which can indicate stress or infection.
4. The Glasgow Coma Scale (GCS) score is 15, which is normal.
5. The individual components of the GCS show some concern: the best motor response is 6, eye opening is 4, and verbal response is 5, suggesting potential neurological changes.
6. The respiratory rate is 22, which is on the higher side.
7. The pulse oximetry is 92%, indicating possible hypoxia.
8. The platelet count is 276,000/µL, which is within normal range.
9. The white blood cell count is elevated at 18.7, suggesting an inflammatory or infectious process.
10. The temperature is 99.6°F, close to the fever threshold.
11. These factors, combined with the elevated glucose and potential hypoxia, suggest a heightened risk of sepsis.
12. Close monitoring and further investigation for infection are warranted.


In [28]:
### Relevance flag for each atomic claim
answer = 'Yes'
example = Timeseries(time=time, measurement=measure, value=value)
# One result per claim, in claim order. 0.8 is passed as the fourth positional
# argument -- presumably a relevance threshold; confirm against is_claim_relevant.
claim_relevances = list(
    map(
        lambda atomic_claim: is_claim_relevant(example, answer, atomic_claim, 0.8),
        raw_atomic_claims,
    )
)
claim_relevances

[False, False, True, False, False, True, True, False, True, True, True, False]

In [29]:
### Distil the atomic claims down to the relevant ones
answer = 'Yes'
example = Timeseries(time=time, measurement=measure, value=value)

# Same inputs as the per-claim relevance check, with 0.8 as the fourth argument.
distilled_claims = distill_relevant_features(example, answer, raw_atomic_claims, 0.8)

# Header first, then one numbered line per distilled claim (numbering from 1).
numbered_distilled = (
    f"{idx}. {text}" for idx, text in enumerate(distilled_claims, start=1)
)
print("Distilled relevant claims:", *numbered_distilled, sep="\n")

Distilled relevant claims:
1. The glucose level is significantly elevated at 336 mg/dL, which can indicate stress or infection.
2. The respiratory rate is 22, which is on the higher side.
3. The pulse oximetry is 92%, indicating possible hypoxia.
4. The white blood cell count is elevated at 18.7, suggesting an inflammatory or infectious process.
5. The temperature is 99.6°F, close to the fever threshold.
6. These factors, combined with the elevated glucose and potential hypoxia, suggest a heightened risk of sepsis.


In [30]:
### Compute alignment scores
# Score the distilled claims. In the recorded output the result is a dict with
# an 'alignment_scores' list of {'claim', 'score'} entries and a 'total_score'.
# NOTE(review): the recorded total (3.5) equals the mean of the six per-claim
# scores -- presumably an average; confirm in calculate_expert_alignment_score.
scores = calculate_expert_alignment_score(distilled_claims)
scores

{'alignment_scores': [{'claim': 'The glucose level is significantly elevated at 336 mg/dL, which can indicate stress or infection.',
   'score': 2},
  {'claim': 'The respiratory rate is 22, which is on the higher side.',
   'score': 3},
  {'claim': 'The pulse oximetry is 92%, indicating possible hypoxia.',
   'score': 5},
  {'claim': 'The white blood cell count is elevated at 18.7, suggesting an inflammatory or infectious process.',
   'score': 3},
  {'claim': 'The temperature is 99.6°F, close to the fever threshold.',
   'score': 5},
  {'claim': 'These factors, combined with the elevated glucose and potential hypoxia, suggest a heightened risk of sepsis.',
   'score': 3}],
 'total_score': 3.5}