In [27]:
#!/usr/bin/env python3
"""
COMPLETE THRESHOLD ROBUSTNESS ANALYSIS
======================================

This script performs a complete threshold robustness test starting from
adjacency matrices and produces all results and visualizations.

INPUT:
  - Adjacency matrices (Adj_Matrix_*.csv) containing odds ratios
  - ICD code mappings (ICD10_Diagnoses_All.csv)
  - Prevalence data (Prevalence_Sex_Age_Year_ICD.csv)
  - Mortality data (mortality_diag_Female.csv, mortality_diag_Male.csv)

OUTPUT:
  - CSV files with results for each threshold
  - Summary statistics and overlap analysis
  - Publication-ready visualizations

USAGE:
  python threshold_robustness_complete.py
"""

import pandas as pd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from typing import Dict, List, Set
import warnings
warnings.filterwarnings('ignore')

# ============================================================================
# CONFIGURATION
# ============================================================================

DATA_DIR = Path('Data')
OUTPUT_DIR = Path('outputs')
FIG_DIR = OUTPUT_DIR / 'threshold_robustness_figures'
FIG_DIR.mkdir(exist_ok=True)

# Age group mapping
AGE_MAP = {1: '0-9', 2: '10-19', 3: '20-29', 4: '30-39',
           5: '40-49', 6: '50-59', 7: '60-69', 8: '70-79'}

# OR thresholds to test
THRESHOLDS = {
    'or_1.5': {'name': 'OR > 1.5 (Baseline)', 'value': 1.5},
    'or_2.0': {'name': 'OR > 2.0 (Strict)', 'value': 2.0}
}

print("=" * 80)
print("COMPLETE THRESHOLD ROBUSTNESS ANALYSIS")
print("=" * 80)
print()
print("Starting from adjacency matrices...")
print(f"Thresholds: {', '.join([t['name'] for t in THRESHOLDS.values()])}")
print()

# ============================================================================
# STEP 1: LOAD DATA
# ============================================================================

print("Step 1: Loading data...")

def load_prevalence_data():
    """Load ICD codes and prevalence data"""
    icd_df = pd.read_csv(DATA_DIR / 'ICD10_Diagnoses_All.csv')
    prev_df = pd.read_csv(DATA_DIR / 'Prevalence_Sex_Age_Year_ICD.csv')
    return icd_df, prev_df

def load_mortality_data():
    """Load mortality data from separate male/female files"""
    dfs = []
    age_mapping = {1: '0-9', 2: '10-19', 3: '20-29', 4: '30-39',
                   5: '40-49', 6: '50-59', 7: '60-69', 8: '70-79'}
    
    for gender, filepath in [('Female', 'mortality_diag_Female.csv'), 
                             ('Male', 'mortality_diag_Male.csv')]:
        file = DATA_DIR / filepath
        if file.exists():
            df = pd.read_csv(file)
            df['sex'] = gender
            if 'age_10' in df.columns:
                df['Age_Group'] = df['age_10'].map(age_mapping)
            dfs.append(df)
    
    if dfs:
        mort_df = pd.concat(dfs, ignore_index=True)
        if 'mortality' in mort_df.columns:
            mort_df['mortality'] = pd.to_numeric(mort_df['mortality'], errors='coerce').fillna(0)
        return mort_df
    return None

icd_df, prev_df = load_prevalence_data()
mort_df = load_mortality_data()

print(f"  ✓ Loaded {len(icd_df)} ICD codes")
print(f"  ✓ Loaded prevalence data")
print(f"  ✓ Loaded mortality data")
print()

# Check available age groups
available_age_groups = {}
for gender in ['Female', 'Male']:
    available_age_groups[gender] = []
    for age_group in range(1, 9):
        adj_path = DATA_DIR / f'Adj_Matrix_{gender}_ICD_age_{age_group}.csv'
        if adj_path.exists():
            available_age_groups[gender].append(age_group)

print("Available data:")
for gender in ['Female', 'Male']:
    print(f"  {gender}: Age groups {available_age_groups[gender]}")
print()

# ============================================================================
# STEP 2: DEFINE ANALYSIS FUNCTIONS
# ============================================================================

def load_adjacency_matrix(gender: str, age_group: int, or_threshold: float) -> np.ndarray:
    """Load adjacency matrix and apply OR threshold"""
    adj_path = DATA_DIR / f'Adj_Matrix_{gender}_ICD_age_{age_group}.csv'
    A = pd.read_csv(adj_path, sep=' ', header=None).values
    # Apply threshold: keep only edges with OR >= threshold
    A_thresholded = (A >= or_threshold).astype(float)
    return A_thresholded

def compute_degree_outliers(gender: str, age_group: int, icd_df: pd.DataFrame, 
                            prev_dict: dict, or_threshold: float) -> pd.DataFrame:
    """Compute high-degree outliers (80th percentile of log(degree/prevalence))"""
    A = load_adjacency_matrix(gender, age_group, or_threshold)
    G = nx.from_numpy_array(A)
    
    nodes = []
    for node in range(len(A)):
        degree = G.degree(node)
        if degree > 0:
            icd_row = icd_df[icd_df['diagnose_id'] == node + 1]
            if len(icd_row) > 0:
                icd_code = icd_row.iloc[0]['icd_code']
                prev = prev_dict.get(icd_code, 0)
                if prev > 0:
                    ratio = degree / prev
                    log_ratio = np.log10(ratio)
                    nodes.append({
                        'node': node, 'icd_code': icd_code,
                        'degree': degree, 'prevalence': prev,
                        'log_ratio': log_ratio
                    })
    
    df_nodes = pd.DataFrame(nodes)
    if len(df_nodes) > 0:
        upper_bound = df_nodes['log_ratio'].quantile(0.80)
        outliers = df_nodes[df_nodes['log_ratio'] >= upper_bound].copy()
        outliers['Sex'] = gender
        outliers['Age_Group'] = age_group
        outliers['OR_Threshold'] = or_threshold
        return outliers
    return pd.DataFrame()

def compute_high_mortality_sinks(gender: str, age_group: int, icd_df: pd.DataFrame,
                                mort_dict: dict, or_threshold: float) -> pd.DataFrame:
    """Compute high-mortality sinks (top 20% Z-score product)"""
    A = load_adjacency_matrix(gender, age_group, or_threshold)
    G = nx.from_numpy_array(A)
    betweenness = nx.betweenness_centrality(G)
    
    nodes = []
    for node in range(len(A)):
        bet = betweenness.get(node, 0)
        if bet > 0:
            icd_row = icd_df[icd_df['diagnose_id'] == node + 1]
            if len(icd_row) > 0:
                icd_code = icd_row.iloc[0]['icd_code']
                mort = mort_dict.get(icd_code, 0)
                nodes.append({
                    'node': node, 'icd_code': icd_code,
                    'betweenness': bet, 'mortality': mort
                })
    
    df_nodes = pd.DataFrame(nodes)
    if len(df_nodes) > 0:
        mean_bet, std_bet = df_nodes['betweenness'].mean(), df_nodes['betweenness'].std()
        mean_mort, std_mort = df_nodes['mortality'].mean(), df_nodes['mortality'].std()
        
        if std_bet > 0 and std_mort > 0:
            df_nodes['z_betweenness'] = (df_nodes['betweenness'] - mean_bet) / std_bet
            df_nodes['z_mortality'] = (df_nodes['mortality'] - mean_mort) / std_mort
            df_nodes['z_product'] = df_nodes['z_betweenness'] * df_nodes['z_mortality']
            
            threshold = df_nodes['z_product'].quantile(0.80)
            sinks = df_nodes[df_nodes['z_product'] >= threshold].copy()
            sinks['Sex'] = gender
            sinks['Age_Group'] = age_group
            sinks['OR_Threshold'] = or_threshold
            return sinks
    return pd.DataFrame()

def compute_high_mortality_bridges(gender: str, age_group: int, icd_df: pd.DataFrame,
                                  mort_dict: dict, or_threshold: float) -> pd.DataFrame:
    """Compute high-mortality bridges (top 5% Z-score, 30% mortality diff)"""
    A = load_adjacency_matrix(gender, age_group, or_threshold)
    G = nx.from_numpy_array(A)
    edge_betweenness = nx.edge_betweenness_centrality(G)
    
    edges = []
    for (node1, node2), bet in edge_betweenness.items():
        if bet > 0:
            icd_row1 = icd_df[icd_df['diagnose_id'] == node1 + 1]
            icd_row2 = icd_df[icd_df['diagnose_id'] == node2 + 1]
            
            if len(icd_row1) > 0 and len(icd_row2) > 0:
                icd1 = icd_row1.iloc[0]['icd_code']
                icd2 = icd_row2.iloc[0]['icd_code']
                mort1 = mort_dict.get(icd1, 0)
                mort2 = mort_dict.get(icd2, 0)
                mort_diff = abs(mort1 - mort2)
                
                edges.append({
                    'node1': node1, 'node2': node2,
                    'icd1': icd1, 'icd2': icd2,
                    'betweenness': bet,
                    'mortality1': mort1, 'mortality2': mort2,
                    'mort_diff': mort_diff
                })
    
    df_edges = pd.DataFrame(edges)
    if len(df_edges) > 0:
        # Filter by 10% mortality difference (as in original Z-score method)
        df_edges = df_edges[df_edges['mort_diff'] >= 0.10].copy()
        
        if len(df_edges) > 0:
            mean_bet, std_bet = df_edges['betweenness'].mean(), df_edges['betweenness'].std()
            mean_diff, std_diff = df_edges['mort_diff'].mean(), df_edges['mort_diff'].std()
            
            if std_bet > 0 and std_diff > 0:
                df_edges['z_betweenness'] = (df_edges['betweenness'] - mean_bet) / std_bet
                df_edges['z_mort_diff'] = (df_edges['mort_diff'] - mean_diff) / std_diff
                df_edges['z_product'] = df_edges['z_betweenness'] * df_edges['z_mort_diff']
                
                threshold = df_edges['z_product'].quantile(0.95)
                bridges = df_edges[df_edges['z_product'] >= threshold].copy()
                bridges['Sex'] = gender
                bridges['Age_Group'] = age_group
                bridges['OR_Threshold'] = or_threshold
                return bridges
    return pd.DataFrame()

def create_unique_identifier(row, is_edge=False):
    """Create unique identifier for node or edge"""
    if is_edge:
        return f"{row['Sex']}_{row['Age_Group']}_{row['icd1']}_{row['icd2']}"
    else:
        return f"{row['Sex']}_{row['Age_Group']}_{row['icd_code']}"

# ============================================================================
# STEP 3: RUN ANALYSIS FOR ALL THRESHOLDS
# ============================================================================

print("Step 2: Computing critical nodes/edges for each threshold...")
print()

results = {}

for threshold_key, threshold_info in THRESHOLDS.items():
    threshold_name = threshold_info['name']
    threshold_value = threshold_info['value']
    
    print(f"Processing: {threshold_name}")
    print("-" * 80)
    
    all_outliers = []
    all_sinks = []
    all_bridges = []
    
    for gender in ['Female', 'Male']:
        for age_group in available_age_groups[gender]:
            age_str = AGE_MAP[age_group]
            print(f"  {gender} age {age_group} ({age_str})...", end=' ')
            
            # Get prevalence and mortality
            prev_subset = prev_df[
                (prev_df['sex'] == gender) & 
                (prev_df['Age_Group'] == age_str) &
                (prev_df['year'] == 2014)
            ]
            prev_dict = dict(zip(prev_subset['icd_code'], prev_subset['p']))
            
            mort_subset = mort_df[
                (mort_df['sex'] == gender) &
                (mort_df['Age_Group'] == age_str)
            ]
            mort_dict = dict(zip(mort_subset['icd_code'], mort_subset['mortality']))
            
            try:
                # Compute metrics
                outliers = compute_degree_outliers(gender, age_group, icd_df, prev_dict, threshold_value)
                if len(outliers) > 0:
                    all_outliers.append(outliers)
                
                sinks = compute_high_mortality_sinks(gender, age_group, icd_df, mort_dict, threshold_value)
                if len(sinks) > 0:
                    all_sinks.append(sinks)
                
                bridges = compute_high_mortality_bridges(gender, age_group, icd_df, mort_dict, threshold_value)
                if len(bridges) > 0:
                    all_bridges.append(bridges)
                
                print("✓")
            except Exception as e:
                print(f"Error: {e}")
    
    # Store results
    results[threshold_key] = {
        'name': threshold_name,
        'value': threshold_value,
        'outliers_df': pd.concat(all_outliers, ignore_index=True) if all_outliers else pd.DataFrame(),
        'sinks_df': pd.concat(all_sinks, ignore_index=True) if all_sinks else pd.DataFrame(),
        'bridges_df': pd.concat(all_bridges, ignore_index=True) if all_bridges else pd.DataFrame()
    }
    
    # Create unique identifiers
    if len(results[threshold_key]['outliers_df']) > 0:
        results[threshold_key]['outliers_set'] = set(
            results[threshold_key]['outliers_df'].apply(create_unique_identifier, axis=1)
        )
    else:
        results[threshold_key]['outliers_set'] = set()
    
    if len(results[threshold_key]['sinks_df']) > 0:
        results[threshold_key]['sinks_set'] = set(
            results[threshold_key]['sinks_df'].apply(create_unique_identifier, axis=1)
        )
    else:
        results[threshold_key]['sinks_set'] = set()
    
    if len(results[threshold_key]['bridges_df']) > 0:
        results[threshold_key]['bridges_set'] = set(
            results[threshold_key]['bridges_df'].apply(
                lambda row: create_unique_identifier(row, is_edge=True), axis=1
            )
        )
    else:
        results[threshold_key]['bridges_set'] = set()
    
    print(f"  ✓ Outliers: {len(results[threshold_key]['outliers_set'])}")
    print(f"  ✓ Sinks: {len(results[threshold_key]['sinks_set'])}")
    print(f"  ✓ Bridges: {len(results[threshold_key]['bridges_set'])}")
    print()

# ============================================================================
# STEP 4: COMPUTE OVERLAP STATISTICS
# ============================================================================

print("Step 3: Computing overlap statistics...")
print()

def compute_jaccard(set1, set2):
    """Compute Jaccard index"""
    if len(set1) == 0 and len(set2) == 0:
        return np.nan
    union = set1 | set2
    if len(union) == 0:
        return 0.0
    intersection = set1 & set2
    return len(intersection) / len(union)

comparisons = [
    ('or_1.5', 'or_2.0', 'OR > 1.5 vs OR > 2.0')
]

overlap_results = []

for key1, key2, label in comparisons:
    r1 = results[key1]
    r2 = results[key2]
    
    stats = {
        'comparison': label,
        'condition1': r1['name'],
        'condition2': r2['name'],
        
        'outliers_jaccard': compute_jaccard(r1['outliers_set'], r2['outliers_set']),
        'outliers_common': len(r1['outliers_set'] & r2['outliers_set']),
        'outliers_cond1_only': len(r1['outliers_set'] - r2['outliers_set']),
        'outliers_cond2_only': len(r2['outliers_set'] - r1['outliers_set']),
        
        'sinks_jaccard': compute_jaccard(r1['sinks_set'], r2['sinks_set']),
        'sinks_common': len(r1['sinks_set'] & r2['sinks_set']),
        'sinks_cond1_only': len(r1['sinks_set'] - r2['sinks_set']),
        'sinks_cond2_only': len(r2['sinks_set'] - r1['sinks_set']),
        
        'bridges_jaccard': compute_jaccard(r1['bridges_set'], r2['bridges_set']),
        'bridges_common': len(r1['bridges_set'] & r2['bridges_set']),
        'bridges_cond1_only': len(r1['bridges_set'] - r2['bridges_set']),
        'bridges_cond2_only': len(r2['bridges_set'] - r1['bridges_set'])
    }
    
    overlap_results.append(stats)
    
    print(f"{label}:")
    print(f"  Outliers: J={stats['outliers_jaccard']:.3f}, Common={stats['outliers_common']}")
    print(f"  Sinks:    J={stats['sinks_jaccard']:.3f}, Common={stats['sinks_common']}")
    print(f"  Bridges:  J={stats['bridges_jaccard']:.3f}, Common={stats['bridges_common']}")
    print()

# ============================================================================
# STEP 5: SAVE RESULTS
# ============================================================================

print("Step 4: Saving results...")
print()

# Summary table
summary_data = []
for key, data in results.items():
    summary_data.append({
        'Condition': data['name'],
        'OR_Threshold': data['value'],
        'Outliers': len(data['outliers_set']),
        'Sinks': len(data['sinks_set']),
        'Bridges': len(data['bridges_set'])
    })

summary_df = pd.DataFrame(summary_data)
summary_df.to_csv(OUTPUT_DIR / 'threshold_robustness_summary.csv', index=False)
print(f"  ✓ Saved: threshold_robustness_summary.csv")

# Overlap statistics
overlap_df = pd.DataFrame(overlap_results)
overlap_df.to_csv(OUTPUT_DIR / 'threshold_robustness_overlap_statistics.csv', index=False)
print(f"  ✓ Saved: threshold_robustness_overlap_statistics.csv")

# Detailed results for each condition
for key, data in results.items():
    if len(data['outliers_df']) > 0:
        data['outliers_df'].to_csv(OUTPUT_DIR / f'threshold_{key}_outliers.csv', index=False)
        print(f"  ✓ Saved: threshold_{key}_outliers.csv")
    
    if len(data['sinks_df']) > 0:
        data['sinks_df'].to_csv(OUTPUT_DIR / f'threshold_{key}_sinks.csv', index=False)
        print(f"  ✓ Saved: threshold_{key}_sinks.csv")
    
    if len(data['bridges_df']) > 0:
        data['bridges_df'].to_csv(OUTPUT_DIR / f'threshold_{key}_bridges.csv', index=False)
        print(f"  ✓ Saved: threshold_{key}_bridges.csv")

print()

# ============================================================================
# STEP 6: CREATE VISUALIZATIONS
# ============================================================================

print("Step 5: Creating visualizations...")
print()

# Set style
plt.style.use('seaborn-v0_8-paper')
sns.set_palette("husl")

# ------------------------------------------------------------------------
# FIGURE 1: Comparison Bars
# ------------------------------------------------------------------------

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

metrics = ['Outliers', 'Sinks', 'Bridges']
colors = ['#3498db', '#e74c3c', '#2ecc71']

for idx, metric in enumerate(metrics):
    ax = axes[idx]
    
    thresholds = summary_df['OR_Threshold'].values
    values = summary_df[metric].values
    
    bars = ax.bar(range(len(thresholds)), values, color=colors[idx], 
                  alpha=0.8, edgecolor='black', linewidth=1.5)
    
    ax.set_xticks(range(len(thresholds)))
    ax.set_xticklabels(['OR > 1.5\n(Baseline)', 'OR > 2.0\n(Strict)'], fontsize=11)
    ax.set_ylabel('Count', fontsize=12, fontweight='bold')
    ax.set_title(f'{metric}', fontsize=13, fontweight='bold')
    ax.grid(axis='y', alpha=0.3, linestyle='--')
    
    for bar, val in zip(bars, values):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
               f'{int(val)}', ha='center', va='bottom', fontsize=11, fontweight='bold')

plt.suptitle('Counts of Critical Nodes/Edges Across OR Thresholds',
             fontsize=15, fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig(FIG_DIR / 'threshold_comparison_bars.png', dpi=300, bbox_inches='tight')
print("  ✓ Saved: threshold_comparison_bars.png")
plt.close()

# ------------------------------------------------------------------------
# FIGURE 2: Jaccard Indices
# ------------------------------------------------------------------------

fig, ax = plt.subplots(figsize=(8, 6))

comparisons_labels = overlap_df['comparison'].values
metrics_lower = ['outliers', 'sinks', 'bridges']
metric_names = ['Outliers', 'Sinks', 'Bridges']

x = np.arange(len(metric_names))
colors = ['#3498db', '#e74c3c', '#2ecc71']

# Since we only have one comparison, create a simple bar chart
values = [overlap_df.loc[0, f'{m}_jaccard'] for m in metrics_lower]

bars = ax.bar(x, values, color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)

for bar, val in zip(bars, values):
    height = bar.get_height()
    if not np.isnan(val):
        ax.text(bar.get_x() + bar.get_width()/2., height + 0.02,
               f'{val:.3f}', ha='center', va='bottom', fontsize=12, fontweight='bold')

ax.set_xlabel('Metric', fontsize=13, fontweight='bold')
ax.set_ylabel('Jaccard Index', fontsize=13, fontweight='bold')
ax.set_title('Robustness: OR > 1.5 vs OR > 2.0',
             fontsize=15, fontweight='bold', pad=20)
ax.set_xticks(x)
ax.set_xticklabels(metric_names, fontsize=12)
ax.grid(axis='y', alpha=0.3, linestyle='--')

ax.axhline(y=0.7, color='green', linestyle='--', alpha=0.6, linewidth=2, label='J > 0.70')
ax.axhline(y=0.4, color='orange', linestyle='--', alpha=0.6, linewidth=2, label='J > 0.40')
ax.set_ylim(0, 1.05)

ax.legend(loc='lower left', framealpha=0.95, fontsize=11)

plt.tight_layout()
plt.savefig(FIG_DIR / 'threshold_jaccard_indices.png', dpi=300, bbox_inches='tight')
print("  ✓ Saved: threshold_jaccard_indices.png")
plt.close()

# ============================================================================
# STEP 7: SUMMARY
# ============================================================================

print()
print("=" * 80)
print("ANALYSIS COMPLETE!")
print("=" * 80)
print()
print("Summary:")
print(summary_df.to_string(index=False))
print()

# Calculate average Jaccard indices from overlap statistics
avg_outliers = overlap_df['outliers_jaccard'].mean()
avg_sinks = overlap_df['sinks_jaccard'].mean()
avg_bridges = overlap_df['bridges_jaccard'].mean()
overall_avg = np.mean([avg_outliers, avg_sinks, avg_bridges])

if overall_avg > 0.7:
    assessment = "EXCELLENT"
elif overall_avg > 0.5:
    assessment = "GOOD"
else:
    assessment = "MODERATE"

print(f"Average Jaccard Indices:")
print(f"  Outliers: {avg_outliers:.3f}")
print(f"  Sinks:    {avg_sinks:.3f}")
print(f"  Bridges:  {avg_bridges:.3f}")
print(f"  Overall:  {overall_avg:.3f} - {assessment}")
print()
print("Output files:")
print(f"  CSV files: {OUTPUT_DIR}/threshold_*.csv")
print(f"  Figures:   {FIG_DIR}/*.png")
print()
print("=" * 80)

COMPLETE THRESHOLD ROBUSTNESS ANALYSIS

Starting from adjacency matrices...
Thresholds: OR > 1.5 (Baseline), OR > 2.0 (Strict)

Step 1: Loading data...
  ✓ Loaded 1080 ICD codes
  ✓ Loaded prevalence data
  ✓ Loaded mortality data

Available data:
  Female: Age groups [1, 2, 3, 4, 5, 6, 7, 8]
  Male: Age groups [1, 2, 3, 4, 5, 6, 7, 8]

Step 2: Computing critical nodes/edges for each threshold...

Processing: OR > 1.5 (Baseline)
--------------------------------------------------------------------------------
  Female age 1 (0-9)... ✓
  Female age 2 (10-19)... ✓
  Female age 3 (20-29)... ✓
  Female age 4 (30-39)... ✓
  Female age 5 (40-49)... ✓
  Female age 6 (50-59)... ✓
  Female age 7 (60-69)... ✓
  Female age 8 (70-79)... ✓
  Male age 1 (0-9)... ✓
  Male age 2 (10-19)... ✓
  Male age 3 (20-29)... ✓
  Male age 4 (30-39)... ✓
  Male age 5 (40-49)... ✓
  Male age 6 (50-59)... ✓
  Male age 7 (60-69)... ✓
  Male age 8 (70-79)... ✓
  ✓ Outliers: 785
  ✓ Sinks: 432
  ✓ Bridges: 115

Process

In [25]:
#!/usr/bin/env python3
"""
Edge Weight Distribution Analysis for Critical Diseases
=======================================================
Generates histograms showing the distribution of edge weights (odds ratios)
for edges connected to outliers, sinks, and bridges in the OR > 1.5 network.
"""

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Configuration
DATA_DIR = Path('Data')
OUTPUT_DIR = Path('outputs')
FIG_DIR = OUTPUT_DIR / 'edge_weight_histograms'
FIG_DIR.mkdir(exist_ok=True)

# Age group mapping
AGE_MAP = {1: '0-9', 2: '10-19', 3: '20-29', 4: '30-39',
           5: '40-49', 6: '50-59', 7: '60-69', 8: '70-79'}

print("=" * 80)
print("EDGE WEIGHT DISTRIBUTION ANALYSIS")
print("=" * 80)
print()

# ============================================================================
# STEP 1: LOAD CRITICAL DISEASE LISTS
# ============================================================================

print("Step 1: Loading critical disease identifications...")

# Load the OR > 1.5 results
outliers_df = pd.read_csv(OUTPUT_DIR / 'threshold_or_1.5_outliers.csv')
sinks_df = pd.read_csv(OUTPUT_DIR / 'threshold_or_1.5_sinks.csv')
bridges_df = pd.read_csv(OUTPUT_DIR / 'threshold_or_1.5_bridges.csv')

print(f"  ✓ Outliers: {len(outliers_df)}")
print(f"  ✓ Sinks: {len(sinks_df)}")
print(f"  ✓ Bridges: {len(bridges_df)}")
print()

# Load ICD mapping
icd_df = pd.read_csv(DATA_DIR / 'ICD10_Diagnoses_All.csv')
icd_to_node = dict(zip(icd_df['icd_code'], icd_df['diagnose_id'] - 1))  # 0-indexed

# ============================================================================
# STEP 2: EXTRACT EDGE WEIGHTS FOR EACH CATEGORY
# ============================================================================

print("Step 2: Extracting edge weights...")

def get_edge_weights_for_nodes(icd_codes, sex, age_group):
    """Get all edge weights for given nodes (outliers or sinks)"""
    # Load adjacency matrix
    adj_path = DATA_DIR / f'Adj_Matrix_{sex}_ICD_age_{age_group}.csv'
    if not adj_path.exists():
        return []
    
    A = pd.read_csv(adj_path, sep=' ', header=None).values
    
    weights = []
    for icd_code in icd_codes:
        node = icd_to_node.get(icd_code)
        if node is None:
            continue
        
        # Get all edges connected to this node
        # Outgoing edges
        for neighbor in range(len(A)):
            if neighbor != node and A[node, neighbor] >= 1.5:
                weights.append(A[node, neighbor])
        
        # Incoming edges (if matrix is not symmetric)
        for neighbor in range(len(A)):
            if neighbor != node and A[neighbor, node] >= 1.5:
                weights.append(A[neighbor, node])
    
    return weights

def get_edge_weights_for_bridges(bridges_subset):
    """Get edge weights for specific bridge edges"""
    weights = []
    
    for _, row in bridges_subset.iterrows():
        sex = row['Sex']
        age_group = row['Age_Group']
        icd1 = row['icd1']
        icd2 = row['icd2']
        
        # Load adjacency matrix
        adj_path = DATA_DIR / f'Adj_Matrix_{sex}_ICD_age_{age_group}.csv'
        if not adj_path.exists():
            continue
        
        A = pd.read_csv(adj_path, sep=' ', header=None).values
        
        node1 = icd_to_node.get(icd1)
        node2 = icd_to_node.get(icd2)
        
        if node1 is not None and node2 is not None:
            weight = A[node1, node2]
            if weight >= 1.5:
                weights.append(weight)
    
    return weights

# Collect edge weights for each category
outlier_weights = []
sink_weights = []
bridge_weights = []

# Process outliers
print("  Processing outliers...")
for sex in ['Female', 'Male']:
    for age_group in range(1, 9):
        subset = outliers_df[(outliers_df['Sex'] == sex) & (outliers_df['Age_Group'] == age_group)]
        if len(subset) > 0:
            icd_codes = subset['icd_code'].unique()
            weights = get_edge_weights_for_nodes(icd_codes, sex, age_group)
            outlier_weights.extend(weights)

print(f"    Collected {len(outlier_weights)} edge weights")

# Process sinks
print("  Processing sinks...")
for sex in ['Female', 'Male']:
    for age_group in range(1, 9):
        subset = sinks_df[(sinks_df['Sex'] == sex) & (sinks_df['Age_Group'] == age_group)]
        if len(subset) > 0:
            icd_codes = subset['icd_code'].unique()
            weights = get_edge_weights_for_nodes(icd_codes, sex, age_group)
            sink_weights.extend(weights)

print(f"    Collected {len(sink_weights)} edge weights")

# Process bridges (these are specific edges, not nodes)
print("  Processing bridges...")
bridge_weights = get_edge_weights_for_bridges(bridges_df)
print(f"    Collected {len(bridge_weights)} edge weights")
print()

# ============================================================================
# STEP 3: CREATE VISUALIZATIONS
# ============================================================================

print("Step 3: Creating visualizations...")

# Set style
plt.style.use('seaborn-v0_8-paper')
sns.set_palette("husl")

# ------------------------------------------------------------------------
# FIGURE 1: Three separate histograms
# ------------------------------------------------------------------------

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

categories = [
    ('Outliers', outlier_weights, '#3498db'),
    ('Sinks', sink_weights, '#e74c3c'),
    ('Bridges', bridge_weights, '#2ecc71')
]

for idx, (name, weights, color) in enumerate(categories):
    ax = axes[idx]
    
    if len(weights) > 0:
        # Create histogram with log-spaced bins
        log_bins = np.logspace(np.log10(1.5), np.log10(max(weights)), 50)
        counts, bins, patches = ax.hist(weights, bins=log_bins, color=color, alpha=0.7, 
                                        edgecolor='black', linewidth=0.5)
        
        # Set log scale and x-axis limits
        ax.set_xscale('log')
        ax.set_xlim(1.5, max(weights) * 1.1)
        
        # Add statistics
        mean_weight = np.mean(weights)
        median_weight = np.median(weights)
        
        ax.axvline(mean_weight, color='red', linestyle='--', linewidth=2, 
                  label=f'Mean: {mean_weight:.2f}')
        ax.axvline(median_weight, color='orange', linestyle='--', linewidth=2,
                  label=f'Median: {median_weight:.2f}')
        
        ax.set_xlabel('Odds Ratio (log scale)', fontsize=12, fontweight='bold')
        ax.set_ylabel('Frequency', fontsize=12, fontweight='bold')
        ax.set_title(f'{name}\n(n={len(weights)} edges)', fontsize=13, fontweight='bold')
        ax.legend(loc='upper left', fontsize=10)
        ax.grid(axis='both', alpha=0.3, linestyle='--')
        
        # Add text with statistics
        stats_text = f'Min: {np.min(weights):.2f}\nMax: {np.max(weights):.2f}\nStd: {np.std(weights):.2f}'
        ax.text(0.98, 0.65, stats_text, transform=ax.transAxes,
               fontsize=9, verticalalignment='top', horizontalalignment='right',
               bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    else:
        ax.text(0.5, 0.5, 'No data', transform=ax.transAxes,
               ha='center', va='center', fontsize=14)
        ax.set_xlabel('Odds Ratio', fontsize=12, fontweight='bold')
        ax.set_ylabel('Frequency', fontsize=12, fontweight='bold')
        ax.set_title(f'{name}', fontsize=13, fontweight='bold')

plt.suptitle('Distribution of Edge Weights (Odds Ratios > 1.5)',
             fontsize=15, fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig(FIG_DIR / 'edge_weight_histograms_separate.png', dpi=300, bbox_inches='tight')
print("  ✓ Saved: edge_weight_histograms_separate.png")
plt.close()

# ------------------------------------------------------------------------
# FIGURE 2: Overlaid histograms
# ------------------------------------------------------------------------

fig, ax = plt.subplots(figsize=(10, 6))

# Determine overall range for consistent bins
all_weights = []
if len(outlier_weights) > 0:
    all_weights.extend(outlier_weights)
if len(sink_weights) > 0:
    all_weights.extend(sink_weights)
if len(bridge_weights) > 0:
    all_weights.extend(bridge_weights)

if len(all_weights) > 0:
    log_bins = np.logspace(np.log10(1.5), np.log10(max(all_weights)), 50)
else:
    log_bins = 50

if len(outlier_weights) > 0:
    ax.hist(outlier_weights, bins=log_bins, color='#3498db', alpha=0.5, 
           label=f'Outliers (n={len(outlier_weights)})', edgecolor='black', linewidth=0.5)

if len(sink_weights) > 0:
    ax.hist(sink_weights, bins=log_bins, color='#e74c3c', alpha=0.5,
           label=f'Sinks (n={len(sink_weights)})', edgecolor='black', linewidth=0.5)

if len(bridge_weights) > 0:
    ax.hist(bridge_weights, bins=log_bins, color='#2ecc71', alpha=0.5,
           label=f'Bridges (n={len(bridge_weights)})', edgecolor='black', linewidth=0.5)

ax.set_xscale('log')
if len(all_weights) > 0:
    ax.set_xlim(1.5, max(all_weights) * 1.1)
ax.set_xlabel('Odds Ratio (log scale)', fontsize=13, fontweight='bold')
ax.set_ylabel('Frequency', fontsize=13, fontweight='bold')
ax.set_title('Comparison of Edge Weight Distributions (OR > 1.5)',
             fontsize=15, fontweight='bold', pad=20)
ax.legend(loc='upper left', fontsize=11, framealpha=0.9)
ax.grid(axis='both', alpha=0.3, linestyle='--')

plt.tight_layout()
plt.savefig(FIG_DIR / 'edge_weight_histograms_overlaid.png', dpi=300, bbox_inches='tight')
print("  ✓ Saved: edge_weight_histograms_overlaid.png")
plt.close()

# ------------------------------------------------------------------------
# FIGURE 3: Box plots for comparison
# ------------------------------------------------------------------------

fig, ax = plt.subplots(figsize=(8, 6))

data_to_plot = []
labels = []

if len(outlier_weights) > 0:
    data_to_plot.append(outlier_weights)
    labels.append(f'Outliers\n(n={len(outlier_weights)})')

if len(sink_weights) > 0:
    data_to_plot.append(sink_weights)
    labels.append(f'Sinks\n(n={len(sink_weights)})')

if len(bridge_weights) > 0:
    data_to_plot.append(bridge_weights)
    labels.append(f'Bridges\n(n={len(bridge_weights)})')

bp = ax.boxplot(data_to_plot, labels=labels, patch_artist=True,
                medianprops=dict(color='red', linewidth=2),
                boxprops=dict(facecolor='lightblue', alpha=0.7),
                whiskerprops=dict(linewidth=1.5),
                capprops=dict(linewidth=1.5))

# Color boxes differently
colors = ['#3498db', '#e74c3c', '#2ecc71']
for patch, color in zip(bp['boxes'], colors):
    patch.set_facecolor(color)
    patch.set_alpha(0.7)

ax.set_yscale('log')
# Set y-axis limits to start at 1.5
if len(data_to_plot) > 0:
    max_val = max([max(d) for d in data_to_plot])
    ax.set_ylim(1.5, max_val * 1.2)
ax.set_ylabel('Odds Ratio (log scale)', fontsize=13, fontweight='bold')
ax.set_title('Distribution Comparison of Edge Weights (OR > 1.5)',
             fontsize=15, fontweight='bold', pad=20)
ax.grid(axis='y', alpha=0.3, linestyle='--')

plt.tight_layout()
plt.savefig(FIG_DIR / 'edge_weight_boxplots.png', dpi=300, bbox_inches='tight')
print("  ✓ Saved: edge_weight_boxplots.png")
plt.close()

# ============================================================================
# STEP 4: SUMMARY STATISTICS
# ============================================================================

print()
print("=" * 80)
print("SUMMARY STATISTICS")
print("=" * 80)
print()

def print_stats(name, weights):
    if len(weights) > 0:
        print(f"{name}:")
        print(f"  Count:   {len(weights)}")
        print(f"  Mean:    {np.mean(weights):.3f}")
        print(f"  Median:  {np.median(weights):.3f}")
        print(f"  Std:     {np.std(weights):.3f}")
        print(f"  Min:     {np.min(weights):.3f}")
        print(f"  Max:     {np.max(weights):.3f}")
        print(f"  Q25:     {np.percentile(weights, 25):.3f}")
        print(f"  Q75:     {np.percentile(weights, 75):.3f}")
        print()
    else:
        print(f"{name}: No data")
        print()

print_stats("Outliers (edges connected to high-degree outliers)", outlier_weights)
print_stats("Sinks (edges connected to high-mortality sinks)", sink_weights)
print_stats("Bridges (high-mortality bridge edges)", bridge_weights)

# Save statistics to CSV
stats_data = []
for name, weights in [('Outliers', outlier_weights), ('Sinks', sink_weights), ('Bridges', bridge_weights)]:
    if len(weights) > 0:
        stats_data.append({
            'Category': name,
            'Count': len(weights),
            'Mean': np.mean(weights),
            'Median': np.median(weights),
            'Std': np.std(weights),
            'Min': np.min(weights),
            'Max': np.max(weights),
            'Q25': np.percentile(weights, 25),
            'Q75': np.percentile(weights, 75)
        })

stats_df = pd.DataFrame(stats_data)
stats_df.to_csv(OUTPUT_DIR / 'edge_weight_statistics.csv', index=False)
print("✓ Saved statistics to: edge_weight_statistics.csv")

print()
print("=" * 80)
print("✓ ANALYSIS COMPLETE")
print("=" * 80)
print()
print("Output files:")
print(f"  Figures:    {FIG_DIR}/*.png")
print(f"  Statistics: {OUTPUT_DIR}/edge_weight_statistics.csv")
print()

EDGE WEIGHT DISTRIBUTION ANALYSIS

Step 1: Loading critical disease identifications...
  ✓ Outliers: 785
  ✓ Sinks: 432
  ✓ Bridges: 115

Step 2: Extracting edge weights...
  Processing outliers...
    Collected 31534 edge weights
  Processing sinks...
    Collected 11728 edge weights
  Processing bridges...
    Collected 115 edge weights

Step 3: Creating visualizations...
  ✓ Saved: edge_weight_histograms_separate.png
  ✓ Saved: edge_weight_histograms_overlaid.png
  ✓ Saved: edge_weight_boxplots.png

SUMMARY STATISTICS

Outliers (edges connected to high-degree outliers):
  Count:   31534
  Mean:    8.865
  Median:  3.106
  Std:     34.319
  Min:     1.500
  Max:     1238.994
  Q25:     2.192
  Q75:     5.566

Sinks (edges connected to high-mortality sinks):
  Count:   11728
  Mean:    9.404
  Median:  3.398
  Std:     19.856
  Min:     1.502
  Max:     341.887
  Q25:     2.162
  Q75:     7.643

Bridges (high-mortality bridge edges):
  Count:   115
  Mean:    9.537
  Median:  2.724
  