In [1]:
import torch
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import negative_sampling
from torch_geometric.data import InMemoryDataset, download_url
from torch_geometric.io import read_planetoid_data
from torch_geometric.nn import GATConv, Sequential, GCNConv
from torch_geometric.transforms import NormalizeFeatures, RandomLinkSplit
import os
from sklearn.metrics import roc_auc_score
from typing import Any, Callable, List, Optional, Tuple, Union

## 一个简化的 InMemory 数据集类
以公开数据集`PubMed`为例。`PubMed`数据集存储的是文章引用网络，文章对应图的结点，如果两篇文章存在引用关系（无论引用与被引用），则这两篇文章对应的结点之间存在边。

In [2]:
class PlanetoidPubMed(InMemoryDataset):
    r"""The citation network datasets "PubMed" from the
    `"Revisiting Semi-Supervised Learning with Graph Embeddings"
    <https://arxiv.org/abs/1603.08861>`_ paper.
    Nodes represent documents and edges represent citation links.
    Training, validation and test splits are given by binary masks.

    Args:
        root (string): Root directory where the dataset should be saved.
        split (string): The type of dataset split
            (:obj:`"public"`, :obj:`"full"`, :obj:`"random"`).
            If set to :obj:`"public"`, the split will be the public fixed split
            from the
            `"Revisiting Semi-Supervised Learning with Graph Embeddings"
            <https://arxiv.org/abs/1603.08861>`_ paper.
            If set to :obj:`"full"`, all nodes except those in the validation
            and test sets will be used for training (as in the
            `"FastGCN: Fast Learning with Graph Convolutional Networks via
            Importance Sampling" <https://arxiv.org/abs/1801.10247>`_ paper).
            If set to :obj:`"random"`, train, validation, and test sets will be
            randomly generated, according to :obj:`num_train_per_class`,
            :obj:`num_val` and :obj:`num_test`. (default: :obj:`"public"`)
        num_train_per_class (int, optional): The number of training samples
            per class in case of :obj:`"random"` split. (default: :obj:`20`)
        num_val (int, optional): The number of validation samples in case of
            :obj:`"random"` split. (default: :obj:`500`)
        num_test (int, optional): The number of test samples in case of
            :obj:`"random"` split. (default: :obj:`1000`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
    """
    url = 'https://github.com/kimiyoung/planetoid/raw/master/data'

    def __init__(self, root, split="public", num_train_per_class=20,
                 num_val=500, num_test=1000, transform=None,
                 pre_transform=None):

        super(PlanetoidPubMed, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

        self.split = split
        assert self.split in ['public', 'full', 'random']

        if split == 'full':
            data = self.get(0)
            data.train_mask.fill_(True)
            data.train_mask[data.val_mask | data.test_mask] = False
            self.data, self.slices = self.collate([data])

        elif split == 'random':
            data = self.get(0)
            data.train_mask.fill_(False)
            for c in range(self.num_classes):
                idx = (data.y == c).nonzero(as_tuple=False).view(-1)
                idx = idx[torch.randperm(idx.size(0))[:num_train_per_class]]
                data.train_mask[idx] = True

            remaining = (~data.train_mask).nonzero(as_tuple=False).view(-1)
            remaining = remaining[torch.randperm(remaining.size(0))]

            data.val_mask.fill_(False)
            data.val_mask[remaining[:num_val]] = True

            data.test_mask.fill_(False)
            data.test_mask[remaining[num_val:num_val + num_test]] = True

            self.data, self.slices = self.collate([data])

    @property
    def raw_dir(self):
        return os.path.join(self.root, 'raw')

    @property
    def processed_dir(self):
        return os.path.join(self.root, 'processed')

    @property
    def raw_file_names(self):
        names = ['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index']
        return ['ind.pubmed.{}'.format(name) for name in names]

    @property
    def processed_file_names(self):
        return 'data.pt'

    def download(self):
        for name in self.raw_file_names:
            download_url('{}/{}'.format(self.url, name), self.raw_dir)

    def process(self):
        data = read_planetoid_data(self.raw_dir, 'pubmed')
        data = data if self.pre_transform is None else self.pre_transform(data)
        torch.save(self.collate([data]), self.processed_paths[0])

    def __repr__(self):
        return '{}()'.format(self.name)

In [3]:
dataset = PlanetoidPubMed('../datasets/PlanetoidPubMed', transform=NormalizeFeatures())
print(dataset.num_classes)
print(dataset[0].num_edges)
print(dataset[0].num_nodes)
print(dataset[0].num_features)

3
88648
19717
500


## 节点预测
### 定义网络

In [4]:
class GAT(torch.nn.Module):
    def __init__(self, num_features, hidden_channels_list, num_classes, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        hns = [num_features] + hidden_channels_list
        conv_list = []
        for idx in range(len(hidden_channels_list)):
            conv_list.append((GATConv(hns[idx], hns[idx + 1]), 'x, edge_index -> x'))
            conv_list.append(torch.nn.ReLU(inplace=True), )
        self.convseq = Sequential('x, edge_index', conv_list)
        self.linear = torch.nn.Linear(hidden_channels_list[-1], num_classes)
    
    def forward(self, x, edge_index):
        x = self.convseq(x, edge_index)
        x = torch.nn.functional.dropout(x, p=0.5, training=self.training)
        return self.linear(x)

### 训练函数

In [5]:
def train(model, data, optimizer, criterion):
    model.train()
    optimizer.zero_grad()  # Clear gradients.
    out = model(data.x, data.edge_index)  # Perform a single forward pass.
    # Compute the loss solely based on the training nodes.
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    loss.backward()  # Derive gradients.
    optimizer.step()  # Update parameters based on gradients.
    return loss

def test(model, data):
    model.eval()
    with torch.no_grad():
        out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)  # Use the class with highest probability.
    test_correct = pred[data.test_mask] == data.y[data.test_mask]  # Check against ground-truth labels.
    test_acc = int(test_correct.sum()) / int(data.test_mask.sum())  # Derive ratio of correct predictions.
    return test_acc

### 训练和测试

In [6]:
data = dataset[0].cuda()
model = GAT(num_features=dataset.num_features, hidden_channels_list=[200, 100], num_classes=dataset.num_classes).cuda()
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(200):
    loss = train(model, data, optimizer, criterion)
    print(f'Epoch {epoch + 1:03d}, Loss {loss:.4f}')

test_acc = test(model, data)
print(f'Test Accuracy: {test_acc:.4f}')

GAT(
  (convseq): Sequential(
    (0) - GATConv(500, 200, heads=1): x, edge_index -> x
    (1) - ReLU(inplace=True): x -> x
    (2) - GATConv(200, 100, heads=1): x, edge_index -> x
    (3) - ReLU(inplace=True): x -> x
  )
  (linear): Linear(in_features=100, out_features=3, bias=True)
)
Epoch 001, Loss 1.0988
Epoch 002, Loss 1.0957
Epoch 003, Loss 1.0931
Epoch 004, Loss 1.0832
Epoch 005, Loss 1.0769
Epoch 006, Loss 1.0673
Epoch 007, Loss 1.0215
Epoch 008, Loss 0.9741
Epoch 009, Loss 0.9371
Epoch 010, Loss 0.8782
Epoch 011, Loss 0.7742
Epoch 012, Loss 0.7022
Epoch 013, Loss 0.5933
Epoch 014, Loss 0.4859
Epoch 015, Loss 0.4228
Epoch 016, Loss 0.3232
Epoch 017, Loss 0.2407
Epoch 018, Loss 0.2863
Epoch 019, Loss 0.1684
Epoch 020, Loss 0.1706
Epoch 021, Loss 0.0979
Epoch 022, Loss 0.1455
Epoch 023, Loss 0.0624
Epoch 024, Loss 0.0477
Epoch 025, Loss 0.0444
Epoch 026, Loss 0.0518
Epoch 027, Loss 0.0475
Epoch 028, Loss 0.0157
Epoch 029, Loss 0.0292
Epoch 030, Loss 0.0103
Epoch 031, Loss 0.0258


## 边预测
边预测任务，目标是预测两个节点之间是否存在边。拿到一个图数据集，我们有节点属性`x`，边端点`edge_index`。`edge_index`存储的便是正样本。为了构建边预测任务，我们需要生成一些负样本，即采样一些不存在边的节点对作为负样本边，正负样本数量应平衡。此外要将样本分为训练集、验证集和测试集三个集合。
PyG中为我们提供了现成的采样负样本边的方法，`RandomLinkSplit`
该函数将自动地采样得到负样本，并将正负样本分成训练集、验证集和测试集三个集合。
`edge_index` 在 `RandomLinkSplit` 之后表示剩余的、用于训练模型的边
`edge_label_index` 则用于评估模型在链接预测任务上的性能，包括模型对存在和不存在的边的预测能力。
### 数据

In [7]:
dataset = Planetoid('../datasets/Planetoid/', 'Cora', transform=NormalizeFeatures())
data = dataset[0]
data.train_mask = data.val_mask = data.test_mask = data.y = None
print(data.edge_index.shape)
random_split = RandomLinkSplit(is_undirected=True)
train_data, val_data, test_data = random_split(data)
for name, d in zip(['train', 'val', 'test'], [train_data, val_data, test_data]):
    print(f'{name}: {d}')  

torch.Size([2, 10556])
train: Data(x=[2708, 1433], edge_index=[2, 7392], edge_label=[7392], edge_label_index=[2, 7392])
val: Data(x=[2708, 1433], edge_index=[2, 7392], edge_label=[1054], edge_label_index=[2, 1054])
test: Data(x=[2708, 1433], edge_index=[2, 8446], edge_label=[2110], edge_label_index=[2, 2110])


### 定义网络

In [8]:
class Net(torch.nn.Module):
    def __init__(self, in_channels, out_channels, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.conv1 = GCNConv(in_channels, 128)
        self.conv2 = GCNConv(128, out_channels)
    
    def encode(self, x, edge_index):
        x = torch.nn.functional.relu(self.conv1(x, edge_index))
        return self.conv2(x, edge_index)
    
    def decode(self, z, edge_index):
        """
        根据边两端节点的表征生成边为真的几率
        """
        return (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)
    
    def decode_all(self, z):
        """
        用于推理阶段，对所有的节点对预测存在边的几率
        """
        prob_adj = z @ z.t()
        return (prob_adj > 0).nonzero(as_tuple=False).t()

### 训练函数

In [9]:
def train(model, data, optimizer):
    data = data.to(next(model.parameters()).device)
    model.train()
    optimizer.zero_grad()
    z = model.encode(data.x, data.edge_label_index)
    link_logits = model.decode(z, data.edge_label_index)
    loss = torch.nn.functional.binary_cross_entropy_with_logits(link_logits, data.edge_label)
    loss.backward()
    optimizer.step()
    return loss

def test(model, data):
    data = data.to(next(model.parameters()).device)
    model.eval()
    with torch.no_grad():
        z = model.encode(data.x, data.edge_label_index)
        link_logits = model.decode(z, data.edge_label_index)
    link_probs = torch.sigmoid(link_logits)
    return roc_auc_score(data.edge_label.cpu(), link_probs.cpu())

### 训练代码

In [11]:
model = Net(dataset.num_features, 64).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

best_val_auc = best_test_auc = 0
for epoch in range(100):
    loss = train(model, train_data, optimizer)
    val_auc = test(model, val_data)
    if val_auc > best_val_auc:
        best_val_auc = val_auc
    test_auc = test(model, test_data)
    if test_acc > best_test_auc:
        best_test_auc = test_acc
    print(f'Epoch: {epoch + 1:03d}, Loss: {loss:.4f}, Val: {best_val_auc:.4f}, Test: {best_test_auc:.4f}')

Epoch: 001, Loss: 0.6931, Val: 0.6211, Test: 0.7780
Epoch: 002, Loss: 0.6881, Val: 0.6211, Test: 0.7780
Epoch: 003, Loss: 0.6893, Val: 0.6211, Test: 0.7780
Epoch: 004, Loss: 0.6826, Val: 0.6211, Test: 0.7780
Epoch: 005, Loss: 0.6949, Val: 0.6211, Test: 0.7780
Epoch: 006, Loss: 0.6775, Val: 0.6211, Test: 0.7780
Epoch: 007, Loss: 0.6838, Val: 0.6211, Test: 0.7780
Epoch: 008, Loss: 0.6867, Val: 0.6211, Test: 0.7780
Epoch: 009, Loss: 0.6866, Val: 0.6211, Test: 0.7780
Epoch: 010, Loss: 0.6836, Val: 0.6211, Test: 0.7780
Epoch: 011, Loss: 0.6769, Val: 0.6211, Test: 0.7780
Epoch: 012, Loss: 0.6705, Val: 0.6211, Test: 0.7780
Epoch: 013, Loss: 0.6770, Val: 0.6211, Test: 0.7780
Epoch: 014, Loss: 0.6693, Val: 0.6211, Test: 0.7780
Epoch: 015, Loss: 0.6648, Val: 0.6211, Test: 0.7780
Epoch: 016, Loss: 0.6658, Val: 0.6211, Test: 0.7780
Epoch: 017, Loss: 0.6647, Val: 0.6211, Test: 0.7780
Epoch: 018, Loss: 0.6603, Val: 0.6211, Test: 0.7780
Epoch: 019, Loss: 0.6576, Val: 0.6211, Test: 0.7780
Epoch: 020, 