# Sankey diagram

Branchings:
- First branching: mechanisms
- Second branching: on-scene survival
- Third branching: pre-hospital diagnosis
- Fourth branching: in-hospital diagnosis
- Fifth branching: death


In [None]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go

In [None]:
# Options
# When True, the second column of the diagram shows both the 'Survived on scene'
# and the 'Died on scene' node. When False, only 'Died on scene' is drawn there
# and on-scene survivors flow straight from mechanism to in-hospital diagnosis.
SHOW_ON_SCENE_SURVIVED_NODE = True

In [None]:
# Absolute path to the Excel workbook on the author's machine; adjust before running elsewhere.
data_path = '/Users/jk1/Library/CloudStorage/OneDrive-UniversitédeGenève/icu_research/prehospital/pediatric_trauma/data/Data_PedRegaTrauma_coded_for_analysis_250417.xlsx'

In [None]:
# Load data
# Only the 'All centres cleaned' sheet of the workbook is read.
_df = pd.read_excel(data_path, sheet_name='All centres cleaned')

In [None]:
# Tabulate the raw on-scene survival answers as a quick sanity check.
# The raw column is read directly: the 'on-scene_survival' alias is only
# created in the following cell, so using it here fails on a fresh
# top-to-bottom run.
_df['on-sceen survival (y/n)'].value_counts()

In [None]:
# Expose the raw on-scene survival column (header spelled "on-sceen" in the sheet) under a cleaner name.
_df['on-scene_survival'] = _df['on-sceen survival (y/n)'] 
# Place discharge != 0 means survived hospital stay
# NOTE(review): a missing 'Place discharge' (NaN) also compares != 0 and is
# therefore flagged as survived here — confirm this is intended.
_df['hospital_survival'] = _df['Place discharge'] != 0

In [None]:

# Mechanism encoding (from fig_mechanism.ipynb)
# Each display group collects one or more raw 'Mechanism of injury' codes.
mechanism_encoding_groups_map = {
    'Collisions': [1, 5, 6],
    'Falls': [2, 7],
    'Burns': [3],
    'Other': [4, 8],
}


def _mechanism_group_for(code):
    """Return the group whose code list contains ``code``, or ``pd.NA`` if none does."""
    for group_name, group_codes in mechanism_encoding_groups_map.items():
        if code in group_codes:
            return group_name
    return pd.NA


_df['mechanism_group'] = _df['Mechanism of injury'].map(_mechanism_group_for)

# On-scene survival encoding
def _map_on_scene_survival(val):
    """Normalise a raw on-scene survival answer to a Sankey node label.

    Text answers are compared case-insensitively after trimming whitespace;
    any other value is compared against 1/True and 0/False.

    Returns 'Survived on scene', 'Died on scene', or ``pd.NA`` when the value
    is missing or not recognised.
    """
    if pd.isna(val):
        return pd.NA

    if isinstance(val, str):
        answer = val.strip().lower()
        if answer in {'yes', 'y', 'survived', 'alive', '1', 'true'}:
            return 'Survived on scene'
        if answer in {'no', 'n', 'dead', 'died', '0', 'false'}:
            return 'Died on scene'
        # A string never equals 1 or 0, so unrecognised text is simply missing.
        return pd.NA

    # True hashes and compares equal to 1 (False to 0), so booleans hit the same keys.
    return {1: 'Survived on scene', 0: 'Died on scene'}.get(val, pd.NA)


# Apply the label mapping to every row; unrecognised or missing answers become pd.NA.
_df['on_scene_survival_clean'] = _df['on-scene_survival'].apply(_map_on_scene_survival)

# In-hospital diagnosis encoding (based on dx_accuracy.ipynb)
# Maps the integer diagnosis code to the label displayed on the diagram.
dx_code_to_name = {
    0: 'No diagnosis',
    1: 'Traumatic brain or cervical spine injury',
    2: 'Chest trauma',
    3: 'Abdominal trauma',
    4: 'Pelvic Trauma',
    5: 'Upper extremity trauma',
    6: 'Lower extremity trauma',
    7: 'Spine injury',
    8: 'Face',
    9: 'Polytrauma',
    10: 'Drowning',
    11: 'Burns',
}

# Cells equal to 999 are blanked out before the per-cell cleaning step.
# NOTE(review): presumably 999 is the "unknown" sentinel of the data set — confirm.
_dx = _df['Main diagnosis in-hospital'].replace(999, pd.NA)


def _clean_dx(val):
    """Extract the first in-hospital diagnosis code from a raw cell value.

    Parameters
    ----------
    val : object
        Raw cell content: a number, a string holding one or more
        comma-separated codes (possibly mixed with free text), or a missing
        value.

    Returns
    -------
    int or pd.NA
        The first code found, or ``pd.NA`` when the cell is missing or no
        code can be extracted from it.
    """
    if pd.isna(val):
        return pd.NA

    # NumPy scalars are accepted alongside the built-in numeric types.
    if isinstance(val, (int, float, np.integer, np.floating)):
        try:
            return int(val)
        except (ValueError, OverflowError):
            # Non-finite floats (e.g. inf) carry no usable code.
            return pd.NA

    if not isinstance(val, str):
        return pd.NA

    # Known free-text entries: drop the ones that carry no code of their own
    # (the '2' in 'C2-Intoxikation' must not be read as a code) and translate
    # the one that stands for code 10.
    text = val.replace('C2-Intoxikation,', '').replace('nan,', '')
    text = text.replace('Obstrukt.Atemversagen -REA', '10')

    for token in text.split(','):
        token = token.strip()
        if not token:
            continue
        try:
            # Parse a purely numeric token as a whole so that '1.0' yields 1
            # instead of the digits being glued together into 10.
            return int(float(token))
        except (ValueError, OverflowError):
            pass
        # Token mixes text and digits: fall back to the digits it contains.
        digits = ''.join(ch for ch in token if ch.isdigit())
        if digits:
            try:
                return int(digits)
            except ValueError:
                # Digit-like characters that int() cannot parse (e.g. superscripts).
                return pd.NA
    return pd.NA


# Reduce every raw cell to a single integer code (or pd.NA) ...
_df['in_hosp_dx_code'] = _dx.apply(_clean_dx)
# ... and translate the code to its display label. Codes that are not keys of
# dx_code_to_name end up missing in 'in_hosp_dx'.
_df['in_hosp_dx'] = _df['in_hosp_dx_code'].map(dx_code_to_name)

# Hospital survival encoding (uses derived hospital_survival flag)
def _map_hospital_survival(val):
    """Translate a hospital-survival flag into the outcome label used as a node.

    Text values are compared case-insensitively after trimming whitespace;
    any other value is compared against 1/True and 0/False. Both paths return
    the same pair of labels, so every mapped value matches the outcome node
    names the diagram is built from.

    Returns 'Discharged alive', 'In-hospital death', or ``pd.NA`` when the
    value is missing or not recognised.
    """
    alive_label = 'Discharged alive'
    death_label = 'In-hospital death'

    if pd.isna(val):
        return pd.NA
    if isinstance(val, str):
        v = val.strip().lower()
        if v in {'yes', 'y', 'discharged alive', 'alive', '1', 'true'}:
            return alive_label
        if v in {'no', 'n', 'in-hospital death', 'died', '0', 'false'}:
            return death_label
        return pd.NA
    if val in {1, True}:
        return alive_label
    if val in {0, False}:
        return death_label
    return pd.NA


# Outcome label for the last column of the diagram, derived from the boolean hospital_survival flag.
_df['death'] = _df['hospital_survival'].apply(_map_hospital_survival)

# Build Sankey data
# Only rows in which all four branching variables are known are plotted.
# NOTE(review): this also discards any on-scene death that has no in-hospital
# diagnosis recorded — confirm those rows are not expected in the diagram.
_sankey_columns = ['mechanism_group', 'on_scene_survival_clean', 'in_hosp_dx', 'death']
_sankey_df = _df[_sankey_columns].dropna()

# First column: the mechanism groups that actually occur, sorted by name.
_mech_nodes = sorted(_sankey_df['mechanism_group'].unique())

# Default diagnosis ordering for 3rd branching
_dx_order = [
    'Polytrauma',
    'Traumatic brain or cervical spine injury',
    'Face',
    'Upper extremity trauma',
    'Lower extremity trauma',
    'Chest trauma',
    'Abdominal trauma',
    'Pelvic Trauma',
    'Spine injury',
    'Burns',
    'Drowning',
    'No diagnosis',
]
# Keep the preferred order but drop diagnoses that never occur in the data.
_dx_present = set(_sankey_df['in_hosp_dx'].unique())
_dx_nodes = [dx_name for dx_name in _dx_order if dx_name in _dx_present]

# Last column: outcome labels that occur, survivors listed first.
_death_values = _sankey_df['death'].unique()
_death_nodes = [
    outcome
    for outcome in ['Discharged alive', 'In-hospital death']
    if outcome in _death_values
]


def _linspace_positions(n):
    """Return ``n`` evenly spaced positions spanning 0 to 1.

    Parameters
    ----------
    n : int
        Number of nodes to place.

    Returns
    -------
    list of float
        ``[]`` for ``n <= 0`` (so the list stays the same length as the node
        list it belongs to), ``[0.5]`` for a single node (centred), otherwise
        ``n`` values from 0.0 to 1.0 inclusive.
    """
    if n <= 0:
        return []
    if n == 1:
        return [0.5]
    return [i / (n - 1) for i in range(n)]


# Parallel link arrays for go.Sankey: entry i is one ribbon running from node
# _sources[i] to node _targets[i] with a width of _values[i] rows.
_sources = []
_targets = []
_values = []


def _count_pairs(frame, columns):
    """Return one row per observed combination of ``columns`` plus its row count."""
    return frame.groupby(columns).size().reset_index(name='count')


def _append_links(flow, source_col, source_lookup, target_col, target_lookup):
    """Append one link per row of ``flow`` to the module-level link arrays."""
    for source_name, target_name, n_rows in zip(flow[source_col], flow[target_col], flow['count']):
        _sources.append(source_lookup[source_name])
        _targets.append(target_lookup[target_name])
        _values.append(int(n_rows))


# Rows that continue past the scene; only these feed the diagnosis and outcome columns.
_survived_on_scene = _sankey_df.query("on_scene_survival_clean == 'Survived on scene'")
_on_scene_values = _sankey_df['on_scene_survival_clean'].unique()

# --- Second column: which on-scene nodes exist and where they sit vertically ---
if SHOW_ON_SCENE_SURVIVED_NODE:
    _on_scene_nodes = [
        outcome
        for outcome in ['Survived on scene', 'Died on scene']
        if outcome in _on_scene_values
    ]
    _second_column = _on_scene_nodes
    _y_on_scene = _linspace_positions(len(_on_scene_nodes))
else:
    _on_scene_died_label = 'Died on scene' if 'Died on scene' in _on_scene_values else None
    _second_column = [_on_scene_died_label] if _on_scene_died_label else []
    # The lone 'Died on scene' node is pinned near the bottom of its column.
    _y_on_scene = [0.99] if _on_scene_died_label else []

# --- Node labels and name -> node-index lookups ---
# Nodes are numbered column by column, so each column starts at the total
# size of the columns before it.
_dx_offset = len(_mech_nodes) + len(_second_column)
_death_offset = _dx_offset + len(_dx_nodes)

_labels = _mech_nodes + _second_column + _dx_nodes + _death_nodes
_mech_index = {name: pos for pos, name in enumerate(_mech_nodes)}
_dx_index = {name: _dx_offset + pos for pos, name in enumerate(_dx_nodes)}
_death_index = {name: _death_offset + pos for pos, name in enumerate(_death_nodes)}

# --- Links, appended in the order flow 1, flow 2, flow 3 ---
if SHOW_ON_SCENE_SURVIVED_NODE:
    _on_scene_index = {
        name: len(_mech_nodes) + pos
        for pos, name in enumerate(_on_scene_nodes)
    }

    # Mechanism -> on-scene outcome (all rows), then survivors -> diagnosis.
    _flow_1 = _count_pairs(_sankey_df, ['mechanism_group', 'on_scene_survival_clean'])
    _flow_2 = _count_pairs(_survived_on_scene, ['on_scene_survival_clean', 'in_hosp_dx'])

    _append_links(_flow_1, 'mechanism_group', _mech_index, 'on_scene_survival_clean', _on_scene_index)
    _append_links(_flow_2, 'on_scene_survival_clean', _on_scene_index, 'in_hosp_dx', _dx_index)
else:
    _on_scene_died_index = len(_mech_nodes) if _on_scene_died_label else None

    # Mechanism -> 'Died on scene' (deaths only), and mechanism -> diagnosis
    # directly for the survivors, skipping a 'Survived on scene' node.
    _flow_1 = _count_pairs(
        _sankey_df.query("on_scene_survival_clean == 'Died on scene'"),
        ['mechanism_group', 'on_scene_survival_clean'],
    )
    _flow_2 = _count_pairs(_survived_on_scene, ['mechanism_group', 'in_hosp_dx'])

    if _on_scene_died_label:
        _append_links(
            _flow_1,
            'mechanism_group',
            _mech_index,
            'on_scene_survival_clean',
            {'Died on scene': _on_scene_died_index},
        )
    _append_links(_flow_2, 'mechanism_group', _mech_index, 'in_hosp_dx', _dx_index)

# Diagnosis -> hospital outcome is the same in both layouts.
_flow_3 = _count_pairs(_survived_on_scene, ['in_hosp_dx', 'death'])
_append_links(_flow_3, 'in_hosp_dx', _dx_index, 'death', _death_index)

# --- Node coordinates: one x value per column, nodes spread vertically within it ---
_y_mech = _linspace_positions(len(_mech_nodes))
_y_dx = _linspace_positions(len(_dx_nodes))
_y_death = _linspace_positions(len(_death_nodes))

_column_layout = [
    (0.0, _mech_nodes),
    (0.33, _second_column),
    (0.66, _dx_nodes),
    (1.0, _death_nodes),
]
_x_positions = [x_pos for x_pos, column_nodes in _column_layout for _node in column_nodes]
_y_positions = _y_mech + _y_on_scene + _y_dx + _y_death

# Assemble the Sankey trace from the node labels/positions and the link arrays.
_node_spec = dict(
    label=_labels,
    pad=10,
    thickness=10,
    x=_x_positions,
    y=_y_positions,
)
_link_spec = dict(
    source=_sources,
    target=_targets,
    value=_values,
)
_sankey_trace = go.Sankey(arrangement='snap', node=_node_spec, link=_link_spec)

fig = go.Figure(data=[_sankey_trace])

# No title; 12 pt font for the node labels.
fig.update_layout(title_text='', font_size=12)

fig.show()

In [None]:
# Destination folder for exported figures (absolute path on the author's machine).
output_dir = '/Users/jk1/Library/CloudStorage/OneDrive-UniversitédeGenève/icu_research/prehospital/pediatric_trauma/figures'

# Static PNG export: 1200x900 layout rendered at 4x scale.
_png_path = f'{output_dir}/sankey_diagram.png'
fig.write_image(_png_path, scale=4, width=1200, height=900)
# fig.write_html(f'{output_dir}/sankey_diagram.html', include_plotlyjs='cdn')