In [1]:
import os, time
import numpy as np
import gc, random
import torch
import networkx as nx
import logging
from tqdm import tqdm
import glob
import psutil
import out_manager as om
import torch.nn.functional as F
from itertools import chain
from config import Config
from torch_geometric.data import Data
from torch_geometric.utils import k_hop_subgraph, to_scipy_sparse_matrix
from model.score_gnn import scoregnn_dict
from scipy.sparse.csgraph import shortest_path

In [2]:
config = Config()
ModelClass = scoregnn_dict[config.scoregnn.gnn_type]
out_dir = om.get_existing_out_dir(config)
om.setup_logging(os.path.join(out_dir, "sample_log.txt"))
seed = config.seed
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
device = config.device

In [3]:
class SubgraphBatchSampler:
    def __init__(self, model, predictor, k_min, num_hops, save_dir, alpha = 40, beta = 20, gamma = 2, device = device):
        self.model = model.eval()
        self.predictor = predictor.eval()
        self.k_min = k_min
        self.num_hops = num_hops
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.save_dir = save_dir
        self.device = device
        super().__init__()

        # 注册所有可选打分函数
        self.score_fn_dict = {
            "gnn": self.get_subgraph_scores_gnn,
            "pagerank": self.get_subgraph_scores_pagerank,
            "adamic-adar": self.get_subgraph_scores_adamicadar,
        }

    # def get_max_z(self, data, edge_label_index, y, batch_size=1000):
    #     random.seed(2025)
    #     num_samples = edge_label_index.size(1)
    #     # 使用tqdm添加进度条
    #     for i in tqdm(range(0, num_samples, batch_size), desc="扫描 max_z", unit="batch"):
    #         batch_idx = edge_label_index[:, i:i+batch_size]
    #         batch_data_list = self.sample_all_edges(data, batch_idx, y)
    #         for batch_data in batch_data_list:
    #             zmax = batch_data.z.max().item()
    #             if zmax > self._max_z:
    #                 print(f"⚠️ 更新 _max_z: {self._max_z} -> {zmax} (batch {i//batch_size}, y={y})")
    #                 self._max_z = zmax
    #         del batch_data_list
    #         gc.collect()


    def get_max_z(self, data, edge_label_index, y, batch_size=1000):
        num_samples = edge_label_index.size(1)
        # 使用tqdm添加进度条
        for i in tqdm(range(0, num_samples, batch_size), desc="扫描 max_z", unit="batch"):
            batch_idx = edge_label_index[:, i:i+batch_size]
            batch_data_list = self.sample_all_edges(data, batch_idx, y)
            del batch_data_list
            gc.collect()
    
    def save_batches(self, data, edge_label_index, y, out_prefix, max_z, batch_size=100):
        random.seed(2025)
        os.makedirs(os.path.dirname(out_prefix), exist_ok=True)
        num_samples = edge_label_index.size(1)
        idx = 0
        # 使用tqdm添加进度条
        for i in tqdm(range(0, num_samples, batch_size), desc=f"保存 {out_prefix} 分批文件", unit="batch"):
            batch_idx = edge_label_index[:, i:i+batch_size]
            batch_data_list = self.sample_all_edges(data, batch_idx, y)
            for batch_data in batch_data_list:
                batch_data.x = F.one_hot(batch_data.z, max_z + 1).to(torch.float)
                torch.save(batch_data_list, f"{out_prefix}_batch{idx}.pt")
            del batch_data_list
            gc.collect()
            idx += 1

    def merge_batches(self, batch_prefix, out_file):
        batch_files = sorted(glob.glob(f"{batch_prefix}_batch*.pt"),
                            key=lambda x: int(x.split('_batch')[-1].split('.pt')[0]))
        all_data = []
        for batch_file in batch_files:
            data_list = torch.load(batch_file, map_location='cpu')  # 👈 强制放到 CPU
            all_data.extend(data_list)
            print(f"合并了 {batch_file}，当前总量：{len(all_data)}")
            del data_list
            gc.collect()
            # CPU 上不用显存释放了
        torch.save(all_data, out_file)
        print(f"保存到 {out_file}，总计 {len(all_data)} 条数据")
        del all_data
        gc.collect()

    def merge_pos_neg(self, pos_file, neg_file, out_file):
        pos_data = torch.load(pos_file, map_location='cpu')  # 👈 放在 CPU
        neg_data = torch.load(neg_file, map_location='cpu')  # 👈 放在 CPU
        all_data = pos_data + neg_data
        torch.save(all_data, out_file)
        print(f"最终合并 {out_file}，总计 {len(all_data)} 条（正例 {len(pos_data)}，负例 {len(neg_data)}）")
        del pos_data, neg_data, all_data
        gc.collect()

    
    def process(self):

        seed = 2025
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

        train_data = torch.load(f'./data/{config.dataset}/split/train_data.pt')
        val_data = torch.load(f'./data/{config.dataset}/split/val_data.pt')
        test_data = torch.load(f'./data/{config.dataset}/split/test_data.pt')

        train_data = train_data.to(self.device)
        val_data = val_data.to(self.device)
        test_data = test_data.to(self.device)

        #第一次扫描统计maxz
        self._max_z = 0
        self.get_max_z(train_data, train_data.pos_edge_label_index, 1),
        print(self._max_z)
        self.get_max_z(train_data, train_data.neg_edge_label_index, 0)
        print(self._max_z)
        self.get_max_z(val_data, val_data.pos_edge_label_index, 1),
        print(self._max_z)
        self.get_max_z(val_data, val_data.neg_edge_label_index, 0)
        self.get_max_z(test_data, test_data.pos_edge_label_index, 1),
        self.get_max_z(test_data, test_data.neg_edge_label_index, 0)
        print(self._max_z)

        # 2. 分批次one-hot和保存，绝不汇总到内存
        save_dir = self.save_dir
        train_pos_path = os.path.join(save_dir, "SSSEAL_train_pos")
        train_neg_path = os.path.join(save_dir, "SSSEAL_train_neg")
        val_pos_path = os.path.join(save_dir, "SSSEAL_val_pos")
        val_neg_path = os.path.join(save_dir, "SSSEAL_val_neg")
        test_pos_path = os.path.join(save_dir, "SSSEAL_test_pos")
        test_neg_path = os.path.join(save_dir, "SSSEAL_test_neg")

        print("保存 train 分批文件")
        self.save_batches(train_data, train_data.pos_edge_label_index, 1, train_pos_path, self._max_z)
        self.save_batches(train_data, train_data.neg_edge_label_index, 0, train_neg_path, self._max_z)

        print("保存 val 分批文件")
        self.save_batches(val_data, val_data.pos_edge_label_index, 1, val_pos_path, self._max_z)
        self.save_batches(val_data, val_data.neg_edge_label_index, 0, val_neg_path, self._max_z)

        print("保存 test 分批文件")
        self.save_batches(test_data, test_data.pos_edge_label_index, 1, test_pos_path, self._max_z)
        self.save_batches(test_data, test_data.neg_edge_label_index, 0, test_neg_path, self._max_z)

        print("所有分批处理和保存已完成！🚀")

        del train_data, val_data, test_data
        gc.collect()
    
    def cancat_pos(self):
        split_dir = self.save_dir
        for prefix in ["train", "val", "test"]:
            pos_prefix = os.path.join(split_dir, f"SSSEAL_{prefix}_pos")
            pos_data_list = pos_prefix + "_data_list.pt"
            merged_data_list = pos_data_list  # 直接用 pos_data_list 作为结果

            print(f"\n--- 合并 {prefix} pos batch ---")
            self.merge_batches(pos_prefix, merged_data_list)
            gc.collect()
            torch.cuda.empty_cache()

    def cancat_neg(self):
        split_dir = self.save_dir
        for prefix in ["train", "val", "test"]:
            neg_prefix = os.path.join(split_dir, f"SSSEAL_{prefix}_neg")
            neg_data_list = neg_prefix + "_data_list.pt"
            merged_data_list = neg_data_list  # 直接用 neg_data_list 作为结果

            print(f"\n--- 合并 {prefix} neg batch ---")
            self.merge_batches(neg_prefix, merged_data_list)
            gc.collect()
            torch.cuda.empty_cache()

    def cancat_pos_neg(self):
        split_dir = self.save_dir
        for prefix in ["train", "val", "test"]:
            pos_data_list = os.path.join(split_dir, f"SSSEAL_{prefix}_pos_data_list.pt")
            neg_data_list = os.path.join(split_dir, f"SSSEAL_{prefix}_neg_data_list.pt")
            merged_data_list = os.path.join(f"./data/{config.dataset}/split/ssseal_{prefix}_data_k{self.k_min}_h{self.num_hops}_{config.version}.pt")

            print(f"\n--- 合并 {prefix} pos+neg 为总 data_list ---")
            self.merge_pos_neg(pos_data_list, neg_data_list, merged_data_list)
            gc.collect()
            torch.cuda.empty_cache()

        print("所有 pos+neg 合并已完成！🚀")

    def cancat(self):
        split_dir = self.save_dir
        for prefix in ["train", "val", "test"]:
            pos_prefix = os.path.join(split_dir, f"SSSEAL_{prefix}_pos")
            neg_prefix = os.path.join(split_dir, f"SSSEAL_{prefix}_neg")
            pos_data_list = pos_prefix + "_data_list.pt"
            neg_data_list = neg_prefix + "_data_list.pt"
            merged_data_list = os.path.join(split_dir, f"ssseal_{prefix}_data_k{self.k_min}_h{self.num_hops}_{config.version}.pt")

            print(f"\n--- 合并 {prefix} pos batch ---")
            self.merge_batches(pos_prefix, pos_data_list)
            gc.collect()
            torch.cuda.empty_cache()  # 👈 pos 合并后清理显存

            print(f"--- 合并 {prefix} neg batch ---")
            self.merge_batches(neg_prefix, neg_data_list)
            gc.collect()
            torch.cuda.empty_cache()  # 👈 neg 合并后清理显存

            print(f"--- 合并 {prefix} pos+neg 为总 data_list ---")
            self.merge_pos_neg(pos_data_list, neg_data_list, merged_data_list)
            gc.collect()
            torch.cuda.empty_cache()  # 👈 pos+neg 合并完再清理

        print("所有分批处理、合并已完成！🚀")

        # --------- 自动删除所有 batch 文件 ----------
        pattern = os.path.join(self.save_dir, "SSSEAL_*_batch*.pt")
        batch_files = glob.glob(pattern)
        for file in batch_files:
            try:
                os.remove(file)
                print(f"已删除 {file}")
            except Exception as e:
                print(f"删除 {file} 失败：{e}")
        
        target_files = [
            "SSSEAL_test_neg_data_list.pt",
            "SSSEAL_test_pos_data_list.pt",
            "SSSEAL_val_neg_data_list.pt",
            "SSSEAL_val_pos_data_list.pt",
            "SSSEAL_train_neg_data_list.pt",
            "SSSEAL_train_pos_data_list.pt"
        ]

        for filename in target_files:
            file_path = os.path.join(self.save_dir, filename)
            if os.path.exists(file_path):
                try:
                    os.remove(file_path)
                    print(f"已删除 {file_path}")
                except Exception as e:
                    print(f"删除 {file_path} 失败：{e}")

        print('所有数据已保存并清理临时 batch 文件')

    
    def sample_all_edges(self, data, edge_label_index, y):
        data_list = []
        for src, dst in edge_label_index.t().tolist():
            data_list.append(self.sample_subgraph(src, dst, data, y))
        return data_list
    
    def sample_all_edges_in_batches(self, data, edge_label_index, y, batch_size=256):
        data_list = []
        num_edges = edge_label_index.size(1)
        for start in range(0, num_edges, batch_size):
            end = min(start + batch_size, num_edges)
            batch_edges = edge_label_index[:, start:end]
            for src, dst in batch_edges.t().tolist():
                subgraph = self.sample_subgraph(src, dst, data, y)
                data_list.append(subgraph.cpu())
                del subgraph
            gc.collect() 
            torch.cuda.empty_cache()  # 清理缓存
        return data_list

    
    def sample_subgraph(self, src, dst, data, y):
         
        # # 采k-hop子图，得到子图节点的新编号、子图内边、mapping
        sub_node_index, sub_edge_index, mapping, _ = k_hop_subgraph(
            [src, dst], self.num_hops, data.edge_index, relabel_nodes=True)
        
        if len(sub_node_index) <= self.k_min:
            sub_node_index, sub_edge_index, mapping, _ = k_hop_subgraph(
            [src, dst], self.num_hops + self.gamma, data.edge_index, relabel_nodes=True)

            sub_x = data.x[sub_node_index]
            sub_src, sub_dst = mapping.tolist()

            final_nodes = list(range(len(sub_x)))

        else:
            #子图全部节点初始特征向量(sub.num_of_node, data.x.size(1))
            # sub_x = data.x[sub_node_index].to(self.device)
            sub_edge_index = sub_edge_index.to(self.device)
            sub_x = data.x[sub_node_index]
            sub_src, sub_dst = mapping.tolist()
            
            #构建子图的data
            sub_data = Data(x = sub_x,edge_index = sub_edge_index).to(self.device)
            #获取子图的所有节点分数字典（不包含src和dst）
            # scores_dist, sub_node_emb = self.get_subgraph_scores(sub_src, sub_dst, sub_data)
            
            # ====== 改进点：用torch.topk替代heapq.nlargest，速度更快 =======

            # # 分数从高到低取前top_k
            # scores_dist = self.get_subgraph_scores(sub_src, sub_dst, sub_data)
            # topk_neighbors = heapq.nlargest(self.k_min, scores_dist, key=scores_dist.get)

            candidates_tensor, scores = self.get_subgraph_scores(sub_src, sub_dst, sub_data)

            k = max(1, int(len(scores) * self.alpha // 100))
            _, topk_indices = torch.topk(scores, min(k, scores.size(0)))
            topk_neighbors = candidates_tensor[topk_indices].tolist()
            # ===========================================================
            # 从剩下的候选节点中随机选择20%的节点
            remaining_candidates = [i for i in candidates_tensor.tolist() if i not in topk_neighbors]
            num_random_select = min(int(len(scores) * self.beta // 100), len(remaining_candidates))
            random_neighbors = random.sample(remaining_candidates, num_random_select)

            # 合并前40%和随机选择的节点，得到最终的topk_neighbors
            final_neighbors = topk_neighbors + random_neighbors
            # 源点和目标点在子图的编号
            final_nodes = [sub_src, sub_dst] + final_neighbors
        
        final_nodes = list(set(final_nodes))  # 防止重复
        final_nodes.sort()  # 方便后面重新映射

        # 旧编号到新编号的映射
        node_id_map = {old: new for new, old in enumerate(final_nodes)}

        # 新的x
        final_x = sub_x[final_nodes]

        # mask边：只保留两个端点都在final_nodes内的边
        final_nodes_tensor = torch.tensor(final_nodes, device=self.device)
        mask = torch.isin(sub_edge_index[0], final_nodes_tensor) & \
            torch.isin(sub_edge_index[1], final_nodes_tensor)
        final_edge_index = sub_edge_index[:, mask]

        # 重新编号edge_index
        final_edge_index = torch.stack([
            torch.tensor([node_id_map[int(i)] for i in final_edge_index[0].tolist()], device=self.device),
            torch.tensor([node_id_map[int(i)] for i in final_edge_index[1].tolist()], device=self.device)
        ], dim=0)

        #去除 src-dst 之间的边（无向图记得两个方向都删！）
        src_new = node_id_map[sub_src]
        dst_new = node_id_map[sub_dst]
        mask1 = (final_edge_index[0] != src_new) | (final_edge_index[1] != dst_new)
        mask2 = (final_edge_index[0] != dst_new) | (final_edge_index[1] != src_new)
        mask = mask1 & mask2
        final_edge_index = final_edge_index[:, mask]

        z = self.drnl_node_labeling(final_edge_index, src_new, dst_new, num_nodes = len(final_nodes))

        final_sub_data = Data(x = final_x, z = z, edge_index = final_edge_index, y = y)
        final_sub_data = final_sub_data.to(next(self.model.parameters()).device)
        return final_sub_data
    
    def get_subgraph_scores(self, src, dst, data):
        fn = self.score_fn_dict.get(config.scoresampler.score_fn, self.get_subgraph_scores_gnn)  # 默认GNN
        return fn(src, dst, data)
    
    def get_subgraph_scores_gnn(self, src, dst, data):
        with torch.no_grad():
            device = next(self.model.parameters()).device
            data = data.to(device)
            node_emb =self.model(data.x, data.edge_index)

            #构建所有src和dst分别到子图所有节点的组合(不包含互相)
            candidates = [i for i in range(data.num_nodes) if i != src and i != dst]
            candidates_tensor = torch.tensor(candidates, device=self.device, dtype=torch.long)

            src_1 = torch.tensor([src] * len(candidates), dtype=torch.long, device=self.device)
            dst_1 = torch.tensor(candidates, dtype=torch.long, device=self.device)
            src_2 = torch.tensor([dst] * len(candidates), dtype=torch.long, device=self.device)
            dst_2 = torch.tensor(candidates, dtype=torch.long, device=self.device)
            
            edge_label_index_1 = torch.stack([src_1, dst_1], dim=0)
            edge_label_index_2 = torch.stack([src_2, dst_2], dim=0)

            scores_1 = self.predictor(node_emb, edge_label_index_1)
            scores_2 = self.predictor(node_emb, edge_label_index_2)
            scores = (scores_1 + scores_2) / 2
            # scores_dist = {i: float(score) for i, score in zip(candidates, scores)}
        return candidates_tensor, scores
    
    def get_subgraph_scores_adamicadar(self, src, dst, data):
        edge_index = data.edge_index.cpu().numpy()
        G = nx.Graph()
        G.add_edges_from(edge_index.T.tolist())
        G.add_nodes_from(range(data.num_nodes))  # 保证节点都在

        # 只考虑有边的节点作为候选
        candidates = [i for i in range(data.num_nodes) if i != src and i != dst and G.degree(i) > 0]
        candidates_tensor = torch.tensor(candidates,device=self.device, dtype=torch.long)

        # 如果src或dst本身也是孤立节点，也跳过/直接返回空
        if G.degree(src) == 0 or G.degree(dst) == 0:
            return candidates_tensor, torch.zeros_like(candidates_tensor, dtype=torch.float)

        aa_src = {(u, v): s for u, v, s in nx.adamic_adar_index(G, [(src, i) for i in candidates])}
        aa_dst = {(u, v): s for u, v, s in nx.adamic_adar_index(G, [(dst, i) for i in candidates])}

        scores = []
        for i in candidates:
            s1 = aa_src.get((src, i), 0.0)
            s2 = aa_dst.get((dst, i), 0.0)
            s = (s1 + s2) / 2
            scores.append(s)

        scores = torch.tensor(scores, device=self.device, dtype=torch.float)
        return candidates_tensor, scores
        
    def get_subgraph_scores_pagerank(self, src, dst, data):
        # 1. edge_index转成networkx图，节点编号是局部编号
        edge_index = data.edge_index.cpu().numpy()
        G = nx.Graph()
        G.add_edges_from(edge_index.T.tolist())
        G.add_nodes_from(range(data.num_nodes))  # 确保所有节点都在G中

        # 2. 只考虑有边的节点
        candidates = [i for i in range(data.num_nodes) if i != src and i != dst and G.degree(i) > 0]
        candidates_tensor = torch.tensor(candidates, device=self.device, dtype=torch.long)

        # 如果src或dst本身是孤立节点，直接返回零分
        if G.degree(src) == 0 or G.degree(dst) == 0:
            return candidates_tensor, torch.zeros(len(candidates), device=self.device, dtype=torch.float)

        # 3. Personalized PageRank（以src和dst为个性化起点，各算一次）
        personalization_src = {n: 0 for n in G.nodes}
        personalization_src[src] = 1
        pr_src = nx.pagerank(G, personalization=personalization_src)

        personalization_dst = {n: 0 for n in G.nodes}
        personalization_dst[dst] = 1
        pr_dst = nx.pagerank(G, personalization=personalization_dst)

        # 4. 对每个候选节点，分别查src和dst个性化pagerank的分数，做平均
        scores = []
        for i in candidates:
            s = (pr_src.get(i, 0.0) + pr_dst.get(i, 0.0)) / 2
            scores.append(s)

        # 转成torch张量
        scores = torch.tensor(scores, device=self.device, dtype=torch.float)
        return candidates_tensor, scores

    def drnl_node_labeling(self, edge_index, src, dst, num_nodes=None):
        # Double-radius node labeling (DRNL).
        src, dst = (dst, src) if src > dst else (src, dst)
        adj = to_scipy_sparse_matrix(edge_index, num_nodes=num_nodes).tocsr()

        idx = list(range(src)) + list(range(src + 1, adj.shape[0]))
        adj_wo_src = adj[idx, :][:, idx]

        idx = list(range(dst)) + list(range(dst + 1, adj.shape[0]))
        adj_wo_dst = adj[idx, :][:, idx]

        dist2src = shortest_path(adj_wo_dst, directed=False, unweighted=True,
                                 indices=src)
        dist2src = np.insert(dist2src, dst, 0, axis=0)
        dist2src = torch.from_numpy(dist2src)

        dist2dst = shortest_path(adj_wo_src, directed=False, unweighted=True,
                                 indices=dst - 1)
        dist2dst = np.insert(dist2dst, src, 0, axis=0)
        dist2dst = torch.from_numpy(dist2dst)

        dist = dist2src + dist2dst
        dist_over_2, dist_mod_2 = dist // 2, dist % 2

        z = 1 + torch.min(dist2src, dist2dst)
        z += dist_over_2 * (dist_over_2 + dist_mod_2 - 1)
        z[src] = 1.
        z[dst] = 1.
        z[torch.isnan(z)] = 0.

        self._max_z = max(int(z.max()), self._max_z)

        return z.to(torch.long)

In [4]:
class SubgraphSampler:
    def __init__(self, model, predictor, k_min, num_hops, alpha = 40, beta = 20, gamma = 2, device = device):
        self.model = model.eval()
        self.predictor = predictor.eval()
        self.k_min = k_min
        self.num_hops = num_hops
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.device = device

        # 注册所有可选打分函数
        self.score_fn_dict = {
            "gnn": self.get_subgraph_scores_gnn,
            "pagerank": self.get_subgraph_scores_pagerank,
            "adamic-adar": self.get_subgraph_scores_adamicadar,
        }

    
    def process(self):
        train_data = torch.load(f'./data/{config.dataset}/split/train_data.pt')
        val_data = torch.load(f'./data/{config.dataset}/split/val_data.pt')
        test_data = torch.load(f'./data/{config.dataset}/split/test_data.pt')

        train_data = train_data.to(self.device)
        val_data = val_data.to(self.device)
        test_data = test_data.to(self.device)

        self._max_z = 0

        train_pos_data_list = self.sample_all_edges_in_batches(
        train_data, train_data.pos_edge_label_index, 1)
        train_neg_data_list = self.sample_all_edges_in_batches(
        train_data, train_data.neg_edge_label_index, 0)

        val_pos_data_list = self.sample_all_edges_in_batches(
        val_data, val_data.pos_edge_label_index, 1)
        val_neg_data_list = self.sample_all_edges_in_batches(
        val_data, val_data.neg_edge_label_index, 0)

        test_pos_data_list = self.sample_all_edges_in_batches(
        test_data, test_data.pos_edge_label_index, 1)
        test_neg_data_list = self.sample_all_edges_in_batches(
        test_data, test_data.neg_edge_label_index, 0)

        for data in chain(train_pos_data_list, train_neg_data_list,
                          val_pos_data_list, val_neg_data_list,
                          test_pos_data_list, test_neg_data_list):
            # We solely learn links from structure, dropping any node features:
            data.x = F.one_hot(data.z, self._max_z + 1).to(torch.float)

        train_data_list = train_pos_data_list + train_neg_data_list
        val_data_list = val_pos_data_list + val_neg_data_list
        test_data_list = test_pos_data_list + test_neg_data_list

        torch.save(train_data_list, f'./data/{config.dataset}/split/ssseal_train_data_k{self.k_min}_h{self.num_hops}_{config.version}.pt')
        torch.save(val_data_list, f'./data/{config.dataset}/split/ssseal_val_data_k{self.k_min}_h{self.num_hops}_{config.version}.pt')
        torch.save(test_data_list, f'./data/{config.dataset}/split/ssseal_test_data_k{self.k_min}_h{self.num_hops}_{config.version}.pt')
        print("All processed data have been saved.")
    
    def sample_all_edges(self, data, edge_label_index, y):
        data_list = []
        for src, dst in edge_label_index.t().tolist():
            data_list.append(self.sample_subgraph(src, dst, data, y))
        return data_list
    
    def sample_all_edges_in_batches(self, data, edge_label_index, y, batch_size=256):
        data_list = []
        num_edges = edge_label_index.size(1)
        for start in range(0, num_edges, batch_size):
            end = min(start + batch_size, num_edges)
            batch_edges = edge_label_index[:, start:end]
            for src, dst in batch_edges.t().tolist():
                subgraph = self.sample_subgraph(src, dst, data, y)
                data_list.append(subgraph.cpu())
                del subgraph
            gc.collect() 
            torch.cuda.empty_cache()  # 清理缓存
        return data_list

    
    def sample_subgraph(self, src, dst, data, y):
        
        # # 采k-hop子图，得到子图节点的新编号、子图内边、mapping
        sub_node_index, sub_edge_index, mapping, _ = k_hop_subgraph(
            [src, dst], self.num_hops, data.edge_index, relabel_nodes=True)
        
        if len(sub_node_index) <= self.k_min:
            sub_node_index, sub_edge_index, mapping, _ = k_hop_subgraph(
            [src, dst], self.num_hops + self.gamma, data.edge_index, relabel_nodes=True)

            sub_x = data.x[sub_node_index]
            sub_src, sub_dst = mapping.tolist()

            final_nodes = list(range(len(sub_x)))

        else:
            sub_edge_index = sub_edge_index.to(self.device)
            #子图全部节点初始特征向量(sub.num_of_node, data.x.size(1))
            # sub_x = data.x[sub_node_index].to(self.device)
            sub_x = data.x[sub_node_index]
            sub_src, sub_dst = mapping.tolist()
            
            #构建子图的data
            sub_data = Data(x = sub_x,edge_index = sub_edge_index).to(self.device)
            #获取子图的所有节点分数字典（不包含src和dst）
            # scores_dist, sub_node_emb = self.get_subgraph_scores(sub_src, sub_dst, sub_data)
            
            # ====== 改进点：用torch.topk替代heapq.nlargest，速度更快 =======

            # # 分数从高到低取前top_k
            # scores_dist = self.get_subgraph_scores(sub_src, sub_dst, sub_data)
            # topk_neighbors = heapq.nlargest(self.k_min, scores_dist, key=scores_dist.get)

            candidates_tensor, scores = self.get_subgraph_scores(sub_src, sub_dst, sub_data)

            k = max(1, int(len(scores) * self.alpha // 100))
            _, topk_indices = torch.topk(scores, min(k, scores.size(0)))
            topk_neighbors = candidates_tensor[topk_indices].tolist()
            # ===========================================================
            # 从剩下的候选节点中随机选择20%的节点
            remaining_candidates = [i for i in candidates_tensor.tolist() if i not in topk_neighbors]
            num_random_select = min(int(len(scores) * self.beta // 100), len(remaining_candidates))
            random_neighbors = random.sample(remaining_candidates, num_random_select)

            # 合并前40%和随机选择的节点，得到最终的topk_neighbors
            final_neighbors = topk_neighbors + random_neighbors
            # 源点和目标点在子图的编号
            final_nodes = [sub_src, sub_dst] + final_neighbors
        
        final_nodes = list(set(final_nodes))  # 防止重复
        final_nodes.sort()  # 方便后面重新映射

        # 旧编号到新编号的映射
        node_id_map = {old: new for new, old in enumerate(final_nodes)}

        # 新的x
        final_x = sub_x[final_nodes]

        # mask边：只保留两个端点都在final_nodes内的边
        final_nodes_tensor = torch.tensor(final_nodes, device=self.device)
        mask = torch.isin(sub_edge_index[0], final_nodes_tensor) & \
            torch.isin(sub_edge_index[1], final_nodes_tensor)
        final_edge_index = sub_edge_index[:, mask]

        # 重新编号edge_index
        final_edge_index = torch.stack([
            torch.tensor([node_id_map[int(i)] for i in final_edge_index[0].tolist()], device=self.device),
            torch.tensor([node_id_map[int(i)] for i in final_edge_index[1].tolist()], device=self.device)
        ], dim=0)

        #去除 src-dst 之间的边（无向图记得两个方向都删！）
        src_new = node_id_map[sub_src]
        dst_new = node_id_map[sub_dst]
        mask1 = (final_edge_index[0] != src_new) | (final_edge_index[1] != dst_new)
        mask2 = (final_edge_index[0] != dst_new) | (final_edge_index[1] != src_new)
        mask = mask1 & mask2
        final_edge_index = final_edge_index[:, mask]

        z = self.drnl_node_labeling(final_edge_index, src_new, dst_new, num_nodes = len(final_nodes))

        final_sub_data = Data(x = final_x, z = z, edge_index = final_edge_index, y = y)
        return final_sub_data
    
    def get_subgraph_scores(self, src, dst, data):
        fn = self.score_fn_dict.get(config.scoresampler.score_fn, self.get_subgraph_scores_gnn)  # 默认GNN
        return fn(src, dst, data)
    
    def get_subgraph_scores_gnn(self, src, dst, data):
        with torch.no_grad():
            node_emb =self.model(data.x, data.edge_index)

            #构建所有src和dst分别到子图所有节点的组合(不包含互相)
            candidates = [i for i in range(data.num_nodes) if i != src and i != dst]
            candidates_tensor = torch.tensor(candidates, device=self.device, dtype=torch.long)

            src_1 = torch.tensor([src] * len(candidates), dtype=torch.long, device=self.device)
            dst_1 = torch.tensor(candidates, dtype=torch.long, device=self.device)
            src_2 = torch.tensor([dst] * len(candidates), dtype=torch.long, device=self.device)
            dst_2 = torch.tensor(candidates, dtype=torch.long, device=self.device)
            
            edge_label_index_1 = torch.stack([src_1, dst_1], dim=0)
            edge_label_index_2 = torch.stack([src_2, dst_2], dim=0)

            scores_1 = self.predictor(node_emb, edge_label_index_1)
            scores_2 = self.predictor(node_emb, edge_label_index_2)
            scores = (scores_1 + scores_2) / 2
            # scores_dist = {i: float(score) for i, score in zip(candidates, scores)}
        return candidates_tensor, scores
    
    def get_subgraph_scores_adamicadar(self, src, dst, data):
        edge_index = data.edge_index.cpu().numpy()
        G = nx.Graph()
        G.add_edges_from(edge_index.T.tolist())
        G.add_nodes_from(range(data.num_nodes))  # 保证节点都在

        # 只考虑有边的节点作为候选
        candidates = [i for i in range(data.num_nodes) if i != src and i != dst and G.degree(i) > 0]
        candidates_tensor = torch.tensor(candidates, device=self.device, dtype=torch.long)

        # 如果src或dst本身也是孤立节点，也跳过/直接返回空
        if G.degree(src) == 0 or G.degree(dst) == 0:
            return candidates_tensor, torch.zeros_like(candidates_tensor, dtype=torch.float, device=self.device)

        aa_src = {(u, v): s for u, v, s in nx.adamic_adar_index(G, [(src, i) for i in candidates])}
        aa_dst = {(u, v): s for u, v, s in nx.adamic_adar_index(G, [(dst, i) for i in candidates])}

        scores = []
        for i in candidates:
            s1 = aa_src.get((src, i), 0.0)
            s2 = aa_dst.get((dst, i), 0.0)
            s = (s1 + s2) / 2
            scores.append(s)

        scores = torch.tensor(scores, device=self.device, dtype=torch.float)
        return candidates_tensor, scores
        
    def get_subgraph_scores_pagerank(self, src, dst, data):
        # 1. edge_index转成networkx图，节点编号是局部编号
        edge_index = data.edge_index.cpu().numpy()
        G = nx.Graph()
        G.add_edges_from(edge_index.T.tolist())
        G.add_nodes_from(range(data.num_nodes))  # 确保所有节点都在G中

        # 2. 只考虑有边的节点
        candidates = [i for i in range(data.num_nodes) if i != src and i != dst and G.degree(i) > 0]
        candidates_tensor = torch.tensor(candidates, device=self.device, dtype=torch.long)

        # 如果src或dst本身是孤立节点，直接返回零分
        if G.degree(src) == 0 or G.degree(dst) == 0:
            return candidates_tensor, torch.zeros(len(candidates), device=self.device, dtype=torch.float)

        # 3. Personalized PageRank（以src和dst为个性化起点，各算一次）
        personalization_src = {n: 0 for n in G.nodes}
        personalization_src[src] = 1
        pr_src = nx.pagerank(G, personalization=personalization_src)

        personalization_dst = {n: 0 for n in G.nodes}
        personalization_dst[dst] = 1
        pr_dst = nx.pagerank(G, personalization=personalization_dst)

        # 4. 对每个候选节点，分别查src和dst个性化pagerank的分数，做平均
        scores = []
        for i in candidates:
            s = (pr_src.get(i, 0.0) + pr_dst.get(i, 0.0)) / 2
            scores.append(s)

        # 转成torch张量
        scores = torch.tensor(scores, device=self.device, dtype=torch.float)
        return candidates_tensor, scores

    def drnl_node_labeling(self, edge_index, src, dst, num_nodes=None):
        # Double-radius node labeling (DRNL).
        src, dst = (dst, src) if src > dst else (src, dst)
        adj = to_scipy_sparse_matrix(edge_index, num_nodes=num_nodes).tocsr()

        idx = list(range(src)) + list(range(src + 1, adj.shape[0]))
        adj_wo_src = adj[idx, :][:, idx]

        idx = list(range(dst)) + list(range(dst + 1, adj.shape[0]))
        adj_wo_dst = adj[idx, :][:, idx]

        dist2src = shortest_path(adj_wo_dst, directed=False, unweighted=True,
                                 indices=src)
        dist2src = np.insert(dist2src, dst, 0, axis=0)
        dist2src = torch.from_numpy(dist2src)

        dist2dst = shortest_path(adj_wo_src, directed=False, unweighted=True,
                                 indices=dst - 1)
        dist2dst = np.insert(dist2dst, src, 0, axis=0)
        dist2dst = torch.from_numpy(dist2dst)

        dist = dist2src + dist2dst
        dist_over_2, dist_mod_2 = dist // 2, dist % 2

        z = 1 + torch.min(dist2src, dist2dst)
        z += dist_over_2 * (dist_over_2 + dist_mod_2 - 1)
        z[src] = 1.
        z[dst] = 1.
        z[torch.isnan(z)] = 0.

        self._max_z = max(int(z.max()), self._max_z)

        return z.to(torch.long)

In [5]:
model = ModelClass(config.data_init_num_features, hidden_dim = config.scoregnn.hidden_dim, 
                 output_dim = config.scoregnn.output_dim , num_layers = config.scoregnn.num_layers, 
                 dropout = config.scoregnn.dropout).to(device)
predictor = config.scoregnn.predictor.to(device)

# 加载参数（假设你的文件结构是这样保存的）
checkpoint = torch.load('./model/scoregnn.pth', map_location=device)
model.load_state_dict(checkpoint['model'])
predictor.load_state_dict(checkpoint['predictor'])

model.eval()
predictor.eval()
print("Model and predictor loaded successfully.")

Model and predictor loaded successfully.


In [6]:
# model = ModelClass(config.data_init_num_features, hidden_dim = config.scoregnn.hidden_dim, 
#                  output_dim = config.scoregnn.output_dim , num_layers = config.scoregnn.num_layers, 
#                  dropout = config.scoregnn.dropout).to(device)
# predictor = config.scoregnn.predictor

# # 加载参数（假设你的文件结构是这样保存的）
# checkpoint = torch.load('./model/scoregnn.pth', map_location=device)
# model.load_state_dict(checkpoint['model'])
# predictor.load_state_dict(checkpoint['predictor'])

# model.eval()
# predictor.eval()
# print("Model and predictor loaded successfully.")

In [7]:
start_time = time.time()
sampler = SubgraphBatchSampler(model = model, predictor = predictor, k_min = config.scoresampler.k_min, 
                          num_hops = config.scoresampler.num_hops,save_dir = f'./data/{config.dataset}/split', alpha = config.scoresampler.alpha, 
                          beta = config.scoresampler.beta, gamma = config.scoresampler.gamma)
sampler.process()
end_time = time.time()
logging.info(f'Sample time: {end_time - start_time}')

扫描 max_z: 100%|██████████| 16/16 [19:47<00:00, 74.24s/batch]


72


扫描 max_z: 100%|██████████| 16/16 [14:02<00:00, 52.65s/batch]


90


扫描 max_z: 100%|██████████| 1/1 [01:08<00:00, 68.34s/batch]


90


扫描 max_z: 100%|██████████| 1/1 [00:49<00:00, 49.99s/batch]
扫描 max_z: 100%|██████████| 2/2 [02:33<00:00, 76.74s/batch]
扫描 max_z: 100%|██████████| 2/2 [01:42<00:00, 51.11s/batch]


90
保存 train 分批文件


保存 ./data/Github/split\SSSEAL_train_pos 分批文件: 100%|██████████| 158/158 [1:51:38<00:00, 42.39s/batch]
保存 ./data/Github/split\SSSEAL_train_neg 分批文件:  30%|██▉       | 47/158 [25:41<1:00:41, 32.81s/batch]


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [6]:
sampler = SubgraphBatchSampler(model = model, predictor = predictor, k_min = config.scoresampler.k_min, 
                          num_hops = config.scoresampler.num_hops,save_dir = f'./data/{config.dataset}/split', alpha = config.scoresampler.alpha, 
                          beta = config.scoresampler.beta, gamma = config.scoresampler.gamma)

In [9]:
# sampler.cancat_pos_neg()

In [10]:
# sampler.cancat()

In [None]:
# start_time = time.time()
# sampler = SubgraphSampler(model = model, predictor = predictor, k_min = config.scoresampler.k_min, 
#                           num_hops = config.scoresampler.num_hops, alpha = config.scoresampler.alpha, 
#                           beta = config.scoresampler.beta, gamma = config.scoresampler.gamma)
# sampler.process()
# end_time = time.time()
# logging.info(f'Sample time: {end_time - start_time}')