In [150]:
import torch
from itertools import chain
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data 
from torch_geometric.utils import k_hop_subgraph, negative_sampling
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import Planetoid
import numpy as np
import random
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [151]:
class NeighborScorer(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x):
        return self.mlp(x).squeeze(-1)


In [152]:
def drnl_node_labeling(edge_index, src, dst, num_nodes=None):
    from scipy.sparse.csgraph import shortest_path
    from torch_geometric.utils import to_scipy_sparse_matrix

    src, dst = (dst, src) if src > dst else (src, dst)
    adj = to_scipy_sparse_matrix(edge_index, num_nodes=num_nodes).tocsr()

    idx = list(range(src)) + list(range(src + 1, adj.shape[0]))
    adj_wo_src = adj[idx, :][:, idx]

    idx = list(range(dst)) + list(range(dst + 1, adj.shape[0]))
    adj_wo_dst = adj[idx, :][:, idx]

    dist2src = shortest_path(adj_wo_dst, directed=False, unweighted=True, indices=src)
    dist2src = np.insert(dist2src, dst, 0, axis=0)
    dist2src = torch.from_numpy(dist2src)

    dist2dst = shortest_path(adj_wo_src, directed=False, unweighted=True, indices=dst - 1)
    dist2dst = np.insert(dist2dst, src, 0, axis=0)
    dist2dst = torch.from_numpy(dist2dst)

    dist = dist2src + dist2dst
    dist_over_2, dist_mod_2 = dist // 2, dist % 2

    z = 1 + torch.min(dist2src, dist2dst)
    z += dist_over_2 * (dist_over_2 + dist_mod_2 - 1)
    z[src] = 1.
    z[dst] = 1.
    z[torch.isnan(z)] = 0.
    maxz = 0
    maxz = max(int(z.max()), maxz)
    return z.to(torch.long).to(device), maxz


In [153]:
def dynamic_prune_subgraph(data, y, num_hops, scorer, top_n):
    data_list = []
    maxz_total = 0
    edge_label_index = data.pos_edge_label_index if y == 1 else data.neg_edge_label_index

    for src, dst in edge_label_index.t().tolist():
        node_idx, edge_index, mapping, _ = k_hop_subgraph(
            [src, dst], num_hops, data.edge_index, relabel_nodes=True)
        src_sub, dst_sub = mapping.tolist()
        x_sub = data.x[node_idx.to(data.x.device)]
        all_indices = torch.arange(len(node_idx))
        mask_src_dst = torch.ones(len(node_idx), dtype=torch.bool)
        mask_src_dst[src_sub] = False
        mask_src_dst[dst_sub] = False

        x_sub_wo_srcdst = x_sub[mask_src_dst]
        with torch.no_grad():
            score = scorer(x_sub_wo_srcdst)
        n = min(top_n, x_sub_wo_srcdst.size(0))
        idx_top = score.topk(n).indices
        neighbors_idx = all_indices[mask_src_dst]
        topn_in_nodeidx = neighbors_idx[idx_top.to(neighbors_idx.device)]
        indices_final = torch.cat([torch.tensor([src_sub, dst_sub]), topn_in_nodeidx])
        x_sub_final = x_sub[indices_final]
        node_idx_final = node_idx[indices_final]

        final_set = set(indices_final.tolist())
        edge_mask = [(u in final_set) and (v in final_set) for u, v in edge_index.t().tolist()]
        edge_mask = torch.tensor(edge_mask, dtype=torch.bool)
        edge_index_filtered = edge_index[:, edge_mask]
        old2new = {old.item(): new for new, old in enumerate(indices_final)}
        edge_index_final = edge_index_filtered.clone()
        for i in range(edge_index_filtered.shape[1]):
            edge_index_final[0, i] = old2new[edge_index_filtered[0, i].item()]
            edge_index_final[1, i] = old2new[edge_index_filtered[1, i].item()]

        src_new, dst_new = 0, 1
        mask1 = ~((edge_index_final[0] == src_new) & (edge_index_final[1] == dst_new))
        mask2 = ~((edge_index_final[0] == dst_new) & (edge_index_final[1] == src_new))
        mask = mask1 & mask2

        edge_index_final_no_sd = edge_index_final[:, mask]

        z, maxz = drnl_node_labeling(edge_index_final_no_sd, src_new, dst_new, num_nodes=node_idx_final.size(0))
        maxz_total = max(maxz_total, maxz)
        data_graph = Data(x=x_sub_final, z=z, edge_index=edge_index_final_no_sd, y=y)
        data_list.append(data_graph)
    return data_list, maxz_total

In [154]:
# def dynamic_prune_subgraph(data, y, num_hops, scorer, top_n):
#     data_list = []
#     if y == 1:
#         for src, dst in data.pos_edge_label_index.t().tolist():
#             node_idx, edge_index, mapping, _ = k_hop_subgraph(
#                 [src, dst], num_hops, data.edge_index, relabel_nodes=True)
#             src, dst = mapping.tolist()
#             x_sub = data.x[node_idx]
#             all_indices = torch.arange(len(node_idx))
#             mask_src_dst = torch.ones(len(node_idx), dtype=torch.bool)
#             mask_src_dst[src] = False
#             mask_src_dst[dst] = False
#             x_sub_wo_srcdst = x_sub[mask_src_dst]
#             with torch.no_grad():
#                 score = scorer(x_sub_wo_srcdst)

#             n = min(top_n, x_sub_wo_srcdst.size(0))
#             idx_top = score.topk(n).indices
#             neighbors_idx = all_indices[mask_src_dst]
#             topn_in_nodeidx = neighbors_idx[idx_top]
#             indices_final = torch.cat([torch.tensor([src, dst]), topn_in_nodeidx])
#             x_sub_final = x_sub[indices_final]
#             node_idx_final = node_idx[indices_final]

#             # 3. 保留只与这些节点相关的边
#             # 假设你已经得到 indices_final
#             final_set = set(indices_final.tolist())
#             edge_mask = [(u in final_set) and (v in final_set) for u, v in edge_index.t().tolist()]
#             edge_mask = torch.tensor(edge_mask, dtype=torch.bool)
#             edge_index_filtered = edge_index[:, edge_mask]

#             # old -> new 编号
#             old2new = {old.item(): new for new, old in enumerate(indices_final)}

#             edge_index_final = edge_index_filtered.clone()
#             for i in range(edge_index_filtered.shape[1]):
#                 edge_index_final[0, i] = old2new[edge_index_filtered[0, i].item()]
#                 edge_index_final[1, i] = old2new[edge_index_filtered[1, i].item()]

#             # 删除 src-dst 连接
#             src_new, dst_new = 0, 1
#             mask1 = ~((edge_index_final[0] == src_new) & (edge_index_final[1] == dst_new))
#             mask2 = ~((edge_index_final[0] == dst_new) & (edge_index_final[1] == src_new))
#             mask = mask1 & mask2

#             edge_index_final_no_sd = edge_index_final[:, mask]

#             z, maxz = drnl_node_labeling(edge_index_final_no_sd, src_new, dst_new, num_nodes=node_idx_final.size(0))
#             data = Data(x = x_sub_final, z = z, edge_index = edge_index_final_no_sd, y = y)
#             data_list.append(data)
#     if y == 0:
#         for src, dst in data.neg_edge_label_index.t().tolist():
#             node_idx, edge_index, mapping, _ = k_hop_subgraph(
#                 [src, dst], num_hops, data.edge_index, relabel_nodes=True)
#             src, dst = mapping.tolist()
#             x_sub = data.x[node_idx]
#             all_indices = torch.arange(len(node_idx))
#             mask_src_dst = torch.ones(len(node_idx), dtype=torch.bool)
#             mask_src_dst[src] = False
#             mask_src_dst[dst] = False
#             x_sub_wo_srcdst = x_sub[mask_src_dst]
#             with torch.no_grad():
#                 score = scorer(x_sub_wo_srcdst)

#             n = min(top_n, x_sub_wo_srcdst.size(0))
#             idx_top = score.topk(n).indices
#             neighbors_idx = all_indices[mask_src_dst]
#             topn_in_nodeidx = neighbors_idx[idx_top]
#             indices_final = torch.cat([torch.tensor([src, dst]), topn_in_nodeidx])
#             x_sub_final = x_sub[indices_final]
#             node_idx_final = node_idx[indices_final]

#             # 3. 保留只与这些节点相关的边
#             # 假设你已经得到 indices_final
#             final_set = set(indices_final.tolist())
#             edge_mask = [(u in final_set) and (v in final_set) for u, v in edge_index.t().tolist()]
#             edge_mask = torch.tensor(edge_mask, dtype=torch.bool)
#             edge_index_filtered = edge_index[:, edge_mask]

#             # old -> new 编号
#             old2new = {old.item(): new for new, old in enumerate(indices_final)}

#             edge_index_final = edge_index_filtered.clone()
#             for i in range(edge_index_filtered.shape[1]):
#                 edge_index_final[0, i] = old2new[edge_index_filtered[0, i].item()]
#                 edge_index_final[1, i] = old2new[edge_index_filtered[1, i].item()]

#             # 删除 src-dst 连接
#             src_new, dst_new = 0, 1
#             mask1 = ~((edge_index_final[0] == src_new) & (edge_index_final[1] == dst_new))
#             mask2 = ~((edge_index_final[0] == dst_new) & (edge_index_final[1] == src_new))
#             mask = mask1 & mask2

#             edge_index_final_no_sd = edge_index_final[:, mask]

#             z, maxz = drnl_node_labeling(edge_index_final_no_sd, src_new, dst_new, num_nodes=node_idx_final.size(0))
#             data = Data(x = x_sub_final, z = z, edge_index = edge_index_final_no_sd, y = y)
#             data_list.append(data)
#     return data, maxz



In [155]:
torch.manual_seed(1)
dataset = Planetoid('./data/Planetoid', name='Cora')
data = dataset[0].to(device)
#随机初始化data.x
num_nodes = data.num_nodes
num_features = 256
data.x = torch.randn((num_nodes, num_features)).to(device)
data.x_z = None

# 用PyG的RandomLinkSplit划分数据集
from torch_geometric.transforms import RandomLinkSplit
transform = RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True, split_labels=True)
train_data, val_data, test_data = transform(data)
train_data = train_data.to(device)
val_data = val_data.to(device)
test_data = test_data.to(device)

print(train_data)

# def make_edge_list(pos_edge_index, neg_edge_index):
#     pos_list = [(int(src), int(dst), 1) for src, dst in pos_edge_index.t()]
#     neg_list = [(int(src), int(dst), 0) for src, dst in neg_edge_index.t()]
#     return pos_list + neg_list

# train_edges = make_edge_list(train_data.pos_edge_label_index, train_data.neg_edge_label_index)
# val_edges = make_edge_list(val_data.pos_edge_label_index, val_data.neg_edge_label_index)
# test_edges = make_edge_list(test_data.pos_edge_label_index, test_data.neg_edge_label_index)

# random.shuffle(train_edges)


Data(x=[2708, 256], edge_index=[2, 8976], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708], pos_edge_label=[4488], pos_edge_label_index=[2, 4488], neg_edge_label=[4488], neg_edge_label_index=[2, 4488])


In [156]:
class DGCNN(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, k=20):
        super().__init__()
        from torch_geometric.nn import GCNConv, global_sort_pool
        self.convs = nn.ModuleList([GCNConv(input_dim, hidden_dim)])
        for _ in range(num_layers-1):
            self.convs.append(GCNConv(hidden_dim, hidden_dim))
        self.pool = global_sort_pool
        self.k = k
        self.lin = nn.Linear(hidden_dim * self.k, 1)  # 修改这里！

    def forward(self, x, edge_index, batch):
        for conv in self.convs:
            x = conv(x, edge_index)
            x = F.relu(x)
        x = self.pool(x, batch, self.k)  # [batch_size, k * hidden_dim]
        return self.lin(x).squeeze(-1)


In [157]:
from sklearn.metrics import roc_auc_score, average_precision_score
import torch

@torch.no_grad()
def evaluate(model, loader, device):
    model.eval()
    ys, preds = [], []
    for data in loader:
        data = data.to(device)
        out = model(data.x_z, data.edge_index, data.batch)
        pred = torch.sigmoid(out).cpu().numpy()
        y = data.y.cpu().numpy()
        preds.append(pred)
        ys.append(y)
    ys = np.concatenate(ys)
    preds = np.concatenate(preds)
    auc = roc_auc_score(ys, preds)
    ap = average_precision_score(ys, preds)
    return auc, ap

In [158]:
# 第一步，预先统计所有子图的maxz，假设你已经有 global_maxz
global_maxz = 30  # 举例，你要实际算一遍所有训练/验证子图的maxz

scorer = NeighborScorer(input_dim=data.x.size(1), hidden_dim=64).to(device)
model = DGCNN(global_maxz+1, hidden_dim=32, num_layers=3).to(device)
optimizer = torch.optim.Adam(list(model.parameters()) + list(scorer.parameters()), lr=0.001)
criterion = nn.BCEWithLogitsLoss()
num_hops = 2
top_n = 10
num_epochs = 10
# val_pos_init, _ = dynamic_prune_subgraph(val_data, y=1, num_hops=num_hops, scorer=scorer, top_n=top_n)
# val_neg_init, _ = dynamic_prune_subgraph(val_data, y=0, num_hops=num_hops, scorer=scorer, top_n=top_n)
# val_loader_init = DataLoader(val_pos_init + val_neg_init, batch_size=32, shuffle=True)
# test_pos_init, _ = dynamic_prune_subgraph(test_data, y=1, num_hops=num_hops, scorer=scorer, top_n=top_n)
# test_neg_init, _ = dynamic_prune_subgraph(test_data, y=0, num_hops=num_hops, scorer=scorer, top_n=top_n)
# test_loader_init = DataLoader(test_pos_init + test_neg_init, batch_size=32, shuffle=True)
# auc_val_init, ap_val_init = evaluate(model, val_loader_init, device)
# auc_test_init, ap_test_init = evaluate(model, test_loader_init, device)
# print(f'训练前 val: AUC={auc_val_init:.4f}, AP={ap_val_init:.4f}')
# print(f'训练前 test: AUC={auc_test_init:.4f}, AP={ap_test_init:.4f}')



for epoch in range(num_epochs):
    model.train()
    scorer.train()
    total_loss = 0
    train_pos_data_list, maxz1 = dynamic_prune_subgraph(train_data, y=1, num_hops=num_hops, scorer=scorer, top_n=top_n)
    train_neg_data_list, maxz2 = dynamic_prune_subgraph(train_data, y=0, num_hops=num_hops, scorer=scorer, top_n=top_n)
    # maxz = max(maxz1, maxz2)  # 不需要每轮用，统一用 global_maxz
    for data in chain(train_pos_data_list, train_neg_data_list):
        data.x_z = F.one_hot(data.z, global_maxz + 1).to(torch.float)
    train_data_list = train_pos_data_list + train_neg_data_list
    train_loader = DataLoader(train_data_list, batch_size=32, shuffle=True)
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x_z, data.edge_index, data.batch)
        y = data.y.float().to(out.device)
        if y.shape != out.shape:
            y = y.view_as(out)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * data.num_graphs
    print(f'Epoch {epoch} | Train Loss: {total_loss / len(train_data_list):.4f}')


val_pos_final, _ = dynamic_prune_subgraph(val_data, y=1, num_hops=num_hops, scorer=scorer, top_n=top_n)
val_neg_final, _ = dynamic_prune_subgraph(val_data, y=0, num_hops=num_hops, scorer=scorer, top_n=top_n)
val_loader_final = DataLoader(val_pos_final + val_neg_final, batch_size=32, shuffle=False)

test_pos_final, _ = dynamic_prune_subgraph(test_data, y=1, num_hops=num_hops, scorer=scorer, top_n=top_n)
test_neg_final, _ = dynamic_prune_subgraph(test_data, y=0, num_hops=num_hops, scorer=scorer, top_n=top_n)
test_loader_final = DataLoader(test_pos_final + test_neg_final, batch_size=32, shuffle=False)
auc_val_final, ap_val_final = evaluate(model, val_loader_final, device)
auc_test_final, ap_test_final = evaluate(model, test_loader_final, device)
print(f'训练后 val: AUC={auc_val_final:.4f}, AP={ap_val_final:.4f}')
print(f'训练后 test: AUC={auc_test_final:.4f}, AP={ap_test_final:.4f}')






Epoch 0 | Train Loss: 0.6259
Epoch 1 | Train Loss: 0.5862
Epoch 2 | Train Loss: 0.5765
Epoch 3 | Train Loss: 0.5726
Epoch 4 | Train Loss: 0.5698
Epoch 5 | Train Loss: 0.5723
Epoch 6 | Train Loss: 0.5673
Epoch 7 | Train Loss: 0.5681
Epoch 8 | Train Loss: 0.5663
Epoch 9 | Train Loss: 0.5635


AttributeError: 'GlobalStorage' object has no attribute 'x_z'

In [None]:
# scorer = NeighborScorer(input_dim=data.x.size(1), hidden_dim=64)
# model = DGCNN(31, hidden_dim=32, num_layers=3)
# optimizer = torch.optim.Adam(list(model.parameters()) + list(scorer.parameters()), lr=0.001)
# criterion = nn.BCEWithLogitsLoss()
# num_hops = 2
# top_n = 10
# num_epochs = 10

# for epoch in range(num_epochs):
#     model.train()
#     scorer.train()
#     total_loss = 0
#     train_pos_data_list, maxz1 = dynamic_prune_subgraph(train_data,y =1, num_hops=num_hops, scorer = scorer, top_n = 10)
#     train_neg_data_list, maxz2 = dynamic_prune_subgraph(train_data,y =1, num_hops=num_hops, scorer = scorer, top_n = 10)
#     maxz = max(maxz1,maxz2)
#     print(maxz)
#     for data in chain(train_pos_data_list, train_neg_data_list,):
#         data.x = F.one_hot(data.z, maxz + 1).to(torch.float)
#     train_data_list = train_pos_data_list + train_neg_data_list
#     train_loader = DataLoader(train_data_list, batch_size=32, shuffle=True)
#     for data in train_loader:
#         optimizer.zero_grad()
#         out = model(data.x, data.edge_index, data.batch)
#         loss =loss = criterion(out.view(-1), data.y)
#         loss.backward()
#         optimizer()
#         total_loss+= float(loss) * data.num_graphs
#     print(f'Epoch {epoch} | Train Loss: {total_loss / len(train_data_list):.4f}')
