<a href="https://colab.research.google.com/github/Medical-Event-Data-Standard/MEDS_ML4H_2025_Tutorial/blob/main/ACES_visualizer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Source: https://gist.githubusercontent.com/Oufattole/d78fe58b0bc67ba6286e64a42c1161fc/raw/0df6184159ef33ab0b32efd703b571867e20f569/gistfile1.txt

import plotly.figure_factory as ff
import pandas as pd
import yaml
import plotly.express as px
from pathlib import Path


def parse_duration(expr):
    """Parse a duration expression into a ``pd.Timedelta``.

    Accepted forms:
      * plain durations understood by ``pd.Timedelta`` such as ``'180d'`` or ``'24h'``
      * the literal ``'trigger'``, which maps to a zero-length duration
      * offset expressions such as ``'start + 180d'``, where the text after the
        first ``+`` is parsed as the duration

    Args:
        expr: The expression to parse. Non-string values are not parsed.

    Returns:
        The parsed ``pd.Timedelta``. Falls back to 24 hours when ``expr`` is not
        a string or cannot be parsed.
    """
    default = pd.Timedelta('24h')

    if not isinstance(expr, str):
        return default

    if expr == 'trigger':
        return pd.Timedelta(0)

    # "start + 180d" style expressions: keep only the part after the first '+'.
    duration_text = expr.split('+')[1].strip() if '+' in expr else expr

    try:
        return pd.Timedelta(duration_text)
    except (ValueError, TypeError, OverflowError):
        # Only swallow parse failures; a bare `except:` would also hide
        # KeyboardInterrupt / SystemExit.
        return default

def get_window_times(window_config, trigger_time, input_default_lookback):
    """Resolve a window's ``start``/``end`` settings to concrete timestamps.

    Args:
        window_config: Mapping for one window; only its ``'start'`` and ``'end'``
            entries are read.
        trigger_time: Timestamp of the trigger event that offsets are applied to.
        input_default_lookback: Duration (anything ``pd.Timedelta`` accepts) to
            look back from the trigger when the window has no start.

    Returns:
        A ``(start_time, end_time)`` tuple.

    Notes:
        Several cases use fixed placeholder lengths rather than values taken
        from the config: a start mentioning ``gap.end`` is drawn 48h after the
        trigger, an event-bounded end (``->``) is drawn 72h after the start, and
        any end that is not recognised is drawn 24h after the trigger.
    """
    start_spec = window_config.get('start')
    end_spec = window_config.get('end')

    # --- start ---
    if start_spec is None or start_spec == 'null':
        # Open-ended start: look back a fixed amount from the trigger.
        window_start = trigger_time - pd.Timedelta(input_default_lookback)
    elif start_spec != 'trigger' and 'gap.end' in str(start_spec):
        window_start = trigger_time + pd.Timedelta('48h')
    else:
        # 'trigger' and every unrecognised start both begin at the trigger.
        window_start = trigger_time

    # --- end ---
    if not isinstance(end_spec, str):
        return window_start, trigger_time + pd.Timedelta('24h')
    if end_spec == 'trigger':
        return window_start, trigger_time

    if 'trigger + ' in end_spec:
        anchor, offset = trigger_time, pd.Timedelta(end_spec.split('+ ')[1])
    elif 'start + ' in end_spec:
        anchor, offset = window_start, pd.Timedelta(end_spec.split('+ ')[1])
    elif '->' in end_spec:
        anchor, offset = window_start, pd.Timedelta('72h')
    else:
        anchor, offset = trigger_time, pd.Timedelta('24h')

    return window_start, anchor + offset

def format_requirements(window_config):
    """Format a window's requirements into readable, ``<br>``-separated text.

    Args:
        window_config: Mapping for one window. Its optional ``'has'`` entry maps
            predicate names to count ranges, and its optional ``'label'`` entry
            names the label predicate.

    Returns:
        One line per ``has`` entry followed by a ``Label: ...`` line when a
        label is present, joined with ``<br>``. A range that renders as
        ``None, 0`` is shown as ``No <predicate>``. Returns
        ``"No requirements"`` when there is nothing to list.
    """
    requirements = []

    # A `has:` key with no entries parses to None in YAML; treat it like an
    # empty mapping instead of failing on `.items()`.
    has_constraints = window_config.get('has') or {}
    for predicate, count in has_constraints.items():
        # Strip the parentheses so "(None, 0)" reads as "None, 0".
        count_str = f"{count}".replace("(", "").replace(")", "")
        if count_str == "None, 0":
            requirements.append(f"No {predicate}")
        else:
            requirements.append(f"{predicate}: {count_str}")

    if 'label' in window_config:
        requirements.append(f"Label: {window_config['label']}")

    return "<br>".join(requirements) if requirements else "No requirements"

def create_cohort_timeline(config, input_default_lookback):
    """Build a Plotly Gantt figure of the configured windows around the trigger.

    Args:
        config: Parsed cohort configuration. ``config['windows']`` maps window
            names to window settings; ``config.get('trigger')`` is used only
            for the trigger line's label.
        input_default_lookback: Look-back duration forwarded to
            ``get_window_times`` for windows without a start.

    Returns:
        The figure produced by ``ff.create_gantt``, with one bar per window, a
        dashed red vertical line at the trigger, x ticks at the trigger and at
        every window boundary (labelled as relative hours/days/months/years),
        and a marker annotation at the start of each window that has
        requirements.
    """
    # Fixed calendar date standing in for the trigger; every window time is
    # expressed as an offset from it.
    trigger_anchor = pd.Timestamp('2024-01-01')
    anchor_ms = trigger_anchor.timestamp() * 1000

    def _offset_label(hours):
        # Pick the coarsest unit that fits; a month is 30 days, a year 365 days.
        magnitude = abs(hours)
        if magnitude >= 24*365:
            return f"{hours/(24*365):.1f}y"
        if magnitude >= 24*30:
            return f"{hours/(24*30):.1f}m"
        if magnitude >= 24:
            return f"{hours/24:.1f}d"
        return f"{hours:.0f}h"

    rows = []
    hour_offsets = {0}  # the trigger itself always gets a tick

    for name, spec in config['windows'].items():
        window_start, window_end = get_window_times(spec, trigger_anchor, input_default_lookback)

        # Offsets from the trigger, in hours.
        start_h = (window_start - trigger_anchor).total_seconds() / 3600
        end_h = (window_end - trigger_anchor).total_seconds() / 3600

        # Bars are positioned with millisecond timestamps on the x axis.
        rows.append({
            'Task': f"Window: {name}",
            'Start': anchor_ms + (start_h * 3600 * 1000),
            'Finish': anchor_ms + (end_h * 3600 * 1000),
            'Description': format_requirements(spec),
        })

        hour_offsets.add(start_h)
        hour_offsets.add(end_h)

    fig = ff.create_gantt(
        rows,
        colors=px.colors.qualitative.Plotly,
        title='Cohort Timeline',
        index_col='Task',
        showgrid_x=True,
        showgrid_y=True,
    )

    # Vertical marker at offset 0, i.e. the trigger event.
    fig.add_vline(
        x=anchor_ms,
        line_dash="dash",
        line_color="red",
        annotation_text=f"Trigger ({config.get('trigger', 'event')})",
    )

    # Replace the date axis ticks with relative-time labels.
    ordered_offsets = sorted(hour_offsets)
    fig.update_xaxes(
        title_text="Time relative to trigger",
        ticktext=[_offset_label(h) for h in ordered_offsets],
        tickvals=[anchor_ms + (h * 3600 * 1000) for h in ordered_offsets],
        tickmode='array',
    )

    # Flag every window that carries requirements with a marker at its start.
    row_index = {row['Task']: idx for idx, row in enumerate(rows)}
    markers = [
        dict(
            x=row['Start'],
            y=row_index[row['Task']],
            text="⚡",
            showarrow=False,
            font=dict(size=16),
        )
        for row in rows
        if row['Description'] != "No requirements"
    ]

    fig.update_layout(
        annotations=markers,
        height=300,
        margin=dict(l=100, r=100),
    )

    return fig

# Example task configuration consumed by visualize_cohort_yaml (parsed with
# yaml.safe_load). It defines predicates, the trigger predicate, and three
# windows (input, gap, target).
# Note on parsing: `start: null` loads as None, while range values such as
# `(None, 0)` load as plain strings; format_requirements renders that text as
# "No <predicate>".
YAML_STR = """
predicates:
  icu_admission:
    code: { regex: "^ICU_ADMISSION//.*" }
  icu_discharge:
    code: { regex: "^ICU_DISCHARGE//.*" }
  death:
    code: { regex: "MEDS_DEATH.*" }

  # CMO predicates
  cmo_1:
    code: { any: ["LAB//220001//UNK", "LAB//223758//UNK"] }
    text_value: "Comfort measures only"
  cmo_2:
    code: { any: ["LAB//220001//UNK", "LAB//223758//UNK"] }
    text_value: "Comfort care (CMO, Comfort Measures)"

  # DNR predicates
  dnr_1:
    code: { any: ["LAB//220001//UNK", "LAB//223758//UNK"] }
    text_value: "DNR / DNI"
  dnr_2:
    code: { any: ["LAB//220001//UNK", "LAB//223758//UNK"] }
    text_value: "DNAR (Do Not Attempt Resuscitation)  [DNR]"
  dnr_3:
    code: { any: ["LAB//220001//UNK", "LAB//223758//UNK"] }
    text_value: "DNAR (Do Not Attempt Resuscitation) [DNR] / DNI"
  dnr_4:
    code: { any: ["LAB//220001//UNK", "LAB//223758//UNK"] }
    text_value: "DNR (do not resuscitate)"

  # derived predicates
  cmo:
    expr: or(cmo_1, cmo_2)
  dnr:
    expr: or(dnr_1, dnr_2, dnr_3, dnr_4)

trigger: icu_admission

windows:
  input:
    start: null
    end: trigger + 24h
    start_inclusive: True
    end_inclusive: True
    index_timestamp: end
    has:
      cmo: (None, 0) # Exclude patients on comfort measures only
      dnr: (None, 0) # Exclude patients with DNR orders
  gap:
    start: trigger
    end: start + 30h
    start_inclusive: False
    end_inclusive: True
    has:
      cmo: (None, 0)
      dnr: (None, 0)
      icu_discharge: (None, 0)
  target:
    start: trigger
    end: start + 3d
    start_inclusive: True
    end_inclusive: True
    label: icu_discharge
    has:
      death: (None, 0)
"""

def visualize_cohort_yaml(yaml_fp: "Path | None", input_default_lookback: str):
    """Create the cohort timeline figure from a YAML task configuration.

    Args:
        yaml_fp: Path to a YAML configuration file. Pass ``None`` to use the
            built-in example configuration in ``YAML_STR``.
        input_default_lookback: Look-back duration (e.g. ``"10d"``) used for
            windows that have no start.

    Returns:
        The figure returned by ``create_cohort_timeline``.
    """
    if yaml_fp is None:
        # No file supplied (e.g. running the notebook as-is): use the example.
        yaml_str = YAML_STR
    else:
        yaml_str = Path(yaml_fp).read_text(encoding="utf-8")

    config = yaml.safe_load(yaml_str)
    return create_cohort_timeline(config, input_default_lookback)

In [None]:
# Draw the timeline for the example configuration; "10d" is the look-back
# applied to windows whose start is null (see get_window_times).
visualize_cohort_yaml(None, "10d")