# Assignment 02: Polars on EHR Event Logs

Build a Polars pipeline that summarizes diagnosis prevalence from synthetic EHR events. Use lazy scans, filtering, joins, and group-bys to compute site-level diabetes prevalence.

## Setup


In [19]:
import polars as pl
import yaml
from pathlib import Path
from datetime import datetime
from generate_test_data import generate_test_data

print(f"Polars version: {pl.__version__}")
print("Environment ready!")

Polars version: 1.37.1
Environment ready!


## Configuration


In [20]:
with open("config.yaml") as f:
    config = yaml.safe_load(f)

print("Config loaded:")
print(f"  Patients: {config['data']['patients_path']}")
print(f"  Sites: {config['data']['sites_path']}")
print(f"  Events: {config['data']['events_path']}")
print(f"  ICD-10 lookup: {config['data']['icd10_path']}")

Config loaded:
  Patients: data/patients.parquet
  Sites: data/sites.parquet
  Events: data/events.parquet
  ICD-10 lookup: data/icd10_codes.parquet


## Generate data


In [21]:
SIZE = config["data"]["size"]
DATA_DIR = Path(config["data"]["dir"])

# Create output directory if it doesn't exist
DATA_DIR.mkdir(parents=True, exist_ok=True)

# Generate data - "medium" takes ~10 seconds on my laptop
generate_test_data(size=SIZE, output_dir=DATA_DIR)

INFO Loading codebooks
INFO Generating sites
INFO Generating patients
INFO Generating events
INFO Wrote 100000 patients
INFO Wrote 8 sites
INFO Wrote 7487826 events
INFO Output directory: data


## Hints (optional)

- Distinct patient counts: call `.unique()` before `group_by()`.
  - Example: `events.select(["site_id", "patient_id"]).unique()`
- Prefix filter for ICD-10: `pl.col("code").str.starts_with(prefix)`
- Optional polish: `.fill_null(0)` after a left join, and `.round(3)` on prevalence

## Part 1: Lazy Data Loading

Use `pl.scan_parquet()` to create LazyFrames without loading data into memory.


In [22]:
# TODO: Scan patients, sites, events, and ICD-10 lookup
patients = None
sites = None
events = None
icd10 = None

# Check schemas (fast, still lazy)
if patients is not None:
    print("Patients schema:")
    print(patients.collect_schema())

if events is not None:
    print("Events schema:")
    print(events.collect_schema())

## Part 2: Filter and Prep Events

Filter to the assignment date window and extract ICD-10 diagnosis events.


In [23]:
start_date = datetime.fromisoformat(config["data"]["start_date"])

# TODO: parse event_ts to Datetime
# TODO: filter events to event_ts >= start_date
# TODO: filter to record_type == "ICD-10-CM" for diagnosis events


patients = pl.scan_parquet(config["data"]["patients_path"])
sites    = pl.scan_parquet(config["data"]["sites_path"])
events   = pl.scan_parquet(config["data"]["events_path"])
icd10    = pl.scan_parquet(config["data"]["icd10_path"])

events_filtered = (
    events
    .with_columns(
        pl.col("event_ts").str.strptime(pl.Datetime, strict=False)
    )
    .filter(
        (pl.col("event_ts") >= start_date) &
        (pl.col("record_type") == "ICD-10-CM")
    )
)

dx_events = events_filtered

## Part 3: Diagnosis Prevalence by Site

Compute the percent of patients per site with a type 2 diabetes diagnosis.


In [24]:
prefix = config["data"]["diabetes_prefix"]

# TODO: Filter dx_events to ICD-10 codes starting with prefix
# TODO: total patients per site (unique patient_id from events_filtered)
# TODO: diabetes patients per site (unique patient_id from filtered dx)
# TODO: join counts + site names, calculate prevalence

total_by_site = (
    events_filtered
    .select(["site_id", "patient_id"])
    .unique()
    .group_by("site_id")
    .agg(pl.len().alias("patients_seen"))
)

diabetes_by_site = (
    dx_events
    .filter(pl.col("code").str.starts_with(prefix))
    .select(["site_id", "patient_id"])
    .unique()
    .group_by("site_id")
    .agg(pl.len().alias("diabetes_patients"))
)

dx_summary = (
    total_by_site
    .join(diabetes_by_site, on="site_id", how="left")
    .with_columns(pl.col("diabetes_patients").fill_null(0))
    .with_columns(
        (pl.col("diabetes_patients") / pl.col("patients_seen")).alias("diabetes_prevalence")
    )
    .join(
        sites.select(["site_id", "site_name", "site_type"]),
        on="site_id",
        how="left",
    )
    .select([
        "site_id",
        "site_name",
        "site_type",
        "patients_seen",
        "diabetes_patients",
        "diabetes_prevalence",
    ])
    .sort("site_id")
)

## Part 4: Collect and Export


In [25]:
# TODO: collect dx_summary using streaming engine
# TODO: create outputs directory
# TODO: write Parquet + CSV outputs using config paths

dx_summary_df = dx_summary.collect(streaming=True)

parquet_path = Path(config["outputs"]["dx_summary_parquet"])
csv_path = Path(config["outputs"]["dx_summary_csv"])
parquet_path.parent.mkdir(parents=True, exist_ok=True)

dx_summary_df.write_parquet(parquet_path)
dx_summary_df.write_csv(csv_path)

if dx_summary is not None:
    print("Outputs ready")

  dx_summary_df = dx_summary.collect(streaming=True)


Outputs ready


## Validation


In [26]:
outputs = [
    config["outputs"]["dx_summary_parquet"],
    config["outputs"]["dx_summary_csv"],
]

missing = [path for path in outputs if not Path(path).exists()]
if missing:
    print("Missing outputs:", missing)
else:
    print("All outputs created")

All outputs created


## Next Steps (Optional)

1. Run `python -m pytest .github/tests/test_assignment.py -v` in your terminal.
2. Use exploratory data analysis (EDA) or visualization techniques to get a feel for the dataset
