In [1]:
import os
import re
import json
import numpy as np
import pandas as pd

def graph_based_summary(data_path, dataset_dir, run_filter="graph_based+",
                        expected_runs=10, metric="f1_macro"):
    """Summarise per-model scores over the selected runs of one dataset.

    Every sub-directory of ``data_path/dataset_dir`` whose lower-cased name
    contains ``run_filter`` (also lower-cased) is treated as one run. Its
    ``results.json`` is read and the ``metric`` value is grouped under the
    ``model_name`` stored in that file (falling back to the run directory name).
    Run directories without a ``results.json`` are skipped.

    Parameters
    ----------
    data_path : str
        Root directory that contains one sub-directory per dataset.
    dataset_dir : str
        Name of the dataset sub-directory to scan.
    run_filter : str, default "graph_based+"
        Case-insensitive substring a run directory name must contain. The
        default keeps only the ``graph_based+`` runs, which is what this
        function has always selected.
    expected_runs : int or None, default 10
        Number of runs every model must have; ``None`` disables the check.
    metric : str, default "f1_macro"
        Key read from each ``results.json``.

    Returns
    -------
    pandas.DataFrame
        Columns ``Model``, ``Mean`` and ``Std`` (``np.std`` with its default
        ``ddof``), sorted by ``Mean`` in descending order with a fresh index.

    Raises
    ------
    FileNotFoundError
        If the dataset directory does not exist.
    RuntimeError
        If no run directory matched ``run_filter`` and had a ``results.json``.
    ValueError
        If a model does not have exactly ``expected_runs`` scores.
    KeyError
        If a ``results.json`` lacks the ``metric`` key.
    """
    ds_path = os.path.join(data_path, dataset_dir)
    if not os.path.isdir(ds_path):
        raise FileNotFoundError(f"'{dataset_dir}' directory not found under {data_path}")

    wanted = run_filter.lower()
    model_scores = {}

    # os.listdir returns entries in arbitrary order; sorting keeps the score
    # lists (and therefore the printed/debug output) reproducible.
    for run_dir in sorted(os.listdir(ds_path)):
        run_path = os.path.join(ds_path, run_dir)
        if not os.path.isdir(run_path):
            continue
        if wanted not in run_dir.lower():
            continue

        res_path = os.path.join(run_path, "results.json")
        if not os.path.isfile(res_path):
            continue
        with open(res_path, "r", encoding="utf-8") as f:
            res = json.load(f)

        model_name = res.get("model_name", run_dir)
        model_scores.setdefault(model_name, []).append(res[metric])

    if not model_scores:
        # Report the filter that was actually applied (the old message named
        # 'graph_based' although only 'graph_based+' runs were scanned).
        raise RuntimeError(f"No '{run_filter}' runs found in {dataset_dir}.")

    # Diagnostic output kept from the original implementation.
    if "GL_LR" in model_scores:
        print("GL_LR scores:", model_scores["GL_LR"])

    if expected_runs is not None:
        for m, scores in model_scores.items():
            if len(scores) != expected_runs:
                raise ValueError(
                    f"Model '{m}' expected {expected_runs} runs, found {len(scores)}."
                )

    rows = [
        {"Model": m, "Mean": float(np.mean(s)), "Std": float(np.std(s))}
        for m, s in model_scores.items()
    ]

    return (
        pd.DataFrame(rows)
        .sort_values("Mean", ascending=False)
        .reset_index(drop=True)
    )


In [8]:
# Build the summary table for one dataset directory and display it as the cell output.
# NOTE(review): the runs root is a machine-specific absolute path — consider a config constant.
df = graph_based_summary("/disk/10tb/home/shmelev/GENLINK/downstream_tasks/runs/real_data_no_mask/", "Western-Europe") # alternatives for dataset_dir: Western-Europe, CR, NC_graph_rel_eng, Scandinavia, Volga
df

Unnamed: 0,Model,Mean,Std
0,GL_SAGEConv_3l_128h,0.947606,0.011066
1,GL_MLP_3l_512h,0.947249,0.009418
2,GL_MLP_9l_128h,0.947027,0.008729
3,GL_MLP_9l_512h,0.945824,0.007775
4,GL_MLP_3l_128h,0.944334,0.009005
5,GL_SAGEConv_3l_512h,0.943992,0.010663
6,GL_TAGConv_3l_512h_w_k3_gnorm,0.942541,0.010734
7,GL_TAGConv_3l_512h_w_k3,0.942443,0.010329
8,GL_TAGConv_3l_512h_w_k3_gnorm_relu,0.942358,0.010002
9,GL_TAGConv_3l_512h_w_k3_gnorm_leaky_relu,0.941835,0.009046


In [3]:
# ---------- Example notebook cell ----------
# NOTE(review): `cr_graph_based_plus_summary` is not defined in any visible cell, so this
# call fails with NameError on a fresh kernel unless it is defined elsewhere. Presumably it
# is an earlier, dataset-specific variant of `graph_based_summary` — confirm, then either
# define it or switch this call to `graph_based_summary`.
data_root = "/disk/10tb/home/shmelev/GENLINK/downstream_tasks/runs/real_data_no_mask/"
summary_df = cr_graph_based_plus_summary(data_root)
summary_df


Unnamed: 0,Model,Mean,Std
0,GL_TAGConv_3l_512h_w_k3_gnorm,0.614577,0.016623
1,GL_TAGConv_3l_512h_nw_k3_gnorm_leaky_relu,0.612924,0.015017
2,GL_TAGConv_3l_512h_w_k3,0.612449,0.016688
3,GL_TAGConv_3l_512h_nw_k3_gnorm_relu,0.612396,0.013504
4,GL_TAGConv_3l_512h_w_k3_gnorm_leaky_relu,0.611929,0.018301
5,GL_TAGConv_3l_512h_w_k3_gnorm_relu,0.611912,0.015999
6,GL_MLP_9l_512h,0.610913,0.022341
7,GL_TAGConv_3l_512h_nw_k3_gnorm_gelu,0.610308,0.014395
8,GL_TAGConv_3l_512h_nw_k3_gnorm,0.610303,0.015755
9,GL_TAGConv_3l_512h_w_k3_gnorm_gelu,0.610284,0.01769


# GNN statistics

In [1]:
# JUPYTER CELL — analyze all torch geometric models in a .py file

import importlib.util
import inspect
import sys
import ast
import textwrap
from collections import Counter, defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F

import pandas as pd

# ---- torch_geometric basics
from torch_geometric.data import Data
from torch_geometric.nn.conv import MessagePassing  # most PyG convs derive from this

# ======================
# 1) CONFIG: your models file
# ======================
# 🔧 EDIT THIS to your actual file path (relative or absolute)
MODELS_FILE = "/disk/10tb/home/shmelev/GENLINK/utils/models.py"  # e.g., "models/gnn_zoo.py"

# ======================
# 2) Dynamic import
# ======================
# Load MODELS_FILE as a module object under the name "user_models".
# NOTE(review): spec_from_file_location may return None when no loader matches the path;
# that case is not handled here — confirm MODELS_FILE always points at a .py file.
spec = importlib.util.spec_from_file_location("user_models", MODELS_FILE)
module = importlib.util.module_from_spec(spec)
# Register the module in sys.modules before executing its body.
sys.modules["user_models"] = module
# Executes the file's top-level code, which defines the model classes inspected below.
spec.loader.exec_module(module)

# ======================
# 3) Build a placeholder Data object
# ======================
def build_dummy_data(num_nodes=8, num_features=4, num_classes=4):
    """Create a small placeholder ``Data`` graph for instantiating and probing models.

    The graph is a ring over ``num_nodes`` nodes in which every edge is stored in
    both directions. Node features are random normal values of width
    ``num_features``; all edge weights and edge attributes are ones.

    Returns
    -------
    Data
        Object carrying ``x``, ``edge_index``, ``weight``, ``edge_weight``
        (the same tensor as ``weight``), ``edge_attr`` with one column,
        an all-zero ``batch`` vector and a ``num_classes`` attribute.
    """
    # Ring i -> (i + 1) mod n, then mirrored so both directions are present.
    ring_from = torch.arange(num_nodes)
    ring_to = (ring_from + 1) % num_nodes
    sources = torch.cat([ring_from, ring_to])
    targets = torch.cat([ring_to, ring_from])
    edge_index = torch.stack([sources, targets], dim=0)

    node_features = torch.randn(num_nodes, num_features)
    num_edges = edge_index.shape[1]
    unit_weights = torch.ones(num_edges, dtype=torch.float)
    unit_attrs = torch.ones(num_edges, 1, dtype=torch.float)  # one feature per edge

    graph = Data(x=node_features, edge_index=edge_index)
    # Edge weights are exposed under both attribute names; per the original notes some
    # of the analysed models read `data.weight`, while `edge_weight` is the usual PyG name.
    graph.weight = unit_weights
    graph.edge_weight = unit_weights
    graph.edge_attr = unit_attrs
    # Single-graph batch assignment, for models that apply pooling.
    graph.batch = torch.zeros(num_nodes, dtype=torch.long)
    # Set explicitly because the model constructors are expected to read it from `data`.
    graph.num_classes = num_classes

    return graph

dummy_data = build_dummy_data(num_nodes=8, num_features=4, num_classes=4)

# ======================
# 4) Helpers
# ======================

# Which modules count as "learnable layers" (for depth)?
# Only nn.Linear and MessagePassing instances are listed, so any other parameterised
# module type is not counted towards depth by code that checks against this tuple.
LEARNABLE_LAYER_TYPES = (nn.Linear, MessagePassing)

# Class names that are treated as activation modules (matched by name, not by type).
_ACTIVATION_MODULE_NAMES = {
    "ReLU", "ReLU6", "LeakyReLU", "ELU", "SELU", "CELU", "GELU", "SiLU", "Swish",
    "Tanh", "Sigmoid", "Softplus", "Softsign", "Hardtanh", "Hardsigmoid", "Hardswish",
    "LogSoftmax", "Softmax", "PReLU", "RReLU", "Tanhshrink", "Mish", "Hardshrink",
}

def is_activation_module(m: nn.Module) -> bool:
    """Return True when the class name of ``m`` is one of the known activation names."""
    class_name = m.__class__.__name__
    return class_name in _ACTIVATION_MODULE_NAMES

# Extract activations used via torch.nn.functional.* / F.*
# Lower-case function names that are recognised as activations when they appear as the
# last component of a call such as `F.relu(...)`.
_ACTIVATION_F_NAMES = {
    "relu","relu6","leaky_relu","elu","selu","celu","gelu","silu",
    "tanh","sigmoid","softplus","softsign","hardtanh","hardsigmoid","hardswish",
    "log_softmax","softmax","prelu","rrelu","tanhshrink","mish","hardshrink"
}

def dotted_name_from_ast(node):
    """Build the dotted name (e.g. ``"a.b.c"``) for an ``ast.Attribute``/``ast.Name`` chain.

    Attribute names are collected from the outermost attribute inwards. If the chain
    bottoms out in something other than an ``ast.Name`` (for example a call), only the
    attribute names gathered so far are returned; a node that is neither an attribute
    nor a name yields an empty string.
    """
    segments = []
    current = node
    while True:
        if isinstance(current, ast.Attribute):
            # Prepend so the final order reads left-to-right as in the source.
            segments.insert(0, current.attr)
            current = current.value
            continue
        if isinstance(current, ast.Name):
            segments.insert(0, current.id)
        break
    return ".".join(segments)

def extract_functional_activations_from_forward(cls):
    """Collect functional activation calls found in the source of ``cls.forward``.

    The source of ``cls.forward`` is parsed with ``ast`` and every call whose last
    name component (lower-cased) is in ``_ACTIVATION_F_NAMES`` is recorded, provided
    the part before it looks like the functional namespace: ``functional``, anything
    ending in ``.functional`` (e.g. ``torch.nn.functional``), or a name starting
    with ``F.``.

    Only the body of ``forward`` itself is inspected; activations applied inside
    helper methods called from ``forward`` are not seen.

    Parameters
    ----------
    cls : type
        Class whose ``forward`` attribute should be inspected.

    Returns
    -------
    set of str
        Lower-case activation names. Empty when ``cls`` has no ``forward``, when its
        source cannot be retrieved, or when the source cannot be parsed.
    """
    acts = set()

    forward = getattr(cls, "forward", None)
    if forward is None:
        return acts  # nothing to inspect

    try:
        src = textwrap.dedent(inspect.getsource(forward))
    except (OSError, TypeError):
        # OSError: source file not available; TypeError: `forward` is a builtin or
        # another object type that inspect cannot map to source code.
        return acts
    try:
        tree = ast.parse(src)
    except SyntaxError:
        return acts

    for node in ast.walk(tree):
        if not isinstance(node, ast.Call):
            continue
        func = node.func
        # A bare name such as `relu(x)` has no prefix and therefore can never match
        # the namespace test below, so only attribute calls are considered.
        if not isinstance(func, ast.Attribute):
            continue
        dotted = dotted_name_from_ast(func)
        prefix, _, last = dotted.rpartition(".")
        name = last.lower()
        if name not in _ACTIVATION_F_NAMES:
            continue
        if (
            prefix == "functional"
            or prefix.endswith(".functional")
            or dotted.startswith("F.")
        ):
            acts.add(name)
    return acts

def count_params(model: nn.Module):
    """Count the scalar parameters of ``model``.

    Returns
    -------
    tuple of (int, int)
        ``(trainable, non_trainable)`` — element counts of parameters with and
        without ``requires_grad`` respectively.
    """
    trainable = 0
    non_trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
        else:
            non_trainable += param.numel()
    return trainable, non_trainable

def run_with_hooks_and_capture_widths(model: nn.Module, data: Data, num_classes: int = 4):
    """
    Run one forward pass of ``model`` on ``data`` and record layer output widths.

    A forward hook is attached to every sub-module of ``model`` that is an instance
    of ``LEARNABLE_LAYER_TYPES``. Each hook call appends the size of the last
    dimension of the module's output (or its length, for 1-D outputs). The hooks are
    always removed again before returning.

    The forward pass is best effort: if the model raises, the exception is swallowed
    and whatever widths were captured up to that point are used. The model is
    switched to eval mode and is left in that mode.

    Parameters
    ----------
    model : nn.Module
        Model to probe; it is called as ``model(data)``.
    data : Data
        Input passed to the model.
    num_classes : int, default 4
        Output width to ignore when choosing the representative hidden width.

    Returns
    -------
    tuple
        ``(depth, widths, rep_width)`` where ``depth`` is the number of hook calls that
        recorded a width, ``widths`` is the list of recorded widths in call order, and
        ``rep_width`` is the most frequent width different from ``num_classes`` (the
        largest one on ties), or ``None`` when nothing was recorded.
    """
    widths = []

    def hook(_module, _inp, out):
        # The output may be a tensor or a tuple/list; in the latter case use the
        # first tensor element, if any.
        t = out
        if isinstance(t, (tuple, list)):
            t = next((o for o in t if torch.is_tensor(o)), None)
        if torch.is_tensor(t):
            if t.dim() >= 2:
                widths.append(int(t.shape[-1]))
            elif t.dim() == 1:
                widths.append(int(t.shape[0]))  # fallback; unusual
        # anything else is ignored

    handles = [
        m.register_forward_hook(hook)
        for m in model.modules()
        if isinstance(m, LEARNABLE_LAYER_TYPES)
    ]

    model.eval()
    try:
        with torch.no_grad():
            # Use the `data` argument (previously the global `dummy_data` was used,
            # silently ignoring whatever the caller passed in).
            model(data)
    except Exception:
        # Deliberately best effort: a failing forward pass just yields fewer
        # (possibly zero) recorded widths.
        pass
    finally:
        # Remove the hooks on every exit path so the model is not left instrumented.
        for h in handles:
            h.remove()

    # depth = how many hooked layers successfully produced an output
    depth = len(widths)

    # Representative hidden width: ignore widths equal to num_classes (typically the
    # output layer); if that leaves nothing, fall back to all recorded widths.
    hidden_candidates = [w for w in widths if w != num_classes]
    if not hidden_candidates:
        hidden_candidates = widths[:]

    # Pick the mode (most common width); on ties pick the largest.
    if hidden_candidates:
        freq = Counter(hidden_candidates)
        max_count = max(freq.values())
        rep_width = max(w for w, c in freq.items() if c == max_count)
    else:
        rep_width = None

    return depth, widths, rep_width

# ======================
# 5) Collect candidate classes
# ======================
# (name, class) pairs for every nn.Module subclass defined in the loaded models file.
all_classes = []
for name, obj in inspect.getmembers(module, inspect.isclass):
    # Only classes defined in this module (avoid imported PyTorch/PyG classes)
    if obj.__module__ == module.__name__ and issubclass(obj, nn.Module):
        all_classes.append((name, obj))

# ======================
# 6) Instantiate and analyze
# ======================
# One dict per model class (becomes one DataFrame row) and a name -> message map of
# constructor failures.
rows = []
errors = {}

for cls_name, cls in all_classes:
    # Every field defaults to None so rows for failing models keep the same columns.
    info = {
        "model_name": cls_name,
        "depth": None,
        "width": None,
        "layer_widths_observed": None,
        "trainable_params": None,
        "non_trainable_params": None,
        "activations": None,
        "error": None,
    }
    try:
        # Attempt instantiation with the pattern you've shown: __init__(self, data)
        model = cls(dummy_data)
    except Exception as e:
        # Record the constructor failure and move on to the next class.
        info["error"] = f"init: {type(e).__name__}: {e}"
        rows.append(info)
        errors[cls_name] = info["error"]
        continue

    # Count params
    tr, ntr = count_params(model)
    info["trainable_params"] = tr
    info["non_trainable_params"] = ntr

    # Depth + widths via forward hooks
    # NOTE(review): num_classes=4 is hard-coded here and must match the num_classes
    # used when dummy_data was built.
    # NOTE(review): run_with_hooks_and_capture_widths catches exceptions from the model
    # call itself, so this except branch is only reached for errors raised elsewhere
    # in that function; a failing forward pass shows up as depth 0 instead.
    try:
        depth, widths, rep_width = run_with_hooks_and_capture_widths(model, dummy_data, num_classes=4)
        info["depth"] = depth
        info["width"] = rep_width
        info["layer_widths_observed"] = widths
    except Exception as e:
        info["error"] = f"hooks/forward: {type(e).__name__}: {e}"

    # Activations: both module-based and F.* functional calls
    acts_module = {type(m).__name__ for m in model.modules() if is_activation_module(m)}
    acts_f = extract_functional_activations_from_forward(cls)
    # Normalize names: module names already like 'ELU', function names are lower-case.
    # Functional names are upper-cased as-is (e.g. 'leaky_relu' -> 'LEAKY_RELU'), so they
    # do not necessarily coincide with the module class spelling.
    acts_clean = set()
    acts_clean.update(acts_module)
    acts_clean.update(a.upper() for a in acts_f)
    info["activations"] = ", ".join(sorted(acts_clean)) if acts_clean else ""

    rows.append(info)

# ======================
# 7) Present results
# ======================
# Build the per-model statistics table.
# NOTE(review): this rebinds `df`, which an earlier cell uses for the score summary.
df = pd.DataFrame(rows)
# Order columns nicely
df = df[
    ["model_name", "depth", "width", "layer_widths_observed",
     "trainable_params", "non_trainable_params", "activations", "error"]
].sort_values("model_name").reset_index(drop=True)

# Pretty print
# Global pandas option: stays in effect for every later display in this kernel.
pd.set_option("display.max_colwidth", 200)
df


Unnamed: 0,model_name,depth,width,layer_widths_observed,trainable_params,non_trainable_params,activations,error
0,GL_ChebConv_3l_128h_w_k3,3.0,128.0,"[128, 128, 4]",52484.0,0.0,ELU,
1,GL_GATConv_3l_128h,4.0,256.0,"[256, 256, 256, 4]",269060.0,0.0,ELU,
2,GL_GATConv_3l_512h,4.0,1024.0,"[1024, 1024, 1024, 4]",4221956.0,0.0,ELU,
3,GL_GATConv_9l_128h,10.0,256.0,"[256, 256, 256, 256, 256, 256, 256, 256, 256, 4]",1063172.0,0.0,ELU,
4,GL_GATConv_9l_512h,10.0,1024.0,"[1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 4]",16835588.0,0.0,ELU,
5,GL_GCNConv_3l_128h_nw,3.0,128.0,"[128, 128, 4]",17668.0,0.0,ELU,
6,GL_GCNConv_3l_128h_w,3.0,128.0,"[128, 128, 4]",17668.0,0.0,ELU,
7,GL_GCNConv_3l_32h_nw,3.0,32.0,"[32, 32, 4]",1348.0,0.0,ELU,
8,GL_GCNConv_3l_512h_nw,3.0,512.0,"[512, 512, 4]",267268.0,0.0,ELU,
9,GL_GCNConv_3l_512h_w,3.0,512.0,"[512, 512, 4]",267268.0,0.0,ELU,


In [2]:
# Filter, format, and emit LaTeX (Booktabs) for your model stats DataFrame `df`

import pandas as pd
import math
import re

# --------- CONFIGURE caption/label if you like ---------
CAPTION = "Architecture statistics for selected widths (128/256/512/1024)."
LABEL = "tab:model_arch_stats"

# --------- Helper: LaTeX-escape cell contents ---------
# (character, replacement) pairs for LaTeX special characters.
_LATEX_SUBS = [
    ("\\", r"\textbackslash{}"),
    ("{", r"\{"),
    ("}", r"\}"),
    ("$", r"\$"),
    ("&", r"\&"),
    ("#", r"\#"),
    ("_", r"\_"),
    ("%", r"\%"),
    ("~", r"\textasciitilde{}"),
    ("^", r"\textasciicircum{}"),
]

# Single-pass translation table. Applying the substitutions one after another would
# re-escape characters introduced by an earlier substitution (the braces of
# "\textbackslash{}" used to come out as "\textbackslash\{\}").
_LATEX_ESCAPE_TABLE = str.maketrans(dict(_LATEX_SUBS))

def latex_escape(s):
    """Return ``s`` as a string that is safe to place in LaTeX text.

    ``None`` becomes an empty string; any other value is converted with ``str``.
    Each LaTeX special character listed in ``_LATEX_SUBS`` is replaced exactly once,
    then runs of whitespace are collapsed to a single space and the result is
    stripped.
    """
    if s is None:
        return ""
    escaped = str(s).translate(_LATEX_ESCAPE_TABLE)
    # collapse excessive whitespace
    return re.sub(r"\s+", " ", escaped).strip()

# --------- 1) Filter by width ---------
# Only rows whose representative width is one of these values are kept.
keep_widths = {128, 256, 512, 1024}
# Work on a copy so the `df` produced by the statistics cell is left untouched.
df2 = df.copy()

# Ensure numeric width (coerce bad values to NaN, then filter)
df2["width"] = pd.to_numeric(df2["width"], errors="coerce")
df2 = df2[df2["width"].isin(keep_widths)].copy()

# --------- 2) Drop requested columns (if present) ---------
for col in ["error", "layer_widths_observed"]:
    if col in df2.columns:
        df2.drop(columns=col, inplace=True)

# --------- 3) Select & order columns for the table ---------
wanted_cols = ["model_name", "depth", "width", "trainable_params", "non_trainable_params", "activations"]
# `existing_cols` also drives the header, alignment and row building further down.
existing_cols = [c for c in wanted_cols if c in df2.columns]
df2 = df2[existing_cols]

# Sort for readability
sort_cols = [c for c in ["width", "depth", "model_name"] if c in df2.columns]
if sort_cols:
    # mergesort is a stable sort, so rows with equal keys keep their previous order.
    df2 = df2.sort_values(sort_cols, kind="mergesort")

# --------- 4) Nicely format numbers ---------
def fmt_int(x):
    """Format ``x`` as an integer with thousands separators.

    ``None`` and float NaN become an empty string. Values that cannot be converted
    with ``int`` are returned as ``str(x)``.
    """
    is_missing = x is None or (isinstance(x, float) and math.isnan(x))
    if is_missing:
        return ""
    try:
        whole = int(x)
    except Exception:
        return str(x)
    return f"{whole:,}"

# Parameter counts become strings with thousands separators.
if "trainable_params" in df2.columns:
    df2["trainable_params"] = df2["trainable_params"].apply(fmt_int)
if "non_trainable_params" in df2.columns:
    df2["non_trainable_params"] = df2["non_trainable_params"].apply(fmt_int)

# Depth and width become plain integer strings (empty string for missing values).
if "depth" in df2.columns:
    df2["depth"] = df2["depth"].apply(lambda v: "" if pd.isna(v) else f"{int(v)}")
if "width" in df2.columns:
    df2["width"] = df2["width"].apply(lambda v: "" if pd.isna(v) else f"{int(v)}")

# --------- 5) Build LaTeX (Booktabs) table ---------
# Column header names (pretty)
header_names = {
    "model_name": r"\textbf{Model}",
    "depth": r"\textbf{Depth}",
    "width": r"\textbf{Width}",
    "trainable_params": r"\textbf{Trainable}",
    "non_trainable_params": r"\textbf{Non-train}",
    "activations": r"\textbf{Activations}",
}
# Column alignment: l c c r r l
align_map = {
    "model_name": "l",
    "depth": "c",
    "width": "c",
    "trainable_params": "r",
    "non_trainable_params": "r",
    "activations": "l",
}
align_spec = "@{}" + "".join(align_map[c] for c in existing_cols) + "@{}"

# Header row
header_row = " & ".join(header_names[c] for c in existing_cols) + r" \\"

# Data rows: every cell is LaTeX-escaped, cells are joined with '&'.
rows = []
for _, r in df2.iterrows():
    cells = []
    for c in existing_cols:
        cells.append(latex_escape(r[c]))
    rows.append(" & ".join(cells) + r" \\")

# Fallback row for an empty table; it spans however many columns are actually
# present instead of assuming six.
empty_row = (
    r"\multicolumn{" + str(max(len(existing_cols), 1))
    + r"}{c}{\emph{No rows after filtering by width.}} \\"
)
body_rows = "\n    ".join(rows) if rows else empty_row

# The caption is typeset text and is escaped. The label is a reference key, not
# typeset text, so it is inserted verbatim: escaping it turned
# "tab:model_arch_stats" into "tab:model\_arch\_stats", which no longer matches
# \ref{tab:model_arch_stats}.
latex_table = fr"""\begin{{table}}[h]
  \centering
  \caption{{{latex_escape(CAPTION)}}}
  \label{{{LABEL}}}
  \begin{{tabular}}{{{align_spec}}}
    \toprule
    {header_row}
    \midrule
    {body_rows}
    \bottomrule
  \end{{tabular}}
\end{{table}}"""

print(latex_table)


\begin{table}[h]
  \centering
  \caption{Architecture statistics for selected widths (128/256/512/1024).}
  \label{tab:model\_arch\_stats}
  \begin{tabular}{@{}lccrrl@{}}
    \toprule
    \textbf{Model} & \textbf{Depth} & \textbf{Width} & \textbf{Trainable} & \textbf{Non-train} & \textbf{Activations} \\
    \midrule
    GL\_ChebConv\_3l\_128h\_w\_k3 & 3 & 128 & 52,484 & 0 & ELU \\
    GL\_GCNConv\_3l\_128h\_nw & 3 & 128 & 17,668 & 0 & ELU \\
    GL\_GCNConv\_3l\_128h\_w & 3 & 128 & 17,668 & 0 & ELU \\
    GL\_MLP\_3l\_128h & 3 & 128 & 17,668 & 0 & ELU \\
    GL\_SAGEConv\_3l\_128h & 3 & 128 & 35,076 & 0 & ELU \\
    GL\_SSGConv\_3l\_128h\_w\_a05\_k1 & 3 & 128 & 17,668 & 0 & ELU \\
    GL\_SSGConv\_3l\_128h\_w\_a09\_k1 & 3 & 128 & 17,668 & 0 & ELU \\
    GL\_TAGConv\_3l\_128h\_nw\_k3 & 3 & 128 & 69,892 & 0 & ELU \\
    GL\_TAGConv\_3l\_128h\_w\_k3 & 3 & 128 & 69,892 & 0 & ELU \\
    GL\_GCNConv\_9l\_128h\_nw & 9 & 128 & 116,740 & 0 & ELU \\
    GL\_GCNConv\_9l\_128h\_w & 9 & 128 & 116,7