# Comprehensive Particle Physics Analysis Notebook
## Multi-Model Performance & Cluster Visualization
**Files used** (one set per model):
- `comprehensive_query_table_{model_name}_model.parquet`
- `{model_name}_model_raw_pairs.parquet`
- `{model_name}_model_T_pairs.npy`

In [None]:
import warnings
from typing import List, Tuple, Optional, Dict

import numpy as np
import polars as pl
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.metrics import (
    roc_curve, auc, precision_recall_curve,
    confusion_matrix, classification_report
)
from sklearn.preprocessing import label_binarize
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.lines import Line2D
import matplotlib.cm as cm

# Silence every warning category process-wide.
# NOTE(review): this also hides warnings the analysis code below can trigger
# (e.g. NumPy empty-slice means, sklearn undefined-metric warnings).
warnings.simplefilter('ignore')

# Plotting style. style.use() resets rcParams, so it must run before the
# palette and figure-size overrides that follow.
plt.style.use('default')
sns.set_palette("husl")
plt.rcParams.update({'figure.figsize': (10, 6)})

# Central definitions shared by every analysis cell.
# Output-class names; the list position is the integer label used in
# true_label / pred_label.
class_names = ['Lone-Lone', 'True-True', 'Cluster-Lone', 'Lone-Cluster', 'Cluster-Cluster']
# Plot colour per class, aligned index-for-index with class_names.
class_colors = ['blue', 'orange', 'green', 'red', 'purple']

# Columns projected by load_comprehensive_data when none are requested.
DEFAULT_COLUMNS = [
    # identifiers
    'event_id', 'edge_id', 'source_id', 'target_id',
    # labels and prediction
    'true_label', 'pred_label', 'is_correct', 'confidence',
    # one score column per class: confidence_class_0 ... confidence_class_4
    *[f'confidence_class_{k}' for k in range(len(class_names))],
    # per-endpoint features: snr/eta/phi of the source, then of the target
    *[f'{feat}_{end}' for end in ('source', 'target') for feat in ('snr', 'eta', 'phi')],
    # pair-level features
    'delta_eta', 'delta_phi', 'spatial_distance',
    'avg_snr', 'event_size',
]

## Data Loading Functions

In [None]:
def load_comprehensive_data(model_name: str, file_pwd: str,
                            columns: Optional[List[str]] = None,
                            use_polars: bool = False) -> pd.DataFrame:
    """
    Load the comprehensive query table for a specific model.

    Reads ``{file_pwd}/comprehensive_query_table_{model_name}_model.parquet``
    and projects only the requested columns to reduce IO.

    Args:
        model_name: Model identifier used in the file name (e.g. "2_layer").
        file_pwd: Directory containing the parquet file.
        columns: Columns to read. ``None`` or an empty list means
            ``DEFAULT_COLUMNS``.
        use_polars: Read with Polars and convert to pandas, instead of using
            pandas + pyarrow directly.

    Returns:
        A pandas DataFrame holding the projected columns.

    Raises:
        FileNotFoundError: Propagated unchanged when the file is missing, so
            callers that catch it (e.g. compare_multiple_models) can skip
            missing models.
        RuntimeError: If the Polars read or conversion fails for any other
            reason.
    """
    cols = columns or DEFAULT_COLUMNS
    file_path = f"{file_pwd}/comprehensive_query_table_{model_name}_model.parquet"
    print(f"-> Loading: {file_path}")

    if use_polars:
        try:
            df_pl = pl.read_parquet(file_path, columns=cols)
            df = df_pl.to_pandas()  # the rest of the pipeline is pandas-based
        except FileNotFoundError:
            # A missing file is not a Polars failure; let it surface as-is.
            raise
        except Exception as e:
            raise RuntimeError(f"Polars failed to load {file_path}: {e}") from e
    else:
        # pyarrow engine with column projection
        df = pd.read_parquet(file_path, engine='pyarrow', columns=cols)

    print(f"Loaded {len(df):,} rows from {file_path}")
    return df


def load_raw_pairs_data(model_name: str, file_pwd: str) -> pd.DataFrame:
    """Read ``{file_pwd}/{model_name}_model_raw_pairs.parquet`` with pyarrow.

    Args:
        model_name: Model identifier used in the file name.
        file_pwd: Directory containing the parquet file.

    Returns:
        The full (unprojected) raw-pairs table as a pandas DataFrame.
    """
    raw_pairs = pd.read_parquet(
        f"{file_pwd}/{model_name}_model_raw_pairs.parquet",
        engine='pyarrow',
    )
    print(f"Loaded raw pairs: {len(raw_pairs):,} rows")
    return raw_pairs


def load_tensor_data(model_name: str, file_pwd: str) -> np.ndarray:
    """Open ``{file_pwd}/{model_name}_model_T_pairs.npy`` as a read-only memory map.

    ``mmap_mode='r'`` avoids reading the whole array into RAM up front; data
    is paged in as it is accessed.

    Returns:
        A read-only memory-mapped NumPy array.
    """
    npy_path = f"{file_pwd}/{model_name}_model_T_pairs.npy"
    pairs_tensor = np.load(npy_path, mmap_mode='r')
    print(f"Loaded tensor: {pairs_tensor.shape}")
    return pairs_tensor

## Downcasting Function

In [None]:
def downcast_dtypes(df: pd.DataFrame, float_cols: Optional[List[str]] = None) -> pd.DataFrame:
    """
    Downcast commonly used columns to memory-efficient dtypes, in place.

    - ``true_label`` / ``pred_label`` -> int8
    - ``event_id`` -> int32, but only when every value is a finite integer
      inside the int32 range. An unchecked int64 -> int32 cast wraps large
      IDs silently instead of raising, so out-of-range columns are left as-is.
    - ``is_correct`` -> bool
    - ``confidence``, every ``confidence_class_*`` column, the default
      feature columns, and any extra names in ``float_cols`` -> float32

    Args:
        df: DataFrame to convert. It is modified in place.
        float_cols: Extra column names to cast to float32 on top of the
            defaults. Names not present in ``df`` are ignored.

    Returns:
        The same ``df`` object, for chaining.
    """
    # integer label columns
    for c in ('true_label', 'pred_label'):
        if c in df.columns:
            df[c] = df[c].astype('int8')

    if 'event_id' in df.columns:
        ev = df['event_id']
        if pd.api.types.is_numeric_dtype(ev):
            try:
                vals = ev.to_numpy(dtype=np.float64, na_value=np.nan)
                i32 = np.iinfo(np.int32)
                fits = vals.size == 0 or (
                    bool(np.isfinite(vals).all())
                    and vals.min() >= i32.min
                    and vals.max() <= i32.max
                    and bool((vals == np.floor(vals)).all())
                )
                if fits:
                    df['event_id'] = ev.astype('int32')
            except (TypeError, ValueError):
                # Best effort: leave event_id untouched if it cannot be checked.
                pass

    if 'is_correct' in df.columns:
        # NOTE: NaN becomes True under astype(bool).
        df['is_correct'] = df['is_correct'].astype('bool')

    # float columns -> float32
    float_targets = {
        'confidence', 'spatial_distance', 'avg_snr',
        'snr_source', 'snr_target', 'eta_source', 'eta_target',
        'phi_source', 'phi_target',
    }
    if float_cols:
        float_targets.update(float_cols)
    for c in df.columns:
        is_class_score = isinstance(c, str) and c.startswith('confidence_class_')
        if is_class_score or c in float_targets:
            df[c] = df[c].astype('float32')

    return df

## Analysis Core Functions

In [None]:
def prepare_roc_data_from_df(df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]:
    """Extract the label vector and per-class score matrix for ROC/PR analysis.

    Returns:
        ``(y_true, y_score)``:

        - ``y_true`` is a 1-D int8 array of ``true_label`` values.
        - ``y_score`` is an ``(N, len(class_names))`` float32 array whose
          column ``k`` holds ``confidence_class_k``.

    Raises:
        KeyError: If any ``confidence_class_k`` column is missing.
    """
    score_cols = [f'confidence_class_{k}' for k in range(len(class_names))]
    absent = [name for name in score_cols if name not in df.columns]
    if absent:
        raise KeyError(f"Missing confidence columns: {absent}")

    labels = df['true_label'].to_numpy(dtype=np.int8)
    # One float32 column per class, stacked side by side -> (N, C)
    scores = np.column_stack([df[name].to_numpy(dtype=np.float32) for name in score_cols])
    return labels, scores

def compute_tt_threshold(y_true: np.ndarray, y_score: np.ndarray, tt_tpr_target: float = 0.99) -> float:
    """
    Compute the True-True score threshold that reaches ``tt_tpr_target`` recall.

    The True-True column of ``y_score`` is treated as a one-vs-rest score and
    a ROC curve is built for it with sklearn's ``roc_curve``, whose thresholds
    come in decreasing order. The returned threshold is the one at the first
    ROC point whose TPR is >= ``tt_tpr_target``. If no point reaches the
    target, the threshold at the maximum TPR is used instead.

    Args:
        y_true: 1-D array of integer class labels (indices into ``class_names``).
        y_score: (N, C) float array of per-class scores.
        tt_tpr_target: Desired True-True recall (TPR).

    Returns:
        The threshold as a Python float. Returns ``inf`` when ``y_true``
        contains no True-True samples, so a ``score >= threshold`` test
        selects nothing.
    """
    tt_idx = class_names.index("True-True")
    # One-vs-rest target for the True-True class only. This equals column
    # tt_idx of label_binarize(...) without building the full (N, C) matrix.
    is_tt = (np.asarray(y_true) == tt_idx).astype(np.int8)
    if not is_tt.any():
        # With no positives roc_curve warns and returns an all-NaN TPR.
        return float('inf')

    fpr, tpr, thresholds = roc_curve(is_tt, y_score[:, tt_idx])
    reached = tpr >= tt_tpr_target
    if reached.any():
        tt_threshold = thresholds[np.argmax(reached)]  # first point meeting the target
    else:
        tt_threshold = thresholds[np.argmax(tpr)]
    return float(tt_threshold)

def apply_tt_only_threshold_fast(df: pd.DataFrame, y_true: np.ndarray, y_score: np.ndarray,
                                 tt_tpr_target: float = 0.99) -> Tuple[pd.DataFrame, float]:
    """
    Override predictions to True-True wherever the True-True score clears a threshold.

    The threshold comes from ``compute_tt_threshold`` and targets
    ``tt_tpr_target`` True-True recall. For every row whose True-True score
    is >= the threshold:

    - ``pred_label`` is set to the True-True index;
    - ``confidence`` is set to that score.

    ``is_correct`` is then recomputed for all rows. Other rows keep their
    predictions.

    Args:
        df: Prediction table with ``pred_label``, ``true_label`` and
            ``confidence``. Its rows must be in the same order as ``y_true``
            and ``y_score``. It is not modified.
        y_true: 1-D integer truth labels.
        y_score: (N, C) per-class scores.
        tt_tpr_target: Target True-True recall for the threshold.

    Returns:
        ``(modified copy of df, tt_threshold)``.

    Raises:
        ValueError: If ``y_true`` or ``y_score`` length differs from ``len(df)``.
    """
    if len(y_score) != len(df) or len(y_true) != len(df):
        raise ValueError(
            f"Length mismatch: df has {len(df)} rows, y_true {len(y_true)}, y_score {len(y_score)}"
        )

    df = df.copy()
    tt_idx = class_names.index("True-True")
    tt_threshold = compute_tt_threshold(y_true, y_score, tt_tpr_target)
    tt_scores = y_score[:, tt_idx]  # (N,)
    override = tt_scores >= tt_threshold  # boolean mask (N,)

    if override.any():
        # Positional assignment through NumPy arrays. Label-based df.loc would
        # also touch unrelated rows, or fail, when the index has duplicate labels.
        pred = df['pred_label'].to_numpy(copy=True)
        pred[override] = tt_idx
        df['pred_label'] = pred

        conf = df['confidence'].to_numpy(copy=True)
        conf[override] = tt_scores[override]
        df['confidence'] = conf

    # Recompute is_correct for every row (vectorized)
    df['is_correct'] = (df['pred_label'].to_numpy(dtype=np.int8) == df['true_label'].to_numpy(dtype=np.int8))
    return df, tt_threshold

## Performance Visualization Functions

In [None]:
def downsample_indices(n: int, max_rows: int = 2_000_000, seed: int = 42) -> Optional[np.ndarray]:
    """
    Pick a reproducible random subset of row positions for plotting.

    Args:
        n: Number of rows in the dataset.
        max_rows: Largest number of rows to keep.
        seed: Seed for ``np.random.default_rng`` so the subset is repeatable.

    Returns:
        ``None`` when ``n <= max_rows`` (use every row). Otherwise a NumPy
        array of ``max_rows`` distinct positions drawn from ``range(n)``
        without replacement, in random (unsorted) order.
    """
    needs_sampling = n > max_rows
    if needs_sampling:
        generator = np.random.default_rng(seed)
        return generator.choice(n, size=max_rows, replace=False)
    return None

def plot_loss_and_accuracy(metrics, save_path=None):
    """
    Plot training/test loss and accuracy curves from a metrics dictionary.

    NOTE: this notebook defines ``plot_loss_and_accuracy`` twice. Whichever
    cell runs last provides the definition in effect.

    Args:
        metrics: Mapping with optional per-epoch sequences under
            ``"train_loss"``, ``"test_loss"``, ``"train_acc"`` and
            ``"test_acc"``. Missing or empty series are not drawn.
        save_path: If given, save the figure there (dpi=150) instead of
            calling ``plt.show()``.
    """
    train_loss = np.array(metrics.get("train_loss", []))
    test_loss = np.array(metrics.get("test_loss", []))
    train_acc = np.array(metrics.get("train_acc", []))
    test_acc = np.array(metrics.get("test_acc", []))

    fig, axs = plt.subplots(1, 2, figsize=(14, 5))

    # Each series gets its own 1-based epoch axis. Train and test histories
    # need not have the same length, and an x built from train_loss alone
    # would make plot() raise on any mismatch.
    for values, label in ((train_loss, "Train Loss"), (test_loss, "Test Loss")):
        if values.size:
            axs[0].plot(np.arange(1, len(values) + 1), values, label=label, linewidth=2)
    axs[0].set_title("Loss Over Epochs", fontsize=14)
    axs[0].set_xlabel("Epoch")
    axs[0].set_ylabel("Loss")
    axs[0].legend()
    axs[0].grid(True, alpha=0.3)

    for values, label in ((train_acc, "Train Accuracy"), (test_acc, "Test Accuracy")):
        if values.size:
            axs[1].plot(np.arange(1, len(values) + 1), values, label=label, linewidth=2)
    axs[1].set_title("Accuracy Over Epochs", fontsize=14)
    axs[1].set_xlabel("Epoch")
    axs[1].set_ylabel("Accuracy")
    # Fractional accuracy fits [0, 1]. If any value exceeds 1 it is presumably
    # a percentage, so widen the axis instead of clipping the curve away.
    acc_peak = max((float(np.nanmax(v)) for v in (train_acc, test_acc) if v.size), default=0.0)
    axs[1].set_ylim(0, 100.0 if acc_peak > 1.0 else 1.0)
    axs[1].legend()
    axs[1].grid(True, alpha=0.3)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150)
        print(f"Saved loss/accuracy curves to {save_path}")
    else:
        plt.show()


def plot_roc_curves(y_true: np.ndarray, y_score: np.ndarray,
                    sample_idx: Optional[np.ndarray] = None,
                    title_suffix: str = ""):
    """
    Plot one-vs-rest ROC curves, with AUC in the legend, for every class.

    Args:
        y_true: 1-D integer class labels.
        y_score: (N, C) per-class scores; column i scores ``class_names[i]``.
        sample_idx: Optional positional indices used to subsample both arrays.
        title_suffix: Appended to the plot title.
    """
    if sample_idx is not None:
        y_true = y_true[sample_idx]
        y_score = y_score[sample_idx, :]

    # Binarize once, outside the per-class loop: an (N, C) indicator matrix.
    y_onehot = label_binarize(y_true, classes=np.arange(len(class_names)))

    plt.figure(figsize=(10, 8))
    for i, (cname, color) in enumerate(zip(class_names, class_colors)):
        fpr, tpr, _ = roc_curve(y_onehot[:, i], y_score[:, i])
        plt.plot(fpr, tpr, lw=2, label=f"{cname} (AUC = {auc(fpr, tpr):.3f})", color=color)
    plt.plot([0, 1], [0, 1], 'k--', lw=2, label="Random classifier")
    plt.xlabel("False Positive Rate"); plt.ylabel("True Positive Rate")
    plt.title(f"ROC Curves - Multi-class{title_suffix}")
    plt.legend(loc="lower right"); plt.grid(alpha=0.3)
    plt.show()
    plt.close()

def plot_precision_recall_curves(y_true: np.ndarray, y_score: np.ndarray,
                                 sample_idx: Optional[np.ndarray] = None,
                                 title_suffix: str = ""):
    """
    Plot one-vs-rest precision-recall curves, with average precision (AP), per class.

    Args:
        y_true: 1-D integer class labels.
        y_score: (N, C) per-class scores; column i scores ``class_names[i]``.
        sample_idx: Optional positional indices used to subsample both arrays.
        title_suffix: Appended to the plot title.
    """
    if sample_idx is not None:
        y_true = y_true[sample_idx]
        y_score = y_score[sample_idx, :]

    # Binarize once, outside the per-class loop.
    y_onehot = label_binarize(y_true, classes=np.arange(len(class_names)))

    plt.figure(figsize=(10, 8))
    for i, (cname, color) in enumerate(zip(class_names, class_colors)):
        precision, recall, _ = precision_recall_curve(y_onehot[:, i], y_score[:, i])
        # Step-wise AP, sum_n (R_n - R_{n-1}) * P_n. This is the formula
        # sklearn's average_precision_score applies to these outputs (recall
        # is returned in decreasing order, hence the leading minus). The
        # trapezoidal auc(recall, precision) interpolates linearly between PR
        # points and is not AP.
        ap = max(0.0, -np.sum(np.diff(recall) * precision[:-1]))
        plt.plot(recall, precision, lw=2, label=f"{cname} (AP = {ap:.3f})", color=color)
    plt.xlabel("Recall"); plt.ylabel("Precision")
    plt.title(f"Precision-Recall Curves{title_suffix}")
    plt.legend(loc="lower left"); plt.grid(alpha=0.3)
    plt.show()
    plt.close()

def plot_confusion_matrix_from_df(df: pd.DataFrame, title_suffix: str = "", normalize: bool = True, sample_idx: Optional[np.ndarray] = None):
    """
    Show a true-vs-predicted confusion-matrix heatmap and print a classification report.

    Args:
        df: Table with ``true_label`` and ``pred_label`` columns.
        title_suffix: Appended to the plot title.
        normalize: If True, divide each row by its true-class count. Rows of
            classes absent from the data stay NaN and render as blank cells.
        sample_idx: Optional positional row indices to evaluate instead of
            all rows.
    """
    df_plot = df.iloc[sample_idx] if sample_idx is not None else df

    y_true = df_plot['true_label'].to_numpy(dtype=np.int8)
    y_pred = df_plot['pred_label'].to_numpy(dtype=np.int8)
    labels = np.arange(len(class_names))

    # Named `matrix` (not `cm`) so the module-level matplotlib.cm alias is
    # not shadowed.
    matrix = confusion_matrix(y_true, y_pred, labels=labels)
    if normalize:
        row_totals = matrix.sum(axis=1, keepdims=True)
        matrix = np.divide(matrix, row_totals,
                           out=np.full(matrix.shape, np.nan), where=row_totals > 0)
        fmt, title = '.2f', f'Normalized Confusion Matrix{title_suffix}'
    else:
        fmt, title = 'd', f'Confusion Matrix{title_suffix}'

    plt.figure(figsize=(10, 8))
    sns.heatmap(matrix, annot=True, fmt=fmt, cmap="Blues", xticklabels=class_names, yticklabels=class_names)
    plt.xlabel("Predicted"); plt.ylabel("True"); plt.title(title)
    plt.xticks(rotation=45); plt.yticks(rotation=0)
    plt.tight_layout(); plt.show()
    plt.close()

    print("Classification Report:")
    # Pin `labels` to every class so `target_names` always lines up. Without
    # it, sklearn raises when some class appears in neither y_true nor y_pred.
    print(classification_report(y_true, y_pred, labels=labels, target_names=class_names,
                                digits=3, zero_division=0))

def plot_classification_confidence_analysis(df: pd.DataFrame, sample_idx: Optional[np.ndarray] = None, title_suffix: str = ""):
    """
    Draw a 2x2 grid of confidence diagnostics for a prediction table.

    Panels:

    - (0,0) confidence of correct vs incorrect predictions;
    - (0,1) confidence of correct predictions, split by true class;
    - (1,0) confidence of incorrect predictions, split by true class;
    - (1,1) accuracy per confidence decile.

    Args:
        df: Table with ``confidence``, ``is_correct`` and ``true_label``.
        sample_idx: Optional positional row indices to plot instead of all rows.
        title_suffix: When non-empty, shown in a figure title.
    """
    df_plot = df.iloc[sample_idx] if sample_idx is not None else df

    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    conf = df_plot['confidence'].to_numpy(dtype=np.float32)
    is_correct = df_plot['is_correct'].to_numpy(dtype=np.bool_)
    y_true = df_plot['true_label'].to_numpy(dtype=np.int8)

    axes[0, 0].hist(conf[is_correct], bins=30, alpha=0.7, label='Correct')
    axes[0, 0].hist(conf[~is_correct], bins=30, alpha=0.7, label='Incorrect')
    axes[0, 0].set_xlabel('Confidence'); axes[0, 0].set_ylabel('Count'); axes[0, 0].set_title('Confidence Distribution: Correct vs Incorrect')
    axes[0, 0].legend(); axes[0, 0].grid(alpha=0.3)

    for cls in range(len(class_names)):
        in_cls = y_true == cls
        cls_mask_correct = in_cls & is_correct
        cls_mask_incorrect = in_cls & ~is_correct
        if cls_mask_correct.any():
            axes[0, 1].hist(conf[cls_mask_correct], bins=20, alpha=0.6, color=class_colors[cls], label=class_names[cls])
        if cls_mask_incorrect.any():
            axes[1, 0].hist(conf[cls_mask_incorrect], bins=20, alpha=0.6, color=class_colors[cls], label=class_names[cls])

    axes[0, 1].set_xlabel('Confidence'); axes[0, 1].set_ylabel('Count'); axes[0, 1].set_title('Confidence Distribution by Class (Correct Predictions)')
    axes[0, 1].legend(); axes[0, 1].grid(alpha=0.3)
    axes[1, 0].set_xlabel('Confidence'); axes[1, 0].set_ylabel('Count'); axes[1, 0].set_title('Confidence Distribution by Class (Incorrect Predictions)')
    axes[1, 0].legend(); axes[1, 0].grid(alpha=0.3)

    # Accuracy in 10 equal-width bins on [0, 1]. The last bin is closed on the
    # right so confidence == 1.0 is counted. Empty bins are NaN (a gap in the
    # line) rather than a misleading 0% accuracy.
    edges = np.linspace(0, 1, 11)
    n_bins = len(edges) - 1
    in_range = (conf >= edges[0]) & (conf <= edges[-1])  # also excludes NaN
    bin_idx = np.clip(np.searchsorted(edges, conf[in_range], side='right') - 1, 0, n_bins - 1)
    totals = np.bincount(bin_idx, minlength=n_bins)
    hits = np.bincount(bin_idx, weights=is_correct[in_range].astype(np.float64), minlength=n_bins)
    accuracy_by_bin = np.divide(hits, totals, out=np.full(n_bins, np.nan), where=totals > 0)
    centers = (edges[:-1] + edges[1:]) / 2
    axes[1, 1].plot(centers, accuracy_by_bin, 'o-', linewidth=2, markersize=8)
    axes[1, 1].set_xlabel('Confidence Bin'); axes[1, 1].set_ylabel('Accuracy'); axes[1, 1].set_title('Accuracy vs Confidence'); axes[1, 1].grid(alpha=0.3)
    axes[1, 1].set_ylim(0, 1)

    if title_suffix:
        fig.suptitle(f"Confidence Analysis{title_suffix}", fontsize=14)
        plt.tight_layout(rect=[0, 0, 1, 0.96])
    else:
        plt.tight_layout()
    plt.show(); plt.close()

# Heavier distribution plots: they take precomputed y_true / y_score arrays and
# an optional sample_idx so large tables can be subsampled before plotting.
def plot_class_wise_distributions(df: pd.DataFrame, y_true: np.ndarray, y_score: np.ndarray,
                                  tt_threshold: float, sample_idx: Optional[np.ndarray] = None, title_suffix: str = ""):
    """
    Draw one panel per output class: that class's score histogram, split by truth label.

    Each legend entry reports the fraction of that truth group whose score
    exceeds the panel threshold: ``tt_threshold`` on the True-True panel,
    0.5 on the others.

    Args:
        df: Unused; kept for call-site compatibility.
        y_true: 1-D integer truth labels.
        y_score: (N, C) per-class scores aligned with ``y_true``.
        tt_threshold: Threshold drawn on the True-True panel.
        sample_idx: Optional positional indices used to subsample both arrays.
        title_suffix: Appended to the figure title.
    """
    if sample_idx is not None:
        y_true_plot = y_true[sample_idx]
        y_score_plot = y_score[sample_idx, :]
    else:
        y_true_plot = y_true
        y_score_plot = y_score

    global_min = float(np.nanmin(y_score_plot))
    global_max = float(np.nanmax(y_score_plot))
    if not (np.isfinite(global_min) and np.isfinite(global_max)):
        global_min, global_max = 0.0, 1.0
    elif global_max <= global_min:
        # Constant scores would give zero-width bins, and density=True divides
        # by the bin width.
        global_min, global_max = global_min - 0.5, global_max + 0.5
    num_bins = 50
    common_bins = np.linspace(global_min, global_max, num_bins + 1)

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    axes = axes.flatten()
    title_suffix_full = f" ({', '.join([f'{class_names[i]}: {(y_true_plot==i).sum():,}' for i in range(len(class_names))])})"
    fig.suptitle(f"Class-wise Score Distributions with True-True Threshold{title_suffix}{title_suffix_full}", fontsize=14)

    for idx, class_name in enumerate(class_names):
        ax = axes[idx]
        ax.set_title(class_name)
        ax.set_xlabel('Score'); ax.set_ylabel('Density')
        optimal_thresh = tt_threshold if class_name == 'True-True' else 0.5

        for truth_type in sorted(np.unique(y_true_plot)):
            scores = y_score_plot[y_true_plot == truth_type, idx]
            fraction_above = np.mean(scores > optimal_thresh) if scores.size > 0 else 0.0
            ax.hist(scores, bins=common_bins, density=True, alpha=0.6, label=f'Truth {truth_type} (> {fraction_above:.3%})', color=class_colors[truth_type])

        ax.axvline(optimal_thresh, color='black', linestyle='dashed', linewidth=2, label=f'Thresh: {optimal_thresh:.4f}')
        ax.legend()
    # Hide every grid cell beyond the number of classes.
    for ax in axes[len(class_names):]:
        ax.set_visible(False)
    plt.tight_layout(rect=[0, 0, 1, 0.95]); plt.show(); plt.close()

def plot_truth_type_distributions(y_true: np.ndarray, y_score: np.ndarray,
                                  tt_threshold: float, sample_idx: Optional[np.ndarray] = None, title_suffix: str = ""):
    """
    Draw one panel per truth label present: histograms of every output class's score.

    Each panel draws a threshold line (``tt_threshold`` for the True-True
    truth panel, 0.5 otherwise). Each legend entry gives the percentage of
    that panel's samples whose class score exceeds that line.

    Args:
        y_true: 1-D integer truth labels.
        y_score: (N, C) per-class scores aligned with ``y_true``.
        tt_threshold: Threshold used on the True-True truth panel.
        sample_idx: Optional positional indices used to subsample both arrays.
        title_suffix: Appended to the figure title.
    """
    if sample_idx is not None:
        y_true_plot = y_true[sample_idx]
        y_score_plot = y_score[sample_idx, :]
    else:
        y_true_plot = y_true
        y_score_plot = y_score

    global_min = float(np.nanmin(y_score_plot))
    global_max = float(np.nanmax(y_score_plot))
    if not (np.isfinite(global_min) and np.isfinite(global_max)):
        global_min, global_max = 0.0, 1.0
    elif global_max <= global_min:
        # Constant scores would give zero-width bins, and density=True divides
        # by the bin width.
        global_min, global_max = global_min - 0.5, global_max + 0.5
    num_bins = 50
    common_bins = np.linspace(global_min, global_max, num_bins + 1)

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    axes = axes.flatten()
    title_suffix_full = f" ({', '.join([f'{class_names[i]}: {(y_true_plot==i).sum():,}' for i in range(len(class_names))])})"
    fig.suptitle(f"Truth Type-wise Score Distributions for All Output Classes{title_suffix}{title_suffix_full}", fontsize=14)

    unique_truth_types = sorted(np.unique(y_true_plot))
    for idx, truth_type in enumerate(unique_truth_types):
        if idx >= len(axes):
            break
        ax = axes[idx]
        ax.set_title(f'Truth Type {truth_type} ({class_names[truth_type]})')
        ax.set_xlabel('Score'); ax.set_ylabel('Density')
        # NOTE(review): the threshold is chosen by the panel's *truth* type but
        # is applied to every output class's score in the loop below. If a
        # per-output-class threshold was intended, select it by class_idx.
        optimal_threshold = tt_threshold if class_names[truth_type] == 'True-True' else 0.5

        in_truth = y_true_plot == truth_type
        for class_idx, class_name in enumerate(class_names):
            scores = y_score_plot[in_truth, class_idx]
            fraction_above = np.mean(scores > optimal_threshold) * 100 if scores.size > 0 else 0.0
            ax.hist(scores, bins=common_bins, density=True, alpha=0.6, label=f'{class_name} ({fraction_above:.1f}%)', color=class_colors[class_idx])

        ax.axvline(optimal_threshold, color='black', linestyle='dashed', linewidth=2, label=f'Thresh: {optimal_threshold:.4f}')
        ax.legend()
    for idx in range(len(unique_truth_types), len(axes)):
        axes[idx].set_visible(False)
    plt.tight_layout(rect=[0, 0, 1, 0.95]); plt.show(); plt.close()

def plot_loss_and_accuracy(metrics, save_path=None):
    """
    Plot training/test loss and accuracy curves from a metrics dictionary.

    NOTE: this notebook defines ``plot_loss_and_accuracy`` twice. Whichever
    cell runs last provides the definition in effect.

    Args:
        metrics: Mapping with optional per-epoch sequences under
            ``"train_loss"``, ``"test_loss"``, ``"train_acc"`` and
            ``"test_acc"``. Missing or empty series are not drawn.
        save_path: If given, save the figure there (dpi=150) instead of
            calling ``plt.show()``.
    """
    train_loss = np.array(metrics.get("train_loss", []))
    test_loss = np.array(metrics.get("test_loss", []))
    train_acc = np.array(metrics.get("train_acc", []))
    test_acc = np.array(metrics.get("test_acc", []))

    fig, axs = plt.subplots(1, 2, figsize=(14, 5))

    # Each series gets its own 1-based epoch axis. Train and test histories
    # need not have the same length, and an x built from train_loss alone
    # would make plot() raise on any mismatch.
    for values, label in ((train_loss, "Train Loss"), (test_loss, "Test Loss")):
        if values.size:
            axs[0].plot(np.arange(1, len(values) + 1), values, label=label, linewidth=2)
    axs[0].set_title("Loss Over Epochs", fontsize=14)
    axs[0].set_xlabel("Epoch")
    axs[0].set_ylabel("Loss")
    axs[0].legend()
    axs[0].grid(True, alpha=0.3)

    for values, label in ((train_acc, "Train Accuracy"), (test_acc, "Test Accuracy")):
        if values.size:
            axs[1].plot(np.arange(1, len(values) + 1), values, label=label, linewidth=2)
    axs[1].set_title("Accuracy Over Epochs", fontsize=14)
    axs[1].set_xlabel("Epoch")
    axs[1].set_ylabel("Accuracy")
    # Fractional accuracy fits [0, 1]. If any value exceeds 1 it is presumably
    # a percentage, so widen the axis instead of clipping the curve away.
    acc_peak = max((float(np.nanmax(v)) for v in (train_acc, test_acc) if v.size), default=0.0)
    axs[1].set_ylim(0, 100.0 if acc_peak > 1.0 else 1.0)
    axs[1].legend()
    axs[1].grid(True, alpha=0.3)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150)
        print(f"Saved loss/accuracy curves to {save_path}")
    else:
        plt.show()

def load_metrics(model_name: str, file_pwd: str) -> Optional[Dict]:
    """
    Load training metrics for ``model_name`` from ``file_pwd``.

    Candidate files are tried in this order, and the first one that loads
    successfully is returned:

    1. ``{model_name}_metrics.pkl``
    2. ``{model_name}_metrics.json``
    3. ``metrics_{model_name}.pkl``
    4. ``metrics_{model_name}.json``

    Missing files are skipped silently. Files that fail to load are reported
    and skipped.

    Returns:
        The deserialized object (typically a dict), or ``None`` if no
        candidate could be loaded.
    """
    import pickle
    import json

    possible_paths = [
        f"{file_pwd}/{model_name}_metrics.pkl",
        f"{file_pwd}/{model_name}_metrics.json",
        f"{file_pwd}/metrics_{model_name}.pkl",
        f"{file_pwd}/metrics_{model_name}.json",
    ]

    for path in possible_paths:
        try:
            if path.endswith('.pkl'):
                # SECURITY: pickle.load can execute arbitrary code. Only point
                # this at metrics files you produced yourself.
                with open(path, 'rb') as f:
                    metrics = pickle.load(f)
            else:
                # JSON is UTF-8 by spec. Without an explicit encoding the
                # platform default (e.g. cp1252) can mis-decode it.
                with open(path, 'r', encoding='utf-8') as f:
                    metrics = json.load(f)
        except FileNotFoundError:
            continue
        except Exception as e:
            # Best effort: report the bad candidate and try the next one.
            print(f"Error loading {path}: {e}")
            continue
        print(f"Loaded metrics from {path}")
        return metrics

    print("No metrics file found")
    return None

## Event Feature Analysis

In [None]:
def plot_event_features_comprehensive(df: pd.DataFrame, num_events: int = 5,
                                      features: Optional[List[str]] = None,
                                      log: bool = False):
    """
    Plot per-event histograms of source/target features, overall and by true class.

    Selects the first ``num_events`` event IDs, in order of first appearance
    in ``df``. For each event it:

    1. prints the per-class edge counts;
    2. draws one row per feature, with the pooled ``{feature}_source`` and
       ``{feature}_target`` values on the left and the same values split by
       ``true_label`` on the right;
    3. prints the source/target means and their correlation.

    Args:
        df: Edge table with ``event_id``, ``true_label`` and
            ``{feature}_source`` / ``{feature}_target`` columns.
        num_events: Maximum number of events to plot.
        features: Feature name prefixes. Defaults to ``['snr', 'eta', 'phi']``.
        log: If True, the ``snr`` histograms use a logarithmic *count*
            (y) axis. The SNR values themselves are not transformed.
    """
    if features is None:
        features = ['snr', 'eta', 'phi']
    if not features:
        return
    n_classes = len(class_names)
    n_features = len(features)

    events_to_plot = df['event_id'].unique()[:max(num_events, 0)]
    # Group only the selected events' rows rather than the whole table.
    # sort=False keeps first-appearance order; groupby drops NaN event IDs.
    subset = df[df['event_id'].isin(events_to_plot)]

    for event_id, event_data in subset.groupby('event_id', sort=False):
        print(f"\nEvent {event_id}: {len(event_data)} edges")
        class_counts = event_data['true_label'].value_counts().to_dict()
        for cls, count in sorted(class_counts.items()):
            name = class_names[int(cls)] if 0 <= int(cls) < n_classes else 'Unknown'
            print(f"  Class {cls} ({name}): {count} edges")

        fig, axes = plt.subplots(n_features, 2, figsize=(15, 5 * n_features), squeeze=False)
        true_labels = event_data['true_label'].to_numpy()
        class_masks = [true_labels == cls for cls in range(n_classes)]
        feature_arrays = {}

        for row, feature in enumerate(features):
            s_vals = event_data[f'{feature}_source'].to_numpy(dtype=np.float32)
            t_vals = event_data[f'{feature}_target'].to_numpy(dtype=np.float32)
            feature_arrays[feature] = (s_vals, t_vals)
            all_vals = np.concatenate((s_vals, t_vals))

            # `log` only switches the histogram count axis to log scale; the
            # feature values are plotted untransformed, so the x label stays
            # the plain feature name.
            use_log = log and feature.lower() == 'snr'
            x_label = feature.upper()
            y_label = 'Count (log scale)' if use_log else 'Count'

            ax_left = axes[row, 0]
            ax_left.hist(all_vals, bins=100, alpha=0.7, log=use_log)
            ax_left.set_title(f'Event {event_id}: Overall {x_label} Distribution\n({len(all_vals)} values, range: {np.nanmin(all_vals):.2f}–{np.nanmax(all_vals):.2f})')
            ax_left.set_xlabel(x_label); ax_left.set_ylabel(y_label); ax_left.grid(True, alpha=0.3)

            ax_right = axes[row, 1]
            for cls, mask in enumerate(class_masks):
                if not mask.any():
                    continue
                cls_vals = np.concatenate((s_vals[mask], t_vals[mask]))
                ax_right.hist(cls_vals, bins=100, alpha=0.5, color=class_colors[cls], label=f'Class {cls} ({class_names[cls]})', log=use_log)
            ax_right.set_title(f'Event {event_id}: {x_label} by True Class'); ax_right.set_xlabel(x_label); ax_right.set_ylabel(y_label); ax_right.legend(); ax_right.grid(True, alpha=0.3)
        plt.tight_layout(); plt.show(); plt.close()

        # Summary statistics, reusing the arrays extracted for plotting
        print(f"\nEvent {event_id} Feature Statistics:")
        for feature in features:
            s, t = feature_arrays[feature]
            corr = np.corrcoef(s, t)[0, 1] if len(s) > 1 else np.nan
            print(f"  {feature.upper()}: Source mean={s.mean():.3f}, Target mean={t.mean():.3f}, Correlation={corr:.3f}")

def _binned_error_rate(values: np.ndarray, is_correct: np.ndarray, n_edges: int = 10) -> Tuple[np.ndarray, np.ndarray]:
    """Compute the error rate in ``n_edges - 1`` equal-width bins over [0, max(values)].

    Bins are left-closed, except the last, which also includes the maximum.
    Non-finite and negative values are ignored, and empty bins yield NaN.

    Returns:
        ``(bin_centers, error_rates)``, each of length ``n_edges - 1``.
    """
    n_bins = n_edges - 1
    finite = np.isfinite(values)
    upper = float(values[finite].max()) if finite.any() else 0.0
    edges = np.linspace(0.0, upper, n_edges)
    in_range = finite & (values >= edges[0]) & (values <= edges[-1])
    # Clipping folds values equal to the top edge into the last bin instead of
    # dropping them.
    bin_idx = np.clip(np.searchsorted(edges, values[in_range], side='right') - 1, 0, n_bins - 1)
    totals = np.bincount(bin_idx, minlength=n_bins)
    n_errors = np.bincount(bin_idx, weights=(~is_correct[in_range]).astype(np.float64), minlength=n_bins)
    rates = np.divide(n_errors, totals, out=np.full(n_bins, np.nan), where=totals > 0)
    centers = (edges[:-1] + edges[1:]) / 2
    return centers, rates


def plot_error_analysis_by_features(df: pd.DataFrame, title_suffix: str = ""):
    """
    Draw a 2x2 error analysis of a prediction table.

    Panels:

    - (0,0) error rate vs spatial distance;
    - (0,1) error rate vs average SNR;
    - (1,0) a true-vs-predicted count heatmap of the misclassified rows;
    - (1,1) the confidence distribution of each (true -> predicted) error type.

    Args:
        df: Table with ``is_correct``, ``true_label``, ``pred_label``,
            ``confidence``, ``spatial_distance`` and ``avg_snr`` columns.
        title_suffix: When non-empty, shown in a figure title.
    """
    is_correct = df['is_correct'].to_numpy(dtype=np.bool_)
    if is_correct.all():
        print("No errors found for analysis"); return
    errors = df[~is_correct]
    n_classes = len(class_names)

    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    spatial = df['spatial_distance'].to_numpy(dtype=np.float32)
    avg_snr = df['avg_snr'].to_numpy(dtype=np.float32)

    # Error rate vs spatial distance
    centers, rates = _binned_error_rate(spatial, is_correct)
    axes[0, 0].plot(centers, rates, 'o-', linewidth=2)
    axes[0, 0].set_xlabel('Spatial Distance'); axes[0, 0].set_ylabel('Error Rate'); axes[0, 0].set_title('Error Rate vs Spatial Distance'); axes[0, 0].grid(alpha=0.3)

    # Error rate vs average SNR
    centers, rates = _binned_error_rate(avg_snr, is_correct)
    axes[0, 1].plot(centers, rates, 'o-', linewidth=2, color='orange')
    axes[0, 1].set_xlabel('Average SNR'); axes[0, 1].set_ylabel('Error Rate'); axes[0, 1].set_title('Error Rate vs Average SNR'); axes[0, 1].grid(alpha=0.3)

    # Error-type heatmap on the full n_classes x n_classes grid, so rows and
    # columns stay aligned with the class_names tick labels even when some
    # labels never occur among the errors.
    err_true = errors['true_label'].to_numpy().astype(np.int64)
    err_pred = errors['pred_label'].to_numpy().astype(np.int64)
    err_conf = errors['confidence'].to_numpy()
    valid = (err_true >= 0) & (err_true < n_classes) & (err_pred >= 0) & (err_pred < n_classes)
    error_counts = np.zeros((n_classes, n_classes), dtype=np.int64)
    np.add.at(error_counts, (err_true[valid], err_pred[valid]), 1)
    im = axes[1, 0].imshow(error_counts, cmap='Reds', aspect='auto')
    axes[1, 0].set_xlabel('Predicted Label'); axes[1, 0].set_ylabel('True Label'); axes[1, 0].set_title('Error Type Heatmap')
    axes[1, 0].set_xticks(range(n_classes)); axes[1, 0].set_xticklabels(class_names, rotation=45)
    axes[1, 0].set_yticks(range(n_classes)); axes[1, 0].set_yticklabels(class_names)
    plt.colorbar(im, ax=axes[1, 0])

    # Confidence distribution per (true -> predicted) error type
    for true_label in np.unique(err_true):
        err_mask = err_true == true_label
        for pred_label in np.unique(err_pred[err_mask]):
            conf_vals = err_conf[err_mask & (err_pred == pred_label)]
            if conf_vals.size > 0:
                axes[1, 1].hist(conf_vals, bins=20, alpha=0.6, label=f'{class_names[int(true_label)]}→{class_names[int(pred_label)]}')
    axes[1, 1].set_xlabel('Confidence'); axes[1, 1].set_ylabel('Count'); axes[1, 1].set_title('Confidence Distribution by Error Type'); axes[1, 1].legend(); axes[1, 1].grid(alpha=0.3)

    if title_suffix:
        fig.suptitle(f"Error Analysis{title_suffix}", fontsize=14)
        plt.tight_layout(rect=[0, 0, 1, 0.96])
    else:
        plt.tight_layout()
    plt.show(); plt.close()

## Multi-Model Comparison

In [None]:
def compare_multiple_models(model_names=("1_layer", "2_layer", "3_layer", "6_layer", "9_layer", "12_layer"),
                            file_pwd: Optional[str] = None):
    """
    Compute overall and per-class accuracy for several models.

    Args:
        model_names: Model identifiers. A purely numeric prefix before the
            first ``_`` (e.g. ``"6"`` in ``"6_layer"``) becomes the layer
            count; other names get NaN.
        file_pwd: Directory with the per-model parquet files. Defaults to the
            notebook-level ``file_pwd`` variable.

    Returns:
        DataFrame with one row per model that was found. Its columns are
        ``model``, ``layers``, ``accuracy``, ``total_samples`` and
        ``class_{i}_acc`` for each class. A class with no samples gets 0.

    Raises:
        NameError: If ``file_pwd`` is neither passed nor defined in the notebook.
    """
    if file_pwd is None:
        file_pwd = globals().get('file_pwd')
        if file_pwd is None:
            raise NameError("file_pwd is not defined; pass file_pwd=... or set it in the notebook")

    n_classes = len(class_names)
    comparison_data = []

    for model_name in model_names:
        try:
            # Only the two columns needed here are read from disk.
            df = load_comprehensive_data(model_name, file_pwd, columns=['true_label', 'is_correct'])
        except FileNotFoundError:
            print(f"File not found for {model_name}, skipping...")
            continue

        accuracy = df['is_correct'].mean()
        per_class = df.groupby('true_label')['is_correct'].mean()
        prefix = model_name.split('_')[0]

        record = {
            'model': model_name,
            'layers': int(prefix) if prefix.isdigit() else np.nan,
            'accuracy': accuracy,
            'total_samples': len(df),
        }
        record.update({f'class_{cls}_acc': per_class.get(cls, 0) for cls in range(n_classes)})
        comparison_data.append(record)

        print(f"Processed {model_name}: accuracy = {accuracy:.3%}")

    return pd.DataFrame(comparison_data)

def plot_model_comparison(comparison_df, file_pwd: Optional[str] = None):
    """
    Draw a 2x2 comparison of the models summarized by ``compare_multiple_models``.

    Panels:

    - (0,0) overall accuracy vs layer count;
    - (0,1) per-class accuracy bars per model;
    - (1,0) confidence histograms per model;
    - (1,1) error rate vs spatial distance per model.

    Args:
        comparison_df: DataFrame with ``model``, ``layers``, ``accuracy`` and
            ``class_{i}_acc`` columns.
        file_pwd: Directory with the per-model parquet files, used for the
            bottom two panels. Defaults to the notebook-level ``file_pwd``.
            If neither is set, those panels are left empty.
    """
    if file_pwd is None:
        file_pwd = globals().get('file_pwd')

    n_classes = len(class_names)
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # Overall accuracy by number of layers
    axes[0, 0].plot(comparison_df['layers'], comparison_df['accuracy'], 'o-', linewidth=2, markersize=8)
    axes[0, 0].set_xlabel('Number of Layers')
    axes[0, 0].set_ylabel('Overall Accuracy')
    axes[0, 0].set_title('Model Accuracy vs Number of Layers')
    axes[0, 0].grid(True, alpha=0.3)

    # Per-class accuracy: one bar group per model, one bar per class
    x_pos = np.arange(len(comparison_df))
    width = 0.15
    for i, (cls_name, color) in enumerate(zip(class_names, class_colors)):
        axes[0, 1].bar(x_pos + i * width, comparison_df[f'class_{i}_acc'],
                       width, label=cls_name, color=color, alpha=0.7)
    axes[0, 1].set_xlabel('Model')
    axes[0, 1].set_ylabel('Class Accuracy')
    axes[0, 1].set_title('Per-Class Accuracy by Model')
    axes[0, 1].set_xticks(x_pos + (n_classes - 1) * width / 2)  # centre of each bar group
    axes[0, 1].set_xticklabels(comparison_df['model'], rotation=45)
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)

    # Bottom row: load each model once, reading only the needed columns.
    if file_pwd is None:
        print("file_pwd is not set; skipping confidence and error-rate panels.")
    else:
        for model_name in comparison_df['model']:
            try:
                df = load_comprehensive_data(model_name, file_pwd,
                                             columns=['confidence', 'spatial_distance', 'is_correct'])
            except Exception as e:
                print(f"Could not load {model_name}: {e}")
                continue

            axes[1, 0].hist(df['confidence'], bins=30, alpha=0.6,
                            label=f"{model_name}", density=True)

            spatial = df['spatial_distance'].to_numpy(dtype=np.float64)
            is_correct = df['is_correct'].to_numpy(dtype=np.bool_)
            finite = np.isfinite(spatial)
            if not finite.any():
                continue
            edges = np.linspace(0.0, float(spatial[finite].max()), 10)
            n_bins = len(edges) - 1
            in_range = finite & (spatial >= edges[0]) & (spatial <= edges[-1])
            # Left-closed bins. Clipping puts values equal to the maximum into
            # the last bin instead of dropping them.
            bin_idx = np.clip(np.searchsorted(edges, spatial[in_range], side='right') - 1, 0, n_bins - 1)
            totals = np.bincount(bin_idx, minlength=n_bins)
            n_err = np.bincount(bin_idx, weights=(~is_correct[in_range]).astype(np.float64), minlength=n_bins)
            # Empty bins are NaN (gaps in the line) rather than a fake 0% error rate.
            error_rates = np.divide(n_err, totals, out=np.full(n_bins, np.nan), where=totals > 0)
            centers = (edges[:-1] + edges[1:]) / 2
            axes[1, 1].plot(centers, error_rates, 'o-', label=model_name, alpha=0.7)

    axes[1, 0].set_xlabel('Confidence')
    axes[1, 0].set_ylabel('Density')
    axes[1, 0].set_title('Confidence Distribution by Model')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)

    axes[1, 1].set_xlabel('Spatial Distance')
    axes[1, 1].set_ylabel('Error Rate')
    axes[1, 1].set_title('Error Rate vs Spatial Distance by Model')
    axes[1, 1].legend()
    axes[1, 1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()
    plt.close()

## Main Analysis Workflow

In [None]:
def run_comprehensive_model_analysis(df: pd.DataFrame, model_name: str = "",
                                     fast_mode: bool = True, max_plot_rows: int = 2_000_000,
                                     metrics: Optional[Dict] = None):
    """
    Run the single-model analysis: True-True thresholding, summary stats and plots.

    Steps:
      1. Downcast dtypes (``downcast_dtypes``) and build the ROC inputs once
         (``prepare_roc_data_from_df``); they are reused by every plot below.
      2. Apply a True-True-only threshold targeting 99% TPR
         (``apply_tt_only_threshold_fast``), which returns the adjusted table and
         the chosen threshold.
      3. Print overall accuracy and per-class accuracy/recall on the adjusted table.
      4. Plot training curves when ``metrics`` is given (always, it is cheap).
      5. Generate diagnostic plots. With ``fast_mode=True`` plots are subsampled and
         the unnormalized confusion matrix, feature error analysis and event-level
         plots are skipped. Otherwise all plots are generated, still row-capped
         through ``sample_idx``.

    Args:
        df: Comprehensive query table for one model.
        model_name: Used in printed headers and plot titles.
        fast_mode: Subsample plots and skip the heaviest ones.
        max_plot_rows: Row cap handed to ``downsample_indices`` for plotting.
        metrics: Training history for ``plot_loss_and_accuracy``; skipped when None.

    Returns:
        ``(df_optimized, tt_threshold)`` as produced by ``apply_tt_only_threshold_fast``.
    """
    title_suffix = f" - {model_name}" if model_name else ""
    print(f"=== COMPREHENSIVE MODEL ANALYSIS{title_suffix} ===")

    # Downcast dtypes for memory & speed
    df = downcast_dtypes(df)

    # Build the ROC arrays once; every plot below reuses them.
    y_true, y_score = prepare_roc_data_from_df(df)

    # Apply the True-True-only threshold (vectorized)
    df_optimized, tt_threshold = apply_tt_only_threshold_fast(df, y_true, y_score, tt_tpr_target=0.99)
    print(f"✅ True-True threshold for 99% TPR: {tt_threshold:.4f}")

    print(f"Total samples: {len(df_optimized):,}")
    print(f"Overall accuracy: {df_optimized['is_correct'].mean():.3%}")

    true_labels = df_optimized['true_label']
    pred_labels = df_optimized['pred_label']

    print("Class distribution:")
    for cls_idx, cname in enumerate(class_names):
        cls_mask = true_labels == cls_idx
        count = int(cls_mask.sum())
        acc = df_optimized.loc[cls_mask, 'is_correct'].mean() if count else 0
        true_pos = int((cls_mask & (pred_labels == cls_idx)).sum())
        recall = true_pos / count if count > 0 else 0
        print(f"  {cname}: {count:,} samples, accuracy: {acc:.3%}, recall: {recall:.3%}")

    tt_idx = class_names.index("True-True")
    tt_mask = true_labels == tt_idx
    tt_true_count = int(tt_mask.sum())
    if tt_true_count > 0:
        tt_recall = int((tt_mask & (pred_labels == tt_idx)).sum()) / tt_true_count
        print(f"\n✅ Verified True-True Recall: {tt_recall:.3%}")
    else:
        # Without True-True edges the recall is 0/0; report instead of printing NaN.
        print("\n⚠️  No True-True samples - cannot verify True-True recall")

    # Training curves are lightweight, so they are plotted regardless of fast_mode.
    if metrics is not None:
        print("\n📊 Plotting training history...")
        plot_loss_and_accuracy(metrics)
    else:
        print("\n⚠️  No metrics provided - skipping training history plots")

    # Row indices used to cap plotting cost (may be None; passed straight through).
    sample_idx = downsample_indices(len(df_optimized), max_rows=max_plot_rows)
    if fast_mode:
        print("⚡ Fast mode enabled: plotting will use subsampling or be skipped for heavy plots.")
    else:
        print("ℹ️ Full plotting mode: generating all plots (this may take time).")

    # Plotting: pass y_true, y_score to avoid repeated prepare calls
    if not fast_mode:
        # Full plot set (still safe-guarded by sample_idx for extremely large n)
        plot_roc_curves(y_true, y_score, sample_idx=sample_idx, title_suffix=title_suffix)
        plot_precision_recall_curves(y_true, y_score, sample_idx=sample_idx, title_suffix=title_suffix)
        plot_confusion_matrix_from_df(df_optimized, title_suffix=title_suffix, normalize=True, sample_idx=sample_idx)
        plot_confusion_matrix_from_df(df_optimized, title_suffix=title_suffix, normalize=False, sample_idx=sample_idx)
        plot_classification_confidence_analysis(df_optimized, sample_idx=sample_idx, title_suffix=title_suffix)
        plot_error_analysis_by_features(df_optimized, title_suffix=title_suffix)
        plot_class_wise_distributions(df_optimized, y_true, y_score, tt_threshold, sample_idx=sample_idx, title_suffix=title_suffix)
        plot_truth_type_distributions(y_true, y_score, tt_threshold, sample_idx=sample_idx, title_suffix=title_suffix)
        plot_event_features_comprehensive(df_optimized, num_events=5)
    else:
        # Fast mode: subsampled plots only; heavy plots are skipped.
        print("Generating key diagnostic plots (subsampled)...")
        sub_suffix = title_suffix + " (subsampled)"
        plot_roc_curves(y_true, y_score, sample_idx=sample_idx, title_suffix=sub_suffix)
        plot_precision_recall_curves(y_true, y_score, sample_idx=sample_idx, title_suffix=sub_suffix)
        plot_confusion_matrix_from_df(df_optimized, title_suffix=sub_suffix, normalize=True, sample_idx=sample_idx)
        plot_classification_confidence_analysis(df_optimized, sample_idx=sample_idx, title_suffix=sub_suffix)
        # heavy distribution plots (subsampled)
        plot_class_wise_distributions(df_optimized, y_true, y_score, tt_threshold, sample_idx=sample_idx, title_suffix=sub_suffix)
        plot_truth_type_distributions(y_true, y_score, tt_threshold, sample_idx=sample_idx, title_suffix=sub_suffix)
        # skip per-event comprehensive plots in fast mode (they can be enabled when needed)
        print("Note: event-level comprehensive plots skipped in fast mode to save time. Re-run with fast_mode=False to generate them.")

    return df_optimized, tt_threshold

## Cluster Visualization Functions

In [None]:
def _plot_snr_bars3d(eta_centers, phi_centers, eta_width, phi_width, values,
                     zlabel, title, cbar_label):
    """Draw one 3D bar chart of per-bin ``values`` on the (eta, phi) grid and show it."""
    eta_grid, phi_grid = np.meshgrid(eta_centers, phi_centers, indexing='ij')
    dz = values.ravel()
    # Each axis gets 80% of its own bin width (eta and phi bins generally differ).
    dx = eta_width * 0.8
    dy = phi_width * 0.8

    fig = plt.figure(figsize=(10, 7))
    ax = fig.add_subplot(111, projection='3d')

    # Normalize dz for color mapping
    norm = plt.Normalize(dz.min(), dz.max())

    # bar3d positions each bar by its lower corner, so shift by half the footprint
    # to centre every bar on its bin.
    ax.bar3d(
        eta_grid.ravel() - dx / 2,
        phi_grid.ravel() - dy / 2,
        np.zeros_like(dz),
        dx, dy, dz,
        color=cm.viridis(norm(dz)),
        shade=True,
    )

    ax.set_xlabel('η')
    ax.set_ylabel('φ')
    ax.set_zlabel(zlabel)
    ax.set_title(title)

    mappable = cm.ScalarMappable(norm=norm, cmap='viridis')
    fig.colorbar(mappable, ax=ax, shrink=0.6, label=cbar_label)
    plt.tight_layout()
    plt.show()

def plot_3d_snr_distributions(df, n_eta: int = 30, n_phi: int = 30):
    """
    Plot source-node SNR over the (eta, phi) plane as two 3D bar charts.

    Every row contributes its source node. Rows are binned on an
    ``n_eta`` x ``n_phi`` grid. The first figure shows the summed SNR per bin,
    the second the mean SNR per bin, with empty bins set to 0. Rows whose
    eta, phi or SNR is not finite are ignored, because ``np.histogram2d``
    cannot auto-range NaN coordinates.

    Args:
        df: table with ``eta_source``, ``phi_source`` and ``snr_source`` columns.
        n_eta: number of bins along eta (default 30).
        n_phi: number of bins along phi (default 30).
    """
    eta = np.asarray(df['eta_source'], dtype=float)
    phi = np.asarray(df['phi_source'], dtype=float)
    snr = np.asarray(df['snr_source'], dtype=float)
    finite = np.isfinite(eta) & np.isfinite(phi) & np.isfinite(snr)
    eta, phi, snr = eta[finite], phi[finite], snr[finite]

    hist_snr, eta_edges, phi_edges = np.histogram2d(
        eta, phi, bins=[n_eta, n_phi], weights=snr
    )
    # Reuse the same edges so counts line up bin-for-bin with hist_snr.
    counts, _, _ = np.histogram2d(eta, phi, bins=[eta_edges, phi_edges])
    mean_snr = np.divide(hist_snr, counts, out=np.zeros_like(hist_snr), where=counts > 0)

    eta_centers = 0.5 * (eta_edges[:-1] + eta_edges[1:])
    phi_centers = 0.5 * (phi_edges[:-1] + phi_edges[1:])
    eta_width = eta_edges[1] - eta_edges[0]
    phi_width = phi_edges[1] - phi_edges[0]

    # Plot 1: Total SNR Distribution
    _plot_snr_bars3d(eta_centers, phi_centers, eta_width, phi_width, hist_snr,
                     zlabel='Total SNR', title='3D SNR Distribution - Source Nodes',
                     cbar_label='SNR')

    # Plot 2: Mean SNR Distribution
    _plot_snr_bars3d(eta_centers, phi_centers, eta_width, phi_width, mean_snr,
                     zlabel='Mean SNR', title='3D Mean SNR Distribution - Source Nodes',
                     cbar_label='Mean SNR')

def plot_2d_class_distributions(df, classes=(1, 2, 3, 4)):
    """
    Scatter source-node (eta, phi) positions coloured by true edge class.

    Args:
        df: table with ``true_label``, ``eta_source`` and ``phi_source`` columns.
        classes: class indices into ``class_names`` / ``class_colors`` to show.
            Defaults to classes 1-4 (everything except Lone-Lone); pass
            ``range(len(class_names))`` to include all classes.
    """
    classes = list(classes)
    subset = df[df['true_label'].isin(classes)]

    # Colours and legend names come from the central definitions so this plot
    # stays in sync with every other class-coloured plot in the notebook.
    color_map = {cls: class_colors[cls] for cls in classes}

    plt.figure(figsize=(10, 8))
    plt.scatter(
        subset['eta_source'],
        subset['phi_source'],
        c=subset['true_label'].map(color_map),
        s=15,
        alpha=0.7,
    )

    # Proxy handles: one legend entry per requested class
    legend_elements = [
        Line2D([0], [0], marker='o', color='w', markerfacecolor=class_colors[cls],
               markersize=8, label=class_names[cls])
        for cls in classes
    ]

    plt.legend(handles=legend_elements, loc='upper right')
    plt.xlabel('η (Eta)')
    plt.ylabel('φ (Phi)')
    plt.title('Eta–Phi Distribution by True Class (Source Nodes)')
    plt.grid(True, alpha=0.3)
    plt.show()

def plot_comparison_visualizations(df):
    """
    Plot source vs target node positions, then correct vs incorrect predictions.

    Figure 1 shows side-by-side (eta, phi) scatters of source and target nodes,
    coloured by SNR. Figure 2 shows the source-node (eta, phi) positions of rows
    whose ``is_correct`` is True (green) or False (red). Incorrect points are drawn
    last, so they sit on top.

    Args:
        df: table with ``eta_/phi_/snr_source``, ``eta_/phi_/snr_target`` and
            ``is_correct`` columns.
    """
    # Source vs target comparison: one panel per node role
    _, axes = plt.subplots(1, 2, figsize=(15, 6))
    for ax, role in zip(axes, ('source', 'target')):
        sc = ax.scatter(df[f'eta_{role}'], df[f'phi_{role}'],
                        c=df[f'snr_{role}'], cmap='viridis', s=10, alpha=0.6)
        ax.set_xlabel('η')
        ax.set_ylabel('φ')
        ax.set_title(f'{role.capitalize()} Nodes - Colored by SNR')
        plt.colorbar(sc, ax=ax, label='SNR')

    plt.tight_layout()
    plt.show()

    # Correct vs incorrect predictions. The explicit == True / == False comparisons
    # keep rows with a missing is_correct out of both groups.
    plt.figure(figsize=(10, 8))

    correct = df[df['is_correct'] == True]  # noqa: E712
    incorrect = df[df['is_correct'] == False]  # noqa: E712

    plt.scatter(correct['eta_source'], correct['phi_source'],
                c='green', s=10, alpha=0.6, label='Correct Predictions')
    plt.scatter(incorrect['eta_source'], incorrect['phi_source'],
                c='red', s=10, alpha=0.6, label='Incorrect Predictions')

    plt.xlabel('η (Eta)')
    plt.ylabel('φ (Phi)')
    plt.title('Prediction Accuracy in Eta-Phi Space')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

## Execution: Single Model Analysis

In [None]:
if __name__ == "__main__":
    # === USER CONFIGURE HERE ===
    model_name = "fixed_generator_bs1"   # change as needed
    file_pwd = "/storage/mxg1065/fixed_batch_size_models"
    use_polars = True             # True = load with Polars (must be installed); False = pandas + pyarrow
    fast_mode = True              # True = fast (subsampled plots, skip heavy event-level plots)
    max_plot_rows = 2_000_000     # cap for plotting; adjust if you have more memory
    # ============================

    print(f"🚀 Starting analysis for {model_name} model...")

    metrics = load_metrics(model_name, file_pwd)

    df = load_comprehensive_data(model_name, file_pwd, columns=DEFAULT_COLUMNS, use_polars=use_polars)

    print("📈 Basic Statistics (pre-optimization):")
    print(f"Total samples: {len(df):,}")
    has_is_correct = 'is_correct' in df.columns
    if has_is_correct:
        print(f"Overall accuracy: {df['is_correct'].mean():.3%}")
    print("\nClass Distribution (pre-optimization):")
    # Compute the set of present labels once instead of once per class.
    present_labels = set(df['true_label'].unique())
    for cls_idx, cls_name in enumerate(class_names):
        if cls_idx not in present_labels:
            print(f"  {cls_name}: 0 samples")
            continue
        cls_mask = df['true_label'] == cls_idx
        count = int(cls_mask.sum())
        if has_is_correct:
            acc = df.loc[cls_mask, 'is_correct'].mean() if count > 0 else 0
            print(f"  {cls_name}: {count:,} samples, accuracy: {acc:.3%}")
        else:
            print(f"  {cls_name}: {count:,} samples")

    # Run the fast/optimized analysis
    df_optimized, tt_threshold = run_comprehensive_model_analysis(df, model_name=model_name,
                                                                  fast_mode=fast_mode,
                                                                  max_plot_rows=max_plot_rows,
                                                                  metrics=metrics)

    print("\nAnalysis complete.")
    print(f"True-True threshold: {tt_threshold:.4f}")

## Execution: Cluster Visualizations

In [None]:
# Cluster visualizations on the raw (pre-threshold) table loaded in the cell above.
print("🌌 Creating cluster visualizations...")
plot_3d_snr_distributions(df)
plot_2d_class_distributions(df)
plot_comparison_visualizations(df)

## Execution: Multi-Model Comparison

In [None]:
# print("🔄 Comparing multiple models...")
# comparison_df = compare_multiple_models(["1_layer", "2_layer", "3_layer", "6_layer", "9_layer", "12_layer"])
# plot_model_comparison(comparison_df)

## Event-Level Deep Dive

In [None]:
print("🔍 Analyzing specific events...")
# Candidate events ranked by row (edge) count; only the top few get a deep dive.
n_deep_dive_events = 2
interesting_events = df['event_id'].value_counts().head(5).index
for event_id in interesting_events[:n_deep_dive_events]:
    event_data = df[df['event_id'] == event_id]
    print(f"\n🔍 Event {event_id} Deep Dive:")
    print(f"   Total edges: {len(event_data)}")
    print(f"   Accuracy: {event_data['is_correct'].mean():.3%}")

    # Plot feature distributions for this event (reuse the slice computed above)
    plot_event_features_comprehensive(event_data,
                                      num_events=1,
                                      features=['snr', 'eta', 'phi'])

## Summary

In [None]:
# Summary banner. The multi-model comparison cell is commented out by default, and
# fast_mode skips the feature error analysis and event-level comprehensive plots.
print("🎉 Analysis completed successfully!")
print("\n📊 Generated Analysis:")
print("✅ ROC Curves & Precision-Recall Curves")
print("✅ Confusion Matrices & Classification Reports")
print("✅ Confidence Analysis & Error Analysis")
print("✅ 3D SNR Distributions & Cluster Visualizations")
print("✅ Multi-Model Comparisons (only if the comparison cell was enabled)")
print("✅ Event-Level Deep Dives")