In [None]:
# 1. Data Intake
# GNN Data Intake — Integrity Check & Schema Normalization
#
# This script will:
# 1) Load your provided Excel files (nodes/edges)
# 2) Display them as interactive tables
# 3) Auto-detect identifier columns (node id, edge src/dst)
# 4) Validate referential integrity, nulls, dtypes
# 5) Compile a concise validation report
# 6) Produce standardized CSVs for downstream GNN code if possible
#
# Files:
# assumed_node_features.xlsx
# assumed_edge_features.xlsx
#
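# Outputs (written to the current working directory):
# validation_report.txt
# nodes_standardized.csv
# edges_standardized.csv
# invalid_edge_references.csv (only when some edge endpoints do not match a node)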

In [1]:
import re
from pathlib import Path
from typing import List, Optional, Tuple

import numpy as np
import pandas as pd

In [17]:
NODE_XLSX = Path("assumed_node_features.xlsx")
EDGE_XLSX = Path("assumed_edge_features.xlsx")

# ---------- 1) Load ----------
# Raw previews for a first look; main() re-reads both files on its own.
nodes = pd.read_excel(NODE_XLSX)
edges = pd.read_excel(EDGE_XLSX)
# Only a cell's last expression is rendered automatically, so the node preview
# needs an explicit `display` (IPython built-in in every Jupyter kernel);
# otherwise it is silently discarded.
display(nodes.head(20))
edges.head(20)

Unnamed: 0,mol_id,bond_idx,a1,a2,elem1,elem2,pair,distance,bo_mayer_abs,bo_wiberg,bo_mull
0,C23,0,0,1,C,C,C-C,1.521615,1.019884,1.120686,1.110974
1,C23,1,0,5,C,C,C-C,1.529253,0.894315,1.078091,1.040698
2,C23,2,0,23,C,H,C-H,1.096959,0.769069,0.934254,0.697315
3,C23,3,0,24,C,H,C-H,1.094622,0.944571,1.120111,0.382928
4,C23,4,1,2,C,C,C-C,1.548193,1.06256,1.26406,0.421025
5,C23,5,1,25,C,H,C-H,1.098821,0.755087,1.255285,0.930093
6,C23,6,1,26,C,H,C-H,1.097661,0.963163,0.788721,0.551803
7,C23,7,2,3,C,C,C-C,1.569927,0.926406,1.031078,0.71702
8,C23,8,2,21,C,C,C-C,1.539477,1.115778,1.071979,0.729561
9,C23,9,2,22,C,C,C-C,1.55145,0.858996,1.130484,1.142553


In [19]:
# ---------- Column mapping & options (config) ----------
# Legacy column names in the raw Excel sheets and the canonical names main() maps them to.
NODE_ID_LEGACY = "atom_idx"   # nodes: -> node_id
EDGE_SRC_LEGACY = "a1"        # edges: -> src
EDGE_DST_LEGACY = "a2"        # edges: -> dst

# When True, main() drops edges whose src == dst before undirected normalization.
REMOVE_SELF_LOOPS = False

---

In [None]:
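# Standardize the columns that represent "node ID" and "edge endpoints" across inconsistent naming.
# ---------- Helpers ----------
## 1. Keep original column names consistent
## A regex-based selector: it scans the existing columns and returns those matching any candidate pattern, preserving order and avoiding duplicates.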
def pick_col(patterns: List[str], columns: pd.Index) -> List[str]:
    """
    Order-preserving regex selector over column labels.

    Patterns are tried in priority order; within one pattern, columns are
    visited in their existing order. Each matching column is returned once,
    at the position of the first (highest-priority) pattern that matched it.
    Matching is case-insensitive and uses ``re.search``, so anchor patterns
    with ``^``/``$`` to require a whole-name match.

    Args:
        patterns: regex patterns, highest priority first.
        columns: column labels to scan (e.g. ``df.columns``). Non-string
            labels (such as integer headers) are matched against their
            ``str()`` form instead of raising TypeError.

    Returns:
        The matching column labels themselves (not their string forms).
    """
    hits: List[str] = []
    seen = set()
    for pat in patterns:
        rx = re.compile(pat, flags=re.I)  # compile once per pattern, not per column
        for c in columns:
            if c not in seen and rx.search(str(c)):
                hits.append(c)
                seen.add(c)
    return hits

def ensure_graph_alias(df: pd.DataFrame) -> Tuple[pd.DataFrame, str, str]:
    """
    Ensure the DataFrame has a canonical 'graph_alias' column.

    The first column whose name matches (case-insensitive, whole name, in this
    priority order) graph_alias, graph_id, mol_id / mol_name, molecule, or name
    is renamed to ``graph_alias``. If none matches, a constant
    ``graph_alias='MOL_0'`` column is added (single-graph fallback).

    The input frame is never modified. A new frame is returned in every
    branch, so later in-place edits on the result (e.g. adding ``node_id``)
    cannot leak back into the caller's raw frame.

    Returns:
        (df_with_alias, alias_col_name, note). ``alias_col_name`` is always
        ``"graph_alias"``. ``note`` is a human-readable message, or "" when an
        existing column was used.
    """
    candidates = pick_col(
        [r"^graph_?alias$", r"^graph_?id$", r"^mol_?(id|name)$", r"^molecule$", r"^name$"],
        df.columns
    )
    note = ""
    out = df.copy()
    if candidates:
        alias = candidates[0]
        if alias != "graph_alias":
            out = out.rename(columns={alias: "graph_alias"})
    else:
        out["graph_alias"] = "MOL_0"
        note = "No alias column found; created default graph_alias='MOL_0'."
    return out, "graph_alias", note

def ensure_zero_based(nodes_df: pd.DataFrame, edges_df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame, List[str]]:
    """
    Normalize each graph to 0-based node indices.

    For every graph_alias whose smallest non-null node_id equals 1, node_id
    (in nodes_df) and src/dst (in edges_df rows with the same graph_alias) are
    decremented by one. Other graphs are left untouched. Both frames are
    modified in place and also returned.

    Returns:
        (nodes_df, edges_df, log), where log holds one line per shifted graph.
    """
    shifted: List[str] = []
    lowest_by_graph = nodes_df.groupby("graph_alias")["node_id"].min()
    for graph, lowest in lowest_by_graph.items():
        # Skip graphs with no usable ids and graphs whose ids do not start at 1.
        if pd.isna(lowest) or int(lowest) != 1:
            continue
        nodes_df.loc[nodes_df["graph_alias"] == graph, "node_id"] -= 1
        in_graph = edges_df["graph_alias"] == graph
        for end in ("src", "dst"):
            edges_df.loc[in_graph, end] -= 1
        shifted.append(f"- Shifted {graph} from 1-based → 0-based.")
    return nodes_df, edges_df, shifted

def quality_table(df: pd.DataFrame) -> pd.DataFrame:
    """
    Summarize per-column data quality without modifying ``df``.

    One row per column of ``df``:
      dtype            - dtype name as a string
      nulls / null_%   - missing-value count and percentage (rounded to 2 dp)
      nonfinite_count  - numeric dtypes only: NaN/NA count plus +/-inf count
                         after ``pd.to_numeric`` coercion; 0 for other dtypes
      unique_nonnull   - number of distinct non-null values
      is_constant      - True when there is at most one distinct non-null value
    Rows are ordered by null_% then nonfinite_count, both descending.
    """
    row_count = len(df)
    null_counts = df.isna().sum()
    null_share = 0.0 if row_count == 0 else (null_counts / row_count * 100).round(2)

    nonfinite = pd.Series(0, index=df.columns, dtype="int64")
    if row_count > 0:
        numeric_cols = [col for col, dt in df.dtypes.items() if pd.api.types.is_numeric_dtype(dt)]
        for col in numeric_cols:
            coerced = pd.to_numeric(df[col], errors="coerce")
            n_missing = int(coerced.isna().sum())
            # NaNs are zeroed first so only genuine infinities are counted here.
            n_infinite = int(np.isinf(coerced.replace({np.nan: 0})).sum())
            nonfinite[col] = n_missing + n_infinite

    distinct = df.nunique(dropna=True)
    summary = pd.DataFrame(
        {
            "dtype": df.dtypes.astype(str),
            "nulls": null_counts,
            "null_%": null_share,
            "nonfinite_count": nonfinite,
            "unique_nonnull": distinct,
            "is_constant": distinct.le(1),
        }
    )
    return summary.sort_values(["null_%", "nonfinite_count"], ascending=[False, False])

# ---------------------- Main ----------------------
def main():
    """
    Run the full intake pipeline and write the validation report.

    Steps:
      0) load both Excel files;
      1) map the graph key to ``graph_alias``;
      2) build nullable-int ``node_id``/``src``/``dst`` from the configured
         legacy columns;
      3) report node-id nulls/duplicates;
      4) shift 1-based graphs to 0-based;
      5) check per-graph referential integrity, writing offending edges to
         ``invalid_edge_references.csv`` when any exist;
      6) canonicalize undirected edges (src <= dst) and drop duplicates;
      7) report nulls / non-finite values;
      8) detect numeric feature columns;
      9) write ``nodes_standardized.csv`` / ``edges_standardized.csv``;
      10) write ``validation_report.txt``.
    All files go to the current working directory. NaNs are kept (no imputation).

    Returns:
        (nodes_std, edges_std, report): the standardized frames and the report
        lines, so later notebook cells can inspect them.

    Raises:
        FileNotFoundError: an input Excel file is missing.
        KeyError: a configured legacy id / endpoint column is missing.
    """
    report = []
    add = report.append

    # 0) Load. Raise real exceptions so a missing input stops the run with a clear message.
    for xlsx in (NODE_XLSX, EDGE_XLSX):
        if not xlsx.exists():
            raise FileNotFoundError(f"Missing input file: {xlsx}")

    nodes = pd.read_excel(NODE_XLSX)
    edges = pd.read_excel(EDGE_XLSX)

    # 1) Ensure alias column
    nodes_std, _, note_nodes = ensure_graph_alias(nodes)
    edges_std, _, note_edges = ensure_graph_alias(edges)
    add("# GNN Intake Validation Report (alias, undirected, NaN-kept)")
    add("## Column mappings")
    if note_nodes:
        add(f"- Nodes: {note_nodes}")
    if note_edges:
        add(f"- Edges: {note_edges}")
    add("- graph key: graph_alias")
    add(f"- node id: node_id ← {NODE_ID_LEGACY}")
    add(f"- edge ends: src ← {EDGE_SRC_LEGACY}, dst ← {EDGE_DST_LEGACY}")

    # 2) Canonical IDs as nullable Int64: non-numeric ids become <NA> (errors="coerce").
    if NODE_ID_LEGACY not in nodes_std.columns:
        raise KeyError(f"Expected node id column '{NODE_ID_LEGACY}' not found in nodes.")
    if EDGE_SRC_LEGACY not in edges_std.columns or EDGE_DST_LEGACY not in edges_std.columns:
        raise KeyError(f"Expected edge columns '{EDGE_SRC_LEGACY}'/'{EDGE_DST_LEGACY}' not found in edges.")

    nodes_std["node_id"] = pd.to_numeric(nodes_std[NODE_ID_LEGACY], errors="coerce").astype("Int64")
    edges_std["src"]     = pd.to_numeric(edges_std[EDGE_SRC_LEGACY], errors="coerce").astype("Int64")
    edges_std["dst"]     = pd.to_numeric(edges_std[EDGE_DST_LEGACY], errors="coerce").astype("Int64")

    # 3) Quick node-id health
    dup_count  = nodes_std[["graph_alias", "node_id"]].duplicated().sum()
    null_count = nodes_std["node_id"].isna().sum()
    add("## Node-ID health")
    add(f"- node_id nulls: {int(null_count)}; duplicates (table): {int(dup_count)}")

    # 4) Enforce 0-based per graph
    nodes_std, edges_std, zero_logs = ensure_zero_based(nodes_std, edges_std)
    if zero_logs:
        add("## Index normalization")
        add("\n".join(zero_logs))

    # 5) Referential integrity (vectorized, per graph):
    #    every edge endpoint must reference a valid node of the same molecule.
    add("## Referential integrity (per graph_alias)")
    rows = []
    bad_records = []

    aliases_nodes = set(nodes_std["graph_alias"].dropna().unique())
    aliases_edges = set(edges_std["graph_alias"].dropna().unique())
    missing_in_nodes = sorted(aliases_edges - aliases_nodes)
    missing_in_edges = sorted(aliases_nodes - aliases_edges)
    if missing_in_nodes:
        add(f"- ⚠️ Edges contain graph_alias with no nodes: {missing_in_nodes[:10]}{' ...' if len(missing_in_nodes)>10 else ''}")
    if missing_in_edges:
        add(f"- ℹ️ Nodes contain graph_alias with no edges: {missing_in_edges[:10]}{' ...' if len(missing_in_edges)>10 else ''}")

    sentinel = -10**9  # stands in for <NA> endpoints; assumed never to be a real node id
    for alias in sorted(aliases_edges | aliases_nodes):
        g_nodes = nodes_std.loc[nodes_std["graph_alias"] == alias, ["node_id"]].dropna()
        g_edges = edges_std.loc[edges_std["graph_alias"] == alias, ["graph_alias", "src", "dst"]]
        V = set(g_nodes["node_id"].astype("Int64").dropna().astype(int).tolist())

        if len(g_edges) == 0 and len(g_nodes) == 0:
            continue

        if len(V) == 0 and len(g_edges) > 0:
            rows.append((alias, len(g_edges), len(g_edges), 0, len(g_edges)))
            bad_records.append(g_edges.assign(reason="no_nodes_for_alias"))
            continue

        src_invalid_mask = ~g_edges["src"].astype("Int64").fillna(sentinel).astype(int).isin(V)
        dst_invalid_mask = ~g_edges["dst"].astype("Int64").fillna(sentinel).astype(int).isin(V)

        bad_src = int(src_invalid_mask.sum())
        bad_dst = int(dst_invalid_mask.sum())
        rows.append((alias, bad_src, bad_dst, len(V), len(g_edges)))

        if bad_src:
            bad_records.append(g_edges[src_invalid_mask].assign(reason="src_not_in_nodes"))
        if bad_dst:
            bad_records.append(g_edges[dst_invalid_mask].assign(reason="dst_not_in_nodes"))

    add("| graph_alias | bad_src | bad_dst | |V| nodes | |E| edges |")
    add("|---|---:|---:|---:|---:|")
    for alias, bsrc, bdst, nV, nE in rows:
        add(f"| {alias} | {bsrc} | {bdst} | {nV} | {nE} |")

    if bad_records:
        bad_df = pd.concat(bad_records, ignore_index=True)
        bad_path = Path("invalid_edge_references.csv")
        bad_df.to_csv(bad_path, index=False)
        add(f"- 🔍 Detailed invalid edges saved to: {bad_path.as_posix()}")
    else:
        add("- ✅ No invalid edge references found.")

    # 6) Undirected canonicalization & dedup
    if REMOVE_SELF_LOOPS:
        before = len(edges_std)
        # `ne` on nullable ints yields <NA> when an end is missing; such rows are
        # not self-loops, so keep them (and avoid masking with <NA>, which raises).
        not_loop = edges_std["src"].ne(edges_std["dst"]).fillna(True).astype(bool)
        edges_std = edges_std[not_loop].copy()
        add(f"- Self-loops removed: {before - len(edges_std)}")

    pre_n = len(edges_std)
    # Swap only where both ends are known and src > dst. A row-wise min/max with
    # skipna would collapse an (NA, k) edge into a fake self-loop (k, k).
    swap = (edges_std["src"] > edges_std["dst"]).fillna(False).astype(bool)
    src_canon = edges_std["src"].where(~swap, edges_std["dst"])
    dst_canon = edges_std["dst"].where(~swap, edges_std["src"])
    edges_std = (
        edges_std.assign(src=src_canon, dst=dst_canon)
        .drop_duplicates(subset=["graph_alias", "src", "dst"])
        .copy()  # detach from the filtered parent to avoid SettingWithCopy later
    )
    post_n = len(edges_std)
    add("## Undirected normalization")
    add(f"- canonical src<=dst; duplicates removed: {pre_n - post_n}")

    # 7) Data quality overview (no imputation; just reporting).
    #    quality_table already sorts by null_% then nonfinite_count (descending).
    nodes_q = quality_table(nodes_std)
    edges_q = quality_table(edges_std)

    add("## Data quality (nulls/dtypes/non-finite)")
    add("- Top nodes columns by null%:")
    add(nodes_q.head(10).to_string())
    add("- Top edges columns by null%:")
    add(edges_q.head(10).to_string())

    add("## NaN policy")
    add("- NaNs retained at intake; no imputation at this stage.")
    add("- Later: impute + add missingness masks; fit scalers on observed values only.")

    # 8) Feature discovery
    def numeric_feature_cols(df: pd.DataFrame, exclude: set) -> list:
        """Columns not in `exclude` that have a numeric dtype, or are >90% numeric
        after coercion. 'Unnamed: N' columns (a saved pandas row index) are skipped."""
        cols = []
        for c in df.columns:
            if c in exclude or str(c).startswith("Unnamed:"):
                continue
            if pd.api.types.is_numeric_dtype(df[c]):
                cols.append(c)
            else:
                coerced = pd.to_numeric(df[c], errors="coerce")
                if len(df) > 0 and coerced.notna().mean() > 0.9:
                    cols.append(c)
        return cols

    # The legacy id/endpoint columns remain next to node_id/src/dst; exclude them
    # so ids never enter the feature matrix (a1/a2 would also disagree with
    # src/dst after canonical reordering).
    # NOTE(review): other index-like columns (e.g. bond_idx) are still treated as features — confirm.
    exclude_nodes = {"graph_alias", "node_id", NODE_ID_LEGACY}
    exclude_edges = {"graph_alias", "src", "dst", EDGE_SRC_LEGACY, EDGE_DST_LEGACY}
    node_feat_cols = numeric_feature_cols(nodes_std, exclude_nodes)
    edge_feat_cols = numeric_feature_cols(edges_std, exclude_edges)

    add("## Detected numeric feature columns")
    add(f"- Excluded id/index columns: nodes {sorted(map(str, exclude_nodes))}, edges {sorted(map(str, exclude_edges))}, plus any 'Unnamed:*'")
    add(f"- Node features (count={len(node_feat_cols)}): {node_feat_cols[:24]}{' ...' if len(node_feat_cols)>24 else ''}")
    add(f"- Edge features (count={len(edge_feat_cols)}): {edge_feat_cols[:24]}{' ...' if len(edge_feat_cols)>24 else ''}")

    # 9) Save standardized outputs (current directory)
    saved_nodes = saved_edges = False
    if "node_id" in nodes_std.columns:
        keep = ["graph_alias", "node_id"] + node_feat_cols
        nodes_std[keep].to_csv("nodes_standardized.csv", index=False)
        saved_nodes = True
    if {"src", "dst"}.issubset(edges_std.columns):
        keep = ["graph_alias", "src", "dst"] + edge_feat_cols
        edges_std[keep].to_csv("edges_standardized.csv", index=False)
        saved_edges = True

    if saved_nodes:
        add("✅ Saved standardized nodes → nodes_standardized.csv")
    else:
        add("❌ Nodes could not be standardized (missing node_id).")
    if saved_edges:
        add("✅ Saved standardized edges → edges_standardized.csv")
    else:
        add("❌ Edges could not be standardized (missing src/dst).")

    # 10) Write report
    Path("validation_report.txt").write_text("\n".join(report), encoding="utf-8")
    print("\n".join(report))
    return nodes_std, edges_std, report

if __name__ == "__main__":
    # Keep the standardized frames in the notebook namespace for the cells below.
    nodes_std, edges_std, report = main()


# GNN Intake Validation Report (alias, undirected, NaN-kept)
## Column mappings
- graph key: graph_alias
- node id: node_id ← atom_idx
- edge ends: src ← a1, dst ← a2
## Node-ID health
- node_id nulls: 0; duplicates (table): 0
## Referential integrity (per graph_alias)
- ℹ️ Nodes contain graph_alias with no edges: ['C24']
| graph_alias | bad_src | bad_dst | |V| nodes | |E| edges |
|---|---:|---:|---:|---:|
| C23 | 0 | 0 | 65 | 67 |
| C24 | 0 | 0 | 68 | 0 |
| C25 | 0 | 0 | 71 | 73 |
| C26 | 0 | 0 | 74 | 76 |
| C2822R | 0 | 0 | 80 | 82 |
| C2822S | 0 | 0 | 74 | 76 |
| C29 | 0 | 0 | 79 | 83 |
| C2922R | 0 | 0 | 83 | 85 |
| C2922S | 0 | 0 | 83 | 85 |
| C30 | 0 | 0 | 82 | 86 |
| H31R | 0 | 0 | 85 | 89 |
| H31S | 0 | 0 | 85 | 89 |
| H32R | 0 | 0 | 88 | 92 |
| H32S | 0 | 0 | 88 | 92 |
| H33R | 0 | 0 | 91 | 95 |
| H33S | 0 | 0 | 91 | 95 |
| H34R | 0 | 0 | 94 | 98 |
| H34S | 0 | 0 | 94 | 98 |
| Tm | 0 | 0 | 73 | 77 |
| Ts | 0 | 0 | 73 | 77 |
- ✅ No invalid edge references found.
## Undirected n

---

In [None]:
# ---------- Optional: enforce 0-based indexing for atom indices ----------
# RDKit uses 0-based atom indices, but some quantum chemistry outputs use 1-based.
# This cell re-applies the per-graph 0-based normalization to the intake frames.
# It reuses the `ensure_zero_based` helper from the helpers cell, which returns
# (nodes, edges, log), instead of redefining it. A second definition returning
# only (nodes, edges) would shadow that helper and break main(), which unpacks
# three values. Only graphs whose minimum node_id is exactly 1 are shifted, so
# frames that are already 0-based are left unchanged.
if "nodes_std" in globals() and "edges_std" in globals():
    nodes_std, edges_std, zero_based_log = ensure_zero_based(nodes_std, edges_std)
    for entry in zero_based_log:
        print(f"🔧 {entry}")
else:
    print("nodes_std / edges_std are not defined yet; run the intake step first. Skipping 0-based check.")


In [None]:
# 4. Every edge endpoint must reference a valid node of the same molecule (graph_alias).
# ---------- 4) Per-graph referential integrity (vectorized) ----------
# Standalone re-run of the intake integrity check on `nodes_std` / `edges_std`.
# Report lines go to `add` if a report helper with that name exists in this
# namespace; otherwise they are printed.
_log = globals().get("add", print)
rows = []
bad_records = []  # collect bad rows for optional CSV

if "nodes_std" not in globals() or "edges_std" not in globals():
    print("nodes_std / edges_std are not defined yet; run the intake step first. Skipping integrity check.")
else:
    # 4.1 detect graph_alias sets present in nodes vs edges
    aliases_nodes = set(nodes_std["graph_alias"].dropna().unique())
    aliases_edges = set(edges_std["graph_alias"].dropna().unique())

    missing_in_nodes = sorted(aliases_edges - aliases_nodes)   # edges have graphs that nodes don't
    missing_in_edges = sorted(aliases_nodes - aliases_edges)   # nodes have graphs that edges don't

    if missing_in_nodes:
        _log(f"- ⚠️ Edges contain graph_alias with **no nodes**: {missing_in_nodes[:10]}{' ...' if len(missing_in_nodes)>10 else ''}")
    if missing_in_edges:
        _log(f"- ℹ️ Nodes contain graph_alias with **no edges**: {missing_in_edges[:10]}{' ...' if len(missing_in_edges)>10 else ''}")

    # 4.2 summary per graph_alias
    sentinel = -10**9  # stands in for <NA> endpoints; assumed never to be a real node id
    for alias in sorted(aliases_edges | aliases_nodes):
        g_nodes = nodes_std.loc[nodes_std["graph_alias"] == alias, ["node_id"]].dropna()
        g_edges = edges_std.loc[edges_std["graph_alias"] == alias, ["graph_alias", "src", "dst"]]
        V = set(g_nodes["node_id"].astype("Int64").dropna().astype(int).tolist())

        if g_edges.empty and g_nodes.empty:
            continue  # nothing to check

        if not V:
            # Every edge is invalid because this graph has no nodes.
            rows.append((alias, len(g_edges), len(g_edges), 0, len(g_edges)))
            bad_records.append(g_edges.assign(reason="no_nodes_for_alias"))
            continue

        # Missing endpoints are filled with the sentinel so they count as invalid.
        src_invalid_mask = ~g_edges["src"].astype("Int64").fillna(sentinel).astype(int).isin(V)
        dst_invalid_mask = ~g_edges["dst"].astype("Int64").fillna(sentinel).astype(int).isin(V)

        bad_src = int(src_invalid_mask.sum())
        bad_dst = int(dst_invalid_mask.sum())
        rows.append((alias, bad_src, bad_dst, len(V), len(g_edges)))

        # stash bad rows with reasons for debugging
        if bad_src:
            bad_records.append(g_edges[src_invalid_mask].assign(reason="src_not_in_nodes"))
        if bad_dst:
            bad_records.append(g_edges[dst_invalid_mask].assign(reason="dst_not_in_nodes"))

    _log("| graph_alias | bad_src | bad_dst | |V| nodes | |E| edges |")
    _log("|---|---:|---:|---:|---:|")
    for alias, bsrc, bdst, nV, nE in rows:
        _log(f"| {alias} | {bsrc} | {bdst} | {nV} | {nE} |")

In [None]:
# 5. Undirected edge check: enforce canonical order src <= dst, then drop duplicate edges.
# Report lines go to `add` if a report helper with that name exists; otherwise they are printed.
_log = globals().get("add", print)
if "edges_std" not in globals():
    print("edges_std is not defined yet; run the intake step first. Skipping undirected canonicalization.")
elif {"src", "dst"}.issubset(edges_std.columns):
    pre_n = len(edges_std)
    # Swap only where both ends are known and src > dst. A row-wise min/max with
    # skipna would turn an (NA, k) edge into a fake self-loop (k, k).
    swap = (edges_std["src"] > edges_std["dst"]).fillna(False).astype(bool)
    edges_std = (
        edges_std.assign(
            src=edges_std["src"].where(~swap, edges_std["dst"]),
            dst=edges_std["dst"].where(~swap, edges_std["src"]),
        )
        .drop_duplicates(subset=["graph_alias", "src", "dst"])
        .copy()  # detach from the filtered parent to avoid SettingWithCopy later
    )
    post_n = len(edges_std)
    _log("")
    _log("## Undirected canonicalization")
    _log(f"- Canonical order `src<=dst` enforced. Duplicates removed: {pre_n - post_n}.")


In [None]:
# 6. NaN check
def nulls_table(df: pd.DataFrame) -> pd.DataFrame:
    """
    Per-column missing-value summary.

    Returns one row per column of ``df`` with the null count, null percentage
    (rounded to 2 dp; 0.0 for an empty frame) and dtype name, sorted by
    percentage descending.
    """
    missing = df.isna().sum()
    share = 0.0 if len(df) == 0 else (missing / len(df) * 100).round(2)
    table = pd.DataFrame({"nulls": missing, "null_%": share, "dtype": df.dtypes.astype(str)})
    return table.sort_values("null_%", ascending=False)

# Show the null/dtype summaries with IPython's built-in `display` (available in
# every Jupyter kernel). Report lines go to `add` if a report helper with that
# name exists; otherwise they are printed.
if "nodes_std" in globals() and "edges_std" in globals():
    nodes_nulls = nulls_table(nodes_std)
    edges_nulls = nulls_table(edges_std)
    print("Nodes — nulls & dtypes (post-normalization)")
    display(nodes_nulls)
    print("Edges — nulls & dtypes (post-normalization)")
    display(edges_nulls)
else:
    print("nodes_std / edges_std are not defined yet; run the intake step first. Skipping NaN tables.")

_log = globals().get("add", print)
_log("")
_log("## NaN policy (current phase)")
_log("- NaNs retained as-is. No imputation at intake.")
_log("- TODO (later): impute + add missingness masks; fit scalers on observed values only.")


---