In [1]:
import dgl

In [5]:
from dgl.contrib.data import load_data
import numpy as np
data = load_data(dataset='aifb')

Loading dataset aifb
Number of nodes:  8285
Number of edges:  66371
Number of relations:  91
Number of classes:  4
removing nodes that are more than 3 hops away


In [4]:
"""This is a demo for graph classification with dgl where we have a
synthetic dataset consisting of cycle_graphs and star_graphs and
we want to perform a binary classification."""
import dgl
import dgl.function as fn
import networkx as nx
import random
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Generate a synthetic dataset
ds = []
for i in range(5):
    num_nodes = 5
    cycle = dgl.DGLGraph(nx.cycle_graph(num_nodes))
    cycle.ndata['h'] = th.ones(num_nodes, 1)
    ds.append((cycle, 0))

    star = dgl.DGLGraph(nx.star_graph(num_nodes - 1))
    star.ndata['h'] = th.ones(num_nodes, 1)
    ds.append((star, 1))

random.shuffle(ds)
g_list = [data[0] for data in ds]
labels = th.tensor([data[1] for data in ds]).float().view(-1, 1)

# Model and optimizer
weight1 = th.randn((1, 16), requires_grad=True)
weight2 = th.randn((16, 1), requires_grad=True)
optimizer = optim.Adam([weight1, weight2], lr=0.01)
loss_func = nn.BCELoss()

# Configure message passing. With the message func and reduce func defined
# below, the updated node feature will simply be node degree.
msg_func = fn.copy_src(src='h', out='m')
reduce_func = fn.sum(msg='m', out='h')

# Training
for i in range(100):
    bg = dgl.batch(g_list)
    # Perform message passing
    bg.update_all(msg_func, reduce_func)
    # Readout and get graph features for the 10 graphs.
    bg_h = dgl.mean_nodes(bg, 'h')
    logits = F.relu(th.mm(bg_h, weight1))
    logits = th.mm(logits, weight2)
    prediction = th.sigmoid(logits)

    optimizer.zero_grad()
    loss = loss_func(prediction, labels)
    print('The prediction loss is {}'.format(loss))
    loss.backward()
    optimizer.step()

The prediction loss is 1.656454086303711
The prediction loss is 1.520641803741455
The prediction loss is 1.3898983001708984
The prediction loss is 1.2655490636825562
The prediction loss is 1.1491563320159912
The prediction loss is 1.0424907207489014
The prediction loss is 0.9474363327026367
The prediction loss is 0.8658297657966614
The prediction loss is 0.7992105484008789
The prediction loss is 0.7485160827636719
The prediction loss is 0.7137940526008606
The prediction loss is 0.6940209865570068
The prediction loss is 0.6871294379234314
The prediction loss is 0.6902592182159424
The prediction loss is 0.7001688480377197
The prediction loss is 0.7136803269386292
The prediction loss is 0.7280322909355164
The prediction loss is 0.7410855293273926
The prediction loss is 0.7513815760612488
The prediction loss is 0.758097767829895
The prediction loss is 0.7609473466873169
The prediction loss is 0.760063111782074
The prediction loss is 0.7558855414390564
The prediction loss is 0.7490631937980