https://medium.com/@pytorch_geometric/link-prediction-on-heterogeneous-graphs-with-pyg-6d5c29677c70

In [1]:
import os.path as osp
import torch
from torch_geometric.utils import negative_sampling
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.utils import train_test_split_edges
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score



In [2]:
dataset = Planetoid('dataset', 'Cora', transform=T.NormalizeFeatures())
data = dataset[0]
data.train_mask = data.val_mask = data.test_mask = data.y = None 

In [3]:
transform = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    disjoint_train_ratio=0.3,
    neg_sampling_ratio=2.0,
    add_negative_train_samples=False,    
)
train_data, val_data, test_data = transform(data)

In [4]:
from torch_geometric.loader import LinkNeighborLoader
train_loader = LinkNeighborLoader(
    data=train_data,
    num_neighbors=[10, 5],
    neg_sampling_ratio=2.0,
    batch_size=128,
    shuffle=True,
)

In [5]:
from torch_geometric.nn import SAGEConv
import torch.nn.functional as F
class GNN(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, 128)
        self.conv2 = SAGEConv(128, out_channels)
    def forward(self, x, edge_index) :
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x

class Classifier(torch.nn.Module):
    def forward(self, x_from, x_to,):
        return (x_from * x_to).sum(dim=-1)

class Model(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.gnn = GNN(in_channels, out_channels)
        self.classifier = Classifier()
    def forward(self, data):
        x_out = self.gnn(data.x, data.edge_label_index)
        pred = self.classifier(
            x_out[data.edge_label_index[0]], ## 边的起始点。
            x_out[data.edge_label_index[-1]] ## 边的终结点。
        )
        return pred
        
model = Model(in_channels=data.num_features, out_channels=64)

In [6]:
# !pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.1.0+cu111.html
# import torch_geometric
# torch_geometric.typing.WITH_TORCH_SPARSE

In [7]:
import tqdm
import torch.nn.functional as F
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: '{device}'")
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
for epoch in range(1, 6):
    total_loss = total_examples = 0
    for sampled_data in tqdm.tqdm(train_loader):
        optimizer.zero_grad()
        sampled_data.to(device)
        pred = model(sampled_data)
        ground_truth = sampled_data.edge_label
        loss = F.binary_cross_entropy_with_logits(pred, ground_truth)
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * pred.numel()
        total_examples += pred.numel()
    print(f"Epoch: {epoch:03d}, Loss: {total_loss / total_examples:.4f}")

Device: 'cpu'


100%|███████████████████████████████████████████| 47/47 [00:00<00:00, 78.34it/s]


Epoch: 001, Loss: 0.6944


100%|███████████████████████████████████████████| 47/47 [00:00<00:00, 83.20it/s]


Epoch: 002, Loss: 0.6513


100%|███████████████████████████████████████████| 47/47 [00:00<00:00, 82.26it/s]


Epoch: 003, Loss: 0.6312


100%|███████████████████████████████████████████| 47/47 [00:00<00:00, 82.46it/s]


Epoch: 004, Loss: 0.6139


100%|███████████████████████████████████████████| 47/47 [00:00<00:00, 82.40it/s]

Epoch: 005, Loss: 0.5881





In [8]:
test_loader = LinkNeighborLoader(
    data=val_data,
    num_neighbors=[10, 5],
    neg_sampling_ratio=2.0,
    batch_size=128,
    shuffle=True,
)

In [9]:
from sklearn.metrics import roc_auc_score
preds = []
ground_truths = []
for sampled_data in tqdm.tqdm(test_loader):
    with torch.no_grad():
        sampled_data.to(device)
        preds.append(model(sampled_data))
        ground_truths.append(sampled_data.edge_label)
pred = torch.cat(preds, dim=0).cpu().numpy()
ground_truth = torch.cat(ground_truths, dim=0).cpu().numpy()
auc = roc_auc_score(ground_truth, pred)
print()
print(f"Validation AUC: {auc:.4f}")

100%|██████████████████████████████████████████| 66/66 [00:00<00:00, 123.04it/s]


Validation AUC: 0.6984





In [10]:
pred.shape

(25338,)

In [13]:
data

Data(x=[2708, 1433], edge_index=[2, 10556])

In [15]:
data

Data(x=[2708, 1433], edge_index=[2, 10556])

In [14]:
model(data)

AttributeError: 'GlobalStorage' object has no attribute 'edge_label_index'