In [2]:
import numpy as np
%matplotlib inline

PML&DL Homework (optional)
=============================
Based on the Tutorial: https://docs.dgl.ai/en/0.4.x/tutorials/basics/4_batch.html 

Instructions for the assignment are given in tags: [PMLDL] ... [/PMLDL]

Graph classification is an important problem
with applications across many fields, such as bioinformatics, chemoinformatics, social
network analysis, urban computing, and cybersecurity. Applying graph neural
networks to this problem has been a popular approach recently, as the following research shows
(`Ying et al., 2018 <https://arxiv.org/abs/1806.08804>`_,
`Cangea et al., 2018 <https://arxiv.org/abs/1811.01287>`_,
`Knyazev et al., 2018 <https://arxiv.org/abs/1811.09595>`_,
`Bianchi et al., 2019 <https://arxiv.org/abs/1901.01343>`_,
`Liao et al., 2019 <https://arxiv.org/abs/1901.01484>`_,
`Gao et al., 2019 <https://openreview.net/forum?id=HJePRoAct7>`_).


Simple graph classification task
--------------------------------
In this tutorial, you learn how to perform batched graph classification
with DGL. The example task objective is to classify eight types of topologies shown here.

![](https://data.dgl.ai/tutorial/batch/dataset_overview.png)

Implement a synthetic dataset :class:`data.MiniGCDataset` in DGL. The dataset has eight 
different types of graphs and each class has the same number of graph samples.



In [4]:
from IPython.display import clear_output
from dgl import DGLHeteroGraph
# Used partially idea from: https://towardsdatascience.com/node2vec-embeddings-for-graph-data-32a866340fef
import dgl
from dgl.data import MiniGCDataset
from node2vec import Node2Vec

'''
[PMLDL]
Check the MiniGCDataset source code and implement a class that generated a vector of features for each node (e.g. using node2vec).
So, result would be a YourNameItGCDataset class. 
[/PMLDL]
'''


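class RavidaItGCDAtaset(MiniGCDataset):
    """Wraps a MiniGCDataset and adds node2vec features for every node."""

    def __init__(self, dataset):
        # The parent constructor is deliberately not called: the graphs come
        # from an existing dataset instead of being generated again.
        outs = []
        for graph, label in dataset:
            node2vec = Node2Vec(dgl.to_networkx(graph), dimensions=20, walk_length=16,
                                num_walks=100, workers=2)
            model = node2vec.fit(window=10, min_count=1, batch_words=4)
            # model.wv.vectors is ordered by word frequency, not by node id,
            # so look up each node explicitly (node2vec stores ids as strings).
            emb = np.stack([model.wv[str(n)] for n in range(graph.number_of_nodes())])
            outs.append((graph, emb, label))
        self.out = outs

    def __len__(self):
        # Required by DataLoader; the inherited __len__ relies on attributes
        # that are never set here.
        return len(self.out)

    def __getitem__(self, item):
        return self.out[item]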


import matplotlib.pyplot as plt
import networkx as nx

# A dataset with 80 samples, each graph is
# of size [10, 20]
dataset = MiniGCDataset(80, 10, 20)
graph, label = dataset[0]

gr = RavidaItGCDAtaset(dataset)
clear_output()
'''
[PMLDL]

Here YourNameItGCDataset dataset[0] should return three elements: 
graph, node_features, label = dataset[0]
node_features stores all node features of the graph

[/PMLDL]
'''

In [None]:
graphs, feat, labels = gr[0]
print(feat.shape)
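
The node features above come from node2vec, which generates random walks over the graph and feeds them to a skip-gram model. As a rough illustration of the walk-generation step only, here is a minimal sketch using uniform walks (the special case where node2vec's return and in-out parameters are both 1). The function name `uniform_random_walks` and the toy graph are assumptions made for this example; they are not part of the `node2vec` package.

```python
import random


def uniform_random_walks(adjacency, walk_length, num_walks, seed=0):
    """Generate uniform random walks over an adjacency-list graph.

    adjacency: dict mapping node -> list of neighbour nodes.
    Returns a list of walks; each walk is a list of node ids.
    """
    rng = random.Random(seed)
    walks = []
    for _ in range(num_walks):
        # Start one walk from every node, as node2vec-style samplers do.
        for start in adjacency:
            walk = [start]
            while len(walk) < walk_length:
                neighbours = adjacency[walk[-1]]
                if not neighbours:  # dead end: stop this walk early
                    break
                walk.append(rng.choice(neighbours))
            walks.append(walk)
    return walks


# A 4-node cycle 0-1-2-3-0, stored as an undirected adjacency list.
cycle = {0: [1, 3], 1: [0, 2], 2: [1, 3], 3: [2, 0]}
walks = uniform_random_walks(cycle, walk_length=5, num_walks=2)
print(len(walks), walks[0])
```

Each walk plays the role of a "sentence" whose "words" are node ids, which is why the embedding lookup is keyed by node id rather than by row position.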

Form a graph mini-batch
-----------------------
To train neural networks efficiently, a common practice is to batch
multiple samples together to form a mini-batch. Batching fixed-shaped tensor
inputs is common. For example, batching two images of size 28 x 28
gives a tensor of shape 2 x 28 x 28. By contrast, batching graph inputs
has two challenges:

* Graphs are sparse.
* Graphs can vary in size, for example in the number of nodes and edges.

To address this, DGL provides a :func:`dgl.batch` API. It leverages the idea that
a batch of graphs can be viewed as one large graph made of many disjoint
connected components. Below is a visualization that gives the general idea.

![](https://data.dgl.ai/tutorial/batch/batch.png)
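
The idea of treating a batch as one large graph with disjoint components can be sketched without DGL. The helper below is a hypothetical illustration, not DGL's implementation: it assumes each graph is given as a node count plus an edge list, shifts node ids by a running offset, and records which input graph each batched node came from.

```python
def batch_edge_lists(graphs):
    """Merge graphs into one graph with disjoint components.

    graphs: list of (num_nodes, edges), where edges is a list of (src, dst).
    Returns (total_nodes, batched_edges, graph_ids); graph_ids[i] is the
    index of the input graph that batched node i belongs to.
    """
    offset = 0
    batched_edges = []
    graph_ids = []
    for graph_index, (num_nodes, edges) in enumerate(graphs):
        # Shift node ids so they do not collide with earlier graphs.
        batched_edges.extend((src + offset, dst + offset) for src, dst in edges)
        graph_ids.extend([graph_index] * num_nodes)
        offset += num_nodes
    return offset, batched_edges, graph_ids


triangle = (3, [(0, 1), (1, 2), (2, 0)])
path = (2, [(0, 1)])
total_nodes, edges, graph_ids = batch_edge_lists([triangle, path])
print(total_nodes)  # 5
print(edges)        # [(0, 1), (1, 2), (2, 0), (3, 4)]
print(graph_ids)    # [0, 0, 0, 1, 1]
```

Because no edge crosses between components, message passing on the merged graph never mixes information from different input graphs.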


Define the following ``collate`` function to form a mini-batch from a given
list of graph and label pairs.



In [7]:
import dgl
import torch
import numpy as np


def collate(samples):
    # The input `samples` is a list of pairs
    #  (graph, label).
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    return batched_graph, torch.tensor(labels)


def my_collate(samples):
    # The input `samples` is a list of triples
    #  (graph, node_features, label).
    graphs, features, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    # Stack the per-graph feature matrices in the same order as the batched nodes.
    batched_features = torch.tensor(np.vstack(features))
    return batched_graph, batched_features, torch.tensor(labels)



'''
[PMLDL]

Here you need to implement a new version of the collate(...) function.
New version should collate (stack) the graph, features and labels in a mini-batch.

[/PMLDL]
'''

The return type of :func:`dgl.batch` is still a graph. In the same way, 
a batch of tensors is still a tensor. This means that any code that works
for one graph immediately works for a batch of graphs. More importantly,
because DGL processes messages on all nodes and edges in parallel, this greatly
improves efficiency.

Graph classifier
----------------
Graph classification proceeds as follows.

![](https://data.dgl.ai/tutorial/batch/graph_classifier.png)


From a batch of graphs, perform message passing and graph convolution
for nodes to communicate with others. After message passing, compute a
tensor for graph representation from node (and edge) attributes. This step might 
be called readout or aggregation. Finally, the graph 
representations are fed into a classifier $g$ to predict the graph labels.

Graph convolution layers can be found in the ``dgl.nn.<backend>`` submodule.
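
To make the message-passing step concrete, here is a minimal NumPy sketch of a single graph convolution in one common formulation, $\mathrm{ReLU}(D^{-1/2} A D^{-1/2} H W)$. This is an illustration under stated assumptions: a dense adjacency matrix, no bias term, and symmetric normalization. The exact normalization applied by DGL's `GraphConv` depends on the library version and its arguments, so this is not a reimplementation of it.

```python
import numpy as np


def graph_conv(adj, features, weight):
    """One graph convolution: ReLU(D^-1/2 A D^-1/2 H W)."""
    degree = adj.sum(axis=1)
    inv_sqrt = np.zeros_like(degree, dtype=float)
    nonzero = degree > 0
    inv_sqrt[nonzero] = degree[nonzero] ** -0.5
    # Scale row i by 1/sqrt(d_i) and column j by 1/sqrt(d_j).
    norm_adj = inv_sqrt[:, None] * adj * inv_sqrt[None, :]
    return np.maximum(norm_adj @ features @ weight, 0.0)


# Path graph 0-1-2 with self-loops, as a dense adjacency matrix.
adj = np.array([[1.0, 1.0, 0.0],
                [1.0, 1.0, 1.0],
                [0.0, 1.0, 1.0]])
features = np.eye(3)      # one-hot input features, one row per node
weight = np.ones((3, 2))  # toy weight matrix: 3 input dims -> 2 output dims
out = graph_conv(adj, features, weight)
print(out.shape)  # (3, 2)
```

Stacking two such layers lets each node aggregate information from neighbours up to two hops away.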



In [None]:
from dgl.nn.pytorch import GraphConv

Readout and classification
--------------------------
For this demonstration, consider initial node features to be their degrees.
After two rounds of graph convolution, perform a graph readout by averaging
over all node features for each graph in the batch.

\begin{align}h_g=\frac{1}{|\mathcal{V}|}\sum_{v\in\mathcal{V}}h_{v}\end{align}

In DGL, :func:`dgl.mean_nodes` handles this task for a batch of
graphs with variable size. You then feed the graph representations into a
classifier with one linear layer to obtain pre-softmax logits.
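
For a batched graph, the mean readout is computed separately for each component. The NumPy sketch below shows the arithmetic; the helper name `mean_readout` and the `graph_ids` vector (the index of the graph each node belongs to) are assumptions made for illustration, and every graph is assumed to have at least one node.

```python
import numpy as np


def mean_readout(node_features, graph_ids, num_graphs):
    """Average node features per graph: h_g = (1/|V|) * sum_v h_v."""
    sums = np.zeros((num_graphs, node_features.shape[1]))
    # Unbuffered add so that repeated graph ids accumulate correctly.
    np.add.at(sums, graph_ids, node_features)
    counts = np.bincount(graph_ids, minlength=num_graphs)
    return sums / counts[:, None]


# Five nodes in total: three belong to graph 0, two to graph 1.
node_features = np.array([[1.0, 2.0],
                          [3.0, 4.0],
                          [5.0, 6.0],
                          [10.0, 20.0],
                          [30.0, 40.0]])
graph_ids = np.array([0, 0, 0, 1, 1])
hg = mean_readout(node_features, graph_ids, num_graphs=2)
print(hg)  # [[ 3.  4.] [20. 30.]]
```

The result has one row per graph regardless of how many nodes each graph contains, which is what the linear classifier expects.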



In [6]:
import torch.nn as nn
import torch.nn.functional as F


class Classifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes):
        super(Classifier, self).__init__()
        self.conv1 = GraphConv(in_dim, hidden_dim)
        self.conv2 = GraphConv(hidden_dim, hidden_dim)
        self.classify = nn.Linear(hidden_dim, n_classes)

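    def forward(self, g, h):
        # `h` holds the input node features, one row per node of the batched graph.
        # The original tutorial used the node degree as the initial feature instead
        # (for undirected graphs, the in-degree is the same as the out-degree):

        # h = g.in_degrees().view(-1, 1).float()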

        '''
        [PMLDL]
        This call: h = g.in_degrees().view(-1, 1).float()
        feeds the list of node degrees into the input of the GCN (so, in_dim == 1).

        Here you need to change the code to feed features of nodes into input layer.
        Thus, in_dim will be the dimension of a feature-vector (say, 50).
        In the lecture slides such vector for a node v was called x_v

        [/PMLDL]
        '''

        # Perform graph convolution and activation function.
        h = F.relu(self.conv1(g, h))
        h = F.relu(self.conv2(g, h))
        g.ndata['h'] = h
        # Calculate graph representation by averaging all the node representations.
        hg = dgl.mean_nodes(g, 'h')
        return self.classify(hg)

Setup and training
------------------
Create a synthetic dataset of $400$ graphs with $10$ ~
$20$ nodes. $320$ graphs constitute a training set and
$80$ graphs constitute a test set.



In [8]:
import torch.optim as optim
from torch.utils.data import DataLoader

# Create training and test sets.
'''
[PMLDL]

Here calls to the YourNameItGCDataset class

[/PMLDL]
'''
trainset = MiniGCDataset(320, 10, 20)
testset = MiniGCDataset(80, 10, 20)

In [None]:
test = RavidaItGCDAtaset(testset)
train  = RavidaItGCDAtaset(trainset)

clear_output()

# Use PyTorch's DataLoader and the collate function
# defined before.
'''
[PMLDL]

Here pass new version of collate()

[/PMLDL]
'''
# Pass the function itself; calling it here would raise a TypeError.
data_loader = DataLoader(train, batch_size=32, shuffle=True,
                         collate_fn=my_collate)

# Create model
'''
[PMLDL]

Here pass new dimension of input (say, 50) as the first parameter.

[/PMLDL]
'''

# in_dim must match the node2vec embedding size (dimensions=20).
model = Classifier(20, 256, train.num_classes)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
model.train()

epoch_losses = []
for epoch in range(80):
    epoch_loss = 0
    for iter, (bg, feat, label) in enumerate(data_loader):
        prediction = model(bg, feat)
        loss = loss_func(prediction, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
    epoch_loss /= (iter + 1)
    print('Epoch {}, loss {:.4f}'.format(epoch, epoch_loss))
    epoch_losses.append(epoch_loss)

The learning curve of a run is presented below.



In [None]:
plt.title('cross entropy averaged over minibatches')
plt.plot(epoch_losses)
plt.show()

The trained model is evaluated on the test set created above. Training time
is deliberately kept short in this tutorial, so a longer run can reach a higher
accuracy ($80$ % ~ $90$ %) than the values printed below.



In [None]:
model.eval()
# Convert a list of tuples to two lists
'''
[PMLDL]

Slightly modify this code to evaluate results, as batch should contain three elements

[/PMLDL]
'''
# Iterate over `test` (graph, features, label), not the raw `testset` (graph, label).
test_X_graph, test_X_features, test_Y = map(list, zip(*test))
test_bg = dgl.batch(test_X_graph)
test_feat = torch.tensor(np.vstack(test_X_features))
test_Y = torch.tensor(test_Y).float().view(-1, 1)

probs_Y = torch.softmax(model(test_bg, test_feat), 1)
sampled_Y = torch.multinomial(probs_Y, 1)
argmax_Y = torch.max(probs_Y, 1)[1].view(-1, 1)
print('Accuracy of sampled predictions on the test set: {:.4f}%'.format(
    (test_Y == sampled_Y.float()).sum().item() / len(test_Y) * 100))
print('Accuracy of argmax predictions on the test set: {:.4f}%'.format(
    (test_Y == argmax_Y.float()).sum().item() / len(test_Y) * 100))
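
The two accuracy figures differ in how a class is picked from the softmax output: the argmax prediction always takes the most probable class, while the sampled prediction draws a class at random according to the predicted probabilities. The NumPy sketch below illustrates the difference on made-up logits and labels; all values are invented for this example and are not results from the model above.

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


# Made-up logits for 4 graphs and 3 classes, with made-up true labels.
logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0],
                   [1.5, 1.4, 0.0],
                   [-0.5, 2.5, 0.3]])
labels = np.array([0, 2, 1, 1])

probs = softmax(logits)

# Argmax prediction: deterministic, always the most probable class.
argmax_pred = probs.argmax(axis=1)
argmax_acc = (argmax_pred == labels).mean() * 100

# Sampled prediction: one random draw per graph from its class distribution.
rng = np.random.default_rng(0)
sampled_pred = np.array([rng.choice(len(p), p=p) for p in probs])
sampled_acc = (sampled_pred == labels).mean() * 100

print('argmax accuracy: {:.1f}%'.format(argmax_acc))
print('sampled accuracy: {:.1f}%'.format(sampled_acc))
```

Sampled accuracy is usually the lower of the two when the model is uncertain, because a draw can land on a class other than the most probable one.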



What's next?
------------
Graph classification with graph neural networks is still a new field,
with many discoveries yet to be made. The work requires 
mapping different graphs to different embeddings, while preserving
their structural similarity in the embedding space. To learn more about it, see 
`How Powerful Are Graph Neural Networks? <https://arxiv.org/abs/1810.00826>`_, a research paper
published at the International Conference on Learning Representations 2019.

For more examples about batched graph processing, see the following:

* Tutorials for `Tree LSTM <https://docs.dgl.ai/tutorials/models/2_small_graph/3_tree-lstm.html>`_ and `Deep Generative Models of Graphs <https://docs.dgl.ai/tutorials/models/3_generative_model/5_dgmg.html>`_
* An example implementation of `Junction Tree VAE <https://github.com/dmlc/dgl/tree/master/examples/pytorch/jtnn>`_

