# MUTAG Generator Evaluation

This notebook compiles a comprehensive evaluation of the initial class-conditional generators trained on the MUTAG dataset. It aggregates metrics from the explainee classifier, the SimGNN distance approximator, and the generators themselves to help diagnose explanation quality and training health.


## Notebook goals

* Load the MUTAG explainee, distance model, and per-class generators.
* Evaluate the explainee's predictive performance on a hold-out split.
* Measure SimGNN graph edit distance (GED) approximation quality.
* Score each generator with respect to its loss terms, fidelity, and diversity.
* Visualize generated graphs alongside real samples from the dataset.
* Summarize parameter counts and other diagnostics useful for debugging.

> **Tip:** Run the end-to-end training pipeline (`python -m src.pipeline.pipeline --config config/mutag.yaml`) beforehand so that checkpoints exist for all components.


## 1. Environment and configuration


In [None]:

import os
import sys
import math
import json
import yaml
import random
from pathlib import Path
from collections import defaultdict, Counter

import torch
import torch.nn.functional as F

try:
    import pandas as pd
except ImportError:  # fallback for minimal environments
    pd = None

try:
    import matplotlib.pyplot as plt
    import seaborn as sns
except ImportError:
    plt = None
    sns = None

try:
    import numpy as np
except ImportError:
    np = None

# The notebook lives one level below the repository root; resolve that root
# once and make the `src.*` packages importable from it.
PROJECT_ROOT = Path("../").resolve()
_project_root_str = str(PROJECT_ROOT)
if _project_root_str not in sys.path:
    sys.path.append(_project_root_str)

# Every other location is derived from the project root.
CONFIG_PATH = PROJECT_ROOT.joinpath("config", "mutag.yaml")
CHECKPOINT_ROOT = PROJECT_ROOT.joinpath("checkpoints")
DATA_ROOT = PROJECT_ROOT.joinpath("data")

# Echo the environment so a saved notebook records what it ran against.
print(f"Project root: {PROJECT_ROOT}")
print(f"Using config: {CONFIG_PATH}")
print(f"Torch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")



In [None]:

# Parse the experiment configuration from CONFIG_PATH.
with open(CONFIG_PATH, "r") as f:
    cfg = yaml.safe_load(f)

_experiment_cfg = cfg.get("experiment", {})

# Seed the Python and torch RNGs (and every CUDA device when present).
seed = _experiment_cfg.get("seed", 42)
random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# The GPU is used only when it is both available and requested as "cuda".
_use_cuda = torch.cuda.is_available() and _experiment_cfg.get("device", "cuda") == "cuda"
device = torch.device("cuda") if _use_cuda else torch.device("cpu")

print(f"Random seed set to {seed}")
print(f"Evaluation device: {device}")


## 2. Dataset utilities


In [None]:

from typing import Dict, List, Tuple

try:
    from torch_geometric.loader import DataLoader
    from torch_geometric.data import Data, Batch
    from torch_geometric.utils import to_networkx
except ImportError as exc:
    raise ModuleNotFoundError(
        "torch_geometric is required for this notebook. Install it before proceeding."
    ) from exc

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix

from src.datasets.mutag import load_mutag
from src.datasets.ged_dataset import GEDDataset, collate_pairs, ensure_ged_dataset
from src.models.explainee_gnn import ExplaineeGIN
from src.models.sim_gnn import SimGNN
from src.models.generator import GraphGenerator
from src.models.adapter import GeneratorAdapter
from src.models.losses import (
    SoftContrastiveEmbedLoss,
    PredictionConfidenceLoss,
    EdgePenalty,
    _build_data_from_gen_output,
)
from src.utils.embeddings import compute_classwise_means
from src.trainers.train_distance import evaluate as evaluate_distance_model


def prepare_mutag_splits(test_ratio: float = 0.2, batch_size: int = 32, shuffle: bool = True, seed: int = 42):
    """Load MUTAG and build stratified train/test loaders.

    Args:
        test_ratio: Fraction of graphs assigned to the test split.
        batch_size: Batch size used by both loaders.
        shuffle: Whether the training loader shuffles; the test loader never does.
        seed: ``random_state`` forwarded to ``train_test_split``.

    Returns:
        A ``(dataset, train_loader, test_loader)`` tuple.
    """
    dataset = load_mutag(root=str(DATA_ROOT / "MUTAG"))
    indices = list(range(len(dataset)))
    # Stratification is driven by the per-graph label.
    graph_labels = [int(graph.y) for graph in dataset]
    train_idx, test_idx = train_test_split(
        indices,
        test_size=test_ratio,
        stratify=graph_labels,
        random_state=seed,
    )
    train_loader = DataLoader([dataset[i] for i in train_idx], batch_size=batch_size, shuffle=shuffle)
    test_loader = DataLoader([dataset[i] for i in test_idx], batch_size=batch_size, shuffle=False)
    return dataset, train_loader, test_loader


def dataset_summary(dataset) -> "pd.DataFrame | list":
    """Collect per-graph size and label information.

    The return annotation is a string on purpose: the notebook falls back to
    ``pd = None`` when pandas is missing, and an eagerly evaluated
    ``pd.DataFrame`` annotation would raise at definition time in that case.

    Args:
        dataset: Iterable of graphs exposing ``num_nodes``, ``edge_index`` and
            optionally a single-element ``y``.

    Returns:
        A DataFrame with columns ``num_nodes``, ``num_edges`` and ``label`` when
        pandas is available, otherwise the equivalent list of dicts. ``label``
        is ``None`` for graphs without a label.
    """
    rows = []
    for graph in dataset:
        # `y` may be absent or explicitly None; both mean "no label".
        label = getattr(graph, "y", None)
        rows.append({
            "num_nodes": int(graph.num_nodes),
            "num_edges": int(graph.edge_index.size(1)),
            "label": int(label.item()) if label is not None else None,
        })
    return pd.DataFrame(rows) if pd is not None else rows


In [None]:

# Build the 80/20 stratified split once, using the explainee's configured batch
# size; the loaders created here are reused by the later sections.
dataset, train_loader, test_loader = prepare_mutag_splits(
    test_ratio=0.2,
    batch_size=cfg["explainee"].get("batch_size", 32),
    seed=seed,
)

print(f"Loaded MUTAG dataset with {len(dataset)} graphs")
print(f"Train loader batches: {len(train_loader)} | Test loader batches: {len(test_loader)}")

# Summary statistics over the whole dataset (not just one split); shown only
# when pandas is available.
mutag_stats = dataset_summary(dataset)
if pd is not None:
    display(mutag_stats.describe(include='all'))


## 3. Load explainee and evaluate performance


In [None]:

def build_explainee(dataset, cfg_explainee: Dict, device: torch.device) -> ExplaineeGIN:
    """Instantiate the explainee GIN and restore its checkpoint when one exists.

    Args:
        dataset: Indexable dataset; the feature width is read from ``dataset[0].x``.
        cfg_explainee: Explainee section of the config (hyper-parameters, ``save_path``).
        device: Device the model is moved to and the checkpoint is mapped onto.

    Returns:
        The model in eval mode. If the checkpoint file is missing, a warning is
        printed and the model keeps its freshly initialised weights.
    """
    # Falls back to the dataset-level class count from the global `cfg`.
    default_num_classes = cfg["dataset"].get("num_classes", 2)
    hyperparams = {
        "in_dim": dataset[0].x.size(1),
        "hidden_dim": cfg_explainee.get("hidden_dim", 32),
        "num_layers": cfg_explainee.get("num_layers", 2),
        "dropout": cfg_explainee.get("dropout", 0.2),
        "num_classes": cfg_explainee.get("num_classes", default_num_classes),
    }
    model = ExplaineeGIN(**hyperparams).to(device)

    ckpt_path = PROJECT_ROOT / cfg_explainee.get("save_path", "models/explainees/gin_mutag.pt")
    if not ckpt_path.exists():
        print(f"⚠️ Explainee checkpoint missing: {ckpt_path}. Using randomly initialised weights.")
    else:
        state = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(state)
        print(f"Loaded explainee checkpoint from {ckpt_path}")
    model.eval()
    return model


def evaluate_explainee(model: ExplaineeGIN, loader: DataLoader, device: torch.device) -> Dict[str, float]:
    """Score the explainee on ``loader`` and show its metrics and confusion matrix.

    Args:
        model: Explainee classifier; called as ``model(batch)`` and expected to
            return per-graph logits.
        loader: Loader yielding batches that carry labels in ``batch.y``.
        device: Device each batch is moved to before the forward pass.

    Returns:
        A dict with the scalar metrics ``loss`` (mean cross-entropy),
        ``accuracy``, ``precision``, ``recall`` and ``f1`` (weighted averages),
        plus the CPU tensors ``logits``, ``labels`` and ``probs``.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    total = 0
    all_logits, all_labels = [], []

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            logits = model(batch)
            # Sum per batch and divide once at the end so uneven batches are weighted correctly.
            loss = F.cross_entropy(logits, batch.y, reduction='sum')
            total_loss += loss.item()
            total += batch.y.size(0)
            all_logits.append(logits.cpu())
            all_labels.append(batch.y.cpu())

    if not all_logits:
        raise ValueError("evaluate_explainee received an empty loader; nothing to evaluate.")

    logits = torch.cat(all_logits)
    labels = torch.cat(all_labels)
    probs = torch.softmax(logits, dim=-1)
    preds = torch.argmax(probs, dim=-1)

    metrics = {
        "loss": total_loss / max(total, 1),
        "accuracy": accuracy_score(labels, preds),
        "precision": precision_score(labels, preds, average='weighted', zero_division=0),
        "recall": recall_score(labels, preds, average='weighted', zero_division=0),
        "f1": f1_score(labels, preds, average='weighted', zero_division=0),
    }

    # Fix the label set to the model's output width so the matrix always has one
    # row/column per class, even if a class is absent from this split.
    class_ids = list(range(logits.size(-1)))
    cm = confusion_matrix(labels, preds, labels=class_ids)
    if pd is not None:
        display(pd.DataFrame(metrics, index=["Explainee"]))
        display(
            pd.DataFrame(
                cm,
                index=[f"true_{c}" for c in class_ids],
                columns=[f"pred_{c}" for c in class_ids],
            )
        )
    else:
        print(metrics)
        print("Confusion matrix:\n", cm)

    return {
        **metrics,
        "logits": logits,
        "labels": labels,
        "probs": probs,
    }


In [None]:

# Build the explainee from its config section and score it on the held-out split.
# `explainee_model` is reused by the generator loss and fidelity computations below.
explainee_cfg = cfg.get("explainee", {})
explainee_model = build_explainee(dataset, explainee_cfg, device)
explainee_eval = evaluate_explainee(explainee_model, test_loader, device)


## 4. SimGNN distance model evaluation


In [None]:

def load_distance_model(cfg_distance: Dict, dataset, device: torch.device):
    """Load (or first generate) the GED pair dataset and the SimGNN distance model.

    Args:
        cfg_distance: Distance-model section of the config (``data_path``,
            ``batch_size``, model hyper-parameters, ``save_path``).
        dataset: Graph dataset used to generate GED pairs when the pair file is missing.
        device: Device the model is moved to and the checkpoint is mapped onto.

    Returns:
        ``(model, ged_dataset, loader)`` with the model in eval mode. If the
        checkpoint file is missing, a warning is printed and the model keeps
        its freshly initialised weights.
    """
    # torch_geometric's DataLoader replaces any user-supplied `collate_fn` with
    # its own collater, so `collate_pairs` would be silently ignored. The plain
    # PyTorch loader honours it.
    from torch.utils.data import DataLoader as PairDataLoader

    ged_path = PROJECT_ROOT / cfg_distance.get("data_path", "data/mutag_ged.pt")
    if not ged_path.exists():
        print(f"Generating GED pairs at {ged_path} ...")
        ensure_ged_dataset(list(dataset), out_path=str(ged_path), alpha=cfg_distance.get("alpha", 0.5))

    ged_dataset = GEDDataset.load(ged_path)
    loader = PairDataLoader(
        ged_dataset,
        batch_size=cfg_distance.get("batch_size", 32),
        shuffle=False,
        collate_fn=collate_pairs,
    )

    # Feature width is taken from the first graph of the first pair.
    in_dim = ged_dataset[0][0].x.size(1)
    model = SimGNN(
        in_dim,
        hidden_dim=cfg_distance.get("hidden_dim", 64),
        tensor_channels=cfg_distance.get("tensor_channels", 8),
        use_tensor=cfg_distance.get("use_tensor", True),
    ).to(device)

    ckpt_path = PROJECT_ROOT / cfg_distance.get("save_path", "models/distances/simgnn_mutag.pt")
    if ckpt_path.exists():
        state = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(state)
        print(f"Loaded SimGNN checkpoint from {ckpt_path}")
    else:
        print(f"⚠️ Distance model checkpoint missing: {ckpt_path}. Using randomly initialised weights.")
    model.eval()
    return model, ged_dataset, loader


def summarize_distance_metrics(model: SimGNN, loader, device: torch.device) -> Dict[str, float]:
    """Evaluate the distance model and show the resulting metrics.

    Args:
        model: SimGNN instance to evaluate.
        loader: Loader of GED pairs passed through to the trainer's ``evaluate``.
        device: Evaluation device.

    Returns:
        The metrics mapping returned by ``evaluate_distance_model``; it is
        rendered as a one-row table with pandas, or printed without it.
    """
    metrics = evaluate_distance_model(model, loader, device)
    if pd is None:
        print(metrics)
        return metrics
    display(pd.DataFrame(metrics, index=["SimGNN"]))
    return metrics


In [None]:

# Load SimGNN together with its GED pair dataset, then report its approximation metrics.
# `simgnn_model` is reused below for GED-to-real and diversity scores.
distance_cfg = cfg.get("distance_model", {})
simgnn_model, ged_dataset, ged_loader = load_distance_model(distance_cfg, dataset, device)
distance_metrics = summarize_distance_metrics(simgnn_model, ged_loader, device)


## 5. Generator loading and helper utilities


In [None]:
from dataclasses import dataclass

@dataclass
class GeneratorArtifacts:
    """One per-class generator together with the checkpoint path it is tied to."""

    # Generator instance for this class.
    model: GraphGenerator
    # Checkpoint path associated with the generator (see `load_generators` for how it is chosen).
    state_path: Path
    # Index of the class this generator is associated with.
    target_class: int


def infer_generator_spec(dataset) -> Dict:
    """Derive the generator's constructor arguments from the dataset.

    Args:
        dataset: Graph dataset handed to ``GeneratorAdapter``.

    Returns:
        A dict holding the adapter's ``max_nodes``, ``num_cont_node_feats``,
        ``dis_node_blocks``, ``num_cont_edge_feats`` and ``dis_edge_blocks``.
        The spec is also printed.
    """
    adapter = GeneratorAdapter(dataset)
    field_names = (
        "max_nodes",
        "num_cont_node_feats",
        "dis_node_blocks",
        "num_cont_edge_feats",
        "dis_edge_blocks",
    )
    spec = {name: getattr(adapter, name) for name in field_names}
    print("Generator spec:", spec)
    return spec


def load_generators(cfg_gen: Dict, dataset, device: torch.device):
    """Create one generator per class and restore its checkpoint when available.

    The checkpoint path comes from the ``save_path`` template, where every
    ``X`` is replaced by the class index. If the template contains
    ``generator-class``, the ``generator_class`` spelling is tried as a fallback.

    Args:
        cfg_gen: Generator section of the config.
        dataset: Dataset used to infer the generator spec.
        device: Device the generators are moved to.

    Returns:
        A list of ``GeneratorArtifacts`` in class order. A generator whose
        checkpoint is missing keeps its initial weights (a warning is printed)
        and records the primary expected path as ``state_path``.
    """
    spec = infer_generator_spec(dataset)
    save_template = cfg_gen.get("save_path", "checkpoints/generators/generator_classX.pt")
    fallback_template = (
        save_template.replace("generator-class", "generator_class")
        if "generator-class" in save_template
        else None
    )
    num_classes = cfg["dataset"].get("num_classes", 2)

    artifacts = []
    for class_idx in range(num_classes):
        class_tag = str(class_idx)
        primary_path = PROJECT_ROOT / save_template.replace("X", class_tag)
        checked_paths = [primary_path]
        resolved_path = primary_path if primary_path.exists() else None

        # Only consult the fallback spelling when the primary path is absent.
        if resolved_path is None and fallback_template is not None:
            alt_path = PROJECT_ROOT / fallback_template.replace("X", class_tag)
            checked_paths.append(alt_path)
            if alt_path.exists():
                print(f"Fallback: resolved generator checkpoint for class {class_idx} to {alt_path}")
                resolved_path = alt_path

        gen = GraphGenerator(
            batch_size=cfg_gen.get("batch_size", 1),
            temperature=cfg_gen.get("temperature", 1.0),
            **spec,
        ).to(device)

        if resolved_path is None or not resolved_path.exists():
            missing_list = ", ".join(str(p) for p in checked_paths)
            print(f"⚠️ Generator checkpoint missing for class {class_idx}. Checked: {missing_list}")
        else:
            state = torch.load(resolved_path, map_location=device)
            gen.load_state_dict(state)
            print(f"Loaded generator for class {class_idx} from {resolved_path}")
        gen.eval()

        recorded_path = resolved_path if resolved_path is not None else primary_path
        artifacts.append(
            GeneratorArtifacts(model=gen, state_path=recorded_path, target_class=class_idx)
        )
    return artifacts


# Instantiate every per-class generator up front; `generators` is a list in class order.
generator_cfg = cfg.get("generator", {})
generators = load_generators(generator_cfg, dataset, device)


### Class-wise embedding means

The embedding loss requires per-layer class means from the explainee. We recompute them using the training split to avoid leaking test labels.


In [None]:

# Per-layer class means of the explainee's embeddings, computed on the training
# loader only. The printed keys are the layer names later passed to the
# embedding loss.
classwise_means = compute_classwise_means(
    model=explainee_model,
    dataloader=train_loader,
    device=device,
    num_classes=cfg["dataset"].get("num_classes", 2),
)
print(f"Computed means for {len(classwise_means)} layers: {list(classwise_means.keys())}")


## 6. Generator metrics


### Metric glossary

The tables below use the following quantities:

* **Explainee metrics** – `accuracy`, `precision`, `recall`, and `f1` measure how well the explainee model predicts the held-out MUTAG labels. The `confusion_matrix` highlights per-class confusions.
* **SimGNN distance metrics** – `mae_norm`/`mse_norm` report absolute and squared errors on normalised GED targets, while `mae_raw`/`mse_raw` do the same on the unnormalised distances. `corr_norm`/`corr_raw` give the Pearson correlation between SimGNN predictions and the ground-truth GED scores.
* **Generator loss terms** – `total` aggregates the embedding (`pull` + `push`), prediction (`pred`, weighted by `lambda_pred`), and structural (`edge`, weighted by `lambda_edge`) penalties used during training. Lower values indicate the generator matches the explainee prototype with fewer regularisation violations.
* **Fidelity metrics** – `confidence_mean` is the explainee's average probability for the target class, `margin_mean` is the gap to the runner-up class, and `entropy_mean` captures output uncertainty. Higher confidence and margin with lower entropy reflect explanations aligned with the explainee.
* **Structural summaries** – `num_nodes`, `num_edges`, `avg_degree`, and `density` describe the generated graphs' size and connectivity, mirroring the statistics reported for the observed MUTAG graphs.
* **GED alignment** – `ged_min` is the smallest SimGNN-estimated distance between a generated graph and the real graphs of its target class (a random subset of at most 32 real graphs is compared), while `ged_mean` is the average distance across that comparison set. Lower values imply better structural faithfulness.
* **Diversity** – the average pairwise SimGNN distance between generated samples, indicating how varied the generator outputs are for a class.
* **Observed graph statistics** – the report also lists dataset-wide counts and summary stats for real MUTAG graphs to contextualise the generator outputs.


In [None]:

from itertools import combinations


def generator_forward_samples(gen: GraphGenerator, num_samples: int = 16):
    """Call the generator ``num_samples`` times without tracking gradients.

    Args:
        gen: Callable generator; ``gen()`` must return a mapping of tensors.
        num_samples: Number of forward passes to run.

    Returns:
        A list with one dict per call, every tensor detached and moved to the CPU.
    """
    def _to_cpu(raw_output):
        return {key: tensor.detach().cpu() for key, tensor in raw_output.items()}

    with torch.no_grad():
        return [_to_cpu(gen()) for _ in range(num_samples)]


def logits_from_gen_output(gen_out, explainee: torch.nn.Module, thresh: float = 0.5) -> torch.Tensor:
    """Run the explainee on a generator output and return its logits on the CPU.

    Args:
        gen_out: Mapping of tensors produced by the generator.
        explainee: Model applied to the graph batch built from ``gen_out``.
        thresh: Threshold forwarded to ``_build_data_from_gen_output``.

    Returns:
        The explainee's output, detached and moved to the CPU.
    """
    # The tensors are moved to the notebook-level `device` before the batch is built.
    on_device = {name: tensor.to(device) for name, tensor in gen_out.items()}
    graph_batch = _build_data_from_gen_output(on_device, thresh=thresh)
    return explainee(graph_batch).detach().cpu()


def build_data_from_output(gen_out, thresh: float = 0.5):
    """Convert a single-graph generator output into one CPU graph object.

    Args:
        gen_out: Mapping of tensors produced by the generator.
        thresh: Threshold forwarded to ``_build_data_from_gen_output``.

    Returns:
        The only graph in the built batch, on the CPU, with ``edge_index`` cast to long.

    Raises:
        AssertionError: If the built batch does not contain exactly one graph.
    """
    on_device = {name: tensor.to(device) for name, tensor in gen_out.items()}
    graphs = _build_data_from_gen_output(on_device, thresh=thresh).to_data_list()
    assert len(graphs) == 1, "Generator batch size > 1 not supported yet"
    graph = graphs[0].cpu()
    graph.edge_index = graph.edge_index.long()
    return graph


def compute_loss_terms(gen_out, target_class: int, lambda_pred: float, lambda_edge: float):
    """Evaluate the individual generator loss terms for one generator output.

    Uses the notebook-level ``explainee_model``, ``classwise_means`` and ``device``.

    Args:
        gen_out: Mapping of tensors produced by the generator; must contain ``"adj"``.
        target_class: Class the generator is evaluated against.
        lambda_pred: Weight of the prediction term in ``total``.
        lambda_edge: Weight of the edge penalty in ``total``.

    Returns:
        A dict of floats with keys ``total``, ``pull``, ``push``, ``pred`` and
        ``edge``, where ``total = pull + push + lambda_pred * pred + lambda_edge * edge``.
    """
    embed_criterion = SoftContrastiveEmbedLoss(
        explainee=explainee_model,
        classwise_means=classwise_means,
        target_class=target_class,
        layer_names=classwise_means.keys(),
    )
    pred_criterion = PredictionConfidenceLoss(explainee_model, target_class=target_class)
    edge_criterion = EdgePenalty()

    on_device = {name: tensor.to(device) for name, tensor in gen_out.items()}
    embed_parts = embed_criterion.forward_components(on_device)
    terms = {
        "pull": embed_parts["pull"],
        "push": embed_parts["push"],
        "pred": pred_criterion(on_device),
        "edge": edge_criterion(on_device["adj"]),
    }
    total = (terms["pull"] + terms["push"]) + lambda_pred * terms["pred"] + lambda_edge * terms["edge"]

    result = {"total": float(total.cpu())}
    for name, value in terms.items():
        result[name] = float(value.cpu())
    return result


def graph_stats(data: Data):
    """Compute size and connectivity statistics for one graph.

    Args:
        data: Graph exposing ``num_nodes`` and an integer ``edge_index`` of shape ``(2, E)``.

    Returns:
        A dict with ``num_nodes``, ``num_edges`` (half the number of
        ``edge_index`` columns), ``avg_degree`` (mean count of each node as an
        edge source) and ``density`` (``num_edges`` over the number of
        unordered node pairs, with the denominator floored at 1).
    """
    edge_index = data.edge_index
    num_nodes = int(data.num_nodes)

    edge_columns = edge_index.size(1)
    num_edges = int(edge_columns // 2) if edge_columns > 0 else 0

    if edge_index.numel() > 0:
        deg = torch.bincount(edge_index[0], minlength=num_nodes).float()
    else:
        deg = torch.zeros(num_nodes)
    # Guard the zero-node case, where the mean of an empty tensor would be NaN.
    avg_degree = float(deg.mean().item()) if deg.numel() > 0 else 0.0

    possible_pairs = num_nodes * (num_nodes - 1) / 2
    return {
        "num_nodes": num_nodes,
        "num_edges": num_edges,
        "avg_degree": avg_degree,
        "density": float(num_edges) / max(possible_pairs, 1),
    }


def fidelity_metrics(logits: torch.Tensor, target_class: int):
    """Summarise how strongly the explainee assigns samples to ``target_class``.

    Args:
        logits: Tensor of shape ``(num_samples, num_classes)``.
        target_class: Column index of the class of interest.

    Returns:
        A dict of floats:
        ``confidence_mean`` / ``confidence_std`` are the mean and population std
        of the target-class probability. ``margin_mean`` is the mean of
        ``p(target) - max p(other class)``, which is negative when another
        class wins. ``entropy_mean`` is the mean Shannon entropy (nats) of the
        softmax output.
    """
    probs = torch.softmax(logits, dim=-1)
    target_probs = probs[:, target_class]

    if probs.size(1) > 1:
        # The runner-up must be the strongest class *other than* the target.
        # Taking the second top-k entry is only correct when the target is the
        # argmax; otherwise it hides negative margins (a target ranked second
        # would be reported with margin 0).
        others = probs.clone()
        others[:, target_class] = float("-inf")
        runner_up = others.max(dim=1).values
    else:
        # Single-class output: there is no competitor to compare against.
        runner_up = torch.zeros_like(target_probs)
    margins = target_probs - runner_up

    # Clamp before the log so zero probabilities contribute 0 instead of NaN.
    entropy = -(probs * probs.clamp_min(1e-9).log()).sum(dim=1)
    return {
        "confidence_mean": float(target_probs.mean()),
        "confidence_std": float(target_probs.std(unbiased=False)) if target_probs.numel() > 1 else 0.0,
        "margin_mean": float(margins.mean()),
        "entropy_mean": float(entropy.mean()),
    }


def distance_to_real_samples(data: Data, real_graphs: List[Data], distance_model: SimGNN, device: torch.device):
    """Score one graph against a set of real graphs with the distance model.

    Args:
        data: Graph to compare (typically a generated sample).
        real_graphs: Real graphs to compare against; may be empty.
        distance_model: Callable taking two single-graph batches and returning a scalar tensor.
        device: Device the batches are moved to.

    Returns:
        ``{"ged_min": ..., "ged_mean": ...}`` over all comparisons, or NaN for
        both when there is nothing to compare against.
    """
    nan = float("nan")
    if not real_graphs:
        return {"ged_min": nan, "ged_mean": nan}

    with torch.no_grad():
        batch_fake = Batch.from_data_list([data]).to(device)
        scores = [
            float(distance_model(batch_fake, Batch.from_data_list([real]).to(device)).cpu())
            for real in real_graphs
        ]

    if not scores:
        return {"ged_min": nan, "ged_mean": nan}
    return {
        "ged_min": min(scores),
        "ged_mean": float(sum(scores) / len(scores)),
    }


def pairwise_diversity(datas: List[Data], distance_model: SimGNN, device: torch.device) -> float:
    """Average the distance model's score over all unordered pairs of graphs.

    Args:
        datas: Graphs to compare with each other.
        distance_model: Callable taking two single-graph batches and returning a scalar tensor.
        device: Device the batches are moved to.

    Returns:
        The mean pairwise score, or NaN when fewer than two graphs are given.
    """
    if len(datas) < 2:
        return float("nan")

    scores = []
    with torch.no_grad():
        for first, second in combinations(datas, 2):
            batch_first = Batch.from_data_list([first]).to(device)
            batch_second = Batch.from_data_list([second]).to(device)
            scores.append(float(distance_model(batch_first, batch_second).cpu()))

    if not scores:
        return float("nan")
    return float(sum(scores) / len(scores))


def evaluate_generator(artifact, cfg_gen: Dict, real_dataset, distance_model: SimGNN, device: torch.device, num_samples: int = 16):
    """Sample a generator and score every sample on loss, fidelity, structure and GED.

    Args:
        artifact: ``GeneratorArtifacts`` providing ``model`` and ``target_class``.
        cfg_gen: Generator config; ``lambda_pred`` and ``lambda_edge`` weight the loss terms.
        real_dataset: Real graphs; only those labelled ``target_class`` are used,
            randomly subsampled to 32 when there are more.
        distance_model: SimGNN used for GED-to-real and diversity scores.
        device: Device used for the distance computations.
        num_samples: Number of generator samples to draw.

    Returns:
        ``(summary, diversity)``. ``summary`` has one row per sample (a DataFrame
        with ``attrs["diversity"]`` set when pandas is available, otherwise a
        list of dicts); ``diversity`` is the mean pairwise distance between samples.
    """
    outputs = generator_forward_samples(artifact.model, num_samples=num_samples)
    lambda_pred = cfg_gen.get("lambda_pred", 1.0)
    lambda_edge = cfg_gen.get("lambda_edge", 0.1)
    target = artifact.target_class

    real_subset = [g for g in real_dataset if int(g.y.item()) == target]
    # Cap the comparison set so the per-sample GED loop stays cheap.
    if len(real_subset) > 32:
        real_subset = random.sample(real_subset, 32)

    rows, datas = [], []
    for idx, out in enumerate(outputs):
        logits = logits_from_gen_output(out, explainee_model)
        data = build_data_from_output(out)
        datas.append(data)
        rows.append({
            "sample": idx,
            **compute_loss_terms(out, target, lambda_pred=lambda_pred, lambda_edge=lambda_edge),
            **fidelity_metrics(logits, target),
            **graph_stats(data),
            **distance_to_real_samples(data, real_subset, distance_model, device),
        })

    diversity = pairwise_diversity(datas, distance_model, device)
    if pd is None:
        return rows, diversity
    summary = pd.DataFrame(rows)
    summary.attrs["diversity"] = diversity
    return summary, diversity


In [None]:

# Evaluate every class generator; results are keyed by target class.
generator_results = {}
for artifact in generators:
    summary_df, diversity = evaluate_generator(
        artifact, generator_cfg, dataset, simgnn_model, device, num_samples=16
    )
    generator_results[artifact.target_class] = summary_df

    # The table and the diversity line are shown only for a non-empty DataFrame.
    _has_table = pd is not None and isinstance(summary_df, pd.DataFrame) and not summary_df.empty
    if not _has_table:
        continue
    display(summary_df.describe())
    print(f"Pairwise diversity (SimGNN GED) for class {artifact.target_class}: {diversity}")


## 7. Visual inspection of generated graphs


In [None]:

def plot_graph_pair(gen_data: Data, real_data: Data, title: str):
    """Draw a generated graph next to a real graph.

    Args:
        gen_data: Generated graph, drawn in the left panel.
        real_data: Real graph, drawn in the right panel in light grey.
        title: Text inserted into the left panel's title.

    Returns:
        The matplotlib figure, or ``None`` (after printing a notice) when
        matplotlib is unavailable.
    """
    if plt is None:
        print("matplotlib is not available; skipping visualization.")
        return
    import networkx as nx

    fig, (ax_gen, ax_real) = plt.subplots(1, 2, figsize=(8, 4))
    panels = (
        (ax_gen, gen_data, f"Generated ({title})", {}),
        (ax_real, real_data, "Nearest real", {"node_color": 'lightgrey'}),
    )
    for ax, graph_data, panel_title, extra_style in panels:
        graph = to_networkx(graph_data, to_undirected=True)
        # Seeded layout keeps the drawing stable across re-runs.
        layout = nx.spring_layout(graph, seed=seed)
        ax.set_title(panel_title)
        nx.draw_networkx(graph, pos=layout, ax=ax, with_labels=False, node_size=150, **extra_style)

    for ax in (ax_gen, ax_real):
        ax.axis('off')
    plt.tight_layout()
    return fig


# Look generators up by their class id rather than by list position.
artifact_by_class = {artifact.target_class: artifact for artifact in generators}

for cls, df in generator_results.items():
    if pd is None or not isinstance(df, pd.DataFrame) or df.empty:
        continue
    if df["ged_min"].isna().all():
        # No real graphs were available for this class, so there is nothing to pair against.
        continue
    best_idx = int(df["ged_min"].idxmin())

    # NOTE(review): `evaluate_generator` does not return the sampled graphs, so the
    # sample is re-drawn here; it matches the scored one only if the generator is
    # deterministic in eval mode. The earlier full re-evaluation was discarded
    # unused and has been removed.
    out = generator_forward_samples(artifact_by_class[cls].model, num_samples=best_idx + 1)[best_idx]
    gen_data = build_data_from_output(out)

    target_graphs = [g for g in dataset if int(g.y.item()) == cls]
    real_subset = random.sample(target_graphs, 32) if len(target_graphs) > 32 else target_graphs

    # Find the real graph with the smallest SimGNN distance to the generated one.
    best_real = None
    best_score = float("inf")
    with torch.no_grad():
        batch_fake = Batch.from_data_list([gen_data]).to(device)
        for real in real_subset:
            batch_real = Batch.from_data_list([real]).to(device)
            score = float(simgnn_model(batch_fake, batch_real).cpu())
            if score < best_score:
                best_score = score
                best_real = real

    if best_real is not None:
        fig = plot_graph_pair(gen_data, best_real, title=f"class {cls}")
        if plt is not None:
            plt.suptitle(f"Class {cls} | SimGNN GED ≈ {best_score:.3f}")
            plt.show()


## 8. Parameter statistics


In [None]:

def parameter_summary(model: torch.nn.Module):
    """Count a module's parameters.

    Args:
        model: Module whose ``parameters()`` are counted.

    Returns:
        ``{"total_params": ..., "trainable_params": ...}``, where the trainable
        count includes only parameters with ``requires_grad`` set.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
    return {"total_params": total, "trainable_params": trainable}


# One row per component; generator rows additionally record the checkpoint
# path and the range of the generator's `adj_logits`.
param_rows = [
    {"component": "Explainee", **parameter_summary(explainee_model)},
    {"component": "SimGNN", **parameter_summary(simgnn_model)},
]
for art in generators:
    adj_logits = art.model.adj_logits
    param_rows.append({
        **parameter_summary(art.model),
        "component": f"Generator_class_{art.target_class}",
        "state_path": str(art.state_path),
        "adj_min": float(adj_logits.min().item()),
        "adj_max": float(adj_logits.max().item()),
    })

if pd is None:
    param_df = param_rows
    print(param_rows)
else:
    param_df = pd.DataFrame(param_rows)
    display(param_df)


## 9. Consolidated report


In [None]:
# Skeleton of the JSON report. Only plain numeric explainee metrics are kept
# (the logits/labels/probs tensors are filtered out by the isinstance check).
report = dict(
    explainee={name: float(value) for name, value in explainee_eval.items() if isinstance(value, (int, float))},
    distance={name: float(value) for name, value in distance_metrics.items()},
    observed_graphs={},
    generators={},
)

# Structural statistics of every real graph, split into one list per field.
observed_stats = list(map(graph_stats, dataset))
node_counts, edge_counts, avg_degrees, densities = (
    [stats[field] for stats in observed_stats]
    for field in ("num_nodes", "num_edges", "avg_degree", "density")
)
labels = [int(graph.y.item()) if hasattr(graph, "y") else None for graph in dataset]
label_counter = Counter(filter(lambda label: label is not None, labels))

def summarise(values):
    """Return mean, population standard deviation, min and max of ``values``.

    Args:
        values: Sequence of numbers; may be empty.

    Returns:
        A dict of floats with keys ``mean``, ``std``, ``min`` and ``max``; every
        entry is NaN for an empty input.
    """
    if not values:
        return dict.fromkeys(("mean", "std", "min", "max"), float("nan"))

    count = len(values)
    mean = sum(values) / count
    # Population variance (divide by N, not N - 1).
    variance = sum((val - mean) ** 2 for val in values) / count
    return {
        "mean": float(mean),
        "std": float(variance ** 0.5),
        "min": float(min(values)),
        "max": float(max(values)),
    }

# Dataset-wide context for the generator statistics.
report["observed_graphs"] = {
    "count": len(dataset),
    "label_counts": {int(label): int(n) for label, n in label_counter.items()},
    **{
        name: summarise(series)
        for name, series in (
            ("num_nodes", node_counts),
            ("num_edges", edge_counts),
            ("avg_degree", avg_degrees),
            ("density", densities),
        )
    },
}

# Per-class generator summary: the mean of every metric column, grouped by theme.
for cls, df in generator_results.items():
    _usable = pd is not None and isinstance(df, pd.DataFrame) and not df.empty
    if not _usable:
        continue

    def floatify(mapping):
        """Cast every value to float, mapping NaN-like values to a plain NaN."""
        cleaned = {}
        for k, v in mapping.items():
            # `v == v` is False only for NaN.
            cleaned[k] = float(v) if v == v else float("nan")
        return cleaned

    _column_groups = (
        ("loss", ["total", "pull", "push", "pred", "edge"]),
        ("fidelity", ["confidence_mean", "margin_mean", "entropy_mean"]),
        ("structure", ["num_nodes", "num_edges", "avg_degree", "density"]),
        ("ged", ["ged_min", "ged_mean"]),
    )
    _entry = {
        section: floatify(df[columns].mean().to_dict())
        for section, columns in _column_groups
    }
    _diversity = df.attrs.get("diversity")
    _entry["diversity"] = float(_diversity) if _diversity is not None else float("nan")
    report["generators"][f"class_{cls}"] = _entry

print(json.dumps(report, indent=2, sort_keys=True))


## Next steps

* Inspect cases where generator confidence is low but pull/push losses are small — this may indicate insufficient prediction guidance.
* Compare GED statistics against real validation graphs to ensure the distance model generalises.
* Use the diversity metric to tune temperature or regularisation if the generator collapses to a single prototype.
