In [3]:
#!/usr/bin/env python3
import os
import sys
import json
import math
from pathlib import Path
from collections import defaultdict

import numpy as np
import pandas as pd
from scipy.stats import pearsonr


def sigmoid(x: float) -> float:
    """Scalar sigmoid with basic numerical stability."""
    if x >= 0:
        z = math.exp(-x)
        return 1.0 / (1.0 + z)
    else:
        z = math.exp(x)
        return z / (1.0 + z)


def read_jsonl(path: Path):
    with path.open("r") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            yield json.loads(line)


def arch_from_dir(model_dir_name: str) -> str:
    """
    Convert 'modelGIN' -> 'GIN', 'modelGAT' -> 'GAT', etc.
    Fallback: return original if it doesn't start with 'model'.
    """
    if model_dir_name.lower().startswith("model"):
        return model_dir_name[5:]
    return model_dir_name


def collect_impacts(impact_file: Path):
    """
    Collect motif impacts from masked-edge-impact.jsonl or masked-impact.jsonl.
    Returns:
      impacts[(split, motif_idx)] -> list of impact values
      graph_ids[(split, motif_idx)] -> set of 'split|graph_idx'
    """
    impacts = defaultdict(list)
    graph_ids = defaultdict(set)

    for rec in read_jsonl(impact_file):
        split = rec.get("split", "unknown")

        motif_idx = rec.get("motif_idx", rec.get("motif_index", None))
        if motif_idx is None or motif_idx < 0:
            # ignore UNK / background motif
            continue

        old_pred = rec["old_prediction"]
        new_pred = rec["new_prediction"]
        imp = abs(sigmoid(new_pred) - sigmoid(old_pred))

        key = (split, motif_idx)
        impacts[key].append(imp)

        graph_idx = rec.get("graph_idx", None)
        if graph_idx is not None:
            graph_ids[key].add(f"{split}|{graph_idx}")

    return impacts, graph_ids


def collect_node_scores(node_scores_file: Path):
    """
    Collect motif-level node scores from node_scores.jsonl.
    Returns:
      scores[(split, motif_index)] -> list of scores
      graph_ids[(split, motif_index)] -> set of 'split|smiles'
    """
    scores = defaultdict(list)
    graph_ids = defaultdict(set)

    for rec in read_jsonl(node_scores_file):
        split = rec.get("split", "unknown")

        motif_index = rec.get("motif_index", rec.get("motif_idx", None))
        if motif_index is None or motif_index < 0:
            continue

        score = rec["score"]
        smiles = rec.get("smiles", None)

        key = (split, motif_index)
        scores[key].append(score)

        if smiles is not None:
            graph_ids[key].add(f"{split}|{smiles}")

    return scores, graph_ids


def aggregate_mean(dict_of_lists):
    return {k: float(np.mean(v)) for k, v in dict_of_lists.items() if len(v) > 0}


def aggregate_counts(dict_of_sets):
    return {k: len(v) for k, v in dict_of_sets.items()}


def build_corr_df(
    avg_scores,
    avg_impacts,
    cnt_graphs_node,
    cnt_graphs_imp,
    impact_type: str,
) -> pd.DataFrame:
    """
    Build a DataFrame with:
      split, motif_idx, avg_score, avg_impact, n_graphs_node, n_graphs_imp,
      graph_count_match, graph_count_diff, impact_type
    Only motifs/splits that appear in both avg_scores and avg_impacts are kept.
    """
    rows = []
    common_keys = set(avg_scores.keys()) & set(avg_impacts.keys())

    for (split, motif_idx) in sorted(common_keys):
        s = avg_scores[(split, motif_idx)]
        imp = avg_impacts[(split, motif_idx)]
        n_node = cnt_graphs_node.get((split, motif_idx), 0)
        n_imp = cnt_graphs_imp.get((split, motif_idx), 0)
        rows.append(
            dict(
                split=split,
                motif_idx=motif_idx,
                avg_score=s,
                avg_impact=imp,
                n_graphs_node=n_node,
                n_graphs_imp=n_imp,
                graph_count_match=(n_node == n_imp),
                graph_count_diff=n_node - n_imp,
                impact_type=impact_type,
            )
        )

    return pd.DataFrame(rows)


def compute_corr(df: pd.DataFrame):
    """Compute Pearson correlation between avg_score and avg_impact."""
    if df.empty:
        return np.nan, np.nan, 0
    x = df["avg_score"].values
    y = df["avg_impact"].values
    # Pearson undefined if no variation or only one point
    if len(df) < 2 or np.allclose(x, x[0]) or np.allclose(y, y[0]):
        return np.nan, np.nan, len(df)
    r, p = pearsonr(x, y)
    return float(r), float(p), len(df)


def detail_rows_from_df(
    df: pd.DataFrame,
    experiment_name: str,
    dataset_name: str,
    arch_name: str,
    impact_type: str,
):
    rows = []
    for _, row in df.iterrows():
        rows.append(
            dict(
                Experiment=experiment_name,
                Dataset=dataset_name,
                Arch=arch_name,
                ImpactType=impact_type,
                Split=row["split"],
                MotifIdx=int(row["motif_idx"]),
                AvgScore=float(row["avg_score"]),
                AvgImpact=float(row["avg_impact"]),
                NumGraphsNode=int(row["n_graphs_node"]),
                NumGraphsImpact=int(row["n_graphs_imp"]),
                GraphCountMatch=bool(row["graph_count_match"]),
                GraphCountDiff=int(row["graph_count_diff"]),
            )
        )
    return rows


def main(root_dir: str):
    root = Path(root_dir).resolve()
    if not root.exists():
        print(f"Root directory {root} does not exist.", file=sys.stderr)
        sys.exit(1)

    # MotifLoss_* experiment dirs
    exp_dirs = sorted(
        [p for p in root.iterdir() if p.is_dir() and p.name.startswith("MotifLoss_")]
    )
    if not exp_dirs:
        print("No MotifLoss_* directories found under root.", file=sys.stderr)
        sys.exit(1)

    summary_rows = []
    motif_detail_rows = []

    for exp_dir in exp_dirs:
        experiment_name = exp_dir.name  # e.g. MotifLoss_0
        print(f"\n=== Processing experiment: {experiment_name} ===")

        # /MotifLoss_k/<dataset>/<modelARCH>/foldX/seedY/
        for dataset_dir in exp_dir.iterdir():
            if not dataset_dir.is_dir():
                continue
            dataset_name = dataset_dir.name  # e.g. esol

            for model_dir in dataset_dir.iterdir():
                if not model_dir.is_dir():
                    continue
                arch_name = arch_from_dir(model_dir.name)  # e.g. GIN

                print(f"  Dataset={dataset_name}, Arch={arch_name}")

                # Containers aggregated across folds/seeds
                edge_impacts_all = defaultdict(list)
                edge_graph_ids_all = defaultdict(set)

                graph_impacts_all = defaultdict(list)
                graph_graph_ids_all = defaultdict(set)

                scores_all = defaultdict(list)
                score_graph_ids_all = defaultdict(set)

                # Traverse folds and seeds
                for fold_dir in model_dir.glob("fold*"):
                    if not fold_dir.is_dir():
                        continue
                    for seed_dir in fold_dir.glob("seed*"):
                        if not seed_dir.is_dir():
                            continue

                        edge_impact_file = seed_dir / "masked-edge-impact.jsonl"
                        graph_impact_file = seed_dir / "masked-impact.jsonl"
                        node_scores_file = seed_dir / "node_scores.jsonl"

                        if not node_scores_file.exists():
                            print(f"    [WARN] Missing node_scores in {seed_dir}")
                            continue

                        # Node scores
                        s, s_graph_ids = collect_node_scores(node_scores_file)
                        for k, v in s.items():
                            scores_all[k].extend(v)
                        for k, v in s_graph_ids.items():
                            score_graph_ids_all[k].update(v)

                        # Edge-level impacts
                        if edge_impact_file.exists():
                            e_imp, e_graph = collect_impacts(edge_impact_file)
                            for k, v in e_imp.items():
                                edge_impacts_all[k].extend(v)
                            for k, v in e_graph.items():
                                edge_graph_ids_all[k].update(v)
                        else:
                            print(f"    [WARN] Missing masked-edge-impact in {seed_dir}")

                        # Graph-level impacts
                        if graph_impact_file.exists():
                            g_imp, g_graph = collect_impacts(graph_impact_file)
                            for k, v in g_imp.items():
                                graph_impacts_all[k].extend(v)
                            for k, v in g_graph.items():
                                graph_graph_ids_all[k].update(v)
                        else:
                            print(f"    [WARN] Missing masked-impact in {seed_dir}")

                if not scores_all:
                    print(f"    [INFO] No node scores found; skipping.")
                    continue

                # Aggregate means & counts
                avg_scores = aggregate_mean(scores_all)
                count_graphs_node = aggregate_counts(score_graph_ids_all)

                avg_edge_impacts = aggregate_mean(edge_impacts_all)
                count_graphs_edge = aggregate_counts(edge_graph_ids_all)

                avg_graph_impacts = aggregate_mean(graph_impacts_all)
                count_graphs_graph = aggregate_counts(graph_graph_ids_all)

                # ---- Edge impacts vs scores (masked-edge-impact) ----
                if avg_edge_impacts:
                    df_edge = build_corr_df(
                        avg_scores,
                        avg_edge_impacts,
                        count_graphs_node,
                        count_graphs_edge,
                        impact_type="masked-edge-impact",
                    )

                    # Compute correlation per split
                    for split in sorted(df_edge["split"].unique()):
                        df_split = df_edge[df_edge["split"] == split]
                        r_edge, p_edge, n_edge = compute_corr(df_split)

                        summary_rows.append(
                            dict(
                                Experiment=experiment_name,
                                Dataset=dataset_name,
                                Arch=arch_name,
                                ImpactType="masked-edge-impact",
                                Split=split,
                                NumMotifs=n_edge,
                                PearsonR=r_edge,
                                PValue=p_edge,
                            )
                        )

                    motif_detail_rows.extend(
                        detail_rows_from_df(
                            df_edge,
                            experiment_name,
                            dataset_name,
                            arch_name,
                            impact_type="masked-edge-impact",
                        )
                    )

                # ---- Graph impacts vs scores (masked-impact) ----
                if avg_graph_impacts:
                    df_graph = build_corr_df(
                        avg_scores,
                        avg_graph_impacts,
                        count_graphs_node,
                        count_graphs_graph,
                        impact_type="masked-impact",
                    )

                    for split in sorted(df_graph["split"].unique()):
                        df_split = df_graph[df_graph["split"] == split]
                        r_graph, p_graph, n_graph = compute_corr(df_split)

                        summary_rows.append(
                            dict(
                                Experiment=experiment_name,
                                Dataset=dataset_name,
                                Arch=arch_name,
                                ImpactType="masked-impact",
                                Split=split,
                                NumMotifs=n_graph,
                                PearsonR=r_graph,
                                PValue=p_graph,
                            )
                        )

                    motif_detail_rows.extend(
                        detail_rows_from_df(
                            df_graph,
                            experiment_name,
                            dataset_name,
                            arch_name,
                            impact_type="masked-impact",
                        )
                    )

    # -------------------------
    # Build DataFrames
    # -------------------------
    summary_df = pd.DataFrame(summary_rows)
    detail_df = pd.DataFrame(motif_detail_rows)

    # NEW: drop any exact duplicate summary rows
    if not summary_df.empty:
        summary_df = summary_df.drop_duplicates(
            subset=["Experiment", "Dataset", "Arch", "ImpactType", "Split"],
            keep="first",
        )

    # NEW: drop any exact duplicate motif-detail rows
    if not detail_df.empty:
        detail_df = detail_df.drop_duplicates()

    summary_out = root / "motif_correlation_summary_all_splits.csv"
    detail_out = root / "motif_correlation_details_all_splits.csv"

    summary_df.to_csv(summary_out, index=False)
    detail_df.to_csv(detail_out, index=False)

    print(f"\nSaved summary to: {summary_out}")
    print(f"Saved motif-level details to: {detail_out}")

    # ==============================
    # NEW: mean + std of PearsonR
    # ==============================
    if not summary_df.empty:
        # Group across splits for each (Experiment, Dataset, Arch, ImpactType)
        agg_df = (
            summary_df
            .groupby(["Experiment", "Dataset", "Arch", "ImpactType"], as_index=False)
            .agg(
                NumSplits=("PearsonR", "count"),
                MeanPearsonR=("PearsonR", "mean"),
                StdPearsonR=("PearsonR", "std"),
            )
        )
    else:
        agg_df = pd.DataFrame(
            columns=[
                "Experiment",
                "Dataset",
                "Arch",
                "ImpactType",
                "NumSplits",
                "MeanPearsonR",
                "StdPearsonR",
            ]
        )

    # NEW: safety â€” ensure one row per (Experiment, Dataset, Arch, ImpactType)
    if not agg_df.empty:
        agg_df = agg_df.drop_duplicates(
            subset=["Experiment", "Dataset", "Arch", "ImpactType"],
            keep="first",
        )

    agg_out = root / "motif_correlation_summary_agg.csv"
    agg_df.to_csv(agg_out, index=False)
    print(f"Saved aggregated mean/std of PearsonR to: {agg_out}")


main(".")



=== Processing experiment: MotifLoss_0 ===
  Dataset=Benzene, Arch=GCN
  Dataset=Benzene, Arch=GAT
  Dataset=Benzene, Arch=GIN
  Dataset=hERG, Arch=GCN
  Dataset=hERG, Arch=GAT
  Dataset=hERG, Arch=GIN
  Dataset=Mutagenicity, Arch=GCN
  Dataset=Mutagenicity, Arch=GAT
  Dataset=Mutagenicity, Arch=GIN
  Dataset=Fluoride_Carbonyl, Arch=GCN
  Dataset=Fluoride_Carbonyl, Arch=GAT
  Dataset=Fluoride_Carbonyl, Arch=GIN
  Dataset=Alkane_Carbonyl, Arch=GAT
  Dataset=Alkane_Carbonyl, Arch=GIN
  Dataset=Alkane_Carbonyl, Arch=GCN
  Dataset=BBBP, Arch=GCN
  Dataset=BBBP, Arch=GAT
  Dataset=BBBP, Arch=GIN
  Dataset=Lipophilicity, Arch=GCN
  Dataset=Lipophilicity, Arch=GAT
  Dataset=Lipophilicity, Arch=GIN
  Dataset=esol, Arch=GCN
  Dataset=esol, Arch=GAT
  Dataset=esol, Arch=GIN

=== Processing experiment: MotifLoss_0.0 ===
  Dataset=hERG, Arch=GCN
  Dataset=hERG, Arch=GAT
  Dataset=hERG, Arch=GIN
  Dataset=hERG, Arch=SAGE
  Dataset=Alkane_Carbonyl, Arch=SAGE
  Dataset=Alkane_Carbonyl, Arch=GIN
  Da

In [8]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os

csv_path = "motif_correlation_summary_all_splits.csv"
df = pd.read_csv(csv_path)

out_dir = "significance_heatmaps"
os.makedirs(out_dir, exist_ok=True)

split_order = ["train", "valid", "test"]
df["Split"] = pd.Categorical(df["Split"], categories=split_order, ordered=True)

saved_files = []

for arch in df["Arch"].unique():
    for impact in df["ImpactType"].unique():
        sub = df[(df["Arch"] == arch) & (df["ImpactType"] == impact)]
        if sub.empty:
            continue
        
        sub = sub.copy()
        sub["neglogp"] = sub["PearsonR"]#-np.log10(sub["PearsonR"].replace(0, 1e-300))
        
        # Aggregate duplicates by mean
        g = sub.groupby(["Dataset", "Split"])["neglogp"].mean().reset_index()
        pivot = g.pivot(index="Dataset", columns="Split", values="neglogp")
        
        fig, ax = plt.subplots(figsize=(8, 6))
        im = ax.imshow(pivot, aspect='auto')
        
        ax.set_xticks(np.arange(len(pivot.columns)))
        ax.set_xticklabels(pivot.columns)
        ax.set_yticks(np.arange(len(pivot.index)))
        ax.set_yticklabels(pivot.index)
        
        for i in range(len(pivot.index)):
            for j in range(len(pivot.columns)):
                val = pivot.iloc[i, j]
                text = "NA" if pd.isna(val) else f"{val:.2f}"
                ax.text(j, i, text, ha="center", va="center", fontsize=8)
        
        ax.set_title(f"Significance Heatmap\nArch={arch}, Impact={impact}")
        fig.colorbar(im, ax=ax, label="-log10(p-value)")
        
        out_path = os.path.join(out_dir, f"heatmap_{arch}_{impact}.png")
        fig.tight_layout()
        fig.savefig(out_path, dpi=300)
        plt.close(fig)
        
        saved_files.append(out_path)
