# Step 2: AF Cohort Construction

This notebook builds the complete AF cohort by:
1. Constructing AF episodes from validated rhythm itemids
2. Linking medication administrations
3. Adding demographics and outcomes
4. Exporting the cohort for analysis

**Prerequisites**: Run `01_rhythm_itemid_discovery.ipynb` first. It writes `../data/validated_rhythm_itemids.csv`, the list of validated rhythm itemids that this notebook reads.

In [None]:
# Import libraries
import sys
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from google.cloud import bigquery

sys.path.append('..')
from config import *

# DataFrame display: every column, at most 100 rows (pattern/value pairs)
pd.set_option('display.max_columns', None, 'display.max_rows', 100)

# BigQuery client bound to the output project configured in config.py
client = bigquery.Client(project=OUTPUT_PROJECT_ID)

# Short status banner, one item per line
print(
    "Environment setup complete",
    f"Using project: {OUTPUT_PROJECT_ID}",
    "MIMIC-IV version: 3.1",
    sep="\n",
)

## Load Validated Rhythm ItemIDs

In [None]:
# Load the validated rhythm itemids written by notebook 01 (Step 1).
#
# Stop here if they are missing or empty. Continuing with an empty list would
# not make the BigQuery script fail (`itemid IN UNNEST([])` simply matches
# nothing), so every downstream cell would silently run on an empty cohort.
VALIDATED_ITEMIDS_PATH = '../data/validated_rhythm_itemids.csv'

try:
    validated_itemids_df = pd.read_csv(VALIDATED_ITEMIDS_PATH)
except FileNotFoundError as err:
    print("ERROR: validated_rhythm_itemids.csv not found.")
    print("Please run notebook 01_rhythm_itemid_discovery.ipynb first.")
    raise FileNotFoundError(
        f"{VALIDATED_ITEMIDS_PATH} not found; run 01_rhythm_itemid_discovery.ipynb first"
    ) from err

# Cast to plain Python ints so the list renders as a valid ARRAY<INT64>
# literal in the cohort query (a float column would render as e.g. 220048.0).
RHYTHM_ITEMIDS = [int(itemid) for itemid in validated_itemids_df['itemid'].dropna()]
if not RHYTHM_ITEMIDS:
    raise ValueError(f"No itemids found in {VALIDATED_ITEMIDS_PATH}; re-run notebook 01.")

print(f"Loaded {len(RHYTHM_ITEMIDS)} validated rhythm itemids:")
print(RHYTHM_ITEMIDS)
print("\nItemID details:")
print(validated_itemids_df[['itemid', 'item_label', 'af_hits', 'af_percentage']])

## Build Complete AF Cohort

This query combines:
- AF episode detection from chartevents
- Medication administrations (eMAR + ICU drips from inputevents)
- Demographics, outcomes, and SOFA score

In [None]:
# Build the comprehensive cohort query.
#
# Table names, medication lists, AF_REGEX and MIN_AF_DURATION_HOURS come from
# config; RHYTHM_ITEMIDS comes from the previous cell. Python list reprs
# (e.g. [1, 2] or ['a', 'b']) are valid BigQuery array literals, which is what
# the DECLARE statements rely on. The result has one row per AF episode,
# keyed by (stay_id, episode_number).
cohort_query = f"""
-- Complete AF Cohort Construction
DECLARE RHYTHM_ITEMIDS ARRAY<INT64> DEFAULT {RHYTHM_ITEMIDS};
DECLARE AA_RX ARRAY<STRING> DEFAULT {ANTIARRHYTHMIC_MEDS};
DECLARE RATE_RX ARRAY<STRING> DEFAULT {RATE_CONTROL_MEDS};

-- Build AF episodes
WITH rhythm_obs AS (
  SELECT
    ce.stay_id,
    ce.charttime,
    LOWER(ce.value) AS value_lc
  FROM `{CHARTEVENTS_TABLE}` ce
  WHERE ce.itemid IN UNNEST(RHYTHM_ITEMIDS)
    AND ce.charttime IS NOT NULL
    AND ce.value IS NOT NULL
),

labeled AS (
  SELECT
    stay_id,
    charttime,
    REGEXP_CONTAINS(value_lc, r'{AF_REGEX}') AS is_af
  FROM rhythm_obs
),

-- One observation per (stay_id, charttime). Several rhythm itemids can be
-- charted at the same instant; LOGICAL_OR makes the label deterministic
-- (AF if any simultaneous observation matches AF_REGEX).
dedup AS (
  SELECT
    stay_id,
    charttime,
    LOGICAL_OR(is_af) AS is_af
  FROM labeled
  GROUP BY stay_id, charttime
),

-- Flag rows where the rhythm label changes. BigQuery does not allow an
-- analytic function nested inside another one, so LAG is evaluated here
-- and the running SUM in the next CTE.
change_points AS (
  SELECT
    stay_id,
    charttime,
    is_af,
    CASE
      WHEN LAG(is_af) OVER w IS NULL OR is_af != LAG(is_af) OVER w THEN 1
      ELSE 0
    END AS starts_segment
  FROM dedup
  WINDOW w AS (PARTITION BY stay_id ORDER BY charttime)
),

-- Running count of change points = segment id within the stay.
segmented AS (
  SELECT
    stay_id,
    charttime,
    is_af,
    SUM(starts_segment) OVER (
      PARTITION BY stay_id
      ORDER BY charttime
      ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
    ) AS seg_id
  FROM change_points
),

-- One row per segment; a segment ends where the next one starts.
segments AS (
  SELECT
    stay_id,
    seg_id,
    ANY_VALUE(is_af) AS is_af,
    MIN(charttime) AS seg_start,
    LEAD(MIN(charttime)) OVER (PARTITION BY stay_id ORDER BY MIN(charttime)) AS seg_end
  FROM segmented
  GROUP BY stay_id, seg_id
),

-- AF segments longer than MIN_AF_DURATION_HOURS. The last segment of a stay
-- has no successor and is closed at ICU outtime.
af_episodes AS (
  SELECT
    s.stay_id,
    s.seg_start AS af_start,
    COALESCE(s.seg_end, i.outtime) AS af_end,
    TIMESTAMP_DIFF(COALESCE(s.seg_end, i.outtime), s.seg_start, MINUTE)/60.0 AS af_hours,
    ROW_NUMBER() OVER (PARTITION BY s.stay_id ORDER BY s.seg_start) AS episode_number
  FROM segments s
  JOIN `{ICUSTAYS_TABLE}` i USING (stay_id)
  WHERE s.is_af = TRUE
    AND TIMESTAMP_DIFF(COALESCE(s.seg_end, i.outtime), s.seg_start, MINUTE)/60.0 > {MIN_AF_DURATION_HOURS}
),

-- Get medications.
-- eMAR administrations. Drug names are lower-cased on both sides so the
-- casing used in config does not matter. The medication name is read from
-- the eMAR table itself; a join to EMAR_DETAIL_TABLE is not needed for the
-- match and could repeat each administration once per detail row.
-- NOTE(review): rows are not filtered on administration status
-- (emar.event_txt); confirm whether non-administered events should count.
emar AS (
  SELECT
    e.subject_id,
    e.hadm_id,
    e.charttime AS admin_time,
    CASE
      WHEN EXISTS (SELECT 1 FROM UNNEST(AA_RX) g WHERE LOWER(e.medication) LIKE CONCAT('%', LOWER(g), '%'))
      THEN 'AA'
      WHEN EXISTS (SELECT 1 FROM UNNEST(RATE_RX) g WHERE LOWER(e.medication) LIKE CONCAT('%', LOWER(g), '%'))
      THEN 'RATE'
    END AS drug_class
  FROM `{EMAR_TABLE}` e
  WHERE e.charttime IS NOT NULL
    AND (
      EXISTS (SELECT 1 FROM UNNEST(AA_RX) g WHERE LOWER(e.medication) LIKE CONCAT('%', LOWER(g), '%'))
      OR EXISTS (SELECT 1 FROM UNNEST(RATE_RX) g WHERE LOWER(e.medication) LIKE CONCAT('%', LOWER(g), '%'))
    )
),

-- ICU infusions. LIKE is case-sensitive, so the item label is lower-cased
-- exactly like the eMAR medication name above.
icu_drips AS (
  SELECT
    ie.stay_id,
    ie.starttime AS t_start,
    ie.endtime AS t_end,
    CASE
      WHEN EXISTS (SELECT 1 FROM UNNEST(AA_RX) g WHERE LOWER(di.label) LIKE CONCAT('%', LOWER(g), '%'))
      THEN 'AA'
      ELSE 'RATE'
    END AS drug_class
  FROM `{INPUTEVENTS_TABLE}` ie
  JOIN `{D_ITEMS_TABLE}` di ON ie.itemid = di.itemid
  WHERE (
    EXISTS (SELECT 1 FROM UNNEST(AA_RX) g WHERE LOWER(di.label) LIKE CONCAT('%', LOWER(g), '%'))
    OR EXISTS (SELECT 1 FROM UNNEST(RATE_RX) g WHERE LOWER(di.label) LIKE CONCAT('%', LOWER(g), '%'))
  )
),

-- Match medications to AF episodes. eMAR and drips are aggregated per
-- episode separately so the two sources never cross-multiply.
emar_by_episode AS (
  SELECT
    af.stay_id,
    af.episode_number,
    COUNTIF(em.drug_class = 'AA') > 0 AS received_aa_emar,
    COUNTIF(em.drug_class = 'RATE') > 0 AS received_rate_emar
  FROM af_episodes af
  JOIN `{ICUSTAYS_TABLE}` icu ON af.stay_id = icu.stay_id
  JOIN emar em
    ON icu.subject_id = em.subject_id
    AND icu.hadm_id = em.hadm_id
    AND em.admin_time BETWEEN af.af_start AND af.af_end
  GROUP BY af.stay_id, af.episode_number
),

-- A drip counts when its [t_start, t_end] interval overlaps the episode.
drips_by_episode AS (
  SELECT
    af.stay_id,
    af.episode_number,
    COUNTIF(dr.drug_class = 'AA') > 0 AS received_aa_drip,
    COUNTIF(dr.drug_class = 'RATE') > 0 AS received_rate_drip
  FROM af_episodes af
  JOIN icu_drips dr
    ON af.stay_id = dr.stay_id
    AND dr.t_start <= af.af_end
    AND COALESCE(dr.t_end, dr.t_start) >= af.af_start
  GROUP BY af.stay_id, af.episode_number
),

af_with_meds AS (
  SELECT
    af.stay_id,
    af.af_start,
    af.af_end,
    af.af_hours,
    af.episode_number,
    COALESCE(em.received_aa_emar, FALSE) AS received_aa_emar,
    COALESCE(em.received_rate_emar, FALSE) AS received_rate_emar,
    COALESCE(dr.received_aa_drip, FALSE) AS received_aa_drip,
    COALESCE(dr.received_rate_drip, FALSE) AS received_rate_drip
  FROM af_episodes af
  LEFT JOIN emar_by_episode em
    ON af.stay_id = em.stay_id AND af.episode_number = em.episode_number
  LEFT JOIN drips_by_episode dr
    ON af.stay_id = dr.stay_id AND af.episode_number = dr.episode_number
),

-- SOFA_TABLE may hold several rows per stay (e.g. an hourly SOFA table);
-- collapse to one row per stay so the final join cannot duplicate episodes.
-- With a single row per stay, MAX returns that row's value unchanged.
-- NOTE(review): with several rows this is the peak SOFA over the stay;
-- confirm whether SOFA at AF onset is the intended covariate.
sofa_by_stay AS (
  SELECT
    stay_id,
    MAX(sofa_24hours) AS sofa_24hours
  FROM `{SOFA_TABLE}`
  GROUP BY stay_id
)

-- Final cohort with demographics and outcomes
SELECT
  p.subject_id,
  ad.hadm_id,
  icu.stay_id,
  afm.af_start,
  afm.af_end,
  afm.af_hours,
  afm.episode_number,
  afm.episode_number = 1 AS is_first_episode,
  afm.received_aa_emar OR afm.received_aa_drip AS received_antiarrhythmic,
  afm.received_rate_emar OR afm.received_rate_drip AS received_rate_control,
  age.age,
  p.gender,
  ad.race,
  ad.admission_type,
  icu.first_careunit,
  icu.intime AS icu_intime,
  icu.outtime AS icu_outtime,
  TIMESTAMP_DIFF(icu.outtime, icu.intime, HOUR)/24.0 AS icu_los_days,
  ad.hospital_expire_flag,
  TIMESTAMP_DIFF(ad.dischtime, ad.admittime, HOUR)/24.0 AS hosp_los_days,
  CASE
    WHEN p.dod IS NOT NULL AND ad.dischtime IS NOT NULL
      AND TIMESTAMP_DIFF(p.dod, ad.dischtime, DAY) <= 30
    THEN TRUE
    ELSE FALSE
  END AS mortality_30day,
  sofa.sofa_24hours
FROM af_with_meds afm
JOIN `{ICUSTAYS_TABLE}` icu ON afm.stay_id = icu.stay_id
JOIN `{ADMISSIONS_TABLE}` ad ON icu.hadm_id = ad.hadm_id
JOIN `{PATIENTS_TABLE}` p ON ad.subject_id = p.subject_id
LEFT JOIN `{AGE_TABLE}` age ON p.subject_id = age.subject_id AND ad.hadm_id = age.hadm_id
LEFT JOIN sofa_by_stay sofa ON afm.stay_id = sofa.stay_id
ORDER BY subject_id, hadm_id, stay_id, episode_number;
"""

print("Executing cohort construction query...")
print("This will scan chartevents and may take several minutes...\n")

In [None]:
# Execute the query (it scans chartevents, so this can take several minutes)
af_cohort = client.query(cohort_query).to_dataframe()

# Grain check: the cohort must hold exactly one row per AF episode. A join
# that fans out (e.g. a table with several rows per stay) would break this
# and silently inflate every count below.
n_duplicate_episodes = int(af_cohort.duplicated(['stay_id', 'episode_number']).sum())
assert n_duplicate_episodes == 0, (
    f"{n_duplicate_episodes} duplicate (stay_id, episode_number) rows in af_cohort; "
    "check the joins in cohort_query"
)

print("\nCohort construction complete!")
print(f"Total AF episodes: {len(af_cohort):,}")
print(f"Unique ICU stays with AF: {af_cohort['stay_id'].nunique():,}")
print(f"Unique hospital admissions: {af_cohort['hadm_id'].nunique():,}")
print(f"Unique patients: {af_cohort['subject_id'].nunique():,}")

## Preview the Cohort

In [None]:
# Peek at the first ten AF-episode rows (rendered as a rich table)
af_cohort.iloc[:10]

In [None]:
# Data types and missing values.
# DataFrame.info() writes its report itself and returns None, so it is called
# directly; wrapping it in print() added a stray "None" line to the output.
print("\nData summary:")
af_cohort.info()
print("\nMissing values:")
print(af_cohort.isnull().sum())

## Quick Summary Statistics

In [None]:
print("AF EPISODE CHARACTERISTICS")
print("=" * 50)
# episode_number restarts at 1 in every ICU stay, so its maximum counts
# episodes per stay, not per patient. Counting rows per subject counts the
# episodes across all of that patient's stays.
episodes_per_patient = af_cohort.groupby('subject_id').size()
print(f"Episodes per patient (median): {episodes_per_patient.median():.0f}")
print(f"AF duration (median hours): {af_cohort['af_hours'].median():.1f}")
print(f"AF duration (mean hours): {af_cohort['af_hours'].mean():.1f}")
# is_first_episode marks the first qualifying episode of each ICU stay
print(f"\nFirst episodes only: {af_cohort['is_first_episode'].sum():,}")

print("\n\nMEDICATION USAGE")
print("=" * 50)
print(f"Episodes with antiarrhythmic: {af_cohort['received_antiarrhythmic'].sum():,} ({af_cohort['received_antiarrhythmic'].mean()*100:.1f}%)")
print(f"Episodes with rate control: {af_cohort['received_rate_control'].sum():,} ({af_cohort['received_rate_control'].mean()*100:.1f}%)")

print("\n\nDEMOGRAPHICS (unique patients)")
print("=" * 50)
# One row per patient: each subject's chronologically first AF episode
# (earliest ICU admission, then earliest AF onset). The sort is explicit
# because hadm_id/stay_id order is not guaranteed to be chronological.
# Demographics and outcomes below therefore describe that first admission.
# patient_level is also written to disk by the save cell.
patient_level = (
    af_cohort
    .sort_values(['subject_id', 'icu_intime', 'af_start', 'stay_id', 'episode_number'])
    .drop_duplicates('subject_id')
)
print(f"Median age: {patient_level['age'].median():.0f} years")
print(f"\nGender distribution:")
print(patient_level['gender'].value_counts())

print("\n\nOUTCOMES (unique patients)")
print("=" * 50)
print(f"Hospital mortality: {patient_level['hospital_expire_flag'].sum():,} ({patient_level['hospital_expire_flag'].mean()*100:.1f}%)")
print(f"30-day mortality: {patient_level['mortality_30day'].sum():,} ({patient_level['mortality_30day'].mean()*100:.1f}%)")
print(f"Median ICU LOS: {patient_level['icu_los_days'].median():.1f} days")
print(f"Median hospital LOS: {patient_level['hosp_los_days'].median():.1f} days")

## Save the Cohort

In [None]:
# Write the three cohort extracts under ../COHORT_OUTPUT_PATH (from config).
# DataFrame.to_csv does not create missing directories, so create it first.
Path(f"../{COHORT_OUTPUT_PATH}").mkdir(parents=True, exist_ok=True)

# Complete cohort: one row per AF episode
output_file = f"../{COHORT_OUTPUT_PATH}/af_cohort_complete.csv"
af_cohort.to_csv(output_file, index=False)
print(f"Complete cohort saved to: {output_file}")

# First AF episode of each ICU stay
first_episodes = af_cohort[af_cohort['is_first_episode'] == True]
output_file_first = f"../{COHORT_OUTPUT_PATH}/af_cohort_first_episodes.csv"
first_episodes.to_csv(output_file_first, index=False)
print(f"First episodes only saved to: {output_file_first}")

# Patient-level summary (one row per patient). patient_level is built in the
# summary statistics cell, which must be run before this one.
output_file_patient = f"../{COHORT_OUTPUT_PATH}/af_cohort_patient_level.csv"
patient_level.to_csv(output_file_patient, index=False)
print(f"Patient-level data saved to: {output_file_patient}")

## Next Steps

The cohort has been constructed and saved as three CSV files: all AF episodes, the first episode of each ICU stay, and one row per patient. Proceed to:
- `03_statistical_analysis.ipynb` for detailed statistical summaries
- `04_visualization_dashboard.ipynb` for charts and visualizations