
# Graph Classification with DGL

Here we demonstrate how to use DGL for graph classification tasks. The dataset we use here is ENZYMES, a public benchmark from the TU Dortmund graph dataset collection (loaded below with `TUDataset`). It contains 600 graphs, each representing a protein, and every graph is labeled with one of 6 enzyme classes. Each node carries an 18-dimensional attribute vector, which we use as the input feature.

In [51]:
import dgl
from dgl.data import TUDataset
from dgl.data.utils import split_dataset

In [52]:
import numpy as np

In [53]:
import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
import torch.nn.functional as F
from tqdm import tqdm
from dgl.nn.pytorch import conv

In [54]:
dgl.__version__

'0.4'

## Load Dataset

<img src="./asset/enzymes.png" width="500"/>

In [55]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [56]:
# Load the ENZYMES graph-classification benchmark through DGL's TUDataset wrapper.
dataset = TUDataset("ENZYMES")

In [57]:
dataset[0][0]

DGLGraph(num_nodes=37, num_edges=168,
         ndata_schemes={'node_labels': Scheme(shape=(1,), dtype=torch.int64), 'node_attr': Scheme(shape=(18,), dtype=torch.int64)}
         edata_schemes={})

In [58]:
# Store the graph labels as a torch tensor.
# NOTE(review): presumably this makes each per-graph label a tensor so that
# `torch.stack` in the collate function can combine them -- confirm.
dataset.graph_labels=torch.tensor(dataset.graph_labels)

In [59]:
# Cast every graph's 'node_attr' feature to float so it can be fed to the
# graph convolution layers.
for idx in range(len(dataset)):
    g = dataset[idx][0]
    g.ndata['node_attr'] = g.ndata['node_attr'].float()

### Split dataset into train and val

In [60]:
# 80% / 20% train/validation split, shuffled with a fixed random state.
trainset, valset = split_dataset(dataset, [0.8, 0.2], shuffle=True, random_state=42)

In [61]:
# Inspect the first (graph, label) sample of the dataset.
graph, label = dataset[0]
print(graph, label, sep="\n")

DGLGraph(num_nodes=37, num_edges=168,
         ndata_schemes={'node_labels': Scheme(shape=(1,), dtype=torch.int64), 'node_attr': Scheme(shape=(18,), dtype=torch.float32)}
         edata_schemes={})
tensor([5])


## Prepare Dataloader

DGL can batch multiple small graphs together into a single large graph to accelerate computation. Details of batching can be found [here](https://docs.dgl.ai/tutorials/basics/4_batch.html).

<img src="https://s3.us-east-2.amazonaws.com/dgl.ai/tutorial/batch/batch.png" width="500"/>

In [62]:
def collate_molgraphs_for_classification(data):
    """Collate a list of (graph, label) samples into a single batch.

    Parameters
    ----------
    data : list of (graph, label) pairs
        Samples drawn from the dataset; every label must be a tensor so the
        labels can be stacked.

    Returns
    -------
    bg
        The result of ``dgl.batch`` applied to the list of graphs.
    labels : torch.Tensor
        The labels stacked along a new leading dimension.
    """
    # Transpose the list of pairs into one sequence of graphs and one of labels.
    graph_seq, label_seq = zip(*data)
    bg = dgl.batch(list(graph_seq))
    labels = torch.stack(list(label_seq), dim=0)
    return bg, labels

# Both loaders use the same batch size and the graph-aware collate function.
loader_kwargs = dict(batch_size=512,
                     collate_fn=collate_molgraphs_for_classification)
train_loader = DataLoader(trainset, **loader_kwargs)
val_loader = DataLoader(valset, **loader_kwargs)

## Prepare Model and Optimizer

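Here we use a three-layer Graph Convolutional Network to classify the graphs: the node representations produced by the last layer are summed over each graph and passed to a linear classifier. The source code of a similar classifier from the DGL model zoo can be found [here](https://github.com/dmlc/dgl/blob/master/python/dgl/model_zoo/chem/classifiers.py#L111).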

In [63]:
class GCNModel(nn.Module):
    """Graph classifier: three GraphConv layers, sum readout, linear head.

    Parameters
    ----------
    in_feats : int
        Size of the input node features.
    n_hidden : int
        Output size of every graph convolution layer.
    out_feats : int
        Size of the classifier output (one logit per class).
    """

    def __init__(self,
                 in_feats,
                 n_hidden,
                 out_feats):
        super().__init__()
        # The first layer maps the input features to the hidden size; the two
        # following layers keep the hidden size.
        input_sizes = [in_feats, n_hidden, n_hidden]
        self.layers = nn.ModuleList([
            conv.GraphConv(size, n_hidden, activation=F.relu)
            for size in input_sizes
        ])

        self.classifier = nn.Linear(n_hidden, out_feats)

    def forward(self, g, features):
        """Return the classifier output for every graph batched in ``g``.

        ``features`` holds the input node features of ``g``.
        """
        node_repr = features
        for gcn_layer in self.layers:
            node_repr = gcn_layer(g, node_repr)
        # Write the temporary 'feat' field inside a local scope so that it is
        # not left on the caller's graph.
        with g.local_scope():
            g.ndata['feat'] = node_repr
            graph_repr = dgl.sum_nodes(g, 'feat')
        return self.classifier(graph_repr)

In [64]:
# 18 input features (the size of 'node_attr') and 6 output logits.
model = GCNModel(in_feats=18, n_hidden=64, out_feats=6)
model = model.to(device)
# Cross-entropy over the logits; Adam with its default hyper-parameters.
loss_criterion = nn.CrossEntropyLoss()
optimizer = Adam(params=model.parameters())
print(model)

GCNModel(
  (layers): ModuleList(
    (0): GraphConv(in=18, out=64, normalization=True, activation=<function relu at 0x7f59a57ac0d0>)
    (1): GraphConv(in=64, out=64, normalization=True, activation=<function relu at 0x7f59a57ac0d0>)
    (2): GraphConv(in=64, out=64, normalization=True, activation=<function relu at 0x7f59a57ac0d0>)
  )
  (classifier): Linear(in_features=64, out_features=6, bias=True)
)


### Training

In [65]:
# NOTE: `logits` is only assigned inside the training and inference loops
# further down; running this cell before them raises NameError.
logits.shape

torch.Size([120, 6])

In [66]:
# NOTE: `labels` is only assigned inside the training and inference loops
# further down; running this cell before them raises NameError.
labels.shape

torch.Size([120])

In [67]:
# Train for a fixed number of epochs, printing the mean batch loss and the
# training accuracy after every epoch.
epochs = 500
model.train()
for epoch in range(epochs):
    batch_losses = []
    n_correct = 0
    n_seen = 0
    for bg, labels in train_loader:
        # Take the node attributes off the batched graph and use them as input.
        atom_feats = bg.ndata.pop('node_attr').float().to(device)
        # Drop the trailing dimension of the stacked labels.
        labels = labels.to(device).squeeze(-1)

        # Forward pass and one optimisation step.
        logits = model(bg, atom_feats)
        loss = loss_criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Bookkeeping for the per-epoch report.
        predictions = logits.argmax(1)
        n_correct += (predictions == labels.long()).float().sum().item()
        n_seen += len(labels)
        batch_losses.append(loss.item())
    print("Epoch {:05d} | Loss: {:.4f} | Accuracy: {:.4f}".format(epoch, np.mean(batch_losses), n_correct/n_seen))

Epoch 00000 | Loss: 128.8334 | Accuracy: 0.1729
Epoch 00001 | Loss: 97.0797 | Accuracy: 0.1729
Epoch 00002 | Loss: 67.1622 | Accuracy: 0.1729
Epoch 00003 | Loss: 53.3588 | Accuracy: 0.1750
Epoch 00004 | Loss: 43.3624 | Accuracy: 0.1750
Epoch 00005 | Loss: 29.1658 | Accuracy: 0.2062
Epoch 00006 | Loss: 14.4205 | Accuracy: 0.1562
Epoch 00007 | Loss: 13.5954 | Accuracy: 0.2104
Epoch 00008 | Loss: 14.7965 | Accuracy: 0.1958
Epoch 00009 | Loss: 18.5140 | Accuracy: 0.2250
Epoch 00010 | Loss: 20.7691 | Accuracy: 0.1688
Epoch 00011 | Loss: 21.9511 | Accuracy: 0.1625
Epoch 00012 | Loss: 21.5492 | Accuracy: 0.1771
Epoch 00013 | Loss: 19.2594 | Accuracy: 0.1708
Epoch 00014 | Loss: 17.2495 | Accuracy: 0.2167
Epoch 00015 | Loss: 14.5592 | Accuracy: 0.2125
Epoch 00016 | Loss: 14.0101 | Accuracy: 0.1938
Epoch 00017 | Loss: 14.0666 | Accuracy: 0.1917
Epoch 00018 | Loss: 12.1471 | Accuracy: 0.1875
Epoch 00019 | Loss: 13.1340 | Accuracy: 0.1979
Epoch 00020 | Loss: 13.1660 | Accuracy: 0.1792
Epoch 00021 

Epoch 00180 | Loss: 1.5261 | Accuracy: 0.3688
Epoch 00181 | Loss: 1.5203 | Accuracy: 0.3812
Epoch 00182 | Loss: 1.5212 | Accuracy: 0.4000
Epoch 00183 | Loss: 1.5202 | Accuracy: 0.3729
Epoch 00184 | Loss: 1.5155 | Accuracy: 0.3604
Epoch 00185 | Loss: 1.5127 | Accuracy: 0.3896
Epoch 00186 | Loss: 1.5100 | Accuracy: 0.3688
Epoch 00187 | Loss: 1.5126 | Accuracy: 0.3771
Epoch 00188 | Loss: 1.5066 | Accuracy: 0.3896
Epoch 00189 | Loss: 1.5051 | Accuracy: 0.3812
Epoch 00190 | Loss: 1.5020 | Accuracy: 0.4000
Epoch 00191 | Loss: 1.4996 | Accuracy: 0.3896
Epoch 00192 | Loss: 1.5005 | Accuracy: 0.3812
Epoch 00193 | Loss: 1.4912 | Accuracy: 0.4062
Epoch 00194 | Loss: 1.4918 | Accuracy: 0.4042
Epoch 00195 | Loss: 1.4924 | Accuracy: 0.3979
Epoch 00196 | Loss: 1.4865 | Accuracy: 0.4167
Epoch 00197 | Loss: 1.4818 | Accuracy: 0.4229
Epoch 00198 | Loss: 1.4849 | Accuracy: 0.4125
Epoch 00199 | Loss: 1.4817 | Accuracy: 0.4188
Epoch 00200 | Loss: 1.4752 | Accuracy: 0.4146
Epoch 00201 | Loss: 1.4743 | Accur

Epoch 00360 | Loss: 1.3458 | Accuracy: 0.5042
Epoch 00361 | Loss: 1.3451 | Accuracy: 0.5083
Epoch 00362 | Loss: 1.3443 | Accuracy: 0.5062
Epoch 00363 | Loss: 1.3436 | Accuracy: 0.5083
Epoch 00364 | Loss: 1.3429 | Accuracy: 0.5042
Epoch 00365 | Loss: 1.3422 | Accuracy: 0.5021
Epoch 00366 | Loss: 1.3414 | Accuracy: 0.5021
Epoch 00367 | Loss: 1.3407 | Accuracy: 0.5021
Epoch 00368 | Loss: 1.3400 | Accuracy: 0.5062
Epoch 00369 | Loss: 1.3393 | Accuracy: 0.5042
Epoch 00370 | Loss: 1.3385 | Accuracy: 0.5083
Epoch 00371 | Loss: 1.3378 | Accuracy: 0.5104
Epoch 00372 | Loss: 1.3371 | Accuracy: 0.5104
Epoch 00373 | Loss: 1.3363 | Accuracy: 0.5083
Epoch 00374 | Loss: 1.3356 | Accuracy: 0.5083
Epoch 00375 | Loss: 1.3349 | Accuracy: 0.5083
Epoch 00376 | Loss: 1.3341 | Accuracy: 0.5083
Epoch 00377 | Loss: 1.3334 | Accuracy: 0.5083
Epoch 00378 | Loss: 1.3326 | Accuracy: 0.5104
Epoch 00379 | Loss: 1.3319 | Accuracy: 0.5125
Epoch 00380 | Loss: 1.3312 | Accuracy: 0.5125
Epoch 00381 | Loss: 1.3304 | Accur

### Inference

In [68]:
# Measure classification accuracy on the validation set.
model.eval()
true_samples = 0
num_samples = 0
# Gradients are not needed for evaluation.
with torch.no_grad():
    for bg, labels in val_loader:
        # Same input handling as in the training loop: take the node
        # attributes off the batched graph and cast them to float.
        atom_feats = bg.ndata.pop('node_attr').float()
        atom_feats, labels = atom_feats.to(device), \
                                   labels.to(device).squeeze(-1)
        logits = model(bg, atom_feats)
        # Count the graphs whose highest-scoring class matches the label.
        num_samples += len(labels)
        true_samples += (logits.argmax(1) == labels.long()).float().sum().item()
print(true_samples/num_samples)

0.48333333333333334
