## Compare pooled vs class-conditional topology scoring

This notebook quantifies the impact of enabling **class-conditional** topology scoring (`cfg.detector.topo_class_conditional=True`) versus the current **pooled** (all-class) scoring (`False`).

- **Dataset**: `synthetic_shapes_3class` from `DATASET_REGISTRY` (small, local-friendly; avoids MNIST)
- **Pipeline**: uses `src.api.run_pipeline()` (same wiring as the rest of the repo)
- **Outputs**: adversarial detection metrics + wall-clock runtime of the full `run_pipeline()` call (model training included)

### Setup

We keep the experiment small for local runtime by:
- using `n_samples=600`
- limiting `max_points_for_scoring` (persistent homology, PH, is the expensive part)
- using FGSM (fast)
- setting a modest training epoch count

In [4]:
import time
from dataclasses import replace
import numpy as np

import pathlib
import sys 

# Put the parent of the current working directory on the import path (once)
# so that the `src.*` imports below can be resolved.
repo_root = pathlib.Path("..").resolve()
repo_root_str = str(repo_root)
if repo_root_str not in sys.path:
    sys.path.append(repo_root_str)
from src.api import run_pipeline
from src.utils import ExperimentConfig, DataConfig, ModelConfig, AttackConfig, GraphConfig, DetectorConfig

def _summarize(res, *, wall_s: float) -> dict:
    # res.eval.metrics already contains the repo's standard metrics.
    m = dict(res.eval.metrics)
    out = {
        'wall_s': float(wall_s),
        'threshold': float(getattr(res.detector, 'threshold', np.nan)),
        'roc_auc': m.get('roc_auc'),
        'pr_auc': m.get('pr_auc'),
        'fpr_at_tpr95': m.get('fpr_at_tpr95'),
        'accuracy': m.get('accuracy'),
        'precision': m.get('precision'),
        'recall': m.get('recall'),
        'f1': m.get('f1'),
    }
    return out


def run_one(cfg: ExperimentConfig, *, label: str, max_points_for_scoring: int = 200) -> dict:
    """Run the pipeline once with ``cfg`` and return a timed, labelled summary.

    Parameters
    ----------
    cfg
        Experiment configuration passed to ``run_pipeline`` unchanged.
    label
        Name stored under ``'label'`` in the returned dict.
    max_points_for_scoring
        Forwarded to ``run_pipeline`` after conversion to ``int``.

    Returns
    -------
    dict
        The output of ``_summarize`` extended with ``label`` and the two
        class-conditional detector settings read from ``cfg.detector``.
    """
    started = time.perf_counter()
    result = run_pipeline(
        dataset_name='synthetic_shapes_3class',
        model_name='CNN',
        cfg=cfg,
        # Plots and the OOD pass are switched off to keep runtime manageable.
        make_plots=False,
        run_ood=False,
        eval_only_successful_attacks=True,
        max_points_for_scoring=int(max_points_for_scoring),
    )
    elapsed = time.perf_counter() - started

    summary = _summarize(result, wall_s=elapsed)
    detector_cfg = cfg.detector
    # Record which detector variant produced these numbers; the fallbacks
    # apply when the config object lacks the attribute.
    summary.update(
        label=str(label),
        topo_class_conditional=bool(getattr(detector_cfg, 'topo_class_conditional', False)),
        topo_class_scoring_mode=str(getattr(detector_cfg, 'topo_class_scoring_mode', 'min_over_classes')),
    )
    return summary

### Base config (kept identical across runs)

We only change `cfg.detector.topo_class_conditional` between the two runs.

In [5]:
# Base experiment config (local-friendly), assembled from named sub-configs.
data_cfg = DataConfig(n_samples=600, noise=0.1, train_ratio=0.6, val_ratio=0.2, test_ratio=0.2)
model_cfg = ModelConfig(epochs=20, batch_size=64, learning_rate=1e-3, weight_decay=1e-4)
attack_cfg = AttackConfig(attack_type='fgsm', epsilon=0.05)

graph_cfg = GraphConfig(
    space='feature',
    feature_layer='penultimate',
    # Topology stays enabled; neighborhood size is kept small for speed.
    use_topology=True,
    topo_k=30,
    topo_preprocess='pca',
    topo_pca_dim=10,
    topo_maxdim=1,
    # Optional speed knobs
    k=10,
    use_tangent=True,
    tangent_k=20,
)

detector_cfg = DetectorConfig(
    detector_type='topology_score',
    topo_percentile=95.0,
    topo_cov_shrinkage=1e-3,
    # The class-conditional settings below are the ones overridden per run.
    topo_class_conditional=False,
    topo_class_scoring_mode='min_over_classes',
    topo_min_clean_per_class=5,
)

cfg_base = ExperimentConfig(
    seed=42,
    device='cpu',
    data=data_cfg,
    model=model_cfg,
    attack=attack_cfg,
    graph=graph_cfg,
    detector=detector_cfg,
)

cfg_base

ExperimentConfig(data=DataConfig(n_samples=600, noise=0.1, random_state=42, train_ratio=0.6, val_ratio=0.2, test_ratio=0.2, root='./data', download=False), model=ModelConfig(input_dim=2, hidden_dims=[64, 32], output_dim=2, activation='relu', learning_rate=0.001, epochs=20, batch_size=64, weight_decay=0.0001, random_state=42), attack=AttackConfig(attack_type='fgsm', epsilon=0.05, num_steps=10, step_size=0.01, random_start=True), ood=OODConfig(enabled=False, method='feature_shuffle', severity=1.0, seed=None, batch_size=128, patch_size=4, blur_kernel_size=5, blur_sigma=1.0, saltpepper_p=0.05), graph=GraphConfig(k=10, sigma=None, space='feature', feature_layer='penultimate', normalized_laplacian=True, use_diffusion=False, diffusion_components=10, use_tangent=True, tangent_k=20, tangent_dim=None, tangent_var_threshold=0.9, tangent_dim_min=2, tangent_dim_max=None, use_topology=True, topo_k=30, topo_maxdim=1, topo_metric='euclidean', topo_thresh=None, topo_min_persistence=1e-06, topo_preproce

### Run 1: pooled topology scoring (current default)

In [6]:
# Pooled baseline: the flag is set to False explicitly on a copy of the base detector config.
pooled_detector = replace(cfg_base.detector, topo_class_conditional=False)
cfg_pooled = replace(cfg_base, detector=pooled_detector)
res_pooled = run_one(cfg_pooled, label='pooled', max_points_for_scoring=200)
res_pooled

Epoch [10/20] Train Loss: 0.6179, Train Acc: 83.33%, Val Loss: 0.5454, Val Acc: 91.67%
Epoch [20/20] Train Loss: 0.0086, Train Acc: 100.00%, Val Loss: 0.0166, Val Acc: 100.00%


{'wall_s': 14.791697791000047,
 'threshold': 3.1898625580667903,
 'roc_auc': np.float64(0.8090277777777779),
 'pr_auc': np.float64(0.8569394799631662),
 'fpr_at_tpr95': np.float64(0.9833333333333333),
 'accuracy': 0.8541666666666666,
 'precision': np.float64(0.8235294117647058),
 'recall': np.float64(0.7777777777777778),
 'f1': np.float64(0.7999999999999999),
 'label': 'pooled',
 'topo_class_conditional': False,
 'topo_class_scoring_mode': 'min_over_classes'}

### Run 2: class-conditional topology scoring

This fits one Gaussian per class on clean topology features and scores by the min distance over classes (default).

In [7]:
# Class-conditional variant: same base config, with only the detector's
# class-conditional flag and scoring mode overridden.
cc_detector = replace(
    cfg_base.detector,
    topo_class_conditional=True,
    topo_class_scoring_mode='min_over_classes',
)
cfg_cc = replace(cfg_base, detector=cc_detector)
res_cc = run_one(cfg_cc, label='class_conditional', max_points_for_scoring=200)
res_cc

Epoch [10/20] Train Loss: 0.6179, Train Acc: 83.33%, Val Loss: 0.5454, Val Acc: 91.67%
Epoch [20/20] Train Loss: 0.0086, Train Acc: 100.00%, Val Loss: 0.0166, Val Acc: 100.00%


{'wall_s': 13.754407624999999,
 'threshold': 3.2552150433325644,
 'roc_auc': np.float64(0.8241898148148149),
 'pr_auc': np.float64(0.8679843621741854),
 'fpr_at_tpr95': np.float64(0.9583333333333334),
 'accuracy': 0.8697916666666666,
 'precision': np.float64(0.8615384615384616),
 'recall': np.float64(0.7777777777777778),
 'f1': np.float64(0.8175182481751826),
 'label': 'class_conditional',
 'topo_class_conditional': True,
 'topo_class_scoring_mode': 'min_over_classes'}

### Summary + deltas

In [8]:
results = [res_pooled, res_cc]
summary_keys = ['roc_auc', 'pr_auc', 'fpr_at_tpr95', 'accuracy', 'f1', 'threshold', 'wall_s']
for run in results:
    # One paragraph per run: the label, then one right-aligned "key: value"
    # line per metric, then a blank separator line.
    report = [run['label']]
    report.extend(f"  {key:>14}: {run.get(key)}" for key in summary_keys)
    print('\n'.join(report), end='\n\n')

def _delta(a, b, k):
    try:
        return float(b[k]) - float(a[k])
    except Exception:
        return None

print('Delta (class_conditional - pooled)')
delta_keys = ('roc_auc', 'pr_auc', 'fpr_at_tpr95', 'accuracy', 'f1', 'wall_s')
for metric in delta_keys:
    # _delta yields None when a metric is unavailable for either run.
    change = _delta(res_pooled, res_cc, metric)
    print(f"  {metric:>14}: {change}")

pooled
         roc_auc: 0.8090277777777779
          pr_auc: 0.8569394799631662
    fpr_at_tpr95: 0.9833333333333333
        accuracy: 0.8541666666666666
              f1: 0.7999999999999999
       threshold: 3.1898625580667903
          wall_s: 14.791697791000047

class_conditional
         roc_auc: 0.8241898148148149
          pr_auc: 0.8679843621741854
    fpr_at_tpr95: 0.9583333333333334
        accuracy: 0.8697916666666666
              f1: 0.8175182481751826
       threshold: 3.2552150433325644
          wall_s: 13.754407624999999

Delta (class_conditional - pooled)
         roc_auc: 0.015162037037037002
          pr_auc: 0.01104488221101918
    fpr_at_tpr95: -0.02499999999999991
        accuracy: 0.015625
              f1: 0.01751824817518266
          wall_s: -1.037290166000048


### Optional: predicted-class scoring mode

If you want to test the alternative scoring mode that conditions on the classifier prediction:
- set `cfg.detector.topo_class_scoring_mode='predicted_class'`
- keep `cfg.detector.topo_class_conditional=True`

Note: this mode can behave differently for targeted attacks that land on the target class manifold.

In [None]:
# Optional extra run: class-conditional scoring with the 'predicted_class' mode.
# This cell is live code, so it executes (and runs the pipeline again) whenever it is run.
cc_pred_detector = replace(
    cfg_base.detector,
    topo_class_conditional=True,
    topo_class_scoring_mode='predicted_class',
)
cfg_cc_pred = replace(cfg_base, detector=cc_pred_detector)
res_cc_pred = run_one(cfg_cc_pred, label='class_conditional_predicted_class', max_points_for_scoring=200)
res_cc_pred