In [1]:
import os

import dgl
import dgl.data
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
dataset = dgl.data.GINDataset("PROTEINS", self_loop=True)

In [3]:
print("Node feature dimensionality:", dataset.dim_nfeats)
print("Number of graph categories:", dataset.gclasses)
print(len(dataset))
print(type(dataset))
for g, label in dataset:
    print(g, "\n", label.item(), "\n-------------------")

Node feature dimensionality: 3
Number of graph categories: 2
1113
<class 'dgl.data.gindt.GINDataset'>
Graph(num_nodes=42, num_edges=204,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={}) 
 0 
-------------------
Graph(num_nodes=27, num_edges=119,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={}) 
 0 
-------------------
Graph(num_nodes=10, num_edges=44,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={}) 
 0 
-------------------
Graph(num_nodes=24, num_edges=116,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={}) 
 0 
-------------------
Graph(num_nodes=11, num_edges=53,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64)

In [4]:
print(g.ndata["label"].shape)
print(g.ndata["attr"].shape)
g.ndata["label"]

torch.Size([40])
torch.Size([40, 3])


tensor([0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

In [5]:
from torch.utils.data.sampler import SubsetRandomSampler
from dgl.dataloading import GraphDataLoader

num_examples = len(dataset)
num_test = int(num_examples * 0.2)
num_train = num_examples - num_test

test_sampler = SubsetRandomSampler(torch.arange(num_test))
train_sampler = SubsetRandomSampler(torch.arange(num_test, num_examples))

print(f"test size: {len(test_sampler)}")
print(f"train size: {len(train_sampler)}")

test size: 222
train size: 891


In [145]:
# 与pytorch类似
train_dataloader = GraphDataLoader(dataset, sampler=train_sampler, batch_size=32)
test_dataloader = GraphDataLoader(dataset, sampler=test_sampler, batch_size=32)

In [7]:
batch = next(iter(train_dataloader))
print(batch)
# 直接next-iter得到的图是一坨batch
print("-----------------------------------------------")
batched_graph, labels = batch
print(batched_graph)
print("-----------------------------------------------")

# Recover the original graph elements from the minibatch
graphs = dgl.unbatch(batched_graph)
cnt_nodes = 0
cnt_edges = 0
for i, g in enumerate(graphs):
    print(g, "   label:", labels[i].item())
    cnt_nodes += g.num_nodes()
    cnt_edges += g.num_edges()

[Graph(num_nodes=515, num_edges=2383,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={}), tensor([1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0])]
-----------------------------------------------
Graph(num_nodes=515, num_edges=2383,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={})
-----------------------------------------------
Graph(num_nodes=51, num_edges=229,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={})    label: 1
Graph(num_nodes=46, num_edges=218,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={})    label: 1
Graph(num_nodes=15, num_edges=71,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(

In [8]:
print(cnt_nodes, cnt_edges)

515 2383


In [106]:
from dgl.nn import GraphConv


class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)

        g.ndata["out"] = h
        return dgl.mean_nodes(g, "out")

In [146]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [147]:
model = GCN(dataset.dim_nfeats, 16, dataset.num_classes).to(device)


In [157]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001,)


In [151]:
def train(model, epoch=50):
    model.train()
    for e in range(epoch):
        for batched_graph, labels in train_dataloader:
            batched_graph = batched_graph.to(device)
            labels = labels.to(device)
            logits = model(batched_graph, batched_graph.ndata["attr"])
            loss = F.cross_entropy(logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
        if e % 10 == 0:
            train_acc = compute_acc(model, train_dataloader)
            test_acc = compute_acc(model, test_dataloader)
            print(f"{e}: loss: {loss.item():.6f}  train acc: {train_acc:.4f},  test acc: {test_acc:.4f}")

In [152]:
def compute_acc(model, dataloader):
    model.eval()
    with torch.no_grad():
        num_correct = 0
        num_tests = 0
        for batched_graph, labels in dataloader:
            batched_graph = batched_graph.to(device)
            labels = labels.to(device)
            pred = model(batched_graph, batched_graph.ndata["attr"]).argmax(1)
            num_correct += (pred == labels).sum()
            num_tests += len(labels)
        return (num_correct / num_tests).item()

In [158]:
train(model, 100)

0: loss: 0.594593  train acc: 0.6577,  test acc: 0.7027
10: loss: 0.625737  train acc: 0.6566,  test acc: 0.7027
20: loss: 0.757888  train acc: 0.6554,  test acc: 0.6982
30: loss: 0.497171  train acc: 0.6554,  test acc: 0.6982
40: loss: 0.534115  train acc: 0.6554,  test acc: 0.7027
50: loss: 0.595600  train acc: 0.6566,  test acc: 0.7072
60: loss: 0.741253  train acc: 0.6566,  test acc: 0.6982
70: loss: 0.549053  train acc: 0.6577,  test acc: 0.7027
80: loss: 0.562436  train acc: 0.6577,  test acc: 0.7027
90: loss: 0.611415  train acc: 0.6554,  test acc: 0.7027
