# Sankey diagram

Branchings:
- First branching: mechanisms
- Second branching: on-scene survival
- Third branching: pre-hospital diagnosis
- Fourth branching: in-hospital diagnosis
- Fifth branching: death


In [None]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go

In [None]:
# Options
# True: both on-scene outcomes get their own node column between mechanism and
# pre-hospital diagnosis. False: only a 'Died on scene' node is drawn and
# survivors flow straight from mechanism to pre-hospital diagnosis.
SHOW_ON_SCENE_SURVIVED_NODE = True
# True: links are coloured by in-hospital diagnosis (the first two flows are
# additionally split by that diagnosis so each link has a single colour).
COLOR_FLOWS_BY_INHOSPITAL_DX = True

In [None]:
# Absolute, machine-specific path to the source Excel workbook; adjust when running elsewhere.
data_path = '/Users/jk1/Library/CloudStorage/OneDrive-UniversitédeGenève/icu_research/prehospital/pediatric_trauma/data/Data_PedRegaTrauma_coded_for_analysis_250417.xlsx'

In [None]:
# Load data
# Only the 'All centres cleaned' sheet of the workbook is read.
_df = pd.read_excel(data_path, sheet_name='All centres cleaned')

In [None]:
# Inspect the raw on-scene survival values. The raw column is read directly
# because the 'on-scene_survival' alias is only created in the following cell,
# so looking it up here would fail on a fresh top-to-bottom run.
_df['on-sceen survival (y/n)'].value_counts()

In [None]:
# Alias the raw column (its header in the workbook is spelled 'on-sceen') under a cleaner name.
_df['on-scene_survival'] = _df['on-sceen survival (y/n)'] 
# Place discharge != 0 means survived hospital stay
# NOTE(review): a missing 'Place discharge' (NaN) also compares != 0, so such
# rows are flagged as survived rather than unknown — confirm this is intended.
_df['hospital_survival'] = _df['Place discharge'] != 0

In [None]:

# Mechanism encoding (from fig_mechanism.ipynb)
# Each group name lists the raw 'Mechanism of injury' codes that belong to it.
mechanism_encoding_groups_map = {
    'Collisions': [1, 5, 6],
    'Falls': [2, 7],
    'Burns': [3],
    'Other': [4, 8],
}

# Invert the group -> codes table once so every row needs a single dictionary
# lookup instead of a scan over each group's code list. The code lists are
# disjoint, so the inversion is unambiguous.
_mechanism_code_to_group = {
    code: group
    for group, codes in mechanism_encoding_groups_map.items()
    for code in codes
}

# Codes that belong to no group (missing values included) become pd.NA.
_df['mechanism_group'] = _df['Mechanism of injury'].map(
    lambda code: _mechanism_code_to_group.get(code, pd.NA)
)

# On-scene survival encoding
def _map_on_scene_survival(val):
    """Translate a raw on-scene survival value into a Sankey node label.

    Text values are compared case-insensitively after trimming whitespace;
    any other value is matched against 1 / 0 (which also covers True / False).

    Returns:
        'Survived on scene', 'Died on scene', or pd.NA when the value is
        missing or not recognised.
    """
    if pd.isna(val):
        return pd.NA

    survived, died = 'Survived on scene', 'Died on scene'

    if isinstance(val, str):
        text_labels = {
            **dict.fromkeys(('yes', 'y', 'survived', 'alive', '1', 'true'), survived),
            **dict.fromkeys(('no', 'n', 'dead', 'died', '0', 'false'), died),
        }
        return text_labels.get(val.strip().lower(), pd.NA)

    # True == 1 and False == 0 hash alike, so two keys cover the boolean forms too.
    numeric_labels = {1: survived, 0: died}
    return numeric_labels.get(val, pd.NA)


# Normalised on-scene outcome label per row (pd.NA where the raw value is not recognised).
_df['on_scene_survival_clean'] = _df['on-scene_survival'].apply(_map_on_scene_survival)

# Pre-hospital diagnosis preprocessing (from dx_accuracy.ipynb)
# Known free-text entries are replaced by explicit codes before the generic
# digit extraction runs, because that extraction would otherwise pick up
# unrelated digits from the text (e.g. the CO percentages below).
_pre = _df['Main diagnosis pre-hospital']
# The literal string '<NA>' is treated as a missing value.
_pre = _pre.replace('<NA>', pd.NA)
# Recoded as 10 ('Drowning' in dx_code_to_name).
_pre = _pre.replace('Vd. a. Asphiktische REA', 10)
# The two CO-intoxication free-text entries are recoded as 0 ('No diagnosis').
_pre = _pre.replace(
    '1. CO Intoxikation durch Rauchgasvergiftung (Kachelofen)\n   - CO 20%\n   - Schwindel, Unwohlsein, fragliche krampfartigen Äquivalente',
    0,
 )
_pre = _pre.replace(
    '1. CO INtoxikation durch Rauchgasvergiftung (Kachelofen) mit\n   - Krampfäquivalent, Schwindel, Übelkeit\n   - CO 22%',
    0,
 )
# 999 is replaced with a missing value.
_pre = _pre.replace(999, pd.NA)

# In-hospital diagnosis encoding (based on dx_accuracy.ipynb)
# Integer diagnosis code -> display name. The same table is used below to name
# both the pre-hospital and the in-hospital diagnosis codes.
dx_code_to_name = {
    0: 'No diagnosis',
    1: 'Traumatic brain or cervical spine injury',
    2: 'Chest trauma',
    3: 'Abdominal trauma',
    4: 'Pelvic Trauma',
    5: 'Upper extremity trauma',
    6: 'Lower extremity trauma',
    7: 'Spine injury',
    8: 'Face',
    9: 'Polytrauma',
    10: 'Drowning',
    11: 'Burns',
}


def _clean_dx(val):
    """Reduce a raw in-hospital diagnosis cell to a single integer code.

    Numeric cells (Python or NumPy scalars) are truncated to ``int``. Text
    cells are first rid of two known free-text fragments, then stripped to
    digits and commas; when several comma-separated codes remain, only the
    first one is kept.

    Returns:
        The code as a Python ``int``, or ``pd.NA`` when the cell is missing or
        no code can be extracted.
    """
    if pd.isna(val):
        return pd.NA

    # np.integer / np.floating are listed explicitly: NumPy integer scalars
    # (and non-64-bit floats) are not instances of the builtin int / float.
    if isinstance(val, (int, float, np.integer, np.floating)):
        try:
            return int(val)
        except (ValueError, OverflowError):
            # e.g. +/- infinity has no integer representation
            return pd.NA

    if isinstance(val, str):
        # Drop fragments whose digits are not diagnosis codes ('C2-...') and
        # stringified missing values, and recode one known free-text entry.
        val = val.replace('C2-Intoxikation,', '').replace('nan,', '')
        val = val.replace('Obstrukt.Atemversagen -REA', '10')
        # keep digits and commas only
        codes = ''.join(ch for ch in val if ch.isdigit() or ch == ',').strip(',')
        if not codes:
            return pd.NA
        try:
            return int(codes.split(',')[0])
        except ValueError:
            return pd.NA

    return pd.NA


def _clean_pre_dx(val):
    """Reduce a raw pre-hospital diagnosis cell to a single integer code.

    Numeric cells (Python or NumPy scalars) are truncated to ``int``. Text
    cells are stripped to digits and commas; when several comma-separated
    codes remain, only the first one is kept.

    Returns:
        The code as a Python ``int``, or ``pd.NA`` when the cell is missing or
        no code can be extracted.
    """
    if pd.isna(val):
        return pd.NA

    # np.integer / np.floating are listed explicitly: NumPy integer scalars
    # (and non-64-bit floats) are not instances of the builtin int / float.
    if isinstance(val, (int, float, np.integer, np.floating)):
        try:
            return int(val)
        except (ValueError, OverflowError):
            # e.g. +/- infinity has no integer representation
            return pd.NA

    if isinstance(val, str):
        # keep digits and commas only
        codes = ''.join(ch for ch in val if ch.isdigit() or ch == ',').strip(',')
        if not codes:
            return pd.NA
        try:
            return int(codes.split(',')[0])
        except ValueError:
            return pd.NA

    return pd.NA


# 999 is replaced with a missing value before the in-hospital codes are cleaned.
_dx = _df['Main diagnosis in-hospital'].replace(999, pd.NA)

# Integer code per row, then its display name looked up in dx_code_to_name.
_df['in_hosp_dx_code'] = _dx.apply(_clean_dx)
_df['in_hosp_dx'] = _df['in_hosp_dx_code'].map(dx_code_to_name)

# Same two steps for the pre-hospital diagnosis (already preprocessed in _pre).
_df['pre_hosp_dx_code'] = _pre.apply(_clean_pre_dx)
_df['pre_hosp_dx'] = _df['pre_hosp_dx_code'].map(dx_code_to_name)

# Hospital survival encoding (uses derived hospital_survival flag)
def _map_hospital_survival(val):
    """Translate a hospital-survival value into a death-node label.

    Text values are compared case-insensitively after trimming whitespace;
    any other value is matched against 1 / 0 (which also covers True / False).
    Both forms yield the same two labels, which are the only ones the
    death-node list built further down recognises.

    Returns:
        'Discharged alive', 'In-hospital death', or pd.NA when the value is
        missing or not recognised.
    """
    alive_label = 'Discharged alive'
    death_label = 'In-hospital death'

    if pd.isna(val):
        return pd.NA
    if isinstance(val, str):
        v = val.strip().lower()
        if v in {'yes', 'y', 'discharged alive', 'alive', '1', 'true'}:
            return alive_label
        if v in {'no', 'n', 'in-hospital death', 'died', '0', 'false'}:
            return death_label
        return pd.NA
    if val in {1, True}:
        return alive_label
    if val in {0, False}:
        return death_label
    return pd.NA


# Turn the hospital_survival flag into the label used for the last node column.
_df['death'] = _df['hospital_survival'].apply(_map_hospital_survival)

# Build Sankey data
# Only rows with a value in all five branching columns are kept.
_sankey_columns = [
    'mechanism_group',
    'on_scene_survival_clean',
    'pre_hosp_dx',
    'in_hosp_dx',
    'death',
]
_sankey_df = _df[_sankey_columns].dropna()

_mech_nodes = sorted(_sankey_df['mechanism_group'].unique())

# Default diagnosis ordering for diagnosis branchings
_dx_order = [
    'Polytrauma',
    'Traumatic brain or cervical spine injury',
    'Face',
    'Upper extremity trauma',
    'Lower extremity trauma',
    'Chest trauma',
    'Abdominal trauma',
    'Pelvic Trauma',
    'Spine injury',
    'Burns',
    'Drowning',
    'No diagnosis',
]


def _in_display_order(present):
    """Return the entries of _dx_order that occur in *present*, in display order."""
    return [name for name in _dx_order if name in present]


# Node lists only contain values that actually occur in the filtered data.
_pre_dx_present = set(_sankey_df['pre_hosp_dx'].unique())
_pre_dx_nodes = _in_display_order(_pre_dx_present)

_dx_present = set(_sankey_df['in_hosp_dx'].unique())
_dx_nodes = _in_display_order(_dx_present)

_death_present = set(_sankey_df['death'].unique())
_death_nodes = [
    outcome
    for outcome in ('Discharged alive', 'In-hospital death')
    if outcome in _death_present
]


def _linspace_positions(n):
    """Return *n* evenly spaced positions covering [0, 1].

    The result always has exactly *n* entries so it can be concatenated with
    the positions of other node columns without shifting them:
    no nodes give an empty list, a single node is centred at 0.5, and two or
    more nodes run from 0.0 to 1.0 inclusive.
    """
    if n <= 0:
        # An empty column must not contribute a position.
        return []
    if n == 1:
        return [0.5]
    return [i / (n - 1) for i in range(n)]


def _wrap_label(label, max_len=18):
    """Break a long label into '<br>'-separated lines of at most *max_len* characters.

    Words are split on single spaces and packed greedily; a single word longer
    than *max_len* is kept whole on its own line. Non-string labels and labels
    that already fit are returned unchanged.
    """
    if not isinstance(label, str) or len(label) <= max_len:
        return label

    wrapped = []
    line = ''
    for token in label.split(' '):
        separator = 1 if line else 0
        if len(line) + len(token) + separator > max_len:
            # The token does not fit: close the current line and start a new one.
            if line:
                wrapped.append(line)
            line = token
            continue
        line = f"{line} {token}".strip()

    if line:
        wrapped.append(line)
    return '<br>'.join(wrapped)


def _hex_to_rgba(hex_color, alpha=0.35):
    """Convert a '#rrggbb' or '#rgb' colour into an 'rgba(r,g,b,alpha)' string.

    Anything that is not a string, or whose first six hex digits cannot be
    parsed, yields a neutral grey with the requested alpha.
    """
    fallback = f'rgba(150,150,150,{alpha})'
    if not isinstance(hex_color, str):
        return fallback

    digits = hex_color.lstrip('#')
    if len(digits) == 3:
        # Expand shorthand: 'abc' -> 'aabbcc'.
        digits = ''.join(ch + ch for ch in digits)

    try:
        red = int(digits[0:2], 16)
        green = int(digits[2:4], 16)
        blue = int(digits[4:6], 16)
    except Exception:
        return fallback
    return f'rgba({red},{green},{blue},{alpha})'


# Parallel link arrays: entry i of each list describes link i.
_sources, _targets, _values, _link_inhosp = [], [], [], []


def _append_link(source, target, value, in_hosp_dx=None):
    """Record one link in the module-level link arrays.

    Args:
        source: node index the link starts from.
        target: node index the link ends at.
        value: link weight; stored as ``int``.
        in_hosp_dx: in-hospital diagnosis associated with the link. It is only
            recorded when COLOR_FLOWS_BY_INHOSPITAL_DX is enabled.
    """
    _sources.append(source)
    _targets.append(target)
    _values.append(int(value))
    if not COLOR_FLOWS_BY_INHOSPITAL_DX:
        return
    _link_inhosp.append(in_hosp_dx)


# ---------------------------------------------------------------------------
# Node layout and links
#
# Nodes are numbered column by column in the order
#   mechanism -> on-scene outcome -> pre-hospital dx -> in-hospital dx -> death
# and _labels, _x_positions and _y_positions all follow that order.
#
# The two layouts (SHOW_ON_SCENE_SURVIVED_NODE on/off) differ only in the
# on-scene column and in the first two flows; everything else is shared.
# ---------------------------------------------------------------------------


def _count_flows(frame, keys):
    """Count the rows of *frame* for every combination of the columns in *keys*."""
    return frame.groupby(keys).size().reset_index(name='count')


def _offset_index(names, offset):
    """Map each name to its global node index, numbering from *offset*."""
    return {name: position + offset for position, name in enumerate(names)}


def _append_flow(flow, source_of, target_of):
    """Append one link per row of *flow*.

    *source_of* and *target_of* take the row and return the node indices.
    """
    for _, row in flow.iterrows():
        _append_link(source_of(row), target_of(row), row['count'], row.get('in_hosp_dx'))


# Row -> node index lookups. They read the index tables at call time, so they
# can be defined before the tables are built.
def _mech_node(row):
    return _mech_index[row['mechanism_group']]


def _on_scene_node(row):
    return _on_scene_index[row['on_scene_survival_clean']]


def _pre_dx_node(row):
    return _pre_dx_index[row['pre_hosp_dx']]


def _dx_node(row):
    return _dx_index[row['in_hosp_dx']]


def _death_node(row):
    return _death_index[row['death']]


_on_scene_observed = set(_sankey_df['on_scene_survival_clean'].unique())
_survived_rows = _sankey_df.query("on_scene_survival_clean == 'Survived on scene'")
# When links are coloured by in-hospital diagnosis, the first two flows are
# additionally split by it so that every link has exactly one diagnosis.
_color_keys = ['in_hosp_dx'] if COLOR_FLOWS_BY_INHOSPITAL_DX else []

if SHOW_ON_SCENE_SURVIVED_NODE:
    # Both on-scene outcomes get a node; survivors continue from their node.
    _on_scene_nodes = [
        d
        for d in ['Survived on scene', 'Died on scene']
        if d in _on_scene_observed
    ]
    _on_scene_column = _on_scene_nodes
    _on_scene_index = _offset_index(_on_scene_nodes, len(_mech_nodes))

    _flow_1 = _count_flows(
        _sankey_df,
        ['mechanism_group', 'on_scene_survival_clean'] + _color_keys,
    )
    _flow_2 = _count_flows(
        _survived_rows,
        ['on_scene_survival_clean', 'pre_hosp_dx'] + _color_keys,
    )

    _y_on_scene = _linspace_positions(len(_on_scene_nodes))
    _column_x = [0.0, 0.25, 0.5, 0.75, 1.0]
else:
    # Only a 'Died on scene' node (if any such row exists); survivors flow
    # directly from mechanism to pre-hospital diagnosis.
    _on_scene_died_label = (
        'Died on scene' if 'Died on scene' in _on_scene_observed else None
    )
    _on_scene_column = [_on_scene_died_label] if _on_scene_died_label else []
    _on_scene_died_index = len(_mech_nodes) if _on_scene_died_label else None

    _flow_1 = _count_flows(
        _sankey_df.query("on_scene_survival_clean == 'Died on scene'"),
        ['mechanism_group', 'on_scene_survival_clean'] + _color_keys,
    )
    _flow_2 = _count_flows(
        _survived_rows,
        ['mechanism_group', 'pre_hosp_dx'] + _color_keys,
    )

    # The single 'Died on scene' node is pinned near the bottom.
    _y_on_scene = [0.99] if _on_scene_died_label else []
    # Without an on-scene node the remaining four columns are spread evenly
    # (the second entry is then unused because the on-scene column is empty).
    _column_x = (
        [0.0, 0.25, 0.5, 0.75, 1.0]
        if _on_scene_died_label
        else [0.0, 0.25, 0.33, 0.66, 1.0]
    )

# Flows shared by both layouts (on-scene survivors only).
_flow_3 = _count_flows(_survived_rows, ['pre_hosp_dx', 'in_hosp_dx'])
_flow_4 = _count_flows(_survived_rows, ['in_hosp_dx', 'death'])

_labels = _mech_nodes + _on_scene_column + _pre_dx_nodes + _dx_nodes + _death_nodes

# Global node index of every label; each column starts where the previous ends.
_pre_dx_offset = len(_mech_nodes) + len(_on_scene_column)
_dx_offset = _pre_dx_offset + len(_pre_dx_nodes)
_death_offset = _dx_offset + len(_dx_nodes)

_mech_index = _offset_index(_mech_nodes, 0)
_pre_dx_index = _offset_index(_pre_dx_nodes, _pre_dx_offset)
_dx_index = _offset_index(_dx_nodes, _dx_offset)
_death_index = _offset_index(_death_nodes, _death_offset)

# Links are appended flow by flow: 1, 2, 3, 4.
if SHOW_ON_SCENE_SURVIVED_NODE:
    _append_flow(_flow_1, _mech_node, _on_scene_node)
    _append_flow(_flow_2, _on_scene_node, _pre_dx_node)
else:
    if _on_scene_died_label:
        _append_flow(_flow_1, _mech_node, lambda row: _on_scene_died_index)
    _append_flow(_flow_2, _mech_node, _pre_dx_node)
_append_flow(_flow_3, _pre_dx_node, _dx_node)
_append_flow(_flow_4, _dx_node, _death_node)

# Node coordinates, in the same column order as _labels.
_y_mech = _linspace_positions(len(_mech_nodes))
_y_pre_dx = _linspace_positions(len(_pre_dx_nodes))
_y_dx = _linspace_positions(len(_dx_nodes))
_y_death = _linspace_positions(len(_death_nodes))

_column_nodes = [_mech_nodes, _on_scene_column, _pre_dx_nodes, _dx_nodes, _death_nodes]
_x_positions = [
    column_x
    for column_x, nodes in zip(_column_x, _column_nodes)
    for _ in nodes
]
_y_positions = _y_mech + _y_on_scene + _y_pre_dx + _y_dx + _y_death
    

# Colors (same palette for pre- and in-hospital diagnosis)
_dx_palette = [
    '#1f77b4',
    '#ff7f0e',
    '#2ca02c',
    '#d62728',
    '#9467bd',
    '#8c564b',
    '#e377c2',
    '#7f7f7f',
    '#bcbd22',
    '#17becf',
    '#9edae5',
    '#c7c7c7',
]
# Every diagnosis shown in either diagnosis column gets a colour. In-hospital
# names come first so they keep the palette slot they would have on their own;
# diagnoses that only occur pre-hospital take the following slots.
_dx_color_names = _dx_nodes + [name for name in _pre_dx_nodes if name not in _dx_nodes]
_dx_color_map = {
    name: _dx_palette[i % len(_dx_palette)]
    for i, name in enumerate(_dx_color_names)
}

_mech_palette = ['#4e79a7', '#f28e2b', '#e15759', '#76b7b2']
_mech_color_map = {name: _mech_palette[i % len(_mech_palette)] for i, name in enumerate(_mech_nodes)}

_on_scene_color_map = {
    'Survived on scene': '#59a14f',
    'Died on scene': '#e15759',
}
_death_color_map = {
    'Discharged alive': '#59a14f',
    'In-hospital death': '#e15759',
}

# Labels of the on-scene column, in the order they appear in _labels.
if SHOW_ON_SCENE_SURVIVED_NODE:
    _on_scene_labels = list(_on_scene_nodes)
else:
    _on_scene_labels = [_on_scene_died_label] if _on_scene_died_label else []
_on_scene_label_set = set(_on_scene_labels)

# Node colours are assembled column by column, mirroring how _labels is built
# (mechanism, on-scene, pre-hospital dx, in-hospital dx, death). Colouring by
# column rather than by label text matters because 'Burns' is both a mechanism
# group and a diagnosis name: a lookup by label alone would give the diagnosis
# node the mechanism colour.
_default_node_color = '#bab0ac'
_node_colors = (
    [_mech_color_map.get(name, _default_node_color) for name in _mech_nodes]
    + [_on_scene_color_map.get(name, _default_node_color) for name in _on_scene_labels]
    + [_dx_color_map.get(name, _default_node_color) for name in _pre_dx_nodes]
    + [_dx_color_map.get(name, _default_node_color) for name in _dx_nodes]
    + [_death_color_map.get(name, _default_node_color) for name in _death_nodes]
)

if COLOR_FLOWS_BY_INHOSPITAL_DX:
    # One translucent colour per link, taken from the link's in-hospital
    # diagnosis; links without one fall back to grey inside _hex_to_rgba.
    _link_colors = [
        _hex_to_rgba(_dx_color_map.get(_link_dx), alpha=0.25)
        for _link_dx in _link_inhosp
    ]
else:
    _link_colors = None


# The last column's labels are blanked on the nodes themselves and drawn as
# paper-anchored annotations at the right edge of the plot instead.
_rightmost_labels = set(_death_nodes)
_labels_display = [
    _wrap_label(name) if name not in _rightmost_labels else ''
    for name in _labels
]

_rightmost_annotations = []
for _label, _x, _y in zip(_labels, _x_positions, _y_positions):
    if _label not in _rightmost_labels:
        continue
    if _label == 'Discharged alive':
        # Hand-tuned vertical position for this label.
        _y = 0.4
    _rightmost_annotations.append(
        dict(
            xref='paper',
            yref='paper',
            x=1,
            # Node y grows downwards, paper y grows upwards.
            y=1 - _y,
            text=_wrap_label(_label),
            showarrow=False,
            xanchor='left',
            yanchor='middle',
            align='left',
            font=dict(size=12),
        )
    )

# Node and link specifications are assembled first, then wrapped in a single
# Sankey trace.
_node_spec = dict(
    label=_labels_display,
    pad=15,
    thickness=15,
    x=_x_positions,
    y=_y_positions,
    color=_node_colors,
)
_link_spec = dict(
    source=_sources,
    target=_targets,
    value=_values,
    color=_link_colors if COLOR_FLOWS_BY_INHOSPITAL_DX else None,
)
_sankey_trace = go.Sankey(arrangement='snap', node=_node_spec, link=_link_spec)

fig = go.Figure(data=[_sankey_trace])

# The wider right margin leaves room for the annotations placed at x = 1.
fig.update_layout(
    title_text='',
    font_size=12,
    annotations=_rightmost_annotations,
    margin=dict(r=120),
)

fig.show()

In [None]:
# Absolute, machine-specific output folder for exported figures; adjust when running elsewhere.
output_dir = '/Users/jk1/Library/CloudStorage/OneDrive-UniversitédeGenève/icu_research/prehospital/pediatric_trauma/figures'
# Static PNG export of a 1200x700 layout with scale factor 4.
# NOTE(review): write_image presumably requires a static-image backend
# (e.g. kaleido) to be installed — confirm in the target environment.
fig.write_image(
    f'{output_dir}/sankey_diagram2.png',
    scale=4,
    width=1200,
    height=700,
)
# fig.write_html(f'{output_dir}/sankey_diagram.html', include_plotlyjs='cdn')