In [9]:
import numpy as np
import heapq
import torch
from torch_geometric.data import Data
from torch_geometric.utils import k_hop_subgraph
from model.score_gnn import ScoreGNN, DotProductPredictor, HadamardMLPPredictor, ConcatMLPPredictor

# Reproducibility: fix the global RNGs of NumPy and PyTorch before any model/sampling code runs.
seed = 2025
np.random.seed(seed)
torch.manual_seed(seed)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [27]:
class SubgraphSampler:
    """Extract score-guided enclosing subgraphs around candidate links.

    For every (src, dst) pair the sampler takes the ``num_hops``-hop subgraph
    around both endpoints, embeds it with ``model``, scores every other node by
    the mean ``predictor`` score of its links to src and to dst, and keeps src,
    dst and the ``k_top`` best-scoring nodes. The resulting ``Data`` objects use
    the GNN embeddings as node features and never contain the src-dst link.
    """

    def __init__(self, model, predictor, k_top=20, num_hops=5, device=None):
        """
        Args:
            model: node encoder, called as ``model(x, edge_index)``; switched to eval mode.
                Assumed to already live on ``device`` (it is not moved here).
            predictor: link scorer, called as ``predictor(node_emb, edge_label_index)``;
                switched to eval mode.
            k_top: number of highest-scoring nodes kept besides src and dst.
            num_hops: radius of the neighbourhood extracted around each link.
            device: device for subgraph tensors. ``None`` selects CUDA when
                available, otherwise CPU.
        """
        self.model = model.eval()
        self.predictor = predictor.eval()
        self.k_top = k_top
        self.num_hops = num_hops
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device

    def process(self, data_dir='./data/Cora/split'):
        """Sample subgraphs for every labelled link of the train/val/test splits and save them.

        Reads ``{split}_data.pt`` from ``data_dir`` and writes
        ``ssseal_{split}_data.pt`` next to it. Each saved list holds all
        positive-link subgraphs (y=1) followed by all negative-link subgraphs (y=0).
        """
        for split in ('train', 'val', 'test'):
            # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which
            # cannot unpickle PyG Data objects; pass weights_only=False there (trusted files only).
            split_data = torch.load(f'{data_dir}/{split}_data.pt').to(self.device)
            data_list = (
                self.sample_all_edges(split_data, split_data.pos_edge_label_index, 1)
                + self.sample_all_edges(split_data, split_data.neg_edge_label_index, 0)
            )
            torch.save(data_list, f'{data_dir}/ssseal_{split}_data.pt')
        print("All processed data have been saved.")

    def sample_all_edges(self, data, edge_label_index, y):
        """Return one sampled subgraph per column of ``edge_label_index`` (shape (2, E)), all labelled ``y``."""
        return [self.sample_subgraph(src, dst, data, y)
                for src, dst in edge_label_index.t().tolist()]

    def sample_subgraph(self, src, dst, data, y):
        """Build the pruned, score-selected subgraph for one candidate link.

        Returns:
            ``Data`` with ``x`` = GNN embeddings of the kept nodes (ordered by their
            subgraph index), ``edge_index`` = edges among kept nodes relabelled to
            0..n-1 without the src-dst link (either direction), and ``y``.
        """
        # k-hop subgraph around both endpoints, relabelled to 0..num_sub_nodes-1;
        # `mapping` holds the new indices of src and dst. Passing num_nodes avoids an
        # out-of-range mask when an endpoint has no incident edges.
        sub_node_index, sub_edge_index, mapping, _ = k_hop_subgraph(
            [src, dst], self.num_hops, data.edge_index,
            relabel_nodes=True, num_nodes=data.num_nodes)
        sub_edge_index = sub_edge_index.to(self.device)
        sub_src, sub_dst = mapping.tolist()

        # Remove the target link *before* embedding: positive training links are
        # present in the message-passing graph, so keeping the edge would leak the
        # label into the node features used as `x` below.
        sub_edge_index = self._remove_link(sub_edge_index, sub_src, sub_dst)

        sub_data = Data(x=data.x[sub_node_index].to(self.device),
                        edge_index=sub_edge_index).to(self.device)
        # Scores for every subgraph node except src and dst.
        scores_dict, sub_node_emb = self.get_subgraph_scores(sub_src, sub_dst, sub_data)
        topk_neighbors = heapq.nlargest(self.k_top, scores_dict, key=scores_dict.get)

        # The set only matters when src == dst; sorting gives a deterministic node order.
        final_nodes = sorted({sub_src, sub_dst, *topk_neighbors})
        final_nodes_tensor = torch.tensor(final_nodes, dtype=torch.long, device=self.device)

        final_edge_index = self._induced_relabel(
            sub_edge_index, final_nodes_tensor, sub_node_index.numel())
        return Data(x=sub_node_emb[final_nodes], edge_index=final_edge_index, y=y)

    @staticmethod
    def _remove_link(edge_index, a, b):
        """Return ``edge_index`` without any a->b or b->a edge."""
        is_link = (((edge_index[0] == a) & (edge_index[1] == b))
                   | ((edge_index[0] == b) & (edge_index[1] == a)))
        return edge_index[:, ~is_link]

    @staticmethod
    def _induced_relabel(edge_index, keep, num_nodes):
        """Keep edges whose endpoints are both in ``keep`` and renumber nodes by their position in ``keep``.

        ``keep`` must contain unique indices in ``[0, num_nodes)``. The result is
        always a long tensor of shape (2, E'), including when E' == 0.
        """
        dev = edge_index.device
        new_id = torch.full((num_nodes,), -1, dtype=torch.long, device=dev)
        new_id[keep.to(dev)] = torch.arange(keep.numel(), device=dev)
        mapped = new_id[edge_index]
        return mapped[:, (mapped >= 0).all(dim=0)]

    def get_subgraph_scores(self, src, dst, data):
        """Score every node of ``data`` except src and dst by its affinity to the link.

        Returns:
            (scores, node_emb): ``scores`` maps each candidate node index to the mean
            of ``predictor(node_emb, [[src], [v]])`` and ``predictor(node_emb, [[dst], [v]])``;
            ``node_emb`` is ``model(data.x, data.edge_index)``.
        """
        with torch.no_grad():
            node_emb = self.model(data.x, data.edge_index)
            dev = node_emb.device
            all_nodes = torch.arange(data.num_nodes, device=dev)
            candidates = all_nodes[(all_nodes != src) & (all_nodes != dst)]
            src_col = torch.full_like(candidates, src)
            dst_col = torch.full_like(candidates, dst)
            scores_1 = self.predictor(node_emb, torch.stack([src_col, candidates], dim=0))
            scores_2 = self.predictor(node_emb, torch.stack([dst_col, candidates], dim=0))
            # One score per candidate; accepts predictors returning (N,) or (N, 1).
            scores = ((scores_1 + scores_2) / 2).reshape(candidates.numel())
        return dict(zip(candidates.tolist(), scores.tolist())), node_emb

In [28]:
# Hyper-parameters used to rebuild the ScoreGNN encoder and its link predictor.
# NOTE(review): 'lr' and 'epochs' are not read by any visible cell of this notebook.
args = {
    'hidden_dim': 256,
    'output_dim': 128,
    'num_layers': 3,
    'dropout': 0.5,
    'lr': 0.01,
    'epochs': 200,
}
# The predictor consumes encoder embeddings, so its input width is tied to
# output_dim instead of repeating the literal.
args['predictor'] = HadamardMLPPredictor(input_dim=args['output_dim']).to(device)

In [29]:
# Only the feature width is needed from the split, so load it onto the CPU regardless
# of which device its tensors were saved from (mirrors map_location on the checkpoint).
train_data = torch.load('./data/Cora/split/train_data.pt', map_location='cpu')
model = ScoreGNN(
    train_data.num_features,
    args['hidden_dim'],
    args['output_dim'],
    args['num_layers'],
    args['dropout'],
).to(device)
predictor = args['predictor']

# Load the trained parameters; the checkpoint is expected to be a dict holding
# 'model' and 'predictor' state dicts (assumed saving layout — confirm against the trainer).
checkpoint = torch.load('./model/scoregnn.pth', map_location=device)
model.load_state_dict(checkpoint['model'])
predictor.load_state_dict(checkpoint['predictor'])

model.eval()
predictor.eval()
print("Model and predictor loaded successfully.")

Model and predictor loaded successfully.


In [30]:
# Build score-guided subgraphs for all splits and write them to disk.
sampler = SubgraphSampler(model, predictor, device=device)
sampler.process()

All processed data have been saved.
