# Project 5: ICU Summary Generator - Automated Clinical Documentation

**Objective**: Automatically generate comprehensive ICU patient summaries from EHR-style data (this notebook uses synthetic, randomly generated patient records)

**Tech Stack**: PySpark, spaCy, Medical NLP

## Cell 1: Environment Setup

In [0]:
%pip install spacy
%pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.1/en_core_sci_md-0.5.1.tar.gz

# NOTE(review): the model archive is pinned to v0.5.1 while spacy itself is
# unpinned — consider pinning a spaCy version known to work with this model.
# Restart the Python process after the installs above; dbutils is not defined
# in this notebook, so this presumably relies on a Databricks runtime — TODO confirm.
dbutils.library.restartPython()

## Cell 2: Import Libraries

In [0]:
from pyspark.sql import Window
from pyspark.sql.functions import (
    col, lit, concat_ws, when, collect_list, struct,
    max as spark_max, min as spark_min, avg, count as spark_count,
    row_number, explode, udf
)
from pyspark.sql.types import StringType, ArrayType, StructType, StructField

import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import random
import spacy

print("✅ Libraries imported")

## Cell 3: Load Medical NLP Model

In [0]:
# Load the "en_core_sci_md" pipeline (the model archive installed in Cell 1)
# and report which pipeline components it contains.
MODEL_NAME = "en_core_sci_md"
nlp = spacy.load(MODEL_NAME)

print("✅ Loaded spaCy medical NLP model")
print(f"   Pipeline: {nlp.pipe_names}")

## Cell 4: Generate ICU Patient Data

In [0]:
# Seed both RNGs so every run produces the same synthetic cohort.
random.seed(42)
np.random.seed(42)

# Build 30 synthetic ICU admissions. The random draws below happen in a fixed
# order (admission offset, age, gender, diagnosis, length of stay); keep that
# order, otherwise the seeded data changes.
diagnosis_options = [
    'Septic Shock',
    'Acute Respiratory Failure',
    'Post-Cardiac Surgery',
    'Acute Myocardial Infarction',
    'Traumatic Brain Injury'
]
cohort_start = datetime(2024, 1, 1)

patient_data = []
for patient_number in range(1, 31):
    admitted_at = cohort_start + timedelta(days=random.randint(0, 200))
    age_years = random.randint(40, 85)
    gender_code = random.choice(['M', 'F'])
    admitting_dx = random.choice(diagnosis_options)
    los_hours = random.randint(48, 240)

    patient_data.append({
        'patient_id': f'ICU_{patient_number:03d}',
        'age': age_years,
        'gender': gender_code,
        'admission_date': admitted_at.isoformat(),
        'admission_diagnosis': admitting_dx,
        'icu_los_hours': los_hours
    })

patients_df = spark.createDataFrame(patient_data)

print(f"✅ Generated {patients_df.count()} ICU patients")
patients_df.show(5)

## Cell 5: Generate Vital Signs Data

In [0]:
# One synthetic vital-sign reading every 4 hours of each patient's ICU stay.
# The dict fields are evaluated top to bottom, so the order of random draws
# (HR, SBP, DBP, RR, temperature, SpO2) is fixed per reading.
vital_signs_data = []

for patient in patient_data:
    stay_start = datetime.fromisoformat(patient['admission_date'])
    reading_count = patient['icu_los_hours'] // 4  # Every 4 hours

    vital_signs_data.extend(
        {
            'patient_id': patient['patient_id'],
            'timestamp': (stay_start + timedelta(hours=reading_index * 4)).isoformat(),
            'heart_rate': random.randint(60, 140),
            'sbp': random.randint(90, 180),
            'dbp': random.randint(50, 110),
            'respiratory_rate': random.randint(12, 30),
            'temperature': round(random.uniform(36.0, 39.5), 1),
            'spo2': random.randint(88, 100)
        }
        for reading_index in range(reading_count)
    )

vitals_df = spark.createDataFrame(vital_signs_data)

print(f"✅ Generated {vitals_df.count()} vital sign readings")

## Cell 6: Generate Lab Results

In [0]:
# One synthetic lab panel per ICU day (at least one panel per patient).
# Values are drawn in the order the dict fields are listed.
lab_data = []

for patient in patient_data:
    stay_start = datetime.fromisoformat(patient['admission_date'])
    panel_count = max(1, patient['icu_los_hours'] // 24)  # Daily labs

    lab_data.extend(
        {
            'patient_id': patient['patient_id'],
            'lab_date': (stay_start + timedelta(days=day_offset)).date().isoformat(),
            'wbc': round(random.uniform(4.0, 20.0), 1),
            'hemoglobin': round(random.uniform(8.0, 16.0), 1),
            'platelet': random.randint(100, 400),
            'sodium': random.randint(130, 145),
            'potassium': round(random.uniform(3.0, 5.5), 1),
            'creatinine': round(random.uniform(0.5, 3.0), 2),
            'lactate': round(random.uniform(0.5, 4.0), 1)
        }
        for day_offset in range(panel_count)
    )

labs_df = spark.createDataFrame(lab_data)

print(f"✅ Generated {labs_df.count()} lab result sets")

## Cell 7: Generate Medications

In [0]:
# Assign each patient 3-7 distinct medications, each with a random route and
# frequency (route is drawn before frequency for every order).
medication_data = []
med_options = [
    'Norepinephrine', 'Propofol', 'Fentanyl', 'Vancomycin',
    'Piperacillin-Tazobactam', 'Insulin', 'Heparin',
    'Furosemide', 'Metoprolol'
]
route_options = ['IV', 'PO', 'IV infusion']
frequency_options = ['q4h', 'q6h', 'continuous', 'daily']

for patient in patient_data:
    prescribed = random.sample(med_options, random.randint(3, 7))

    medication_data.extend(
        {
            'patient_id': patient['patient_id'],
            'medication': drug_name,
            'route': random.choice(route_options),
            'frequency': random.choice(frequency_options)
        }
        for drug_name in prescribed
    )

meds_df = spark.createDataFrame(medication_data)

print(f"✅ Generated {meds_df.count()} medication orders")

## Cell 8: Generate Clinical Notes

In [0]:
# Templated progress notes: one note per day, starting on the admission date.
notes_data = []

note_templates = {
    'Septic Shock': [
        'Patient admitted with septic shock. Started on broad-spectrum antibiotics and vasopressors.',
        'Hemodynamically improving. Lactate trending down. Weaning vasopressors.',
        'Stable off vasopressors. Antibiotics narrowed based on cultures.'
    ],
    'Acute Respiratory Failure': [
        'Intubated for hypoxemic respiratory failure. Lung-protective ventilation initiated.',
        'PEEP trial performed. Oxygenation improving.',
        'Successful spontaneous breathing trial. Extubated to nasal cannula.'
    ],
    'Post-Cardiac Surgery': [
        'Post-op day 1 s/p CABG. Hemodynamically stable. Chest tubes draining.',
        'Chest tubes removed. Progressing with physical therapy.',
        'Ready for transfer to step-down unit.'
    ]
}

# Diagnoses without a dedicated template fall back to these generic notes.
fallback_notes = ['Patient in ICU.', 'Stable.', 'Improving.']

for patient in patient_data:
    stay_start = datetime.fromisoformat(patient['admission_date'])
    daily_notes = note_templates.get(patient['admission_diagnosis'], fallback_notes)

    notes_data.extend(
        {
            'patient_id': patient['patient_id'],
            'note_date': (stay_start + timedelta(days=day_offset)).date().isoformat(),
            'note_text': note_body
        }
        for day_offset, note_body in enumerate(daily_notes)
    )

notes_df = spark.createDataFrame(notes_data)

print(f"✅ Generated {notes_df.count()} clinical notes")

## Cell 9: Analyze Vital Signs

In [0]:
# Per-patient vital sign statistics over the whole stay.
vital_aggregates = [
    avg("heart_rate").alias("avg_hr"),
    spark_max("heart_rate").alias("max_hr"),
    spark_min("heart_rate").alias("min_hr"),
    avg("sbp").alias("avg_sbp"),
    spark_max("sbp").alias("max_sbp"),
    spark_min("sbp").alias("min_sbp"),
    avg("respiratory_rate").alias("avg_rr"),
    spark_max("temperature").alias("max_temp"),
    spark_min("spo2").alias("min_spo2"),
]

# Each flag evaluates to its label when the threshold is crossed and to NULL
# otherwise; concat_ws joins only the labels that fired.
vital_flags = [
    when(col("max_hr") > 120, "Tachycardia"),
    when(col("min_sbp") < 90, "Hypotension"),
    when(col("max_temp") > 38.3, "Fever"),
    when(col("min_spo2") < 92, "Hypoxemia"),
]

vitals_stats = (
    vitals_df
    .groupBy("patient_id")
    .agg(*vital_aggregates)
    .withColumn("abnormalities", concat_ws("; ", *vital_flags))
)

print("📊 Vital Signs Summary:")
vitals_stats.show(5, truncate=False)

## Cell 10: Analyze Lab Results

In [0]:
# Rank each patient's lab panels newest-first; rank 1 is the latest panel.
window = Window.partitionBy("patient_id").orderBy(col("lab_date").desc())

# Label for each out-of-range lab value (NULL when the value is in range).
lab_flags = [
    when(col("wbc") > 12.0, "Leukocytosis"),
    when(col("hemoglobin") < 10.0, "Anemia"),
    when(col("creatinine") > 1.3, "Elevated creatinine"),
    when(col("lactate") > 2.0, "Elevated lactate"),
]

latest_labs = (
    labs_df
    .withColumn("row_num", row_number().over(window))
    .where(col("row_num") == 1)
    .drop("row_num")
    .withColumn("lab_abnormalities", concat_ws("; ", *lab_flags))
)

print("🔬 Latest Lab Results:")
latest_labs.select("patient_id", "wbc", "creatinine", "lactate", "lab_abnormalities").show(5, truncate=False)

## Cell 11: Summarize Medications

In [0]:
# Collect every medication order of a patient into one list of
# (medication, route, frequency) structs.
med_order = struct(col("medication"), col("route"), col("frequency"))

meds_summary = (
    meds_df
    .groupBy("patient_id")
    .agg(collect_list(med_order).alias("med_list"))
)

# Format medications
def format_meds(med_list):
    """Render medication entries as a single '; '-separated string.

    Each entry must be indexable by the keys 'medication', 'route' and
    'frequency'; it is rendered as "<medication> (<route>, <frequency>)".

    Returns the literal string "None" when the list is missing or empty.
    """
    if med_list:
        rendered = []
        for entry in med_list:
            rendered.append(
                f"{entry['medication']} ({entry['route']}, {entry['frequency']})"
            )
        return "; ".join(rendered)
    return "None"

format_meds_udf = udf(format_meds, StringType())

# Intervention flags are derived by substring match on the formatted
# medication text.
med_text = col("medications")
mentions_vasopressor = med_text.contains("Norepinephrine")
mentions_antibiotic = med_text.contains("Vancomycin") | med_text.contains("Piperacillin")

meds_summary = (
    meds_summary
    .withColumn("medications", format_meds_udf(col("med_list")))
    .withColumn("on_vasopressors", when(mentions_vasopressor, True).otherwise(False))
    .withColumn("on_antibiotics", when(mentions_antibiotic, True).otherwise(False))
)

print("💊 Medication Summary:")
meds_summary.select("patient_id", "on_vasopressors", "on_antibiotics").show(5)

## Cell 12: Aggregate Clinical Notes

In [0]:
# Collect each patient's notes into one list of (note_date, note_text) structs.
note_entry = struct(col("note_date"), col("note_text"))

notes_summary = (
    notes_df
    .groupBy("patient_id")
    .agg(collect_list(note_entry).alias("notes_list"))
)

# Create clinical course summary
def summarize_course(notes_list):
    """Condense a patient's notes into one chronological line.

    Entries must be indexable by 'note_date' and 'note_text'. They are sorted
    by 'note_date' and numbered "Day 1", "Day 2", ... by position in that
    order, then joined with " | ".

    Returns "No notes" when the list is missing or empty.
    """
    if not notes_list:
        return "No notes"

    chronological = list(notes_list)
    chronological.sort(key=lambda entry: entry['note_date'])

    day_entries = []
    for day_number, entry in enumerate(chronological, start=1):
        day_entries.append(f"Day {day_number}: {entry['note_text']}")
    return " | ".join(day_entries)

summarize_udf = udf(summarize_course, StringType())

# Turn each patient's collected notes into a single clinical-course string.
course_text = summarize_udf(col("notes_list"))
notes_summary = notes_summary.withColumn("clinical_course", course_text)

print("📋 Clinical Course:")
notes_summary.select("patient_id", "clinical_course").show(3, truncate=100)

## Cell 13: Generate Complete Summaries

In [0]:
# Attach every per-patient component to the patient table. Left joins keep a
# patient even when one of the components has no row for them.
lab_columns = latest_labs.select(
    "patient_id", "wbc", "hemoglobin", "creatinine", "lactate", "lab_abnormalities"
)
med_columns = meds_summary.select(
    "patient_id", "medications", "on_vasopressors", "on_antibiotics"
)
course_columns = notes_summary.select("patient_id", "clinical_course")

icu_summary = (
    patients_df
    .join(vitals_stats, "patient_id", "left")
    .join(lab_columns, "patient_id", "left")
    .join(med_columns, "patient_id", "left")
    .join(course_columns, "patient_id", "left")
)

# Generate summary text
def generate_summary(row):
    """Render one joined patient record as a plain-text ICU summary.

    Args:
        row: Record indexable by column name (e.g. the struct passed to the
            UDF, or a dict) with the patient, vitals, lab, medication and
            clinical-course columns of ``icu_summary``.

    Returns:
        str: The summary text, from the "=== ICU PATIENT SUMMARY ===" header
        to the "=== END OF SUMMARY ===" footer.

    The components are LEFT-joined onto the patient table, so any vitals, lab,
    medication or notes column may be None. Missing values are rendered as
    "N/A" instead of raising TypeError inside the UDF (which would fail the
    whole Spark job). For fully populated rows the output is unchanged.
    """
    def _fmt(value, spec=""):
        """Format ``value`` with ``spec``; missing values become 'N/A'."""
        return "N/A" if value is None else format(value, spec)

    admitted = row['admission_date']
    admitted_day = "N/A" if admitted is None else admitted[:10]

    los_hours = row['icu_los_hours']
    los_days = None if los_hours is None else los_hours // 24

    # A missing medication list keeps the same "None" text format_meds uses;
    # a missing clinical course uses summarize_course's "No notes".
    medications = "None" if row['medications'] is None else row['medications']
    clinical_course = "No notes" if row['clinical_course'] is None else row['clinical_course']

    lines = [
        "=== ICU PATIENT SUMMARY ===",
        f"Patient: {row['patient_id']} | Age: {_fmt(row['age'])}y {_fmt(row['gender'])}",
        f"Admitted: {admitted_day}",
        f"ICU Length of Stay: {_fmt(los_hours)} hours ({_fmt(los_days)} days)",
        "",
        "ADMISSION DIAGNOSIS:",
        _fmt(row['admission_diagnosis']),
        "",
        "VITAL SIGNS:",
        f"- Heart Rate: {_fmt(row['avg_hr'], '.0f')} bpm "
        f"(range: {_fmt(row['min_hr'])}-{_fmt(row['max_hr'])})",
        f"- Blood Pressure: {_fmt(row['avg_sbp'], '.0f')} mmHg "
        f"(range: {_fmt(row['min_sbp'])}-{_fmt(row['max_sbp'])})",
        f"- SpO2: Minimum {_fmt(row['min_spo2'])}%",
    ]

    # Empty string (no flags fired) and None (no vitals row) are both skipped.
    if row['abnormalities']:
        lines.append(f"- Abnormalities: {row['abnormalities']}")

    lines += [
        "",
        "LABORATORY RESULTS (Latest):",
        f"- WBC: {_fmt(row['wbc'], '.1f')} K/µL",
        f"- Hemoglobin: {_fmt(row['hemoglobin'], '.1f')} g/dL",
        f"- Creatinine: {_fmt(row['creatinine'], '.2f')} mg/dL",
        f"- Lactate: {_fmt(row['lactate'], '.1f')} mmol/L",
    ]

    if row['lab_abnormalities']:
        lines.append(f"- Abnormalities: {row['lab_abnormalities']}")

    lines += [
        "",
        "MEDICATIONS:",
        f"{medications}",
        "",
        "CRITICAL INTERVENTIONS:",
        f"- Vasopressor Support: {'Yes' if row['on_vasopressors'] else 'No'}",
        f"- Antibiotics: {'Yes' if row['on_antibiotics'] else 'No'}",
        "",
        "CLINICAL COURSE:",
        f"{clinical_course}",
        "",
        "=== END OF SUMMARY ===",
    ]
    return "\n".join(lines)

generate_summary_udf = udf(generate_summary, StringType())

# Pack every joined column into one struct so the UDF sees the whole record,
# then store the rendered text in a new "summary" column.
full_record = struct([col(column_name) for column_name in icu_summary.columns])
icu_summary = icu_summary.withColumn("summary", generate_summary_udf(full_record))

print(f"✅ Generated {icu_summary.count()} ICU summaries")

## Cell 14: Display Sample Summaries

In [0]:
# Pull three summaries to the driver and print each under a numbered banner.
samples = icu_summary.select("patient_id", "summary").limit(3).collect()

banner = '=' * 80
for sample_number, sample_row in enumerate(samples, start=1):
    print(f"\n{banner}")
    print(f"SAMPLE {sample_number}")
    print(banner)
    print(sample_row['summary'])

## Cell 15: Summary Statistics

In [0]:
# Calculate statistics
#
# The abnormality columns are built with concat_ws, which skips NULL inputs
# and yields an empty string (not NULL) when no flag fired. Testing
# isNotNull() therefore counted every patient as abnormal; a patient only has
# abnormalities when the flag string is non-empty. (A NULL value — no matching
# vitals/labs row after the left join — also fails the comparison and is not
# counted.)
has_vital_abnormality = col("abnormalities").isNotNull() & (col("abnormalities") != "")
has_lab_abnormality = col("lab_abnormalities").isNotNull() & (col("lab_abnormalities") != "")

# count() ignores NULLs, so count(when(cond, 1)) counts the rows where cond holds.
stats = icu_summary.select(
    spark_count("patient_id").alias("total_patients"),
    avg("icu_los_hours").alias("avg_los_hours"),
    spark_count(when(col("on_vasopressors") == True, 1)).alias("patients_on_vasopressors"),
    spark_count(when(col("on_antibiotics") == True, 1)).alias("patients_on_antibiotics"),
    spark_count(when(has_vital_abnormality, 1)).alias("patients_with_vital_abnormalities"),
    spark_count(when(has_lab_abnormality, 1)).alias("patients_with_lab_abnormalities")
).collect()[0]

print("\n" + "="*80)
print("ICU SUMMARY GENERATOR - PROJECT SUMMARY")
print("="*80)

print(f"\n📊 Dataset Statistics:")
print(f"   Total Patients: {stats['total_patients']}")
print(f"   Vital Signs Readings: {vitals_df.count()}")
print(f"   Lab Results: {labs_df.count()}")
print(f"   Medications: {meds_df.count()}")
print(f"   Clinical Notes: {notes_df.count()}")

print(f"\n📋 Summary Generation:")
print(f"   Average ICU LOS: {stats['avg_los_hours']:.1f} hours ({stats['avg_los_hours']/24:.1f} days)")
print(f"   Patients on Vasopressors: {stats['patients_on_vasopressors']}")
print(f"   Patients on Antibiotics: {stats['patients_on_antibiotics']}")
print(f"   Patients with Vital Abnormalities: {stats['patients_with_vital_abnormalities']}")
print(f"   Patients with Lab Abnormalities: {stats['patients_with_lab_abnormalities']}")

print("\n" + "="*80)
print("✅ Project 5 Complete - ICU Summaries Generated")
print("="*80)