
# Graph Classification with DGL

Here we demonstrate how to use DGL for graph classification. The dataset we use is ENZYMES, a benchmark from the TU Dortmund graph dataset collection that DGL loads through `TUDataset`. It contains 600 protein tertiary structures, each represented as a graph and labeled with one of 6 top-level enzyme classes. As loaded here, every node carries an 18-dimensional feature vector.

```python
import dgl
from dgl.data import TUDataset
from dgl.data.utils import split_dataset
from dgl.nn.pytorch import conv

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
```

## Load Dataset

<img src="./asset/enzymes.png" width="500"/>

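```python
device = "cpu"

# Download (on first use) and load the ENZYMES benchmark
dataset = TUDataset("ENZYMES")

# Store graph labels as a tensor so that each sample's label is a tensor
# that torch.stack can batch in the collate function
dataset.graph_labels = torch.as_tensor(dataset.graph_labels)

# Convert node features to float32, the dtype GraphConv expects
for i in range(len(dataset)):
    g = dataset[i][0]
    g.ndata['feat'] = torch.as_tensor(g.ndata['feat']).float()
```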

### Split dataset into train and val

```python
trainset, valset = split_dataset(dataset, [0.8, 0.2], shuffle=True, random_state=42)
```
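`split_dataset` turns the fractions into subset sizes and slices an (optionally shuffled) permutation of the indices. With ENZYMES' 600 graphs, `[0.8, 0.2]` gives 480 training and 120 validation graphs. The pure-Python sketch below illustrates that logic under our assumption about how sizes are rounded (the last split takes the remainder). `fraction_split` is a hypothetical helper, and because it shuffles with Python's `random` module rather than NumPy, it will not reproduce DGL's exact indices for `random_state=42`.

```python
import random

def fraction_split(num_items, fractions, seed=None):
    """Hypothetical helper: split indices 0..num_items-1 into chunks sized by
    fractions. The last chunk takes the remainder so every index is used once."""
    assert abs(sum(fractions) - 1.0) < 1e-8, "fractions must sum to 1"
    lengths = [int(num_items * f) for f in fractions]
    lengths[-1] = num_items - sum(lengths[:-1])
    indices = list(range(num_items))
    random.Random(seed).shuffle(indices)  # seeded shuffle for reproducibility
    splits, start = [], 0
    for length in lengths:
        splits.append(indices[start:start + length])
        start += length
    return splits

train_idx, val_idx = fraction_split(600, [0.8, 0.2], seed=42)
print(len(train_idx), len(val_idx))  # 480 120
```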

```python
graph, label = dataset[0]
print(graph)
print(label)
```

```
DGLGraph(num_nodes=37, num_edges=168,
         ndata_schemes={'feat': Scheme(shape=(18,), dtype=torch.float32)}
         edata_schemes={})
tensor(5)
```


## Prepare Dataloader

DGL can batch multiple small graphs into one large graph to accelerate computation. Details of batching can be found [here](https://docs.dgl.ai/tutorials/basics/4_batch.html).

<img src="https://s3.us-east-2.amazonaws.com/dgl.ai/tutorial/batch/batch.png" width="500"/>
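Conceptually, `dgl.batch` builds the disjoint union of its input graphs: the node IDs of each graph are shifted by the number of nodes that came before it, and no edges are added between graphs. DGL also records how many nodes each member graph has, which is what lets readout functions such as `dgl.sum_nodes` pool each graph separately. The pure-Python sketch below shows only the ID-shifting idea on edge lists; `batch_edge_lists` is a hypothetical helper, not DGL's implementation.

```python
def batch_edge_lists(graphs):
    """Hypothetical helper. graphs: list of (num_nodes, [(src, dst), ...]).
    Returns (total_nodes, merged_edges, nodes_per_graph)."""
    offset = 0
    merged_edges = []
    nodes_per_graph = []
    for num_nodes, edges in graphs:
        # Shift node IDs so they do not collide with earlier graphs
        merged_edges.extend((src + offset, dst + offset) for src, dst in edges)
        nodes_per_graph.append(num_nodes)
        offset += num_nodes
    return offset, merged_edges, nodes_per_graph

g1 = (3, [(0, 1), (1, 2)])
g2 = (2, [(0, 1)])
total, edges, sizes = batch_edge_lists([g1, g2])
print(total, edges, sizes)  # 5 [(0, 1), (1, 2), (3, 4)] [3, 2]
```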

```python
def collate_molgraphs_for_classification(data):
    """Batch a list of (graph, label) pairs for the DataLoader."""
    graphs, labels = map(list, zip(*data))
    # Merge the graphs into a single batched graph
    bg = dgl.batch(graphs)
    # Stack the scalar label tensors into a tensor of shape (batch_size,)
    labels = torch.stack(labels, dim=0)
    return bg, labels

train_loader = DataLoader(trainset, batch_size=512, shuffle=True,
                          collate_fn=collate_molgraphs_for_classification)
val_loader = DataLoader(valset, batch_size=512,
                        collate_fn=collate_molgraphs_for_classification)
```
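A PyTorch `DataLoader` with the default `drop_last=False` yields `ceil(num_samples / batch_size)` batches per epoch. The 80/20 split puts 480 of ENZYMES' 600 graphs in the training set, so with `batch_size=512` the training loader produces a single batch and each epoch performs exactly one optimizer step on the full training set. The hypothetical helper `num_batches` below only reproduces that arithmetic.

```python
import math

def num_batches(num_samples, batch_size, drop_last=False):
    """Hypothetical helper: mini-batches a DataLoader-style loader yields per epoch."""
    if drop_last:
        return num_samples // batch_size  # the incomplete last batch is dropped
    return math.ceil(num_samples / batch_size)

print(num_batches(480, 512))  # 1: the whole training set fits in one batch
print(num_batches(120, 512))  # 1
print(num_batches(480, 64))   # 8
```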

## Prepare Model and Optimizer

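Here we use a three-layer Graph Convolutional Network (GCN) to compute node representations, sum them into one vector per graph, and classify that vector with a linear layer. The source code of a related GCN classifier in DGL's model zoo can be found [here](https://github.com/dmlc/dgl/blob/master/python/dgl/model_zoo/chem/classifiers.py#L111).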

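```python
class GCNModel(nn.Module):
    """Three GraphConv layers, a sum readout over nodes, and a linear classifier."""
    def __init__(self, in_feats, n_hidden, out_feats):
        super().__init__()
        self.layers = nn.ModuleList([
            conv.GraphConv(in_feats, n_hidden, activation=F.relu),
            conv.GraphConv(n_hidden, n_hidden, activation=F.relu),
            conv.GraphConv(n_hidden, n_hidden, activation=F.relu)
        ])
        # Maps the pooled graph representation to one logit per class
        self.classifier = nn.Linear(n_hidden, out_feats)

    def forward(self, g, features):
        h = features
        # Message passing: each layer updates node features from their neighbors
        for layer in self.layers:
            h = layer(g, h)
        # local_scope keeps the temporary node feature from leaking into g
        with g.local_scope():
            g.ndata['h'] = h
            # Sum node features within each graph -> (batch_size, n_hidden)
            h_g = dgl.sum_nodes(g, 'h')
        return self.classifier(h_g)
```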
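With its default symmetric normalization, each `GraphConv` layer updates node i roughly as `h_i' = ReLU(b + Σ_j h_j W / sqrt(d_i · d_j))`, summing over the neighbors j of i, where d is the node degree. `GraphConv` does not add self-loops by itself, so a node's own features are not part of its update unless self-loops are added to the graph. `dgl.sum_nodes` then adds up the node vectors of each graph in the batch. The pure-Python sketch below shows one such layer (bias omitted) followed by the sum readout, assuming an undirected graph stored with both edge directions; `gcn_layer` and `sum_readout` are hypothetical helpers, not DGL functions.

```python
import math

def gcn_layer(num_nodes, edges, feats, weight):
    """Hypothetical helper: symmetric degree normalization, linear map, ReLU (no bias).
    edges: directed (src, dst) pairs; feats: num_nodes x in_dim; weight: in_dim x out_dim."""
    out_deg = [0] * num_nodes
    in_deg = [0] * num_nodes
    for src, dst in edges:
        out_deg[src] += 1
        in_deg[dst] += 1
    in_dim, out_dim = len(feats[0]), len(weight[0])
    # Aggregate neighbor features, each scaled by 1 / sqrt(deg(src) * deg(dst))
    agg = [[0.0] * in_dim for _ in range(num_nodes)]
    for src, dst in edges:
        norm = 1.0 / math.sqrt(out_deg[src] * in_deg[dst])
        for k in range(in_dim):
            agg[dst][k] += norm * feats[src][k]
    # Multiply by the weight matrix, then apply ReLU
    return [
        [max(0.0, sum(row[k] * weight[k][o] for k in range(in_dim))) for o in range(out_dim)]
        for row in agg
    ]

def sum_readout(node_feats, nodes_per_graph):
    """Hypothetical helper: sum node features separately for each graph in a batch."""
    readouts, start = [], 0
    for n in nodes_per_graph:
        rows = node_feats[start:start + n]
        readouts.append([sum(col) for col in zip(*rows)])
        start += n
    return readouts

# Batched graph: a 3-node path (0-1-2) and a single edge (3-4),
# with every undirected edge stored in both directions
edges = [(0, 1), (1, 0), (1, 2), (2, 1), (3, 4), (4, 3)]
feats = [[1.0], [2.0], [3.0], [1.0], [1.0]]
h = gcn_layer(5, edges, feats, weight=[[1.0]])
print(sum_readout(h, [3, 2]))  # approximately [[5.657], [2.0]]
```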

```python
model = GCNModel(in_feats=18, n_hidden=64, out_feats=6).to(device)
# Multi-class classification over the 6 enzyme classes
loss_criterion = nn.CrossEntropyLoss()
# Adam with PyTorch's default hyperparameters
optimizer = Adam(model.parameters())
print(model)
```

```
GCNModel(
  (layers): ModuleList(
    (0): GraphConv(in=18, out=64, normalization=True, activation=<function relu at 0x116a00158>)
    (1): GraphConv(in=64, out=64, normalization=True, activation=<function relu at 0x116a00158>)
    (2): GraphConv(in=64, out=64, normalization=True, activation=<function relu at 0x116a00158>)
  )
  (classifier): Linear(in_features=64, out_features=6, bias=True)
)
```
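`Adam(model.parameters())` uses PyTorch's defaults: learning rate 1e-3, betas (0.9, 0.999) and eps 1e-8. Adam divides a running average of the gradient by a running estimate of its magnitude, so its first steps move each parameter by roughly the learning rate regardless of how large the raw gradient is. Since each epoch in this notebook is a single full batch, training takes one such step per epoch. The sketch below applies the Adam update rule to a single scalar (no weight decay or AMSGrad); `adam_minimize` is a hypothetical helper, not part of PyTorch.

```python
import math

def adam_minimize(grad_fn, x0, steps, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Hypothetical helper: run Adam on one scalar parameter.
    Default hyperparameters match torch.optim.Adam's defaults."""
    x, m, v = x0, 0.0, 0.0
    for t in range(1, steps + 1):
        g = grad_fn(x)
        m = beta1 * m + (1 - beta1) * g          # running mean of gradients
        v = beta2 * v + (1 - beta2) * g * g      # running mean of squared gradients
        m_hat = m / (1 - beta1 ** t)             # bias corrections for zero init
        v_hat = v / (1 - beta2 ** t)
        x -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return x

# Minimizing (x - 3)^2 and 100 * (x - 3)^2 from x = 0:
# the first step has size ~lr whether the gradient is -6 or -600
print(adam_minimize(lambda x: 2 * (x - 3), 0.0, steps=1))    # ~0.001
print(adam_minimize(lambda x: 200 * (x - 3), 0.0, steps=1))  # ~0.001
```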


### Training

```python
epochs = 500
model.train()
for epoch in range(epochs):
    loss_list = []
    num_correct = 0
    num_samples = 0
    for bg, labels in train_loader:
        # Take the node features out of the batched graph and move them to the device
        node_feats = bg.ndata.pop('feat').to(device)
        labels = labels.to(device)
        logits = model(bg, node_feats)
        loss = loss_criterion(logits, labels)
        # A prediction is correct when the highest logit belongs to the true class
        num_correct += (logits.argmax(1) == labels).sum().item()
        num_samples += len(labels)
        loss_list.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print("Epoch {:05d} | Loss: {:.4f} | Accuracy: {:.4f}".format(
        epoch, np.mean(loss_list), num_correct / num_samples))
```

```
Epoch 00000 | Loss: 78.3070 | Accuracy: 0.1542
Epoch 00001 | Loss: 52.2662 | Accuracy: 0.1583
Epoch 00002 | Loss: 35.6786 | Accuracy: 0.1875
Epoch 00003 | Loss: 26.5348 | Accuracy: 0.1792
Epoch 00004 | Loss: 17.2299 | Accuracy: 0.1938
Epoch 00005 | Loss: 9.4017 | Accuracy: 0.2271
Epoch 00006 | Loss: 4.8281 | Accuracy: 0.2479
Epoch 00007 | Loss: 5.1498 | Accuracy: 0.2188
Epoch 00008 | Loss: 4.4797 | Accuracy: 0.2604
Epoch 00009 | Loss: 4.5233 | Accuracy: 0.2250
Epoch 00010 | Loss: 4.2406 | Accuracy: 0.1938
Epoch 00011 | Loss: 3.9234 | Accuracy: 0.2271
Epoch 00012 | Loss: 3.5610 | Accuracy: 0.2604
Epoch 00013 | Loss: 3.4678 | Accuracy: 0.2292
Epoch 00014 | Loss: 3.2297 | Accuracy: 0.2229
Epoch 00015 | Loss: 2.9896 | Accuracy: 0.2417
Epoch 00016 | Loss: 2.8344 | Accuracy: 0.2479
Epoch 00017 | Loss: 2.6362 | Accuracy: 0.2021
Epoch 00018 | Loss: 2.6142 | Accuracy: 0.1875
Epoch 00019 | Loss: 2.5773 | Accuracy: 0.1854
Epoch 00020 | Loss: 2.5366 | Accuracy: 0.1917
Epoch 00021 | Loss: 2.3635
...
Epoch 00178 | Loss: 1.4552 | Accuracy: 0.4458
Epoch 00179 | Loss: 1.4684 | Accuracy: 0.4146
Epoch 00180 | Loss: 1.4756 | Accuracy: 0.4437
Epoch 00181 | Loss: 1.4737 | Accuracy: 0.4021
Epoch 00182 | Loss: 1.4531 | Accuracy: 0.4500
Epoch 00183 | Loss: 1.4363 | Accuracy: 0.4625
Epoch 00184 | Loss: 1.4299 | Accuracy: 0.4500
Epoch 00185 | Loss: 1.4318 | Accuracy: 0.4583
Epoch 00186 | Loss: 1.4404 | Accuracy: 0.4458
Epoch 00187 | Loss: 1.4523 | Accuracy: 0.4542
Epoch 00188 | Loss: 1.4641 | Accuracy: 0.4104
Epoch 00189 | Loss: 1.4581 | Accuracy: 0.4417
Epoch 00190 | Loss: 1.4414 | Accuracy: 0.4375
Epoch 00191 | Loss: 1.4250 | Accuracy: 0.4521
Epoch 00192 | Loss: 1.4215 | Accuracy: 0.4500
Epoch 00193 | Loss: 1.4281 | Accuracy: 0.4646
Epoch 00194 | Loss: 1.4425 | Accuracy: 0.4562
Epoch 00195 | Loss: 1.4567 | Accuracy: 0.4167
Epoch 00196 | Loss: 1.4484 | Accuracy: 0.4458
Epoch 00197 | Loss: 1.4305 | Accuracy: 0.4313
Epoch 00198 | Loss: 1.4150 | Accuracy: 0.4604
Epoch 00199 | Loss: 1.4132
...
Epoch 00357 | Loss: 1.2349 | Accuracy: 0.5208
Epoch 00358 | Loss: 1.2323 | Accuracy: 0.5250
Epoch 00359 | Loss: 1.2307 | Accuracy: 0.5250
Epoch 00360 | Loss: 1.2307 | Accuracy: 0.5188
Epoch 00361 | Loss: 1.2321 | Accuracy: 0.5396
Epoch 00362 | Loss: 1.2347 | Accuracy: 0.5229
Epoch 00363 | Loss: 1.2392 | Accuracy: 0.5417
Epoch 00364 | Loss: 1.2479 | Accuracy: 0.5083
Epoch 00365 | Loss: 1.2632 | Accuracy: 0.5333
Epoch 00366 | Loss: 1.2692 | Accuracy: 0.4917
Epoch 00367 | Loss: 1.2585 | Accuracy: 0.5250
Epoch 00368 | Loss: 1.2328 | Accuracy: 0.5062
Epoch 00369 | Loss: 1.2207 | Accuracy: 0.5292
Epoch 00370 | Loss: 1.2248 | Accuracy: 0.5208
Epoch 00371 | Loss: 1.2375 | Accuracy: 0.5167
Epoch 00372 | Loss: 1.2563 | Accuracy: 0.5354
Epoch 00373 | Loss: 1.2693 | Accuracy: 0.5021
Epoch 00374 | Loss: 1.2571 | Accuracy: 0.5167
Epoch 00375 | Loss: 1.2338 | Accuracy: 0.5208
Epoch 00376 | Loss: 1.2170 | Accuracy: 0.5292
Epoch 00377 | Loss: 1.2264 | Accuracy: 0.5250
Epoch 00378 | Loss: 1.2603
...
```
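`nn.CrossEntropyLoss` combines log-softmax with negative log-likelihood and, by default, averages over the batch. Two reference points help read the log above. A model that gives all six classes the same score has a loss of ln 6 ≈ 1.79, so later epochs (around 1.2 to 1.5) are doing better than uninformed guessing. In addition, the per-sample loss is never smaller than the gap between the highest logit and the true-class logit, so an average loss of 78 at epoch 0 means the true class was typically scored far below the top class. The pure-Python sketch below computes the per-sample loss and the argmax accuracy used in the loop; `cross_entropy` and `accuracy` are illustrative helpers, not PyTorch functions.

```python
import math

def cross_entropy(logits, label):
    """Illustrative helper: softmax cross-entropy for one sample,
    computed stably with the log-sum-exp trick."""
    max_logit = max(logits)
    log_sum_exp = max_logit + math.log(sum(math.exp(z - max_logit) for z in logits))
    return log_sum_exp - logits[label]

def accuracy(batch_logits, labels):
    """Illustrative helper: fraction of samples whose top-scoring class is the label."""
    correct = sum(
        max(range(len(z)), key=z.__getitem__) == y
        for z, y in zip(batch_logits, labels)
    )
    return correct / len(labels)

# Equal logits over 6 classes give a loss of ln(6), about 1.7918
print(cross_entropy([0.0] * 6, 3))
# Predictions are classes 0, 1, 0 against labels 0, 1, 1: two of three correct
print(accuracy([[2.0, 1.0], [0.0, 3.0], [1.0, 0.5]], [0, 1, 1]))
```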

### Inference

```python
model.eval()
num_correct = 0
num_samples = 0
# Gradients are not needed for evaluation
with torch.no_grad():
    for bg, labels in val_loader:
        node_feats = bg.ndata.pop('feat').to(device)
        labels = labels.to(device)
        logits = model(bg, node_feats)
        num_correct += (logits.argmax(1) == labels).sum().item()
        num_samples += len(labels)
# Validation accuracy
print(num_correct / num_samples)
```

```
0.4583333333333333
```


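The model classifies about 46% of the validation graphs correctly (0.4583 = 55/120). This is well above the roughly 17% (1/6) expected from random guessing among six classes, but far from perfect, and somewhat below the training accuracy of about 0.52 in the last logged epochs. Since ENZYMES contains 100 graphs per class, accuracy is a reasonable headline metric here, although a random 120-graph validation split need not be exactly balanced, so per-class results are still worth checking. More graph classification examples, including molecular property prediction on datasets such as Tox21, can be found in our [model zoo](https://github.com/dmlc/dgl/tree/master/examples/pytorch/model_zoo/chem/property_prediction).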
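To look beyond a single accuracy number, one can compare against a trivial majority-class baseline and break the errors down by class. The pure-Python sketch below builds a confusion matrix, per-class recall and the majority baseline from two lists of integer class IDs; the helper names are hypothetical. In the notebook, such lists could be collected in the validation loop with `logits.argmax(1).tolist()` and `labels.tolist()`; the example data below is made up for illustration and is not this model's output.

```python
from collections import Counter

def confusion_matrix(labels, preds, num_classes):
    """Hypothetical helper: matrix[t][p] counts samples of class t predicted as p."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(labels, preds):
        matrix[t][p] += 1
    return matrix

def per_class_recall(matrix):
    """Hypothetical helper: correct predictions / samples, per true class.
    Classes with no samples get None."""
    recalls = []
    for c, row in enumerate(matrix):
        total = sum(row)
        recalls.append(row[c] / total if total else None)
    return recalls

def majority_baseline(labels):
    """Hypothetical helper: accuracy of always predicting the most frequent label."""
    return Counter(labels).most_common(1)[0][1] / len(labels)

# Made-up example with three classes
labels = [0, 0, 1, 1, 2, 2]
preds = [0, 1, 1, 1, 2, 0]
m = confusion_matrix(labels, preds, 3)
print(m)                            # [[1, 1, 0], [0, 2, 0], [1, 0, 1]]
print(per_class_recall(m))          # [0.5, 1.0, 0.5]
print(majority_baseline(labels))    # 0.333... (every class appears twice)
```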