In [1]:
import torch
from torch_geometric.data import Data
from torch_geometric.utils import k_hop_subgraph, to_undirected
import numpy as np

# ===== 1. Original graph and node features =====
# Edges are written as (source, target) pairs and transposed into the
# [2, num_edges] layout; to_undirected then adds the reverse direction
# of every edge.
edge_pairs = [(0, 1), (0, 2), (1, 2), (2, 3), (3, 4), (4, 5)]
edge_index = to_undirected(
    torch.tensor(edge_pairs, dtype=torch.long).t().contiguous()
)
num_nodes = 6

# Randomly initialised node features (stand-in for learned embeddings);
# each node gets a 4-dimensional vector. The seed is fixed so the demo
# is reproducible.
embedding_dim = 4
torch.manual_seed(42)
node_emb = torch.randn(num_nodes, embedding_dim)

data = Data(edge_index=edge_index, x=node_emb)
print(data)

# ===== 2. k-hop subgraph sampling =====
# Extract the num_hops-hop neighbourhood around the (src, dst) pair.
# With relabel_nodes=True the returned edge index uses subgraph-local ids,
# and `mapping` holds the local ids of src and dst (in that order).
src, dst = 2, 3
num_hops = 1
sub_node_index, sub_edge_index, mapping, _ = k_hop_subgraph(
    node_idx=[src, dst],
    num_hops=num_hops,
    edge_index=data.edge_index,
    relabel_nodes=True,
)

print("原图节点编号:", list(range(num_nodes)))
print("子图内节点对应原图编号:", sub_node_index.tolist())
print("src, dst 在子图的编号:", mapping.tolist())

# ===== 3. Subgraph node embeddings =====
# Row i of sub_x is the embedding of original node sub_node_index[i];
# shape: [number of subgraph nodes, embedding_dim].
sub_x = data.x[sub_node_index]
sub_src, sub_dst = mapping.tolist()

# ===== 4. Candidate node list (excluding src and dst) =====
# Candidates are subgraph-local ids of every node other than the two
# endpoints.
candidates = [
    local_id
    for local_id in range(len(sub_node_index))
    if local_id not in (sub_src, sub_dst)
]
print("候选节点（子图编号）:", candidates)

# ===== 5. Placeholder scores (swap in real scoring logic as needed) =====
# Simulated score: first embedding dimension of each candidate plus
# uniform [0, 1) noise from NumPy. The NumPy noise is float64, so the
# resulting scores tensor is float64 as well.
np.random.seed(0)
# Index through an explicit long tensor so an empty candidate list is
# handled the same way as a non-empty one.
candidate_index = torch.tensor(candidates, dtype=torch.long)
scores = sub_x[candidate_index, 0] + torch.tensor(np.random.rand(len(candidates)))
print("候选节点分数:", scores.tolist())

# ===== 6. Select top-k =====
top_k = 2
# torch.topk raises if k exceeds the number of elements, which happens
# whenever the sampled neighbourhood has fewer than top_k candidates;
# clamp k to what is actually available.
num_selected = min(top_k, len(candidates))
topk_indices = torch.topk(scores, num_selected).indices.tolist()
# topk_indices are positions within `candidates`; map them back to
# subgraph-local node ids.
topk_neighbors = [candidates[i] for i in topk_indices]
print("Top-k选中节点（子图编号）:", topk_neighbors)

# ===== 7. Collect the final nodes and renumber them =====
# Union of the two endpoints and the selected neighbours, de-duplicated
# and sorted by subgraph-local id.
final_nodes = sorted({sub_src, sub_dst, *topk_neighbors})
print("最终节点（子图编号）:", final_nodes)

# Map each old (subgraph-local) id to its new consecutive id, i.e. its
# position in the sorted final_nodes list.
node_id_map = dict(zip(final_nodes, range(len(final_nodes))))

# ===== 8. Final embeddings =====
# Rows are ordered like final_nodes, so row j belongs to new id j.
final_x = sub_x[final_nodes]
print("最终节点 embedding:")
print(final_x)

# ===== 9. Keep only edges between final nodes, then renumber =====
# An edge survives only if both of its endpoints are in final_nodes.
kept_nodes = torch.tensor(
    final_nodes, dtype=torch.long, device=sub_edge_index.device
)
mask = torch.isin(sub_edge_index[0], kept_nodes) & \
       torch.isin(sub_edge_index[1], kept_nodes)

# Vectorised relabelling via a lookup table indexed by subgraph-local id.
# Entries for dropped nodes stay -1 but are never read, because the mask
# only keeps edges whose endpoints are both in kept_nodes. Unlike building
# tensors from Python lists, this keeps dtype torch.long even when no edge
# survives (torch.tensor([]) would default to float32).
relabel = torch.full(
    (len(sub_node_index),), -1, dtype=torch.long, device=sub_edge_index.device
)
relabel[kept_nodes] = torch.arange(
    len(final_nodes), dtype=torch.long, device=sub_edge_index.device
)
final_edge_index = relabel[sub_edge_index[:, mask]]
print("最终子图边:", final_edge_index.tolist())

# New ids of the original src / dst pair inside the final subgraph.
src_new = node_id_map[sub_src]
dst_new = node_id_map[sub_dst]
print(src_new, dst_new)

# ===== 10. Final Data object (can be fed directly into a GNN) =====
final_data = Data(x=final_x, edge_index=final_edge_index)
print(final_data)


Data(x=[6, 4], edge_index=[2, 12])
原图节点编号: [0, 1, 2, 3, 4, 5]
子图内节点对应原图编号: [0, 1, 2, 3, 4]
src, dst 在子图的编号: [2, 3]
候选节点（子图编号）: [0, 1, 4]
候选节点分数: [2.4757287918988213, 1.393607764275862, 0.2930366013791085]
Top-k选中节点（子图编号）: [0, 1]
最终节点（子图编号）: [0, 1, 2, 3]
最终节点 embedding:
tensor([[ 1.9269,  1.4873,  0.9007, -2.1055],
        [ 0.6784, -1.2345, -0.0431, -1.6047],
        [ 0.3559, -0.6866, -0.4934,  0.2415],
        [-1.1109,  0.0915, -2.3169, -0.2168]])
最终子图边: [[0, 0, 1, 1, 2, 2, 2, 3], [1, 2, 0, 2, 0, 1, 3, 2]]
2 3
Data(x=[4, 4], edge_index=[2, 8])
