In [1]:
import pandas as pd
import numpy as np
from datasets import load_dataset
import openai
from openai import OpenAI
import time
from tqdm import tqdm

import importlib
import sys; sys.path.append("../src")
import sepsis
importlib.reload(sepsis)
from sepsis import SepsisExample, get_llm_generated_answer, isolate_individual_features, distill_relevant_features, calculate_expert_alignment_score, parse_measurement_string, format_time_series_for_prompt, query_openai

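### Load sepsis data

Load the `test` split of `BrachioLab/mcmed-sepsis` and draw a small, seeded random sample of rows for this walkthrough.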

In [2]:
# Size of the walkthrough sample and the seed that makes it reproducible.
N_SAMPLES = 3
SAMPLE_SEED = 11

# Keep each stage under its own name: the raw DatasetDict, the full test split as a
# DataFrame, and the sampled frame `sepsis_data` that all later cells iterate over.
sepsis_dataset = load_dataset("BrachioLab/mcmed-sepsis")
sepsis_test_df = sepsis_dataset['test'].to_pandas()
sepsis_data = (
    sepsis_test_df
    .sample(N_SAMPLES, random_state=SAMPLE_SEED)
    .reset_index(drop=True)
)

### Stage 0: Get LLM Explanations

In [3]:
# Stage 0: for each sampled row, parse the measurement string and ask the LLM for a label
# and a free-text explanation. Rows for which get_llm_generated_answer returns a None label
# are skipped. In that case sepsis_examples is shorter than sepsis_data and its positions no
# longer line up with sepsis_data's index. The skipped row indices are kept in `skipped_rows`
# so the drop is visible rather than silent.
sepsis_examples = []
skipped_rows = []
for row_idx, row in tqdm(sepsis_data.iterrows(), total=len(sepsis_data), desc="Processing Rows"):
    time_series_string = row['data']
    time_series_dict = parse_measurement_string(time_series_string)
    llm_label, explanation = get_llm_generated_answer(time_series_dict)
    if llm_label is None:
        skipped_rows.append(row_idx)
        continue
    sepsis_examples.append(SepsisExample(
        time_series_text=time_series_string,
        time_series_data=time_series_dict,
        ground_truth=row['label'],
        llm_label=llm_label,
        llm_explanation=explanation
    ))

if skipped_rows:
    print(f"Skipped {len(skipped_rows)} of {len(sepsis_data)} rows with no LLM label: {skipped_rows}")

Processing Rows: 100% 3/3 [00:00<00:00, 168.55it/s]


In [4]:
# Display the free-text explanation the LLM returned for the first retained example.
sepsis_examples[0].llm_explanation

'Elevated HR, signs of systemic inflammation, and oxygen saturation decline suggest high sepsis risk.'

### Stage 1: Atomic claim extraction

In [5]:
# Stage 1: break each LLM explanation into atomic claims, trimming surrounding whitespace.
# When isolate_individual_features returns None, the example is left as-is
# (no `claims` value is assigned to it in this pass).
for sepsis_example in sepsis_examples:
    raw_claims = isolate_individual_features(sepsis_example.llm_explanation)
    if raw_claims is not None:
        sepsis_example.claims = [text.strip() for text in raw_claims]

In [6]:
# Display the atomic claims extracted from the first example's explanation in Stage 1.
sepsis_examples[0].claims

['Elevated heart rate suggests high sepsis risk.',
 'Signs of systemic inflammation suggest high sepsis risk.',
 'Oxygen saturation decline suggests high sepsis risk.']

### Stage 2: Distill relevant claims

In [7]:
# Stage 2: store, on each example, whatever distill_relevant_features returns as its
# relevant subset of claims.
for sepsis_example in sepsis_examples:
    sepsis_example.relevant_claims = distill_relevant_features(sepsis_example)

100% 3/3 [00:11<00:00,  3.82s/it]
100% 4/4 [00:11<00:00,  2.81s/it]
100% 5/5 [00:14<00:00,  2.90s/it]


In [8]:
# Display the claims Stage 2 kept as relevant for the first example.
sepsis_examples[0].relevant_claims

['Elevated heart rate suggests high sepsis risk.',
 'Signs of systemic inflammation suggest high sepsis risk.',
 'Oxygen saturation decline suggests high sepsis risk.']

### Stage 3: Calculate alignment scores

In [9]:
# Stage 3: score every relevant claim with calculate_expert_alignment_score and average the
# per-claim scores into example.final_alignment. Claims whose returned category is None are
# skipped. The per-claim reasoning strings are collected in `alignment_reasonings`, one list
# per example in the same order as sepsis_examples. They are kept outside SepsisExample so
# that no new, possibly undeclared, attribute is set on it.
alignment_reasonings = []
for example in sepsis_examples:
    alignment_scores = []
    alignment_categories = []
    reasonings = []
    # NOTE(review): `or []` guards against relevant_claims being None, assumed possible when
    # Stage 2 fails (Stage 1's helper can return None) — confirm.
    for claim in tqdm(example.relevant_claims or []):
        category, alignment_score, reasoning = calculate_expert_alignment_score(claim)
        if category is None:
            continue
        alignment_scores.append(alignment_score)
        alignment_categories.append(category)
        reasonings.append(reasoning)
    example.alignment_scores = alignment_scores
    example.alignment_categories = alignment_categories
    alignment_reasonings.append(reasonings)
    # np.mean([]) also yields nan but emits a "Mean of empty slice" RuntimeWarning;
    # assign nan explicitly when no claim was scored.
    example.final_alignment = np.mean(alignment_scores) if alignment_scores else np.nan

100% 3/3 [00:00<00:00,  3.19it/s]
100% 1/1 [00:04<00:00,  4.57s/it]
100% 2/2 [00:00<00:00, 19.85it/s]


In [10]:
# Display the per-claim alignment scores computed in Stage 3 for the first example.
sepsis_examples[0].alignment_scores

[0.8, 0.8, 0.7]

In [11]:
# Display the category returned for each scored claim, in the same order as alignment_scores.
sepsis_examples[0].alignment_categories

['Presence of\u202f≥\u202f2 SIRS criteria—temperature\u202f>\u202f38\u202f°C or\u202f<\u202f36\u202f°C, heart\u202frate\u202f>\u202f90\u202fbpm, respiratory\u202frate\u202f>\u202f20\u202f/min or PaCO₂\u202f<\u202f32\u202fmm\u202fHg, or WBC\u202f>\u202f12\u202f000/µL or\u202f<\u202f4\u202f000/µL—identifies systemic inflammation consistent with early sepsis.',
 'Presence of\u202f≥\u202f2 SIRS criteria—temperature\u202f>\u202f38\u202f°C or\u202f<\u202f36\u202f°C, heart\u202frate\u202f>\u202f90\u202fbpm, respiratory\u202frate\u202f>\u202f20\u202f/min or PaCO₂\u202f<\u202f32\u202fmm\u202fHg, or WBC\u202f>\u202f12\u202f000/µL or\u202f<\u202f4\u202f000/µL—identifies systemic inflammation consistent with early sepsis.',
 'An increase of\u202f≥\u202f2\u202fpoints in any SOFA component—e.g., PaO₂/FiO₂\u202f<\u202f300, platelets\u202f<\u202f100\u202f×\u202f10⁹/L, bilirubin\u202f>\u202f2\u202fmg/dL, creatinine\u202f>\u202f2\u202fmg/dL, or GCS\u202f<\u202f12—confirms new organ dysfunction and high 