In [None]:
# ============================================================================
# PRODUCTION-GRADE FEDERATED GRAPH-AUGMENTED ENSEMBLE FOR CONSTRUCTION SAFETY
# FINAL VERSION - 17K Dataset with Comprehensive Validation & Separate Figures
# ============================================================================

!pip install numpy pandas scikit-learn scipy matplotlib seaborn shap networkx kneed xgboost lightgbm catboost -q

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import shap
from collections import defaultdict
import time
import warnings
warnings.filterwarnings("ignore")

from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import (accuracy_score, precision_score, recall_score,
                             f1_score, roc_auc_score, average_precision_score,
                             confusion_matrix, roc_curve, precision_recall_curve)
from sklearn.neural_network import MLPClassifier
from sklearn.ensemble import (RandomForestClassifier, GradientBoostingClassifier,
                              AdaBoostClassifier, ExtraTreesClassifier)
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics.pairwise import cosine_similarity
from scipy.stats import wilcoxon
from kneed import KneeLocator

import xgboost as xgb
import lightgbm as lgb
from catboost import CatBoostClassifier

# =============================================================================
# CONFIGURATION
# =============================================================================

# Experiment repetition: one train/test split per seed.
N_RUNS = 5  # Increased for 17K dataset
TEST_SIZE = 0.2
RANDOM_SEEDS = [42 + offset for offset in range(N_RUNS)]

# Switches for the automated hyper-parameter selection steps further down.
AUTO_SELECT_ALGORITHM = True
AUTO_SELECT_K = True
AUTO_SELECT_CLIENTS = True
AUTO_SELECT_ROUNDS = True
# Candidate privacy budgets (epsilon values).
EPSILON_RANGE_FOR_PARETO = [0.5, 1.0, 2.0, 5.0, 10.0]

# Run banner, one line per argument.
print(
    "=" * 100,
    "FEDERATED SIMPLIFIED GRAPH CONVOLUTION (SGC) FOR CONSTRUCTION SAFETY",
    "FINAL VERSION - 17K Dataset with Comprehensive Validation",
    "=" * 100,
    sep="\n",
)

# =============================================================================
# PART 1: DATA LOADING & ENCODING
# =============================================================================

# Load the raw records and work on a copy so df_raw stays untouched.
df_raw = pd.read_csv("/content/drive/MyDrive/Datasets/construction_osha_dataset.csv")
df = df_raw.copy()
# Unparseable dates become NaT instead of raising.
df["EventDate"] = pd.to_datetime(df["EventDate"], errors="coerce")
# Keep only rows that have every categorical field later used as a feature.
df = df.dropna(subset=["Nature", "Part of Body", "Event", "Source", "Primary NAICS", "State"])
# Binary target: 1 where the Amputation column equals 1.0, otherwise 0.
df["Severity"] = (df["Amputation"] == 1.0).astype(int)

# "\n" (not "\\n") so a real blank line is printed instead of a literal backslash-n.
print(f"\nTotal records: {len(df):,}")
print("Class distribution:")
print(df["Severity"].value_counts(normalize=True).to_string())

# Inverse-frequency class weights: total / (2 * count_of_class).
# NOTE(review): indexing class_counts[0] / class_counts[1] raises KeyError if a
# class is absent after the dropna above.
class_counts = df["Severity"].value_counts().sort_index()
total = len(df)
class_weight_0 = total / (2 * class_counts[0])
class_weight_1 = total / (2 * class_counts[1])
CLASS_WEIGHTS = {0: class_weight_0, 1: class_weight_1}

print("\nClass weights:")
print(f"  Class 0: {class_weight_0:.4f}")
print(f"  Class 1: {class_weight_1:.4f}")

# One-hot encoding
# Categorical fields that are one-hot encoded into the model's feature matrix.
categorical_cols = ["Nature", "Part of Body", "Event", "Source", "Primary NAICS", "State"]

ohe = OneHotEncoder(sparse_output=False, handle_unknown='ignore')
X_encoded = ohe.fit_transform(df[categorical_cols])

# Human-readable name per encoded column, in the encoder's output order
# (column by column, category by category).
feature_names_ohe = []
for col, categories in zip(categorical_cols, ohe.categories_):
    feature_names_ohe.extend(f"{col}_{cat}" for cat in categories)

# "\n" (not "\\n") so a real blank line is printed instead of a literal backslash-n.
print(f"\nOne-hot encoded features: {X_encoded.shape[1]:,} dimensions")

ohe_cols = [f"ohe_{i}" for i in range(X_encoded.shape[1])]

# Attach all encoded columns in one concat instead of inserting them one at a
# time. The index must be df.index because dropna left gaps in the row labels;
# a default RangeIndex would misalign the rows.
df_enc = pd.concat(
    [df, pd.DataFrame(X_encoded, columns=ohe_cols, index=df.index)],
    axis=1,
)

# =============================================================================
# DATA VERIFICATION SECTION
# =============================================================================

# "\n" (not "\\n") throughout so blank lines are printed instead of a literal backslash-n.
print("\n" + "="*100)
print("DATA VERIFICATION & STATISTICS")
print("="*100)

print("\nDataset shape:")
print(f"  Total records (rows):     {len(df_enc):,}")
print(f"  Total features (columns): {X_encoded.shape[1]:,}")
print("\nOriginal categorical columns breakdown:")

# One row per source column: how many one-hot features it contributes and
# what share of the encoded width that is.
category_breakdown = []
for col, categories in zip(categorical_cols, ohe.categories_):
    n_categories = len(categories)
    category_breakdown.append({
        'Column': col,
        'Unique_Values': n_categories,
        'Percentage': (n_categories / X_encoded.shape[1]) * 100
    })
    print(f"  {col:20s}: {n_categories:4d} unique values ({(n_categories/X_encoded.shape[1])*100:5.1f}%)")

category_df = pd.DataFrame(category_breakdown)

print(f"\nTotal one-hot features: {sum(len(categories) for categories in ohe.categories_):,}")
# NOTE(review): these two lines are printed unconditionally; nothing here
# compares the row count against the pre-encoding frame.
print(f"\n✓ All {len(df_enc):,} records preserved!")
print(f"✓ Each record has {X_encoded.shape[1]:,} features")

# Class imbalance analysis
print("\nClass Imbalance Analysis:")
print(f"  Majority class (0):  {class_counts[0]:,} ({(class_counts[0]/total)*100:.2f}%)")
print(f"  Minority class (1):  {class_counts[1]:,} ({(class_counts[1]/total)*100:.2f}%)")
print(f"  Imbalance ratio:     {class_counts[0]/class_counts[1]:.2f}:1")

# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================

def apply_class_balancing(X, y, class_weights, random_state=42):
    """Oversample the minority class (label 1) with replacement and shuffle.

    The target number of minority rows is
    ``int(n_majority * (class_weights[1] / class_weights[0]) * 0.5)``.
    Rows are duplicated only when that target exceeds the current minority
    count; otherwise the data is just shuffled.

    Args:
        X: 2-D feature array, indexable by an integer index array.
        y: 1-D label array aligned with ``X`` (0 = majority, 1 = minority).
        class_weights: mapping with keys 0 and 1.
        random_state: seed for the sampling and the final shuffle.

    Returns:
        ``(X_balanced, y_balanced)``. If either class is absent, the inputs
        are returned unchanged and unshuffled.

    Note:
        A private ``RandomState`` is used instead of ``np.random.seed`` so the
        caller's global NumPy RNG is not reset as a side effect. For a given
        ``random_state`` the draws match those of the seeded global generator.
    """
    # NOTE(review): with inverse-frequency weights (as CLASS_WEIGHTS is built in
    # this file) the target exceeds the majority count whenever the imbalance
    # ratio is above 2:1 — confirm this is the intended amount of oversampling.
    rng = np.random.RandomState(random_state)

    minority_class = 1
    majority_class = 0

    minority_indices = np.where(y == minority_class)[0]
    majority_indices = np.where(y == majority_class)[0]

    if len(minority_indices) == 0 or len(majority_indices) == 0:
        return X, y

    target_ratio = class_weights[minority_class] / class_weights[majority_class]
    n_minority_target = int(len(majority_indices) * target_ratio * 0.5)

    n_oversample = n_minority_target - len(minority_indices)
    if n_oversample > 0:
        oversample_indices = rng.choice(minority_indices, n_oversample, replace=True)
        X_balanced = np.vstack([X, X[oversample_indices]])
        y_balanced = np.hstack([y, y[oversample_indices]])
    else:
        X_balanced = X
        y_balanced = y

    shuffle_idx = rng.permutation(len(y_balanced))
    return X_balanced[shuffle_idx], y_balanced[shuffle_idx]

# =============================================================================
# PART 2: COMPREHENSIVE ALGORITHM SELECTION
# =============================================================================

# "\n" (not "\\n") so a real blank line is printed instead of a literal backslash-n.
print("\n" + "="*100)
print("AUTOMATIC ALGORITHM SELECTION WITH COMPREHENSIVE COMPARISON")
print("="*100)

# Draw a fixed-size subsample for the selection experiments. A seeded local
# generator is used because nothing seeds NumPy's global RNG before this point,
# which made the sample (and so the selected algorithm) vary between executions.
sample_size = min(2000, len(df_enc))
sample_idx = np.random.RandomState(42).choice(len(df_enc), sample_size, replace=False)
X_sample = X_encoded[sample_idx]
y_sample = df["Severity"].values[sample_idx]

# Stratified 80/20 hold-out split of the subsample.
X_val_train, X_val_test, y_val_train, y_val_test = train_test_split(
    X_sample, y_sample, test_size=0.2, stratify=y_sample, random_state=42
)

# Fit the scaler on the training part only, then apply it to the hold-out part.
scaler_val = StandardScaler()
X_val_train_scaled = scaler_val.fit_transform(X_val_train)
X_val_test_scaled = scaler_val.transform(X_val_test)

# Test all algorithms
if AUTO_SELECT_ALGORITHM:
    # "\n" (not "\\n") throughout so blank lines are printed instead of a
    # literal backslash-n.
    print("\n1. Testing ALL ML/DL algorithms...")

    def create_all_algorithm_candidates():
        """Return a dict mapping display name to a fresh, unfitted classifier.

        Where the estimator supports it, class imbalance is handled through
        its own weighting option; AdaBoost, GradientBoosting and MLP are
        created without class weighting.
        """
        candidates = {
            "XGBoost": xgb.XGBClassifier(
                n_estimators=100, max_depth=6, learning_rate=0.1,
                scale_pos_weight=class_weight_1/class_weight_0,
                random_state=42, eval_metric='logloss', verbosity=0
            ),
            "LightGBM": lgb.LGBMClassifier(
                n_estimators=100, max_depth=6, learning_rate=0.1,
                class_weight='balanced', random_state=42, verbosity=-1
            ),
            "CatBoost": CatBoostClassifier(
                iterations=100, depth=6, learning_rate=0.1,
                class_weights=[class_weight_0, class_weight_1],
                random_state=42, verbose=0
            ),
            "RandomForest": RandomForestClassifier(
                n_estimators=100, max_depth=10, class_weight='balanced',
                random_state=42, n_jobs=-1
            ),
            "ExtraTrees": ExtraTreesClassifier(
                n_estimators=100, max_depth=10, class_weight='balanced',
                random_state=42, n_jobs=-1
            ),
            "AdaBoost": AdaBoostClassifier(
                n_estimators=100, learning_rate=1.0, random_state=42
            ),
            "GradientBoosting": GradientBoostingClassifier(
                n_estimators=100, max_depth=6, learning_rate=0.1,
                random_state=42
            ),
            "LogisticRegression": LogisticRegression(
                class_weight='balanced', max_iter=1000, random_state=42, n_jobs=-1
            ),
            "DecisionTree": DecisionTreeClassifier(
                max_depth=10, class_weight='balanced', random_state=42
            ),
            "MLP": MLPClassifier(
                hidden_layer_sizes=(64, 32), activation='relu', solver='adam',
                max_iter=200, random_state=42, early_stopping=True,
                validation_fraction=0.1, verbose=False
            )
        }
        return candidates

    algorithm_comparison = []

    # The splitter has a fixed integer seed, so one instance produces the same
    # three folds for every candidate; no need to rebuild it per algorithm.
    skf = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)

    for name, model in create_all_algorithm_candidates().items():
        print(f"\n   Testing {name}...")
        cv_scores = []

        try:
            start_time = time.time()

            for train_idx, val_idx in skf.split(X_val_train_scaled, y_val_train):
                X_tr, X_vl = X_val_train_scaled[train_idx], X_val_train_scaled[val_idx]
                y_tr, y_vl = y_val_train[train_idx], y_val_train[val_idx]

                # These two candidates are built without class weighting, so
                # the training fold is oversampled instead.
                if name in ["MLP", "AdaBoost"]:
                    X_tr, y_tr = apply_class_balancing(X_tr, y_tr, CLASS_WEIGHTS, 42)

                model.fit(X_tr, y_tr)
                y_pred = model.predict(X_vl)
                y_proba = model.predict_proba(X_vl)[:, 1]

                f1 = f1_score(y_vl, y_pred, zero_division=0)
                roc = roc_auc_score(y_vl, y_proba)
                recall = recall_score(y_vl, y_pred, zero_division=0)

                cv_scores.append({'f1': f1, 'roc': roc, 'recall': recall})

            elapsed_time = time.time() - start_time

            avg_f1 = np.mean([s['f1'] for s in cv_scores])
            avg_roc = np.mean([s['roc'] for s in cv_scores])
            avg_recall = np.mean([s['recall'] for s in cv_scores])
            # Ranking score: weighted blend of the fold-averaged metrics.
            combined_score = 0.5 * avg_f1 + 0.3 * avg_roc + 0.2 * avg_recall

            algorithm_comparison.append({
                'Algorithm': name,
                'F1': avg_f1,
                'ROC-AUC': avg_roc,
                'Recall': avg_recall,
                'Time(s)': elapsed_time,
                'Score': combined_score
            })

            print(f"     F1: {avg_f1:.4f} | ROC: {avg_roc:.4f} | Recall: {avg_recall:.4f} | Time: {elapsed_time:.2f}s | Score: {combined_score:.4f}")

        except Exception as e:
            # Deliberate best-effort: a failing candidate is reported and
            # ranked last (all-zero row) so the comparison continues.
            print(f"     FAILED: {e}")
            algorithm_comparison.append({
                'Algorithm': name,
                'F1': 0, 'ROC-AUC': 0, 'Recall': 0, 'Time(s)': 0, 'Score': 0
            })

    # Create comparison table, best score first.
    comparison_df = pd.DataFrame(algorithm_comparison)
    comparison_df = comparison_df.sort_values('Score', ascending=False)

    print("\n" + "="*100)
    print("ALGORITHM COMPARISON TABLE")
    print("="*100)
    print(comparison_df.to_string(index=False))

    BEST_ALGORITHM_NAME = comparison_df.iloc[0]['Algorithm']
    BEST_ALGORITHM_SCORE = comparison_df.iloc[0]['Score']

    print(f"\n   ✓ BEST ALGORITHM: {BEST_ALGORITHM_NAME}")
    print(f"     Score: {BEST_ALGORITHM_SCORE:.4f}")

else:
    # Manual fallback; no comparison table exists in this mode.
    BEST_ALGORITHM_NAME = "XGBoost"
    comparison_df = None

# K selection
# "\n" (not "\\n") throughout so blank lines are printed instead of a literal
# backslash-n.
if AUTO_SELECT_K:
    print("\n2. Selecting optimal K...")
    # NOTE(review): the "auto" branch assigns a fixed value; no search is run here.
    K_NEIGHBORS_OPTIMAL = 15  # Increased for 17K dataset
    print(f"   ✓ Using K = {K_NEIGHBORS_OPTIMAL}")
else:
    K_NEIGHBORS_OPTIMAL = 10

# Auto-select number of clients (UPDATED for 17K dataset)
if AUTO_SELECT_CLIENTS:
    print("\n3. Selecting optimal number of clients...")
    n_unique_naics = df['Primary NAICS'].nunique()
    # One client per five NAICS codes, clamped to the range [5, 20].
    NUM_CLIENTS = min(max(5, n_unique_naics // 5), 20)  # 5-20 clients
    print(f"   ✓ Optimal number of clients: {NUM_CLIENTS} (based on {n_unique_naics} unique NAICS codes)")
else:
    NUM_CLIENTS = 5

# Auto-select federated rounds (UPDATED for 17K dataset)
if AUTO_SELECT_ROUNDS:
    print("\n4. Selecting optimal federated rounds...")
    dataset_size = len(df_enc)
    # One round per 500 records per client, clamped to the range [5, 15].
    FED_ROUNDS = min(max(5, dataset_size // (NUM_CLIENTS * 500)), 15)  # 5-15 rounds
    print(f"   ✓ Optimal federated rounds: {FED_ROUNDS} (based on dataset size {dataset_size:,})")
else:
    FED_ROUNDS = 5

# Privacy budget selection
print("\n5. Privacy-Utility Tradeoff Analysis...")

def create_best_algorithm():
    """Build a fresh, unfitted classifier for the selected algorithm.

    The algorithm is chosen by the module-level ``BEST_ALGORITHM_NAME``; the
    module-level ``class_weight_0`` / ``class_weight_1`` feed the estimators
    that take explicit weights. Any name not listed below yields the MLP.
    """
    def build_mlp():
        return MLPClassifier(
            hidden_layer_sizes=(64, 32), activation='relu', solver='adam',
            max_iter=200, random_state=42, early_stopping=True,
            validation_fraction=0.1, verbose=False
        )

    # Name -> zero-argument factory; only the selected estimator is instantiated.
    factories = {
        "XGBoost": lambda: xgb.XGBClassifier(
            n_estimators=100, max_depth=6, learning_rate=0.1,
            scale_pos_weight=class_weight_1/class_weight_0,
            random_state=42, eval_metric='logloss', verbosity=0
        ),
        "LightGBM": lambda: lgb.LGBMClassifier(
            n_estimators=100, max_depth=6, learning_rate=0.1,
            class_weight='balanced', random_state=42, verbosity=-1
        ),
        "CatBoost": lambda: CatBoostClassifier(
            iterations=100, depth=6, learning_rate=0.1,
            class_weights=[class_weight_0, class_weight_1],
            random_state=42, verbose=0
        ),
        "RandomForest": lambda: RandomForestClassifier(
            n_estimators=100, max_depth=10, class_weight='balanced',
            random_state=42, n_jobs=-1
        ),
        "ExtraTrees": lambda: ExtraTreesClassifier(
            n_estimators=100, max_depth=10, class_weight='balanced',
            random_state=42, n_jobs=-1
        ),
        "AdaBoost": lambda: AdaBoostClassifier(
            n_estimators=100, learning_rate=1.0, random_state=42
        ),
        "GradientBoosting": lambda: GradientBoostingClassifier(
            n_estimators=100, max_depth=6, learning_rate=0.1, random_state=42
        ),
        "LogisticRegression": lambda: LogisticRegression(
            class_weight='balanced', max_iter=1000, random_state=42, n_jobs=-1
        ),
        "DecisionTree": lambda: DecisionTreeClassifier(
            max_depth=10, class_weight='balanced', random_state=42
        ),
    }

    selected_factory = factories.get(BEST_ALGORITHM_NAME, build_mlp)
    return selected_factory()

# One row per privacy setting: hold-out metrics and their loss relative to the
# noise-free baseline.
pareto_results = []

X_val_train_balanced, y_val_train_balanced = apply_class_balancing(
    X_val_train_scaled, y_val_train, CLASS_WEIGHTS, 42
)

# Baseline: selected algorithm, oversampled training data, no noise.
baseline_model = create_best_algorithm()
baseline_model.fit(X_val_train_balanced, y_val_train_balanced)
baseline_f1 = f1_score(y_val_test, baseline_model.predict(X_val_test_scaled), zero_division=0)
baseline_roc = roc_auc_score(y_val_test, baseline_model.predict_proba(X_val_test_scaled)[:, 1])

pareto_results.append({
    'epsilon': np.inf,
    'f1': baseline_f1,
    'roc': baseline_roc,
    'utility_loss_f1': 0.0,
    'privacy_strength': 0.0,
    'label': 'No DP'
})

print(f"   Baseline (no DP): F1={baseline_f1:.4f}")

# Dedicated, seeded generator for the perturbation so this sweep is
# reproducible on its own rather than depending on whatever state the global
# NumPy RNG was left in by earlier code.
dp_rng = np.random.RandomState(42)

# NOTE(review): EPSILON_RANGE_FOR_PARETO is defined in the configuration but
# this loop uses its own shorter list — confirm which one is intended.
for eps in [1.0, 2.0, 5.0]:
    noise_multiplier = 2.0 / eps
    model = create_best_algorithm()
    model.fit(X_val_train_balanced, y_val_train_balanced)

    # Noise is applied to the predicted probabilities (output perturbation),
    # with standard deviation 0.01 * (2 / eps).
    y_pred_proba = model.predict_proba(X_val_test_scaled)[:, 1]
    noise = dp_rng.normal(0, noise_multiplier * 0.01, y_pred_proba.shape)
    y_pred_proba_noisy = np.clip(y_pred_proba + noise, 0, 1)
    y_pred = (y_pred_proba_noisy >= 0.5).astype(int)

    f1_dp = f1_score(y_val_test, y_pred, zero_division=0)
    roc_dp = roc_auc_score(y_val_test, y_pred_proba_noisy)

    pareto_results.append({
        'epsilon': eps,
        'f1': f1_dp,
        'roc': roc_dp,
        'utility_loss_f1': baseline_f1 - f1_dp,
        'privacy_strength': 1.0 / eps,
        'label': f'ε={eps}'
    })

    print(f"   ε={eps}: F1={f1_dp:.4f}")

pareto_df = pd.DataFrame(pareto_results)

# Privacy parameters used by the main experiment (fixed, not derived from the
# sweep above).
DP_EPSILON = 2.0
DP_DELTA = 1e-5
DP_CLIP_NORM = 5.0

# "\n" (not "\\n") so a real blank line is printed instead of a literal backslash-n.
print("\n✓ FINAL HYPERPARAMETERS:")
print(f"  - Algorithm: {BEST_ALGORITHM_NAME}")
print(f"  - K: {K_NEIGHBORS_OPTIMAL}")
print(f"  - Clients: {NUM_CLIENTS}")
print(f"  - Fed Rounds: {FED_ROUNDS}")
print(f"  - Privacy ε: {DP_EPSILON}")

# =============================================================================
# PART 3: GRAPH CONSTRUCTION
# =============================================================================

def build_local_similarity_graph(features, k, random_state=42):
    """Build a directed k-nearest-neighbour graph from cosine similarity.

    Each row of ``features`` becomes a node with outgoing edges to its
    ``min(k, n - 1)`` most similar other rows. A tiny random jitter (scale
    1e-6) breaks ties between equally similar candidates; the stored edge
    weight is the un-jittered similarity.

    Args:
        features: 2-D array, one row per node.
        k: requested number of neighbours per node.
        random_state: seed for the tie-breaking jitter.

    Returns:
        ``(edge_index, edge_weights)``: ``edge_index`` is an integer array of
        shape ``(2, E)`` (row 0 = source node, row 1 = neighbour), ordered by
        source and then by decreasing similarity. ``edge_weights`` has shape
        ``(E,)``. Both are empty when there are fewer than two rows or
        ``k < 1``.

    Note:
        The node itself is excluded explicitly by masking the diagonal.
        Dropping "the top-ranked entry" is not reliable: with duplicate rows
        (similarity 1.0) or all-zero rows (self-similarity 0) the jitter can
        rank another node above self, producing a self-loop and losing a real
        neighbour. A private ``RandomState`` is used so the global NumPy RNG
        is not reseeded.
    """
    n = features.shape[0]
    k = min(k, n - 1)
    if n == 0 or k < 1:
        return np.zeros((2, 0), dtype=int), np.array([])

    sim = cosine_similarity(features)

    rng = np.random.RandomState(random_state)
    sim_noisy = sim + rng.randn(n, n) * 1e-6
    np.fill_diagonal(sim_noisy, -np.inf)  # a node is never its own neighbour

    # Per row: indices sorted by decreasing jittered similarity, top k kept.
    neighbours = np.argsort(sim_noisy, axis=1)[:, ::-1][:, :k]

    sources = np.repeat(np.arange(n), k)
    targets = neighbours.ravel()
    return np.vstack([sources, targets]), sim[sources, targets]

def create_non_iid_clients(df_subset, n_clients, random_state=42):
    """Partition rows into clients so that each NAICS code lives on one client.

    The distinct "Primary NAICS" values are shuffled and dealt round-robin to
    ``n_clients`` buckets; every row of a code goes to that code's bucket.

    Args:
        df_subset: DataFrame with a "Primary NAICS" column.
        n_clients: number of buckets to deal codes into.
        random_state: seed for the shuffle of NAICS codes.

    Returns:
        List of index arrays, one per non-empty client, so it may be shorter
        than ``n_clients``. The values are index *labels* of ``df_subset``
        (as produced by ``groupby(...).groups``); they equal row positions
        only when the frame has a default 0..n-1 index.

    Note:
        A private ``RandomState`` is used instead of ``np.random.seed`` so the
        caller's global NumPy RNG is not reset; the shuffle for a given seed
        is the same as with the seeded global generator.
    """
    rng = np.random.RandomState(random_state)
    naics_groups = df_subset.groupby("Primary NAICS").groups
    naics_codes = list(naics_groups)
    rng.shuffle(naics_codes)

    client_indices = [[] for _ in range(n_clients)]
    for position, naics in enumerate(naics_codes):
        client_indices[position % n_clients].extend(naics_groups[naics].tolist())

    return [np.array(idx) for idx in client_indices if idx]

class LocalStandardScaler:
    """Thin wrapper that gives each client its own ``StandardScaler``.

    The wrapped scaler is exposed as ``self.scaler``; both methods delegate
    to it directly.
    """

    def __init__(self):
        self.scaler = StandardScaler()

    def fit_transform(self, X):
        """Fit the wrapped scaler on ``X`` and return the scaled data."""
        scaled = self.scaler.fit_transform(X)
        return scaled

    def transform(self, X):
        """Scale ``X`` with the statistics learned by ``fit_transform``."""
        scaled = self.scaler.transform(X)
        return scaled

# =============================================================================
# PART 4: SIMPLIFIED GRAPH CONVOLUTION MODEL
# =============================================================================

class SimplifiedGraphConvolution:
    """
    Simplified Graph Convolution (SGC).
    Architecture: X' = concat(X, Â·X) → Classifier

    Â is the row-normalised weighted adjacency matrix built from the supplied
    edges. The classifier is a conventional estimator chosen by
    ``algorithm_name`` and trained on the concatenated features.

    Attributes set by this class:
        input_dim, algorithm_name, class_weights, random_state: constructor
            arguments (``class_weights`` defaults to ``{0: 1.0, 1: 1.0}``).
        aggregation_weights: the Â used by the most recent call to
            ``aggregate_neighbors`` (dense ``n x n`` array), or ``None``.
        base_model: the wrapped estimator.
    """

    def __init__(self, input_dim, algorithm_name, class_weights=None, random_state=42):
        self.input_dim = input_dim
        self.algorithm_name = algorithm_name
        self.class_weights = class_weights if class_weights else {0: 1.0, 1: 1.0}
        self.random_state = random_state

        self.aggregation_weights = None
        self.base_model = self._create_algorithm()

    def _create_algorithm(self):
        """Create an unfitted estimator for ``self.algorithm_name``.

        Any name that is not matched below falls through to the MLP.
        """
        if self.algorithm_name == "XGBoost":
            return xgb.XGBClassifier(
                n_estimators=100, max_depth=6, learning_rate=0.1,
                scale_pos_weight=self.class_weights[1]/self.class_weights[0],
                random_state=self.random_state, eval_metric='logloss', verbosity=0
            )
        elif self.algorithm_name == "LightGBM":
            return lgb.LGBMClassifier(
                n_estimators=100, max_depth=6, learning_rate=0.1,
                class_weight='balanced', random_state=self.random_state, verbosity=-1
            )
        elif self.algorithm_name == "CatBoost":
            return CatBoostClassifier(
                iterations=100, depth=6, learning_rate=0.1,
                class_weights=[self.class_weights[0], self.class_weights[1]],
                random_state=self.random_state, verbose=0
            )
        elif self.algorithm_name == "RandomForest":
            return RandomForestClassifier(
                n_estimators=100, max_depth=10, class_weight='balanced',
                random_state=self.random_state, n_jobs=-1
            )
        elif self.algorithm_name == "ExtraTrees":
            return ExtraTreesClassifier(
                n_estimators=100, max_depth=10, class_weight='balanced',
                random_state=self.random_state, n_jobs=-1
            )
        elif self.algorithm_name == "AdaBoost":
            return AdaBoostClassifier(
                n_estimators=100, learning_rate=1.0, random_state=self.random_state
            )
        elif self.algorithm_name == "GradientBoosting":
            return GradientBoostingClassifier(
                n_estimators=100, max_depth=6, learning_rate=0.1,
                random_state=self.random_state
            )
        elif self.algorithm_name == "LogisticRegression":
            return LogisticRegression(
                class_weight='balanced', max_iter=1000,
                random_state=self.random_state, n_jobs=-1
            )
        elif self.algorithm_name == "DecisionTree":
            return DecisionTreeClassifier(
                max_depth=10, class_weight='balanced',
                random_state=self.random_state
            )
        else:  # MLP
            return MLPClassifier(
                hidden_layer_sizes=(64, 32), activation='relu', solver='adam',
                max_iter=200, random_state=self.random_state,
                early_stopping=True, validation_fraction=0.1, verbose=False
            )

    def aggregate_neighbors(self, X, edge_index, edge_weights=None):
        """Graph convolution: X' = concat(X, Â·X).

        Args:
            X: ``(n, d)`` feature array.
            edge_index: ``(2, E)`` array; column ``e`` is an edge from node
                ``edge_index[0, e]`` to ``edge_index[1, e]``. Integer-valued
                float arrays are accepted and cast to int.
            edge_weights: optional length-``E`` weights; every edge gets
                weight 1.0 when omitted.

        Returns:
            ``(n, 2d)`` array: the original features followed by the
            weighted mean of each node's neighbours. Nodes without outgoing
            edges (and the whole graph when ``E == 0``) get zeros in the
            second half.

        Side effects:
            Stores the dense normalised adjacency in ``aggregation_weights``.
        """
        n = X.shape[0]
        if edge_index.shape[1] == 0:
            self.aggregation_weights = np.zeros((n, n))
            return np.concatenate([X, np.zeros_like(X)], axis=1)

        # Cast so that an edge array carried in a float container still works
        # as an index.
        sources, targets = np.asarray(edge_index, dtype=int)

        # Vectorised fill; for a repeated (i, j) pair the last weight wins,
        # exactly as with one-by-one assignment.
        A = np.zeros((n, n))
        if edge_weights is None:
            A[sources, targets] = 1.0
        else:
            A[sources, targets] = np.asarray(edge_weights)[:sources.size]

        # Row-normalise; rows that sum to exactly zero are left as all zeros.
        # NOTE(review): weights may be negative (cosine similarity), so a row
        # sum can be close to zero without being exactly zero — confirm that
        # this cannot occur with the graphs passed in.
        row_sum = A.sum(axis=1, keepdims=True)
        row_sum[row_sum == 0] = 1.0
        A_norm = A / row_sum

        self.aggregation_weights = A_norm
        X_agg = A_norm @ X
        return np.concatenate([X, X_agg], axis=1)

    def fit(self, X, y, edge_index, edge_weights=None):
        """Convolve ``X`` over the graph and fit the wrapped estimator.

        For "MLP" and "AdaBoost" (created without class weighting) the
        convolved training set is oversampled first. Returns ``self``.
        """
        X_conv = self.aggregate_neighbors(X, edge_index, edge_weights)

        if self.algorithm_name in ["MLP", "AdaBoost"]:
            X_conv, y = apply_class_balancing(X_conv, y, self.class_weights, self.random_state)

        self.base_model.fit(X_conv, y)
        return self

    def predict(self, X, edge_index, edge_weights=None):
        """Return hard class predictions for ``X`` on the given graph."""
        X_conv = self.aggregate_neighbors(X, edge_index, edge_weights)
        return self.base_model.predict(X_conv)

    def predict_proba(self, X, edge_index, edge_weights=None):
        """Return class probabilities for ``X`` on the given graph."""
        X_conv = self.aggregate_neighbors(X, edge_index, edge_weights)
        return self.base_model.predict_proba(X_conv)

    def get_params(self):
        """Serialize the wrapped estimator to ``{"model_bytes": <pickle bytes>}``."""
        import pickle
        return {"model_bytes": pickle.dumps(self.base_model)}

    def set_params(self, params):
        """Replace the wrapped estimator with the one pickled in ``params``.

        Does nothing when ``params`` has no non-None "model_bytes" entry.
        """
        import pickle
        if "model_bytes" in params and params["model_bytes"] is not None:
            # SECURITY: pickle.loads executes arbitrary code from its input.
            # Only pass bytes produced by get_params() inside this process;
            # never feed it data received from an untrusted party.
            self.base_model = pickle.loads(params["model_bytes"])

# =============================================================================
# PART 5: CLIENT-LEVEL DP
# =============================================================================

def add_dp_noise_to_predictions(predictions, epsilon, delta):
    """Perturb predictions with Gaussian noise and clip the result to [0, 1].

    ``sigma = sqrt(2 * ln(1.25 / delta)) / epsilon``; the noise actually drawn
    has standard deviation ``0.01 * sigma`` and comes from NumPy's global RNG.

    NOTE(review): the 0.01 factor makes the noise a hundred times smaller than
    sigma — confirm what privacy guarantee is intended.
    """
    gaussian_sigma = np.sqrt(2 * np.log(1.25 / delta)) / epsilon
    noise_scale = gaussian_sigma * 0.01
    perturbed = predictions + np.random.normal(0, noise_scale, predictions.shape)
    return np.clip(perturbed, 0, 1)

# =============================================================================
# PART 6: FEDERATED LEARNING
# =============================================================================

class FederatedSGC:
    """Coordinates per-client SGC models and one shared global SGC model.

    Each round: every client fits locally, the clients' graph-convolved
    features are pooled, the global estimator is fitted on the pool, and the
    fitted global estimator is copied back to every client.
    """

    def __init__(self, n_clients, input_dim, algorithm_name, class_weights,
                 epsilon, delta, clip_norm):
        self.n_clients = n_clients
        self.input_dim = input_dim
        self.algorithm_name = algorithm_name
        self.class_weights = class_weights
        # Privacy parameters are stored; train() in this class does not read them.
        self.epsilon = epsilon
        self.delta = delta
        self.clip_norm = clip_norm

        # One model per client, each with its own seed (42, 43, ...).
        self.clients = []
        for client_no in range(n_clients):
            self.clients.append(
                SimplifiedGraphConvolution(input_dim, algorithm_name,
                                           class_weights, 42 + client_no)
            )

        self.global_model = SimplifiedGraphConvolution(input_dim, algorithm_name,
                                                       class_weights, 999)

        # Per-round training metrics, filled by train().
        self.convergence_history = {"round": [], "train_f1": [], "train_loss": []}

    def train(self, client_datasets, rounds, apply_dp=False):
        """Run ``rounds`` rounds of training.

        Args:
            client_datasets: sequence of dicts with keys "X", "y",
                "edge_index" and "edge_weights"; entry ``i`` belongs to
                ``self.clients[i]``. Entries with zero rows are skipped.
            rounds: number of rounds.
            apply_dp: when True, Gaussian noise with standard deviation 0.01
                is added to each client's convolved features before pooling.
        """
        boosted_by_oversampling = ["MLP", "AdaBoost"]

        for round_no in range(rounds):
            # Clients that actually hold rows, paired with their position.
            active = [(cid, data) for cid, data in enumerate(client_datasets)
                      if data["X"].shape[0] != 0]

            # Step 1: local training on every active client. This stays a
            # separate pass from Step 2 because fitting may reseed NumPy's
            # global RNG, which the DP noise below draws from.
            for cid, data in active:
                self.clients[cid].fit(data["X"], data["y"],
                                      data["edge_index"], data["edge_weights"])

            # Step 2: collect each client's graph-convolved features.
            pooled_features = []
            pooled_labels = []
            for cid, data in active:
                convolved = self.clients[cid].aggregate_neighbors(
                    data["X"], data["edge_index"], data["edge_weights"]
                )
                if apply_dp:
                    convolved = convolved + np.random.normal(0, 0.01, convolved.shape)

                pooled_features.append(convolved)
                pooled_labels.append(data["y"])

            # Step 3: fit the global estimator on the pool and broadcast it.
            if pooled_features:
                X_global = np.vstack(pooled_features)
                y_global = np.hstack(pooled_labels)

                if self.algorithm_name in boosted_by_oversampling:
                    X_global, y_global = apply_class_balancing(
                        X_global, y_global, self.class_weights, 42
                    )

                self.global_model.base_model.fit(X_global, y_global)

                # Every client's estimator (including the one just fitted in
                # Step 1) is replaced by a copy of the global estimator.
                broadcast = self.global_model.get_params()
                for client in self.clients:
                    client.set_params(broadcast)

            # Convergence tracking: per-client training F1; loss is 1 - F1.
            round_f1 = []
            round_loss = []
            for cid, data in active:
                predicted = self.clients[cid].predict(data["X"], data["edge_index"],
                                                      data["edge_weights"])
                client_f1 = f1_score(data["y"], predicted, zero_division=0)
                round_f1.append(client_f1)
                round_loss.append(1 - client_f1)

            if round_f1:
                history = self.convergence_history
                history["round"].append(round_no + 1)
                history["train_f1"].append(np.mean(round_f1))
                history["train_loss"].append(np.mean(round_loss))

# =============================================================================
# PART 7: EVALUATION
# =============================================================================

def evaluate_sgc(model, X, y, edge_index, edge_weights):
    """Score a graph-aware binary classifier on one labelled split.

    Args:
        model: object exposing ``predict_proba(X, edge_index, edge_weights)``
            whose result is indexed as ``[:, 1]`` for the positive class.
        X: feature matrix; only ``X.shape[0]`` is inspected here.
        y: ground-truth labels.
        edge_index, edge_weights: graph description, forwarded to the model
            untouched.

    Returns:
        Tuple ``(accuracy, precision, recall, f1, roc_auc, pr_auc, y_pred,
        y_proba)``. For an empty split the metrics are ``0, 0, 0, 0, 0.5, 0``
        and both arrays are empty.
    """
    nothing_to_score = X.shape[0] == 0 or len(y) == 0
    if nothing_to_score:
        return 0, 0, 0, 0, 0.5, 0, np.array([]), np.array([])

    positive_scores = model.predict_proba(X, edge_index, edge_weights)[:, 1]
    # Fixed 0.5 decision threshold on the positive-class probability.
    hard_labels = (positive_scores >= 0.5).astype(int)

    accuracy = accuracy_score(y, hard_labels)
    precision = precision_score(y, hard_labels, zero_division=0)
    sensitivity = recall_score(y, hard_labels, zero_division=0)
    f_measure = f1_score(y, hard_labels, zero_division=0)

    # Ranking metrics are only computed when both classes occur in y;
    # otherwise fall back to chance-level ROC-AUC and zero PR-AUC.
    if len(np.unique(y)) > 1:
        roc_area = roc_auc_score(y, positive_scores)
        pr_area = average_precision_score(y, positive_scores)
    else:
        roc_area = 0.5
        pr_area = 0.0

    return (accuracy, precision, sensitivity, f_measure,
            roc_area, pr_area, hard_labels, positive_scores)

# =============================================================================
# PART 8: MAIN EXPERIMENT
# =============================================================================

# "\n" (not "\\n"): the doubled backslash printed a literal backslash-n
# instead of a blank line.
print("\n" + "="*100)
print(f"MAIN EXPERIMENT (N={N_RUNS} runs)")
print("="*100)

# Per-run metric rows for each SGC variant; one row is appended per run.
results_fed_sgc = []
results_fed_sgc_dp = []
results_cent_sgc = []

# One result list per baseline algorithm. The None check must wrap the whole
# comprehension: written as a trailing filter it was evaluated only after
# comparison_df['Algorithm'] had already been subscripted, so a missing
# comparison table raised TypeError instead of yielding an empty dict.
if comparison_df is not None:
    results_baselines = {name: [] for name in comparison_df['Algorithm'].values}
else:
    results_baselines = {}

# Store client distribution info
client_distribution_info = []

# Filled during the run loop; pooled test labels / predictions feed the plots.
convergence_tracker = None
all_y_test = []
all_ypred_fg = []
all_yproba_fg = []

def _make_baseline_model(name, seed):
    """Build the unfitted baseline classifier configured for ``name``.

    Args:
        name: algorithm label as it appears in ``comparison_df['Algorithm']``.
        seed: random state passed to the estimator.

    Returns:
        A fresh estimator, or ``None`` when ``name`` has no configuration here,
        so the caller can skip it rather than reuse a stale model.
    """
    if name == "XGBoost":
        return xgb.XGBClassifier(
            n_estimators=100, max_depth=6, learning_rate=0.1,
            scale_pos_weight=class_weight_1/class_weight_0,
            random_state=seed, eval_metric='logloss', verbosity=0
        )
    if name == "LightGBM":
        return lgb.LGBMClassifier(
            n_estimators=100, max_depth=6, learning_rate=0.1,
            class_weight='balanced', random_state=seed, verbosity=-1
        )
    if name == "CatBoost":
        return CatBoostClassifier(
            iterations=100, depth=6, learning_rate=0.1,
            class_weights=[class_weight_0, class_weight_1],
            random_state=seed, verbose=0
        )
    if name == "RandomForest":
        return RandomForestClassifier(
            n_estimators=100, max_depth=10, class_weight='balanced',
            random_state=seed, n_jobs=-1
        )
    if name == "ExtraTrees":
        return ExtraTreesClassifier(
            n_estimators=100, max_depth=10, class_weight='balanced',
            random_state=seed, n_jobs=-1
        )
    if name == "AdaBoost":
        return AdaBoostClassifier(
            n_estimators=100, learning_rate=1.0, random_state=seed
        )
    if name == "GradientBoosting":
        return GradientBoostingClassifier(
            n_estimators=100, max_depth=6, learning_rate=0.1, random_state=seed
        )
    if name == "LogisticRegression":
        return LogisticRegression(
            class_weight='balanced', max_iter=1000, random_state=seed, n_jobs=-1
        )
    if name == "DecisionTree":
        return DecisionTreeClassifier(
            max_depth=10, class_weight='balanced', random_state=seed
        )
    if name == "MLP":
        return MLPClassifier(
            hidden_layer_sizes=(64, 32), activation='relu', solver='adam',
            max_iter=200, random_state=seed, early_stopping=True,
            validation_fraction=0.1, verbose=False
        )
    return None


# One independent repetition per seed: fresh stratified split, centralized SGC,
# plain baselines, then federated SGC without and with DP noise.
for run_id, seed in enumerate(RANDOM_SEEDS, start=1):
    # "\n" (not "\\n") so the banner is preceded by a real blank line.
    print(f"\n{'='*30} RUN {run_id}/{N_RUNS} {'='*30}")
    np.random.seed(seed)

    idx_all = np.arange(len(df_enc))
    y_all = df_enc["Severity"].values
    train_idx, test_idx = train_test_split(idx_all, test_size=TEST_SIZE,
                                           stratify=y_all, random_state=seed)

    train_df = df_enc.iloc[train_idx].reset_index(drop=True)
    test_df = df_enc.iloc[test_idx].reset_index(drop=True)

    X_train_ohe = train_df[ohe_cols].values
    X_test_ohe = test_df[ohe_cols].values
    y_train = train_df["Severity"].values
    y_test = test_df["Severity"].values

    # --- Centralized SGC ---
    # Scaler is fitted on the training split and only applied to the test split.
    scaler_cent = LocalStandardScaler()
    X_train_cent = scaler_cent.fit_transform(X_train_ohe)
    X_test_cent = scaler_cent.transform(X_test_ohe)

    # Train and test graphs are built separately, each from its own rows only.
    edge_train_cent, ew_train_cent = build_local_similarity_graph(
        X_train_cent, K_NEIGHBORS_OPTIMAL, seed
    )
    edge_test_cent, ew_test_cent = build_local_similarity_graph(
        X_test_cent, K_NEIGHBORS_OPTIMAL, seed
    )

    sgc_cent = SimplifiedGraphConvolution(X_train_cent.shape[1],
                                         BEST_ALGORITHM_NAME, CLASS_WEIGHTS, seed)
    sgc_cent.fit(X_train_cent, y_train, edge_train_cent, ew_train_cent)

    acc_c, _, _, f1_c, roc_c, pr_c, _, _ = evaluate_sgc(
        sgc_cent, X_test_cent, y_test, edge_test_cent, ew_test_cent
    )
    results_cent_sgc.append([acc_c, f1_c, roc_c, pr_c])

    # --- Baselines (All Algorithms) ---
    # Trained on the raw one-hot features, without graph aggregation.
    if comparison_df is not None:
        for name in comparison_df['Algorithm'].values:
            # The selected algorithm is already covered by the SGC variants.
            if name == BEST_ALGORITHM_NAME:
                continue

            model = _make_baseline_model(name, seed)
            if model is None:
                # Previously an unrecognised name fell through the if/elif
                # chain and silently re-fitted the preceding algorithm's model
                # (or raised NameError on the first iteration).
                print(f"  [skip] no baseline configuration for '{name}'")
                continue

            model.fit(X_train_ohe, y_train)
            y_proba = model.predict_proba(X_test_ohe)[:, 1]
            y_pred = (y_proba >= 0.5).astype(int)

            results_baselines[name].append([
                accuracy_score(y_test, y_pred),
                f1_score(y_test, y_pred, zero_division=0),
                roc_auc_score(y_test, y_proba),
                average_precision_score(y_test, y_proba)
            ])

    # --- Federated Setup ---
    client_indices = create_non_iid_clients(train_df, NUM_CLIENTS, seed)

    client_datasets = []
    for cid, global_idx in enumerate(client_indices):
        # Clients that received no rows are dropped entirely.
        if len(global_idx) == 0:
            continue

        Xc_ohe = X_train_ohe[global_idx]
        yc = y_train[global_idx]

        # Each client scales with statistics of its own rows only.
        scaler_local = LocalStandardScaler()
        Xc = scaler_local.fit_transform(Xc_ohe)

        edge_local, ew_local = build_local_similarity_graph(Xc, K_NEIGHBORS_OPTIMAL, seed+cid)

        client_datasets.append({
            "X": Xc, "y": yc,
            "edge_index": edge_local,
            "edge_weights": ew_local
        })

        # Track client distribution (first run only, to avoid duplicate rows)
        if run_id == 1:
            n_class_0 = np.sum(yc == 0)
            n_class_1 = np.sum(yc == 1)
            client_distribution_info.append({
                'Client': f'Client {cid+1}',
                'Samples': len(yc),
                'Class_0': n_class_0,
                'Class_1': n_class_1,
                # max(..., 1) guards against division by zero for
                # clients with no positive samples.
                'Imbalance_Ratio': n_class_0 / max(n_class_1, 1)
            })

    # --- Fed-SGC (No DP) ---
    fed_sgc = FederatedSGC(
        len(client_datasets), X_train_ohe.shape[1], BEST_ALGORITHM_NAME,
        CLASS_WEIGHTS, DP_EPSILON, DP_DELTA, DP_CLIP_NORM
    )

    fed_sgc.train(client_datasets, FED_ROUNDS, apply_dp=False)

    # Only the last run's convergence history is kept for plotting.
    if run_id == N_RUNS:
        convergence_tracker = fed_sgc.convergence_history

    # NOTE(review): unlike the centralized path, this scaler is *fitted* on the
    # test split. Presumably intentional because no global training scaler
    # exists in the federated setting - confirm, since it uses test statistics.
    scaler_test = LocalStandardScaler()
    X_test_scaled = scaler_test.fit_transform(X_test_ohe)
    edge_test, ew_test = build_local_similarity_graph(X_test_scaled, K_NEIGHBORS_OPTIMAL, seed)

    acc_fg, _, _, f1_fg, roc_fg, pr_fg, ypred_fg, yproba_fg = evaluate_sgc(
        fed_sgc.global_model, X_test_scaled, y_test, edge_test, ew_test
    )
    results_fed_sgc.append([acc_fg, f1_fg, roc_fg, pr_fg])

    # Store for aggregate plots
    all_y_test.extend(y_test)
    all_ypred_fg.extend(ypred_fg)
    all_yproba_fg.extend(yproba_fg)

    # --- Fed-SGC WITH DP ---
    # Same clients and test graph; only the apply_dp flag differs.
    fed_sgc_dp = FederatedSGC(
        len(client_datasets), X_train_ohe.shape[1], BEST_ALGORITHM_NAME,
        CLASS_WEIGHTS, DP_EPSILON, DP_DELTA, DP_CLIP_NORM
    )

    fed_sgc_dp.train(client_datasets, FED_ROUNDS, apply_dp=True)

    acc_dp, _, _, f1_dp, roc_dp, pr_dp, _, _ = evaluate_sgc(
        fed_sgc_dp.global_model, X_test_scaled, y_test, edge_test, ew_test
    )
    results_fed_sgc_dp.append([acc_dp, f1_dp, roc_dp, pr_dp])

    print(f"  Central-SGC: F1={f1_c:.3f} | Fed-SGC: F1={f1_fg:.3f} | Fed+DP: F1={f1_dp:.3f}")

# =============================================================================
# PART 9: COMPREHENSIVE RESULTS
# =============================================================================

# Column order matches the per-run rows appended during the experiment loop.
cols = ["Accuracy", "F1", "ROC_AUC", "PR_AUC"]

res_fed = pd.DataFrame(results_fed_sgc, columns=cols)
res_fed_dp = pd.DataFrame(results_fed_sgc_dp, columns=cols)
res_cent = pd.DataFrame(results_cent_sgc, columns=cols)

# "\n" (not "\\n"): the doubled backslash printed a literal backslash-n.
print("\n" + "="*100)
print("FINAL RESULTS - COMPREHENSIVE COMPARISON")
print("="*100)

# Build comprehensive summary (SGC variants first, then baselines with data)
all_results = {
    f"Fed-SGC ({BEST_ALGORITHM_NAME})": res_fed,
    f"Fed-SGC+DP ({BEST_ALGORITHM_NAME})": res_fed_dp,
    f"Central-SGC ({BEST_ALGORITHM_NAME})": res_cent,
}

# Loop variables carry specific names so module-level names as generic as
# `df`, `name` or `results` are not rebound as a side effect of reporting.
for baseline_name, baseline_rows in results_baselines.items():
    # Baselines that were skipped in every run have no rows.
    if len(baseline_rows) > 0:
        all_results[baseline_name] = pd.DataFrame(baseline_rows, columns=cols)

summary = pd.DataFrame({
    model_name: (model_runs.mean().round(4).astype(str) + " ± "
                 + model_runs.std().round(4).astype(str))
    for model_name, model_runs in all_results.items()
}, index=cols).T

print(summary.to_string())

# Save comprehensive comparison table
print("\n" + "="*100)
print("SAVING COMPREHENSIVE COMPARISON TABLE")
print("="*100)

comparison_table = []
for model_name, model_runs in all_results.items():
    mean_f1 = model_runs['F1'].mean()
    mean_roc = model_runs['ROC_AUC'].mean()
    comparison_table.append({
        'Algorithm': model_name,
        'F1': mean_f1,
        'ROC-AUC': mean_roc,
        # Recall and wall-clock time are not recorded by the run loop.
        # NaN marks them as "not measured"; the previous 0.0 placeholder was
        # written to the CSV as if it were a real result.
        'Recall': np.nan,
        'Time(s)': np.nan,
        # Score uses the F1 and ROC-AUC terms only (no recall term), so it is
        # on a 0-0.8 scale.
        'Score': 0.5 * mean_f1 + 0.3 * mean_roc
    })

final_comparison_df = pd.DataFrame(comparison_table)
final_comparison_df = final_comparison_df.sort_values('Score', ascending=False)
final_comparison_df.to_csv('algorithm_comparison_final.csv', index=False)

print("\n" + final_comparison_df.to_string(index=False))
print("\n✓ Saved to: algorithm_comparison_final.csv")

# Save client distribution
client_dist_df = pd.DataFrame(client_distribution_info)
client_dist_df.to_csv('client_distribution.csv', index=False)
print("✓ Saved to: client_distribution.csv")

# =============================================================================
# PART 10: SEPARATE PUBLICATION-QUALITY FIGURES
# =============================================================================

print("\\n" + "="*100)
print("GENERATING PUBLICATION-QUALITY FIGURES (SEPARATE)")
print("="*100)

# Convert to numpy arrays
all_y_test = np.array(all_y_test)
all_ypred_fg = np.array(all_ypred_fg)
all_yproba_fg = np.array(all_yproba_fg)

# ============ FIGURE 1: ROC Curve ============
# "\n" (not "\\n"): the doubled backslash printed a literal backslash-n.
print("\n1. Generating ROC Curve...")
fig1, ax1 = plt.subplots(figsize=(10, 8))

# Curve and AUC are both computed from the same label / score arrays.
fpr_fg, tpr_fg, _ = roc_curve(all_y_test, all_yproba_fg)
roc_auc_fg = roc_auc_score(all_y_test, all_yproba_fg)

ax1.plot(fpr_fg, tpr_fg, label=f'Fed-SGC (AUC={roc_auc_fg:.3f})',
         linewidth=3, color='darkgreen')
# Diagonal = chance-level reference.
ax1.plot([0, 1], [0, 1], 'k--', alpha=0.3, linewidth=2, label='Random Classifier')
ax1.set_xlabel("False Positive Rate", fontsize=14, fontweight='bold')
ax1.set_ylabel("True Positive Rate", fontsize=14, fontweight='bold')
ax1.set_title("Receiver Operating Characteristic (ROC) Curve", fontsize=16, fontweight='bold')
ax1.legend(fontsize=12, loc='lower right')
ax1.grid(alpha=0.3, linestyle='--')

# Save and close through the figure handle rather than pyplot's implicit
# "current figure".
fig1.tight_layout()
fig1.savefig("figure1_roc_curve.png", dpi=300, bbox_inches='tight')
print("   ✓ Saved: figure1_roc_curve.png")
plt.close(fig1)

# ============ FIGURE 2: Precision-Recall Curve ============
# "\n" (not "\\n"): the doubled backslash printed a literal backslash-n.
print("\n2. Generating Precision-Recall Curve...")
fig2, ax2 = plt.subplots(figsize=(10, 8))

precision, recall, _ = precision_recall_curve(all_y_test, all_yproba_fg)
# Average precision summarises the curve in the legend.
pr_auc = average_precision_score(all_y_test, all_yproba_fg)

ax2.plot(recall, precision, label=f'Fed-SGC (AP={pr_auc:.3f})',
         linewidth=3, color='darkblue')
ax2.set_xlabel("Recall", fontsize=14, fontweight='bold')
ax2.set_ylabel("Precision", fontsize=14, fontweight='bold')
ax2.set_title("Precision-Recall Curve", fontsize=16, fontweight='bold')
ax2.legend(fontsize=12, loc='lower left')
ax2.grid(alpha=0.3, linestyle='--')

# Explicit figure handle instead of pyplot's implicit current figure.
fig2.tight_layout()
fig2.savefig("figure2_precision_recall.png", dpi=300, bbox_inches='tight')
print("   ✓ Saved: figure2_precision_recall.png")
plt.close(fig2)

# ============ FIGURE 3: Confusion Matrix ============
# "\n" (not "\\n"): the doubled backslash printed a literal backslash-n.
print("\n3. Generating Confusion Matrix...")
fig3, ax3 = plt.subplots(figsize=(10, 8))

# Pin the label order so the matrix is always 2x2 and lines up with the two
# hard-coded tick labels, even if one class is absent from the inputs.
class_names = ['No Amputation', 'Amputation']
cm = confusion_matrix(all_y_test, all_ypred_fg, labels=[0, 1])
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=True, ax=ax3,
            xticklabels=class_names,
            yticklabels=class_names,
            annot_kws={'size': 16, 'fontweight': 'bold'})
ax3.set_xlabel("Predicted Label", fontsize=14, fontweight='bold')
ax3.set_ylabel("True Label", fontsize=14, fontweight='bold')
ax3.set_title("Confusion Matrix - Federated SGC", fontsize=16, fontweight='bold')

fig3.tight_layout()
fig3.savefig("figure3_confusion_matrix.png", dpi=300, bbox_inches='tight')
print("   ✓ Saved: figure3_confusion_matrix.png")
plt.close(fig3)

# ============ FIGURE 4: Model Comparison ============
# "\n" (not "\\n"): the doubled backslash printed a literal backslash-n.
print("\n4. Generating Model Comparison...")
fig4, ax4 = plt.subplots(figsize=(12, 8))

# First ten rows of the final comparison table, in its current row order.
top_models = final_comparison_df.head(10)
# Federated variants are highlighted in green, everything else in blue.
colors = ['darkgreen' if 'Fed-SGC' in name else 'steelblue' for name in top_models['Algorithm']]

bar_positions = range(len(top_models))
bars = ax4.barh(bar_positions, top_models['F1'], alpha=0.8,
                edgecolor='black', linewidth=1.5, color=colors)
ax4.set_yticks(bar_positions)
ax4.set_yticklabels(top_models['Algorithm'], fontsize=11)
ax4.set_xlabel("F1-Score", fontsize=14, fontweight='bold')
ax4.set_title("Model Performance Comparison (Top 10)", fontsize=16, fontweight='bold')
ax4.grid(axis='x', alpha=0.3, linestyle='--')
# First row of the table at the top of the chart.
ax4.invert_yaxis()

# Add value labels just right of each bar
for bar_pos, f1_value in enumerate(top_models['F1']):
    ax4.text(f1_value + 0.005, bar_pos, f"{f1_value:.3f}",
             va='center', fontsize=10, fontweight='bold')

fig4.tight_layout()
fig4.savefig("figure4_model_comparison.png", dpi=300, bbox_inches='tight')
print("   ✓ Saved: figure4_model_comparison.png")
plt.close(fig4)

# ============ FIGURE 5: Privacy-Utility Tradeoff ============
# "\n" (not "\\n"): the doubled backslash printed a literal backslash-n.
print("\n5. Generating Privacy-Utility Tradeoff...")
fig5, ax5 = plt.subplots(figsize=(10, 8))

# Rows with epsilon == inf are excluded from the plot.
pareto_plot = pareto_df[pareto_df['epsilon'] != np.inf]
ax5.plot(pareto_plot['privacy_strength'], pareto_plot['f1'],
         marker='o', markersize=12, linewidth=3, color='purple',
         label='Privacy-Utility Frontier')

# Tag every point with its epsilon value.
for _, row in pareto_plot.iterrows():
    ax5.annotate(f"ε={row['epsilon']}",
                xy=(row['privacy_strength'], row['f1']),
                xytext=(10, -10), textcoords='offset points',
                fontsize=11, fontweight='bold',
                bbox=dict(boxstyle='round,pad=0.5', facecolor='yellow', alpha=0.3))

ax5.set_xlabel("Privacy Strength (1/ε)", fontsize=14, fontweight='bold')
ax5.set_ylabel("F1-Score", fontsize=14, fontweight='bold')
ax5.set_title("Privacy-Utility Tradeoff Analysis", fontsize=16, fontweight='bold')
ax5.legend(fontsize=12)
ax5.grid(alpha=0.3, linestyle='--')

fig5.tight_layout()
fig5.savefig("figure5_privacy_utility.png", dpi=300, bbox_inches='tight')
print("   ✓ Saved: figure5_privacy_utility.png")
plt.close(fig5)

# ============ FIGURE 6: Convergence Analysis ============
# "\n" (not "\\n"): the doubled backslash printed a literal backslash-n.
print("\n6. Generating Convergence Analysis...")
fig6, (ax6a, ax6b) = plt.subplots(1, 2, figsize=(16, 6))

# Without a recorded history both panels stay empty (the file is still saved).
if convergence_tracker and len(convergence_tracker["round"]) > 0:
    # (axis, history key, marker, colour, legend label, y label, title)
    convergence_panels = [
        (ax6a, "train_f1", 'o', 'darkblue', 'Training F1-Score',
         "F1-Score", "Federated Learning Convergence (F1)"),
        (ax6b, "train_loss", 's', 'darkred', 'Training Loss',
         "Loss (1 - F1)", "Federated Learning Convergence (Loss)"),
    ]
    for (panel_ax, history_key, panel_marker, panel_color,
         panel_label, panel_ylabel, panel_title) in convergence_panels:
        panel_ax.plot(convergence_tracker["round"], convergence_tracker[history_key],
                      marker=panel_marker, linewidth=3, color=panel_color,
                      markersize=10, label=panel_label)
        panel_ax.set_xlabel("Communication Round", fontsize=14, fontweight='bold')
        panel_ax.set_ylabel(panel_ylabel, fontsize=14, fontweight='bold')
        panel_ax.set_title(panel_title, fontsize=16, fontweight='bold')
        panel_ax.legend(fontsize=12)
        panel_ax.grid(alpha=0.3, linestyle='--')

fig6.tight_layout()
fig6.savefig("figure6_convergence.png", dpi=300, bbox_inches='tight')
print("   ✓ Saved: figure6_convergence.png")
plt.close(fig6)

# ============ FIGURE 7: Algorithm Selection Scores ============
# "\n" (not "\\n"): the doubled backslash printed a literal backslash-n.
print("\n7. Generating Algorithm Selection Scores...")
fig7, ax7 = plt.subplots(figsize=(12, 8))

# With no comparison table the axes stay empty (the file is still saved).
if comparison_df is not None:
    x_pos = np.arange(len(comparison_df))
    # Highlight the bar that actually holds the maximum score. Colouring
    # position 0 was only correct when the table happened to be sorted by
    # Score in descending order.
    best_pos = int(np.argmax(comparison_df['Score'].values))
    colors_algo = ['gold' if i == best_pos else 'steelblue' for i in range(len(comparison_df))]

    bars = ax7.bar(x_pos, comparison_df['Score'], alpha=0.8,
                   edgecolor='black', linewidth=1.5, color=colors_algo)
    ax7.set_xticks(x_pos)
    ax7.set_xticklabels(comparison_df['Algorithm'], rotation=45, ha='right', fontsize=11)
    ax7.set_ylabel("Combined Score (F1×0.5 + ROC×0.3 + Recall×0.2)", fontsize=12, fontweight='bold')
    ax7.set_title("Algorithm Selection Scores", fontsize=16, fontweight='bold')
    ax7.grid(axis='y', alpha=0.3, linestyle='--')

    # Highlight best: horizontal reference line at the top score
    ax7.axhline(y=comparison_df['Score'].max(), color='red', linestyle='--',
                linewidth=2, alpha=0.5, label='Best Score')
    ax7.legend(fontsize=12)

fig7.tight_layout()
fig7.savefig("figure7_algorithm_selection.png", dpi=300, bbox_inches='tight')
print("   ✓ Saved: figure7_algorithm_selection.png")
plt.close(fig7)

# ============ FIGURE 8: F1-Score Distribution ============
# "\n" (not "\\n"): the doubled backslash printed a literal backslash-n.
print("\n8. Generating F1-Score Distribution...")
fig8, ax8 = plt.subplots(figsize=(14, 8))

# First six entries of all_results in insertion order (not ranked by score),
# limited to six for readability.
top_6_names = list(all_results.keys())[:6]
box_data = [all_results[name]['F1'].values for name in top_6_names]
# Strip the " (<best algorithm>)" suffix to keep tick labels short.
box_labels = [name.replace(f' ({BEST_ALGORITHM_NAME})', '') for name in top_6_names]

# Tick labels are applied via set_xticklabels below. The boxplot `labels`
# keyword is deprecated in recent Matplotlib (renamed `tick_labels`), so it is
# not passed here; this works on both old and new versions.
bp = ax8.boxplot(box_data, patch_artist=True,
                 notch=True, showmeans=True,
                 boxprops=dict(facecolor='lightblue', edgecolor='black', linewidth=1.5),
                 whiskerprops=dict(color='black', linewidth=1.5),
                 capprops=dict(color='black', linewidth=1.5),
                 medianprops=dict(color='red', linewidth=2),
                 meanprops=dict(marker='D', markerfacecolor='green', markersize=8))

ax8.set_xticklabels(box_labels, rotation=30, ha='right', fontsize=11)
ax8.set_ylabel("F1-Score", fontsize=14, fontweight='bold')
ax8.set_title("F1-Score Distribution Across Runs", fontsize=16, fontweight='bold')
ax8.grid(axis='y', alpha=0.3, linestyle='--')

fig8.tight_layout()
fig8.savefig("figure8_f1_distribution.png", dpi=300, bbox_inches='tight')
print("   ✓ Saved: figure8_f1_distribution.png")
plt.close(fig8)

# ============ FIGURE 9: Client Data Distribution ============
# "\n" (not "\\n"): the doubled backslash printed a literal backslash-n.
print("\n9. Generating Client Data Distribution...")
fig9, (ax9a, ax9b) = plt.subplots(1, 2, figsize=(16, 6))

# With no recorded client info both panels stay empty (file is still saved).
if len(client_distribution_info) > 0:
    client_dist_df = pd.DataFrame(client_distribution_info)
    x_pos = np.arange(len(client_dist_df))

    # Left panel: samples per client
    ax9a.bar(range(len(client_dist_df)), client_dist_df['Samples'],
             alpha=0.8, edgecolor='black', linewidth=1.5, color='teal')
    ax9a.set_xticks(range(len(client_dist_df)))
    ax9a.set_xticklabels(client_dist_df['Client'], rotation=45, ha='right')
    ax9a.set_ylabel("Number of Samples", fontsize=14, fontweight='bold')
    ax9a.set_title("Data Distribution Across Clients", fontsize=16, fontweight='bold')
    ax9a.grid(axis='y', alpha=0.3, linestyle='--')

    # Right panel: class distribution per client (stacked bar; class 1 sits
    # on top of class 0 via `bottom`)
    ax9b.bar(x_pos, client_dist_df['Class_0'], label='Class 0 (No Amp)',
             alpha=0.8, edgecolor='black', linewidth=1.5, color='skyblue')
    ax9b.bar(x_pos, client_dist_df['Class_1'], bottom=client_dist_df['Class_0'],
             label='Class 1 (Amp)', alpha=0.8, edgecolor='black', linewidth=1.5,
             color='salmon')
    ax9b.set_xticks(x_pos)
    ax9b.set_xticklabels(client_dist_df['Client'], rotation=45, ha='right')
    ax9b.set_ylabel("Number of Samples", fontsize=14, fontweight='bold')
    ax9b.set_title("Class Distribution per Client", fontsize=16, fontweight='bold')
    ax9b.legend(fontsize=12)
    ax9b.grid(axis='y', alpha=0.3, linestyle='--')

fig9.tight_layout()
fig9.savefig("figure9_client_distribution.png", dpi=300, bbox_inches='tight')
print("   ✓ Saved: figure9_client_distribution.png")
plt.close(fig9)

# ============ FIGURE 10: Feature Category Breakdown ============
# "\n" (not "\\n"): the doubled backslash printed a literal backslash-n.
print("\n10. Generating Feature Category Breakdown...")
fig10, ax10 = plt.subplots(figsize=(10, 8))

# Pull the first wedge slightly out of the pie; the rest stay flush.
explode = [0.05 if i == 0 else 0 for i in range(len(category_df))]
colors_pie = plt.cm.Set3(range(len(category_df)))

wedges, texts, autotexts = ax10.pie(category_df['Unique_Values'],
                                     labels=category_df['Column'],
                                     autopct='%1.1f%%', startangle=90,
                                     colors=colors_pie, explode=explode,
                                     textprops={'fontsize': 12, 'fontweight': 'bold'})

# "\n" (not "\\n"): with the doubled backslash the saved figure showed a
# literal backslash-n in its title instead of a two-line title.
ax10.set_title(f"Feature Category Distribution\n(Total: {X_encoded.shape[1]} features)",
              fontsize=16, fontweight='bold')

# Add legend with counts, placed outside the axes on the right
legend_labels = [f"{row['Column']}: {row['Unique_Values']}"
                for _, row in category_df.iterrows()]
ax10.legend(legend_labels, loc='upper left', bbox_to_anchor=(1, 1), fontsize=10)

fig10.tight_layout()
fig10.savefig("figure10_feature_breakdown.png", dpi=300, bbox_inches='tight')
print("   ✓ Saved: figure10_feature_breakdown.png")
plt.close(fig10)

# ============ FIGURE 11: Performance Metrics Radar Chart ============
# "\n" (not "\\n"): the doubled backslash printed a literal backslash-n.
print("\n11. Generating Performance Metrics Radar Chart...")
fig11, ax11 = plt.subplots(figsize=(10, 10), subplot_kw=dict(projection='polar'))

# Select top 3 models (first three rows of the final comparison table)
top_3_models = final_comparison_df.head(3)['Algorithm'].values
categories = ['F1', 'ROC-AUC', 'Accuracy', 'PR-AUC']
# Result-frame column behind each displayed category, in the same order.
metric_columns = ['F1', 'ROC_AUC', 'Accuracy', 'PR_AUC']
N = len(categories)

# One spoke per metric; the first angle is repeated to close the polygon.
angles = [n / float(N) * 2 * np.pi for n in range(N)]
angles += angles[:1]

colors_radar = ['darkgreen', 'steelblue', 'orange']

for idx, model_name in enumerate(top_3_models):
    if model_name in all_results:
        values = [all_results[model_name][column].mean() for column in metric_columns]
        values += values[:1]  # close the polygon

        ax11.plot(angles, values, 'o-', linewidth=2, label=model_name,
                 color=colors_radar[idx], markersize=8)
        ax11.fill(angles, values, alpha=0.15, color=colors_radar[idx])

ax11.set_xticks(angles[:-1])
ax11.set_xticklabels(categories, fontsize=12, fontweight='bold')
ax11.set_ylim(0, 1)
# "\n" (not "\\n"): with the doubled backslash the saved figure showed a
# literal backslash-n in its title instead of a two-line title.
ax11.set_title("Performance Metrics Comparison\n(Top 3 Models)",
              fontsize=16, fontweight='bold', pad=20)
ax11.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1), fontsize=10)
ax11.grid(True, linestyle='--', alpha=0.3)

fig11.tight_layout()
fig11.savefig("figure11_radar_chart.png", dpi=300, bbox_inches='tight')
print("   ✓ Saved: figure11_radar_chart.png")
plt.close(fig11)

# ============ FIGURE 12: Statistical Significance Heatmap ============
# "\n" (not "\\n"): the doubled backslash printed a literal backslash-n.
print("\n12. Generating Statistical Significance Heatmap...")
fig12, ax12 = plt.subplots(figsize=(10, 8))

# Pairwise p-values for the first five entries of all_results (insertion order).
top_5_for_sig = list(all_results.keys())[:5]

# Start from 1.0 everywhere: a model compared with itself shows no difference,
# so the diagonal must not read as "significant" (it was previously 0).
p_value_matrix = np.ones((len(top_5_for_sig), len(top_5_for_sig)))

for i, model1 in enumerate(top_5_for_sig):
    # The two-sided test is symmetric, so compute each unordered pair once
    # and mirror it.
    for j, model2 in enumerate(top_5_for_sig[i + 1:], start=i + 1):
        try:
            _, p = wilcoxon(all_results[model1]['F1'].values,
                           all_results[model2]['F1'].values)
        except ValueError:
            # Raised e.g. for unequal lengths or, on some SciPy versions,
            # all-zero differences. A bare `except:` here also swallowed
            # KeyboardInterrupt and genuine programming errors.
            p = 1.0
        if np.isnan(p):
            # Treat an undefined p-value the same as a failed test.
            p = 1.0
        p_value_matrix[i, j] = p
        p_value_matrix[j, i] = p

# Create heatmap; labels drop the "(<algorithm>)" suffix
short_labels = [m.split('(')[0].strip() for m in top_5_for_sig]
sns.heatmap(p_value_matrix, annot=True, fmt='.4f', cmap='RdYlGn_r',
            xticklabels=short_labels,
            yticklabels=short_labels,
            ax=ax12, cbar_kws={'label': 'p-value'},
            vmin=0, vmax=0.1, linewidths=1, linecolor='black')

# "\n" (not "\\n") so the title renders on two lines in the saved figure.
ax12.set_title("Statistical Significance (Wilcoxon Test)\np < 0.05 indicates significant difference",
              fontsize=14, fontweight='bold')

fig12.tight_layout()
fig12.savefig("figure12_significance_heatmap.png", dpi=300, bbox_inches='tight')
print("   ✓ Saved: figure12_significance_heatmap.png")
plt.close(fig12)

# "\n" (not "\\n"): the doubled backslash printed a literal backslash-n
# instead of a blank line before each banner.
print("\n" + "="*100)
print("ALL FIGURES GENERATED SUCCESSFULLY!")
print("="*100)

# =============================================================================
# STATISTICAL SIGNIFICANCE
# =============================================================================

print("\n" + "="*100)
print("STATISTICAL SIGNIFICANCE TESTS")
print("="*100)

def wilcoxon_test(name1, df1, name2, df2, metric="F1"):
    """Print a paired Wilcoxon signed-rank comparison of two result tables.

    Args:
        name1, name2: display names of the two models.
        df1, df2: result frames holding one row per run; ``metric`` must be a
            column of both and the rows are compared pairwise.
        metric: column to compare (default ``"F1"``).

    Prints one line with the mean difference (``df1 - df2``), the p-value and
    a significance marker at the 0.05 level. If SciPy rejects the input with a
    ``ValueError``, an "Unable to compute" line including the reason is
    printed instead. Returns ``None``.
    """
    try:
        # Keep the try body to the call that is expected to fail on bad data.
        _, p = wilcoxon(df1[metric].values, df2[metric].values)
    except ValueError as err:
        # Narrowed from a bare `except:`, which also hid KeyboardInterrupt and
        # programming errors such as a misspelled metric column.
        print(f"{name1:30s} vs {name2:30s}: Unable to compute ({err})")
        return

    mean_diff = df1[metric].mean() - df2[metric].mean()
    sig = "✓ SIG" if p < 0.05 else "✗ NS"
    print(f"{name1:30s} vs {name2:30s}: Δ={mean_diff:+.4f}, p={p:.4e} [{sig}]")

# Pairwise comparisons between the three SGC variants (F1 by default).
wilcoxon_test("Fed-SGC", res_fed, "Central-SGC", res_cent)
wilcoxon_test("Fed-SGC", res_fed, "Fed-SGC+DP", res_fed_dp)
wilcoxon_test("Central-SGC", res_cent, "Fed-SGC+DP", res_fed_dp)

# "\n" (not "\\n"): the doubled backslash printed a literal backslash-n
# instead of a blank line.
print("\n" + "="*100)
print("EXPERIMENT COMPLETE - FINAL VERSION FOR 17K DATASET")
print("="*100)
print(f"\n✓ Dataset: {len(df_enc):,} records")
print(f"✓ Features: {X_encoded.shape[1]:,} dimensions")
print(f"✓ Best Algorithm: {BEST_ALGORITHM_NAME}")
print(f"✓ Fed-SGC F1: {res_fed['F1'].mean():.3f} ± {res_fed['F1'].std():.3f}")
print(f"✓ Optimal Clients: {NUM_CLIENTS}")
print(f"✓ Optimal Rounds: {FED_ROUNDS}")
print(f"✓ Algorithms Tested: {len(comparison_df) if comparison_df is not None else 0}")

# Artefacts written by the results and figure sections, in creation order.
generated_files = (
    "algorithm_comparison_final.csv",
    "client_distribution.csv",
    "figure1_roc_curve.png",
    "figure2_precision_recall.png",
    "figure3_confusion_matrix.png",
    "figure4_model_comparison.png",
    "figure5_privacy_utility.png",
    "figure6_convergence.png",
    "figure7_algorithm_selection.png",
    "figure8_f1_distribution.png",
    "figure9_client_distribution.png",
    "figure10_feature_breakdown.png",
    "figure11_radar_chart.png",
    "figure12_significance_heatmap.png",
)
print("\n✓ Generated Files:")
for generated_file in generated_files:
    print(f"  - {generated_file}")
print("="*100)


FEDERATED SIMPLIFIED GRAPH CONVOLUTION (SGC) FOR CONSTRUCTION SAFETY
FINAL VERSION - 17K Dataset with Comprehensive Validation
\nTotal records: 17,663
Class distribution:
Severity
0    0.847478
1    0.152522
\nClass weights:
  Class 0: 0.5900
  Class 1: 3.2782
\nOne-hot encoded features: 1,572 dimensions
DATA VERIFICATION & STATISTICS
\nDataset shape:
  Total records (rows):     17,663
  Total features (columns): 1,572
\nOriginal categorical columns breakdown:
  Nature              :  155 unique values (  9.9%)
  Part of Body        :  165 unique values ( 10.5%)
  Event               :  336 unique values ( 21.4%)
  Source              :  817 unique values ( 52.0%)
  Primary NAICS       :   45 unique values (  2.9%)
  State               :   54 unique values (  3.4%)
\nTotal one-hot features: 1,572
\n✓ All 17,663 records preserved!
✓ Each record has 1,572 features
\nClass Imbalance Analysis:
  Majority class (0):  14,969 (84.75%)
  Minority class (1):  2,694 (15.25%)
  Imbalance ratio: 

In [None]:
# ============================================================================
# ADDITIONAL PUBLICATION-QUALITY FIGURES
# Run this AFTER your main experiment completes
# ============================================================================

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (confusion_matrix, roc_curve, precision_recall_curve,
                             classification_report, roc_auc_score, f1_score)
from sklearn.calibration import calibration_curve
import warnings
warnings.filterwarnings("ignore")

# Cell banner.
print("=" * 100)
print("GENERATING ADDITIONAL PUBLICATION FIGURES")
print("=" * 100)

# ============================================================================
# FIGURE 13: Multi-Model ROC Comparison
# ============================================================================
# Overlays one ROC curve per entry of `models_for_roc` plus the chance diagonal.
# `all_y_test` / `all_yproba_fg` are not defined in this cell; they must already
# exist from the main experiment run.
print("\n13. Generating Multi-Model ROC Comparison...")

fig13, ax13 = plt.subplots(figsize=(12, 10))

# Display name -> (true labels, positive-class scores). Add further models here.
models_for_roc = {'Fed-SGC (XGBoost)': (all_y_test, all_yproba_fg)}

# Palette is cycled if there are more models than colours.
colors_roc = ['darkgreen', 'steelblue', 'orange', 'purple', 'red']

for idx, name in enumerate(models_for_roc):
    y_true, y_proba = models_for_roc[name]
    auc_score = roc_auc_score(y_true, y_proba)
    fpr, tpr, _ = roc_curve(y_true, y_proba)
    ax13.plot(
        fpr,
        tpr,
        color=colors_roc[idx % len(colors_roc)],
        linewidth=3,
        label=f'{name} (AUC={auc_score:.3f})',
    )

# Chance-level reference line.
ax13.plot([0, 1], [0, 1], color='k', linestyle='--', alpha=0.3, linewidth=2,
          label='Random')

ax13.set_title("ROC Curves - Model Comparison", fontsize=16, fontweight='bold')
ax13.set_xlabel("False Positive Rate", fontsize=14, fontweight='bold')
ax13.set_ylabel("True Positive Rate", fontsize=14, fontweight='bold')
ax13.grid(alpha=0.3, linestyle='--')
ax13.legend(fontsize=11, loc='lower right')

plt.tight_layout()
plt.savefig("figure13_multi_roc.png", dpi=300, bbox_inches='tight')
print("   ✓ Saved: figure13_multi_roc.png")
plt.close()

# ============================================================================
# FIGURE 14: Calibration Plot (Reliability Diagram)
# ============================================================================
# Plots the observed fraction of positives against the mean predicted
# probability over 10 uniform-width bins, next to the y = x reference line.
print("\n14. Generating Calibration Plot...")

fig14, ax14 = plt.subplots(figsize=(10, 10))

fraction_of_positives, mean_predicted_value = calibration_curve(
    all_y_test,
    all_yproba_fg,
    n_bins=10,
    strategy='uniform',
)

# Model curve first so it keeps the first legend slot.
ax14.plot(
    mean_predicted_value,
    fraction_of_positives,
    marker='s',
    linestyle='-',
    linewidth=3,
    markersize=10,
    color='darkblue',
    label='Fed-SGC',
)
ax14.plot([0, 1], [0, 1], color='k', linestyle='--', linewidth=2,
          label='Perfect Calibration')

ax14.set_title("Calibration Plot (Reliability Diagram)", fontsize=16, fontweight='bold')
ax14.set_xlabel("Mean Predicted Probability", fontsize=14, fontweight='bold')
ax14.set_ylabel("Fraction of Positives", fontsize=14, fontweight='bold')
ax14.grid(alpha=0.3, linestyle='--')
ax14.legend(fontsize=12, loc='upper left')

plt.tight_layout()
plt.savefig("figure14_calibration.png", dpi=300, bbox_inches='tight')
print("   ✓ Saved: figure14_calibration.png")
plt.close()

# ============================================================================
# FIGURE 15: Threshold Analysis
# ============================================================================
# Left panel: F1 / precision / recall as the decision threshold sweeps [0, 1].
# Right panel: number of positive vs negative predictions at each threshold.
# `all_y_test` / `all_yproba_fg` must already exist from the main experiment.
print("\n15. Generating Threshold Analysis...")

# Imported once here (the cell-level import only brings in f1_score).
from sklearn.metrics import precision_score, recall_score

fig15, (ax15a, ax15b) = plt.subplots(1, 2, figsize=(16, 6))

thresholds = np.linspace(0, 1, 100)
f1_scores = []
precisions = []
recalls = []

for thresh in thresholds:
    y_pred_thresh = (all_yproba_fg >= thresh).astype(int)

    # Score every threshold, including those where all predictions collapse to
    # a single class. zero_division=0 already yields 0 for the undefined
    # "no positive predictions" case, while the "everything positive" case
    # (e.g. threshold 0) correctly reports recall = 1 instead of a forced 0.
    f1_scores.append(f1_score(all_y_test, y_pred_thresh, zero_division=0))
    precisions.append(precision_score(all_y_test, y_pred_thresh, zero_division=0))
    recalls.append(recall_score(all_y_test, y_pred_thresh, zero_division=0))

ax15a.plot(thresholds, f1_scores, label='F1-Score', linewidth=3, color='darkgreen')
ax15a.plot(thresholds, precisions, label='Precision', linewidth=3, color='darkblue')
ax15a.plot(thresholds, recalls, label='Recall', linewidth=3, color='darkred')
ax15a.axvline(x=0.5, color='black', linestyle='--', linewidth=2, alpha=0.5, label='Default (0.5)')

# Threshold with the highest F1 (first one on ties, per np.argmax).
optimal_idx = np.argmax(f1_scores)
optimal_thresh = thresholds[optimal_idx]
ax15a.axvline(x=optimal_thresh, color='gold', linestyle='--', linewidth=2,
             label=f'Optimal ({optimal_thresh:.2f})')

ax15a.set_xlabel("Classification Threshold", fontsize=14, fontweight='bold')
ax15a.set_ylabel("Score", fontsize=14, fontweight='bold')
ax15a.set_title("Metrics vs Classification Threshold", fontsize=16, fontweight='bold')
ax15a.legend(fontsize=11)
ax15a.grid(alpha=0.3, linestyle='--')

# Prediction counts per threshold (positive means score >= threshold).
pos_predictions = [np.sum(all_yproba_fg >= t) for t in thresholds]
neg_predictions = [np.sum(all_yproba_fg < t) for t in thresholds]

ax15b.plot(thresholds, pos_predictions, label='Predicted Positive', linewidth=3, color='red')
ax15b.plot(thresholds, neg_predictions, label='Predicted Negative', linewidth=3, color='blue')
ax15b.axvline(x=0.5, color='black', linestyle='--', linewidth=2, alpha=0.5)
ax15b.set_xlabel("Classification Threshold", fontsize=14, fontweight='bold')
ax15b.set_ylabel("Number of Predictions", fontsize=14, fontweight='bold')
ax15b.set_title("Prediction Distribution vs Threshold", fontsize=16, fontweight='bold')
ax15b.legend(fontsize=11)
ax15b.grid(alpha=0.3, linestyle='--')

plt.tight_layout()
plt.savefig("figure15_threshold_analysis.png", dpi=300, bbox_inches='tight')
print("   ✓ Saved: figure15_threshold_analysis.png")
plt.close()

# ============================================================================
# FIGURE 16: Error Analysis
# ============================================================================
# 2x2 grid: (a) error rate per confidence bin, (b) row-normalised confusion
# matrix, (c) predicted-probability histograms for correct vs incorrect
# predictions, (d) TP / TN / FP / FN counts.
# `all_y_test`, `all_ypred_fg` and `all_yproba_fg` must already exist from the
# main experiment. The TP/TN/FP/FN counts computed here are reused by the
# Figure 17 section below.
print("\n16. Generating Error Analysis...")

fig16, ((ax16a, ax16b), (ax16c, ax16d)) = plt.subplots(2, 2, figsize=(16, 14))

errors = all_ypred_fg != all_y_test
error_indices = np.where(errors)[0]
correct_indices = np.where(~errors)[0]

# Confidence = distance of the probability from 0.5, rescaled to [0, 1].
confidence = np.abs(all_yproba_fg - 0.5) * 2
confidence_bins = np.linspace(0, 1, 11)
n_conf_bins = len(confidence_bins) - 1
error_rates = []

for i, (bin_low, bin_high) in enumerate(zip(confidence_bins[:-1], confidence_bins[1:])):
    if i == n_conf_bins - 1:
        # Close the last bin on the right so confidence == 1.0 (probability of
        # exactly 0 or 1) is counted instead of falling outside every bin.
        mask = (confidence >= bin_low) & (confidence <= bin_high)
    else:
        mask = (confidence >= bin_low) & (confidence < bin_high)

    if np.sum(mask) > 0:
        error_rates.append(np.mean(errors[mask]))
    else:
        # Empty bin: plotted as 0 rather than left undefined.
        error_rates.append(0)

ax16a.bar(range(len(error_rates)), error_rates, alpha=0.8,
         edgecolor='black', linewidth=1.5, color='salmon')
# Label each bar with the confidence range it covers.
ax16a.set_xticks(range(len(error_rates)))
ax16a.set_xticklabels(
    [f'{lo:.1f}-{hi:.1f}' for lo, hi in zip(confidence_bins[:-1], confidence_bins[1:])],
    rotation=45, ha='right')
ax16a.set_xlabel("Confidence Bin", fontsize=12, fontweight='bold')
ax16a.set_ylabel("Error Rate", fontsize=12, fontweight='bold')
ax16a.set_title("Error Rate by Prediction Confidence", fontsize=14, fontweight='bold')
ax16a.grid(axis='y', alpha=0.3, linestyle='--')

# Row-normalise so each true-class row sums to 1.
cm = confusion_matrix(all_y_test, all_ypred_fg)
cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

sns.heatmap(cm_normalized, annot=True, fmt='.3f', cmap='Blues',
           cbar=True, ax=ax16b,
           xticklabels=['No Amputation', 'Amputation'],
           yticklabels=['No Amputation', 'Amputation'])
ax16b.set_xlabel("Predicted Label", fontsize=12, fontweight='bold')
ax16b.set_ylabel("True Label", fontsize=12, fontweight='bold')
ax16b.set_title("Normalized Confusion Matrix", fontsize=14, fontweight='bold')

ax16c.hist(all_yproba_fg[correct_indices], bins=50, alpha=0.7,
          label='Correct Predictions', color='green', edgecolor='black')
ax16c.hist(all_yproba_fg[error_indices], bins=50, alpha=0.7,
          label='Incorrect Predictions', color='red', edgecolor='black')
ax16c.axvline(x=0.5, color='black', linestyle='--', linewidth=2)
ax16c.set_xlabel("Predicted Probability", fontsize=12, fontweight='bold')
ax16c.set_ylabel("Frequency", fontsize=12, fontweight='bold')
ax16c.set_title("Prediction Distribution: Correct vs Incorrect", fontsize=14, fontweight='bold')
ax16c.legend(fontsize=11)
ax16c.grid(alpha=0.3, linestyle='--')

# Outcome counts (class 1 is the positive class).
false_positives = np.sum((all_ypred_fg == 1) & (all_y_test == 0))
false_negatives = np.sum((all_ypred_fg == 0) & (all_y_test == 1))
true_positives = np.sum((all_ypred_fg == 1) & (all_y_test == 1))
true_negatives = np.sum((all_ypred_fg == 0) & (all_y_test == 0))

error_types = ['True Positive', 'True Negative', 'False Positive', 'False Negative']
error_counts = [true_positives, true_negatives, false_positives, false_negatives]
colors_error = ['green', 'lightgreen', 'orange', 'red']

bars = ax16d.bar(error_types, error_counts, alpha=0.8,
                edgecolor='black', linewidth=1.5, color=colors_error)
ax16d.set_ylabel("Count", fontsize=12, fontweight='bold')
ax16d.set_title("Prediction Outcomes Distribution", fontsize=14, fontweight='bold')
ax16d.grid(axis='y', alpha=0.3, linestyle='--')

# Annotate each bar with its count and its share of all test samples.
for bar, count in zip(bars, error_counts):
    height = bar.get_height()
    ax16d.text(bar.get_x() + bar.get_width()/2., height,
              f'{int(count)}\n({count/len(all_y_test)*100:.1f}%)',
              ha='center', va='bottom', fontsize=10, fontweight='bold')

plt.tight_layout()
plt.savefig("figure16_error_analysis.png", dpi=300, bbox_inches='tight')
print("   ✓ Saved: figure16_error_analysis.png")
plt.close()

# ============================================================================
# FIGURE 17: Performance Metrics Breakdown
# ============================================================================
# Bar chart of six summary metrics. Specificity and NPV are derived from the
# true_negatives / false_positives / false_negatives counts computed in the
# Figure 16 section, so that section must have run first.
print("\n17. Generating Performance Metrics Breakdown...")

fig17, ax17 = plt.subplots(figsize=(14, 8))

from sklearn.metrics import accuracy_score, precision_score, recall_score

# Denominators for the two count-based metrics; a zero denominator maps to 0.
actual_negative_total = true_negatives + false_positives
predicted_negative_total = true_negatives + false_negatives

if actual_negative_total > 0:
    specificity_value = true_negatives / actual_negative_total
else:
    specificity_value = 0

if predicted_negative_total > 0:
    npv_value = true_negatives / predicted_negative_total
else:
    npv_value = 0

metrics = {
    'Accuracy': accuracy_score(all_y_test, all_ypred_fg),
    'Precision': precision_score(all_y_test, all_ypred_fg, zero_division=0),
    'Recall': recall_score(all_y_test, all_ypred_fg, zero_division=0),
    'F1-Score': f1_score(all_y_test, all_ypred_fg, zero_division=0),
    'Specificity': specificity_value,
    'NPV': npv_value,
}

values = list(metrics.values())
x_pos = np.arange(len(metrics))
colors_metrics = ['steelblue', 'darkgreen', 'orange', 'purple', 'teal', 'brown']

bars = ax17.bar(
    x_pos,
    values,
    color=colors_metrics,
    alpha=0.8,
    edgecolor='black',
    linewidth=1.5,
)

ax17.set_title("Comprehensive Performance Metrics", fontsize=16, fontweight='bold')
ax17.set_ylabel("Score", fontsize=14, fontweight='bold')
ax17.set_xticks(x_pos)
ax17.set_xticklabels(list(metrics), fontsize=12, fontweight='bold')
ax17.set_ylim([0, 1.05])
ax17.grid(axis='y', alpha=0.3, linestyle='--')

# Print each value just above its bar.
for bar, val in zip(bars, values):
    label_x = bar.get_x() + bar.get_width()/2.
    label_y = bar.get_height() + 0.01
    ax17.text(label_x, label_y, f'{val:.3f}',
              ha='center', va='bottom', fontsize=11, fontweight='bold')

plt.tight_layout()
plt.savefig("figure17_metrics_breakdown.png", dpi=300, bbox_inches='tight')
print("   ✓ Saved: figure17_metrics_breakdown.png")
plt.close()

# ============================================================================
# FIGURE 18: Client Performance Comparison
# ============================================================================
# Reads 'client_distribution.csv' (columns used: Client, Samples,
# Imbalance_Ratio) and plots per-client sample counts and class imbalance.
# The figure is best-effort: if the CSV is missing or unusable the section
# reports why and is skipped, but unrelated errors and Ctrl-C are no longer
# silently swallowed.
print("\n18. Generating Client Performance Comparison...")

fig18, (ax18a, ax18b) = plt.subplots(1, 2, figsize=(16, 6))

try:
    client_dist_df = pd.read_csv('client_distribution.csv')

    colors_client = plt.cm.Set3(range(len(client_dist_df)))
    bars1 = ax18a.bar(range(len(client_dist_df)), client_dist_df['Samples'],
                     alpha=0.8, edgecolor='black', linewidth=1.5, color=colors_client)
    ax18a.set_xticks(range(len(client_dist_df)))
    ax18a.set_xticklabels(client_dist_df['Client'], rotation=45, ha='right')
    ax18a.set_ylabel("Number of Samples", fontsize=14, fontweight='bold')
    ax18a.set_title("Data Distribution Across Clients", fontsize=16, fontweight='bold')
    ax18a.grid(axis='y', alpha=0.3, linestyle='--')

    # Sample count on top of each bar.
    for bar, val in zip(bars1, client_dist_df['Samples']):
        height = bar.get_height()
        ax18a.text(bar.get_x() + bar.get_width()/2., height,
                  f'{int(val)}',
                  ha='center', va='bottom', fontsize=9, fontweight='bold')

    bars2 = ax18b.barh(range(len(client_dist_df)), client_dist_df['Imbalance_Ratio'],
                      alpha=0.8, edgecolor='black', linewidth=1.5, color=colors_client)
    ax18b.set_yticks(range(len(client_dist_df)))
    ax18b.set_yticklabels(client_dist_df['Client'])
    ax18b.set_xlabel("Imbalance Ratio (Class 0 : Class 1)", fontsize=14, fontweight='bold')
    ax18b.set_title("Class Imbalance per Client", fontsize=16, fontweight='bold')
    ax18b.grid(axis='x', alpha=0.3, linestyle='--')
    # Vertical marker at the mean imbalance ratio across clients.
    ax18b.axvline(x=client_dist_df['Imbalance_Ratio'].mean(),
                 color='red', linestyle='--', linewidth=2, label='Mean')
    ax18b.legend()

    plt.tight_layout()
    plt.savefig("figure18_client_comparison.png", dpi=300, bbox_inches='tight')
    print("   ✓ Saved: figure18_client_comparison.png")
except FileNotFoundError:
    print("   ⚠ Client distribution data not found, skipping...")
except (OSError, KeyError, ValueError) as exc:
    # Unreadable file, missing column, or empty/malformed CSV (pandas'
    # EmptyDataError and ParserError are ValueError subclasses).
    print(f"   ⚠ Client distribution data could not be plotted ({exc!r}), skipping...")

plt.close()

# ============================================================================
# FIGURE 19: Algorithm Performance Heatmap
# ============================================================================
# Reads 'algorithm_comparison_final.csv' (columns used: Algorithm, F1,
# ROC-AUC, Score) and draws a metrics-by-algorithm heatmap for the first
# `top_n_algorithms` rows, in the order they appear in the file.
# The figure is best-effort: if the CSV is missing or unusable the section
# reports why and is skipped, but unrelated errors and Ctrl-C are no longer
# silently swallowed.
print("\n19. Generating Algorithm Performance Heatmap...")

fig19, ax19 = plt.subplots(figsize=(12, 8))

try:
    comparison_final = pd.read_csv('algorithm_comparison_final.csv')

    metrics_for_heatmap = ['F1', 'ROC-AUC', 'Score']
    top_n_algorithms = 8

    # head() returns fewer rows when the file lists fewer algorithms.
    top_rows = comparison_final.head(top_n_algorithms)
    heatmap_data = top_rows[metrics_for_heatmap].values

    # Transposed so metrics are rows and algorithms are columns.
    sns.heatmap(heatmap_data.T, annot=True, fmt='.3f', cmap='YlGnBu',
               cbar_kws={'label': 'Performance Score'},
               xticklabels=top_rows['Algorithm'],
               yticklabels=metrics_for_heatmap,
               ax=ax19, linewidths=1, linecolor='black')

    # Title reports how many algorithms are actually shown.
    ax19.set_title(f"Algorithm Performance Heatmap (Top {len(top_rows)})",
                   fontsize=16, fontweight='bold')
    plt.xticks(rotation=45, ha='right')

    plt.tight_layout()
    plt.savefig("figure19_algorithm_heatmap.png", dpi=300, bbox_inches='tight')
    print("   ✓ Saved: figure19_algorithm_heatmap.png")
except FileNotFoundError:
    print("   ⚠ Algorithm comparison data not found, skipping...")
except (OSError, KeyError, ValueError) as exc:
    # Unreadable file, missing column, or empty/malformed CSV (pandas'
    # EmptyDataError and ParserError are ValueError subclasses).
    print(f"   ⚠ Algorithm comparison data could not be plotted ({exc!r}), skipping...")

plt.close()

# ============================================================================
# FIGURE 20: Class Distribution Analysis
# ============================================================================
# 2x2 grid: (a) test-set class pie chart, (b) actual vs predicted class counts,
# (c) predicted-probability histograms split by true class, (d) per-class
# precision / recall / f1-score heatmap.
# `all_y_test`, `all_ypred_fg` and `all_yproba_fg` must already exist from the
# main experiment.
print("\n20. Generating Class Distribution Analysis...")

fig20, ((ax20a, ax20b), (ax20c, ax20d)) = plt.subplots(2, 2, figsize=(16, 14))

class_counts = [np.sum(all_y_test == 0), np.sum(all_y_test == 1)]
colors_pie = ['lightblue', 'salmon']
explode = (0.05, 0.05)

ax20a.pie(class_counts, labels=['No Amputation', 'Amputation'],
         autopct='%1.1f%%', startangle=90, colors=colors_pie,
         explode=explode, textprops={'fontsize': 12, 'fontweight': 'bold'})
ax20a.set_title("Test Set Class Distribution", fontsize=14, fontweight='bold')

pred_counts = [np.sum(all_ypred_fg == 0), np.sum(all_ypred_fg == 1)]
x_pos = [0, 1]
width = 0.35

# Side-by-side bars: actual on the left, predicted on the right of each tick.
bars1 = ax20b.bar([p - width/2 for p in x_pos], class_counts, width,
                 label='Actual', alpha=0.8, edgecolor='black', color='steelblue')
bars2 = ax20b.bar([p + width/2 for p in x_pos], pred_counts, width,
                 label='Predicted', alpha=0.8, edgecolor='black', color='orange')

ax20b.set_xticks(x_pos)
ax20b.set_xticklabels(['No Amputation', 'Amputation'])
ax20b.set_ylabel("Count", fontsize=12, fontweight='bold')
ax20b.set_title("Actual vs Predicted Class Distribution", fontsize=14, fontweight='bold')
ax20b.legend(fontsize=11)
ax20b.grid(axis='y', alpha=0.3, linestyle='--')

ax20c.hist(all_yproba_fg[all_y_test == 0], bins=50, alpha=0.7,
          label='Actual: No Amputation', color='blue', edgecolor='black')
ax20c.hist(all_yproba_fg[all_y_test == 1], bins=50, alpha=0.7,
          label='Actual: Amputation', color='red', edgecolor='black')
ax20c.axvline(x=0.5, color='black', linestyle='--', linewidth=2, label='Threshold')
ax20c.set_xlabel("Predicted Probability", fontsize=12, fontweight='bold')
ax20c.set_ylabel("Frequency", fontsize=12, fontweight='bold')
ax20c.set_title("Probability Distribution by True Class", fontsize=14, fontweight='bold')
ax20c.legend(fontsize=11)
ax20c.grid(alpha=0.3, linestyle='--')

report_class_names = ['No Amp', 'Amp']
report = classification_report(all_y_test, all_ypred_fg,
                              target_names=report_class_names,
                              output_dict=True)

# In DataFrame(report) the metrics are rows and the classes/averages are
# columns. Select by label so the heatmap shows precision, recall and f1-score
# for the two classes only; the previous positional slice dropped f1-score and
# pulled in the scalar 'accuracy' entry as if it were a class.
report_data = pd.DataFrame(report).loc[
    ['precision', 'recall', 'f1-score'], report_class_names
].T
sns.heatmap(report_data, annot=True, fmt='.3f', cmap='RdYlGn',
           ax=ax20d, vmin=0.8, vmax=1.0, linewidths=1, linecolor='black',
           cbar_kws={'label': 'Score'})
ax20d.set_title("Classification Report Heatmap", fontsize=14, fontweight='bold')

plt.tight_layout()
plt.savefig("figure20_class_distribution.png", dpi=300, bbox_inches='tight')
print("   ✓ Saved: figure20_class_distribution.png")
plt.close()

# ============================================================================
# FIGURE 21: Summary Table
# ============================================================================
# Renders the experiment configuration and headline results as a table image.
# Every value comes from names defined by the main experiment (df_enc,
# X_encoded, NUM_CLIENTS, FED_ROUNDS, K_NEIGHBORS_OPTIMAL, BEST_ALGORITHM_NAME,
# res_fed, DP_EPSILON), which must already exist.
print("\n21. Generating Training Progress Summary...")

fig21, ax21 = plt.subplots(figsize=(12, 8))

# (label, formatted value) pairs, in display order.
summary_rows = [
    ('Dataset Size', f'{len(df_enc):,}'),
    ('Features', f'{X_encoded.shape[1]:,}'),
    ('Clients', str(NUM_CLIENTS)),
    ('Rounds', str(FED_ROUNDS)),
    ('K-Neighbors', str(K_NEIGHBORS_OPTIMAL)),
    ('Best Algorithm', BEST_ALGORITHM_NAME),
    ('F1-Score', f'{res_fed["F1"].mean():.4f}'),
    ('ROC-AUC', f'{res_fed["ROC_AUC"].mean():.4f}'),
    ('Privacy ε', str(DP_EPSILON)),
]

summary_data = {
    'Metric': [label for label, _ in summary_rows],
    'Value': [value for _, value in summary_rows],
}

summary_df = pd.DataFrame(summary_data)

# Only the table should be visible, so hide the axes.
ax21.axis('tight')
ax21.axis('off')

table = ax21.table(
    cellText=summary_df.values,
    colLabels=summary_df.columns,
    cellLoc='left',
    loc='center',
    colWidths=[0.4, 0.6],
)

table.auto_set_font_size(False)
table.set_fontsize(12)
table.scale(1, 3)

n_table_cols = len(summary_df.columns)

# Header row (table row 0): white bold text on green.
for j in range(n_table_cols):
    header_cell = table[(0, j)]
    header_cell.set_facecolor('#4CAF50')
    header_cell.set_text_props(weight='bold', color='white', fontsize=14)

# Body rows: grey background on even table rows, white on odd.
for i in range(1, len(summary_df) + 1):
    row_color = '#f0f0f0' if i % 2 == 0 else 'white'
    for j in range(n_table_cols):
        table[(i, j)].set_facecolor(row_color)

ax21.set_title("Experiment Configuration & Results Summary",
               fontsize=16, fontweight='bold', pad=20)

plt.tight_layout()
plt.savefig("figure21_summary_table.png", dpi=300, bbox_inches='tight')
print("   ✓ Saved: figure21_summary_table.png")
plt.close()

# Closing banner: lists the output file name of every figure section in this
# cell (printed unconditionally, in figure order).
additional_figure_files = [
    "figure13_multi_roc.png",
    "figure14_calibration.png",
    "figure15_threshold_analysis.png",
    "figure16_error_analysis.png",
    "figure17_metrics_breakdown.png",
    "figure18_client_comparison.png",
    "figure19_algorithm_heatmap.png",
    "figure20_class_distribution.png",
    "figure21_summary_table.png",
]

print("\n" + "=" * 100)
print("ADDITIONAL FIGURES COMPLETE!")
print("=" * 100)
print("\nGenerated:")
for figure_file in additional_figure_files:
    print(f"  - {figure_file}")
print("=" * 100)


GENERATING ADDITIONAL PUBLICATION FIGURES

13. Generating Multi-Model ROC Comparison...
   ✓ Saved: figure13_multi_roc.png

14. Generating Calibration Plot...
   ✓ Saved: figure14_calibration.png

15. Generating Threshold Analysis...
   ✓ Saved: figure15_threshold_analysis.png

16. Generating Error Analysis...
   ✓ Saved: figure16_error_analysis.png

17. Generating Performance Metrics Breakdown...
   ✓ Saved: figure17_metrics_breakdown.png

18. Generating Client Performance Comparison...
   ✓ Saved: figure18_client_comparison.png

19. Generating Algorithm Performance Heatmap...
   ✓ Saved: figure19_algorithm_heatmap.png

20. Generating Class Distribution Analysis...
   ✓ Saved: figure20_class_distribution.png

21. Generating Training Progress Summary...
   ✓ Saved: figure21_summary_table.png

ADDITIONAL FIGURES COMPLETE!

Generated:
  - figure13_multi_roc.png
  - figure14_calibration.png
  - figure15_threshold_analysis.png
  - figure16_error_analysis.png
  - figure17_metrics_breakdown.