In [None]:
import os
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from scipy.stats import chi2_contingency

#### Part I: Create the Table of Descriptive Statistics

In [16]:
def load_data_for_statistics(base_path, site_mapping, site_labels=('Site1', 'Site2', 'Site3', 'Site4')):
    """Load per-site tables and derive the flags used by the descriptive statistics.

    For every site label the function reads three pickles located under
    ``<base_path>/<site_name>/``: ``processed_data/outcome.pkl``, ``AKI_DEMO.pkl``
    and ``processed_data/dx.pkl``. It merges them on ``PATID``/``ENCOUNTERID`` and
    adds the columns ``REVERSAL``, ``PROGRESSION``, ``DECEASE`` and ``PREADM_CKD``.

    Parameters
    ----------
    base_path : str
        Directory containing one sub-directory per site. A trailing path
        separator is optional.
    site_mapping : dict
        Maps a site label (e.g. ``'Site1'``) to the directory name of that
        site. Labels missing from the mapping are used as directory names.
    site_labels : iterable of str, optional
        Labels of the sites to load; defaults to ``Site1`` .. ``Site4``.

    Returns
    -------
    dict
        ``{site_label: DataFrame}`` with one merged frame per site.
    """
    data_dict = {}
    for site_label in site_labels:
        site_name = site_mapping.get(site_label, site_label)
        site_dir = os.path.join(base_path, site_name)
        processed_dir = os.path.join(site_dir, 'processed_data')

        # NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
        outcome = pd.read_pickle(os.path.join(processed_dir, 'outcome.pkl'))
        demo = pd.read_pickle(os.path.join(site_dir, 'AKI_DEMO.pkl'))
        dx = pd.read_pickle(os.path.join(processed_dir, 'dx.pkl'))

        # Collapse DEATH_DATE to a single value per patient (latest non-missing date, else NaT).
        demo_deduplicated = demo[['PATID', 'ENCOUNTERID', 'DEATH_DATE']].drop_duplicates()
        demo_cleaned = (demo_deduplicated
                        .groupby(['PATID'], as_index=False)
                        .agg({'DEATH_DATE': lambda x: x.max() if x.notna().any() else pd.NaT}))
        # Re-attach the per-encounter demographics to the per-patient death date.
        demo_final = demo_cleaned.merge(
            demo[['PATID', 'ENCOUNTERID', 'AGE', 'SEX', 'RACE', 'HISPANIC']].drop_duplicates(),
            on='PATID',
            how='right')

        site_data = outcome.merge(demo_final,
                                  on=['PATID', 'ENCOUNTERID'],
                                  how='left')
        site_data = site_data.merge(dx[['PATID', 'ENCOUNTERID', 'PREADM_CKD_FLAG']],
                                    on=['PATID', 'ENCOUNTERID'],
                                    how='left')

        # Missing RVRT_SINCE_ONSET compares as False, i.e. "no reversal".
        site_data['REVERSAL'] = ((site_data['RVRT_SINCE_ONSET'] + 1) <= 7)

        # Progression: reaching the next stage within 7 days of onset, given the initial stage.
        onset_day = site_data['ONSET_SINCE_ADMIT']
        filter_aki1up = ((site_data['AKI2_SINCE_ADMIT'] - onset_day) <= 7) & (
                    site_data['AKI_INIT_STG'] == 1)
        filter_aki2up = ((site_data['AKI3_SINCE_ADMIT'] - onset_day) <= 7) & (
                    site_data['AKI_INIT_STG'] == 2)
        site_data['PROGRESSION'] = (filter_aki1up | filter_aki2up)

        # Death between 0 and 7 days (inclusive) after onset; missing dates give False.
        days_to_death = (site_data['DEATH_DATE'] - site_data['ONSET_DATE']).dt.days
        site_data['DECEASE'] = days_to_death.between(0, 7)

        # Encounters without a dx row are treated as having no pre-admission CKD.
        site_data['PREADM_CKD'] = site_data['PREADM_CKD_FLAG'].fillna(False)

        data_dict[site_label] = site_data
    return data_dict

In [17]:
# Load data for the summary table
base_path = './'
result_path = os.path.join(base_path, 'result') + '/'
# Create the output directory up front so later CSV/HTML writes do not fail on a fresh checkout.
os.makedirs(result_path, exist_ok=True)

# Maps the generic site labels to the per-site directory names under base_path.
site_mapping = {
    "... MASKED_FOR_ANONYMITY": "... MASKED_FOR_ANONYMITY",
    # ...
}

data_dict = load_data_for_statistics(base_path, site_mapping)

In [6]:
# Raw race codes that are reported under their own label in the summary table.
race_mapping = {'01': 'Native American', '02': 'Asian', '03': 'Black', '05': 'White'}


def categorize_race(race):
    """Translate a raw race code into the label used in the descriptive table.

    Codes present in ``race_mapping`` return their mapped label; missing values
    and the codes 'NI' / 'UN' return 'Unknown_race'; anything else returns 'Other'.
    """
    mapped_label = race_mapping.get(race)
    if mapped_label is not None:
        return mapped_label
    # The missing-value check must come first so NA-like scalars never reach the membership test.
    is_unknown = pd.isna(race) or race in ('NI', 'UN')
    return 'Unknown_race' if is_unknown else 'Other'


def categorize_hispanic(hispanic):
    """Translate a raw Hispanic-ethnicity code into a summary label.

    'Y' returns 'Yes_eth' and 'N' returns 'No_eth'. Missing values and every
    other code (e.g. 'NI', 'UN') return 'Unknown_eth' instead of being counted
    as non-Hispanic, matching how ``categorize_race`` treats 'NI' / 'UN'.
    """
    if pd.isna(hispanic):
        return 'Unknown_eth'
    if hispanic == 'Y':
        return 'Yes_eth'
    if hispanic == 'N':
        return 'No_eth'
    return 'Unknown_eth'


def categorize_sex(sex):
    """Translate a raw sex code into a summary label.

    'F' returns 'Female' and 'M' returns 'Male'. Missing values and every other
    code (e.g. 'NI', 'UN') return 'Unknown' instead of being counted as male.
    """
    if pd.isnull(sex):
        return 'Unknown'
    if sex == 'F':
        return 'Female'
    if sex == 'M':
        return 'Male'
    return 'Unknown'


# Function to calculate counts, proportions, and significance with Chi-Square Test of Independence
def calculate_stats_and_chisq_test(df_dict):
    """Build the cross-site descriptive table and write it to CSV.

    For each summary category the function counts the records per label at every
    site, formats them as ``"count(percent)"`` and runs a chi-square test of
    independence between site and label. Encounter and patient counts per site
    are prepended as the first two rows.

    Parameters
    ----------
    df_dict : dict
        ``{site_label: DataFrame}``. Each frame must provide the columns AGE,
        SEX, RACE, HISPANIC, PREADM_CKD, AKI_INIT_STG, REVERSAL, PROGRESSION,
        DECEASE, PATID and ENCOUNTERID. The frames are not modified.

    Returns
    -------
    pandas.DataFrame
        One row per (category, label) with one column per site and a
        ``p-value`` column. The same table is written to
        ``result_path + 'descriptive_tbl.csv'`` (``result_path`` is a
        module-level variable).
    """
    results = []
    p_values = []

    # Categories with 'bins' are discretised with pd.cut (left-closed intervals);
    # categories with a 'categorizer' map raw codes to labels value by value.
    categories = {
        'Age': {
            'bins': [18, 26, 36, 46, 56, 66, np.inf],
            'labels': ['18-25', '26-35', '36-45', '46-55', '56-65', '≥66'],
            'column': 'AGE'
        },
        'Sex': {
            'column': 'SEX',
            'categorizer': categorize_sex
        },
        'Race': {
            'column': 'RACE',
            'categorizer': categorize_race
        },
        'Hispanic': {
            'column': 'HISPANIC',
            'categorizer': categorize_hispanic
        },
        'Pre-admission CKD': {
            'bins': [-0.5, 0.5, 1.5],
            'labels': ['No_ckd', 'Yes_ckd'],
            'column': 'PREADM_CKD'
        },
        'AKI Stage at Onset': {
            'bins': [0.5, 1.5, 2.5, 3.5],
            'labels': ['AKI1', 'AKI2', 'AKI3'],
            'column': 'AKI_INIT_STG'
        },
        'AKI Reversal': {
            'bins': [-0.5, 0.5, 1.5],
            'labels': ['No_rvsl', 'Yes_rvsl'],
            'column': 'REVERSAL'
        },
        'AKI Progression': {
            'bins': [-0.5, 0.5, 1.5],
            'labels': ['No_stgup', 'Yes_stgup'],
            'column': 'PROGRESSION'
        },
        'Mortality (7 day)': {
            'bins': [-0.5, 0.5, 1.5],
            'labels': ['No_death', 'Yes_death'],
            'column': 'DECEASE'
        }
    }

    site_names = list(df_dict.keys())

    for category, details in categories.items():
        site_counts = []
        for site, df in df_dict.items():
            raw_values = df[details['column']]
            # Work on a derived Series so the caller's DataFrame is left untouched.
            if 'bins' in details:
                labelled = pd.cut(raw_values, bins=details['bins'], labels=details['labels'],
                                  right=False)
            else:
                labelled = raw_values.apply(details['categorizer'])

            counts = labelled.value_counts(sort=False)
            site_counts.append(counts)

            proportions = counts / counts.sum()
            for label, count in counts.items():
                formatted = f"{count:,}({proportions[label]:.1%})"
                results.append([category, label, site, formatted])

        # Rows = sites, columns = labels; a label absent at a site counts as 0.
        observed = pd.concat(site_counts, axis=1, keys=site_names).fillna(0).T
        # Labels with no observations at any site would give a zero expected
        # frequency, which chi2_contingency rejects; exclude them from the test.
        observed = observed.loc[:, (observed != 0).any(axis=0)]
        _, p, _, _ = chi2_contingency(observed)
        p_values.append([category, f"{p:.1e}"])

    results_df = pd.DataFrame(results, columns=['Category', 'Label', 'Site', 'Value'])
    p_values_df = pd.DataFrame(p_values, columns=['Category', 'p-value'])
    pivot_df = results_df.pivot(index=['Category', 'Label'], columns='Site', values='Value').reset_index()
    sorted_df = pivot_df.merge(p_values_df, on='Category', how='left')
    category_order = [
        'Age',
        'Sex',
        'Race',
        'Hispanic',
        'Pre-admission CKD',
        'AKI Stage at Onset',
        'AKI Reversal',
        'AKI Progression',
        'Mortality (7 day)'
    ]

    # 'Unknown' must be listed here; labels missing from this list become NaN below.
    label_order = ['18-25', '26-35', '36-45', '46-55', '56-65', '≥66',
                   'Female', 'Male', 'Unknown',
                   'White', 'Black', 'Asian', 'Native American', 'Other', 'Unknown_race',
                   'Yes_eth', 'No_eth', 'Unknown_eth',
                   'Yes_ckd', 'No_ckd',
                   'AKI1', 'AKI2', 'AKI3',
                   'Yes_rvsl', 'No_rvsl',
                   'Yes_stgup', 'No_stgup',
                   'Yes_death', 'No_death']

    sorted_df['Category'] = pd.Categorical(sorted_df['Category'], categories=category_order, ordered=True)
    sorted_df['Label'] = pd.Categorical(sorted_df['Label'], categories=label_order, ordered=True)

    sorted_df = sorted_df.sort_values(by=['Category', 'Label']).reset_index(drop=True)

    encounter_counts = ['Encounter Count', 'N']
    patient_counts = ['Patient Count', 'N']
    for site, df in df_dict.items():
        encounter_counts.append(f"{df[['PATID', 'ENCOUNTERID']].drop_duplicates().shape[0]:,}")
        patient_counts.append(f"{df.PATID.nunique():,}")

    # Header comes from the df_dict argument, not from a notebook-level global.
    header = ['Category', 'Label'] + site_names
    enc_df = pd.DataFrame([encounter_counts], columns=header)
    pat_df = pd.DataFrame([patient_counts], columns=header)
    df_final = pd.concat([enc_df, pat_df, sorted_df], axis=0)

    df_final.to_csv(result_path + 'descriptive_tbl.csv', index=False, sep=',', encoding='utf-8')
    return df_final

In [7]:
# Build the descriptive-statistics table for all loaded sites; the call also writes the
# table to 'descriptive_tbl.csv' under result_path.
descriptive_stats = calculate_stats_and_chisq_test(data_dict)

#### Part II: Create the Sankey Plot

In [21]:
def plot_sankey(base_path):
    """Draw the day-by-day state-transition Sankey diagram and save it as HTML.

    Reads ``data_sankey.pkl`` from ``base_path``, counts the transitions between
    the state columns ``State_d0`` .. ``State_d7`` and writes the figure to
    ``result_path + 'sankey_plot.html'`` (``result_path`` is a module-level
    variable). Nothing is returned.
    """
    state_columns = ['State_d0', 'State_d1', 'State_d2', 'State_d3',
                     'State_d4', 'State_d5', 'State_d6', 'State_d7']
    label_single = ['No AKI', 'AKI1', 'AKI2', 'AKI3', 'Discharged', 'Dead']
    color = ['#1f77b4',  # muted blue
             '#ff7f0e',  # safety orange
             '#2ca02c',  # green
             '#e377c2',  # pinkish
             '#d4af37',  # golden yellow
             '#7f7f7f'  # grey
             ]
    num_days = 8
    states_per_day = len(label_single)

    # NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
    sankey_df = pd.read_pickle(base_path + 'data_sankey' + '.pkl')[state_columns]

    # One block of links per pair of consecutive days. A node index is
    # day * states_per_day + state code, so each day gets its own set of nodes.
    per_day_links = []
    column_pairs = zip(sankey_df.columns[:-1], sankey_df.columns[1:])
    for day, (from_col, to_col) in enumerate(column_pairs):
        transitions = sankey_df.groupby([from_col, to_col]).size().reset_index(name='count')
        transitions.columns = ['source', 'target', 'value']
        transitions = transitions.assign(
            source=transitions['source'] + day * states_per_day,
            target=transitions['target'] + (day + 1) * states_per_day,
        )
        per_day_links.append(transitions)

    all_links = pd.concat(per_day_links)
    source = all_links['source'].tolist()
    target = all_links['target'].tolist()
    value = all_links['value'].tolist()

    colors = color * num_days

    # Node label = state name plus the share of all rows in that state on that day.
    total_rows = len(sankey_df)
    labels_with_percentages = []
    for day in range(num_days):
        day_states = sankey_df['State_d' + str(day)]
        for status, state_name in enumerate(label_single):
            pct = (day_states == status).sum() / total_rows * 100
            labels_with_percentages.append(f"{state_name} ({pct:.1f}%)")

    # NOTE(review): 8 x positions are passed for 48 nodes and no y positions are given;
    # confirm that Plotly applies these positions as intended.
    custom_x_positions = [0.05, 0.15, 0.30, 0.45, 0.60, 0.620, 0.72, 1.0]

    node_spec = dict(
        pad=30,
        thickness=20,
        line=dict(color="black", width=0.5),
        label=labels_with_percentages,
        color=colors,
        x=custom_x_positions
    )
    link_spec = dict(
        source=source,
        target=target,
        value=value,
        color='rgba(150,150,150,0.2)',
        hoverinfo='skip'
    )
    fig = go.Figure(go.Sankey(node=node_spec, link=link_spec))

    # Column headers above the diagram: "Onset Day", then "Day 1" .. "Day 7".
    annotations = []
    for day in range(num_days):
        header_text = 'Onset Day' if day == 0 else f"Day {day}"
        annotations.append(dict(
            x=1.007 * (day / (num_days - 1)),
            y=1.065,
            xref="paper",
            yref="paper",
            text=header_text,
            showarrow=False,
            font=dict(size=18)
        ))

    fig.update_layout(
        font_size=12.9,
        width=1560,
        height=890,
        annotations=annotations
    )

    fig.write_html(result_path + 'sankey_plot.html')

In [None]:
# Generate the Sankey plot: plot_sankey reads 'data_sankey.pkl' from base_path and
# writes 'sankey_plot.html' under result_path.
plot_sankey(base_path)