In [1]:
import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import dgl.data
import torch
import torch.nn as nn
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Generate a synthetic dataset with 10000 graphs, ranging from 10 to 500 nodes.
dataset = dgl.data.GINDataset("PROTEINS", self_loop=True) 
#NOTE: self_loop指可以添加自环，即节点到自身的边

In [3]:
print("Node feature dimensionality:", dataset.dim_nfeats)
print("Number of graph categories:", dataset.gclasses)


from dgl.dataloading import GraphDataLoader

Node feature dimensionality: 3
Number of graph categories: 2


## 图分类任务
- 对于一个给定的图，预测一个图的属性（或标签）
- 数据集往往是许多的图，每个图都有不同的属性
- 需要一个操作来把各节点上的特征聚合为一个图整体特征 (**readout**)

类似torch中的DataLoader，dgl提供了GraphDataLoader来加载图数据集

In [None]:
from torch.utils.data.sampler import SubsetRandomSampler

num_examples = len(dataset)
num_train = int(num_examples * 0.8)

'''
训练集与测试集8：2
'''

# 8000个训练集，2000个测试集
train_sampler = SubsetRandomSampler(torch.arange(num_train))
test_sampler = SubsetRandomSampler(torch.arange(num_train, num_examples))

train_dataloader = GraphDataLoader(
    dataset, sampler=train_sampler, batch_size=5, drop_last=False
) #NOTE: 支持用torch中的sampler
test_dataloader = GraphDataLoader(
    dataset, sampler=test_sampler, batch_size=5, drop_last=False
)

In [None]:
it = iter(train_dataloader)
batch = next(it)
batched_graph, labels = batch
#NOTE: 返回的是一个batch_graph数据集，数据集包含两个部分，第一个是图对象，第二个是标签

'''
把一个batch的图合并为一个大图dgl.graph对象
'''

print('大图对象：',batched_graph)
print('标签Batch：',labels)


'''
下面来展示每个batch的大图对象确实包含了5个图
'''
print(
    "该batch中各图的节点数:",
    batched_graph.batch_num_nodes()
    )

print(
    "该batch中各图的边数:",
    batched_graph.batch_num_edges()
    )

'''
用unbatch方法将大图对象还原为原始的5个图对象
'''
graphs = dgl.unbatch(batched_graph)
print("The original graphs in the minibatch:")
print(graphs)

大图对象： Graph(num_nodes=146, num_edges=740,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={})
标签Batch： tensor([0, 0, 0, 0, 0])
该batch中各图的节点数: tensor([18, 20, 47, 34, 27])
该batch中各图的边数: tensor([ 96, 104, 227, 166, 147])
The original graphs in the minibatch:
[Graph(num_nodes=18, num_edges=96,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={}), Graph(num_nodes=20, num_edges=104,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={}), Graph(num_nodes=47, num_edges=227,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={}), Graph(num_nodes=34, num_edges=166,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,)

In [22]:
'''
定义一个最终用mean_nodes方法来readout的GCN模型
'''

from dgl.nn import GraphConv


class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        g.ndata["h"] = h
        return dgl.mean_nodes(g, "h")

In [None]:
# Create the model with given dimensions
model = GCN(dataset.dim_nfeats, 16, dataset.gclasses)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(20):
    # NOTE: data是有两个元素，所以迭代时要用两个变量接收
    for batched_graph, labels in train_dataloader:
        pred = model(batched_graph, batched_graph.ndata["attr"].float()) # 返回一个graph的embedding
        loss = F.cross_entropy(pred, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

num_correct = 0
num_tests = 0
for batched_graph, labels in test_dataloader:
    pred = model(batched_graph, batched_graph.ndata["attr"].float())
    num_correct += (pred.argmax(1) == labels).sum().item()
    num_tests += len(labels)

print("Test accuracy:", num_correct / num_tests)

Test accuracy: 0.1031390134529148
