# Sankey diagram

Branchings:
- First branching: mechanisms
- Second branching: on-scene survival
- Third branching: pre-hospital diagnosis
- Fourth branching: in-hospital diagnosis
- Fifth branching: death


In [None]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go

In [None]:
# Absolute path of the Excel workbook that is read with pd.read_excel further down.
# NOTE(review): machine-specific location (user home / cloud-storage folder) — must be
# adapted before running this notebook on another machine.
data_path = '/Users/jk1/Library/CloudStorage/OneDrive-UniversitédeGenève/icu_research/prehospital/pediatric_trauma/data/Data_PedRegaTrauma_coded_for_analysis_250417.xlsx'

In [None]:
# Load data: read the 'All centres cleaned' sheet of the workbook at `data_path`
# into the working DataFrame that every later cell reads from and adds columns to.
_df = pd.read_excel(data_path, sheet_name='All centres cleaned')

In [None]:
# Preview the first rows of the loaded sheet as a quick sanity check of the import.
_df.head()

In [None]:
# Expose the raw on-scene survival column under a normalised name
# (the header in the sheet is spelled 'on-sceen').
_df['on-scene_survival'] = _df.loc[:, 'on-sceen survival (y/n)']

# Hospital-survival flag: True for every row whose 'Place discharge' is not 0
# (a discharge place other than 0 means the patient survived the hospital stay).
# NOTE(review): a missing 'Place discharge' also compares as "not 0" and is
# therefore flagged True — confirm this is intended.
_df['hospital_survival'] = _df['Place discharge'].ne(0)

In [None]:


# Mechanism encoding (from fig_mechanism.ipynb):
# group label -> list of raw 'Mechanism of injury' codes belonging to it.
mechanism_encoding_groups_map = {
    'collisions': [1, 5, 6],
    'falls': [2, 7],
    'burns': [3],
    'other': [4, 8],
}


def _mechanism_group_of(code):
    """Return the first group whose code list contains `code`, else pd.NA."""
    for group_name, group_codes in mechanism_encoding_groups_map.items():
        if code in group_codes:
            return group_name
    return pd.NA


# One group label per row; codes outside every group become pd.NA.
_df['mechanism_group'] = _df['Mechanism of injury'].map(_mechanism_group_of)

# In-hospital diagnosis encoding (based on dx_accuracy.ipynb)
# Maps each integer diagnosis code to the label used for its Sankey node.
dx_code_to_name = {
    0: 'No diagnosis',
    1: 'Traumatic brain or cervical spine injury',
    2: 'Chest trauma',
    3: 'Abdominal trauma',
    4: 'Pelvic Trauma',
    5: 'Upper extremity trauma',
    6: 'Lower extremity trauma',
    7: 'Spine injury',
    8: 'Face',
    9: 'Polytrauma',
    10: 'Drowning',
    11: 'Burns',
}

# Raw main in-hospital diagnosis column with every 999 replaced by pd.NA.
# NOTE(review): 999 is presumably the "missing/unknown" sentinel of the
# dataset — confirm against the codebook.
_dx = _df['Main diagnosis in-hospital'].replace(999, pd.NA)


def _clean_dx(val):
    """Normalise one raw in-hospital diagnosis cell to a single integer code.

    Parameters
    ----------
    val : object
        Raw cell value: missing, a number, or free text holding one or more
        codes (e.g. ``'9, 2'``).

    Returns
    -------
    int or pd.NA
        * numbers (Python or NumPy scalars) are truncated to ``int``;
        * strings yield the first run of ASCII digits left after the known
          free-text fragments have been removed/recoded;
        * ``pd.NA`` for missing or non-finite values, strings without any
          digit, and any other type.
    """
    import re  # local import keeps this cell independent of the import cell

    if pd.isna(val):
        return pd.NA

    # np.integer / np.floating are listed explicitly because NumPy integer
    # scalars are not subclasses of the built-in int.
    if isinstance(val, (int, float, np.integer, np.floating)):
        try:
            return int(val)
        except (ValueError, OverflowError):
            # Raised for non-finite floats such as inf.
            return pd.NA

    if isinstance(val, str):
        # Remove fragments whose characters must not be read as a code
        # ('C2-...' contains a digit, 'nan,' is a stringified missing value).
        val = val.replace('C2-Intoxikation,', '').replace('nan,', '')
        # This free-text entry is recoded to diagnosis code 10.
        val = val.replace('Obstrukt.Atemversagen -REA', '10')
        # Take the first contiguous digit run only. Gluing together every
        # digit of the string would turn '1.0' into 10 or '1 2' into 12.
        first_code = re.search(r'[0-9]+', val)
        if first_code is None:
            return pd.NA
        return int(first_code.group())

    return pd.NA


# Reduce every raw diagnosis cell to one integer code (or pd.NA) via _clean_dx.
_df['in_hosp_dx_code'] = _dx.apply(_clean_dx)
# Translate codes to display labels; a code that is not a key of
# dx_code_to_name ends up as a missing value.
_df['in_hosp_dx'] = _df['in_hosp_dx_code'].map(dx_code_to_name)

# Death encoding: the outcome is read from the first of these columns
# (in priority order) that exists in the loaded sheet.
_survival_col_candidates = [
    'Survival of hospital stay',
    '1y survival',
    'on-sceen survival (y/n)',
]

_survival_col = None
for _candidate in _survival_col_candidates:
    if _candidate in _df.columns:
        _survival_col = _candidate
        break

# Fail early with the list of accepted headers when none of them is present.
if _survival_col is None:
    raise ValueError(
        'No survival column found. Expected one of: ' + ', '.join(_survival_col_candidates)
    )

# Selected survival column with every 999 replaced by pd.NA.
_survival = _df[_survival_col].replace(999, pd.NA)


def _map_survival(val):
    """Translate one raw survival value into 'Survived', 'Died' or pd.NA.

    Strings are compared case-insensitively after trimming whitespace
    against a fixed list of aliases; any value is additionally compared
    against the codes 1/True (survived) and 0/False (died). Missing or
    unrecognised values give pd.NA.
    """
    if pd.isna(val):
        return pd.NA

    text_aliases = (
        ('Survived', {'yes', 'y', 'survived', 'alive', '1'}),
        ('Died', {'no', 'n', 'dead', 'died', '0'}),
    )
    coded_values = (
        ('Survived', {1, True}),
        ('Died', {0, False}),
    )

    if isinstance(val, str):
        token = val.strip().lower()
        for outcome, aliases in text_aliases:
            if token in aliases:
                return outcome

    # Reached by non-strings and by strings that matched no alias.
    for outcome, codes in coded_values:
        if val in codes:
            return outcome

    return pd.NA


# Outcome label ('Survived' / 'Died' / pd.NA) for every row.
_df['death'] = _survival.apply(_map_survival)

# Build Sankey data: keep only rows where all three levels are known.
_sankey_df = _df.loc[:, ['mechanism_group', 'in_hosp_dx', 'death']].dropna(how='any')

# First column of nodes: mechanism groups in alphabetical order.
_mech_nodes = sorted(set(_sankey_df['mechanism_group']))

# Default diagnosis ordering for 2nd branching
_dx_order = [
    'Polytrauma',
    'Traumatic brain or cervical spine injury',
    'Face',
    'Upper extremity trauma',
    'Lower extremity trauma',
    'Chest trauma',
    'Abdominal trauma',
    'Pelvic Trauma',
    'Spine injury',
    'Burns',
    'Drowning',
    'No diagnosis',
]
# Second column: only diagnoses that occur in the data, in the order above.
_dx_present = set(_sankey_df['in_hosp_dx'])
_dx_nodes = [dx_name for dx_name in _dx_order if dx_name in _dx_present]

# Third column: outcomes that occur in the data, survivors first.
_death_present = set(_sankey_df['death'])
_death_nodes = [outcome for outcome in ('Survived', 'Died') if outcome in _death_present]

# Node labels in one flat list; a node's position in it is its Sankey index.
_labels = [*_mech_nodes, *_dx_nodes, *_death_nodes]

# Label -> node index per column, each column offset by the columns before it.
_dx_offset = len(_mech_nodes)
_death_offset = _dx_offset + len(_dx_nodes)

_mech_index = {label: pos for pos, label in enumerate(_mech_nodes)}
_dx_index = {label: pos for pos, label in enumerate(_dx_nodes, start=_dx_offset)}
_death_index = {label: pos for pos, label in enumerate(_death_nodes, start=_death_offset)}

# Patient counts for every (mechanism group, diagnosis) pair: first set of links.
_mech_dx_groups = _sankey_df.groupby(['mechanism_group', 'in_hosp_dx'])
_flow_1 = _mech_dx_groups.size().reset_index(name='count')

# Patient counts for every (diagnosis, outcome) pair: second set of links.
_dx_death_groups = _sankey_df.groupby(['in_hosp_dx', 'death'])
_flow_2 = _dx_death_groups.size().reset_index(name='count')

# Parallel link lists: entry k runs from node _sources[k] to node _targets[k]
# with width _values[k]. Links of the first level come first.
_sources = [_mech_index[mech] for mech in _flow_1['mechanism_group']]
_targets = [_dx_index[dx_name] for dx_name in _flow_1['in_hosp_dx']]
_values = [int(n_patients) for n_patients in _flow_1['count']]

_sources += [_dx_index[dx_name] for dx_name in _flow_2['in_hosp_dx']]
_targets += [_death_index[outcome] for outcome in _flow_2['death']]
_values += [int(n_patients) for n_patients in _flow_2['count']]

def _linspace_positions(n):
    """Return `n` evenly spaced positions spanning the interval [0, 1].

    Parameters
    ----------
    n : int
        Number of positions (nodes in one Sankey column).

    Returns
    -------
    list of float
        ``[]`` when ``n <= 0``, ``[0.5]`` (centred) for a single position,
        otherwise ``i / (n - 1)`` for ``i = 0 .. n - 1``.
    """
    if n <= 0:
        # No nodes means no coordinates; returning a placeholder here would
        # make the coordinate list longer than the list of nodes.
        return []
    if n == 1:
        return [0.5]
    return [i / (n - 1) for i in range(n)]


# Vertical coordinates: every column spreads its own nodes evenly over [0, 1].
_y_mech = _linspace_positions(len(_mech_nodes))
_y_dx = _linspace_positions(len(_dx_nodes))
_y_death = _linspace_positions(len(_death_nodes))

# Horizontal coordinates: one fixed x per column (mechanism, diagnosis, outcome).
_node_x = [0.0] * len(_mech_nodes) + [0.5] * len(_dx_nodes) + [1.0] * len(_death_nodes)
_node_y = [*_y_mech, *_y_dx, *_y_death]
# NOTE(review): some Plotly versions reportedly ignore node coordinates that
# are exactly 0 — verify the rendered layout matches these positions.

_node_spec = {
    'label': _labels,
    'pad': 20,
    'thickness': 15,
    'x': _node_x,
    'y': _node_y,
}
_link_spec = {
    'source': _sources,
    'target': _targets,
    'value': _values,
}

_sankey_trace = go.Sankey(arrangement='snap', node=_node_spec, link=_link_spec)

fig = go.Figure(data=[_sankey_trace])

# Empty title, 12 pt font for all figure text.
fig.update_layout(title_text='', font_size=12)

fig.show()

In [None]:
# Directory that receives the exported figure.
# NOTE(review): machine-specific absolute path — adapt before running elsewhere.
output_dir = '/Users/jk1/Library/CloudStorage/OneDrive-UniversitédeGenève/icu_research/prehospital/pediatric_trauma/figures'
# Static PNG export of the Sankey figure: 1200 x 900 layout rendered at scale factor 4.
fig.write_image(
    f'{output_dir}/sankey_diagram.png',
    scale=4,
    width=1200,
    height=900,
)
# Optional interactive HTML export, currently disabled:
# fig.write_html(f'{output_dir}/sankey_diagram.html', include_plotlyjs='cdn')