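# Source tutorial (Datawhale team-learning-nlp, GNN chapter 6-2 "Node prediction and edge prediction in practice"): https://github.com/datawhalechina/team-learning-nlp/blob/master/GNN/Markdown%E7%89%88%E6%9C%AC/6-2-%E8%8A%82%E7%82%B9%E9%A2%84%E6%B5%8B%E4%B8%8E%E8%BE%B9%E9%A2%84%E6%B5%8B%E4%BB%BB%E5%8A%A1%E5%AE%9E%E8%B7%B5.md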

In [1]:
import os.path as osp

from torch_geometric.utils import negative_sampling
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.utils import train_test_split_edges
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score



In [None]:
# Inspect how ``train_test_split_edges`` partitions Cora's edges.
dataset = Planetoid('dataset', 'Cora', transform=T.NormalizeFeatures())
data = dataset[0]
# Node-classification attributes are no longer needed for link prediction.
data.train_mask = data.val_mask = data.test_mask = data.y = None

print(data.edge_index.shape)
# torch.Size([2, 10556])

data = train_test_split_edges(data)

# Iterating a ``Data`` object yields ``(key, value)`` pairs in every PyG
# release, whereas ``data.keys`` is a property in older releases and a
# method in newer ones (so ``data.keys()`` is not portable).
for key, value in data:
    print(key, value.shape)

# x torch.Size([2708, 1433])
# val_pos_edge_index torch.Size([2, 263])
# test_pos_edge_index torch.Size([2, 527])
# train_pos_edge_index torch.Size([2, 8976])
# train_neg_adj_mask torch.Size([2708, 2708])
# val_neg_edge_index torch.Size([2, 263])
# test_neg_edge_index torch.Size([2, 527])
#
# Why the positive counts do not add up to 10556:
#   263 + 527 + 8976     = 9766 != 10556
#   263 + 527 + 8976 / 2 = 5278 == 10556 / 2
# The original edge_index stores every undirected edge twice (once per
# direction). The val/test positives keep a single direction, while the
# training positives keep both directions.

In [2]:
import torch
from torch_geometric.nn import GCNConv, SAGEConv, GATConv

class Net(torch.nn.Module):
    """Two-layer GCN encoder with an inner-product edge decoder.

    Args:
        in_channels: Number of input node features.
        out_channels: Size of the node embeddings produced by ``encode``.
        hidden_channels: Width of the first GCN layer. The default of 128 is
            the value this model originally hard-coded.

    Other encoders compared in this notebook can be swapped in for the two
    ``GCNConv`` layers, e.g. ``SAGEConv(in_channels, hidden_channels, "max")``
    and ``SAGEConv(hidden_channels, out_channels, "max")``.
    """

    def __init__(self, in_channels, out_channels, hidden_channels=128):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def encode(self, x, edge_index):
        """Return node embeddings from two GCN layers with a ReLU between."""
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

    def decode(self, z, pos_edge_index, neg_edge_index):
        """Score candidate edges by the dot product of endpoint embeddings.

        Positive edges are scored first, then negative edges. This is the
        ordering ``get_link_labels`` assumes when building targets. The
        returned values are raw logits (no sigmoid applied).
        """
        edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
        src, dst = edge_index
        return (z[src] * z[dst]).sum(dim=-1)

    def decode_all(self, z):
        """Return every node pair with a positive score as a ``[2, K]`` index.

        A logit above 0 corresponds to a sigmoid probability above 0.5. The
        dense ``N x N`` score matrix is materialized. Pairs are reported in
        both directions, including diagonal (self-pair) entries whenever
        ``z[i]`` is non-zero.
        """
        scores = z @ z.t()  # logits, not probabilities
        return (scores > 0).nonzero(as_tuple=False).t()


In [3]:
def get_link_labels(pos_edge_index, neg_edge_index):
    """Build binary link targets matching the output order of ``decode``.

    Returns a float32 vector of ones (one per positive edge) followed by
    zeros (one per negative edge), allocated on the default device.
    """
    n_pos, n_neg = pos_edge_index.size(1), neg_edge_index.size(1)
    return torch.cat([
        torch.ones(n_pos, dtype=torch.float),
        torch.zeros(n_neg, dtype=torch.float),
    ])

def train(data, model, optimizer):
    """Run one full-batch optimisation step of the link predictor.

    Each call draws a fresh set of negative edges, as many as there are
    training positives. It then encodes the graph using only the training
    positives, scores positives and negatives, and applies one step of
    binary cross-entropy on the logits.

    Returns:
        The scalar loss tensor for this step.
    """
    model.train()

    pos_edges = data.train_pos_edge_index
    neg_edges = negative_sampling(
        edge_index=pos_edges,
        num_nodes=data.num_nodes,
        num_neg_samples=pos_edges.size(1),
    )

    optimizer.zero_grad()
    embeddings = model.encode(data.x, pos_edges)
    logits = model.decode(embeddings, pos_edges, neg_edges)
    targets = get_link_labels(pos_edges, neg_edges).to(data.x.device)
    loss = F.binary_cross_entropy_with_logits(logits, targets)
    loss.backward()
    optimizer.step()
    return loss


In [4]:
@torch.no_grad()
def test(data, model):
    """Return ``[val_auc, test_auc]``, the ROC-AUC on each held-out split.

    Embeddings are computed once, by message passing over the training
    positives only. Each split's positive and negative edges are then scored
    and compared against their binary labels.
    """
    model.eval()
    z = model.encode(data.x, data.train_pos_edge_index)

    def _split_auc(prefix):
        # Runs inside the decorator's no-grad context, like the outer call.
        pos_edge_index = data[f'{prefix}_pos_edge_index']
        neg_edge_index = data[f'{prefix}_neg_edge_index']
        probs = model.decode(z, pos_edge_index, neg_edge_index).sigmoid()
        labels = get_link_labels(pos_edge_index, neg_edge_index)
        return roc_auc_score(labels.cpu(), probs.cpu())

    return [_split_auc(prefix) for prefix in ('val', 'test')]


In [5]:

# Train the GCN link predictor on Cora. The reported test AUC is the one
# measured at the epoch with the best validation AUC (model selection on
# the validation split).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Keep the dataset name separate from the dataset object instead of
# rebinding one variable to both.
dataset_name = 'Cora'
path = osp.join('..', 'data', dataset_name)
dataset = Planetoid(path, dataset_name, transform=T.NormalizeFeatures())
data = dataset[0]
# Full edge set captured before the split. It is unused in the visible
# code; presumably kept to compare against ``final_edge_index`` below.
ground_truth_edge_index = data.edge_index.to(device)
# Node-classification attributes are not needed for link prediction.
data.train_mask = data.val_mask = data.test_mask = data.y = None
data = train_test_split_edges(data)
data = data.to(device)

model = Net(dataset.num_features, 64).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)

best_val_auc = test_auc = 0
for epoch in range(1, 101):
    loss = train(data, model, optimizer)
    val_auc, tmp_test_auc = test(data, model)
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        test_auc = tmp_test_auc
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, '
          f'Test: {test_auc:.4f}')

# Inference only: set eval mode explicitly (rather than relying on the last
# ``test`` call having done so) and disable autograd so no graph is built.
model.eval()
with torch.no_grad():
    z = model.encode(data.x, data.train_pos_edge_index)
    final_edge_index = model.decode_all(z)




Epoch: 001, Loss: 0.6930, Val: 0.6575, Test: 0.7050
Epoch: 002, Loss: 0.6808, Val: 0.6504, Test: 0.7050
Epoch: 003, Loss: 0.7211, Val: 0.6616, Test: 0.7016
Epoch: 004, Loss: 0.6766, Val: 0.6911, Test: 0.7085
Epoch: 005, Loss: 0.6851, Val: 0.7499, Test: 0.7423
Epoch: 006, Loss: 0.6896, Val: 0.7637, Test: 0.7746
Epoch: 007, Loss: 0.6910, Val: 0.6941, Test: 0.7746
Epoch: 008, Loss: 0.6913, Val: 0.6535, Test: 0.7746
Epoch: 009, Loss: 0.6906, Val: 0.6425, Test: 0.7746
Epoch: 010, Loss: 0.6888, Val: 0.6426, Test: 0.7746
Epoch: 011, Loss: 0.6853, Val: 0.6483, Test: 0.7746
Epoch: 012, Loss: 0.6812, Val: 0.6528, Test: 0.7746
Epoch: 013, Loss: 0.6809, Val: 0.6634, Test: 0.7746
Epoch: 014, Loss: 0.6800, Val: 0.6766, Test: 0.7746
Epoch: 015, Loss: 0.6750, Val: 0.6889, Test: 0.7746
Epoch: 016, Loss: 0.6705, Val: 0.6964, Test: 0.7746
Epoch: 017, Loss: 0.6675, Val: 0.6997, Test: 0.7746
Epoch: 018, Loss: 0.6635, Val: 0.6978, Test: 0.7746
Epoch: 019, Loss: 0.6582, Val: 0.6944, Test: 0.7746
Epoch: 020, 

In [6]:
# Results at epoch 100 with different encoder layers used in Net
# (GraphSAGE with mean / max aggregation, GCN, GAT):
# Epoch: 100, Loss: 0.4344, Val: 0.9281, Test: 0.8863 GraphSAGE mean
# Epoch: 100, Loss: 0.4129, Val: 0.8898, Test: 0.8938 GraphSAGE max
# Epoch: 100, Loss: 0.4421, Val: 0.9054, Test: 0.9134 GCN
# Epoch: 100, Loss: 0.4439, Val: 0.9164, Test: 0.9056 GAT