In [11]:
import torch
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric.utils import k_hop_subgraph
from torch_geometric.data import Data

# Process-wide torch.load shim: any caller that does not pass ``weights_only``
# gets weights_only=False (full unpickling) instead of the library default.
# NOTE(review): presumably needed because newer PyTorch releases default
# torch.load to weights_only=True, which can reject pickled dataset files.
# This is NOT a security-safe patch -- full unpickling can execute arbitrary
# code -- so only load trusted files while it is active.
# The __name__ guard keeps the cell idempotent: re-running it does not wrap an
# already-wrapped torch.load a second time.
if getattr(torch.load, "__name__", "") != "_torch_load_allow_pickle":
    _torch_load_orig = torch.load

    def _torch_load_allow_pickle(*args, **kwargs):
        """Forward to the original torch.load, defaulting ``weights_only`` to False."""
        if "weights_only" not in kwargs:
            kwargs["weights_only"] = False
        return _torch_load_orig(*args, **kwargs)

    torch.load = _torch_load_allow_pickle

# Load ogbn-proteins through OGB's PyG wrapper (presumably downloading and
# processing it on first use), then take OGB's train/valid/test node split and
# the graph at index 0, which is the one used throughout this notebook.
dataset = PygNodePropPredDataset(name="ogbn-proteins")
split_idx = dataset.get_idx_split()
data: Data = dataset[0]

def make_khop_subgraph(data, split_idx, target_nodes=50_000, max_k=4, seed_idx=None,
                       generator=None):
    """Extract a k-hop neighbourhood around seed node(s) and remap the splits onto it.

    The neighbourhood is grown one hop at a time and growth stops as soon as:
      * the subgraph has at least ``target_nodes`` nodes, or
      * ``max_k`` hops have been taken, or
      * a hop adds no new nodes (the seeds' component is exhausted, so further
        hops could not change the result).
    Growth happens in whole hops, so on dense graphs the result can be much
    larger than ``target_nodes``: it is a stopping threshold, not a cap.

    Args:
        data: PyG ``Data`` with ``edge_index`` and optionally ``x``, ``y`` and
            ``edge_attr``. Only these attributes are carried over; any other
            attributes on ``data`` are not copied to the subgraph.
        split_idx: dict with "train", "valid" and "test" index tensors into
            ``data``'s nodes.
        target_nodes: stop growing once the subgraph has at least this many nodes.
        max_k: maximum number of hops; must be >= 1.
        seed_idx: seed node id(s) -- an int, a 0-d tensor, or a 1-D
            sequence/tensor. When None, one training node is drawn at random.
        generator: optional ``torch.Generator`` used for that random draw, so the
            chosen seed can be reproduced.

    Returns:
        ``(sub_data, split_sub)``: the node-induced subgraph as ``Data`` (with
        ``original_n_id[i]`` = original id of new node ``i``), and the splits
        restricted to the subgraph, expressed in new node ids.

    Raises:
        ValueError: if ``max_k < 1``, or if ``seed_idx`` is None and the training
            split is empty.
    """
    if max_k < 1:
        raise ValueError(f"max_k must be >= 1, got {max_k}")

    edge_index = data.edge_index
    device = edge_index.device
    num_nodes = data.num_nodes

    if seed_idx is None:
        train_idx = split_idx["train"]
        if train_idx.numel() == 0:
            raise ValueError("split_idx['train'] is empty; pass seed_idx explicitly")
        pick = torch.randint(0, train_idx.numel(), (1,), generator=generator).item()
        seed_idx = train_idx[pick]

    # Accept an int, 0-d tensor or 1-D collection of seeds. Deduplicate so the
    # "no growth" size comparison below compares sets of unique nodes.
    sub_nodes = torch.as_tensor(seed_idx, dtype=torch.long).reshape(-1).unique().to(device)

    # The 1-hop expansion of the current (k-1)-hop node set is exactly the k-hop
    # set of the seeds, so one edge pass per hop suffices instead of recomputing
    # all k hops from scratch on every iteration (O(k) vs O(k^2) edge passes).
    for _hop in range(max_k):
        prev_size = sub_nodes.numel()
        sub_nodes, sub_edge_index, _, edge_mask = k_hop_subgraph(
            node_idx=sub_nodes,
            num_hops=1,
            edge_index=edge_index,
            relabel_nodes=True,
            num_nodes=num_nodes,
        )
        if sub_nodes.numel() >= target_nodes:
            break
        if sub_nodes.numel() == prev_size:
            # The new set is a superset of the old one, so equal size means no
            # new nodes: every further hop would return this same subgraph.
            break

    x_sub = getattr(data, "x", None)
    if x_sub is not None:
        x_sub = x_sub[sub_nodes]

    y_sub = getattr(data, "y", None)
    if y_sub is not None:
        y_sub = y_sub[sub_nodes]

    edge_attr = getattr(data, "edge_attr", None)
    if edge_attr is not None:
        # edge_mask selects, in original edge order, the edges kept in sub_edge_index.
        edge_attr = edge_attr[edge_mask]

    sub_data = Data(
        x=x_sub,
        edge_index=sub_edge_index,
        edge_attr=edge_attr,
        y=y_sub,
        num_nodes=sub_nodes.numel(),
    )
    sub_data.original_n_id = sub_nodes

    # remap[original_id] -> new id, or -1 for nodes outside the subgraph. This
    # matches the relabelling applied to sub_edge_index (position in sub_nodes).
    remap = torch.full((num_nodes,), -1, dtype=torch.long, device=device)
    remap[sub_nodes] = torch.arange(sub_nodes.numel(), device=device)
    split_sub = {}
    for part in ("train", "valid", "test"):
        kept = remap[split_idx[part].to(device)]
        split_sub[part] = kept[kept >= 0]

    return sub_data, split_sub

# Extraction parameters. Seeding torch's global RNG makes the randomly drawn
# training seed node -- and therefore the saved subgraph -- reproducible.
SEED = 0
TARGET_NODES = 1000  # stop growing once the neighbourhood has at least this many nodes
MAX_K = 400          # upper bound on hops; extraction stops earlier once TARGET_NODES is reached
OUT_PATH = "ogbn_proteins_subgraph.pt"

torch.manual_seed(SEED)
sub_data, split_sub = make_khop_subgraph(data, split_idx, target_nodes=TARGET_NODES, max_k=MAX_K)

# Sanity checks before persisting: relabelled edges and split indices must
# point inside the subgraph, and edge features must align with edges.
assert sub_data.edge_index.numel() == 0 or int(sub_data.edge_index.max()) < sub_data.num_nodes
assert all(int(idx.max()) < sub_data.num_nodes for idx in split_sub.values() if idx.numel())
assert sub_data.edge_attr is None or sub_data.edge_attr.size(0) == sub_data.edge_index.size(1)

bundle = {"data": sub_data, "split_idx": split_sub,
          "meta": {"source": "ogbn-proteins",
                   "note": "Connected k-hop subgraph around a training node; indices remapped",
                   "rng_seed": SEED, "target_nodes": TARGET_NODES, "max_k": MAX_K}}
torch.save(bundle, OUT_PATH)
# The edge count is edge_index columns, i.e. stored (directed) edges; for this
# symmetric dataset that is twice the undirected edge count NetworkX reports.
print("Saved:", sub_data.num_nodes, "nodes,", sub_data.edge_index.size(1), "edges")


Saved: 10824 nodes, 5754852 edges


In [12]:
import networkx as nx
from torch_geometric.utils import to_networkx

# Convert to an undirected NetworkX graph for inspection. Each undirected edge
# appears once in G, so for this symmetric dataset G has half as many edges as
# sub_data.edge_index has columns.
G = to_networkx(sub_data, to_undirected=True)

print(G)  # basic info
n_components = nx.number_connected_components(G)
print(f'Number of connected components:{n_components}')
# The bundle's meta describes the subgraph as connected: a k-hop neighbourhood
# of a single seed node must be one component, so fail loudly if it is not.
assert n_components == 1, f"expected a connected subgraph, got {n_components} components"

Graph with 10824 nodes and 2877426 edges
Number of connected components:1
