In [1]:
# Purpose: import packages, set device, set a reproducible seed, print versions

import os, time, math, random
import numpy as np

import torch
from torch import nn
from torch.nn import functional as F

import networkx as nx

from torch_geometric.datasets import Planetoid
from torch_geometric.utils import to_networkx
from torch_geometric.nn import GCNConv

# Reproducibility
def set_seed(seed=0):
    """Seed Python's `random`, NumPy and PyTorch with the same value.

    When CUDA is available, every CUDA device is seeded as well.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# Seed Python, NumPy and PyTorch once up front so later cells start from a known RNG state.
set_seed(0)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Print library versions next to the results so the run can be reproduced.
print(f"torch: {torch.__version__}")
try:
    # NOTE(review): torch_geometric submodules are already imported at the top of
    # this cell, so an import failure would normally have surfaced there first;
    # this guard only affects the version print.
    import torch_geometric as tg
    print(f"pyg: {tg.__version__}")
except Exception as e:
    print("pyg import error:", e)

print("device:", device)


torch: 2.4.1
pyg: 2.6.1
device: cpu


In [2]:
# Purpose: download Cora via PyG and report core stats

# Directory handed to Planetoid as `root` (relative to the notebook's working directory).
DATA_ROOT = "./data"
dataset = Planetoid(root=DATA_ROOT, name="Cora")
# Take the first graph object of the dataset and move it to the selected device.
data = dataset[0].to(device)

def report_data_summary(d):
    """Print node/edge/feature/class counts and split sizes for graph `d`.

    The undirected count treats (u, v) and (v, u) as one edge and ignores
    self-loops. The class count is read from the notebook-level `dataset`.
    """
    ei = d.edge_index.cpu().numpy()
    sources, targets = ei[0].tolist(), ei[1].tolist()
    undirected = {
        (min(a, b), max(a, b))
        for a, b in zip(sources, targets)
        if a != b
    }

    n_train = int(d.train_mask.sum())
    n_val = int(d.val_mask.sum())
    n_test = int(d.test_mask.sum())

    print("Cora()")
    print(f"Nodes: {d.num_nodes}  | Edges (directed count): {d.edge_index.size(1)}  | Feats: {d.x.size(1)} | Classes: {dataset.num_classes}")
    print(f"Undirected unique edges: {len(undirected)}")
    print(f"Splits: train={n_train}, val={n_val}, test={n_test}")

# Print the summary for the graph loaded above.
report_data_summary(data)


Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...


Cora()
Nodes: 2708  | Edges (directed count): 10556  | Feats: 1433 | Classes: 7
Undirected unique edges: 5278
Splits: train=140, val=500, test=1000


Done!


In [3]:
# Purpose: helper functions for Largest Connected Component percent and accuracy

def lcc_percent(d):
    """Return (percent of nodes in the largest connected component, number of components).

    The graph is converted to an undirected networkx graph first. When the
    converted graph has no components, (0.0, 0) is returned.
    """
    graph = to_networkx(d, to_undirected=True)
    component_sizes = [len(component) for component in nx.connected_components(graph)]
    if not component_sizes:
        return 0.0, 0
    largest = max(component_sizes)
    return 100.0 * largest / d.num_nodes, len(component_sizes)

@torch.no_grad()
def accuracy(logits, y, mask):
    """Fraction of masked rows whose arg-max logit equals the label.

    Returns a Python float; 0.0 when the mask selects nothing.
    """
    predicted = logits.argmax(dim=-1)
    if mask.sum() > 0:
        hits = predicted[mask] == y[mask]
        return hits.float().mean().item()
    return 0.0


In [4]:
# Purpose: define a simple 2-layer GCN and a trainer that times epochs and returns metrics

class GCN(nn.Module):
    """Two-layer GCN: GCNConv -> ReLU -> dropout -> GCNConv.

    `forward` returns the output of the second convolution unchanged (no
    softmax), which is what `F.cross_entropy` in the trainer consumes.
    """

    def __init__(self, in_dim, hid_dim, out_dim, dropout=0.5):
        super().__init__()
        # Sub-modules are created in the same order as before so parameter
        # initialisation consumes the RNG identically.
        self.conv1 = GCNConv(in_dim, hid_dim)
        self.conv2 = GCNConv(hid_dim, out_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index):
        hidden = F.relu(self.conv1(x, edge_index))
        hidden = self.dropout(hidden)
        return self.conv2(hidden, edge_index)

def train_model(model, d, epochs=200, lr=0.01, wd=5e-4, log_every=20):
    """Full-batch train `model` on graph `d` and report the best-validation snapshot.

    Every epoch does one Adam step on the cross-entropy of the training nodes,
    then an eval-mode forward pass to measure train/val/test accuracy. The
    reported train/val/test values come from the epoch with the highest
    validation accuracy (first such epoch on ties), not from the last epoch.

    Args:
        model: module called as ``model(d.x, d.edge_index)``.
        d: graph object exposing ``x``, ``y``, ``edge_index`` and the three
            ``*_mask`` attributes; it is moved to the notebook-level `device`.
        epochs: number of training epochs; must be at least 1.
        lr: Adam learning rate.
        wd: Adam weight decay.
        log_every: print a progress line every this many epochs (the first and
            last epochs are always printed). ``0`` or ``None`` disables the
            periodic lines.

    Returns:
        dict with keys ``train``, ``val``, ``test`` (accuracies at the best
        validation epoch), ``epochs``, ``train_time`` (seconds for the whole
        loop), ``avg_epoch_time`` (mean seconds per epoch, including the
        evaluation pass) and ``best_val``.

    Raises:
        ValueError: if ``epochs`` is smaller than 1 (there would be no
            snapshot to report).
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    model = model.to(device)
    d = d.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

    epoch_times = []
    best_val = -1.0
    best_snap = None
    # perf_counter is monotonic, so measured intervals cannot go negative if
    # the system clock is adjusted mid-run.
    t0 = time.perf_counter()

    for ep in range(1, epochs + 1):
        t_ep0 = time.perf_counter()

        # --- optimisation step on the training nodes ---
        model.train()
        opt.zero_grad()
        out = model(d.x, d.edge_index)
        loss = F.cross_entropy(out[d.train_mask], d.y[d.train_mask])
        loss.backward()
        opt.step()

        # --- evaluation pass (dropout off, no gradients) ---
        model.eval()
        with torch.no_grad():
            out = model(d.x, d.edge_index)
            acc_train = accuracy(out, d.y, d.train_mask)
            acc_val   = accuracy(out, d.y, d.val_mask)
            acc_test  = accuracy(out, d.y, d.test_mask)

        # Strict '>' keeps the earliest epoch among equal validation scores.
        if acc_val > best_val:
            best_val = acc_val
            best_snap = dict(train=acc_train, val=acc_val, test=acc_test)

        epoch_times.append(time.perf_counter() - t_ep0)

        # A falsy log_every short-circuits before the modulo, avoiding a
        # ZeroDivisionError for log_every=0.
        periodic = bool(log_every) and ep % log_every == 0
        if ep == 1 or periodic or ep == epochs:
            print(f"Epoch {ep:03d} | Loss {loss.item():.3f} | Train {acc_train:.3f} Val {acc_val:.3f} Test {acc_test:.3f}")

    total_time = time.perf_counter() - t0
    avg_epoch_time = float(np.mean(epoch_times)) if epoch_times else 0.0

    result = dict(
        train=best_snap["train"],
        val=best_snap["val"],
        test=best_snap["test"],
        epochs=epochs,
        train_time=total_time,
        avg_epoch_time=avg_epoch_time,
        best_val=best_val
    )
    return result


In [5]:
# Purpose: utilities to work with undirected edges and build new edge_index

def unique_undirected(edge_index):
    """Collapse a directed edge tensor into a list of unique (low, high) node pairs.

    Rows 0 and 1 of `edge_index` are read as sources and targets. Self-loops
    are skipped, and (u, v) / (v, u) map to the same pair. The result is a
    list of tuples of Python ints in set-iteration order.
    """
    ei = edge_index.detach().cpu().numpy()
    pairs = {
        (a, b) if a < b else (b, a)
        for a, b in zip(ei[0].tolist(), ei[1].tolist())
        if a != b
    }
    return list(pairs)

def build_edge_index_from_undirected(edge_pairs, device):
    """Expand undirected (u, v) pairs into a 2 x (2 * len) long tensor on `device`.

    Each pair contributes the columns (u, v) and (v, u), in that order. An
    empty input yields an explicit (2, 0) tensor so the result always has two rows.
    """
    arcs = [arc for u, v in edge_pairs for arc in ([u, v], [v, u])]
    if not arcs:
        return torch.empty((2, 0), dtype=torch.long, device=device)
    as_rows = torch.tensor(arcs, dtype=torch.long, device=device)
    return as_rows.t().contiguous()


In [6]:
# Purpose: compute edge scores for trimming using either Jaccard or Cosine

def neighbor_sets(d):
    """Return a list with one neighbour set per node of `d` (index = node id).

    Every edge is recorded in both directions and self-loops are ignored.
    """
    adjacency = [set() for _ in range(d.num_nodes)]
    ei = d.edge_index.detach().cpu().numpy()
    for a, b in zip(ei[0].tolist(), ei[1].tolist()):
        if a != b:
            adjacency[a].add(b)
            adjacency[b].add(a)
    return adjacency

def jaccard_scores(d, edges_u):
    """Score each (u, v) in `edges_u` by the Jaccard overlap of the endpoints' neighbour sets.

    Returns a list of (score, u, v) tuples sorted by score, highest first;
    edges with equal scores keep their input order. A pair whose neighbour
    sets are both empty scores 0.0.
    """
    neighbours = neighbor_sets(d)

    def _overlap(u, v):
        union_size = len(neighbours[u] | neighbours[v])
        if union_size == 0:
            return 0.0
        return len(neighbours[u] & neighbours[v]) / union_size

    scored = [(_overlap(u, v), u, v) for u, v in edges_u]
    return sorted(scored, key=lambda item: item[0], reverse=True)

def cosine_scores(d, edges_u):
    """Score each (u, v) in `edges_u` by the cosine similarity of the endpoints' feature rows.

    Features are taken from `d.x` as float32. Returns a list of (score, u, v)
    tuples sorted by score, highest first; equal scores keep their input
    order. If either endpoint has a zero-norm feature row the score is 0.0.
    """
    feats = d.x.detach().cpu().numpy().astype(np.float32)
    lengths = np.linalg.norm(feats, axis=1)

    scored = []
    for u, v in edges_u:
        len_u, len_v = lengths[u], lengths[v]
        if len_u == 0.0 or len_v == 0.0:
            similarity = 0.0
        else:
            similarity = float(np.dot(feats[u], feats[v]) / (len_u * len_v))
        scored.append((similarity, u, v))

    return sorted(scored, key=lambda item: item[0], reverse=True)


In [7]:
# Purpose: trim edges by top-K score, preserve hubs if needed, reconnect components once to the LCC

from collections import defaultdict

def degree_from_edges(edges_u, n):
    """Count, for each of `n` nodes, how many pairs in `edges_u` touch it.

    Returns an int64 NumPy array of length `n`; each pair adds one to both
    of its endpoints.
    """
    counts = np.zeros(n, dtype=np.int64)
    for left, right in edges_u:
        counts[left] += 1
        counts[right] += 1
    return counts

def trim_with_guards(
    d,
    method="cosine",            # 'cosine' or 'jaccard'
    keep_rate=0.80,             # fraction of undirected edges to keep
    hub_percent=0.10,           # fraction of top-degree nodes to protect
    reconnect_mode="bridge_once",  # currently only one reconnection per component
    effective_keep_cap=None     # optional cap on effective keep (e.g., 0.83)
):
    """Keep the top-scoring undirected edges of `d`, then apply hub/reconnect/cap guards.

    Steps:
      1. Score every unique undirected edge (cosine or Jaccard) and keep the
         top ``max(1, int(E * keep_rate))`` of them.
      2. Hub guard: each of the ``int(n * hub_percent)`` highest-degree nodes
         (degree measured on the original edges) that ended up with no kept
         edge gets its best-scoring removed edge back.
      3. If ``reconnect_mode == "bridge_once"``: every connected component
         other than the largest gets back its best-scoring removed edge that
         links it directly to the largest component, when such an edge exists.
         Other values of ``reconnect_mode`` skip this step.
      4. If ``effective_keep_cap`` is given and the kept set exceeds
         ``int(E * effective_keep_cap)``, lowest-scoring edges are dropped as
         long as their endpoints stay connected without them. The cap may not
         be reached if every remaining edge is needed for connectivity.

    Args:
        d: graph object exposing ``num_nodes``, ``x``, ``y``, ``edge_index``
            and the three ``*_mask`` attributes.
        method: ``"cosine"`` or ``"jaccard"``.
        keep_rate: target fraction of undirected edges to keep.
        hub_percent: fraction of nodes treated as protected hubs.
        reconnect_mode: reconnection strategy (see step 3).
        effective_keep_cap: optional upper bound on the kept fraction.

    Returns:
        ``(new_edge_index, trim_time, stats)``: the bidirectional edge tensor
        on the same device as ``d.edge_index``, the seconds spent trimming
        (excluding the component statistics), and a dict with keys
        ``orig_undirected``, ``kept_undirected``, ``keep_rate_target``,
        ``keep_rate_effective``, ``LCC_percent`` and ``components``.

    Raises:
        ValueError: if `method` is neither ``"cosine"`` nor ``"jaccard"``.
    """
    t0 = time.perf_counter()
    n = d.num_nodes
    edges_u = unique_undirected(d.edge_index)
    E = len(edges_u)
    target = max(1, int(E * keep_rate))

    # --- 1. score and take the top-K ---
    if method == "cosine":
        scored = cosine_scores(d, edges_u)
    elif method == "jaccard":
        scored = jaccard_scores(d, edges_u)
    else:
        raise ValueError("Unknown method")

    kept = scored[:target]
    removed = scored[target:]

    kept_set = set((u, v) if u < v else (v, u) for _, u, v in kept)

    # --- 2. hub guard: top-degree nodes must not be left with degree 0 ---
    deg0 = degree_from_edges(edges_u, n)
    H = max(0, int(n * hub_percent))
    hub_ids = np.argsort(-deg0)[:H] if H > 0 else []

    # Index removed edges by endpoint so both guards can look up candidates.
    removed_by_node = defaultdict(list)
    for s, u, v in removed:
        a, b = (u, v) if u < v else (v, u)
        removed_by_node[u].append((s, a, b))
        removed_by_node[v].append((s, a, b))

    kept_deg = np.zeros(n, dtype=np.int64)
    for (u, v) in kept_set:
        kept_deg[u] += 1
        kept_deg[v] += 1

    for h in hub_ids:
        if kept_deg[h] == 0 and removed_by_node[h]:
            s, a, b = max(removed_by_node[h], key=lambda t: t[0])
            if (a, b) not in kept_set:
                kept_set.add((a, b))
                kept_deg[a] += 1
                kept_deg[b] += 1

    # --- 3. reconnect each non-LCC component once if possible ---
    if reconnect_mode == "bridge_once":
        G = nx.Graph()
        G.add_nodes_from(range(n))
        G.add_edges_from(list(kept_set))
        comps = list(nx.connected_components(G))
        if len(comps) > 1:
            lcc = max(comps, key=len)
            L = set(lcc)
            for comp in comps:
                if comp is lcc:
                    continue
                best = None
                for x in comp:
                    for s, a, b in removed_by_node.get(x, []):
                        links_to_lcc = (a in comp and b in L) or (b in comp and a in L)
                        if links_to_lcc and ((a, b) not in kept_set):
                            if (best is None) or (s > best[0]):
                                best = (s, a, b)
                if best is not None:
                    _, a, b = best
                    kept_set.add((a, b))

    # --- 4. optional cap to avoid too much inflation ---
    if effective_keep_cap is not None:
        cap_edges = int(E * effective_keep_cap)
        if len(kept_set) > cap_edges:
            # First score seen for each normalised pair wins.
            score_map = {}
            for s, u, v in scored:
                a, b = (u, v) if u < v else (v, u)
                if (a, b) not in score_map:
                    score_map[(a, b)] = s

            G = nx.Graph()
            G.add_nodes_from(range(n))
            G.add_edges_from(list(kept_set))

            # Edges that are bridges now stay bridges after any further
            # removals, so they can be excluded once. Orientation is
            # normalised so the lookup matches the (low, high) keys of kept_set.
            bridges = {(a, b) if a < b else (b, a) for a, b in nx.bridges(G)}

            # sort edges ascending by score
            candidates = sorted(
                [e for e in kept_set if e not in bridges],
                key=lambda e: score_map.get(e, -1.0),
            )
            for a, b in candidates:
                if len(kept_set) <= cap_edges:
                    break
                # Earlier removals can turn a former cycle edge into a bridge,
                # so re-check connectivity on the current graph before
                # committing to each removal.
                G.remove_edge(a, b)
                if nx.has_path(G, a, b):
                    kept_set.remove((a, b))
                else:
                    G.add_edge(a, b)

    new_edge_index = build_edge_index_from_undirected(list(kept_set), d.edge_index.device)
    trim_time = time.perf_counter() - t0

    # Build a graph object of the same class as the input (not the notebook
    # global) just to compute LCC% and the component count.
    tmp = type(d)(
        x=d.x, y=d.y, edge_index=new_edge_index,
        train_mask=d.train_mask, val_mask=d.val_mask, test_mask=d.test_mask
    ).cpu()
    lccp, comps = lcc_percent(tmp)

    stats = dict(
        orig_undirected=E,
        kept_undirected=len(kept_set),
        keep_rate_target=keep_rate,
        # An edgeless input has nothing to keep; report 0.0 instead of dividing by zero.
        keep_rate_effective=(len(kept_set) / E) if E else 0.0,
        LCC_percent=lccp,
        components=comps
    )
    return new_edge_index, trim_time, stats


In [8]:
# Purpose: define one place for all knobs used by the next cells

# Trimming method and guards
# These four (plus EFF_CAP) are forwarded to trim_with_guards by the cells below.
METHOD   = "cosine"       # 'cosine' or 'jaccard'
K_KEEP   = 0.80           # target keep-rate on undirected edges (e.g., 0.90, 0.80, 0.70)
HUB_P    = 0.10           # fraction of top-degree nodes to protect (0.05 or 0.10 worked well)
RECONNECT= "bridge_once"  # reconnection strategy
EFF_CAP  = None           # optional: e.g. 0.83 to cap effective keep if reconnect inflates too much

# Model and training
HIDDEN   = 64             # hid_dim passed to GCN
DROPOUT  = 0.5            # dropout probability passed to GCN
EPOCHS   = 200            # epochs passed to train_model
LR       = 0.01           # Adam learning rate passed to train_model
WD       = 5e-4           # Adam weight decay passed to train_model

print("Params set. METHOD:", METHOD, "| K_KEEP:", K_KEEP, "| HUB_P:", HUB_P)


Params set. METHOD: cosine | K_KEEP: 0.8 | HUB_P: 0.1


In [9]:
# Purpose: train a baseline GCN on the original graph and record timings and accuracy

# Re-seed so model initialisation and dropout start from seed 0 (the trimmed runs do the same).
set_seed(0)
# lcc_percent returns (percent, component_count), so this prints a 2-tuple.
print("LCC% (original):", lcc_percent(data))

baseline_model = GCN(dataset.num_node_features, HIDDEN, dataset.num_classes, DROPOUT)
# `base` is read again by the summary cells further down.
base = train_model(baseline_model, data, epochs=EPOCHS, lr=LR, wd=WD, log_every=20)

print("\nBASELINE (GCN on original)")
# Floats are printed with four decimals; other values (e.g. the epoch count) as-is.
for k, v in base.items():
    if isinstance(v, float):
        print(f"{k}: {v:.4f}")
    else:
        print(f"{k}: {v}")


LCC% (original): (91.76514032496307, 78)
Epoch 001 | Loss 1.942 | Train 0.821 Val 0.516 Test 0.540
Epoch 020 | Loss 0.012 | Train 1.000 Val 0.768 Test 0.783
Epoch 040 | Loss 0.012 | Train 1.000 Val 0.768 Test 0.791
Epoch 060 | Loss 0.018 | Train 1.000 Val 0.770 Test 0.800
Epoch 080 | Loss 0.013 | Train 1.000 Val 0.752 Test 0.801
Epoch 100 | Loss 0.012 | Train 1.000 Val 0.768 Test 0.806
Epoch 120 | Loss 0.010 | Train 1.000 Val 0.760 Test 0.810
Epoch 140 | Loss 0.012 | Train 1.000 Val 0.764 Test 0.810
Epoch 160 | Loss 0.010 | Train 1.000 Val 0.770 Test 0.814
Epoch 180 | Loss 0.012 | Train 1.000 Val 0.770 Test 0.806
Epoch 200 | Loss 0.011 | Train 1.000 Val 0.764 Test 0.810

BASELINE (GCN on original)
train: 1.0000
val: 0.7920
test: 0.8080
epochs: 200
train_time: 2.6057
avg_epoch_time: 0.0130
best_val: 0.7920


In [10]:
# Purpose: trim edges using the chosen method and guards, then train the same model on the trimmed graph

# Re-seed with the same value the baseline cell uses.
set_seed(0)
# Returns the trimmed edge tensor, the trimming time in seconds, and a stats dict.
new_edge_index, trim_time, stats = trim_with_guards(
    data,
    method=METHOD,
    keep_rate=K_KEEP,
    hub_percent=HUB_P,
    reconnect_mode=RECONNECT,
    effective_keep_cap=EFF_CAP
)
print("Trim stats:", stats, f"| trim_time: {trim_time:.2f}s")

# build a new Data with the trimmed edges
# NOTE(review): this import sits mid-notebook; moving it to the import cell would
# keep all dependencies in one place.
from torch_geometric.data import Data
# Features, labels and masks are reused from `data`; only edge_index differs.
trimmed = Data(
    x=data.x, y=data.y, edge_index=new_edge_index,
    train_mask=data.train_mask, val_mask=data.val_mask, test_mask=data.test_mask
).to(device)

# Re-seed again right before building the model, mirroring the baseline cell.
set_seed(0)
trim_model = GCN(dataset.num_node_features, HIDDEN, dataset.num_classes, DROPOUT)
trim = train_model(trim_model, trimmed, epochs=EPOCHS, lr=LR, wd=WD, log_every=20)

print("\nTRIMMED (GCN on", METHOD, "kept graph + guards)")
# Floats are printed with four decimals; other values as-is.
for k, v in trim.items():
    if isinstance(v, float):
        print(f"{k}: {v:.4f}")
    else:
        print(f"{k}: {v}")
# Training time plus the one-off trimming cost.
print(f"Total time (incl. trimming): {trim['train_time'] + trim_time:.2f}s")


Trim stats: {'orig_undirected': 5278, 'kept_undirected': 4358, 'keep_rate_target': 0.8, 'keep_rate_effective': 0.8256915498294809, 'LCC_percent': 90.76809453471196, 'components': 105} | trim_time: 0.04s
Epoch 001 | Loss 1.942 | Train 0.843 Val 0.534 Test 0.547
Epoch 020 | Loss 0.012 | Train 1.000 Val 0.740 Test 0.758
Epoch 040 | Loss 0.008 | Train 1.000 Val 0.744 Test 0.758
Epoch 060 | Loss 0.013 | Train 1.000 Val 0.750 Test 0.771
Epoch 080 | Loss 0.012 | Train 1.000 Val 0.732 Test 0.772
Epoch 100 | Loss 0.012 | Train 1.000 Val 0.736 Test 0.776
Epoch 120 | Loss 0.010 | Train 1.000 Val 0.744 Test 0.780
Epoch 140 | Loss 0.011 | Train 1.000 Val 0.736 Test 0.777
Epoch 160 | Loss 0.009 | Train 1.000 Val 0.732 Test 0.779
Epoch 180 | Loss 0.012 | Train 1.000 Val 0.742 Test 0.784
Epoch 200 | Loss 0.009 | Train 1.000 Val 0.734 Test 0.770

TRIMMED (GCN on cosine kept graph + guards)
train: 0.9929
val: 0.7700
test: 0.7720
epochs: 200
train_time: 2.4452
avg_epoch_time: 0.0122
best_val: 0.7700
Total time (incl. trimming): 2.49s

In [11]:
# Purpose: print a clear summary comparing baseline and trimmed with percentages
# This cell assumes K_KEEP, HUB_P, base, trim, data, stats, trim_time all exist

def pct(a, b):
    """Percent change from `a` to `b`, relative to `a`; 0.0 when `a` is zero."""
    if a == 0:
        return 0.0
    delta = b - a
    return 100.0 * delta / a

# Side-by-side comparison of the baseline and trimmed runs.
# Reads whatever `trim`, `stats` and `trim_time` currently hold, i.e. the most recent trimmed run.
print("=== SUMMARY ===")
print(f"Keep-Rate target: {int(K_KEEP*100)}%  |  Hub-preserve: {int(HUB_P*100)}%")
# lcc_percent returns (percent, component_count); only the percent is shown here.
print(f"LCC% original: {lcc_percent(data)[0]:.2f}%   LCC% trimmed: {stats['LCC_percent']:.2f}%")
# Δ is trimmed minus baseline, so a negative value means the trimmed graph scored lower.
print(f"Baseline Test Acc: {base['test']:.4f}   Trimmed Test Acc: {trim['test']:.4f}   Δ = {trim['test']-base['test']:.4f}")
print(
    f"Avg epoch time (baseline): {base['avg_epoch_time']:.4f}s   (trimmed): {trim['avg_epoch_time']:.4f}s   "
    f"~{pct(base['avg_epoch_time'], trim['avg_epoch_time']):+.1f}% change"
)
print(f"Total train time (baseline): {base['train_time']:.2f}s   (trimmed): {trim['train_time']:.2f}s")
# The baseline has no trimming cost; the trimmed total adds trim_time.
print(f"Total incl. trimming: baseline={base['train_time']:.2f}s  trimmed={(trim['train_time']+trim_time):.2f}s")


=== SUMMARY ===
Keep-Rate target: 80%  |  Hub-preserve: 10%
LCC% original: 91.77%   LCC% trimmed: 90.77%
Baseline Test Acc: 0.8080   Trimmed Test Acc: 0.7720   Δ = -0.0360
Avg epoch time (baseline): 0.0130s   (trimmed): 0.0122s   ~-6.0% change
Total train time (baseline): 2.61s   (trimmed): 2.45s
Total incl. trimming: baseline=2.61s  trimmed=2.49s


In [12]:
# Purpose: quick reruns by changing METHOD or K_KEEP without touching earlier cells
# Example: switch to Jaccard at 0.90 keep

# NOTE: this cell rebinds METHOD, K_KEEP, new_edge_index, trim_time, stats,
# trimmed, trim_model and trim. Re-running an earlier cell afterwards will
# pick up these values rather than the ones from the configuration cell.
METHOD = "jaccard"
K_KEEP = 0.90
print("Re-running with METHOD:", METHOD, "K_KEEP:", K_KEEP)

# Same seed as the baseline cell.
set_seed(0)
new_edge_index, trim_time, stats = trim_with_guards(
    data,
    method=METHOD,
    keep_rate=K_KEEP,
    hub_percent=HUB_P,
    reconnect_mode=RECONNECT,
    effective_keep_cap=EFF_CAP
)
print("Trim stats:", stats, f"| trim_time: {trim_time:.2f}s")

# Same class as `data`; only edge_index differs from the original graph.
trimmed = type(data)(
    x=data.x, y=data.y, edge_index=new_edge_index,
    train_mask=data.train_mask, val_mask=data.val_mask, test_mask=data.test_mask
).to(device)

# Re-seed right before building the model, mirroring the baseline cell.
set_seed(0)
trim_model = GCN(dataset.num_node_features, HIDDEN, dataset.num_classes, DROPOUT)
trim = train_model(trim_model, trimmed, epochs=EPOCHS, lr=LR, wd=WD, log_every=20)

print("\nTRIMMED (GCN on", METHOD, "kept graph + guards)")
# Floats are printed with four decimals; other values as-is.
for k, v in trim.items():
    if isinstance(v, float):
        print(f"{k}: {v:.4f}")
    else:
        print(f"{k}: {v}")
print(f"Total time (incl. trimming): {trim['train_time'] + trim_time:.2f}s")

# summary again
print("\n=== SUMMARY (rerun) ===")
print(f"Keep-Rate target: {int(K_KEEP*100)}%  |  Hub-preserve: {int(HUB_P*100)}%")
print(f"LCC% original: {lcc_percent(data)[0]:.2f}%   LCC% trimmed: {stats['LCC_percent']:.2f}%")
print(f"Baseline Test Acc: {base['test']:.4f}   Trimmed Test Acc: {trim['test']:.4f}   Δ = {trim['test']-base['test']:.4f}")
# Use the shared pct() helper from the summary cell so both summaries compute
# the change the same way, including its guard against a zero baseline.
print(
    f"Avg epoch time (baseline): {base['avg_epoch_time']:.4f}s   (trimmed): {trim['avg_epoch_time']:.4f}s   "
    f"~{pct(base['avg_epoch_time'], trim['avg_epoch_time']):+.1f}% change"
)
print(f"Total train time (baseline): {base['train_time']:.2f}s   (trimmed): {trim['train_time']:.2f}s")
print(f"Total incl. trimming: baseline={base['train_time']:.2f}s  trimmed={(trim['train_time']+trim_time):.2f}s")


Re-running with METHOD: jaccard K_KEEP: 0.9
Trim stats: {'orig_undirected': 5278, 'kept_undirected': 4870, 'keep_rate_target': 0.9, 'keep_rate_effective': 0.9226979916635089, 'LCC_percent': 91.3589364844904, 'components': 104} | trim_time: 0.03s
Epoch 001 | Loss 1.940 | Train 0.807 Val 0.522 Test 0.526
Epoch 020 | Loss 0.012 | Train 1.000 Val 0.762 Test 0.770
Epoch 040 | Loss 0.013 | Train 1.000 Val 0.756 Test 0.781
Epoch 060 | Loss 0.017 | Train 1.000 Val 0.750 Test 0.789
Epoch 080 | Loss 0.013 | Train 1.000 Val 0.744 Test 0.788
Epoch 100 | Loss 0.013 | Train 1.000 Val 0.752 Test 0.803
Epoch 120 | Loss 0.010 | Train 1.000 Val 0.742 Test 0.800
Epoch 140 | Loss 0.012 | Train 1.000 Val 0.744 Test 0.793
Epoch 160 | Loss 0.011 | Train 1.000 Val 0.740 Test 0.796
Epoch 180 | Loss 0.012 | Train 1.000 Val 0.748 Test 0.800
Epoch 200 | Loss 0.010 | Train 1.000 Val 0.750 Test 0.790

TRIMMED (GCN on jaccard kept graph + guards)
train: 0.9929
val: 0.7740
test: 0.7900
epochs: 200
train_time: 2.5210
