# Anomaly Detection in Bipartite Graphs - Model Comparison

This notebook compares different models for anomaly detection in bipartite graphs.

## 0. Setup: Autoreload and Imports

In [None]:
# Load IPython's autoreload extension so edits to imported modules (e.g. the local `src` package)
# are picked up without restarting the kernel
%load_ext autoreload
# Mode 2: reload all modules before each cell execution
%autoreload 2

In [None]:
import sys
import os
# Add the parent of the current working directory to the import path so the
# `from src...` imports below resolve.
sys.path.append('..')

# --- MPS Fallback ---
# Set before `import torch` below -- keep this ordering. Presumably lets PyTorch fall back
# to CPU for ops the Apple MPS backend does not implement (see PyTorch MPS docs).
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
print(f"PYTORCH_ENABLE_MPS_FALLBACK set to: {os.environ.get('PYTORCH_ENABLE_MPS_FALLBACK')}")

import torch
import numpy as np
import matplotlib.pyplot as plt

import torch_geometric.transforms as T
import torch_geometric
from sklearn.metrics import precision_recall_curve

# Import custom modules
from src.data.dataloader import load_member_features, load_provider_features, load_claims_data, load_claims_data_with_splitting, prepare_hetero_data, prepare_hetero_data_with_splitting
from src.data.anomaly_injection import  * 
from src.models.main_model import *
from src.models.baseline_models import MLPAutoencoder, SklearnBaseline, GCNAutoencoder, GATAutoencoder, SAGEAutoencoder
from src.utils.vizualize import *
from src.utils.train_utils import *
from src.utils.eval_utils import *
from src.utils.stat_utils import *

## 1. Load and Prepare Data

In [None]:
df_member_features, members_dataset = load_member_features("../data/final_members_df.pickle")
df_provider_features, providers_dataset = load_provider_features("../data/final_df.pickle")
# A single call yields both `df_edges` and the temporal train/val/test edge splits.
# The previous separate `load_claims_data(...)` call was redundant: its `df_edges`
# result was overwritten by this call immediately afterwards.
df_edges, train_edges, val_edges, test_edges = load_claims_data_with_splitting(
    "../data/df_descriptions.pickle",
    members_dataset=members_dataset,
    providers_dataset=providers_dataset,
)

# Raw descriptions frame. `df` is not used in the cells visible here; kept in case later
# cells rely on it -- TODO remove if unused.
# NOTE: pickle.load can execute arbitrary code from the file; only load trusted data.
import pickle
with open("../data/df_descriptions.pickle", 'rb') as pickle_file:
    df = pickle.load(pickle_file)

print(f"Members: {len(members_dataset)}")
print(f"Providers: {len(providers_dataset)}")
print(f"Edges: {len(df_edges)}")

print("\nMember features:")
display(df_member_features.head())
print("\nProvider features:")
display(df_provider_features.head())
print("\nEdges:")
display(df_edges.head())


In [None]:
# Full graph over all edges (used for the full-graph anomaly injection below)
data = prepare_hetero_data(df_member_features, df_provider_features, df_edges)
# Per-split graphs built from the temporal train/val/test edge splits
train_data, val_data, test_data = prepare_hetero_data_with_splitting(
    df_member_features,
    df_provider_features,
    train_edges,
    val_edges,
    test_edges,
)

member_x = data['member'].x
provider_x = data['provider'].x
target_edge_index = data['provider', 'to', 'member'].edge_index

print(data)
print(f"Member features shape: {member_x.shape}")
print(f"Provider features shape: {provider_x.shape}")
print(f"Number of edges: {target_edge_index.shape[1]}")

## 2. Injecting Anomalies

In [None]:

from src.data.new_anomaly_injection import *

# Inject anomalies into the full (unsplit) graph; the result is used by the
# visualization cell further down.
modified_graph, node_labels, edge_labels, tracking = inject_scenario_anomalies(
    data,
    p_node_anomalies=0.05,   # Target 5% of nodes per type
    p_edge_anomalies=0.1,    # Target 10% of edges to be anomalous
    lambda_structural=0.5,   # 50% structural, 50% attribute focus
    seed=42,
)

summarize_injected_anomalies(
    modified_graph,
    node_labels,
    edge_labels,
    tracking,
)

In [None]:
# Re-display the summary of the full-graph injection
injected_full_graph = (modified_graph, node_labels, edge_labels, tracking)
summarize_injected_anomalies(*injected_full_graph)

## 3. Splitting into Training, Validation, and Test Sets

### 3.1 Inductive Setting

In [28]:
from src.utils.data_utils import *

In [7]:
# The temporal split was already computed above; expose it under the split names used below.
train_g = train_data
val_g = val_data
test_g = test_data

In [None]:
# Inject anomalies independently into each temporal split (inductive setting).
from src.data.new_anomaly_injection import *

# Start from the clean temporal splits on every run of this cell, so the injected graphs
# below are rebuilt rather than re-injected (assuming inject_scenario_anomalies returns
# a new graph instead of mutating its input -- TODO confirm).
train_g, val_g, test_g = train_data, val_data, test_data

p_node_anomalies = 0.05   # Target 5% of nodes per type
p_edge_anomalies = 0.1    # Target 10% of edges to be anomalous
lambda_structural = 0.5   # 50% structural, 50% attribute focus
seed = 42                 # Same seed is used for every split


def _inject_and_summarize(graph, split_title):
    """Inject anomalies into one split graph and print its summary.

    Uses the module-level p_node_anomalies / p_edge_anomalies / lambda_structural / seed
    settings and returns whatever ``inject_scenario_anomalies`` returns, i.e. the
    ``(graph, node_labels, edge_labels, tracking)`` 4-tuple unpacked below.
    """
    injected = inject_scenario_anomalies(
        graph,
        p_node_anomalies=p_node_anomalies,
        p_edge_anomalies=p_edge_anomalies,
        lambda_structural=lambda_structural,
        seed=seed,
    )
    print(f"\n--- {split_title} Graph Anomaly Summary ---")
    summarize_injected_anomalies(*injected)
    return injected


(train_g, gt_node_labels_train, gt_edge_labels_dict_train,
 final_anomaly_tracking_train) = _inject_and_summarize(train_g, "Training")
(val_g, gt_node_labels_val, gt_edge_labels_dict_val,
 final_anomaly_tracking_val) = _inject_and_summarize(val_g, "Validation")
(test_g, gt_node_labels_test, gt_edge_labels_dict_test,
 final_anomaly_tracking_test) = _inject_and_summarize(test_g, "Testing")

comparison_input = {
    'Train': (train_g, gt_node_labels_train, gt_edge_labels_dict_train, final_anomaly_tracking_train),
    'Validation': (val_g, gt_node_labels_val, gt_edge_labels_dict_val, final_anomaly_tracking_val),
    'Test': (test_g, gt_node_labels_test, gt_edge_labels_dict_test, final_anomaly_tracking_test),
}

# Store ground truth node/edge labels and anomaly tracking for each split
GT_NODE_LABELS = {
    "train": gt_node_labels_train,
    "val": gt_node_labels_val,
    "test": gt_node_labels_test,
}

GT_EDGE_LABELS = {
    "train": gt_edge_labels_dict_train,
    "val": gt_edge_labels_dict_val,
    "test": gt_edge_labels_dict_test,
}

ANOMALY_TRACKING = {
    "train": final_anomaly_tracking_train,
    "val": final_anomaly_tracking_val,
    "test": final_anomaly_tracking_test,
}

# --- Run the Comparison ---
compare_anomaly_splits(comparison_input)



In [None]:
# `tuning_df` is produced by the hyperparameter-tuning cell in section 4; guard so a fresh
# Restart & Run All does not stop here with a NameError.
if 'tuning_df' in globals():
    print(tuning_df)
else:
    print("tuning_df is not defined yet - run the hyperparameter tuning cell first.")

In [None]:
# Visualize a sample of the anomalies injected into the full graph
vis_options = {
    'num_instances_per_scenario': 2,   # Sample 2 nodes/instances per scenario
    'num_normal_nodes_per_type': 30,   # Normal nodes per node type
    'neighborhood_hops': 1,
    'provider_node_type': 'provider',
    'member_node_type': 'member',
    'output_filename': "anomaly_sample_vis_shapes_full_struct.html",
    'notebook': False,
}
visualize_anomaly_sample(modified_graph, node_labels, edge_labels, tracking, **vis_options)

## 4. Model Comparison

In [None]:
# Prefer Apple's MPS backend when it is available; otherwise run on CPU.
use_mps = torch.backends.mps.is_available()
device = torch.device("mps") if use_mps else torch.device("cpu")
print(f"Using device: {device}")

### 4.1 Main model 

In [None]:
# Move the graphs of all splits to the target device (each split is transferred once).
print(f"Moving data to device: {device}")
train_g_on_device = train_g.to(device)
val_g_on_device = val_g.to(device)
test_g_on_device = test_g.to(device)
# Alias for code that refers to the training graph as `data_on_device`
# (previously a second, redundant `train_g.to(device)` call).
data_on_device = train_g_on_device

# Move ground-truth node labels for every split as well.
# NOTE: the training cell below passes the CPU dicts from GT_NODE_LABELS, not these.
gt_node_labels_train_on_device = {k: v.to(device) for k, v in gt_node_labels_train.items()}
gt_node_labels_val_on_device = {k: v.to(device) for k, v in gt_node_labels_val.items()}
gt_node_labels_test_on_device = {k: v.to(device) for k, v in gt_node_labels_test.items()}
# Alias instead of copying the training labels to the device a second time.
gt_node_labels_on_device = gt_node_labels_train_on_device

# --- Model Hyperparameters ---
# Input dimensions are read from the training graph's node feature matrices.
in_dim_member = data_on_device['member'].x.size(1)
in_dim_provider = data_on_device['provider'].x.size(1)
hidden_dim = 64
latent_dim = 32

# Edge feature dimension of the target relation (0 when it carries no edge_attr).
target_edge_type_for_dim = ('provider', 'to', 'member')
if hasattr(data_on_device[target_edge_type_for_dim], 'edge_attr') and data_on_device[target_edge_type_for_dim].edge_attr is not None:
    edge_dim = data_on_device[target_edge_type_for_dim].edge_attr.size(1)
    print(f"Edge dimension (edge_dim) detected as: {edge_dim}")
else:
    edge_dim = 0  # Or handle as error if edge attributes are expected
    print("Warning: No edge_attr found for target edge type. Setting edge_dim=0.")

num_conv_layers = 2
num_dec_layers = 2
dropout_rate = 0.5

# --- Instantiate the Main Model ---
# This is the single model instance that the training cell below trains and evaluates.
bgae_model = BipartiteGraphAutoEncoder_ReportBased(
    in_dim_member=in_dim_member,
    in_dim_provider=in_dim_provider,
    edge_dim=edge_dim,
    hidden_dim=hidden_dim,
    latent_dim=latent_dim,
    num_conv_layers=num_conv_layers,
    num_dec_layers=num_dec_layers,
    dropout=dropout_rate,
).to(device)

# NOTE(review): the compiled wrapper `model` is NOT what gets trained -- the training cell
# passes `bgae_model`. Kept only in case later cells reference `model`.
model = torch_geometric.compile(bgae_model)

print("\nBipartite GAE Model Instantiated:")
print(bgae_model)

# --- Optimizer Definition ---
learning_rate = 5e-4
weight_decay = 1e-5
optimizer = torch.optim.Adam(bgae_model.parameters(), lr=learning_rate, weight_decay=weight_decay)

print("\nOptimizer Defined.")

In [30]:
from src.utils.train_utils import *

In [None]:
# --- Define training hyperparameters ---
num_epochs = 200       # Maximum number of training epochs
lambda_attr = 0.5      # Weight for attribute loss (both train and val scoring)
lambda_struct = 0.5    # Weight for structure loss (both train and val scoring)
k_neg_train = 5        # Negative samples for TRAINING loss calculation
val_k_neg_score = 0    # Negative samples for VALIDATION scoring (0 = none)
eval_freq = 10         # Evaluate on validation set every 10 epochs
target_edge_type = ('provider', 'to', 'member')  # Primary edge type
node_k_list_eval = [50, 100, 200]  # K values for node P@K/R@K calculation during validation

# --- Define Early Stopping & Saving Parameters ---
early_stop_metric_choice = 'AP'       # Metric to monitor ('AUROC', 'AP', 'Best F1')
# Element to monitor: 'provider', 'member', str(target_edge_type), or 'Avg AP'
# (presumably the average AP over element types -- confirm in train_model_inductive_with_metrics).
early_stop_element_choice = 'Avg AP'
early_stop_patience_checks = 50       # How many validation CHECKS to wait for improvement
# NOTE: num_epochs=200 with eval_freq=10 gives at most 20 validation checks, so a patience
# of 50 checks means early stopping cannot trigger in this run (only best-model tracking).
best_model_save_path = "./saved_models/best_bgae_model.pth"  # Full checkpoint file path

# Make sure the checkpoint directory exists before training tries to write into it.
os.makedirs(os.path.dirname(best_model_save_path) or ".", exist_ok=True)

print(f"\n--- Starting Training Run ---")
print(f"Epochs: {num_epochs}, Val Freq: {eval_freq}, Patience: {early_stop_patience_checks} checks")
print(f"Early Stopping: Monitor '{early_stop_element_choice}' '{early_stop_metric_choice}'")
print(f"Saving best model to: {best_model_save_path}")

# Train with periodic validation, early stopping and best-checkpoint saving
trained_model, history, saved_best_model_filepath = train_model_inductive_with_metrics(
    model=bgae_model,
    train_graph=train_g_on_device,
    num_epochs=num_epochs,
    optimizer=optimizer,
    lambda_attr=lambda_attr,             # Used for training loss
    lambda_struct=lambda_struct,         # Used for training loss
    k_neg_samples=k_neg_train,           # Negative samples during training loss calc
    target_edge_type=target_edge_type,
    device=device,
    log_freq=eval_freq,                  # How often to print logs

    # --- Validation Arguments ---
    val_graph=val_g_on_device,
    gt_node_labels_val=GT_NODE_LABELS["val"],  # Node labels for validation metrics
    gt_edge_labels_val=GT_EDGE_LABELS["val"],  # Edge labels for validation metrics
    val_log_freq=eval_freq,              # How often validation logic runs

    # --- Validation Scoring Arguments ---
    val_lambda_attr=lambda_attr,         # Same lambdas as training for consistent scoring
    val_lambda_struct=lambda_struct,
    val_k_neg_samples_score=val_k_neg_score,  # K for negative sampling in validation scoring

    # --- Node Metrics Argument ---
    node_k_list=node_k_list_eval,        # K values for P@K, R@K calculation

    # --- Early Stopping Arguments ---
    early_stopping_metric=early_stop_metric_choice,
    early_stopping_element=early_stop_element_choice,
    patience=early_stop_patience_checks,
    save_best_model_path=best_model_save_path,
)

print("\n--- Training Call Finished ---")
if saved_best_model_filepath:
    print(f"Best model artifact saved at: {saved_best_model_filepath}")
    # The returned `trained_model` is expected to hold the best model's state.
else:
    print("No best model was saved (check logs for early stopping/improvement status).")


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from typing import Dict, Optional, Tuple, List
import os

# Ensure necessary functions are available (if using PR curves later)
# from sklearn.metrics import average_precision_score, precision_recall_curve, PrecisionRecallDisplay

# Report-style seaborn theme applied to every plot drawn below
sns.set_theme(font_scale=1.1, palette="deep", style="whitegrid")

# Define check_val_data helper function (outside main function)
def check_val_data(history: Dict, key: str, val_epochs: List) -> bool:
    """Return True when ``history[key]`` is a non-empty sized series aligned 1:1 with ``val_epochs``."""
    series = history.get(key, [])
    # Missing values and scalars (no __len__) can never be plotted against epochs.
    if series is None or not hasattr(series, '__len__'):
        return False
    n_points = len(series)
    # A non-empty series whose length equals len(val_epochs) also implies val_epochs is non-empty.
    return n_points > 0 and n_points == len(val_epochs)

# --- Main Plotting Function ---
def plot_training_validation_performance_report_combined_loss(
    history: Dict,
    target_edge_type: tuple = ('provider', 'to', 'member'),
    figsize_loss: Tuple[int, int] = (10, 6),
    figsize_metrics: Tuple[int, int] = (10, 12),
    save_dir: Optional[str] = None,
    metrics_ylim: Optional[Tuple[float, float]] = (-0.05, 0.7),
    ):
    """
    Generate report plots for a BGAE training run.

    Figure 1 shows total train/validation loss and, if present, the attribute/structure
    loss components. Figure 2 shows validation AP per element type plus their mean.

    Args:
        history: Training history dict. Keys read: 'train_loss', 'val_loss',
            'train_loss_attr', 'train_loss_struct', 'val_loss_attr', 'val_loss_struct',
            'epochs_validated', 'best_epoch' and 'val_{provider,member,edge}_AP'.
            Validation series are drawn only when their length matches 'epochs_validated'.
        target_edge_type: Currently unused; kept for call-site compatibility.
        figsize_loss: (width, height) of the loss figure. The height is scaled by 0.7
            when only one loss panel is drawn.
        figsize_metrics: (width, height) of the validation-AP figure.
        save_dir: If given, both figures are saved as 300-dpi PNGs into this directory.
        metrics_ylim: (bottom, top) y-limits of the AP figure, or None to autoscale.
            The default keeps the original report framing. Note that AP values above 0.7
            are clipped out of view with it; pass e.g. (-0.05, 1.05) to show the full range.
    """
    print("Generating Report Plots (Combined Losses and Validation Metrics)...")

    # --- Data Extraction ---
    train_epochs = np.arange(1, len(history.get('train_loss', [])) + 1)
    val_epochs_raw = history.get('epochs_validated', [])
    val_epochs = list(val_epochs_raw) if isinstance(val_epochs_raw, np.ndarray) else val_epochs_raw
    best_epoch = history.get('best_epoch', -1)
    # Shared by both figures. Previously this was defined only inside the loss-figure
    # branch, so histories with AP data but no loss data raised NameError.
    grid_style = {'linestyle': ':', 'alpha': 0.7}

    _plot_report_loss_figure(history, train_epochs, val_epochs, best_epoch,
                             figsize_loss, save_dir, grid_style)
    _plot_report_validation_ap_figure(history, val_epochs, best_epoch,
                                      figsize_metrics, save_dir, grid_style, metrics_ylim)


def _has_train_series(history: Dict, key: str) -> bool:
    """True when ``history[key]`` exists, is not None and is non-empty."""
    series = history.get(key)
    return series is not None and len(series) > 0


def _plot_report_loss_figure(history, train_epochs, val_epochs, best_epoch,
                             figsize_loss, save_dir, grid_style):
    """Draw (and optionally save) the total-loss and loss-component panels, each only if data exists."""
    has_train_total = _has_train_series(history, 'train_loss')
    has_val_total = check_val_data(history, 'val_loss', val_epochs)
    has_train_attr = _has_train_series(history, 'train_loss_attr')
    has_train_struct = _has_train_series(history, 'train_loss_struct')
    has_val_attr = check_val_data(history, 'val_loss_attr', val_epochs)
    has_val_struct = check_val_data(history, 'val_loss_struct', val_epochs)

    has_total_loss = has_train_total or has_val_total
    has_any_components = has_train_attr or has_train_struct or has_val_attr or has_val_struct
    num_plots_loss = int(has_total_loss) + int(has_any_components)

    if num_plots_loss == 0:
        print("No loss data found to plot.")
        return

    # A single panel gets 70% of the requested height.
    fig_height = figsize_loss[1] if num_plots_loss == 2 else figsize_loss[1] * 0.7
    fig_loss, axes_loss = plt.subplots(num_plots_loss, 1, figsize=(figsize_loss[0], fig_height),
                                       sharex=True, squeeze=False)
    axes_loss = axes_loss.flatten()
    fig_loss.suptitle('BGAE Model Loss Trajectories', fontsize=16, y=1.02)
    line_width = 1.8
    best_label = f'Best Epoch ({best_epoch})'
    plot_idx = 0

    # Panel 1: total loss
    if has_total_loss:
        ax = axes_loss[plot_idx]
        if has_train_total:
            ax.plot(train_epochs, history['train_loss'], label='Train Total Loss',
                    color='royalblue', linewidth=line_width, alpha=0.9)
        if has_val_total:
            ax.plot(val_epochs, history['val_loss'], label='Validation Total Loss',
                    color='darkorange', marker='.', linestyle='--',
                    linewidth=line_width * 0.9, markersize=5)
        ax.set_ylabel('Total Loss')
        # Panel titles only when there are two panels (otherwise the suptitle suffices).
        if num_plots_loss > 1:
            ax.set_title('Total Training and Validation Loss')
        if best_epoch != -1:
            ax.axvline(x=best_epoch, color='crimson', linestyle=':', linewidth=1.5, label=best_label)
        ax.legend(loc='best')
        ax.grid(True, **grid_style)
        plot_idx += 1

    # Panel 2: loss components (pre-lambda weights)
    if has_any_components:
        ax = axes_loss[plot_idx]
        if has_train_attr:
            ax.plot(train_epochs, history['train_loss_attr'], label='Train Attr Comp.',
                    color='forestgreen', linestyle='-', linewidth=line_width * 0.8, alpha=0.9)
        if has_train_struct:
            ax.plot(train_epochs, history['train_loss_struct'], label='Train Struct Comp.',
                    color='mediumpurple', linestyle='--', linewidth=line_width * 0.8, alpha=0.9)
        if has_val_attr:
            ax.plot(val_epochs, history['val_loss_attr'], label='Val Attr Comp.',
                    color='limegreen', linestyle='-', marker='x', markersize=5,
                    linewidth=line_width * 0.7, alpha=0.8)
        if has_val_struct:
            ax.plot(val_epochs, history['val_loss_struct'], label='Val Struct Comp.',
                    color='blueviolet', linestyle='--', marker='+', markersize=6,
                    linewidth=line_width * 0.7, alpha=0.8)
        ax.set_ylabel('Loss Component Value')
        if num_plots_loss > 1:
            ax.set_title('Loss Components (Pre-Lambda Weights)')
        if best_epoch != -1:
            # Legend entry for the best epoch only on the first panel.
            ax.axvline(x=best_epoch, color='crimson', linestyle=':', linewidth=1.5,
                       label="_nolegend_" if has_total_loss else best_label)
        ax.legend(loc='best')
        ax.grid(True, **grid_style)

    # x-label on the bottom panel only (axes share x).
    axes_loss[-1].set_xlabel('Epoch')
    fig_loss.tight_layout(rect=[0, 0.03, 1, 0.95])

    if save_dir:
        try:
            os.makedirs(save_dir, exist_ok=True)
            save_loss_path = os.path.join(save_dir, 'report_loss_curves_combined.png')
            fig_loss.savefig(save_loss_path, dpi=300, bbox_inches='tight')
            print(f"Combined loss plot saved to {save_loss_path}")
        except Exception as e:
            print(f"Error saving combined loss plot: {e}")
    plt.show()


def _plot_report_validation_ap_figure(history, val_epochs, best_epoch, figsize_metrics,
                                      save_dir, grid_style, metrics_ylim):
    """Draw (and optionally save) validation AP per element type plus their NaN-aware mean."""
    import warnings  # previously used in this cell without being imported

    print("--- Validation AP Metrics Plotting ---")
    metric_to_plot = 'AP'
    element_types = ['provider', 'member', 'edge']
    # Element types whose validation series lines up with the validated epochs.
    available = [elem for elem in element_types
                 if check_val_data(history, f'val_{elem}_{metric_to_plot}', val_epochs)]

    if not available or len(val_epochs) == 0:
        print(f"No validation {metric_to_plot} data (or validation epochs) found to plot.")
        return

    fig_metrics, ax = plt.subplots(1, 1, figsize=figsize_metrics)
    fig_metrics.suptitle(f'BGAE Model Validation {metric_to_plot} vs. Epoch', fontsize=16, y=1.0)

    element_colors = {'provider': 'mediumblue', 'member': 'darkgreen', 'edge': 'firebrick', 'average': 'black'}
    element_markers = {'provider': '.', 'member': '^', 'edge': 's', 'average': ''}
    element_linestyles = {'provider': '-', 'member': '--', 'edge': ':', 'average': '-.'}
    marker_size = 5
    line_width_metrics = 1.7
    best_epoch_style = {'color': 'crimson', 'linestyle': ':', 'linewidth': 1.5}

    for element_type in available:
        key = f'val_{element_type}_{metric_to_plot}'
        label = 'Edge' if element_type == 'edge' else f'{element_type.capitalize()} Node'
        ax.plot(val_epochs, history[key], label=label,
                color=element_colors[element_type], marker=element_markers[element_type],
                linestyle=element_linestyles[element_type], markersize=marker_size,
                linewidth=line_width_metrics)

    # Per-epoch mean over the available element types. nanmean skips NaN entries.
    # Its RuntimeWarning (e.g. for all-NaN columns) is silenced.
    stacked_ap = np.vstack([np.array(history[f'val_{elem}_{metric_to_plot}']) for elem in available])
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        average_ap = np.nanmean(stacked_ap, axis=0)
    ax.plot(val_epochs, average_ap, label='Average AP',
            color=element_colors['average'], linestyle=element_linestyles['average'],
            linewidth=line_width_metrics + 0.3)

    ax.set_ylabel(f'{metric_to_plot} Score')
    ax.set_xlabel('Epoch')
    if metrics_ylim is not None:
        ax.set_ylim(bottom=metrics_ylim[0], top=metrics_ylim[1])
    ax.grid(True, **grid_style)
    if best_epoch != -1:
        ax.axvline(x=best_epoch, label=f'Best Epoch ({best_epoch})', **best_epoch_style)
    ax.legend(title='Element Type', loc='best', fontsize='small')

    fig_metrics.tight_layout(rect=[0, 0.03, 1, 0.96])  # leave room for the suptitle

    if save_dir:
        try:
            os.makedirs(save_dir, exist_ok=True)
            save_metrics_path = os.path.join(save_dir, f'report_validation_{metric_to_plot}_metric.png')
            fig_metrics.savefig(save_metrics_path, dpi=300, bbox_inches='tight')
            print(f"AP metrics plot saved to {save_metrics_path}")
        except Exception as e:
            print(f"Error saving {metric_to_plot} metrics plot: {e}")
    plt.show()



# Directory for the report figures (created if missing; no-op if it already exists)
save_directory = "report_plots_example"
os.makedirs(save_directory, exist_ok=True)

# Draw and save the loss and validation-AP report figures for the training run above
plot_training_validation_performance_report_combined_loss(
    history=history,
    target_edge_type=('provider', 'to', 'member'),
    figsize_metrics=(10, 4),
    save_dir=save_directory,
)

In [49]:
def run_hyperparameter_tuning(
    num_trials: int,
    param_space: dict,
    train_g_dev: HeteroData,
    val_g_dev: HeteroData,
    gt_node_labels_val_dev: dict,
    gt_edge_labels_val_dev: dict,
    in_dim_member: int,
    in_dim_provider: int,
    edge_dim: int,
    device: str,
    # --- Assumes BipartiteGraphAutoEncoder_ReportBased class is available ---
    num_epochs_per_trial: int = 200,
    early_stopping_metric: str = 'AP',
    early_stopping_element: str = 'Avg AP',
    patience: int = 10,
    val_freq: int = 10,
    target_edge: tuple = ('provider', 'to', 'member')
    ) -> Tuple[Optional[dict], pd.DataFrame]:
    """
    Tune BipartiteGraphAutoEncoder_ReportBased hyperparameters by grid or random search.

    All valid combinations of ``param_space`` are enumerated first. If there are at most
    ``num_trials`` of them, every combination is tried (exhaustive grid search). Otherwise
    ``num_trials`` distinct combinations are sampled at random without replacement.

    Constraints applied to each combination:
      * combinations with ``latent_dim > hidden_dim`` are discarded;
      * ``num_dec_layers`` is set equal to ``num_conv_layers``;
      * ``lambda_struct_only`` is renamed to ``lambda_struct`` and
        ``lambda_attr = 1 - lambda_struct`` is derived when training.

    Args:
        num_trials: Maximum number of configurations to train.
        param_space: Mapping name -> list of candidate values. Must contain 'hidden_dim',
            'latent_dim', 'num_conv_layers', 'learning_rate', 'dropout' and
            'lambda_struct_only'; 'weight_decay' is optional (default 1e-5).
        train_g_dev / val_g_dev: Training / validation graphs already on ``device``.
        gt_node_labels_val_dev / gt_edge_labels_val_dev: Validation ground-truth labels.
        in_dim_member / in_dim_provider / edge_dim: Model input dimensions.
        device: Device the models are moved to.
        num_epochs_per_trial: Maximum training epochs per configuration.
        early_stopping_metric / early_stopping_element: Validation quantity that is
            monitored and maximized.
        patience: Early-stopping patience passed to the training function.
        val_freq: Validation frequency in epochs.
        target_edge: Edge type used by the training function.

    Returns:
        Tuple[Optional[dict], pd.DataFrame]: Best parameters (with derived 'lambda_attr'),
        or None if no trial produced a finite score, and a DataFrame with one row per trial.
    """
    import itertools
    import math
    import random

    tuning_results_list = []
    best_overall_metric = -np.inf
    best_overall_params = None

    # --- 1. Generate and Filter All Valid Combinations ---
    print("--- Calculating All Valid Hyperparameter Combinations ---")
    param_names = list(param_space.keys())
    valid_combinations = []
    for combo_values in itertools.product(*(param_space[name] for name in param_names)):
        params = dict(zip(param_names, combo_values))
        # latent_dim must not exceed hidden_dim
        if params['latent_dim'] > params['hidden_dim']:
            continue
        # Symmetric encoder/decoder depth
        params['num_dec_layers'] = params['num_conv_layers']
        # lambda_attr + lambda_struct = 1: only lambda_struct is stored; lambda_attr is derived.
        params['lambda_struct'] = params.pop('lambda_struct_only')
        valid_combinations.append(params)

    total_valid_combinations = len(valid_combinations)
    print(f"Total valid combinations found: {total_valid_combinations}")

    # --- 2. Decide Search Strategy and Prepare Trials ---
    if total_valid_combinations <= num_trials:
        print(f"\n--- Starting EXHAUSTIVE Grid Search ({total_valid_combinations} Trials) ---")
        search_mode = "Exhaustive"
        trials_to_run = valid_combinations
    else:
        print(f"\n--- Starting RANDOM Search ({num_trials} Trials out of {total_valid_combinations} possible) ---")
        search_mode = "Random"
        # Sample from the already-validated combinations without replacement, so no
        # configuration is trained twice (independent random.choice draws could repeat).
        trials_to_run = random.sample(valid_combinations, max(0, num_trials))
    actual_num_trials = len(trials_to_run)

    print(f"  Objective: Maximize Validation '{early_stopping_element}' '{early_stopping_metric}'")

    # --- 3. Run Tuning Loop ---
    for trial_num, trial_params in enumerate(trials_to_run, start=1):
        params = dict(trial_params)  # copy so recorded results never alias valid_combinations
        print(f"\n--- Trial {trial_num}/{actual_num_trials} ({search_mode}) ---")

        lambda_attr_calc = 1.0 - params['lambda_struct']
        print(f"  Params: {params} (Derived lambda_attr = {lambda_attr_calc:.2f})")

        # a) Instantiate Model and Optimizer
        try:
            model = BipartiteGraphAutoEncoder_ReportBased(
                in_dim_member=in_dim_member,
                in_dim_provider=in_dim_provider,
                edge_dim=edge_dim,
                hidden_dim=params['hidden_dim'],
                latent_dim=params['latent_dim'],
                num_conv_layers=params['num_conv_layers'],
                num_dec_layers=params['num_dec_layers'],
                dropout=params['dropout']
            ).to(device)
        except Exception as e:
            # Previously referenced an undefined `model_class`, raising NameError here.
            print(f"ERROR: Failed to instantiate model 'BipartiteGraphAutoEncoder_ReportBased': {e}. Skipping trial.")
            continue

        # torch.compile is intentionally not used here (it previously caused errors).

        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=params['learning_rate'],
            weight_decay=params.get('weight_decay', 1e-5)
        )

        # b) Train Model
        print(f"  Starting training for trial {trial_num}...")
        try:
            _, history, _ = train_model_inductive_with_metrics(
                model=model,
                train_graph=train_g_dev,
                num_epochs=num_epochs_per_trial,
                optimizer=optimizer,
                lambda_attr=lambda_attr_calc,
                lambda_struct=params['lambda_struct'],
                k_neg_samples=5,
                target_edge_type=target_edge,
                device=device,
                log_freq=val_freq * 5,  # Log less often during tuning
                val_graph=val_g_dev,
                gt_node_labels_val=gt_node_labels_val_dev,
                gt_edge_labels_val=gt_edge_labels_val_dev,
                val_log_freq=val_freq,
                val_lambda_attr=lambda_attr_calc,
                val_lambda_struct=params['lambda_struct'],
                val_k_neg_samples_score=1,
                node_k_list=[50, 100, 200],
                early_stopping_metric=early_stopping_metric,
                early_stopping_element=early_stopping_element,
                patience=patience,
                save_best_model_path=None  # No saving during tuning
            )
        except Exception as e:
            print(f"ERROR during training for trial {trial_num}: {e}")
            history = {'best_val_metric_value': -np.inf, 'best_epoch': -1}

        # c) Record Results
        best_metric_in_trial = history.get('best_val_metric_value', -np.inf)
        if best_metric_in_trial is None:
            best_metric_in_trial = -np.inf
        best_epoch_in_trial = history.get('best_epoch', -1)
        print(f"  Trial {trial_num} Result: Best Val '{early_stopping_element}' '{early_stopping_metric}' = {best_metric_in_trial:.4f} (at epoch {best_epoch_in_trial})")

        trial_data = {'trial': trial_num, 'search_mode': search_mode}
        trial_data.update(params)
        trial_data['best_val_metric'] = best_metric_in_trial
        trial_data['best_epoch'] = best_epoch_in_trial
        tuning_results_list.append(trial_data)

        # d) Update overall best. Only finite scores qualify: previously a failed trial
        #    (-inf) satisfied `-inf >= -inf` and was reported as the best configuration.
        try:
            is_valid_score = math.isfinite(float(best_metric_in_trial))
        except (TypeError, ValueError):
            is_valid_score = False
        # `>=` keeps the original tie-break: later trials win exact ties.
        if is_valid_score and best_metric_in_trial >= best_overall_metric:
            if best_metric_in_trial > best_overall_metric + 1e-7:
                print("  ** New Best Overall Found! **")
            best_overall_metric = best_metric_in_trial
            best_overall_params = dict(params)

    print("\n--- Hyperparameter Tuning Finished ---")

    # --- 4. Report Final Results ---
    results_df = pd.DataFrame(tuning_results_list)

    if best_overall_params is not None:
        print("\n--- Best Hyperparameters Found ---")
        print(f"  Search Mode Run: {search_mode}")
        print(f"  Best Validation '{early_stopping_element}' '{early_stopping_metric}': {best_overall_metric:.4f}")
        print("  Optimal Parameters:")
        best_lambda_attr = 1.0 - best_overall_params['lambda_struct']
        print(f"    lambda_attr: {best_lambda_attr:.2f} (derived from constraint)")
        for param, value in best_overall_params.items():
            print(f"    {param}: {value}")
        # Include the derived lambda_attr so callers get the full configuration.
        best_overall_params_final = {**best_overall_params, 'lambda_attr': best_lambda_attr}
    else:
        best_overall_params_final = None
        print("No successful trials completed or no improvement found.")

    print("\n--- Tuning Results Summary (Top 5 by best_val_metric) ---")
    if results_df.empty:
        # sort_values would raise KeyError on a frame with no 'best_val_metric' column.
        print("(no trials recorded)")
    else:
        print(results_df.sort_values(by='best_val_metric', ascending=False).head().to_string())

    return best_overall_params_final, results_df

In [None]:
# --- Define Search Space ---
# Reduced for a quick example run - expand as needed.
tuning_param_space = {
    'hidden_dim': [32, 128],
    'latent_dim': [8, 32],              # Combos with latent_dim > hidden_dim are meant to be skipped (none here)
    'num_conv_layers': [1, 3],          # Decoder depth is tied to this (as in run_hyperparameter_tuning)
    'learning_rate': [5e-4, 1e-3],
    'dropout': [0.3, 0.7],
    'lambda_struct_only': [0.25, 0.75]  # Tuned directly; lambda_attr = 1 - lambda_struct
}

# NOTE: `run_hyperparameter_tuning_repeated` is defined in a cell further down this
# notebook. Run that definition cell first. TODO: move it above this cell so that
# Restart & Run All works.
best_params, tuning_df = run_hyperparameter_tuning_repeated(
    num_trials=10,
    n_iter=3,                   # Repeat each configuration 3 times
    param_space=tuning_param_space,
    train_g_dev=train_g_on_device,
    val_g_dev=val_g_on_device,
    gt_node_labels_val_dev=GT_NODE_LABELS["val"],
    gt_edge_labels_val_dev=GT_EDGE_LABELS["val"],
    in_dim_member=in_dim_member,
    in_dim_provider=in_dim_provider,
    edge_dim=edge_dim,
    device=device,
    num_epochs_per_trial=200,   # Max epochs per run (early stopping may end runs sooner)
    early_stopping_metric='AP',
    early_stopping_element='Avg AP',  # Monitor average AP
    patience=15,
    val_freq=10,
)

print("\n--- Final Best Parameters from Tuning ---")
print(best_params)


In [154]:
def run_hyperparameter_tuning_repeated(
    # --- Inputs ---
    num_trials: int,
    param_space: dict,
    train_g_dev: HeteroData,
    val_g_dev: HeteroData,
    gt_node_labels_val_dev: dict,
    gt_edge_labels_val_dev: dict,
    in_dim_member: int,
    in_dim_provider: int,
    edge_dim: int,
    device: str,
    # --- Repetition Parameter ---
    n_iter: int = 1, # Number of times to repeat each configuration
    # --- Fixed Training/Eval Parameters ---
    num_epochs_per_trial: int = 200,
    early_stopping_metric: str = 'AP',
    early_stopping_element: str = 'Avg AP',
    patience: int = 10,
    val_freq: int = 10,
    k_neg_train: int = 5,
    k_neg_val_score: int = 1,
    target_edge: tuple = ('provider', 'to', 'member')
    ) -> Tuple[Optional[dict], pd.DataFrame]:
    """
    Tune BipartiteGraphAutoEncoder_ReportBased hyperparameters, training every
    configuration `n_iter` times and ranking configurations by the mean of their
    best validation metric.

    Search strategy: a combination from `param_space` is valid when
    latent_dim <= hidden_dim. If there are at most `num_trials` valid
    combinations, all of them are run (exhaustive grid search). Otherwise
    `num_trials` distinct valid combinations are drawn at random (without
    replacement) before any training starts.

    For each configuration, `lambda_struct_only` becomes the structural loss
    weight `lambda_struct`, the attribute weight is derived as
    `1 - lambda_struct`, and `num_dec_layers` is set to `num_conv_layers`.

    Run i of configuration c is seeded with `c * n_iter + i` (python `random`,
    numpy and torch), so repeated runs of one configuration differ while a
    whole tuning call is reproducible.

    Args:
        num_trials: Maximum number of distinct configurations to evaluate.
        param_space: Mapping from hyperparameter name to a list of candidate
            values. Must contain 'hidden_dim', 'latent_dim', 'num_conv_layers',
            'learning_rate', 'dropout' and 'lambda_struct_only'.
        train_g_dev, val_g_dev: Training / validation graphs, forwarded to
            train_model_inductive_with_metrics.
        gt_node_labels_val_dev, gt_edge_labels_val_dev: Validation ground-truth
            labels, forwarded to train_model_inductive_with_metrics.
        in_dim_member, in_dim_provider, edge_dim: Model input feature sizes.
        device: Target device for the model. A value starting with 'cuda'
            also seeds all CUDA devices.
        n_iter: Repetitions per configuration (non-int or < 1 falls back to 1).
        num_epochs_per_trial, early_stopping_metric, early_stopping_element,
        patience, val_freq, k_neg_train, k_neg_val_score, target_edge:
            Forwarded to train_model_inductive_with_metrics for every run.

    Returns:
        Tuple[Optional[dict], pd.DataFrame]:
            - Best configuration (by mean metric; ties within 1e-7 broken by
              lower std), with derived keys 'lambda_struct', 'num_dec_layers'
              and 'lambda_attr'. None if no configuration produced a finite
              mean metric.
            - One row per configuration with the mean and sample std (ddof=1)
              of the validation metric, a normal-approximation 95% CI radius,
              and the mean best epoch. Rows are sorted by mean metric, descending.
    """
    # Local imports: the notebook imports these only in later cells.
    import itertools
    import random

    mean_col = f'mean_{early_stopping_element}_{early_stopping_metric}'
    std_col = f'std_{early_stopping_element}_{early_stopping_metric}'
    tie_tolerance = 1e-7

    # --- Input Validation ---
    if not isinstance(n_iter, int) or n_iter < 1:
        print("Warning: n_iter must be a positive integer. Setting n_iter=1.")
        n_iter = 1

    # --- 1. Generate All Valid Combinations ---
    print("--- Calculating All Valid Hyperparameter Combinations ---")
    param_names = list(param_space.keys())
    valid_combinations = []
    for combo_values in itertools.product(*(param_space[name] for name in param_names)):
        params = dict(zip(param_names, combo_values))
        # Constraint: the latent size may not exceed the hidden size.
        if params.get('latent_dim', 0) > params.get('hidden_dim', float('inf')):
            continue
        valid_combinations.append(params)

    total_valid_combinations = len(valid_combinations)
    print(f"Total valid combinations found (after constraints): {total_valid_combinations}")

    # --- 2. Decide Search Strategy and Select Configurations ---
    if total_valid_combinations <= num_trials:
        print(f"\n--- Starting EXHAUSTIVE Grid Search ({total_valid_combinations} combinations, {n_iter} runs each) ---")
        search_mode = "Exhaustive"
        # Iterate the *filtered* list so that no valid combination is dropped.
        configs_to_run = valid_combinations
    else:
        print(f"\n--- Starting RANDOM Search ({num_trials} combinations, {n_iter} runs each, from {total_valid_combinations} possible) ---")
        search_mode = "Random"
        # Sample up front, without replacement: configurations are unique and the
        # draw is not affected by the per-run reseeding below.
        configs_to_run = random.sample(valid_combinations, max(num_trials, 0))
    num_configs = len(configs_to_run)

    print(f"  Objective: Maximize Mean Validation '{early_stopping_element}' '{early_stopping_metric}' over {n_iter} runs")

    # --- 3. Run Tuning Loop ---
    tuning_results_list = [] # One aggregated row per configuration
    best_overall_mean_metric = -np.inf
    best_overall_std_metric = np.inf
    best_overall_params = None # Stored as sampled (with 'lambda_struct_only')
    best_config_idx = None

    for i, params_sampled in enumerate(configs_to_run):
        # Derive the parameters actually passed to the model and trainer.
        params = params_sampled.copy()
        params['lambda_struct'] = params.pop('lambda_struct_only')
        params['num_dec_layers'] = params['num_conv_layers'] # Decoder depth mirrors encoder depth
        lambda_attr_calc = 1.0 - params['lambda_struct']

        print(f"\n--- Configuration {i + 1}/{num_configs} ({search_mode}) ---")
        print(f"  Params: {params} (Derived lambda_attr = {lambda_attr_calc:.2f})")
        print(f"  Running {n_iter} times...")

        run_metrics = [] # Best validation metric per run (NaN on failure)
        run_epochs = []  # Best epoch per run (-1 on failure)

        for run in range(n_iter):
            print(f"    Run {run + 1}/{n_iter}: Starting training...")
            model = None
            optimizer = None
            try:
                run_seed = i * n_iter + run # Unique, reproducible seed per (config, run)
                random.seed(run_seed)
                np.random.seed(run_seed)
                torch.manual_seed(run_seed)
                if str(device).startswith('cuda'):
                    torch.cuda.manual_seed_all(run_seed)

                # A fresh model for every run keeps the repetitions independent.
                model = BipartiteGraphAutoEncoder_ReportBased(
                    in_dim_member=in_dim_member,
                    in_dim_provider=in_dim_provider,
                    edge_dim=edge_dim,
                    hidden_dim=params['hidden_dim'],
                    latent_dim=params['latent_dim'],
                    num_conv_layers=params['num_conv_layers'],
                    num_dec_layers=params['num_dec_layers'],
                    dropout=params['dropout']
                ).to(device)
            except Exception as e:
                print(f"    ERROR: Failed to instantiate model 'BipartiteGraphAutoEncoder_ReportBased' on run {run+1}: {e}. Skipping run.")
                run_metrics.append(np.nan)
                run_epochs.append(-1)
                continue

            try:
                optimizer = torch.optim.Adam(model.parameters(), lr=params['learning_rate'])
                _, history, _ = train_model_inductive_with_metrics(
                    model=model,
                    train_graph=train_g_dev,
                    num_epochs=num_epochs_per_trial,
                    optimizer=optimizer,
                    lambda_attr=lambda_attr_calc,
                    lambda_struct=params['lambda_struct'],
                    k_neg_samples=k_neg_train,
                    target_edge_type=target_edge,
                    device=device,
                    log_freq=9999, # Effectively silences per-epoch training logs
                    val_graph=val_g_dev,
                    gt_node_labels_val=gt_node_labels_val_dev,
                    gt_edge_labels_val=gt_edge_labels_val_dev,
                    val_log_freq=val_freq,
                    val_lambda_attr=lambda_attr_calc,
                    val_lambda_struct=params['lambda_struct'],
                    val_k_neg_samples_score=k_neg_val_score,
                    node_k_list=[50], # Only base metrics are needed for early stopping
                    early_stopping_metric=early_stopping_metric,
                    early_stopping_element=early_stopping_element,
                    patience=patience,
                    save_best_model_path=None # No checkpointing during tuning
                )
                metric_value = history.get('best_val_metric_value', -np.inf)
                # A run that never improved reports -inf (or NaN); record it as a failure.
                if metric_value is not None and metric_value > -np.inf:
                    run_metrics.append(float(metric_value))
                else:
                    run_metrics.append(np.nan)
                run_epochs.append(history.get('best_epoch', -1))
            except Exception as e:
                print(f"    ERROR during training for config {i+1}, run {run+1}: {e}")
                run_metrics.append(np.nan)
                run_epochs.append(-1)
            finally:
                # Release accelerator memory between runs.
                del model, optimizer
                if torch.cuda.is_available(): torch.cuda.empty_cache()
                elif torch.backends.mps.is_available(): torch.mps.empty_cache()

        # --- Aggregate Results for this Configuration ---
        valid_run_metrics = [m for m in run_metrics if not np.isnan(m)]
        valid_run_epochs = [e for e in run_epochs if e is not None and e != -1]
        n_successful_runs = len(valid_run_metrics)
        if n_successful_runs > 0:
            mean_metric = float(np.mean(valid_run_metrics))
            if n_successful_runs > 1:
                # Sample std (ddof=1) is the correct estimator for a standard error / CI.
                std_metric = float(np.std(valid_run_metrics, ddof=1))
                ci_radius = 1.96 * std_metric / np.sqrt(n_successful_runs)
            else:
                std_metric = 0.0
                ci_radius = 0.0
            mean_best_epoch = float(np.mean(valid_run_epochs)) if valid_run_epochs else -1
        else:
            mean_metric = -np.inf
            std_metric = np.nan
            ci_radius = np.nan
            mean_best_epoch = -1

        print(f"  Config {i + 1} Aggregated Result ({n_successful_runs}/{n_iter} succ. runs): Mean '{early_stopping_element}' '{early_stopping_metric}' = {mean_metric:.4f} ± {std_metric:.4f}")

        trial_data = {'config_idx': i + 1, 'search_mode': search_mode, 'n_successful_runs': n_successful_runs}
        trial_data.update(params)
        trial_data[mean_col] = mean_metric
        trial_data[std_col] = std_metric
        trial_data['ci_95_radius'] = ci_radius
        trial_data['mean_best_epoch'] = mean_best_epoch
        tuning_results_list.append(trial_data)

        # --- Update overall best: higher mean wins; near-ties go to the lower std ---
        improved = mean_metric > best_overall_mean_metric + tie_tolerance
        tie_with_lower_std = (
            best_overall_params is not None
            and abs(mean_metric - best_overall_mean_metric) <= tie_tolerance
            and std_metric < best_overall_std_metric
        )
        if improved:
            print(f"    ** New Best Overall Mean Metric Found! **")
        elif tie_with_lower_std:
            print(f"    ** Same Mean Metric, Lower Std Dev Found! **")
        if improved or tie_with_lower_std:
            best_overall_mean_metric = mean_metric
            best_overall_std_metric = std_metric
            best_overall_params = params_sampled.copy()
            best_config_idx = i + 1

    print("\n--- Hyperparameter Tuning Finished ---")

    # --- 4. Report Final Results ---
    results_df = pd.DataFrame(tuning_results_list)
    if not results_df.empty:
        # Stable sort keeps configuration order among equal means.
        results_df = results_df.sort_values(by=mean_col, ascending=False, kind='mergesort')

    if best_overall_params is not None:
        print("\n--- Best Hyperparameters Found ---")
        print(f"  Search Mode Run: {search_mode}")
        print(f"  Best Mean Validation '{early_stopping_element}' '{early_stopping_metric}': {best_overall_mean_metric:.4f}")
        # Look the row up by index: with std tie-breaking, the chosen config need not
        # be the first row after sorting by mean.
        best_row = results_df.loc[results_df['config_idx'] == best_config_idx].iloc[0]
        print(f"  Std Dev: {best_row[std_col]:.4f} | 95% CI Radius: {best_row['ci_95_radius']:.4f} (over {best_row['n_successful_runs']} runs)")
        print("  Optimal Parameters (sampled values):")
        final_params = best_overall_params.copy()
        final_lambda_struct = final_params.pop('lambda_struct_only')
        final_params['lambda_struct'] = final_lambda_struct
        final_params['num_dec_layers'] = final_params['num_conv_layers']
        final_lambda_attr = 1.0 - final_lambda_struct
        print(f"    lambda_attr: {final_lambda_attr:.2f} (derived)")
        for param, value in final_params.items():
            print(f"    {param}: {value}")
        # Include the derived weight, as run_hyperparameter_tuning does for its result.
        final_params['lambda_attr'] = final_lambda_attr
        final_best_params_to_return = final_params
    else:
        final_best_params_to_return = None
        print("No successful trials completed or no improvement found.")

    print("\n--- Tuning Results Summary (Top 5 by Mean Metric) ---")
    print(results_df.head().to_string())

    return final_best_params_to_return, results_df

In [None]:
# --- Define Search Space ---
# Smoke-test grid: 8 valid combinations, of which only 2 are sampled, 2 runs each.
# The tuner skips combinations with latent_dim > hidden_dim; lambda_attr = 1 - lambda_struct_only.
tuning_param_space_2 = dict(
    hidden_dim=[32, 64],
    latent_dim=[32],
    num_conv_layers=[1, 2],
    learning_rate=[5e-4, 1e-3],
    dropout=[0.7],
    lambda_struct_only=[0.75],
)

best_params_test, tuning_df_test = run_hyperparameter_tuning_repeated(
    # Search budget
    param_space=tuning_param_space_2,
    num_trials=2,
    n_iter=2,
    # Graphs, validation labels and model input sizes
    train_g_dev=train_g_on_device,
    val_g_dev=val_g_on_device,
    gt_node_labels_val_dev=GT_NODE_LABELS["val"],
    gt_edge_labels_val_dev=GT_EDGE_LABELS["val"],
    in_dim_member=in_dim_member,
    in_dim_provider=in_dim_provider,
    edge_dim=edge_dim,
    device=device,
    # Very short runs: 20 epochs, validated every 10
    num_epochs_per_trial=20,
    early_stopping_metric='AP',
    early_stopping_element='Avg AP',  # monitor the average AP
    patience=30,
    val_freq=10,
)

print("\n--- Final Best Parameters from Tuning ---")
print(best_params_test)


In [None]:
import itertools # Needed for product to generate combinations
# --- Define Search Space ---
# Grid over dropout x lambda_struct_only (4 combinations); all other values fixed.
# NOTE(review): unlike the neighbouring cells, this one calls the single-run tuner
# `run_hyperparameter_tuning`, not `run_hyperparameter_tuning_repeated` — confirm intended.
tuning_param_space_3 = dict(
    hidden_dim=[64],
    latent_dim=[32],  # kept <= hidden_dim
    num_conv_layers=[1],
    learning_rate=[5e-4],
    dropout=[0.5, 0.7],
    lambda_struct_only=[0.5, 0.75],  # lambda_attr = 1 - lambda_struct_only
)

best_params_3, tuning_df_3 = run_hyperparameter_tuning(
    # Search budget
    param_space=tuning_param_space_3,
    num_trials=5,
    # Graphs, validation labels and model input sizes
    train_g_dev=train_g_on_device,
    val_g_dev=val_g_on_device,
    gt_node_labels_val_dev=GT_NODE_LABELS["val"],
    gt_edge_labels_val_dev=GT_EDGE_LABELS["val"],
    in_dim_member=in_dim_member,
    in_dim_provider=in_dim_provider,
    edge_dim=edge_dim,
    device=device,
    # Longer runs (300 epochs) with validation every 10 epochs
    num_epochs_per_trial=300,
    early_stopping_metric='AP',
    early_stopping_element='Avg AP',  # monitor the average AP
    patience=30,
    val_freq=10,
)

print("\n--- Final Best Parameters from Tuning ---")
print(best_params_3)


In [None]:
import itertools # Needed for product to generate combinations
# --- Define Search Space ---
# Sweep only the structure/attribute loss balance (6 values). The other values are fixed.
# With num_trials=6 every value is tried once (exhaustive search, default n_iter=1).
tuning_param_space_4 = dict(
    hidden_dim=[64],
    latent_dim=[32],  # the tuner skips latent_dim > hidden_dim
    num_conv_layers=[1],
    learning_rate=[5e-4],
    dropout=[0.5],
    lambda_struct_only=[0.1, 0.25, 0.5, 0.6, 0.75, 0.9],  # lambda_attr = 1 - lambda_struct_only
)

best_params_4, tuning_df_4 = run_hyperparameter_tuning_repeated(
    # Search budget
    param_space=tuning_param_space_4,
    num_trials=6,
    # Graphs, validation labels and model input sizes
    train_g_dev=train_g_on_device,
    val_g_dev=val_g_on_device,
    gt_node_labels_val_dev=GT_NODE_LABELS["val"],
    gt_edge_labels_val_dev=GT_EDGE_LABELS["val"],
    in_dim_member=in_dim_member,
    in_dim_provider=in_dim_provider,
    edge_dim=edge_dim,
    device=device,
    # 300-epoch runs validated every 10 epochs
    num_epochs_per_trial=300,
    early_stopping_metric='AP',
    early_stopping_element='Avg AP',  # monitor the average AP
    patience=30,
    val_freq=10,
)

print("\n--- Final Best Parameters from Tuning ---")
print(best_params_4)


In [None]:
import itertools # Needed for product to generate combinations
# --- Define Search Space ---
# Same lambda sweep as the previous grid, but each of the 6 values is trained 3 times
# so the tuner can report a mean, spread and CI per value.
tuning_param_space_5 = dict(
    hidden_dim=[64],
    latent_dim=[32],  # the tuner skips latent_dim > hidden_dim
    num_conv_layers=[1],
    learning_rate=[5e-4],
    dropout=[0.5],
    lambda_struct_only=[0.1, 0.25, 0.5, 0.6, 0.75, 0.9],  # lambda_attr = 1 - lambda_struct_only
)

best_params_5, tuning_df_5 = run_hyperparameter_tuning_repeated(
    # Search budget: all 6 values, 3 runs each
    param_space=tuning_param_space_5,
    num_trials=6,
    n_iter=3,
    # Graphs, validation labels and model input sizes
    train_g_dev=train_g_on_device,
    val_g_dev=val_g_on_device,
    gt_node_labels_val_dev=GT_NODE_LABELS["val"],
    gt_edge_labels_val_dev=GT_EDGE_LABELS["val"],
    in_dim_member=in_dim_member,
    in_dim_provider=in_dim_provider,
    edge_dim=edge_dim,
    device=device,
    # 200-epoch runs, validated every 50 epochs
    num_epochs_per_trial=200,
    early_stopping_metric='AP',
    early_stopping_element='Avg AP',  # monitor the average AP
    patience=15,
    val_freq=50,
)

print("\n--- Final Best Parameters from Tuning ---")
print(best_params_5)


In [None]:
# --- Train the final model with the tuned hyperparameters ---
# The tuner returns None when no configuration succeeded, so only copy a real dict
# (calling .copy() on None would crash before the guard below).
best_params = best_params_5.copy() if best_params_5 else None

if best_params:
    print("\n--- Training Final Model with Best Parameters ---")
    # Possible extension (not implemented): retrain on train + val combined, which needs
    # helpers to merge the graphs and their labels. Here we train longer on the training graph.

    final_lambda_struct = best_params['lambda_struct']
    final_lambda_attr = 1.0 - final_lambda_struct
    num_epochs = 1000

    final_model = BipartiteGraphAutoEncoder_ReportBased(
        in_dim_member=in_dim_member,
        in_dim_provider=in_dim_provider,
        edge_dim=edge_dim,
        hidden_dim=best_params['hidden_dim'],
        latent_dim=best_params['latent_dim'],
        num_conv_layers=best_params['num_conv_layers'],
        num_dec_layers=best_params['num_dec_layers'],
        dropout=best_params['dropout']
    ).to(device)

    # NOTE(review): `weight_decay` comes from an earlier cell; the tuning runs used no
    # weight decay, so this value was not part of the search.
    final_optimizer = torch.optim.Adam(final_model.parameters(), lr=best_params['learning_rate'], weight_decay=weight_decay)

    final_model, final_history, saved_best_final_model_filepath = train_model_inductive_with_metrics(
        model=final_model,
        train_graph=train_g_on_device,
        num_epochs=num_epochs,
        optimizer=final_optimizer,
        lambda_attr=final_lambda_attr,             # Tuned training-loss weights
        lambda_struct=final_lambda_struct,
        k_neg_samples=k_neg_train,                 # Negative samples for the training loss
        target_edge_type=target_edge_type,
        device=device,
        log_freq=eval_freq,                        # How often to print training logs

        # --- Validation Arguments ---
        val_graph=val_g_on_device,
        gt_node_labels_val=GT_NODE_LABELS["val"],
        gt_edge_labels_val=GT_EDGE_LABELS["val"],
        val_log_freq=eval_freq,                    # How often validation runs

        # --- Validation Scoring Arguments ---
        # Score validation with the SAME tuned lambdas used for training; the global
        # `lambda_attr` / `lambda_struct` from earlier cells need not match them.
        val_lambda_attr=final_lambda_attr,
        val_lambda_struct=final_lambda_struct,
        val_k_neg_samples_score=val_k_neg_score,   # K negatives for the validation score

        # --- Node Metrics Argument ---
        node_k_list=node_k_list_eval,              # K values for P@K / R@K

        # --- Early Stopping Arguments ---
        early_stopping_metric=early_stop_metric_choice,
        early_stopping_element=early_stop_element_choice,
        patience=100,                              # Fixed here rather than early_stop_patience_checks
        save_best_model_path=best_model_save_path  # Where the best checkpoint is written
    )
    # Next step: evaluate final_model on the test graph.
else:
    print("No tuned parameters available from best_params_5; skipping final model training.")
    # Then evaluate this final_model on the *test set* using evaluate_model_inductively

In [None]:
# Display the hyperparameters selected for the final model.
best_params

In [None]:
# Plot the final model's training/validation curves and save the figure inside the
# report output directory. That directory is created here and is now actually used.
plot_output_dir = "report_plots_example"
os.makedirs(plot_output_dir, exist_ok=True)
plot_training_validation_performance_report_combined_loss(
    history=final_history,               # History returned by train_model_inductive_with_metrics
    target_edge_type=target_edge_type,   # Same edge type the final model was trained on
    figsize_metrics=(10, 4),
    # NOTE(review): a full file path is passed for `save_dir`; confirm the helper
    # treats it as an output file path rather than a directory.
    save_dir=os.path.join(plot_output_dir, 'training_validation_curves.png'),
)

In [169]:
# Add the attribute-loss weight implied by the tuned structural weight (they sum to 1);
# the scoring/evaluation cells below read best_params["lambda_attr"].
best_params["lambda_attr"] = 1 - best_params["lambda_struct"]

In [None]:
# --- Score the training and validation graphs with the final model ---
print("\n--- Calculating Anomaly Scores on Training and Validation Graphs ---")
# Scoring reuses the lambda weights stored in best_params for the final model.
node_scores_train, edge_scores_train = calculate_anomaly_scores(
    trained_model=final_model,
    eval_graph_data=train_g_on_device,
    lambda_attr=best_params["lambda_attr"],
    lambda_struct=best_params["lambda_struct"],
    target_edge_type=target_edge_type,
)
node_scores_val, edge_scores_val = calculate_anomaly_scores(
    trained_model=final_model,
    eval_graph_data=val_g_on_device,
    lambda_attr=best_params["lambda_attr"],
    lambda_struct=best_params["lambda_struct"],
    target_edge_type=target_edge_type,
)
print("--- Score Calculation Finished ---")

# --- Compute metrics at several top-K cut-offs ---
print("\n--- Evaluating Performance on Training and Validation Graphs ---")
k_values_for_eval = [50, 100, 200, 500]

# NOTE(review): these per-split label variables come from an earlier cell, while the
# training cells use GT_NODE_LABELS[...] / GT_EDGE_LABELS[...]; confirm they match.
results_train = evaluate_performance_inductive(
    node_scores=node_scores_train,
    edge_scores=edge_scores_train,
    gt_node_labels_eval=gt_node_labels_train,
    gt_edge_labels_eval=gt_edge_labels_dict_train,
    k_list=k_values_for_eval,
)
results_val = evaluate_performance_inductive(
    node_scores=node_scores_val,
    edge_scores=edge_scores_val,
    gt_node_labels_eval=gt_node_labels_val,
    gt_edge_labels_eval=gt_edge_labels_dict_val,
    k_list=k_values_for_eval,
)

# --- Tabulate results: each key of the result dicts becomes one row ---
print("\n--- Inductive Evaluation Results ---")

print("\nTraining Graph Node Results:")
df_node_results_train = pd.DataFrame(results_train['nodes']).transpose()
display(df_node_results_train)

print("\nTraining Graph Edge Results:")
# Edge keys may be tuples; stringify them so they form a flat index.
edge_results_train_str_keys = dict((str(key), value) for key, value in results_train['edges'].items())
df_edge_results_train = pd.DataFrame(edge_results_train_str_keys).transpose()
display(df_edge_results_train)

print("\nValidation Graph Node Results:")
df_node_results_val = pd.DataFrame(results_val['nodes']).transpose()
display(df_node_results_val)

print("\nValidation Graph Edge Results:")
edge_results_val_str_keys = dict((str(key), value) for key, value in results_val['edges'].items())
df_edge_results_val = pd.DataFrame(edge_results_val_str_keys).transpose()
display(df_edge_results_val)

print("\n--- Analysis Complete ---")

In [None]:
# --- Score the held-out test graph with the final model ---
print("\n--- Calculating Anomaly Scores on Test Graph ---")
# Scoring reuses the lambda weights stored in best_params for the final model.
node_scores_test, edge_scores_test = calculate_anomaly_scores(
    trained_model=final_model,
    eval_graph_data=test_g_on_device,
    lambda_attr=best_params["lambda_attr"],
    lambda_struct=best_params["lambda_struct"],
    target_edge_type=target_edge_type,
)
print("--- Score Calculation Finished ---")

# --- Compute test metrics at several top-K cut-offs ---
print("\n--- Evaluating Performance on Test Graph ---")
k_values_for_eval = [50, 100, 200, 500]

# NOTE(review): test labels come from variables defined in an earlier cell; confirm
# they correspond to GT_NODE_LABELS["test"] / GT_EDGE_LABELS["test"].
results_test = evaluate_performance_inductive(
    node_scores=node_scores_test,
    edge_scores=edge_scores_test,
    gt_node_labels_eval=gt_node_labels_test,
    gt_edge_labels_eval=gt_edge_labels_dict_test,
    k_list=k_values_for_eval,
)

# --- Tabulate results: each key of the result dicts becomes one row ---
print("\nTest Graph Node Results:")
df_node_results_test = pd.DataFrame(results_test['nodes']).transpose()
display(df_node_results_test)

print("\nTest Graph Edge Results:")
# Edge keys may be tuples; stringify them so they form a flat index.
edge_results_test_str_keys = dict((str(key), value) for key, value in results_test['edges'].items())
df_edge_results_test = pd.DataFrame(edge_results_test_str_keys).transpose()
display(df_edge_results_test)

print("\n--- Test Graph Analysis Complete ---")

In [67]:
from src.utils.train_utils import calculate_anomaly_scores

In [None]:
from src.utils.eval_utils import *
# Evaluate the final model on the train/val/test graphs in one call, with the same
# lambda weighting it was trained with.
# Negative-sample count for scoring: this must be a count, not the list of top-K
# cut-offs ([50, 100, 200, 500]) that was previously copied here. Use the value
# that validation scoring used during training, for consistency.
# NOTE(review): confirm evaluate_model_inductively expects an int for k_neg_samples.
k_neg = val_k_neg_score

ev_params = {
    "k_list": list(k_values_for_eval),   # Top-K cut-offs for the ranking metrics
    "lambda_attr" : best_params["lambda_attr"],
    "lambda_struct" : best_params["lambda_struct"],
    "k_neg_samples": k_neg,
}
all_scores, summary_df, anomaly_type_df = evaluate_model_inductively(
    trained_model=final_model,
    train_graph=train_g,
    val_graph=val_g,
    test_graph=test_g,
    gt_node_labels=GT_NODE_LABELS,
    gt_edge_labels=GT_EDGE_LABELS,
    anomaly_tracking_all=ANOMALY_TRACKING,
    device=device,
    eval_params=ev_params,
    target_edge_type=target_edge_type,
    plot=True,
    verbose=True
)

print("\n--- Overall Metrics Summary ---")
display(summary_df)

print("\n--- Anomaly Type Performance Summary ---")
display(anomaly_type_df)

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, Tuple, Optional, List
import os
from collections import defaultdict

# Global seaborn theme for report figures: white grid, viridis palette, slightly enlarged fonts.
sns.set_theme(font_scale=1.1, palette="viridis", style="whitegrid")

def extract_main_category(tag):
    """Map a raw anomaly tag to a coarse category used for grouping results.

    Rules, applied in order:
      * Edge types, given either as a tuple or as their string repr starting
        with "('", map to 'Edge'.
      * Missing values (NaN/None) and any other non-string value map to 'Unknown'.
      * "<category>/<detail>" tags map to the canonical spelling of <category>
        ('Structural', 'Attribute', 'Combined', 'Unknown') when recognised.
        Otherwise <category> is returned as written, or 'Other' if it is empty.
      * Bare tags equal to a known category (case-insensitive) map to it.
        Otherwise the first keyword found among 'attribute', 'structural',
        'combined' decides. If none is found, the result is 'Other'.

    Surrounding whitespace is ignored.

    Args:
        tag: The value from the 'Anomaly Tag' column (usually a str or NaN).

    Returns:
        str: The main category name.
    """
    canonical = {
        'structural': 'Structural',
        'attribute': 'Attribute',
        'combined': 'Combined',
        'unknown': 'Unknown',
    }

    # Check the type before anything else: pd.isna() on a tuple returns an array,
    # and using that array in an `if` raises a ValueError.
    if isinstance(tag, tuple):
        return 'Edge'
    if not isinstance(tag, str):
        return 'Unknown'  # NaN/None or another non-string value

    text = tag.strip()
    if text.startswith("('"):
        return 'Edge'  # String repr of an edge type tuple

    if '/' in text:
        category = text.split('/', 1)[0].strip()
        return canonical.get(category.lower(), category or 'Other')

    lowered = text.lower()
    if lowered in canonical:
        return canonical[lowered]
    # Keyword fallback; order matters when a tag mentions several categories.
    for keyword in ('attribute', 'structural', 'combined'):
        if keyword in lowered:
            return canonical[keyword]
    return 'Other'


def _add_category_and_proportion(df_split: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df_split`` with 'Main Category' and 'Proportion (%)' columns added.

    'Proportion (%)' is each tag's share (in %) of all anomalies with the same
    'Node Type'; it is 0 when that node type's total is not positive.
    """
    out = df_split.copy()
    out['Main Category'] = out['Anomaly Tag'].apply(extract_main_category)
    # Per-row total of the row's node type. Rows whose 'Node Type' is missing get NaN,
    # which the `where` below turns into a proportion of 0 (same as a zero total).
    totals = out.groupby('Node Type')['Count'].transform('sum')
    out['Proportion (%)'] = (out['Count'] / totals * 100).where(totals > 0, 0)
    return out


def _build_category_summary(df_split: pd.DataFrame, metrics: List[str]) -> Optional[pd.DataFrame]:
    """Table 1: mean of ``metrics`` plus anomaly totals/proportions per (Main Category, Node Type)."""
    print("\n--- Table 1: Performance Summary by Main Anomaly Category & Node Type ---")
    try:
        grouped = df_split.groupby(['Main Category', 'Node Type'])
        summary = grouped[metrics].mean()
        summary['Total Anomalies'] = grouped['Count'].sum()
        summary['Category Proportion (%)'] = grouped['Proportion (%)'].sum()

        cols_order = ['Total Anomalies', 'Category Proportion (%)'] + metrics
        summary = summary.reindex(columns=cols_order, fill_value=np.nan).round(3)
        print(summary.to_string())
        return summary
    except Exception as e:
        print(f"Error creating category summary table: {e}")
        return None


def _build_detailed_table(df_split: pd.DataFrame, metrics: List[str], sort_metric: str) -> Optional[pd.DataFrame]:
    """Table 2: one row per anomaly tag, sorted by Node Type, Main Category, then ``sort_metric`` (desc)."""
    print(f"\n--- Table 2: Detailed Performance by Specific Anomaly Tag (Sorted by {sort_metric}) ---")
    if sort_metric not in df_split.columns:
        print(f"Error: Sort metric '{sort_metric}' not found in DataFrame columns.")
        return None
    try:
        # Include score statistics when the input provides them.
        score_stats = [col for col in ('Mean Score', 'Median Score') if col in df_split.columns]
        cols_detailed = ['Node Type', 'Main Category', 'Anomaly Tag', 'Count', 'Proportion (%)'] + metrics + score_stats
        cols_detailed = [col for col in cols_detailed if col in df_split.columns]

        detailed = df_split.sort_values(
            by=['Node Type', 'Main Category', sort_metric],
            ascending=[True, True, False],
        )[cols_detailed].round(3)

        print(detailed.to_string(index=False, max_rows=50))
        return detailed
    except Exception as e:
        print(f"Error creating detailed sorted table: {e}")
        return None


def _plot_category_comparison(df_category_summary: Optional[pd.DataFrame], metric: str, split_name: str,
                              figsize: Tuple[int, int], save_dir: Optional[str]) -> None:
    """Plot 1: grouped bar chart of the average ``metric`` per main category and node type."""
    print(f"\n--- Plot 1: Comparison of Average {metric} by Main Category ---")
    if df_category_summary is None or metric not in df_category_summary.columns:
        print(f"Cannot generate comparison plot: Summary data missing or plot metric '{metric}' not found.")
        return
    try:
        plot_data = df_category_summary.reset_index()
        fig, ax = plt.subplots(figsize=figsize)
        sns.barplot(
            data=plot_data,
            x='Main Category',
            y=metric,
            hue='Node Type',
            palette='viridis',
            edgecolor='grey',
            linewidth=0.75,
            ax=ax,
        )
        ax.set_title(f'Average {metric} by Main Anomaly Category ({split_name.capitalize()} Set)')
        ax.set_xlabel('Main Anomaly Category')
        ax.set_ylabel(f'Average {metric}')
        ax.tick_params(axis='x', labelrotation=0)
        ax.legend(title='Node Type', bbox_to_anchor=(1.02, 1), loc='upper left')
        ax.grid(True, axis='y', linestyle=':', alpha=0.7)
        for container in ax.containers:
            ax.bar_label(container, fmt='%.2f', label_type='edge', padding=2, fontsize=9)
        ax.set_ylim(bottom=0)
        fig.tight_layout(rect=[0, 0, 0.88, 0.97])  # leave room on the right for the legend

        if save_dir:
            plot_path = os.path.join(save_dir, f'plot_category_comparison_{metric}_{split_name}.png')
            try:
                fig.savefig(plot_path, dpi=300, bbox_inches='tight')
                print(f"Comparison plot saved to {plot_path}")
            except Exception as e:
                print(f"Error saving comparison plot: {e}")
        plt.show()
    except Exception as e:
        print(f"Error generating comparison plot: {e}")


def _plot_score_distribution(df_split: pd.DataFrame, split_name: str,
                             figsize: Tuple[int, int], save_dir: Optional[str]) -> None:
    """Plot 2: violin plot approximating each category's score distribution.

    Per-item scores are not available here, so each tag contributes ``int(Count)``
    copies of its 'Mean Score'; within-tag spread is therefore not represented.
    """
    if 'Mean Score' not in df_split.columns or 'Count' not in df_split.columns:
        print("Cannot generate score distribution plot: 'Mean Score' or 'Count' column missing in DataFrame.")
        return
    try:
        usable = df_split[df_split['Mean Score'].notna() & df_split['Count'].notna() & (df_split['Count'] > 0)]
        # Repeat each row position int(Count) times (vectorised; same rows, same order).
        positions = np.repeat(np.arange(len(usable)), usable['Count'].astype(int).to_numpy())
        df_dist = (
            usable.iloc[positions][['Node Type', 'Main Category', 'Mean Score']]
            .rename(columns={'Mean Score': 'Score'})
            .reset_index(drop=True)
        )
        if df_dist.empty:
            print("Could not generate score distribution plot: No valid data points extracted.")
            return

        fig, ax = plt.subplots(figsize=figsize)
        sns.violinplot(
            data=df_dist,
            x='Main Category',
            y='Score',
            hue='Node Type',
            palette='viridis',
            cut=0,
            inner='quartile',
            linewidth=1.0,
            split=False,
            # Equal violin widths for comparison. NOTE: seaborn 0.13+ renames this
            # argument to `density_norm` and deprecates `scale`.
            scale='width',
            ax=ax,
        )
        ax.set_title(f'Approximated Score Distributions by Main Category ({split_name.capitalize()} Set)')
        ax.set_xlabel('Main Anomaly Category')
        ax.set_ylabel('Anomaly Score (Approximated: based on Mean Score per Tag)')
        ax.tick_params(axis='x', labelrotation=0)
        ax.legend(title='Node Type', bbox_to_anchor=(1.02, 1), loc='upper left')
        ax.grid(True, axis='y', linestyle=':', alpha=0.7)
        fig.tight_layout(rect=[0, 0, 0.88, 0.97])

        if save_dir:
            plot_path = os.path.join(save_dir, f'plot_category_score_distribution_{split_name}.png')
            try:
                fig.savefig(plot_path, dpi=300, bbox_inches='tight')
                print(f"Score distribution plot saved to {plot_path}")
            except Exception as e:
                print(f"Error saving score distribution plot: {e}")
        plt.show()
    except Exception as e:
        print(f"Error generating score distribution plot: {e}")


def _save_summary_tables(df_category_summary: Optional[pd.DataFrame], df_detailed_sorted: Optional[pd.DataFrame],
                         split_name: str, save_dir: Optional[str]) -> None:
    """Write whichever summary tables exist as CSV files into ``save_dir`` (no-op if falsy)."""
    if not save_dir:
        return
    saved_any = False
    if df_category_summary is not None:
        cat_path = os.path.join(save_dir, f'summary_category_{split_name}.csv')
        try:
            df_category_summary.to_csv(cat_path)
            saved_any = True
        except Exception as e:
            print(f"Error saving category summary: {e}")
    if df_detailed_sorted is not None:
        det_path = os.path.join(save_dir, f'summary_detailed_tags_{split_name}.csv')
        try:
            df_detailed_sorted.to_csv(det_path, index=False)
            saved_any = True
        except Exception as e:
            print(f"Error saving detailed summary: {e}")
    if saved_any:
        print(f"Summary tables saved to {save_dir}")


def analyze_anomaly_types(
    anomaly_type_df: pd.DataFrame,
    split_name: str = 'test',
    sort_metric: str = 'AP',
    metrics_to_analyze: List[str] = ['AUROC', 'AP', 'Best F1'], # never mutated; copied below
    plot_metric_comparison: str = 'AP',
    plot_score_distributions: bool = True,
    plot_figsize_comparison: Tuple[int, int] = (10, 6),
    plot_figsize_distribution: Tuple[int, int] = (12, 6),
    save_dir: Optional[str] = None
    ) -> Tuple[Optional[pd.DataFrame], Optional[pd.DataFrame]]:
    """Break down detection performance by anomaly type for one data split.

    Prints, returns, and optionally saves two tables, and shows and optionally
    saves up to two figures:

    * Table 1 - per (Main Category, Node Type): mean of ``metrics_to_analyze``
      (plus ``plot_metric_comparison`` if it is not already listed), total anomaly
      count, and share of that node type's anomalies.
    * Table 2 - one row per anomaly tag, sorted by Node Type, Main Category and
      ``sort_metric`` (descending).
    * Plot 1 - bar chart of the mean ``plot_metric_comparison`` per category.
    * Plot 2 - violin plot of an *approximated* score distribution, built from
      'Mean Score' repeated 'Count' times per tag.

    Args:
        anomaly_type_df: Per-tag metrics vs. normals. Required columns: 'Split',
            'Node Type', 'Anomaly Tag', 'Count', every metric in
            ``metrics_to_analyze``, and ``plot_metric_comparison``. Optional:
            'Mean Score' (needed for Plot 2) and 'Median Score'.
        split_name: Value of 'Split' to analyse ('train', 'val' or 'test').
        sort_metric: Column used, descending, to order tags in Table 2.
        metrics_to_analyze: Metric columns summarised in both tables. Any
            sequence of column names is accepted.
        plot_metric_comparison: Metric shown in Plot 1.
        plot_score_distributions: Whether to draw Plot 2.
        plot_figsize_comparison: Figure size of Plot 1.
        plot_figsize_distribution: Figure size of Plot 2.
        save_dir: Optional directory for the CSV tables and PNG plots. It is
            created if missing.

    Returns:
        ``(df_category_summary, df_detailed_sorted)``. Either element is None if
        its table could not be built. Both are None when the input is empty,
        the split has no rows, or required columns are missing.

    Raises:
        KeyError: If ``anomaly_type_df`` has no 'Split' column.
    """
    print(f"\n--- Analyzing Anomaly Type Performance for Split: '{split_name}' ---")

    if anomaly_type_df is None or anomaly_type_df.empty:
        print("Input DataFrame `anomaly_type_df` is empty. Cannot perform analysis.")
        return None, None

    # Fresh list: accepts tuples, and the shared default list is never modified.
    metrics = list(metrics_to_analyze)

    # --- 1. Preprocessing ---
    df_split = anomaly_type_df[anomaly_type_df['Split'] == split_name]
    if df_split.empty:
        print(f"No data found for split '{split_name}' in `anomaly_type_df`.")
        return None, None

    # The plotted metric must exist and must be aggregated even if not listed in `metrics`.
    required_metrics = metrics + ([plot_metric_comparison] if plot_metric_comparison not in metrics else [])
    required_cols = ['Node Type', 'Anomaly Tag', 'Count'] + required_metrics
    missing_cols = [col for col in required_cols if col not in df_split.columns]
    if missing_cols:
        print(f"Error: Missing required columns in DataFrame: {missing_cols}")
        return None, None

    if save_dir:
        # Otherwise every savefig/to_csv below fails when the directory is new.
        try:
            os.makedirs(save_dir, exist_ok=True)
        except OSError as e:
            print(f"Error creating save directory '{save_dir}': {e}")

    df_split = _add_category_and_proportion(df_split)  # works on a copy

    # --- 2. Tables ---
    df_category_summary = _build_category_summary(df_split, required_metrics)
    df_detailed_sorted = _build_detailed_table(df_split, metrics, sort_metric)

    # --- 3. Plots ---
    _plot_category_comparison(df_category_summary, plot_metric_comparison, split_name,
                              plot_figsize_comparison, save_dir)

    print("\n--- Plot 2: Anomaly Score Distributions by Main Category ---")
    if plot_score_distributions:
        _plot_score_distribution(df_split, split_name, plot_figsize_distribution, save_dir)
    else:
        print("Skipping score distribution plot as per request.")

    # --- 4. Save tables ---
    _save_summary_tables(df_category_summary, df_detailed_sorted, split_name, save_dir)

    return df_category_summary, df_detailed_sorted



# Per-anomaly-type breakdown on the training split. Pass `save_dir=...` to also
# write the tables and figures to disk.
category_summary, detailed_summary = analyze_anomaly_types(
    anomaly_type_df,
    split_name='train',
    sort_metric='AP',
    metrics_to_analyze=['AUROC', 'AP', 'Best F1'],
    plot_metric_comparison='AP',
    plot_score_distributions=True,
)


### 4.2 Baseline Models

#### 4.2.1 Feature-based baseline models

In [None]:
# Add necessary imports (should already be there from previous step)
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import IsolationForest
from sklearn.svm import OneClassSVM
from sklearn.metrics import roc_auc_score, average_precision_score
from src.utils.baseline_utils import *


# Clone the splits so that any in-place changes augment_features_for_sklearn might
# make cannot leak back into train_g / val_g / test_g used by other sections.
train_graph = train_g.clone()
val_graph = val_g.clone()
test_graph = test_g.clone()

# --- 1. Augment Features for Each Split  ---
print("Augmenting features for train, val, and test splits...")
augmented_features_train = augment_features_for_sklearn(train_graph)
augmented_features_val = augment_features_for_sklearn(val_graph)
augmented_features_test = augment_features_for_sklearn(test_graph)
print("Feature augmentation complete.")

# --- Initialize results storage ---
results_baselines_separate = {} # {node_type: {model_name: {split: {'scores', 'labels'}}}}
eval_summary_separate = []      # one metrics row per (model, split, node type)
k_list_baselines = [50, 100, 200]

# Standardise features (StandardScaler fit on train only) before fitting the detectors.
# Off by default, which reproduces the unscaled setup this cell has been using.
SCALE_FEATURES = False


def _iforest_contamination(anomaly_rate):
    """Map an observed anomaly rate to a `contamination` value IsolationForest accepts.

    scikit-learn only accepts 'auto' or a float in (0, 0.5]. A training split without
    anomalies (rate 0), or one with a rate above 0.5, would otherwise raise.
    """
    if not np.isfinite(anomaly_rate) or anomaly_rate <= 0:
        return 'auto'
    return min(float(anomaly_rate), 0.5)


# --- 2. Loop through Node Types for Separate Training and Evaluation ---
for node_type in ['provider', 'member']:
    print(f"\n--- Processing Node Type: {node_type} ---")

    # --- 2a. Extract Features and Labels for this Node Type ---
    print(f"  Extracting {node_type} features and labels...")
    try:
        X_train_nt = augmented_features_train[node_type]
        y_train_nt = gt_node_labels_train[node_type].cpu().numpy()

        X_val_nt = augmented_features_val[node_type]
        y_val_nt = gt_node_labels_val[node_type].cpu().numpy()

        X_test_nt = augmented_features_test[node_type]
        y_test_nt = gt_node_labels_test[node_type].cpu().numpy()

        # Features and labels must describe the same nodes in every split.
        if not (X_train_nt.shape[0] == y_train_nt.shape[0] and
                X_val_nt.shape[0] == y_val_nt.shape[0] and
                X_test_nt.shape[0] == y_test_nt.shape[0]):
            raise ValueError("Feature/label shape mismatch for a split.")

        if X_train_nt.shape[0] == 0:
            print(f"  Skipping {node_type}: No training data found.")
            continue

    except (KeyError, ValueError) as e:
        print(f"  Skipping {node_type}: Error extracting data - {e}")
        continue

    print(f"    Train shape: {X_train_nt.shape}")
    print(f"    Val shape: {X_val_nt.shape}")
    print(f"    Test shape: {X_test_nt.shape}")

    # --- 2b. Optionally Scale Features ---
    if SCALE_FEATURES:
        print(f"  Scaling {node_type} features (StandardScaler fit on train only)...")
        scaler_nt = StandardScaler()
        X_train_scaled_nt = scaler_nt.fit_transform(X_train_nt)
        # Val/test may be empty for this node type; transform() would reject 0 rows.
        X_val_scaled_nt = scaler_nt.transform(X_val_nt) if X_val_nt.shape[0] > 0 else X_val_nt
        X_test_scaled_nt = scaler_nt.transform(X_test_nt) if X_test_nt.shape[0] > 0 else X_test_nt
        print("    Scaling complete.")
    else:
        print(f"  Scaling disabled (SCALE_FEATURES=False); using raw {node_type} features.")
        X_train_scaled_nt = X_train_nt.copy()
        X_val_scaled_nt = X_val_nt.copy()
        X_test_scaled_nt = X_test_nt.copy()

    # --- 2c. Define and Train Baseline Models ---
    print(f"  Training baseline models for {node_type}...")
    anomaly_rate_nt = float(y_train_nt.mean())
    contamination_est_nt = _iforest_contamination(anomaly_rate_nt)
    print(f"    Train anomaly rate ({node_type}): {anomaly_rate_nt:.4f} -> IsolationForest contamination: {contamination_est_nt}")

    models_nt = {}
    iforest_nt = IsolationForest(contamination=contamination_est_nt, random_state=42)
    iforest_nt.fit(X_train_scaled_nt)
    models_nt['IsolationForest'] = iforest_nt
    print("    Isolation Forest trained.")

    # nu must lie in (0, 1]; clamp the observed rate into [0.01, 0.99].
    ocsvm_nt = OneClassSVM(nu=max(0.01, min(0.99, anomaly_rate_nt)), kernel='rbf', gamma='auto')
    ocsvm_nt.fit(X_train_scaled_nt)
    models_nt['OneClassSVM'] = ocsvm_nt
    print("    One-Class SVM trained.")
    print(f"    Baseline training for {node_type} complete.")

    # --- 2d. Predict/Score Anomalies ---
    print(f"  Generating anomaly scores for {node_type}...")
    node_results = results_baselines_separate.setdefault(node_type, {})
    split_inputs = (
        ('train', X_train_scaled_nt, y_train_nt),
        ('val', X_val_scaled_nt, y_val_nt),
        ('test', X_test_scaled_nt, y_test_nt),
    )
    for model_name, model in models_nt.items():
        model_results = node_results.setdefault(model_name, {})
        for split_key, X_split, y_split in split_inputs:
            # sklearn's decision_function is lower for outliers; negate so that
            # higher = more anomalous. Empty splits get an empty score array.
            split_scores = -model.decision_function(X_split) if X_split.shape[0] > 0 else np.array([])
            model_results[split_key] = {'scores': split_scores, 'labels': y_split}
    print(f"    Scoring for {node_type} complete.")

    # --- 2e. Evaluate ---
    print(f"  Evaluating baseline models for {node_type}...")
    for model_name, split_results in node_results.items():
        print(f"    Model: {model_name}")
        for split_name, data in split_results.items():
            scores = data['scores']
            labels = data['labels']

            if len(scores) == 0:
                print(f"      Split: {split_name} - No data to evaluate.")
                continue

            num_anomalies = int(np.sum(labels))
            print(f"      Split: {split_name} - Items: {len(scores)}, Anomalies: {num_anomalies}")
            metrics = compute_evaluation_metrics(scores, labels, k_list=k_list_baselines)

            summary_row = {
                'Model': model_name,
                'Split': split_name,
                'Element': f'Node ({node_type})',
                'Num Items': len(scores),
                'Num Anomalies': num_anomalies,
                '% Anomalies': np.sum(labels) / len(scores) * 100,  # len(scores) > 0 here
            }
            summary_row.update(metrics)
            eval_summary_separate.append(summary_row)

            print(f"        AUROC: {metrics.get('AUROC', 0.0):.4f}, AP: {metrics.get('AP', 0.0):.4f}, Best F1: {metrics.get('Best F1', 0.0):.4f}")

# --- 3. Display Combined Results ---
# Plain dict: entries are only ever assigned explicitly (no default factory needed).
all_model_summaries = {}
baseline_summary_df_separate = pd.DataFrame(eval_summary_separate)
# Identifiers and counts first, headline metrics next, then Precision/Recall@k pairs.
ordered_cols = ['Model', 'Split', 'Element', 'Num Items', 'Num Anomalies', '% Anomalies',
                'AUROC', 'AP', 'Best F1', 'Best F1 Threshold'] + \
               [f'{p}@{k}' for k in k_list_baselines for p in ['Precision', 'Recall']]
# Keep only columns the metrics helper actually produced.
existing_cols_ordered = [col for col in ordered_cols if col in baseline_summary_df_separate.columns]
baseline_summary_df_separate = baseline_summary_df_separate.reindex(columns=existing_cols_ordered, fill_value=np.nan)

print("\n--- Baseline Evaluation Summary (Separate Models) ---")
# Widen pandas display limits so the whole table prints.
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 50)
pd.set_option('display.width', 1000)
print(baseline_summary_df_separate.to_string())
print("Baseline evaluation complete.")

# Store the final summary DataFrame
all_model_summaries['FeatureBasedSeparate'] = baseline_summary_df_separate

In [None]:
# --- 3. Display Combined Results ---
# Rebuild the feature-based summary from the collected per-split rows and
# register it under 'FeatureBasedSeparate'.
all_model_summaries = defaultdict()
baseline_summary_df_separate = pd.DataFrame(eval_summary_separate)

# Preferred order: identifiers and counts, headline metrics, then P/R@k pairs.
ordered_cols = [
    'Model', 'Split', 'Element', 'Num Items', 'Num Anomalies', '% Anomalies',
    'AUROC', 'AP', 'Best F1', 'Best F1 Threshold',
]
ordered_cols += [f'{metric}@{k}' for k in k_list_baselines for metric in ('Precision', 'Recall')]
# Columns the metrics helper did not produce are dropped rather than added empty.
existing_cols_ordered = [name for name in ordered_cols if name in baseline_summary_df_separate.columns]
baseline_summary_df_separate = baseline_summary_df_separate.reindex(
    columns=existing_cols_ordered,
    fill_value=np.nan,
)

print("\n--- Baseline Evaluation Summary (Separate Models) ---")
# Wider pandas display limits so the full table is printed.
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 50)
pd.set_option('display.width', 1000)
print(baseline_summary_df_separate.to_string())
print("Baseline evaluation complete.")

all_model_summaries['FeatureBasedSeparate'] = baseline_summary_df_separate

#### 4.2.2 Structure-based baseline models (GAE and OddBall)

##### OddBall

In [208]:
from src.external_repos.OddBall.egonet_extractor import EgonetFeatureExtractor
from src.external_repos.OddBall.utils import select_group_nodes, get_node_property_list
from src.external_repos.OddBall.anomaly_detection import (
        StarCliqueAnomalyDetection, HeavyVicinityAnomalyDetection, DominantEdgeAnomalyDetection,
        StarCliqueLOFAnomalyDetection, HeavyVicinityLOFAnomalyDetection, DominantEdgeLOFAnomalyDetection
    )
from src.external_repos.OddBall.oddball_runner import *


In [None]:
# --- OddBall Baseline Evaluation ---
print("\n--- Starting OddBall Baseline Evaluation ---")

# OddBall parameters
oddball_anomaly_type = 'sc' # Options: 'sc', 'hv', 'de'
oddball_use_lof = False     # Keep False for simplicity first

oddball_results_list = []
k_list_baselines = [50, 100, 200] # Same K as the feature-based baselines

graphs_dict = {'train': train_g, 'val': val_g, 'test': test_g}

for split_name, graph_split in graphs_dict.items():
    print(f"\n===== Processing Split: {split_name} =====")
    gt_labels_split = GT_NODE_LABELS[split_name] # GT labels for this split, keyed by node type

    for node_type in ['provider', 'member']:
        if node_type not in graph_split.node_types or graph_split[node_type].num_nodes == 0:
            print(f"Skipping Oddball for {node_type} in {split_name} split (no nodes).")
            continue

        metrics = evaluate_oddball_inductive(
            graph_split=graph_split,
            gt_node_labels_split=gt_labels_split,
            anomaly_type=oddball_anomaly_type,
            node_type_to_eval=node_type,
            use_lof=oddball_use_lof,
            k_list=k_list_baselines
        )
        if not metrics:
            print(f"OddBall evaluation failed for {node_type} in {split_name} split.")
            continue

        num_items = graph_split[node_type].num_nodes
        num_anomalies = int(gt_labels_split.get(node_type, torch.tensor([])).sum().item())
        summary_row = {
            'Model': f'OddBall ({oddball_anomaly_type}, LOF={oddball_use_lof})',
            'Split': split_name,
            'Element': f'Node ({node_type})',
            'Num Items': num_items,
            'Num Anomalies': num_anomalies,
            '% Anomalies': (num_anomalies / num_items * 100) if num_items > 0 else 0
        }
        summary_row.update(metrics)
        oddball_results_list.append(summary_row)


# --- Display OddBall Results ---
if oddball_results_list:
    oddball_summary_df = pd.DataFrame(oddball_results_list)
    # Same column order as the other baselines; keep only columns that exist.
    ordered_cols = ['Model', 'Split', 'Element', 'Num Items', 'Num Anomalies', '% Anomalies',
                    'AUROC', 'AP', 'Best F1', 'Best F1 Threshold'] + \
                   [f'{p}@{k}' for k in k_list_baselines for p in ['Precision', 'Recall']]
    existing_cols_ordered = [col for col in ordered_cols if col in oddball_summary_df.columns]
    oddball_summary_df = oddball_summary_df.reindex(columns=existing_cols_ordered, fill_value=np.nan)

    print("\n--- OddBall Evaluation Summary ---")
    pd.set_option('display.max_rows', 500)
    pd.set_option('display.max_columns', 50)
    pd.set_option('display.width', 1000)
    print(oddball_summary_df.to_string())
else:
    print("\n--- No OddBall results were generated. ---")
    # Define an empty frame rather than leaving the name undefined (or holding a
    # stale result from an earlier run) for the bookkeeping below.
    oddball_summary_df = pd.DataFrame()

all_model_summaries['StructureBased'] = oddball_summary_df
# Skip empty frames: concat of empty frames is deprecated for dtype inference and
# raises when given nothing at all.
frames_to_combine = [frame for frame in (baseline_summary_df_separate, oddball_summary_df) if not frame.empty]
combined_summary = pd.concat(frames_to_combine, ignore_index=True) if frames_to_combine else pd.DataFrame()

#### 4.2.3 Hybrid methods (DOMINANT)

In [None]:
import torch
from torch_geometric.data import Data
from torch_geometric.utils import to_undirected
from pygod.detector import DOMINANT, GAE
from src.utils.eval_utils import compute_evaluation_metrics # Ensure this is imported
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm # For progress bars
import time

# --- Configuration ---
k_list_baselines = [50, 100, 200, 500]  # cut-offs for the Precision@k / Recall@k columns

# PyGOD graph baselines are pinned to CPU so results are consistent and do not
# depend on MPS support for every operator they use.
baseline_device = torch.device("cpu")
print(f"Using device for PyGOD graph baselines: {baseline_device}")

target_edge_type = ('provider', 'to', 'member')

# Containers for the graph-baseline outputs: evaluation results and raw scores.
graph_baseline_results = dict()
all_baseline_scores = dict()

# --- Helper Function to Prepare Homogeneous Data (Keep as before) ---
def prepare_homogeneous_data(hetero_graph: HeteroData, target_device: torch.device, make_undirected: bool = True) -> Optional[Data]:
    """Convert a HeteroData graph into a homogeneous ``Data`` object on ``target_device``.

    The conversion runs on CPU via ``to_homogeneous(node_attrs=['x'])``. Afterwards
    the result always has an ``edge_index`` (an empty ``(2, 0)`` long tensor if none
    was produced) and a ``num_nodes`` value. When ``make_undirected`` is True, the
    edges are symmetrised with ``to_undirected``, provided every edge index is
    below ``num_nodes``.

    Args:
        hetero_graph: Graph to convert. Only the ``x`` node attribute is merged
            (presumably every node type carries it; TODO confirm).
        target_device: Device the returned ``Data`` is moved to.
        make_undirected: Whether to add reverse edges.

    Returns:
        The homogeneous graph, or None if any step raised (the error is printed).
    """
    def _infer_num_nodes(data) -> int:
        """Node count from feature rows (counts isolated nodes), else max edge index + 1, else 0."""
        features = getattr(data, 'x', None)
        if features is not None:
            return int(features.shape[0])
        if data.edge_index.numel() > 0:
            return int(data.edge_index.max().item()) + 1
        return 0

    try:
        # NOTE(review): PyG's .cpu() appears to move the stored tensors in place and
        # return the same object, so the caller's graph may also end up on CPU.
        # Confirm, and clone first if the original must stay on its device.
        hetero_graph_cpu = hetero_graph.cpu()
        homo_data = hetero_graph_cpu.to_homogeneous(node_attrs=['x'])

        if getattr(homo_data, 'edge_index', None) is None:
            print("Warning: No 'edge_index' found after to_homogeneous. Creating empty edge_index.")
            homo_data.edge_index = torch.empty((2, 0), dtype=torch.long)

        # Resolve num_nodes once, preferring feature rows. Max edge index + 1 would
        # under-count trailing isolated nodes and disagree with x.
        if getattr(homo_data, 'num_nodes', None) is None:
            homo_data.num_nodes = _infer_num_nodes(homo_data)
            if homo_data.num_nodes == 0:
                print("Warning: Manually set num_nodes to 0.")

        if make_undirected and homo_data.edge_index.shape[1] > 0:
            max_index = int(homo_data.edge_index.max().item())
            # to_undirected needs every index to be a valid node id.
            if max_index < homo_data.num_nodes:
                homo_data.edge_index = to_undirected(homo_data.edge_index, num_nodes=homo_data.num_nodes)
            else:
                print(f"Warning: Cannot make graph undirected. Max edge index {max_index} >= num_nodes {homo_data.num_nodes}. Keeping directed.")

        return homo_data.to(target_device)
    except Exception as e:
        print(f"Error during to_homogeneous conversion: {e}")
        return None

# --- Prepare Homogeneous Data Splits ---
print("Preparing homogeneous data splits for graph baselines...")
train_g_homo = prepare_homogeneous_data(train_g, target_device=baseline_device, make_undirected=True)
val_g_homo = prepare_homogeneous_data(val_g, target_device=baseline_device, make_undirected=True)
test_g_homo = prepare_homogeneous_data(test_g, target_device=baseline_device, make_undirected=True)

if any(split_homo is None for split_homo in (train_g_homo, val_g_homo, test_g_homo)):
    # A single failed conversion invalidates all three, so a partial set is never used.
    print("Failed to prepare homogeneous data. Skipping graph baselines.")
    train_g_homo = val_g_homo = test_g_homo = None
else:
    print("Homogeneous data prepared.")

def map_homo_results(homo_data: 'Data', homo_scores: np.ndarray, original_hetero_graph: 'HeteroData', gt_node_labels_split: 'Dict[str, torch.Tensor]') -> 'Dict[str, Dict]':
    """Split homogeneous-graph anomaly scores back into per-node-type score/label pairs.

    The mapping assumes that ``homo_data.node_type == i`` marks the nodes of
    ``original_hetero_graph.node_types[i]``, and that nodes of one type appear in
    the same relative order as the entries of that type's ground-truth vector.

    Args:
        homo_data: Homogeneous graph; must carry a ``node_type`` tensor with one
            entry per node.
        homo_scores: One anomaly score per homogeneous node. A NumPy array or a
            torch tensor; tensors are detached and moved to CPU.
        original_hetero_graph: Graph the homogeneous one was built from. Only its
            ``node_types`` ordering is used.
        gt_node_labels_split: Ground-truth labels for this split, keyed by node
            type. Missing types are treated as all-normal (zeros). A label vector
            whose length differs from the node count is truncated, with a warning.

    Returns:
        ``{node_type: {'scores': np.ndarray, 'labels': np.ndarray}}`` for every
        node type that has at least one node and matching score/label lengths.
        Returns an empty dict if ``node_type`` is missing or if the number of
        scores does not match the number of homogeneous nodes.
    """
    # Annotations are strings so this def does not require `Data`, `HeteroData` or
    # `Dict` to already be in scope when the cell runs.
    if not hasattr(homo_data, 'node_type'):
        print("Error: Homogeneous data is missing 'node_type'. Cannot map results.")
        return {}
    node_type_np = homo_data.node_type.cpu().numpy()

    # Callers may hand over a tensor (the evaluation loop also handles tensor scores);
    # normalise to NumPy so boolean-mask indexing below is unambiguous.
    if isinstance(homo_scores, torch.Tensor):
        homo_scores = homo_scores.detach().cpu().numpy()
    else:
        homo_scores = np.asarray(homo_scores)

    if len(homo_scores) != len(node_type_np):
        print(f"Error: Got {len(homo_scores)} scores for {len(node_type_np)} homogeneous nodes. Cannot map results.")
        return {}

    results = {}
    for type_idx, node_type in enumerate(original_hetero_graph.node_types):  # original type order
        mask = node_type_np == type_idx
        num_nodes_of_type = int(mask.sum())

        # Align the ground-truth labels with the nodes of this type.
        if node_type in gt_node_labels_split:
            labels_full = gt_node_labels_split[node_type]
            if len(labels_full) != num_nodes_of_type:
                print(f"Warning: Label mapping mismatch for {node_type}. Expected {num_nodes_of_type} labels in split GT, found {len(labels_full)}.")
            labels_cut = labels_full[:num_nodes_of_type]
            type_labels = labels_cut.cpu().numpy() if isinstance(labels_cut, torch.Tensor) else np.asarray(labels_cut)
        else:
            type_labels = np.zeros(num_nodes_of_type)

        if num_nodes_of_type == 0:
            continue

        type_scores = homo_scores[mask]
        # Only possible when the GT vector was shorter than the node count.
        if len(type_scores) != len(type_labels):
            print(f"CRITICAL WARNING: Final score/label length mismatch for {node_type} after mapping. Scores: {len(type_scores)}, Labels: {len(type_labels)}. Skipping evaluation for this type.")
            continue

        results[node_type] = {'scores': type_scores, 'labels': type_labels}
    return results

# --- Define Baseline Models ---
# PyGOD models keyed by display name; left empty when no homogeneous training graph exists.
pygod_baselines = {}
if train_g_homo is not None:
    # Shared contamination estimate: fraction of anomalous training nodes, clipped to
    # [0.001, 0.5] (0.1 if the graph has no nodes).
    total_train_nodes = train_g_homo.num_nodes
    total_train_anomalies = sum(
        GT_NODE_LABELS['train'][nt].sum().item()
        for nt in GT_NODE_LABELS['train']
        if nt in train_g.node_types
    )
    contamination_est = total_train_anomalies / total_train_nodes if total_train_nodes > 0 else 0.1
    contamination_est = max(0.001, min(0.5, contamination_est))

    # Both DOMINANT and GAE consume node features, so both hinge on `x` being present.
    has_node_features = getattr(train_g_homo, 'x', None) is not None

    if has_node_features:
        # 1. DOMINANT
        print(f"Defining DOMINANT with contamination={contamination_est:.4f}")
        pygod_baselines['DOMINANT'] = DOMINANT(
            epoch=50,
            contamination=contamination_est,
            gpu=-1,  # force CPU
            verbose=0,
        )
        # 2. GAE with structural reconstruction (needs the torch-sparse / pyg-lib dependency)
        print(f"Defining GAE_Structure with contamination={contamination_est:.4f}")
        pygod_baselines['GAE_Structure'] = GAE(
            epoch=50,
            contamination=contamination_est,
            recon_s=True,
            gpu=-1,  # force CPU
            verbose=0,
        )
    else:
        print("Skipping DOMINANT baseline definition: Missing 'x' features in homogeneous graph.")
        print("Skipping GAE_Structure baseline definition: Missing 'x' features.")


# --- Run and Evaluate PyGOD Baselines ---
# Imported explicitly: the notebook's only visible `import time` is in a later cell. On a
# fresh top-to-bottom run, a missing name would otherwise be swallowed by the broad
# `except Exception` below and show up only as a silently skipped baseline.
import time


def _as_numpy(values) -> np.ndarray:
    """Return `values` as a NumPy array.

    Torch tensors are detached and moved to CPU, and None becomes an empty array.
    """
    if values is None:
        return np.array([])
    if isinstance(values, torch.Tensor):
        return values.detach().cpu().numpy()
    return np.asarray(values)


# For each model: fit on the training graph, score every split, map the scores back to
# node types, and record raw scores in `all_baseline_scores` and metrics in
# `graph_baseline_results`.
if train_g_homo is not None and pygod_baselines:
    for model_name, model in tqdm(pygod_baselines.items(), desc="Running PyGOD Graph Baselines"):
        print(f"\n--- Processing Baseline: {model_name} ---")
        model_scores = {'train': {}, 'val': {}, 'test': {}}

        try:
            # --- Fit on Training Data ---
            print(f"Fitting {model_name} on training data...")
            fit_data = train_g_homo  # prepared on `baseline_device`

            if fit_data is None or fit_data.num_nodes == 0:
                print(f"Skipping fit for {model_name}: Invalid or empty training data.")
                continue

            # Both DOMINANT and GAE consume node features `x`.
            if getattr(fit_data, 'x', None) is None:
                print(f"Skipping fit for {model_name}: Missing 'x' features required by model.")
                continue
            if getattr(fit_data, 'edge_index', None) is None or fit_data.edge_index.numel() == 0:
                # The models might behave unexpectedly without edges; fit anyway but flag it.
                print(f"Warning: Fitting {model_name} on graph with no edges.")

            start_fit = time.time()
            model.fit(fit_data)
            end_fit = time.time()
            print(f"Fit completed in {end_fit - start_fit:.2f}s.")

            # --- Get Scores (Decision Function) ---
            print(f"Calculating scores for {model_name}...")
            for split_name, hetero_graph, homo_graph in [('train', train_g, train_g_homo),
                                                        ('val', val_g, val_g_homo),
                                                        ('test', test_g, test_g_homo)]:
                if homo_graph is None:
                    print(f"  Skipping scoring for {split_name}: Homogeneous graph is None.")
                    continue

                try:
                    eval_data = homo_graph

                    if eval_data.num_nodes == 0:
                        print(f"  Skipping scoring for {split_name}: No nodes in homogeneous graph.")
                        homo_scores_split = np.array([])
                    elif getattr(eval_data, 'x', None) is None:
                        # Really skip: dummy all-zero scores must not be evaluated and
                        # reported as if the model had produced them.
                        print(f"  Skipping scoring for {split_name} with {model_name}: Missing 'x' features.")
                        model_scores[split_name] = {}
                        continue
                    else:
                        if getattr(eval_data, 'edge_index', None) is None or eval_data.edge_index.numel() == 0:
                            # Scores can still be produced from features alone.
                            print(f"  Warning: Scoring {split_name} with {model_name} on graph with no edges.")
                        homo_scores_split = model.decision_function(eval_data)

                    gt_labels_split = GT_NODE_LABELS.get(split_name, {})
                    model_scores[split_name] = map_homo_results(homo_graph, homo_scores_split, hetero_graph, gt_labels_split)

                except Exception as e_score:
                    print(f"Error scoring {split_name} data with {model_name}: {e_score}")
                    model_scores[split_name] = {}

            all_baseline_scores[model_name] = model_scores

            # --- Evaluate Mapped Scores ---
            print(f"Evaluating {model_name}...")
            for split_name, split_scores_mapped in model_scores.items():
                print(f"  Split: {split_name}")
                if not split_scores_mapped:
                    print(f"    Skipping evaluation for {split_name}: No mapped scores found.")
                    continue
                for node_type, type_data in split_scores_mapped.items():
                    # Metrics expect NumPy; scores may still be tensors depending on the model.
                    scores_np = _as_numpy(type_data.get('scores'))
                    labels_np = _as_numpy(type_data.get('labels'))

                    if len(scores_np) != len(labels_np):
                        print(f"    {node_type}: Skipping evaluation due to score/label length mismatch ({len(scores_np)} vs {len(labels_np)}).")
                    elif len(scores_np) == 0:
                        print(f"    {node_type}: No scores/labels found for evaluation.")
                    else:
                        metrics = compute_evaluation_metrics(scores_np, labels_np, k_list=k_list_baselines)
                        graph_baseline_results.setdefault(model_name, {}).setdefault(split_name, {})[node_type] = metrics
                        print(f"    {node_type}: AUROC={metrics.get('AUROC', 0):.4f}, AP={metrics.get('AP', 0):.4f}, Best F1={metrics.get('Best F1', 0):.4f}")

        except ImportError as ie:
            # GAE with structural reconstruction needs an extra sparse backend.
            if model_name == 'GAE_Structure' and ('torch_sparse' in str(ie) or 'NeighborSampler' in str(ie) or 'pyg-lib' in str(ie)):
                print(f"\n****** MISSING DEPENDENCY for {model_name} ******")
                print(f"Error: {ie}")
                print(f"Please install 'torch-sparse' or 'pyg-lib'.")
                print(f"  pip install torch-sparse  OR  pip install pyg-lib")
                print(f"*******************************\n")
            else:
                print(f"ImportError running {model_name}: {ie}")
        except AttributeError as ae:
            print(f"AttributeError running {model_name}: {ae}. ")
            print("  Ensure homogeneous graph has 'x' and 'edge_index'.")
        except Exception as e:
            print(f"Error running baseline {model_name}: {e}")


# --- Display Graph-Based Baseline Results ---
print("\n--- Graph-Based Baseline Evaluation Summary ---")

# One row per (model, split, node type): ground-truth counts followed by the metrics.
summary_graph_list = []
for model_name, splits in graph_baseline_results.items():
    for split_name, types in splits.items():
        for node_type, metrics in types.items():
            split_gt = GT_NODE_LABELS.get(split_name, {})
            num_items = len(split_gt.get(node_type, []))
            num_anomalies = int(split_gt.get(node_type, torch.tensor([])).sum().item())
            row = {
                'Model': model_name,
                'Split': split_name,
                'Element': f'Node ({node_type})',
                'Num Items': num_items,
                'Num Anomalies': num_anomalies,
                '% Anomalies': (num_anomalies / num_items * 100) if num_items > 0 else 0,
            }
            row.update(metrics)
            summary_graph_list.append(row)

graph_summary_df = pd.DataFrame(summary_graph_list)
if not graph_summary_df.empty:
    # Fixed column order; columns absent from the results are simply left out.
    at_k_cols = [f'{p}@{k}' for k in k_list_baselines for p in ['Precision', 'Recall']]
    ordered_cols_struct = ['Model', 'Split', 'Element', 'Num Items', 'Num Anomalies', '% Anomalies',
                           'AUROC', 'AP', 'Best F1', 'Best F1 Threshold'] + at_k_cols
    existing_cols_struct = [c for c in ordered_cols_struct if c in graph_summary_df.columns]
    graph_summary_df = graph_summary_df.reindex(columns=existing_cols_struct, fill_value=np.nan)

print(graph_summary_df.to_string())
print("Graph-based baseline evaluation complete.")

# Keep the summary for the cross-model comparison.
all_model_summaries['GraphBased'] = graph_summary_df

In [214]:
# Register the graph-based baseline summary and prepend it to the combined comparison table.
# NOTE(review): the key spelling is kept unchanged in case other cells look it up.
all_model_summaries['GraphAutoEncoderFramemwork'] = graph_summary_df
# The concat below is self-referential. Drop rows for these models left by an earlier
# execution of this cell first, so re-running it does not duplicate them.
if not graph_summary_df.empty and "Model" in combined_summary.columns:
    combined_summary = combined_summary[~combined_summary["Model"].isin(graph_summary_df["Model"].unique())]
combined_summary = pd.concat([graph_summary_df, combined_summary], ignore_index=True)

## 5. Compare Results

In [241]:
# Raw model key -> citation-style display name.
model_name_map = {
    "Oddball": "Oddball (Akoglu, 2010)",                       # Ref: OddBall: Spotting Anomalies in Weighted Graphs
    "GAE_Structure": "GAE (Kipf & Welling, 2016)",              # Ref: Variational Graph Auto-Encoders
    "DOMINANT": "DOMINANT (Ding et al., 2019)",                 # Ref: Deep Anomaly Detection on Attributed Networks
    "OneClassSVM": "One-Class SVM (Schölkopf et al., 2001)",    # Ref: Estimating the support of a high-dimensional distribution
    "IsolationForest": "Isolation Forest (Liu et al., 2008)",   # Ref: Isolation forest
}

# Raw model key -> baseline family. Keyed exactly like `model_name_map`, so every
# category lookup uses the display name the rename actually produces.
model_family_map = {
    "IsolationForest": "feature-based baseline",
    "OneClassSVM": "feature-based baseline",
    "Oddball": "structure-based baseline",
    "DOMINANT": "graph_autoencoder-based baseline",
    "GAE_Structure": "graph_autoencoder-based baseline",
}

# Rename raw keys. Values that are not keys (e.g. rows renamed by a previous run) are
# left untouched, so this cell is safe to re-run.
combined_summary["Model"] = combined_summary["Model"].replace(model_name_map)

category_by_display_name = {model_name_map[key]: family for key, family in model_family_map.items()}
combined_summary["Category"] = (
    combined_summary["Model"].map(category_by_display_name).fillna("Other Category")
)



In [252]:
# Restrict the comparison to the baselines reported alongside the proposed model.
# Filter `combined_summary` itself rather than a backup variable that must already exist
# in the kernel (`combined_summary_old`), so the notebook runs top to bottom. The filter
# is idempotent, so re-running this cell is safe.
models_to_keep = ["Isolation Forest (Liu et al., 2008)", "Oddball (Akoglu, 2010)", "DOMINANT (Ding et al., 2019)"]
combined_summary = combined_summary[combined_summary["Model"].isin(models_to_keep)].copy()

In [None]:
# Main model summary: tag the proposed model's rows and stack them above the baselines.
summary_main_model = summary_df.copy()
summary_main_model["Model"] = "BGAE (Proposed model)"
summary_main_model["Category"] = "bipartite_graph_autoencoder-based main model"
# ignore_index=True: each input frame keeps its own index, and stacking them otherwise
# produces duplicate row labels that break label-aligned operations later on.
summary_all = pd.concat([summary_main_model, combined_summary], axis=0, ignore_index=True)

print("Average AP for all models (nodes only) :")
# Test split, node rows only (element names starting with "E" are edge rows).
is_test_node_row = (summary_all.Split == "test") & (~summary_all.Element.str.startswith("E"))
display(
    summary_all.loc[is_test_node_row]
    .groupby(["Model", "Category"])          # one group per model (and its category)
    .agg({"AP": "mean"})                     # mean AP across node types
    .rename(columns={"AP": "Avg AP"})
    .sort_values(by=["Model", "Avg AP"])     # alphabetical by model; Avg AP only breaks ties
    .reset_index()
    .head(10)                                # first 10 rows of that listing (not the 10 best)
)
print("Metrics for all models evaluated on the test set :")
columns_metrics = ["Model", "Element", "AUROC", "AP", "Best F1",
                   'Precision@50', 'Recall@50', 'Precision@100', 'Recall@100', 'Precision@200', 'Recall@200']
display(summary_all.loc[summary_all.Split == "test"][columns_metrics].sort_values(by="Model"))





In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import io # For creating the sample dataframes
from IPython.display import display # For displaying styled table in notebooks


# Work on a copy so the display-name remapping below never touches `summary_all`.
combined_summary = summary_all.copy()

# Citation-style model names -> short labels used in tables and figures.
model_category_map = {
    "BGAE (Proposed model)": "Our Model (BGAE)",
    "DOMINANT (Ding et al., 2019)": "DOMINANT",
    "Isolation Forest (Liu et al., 2008)": "Isolation Forest",
    "Oddball (Akoglu, 2010)": "Oddball",
}
# Raw category identifiers -> readable legend labels.
category_display_map = {
    "bipartite_graph_autoencoder-based main model": "Our Approach",
    "graph_autoencoder-based baseline": "Graph Autoencoder Baseline",
    "feature-based baseline": "Feature-based Baseline",
    "structure-based baseline": "Structure-based Baseline",
}

# Shorten model names; names without a mapping are kept as they are.
original_model_names = combined_summary["Model"]
combined_summary["Model"] = original_model_names.map(model_category_map).fillna(original_model_names)

# Category is re-derived from the short names (one category per model).
simplified_category_map = {
    "Our Model (BGAE)": "Our Approach",
    "DOMINANT": "Graph Autoencoder Baseline",
    "Isolation Forest": "Feature-based Baseline",
    "Oddball": "Structure-based Baseline",
}
combined_summary["Category"] = combined_summary["Model"].map(simplified_category_map).fillna("Unknown Category")

# --- Plot configuration ---
sns.set_theme(style="whitegrid")
plt.rcParams.update({'figure.dpi': 300, 'savefig.dpi': 300})  # high resolution for publication figures

# One fixed color per model (tab10) and per category (viridis), reused by every figure.
unique_models = combined_summary['Model'].unique()
model_colors = sns.color_palette('tab10', n_colors=len(unique_models))
model_color_map = {model: color for model, color in zip(unique_models, model_colors)}

unique_categories = combined_summary['Category'].unique()
category_colors = sns.color_palette('viridis', n_colors=len(unique_categories))
category_color_map = {category: color for category, color in zip(unique_categories, category_colors)}

# The figures only cover node anomalies, so edge rows are dropped here once.
node_elements = ["Node (member)", "Node (provider)"]
combined_summary_nodes_only = combined_summary.loc[combined_summary["Element"].isin(node_elements)].copy()


# --- Main Results Table (test split) ---
# Key metrics per model and element type, formatted for the notebook and exported to LaTeX.

main_table_cols = ["Model", "Category", "Element", "AUROC", "AP", "Best F1",
                   "Precision@50", "Recall@50", "Precision@100", "Recall@100",
                   "Precision@200", "Recall@200"]
main_table_data = (
    combined_summary.loc[combined_summary.Split == "test", main_table_cols]
    .sort_values(by=["Model", "Element"])  # consistent row order
    .copy()
)

# Decimal places per metric column.
format_dict = {
    "AUROC": "{:.3f}",
    "AP": "{:.3f}",
    "Best F1": "{:.3f}",
    "Precision@50": "{:.2f}",
    "Recall@50": "{:.3f}",
    "Precision@100": "{:.2f}",
    "Recall@100": "{:.3f}",
    "Precision@200": "{:.2f}",
    "Recall@200": "{:.3f}",
}

# Styler renders the row index unless it is hidden. Left visible, it would add an unlabeled
# first column to the LaTeX output, which `column_format` (one spec per data column) does
# not account for.
styled_table = main_table_data.style.format(format_dict).hide(axis="index")

# --- LaTeX export ---
# The default environment ("table") wraps the tabular so caption/label/position are valid.
# Passing environment="tabular" would nest a tabular inside a tabular.
latex_string = styled_table.to_latex(
    caption="Performance metrics of anomaly detection models on the test set, by element type.",
    label="tab:main_results",
    column_format="lll" + "c" * (main_table_data.shape[1] - 3),  # text columns left, metrics centred
    position="!htbp",
    hrules=True,            # booktabs rules (\toprule, \midrule, \bottomrule)
    multicol_align="c",
    convert_css=False,      # formatting comes from .format(), not CSS
)

print("--- LaTeX Table Code ---")
print(latex_string)

print("--- Main Results Table (Test Set) ---")
display(styled_table)

# --- 3. Create Figure 1: Overall Average Node AP Comparison ---
# Mean test-set AP over node types for each model, best model first.

node_avg_ap = (
    combined_summary_nodes_only[combined_summary_nodes_only.Split == "test"]
    .groupby(["Model", "Category"])["AP"]
    .mean()
    .reset_index()
    .sort_values(by="AP", ascending=False)
)
# Ranking used for the x-axis here; Figure 2 reuses it.
model_order_fig1 = node_avg_ap["Model"].tolist()

fig, ax = plt.subplots(figsize=(7, 5))
sns.barplot(
    x="Model",
    y="AP",
    hue="Category",
    data=node_avg_ap,
    palette=category_color_map,
    order=model_order_fig1,
    dodge=False,  # each model has exactly one category, so keep bars centred on their ticks
    ax=ax,
)

ax.set_title("Average Precision (AP) for Node Anomaly Detection on Test Set", fontsize=14)
ax.set_xlabel("Model", fontsize=12)
ax.set_ylabel("Average Precision (AP)", fontsize=12)
plt.setp(ax.get_xticklabels(), rotation=45, ha="right")

# Rebuild the legend outside the axes, with friendlier category names where a mapping exists.
handles, labels = ax.get_legend_handles_labels()
friendly_labels = [category_display_map.get(label, label) for label in labels]
ax.legend(handles, friendly_labels, title="Model Type", bbox_to_anchor=(1.05, 1), loc='upper left')

# Print the AP value on each bar.
for container in ax.containers:
    ax.bar_label(container, fmt='%.3f', label_type='center', padding=3)

fig.tight_layout()
# bbox_inches="tight" keeps the legend placed outside the axes inside the saved PNG.
fig.savefig("avg_node_ap_barplot.png", bbox_inches="tight")
plt.show()


# --- 4. Create Figure 2: Key Metric Comparison by Anomaly Type (AP and Best F1) ---
# Node anomalies only; AUROC is not shown in this figure.

metrics_to_plot_fig2 = ["AP", "Best F1"]
detailed_metrics = combined_summary_nodes_only[combined_summary_nodes_only.Split == "test"][
    ["Model", "Element"] + metrics_to_plot_fig2
]

# Long format: one row per (Model, Element, Metric).
detailed_metrics_melted = detailed_metrics.melt(
    id_vars=["Model", "Element"],
    value_vars=metrics_to_plot_fig2,
    var_name="Metric",
    value_name="Score"
).sort_values(by=["Element", "Metric", "Score"], ascending=[True, True, False])

# Grid: one row per metric, one column per node type. The x-axis order reuses Figure 1's
# ranking so models appear in the same position in both figures.
g = sns.catplot(
    x="Model",
    y="Score",
    hue="Model",
    col="Element",
    row="Metric",
    data=detailed_metrics_melted,
    kind="bar",
    palette=model_color_map,
    sharey=False,
    height=3.5,
    aspect=1,
    order=model_order_fig1,
    legend_out=True
)

g.fig.suptitle("Anomaly Detection Metrics by Node Type on Test Set", y=1.02, fontsize=16)
g.set_axis_labels("Model", "Score")
g.set_xticklabels(rotation=45, ha="right")
g.set_titles(row_template="{row_name}", col_template="{col_name} Anomalies")  # e.g. "Node (member) Anomalies"

g.add_legend(title="Model", bbox_to_anchor=(1.02, 0.5), loc='center left')

plt.tight_layout(rect=[0, 0, 0.85, 1.02])  # leave room on the right for the legend
# FacetGrid.savefig defaults to bbox_inches="tight", so the suptitle placed above the figure
# (y=1.02) and the outside legend are not cropped from the PNG.
g.savefig("detailed_metrics_by_node_type.png")
plt.show()


# --- 5. Create Figure 3: Precision and Recall at Top-K Anomalies ---
# Node anomalies only (edges excluded).

pr_k_cols = ["Precision@50", "Recall@50", "Precision@100", "Recall@100", "Precision@200", "Recall@200"]
pr_k_data = combined_summary_nodes_only[combined_summary_nodes_only.Split == "test"][["Model", "Element"] + pr_k_cols]

# Long format: one row per (Model, Element, metric@k).
pr_k_melted = pr_k_data.melt(
    id_vars=["Model", "Element"],
    value_vars=pr_k_cols,
    var_name="Metric_k",
    value_name="Score"
)

# "Precision@50" -> Metric_Type="Precision", k=50
pr_k_melted[['Metric_Type', 'k']] = pr_k_melted['Metric_k'].str.split('@', expand=True)
pr_k_melted['k'] = pr_k_melted['k'].astype(int)
pr_k_melted = pr_k_melted.sort_values(by=["Element", "Metric_Type", "k", "Score"], ascending=[True, True, True, False])

# Colors come from the shared `model_color_map` built in the setup section, so each model
# keeps the same color as in Figure 2 and any model present in the data has an entry.
g = sns.relplot(
    x="k",
    y="Score",
    hue="Model",            # color = model
    style="Metric_Type",    # line style = Precision vs Recall
    col="Element",          # one panel per node type
    data=pr_k_melted,
    kind="line",
    palette=model_color_map,
    facet_kws={'sharey': False, 'sharex': True},
    height=4,
    aspect=1.2
)

g.fig.suptitle("Precision and Recall at Top-K Node Anomalies on Test Set", y=1.02, fontsize=16)
g.set_axis_labels("Number of Top Anomalies (k)", "Score")
g.set_titles("{col_name} Anomalies")
g.set(xticks=[50, 100, 200], xlim=(40, 210))  # ticks at the evaluated k values

# Move relplot's automatic hue/style legend just outside the right edge, vertically
# centred. Assigning `legend.loc` on a matplotlib Legend has no effect, so use seaborn's
# helper, which re-creates the legend at the requested location.
sns.move_legend(g, "center left", bbox_to_anchor=(1.02, 0.5), title="Legend")

plt.tight_layout(rect=[0, 0, 0.85, 1.02])  # reserve ~15% on the right for the legend
# FacetGrid.savefig defaults to bbox_inches="tight", so the suptitle and the outside legend
# are kept in the PNG.
g.savefig("pr_at_k_node_lines.png")
plt.show()

## 6. Robust Evaluation: Several Runs

In [335]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from collections import defaultdict
import time
import random
import os
from typing import Dict, Tuple, Optional, List

In [None]:
# Seeds for repeated anomaly injection: each seed yields one independent data iteration.
random_seeds = list(range(5))

# seed -> injected graphs, ground-truth labels and anomaly tracking for train/val/test
DATA_SPLITS_ITERATIONS = {}

# --- Load Base Data (Once) ---
print("Loading base data...")
df_member_features, members_dataset = load_member_features("../data/final_members_df.pickle")
df_provider_features, providers_dataset = load_provider_features("../data/final_df.pickle")
_, train_edges, val_edges, test_edges = load_claims_data_with_splitting("../data/df_descriptions.pickle", members_dataset=members_dataset, providers_dataset=providers_dataset)

# Base HeteroData splits, cloned for every seed below.
base_train_g, base_val_g, base_test_g = prepare_hetero_data_with_splitting(df_member_features, df_provider_features, train_edges, val_edges, test_edges)
print("Base data loaded and split.")

# --- Anomaly Injection Params ---
p_node_anomalies = 0.05
p_edge_anomalies = 0.1
lambda_structural_injection = 0.5

base_splits = {"train": base_train_g, "val": base_val_g, "test": base_test_g}

# --- Loop through seeds to generate and store data variations ---
print(f"Generating {len(random_seeds)} data iterations...")
for random_seed in random_seeds:
    print(f"  Processing seed: {random_seed}")
    injected_graphs, gt_node_labels_iter, gt_edge_labels_iter, tracking_iter = {}, {}, {}, {}
    for split_name, base_graph in base_splits.items():
        # Clone so injection (possibly in place) never alters the shared base graphs. Every
        # split uses this loop's `random_seed`, so each iteration is reproducible and does
        # not depend on a seed variable left over from another cell.
        graph_inj, gt_nodes, gt_edges, tracking = inject_scenario_anomalies(
            base_graph.clone(),
            p_node_anomalies=p_node_anomalies,
            p_edge_anomalies=p_edge_anomalies,
            lambda_structural=lambda_structural_injection,
            seed=random_seed,
        )
        injected_graphs[split_name] = graph_inj
        gt_node_labels_iter[split_name] = gt_nodes
        gt_edge_labels_iter[split_name] = gt_edges
        tracking_iter[split_name] = tracking

    # Store the data for this iteration, keyed by the seed
    DATA_SPLITS_ITERATIONS[random_seed] = {
        'train_graph': injected_graphs["train"],
        'val_graph': injected_graphs["val"],
        'test_graph': injected_graphs["test"],
        'gt_node_labels': gt_node_labels_iter,
        'gt_edge_labels': gt_edge_labels_iter,
        'anomaly_tracking': tracking_iter,
    }

print(f"Stored data and labels for {len(DATA_SPLITS_ITERATIONS)} iterations.")

In [337]:
def run_bgae_iteration(
    iteration_data: Dict,
    best_params: Dict,
    device: torch.device,
    seed: int
    ) -> Tuple[pd.DataFrame, pd.DataFrame, Dict]:
    """
    Trains and evaluates the BGAE model for a single iteration (seed).

    Args:
        iteration_data (Dict): Dictionary containing 'train_graph', 'val_graph', 'test_graph',
                               'gt_node_labels', 'gt_edge_labels', 'anomaly_tracking'
                               (label/tracking entries are dicts keyed by split name).
        best_params (Dict): Selected BGAE hyperparameters. Read keys: 'lambda_struct',
                            'hidden_dim', 'latent_dim', 'num_conv_layers', 'num_dec_layers',
                            'dropout', 'learning_rate', and optionally 'weight_decay' (default 1e-5).
        device: The torch device to use.
        seed: The random seed used for this iteration (for reproducibility).

    Returns:
        Tuple[pd.DataFrame, pd.DataFrame, Dict]:
            - summary_df: Overall performance metrics for this iteration.
            - anomaly_type_df: Per-anomaly-type metrics for this iteration.
            - all_scores: Raw anomaly scores for nodes and edges for this iteration.
    """
    # Imported locally so this function does not depend on a star import providing `random`.
    import random

    # 1. Set seeds for reproducibility of this specific run
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if device.type == 'cuda':
        torch.cuda.manual_seed_all(seed)
    elif device.type == 'mps':
        torch.mps.manual_seed(seed)

    # 2. Prepare data for this iteration
    train_g = iteration_data['train_graph']
    val_g = iteration_data['val_graph']
    test_g = iteration_data['test_graph']
    gt_node_labels = iteration_data['gt_node_labels']
    gt_edge_labels = iteration_data['gt_edge_labels']
    anomaly_tracking = iteration_data['anomaly_tracking']

    train_g_dev = train_g.to(device)
    val_g_dev = val_g.to(device)
    test_g_dev = test_g.to(device)
    # Only the validation labels are moved to the device: training uses them for early stopping.
    gt_node_labels_val_dev = {k: v.to(device) for k, v in gt_node_labels["val"].items()}
    gt_edge_labels_val_dev = {k: v.to(device) for k, v in gt_edge_labels["val"].items()}

    # 3. Get model params & instantiate model/optimizer
    iter_params = best_params.copy()
    final_lambda_struct = iter_params['lambda_struct']
    final_lambda_attr = 1.0 - final_lambda_struct  # attribute and structure loss weights sum to 1
    target_edge_type = ('provider', 'to', 'member')

    in_dim_member = train_g['member'].x.size(1)
    in_dim_provider = train_g['provider'].x.size(1)
    # Look up the edge store by edge type explicitly. The previous `train_g.get(edge_type, {})`
    # is not an edge-type lookup on HeteroData (it resolves against the graph's global store),
    # so it returned the `{}` default and edge_dim was always 0 even when edge features exist.
    # NOTE(review): this enables edge features in the model; re-check best_params if they
    # were tuned with the old edge_dim=0 behaviour.
    edge_attr = None
    if target_edge_type in train_g.edge_types:
        edge_attr = getattr(train_g[target_edge_type], 'edge_attr', None)
    edge_dim = edge_attr.size(1) if edge_attr is not None else 0

    model = BipartiteGraphAutoEncoder_ReportBased(
        in_dim_member=in_dim_member,
        in_dim_provider=in_dim_provider,
        edge_dim=edge_dim,
        hidden_dim=iter_params['hidden_dim'],
        latent_dim=iter_params['latent_dim'],
        num_conv_layers=iter_params['num_conv_layers'],
        num_dec_layers=iter_params['num_dec_layers'],
        dropout=iter_params['dropout']
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=iter_params['learning_rate'],
                                 weight_decay=iter_params.get('weight_decay', 1e-5))

    # 4. Train model
    num_epochs_final = 1000
    k_neg_train = 5
    val_k_neg_score = 0
    eval_freq = 10  # validation frequency used by early stopping
    early_stop_patience_final = 100  # patience used during final training

    trained_model_iter, _, _ = train_model_inductive_with_metrics(
        model=model,
        train_graph=train_g_dev,
        num_epochs=num_epochs_final,
        optimizer=optimizer,
        lambda_attr=final_lambda_attr,
        lambda_struct=final_lambda_struct,
        k_neg_samples=k_neg_train,
        target_edge_type=target_edge_type,
        device=device,
        log_freq=9999, # Suppress logs during repeated runs
        # Validation Args
        val_graph=val_g_dev,
        gt_node_labels_val=gt_node_labels_val_dev,
        gt_edge_labels_val=gt_edge_labels_val_dev,
        val_log_freq=eval_freq,
        # Validation Scoring Args
        val_lambda_attr=final_lambda_attr,
        val_lambda_struct=final_lambda_struct,
        val_k_neg_samples_score=val_k_neg_score,
        node_k_list=[50], # Minimal list sufficient for early stopping logic
        # Early Stopping Args
        early_stopping_metric='AP',
        early_stopping_element='Avg AP',
        patience=early_stop_patience_final,
        save_best_model_path=None # No saving during these runs
    )

    # 5. Evaluate model on all splits
    eval_params = {
        "k_list": [50, 100, 200, 500], # Ks used for final reporting
        "lambda_attr": final_lambda_attr,
        "lambda_struct": final_lambda_struct,
        "k_neg_samples_score": val_k_neg_score # k=0 for eval scoring
    }
    all_scores, summary_df, anomaly_type_df = evaluate_model_inductively(
        trained_model=trained_model_iter,
        train_graph=train_g_dev,
        val_graph=val_g_dev,
        test_graph=test_g_dev,
        gt_node_labels=gt_node_labels, # Full per-split dict for this iteration
        gt_edge_labels=gt_edge_labels, # Full per-split dict for this iteration
        anomaly_tracking_all=anomaly_tracking, # Full per-split dict for this iteration
        device=device,
        eval_params=eval_params,
        target_edge_type=target_edge_type,
        plot=False,
        verbose=False
    )

    # 6. Clean up: drop every reference to the model (`model` may be the same object as
    # `trained_model_iter`) before emptying the cache, otherwise nothing can be released.
    del trained_model_iter, model, optimizer
    if device.type == 'cuda':
        torch.cuda.empty_cache()
    elif device.type == 'mps':
        torch.mps.empty_cache()

    return summary_df, anomaly_type_df, all_scores

In [None]:
# Run BGAE once per seed on that seed's pre-generated splits and collect the per-seed tables.
# A failing seed is recorded as a {'seed', 'error'} row so the aggregation cells can filter it out.
all_summary_dfs = []
all_anomaly_type_dfs = []
all_run_scores = {}  # seed -> raw node/edge anomaly scores returned by run_bgae_iteration

for seed in random_seeds:
    print(f"\n===== Running BGAE Iteration for Seed: {seed} =====")
    iteration_data = DATA_SPLITS_ITERATIONS.get(seed)
    if not iteration_data:
        print(f"Warning: Data for seed {seed} not found in DATA_SPLITS_ITERATIONS.")
        continue

    try:
        summary_df, anomaly_type_df, scores = run_bgae_iteration(
            iteration_data=iteration_data,
            best_params=best_params,  # tuned BGAE hyperparameters
            device=device,
            seed=seed,
        )
        # Tag both tables with the seed, then collect them.
        summary_df['seed'] = seed
        anomaly_type_df['seed'] = seed
        all_summary_dfs.append(summary_df)
        all_anomaly_type_dfs.append(anomaly_type_df)
        all_run_scores[seed] = scores
        print(f"--- Finished Iteration {seed} ---")
    except Exception as e:
        print(f"--- ERROR in Iteration {seed}: {e} ---")
        all_summary_dfs.append(pd.DataFrame([{'seed': seed, 'error': str(e)}]))
        all_anomaly_type_dfs.append(pd.DataFrame([{'seed': seed, 'error': str(e)}]))

# --- Aggregate results after the loop ---
if all_summary_dfs:
    final_summary_df = pd.concat(all_summary_dfs, ignore_index=True)
    print("\n--- Aggregated Overall Results ---")
if all_anomaly_type_dfs:
    final_anomaly_type_df = pd.concat(all_anomaly_type_dfs, ignore_index=True)
    print("\n--- Aggregated Anomaly Type Results ---")


In [None]:
# Peek at the raw anomaly scores stored for seed 0 (None if seed 0 was not run or failed).
all_run_scores[0] if 0 in all_run_scores else None

In [None]:
# --- Aggregation & Final Reporting ---
# Concatenates the per-seed BGAE summaries, drops failed runs and reports mean ± std per (Split, Element).
# The numeric aggregate is stored in final_summary_df; a string-formatted copy is printed.
# Assumes node_k_list_eval (K values reported as Precision@K / Recall@K) is defined in an earlier cell.
print("\n--- Aggregating Results Across All Iterations ---")

# --- Aggregate Overall Summary ---
if all_summary_dfs:
    full_summary_results_df = pd.concat(all_summary_dfs, ignore_index=True)
    metrics_to_check_agg = ['AUROC', 'AP', 'Best F1'] # Core metrics to check for validity
    cols_exist = all(m in full_summary_results_df.columns for m in metrics_to_check_agg)

    if cols_exist:
        # Filter out rows where core metrics might be NaN (indicating a failed run segment)
        successful_summary_df = full_summary_results_df.dropna(subset=metrics_to_check_agg, how='any')
    else:
        print("Warning: Core metric columns (AUROC, AP, Best F1) missing in overall summary. Aggregation might be inaccurate.")
        successful_summary_df = full_summary_results_df # Proceed cautiously

    if not successful_summary_df.empty:
        print(f"Aggregating results from {successful_summary_df['seed'].nunique()} successful runs (overall summary)...")
        group_cols = ['Split', 'Element']
        # Only aggregate metric columns that actually exist: in the "proceed cautiously" branch a core
        # metric may be missing, and selecting it from the groupby would raise KeyError.
        metrics_to_aggregate = [m for m in metrics_to_check_agg if m in successful_summary_df.columns] + \
                               [f'{p}@{k}' for k in node_k_list_eval for p in ['Precision', 'Recall'] if f'{p}@{k}' in successful_summary_df.columns]
        missing_group_cols = [c for c in group_cols if c not in successful_summary_df.columns]

        if missing_group_cols or not metrics_to_aggregate:
            # e.g. every run failed, so only the error-marker columns ('seed', 'error') are present
            print(f"Cannot aggregate overall summary: missing grouping columns {missing_group_cols} or no metric columns present.")
        else:
            # Mean and std dev per group, plus the number of distinct seeds contributing to each group
            grouped = successful_summary_df.groupby(group_cols)
            aggregated_summary = grouped[metrics_to_aggregate].agg(['mean', 'std'])
            run_counts_summary = grouped['seed'].nunique().rename('n_runs')
            aggregated_summary = pd.concat([run_counts_summary, aggregated_summary], axis=1)

            print("\n--- Aggregated Performance Summary (Mean ± Std Dev) ---")
            # Flatten column labels for printing. Concatenating the 'n_runs' Series with the
            # (metric, stat) MultiIndex columns can produce a flat Index of mixed str/tuple labels
            # instead of a MultiIndex, so tuple labels are flattened one by one
            # (('AP', 'mean') -> 'AP_mean'); plain labels such as 'n_runs' are kept as-is.
            aggregated_summary_print = aggregated_summary.copy()
            aggregated_summary_print.columns = [
                '_'.join(map(str, col)).strip('_') if isinstance(col, tuple) else col
                for col in aggregated_summary_print.columns
            ]

            # Replace each <metric>_mean / <metric>_std pair with a single "mean ± std" string column
            for col in metrics_to_aggregate:
                mean_col = f'{col}_mean'
                std_col = f'{col}_std'
                if mean_col not in aggregated_summary_print.columns:
                    continue
                mean_str = aggregated_summary_print[mean_col].map('{:.4f}'.format)
                if std_col in aggregated_summary_print.columns:
                    aggregated_summary_print[col] = mean_str + ' ± ' + aggregated_summary_print[std_col].map('{:.4f}'.format)
                else:
                    aggregated_summary_print[col] = mean_str + ' ± nan'
                aggregated_summary_print = aggregated_summary_print.drop(columns=[mean_col, std_col], errors='ignore')

            print(aggregated_summary_print.to_string())
            final_summary_df = aggregated_summary # Store numerical version

    else:
        print("No successful runs found in overall summary results to aggregate.")
else:
    print("No overall summary DataFrames found in all_summary_dfs to aggregate.")



In [None]:
# --- Aggregate Anomaly Type Metrics ---
# Concatenates the per-seed anomaly-type tables in all_anomaly_type_dfs, drops rows from failed runs,
# and reports mean ± std per (Split, Anomaly Tag[, Element Type | Node Type]) group.
# The numeric aggregate is stored in final_anomaly_type_df; a string-formatted copy is printed.

if all_anomaly_type_dfs:
    full_anomaly_type_results_df = pd.concat(all_anomaly_type_dfs, ignore_index=True)
    # NOTE(review): leftover debug output; consider removing once the column set is stable.
    print("\n--- Debug: Columns in full_anomaly_type_results_df ---")
    print(full_anomaly_type_results_df.columns)
    # print(full_anomaly_type_results_df.head()) # Uncomment to see first few rows

    # Filter out failed runs based on a key metric column existence or specific error marker
    metrics_to_check_atype = ['AUROC', 'AP', 'Best F1']
    # Check which core metrics *actually exist* in the concatenated DataFrame
    metrics_present_check = [m for m in metrics_to_check_atype if m in full_anomaly_type_results_df.columns]

    if not metrics_present_check:
         print("Warning: Core metric columns (AUROC, AP, Best F1) missing entirely from anomaly type results. Cannot filter or aggregate metrics.")
         successful_anomaly_df = full_anomaly_type_results_df # Proceed with caution, only stats/counts will be aggregated
    else:
        # Filter rows where *at least one* of the essential metrics is valid (not NaN).
        # Error-marker rows (only 'seed' and 'error' set) are NaN in every metric column, so they are dropped.
        successful_anomaly_df = full_anomaly_type_results_df.dropna(subset=metrics_present_check, how='all').copy() # Use .copy() to avoid SettingWithCopyWarning
        # Alternatively, use 'any' if you want rows where *all* essential metrics are non-NaN:
        # successful_anomaly_df = full_anomaly_type_results_df.dropna(subset=metrics_present_check, how='any').copy()

    if not successful_anomaly_df.empty:
        n_runs_represented = successful_anomaly_df['seed'].nunique()
        print(f"\nAggregating results from {n_runs_represented} successful runs (anomaly type summary)...")

        # Define metrics and stats columns to aggregate AGAIN based on the successful DataFrame
        metrics_to_aggregate_atype = ['AUROC', 'AP', 'Best F1']
        score_stats_to_aggregate = ['Mean Score', 'Median Score']
        metrics_present_atype = [m for m in metrics_to_aggregate_atype if m in successful_anomaly_df.columns]
        score_stats_present = [s for s in score_stats_to_aggregate if s in successful_anomaly_df.columns]

        print(f"Metrics available for aggregation: {metrics_present_atype}")
        print(f"Score stats available for aggregation: {score_stats_present}")

        # Define grouping columns - ensure they exist
        base_grouping_cols = ['Split', 'Anomaly Tag']
        # Determine optional grouping cols based on what's actually in the successful_anomaly_df
        # ('Element Type' takes precedence; 'Node Type' is only used when 'Element Type' is absent).
        optional_grouping_cols = []
        if 'Element Type' in successful_anomaly_df.columns:
            optional_grouping_cols.append('Element Type')
        elif 'Node Type' in successful_anomaly_df.columns: # Fallback if Element Type isn't there but Node Type is
             optional_grouping_cols.append('Node Type')
        grouping_cols_atype = base_grouping_cols + optional_grouping_cols

        if not all(col in successful_anomaly_df.columns for col in grouping_cols_atype):
            print(f"Error: Cannot aggregate. Missing one or more grouping columns {grouping_cols_atype} in filtered anomaly type results.")
        else:
            # Aggregate performance metrics if any are present
            if metrics_present_atype:
                aggregated_metrics_atype = successful_anomaly_df.groupby(
                    grouping_cols_atype, observed=True # Use observed=True for stability with categorical data
                    )[metrics_present_atype].agg(['mean', 'std'])
                # Flatten metrics multi-index, e.g. ('AP', 'mean') -> 'AP_mean'
                if isinstance(aggregated_metrics_atype.columns, pd.MultiIndex):
                     aggregated_metrics_atype.columns = ['_'.join(map(str, col)).strip('_') for col in aggregated_metrics_atype.columns.values]

            else:
                aggregated_metrics_atype = pd.DataFrame() # Create empty df if no metrics to aggregate

            # Aggregate score stats and counts separately.
            # Named aggregations produce flat '<col>_mean' / '<col>_std' columns directly;
            # 'seed' is only used to count the distinct runs per group ('n_runs').
            stats_count_cols_to_agg = score_stats_present + ['Count', 'Proportion (%)', 'seed']
            stats_count_cols_present = [c for c in stats_count_cols_to_agg if c in successful_anomaly_df.columns]

            if stats_count_cols_present:
                agg_dict = {f'{col}_mean': pd.NamedAgg(column=col, aggfunc='mean') for col in stats_count_cols_present if col != 'seed'}
                agg_dict.update({f'{col}_std': pd.NamedAgg(column=col, aggfunc='std') for col in stats_count_cols_present if col != 'seed'})
                agg_dict['n_runs'] = pd.NamedAgg(column='seed', aggfunc='nunique')

                aggregated_stats_counts = successful_anomaly_df.groupby(
                        grouping_cols_atype, observed=True
                        ).agg(**agg_dict) # Use dictionary unpacking for NamedAgg
            else:
                 aggregated_stats_counts = pd.DataFrame() # Create empty if no stats/counts columns

            # Combine aggregated parts
            # Check if both have data before concatenating
            # (both share the same group index, so axis=1 concat aligns them row-wise)
            if not aggregated_stats_counts.empty and not aggregated_metrics_atype.empty:
                 aggregated_anomaly_type = pd.concat([aggregated_stats_counts, aggregated_metrics_atype], axis=1)
            elif not aggregated_stats_counts.empty:
                 aggregated_anomaly_type = aggregated_stats_counts
            elif not aggregated_metrics_atype.empty:
                  aggregated_anomaly_type = aggregated_metrics_atype
            else:
                  aggregated_anomaly_type = pd.DataFrame() # Both were empty


            if not aggregated_anomaly_type.empty:
                print("\n--- Aggregated Anomaly Type Performance (Mean ± Std Dev) ---")
                # Select and format key columns for concise printing
                final_anomaly_type_print = aggregated_anomaly_type.copy()
                cols_to_print_atype = [] # Build dynamically

                # Add run count if available
                if 'n_runs' in final_anomaly_type_print.columns: cols_to_print_atype.append('n_runs')
                # Add mean count if available
                if 'Count_mean' in final_anomaly_type_print.columns: cols_to_print_atype.append('Count_mean')
                # Add mean proportion if available
                prop_mean_col = 'Proportion (%)_mean'; prop_std_col = 'Proportion (%)_std'
                if prop_mean_col in final_anomaly_type_print.columns:
                     if prop_std_col in final_anomaly_type_print.columns:
                         final_anomaly_type_print['Proportion (%)'] = final_anomaly_type_print[prop_mean_col].map('{:.1f}'.format) + \
                                                                       ' ± ' + \
                                                                       final_anomaly_type_print[prop_std_col].map('{:.1f}'.format)
                     else:
                         final_anomaly_type_print['Proportion (%)'] = final_anomaly_type_print[prop_mean_col].map('{:.1f}'.format) + ' ± nan'
                     cols_to_print_atype.append('Proportion (%)')


                # Format performance metrics as "mean ± std" strings (3 decimals)
                for m in metrics_present_atype:
                    mean_col = f'{m}_mean'; std_col = f'{m}_std'
                    col_name_formatted = f'{m}' # Display name
                    if mean_col in final_anomaly_type_print.columns and std_col in final_anomaly_type_print.columns:
                        final_anomaly_type_print[col_name_formatted] = final_anomaly_type_print[mean_col].map('{:.3f}'.format) + \
                                                            ' ± ' + \
                                                            final_anomaly_type_print[std_col].map('{:.3f}'.format)
                        cols_to_print_atype.append(col_name_formatted)
                    elif mean_col in final_anomaly_type_print.columns:
                        final_anomaly_type_print[col_name_formatted] = final_anomaly_type_print[mean_col].map('{:.3f}'.format) + ' ± nan'
                        cols_to_print_atype.append(col_name_formatted)

                # Format score statistics (optional inclusion).
                # NOTE: these formatted columns are built but not printed unless the append lines are uncommented.
                for s in score_stats_present:
                        mean_col = f'{s}_mean'; std_col = f'{s}_std'
                        col_name_formatted = f'{s}' # Display name
                        if mean_col in final_anomaly_type_print.columns and std_col in final_anomaly_type_print.columns:
                            final_anomaly_type_print[col_name_formatted] = final_anomaly_type_print[mean_col].map('{:.3f}'.format) + \
                                                                ' ± ' + \
                                                                final_anomaly_type_print[std_col].map('{:.3f}'.format)
                            # cols_to_print_atype.append(col_name_formatted) # Uncomment to add to printout
                        elif mean_col in final_anomaly_type_print.columns:
                            final_anomaly_type_print[col_name_formatted] = final_anomaly_type_print[mean_col].map('{:.3f}'.format) + ' ± nan'
                            # cols_to_print_atype.append(col_name_formatted) # Uncomment to add to printout

                # Reorder and reset index for printing (moves the grouping keys back into columns)
                final_anomaly_type_print = final_anomaly_type_print.reset_index()
                # Ensure columns exist before selecting for final print
                final_print_cols = grouping_cols_atype + [col for col in cols_to_print_atype if col in final_anomaly_type_print.columns]
                print(final_anomaly_type_print[final_print_cols].round(3).to_string(max_rows=100))

                final_anomaly_type_df = aggregated_anomaly_type # Store numerical version

            else:
                print("Aggregated anomaly type DataFrame is empty after processing.")

    else:
        print("\nNo successful anomaly type results found to aggregate.")

else:
    print("No anomaly type DataFrames found in all_anomaly_type_dfs to aggregate.")

In [339]:
import torch
import numpy as np
import pandas as pd
from sklearn.ensemble import IsolationForest
from sklearn.preprocessing import StandardScaler # Assuming you might still want scaling
from typing import Dict, Tuple, Optional, List

def run_if_iteration(
    iteration_data: Dict,
    device: torch.device, # Keep device arg for consistency, though IF runs on CPU
    seed: int,
    k_list_eval: Optional[List[int]] = None,
    scale_features: bool = False,
    ) -> Optional[pd.DataFrame]:
    """
    Trains and evaluates an Isolation Forest per node type for a single iteration's data.

    For each node type ('provider', 'member') one IsolationForest is fit on the training
    split's augmented features and then used to score the train/val/test splits.

    Args:
        iteration_data (Dict): Dictionary containing 'train_graph', 'val_graph', 'test_graph'
                               and 'gt_node_labels' (split name -> node type -> label tensor).
        device: Unused; kept so all run_*_iteration helpers share a call signature.
        seed (int): Random seed for this iteration (IsolationForest random_state and 'seed' column).
        k_list_eval (Optional[List[int]]): K values for Precision@K / Recall@K.
            Defaults to [50, 100, 200, 500] when None.
        scale_features (bool): If True, standardize features with a StandardScaler fit on the
            training split only before fitting/scoring. Defaults to False (unscaled features).

    Returns:
        Optional[pd.DataFrame]: One row per evaluated (node type, split) with metrics for this
            iteration; an empty DataFrame if nothing could be evaluated; None if an error occurs.
    """
    # None default avoids a shared mutable list; the effective default is the original K list.
    k_list_baselines = [50, 100, 200, 500] if k_list_eval is None else list(k_list_eval)

    try:
        # 1. Extract graphs and labels for this iteration
        train_graph = iteration_data['train_graph']
        val_graph = iteration_data['val_graph']
        test_graph = iteration_data['test_graph']
        gt_node_labels = iteration_data['gt_node_labels'] # Contains train, val, test labels

        # 2. Augment features for each split.
        # Assumes augment_features_for_sklearn returns a dict: {'provider': np.array, 'member': np.array}
        features_by_split = {
            'train': augment_features_for_sklearn(train_graph),
            'val': augment_features_for_sklearn(val_graph),
            'test': augment_features_for_sklearn(test_graph),
        }

        # 3. Results storage for this iteration
        eval_summary_iter = []

        # 4. One Isolation Forest per node type
        for node_type in ['provider', 'member']:
            X_train_nt = features_by_split['train'].get(node_type)
            y_train_nt_t = gt_node_labels["train"].get(node_type)
            X_val_nt = features_by_split['val'].get(node_type)
            y_val_nt_t = gt_node_labels["val"].get(node_type)
            X_test_nt = features_by_split['test'].get(node_type)
            y_test_nt_t = gt_node_labels["test"].get(node_type)

            if X_train_nt is None or y_train_nt_t is None or X_train_nt.shape[0] == 0:
                print(f"  Skipping IF {node_type}: Missing or empty training data for iteration {seed}.")
                continue

            # Labels -> numpy; missing val/test labels become empty arrays and are skipped below
            y_train_nt = y_train_nt_t.cpu().numpy()
            y_val_nt = y_val_nt_t.cpu().numpy() if y_val_nt_t is not None else np.array([])
            y_test_nt = y_test_nt_t.cpu().numpy() if y_test_nt_t is not None else np.array([])

            # 5. Optional standardization, fit on the training split only (no val/test leakage)
            if scale_features:
                scaler_nt = StandardScaler().fit(X_train_nt)
                X_train_nt = scaler_nt.transform(X_train_nt)
                X_val_nt = scaler_nt.transform(X_val_nt) if X_val_nt is not None and X_val_nt.shape[0] > 0 else X_val_nt
                X_test_nt = scaler_nt.transform(X_test_nt) if X_test_nt is not None and X_test_nt.shape[0] > 0 else X_test_nt

            # 6. Train Isolation Forest; contamination estimated from this iteration's training labels
            contamination_est_nt = y_train_nt.mean()
            # Clamp into the valid (0, 0.5] range and keep it strictly below 0.5
            contamination_if = max(0.001, min(0.5, contamination_est_nt))
            if contamination_if == 0.5:
                contamination_if = 0.499

            iforest_nt = IsolationForest(contamination=contamination_if, random_state=seed)
            iforest_nt.fit(X_train_nt)

            # 7. Score and evaluate every split
            split_data = {
                'train': (X_train_nt, y_train_nt),
                'val': (X_val_nt, y_val_nt),
                'test': (X_test_nt, y_test_nt),
            }
            for split_name, (X_split, y_split) in split_data.items():
                if X_split is None or y_split is None or X_split.shape[0] == 0 or y_split.shape[0] == 0:
                    continue

                if X_split.shape[0] != y_split.shape[0]:
                    print(f"  Warning: Shape mismatch for IF {node_type} on {split_name}. Scores: {X_split.shape[0]}, Labels: {y_split.shape[0]}. Skipping eval.")
                    continue

                # IF's decision_function is lower for anomalies, but compute_evaluation_metrics
                # expects higher scores for anomalies, hence the negation.
                scores_split_nt = -iforest_nt.decision_function(X_split)

                metrics = compute_evaluation_metrics(scores_split_nt, y_split, k_list=k_list_baselines)

                n_items = len(scores_split_nt)
                n_anomalies = np.sum(y_split)
                summary_row = {
                    'Model': 'IsolationForest',
                    'Split': split_name,
                    'Element': f'Node ({node_type})',
                    'Num Items': n_items,
                    'Num Anomalies': int(n_anomalies),
                    '% Anomalies': (n_anomalies / n_items * 100) if n_items > 0 else 0,
                    'seed': seed,
                }
                summary_row.update(metrics)
                eval_summary_iter.append(summary_row)

    except Exception as e:
        print(f"  ERROR processing Isolation Forest for iteration {seed}: {e}")
        # None signals failure of this iteration to the caller
        return None

    if not eval_summary_iter:
        return pd.DataFrame() # Return empty dataframe if no nodes were processed

    # Fixed column order; only columns actually produced are kept
    iter_results_df = pd.DataFrame(eval_summary_iter)
    ordered_cols = ['Model', 'Split', 'Element', 'seed', 'Num Items', 'Num Anomalies', '% Anomalies',
                    'AUROC', 'AP', 'Best F1', 'Best F1 Threshold'] + \
                   [f'{p}@{k}' for k in k_list_baselines for p in ['Precision', 'Recall']]
    existing_cols_ordered = [col for col in ordered_cols if col in iter_results_df.columns]
    return iter_results_df.reindex(columns=existing_cols_ordered, fill_value=np.nan)



In [None]:
# Evaluate Isolation Forest on every pre-generated data iteration; iterations that fail (None) are skipped.
all_if_summaries = []
for seed, iter_data in DATA_SPLITS_ITERATIONS.items():
    print(f"\n===== Running IF Iteration for Seed: {seed} =====")
    if_summary_df_iter = run_if_iteration(iteration_data=iter_data, device=device, seed=seed)
    if if_summary_df_iter is None:
        continue
    all_if_summaries.append(if_summary_df_iter)

In [None]:

def run_oddball_iteration(
    iteration_data: Dict,
    seed: int,
    oddball_anomaly_type: str = 'sc', # 'sc', 'hv', 'de'
    oddball_use_lof: bool = False,
    k_list_eval: Optional[List[int]] = None,
    ) -> Optional[pd.DataFrame]:
    """
    Evaluates the OddBall model for a single iteration's data splits using
    the evaluate_oddball_inductive function.

    Args:
        iteration_data (Dict): Dictionary containing 'train_graph', 'val_graph', 'test_graph',
                               'gt_node_labels' (split name -> node type -> label tensor) for one seed.
        seed (int): The random seed for this iteration (used for logging and the 'seed' column).
        oddball_anomaly_type (str): Type parameter for OddBall.
        oddball_use_lof (bool): LOF parameter for OddBall.
        k_list_eval (Optional[List[int]]): K values for P@K, R@K metrics passed to
            evaluate_oddball_inductive. Defaults to [50, 100, 200] when None.

    Returns:
        Optional[pd.DataFrame]: One row per evaluated (split, node type) for this iteration;
            an empty DataFrame if every combination was skipped; None if an error occurs.
    """
    # A None default avoids sharing one mutable list across calls; the effective default is unchanged.
    if k_list_eval is None:
        k_list_eval = [50, 100, 200]

    try:
        graphs_dict = {
            'train': iteration_data['train_graph'],
            'val': iteration_data['val_graph'],
            'test': iteration_data['test_graph']
        }
        gt_node_labels_all_splits = iteration_data['gt_node_labels']
        eval_summary_iter = []

        for split_name, graph_split in graphs_dict.items():
            gt_labels_split = gt_node_labels_all_splits.get(split_name, {})

            for node_type in ['provider', 'member']:
                # Skip node types that are absent, empty, or unlabeled in this split
                if node_type not in graph_split.node_types or \
                   graph_split[node_type].num_nodes == 0 or \
                   gt_labels_split.get(node_type) is None:
                    continue

                metrics = evaluate_oddball_inductive(
                    graph_split=graph_split,
                    gt_node_labels_split=gt_labels_split, # Pass dict for this split
                    anomaly_type=oddball_anomaly_type,
                    node_type_to_eval=node_type,
                    use_lof=oddball_use_lof,
                    k_list=k_list_eval
                )

                if not metrics:
                    print(f"OddBall evaluation failed or returned None for {node_type} in {split_name} split.")
                    continue

                num_items = graph_split[node_type].num_nodes
                # Labels are guaranteed present by the skip check above
                num_anomalies = int(gt_labels_split[node_type].sum().item())
                perc = (num_anomalies / num_items * 100) if num_items > 0 else 0

                summary_row = {
                    'Model': 'OddBall',
                    'Split': split_name,
                    'Element': f'Node ({node_type})',
                    'Num Items': num_items,
                    'Num Anomalies': num_anomalies,
                    '% Anomalies': perc,
                    'seed': seed,
                    'iteration': seed
                }
                summary_row.update(metrics) # Add calculated metrics
                eval_summary_iter.append(summary_row)

    except Exception as e:
        print(f"  ERROR processing OddBall for iteration {seed}: {e}")
        return None

    # Empty DataFrame when no (split, node type) combination produced results
    return pd.DataFrame(eval_summary_iter) if eval_summary_iter else pd.DataFrame()

# --- Evaluate OddBall ('sc' scoring, no LOF) on every pre-generated data iteration ---
all_oddball_summaries = []
for seed, iter_data in DATA_SPLITS_ITERATIONS.items():
    print(f"\n===== Running OddBall Iteration for Seed: {seed} =====")
    oddball_summary_df_iter = run_oddball_iteration(
        iteration_data=iter_data,
        seed=seed,
        oddball_anomaly_type='sc',  # alternatives: 'hv', 'de'
        oddball_use_lof=False,
        k_list_eval=[50, 100, 200, 500],  # same K list as the IF and BGAE evaluations
    )
    if oddball_summary_df_iter is None:
        continue
    all_oddball_summaries.append(oddball_summary_df_iter)

# TODO: aggregate OddBall results across seeds, e.g.
#   final_oddball_summary_df = pd.concat(all_oddball_summaries, ignore_index=True)
#   followed by a groupby(['Split', 'Element']) mean/std as done for the BGAE summary.

In [None]:
import torch
import numpy as np
import pandas as pd
from typing import Dict, Tuple, Optional, List



def run_dominant_iteration(
    iteration_data: Dict,
    device: torch.device, # Mapped to PyGOD's integer `gpu` argument below
    seed: int,
    # --- DOMINANT Specific Parameters ---
    dominant_epochs: int = 200,
    # Add other DOMINANT params if needed (e.g., num_layers, weight_decay)
    k_list_eval: List[int] = [50, 100, 200, 500]
    ) -> Optional[pd.DataFrame]:
    """
    Train and evaluate the DOMINANT detector on one iteration's data splits.

    The heterogeneous train/val/test graphs are converted to homogeneous graphs,
    DOMINANT is fit on the training graph only, and the per-node scores of every
    split are mapped back to the original node types and evaluated against the
    ground-truth labels.

    Args:
        iteration_data (Dict): One seed's data. Must contain 'train_graph',
            'val_graph', 'test_graph' and 'gt_node_labels' (a dict keyed by split
            name, then by node type).
        device (torch.device): Target device. CUDA devices map to their index
            (``cuda`` -> 0, ``cuda:1`` -> 1); any other device type (CPU, MPS)
            maps to -1, PyGOD's CPU setting.
        seed (int): Iteration identifier, recorded in the 'seed' and 'iteration'
            columns (DOMINANT itself is not given a random_state here).
        dominant_epochs (int): Number of DOMINANT training epochs.
        k_list_eval (List[int]): K values passed to the metric computation.

    Returns:
        Optional[pd.DataFrame]: One row per (split, node type) with the metrics
        from ``compute_evaluation_metrics``; an empty DataFrame if nothing could
        be scored; ``None`` if the iteration failed (DOMINANT unavailable,
        invalid training data, or any exception during fit/evaluation).
    """
    import time  # local import: this cell itself only imports torch, numpy, pandas and typing
    import traceback

    def _to_numpy(values) -> np.ndarray:
        """Return `values` (tensor, array-like or None) as a NumPy array; None -> empty array."""
        if values is None:
            return np.array([])
        if isinstance(values, torch.Tensor):
            # detach(): .numpy() refuses tensors that still require grad
            return values.detach().cpu().numpy()
        return np.asarray(values)

    # Resolve the detector class up front. A missing `DOMINANT` name raises
    # NameError, which the `except ImportError` handler below would never catch.
    try:
        dominant_cls = DOMINANT
    except NameError:
        print("Error: DOMINANT is not defined. Install pygod (`pip install pygod`) "
              "and import DOMINANT before running this cell.")
        return None

    # PyGOD selects hardware with an integer: a CUDA device index, or -1 for CPU.
    # A bare `cuda` device means GPU 0; non-CUDA devices (CPU, MPS) run on CPU.
    if device.type == 'cuda':
        pygod_device_id = device.index if device.index is not None else 0
    else:
        pygod_device_id = -1

    eval_summary_iter = []
    try:
        # 1. Extract graphs and labels
        train_graph_hetero = iteration_data['train_graph']
        val_graph_hetero = iteration_data['val_graph']
        test_graph_hetero = iteration_data['test_graph']
        gt_node_labels_all_splits = iteration_data['gt_node_labels']

        # 2. Prepare homogeneous graphs on CPU (PyGOD can still use the GPU
        #    internally via `gpu=`).
        cpu = torch.device('cpu')
        train_g_homo = prepare_homogeneous_data(train_graph_hetero, target_device=cpu)
        val_g_homo = prepare_homogeneous_data(val_graph_hetero, target_device=cpu)
        test_g_homo = prepare_homogeneous_data(test_graph_hetero, target_device=cpu)

        if train_g_homo is None or getattr(train_g_homo, 'x', None) is None:
            print(f"  ERROR: Failed to prepare valid homogeneous training data for DOMINANT (seed {seed}). Skipping iteration.")
            return None

        # 3. Estimate contamination from the training labels, clamped to [0.001, 0.5]
        total_train_nodes = train_g_homo.num_nodes
        train_labels = gt_node_labels_all_splits['train']
        total_train_anomalies = sum(
            train_labels[nt].sum().item()
            for nt in train_labels
            if nt in train_graph_hetero.node_types
        )
        contamination_est = total_train_anomalies / total_train_nodes if total_train_nodes > 0 else 0.1
        contamination_est = max(0.001, min(0.5, contamination_est))

        # 4. Instantiate and train DOMINANT on the training graph only
        print(f"  Training DOMINANT for iteration {seed} (Contamination: {contamination_est:.4f})...")
        model_dominant = dominant_cls(
            epoch=dominant_epochs,
            contamination=contamination_est,
            gpu=pygod_device_id,
            verbose=0,  # keep logs clean
        )
        start_fit = time.time()
        model_dominant.fit(train_g_homo)
        print(f"    Fit completed in {time.time() - start_fit:.2f}s.")

        # 5. Evaluate on all splits
        print(f"  Evaluating DOMINANT for iteration {seed}...")
        splits = [('train', train_graph_hetero, train_g_homo),
                  ('val', val_graph_hetero, val_g_homo),
                  ('test', test_graph_hetero, test_g_homo)]
        for split_name, hetero_graph, homo_graph in splits:
            if homo_graph is None or homo_graph.num_nodes == 0:
                print(f"    Skipping evaluation for {split_name}: Homogeneous graph missing or empty.")
                continue
            if getattr(homo_graph, 'x', None) is None:
                print(f"    Skipping evaluation for {split_name}: Homogeneous graph missing features 'x'.")
                continue
            if getattr(homo_graph, 'edge_index', None) is None:
                print(f"    Warning: Evaluating DOMINANT on {split_name} graph with no 'edge_index'. May be unreliable.")
                # PyGOD needs an edge_index; num_nodes > 0 is guaranteed by the check above.
                homo_graph.edge_index = torch.empty((2, 0), dtype=torch.long, device=homo_graph.x.device)

            # Higher score = more anomalous (PyGOD convention)
            homo_scores_split = model_dominant.decision_function(homo_graph)

            # Map scores back to original node types
            gt_labels_split = gt_node_labels_all_splits.get(split_name, {})
            mapped_results = map_homo_results(homo_graph, homo_scores_split, hetero_graph, gt_labels_split)

            for node_type, type_data in mapped_results.items():
                scores_np = _to_numpy(type_data.get('scores'))
                labels_np = _to_numpy(type_data.get('labels'))
                if len(scores_np) == 0 or len(scores_np) != len(labels_np):
                    continue  # nothing comparable to score for this node type

                metrics = compute_evaluation_metrics(scores_np, labels_np, k_list=k_list_eval)
                num_items = len(scores_np)
                num_anomalies = np.sum(labels_np)
                summary_row = {
                    'Model': 'DOMINANT',
                    'Split': split_name,
                    'Element': f'Node ({node_type})',
                    'Num Items': num_items,
                    'Num Anomalies': int(num_anomalies),
                    '% Anomalies': num_anomalies / num_items * 100,
                    'seed': seed,
                    'iteration': seed,
                }
                summary_row.update(metrics)
                eval_summary_iter.append(summary_row)

    except ImportError:
        print("Error: pygod library not found. Please install it (`pip install pygod`) to run DOMINANT.")
        return None
    except Exception as e:
        print(f"  ERROR processing DOMINANT for iteration {seed}: {e}")
        traceback.print_exc()  # keep the stack so failed iterations can be debugged
        return None  # Indicate failure for this iteration

    # An empty list yields an empty DataFrame (nothing could be scored).
    return pd.DataFrame(eval_summary_iter)


# --- Run DOMINANT once per data-split seed ---
# Failed iterations (run_dominant_iteration returns None) are skipped; the
# function reports the reason itself.
all_dominant_summaries = []
for seed, iter_data in DATA_SPLITS_ITERATIONS.items():
    print(f"\n===== Running DOMINANT Iteration for Seed: {seed} =====")
    dominant_summary_df_iter = run_dominant_iteration(
        iteration_data=iter_data,
        device=device,
        seed=seed,
        dominant_epochs=50,               # fewer than the function's default of 200
        k_list_eval=[50, 100, 200, 500],
    )
    if dominant_summary_df_iter is None:
        continue
    all_dominant_summaries.append(dominant_summary_df_iter)

# Cross-seed aggregation of these summaries is done by
# aggregate_and_compare_model_results further down.

In [None]:
# Inspect the per-seed DOMINANT summary DataFrames collected by the loop above.
all_dominant_summaries

In [None]:
import pandas as pd
import numpy as np
from typing import Dict, Tuple, Optional, List


def aggregate_and_compare_model_results(
    bgae_results_list: List[pd.DataFrame],
    if_results_list: List[pd.DataFrame],
    oddball_results_list: List[pd.DataFrame],
    dominant_results_list: List[pd.DataFrame],
    metrics_to_compare: List[str] = ['AUROC', 'AP', 'Best F1'],
    k_list_eval: List[int] = [50, 100, 200, 500], # K values used during evaluation
    splits_to_compare: List[str] = ['test'] # Focus on test set by default
    ) -> Optional[pd.DataFrame]:
    """
    Aggregates results from multiple runs for different models and creates
    a comparison DataFrame showing mean ± std dev for specified metrics.

    NOTE(review): superseded draft. A second function with the same name (which
    additionally accepts ``sort_elements_by``) is defined later in this same
    cell and replaces this one as soon as the cell executes, so this body is
    never called. Consider deleting it to avoid confusion.

    Args:
        bgae_results_list: List of summary DataFrames from BGAE runs.
        if_results_list: List of summary DataFrames from Isolation Forest runs.
        oddball_results_list: List of summary DataFrames from OddBall runs.
        dominant_results_list: List of summary DataFrames from DOMINANT runs.
        metrics_to_compare: List of primary performance metrics to show in the table.
        k_list_eval: List of K values used, to include relevant P@K/R@K metrics.
        splits_to_compare: List of data splits to include in the comparison (e.g., ['test'], ['val', 'test']).

    Returns:
        Optional[pd.DataFrame]: A DataFrame comparing models, indexed by Split and Element,
                                 with columns for each model showing 'mean ± std' metrics,
                                 or None if aggregation fails.
    """
    print("\n--- Aggregating and Comparing Model Performances ---")

    all_results_list = []  # NOTE(review): never used in this version
    model_data_map = {
        "BGAE": bgae_results_list,
        "IsolationForest": if_results_list,
        "OddBall": oddball_results_list,
        "DOMINANT": dominant_results_list
    }

    # --- 1. Concatenate and Preprocess Results for Each Model ---
    processed_model_dfs = {}
    metrics_to_check = ['AUROC', 'AP', 'Best F1'] # Ensure these core metrics exist

    for model_name, results_list in model_data_map.items():
        if not results_list:
            print(f"Warning: No results found for model '{model_name}'. Skipping.")
            continue
        try:
            full_df = pd.concat(results_list, ignore_index=True)
            # Filter for relevant splits
            filtered_df = full_df[full_df['Split'].isin(splits_to_compare)].copy()
            # Drop rows where essential metrics are NaN (likely failed runs)
            cols_exist = all(m in filtered_df.columns for m in metrics_to_check)
            if cols_exist:
                 successful_df = filtered_df.dropna(subset=metrics_to_check, how='any')
            else:
                 print(f"Warning: Core metrics missing for {model_name}. Proceeding with available data.")
                 successful_df = filtered_df

            if successful_df.empty:
                 print(f"Warning: No successful runs found for model '{model_name}' in specified splits. Skipping.")
                 continue

            processed_model_dfs[model_name] = successful_df
            print(f"Processed {model_name}: Found {successful_df['seed'].nunique()} successful runs in specified splits.")

        except Exception as e:
            print(f"Error processing results for model '{model_name}': {e}")

    if not processed_model_dfs:
        print("Error: No processed results available for any model. Cannot create comparison.")
        return None

    # --- 2. Aggregate Metrics for Each Model ---
    aggregated_data_list = []
    full_metric_list = metrics_to_compare + \
                       [f'{p}@{k}' for k in k_list_eval for p in ['Precision', 'Recall']]

    for model_name, df in processed_model_dfs.items():
        # Ensure metrics to aggregate actually exist in this df
        metrics_present = [m for m in full_metric_list if m in df.columns]
        if not metrics_present:
            print(f"Warning: No specified metrics found for model '{model_name}'. Skipping aggregation.")
            continue

        try:
            # Group by Split and Element, calculate mean and std
            agg_df = df.groupby(['Split', 'Element'])[metrics_present].agg(['mean', 'std'])
            # Add run count
            run_counts = df.groupby(['Split', 'Element'])['seed'].nunique().rename('n_runs')
            # NOTE(review): this mixes a plain 'n_runs' column with the (metric, stat)
            # tuple columns produced by agg(['mean', 'std']); the exact resulting column
            # labels are not verified here — TODO confirm against the pandas version in use.
            agg_df = pd.concat([run_counts, agg_df], axis=1)
            agg_df['Model'] = model_name # Add model identifier
            aggregated_data_list.append(agg_df.reset_index()) # Reset index for easier merging later
        except Exception as e:
            print(f"Error aggregating data for model '{model_name}': {e}")

    if not aggregated_data_list:
        print("Error: Aggregation failed for all models.")
        return None

    # --- 3. Combine and Format the Comparison Table ---
    comparison_df = pd.concat(aggregated_data_list, ignore_index=True)

    # Pivot table for comparison view
    # NOTE(review): the lookups below assume 3-level columns (metric, stat, Model) and
    # ('n_runs', '', Model) after this pivot — presumably, but unverified.
    try:
        pivot_df = comparison_df.pivot_table(
            index=['Split', 'Element'],
            columns='Model'
            # Values will be selected below
        )
    except Exception as e:
         print(f"Error pivoting comparison table: {e}")
         return None # Cannot proceed if pivot fails


    # Format columns as 'Metric (Mean ± Std)'
    final_comparison_cols = {} # Store formatted columns keyed by original metric name
    n_runs_cols = {} # Store n_runs separately

    # Ensure metrics_present considers columns available across *all* processed models if needed,
    # or handle missing columns per model during formatting. Let's handle per model.
    all_metrics_in_pivot = set(lvl0 for lvl0, lvl1 in pivot_df.columns if lvl1=='mean') & set(metrics_to_compare + [f'{p}@{k}' for k in k_list_eval for p in ['Precision', 'Recall']])


    for metric in all_metrics_in_pivot:
        # NOTE(review): metric_mean_cols / metric_std_cols are computed but not used below.
        metric_mean_cols = [col for col in pivot_df.columns if col[0] == metric and col[1] == 'mean']
        metric_std_cols = [col for col in pivot_df.columns if col[0] == metric and col[1] == 'std']

        # Create the formatted column for this metric
        formatted_series_list = []
        models_in_pivot = pivot_df.columns.get_level_values('Model').unique() # Models actually present

        for model_name in models_in_pivot:
             mean_col = (metric, 'mean', model_name)
             std_col = (metric, 'std', model_name)

             if mean_col in pivot_df.columns and std_col in pivot_df.columns:
                 # Format only if both mean and std are present and not NaN
                 formatted = pivot_df[mean_col].map('{:.4f}'.format).str.cat(
                               pivot_df[std_col].map('{:.4f}'.format), sep=' ± '
                           ).where(pivot_df[mean_col].notna() & pivot_df[std_col].notna(), "N/A") # Handle NaN after agg
             elif mean_col in pivot_df.columns:
                  # Format mean only if std is missing/NaN
                  formatted = pivot_df[mean_col].map('{:.4f}'.format).where(pivot_df[mean_col].notna(), "N/A") + ' ± nan'
             else:
                  formatted = pd.Series("N/A", index=pivot_df.index) # Metric not found for this model

             formatted.name = (metric, model_name) # Assign multi-level name
             formatted_series_list.append(formatted)

        # Combine formatted series for this metric across all models
        if formatted_series_list:
             final_comparison_cols[metric] = pd.concat(formatted_series_list, axis=1)


    # Get n_runs separately
    n_runs_cols_present = [col for col in pivot_df.columns if col[0] == 'n_runs' and col[1] == ''] # n_runs has empty second level
    if n_runs_cols_present:
        n_runs_df = pivot_df[n_runs_cols_present]
        n_runs_df.columns = n_runs_df.columns.droplevel(1) # Drop the empty level ''
        n_runs_cols = {'n_runs': n_runs_df}


    # Assemble the final table
    if not final_comparison_cols:
        print("Error: No metric columns could be formatted.")
        return None

    # Create MultiIndex columns for the final DataFrame
    metric_order = metrics_to_compare + sorted([m for m in all_metrics_in_pivot if m not in metrics_to_compare]) # Order metrics
    models_order = sorted(processed_model_dfs.keys()) # Sort model names alphabetically

    final_df_cols = pd.MultiIndex.from_product([metric_order, models_order], names=['Metric', 'Model'])
    final_comparison_df = pd.DataFrame(index=pivot_df.index, columns=final_df_cols)

    # Fill the DataFrame with formatted values
    for metric in metric_order:
         if metric in final_comparison_cols:
             metric_data = final_comparison_cols[metric]
             # Ensure columns align - reindex metric_data if necessary
             final_comparison_df[metric] = metric_data.reindex(columns=final_df_cols.get_level_values('Model').unique(), level='Model')

    # Add n_runs as the first level
    if n_runs_cols:
        final_comparison_df = pd.concat(n_runs_cols, axis=1, keys=['Info']).join(final_comparison_df)
        # Rename 'n_runs' column under 'Info'
        final_comparison_df.rename(columns={'n_runs': 'Runs'}, level=1, inplace=True)


    print("\n--- Final Model Comparison (Mean ± Std Dev) ---")
    print(final_comparison_df.to_string())

    return final_comparison_df


import pandas as pd
import numpy as np
from typing import Dict, Tuple, Optional, List

# Assume compute_evaluation_metrics is defined elsewhere or imported

def aggregate_and_compare_model_results(
    bgae_results_list: List[pd.DataFrame],
    if_results_list: List[pd.DataFrame],
    oddball_results_list: List[pd.DataFrame],
    dominant_results_list: List[pd.DataFrame],
    metrics_to_compare: List[str] = ['AUROC', 'AP', 'Best F1'],
    k_list_eval: List[int] = [50, 100, 200, 500], # K values used during evaluation
    splits_to_compare: List[str] = ['test'], # Focus on test set by default
    sort_elements_by: Optional[List[str]] = None # Optional order for elements in index
    ) -> Optional[pd.DataFrame]:
    """
    Aggregate multi-seed results of several models into one 'mean ± std' comparison table.

    Each input list holds per-iteration summary DataFrames with at least the
    columns 'Split', 'Element', 'seed' and the metric columns. Rows are limited
    to ``splits_to_compare``; rows with NaN in any core metric ('AUROC', 'AP',
    'Best F1') are treated as failed runs and dropped. Every metric is then
    summarised per (Split, Element, Model) as ``"mean ± std"`` with 4 decimals
    (sample std; ``"± nan"`` when only one run exists, ``"N/A"`` if the mean is
    NaN).

    Args:
        bgae_results_list: List of summary DataFrames from BGAE runs.
        if_results_list: List of summary DataFrames from Isolation Forest runs.
        oddball_results_list: List of summary DataFrames from OddBall runs.
        dominant_results_list: List of summary DataFrames from DOMINANT runs.
        metrics_to_compare: Primary performance metrics to show in the table.
        k_list_eval: K values used, to include the matching Precision@K / Recall@K metrics.
        splits_to_compare: Data splits to include (e.g., ['test'], ['val', 'test']).
        sort_elements_by: Optional explicit row order for 'Element' values. Rows
            are ordered by Split, then by position in this list; elements not in
            the list come last, alphabetically.

    Returns:
        Optional[pd.DataFrame]: Table indexed by (Split, Element) with MultiIndex
        columns (Metric, Model). The first Metric is 'Runs' (number of distinct
        seeds), followed by the metrics present, in ``metrics_to_compare`` then
        P@K/R@K order. Models are sorted alphabetically; combinations a model
        has no results for are NaN. Returns an empty DataFrame when no
        successful rows remain. If pivoting fails, it returns the long format
        indexed by (Split, Element, Model). It returns ``None`` when no input
        is usable or aggregation fails.
    """
    print("\n--- Aggregating and Comparing Model Performances ---")

    model_data_map = {
        "BGAE": bgae_results_list,
        "IsolationForest": if_results_list,
        "OddBall": oddball_results_list,
        "DOMINANT": dominant_results_list
    }

    def _format_mean_std(mean_val, std_val) -> str:
        """Render one aggregated cell as 'mean ± std' (4 decimals)."""
        if pd.isna(mean_val):
            return "N/A"
        std_text = f"{std_val:.4f}" if pd.notna(std_val) else "nan"
        return f"{mean_val:.4f} ± {std_text}"

    # --- 1. Combine all results into a single DataFrame tagged with the model name ---
    all_results_list = []
    for model_name, results_list in model_data_map.items():
        if not results_list:
            continue
        try:
            model_df = pd.concat(results_list, ignore_index=True)
        except Exception as e:
            print(f"Error concatenating results for {model_name}: {e}")
            continue
        if not model_df.empty:
            model_df['Model'] = model_name
            all_results_list.append(model_df)

    if not all_results_list:
        print("Error: No valid result lists provided or concatenation failed.")
        return None

    combined_df = pd.concat(all_results_list, ignore_index=True)

    # --- 2. Filter to the requested splits and drop failed runs ---
    filtered_df = combined_df[combined_df['Split'].isin(splits_to_compare)].copy()

    metrics_to_check = ['AUROC', 'AP', 'Best F1'] # Core metrics for filtering failed runs
    if all(m in filtered_df.columns for m in metrics_to_check):
        successful_df = filtered_df.dropna(subset=metrics_to_check, how='any')
    else:
        print("Warning: Core metric columns missing. Filtering might be incomplete.")
        successful_df = filtered_df

    if successful_df.empty:
        print(f"Warning: No successful runs found for models in specified splits: {splits_to_compare}.")
        return pd.DataFrame() # Return empty DataFrame

    print(f"Aggregating results from {successful_df['seed'].nunique()} successful runs across models...")

    # --- 3. Calculate aggregates (mean, std, distinct-run count) ---
    all_metrics_possible = metrics_to_compare + \
                           [f'{p}@{k}' for k in k_list_eval for p in ['Precision', 'Recall']]
    metrics_present = [m for m in all_metrics_possible if m in successful_df.columns]

    if not metrics_present:
        print("Error: None of the specified metrics_to_compare or P@K/R@K metrics found in the data.")
        return None

    try:
        grouped = successful_df.groupby(['Split', 'Element', 'Model'])
        aggregated_means = grouped[metrics_present].mean()
        aggregated_stds = grouped[metrics_present].std()
        aggregated_counts = grouped['seed'].nunique().rename('n_runs') # Count distinct runs per group

        # Every metric column overlaps, so all get the _mean / _std suffix
        aggregated_df = aggregated_means.join(aggregated_stds, lsuffix='_mean', rsuffix='_std')
        aggregated_df = aggregated_df.join(aggregated_counts)
    except Exception as e:
        print(f"Error during aggregation: {e}")
        return None

    # --- 4. Format each (Split, Element, Model) row as 'mean ± std' strings ---
    final_df_list = []
    for (split, element, model), row in aggregated_df.iterrows():
        formatted_row = {'Split': split, 'Element': element, 'Model': model, 'Runs': int(row['n_runs'])}
        for metric in metrics_present:
            formatted_row[metric] = _format_mean_std(
                row.get(f'{metric}_mean', np.nan), row.get(f'{metric}_std', np.nan)
            )
        final_df_list.append(formatted_row)

    if not final_df_list:
        print("Error: No data after formatting.")
        return pd.DataFrame()

    final_comparison_df_long = pd.DataFrame(final_df_list)

    # --- 5. Pivot models into columns ---
    column_order = ['Runs'] + metrics_present
    try:
        # `pivot`, not `pivot_table`: (Split, Element, Model) is unique after the
        # groupby, and the metric cells are strings, which pivot_table's default
        # aggfunc='mean' cannot aggregate (it raises or silently drops them).
        pivot_comparison_df = final_comparison_df_long.pivot(
            index=['Split', 'Element'], columns='Model', values=column_order
        )
        pivot_comparison_df.columns = pivot_comparison_df.columns.set_names(['Metric', 'Model'])
        models_order = sorted(final_comparison_df_long['Model'].unique())
        pivot_comparison_df = pivot_comparison_df.reindex(
            columns=pd.MultiIndex.from_product([column_order, models_order], names=['Metric', 'Model'])
        )
    except Exception as e:
        print(f"Error pivoting final table: {e}")
        print("Returning table in long format instead.")
        return final_comparison_df_long.set_index(['Split', 'Element', 'Model'])

    # --- Optional: order rows by Split, then by the custom Element order ---
    if sort_elements_by:
        try:
            element_rank = {elem: i for i, elem in enumerate(sort_elements_by)}
            index_frame = pivot_comparison_df.index.to_frame(index=False)
            # Unlisted elements rank after all listed ones
            index_frame['_rank'] = index_frame['Element'].map(element_rank).fillna(len(element_rank))
            row_positions = index_frame.sort_values(['Split', '_rank', 'Element']).index.to_numpy()
            pivot_comparison_df = pivot_comparison_df.iloc[row_positions]
        except Exception as e:
            print(f"Warning: Could not sort by custom element order: {e}")

    print("\n--- Final Model Comparison (Mean ± Std Dev) ---")
    # Global display options (kept so later cells displaying the table are not truncated)
    pd.set_option('display.max_rows', 500)
    pd.set_option('display.max_columns', len(pivot_comparison_df.columns) + 1)
    pd.set_option('display.width', 200)
    pd.set_option('display.precision', 4)

    print(pivot_comparison_df.to_string())

    return pivot_comparison_df

# --- Build the cross-model comparison table (test split only) ---
# Inputs are the per-seed lists of summary DataFrames populated by the
# previous iteration loops (BGAE, Isolation Forest, OddBall, DOMINANT).
# To fix the row order, additionally pass e.g.
#   sort_elements_by=['Node (provider)', 'Node (member)']
comparison_table = aggregate_and_compare_model_results(
    all_summary_dfs,
    all_if_summaries,
    all_oddball_summaries,
    all_dominant_summaries,
    metrics_to_compare=['AUROC', 'AP', 'Best F1'],
    k_list_eval=[50, 100, 200, 500],
    splits_to_compare=['test'],
)

In [None]:
# Display the cross-model comparison table built above (may be None or empty if aggregation failed).
comparison_table

In [None]:
import pandas as pd
import os
import pickle
from typing import Dict, List, Optional

def save_iteration_results(
    save_dir: str,
    bgae_results_list: List[pd.DataFrame],
    if_results_list: List[pd.DataFrame],
    oddball_results_list: List[pd.DataFrame],
    dominant_results_list: List[pd.DataFrame],
    aggregated_comparison_df: Optional[pd.DataFrame] = None, # Optional final aggregated table
    aggregated_bgae_summary: Optional[pd.DataFrame] = None, # Optional aggregated specific model tables
    aggregated_if_summary: Optional[pd.DataFrame] = None,
    aggregated_oddball_summary: Optional[pd.DataFrame] = None,
    aggregated_dominant_summary: Optional[pd.DataFrame] = None,
    ) -> None:
    """
    Pickle the raw per-iteration results and any provided aggregated tables into ``save_dir``.

    One ``<name>.pkl`` file is written per non-empty input:
    ``bgae_iterations``, ``if_iterations``, ``oddball_iterations``,
    ``dominant_iterations``, ``aggregated_comparison``, ``aggregated_bgae``,
    ``aggregated_if``, ``aggregated_oddball`` and ``aggregated_dominant``.
    Inputs that are ``None`` or empty lists are skipped.

    Each file is first written to a ``.tmp`` sibling and then moved into place
    with ``os.replace`` (atomic on POSIX). A failure mid-write therefore never
    leaves a truncated pickle under the final name or clobbers an earlier good
    copy; the partial temp file is removed on a best-effort basis.

    Errors are reported with ``print`` and never raised. A failure on one file
    does not stop the others from being saved.

    Args:
        save_dir (str): Directory for the pickle files (created if missing).
        bgae_results_list: List of summary DataFrames from BGAE runs.
        if_results_list: List of summary DataFrames from Isolation Forest runs.
        oddball_results_list: List of summary DataFrames from OddBall runs.
        dominant_results_list: List of summary DataFrames from DOMINANT runs.
        aggregated_comparison_df (Optional[pd.DataFrame]): The final comparison table.
        aggregated_bgae_summary (Optional[pd.DataFrame]): Aggregated results for BGAE.
        aggregated_if_summary (Optional[pd.DataFrame]): Aggregated results for IF.
        aggregated_oddball_summary (Optional[pd.DataFrame]): Aggregated results for OddBall.
        aggregated_dominant_summary (Optional[pd.DataFrame]): Aggregated results for DOMINANT.
    """
    print(f"\n--- Saving Iteration Results to Directory: {save_dir} ---")

    try:
        os.makedirs(save_dir, exist_ok=True)
    except OSError as e:
        print(f"Error: Could not create save directory '{save_dir}'. Cannot save results. Error: {e}")
        return

    # File base name -> object to pickle (aggregated entries may be None)
    data_to_save = {
        "bgae_iterations": bgae_results_list,
        "if_iterations": if_results_list,
        "oddball_iterations": oddball_results_list,
        "dominant_iterations": dominant_results_list,
        "aggregated_comparison": aggregated_comparison_df,
        "aggregated_bgae": aggregated_bgae_summary,
        "aggregated_if": aggregated_if_summary,
        "aggregated_oddball": aggregated_oddball_summary,
        "aggregated_dominant": aggregated_dominant_summary,
    }

    for filename_base, data_object in data_to_save.items():
        if data_object is None or (isinstance(data_object, list) and not data_object):
            continue # Nothing to save for this entry

        file_path = os.path.join(save_dir, f"{filename_base}.pkl")
        tmp_path = f"{file_path}.tmp"
        try:
            with open(tmp_path, 'wb') as f:
                pickle.dump(data_object, f)
            os.replace(tmp_path, file_path)
            print(f"Successfully saved: {file_path}")
        except Exception as e:
            print(f"Error saving '{file_path}': {e}")
            # Best-effort cleanup of the partial temp file; continue with the other files.
            try:
                os.remove(tmp_path)
            except OSError:
                pass

    print("--- Results Saving Attempt Finished ---")


# --- Persist per-iteration results and the comparison table as pickles ---
save_directory = "model_evaluation_results_pickle"

save_iteration_results(
    save_dir=save_directory,
    bgae_results_list=all_summary_dfs,
    if_results_list=all_if_summaries,
    oddball_results_list=all_oddball_summaries,
    dominant_results_list=all_dominant_summaries,
    aggregated_comparison_df=comparison_table,
    # NOTE(review): `all_anomaly_type_dfs` presumably comes from an earlier cell;
    # confirm it is the aggregated BGAE summary this parameter is meant to hold.
    aggregated_bgae_summary=all_anomaly_type_dfs,
    # Aggregated IF / OddBall / DOMINANT tables are left as None (not saved).
)

# --- Round-trip check: reload the BGAE iterations ---
# Initialised up front so the inspection cell below does not raise NameError
# on a fresh kernel when the file is missing.
loaded_bgae_iterations = None
file_to_load = os.path.join(save_directory, "bgae_iterations.pkl")
if os.path.exists(file_to_load):
    with open(file_to_load, 'rb') as f:
        loaded_bgae_iterations = pickle.load(f)
    print(f"\nLoaded {len(loaded_bgae_iterations)} BGAE iteration results.")
else:
    print(f"\nNo BGAE iteration results found at {file_to_load}.")

In [None]:
# Inspect the BGAE iteration results reloaded from disk above.
loaded_bgae_iterations

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, Tuple, Optional, List
import os
import math # For ceiling function

# Apply a seaborn theme (white grid background, "deep" colour palette, fonts scaled
# by 1.1) so the report figures drawn below share one look.
sns.set_theme(style="whitegrid", palette="deep", font_scale=1.1)

def plot_metric_comparison_with_std(
    aggregated_summary_df: pd.DataFrame, # MultiIndex columns ('Metric', 'Stat', 'Model') or flattened 'Metric_Stat'
    models_to_plot: List[str], # Models to look up, in priority order (e.g., ['BGAE', 'IsolationForest'])
    metrics_to_plot: List[str] = ['AUROC', 'AP', 'Best F1'],
    splits_to_plot: List[str] = ['test', 'train', 'val'], # Order matters for bars
    elements_to_plot: Optional[List[str]] = None, # Specific elements/order for x-axis
    plot_figsize: Tuple[float, float] = (20, 5.5), # Width suits 3 subplots
    save_path: Optional[str] = None
    ) -> None:
    """
    Grouped bar plot of mean metrics (with std-dev error bars) per split and element type.

    One subplot is drawn per metric. Within a subplot each x-axis group is an element
    type and each bar in the group is a split. For every (split, element) cell the
    values come from the *first* model in ``models_to_plot`` whose mean column exists,
    so this draws one model's numbers, not several models side by side.

    Args:
        aggregated_summary_df: Aggregated mean/std metrics. Must be indexed by
            ('Split', 'Element') or contain those two columns. Columns are either a
            MultiIndex like ('AUROC', 'mean', 'BGAE') or flattened like 'AUROC_mean'
            (in the flattened case the model name plays no role in the lookup).
        models_to_plot: Model names (the 'Model' column level) to try, in order.
        metrics_to_plot: Metrics to draw, one subplot each.
        splits_to_plot: Splits to draw as bars; the list order is the bar order.
        elements_to_plot: Ordered x-axis elements. If None, every element found in
            the data is used, sorted.
        plot_figsize: Overall figure size (width, height).
        save_path: If provided, the figure is also saved here at 300 dpi; parent
            directories are created as needed.

    Returns:
        None. Prints an error message and returns early if the input cannot be plotted.
    """
    print("\n--- Generating Metric Comparison Plot with Std Dev ---")

    if aggregated_summary_df is None or aggregated_summary_df.empty:
        print("Error: Aggregated summary DataFrame is empty. Cannot generate plot.")
        return
    if not metrics_to_plot:
        # plt.subplots(1, 0) would raise, and there is nothing to draw anyway.
        print("Error: metrics_to_plot is empty. Cannot generate plot.")
        return

    # --- Data Preparation ---
    # Ensure index is ['Split', 'Element'] if not already
    if not isinstance(aggregated_summary_df.index, pd.MultiIndex) or \
       aggregated_summary_df.index.names != ['Split', 'Element']:
        try:
            aggregated_summary_df = aggregated_summary_df.set_index(['Split', 'Element'])
        except KeyError:
            print("Error: DataFrame must contain 'Split' and 'Element' columns to set as index.")
            return

    # Define element order for x-axis
    element_level = aggregated_summary_df.index.get_level_values('Element')
    if elements_to_plot is None:
        elements_order = sorted(element_level.unique())
    else:
        available_elements = set(element_level)
        elements_order = [e for e in elements_to_plot if e in available_elements]
        if not elements_order:
            print(f"Error: None of the specified elements_to_plot found in the data.")
            return
        # Keep only the requested elements, then reindex so rows follow the requested order
        aggregated_summary_df = aggregated_summary_df[element_level.isin(elements_order)]
        current_splits = aggregated_summary_df.index.get_level_values('Split').unique()
        new_index = pd.MultiIndex.from_product([current_splits, elements_order], names=['Split', 'Element'])
        aggregated_summary_df = aggregated_summary_df.reindex(new_index)

    # Define split order for bars within groups
    available_splits = set(aggregated_summary_df.index.get_level_values('Split'))
    splits_order = [s for s in splits_to_plot if s in available_splits]
    if not splits_order:
        print(f"Error: None of the specified splits_to_plot found in the data.")
        return

    # --- Plotting Setup ---
    n_metrics = len(metrics_to_plot)
    n_splits = len(splits_order)
    n_elements = len(elements_order)

    fig, axes = plt.subplots(1, n_metrics, figsize=plot_figsize, sharey=True)
    if n_metrics == 1: axes = [axes] # Ensure iterable

    bar_width = 0.8 / n_splits # Bars of one group share 80% of the slot width
    group_positions = np.arange(n_elements) # Centre positions for each element group

    palette = sns.color_palette("deep", n_splits)
    split_color_map = dict(zip(splits_order, palette))

    columns = aggregated_summary_df.columns
    columns_are_multi = isinstance(columns, pd.MultiIndex)
    # Tallest bar + error bar over ALL subplots. The y-axis is shared (sharey=True),
    # so the final limit must fit every metric, not only the last one drawn.
    max_y_overall = 0.0

    # --- Create Plots ---
    for i, metric in enumerate(metrics_to_plot):
        ax = axes[i]
        ax.set_title(metric)
        ax.set_ylabel("Score" if i == 0 else "") # Label only first y-axis
        ax.set_xticks(group_positions)
        ax.set_xticklabels(elements_order, rotation=0)
        ax.grid(True, axis='y', linestyle=':', alpha=0.7)
        ax.axhline(0, color='grey', linewidth=0.8) # Line at y=0

        for j, split in enumerate(splits_order):
            # Centre the n_splits bars of each group around the group's tick
            offset = (j - (n_splits - 1) / 2) * bar_width
            bar_positions = group_positions + offset
            means = []
            stds = []

            for element in elements_order:
                mean_val = np.nan
                std_val = np.nan
                if (split, element) in aggregated_summary_df.index:
                    # Use the first model (in models_to_plot order) that has a mean column
                    for model_name in models_to_plot:
                        if columns_are_multi:
                            mean_col = (metric, 'mean', model_name)
                            std_col = (metric, 'std', model_name)
                        else:
                            mean_col = f'{metric}_mean'
                            std_col = f'{metric}_std'
                        if mean_col in columns:
                            mean_val = aggregated_summary_df.loc[(split, element), mean_col]
                            # Missing std column -> no error bar
                            std_val = (aggregated_summary_df.loc[(split, element), std_col]
                                       if std_col in columns else 0)
                            break

                # NaN (missing row or column) is drawn as a zero-height bar without error bar
                means.append(mean_val if pd.notna(mean_val) else 0)
                stds.append(std_val if pd.notna(std_val) else 0)

            # Plot bars and error bars for this split
            ax.bar(bar_positions, means, bar_width, label=split, color=split_color_map[split],
                   edgecolor='grey', linewidth=0.5)
            ax.errorbar(bar_positions, means, yerr=stds, fmt='none', ecolor='black',
                        capsize=3, elinewidth=1, capthick=1)

            current_max = np.nanmax(np.array(means) + np.array(stds))
            if pd.notna(current_max) and current_max > max_y_overall:
                max_y_overall = current_max

        ax.legend(title="Split")

    # --- Final Touches ---
    # 10% headroom rounded up to the next 0.1, capped at 1.05 for [0, 1]-bounded metrics
    common_ylim_top = math.ceil(max_y_overall * 11) / 10.0
    if common_ylim_top <= 0:
        common_ylim_top = 1.05 # Nothing positive was drawn; avoid a collapsed (-0.05, 0) axis
    for ax in axes:
        ax.set_ylim(bottom=-0.05, top=min(1.05, common_ylim_top))

    fig.suptitle("Metric Comparison Across Splits", fontsize=16, y=1.05)
    plt.tight_layout(rect=[0, 0, 1, 1])

    if save_path:
        try:
            save_dir = os.path.dirname(save_path)
            if save_dir: os.makedirs(save_dir, exist_ok=True)
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Comparison plot saved to {save_path}")
        except Exception as e:
            print(f"Error saving comparison plot: {e}")

    plt.show()


# --- Example usage (v1 layout) ---
# `final_summary_df` is assumed to come from an earlier cell: indexed by
# ('Split', 'Element'), with MultiIndex columns such as ('AUROC', 'mean', 'BGAE')
# and ('AUROC', 'std', 'BGAE').

# x-axis order; every label must match the 'Element' index level exactly.
element_order_example = [
    "Edge ('provider', 'to', 'member')",
    "Node (member)",
    "Node (provider)",
]

plot_metric_comparison_with_std(
    aggregated_summary_df=final_summary_df,
    models_to_plot=['BGAE'],                    # single-model view
    metrics_to_plot=['AUROC', 'AP', 'Best F1'],
    splits_to_plot=['test', 'train', 'val'],    # bar order inside each group
    elements_to_plot=element_order_example,
    save_path="report_plots/metric_comparison_splits_with_std.png",
)

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, Tuple, Optional, List
import os
import math

# Apply the seaborn report theme (white grid, "deep" palette, fonts x1.1).
# NOTE(review): identical to the set_theme call in the previous plotting cell;
# re-running it is harmless but redundant.
sns.set_theme(style="whitegrid", palette="deep", font_scale=1.1)

def plot_metric_comparison_with_std(
    aggregated_summary_df: pd.DataFrame, # Columns: ('Metric', 'Stat'), ('Metric', 'Stat', 'Model') or flattened 'Metric_Stat'
    metrics_to_plot: List[str] = ['AUROC', 'AP', 'Best F1'],
    splits_to_plot: List[str] = ['test', 'train', 'val'], # Order matters for bars
    elements_to_plot: Optional[List[str]] = None, # Specific elements/order for x-axis
    plot_figsize: Tuple[float, float] = (20, 5.5),
    save_path: Optional[str] = None,
    models_to_plot: Optional[List[str]] = None, # Only used with a 'Model' column level
    ) -> None:
    """
    Grouped bar plot of mean metrics (std-dev error bars) across splits and element types.

    One subplot per metric; x-axis groups are element types and the bars inside a
    group are splits. Missing/NaN means are drawn as 0 and missing/NaN stds as no
    error bar.

    Supported column layouts:
      * flattened strings such as 'AUROC_mean' / 'AUROC_std';
      * a 2-level MultiIndex such as ('AUROC', 'mean');
      * a 3-level MultiIndex such as ('AUROC', 'mean', 'BGAE'). For each cell, mean
        and std come from the same model: the first entry of ``models_to_plot`` (or
        of the models present, when it is None) that has a non-null mean.

    Args:
        aggregated_summary_df: Aggregated mean/std metrics, indexed by
            ('Split', 'Element') or holding those as columns.
        metrics_to_plot: Metrics to create subplots for.
        splits_to_plot: Splits to draw as bars; list order is bar order.
        elements_to_plot: Ordered x-axis elements; None uses all elements, sorted.
        plot_figsize: Overall figure size.
        save_path: If provided, the figure is also saved here at 300 dpi.
        models_to_plot: Model preference order for 3-level columns; ignored otherwise.

    Returns:
        None. Prints an error message and returns early if the input cannot be plotted.
    """
    print("\n--- Generating Metric Comparison Plot with Std Dev ---")

    if aggregated_summary_df is None or aggregated_summary_df.empty:
        print("Error: Aggregated summary DataFrame is empty. Cannot generate plot.")
        return
    if not metrics_to_plot:
        # plt.subplots(1, 0) would raise, and there is nothing to draw anyway.
        print("Error: metrics_to_plot is empty. Cannot generate plot.")
        return

    def _mean_and_std(row: pd.Series, metric: str):
        """Return (mean, std) for `metric` from one summary row; NaN where unavailable."""
        if (metric, 'mean') in row.index:
            mean_raw = row[(metric, 'mean')]
            std_raw = row[(metric, 'std')] if (metric, 'std') in row.index else np.nan
        elif f'{metric}_mean' in row.index:
            mean_raw = row[f'{metric}_mean']
            std_raw = row[f'{metric}_std'] if f'{metric}_std' in row.index else np.nan
        else:
            return np.nan, np.nan
        if not isinstance(mean_raw, pd.Series):
            return mean_raw, std_raw
        # 3-level columns: the partial (metric, stat) key leaves a Series indexed by
        # model. Using it directly in `if pd.notna(...)` would raise ValueError.
        candidates = models_to_plot if models_to_plot is not None else list(mean_raw.index)
        for model_name in candidates:
            if model_name in mean_raw.index and pd.notna(mean_raw[model_name]):
                std_for_model = (std_raw[model_name]
                                 if isinstance(std_raw, pd.Series) and model_name in std_raw.index
                                 else np.nan)
                return mean_raw[model_name], std_for_model
        return np.nan, np.nan

    # --- Data Preparation ---
    # Ensure index is ['Split', 'Element']
    if not isinstance(aggregated_summary_df.index, pd.MultiIndex) or \
       aggregated_summary_df.index.names != ['Split', 'Element']:
        try:
            # Prefer existing 'Split'/'Element' columns
            if 'Split' in aggregated_summary_df.columns and 'Element' in aggregated_summary_df.columns:
                aggregated_summary_df = aggregated_summary_df.set_index(['Split', 'Element'])
            else:
                # They may be index levels with other names/positions; flatten and retry
                temp_df = aggregated_summary_df.reset_index()
                if 'Split' in temp_df.columns and 'Element' in temp_df.columns:
                    aggregated_summary_df = temp_df.set_index(['Split', 'Element'])
                else:
                    raise ValueError("DataFrame must have 'Split' and 'Element' as index or columns.")
        except Exception as e:
            print(f"Error setting index: {e}")
            return

    # Define element order for x-axis
    available_elements = aggregated_summary_df.index.get_level_values('Element').unique()
    if elements_to_plot is None:
        elements_order = sorted(available_elements)
    else:
        elements_order = [e for e in elements_to_plot if e in available_elements]
        if not elements_order:
            print(f"Error: None of the specified elements_to_plot found in the data: {available_elements}")
            return
        # Reindex to enforce the requested order (and drop unrequested elements)
        current_splits = aggregated_summary_df.index.get_level_values('Split').unique()
        new_index = pd.MultiIndex.from_product([current_splits, elements_order], names=['Split', 'Element'])
        aggregated_summary_df = aggregated_summary_df.reindex(new_index)

    # Define split order for bars
    available_splits = aggregated_summary_df.index.get_level_values('Split').unique()
    splits_order = [s for s in splits_to_plot if s in available_splits]
    if not splits_order:
        print(f"Error: None of the specified splits_to_plot found in the data: {available_splits}")
        return

    # --- Plotting Setup ---
    n_metrics = len(metrics_to_plot)
    n_splits = len(splits_order)
    n_elements = len(elements_order)

    fig, axes = plt.subplots(1, n_metrics, figsize=plot_figsize, sharey=True)
    if n_metrics == 1: axes = [axes] # Ensure iterable

    bar_width = 0.8 / n_splits
    group_positions = np.arange(n_elements)
    palette = sns.color_palette("deep", n_splits)
    split_color_map = dict(zip(splits_order, palette))

    # --- Create Plots ---
    max_y_overall = 0 # Max bar+error across all subplots, for one shared ylim

    for i, metric in enumerate(metrics_to_plot):
        ax = axes[i]
        ax.set_title(metric)
        ax.set_ylabel("Score" if i == 0 else "")
        ax.set_xticks(group_positions)
        ax.set_xticklabels(elements_order, rotation=0, ha='center')
        ax.grid(True, axis='y', linestyle=':', alpha=0.7)
        ax.axhline(0, color='grey', linewidth=0.8)

        plot_max_y = 0 # Max bar+error for this subplot

        has_flat_cols = f'{metric}_mean' in aggregated_summary_df.columns
        has_multi_cols = (metric, 'mean') in aggregated_summary_df.columns
        if not has_flat_cols and not has_multi_cols:
            print(f"Warning: Metric '{metric}' mean column not found. Skipping subplot.")
            ax.text(0.5, 0.5, f"Data for\n'{metric}'\nnot found", ha='center', va='center', transform=ax.transAxes)
            continue

        for j, split in enumerate(splits_order):
            offset = (j - (n_splits - 1) / 2) * bar_width
            bar_positions = group_positions + offset
            means = []
            stds = []

            for element in elements_order:
                mean_val = np.nan
                std_val = 0.0 # No error bar unless a valid std is found

                idx = (split, element)
                if idx in aggregated_summary_df.index:
                    mean_val, std_val_raw = _mean_and_std(aggregated_summary_df.loc[idx], metric)
                    if pd.notna(std_val_raw): std_val = std_val_raw

                means.append(mean_val if pd.notna(mean_val) else 0)
                stds.append(std_val)

            valid_means = np.array(means, dtype=float)
            valid_stds = np.array(stds, dtype=float)
            ax.bar(bar_positions, valid_means, bar_width, label=split, color=split_color_map[split],
                   edgecolor='grey', linewidth=0.5)
            ax.errorbar(bar_positions, valid_means, yerr=valid_stds, fmt='none', ecolor='black',
                        capsize=3, elinewidth=1, capthick=1)

            current_max = np.nanmax(valid_means + valid_stds)
            if pd.notna(current_max) and current_max > plot_max_y:
                plot_max_y = current_max

        ax.legend(title="Split", loc='best')
        if plot_max_y > max_y_overall:
            max_y_overall = plot_max_y

    # --- Final Touches ---
    # 10% headroom rounded up to the next 0.1, capped at 1.05 for [0, 1]-bounded metrics
    common_ylim_top = math.ceil(max_y_overall * 11) / 10.0
    if common_ylim_top <= 0:
        common_ylim_top = 1.05 # Nothing positive was drawn; avoid a collapsed axis
    for ax in axes:
        ax.set_ylim(bottom=min(-0.05, ax.get_ylim()[0]), # Keep an existing lower bottom
                    top=min(1.05, common_ylim_top))

    fig.suptitle("Metric Comparison Across Splits", fontsize=16, y=1.03)
    plt.tight_layout(rect=[0, 0, 1, 1])

    if save_path:
        try:
            save_dir = os.path.dirname(save_path)
            if save_dir: os.makedirs(save_dir, exist_ok=True)
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Comparison plot saved to {save_path}")
        except Exception as e:
            print(f"Error saving comparison plot: {e}")

    plt.show()


# --- Example usage (v2 layout: columns ('Metric', 'Stat') or flattened 'Metric_Stat') ---
# `df_split` is built here because the cell that also creates it sits *after* this
# one; without this line a fresh-kernel "Run All" raises NameError. The rename
# shortens the verbose edge label to "Edge" so it matches the x-axis order below.
df_split = final_summary_df.rename(index={"Edge ('provider', 'to', 'member')": "Edge"})

# x-axis order; labels must match the 'Element' index level of df_split exactly.
element_order_example = [
    "Edge",
    "Node (member)",
    "Node (provider)",
]

plot_metric_comparison_with_std(
    aggregated_summary_df=df_split,
    metrics_to_plot=['AUROC', 'AP', 'Best F1'],
    splits_to_plot=['test', 'train', 'val'],
    elements_to_plot=element_order_example,
    save_path="report_plots/metric_comparison_splits_with_std_v2.png",
)

In [372]:
# Map the verbose edge-type label to a short "Edge" tick label for plotting.
element_rename = {"Edge ('provider', 'to', 'member')": "Edge"}
df_split = final_summary_df.rename(element_rename, axis="index")

In [None]:
# Show the summary table with the shortened "Edge" element label.
df_split


In [None]:
import pandas as pd
import numpy as np
from typing import Dict, Tuple, Optional, List

def calculate_average_metric_across_elements(
    results_list: List[pd.DataFrame],
    metric: str = 'AP',
    elements_to_average: List[str] = ['Node (member)', 'Node (provider)', 'Edge'], # Specify elements to average
    split: str = 'test',
    run_col: str = 'seed',
    ) -> Tuple[Optional[float], Optional[float]]:
    """
    Mean and std (across runs) of a metric first averaged over several element types.

    For every run (identified by ``run_col``) the metric is averaged over the rows of
    ``split`` whose 'Element' is in ``elements_to_average``; the mean and sample std
    of those per-run averages are returned. Runs whose average is NaN (metric missing
    for all of that run's selected rows) are ignored.

    Args:
        results_list (List[pd.DataFrame]): One DataFrame per iteration run, each with
            'Split', 'Element', the metric column and the ``run_col`` column.
        metric (str): Metric column to average (e.g. 'AP', 'AUROC').
        elements_to_average (List[str]): 'Element' values to include; exact match.
        split (str): Data split to use ('train', 'val', 'test').
        run_col (str): Column identifying the run, e.g. 'seed' or 'iteration'.

    Returns:
        Tuple[Optional[float], Optional[float]]: (mean, std) of the per-run averages.
        std is 0.0 when only one run contributes. (None, None) if the inputs are
        empty, required columns are missing, or nothing can be computed.
    """
    if not results_list:
        print("Warning: Input results_list is empty.")
        return None, None

    all_iterations_df = pd.concat(results_list, ignore_index=True)

    # --- Data Validation ---
    required_cols = ['Split', 'Element', metric, run_col]
    if not all(col in all_iterations_df.columns for col in required_cols):
        print(f"Error: Input DataFrames missing one or more required columns: {required_cols}")
        return None, None

    # Filter for the specified split and elements
    filtered_df = all_iterations_df[
        (all_iterations_df['Split'] == split) &
        (all_iterations_df['Element'].isin(elements_to_average))
    ].copy()

    if filtered_df.empty:
        print(f"Warning: No data found for split '{split}' and elements '{elements_to_average}'.")
        return None, None

    # Check (across all runs pooled) that every requested element appears at least once
    elements_found = filtered_df['Element'].unique()
    if not all(elem in elements_found for elem in elements_to_average):
        print(f"Warning: Not all requested elements ({elements_to_average}) were found in the filtered data for split '{split}'. Found: {elements_found}")
        # Proceeding with available elements

    if filtered_df[metric].isnull().all():
        print(f"Warning: Metric column '{metric}' contains only NaNs for the selected split/elements.")
        return None, None

    # --- Calculation ---
    try:
        # Average over the selected elements within each run; drop runs with no
        # valid value so they neither count towards len() nor make std NaN.
        average_metric_per_iteration = filtered_df.groupby(run_col)[metric].mean().dropna()

        if average_metric_per_iteration.empty:
            print("Warning: Could not calculate average metric per iteration (maybe grouping failed?).")
            return None, None

        final_mean = average_metric_per_iteration.mean()
        # Sample std needs at least two runs
        final_std = average_metric_per_iteration.std() if len(average_metric_per_iteration) > 1 else 0.0

        return float(final_mean), float(final_std)

    except Exception as e:
        print(f"Error during calculation of average metric across elements: {e}")
        return None, None


# --- Average AP on the test split, per model ---
# Exact 'Element' labels to average over. With a single entry the result is that
# element's AP (mean ± std across runs), not a cross-element average. To average
# across element types use e.g.
#   ['Node (member)', 'Node (provider)', "Edge ('provider', 'to', 'member')"]
elements = ["Edge ('provider', 'to', 'member')"]

print("\n--- Calculating Average AP across Elements (Test Set) ---")

# (label, per-run summary DataFrames) for every model; one call per model instead
# of four copy-pasted blocks.
avg_ap_by_model = {}
for model_label, model_results in [
    ("BGAE", all_summary_dfs),
    ("IF", all_if_summaries),
    ("OddBall", all_oddball_summaries),
    ("DOMINANT", all_dominant_summaries),
]:
    mean_ap, std_ap = calculate_average_metric_across_elements(
        results_list=model_results,
        metric='AP',
        elements_to_average=elements,
        split='test'
    )
    avg_ap_by_model[model_label] = (mean_ap, std_ap)
    if mean_ap is not None:
        print(f"{model_label} Avg AP (Test): {mean_ap:.4f} ± {std_ap:.4f}")

# Keep the individual result names that later cells may refer to.
mean_avg_ap_bgae, std_avg_ap_bgae = avg_ap_by_model["BGAE"]
mean_avg_ap_if, std_avg_ap_if = avg_ap_by_model["IF"]
mean_avg_ap_oddball, std_avg_ap_oddball = avg_ap_by_model["OddBall"]
mean_avg_ap_dominant, std_avg_ap_dominant = avg_ap_by_model["DOMINANT"]

In [None]:
# Inspect the per-run Isolation Forest summary DataFrames.
all_if_summaries

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, Tuple, Optional, List
import os
import math
from collections import defaultdict # Added

# Apply the seaborn report theme (white grid, "deep" palette, fonts x1.1).
# NOTE(review): same call as in the earlier plotting cells; redundant but harmless.
sns.set_theme(style="whitegrid", palette="deep", font_scale=1.1)

def plot_pk_rk_comparison_grid(
    results_dict: Dict[str, List[pd.DataFrame]], # Dict: {'ModelName': [df_run1, df_run2,...]}
    k_list_to_plot: List[int] = [50, 100, 200], # K values to plot on x-axis
    split_to_plot: str = 'test',
    figsize: Tuple[int, int] = (14, 8), # Suits a 2x2 grid
    save_path: Optional[str] = None,
    model_name_map: Optional[Dict[str, str]] = None, # Optional: {'BGAE': 'Our Model (BGAE)'}
    model_palette: Optional[Dict[str, str]] = None # Optional: {'BGAE': 'blue', 'IsolationForest': 'green'}
    ) -> None:
    """
    Plot Precision@K and Recall@K (mean ± std across runs) for member and provider nodes.

    Layout is a 2x2 grid: rows are Precision / Recall, columns are 'Node (member)' /
    'Node (provider)'. Each model is one line with std-dev error bars, where the
    std is taken over the runs in ``results_dict[model]`` (one DataFrame per run).

    Args:
        results_dict: {model name: [summary DataFrame per run]}. Each DataFrame needs
            'Split' and 'Element' columns plus metric columns named like
            'Precision@50' / 'Recall@50'. Models without usable rows are skipped.
        k_list_to_plot: K values for the x-axis. A missing 'Precision@K'/'Recall@K'
            column, or a NaN mean/std, is drawn as 0.
        split_to_plot: Value of the 'Split' column to plot ('test', 'val', ...).
        figsize: Overall figure size.
        save_path: If provided, the figure is also saved here at 300 dpi; parent
            directories are created as needed.
        model_name_map: Optional {model name: legend label}; unmapped models keep their name.
        model_palette: Optional {model name: colour}; other models get colours from
            the seaborn 'deep' palette, then grey.

    Returns:
        None. Prints an error and returns early if there is nothing to plot.
    """
    print(f"\n--- Generating 2x2 Precision/Recall@K Comparison Plot ({split_to_plot} Set) ---")

    # --- 1. Aggregate Mean/Std for P@K and R@K for each model ---
    # aggregated_data[model][element][metric] = (mean, std)
    aggregated_data = defaultdict(lambda: defaultdict(dict))
    models_found = list(results_dict.keys())
    metrics_to_agg = [f'{p}@{k}' for k in k_list_to_plot for p in ['Precision', 'Recall']]

    for model_name, results_list in results_dict.items():
        if not results_list:
            continue
        full_df = pd.concat(results_list, ignore_index=True)
        if 'Split' not in full_df.columns or 'Element' not in full_df.columns:
            print(f"Warning: Results for model '{model_name}' have no 'Split'/'Element' columns; skipping.")
            continue
        split_df = full_df[full_df['Split'] == split_to_plot]
        if split_df.empty:
            continue

        metrics_present = [m for m in metrics_to_agg if m in split_df.columns]
        if not metrics_present:
            continue

        try:
            grouped = split_df.groupby('Element')[metrics_present]
            means = grouped.mean()
            stds = grouped.std() # NaN with a single run -> drawn as 0

            for element in means.index: # e.g. 'Node (member)', 'Node (provider)'
                element_stats = {}
                for metric in metrics_present:
                    mean_val = means.loc[element, metric]
                    std_val = stds.loc[element, metric]
                    element_stats[metric] = (
                        mean_val if pd.notna(mean_val) else 0,
                        std_val if pd.notna(std_val) else 0
                    )
                aggregated_data[model_name][element] = element_stats
        except Exception as e:
            print(f"Warning: Could not aggregate P@K/R@K metrics for model '{model_name}': {e}")

    if not aggregated_data:
        print("Error: No aggregated P@K/R@K data available to plot.")
        return

    # --- 2. Plotting Setup ---
    fig, axes = plt.subplots(2, 2, figsize=figsize, sharex=True, sharey=False)
    fig.suptitle(f'Precision and Recall at Top-K Node Anomalies on {split_to_plot.capitalize()} Set', fontsize=16, y=1.03)

    element_map = {'Node (member)': 0, 'Node (provider)': 1} # element -> grid column
    metric_map = {'Precision': 0, 'Recall': 1}              # metric type -> grid row
    column_to_element = {col: name for name, col in element_map.items()}
    row_to_metric = {row: name for name, row in metric_map.items()}

    # Assign colours once, before drawing: explicit palette first, then defaults, then grey
    default_colors = sns.color_palette("deep", len(models_found))
    if model_palette is None:
        model_palette = {}
    color_assignment = {}
    color_idx = 0
    for model_name in models_found:
        if model_name in model_palette:
            color_assignment[model_name] = model_palette[model_name]
        elif color_idx < len(default_colors):
            color_assignment[model_name] = default_colors[color_idx]
            color_idx += 1
        else:
            color_assignment[model_name] = 'grey'

    if model_name_map is None:
        model_name_map = {model: model for model in models_found}

    line_width = 2.0
    grid_style = {'linestyle': ':', 'alpha': 0.6}

    # --- 3. Create Subplots ---
    plotted_models = set() # Models that drew at least one line (for the legend)

    for model_name in models_found:
        if model_name not in aggregated_data:
            continue

        plot_color = color_assignment.get(model_name, 'grey')
        legend_label = model_name_map.get(model_name, model_name)
        first_plot_for_model = True # Attach the label to one line only

        for element_name, col_idx in element_map.items():
            if element_name not in aggregated_data[model_name]:
                continue
            element_data = aggregated_data[model_name][element_name]

            for metric_type, row_idx in metric_map.items():
                ax = axes[row_idx, col_idx]
                means = [element_data.get(f'{metric_type}@{k}', (0, 0))[0] for k in k_list_to_plot]
                stds = [element_data.get(f'{metric_type}@{k}', (0, 0))[1] for k in k_list_to_plot]

                current_label = legend_label if first_plot_for_model else '_nolegend_'
                ax.errorbar(k_list_to_plot, means, yerr=stds, label=current_label,
                            color=plot_color, linestyle='-',
                            linewidth=line_width, marker='.', markersize=0,
                            capsize=4, elinewidth=1, capthick=1)
                first_plot_for_model = False
                plotted_models.add(model_name)

    # --- 4. Final Touches for Each Axes ---
    max_y_all = 0
    for r in range(2):
        for c in range(2):
            ax = axes[r, c]
            element_name = column_to_element[c]
            metric_type = row_to_metric[r]

            ax.set_title(f"{metric_type} - {element_name} Anomalies")
            ax.grid(True, **grid_style)
            ax.set_xticks(k_list_to_plot)

            # Labels only on the outer axes
            if c == 0: ax.set_ylabel(f"{metric_type} Score")
            if r == 1: ax.set_xlabel("Number of Top Anomalies (k)")

            # Autoscaled top across all panels, used below for one common limit
            max_y_all = max(max_y_all, ax.get_ylim()[1])
            ax.set_ylim(bottom=-0.05)

    # 10% headroom rounded up to the next 0.1, capped at 1.05
    common_ylim_top = min(1.05, math.ceil(max_y_all * 11) / 10.0)
    for r in range(2):
        for c in range(2):
            axes[r, c].set_ylim(top=common_ylim_top)

    # --- 5. Create Shared Legend ---
    from matplotlib.lines import Line2D
    legend_handles = [Line2D([0], [0], color=color_assignment.get(model, 'grey'), lw=line_width,
                             label=model_name_map.get(model, model))
                      for model in models_found if model in plotted_models]

    # Place legend outside the top-right plot
    axes[0, 1].legend(handles=legend_handles, title="Model",
                      bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)

    # --- 6. Layout & Save ---
    plt.tight_layout(rect=[0, 0, 0.88, 0.95])
    if save_path:
        try:
            save_dir = os.path.dirname(save_path)
            if save_dir: os.makedirs(save_dir, exist_ok=True)
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"P@K/R@K comparison plot saved to {save_path}")
        except Exception as e:
            print(f"Error saving P@K/R@K comparison plot: {e}")
    plt.show()


# --- Example usage ---
# Legend label -> list of per-run summary DataFrames for that model.
results_for_plot = {
    "Our Model (BGAE)": all_summary_dfs,
    "Isolation Forest": all_if_summaries,
    "Oddball": all_oddball_summaries,
    "DOMINANT": all_dominant_summaries
}
# Optional {legend label: colour} overrides, e.g. {"Our Model (BGAE)": "tab:blue"}.
# None means every model gets a colour from the default seaborn "deep" palette.
palette_map = None

plot_pk_rk_comparison_grid(
    results_dict=results_for_plot,
    k_list_to_plot=[50, 100, 200, 500], # K values present in the summary columns
    split_to_plot='test',
    model_palette=palette_map,
    # save_path="report_plots/pk_rk_comparison_grid.png",
)

In [None]:
# Display the Oddball summary.
# NOTE(review): `oddball_summary_dfba` is not defined in any visible cell and the `dfba`
# suffix looks like a typo (compare `all_oddball_summaries`); verify before re-running.
oddball_summary_dfba

In [None]:
# Display the collected per-anomaly-type results (structure defined in an earlier cell).
all_anomaly_type_dfs

In [None]:
# NOTE(review): this cell repeats the previous cell's display of `all_anomaly_type_dfs`;
# consider deleting one of them.
all_anomaly_type_dfs

In [None]:
import pandas as pd
import numpy as np
from typing import Dict, Tuple, Optional, List
from collections import defaultdict # Added for scatter plot data prep

# --- Assumed: Helper function from previous steps ---
def extract_main_category(tag):
    """
    Map a raw anomaly tag to its coarse category label.

    Rules, applied in order:
      * missing (NA) or non-string tags -> 'Unknown'
      * strings starting with "('" (stringified edge-type tuples) -> 'Edge'
      * 'prefix/...' tags -> the prefix, with 'structural' / 'attribute' normalised
        to 'Structural' / 'Attribute' (case-insensitive); other prefixes verbatim
      * exactly 'combined' / 'unknown' (case-insensitive) -> 'Combined' / 'Unknown'
      * otherwise the first of 'attribute', 'structural', 'combined' contained in
        the tag (case-insensitive), capitalised; 'Other' if none matches
    """
    if pd.isna(tag) or not isinstance(tag, str):
        return 'Unknown'

    if tag.startswith("('"):
        return 'Edge'

    lowered = tag.lower()

    if '/' in tag:
        prefix = tag.split('/')[0]
        canonical = {'structural': 'Structural', 'attribute': 'Attribute'}
        return canonical.get(prefix.lower(), prefix)

    if lowered in ('combined', 'unknown'):
        return lowered.capitalize()

    for keyword in ('attribute', 'structural', 'combined'):
        if keyword in lowered:
            return keyword.capitalize()
    return 'Other'

def format_mean_std(mean_val, std_val, precision=3):
    """
    Render a mean and standard deviation as the string 'mean ± std'.

    Both inputs are coerced with ``pd.to_numeric(errors='coerce')``. If the mean is
    not numeric the result is "N/A"; if only the std is missing it is shown as 'nan'.
    ``precision`` sets the number of decimals for both numbers.
    """
    mean_num = pd.to_numeric(mean_val, errors='coerce')
    std_num = pd.to_numeric(std_val, errors='coerce')

    if pd.isna(mean_num):
        return "N/A"

    std_text = f"{std_num:.{precision}f}" if pd.notna(std_num) else "nan"
    return f"{mean_num:.{precision}f} ± {std_text}"

# --- Main Analysis Function ---

def generate_anomaly_type_report_tables(
    full_anomaly_type_results_df: pd.DataFrame, # DataFrame with results from ALL runs
    split_name: str = 'test',
    sort_metric: str = 'AP', # Metric for ranking detailed tags
    metrics_to_show: List[str] = ['AUROC', 'AP', 'Best F1'], # Metrics in tables (read-only)
    n_top_bottom: int = 3 # Number of top/bottom tags to show per node type
    ) -> Tuple[Optional[pd.DataFrame], Optional[pd.DataFrame], Optional[pd.DataFrame]]:
    """
    Generates DataFrames for report tables summarizing performance by anomaly type.

    Each input row holds the metrics of one anomaly tag in one run ('seed'). Rows of the
    requested split are aggregated per tag across runs (mean / std) and then summarised
    per (Main Category, Node Type). 'Node Type' is the element identifier used throughout.

    Args:
        full_anomaly_type_results_df (pd.DataFrame): DataFrame with per-tag metrics from ALL iterations.
            Must include 'Split', 'Node Type', 'Anomaly Tag', 'Count', 'seed' and every column in
            `metrics_to_show`. 'Mean Score' and 'Median Score' are aggregated when present.
            'Proportion (%)' is used as-is when present; otherwise it is derived from 'Count' as
            each tag's share of its node type's anomalies within the same run.
        split_name (str): The split to analyze ('train', 'val', or 'test').
        sort_metric (str): Metric used to rank specific tags; must be in `metrics_to_show`.
        metrics_to_show (List[str]): List of performance metric columns for summaries.
        n_top_bottom (int): Number of top and bottom performing tags to show per node type.

    Returns:
        Tuple[Optional[pd.DataFrame], Optional[pd.DataFrame], Optional[pd.DataFrame]]:
            - df_category_summary: Numeric summary indexed by (Main Category, Node Type) with
              'Total Anomalies', 'Category Proportion (%)' and the averaged '<metric>_mean' columns.
            - df_top_bottom_tags: Per-tag aggregates (numeric '<metric>_mean' / '<metric>_std')
              for the top/bottom tags of each node type.
            - df_for_scatter: Per-tag data for the Score vs Performance scatter plot.
            Any element is None if its step fails; all three are None if the input is empty,
            the split has no rows, required columns are missing or the aggregation fails.
    """
    print(f"\n--- Generating Anomaly Type Report Tables for Split: '{split_name}' ---")

    if full_anomaly_type_results_df is None or full_anomaly_type_results_df.empty:
        print("Input DataFrame is empty.")
        return None, None, None

    # --- 1. Preprocessing & Filtering ---
    df = full_anomaly_type_results_df[full_anomaly_type_results_df['Split'] == split_name].copy()
    if df.empty:
        print(f"No data found for split '{split_name}'.")
        return None, None, None

    required_cols = ['Node Type', 'Anomaly Tag', 'Count', 'seed'] + list(metrics_to_show)
    missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        print(f"Error: Missing required columns in DataFrame: {missing_cols}")
        return None, None, None

    # Add Main Category
    df['Main Category'] = df['Anomaly Tag'].apply(extract_main_category)

    # Derive each tag's share (%) of its node type's anomalies *within the same run* when the
    # input does not provide it. Normalising per (seed, Node Type) keeps every run summing to
    # 100%, so the across-run mean below is a true average proportion. Without this column the
    # category table (which sums 'Avg Proportion (%)') could not be built.
    if 'Proportion (%)' not in df.columns:
        run_totals = df.groupby(['seed', 'Node Type'])['Count'].transform('sum')
        df['Proportion (%)'] = (df['Count'] / run_totals * 100).where(run_totals > 0, 0.0)

    # --- 2. Aggregate Mean & Std per Tag/Node Type ---
    # Group by everything except the run seed to get per-tag aggregates across runs.
    grouping_cols = ['Split', 'Node Type', 'Anomaly Tag', 'Main Category']
    cols_to_agg = list(metrics_to_show) + ['Mean Score', 'Median Score', 'Count', 'Proportion (%)']
    cols_present = [c for c in cols_to_agg if c in df.columns]

    try:
        agg_funcs = {col: ['mean', 'std'] for col in cols_present}
        agg_funcs['seed'] = ['nunique'] # Number of runs contributing to each tag
        df_agg = df.groupby(grouping_cols, observed=True).agg(agg_funcs)

        # Flatten MultiIndex columns: ('AP', 'mean') -> 'AP_mean'
        df_agg.columns = ['_'.join(map(str, col)).strip('_') for col in df_agg.columns.values]
        df_agg = df_agg.rename(columns={'seed_nunique': 'n_runs',
                                        'Count_mean': 'Avg Count',
                                        'Proportion (%)_mean': 'Avg Proportion (%)'})
        df_agg = df_agg.reset_index() # Make grouping cols regular columns
    except Exception as e:
        print(f"Error during initial aggregation: {e}")
        return None, None, None

    # --- 3. Table 1: Summary by Main Category & Node Type ---
    print("\n--- Generating Table 1: Summary by Main Category & Node Type ---")
    df_category_summary = None
    try:
        category_group = df_agg.groupby(['Main Category', 'Node Type'], observed=True)

        # Sum per-tag average counts and proportions within each category
        summary_counts = category_group[['Avg Count', 'Avg Proportion (%)']].sum()
        summary_counts = summary_counts.rename(columns={'Avg Count': 'Total Anomalies',
                                                        'Avg Proportion (%)': 'Category Proportion (%)'})

        # Average the per-tag performance means within each category
        metric_mean_cols = [f'{m}_mean' for m in metrics_to_show if f'{m}_mean' in df_agg.columns]
        if metric_mean_cols:
            summary_metrics_mean = category_group[metric_mean_cols].mean()
        else:
            summary_metrics_mean = pd.DataFrame(index=summary_counts.index)

        df_category_summary = summary_counts.join(summary_metrics_mean)

        # Format metrics as Mean ± Std for the display table
        df_category_summary_display = df_category_summary.copy()
        for metric in metrics_to_show:
            mean_col = f'{metric}_mean'
            std_col = f'{metric}_std'
            if mean_col not in df_category_summary_display.columns:
                df_category_summary_display[metric] = "N/A"
                continue
            # Mean of the per-tag std devs: an *indicative* variability, not a pooled std.
            # Only the std column is averaged; averaging the whole group would also touch the
            # non-numeric columns ('Split', 'Anomaly Tag'), which fails on recent pandas.
            if std_col in df_agg.columns:
                std_mean = category_group[std_col].mean()
            else:
                std_mean = pd.Series(np.nan, index=df_category_summary_display.index, dtype=float)
            df_category_summary_display[metric] = [
                format_mean_std(df_category_summary_display.loc[idx, mean_col], std_mean.get(idx, np.nan))
                for idx in df_category_summary_display.index
            ]
            df_category_summary_display = df_category_summary_display.drop(columns=[mean_col])

        # Reorder and select columns for display table
        final_cols_cat = ['Total Anomalies', 'Category Proportion (%)'] + list(metrics_to_show)
        df_category_summary_display = df_category_summary_display.reindex(columns=final_cols_cat, fill_value='N/A').round(3)
        print(df_category_summary_display.to_string())

        # Keep numerical version for later use
        df_category_summary = df_category_summary.round(4)
    except Exception as e:
        print(f"Error creating category summary table: {e}")
        df_category_summary = None

    # --- 4. Table 2: Top & Bottom Performing Specific Tags ---
    print(f"\n--- Generating Table 2: Top/Bottom {n_top_bottom} Tags by {sort_metric} ---")
    df_top_bottom_tags = None
    try:
        sort_col_mean = f'{sort_metric}_mean'
        if sort_metric not in metrics_to_show:
            print(f"Warning: Sort metric '{sort_metric}' not in metrics_to_show list. Cannot sort detailed table.")
        elif sort_col_mean not in df_agg.columns:
            print(f"Warning: Mean column for sort metric '{sort_metric}' ({sort_col_mean}) not found.")
        else:
            top_bottom_list = []
            # Rank tags separately within each node type
            for node_type in df_agg['Node Type'].unique():
                df_node = df_agg[df_agg['Node Type'] == node_type]
                ranked = df_node.sort_values(by=sort_col_mean, ascending=False)
                top_bottom_list.extend([ranked.head(n_top_bottom), ranked.tail(n_top_bottom)])

            if not top_bottom_list:
                print("No data found for top/bottom tags.")
            else:
                # head/tail overlap for small groups; drop the duplicated rows
                df_top_bottom_tags_agg = pd.concat(top_bottom_list).drop_duplicates().sort_values(
                    by=['Node Type', sort_col_mean], ascending=[True, False]
                )

                # Format metrics as Mean ± Std for display
                df_top_bottom_tags_display = df_top_bottom_tags_agg.copy()
                cols_detailed_display = ['Node Type', 'Main Category', 'Anomaly Tag', 'Avg Count', 'Avg Proportion (%)']
                for metric in metrics_to_show:
                    mean_col = f'{metric}_mean'
                    std_col = f'{metric}_std'
                    if mean_col in df_top_bottom_tags_display.columns:
                        df_top_bottom_tags_display[metric] = df_top_bottom_tags_display.apply(
                            lambda row: format_mean_std(row.get(mean_col), row.get(std_col)), axis=1
                        )
                        df_top_bottom_tags_display = df_top_bottom_tags_display.drop(
                            columns=[mean_col, std_col], errors='ignore')
                    else:
                        df_top_bottom_tags_display[metric] = "N/A"
                    cols_detailed_display.append(metric)

                df_top_bottom_tags_display = df_top_bottom_tags_display[cols_detailed_display].round(3)
                print(df_top_bottom_tags_display.to_string(index=False))
                # Store numerical version
                df_top_bottom_tags = df_top_bottom_tags_agg
    except Exception as e:
        print(f"Error creating top/bottom tags table: {e}")

    # --- 5. Prepare Data for Scatter Plot (Optional: Figure 2) ---
    print("\n--- Preparing Data for Score vs Performance Scatter Plot ---")
    df_for_scatter = None
    try:
        scatter_cols = ['Node Type', 'Main Category', 'Anomaly Tag', 'Avg Count', 'Avg Proportion (%)']
        metrics_for_scatter = ['AP', 'Best F1', 'Mean Score', 'Median Score'] # Candidate axes
        scatter_metrics_present = [m for m in metrics_for_scatter if f'{m}_mean' in df_agg.columns]
        scatter_cols += [f'{m}_mean' for m in scatter_metrics_present]

        missing = [c for c in scatter_cols if c not in df_agg.columns]
        if missing:
            print(f"Could not prepare scatter plot data: Missing columns {missing}")
        else:
            # Rename '<metric>_mean' columns to the bare metric name for plotting
            df_for_scatter = df_agg[scatter_cols].rename(
                columns={f'{m}_mean': m for m in scatter_metrics_present})
            print(f"Scatter plot data prepared with {len(df_for_scatter)} unique tags.")
    except Exception as e:
        print(f"Error preparing data for scatter plot: {e}")

    return df_category_summary, df_top_bottom_tags, df_for_scatter

# --- Example Usage ---
# Build the report tables from `final_anomaly_type_df`: the per-tag results of every
# iteration concatenated together (output of the previous aggregation step).
# The guard keeps this cell from failing when that frame has not been built yet.

if 'final_anomaly_type_df' not in locals() or final_anomaly_type_df.empty:
    print("Input DataFrame `final_anomaly_type_df` is missing or empty.")
else:
    metrics = ['AUROC', 'AP', 'Best F1']  # primary performance metrics shown in the tables
    sort_display_metric = 'AP'            # ranks the detailed top/bottom tag table

    cat_summary, top_bottom_summary, scatter_data = generate_anomaly_type_report_tables(
        full_anomaly_type_results_df=final_anomaly_type_df,
        metrics_to_show=metrics,
        sort_metric=sort_display_metric,
        split_name='test',
        n_top_bottom=3,
    )
    # cat_summary / top_bottom_summary feed the LaTeX tables;
    # scatter_data feeds the optional Score vs Performance scatter plot.

In [422]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, Tuple, Optional, List
from collections import defaultdict
import os
import math

# Assume sns.set_theme() is called outside

def extract_main_category(tag):
    """
    Map a raw anomaly tag to its coarse category label.

    * Missing (NA) or non-string tags -> 'Unknown'.
    * 'prefix/...' tags -> the prefix; 'structural' / 'attribute' (any case) are
      normalised to 'Structural' / 'Attribute', any other prefix is returned verbatim.
    * Exactly 'combined' / 'unknown' (any case) -> 'Combined' / 'Unknown'.
    * Otherwise the first keyword of 'attribute', 'structural', 'combined' contained
      in the tag (any case), capitalised; 'Other' if none matches.

    Stringified edge-type tuples get no special handling here.
    """
    if pd.isna(tag):
        return 'Unknown'
    if not isinstance(tag, str):
        return 'Unknown'

    lowered = tag.lower()

    if '/' in tag:
        prefix = tag.split('/')[0]
        return {'structural': 'Structural', 'attribute': 'Attribute'}.get(prefix.lower(), prefix)

    exact_matches = {'combined': 'Combined', 'unknown': 'Unknown'}
    if lowered in exact_matches:
        return exact_matches[lowered]

    return next(
        (keyword.capitalize() for keyword in ('attribute', 'structural', 'combined') if keyword in lowered),
        'Other',
    )

def format_mean_std(mean_val, std_val, precision=3):
    """
    Format a (mean, std) pair as 'mean ± std' with `precision` decimals.

    Inputs are coerced via ``pd.to_numeric(errors='coerce')``; a non-numeric mean
    yields "N/A", and a non-numeric std is rendered as 'nan'.
    """
    def _fixed(value):
        return f"{value:.{precision}f}"

    mean_number = pd.to_numeric(mean_val, errors='coerce')
    std_number = pd.to_numeric(std_val, errors='coerce')
    if not pd.notna(mean_number):
        return "N/A"
    return f"{_fixed(mean_number)} ± {_fixed(std_number) if pd.notna(std_number) else 'nan'}"

def generate_anomaly_type_report_tables(
    full_anomaly_type_results_df: pd.DataFrame,
    split_name: str = 'test',
    sort_metric: str = 'AP',
    metrics_to_show: List[str] = ['AUROC', 'AP', 'Best F1'],
    n_top_bottom: int = 3
    ) -> Tuple[Optional[pd.DataFrame], Optional[pd.DataFrame], Optional[pd.DataFrame]]:
    """
    Generates DataFrames for report tables summarizing performance by anomaly type,
    assuming input df uses 'Node Type' for element identification.

    Each input row holds the metrics of one anomaly tag in one run ('seed'). Rows of the
    requested split are aggregated per tag across runs (mean / std) and then summarised
    per (Main Category, Node Type).

    Args:
        full_anomaly_type_results_df (pd.DataFrame): DataFrame with per-tag metrics from ALL iterations.
            Must include 'Split', 'Node Type', 'Anomaly Tag', 'Count', 'seed' and every column
            in `metrics_to_show`; 'Mean Score' / 'Median Score' are aggregated when present.
        split_name (str): The split to analyze.
        sort_metric (str): Metric used to rank specific tags.
        metrics_to_show (List[str]): Performance metrics for summaries (not mutated).
        n_top_bottom (int): Number of top/bottom tags to show per node type.

    Returns:
        Tuple[Optional[pd.DataFrame], Optional[pd.DataFrame], Optional[pd.DataFrame]]:
            - df_category_summary: Numeric summary indexed by (Main Category, Node Type):
              'Total Anomalies' (sum of per-tag average counts), 'Category Proportion (%)'
              (sum of per-tag average per-run proportions) and the average of each
              '<metric>_mean'. None if it could not be built.
            - df_top_bottom_tags: Display table (metrics as 'mean ± std') of the top/bottom
              `n_top_bottom` tags per node type by `sort_metric`; empty DataFrame if it could
              not be built.
            - df_for_scatter: Per-tag data for the Score vs Performance scatter plot with
              columns 'Node Type', 'Main Category', 'Anomaly Tag', 'Count', 'Proportion (%)'
              plus whichever of 'AP', 'Best F1', 'Mean Score', 'Median Score' were aggregated.
              None if it could not be built.
            All three are None when the input is empty, the split has no rows, required
            columns are missing or the initial aggregation fails.
    """
    print(f"\n--- Generating Anomaly Type Report Tables for Split: '{split_name}' ---")

    if full_anomaly_type_results_df is None or full_anomaly_type_results_df.empty:
        print("Input DataFrame is empty.")
        return None, None, None

    # --- 1. Preprocessing & Filtering ---
    df_split = full_anomaly_type_results_df[full_anomaly_type_results_df['Split'] == split_name].copy()
    if df_split.empty:
        print(f"No data found for split '{split_name}'.")
        return None, None, None

    # 'Mean Score' / 'Median Score' are optional and only aggregated when present.
    score_stat_cols = [c for c in ['Mean Score', 'Median Score'] if c in df_split.columns]
    required_cols = ['Node Type', 'Anomaly Tag', 'Count', 'seed'] + list(metrics_to_show)
    missing_cols = [col for col in required_cols if col not in df_split.columns]
    if missing_cols:
        print(f"Error: Missing required columns in DataFrame: {missing_cols}")
        return None, None, None

    # Add Main Category
    df_split['Main Category'] = df_split['Anomaly Tag'].apply(extract_main_category)

    # Share (%) of each tag among all anomalies of its node type *within the same run*.
    # Normalising per (seed, Node Type) rather than over all runs pooled keeps every run
    # summing to 100%. The across-run mean below is then a true average proportion
    # instead of one diluted by the number of runs.
    run_totals = df_split.groupby(['seed', 'Node Type'])['Count'].transform('sum')
    df_split['Proportion (%)'] = (df_split['Count'] / run_totals * 100).where(run_totals > 0, 0.0)

    # --- 2. Aggregate Mean & Std per Tag/Node Type ---
    grouping_cols = ['Split', 'Node Type', 'Anomaly Tag', 'Main Category']
    cols_to_agg = list(metrics_to_show) + score_stat_cols + ['Count', 'Proportion (%)']

    try:
        agg_funcs = {col: ['mean', 'std'] for col in cols_to_agg}
        agg_funcs['seed'] = ['nunique']  # number of runs contributing to each tag
        df_agg = df_split.groupby(grouping_cols, observed=True).agg(agg_funcs)
        # Flatten ('AP', 'mean') -> 'AP_mean'
        df_agg.columns = ['_'.join(map(str, col)).strip('_') for col in df_agg.columns.values]
        df_agg = df_agg.rename(columns={'seed_nunique': 'n_runs',
                                        'Count_mean': 'Avg Count',
                                        'Proportion (%)_mean': 'Avg Proportion (%)'})
        df_agg = df_agg.reset_index()
    except Exception as e:
        print(f"Error during initial aggregation: {e}")
        return None, None, None

    # --- 3. Table 1: Summary by Main Category & Node Type ---
    print("\n--- Generating Table 1: Summary by Main Category & Node Type ---")
    df_category_summary = None  # numerical version (returned)

    if df_agg.empty:
        print("Skipping Table 1: Initial aggregated DataFrame (df_agg) is missing or empty.")
    else:
        try:
            metric_mean_cols = [f'{m}_mean' for m in metrics_to_show if f'{m}_mean' in df_agg.columns]
            metric_std_cols = [f'{m}_std' for m in metrics_to_show if f'{m}_std' in df_agg.columns]
            count_prop_cols = [c for c in ['Avg Count', 'Avg Proportion (%)'] if c in df_agg.columns]

            # Coerce to numeric so group sums/means never touch object values.
            df_agg_numeric = df_agg.copy()
            for col in metric_mean_cols + metric_std_cols + count_prop_cols:
                df_agg_numeric[col] = pd.to_numeric(df_agg_numeric[col], errors='coerce')

            category_group = df_agg_numeric.groupby(['Main Category', 'Node Type'], observed=True)

            summary_counts = category_group[count_prop_cols].sum().rename(
                columns={'Avg Count': 'Total Anomalies', 'Avg Proportion (%)': 'Category Proportion (%)'})

            # Group results share summary_counts' (Main Category, Node Type) index; the
            # empty fallbacks reuse it so joins/lookups stay aligned.
            if metric_mean_cols:
                summary_metrics_mean = category_group[metric_mean_cols].mean()
            else:
                summary_metrics_mean = pd.DataFrame(index=summary_counts.index)

            # Mean of per-tag std devs: an *indicative* run-to-run variability, not a pooled std.
            if metric_std_cols:
                summary_metrics_std_mean = category_group[metric_std_cols].mean()
            else:
                summary_metrics_std_mean = pd.DataFrame(index=summary_counts.index)

            df_category_summary_num = summary_counts.join(summary_metrics_mean)

            # --- Format for Display ---
            df_category_summary_display = df_category_summary_num.copy()
            for metric in metrics_to_show:
                mean_col = f'{metric}_mean'
                std_series = summary_metrics_std_mean.get(f'{metric}_std')  # None if absent
                if mean_col not in df_category_summary_display.columns:
                    if metric not in df_category_summary_display.columns:
                        df_category_summary_display[metric] = "N/A"
                    continue
                df_category_summary_display[metric] = [
                    format_mean_std(
                        df_category_summary_display.loc[idx, mean_col],
                        std_series.get(idx, np.nan) if std_series is not None else np.nan,
                    )
                    for idx in df_category_summary_display.index
                ]
                df_category_summary_display = df_category_summary_display.drop(columns=[mean_col])

            final_cols_cat = ['Total Anomalies', 'Category Proportion (%)'] + list(metrics_to_show)
            final_cols_cat_present = [c for c in final_cols_cat if c in df_category_summary_display.columns]
            df_category_summary_display = df_category_summary_display[final_cols_cat_present].round(3)
            print(df_category_summary_display.to_string())

            df_category_summary = df_category_summary_num.round(4)
        except Exception as e:
            print(f"Error creating category summary table: {e}")

    # --- 4. Table 2: Top & Bottom Performing Specific Tags ---
    print(f"\n--- Generating Table 2: Top/Bottom {n_top_bottom} Tags by {sort_metric} ---")
    df_top_bottom_tags_display = pd.DataFrame()  # returned; stays empty unless fully built
    try:
        sort_col_mean = f'{sort_metric}_mean'
        if sort_col_mean not in df_agg.columns:
            print(f"Warning: Mean column for sort metric '{sort_metric}' ({sort_col_mean}) not found.")
        else:
            top_bottom_list = []
            # Rank tags separately within each node type
            for node_type in df_agg['Node Type'].unique():
                df_node = df_agg[df_agg['Node Type'] == node_type]
                ranked = df_node.sort_values(by=sort_col_mean, ascending=False)
                top_bottom_list.extend([ranked.head(n_top_bottom), ranked.tail(n_top_bottom)])

            if not top_bottom_list:
                print("No data found for top/bottom tags.")
            else:
                # head/tail overlap for small groups; drop the duplicated rows
                top_bottom_agg = pd.concat(top_bottom_list).drop_duplicates().sort_values(
                    by=['Node Type', sort_col_mean], ascending=[True, False])

                display_table = top_bottom_agg.copy()
                cols_detailed_display = ['Node Type', 'Main Category', 'Anomaly Tag', 'Avg Count', 'Avg Proportion (%)']
                for metric in metrics_to_show:
                    mean_col = f'{metric}_mean'
                    std_col = f'{metric}_std'
                    if mean_col in display_table.columns:
                        display_table[metric] = display_table.apply(
                            lambda row: format_mean_std(row.get(mean_col), row.get(std_col)), axis=1
                        )
                        display_table = display_table.drop(columns=[mean_col, std_col], errors='ignore')
                    else:
                        display_table[metric] = "N/A"
                    cols_detailed_display.append(metric)

                display_table = display_table[cols_detailed_display].round(3)
                print(display_table.to_string(index=False))
                df_top_bottom_tags_display = display_table
    except Exception as e:
        print(f"Error creating top/bottom tags table: {e}")

    # --- 5. Prepare Data for Scatter Plot ---
    print("\n--- Preparing Data for Score vs Performance Scatter Plot ---")
    df_for_scatter = None
    try:
        base_cols = ['Node Type', 'Main Category', 'Anomaly Tag', 'Avg Count', 'Avg Proportion (%)']
        # Aggregated metric columns always carry the '_mean' suffix in df_agg.
        value_cols = [c for c in ['AP_mean', 'Best F1_mean', 'Mean Score_mean', 'Median Score_mean']
                      if c in df_agg.columns]

        missing = [c for c in base_cols if c not in df_agg.columns]
        if missing:
            print(f"Could not prepare scatter plot data: Missing columns {missing}")
        else:
            rename_map = {c: c[:-len('_mean')] for c in value_cols}
            rename_map.update({'Avg Count': 'Count', 'Avg Proportion (%)': 'Proportion (%)'})  # simpler plot names
            df_for_scatter = df_agg[base_cols + value_cols].rename(columns=rename_map)
            print(f"Scatter plot data prepared with {len(df_for_scatter)} unique tags.")
    except Exception as e:
        print(f"Error preparing data for scatter plot: {e}")

    return df_category_summary, df_top_bottom_tags_display, df_for_scatter # Return display version of top/bottom

In [None]:
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns  # used by the category bar plot below; previously only imported in a later cell

# --- Prerequisites ---
# Produced by the multi-iteration run earlier in this notebook:
#   * full_anomaly_type_results_df: per-tag results concatenated over ALL iterations
#     (columns like 'Split', 'Node Type', 'Anomaly Tag', 'Count', 'Mean Score',
#      'Median Score', 'AUROC', 'AP', 'Best F1', 'seed', ...).
#   * all_run_scores: raw scores per iteration, keyed by iteration index.
#   * DATA_SPLITS_ITERATIONS: split data per iteration, keyed by iteration index
#     (expected to provide "gt_node_labels" / "gt_edge_labels").
# `generate_anomaly_type_report_tables` and `plot_score_distributions_split` must be
# defined in earlier cells. The score-distribution plot uses one representative
# iteration, which is usually enough to show the shape of the distributions.

# --- Data Initialization ---
REPRESENTATIVE_ITERATION = 0  # iteration whose raw scores / labels feed the score-distribution plot

final_anomaly_type_df = full_anomaly_type_results_df.copy()
all_scores = all_run_scores.get(REPRESENTATIVE_ITERATION)
# `or {}`: a missing iteration yields None labels (plot skipped) instead of an AttributeError.
representative_split_info = DATA_SPLITS_ITERATIONS.get(REPRESENTATIVE_ITERATION) or {}
gt_node_labels = representative_split_info.get("gt_node_labels")
gt_edge_labels = representative_split_info.get("gt_edge_labels")

# --- Analysis Parameters ---
target_split = 'test' # Analyze the test set results
primary_sort_metric = 'AP' # Rank detailed tags by Average Precision
metrics_to_include = ['AP'] # Metrics to show in tables/plots
top_n_tags = 3 # Show top/bottom 3 specific tags per node type
save_output_dir = "anomaly_type_analysis_results" # Directory to save tables/plots

# --- Run Analysis ---
if not final_anomaly_type_df.empty:
    category_summary_df, top_bottom_tags_df, scatter_data_df = generate_anomaly_type_report_tables(
        full_anomaly_type_results_df=final_anomaly_type_df,
        split_name=target_split,
        sort_metric=primary_sort_metric,
        metrics_to_show=metrics_to_include,
        n_top_bottom=top_n_tags
    )

    # 1. Category comparison plot: aggregated performance per main anomaly category.
    if category_summary_df is not None:
        try:
            print("\n--- Generating Category Performance Plot ---")
            plot_data_cat = category_summary_df.reset_index()
            if 'Node Type' in plot_data_cat.columns and primary_sort_metric in plot_data_cat.columns:
                plt.figure(figsize=(10, 6))
                sns.barplot(data=plot_data_cat, x='Main Category', y=primary_sort_metric, hue='Node Type', palette='viridis', edgecolor='grey')
                plt.title(f'Average {primary_sort_metric} by Main Anomaly Category ({target_split.capitalize()} Set)')
                plt.ylabel(f'Average {primary_sort_metric} (Mean across Runs)')
                plt.xlabel('Main Anomaly Category')
                plt.xticks(rotation=0)
                plt.legend(title='Node Type', bbox_to_anchor=(1.02, 1), loc='upper left')
                plt.grid(True, axis='y', linestyle=':', alpha=0.7)
                plt.tight_layout(rect=[0, 0, 0.88, 0.96])
                if save_output_dir:
                    os.makedirs(save_output_dir, exist_ok=True)
                    plt.savefig(os.path.join(save_output_dir, f"plot_category_compare_{primary_sort_metric}_{target_split}.png"), dpi=300, bbox_inches='tight')
                plt.show()
            else:
                print(f"Skipping category plot: 'Node Type' or '{primary_sort_metric}' column not in category summary.")
        except Exception as e:
            print(f"Could not generate category comparison plot: {e}")

    # 2. Score distributions from the representative iteration's raw scores and labels.
    if all_scores is not None and gt_node_labels is not None and gt_edge_labels is not None:
        try:
            print("\n--- Generating Score Distribution Plot ---")
            plot_score_distributions_split(
                scores_dict=all_scores,
                gt_node_labels_dict=gt_node_labels,
                gt_edge_labels_dict=gt_edge_labels,
                split_name=target_split,
                target_edge_type=('provider', 'to', 'member')
            )
        except NameError:
            print("`plot_score_distributions_split` function not defined. Skipping plot.")
        except Exception as e:
            print(f"Could not generate score distribution plot: {e}")
    else:
        print("Skipping score distribution plot: scores or ground-truth labels for the representative iteration are missing.")

    # --- Save generated tables ---
    if save_output_dir:
        os.makedirs(save_output_dir, exist_ok=True)
        if category_summary_df is not None:
            category_summary_df.to_csv(os.path.join(save_output_dir, f"table_category_summary_{target_split}.csv"))
            print(f"Category summary saved to {save_output_dir}")
        if top_bottom_tags_df is not None:
            top_bottom_tags_df.to_csv(os.path.join(save_output_dir, f"table_top_bottom_tags_{target_split}.csv"), index=False)
            print(f"Top/Bottom tags summary saved to {save_output_dir}")
        if scatter_data_df is not None:
            scatter_data_df.to_csv(os.path.join(save_output_dir, f"data_for_scatter_{target_split}.csv"), index=False)
            print(f"Scatter plot data saved to {save_output_dir}")

else:
    print("Analysis could not be performed: `final_anomaly_type_df` is missing or empty.")

In [None]:
# Quick look at the per-tag results table consumed by the analysis cells
# (presumably concatenated over all iterations — defined in an earlier cell).
full_anomaly_type_results_df

In [None]:
# Inspect the anomaly-tracking entry for member nodes in the test split of iteration 0.
# Each .get() yields None for a missing key, so a missing intermediate level
# surfaces as an AttributeError on the next .get() call.
DATA_SPLITS_ITERATIONS.get(0).get("anomaly_tracking").get("test").get("node").get("member")

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Tuple, Optional, List
import os
from collections import defaultdict

import re
import pandas as pd # Should be imported at the top of your script

# Helper function to format the anomaly tag for plot labels
# This function should be defined somewhere accessible by analyze_anomaly_types_iterations,
# e.g., at the same level as extract_main_category, or as a nested function if preferred.
def _format_tag_for_plot_label(original_tag_str: str, main_category_str: str) -> str:
    """
    Formats an anomaly tag string for better readability in plot labels.
    - 'Category/name_of_anomaly' -> 'Name Of Anomaly'
    - 'Combined' (if main_category is 'Combined') -> 'Combined'
    - Edge tags (if main_category is 'Edge') -> 'Edge'
    - Other simple tags -> 'Original Tag Str' (processed for underscores and capitalization)
    """
    if pd.isna(original_tag_str):
        return "Unknown"

    # Handle special main categories first
    if main_category_str == 'Combined':
        return 'Combined'
    if main_category_str == 'Edge':
        # `extract_main_category` already classified it as 'Edge'
        # The original_tag_str for edges can be complex like "('TypeA', 'rel', 'TypeB')"
        return 'Edge'

    name_to_format = original_tag_str
    # Use regex to remove the "Category/" prefix if one exists.
    # This removes everything up to and including the first '/'
    name_to_format = re.sub(r'^[^/]+/', '', name_to_format, count=1)

    # Replace underscores with spaces
    name_to_format = name_to_format.replace('_', ' ')
    
    # Capitalize the first letter of each word
    formatted_name = ' '.join(word.capitalize() for word in name_to_format.split())

    # Return the formatted name, or the original tag if formatting results in an empty string
    return formatted_name if formatted_name else original_tag_str

# Set a suitable style for report plots
# Global seaborn theme for all subsequent plots in this kernel session.
# The bar plots below pass an explicit `palette=` ('viridis' / 'Set2'), which takes precedence per call.
sns.set_theme(style="whitegrid", palette="viridis", font_scale=1.1) # Using viridis palette

def extract_main_category(tag):
    """
    Map a raw 'Anomaly Tag' value onto its coarse main category.

    Rules, checked in this order:
      * NA or non-string values                    -> 'Unknown'
      * strings starting with "('" (stringified edge-type tuples) -> 'Edge'
      * 'Prefix/rest'                               -> the prefix, with 'structural' /
        'attribute' (any case) normalized to 'Structural' / 'Attribute'; any other
        prefix is returned verbatim
      * exactly 'combined' / 'unknown' (any case)   -> 'Combined' / 'Unknown'
      * otherwise a substring search for 'attribute', 'structural', 'combined'
        (in that priority order), falling back to 'Other'
    """
    if pd.isna(tag):
        return 'Unknown'
    if not isinstance(tag, str):
        return 'Unknown'

    if tag.startswith("('"):
        return 'Edge'

    if '/' in tag:
        prefix = tag.split('/')[0]
        canonical_prefixes = {'structural': 'Structural', 'attribute': 'Attribute'}
        return canonical_prefixes.get(prefix.lower(), prefix)

    lowered = tag.lower()
    exact_labels = {'combined': 'Combined', 'unknown': 'Unknown'}
    if lowered in exact_labels:
        return exact_labels[lowered]

    keyword_priority = (
        ('attribute', 'Attribute'),
        ('structural', 'Structural'),
        ('combined', 'Combined'),
    )
    for keyword, category in keyword_priority:
        if keyword in lowered:
            return category
    return 'Other'


def analyze_anomaly_types_iterations(
    anomaly_type_df_iterations: pd.DataFrame,
    split_name: str = 'test',
    sort_metric: str = 'AP',
    metrics_to_analyze: List[str] = ['AUROC', 'AP', 'Best F1'],
    plot_metric_comparison: str = 'AP',
    top_k_anomalies: int = 10,
    plot_figsize_comparison: Tuple[int, int] = (12, 7),
    plot_figsize_top_k: Tuple[int, int] = (14, 8),
    save_dir: Optional[str] = None
    ) -> Tuple[Optional[pd.DataFrame], Optional[pd.DataFrame]]:
    """
    Analyzes performance breakdown by anomaly type for a specific data split,
    considering multiple experimental iterations (seeds).

    Produces:
      * Table 1 - per (Main Category, Node Type): mean over seeds of the total anomaly
        count and of the category's proportion within its node type, plus mean & std
        (over all tag/seed rows in the group) of each metric in `metrics_to_analyze`.
      * Table 2 - per (Node Type, Main Category, Anomaly Tag): mean count/proportion and
        mean & std of each metric, plus 'Mean Score' / 'Median Score' when those columns
        exist (listed once even if they are also in `metrics_to_analyze`). Sorted by
        `<sort_metric>_mean` descending within node type and category.
      * Plot 1 - bar plot of `plot_metric_comparison` by Main Category and Node Type
        (bar = mean, error bar = std over the underlying tag/seed rows).
      * Plot 2 - top-k anomaly tags by 'AP_mean', colored by Main Category.

    Args:
        anomaly_type_df_iterations (pd.DataFrame): DataFrame with per-tag metrics vs normals
                                        (expected columns: 'Split', 'Node Type', 'Anomaly Tag',
                                        'Count', 'Mean Score', 'Median Score', 'AUROC', 'AP', 'Best F1', 'seed').
        split_name (str): The split to analyze ('train', 'val', or 'test').
        sort_metric (str): Metric used to sort the detailed tag results (e.g., 'AP').
                           The table will be sorted by the mean of this metric.
        metrics_to_analyze (List[str]): List of performance metric columns for summaries.
        plot_metric_comparison (str): The metric to visualize in the category comparison bar plot (plots its mean).
        top_k_anomalies (int): Number of top anomaly types to display in the top-k bar plot.
        plot_figsize_comparison: Figure size for the metric comparison plot.
        plot_figsize_top_k: Figure size for the top-k anomaly types plot.
        save_dir (Optional[str]): Directory to save the output tables and plots (created if missing).

    Returns:
        Tuple[Optional[pd.DataFrame], Optional[pd.DataFrame]]:
            - df_category_summary: Aggregated metrics by Main Category and Node Type.
            - df_detailed_sorted: Detailed metrics per Anomaly Tag, sorted.
            (Returns None, None if the input is empty, the split has no rows, or
            required columns are missing; either element is None if its table fails.)
    """
    print(f"\n--- Analyzing Anomaly Type Performance (with Iterations) for Split: '{split_name}' ---")

    if anomaly_type_df_iterations is None or anomaly_type_df_iterations.empty:
        print("Input DataFrame `anomaly_type_df_iterations` is empty. Cannot perform analysis.")
        return None, None

    # --- 1. Preprocessing ---
    df_split = anomaly_type_df_iterations[anomaly_type_df_iterations['Split'] == split_name].copy()
    if df_split.empty:
        print(f"No data found for split '{split_name}' in `anomaly_type_df_iterations`.")
        return None, None

    # De-duplicate while keeping caller order (dict.fromkeys is order-preserving, set() is not).
    analysis_metrics = list(dict.fromkeys(metrics_to_analyze))
    required_cols_base = ['Node Type', 'Anomaly Tag', 'Count', 'seed']
    # 'AP' is always needed for the top-k plot.
    metrics_for_analysis_and_plotting = list(dict.fromkeys(analysis_metrics + [plot_metric_comparison, 'AP']))

    missing_cols = [col for col in required_cols_base + metrics_for_analysis_and_plotting if col not in df_split.columns]
    if missing_cols:
        print(f"Error: Missing required columns in DataFrame: {missing_cols}")
        return None, None

    df_split['Main Category'] = df_split['Anomaly Tag'].apply(extract_main_category)

    # Share of each tag's count within its node type, computed separately for every seed.
    df_split['Total Anomalies in Node Type for Seed'] = df_split.groupby(['Node Type', 'seed'])['Count'].transform('sum')
    df_split['Proportion (%)'] = (df_split['Count'] / df_split['Total Anomalies in Node Type for Seed'] * 100).fillna(0)

    # --- 2. Table 1: Performance Summary by Main Category & Node Type ---
    print("\n--- Table 1: Performance Summary by Main Anomaly Category & Node Type ---")
    df_category_summary = None
    try:
        metric_aggs = {metric: ['mean', 'std'] for metric in analysis_metrics}
        df_category_metrics_agg = df_split.groupby(['Main Category', 'Node Type'])[analysis_metrics].agg(metric_aggs)
        df_category_metrics_agg.columns = ['_'.join(col).strip() for col in df_category_metrics_agg.columns.values]

        # Sum per seed first, then average over seeds.
        category_counts_per_seed = df_split.groupby(['Main Category', 'Node Type', 'seed'])['Count'].sum().reset_index()
        category_total_anomalies_agg = category_counts_per_seed.groupby(['Main Category', 'Node Type']).agg(
            Total_Anomalies=('Count', 'mean')
        )

        category_proportions_per_seed = df_split.groupby(['Main Category', 'Node Type', 'seed'])['Proportion (%)'].sum().reset_index()
        category_proportion_agg = category_proportions_per_seed.groupby(['Main Category', 'Node Type']).agg(
            Category_Proportion_Pct=('Proportion (%)', 'mean')
        )

        df_category_summary = pd.concat([category_total_anomalies_agg, category_proportion_agg, df_category_metrics_agg], axis=1).reset_index()

        metric_cols_ordered = []
        for m in analysis_metrics:
            metric_cols_ordered.extend([f'{m}_mean', f'{m}_std'])

        cols_order = ['Main Category', 'Node Type', 'Total_Anomalies', 'Category_Proportion_Pct'] + metric_cols_ordered
        df_category_summary = df_category_summary.reindex(columns=cols_order).round(3)

        # std is NaN for single-row groups; report it as 0.
        for col in df_category_summary.columns:
            if col.endswith("_std"):
                df_category_summary[col] = df_category_summary[col].fillna(0)

        df_category_summary = df_category_summary.rename(columns={
            'Total_Anomalies': 'Total Anomalies (Avg)',
            'Category_Proportion_Pct': 'Category Proportion (%) (Avg)',
        })

        print(df_category_summary.to_string(index=False))
    except Exception as e:
        print(f"Error creating category summary table: {e}")
        import traceback
        traceback.print_exc()


    # --- 3. Table 2: Detailed Performance by Specific Anomaly Tag (Sorted) ---
    sort_metric_mean_col_name = f'{sort_metric}_mean'

    print(f"\n--- Table 2: Detailed Performance by Specific Anomaly Tag (Sorted by {sort_metric_mean_col_name}) ---")
    df_detailed_sorted = None
    try:
        # Score statistics are summarized whenever present; skip any already requested as
        # metrics so their _mean/_std columns are not emitted twice.
        score_stats = [col for col in ['Mean Score', 'Median Score']
                       if col in df_split.columns and col not in analysis_metrics]
        detailed_value_cols = analysis_metrics + score_stats

        detailed_aggs = {}
        for m in detailed_value_cols:
            detailed_aggs[f"{m}_mean"] = (m, 'mean')
            detailed_aggs[f"{m}_std"] = (m, 'std')
        detailed_aggs['Count'] = ('Count', 'mean')
        detailed_aggs['Proportion (%)'] = ('Proportion (%)', 'mean')

        df_detailed_agg = df_split.groupby(['Node Type', 'Main Category', 'Anomaly Tag']).agg(**detailed_aggs).reset_index()

        for col in df_detailed_agg.columns:
            if col.endswith("_std"):
                df_detailed_agg[col] = df_detailed_agg[col].fillna(0)

        cols_detailed_ordered = ['Node Type', 'Main Category', 'Anomaly Tag', 'Count', 'Proportion (%)']
        for m in detailed_value_cols:
            cols_detailed_ordered.extend([f'{m}_mean', f'{m}_std'])

        cols_detailed_ordered = [col for col in cols_detailed_ordered if col in df_detailed_agg.columns]
        df_detailed_sorted = df_detailed_agg[cols_detailed_ordered]

        if sort_metric_mean_col_name not in df_detailed_sorted.columns:
            print(f"Error: Sort metric '{sort_metric_mean_col_name}' not found in detailed aggregated DataFrame. Sorting by Node Type, Main Category.")
            df_detailed_sorted = df_detailed_sorted.sort_values(by=['Node Type', 'Main Category'], ascending=[True, True])
        else:
            df_detailed_sorted = df_detailed_sorted.sort_values(
                by=['Node Type', 'Main Category', sort_metric_mean_col_name],
                ascending=[True, True, False]
            )

        df_detailed_sorted = df_detailed_sorted.round(3)
        print(df_detailed_sorted.to_string(index=False, max_rows=50))
    except KeyError as e:
        print(f"KeyError during detailed table creation or sorting (possibly related to '{sort_metric_mean_col_name}'): {e}")
    except Exception as e:
        print(f"Error creating detailed sorted table: {e}")
        import traceback
        traceback.print_exc()

    # --- 4. Plot 1: Metric Comparison by Main Category (with Error Bars for Std Dev) ---
    metric_mean_col = f'{plot_metric_comparison}_mean'
    print(f"\n--- Plot 1: Comparison of Average {plot_metric_comparison} by Main Category (with Std Dev over Iterations) ---")
    if df_category_summary is not None and metric_mean_col in df_category_summary.columns:
        try:
            plt.figure(figsize=plot_figsize_comparison)
            # Plot from the raw per-(tag, seed) rows so seaborn computes mean and std itself.
            ax = sns.barplot(
                data=df_split,
                x='Main Category',
                y=plot_metric_comparison,
                hue='Node Type',
                palette='viridis',
                edgecolor='grey',
                linewidth=0.75,
                estimator=np.mean,
                errorbar='sd'
            )
            plt.title(f'{plot_metric_comparison} by Main Anomaly Category ({split_name.capitalize()} Set)')
            plt.xlabel('Main Anomaly Category')
            # The label follows the plotted metric (it was previously hard-coded to "Average Precision").
            plt.ylabel(f'Mean {plot_metric_comparison} (± std)')
            plt.xticks(rotation=0)
            plt.legend(title='Node Type', bbox_to_anchor=(1.02, 1), loc='upper left')
            ax.grid(True, axis='y', linestyle=':', alpha=0.7)

            for container in ax.containers:
                ax.bar_label(container, fmt='%.2f', label_type='edge', padding=2, fontsize=9)

            plt.ylim(bottom=0)
            plt.tight_layout(rect=[0, 0, 0.88, 0.97])

            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
                plot_path_comp = os.path.join(save_dir, f'plot_category_comparison_iters_{plot_metric_comparison}_{split_name}.png')
                try:
                    plt.savefig(plot_path_comp, dpi=300, bbox_inches='tight')
                    print(f"Comparison plot saved to {plot_path_comp}")
                except Exception as e: print(f"Error saving comparison plot: {e}")
            plt.show()
        except Exception as e:
            print(f"Error generating comparison plot: {e}")
            import traceback
            traceback.print_exc()
    else:
        print(f"Cannot generate comparison plot: Summary data missing or plot metric '{metric_mean_col}' not found.")

    # --- 5. Plot 2: Top-k Best Detected Anomaly Types (by AP_mean, colored by Main Category) ---
    print(f"\n--- Plot 2: Top {top_k_anomalies} Best Detected Anomaly Tags (by AP mean, colored by Main Category) ---")
    if df_detailed_sorted is not None and 'AP_mean' in df_detailed_sorted.columns and top_k_anomalies > 0:
        try:
            # .copy(): head() returns a slice; an owned frame avoids SettingWithCopyWarning.
            top_k_df_for_plot_info = (
                df_detailed_sorted
                .sort_values(by='AP_mean', ascending=False)
                .head(top_k_anomalies)
                .copy()
            )

            if top_k_df_for_plot_info.empty:
                print(f"No data available for top {top_k_anomalies} plot.")
            else:
                top_k_keys = list(zip(top_k_df_for_plot_info['Node Type'], top_k_df_for_plot_info['Anomaly Tag']))

                # One readable label per (Node Type, Anomaly Tag).
                label_by_key = {
                    (node_type, tag): f"{node_type} | {_format_tag_for_plot_label(tag, main_cat)}"
                    for node_type, tag, main_cat in zip(
                        top_k_df_for_plot_info['Node Type'],
                        top_k_df_for_plot_info['Anomaly Tag'],
                        top_k_df_for_plot_info['Main Category'],
                    )
                }
                # Formatting can map distinct tags to the same label (every 'Edge' tag becomes
                # 'Edge'; 'Structural/x' and 'Attribute/x' both become 'X'). Seaborn would merge
                # such bars, so colliding entries fall back to the raw tag.
                all_labels = list(label_by_key.values())
                colliding_labels = {label for label in all_labels if all_labels.count(label) > 1}
                for key in list(label_by_key):
                    if label_by_key[key] in colliding_labels:
                        label_by_key[key] = f"{key[0]} | {key[1]}"
                ordered_tag_identifiers_for_plot = [label_by_key[key] for key in top_k_keys]

                plot_data_top_k = df_split[
                    df_split.set_index(['Node Type', 'Anomaly Tag']).index.isin(top_k_keys)
                ].copy()
                plot_data_top_k['Tag Identifier'] = [
                    label_by_key[key]
                    for key in zip(plot_data_top_k['Node Type'], plot_data_top_k['Anomaly Tag'])
                ]

                plt.figure(figsize=plot_figsize_top_k)
                ax_top_k = sns.barplot(
                    data=plot_data_top_k,
                    x='Tag Identifier',
                    y='AP',
                    hue='Main Category', # Color bars by Main Category
                    order=ordered_tag_identifiers_for_plot,
                    palette='Set2',
                    dodge=False, # one category per x position, so keep bars centered and full width
                    edgecolor='grey',
                    linewidth=0.75,
                    estimator=np.mean,
                    errorbar='sd'
                )
                plt.title(f'Top {top_k_anomalies} Best Detected Anomaly Tags by AP ({split_name.capitalize()} Set)')
                plt.xlabel('Anomaly Tag (Node Type | Tag)')
                plt.ylabel('Average Precision (AP)')
                plt.xticks(rotation=45, ha='right')
                plt.legend(title='Main Category', bbox_to_anchor=(1.02, 1), loc='upper left')
                ax_top_k.grid(True, axis='y', linestyle=':', alpha=0.7)

                for container in ax_top_k.containers:
                    ax_top_k.bar_label(container, fmt='%.2f', label_type='edge', padding=2, fontsize=9)

                plt.ylim(bottom=0)
                plt.tight_layout(rect=[0, 0, 0.85, 0.96]) # Leave room for the legend

                if save_dir:
                    # Plot 1 may have been skipped, so the directory might not exist yet.
                    os.makedirs(save_dir, exist_ok=True)
                    plot_path_top_k = os.path.join(save_dir, f'plot_top_k_anomalies_AP_colored_{split_name}.png')
                    try:
                        plt.savefig(plot_path_top_k, dpi=300, bbox_inches='tight')
                        print(f"Top-k anomalies plot saved to {plot_path_top_k}")
                    except Exception as e: print(f"Error saving top-k anomalies plot: {e}")
                plt.show()
        except Exception as e:
            print(f"Error generating top-k anomalies plot: {e}")
            import traceback
            traceback.print_exc()
    elif top_k_anomalies <= 0:
        print("Skipping top-k anomalies plot as top_k_anomalies is not positive.")
    else:
        print("Cannot generate top-k anomalies plot: Detailed summary data or 'AP_mean' column missing.")


    # --- Save DataFrames ---
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        if df_category_summary is not None:
            cat_path = os.path.join(save_dir, f'summary_category_iters_{split_name}.csv')
            try:
                df_category_summary.to_csv(cat_path, index=False)
                print(f"Category summary table saved to {cat_path}")
            except Exception as e: print(f"Error saving category summary: {e}")
        if df_detailed_sorted is not None:
            det_path = os.path.join(save_dir, f'summary_detailed_tags_iters_{split_name}.csv')
            try:
                df_detailed_sorted.to_csv(det_path, index=False)
                print(f"Detailed summary table saved to {det_path}")
            except Exception as e: print(f"Error saving detailed summary: {e}")

    return df_category_summary, df_detailed_sorted





# Output folder for this iteration-aware version of the analysis.
example_save_dir = 'analysis_results_iterations_v2'

# Test-split breakdown; 'Mean Score' is included as an extra summarized column,
# and the top-k plot shows the 7 best-detected tags.
category_summary_iters, detailed_summary_iters = analyze_anomaly_types_iterations(
    anomaly_type_df_iterations=full_anomaly_type_results_df,
    split_name='test',
    sort_metric='AP',
    metrics_to_analyze=['AUROC', 'AP', 'Best F1', 'Mean Score'],
    plot_metric_comparison='AP',
    top_k_anomalies=7,
    save_dir=example_save_dir,
)

# Echo the returned tables (a heading is always printed; the table only when present).
for _heading, _table, _render_kwargs in (
    ("\n--- Returned Category Summary (Mean for Counts/Props, Mean & Std for Metrics) ---",
     category_summary_iters, {}),
    ("\n--- Returned Detailed Summary (Mean for Counts/Props, Mean & Std for Metrics) ---",
     detailed_summary_iters, {'index': False, 'max_rows': 15}),
):
    print(_heading)
    if _table is not None:
        print(_table.to_string(**_render_kwargs))




In [None]:
# Inspect the per-iteration anomaly-type results collected earlier
# (defined in a cell outside this view — structure not visible here).
all_anomaly_type_dfs

In [None]:
# Re-display the per-tag results table that feeds the analysis cells above.
# NOTE(review): duplicates an earlier inspection cell — consider removing one.
full_anomaly_type_results_df

In [None]:
import pandas as pd
import numpy as np
from typing import Dict, Tuple, Optional, List


def aggregate_and_compare_model_results(
    bgae_results_list: List[pd.DataFrame],
    if_results_list: List[pd.DataFrame],
    oddball_results_list: List[pd.DataFrame],
    dominant_results_list: List[pd.DataFrame],
    metrics_to_compare: List[str] = ['AUROC', 'AP', 'Best F1'],
    k_list_eval: List[int] = [50, 100, 200, 500], # K values used during evaluation
    splits_to_compare: List[str] = ['test'] # Focus on test set by default
    ) -> Optional[pd.DataFrame]:
    """
    Aggregates results from multiple runs for different models and creates
    a comparison DataFrame showing mean ± std dev for specified metrics.

    NOTE(review): this is the FIRST of two definitions of
    ``aggregate_and_compare_model_results`` in this cell. A second definition
    of the same name appears further down the same cell and rebinds the name
    before anything calls this one, so this version is dead code once the
    cell has run. Consider deleting it.

    Pipeline (per model): concatenate the run DataFrames, keep rows whose
    'Split' is in ``splits_to_compare``, drop rows with NaN in any of
    'AUROC'/'AP'/'Best F1' (only when all three columns exist), group by
    ('Split', 'Element') to get mean/std of each metric plus the number of
    distinct 'seed' values, then pivot all models together and format the
    cells as 'mean ± std' strings.

    The input DataFrames are read through the columns 'Split', 'Element' and
    'seed', plus one column per metric.

    Args:
        bgae_results_list: List of summary DataFrames from BGAE runs.
        if_results_list: List of summary DataFrames from Isolation Forest runs.
        oddball_results_list: List of summary DataFrames from OddBall runs.
        dominant_results_list: List of summary DataFrames from DOMINANT runs.
        metrics_to_compare: List of primary performance metrics to show in the table.
        k_list_eval: List of K values used, to include relevant P@K/R@K metrics.
        splits_to_compare: List of data splits to include in the comparison (e.g., ['test'], ['val', 'test']).

    Returns:
        Optional[pd.DataFrame]: A DataFrame comparing models, indexed by Split and Element,
                                 with columns for each model showing 'mean ± std' metrics,
                                 or None if aggregation fails (each failure path prints
                                 a message first). Some steps (the final column
                                 assembly) are not wrapped in try/except and may raise.
    """
    print("\n--- Aggregating and Comparing Model Performances ---")

    all_results_list = [] # NOTE(review): never used in this function
    model_data_map = {
        "BGAE": bgae_results_list,
        "IsolationForest": if_results_list,
        "OddBall": oddball_results_list,
        "DOMINANT": dominant_results_list
    }

    # --- 1. Concatenate and Preprocess Results for Each Model ---
    processed_model_dfs = {}
    metrics_to_check = ['AUROC', 'AP', 'Best F1'] # Ensure these core metrics exist

    for model_name, results_list in model_data_map.items():
        if not results_list:
            print(f"Warning: No results found for model '{model_name}'. Skipping.")
            continue
        try:
            full_df = pd.concat(results_list, ignore_index=True)
            # Filter for relevant splits
            filtered_df = full_df[full_df['Split'].isin(splits_to_compare)].copy()
            # Drop rows where essential metrics are NaN (likely failed runs)
            cols_exist = all(m in filtered_df.columns for m in metrics_to_check)
            if cols_exist:
                 successful_df = filtered_df.dropna(subset=metrics_to_check, how='any')
            else:
                 print(f"Warning: Core metrics missing for {model_name}. Proceeding with available data.")
                 successful_df = filtered_df

            if successful_df.empty:
                 print(f"Warning: No successful runs found for model '{model_name}' in specified splits. Skipping.")
                 continue

            processed_model_dfs[model_name] = successful_df
            print(f"Processed {model_name}: Found {successful_df['seed'].nunique()} successful runs in specified splits.")

        except Exception as e:
            # Best-effort: a failing model is reported and skipped, the others continue.
            print(f"Error processing results for model '{model_name}': {e}")

    if not processed_model_dfs:
        print("Error: No processed results available for any model. Cannot create comparison.")
        return None

    # --- 2. Aggregate Metrics for Each Model ---
    aggregated_data_list = []
    full_metric_list = metrics_to_compare + \
                       [f'{p}@{k}' for k in k_list_eval for p in ['Precision', 'Recall']]

    for model_name, df in processed_model_dfs.items():
        # Ensure metrics to aggregate actually exist in this df
        metrics_present = [m for m in full_metric_list if m in df.columns]
        if not metrics_present:
            print(f"Warning: No specified metrics found for model '{model_name}'. Skipping aggregation.")
            continue

        try:
            # Group by Split and Element, calculate mean and std
            # (.agg(['mean', 'std']) yields two-level columns: (metric, 'mean'/'std')).
            agg_df = df.groupby(['Split', 'Element'])[metrics_present].agg(['mean', 'std'])
            # Add run count
            run_counts = df.groupby(['Split', 'Element'])['seed'].nunique().rename('n_runs')
            agg_df = pd.concat([run_counts, agg_df], axis=1)
            agg_df['Model'] = model_name # Add model identifier
            aggregated_data_list.append(agg_df.reset_index()) # Reset index for easier merging later
        except Exception as e:
            print(f"Error aggregating data for model '{model_name}': {e}")

    if not aggregated_data_list:
        print("Error: Aggregation failed for all models.")
        return None

    # --- 3. Combine and Format the Comparison Table ---
    comparison_df = pd.concat(aggregated_data_list, ignore_index=True)

    # Pivot table for comparison view.
    # No `values=` is passed, so pivot_table aggregates every remaining column with its
    # default aggfunc ('mean'). NOTE(review): the frames above carry multi-level columns
    # from .agg(['mean', 'std']); verify the pivot really produces the
    # (metric, stat, Model) column triples the formatting code below expects.
    try:
        pivot_df = comparison_df.pivot_table(
            index=['Split', 'Element'],
            columns='Model'
            # Values will be selected below
        )
    except Exception as e:
         print(f"Error pivoting comparison table: {e}")
         return None # Cannot proceed if pivot fails


    # Format columns as 'Metric (Mean ± Std)'
    final_comparison_cols = {} # Store formatted columns keyed by original metric name
    n_runs_cols = {} # Store n_runs separately

    # Ensure metrics_present considers columns available across *all* processed models if needed,
    # or handle missing columns per model during formatting. Let's handle per model.
    all_metrics_in_pivot = set(lvl0 for lvl0, lvl1 in pivot_df.columns if lvl1=='mean') & set(metrics_to_compare + [f'{p}@{k}' for k in k_list_eval for p in ['Precision', 'Recall']])


    for metric in all_metrics_in_pivot:
        # NOTE(review): metric_mean_cols / metric_std_cols are computed but not used below.
        metric_mean_cols = [col for col in pivot_df.columns if col[0] == metric and col[1] == 'mean']
        metric_std_cols = [col for col in pivot_df.columns if col[0] == metric and col[1] == 'std']

        # Create the formatted column for this metric
        formatted_series_list = []
        models_in_pivot = pivot_df.columns.get_level_values('Model').unique() # Models actually present

        for model_name in models_in_pivot:
             # Assumes pivot columns are 3-level (metric, stat, Model) -- TODO confirm.
             mean_col = (metric, 'mean', model_name)
             std_col = (metric, 'std', model_name)

             if mean_col in pivot_df.columns and std_col in pivot_df.columns:
                 # Format only if both mean and std are present and not NaN
                 formatted = pivot_df[mean_col].map('{:.4f}'.format).str.cat(
                               pivot_df[std_col].map('{:.4f}'.format), sep=' ± '
                           ).where(pivot_df[mean_col].notna() & pivot_df[std_col].notna(), "N/A") # Handle NaN after agg
             elif mean_col in pivot_df.columns:
                  # Format mean only if std is missing/NaN
                  formatted = pivot_df[mean_col].map('{:.4f}'.format).where(pivot_df[mean_col].notna(), "N/A") + ' ± nan'
             else:
                  formatted = pd.Series("N/A", index=pivot_df.index) # Metric not found for this model

             formatted.name = (metric, model_name) # Assign multi-level name
             formatted_series_list.append(formatted)

        # Combine formatted series for this metric across all models
        if formatted_series_list:
             final_comparison_cols[metric] = pd.concat(formatted_series_list, axis=1)


    # Get n_runs separately
    n_runs_cols_present = [col for col in pivot_df.columns if col[0] == 'n_runs' and col[1] == ''] # n_runs has empty second level
    if n_runs_cols_present:
        n_runs_df = pivot_df[n_runs_cols_present]
        n_runs_df.columns = n_runs_df.columns.droplevel(1) # Drop the empty level ''
        n_runs_cols = {'n_runs': n_runs_df}


    # Assemble the final table
    if not final_comparison_cols:
        print("Error: No metric columns could be formatted.")
        return None

    # Create MultiIndex columns for the final DataFrame
    metric_order = metrics_to_compare + sorted([m for m in all_metrics_in_pivot if m not in metrics_to_compare]) # Order metrics
    models_order = sorted(processed_model_dfs.keys()) # Sort model names alphabetically

    final_df_cols = pd.MultiIndex.from_product([metric_order, models_order], names=['Metric', 'Model'])
    final_comparison_df = pd.DataFrame(index=pivot_df.index, columns=final_df_cols)

    # Fill the DataFrame with formatted values
    for metric in metric_order:
         if metric in final_comparison_cols:
             metric_data = final_comparison_cols[metric]
             # Ensure columns align - reindex metric_data if necessary
             # NOTE(review): `level='Model'` needs a column level named 'Model' on metric_data,
             # but its columns come from Series named by plain (metric, model) tuples, so no
             # level is named -- this may raise, and it is not inside a try/except. Verify.
             final_comparison_df[metric] = metric_data.reindex(columns=final_df_cols.get_level_values('Model').unique(), level='Model')

    # Add n_runs as the first level
    if n_runs_cols:
        # NOTE(review): with a dict as `objs`, pd.concat uses `keys` to SELECT dict entries;
        # 'Info' is not a key of n_runs_cols ({'n_runs': ...}), so this likely raises KeyError.
        final_comparison_df = pd.concat(n_runs_cols, axis=1, keys=['Info']).join(final_comparison_df)
        # Rename 'n_runs' column under 'Info'
        final_comparison_df.rename(columns={'n_runs': 'Runs'}, level=1, inplace=True)


    print("\n--- Final Model Comparison (Mean ± Std Dev) ---")
    print(final_comparison_df.to_string())

    return final_comparison_df


import pandas as pd
import numpy as np
from typing import Dict, Tuple, Optional, List

# Assume compute_evaluation_metrics is defined elsewhere or imported

def aggregate_and_compare_model_results(
    bgae_results_list: List[pd.DataFrame],
    if_results_list: List[pd.DataFrame],
    oddball_results_list: List[pd.DataFrame],
    dominant_results_list: List[pd.DataFrame],
    metrics_to_compare: List[str] = ['AUROC', 'AP', 'Best F1'],
    k_list_eval: List[int] = [50, 100, 200, 500], # K values used during evaluation
    splits_to_compare: List[str] = ['test'], # Focus on test set by default
    sort_elements_by: Optional[List[str]] = None # Optional order for elements in index
    ) -> Optional[pd.DataFrame]:
    """
    Aggregates results from multiple runs for different models and creates
    a comparison DataFrame showing mean ± std dev for specified metrics.

    Args:
        bgae_results_list: List of summary DataFrames from BGAE runs.
        if_results_list: List of summary DataFrames from Isolation Forest runs.
        oddball_results_list: List of summary DataFrames from OddBall runs.
        dominant_results_list: List of summary DataFrames from DOMINANT runs.
        metrics_to_compare: List of primary performance metrics to show in the table.
        k_list_eval: List of K values used, to include relevant P@K/R@K metrics.
        splits_to_compare: List of data splits to include in the comparison (e.g., ['test'], ['val', 'test']).
        sort_elements_by (Optional[List[str]]): Specific order for rows based on 'Element' column.

    Returns:
        Optional[pd.DataFrame]: A DataFrame comparing models, indexed by Split and Element,
                                 with columns for each model showing 'mean ± std' metrics,
                                 or None if aggregation fails.
    """
    print("\n--- Aggregating and Comparing Model Performances ---")

    model_data_map = {
        "BGAE": bgae_results_list,
        "IsolationForest": if_results_list,
        "OddBall": oddball_results_list,
        "DOMINANT": dominant_results_list
    }

    # --- 1. Combine all results into a single DataFrame ---
    all_results_list = []
    for model_name, results_list in model_data_map.items():
        if results_list:
            try:
                model_df = pd.concat(results_list, ignore_index=True)
                # Ensure essential columns exist, add 'Model' column
                if not model_df.empty:
                    model_df['Model'] = model_name
                    all_results_list.append(model_df)
            except Exception as e:
                print(f"Error concatenating results for {model_name}: {e}")

    if not all_results_list:
        print("Error: No valid result lists provided or concatenation failed.")
        return None

    combined_df = pd.concat(all_results_list, ignore_index=True)

    # --- 2. Filter and Clean ---
    filtered_df = combined_df[combined_df['Split'].isin(splits_to_compare)].copy()

    metrics_to_check = ['AUROC', 'AP', 'Best F1'] # Core metrics for filtering failed runs
    cols_exist = all(m in filtered_df.columns for m in metrics_to_check)
    if cols_exist:
        successful_df = filtered_df.dropna(subset=metrics_to_check, how='any')
    else:
        print("Warning: Core metric columns missing. Filtering might be incomplete.")
        successful_df = filtered_df

    if successful_df.empty:
        print(f"Warning: No successful runs found for models in specified splits: {splits_to_compare}.")
        return pd.DataFrame() # Return empty DataFrame

    print(f"Aggregating results from {successful_df['seed'].nunique()} successful runs across models...")

    # --- 3. Calculate Aggregates (Mean, Std, Count) ---
    # Define all metrics we might want to aggregate
    all_metrics_possible = metrics_to_compare + \
                           [f'{p}@{k}' for k in k_list_eval for p in ['Precision', 'Recall']]
    # Filter list to only those metrics actually present in the successful data
    metrics_present = [m for m in all_metrics_possible if m in successful_df.columns]

    if not metrics_present:
         print("Error: None of the specified metrics_to_compare or P@K/R@K metrics found in the data.")
         return None

    try:
        # Group by Split, Element, and Model then aggregate
        grouped = successful_df.groupby(['Split', 'Element', 'Model'])
        aggregated_means = grouped[metrics_present].mean()
        aggregated_stds = grouped[metrics_present].std()
        aggregated_counts = grouped['seed'].nunique().rename('n_runs') # Count distinct runs per group

        # Combine mean, std, and counts
        aggregated_df = aggregated_means.join(aggregated_stds, lsuffix='_mean', rsuffix='_std')
        aggregated_df = aggregated_df.join(aggregated_counts)

    except Exception as e:
        print(f"Error during aggregation: {e}")
        return None

    # --- 4. Format the Comparison Table ---
    final_df_list = []
    # Use index from aggregated_df to handle multi-level index easily
    for idx, row in aggregated_df.iterrows():
        split, element, model = idx # Unpack the index levels
        formatted_row = {'Split': split, 'Element': element, 'Model': model, 'Runs': int(row['n_runs'])}
        for metric in metrics_present:
            mean_val = row.get(f'{metric}_mean', np.nan)
            std_val = row.get(f'{metric}_std', np.nan)

            if pd.notna(mean_val) and pd.notna(std_val):
                formatted_row[metric] = f"{mean_val:.4f} ± {std_val:.4f}"
            elif pd.notna(mean_val):
                formatted_row[metric] = f"{mean_val:.4f} ± nan"
            else:
                formatted_row[metric] = "N/A" # Or np.nan if preferred
        final_df_list.append(formatted_row)

    if not final_df_list:
        print("Error: No data after formatting.")
        return pd.DataFrame()

    final_comparison_df_long = pd.DataFrame(final_df_list)

    # Pivot to get models as columns
    try:
        # Define the order of metrics for columns
        column_order = ['Runs'] + metrics_present
        pivot_comparison_df = final_comparison_df_long.pivot_table(
            index=['Split', 'Element'],
            columns='Model',
            values=column_order # Specify columns to pivot as values
        )
        # Reorder model columns alphabetically for consistency
        pivot_comparison_df = pivot_comparison_df.reindex(columns=sorted(pivot_comparison_df.columns.levels[1]), level='Model')
        # Reorder metric level for readability
        pivot_comparison_df = pivot_comparison_df.reindex(columns=column_order, level=0)


    except Exception as e:
        print(f"Error pivoting final table: {e}")
        print("Returning table in long format instead.")
        # Set index for long format return
        final_comparison_df_long = final_comparison_df_long.set_index(['Split', 'Element', 'Model'])
        return final_comparison_df_long # Return long format as fallback


    # --- Optional: Sort index by custom element order ---
    if sort_elements_by and isinstance(pivot_comparison_df.index, pd.MultiIndex):
         try:
             # Sort by Split first, then by custom Element order
             pivot_comparison_df = pivot_comparison_df.sort_index(
                 level='Element',
                 key=lambda index: index.map({elem: i for i, elem in enumerate(sort_elements_by)}),
                 sort_remaining=True # Sorts by Split automatically after sorting by Element key
             )
         except Exception as e:
              print(f"Warning: Could not sort by custom element order: {e}")


    print("\n--- Final Model Comparison (Mean ± Std Dev) ---")
    # Configure pandas display options for better table printing
    pd.set_option('display.max_rows', 500)
    pd.set_option('display.max_columns', len(pivot_comparison_df.columns) + 1) # Adjust based on columns
    pd.set_option('display.width', 200) # Adjust width as needed
    pd.set_option('display.precision', 4) # Set default float precision

    print(pivot_comparison_df.to_string())

    # Reset display options if desired
    # pd.reset_option('all')

    return pivot_comparison_df

# --- Example Usage ---
# Assume all_summary_dfs, all_if_summaries, all_oddball_summaries, all_dominant_summaries
# are lists of DataFrames populated from the previous iteration loops.

# Example element order (adjust based on your actual 'Element' column values)
# element_order = [
#     'Node (provider)',
#     'Node (member)',
#     "Edge ('provider', 'to', 'member')" # Example edge element name
# ]

# comparison_table = aggregate_and_compare_model_results(
#     bgae_results_list = all_summary_dfs,
#     if_results_list = all_if_summaries,
#     oddball_results_list = all_oddball_summaries,
#     dominant_results_list = all_dominant_summaries,
#     metrics_to_compare = ['AUROC', 'AP', 'Best F1'], # Primary metrics
#     k_list_eval = [50, 100, 200, 500], # K values used
#     splits_to_compare = ['test'], # Focus on test set
#     sort_elements_by = element_order # Optional sorting
# )

# if comparison_table is not None:
#      print("\nComparison Table Created:")
#      # print(comparison_table.to_markdown()) # For easy viewing if needed

#      # Optional: Save to CSV
#      # save_dir = "final_evaluation_results_iter"
#      # if save_dir:
#      #     os.makedirs(save_dir, exist_ok=True)
#      #     comparison_table.to_csv(os.path.join(save_dir, "model_comparison_summary.csv"))

# Compare BGAE against the baselines on the test split (mean ± std over runs).
comparison_table = aggregate_and_compare_model_results(
    all_summary_dfs,         # BGAE
    all_if_summaries,        # Isolation Forest
    all_oddball_summaries,   # OddBall
    all_dominant_summaries,  # DOMINANT
    metrics_to_compare=['AUROC', 'AP', 'Best F1'],
    k_list_eval=[50, 100, 200, 500],
    splits_to_compare=['test'],
)

# if comparison_table is not None:
#      print("\nComparison Table Created:")
#      print(comparison_table.to_markdown()) # Print markdown for easy viewing

#      # Optional: Save to CSV
#      # save_dir = "final_evaluation_results_iter"
#      # if save_dir:
#      #     os.makedirs(save_dir, exist_ok=True)
#      #     comparison_table.to_csv(os.path.join(save_dir, "model_comparison_summary.csv"))