In [34]:
#!/usr/bin/env python3
"""
Bridge Edge Analysis: High Betweenness + Large Mortality Difference
Identifies critical edges connecting diseases with very different mortality rates
"""

import pandas as pd
import numpy as np
import networkx as nx
from pathlib import Path


# Directories (relative paths, so they resolve against the current working directory)
DATA_DIR = Path('Data/')  # input CSVs are read from here (adjacency, ICD codes, mortality, English descriptions)
OUTPUT_DIR = Path('outputs/')  # generated .tex and .csv files are written here

def load_network_edges_with_mortality(gender, age_group):
    """Build one row per network edge with its betweenness and endpoint mortality.

    Reads the adjacency matrix for the given ``gender``/``age_group`` from
    ``DATA_DIR``, builds a graph from it, and computes unweighted edge
    betweenness centrality. Each edge is then annotated with the ICD code,
    description and mortality of both endpoints.

    Parameters:
    - gender: used in the adjacency file name; ``'Female'`` selects the female
      mortality file, any other value selects the male one
    - age_group: value matched against the ``age_10`` column of the mortality
      file and used in the adjacency file name

    Returns a DataFrame with columns Sex, Age_Group, ICD_Code_1/2,
    Description_1/2, Weight, Edge_Betweenness, Mortality_1/2, Mortality_Diff.
    Edges whose endpoints have no ICD code are skipped; a code without a
    mortality entry for this age group is given mortality 0.
    """
    # Adjacency matrix -> graph -> unweighted edge betweenness
    adjacency = pd.read_csv(
        DATA_DIR / f'Adj_Matrix_{gender}_ICD_age_{age_group}.csv',
        sep=' ',
        header=None,
    ).values
    graph = nx.from_numpy_array(adjacency)
    betweenness_by_edge = nx.edge_betweenness_centrality(graph, weight=None)

    # Node index -> ICD code / description (diagnose_id is 1-based, nodes are 0-based)
    diagnoses = pd.read_csv(DATA_DIR / 'ICD10_Diagnoses_All.csv')
    node_ids = diagnoses['diagnose_id'] - 1
    code_by_node = dict(zip(node_ids, diagnoses['icd_code']))
    descr_by_node = dict(zip(node_ids, diagnoses['descr']))

    # ICD code -> mortality, restricted to the requested age group
    mortality_file = (
        'mortality_diag_Female.csv' if gender == 'Female' else 'mortality_diag_Male.csv'
    )
    mortality = pd.read_csv(DATA_DIR / mortality_file)
    in_age_group = mortality[mortality['age_10'] == age_group]
    mortality_by_code = dict(zip(in_age_group['icd_code'], in_age_group['mortality']))

    rows = []
    for edge in graph.edges():
        u, v = edge
        code_u = code_by_node.get(u)
        code_v = code_by_node.get(v)

        # Nodes without a known ICD code cannot be reported
        if code_u is None or code_v is None:
            continue

        mort_u = mortality_by_code.get(code_u, 0)
        mort_v = mortality_by_code.get(code_v, 0)

        rows.append({
            'Sex': gender,
            'Age_Group': age_group,
            'ICD_Code_1': code_u,
            'ICD_Code_2': code_v,
            'Description_1': descr_by_node.get(u, ''),
            'Description_2': descr_by_node.get(v, ''),
            'Weight': adjacency[u, v],
            'Edge_Betweenness': betweenness_by_edge.get(edge, 0),
            'Mortality_1': mort_u,
            'Mortality_2': mort_v,
            'Mortality_Diff': abs(mort_u - mort_v),
        })

    return pd.DataFrame(rows)

def identify_critical_bridge_edges_zscore(df_all, top_percent=40, min_mort_diff=0.10):
    """
    Identify bridge edges using Z-score product method (as in manuscript)

    Within every (Sex, Age_Group) stratum the edge betweenness and the
    mortality difference are standardised, multiplied, and an edge is kept
    when ALL of the following hold:
    - both z-scores are strictly positive,
    - the z-score product is at or above the (100 - top_percent)th percentile
      of the products of all edges in the stratum,
    - Mortality_Diff >= min_mort_diff.

    Parameters:
    - df_all: edge table with at least the columns Sex, Age_Group,
      Edge_Betweenness and Mortality_Diff
    - top_percent: Top percentage of z-score products to select (default 40)
    - min_mort_diff: Minimum absolute mortality difference (default 0.10 = 10%)

    Returns a DataFrame with the input columns plus z_betweenness,
    z_mort_diff, z_product and Selection_Method. When no edge qualifies the
    result is empty but still carries those columns, so downstream column
    lookups keep working.
    """

    print("\nZ-SCORE METHOD (Manuscript approach):")
    print("  - Computing z(betweenness) × z(mortality_diff)")
    print(f"  - Selecting top {top_percent}% of z-score products")
    print(f"  - Minimum absolute mortality difference > {min_mort_diff*100:.0f}%")

    all_bridges = []

    # Process each sex-age group separately
    for sex in ['Female', 'Male']:
        for age_group in sorted(df_all['Age_Group'].unique()):
            subset = df_all[
                (df_all['Sex'] == sex) &
                (df_all['Age_Group'] == age_group)
            ].copy()

            if subset.empty:
                continue

            # Z-scores for betweenness; a constant (or single-row) stratum has
            # no usable spread, so every edge gets z = 0 there
            bet_mean = subset['Edge_Betweenness'].mean()
            bet_std = subset['Edge_Betweenness'].std()
            if bet_std > 0:
                subset['z_betweenness'] = (subset['Edge_Betweenness'] - bet_mean) / bet_std
            else:
                subset['z_betweenness'] = 0

            # Z-scores for mortality difference (same fallback as above)
            mort_mean = subset['Mortality_Diff'].mean()
            mort_std = subset['Mortality_Diff'].std()
            if mort_std > 0:
                subset['z_mort_diff'] = (subset['Mortality_Diff'] - mort_mean) / mort_std
            else:
                subset['z_mort_diff'] = 0

            # The product is computed for every edge. Two negative z-scores
            # also give a positive product, which is why the sign of each
            # z-score is checked explicitly in the filter below.
            subset['z_product'] = subset['z_betweenness'] * subset['z_mort_diff']

            # Percentile cut-off taken over all products in this stratum
            threshold_percentile = 100 - top_percent
            z_threshold = subset['z_product'].quantile(threshold_percentile / 100)

            bridge_edges = subset[
                (subset['z_betweenness'] > 0) &
                (subset['z_mort_diff'] > 0) &
                (subset['z_product'] >= z_threshold) &
                (subset['Mortality_Diff'] >= min_mort_diff)
            ].copy()

            if not bridge_edges.empty:
                bridge_edges['Selection_Method'] = 'Z-Score Product'
                all_bridges.append(bridge_edges)

    if not all_bridges:
        # Keep the schema: a column-less empty frame makes later lookups such
        # as df['ICD_Code_1'] raise KeyError.
        added = ['z_betweenness', 'z_mort_diff', 'z_product', 'Selection_Method']
        columns = list(df_all.columns) + [c for c in added if c not in df_all.columns]
        return df_all.iloc[0:0].reindex(columns=columns)

    return pd.concat(all_bridges, ignore_index=True)

def identify_critical_bridge_edges(df_all, bet_percentile=95, mort_diff_percentile=95, min_mort_diff=0.10):
    """
    Identify edges with high betweenness AND large mortality difference

    Thresholds are computed separately for every (Sex, Age_Group) stratum and
    all comparisons are inclusive (>=).

    Parameters:
    - df_all: edge table with at least the columns Sex, Age_Group,
      Edge_Betweenness and Mortality_Diff
    - bet_percentile: Percentile threshold for betweenness (default 95)
    - mort_diff_percentile: Percentile threshold for mortality difference (default 95)
    - min_mort_diff: Minimum absolute mortality difference (default 0.10 = 10%)

    Returns a DataFrame with the input columns plus Betweenness_Percentile
    and Mort_Diff_Percentile (the thresholds that were applied). When no edge
    qualifies the result is empty but still carries those columns, so
    downstream column lookups keep working.
    """

    print("\nPERCENTILE METHOD:")
    print(f"  - Betweenness > {bet_percentile}th percentile")
    print(f"  - Mortality difference > {mort_diff_percentile}th percentile")
    print(f"  - Minimum absolute mortality difference > {min_mort_diff*100:.0f}%")

    all_bridges = []

    # Process each sex-age group separately
    for sex in ['Female', 'Male']:
        for age_group in sorted(df_all['Age_Group'].unique()):
            subset = df_all[
                (df_all['Sex'] == sex) &
                (df_all['Age_Group'] == age_group)
            ].copy()

            if subset.empty:
                continue

            # Stratum-specific thresholds
            bet_threshold = subset['Edge_Betweenness'].quantile(bet_percentile / 100)
            mort_diff_threshold = subset['Mortality_Diff'].quantile(mort_diff_percentile / 100)

            # Identify bridge edges with:
            # 1. High betweenness (percentile threshold)
            # 2. Large mortality difference (percentile threshold)
            # 3. Minimum absolute mortality difference (clinical significance)
            bridge_edges = subset[
                (subset['Edge_Betweenness'] >= bet_threshold) &
                (subset['Mortality_Diff'] >= mort_diff_threshold) &
                (subset['Mortality_Diff'] >= min_mort_diff)
            ].copy()

            if not bridge_edges.empty:
                bridge_edges['Betweenness_Percentile'] = bet_percentile
                bridge_edges['Mort_Diff_Percentile'] = mort_diff_percentile
                all_bridges.append(bridge_edges)

    if not all_bridges:
        # Keep the schema: a column-less empty frame makes later lookups such
        # as df['ICD_Code_1'] raise KeyError.
        added = ['Betweenness_Percentile', 'Mort_Diff_Percentile']
        columns = list(df_all.columns) + [c for c in added if c not in df_all.columns]
        return df_all.iloc[0:0].reindex(columns=columns)

    return pd.concat(all_bridges, ignore_index=True)

def add_english_descriptions(df):
    """Attach English short descriptions for both ICD codes of every edge.

    Looks up ``ICD_Code_1`` / ``ICD_Code_2`` in ``DiagAll_Eng__2_.csv``
    (columns ``Code`` and ``ShortDescription``) and stores the result in
    ``Description_Eng_1`` / ``Description_Eng_2``. Codes without an English
    entry fall back to the existing ``Description_1`` / ``Description_2``
    value (the original-language text; German according to the original
    comment).

    The columns are added to ``df`` in place and ``df`` is returned.
    """
    translations = pd.read_csv(DATA_DIR / 'DiagAll_Eng__2_.csv')
    english_by_code = dict(zip(translations['Code'], translations['ShortDescription']))

    # English column -> (ICD code column, fallback description column)
    targets = {
        'Description_Eng_1': ('ICD_Code_1', 'Description_1'),
        'Description_Eng_2': ('ICD_Code_2', 'Description_2'),
    }

    # Pass 1: look up the English text for each code
    for english_col, (code_col, _) in targets.items():
        df[english_col] = df[code_col].map(english_by_code)

    # Pass 2: fall back to the existing description where no English text exists
    for english_col, (_, fallback_col) in targets.items():
        df[english_col] = df[english_col].fillna(df[fallback_col])

    return df

def generate_latex_table(df):
    """Generate LaTeX table with lower mortality → higher mortality ordering

    Rows are sorted by Sex, age group and descending Edge_Betweenness. Within
    each row the endpoint with the lower mortality is listed first (on a tie
    the second endpoint is listed first). A ``\\midrule`` is emitted between
    consecutive rows that belong to different (Sex, Age_Group) groups.

    Parameters:
    - df: bridge-edge table with the columns Sex, Age_Group, ICD_Code_1,
      ICD_Code_2, Edge_Betweenness, Mortality_1 and Mortality_2. It is not
      modified.

    Returns the complete ``longtable`` environment as a string. An empty
    ``df`` yields a table with header and footer but no data rows.
    """

    header = """\\begin{longtable}{llllrrr}
\\caption{Critical Bridge Edges: Low to High Mortality Connections} \\label{tab:bridge_edges} \\\\
\\toprule
Sex & Age & ICD (Low) & ICD (High) & Betweenness & Mort. Low & Mort. High \\\\
\\midrule
\\endfirsthead

\\multicolumn{7}{c}{\\tablename\\ \\thetable\\ -- Continued from previous page} \\\\
\\toprule
Sex & Age & ICD (Low) & ICD (High) & Betweenness & Mort. Low & Mort. High \\\\
\\midrule
\\endhead

\\midrule
\\multicolumn{7}{r}{Continued on next page} \\\\
\\endfoot

\\bottomrule
\\endlastfoot

"""
    footer = """\\end{longtable}
"""

    # Nothing to tabulate (also covers a column-less empty frame)
    if df.empty:
        return header + footer

    age_order = {1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8}
    age_map = {1: '0-9', 2: '10-19', 3: '20-29', 4: '30-39',
               5: '40-49', 6: '50-59', 7: '60-69', 8: '70-79'}

    # assign() returns a new frame, so the helper columns never leak into the
    # caller's DataFrame (which is later written to CSV).
    table = df.assign(age_num=df['Age_Group'].map(age_order)).sort_values(
        ['Sex', 'age_num', 'Edge_Betweenness'], ascending=[True, True, False]
    )
    table['Age_Range'] = table['Age_Group'].map(age_map)

    # Reorder endpoints: lower mortality first
    first_is_lower = table['Mortality_1'] < table['Mortality_2']
    table['ICD_Low'] = np.where(first_is_lower, table['ICD_Code_1'], table['ICD_Code_2'])
    table['ICD_High'] = np.where(first_is_lower, table['ICD_Code_2'], table['ICD_Code_1'])
    table['Mort_Low'] = np.where(first_is_lower, table['Mortality_1'], table['Mortality_2'])
    table['Mort_High'] = np.where(first_is_lower, table['Mortality_2'], table['Mortality_1'])

    lines = []
    previous_group = None
    # Iterate in sorted (positional) order. The index labels are no longer
    # positions after sort_values, so group changes are detected by comparing
    # with the previously emitted row rather than by label arithmetic.
    for sex, age_group, age_range, icd_low, icd_high, betweenness, mort_low, mort_high in zip(
        table['Sex'], table['Age_Group'], table['Age_Range'],
        table['ICD_Low'], table['ICD_High'],
        table['Edge_Betweenness'], table['Mort_Low'], table['Mort_High'],
    ):
        group = (sex, age_group)
        if previous_group is not None and group != previous_group:
            lines.append("\\midrule\n")
        lines.append(
            f"{sex} & {age_range} & {icd_low} & {icd_high} & "
            f"{betweenness:.5f} & {mort_low:.4f} & {mort_high:.4f} \\\\\n"
        )
        previous_group = group

    return header + "".join(lines) + footer

def print_summary(df):
    """Print per-sex and per-age-group counts of bridge edges.

    For every age group present, the (up to) three edges with the largest
    ``Mortality_Diff`` are listed, written as lower-mortality code →
    higher-mortality code. Output goes to stdout; nothing is returned.
    """
    rule = "=" * 80
    age_labels = {1: '0-9', 2: '10-19', 3: '20-29', 4: '30-39',
                  5: '40-49', 6: '50-59', 7: '60-69', 8: '70-79'}

    print("\n" + rule)
    print("SUMMARY: CRITICAL BRIDGE EDGES")
    print(rule)

    print(f"\nTotal bridge edges identified: {len(df)}")

    for sex in ('Female', 'Male'):
        per_sex = df[df['Sex'] == sex]
        print(f"\n{sex}: {len(per_sex)} edges")

        for age in sorted(per_sex['Age_Group'].unique()):
            per_age = per_sex[per_sex['Age_Group'] == age]
            print(f"  {age_labels[age]}: {len(per_age)} edges")

            if len(per_age) == 0:
                continue

            # Show the three largest mortality gaps
            for _, edge in per_age.nlargest(3, 'Mortality_Diff').iterrows():
                low = (edge['ICD_Code_1'], edge['Mortality_1'])
                high = (edge['ICD_Code_2'], edge['Mortality_2'])
                # Endpoint 1 is listed first only when its mortality is strictly lower
                if not (edge['Mortality_1'] < edge['Mortality_2']):
                    low, high = high, low

                low_icd, low_mort = low
                high_icd, high_mort = high
                print(f"    - {low_icd} → {high_icd}: "
                      f"Bet={edge['Edge_Betweenness']:.5f}, "
                      f"Mort: {low_mort:.4f} → {high_mort:.4f} "
                      f"(Δ{edge['Mortality_Diff']:.4f})")

def main():
    """Run both bridge-edge selection methods and write their outputs.

    Steps:
    1. Load the edge table for every sex (Female, Male) and age group (1-8).
    2. Select bridge edges with the percentile method and with the z-score
       product method.
    3. Add English descriptions, render one LaTeX table per method and save
       a .tex and a .csv file per method into ``OUTPUT_DIR``.
    4. Print a summary per method and the overlap between the two selections.
    """

    # Selection parameters live in one place so that the printed labels are
    # derived from the values actually used and cannot drift apart.
    percentile = 95
    percentile_min_mort_diff = 0.10
    zscore_top_percent = 5  # Changed from 40% to 5%
    zscore_min_mort_diff = 0.30

    rule = "=" * 80

    def section(title):
        """Print a section banner preceded by a blank line."""
        print("\n" + rule)
        print(title)
        print(rule)

    print(rule)
    print("CRITICAL BRIDGE EDGE ANALYSIS")
    print(rule)

    # Load data for all sex-age groups
    print("\nLoading edge data with mortality...")
    all_data = []

    for gender in ['Female', 'Male']:
        for age_group in range(1, 9):
            print(f"  Processing {gender} age {age_group}...")
            all_data.append(load_network_edges_with_mortality(gender, age_group))

    df_all = pd.concat(all_data, ignore_index=True)
    print(f"\nTotal edges analyzed: {len(df_all)}")

    # Method 1: Percentile method
    section("METHOD 1: PERCENTILE THRESHOLDS")
    df_bridges_percentile = identify_critical_bridge_edges(
        df_all,
        bet_percentile=percentile,
        mort_diff_percentile=percentile,
        min_mort_diff=percentile_min_mort_diff
    )
    print(f"Critical bridge edges identified: {len(df_bridges_percentile)}")

    # Method 2: Z-score product method (as in manuscript, with a stricter cut)
    section(f"METHOD 2: Z-SCORE PRODUCT (TOP {zscore_top_percent}%)")
    df_bridges_zscore = identify_critical_bridge_edges_zscore(
        df_all,
        top_percent=zscore_top_percent,
        min_mort_diff=zscore_min_mort_diff
    )
    print(f"Critical bridge edges identified: {len(df_bridges_zscore)}")

    # Add English descriptions to both
    print("\nAdding English descriptions...")
    df_bridges_percentile = add_english_descriptions(df_bridges_percentile)
    df_bridges_zscore = add_english_descriptions(df_bridges_zscore)

    # Generate LaTeX tables for both methods
    print("\nGenerating LaTeX tables...")

    # Method 1 table
    latex_percentile = generate_latex_table(df_bridges_percentile)

    # Method 2 table: same layout, distinct caption and label
    latex_zscore = generate_latex_table(df_bridges_zscore)
    latex_zscore = latex_zscore.replace(
        "Critical Bridge Edges: Low to High Mortality Connections",
        "Critical Bridge Edges (Z-Score Method): Low to High Mortality Connections"
    )
    latex_zscore = latex_zscore.replace(
        "\\label{tab:bridge_edges}",
        "\\label{tab:bridge_edges_zscore}"
    )

    # Save outputs
    print("\nSaving outputs...")

    # The output directory may not exist on a fresh checkout
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # Save Method 1 (Percentile)
    tex_file_1 = OUTPUT_DIR / 'bridge_edges_mortality_PERCENTILE.tex'
    with open(tex_file_1, 'w', encoding='utf-8') as f:
        f.write(latex_percentile)
    print(f"✓ Method 1 (Percentile) LaTeX saved to: {tex_file_1}")

    csv_file_1 = OUTPUT_DIR / 'bridge_edges_mortality_PERCENTILE.csv'
    df_bridges_percentile.to_csv(csv_file_1, index=False)
    print(f"✓ Method 1 (Percentile) CSV saved to: {csv_file_1}")

    # Save Method 2 (Z-Score)
    tex_file_2 = OUTPUT_DIR / 'bridge_edges_mortality_ZSCORE.tex'
    with open(tex_file_2, 'w', encoding='utf-8') as f:
        f.write(latex_zscore)
    print(f"✓ Method 2 (Z-Score) LaTeX saved to: {tex_file_2}")

    csv_file_2 = OUTPUT_DIR / 'bridge_edges_mortality_ZSCORE.csv'
    df_bridges_zscore.to_csv(csv_file_2, index=False)
    print(f"✓ Method 2 (Z-Score) CSV saved to: {csv_file_2}")

    # Print summaries for both
    section("METHOD 1 SUMMARY (PERCENTILE)")
    print_summary(df_bridges_percentile)

    section("METHOD 2 SUMMARY (Z-SCORE - MANUSCRIPT METHOD)")
    print_summary(df_bridges_zscore)

    # Comparison
    section("COMPARISON OF METHODS")
    print(f"Method 1 (Percentile {percentile}th): {len(df_bridges_percentile)} edges")
    # Label derived from the parameter actually used (was hard-coded "40%")
    print(f"Method 2 (Z-Score top {zscore_top_percent}%): {len(df_bridges_zscore)} edges")

    # Find overlap
    if len(df_bridges_percentile) > 0 and len(df_bridges_zscore) > 0:
        # An edge is identified by its stratum and both ICD codes; tuples are
        # used so that the result frames are not modified.
        key_columns = ['Sex', 'Age_Group', 'ICD_Code_1', 'ICD_Code_2']
        percentile_ids = set(
            df_bridges_percentile[key_columns].itertuples(index=False, name=None)
        )
        zscore_ids = set(
            df_bridges_zscore[key_columns].itertuples(index=False, name=None)
        )

        overlap = percentile_ids & zscore_ids
        print(f"Overlapping edges: {len(overlap)}")
        print(f"Overlap percentage: {len(overlap)/len(df_bridges_percentile)*100:.1f}% of Method 1")
        print(f"Overlap percentage: {len(overlap)/len(df_bridges_zscore)*100:.1f}% of Method 2")

    section("✓ ANALYSIS COMPLETE")

# Script entry point: run the analysis only when executed directly, not on import.
if __name__ == '__main__':
    main()

CRITICAL BRIDGE EDGE ANALYSIS

Loading edge data with mortality...
  Processing Female age 1...
  Processing Female age 2...
  Processing Female age 3...
  Processing Female age 4...
  Processing Female age 5...
  Processing Female age 6...
  Processing Female age 7...
  Processing Female age 8...
  Processing Male age 1...
  Processing Male age 2...
  Processing Male age 3...
  Processing Male age 4...
  Processing Male age 5...
  Processing Male age 6...
  Processing Male age 7...
  Processing Male age 8...

Total edges analyzed: 23886

METHOD 1: PERCENTILE THRESHOLDS

PERCENTILE METHOD:
  - Betweenness > 95th percentile
  - Mortality difference > 95th percentile
  - Minimum absolute mortality difference > 10%
Critical bridge edges identified: 61

METHOD 2: Z-SCORE PRODUCT (TOP 5%)

Z-SCORE METHOD (Manuscript approach):
  - Computing z(betweenness) × z(mortality_diff)
  - Selecting top 5% of z-score products
  - Minimum absolute mortality difference > 30%
Critical bridge edges identi