# Graph Classification with DGL

Here we demonstrate how to use DGL for graph classification. The dataset we use is ENZYMES, a benchmark from the TU Dortmund graph dataset collection that DGL exposes through `TUDataset`. It contains 600 protein tertiary structures. Each structure is represented as a graph and assigned to one of 6 top-level enzyme classes, with 100 graphs per class. Each node carries 18 continuous features.

In [1]:
import dgl
from dgl.data import TUDataset
from dgl import model_zoo
from dgl.data.utils import split_dataset

In [2]:
import numpy as np

In [3]:
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader

## Load Dataset

Loading the dataset takes about one minute.

In [4]:
dataset = TUDataset("ENZYMES")

In [5]:
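# Store the graph labels as a tensor so that dataset[i] returns a tensor label
dataset.graph_labels = torch.tensor(dataset.graph_labels)

In [6]:
# Cast every graph's node features to float32 tensors for the model
for i in range(len(dataset)):
    graph = dataset[i][0]
    graph.ndata['feat'] = torch.tensor(graph.ndata['feat']).float()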

### Split the dataset into training and validation sets

In [7]:
trainset, valset = split_dataset(dataset, [0.8, 0.2], shuffle=True, random_state=42)
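Conceptually, a shuffled 80/20 split permutes the sample indices with a fixed seed and cuts the permutation into consecutive chunks. The sketch below is a plain-Python illustration of that idea, not DGL's `split_dataset` source. The helper `random_split_indices` is hypothetical, and it uses Python's `random` module rather than whatever generator DGL uses, so it will not reproduce DGL's exact indices.

```python
import random

def random_split_indices(n, frac_list=(0.8, 0.2), seed=42):
    """Shuffle indices 0..n-1 with a fixed seed and cut them into chunks by fraction.

    Hypothetical helper illustrating a shuffled fractional split.
    """
    assert abs(sum(frac_list) - 1.0) < 1e-8, "fractions must sum to 1"
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    # Chunk sizes; the last chunk absorbs any rounding remainder
    sizes = [int(n * f) for f in frac_list[:-1]]
    sizes.append(n - sum(sizes))
    splits, start = [], 0
    for size in sizes:
        splits.append(indices[start:start + size])
        start += size
    return splits

train_idx, val_idx = random_split_indices(600)
print(len(train_idx), len(val_idx))  # 480 120
```

Fixing the seed makes the split reproducible across runs, and the two index lists never overlap.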

In [8]:
graph, label = dataset[0]
print(graph)
print(label)

DGLGraph(num_nodes=37, num_edges=168,
         ndata_schemes={'feat': Scheme(shape=(18,), dtype=torch.float32)}
         edata_schemes={})
tensor(5)


## Prepare Dataloader

DGL can batch multiple small graphs into a single large graph to accelerate computation. Details of batching can be found [here](https://docs.dgl.ai/tutorials/basics/4_batch.html).

<img src="https://s3.us-east-2.amazonaws.com/dgl.ai/tutorial/batch/batch.png" width="500"/>
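Batching treats a list of small graphs as one large graph made of disconnected components. The node IDs of each graph are shifted by the number of nodes that precede it, and a membership list records which graph each node belongs to, so that per-graph readouts can be computed later. The toy sketch below shows that relabeling, plus a collate step shaped like the notebook's `collate_molgraphs_for_classification`. The `(num_nodes, edges)` graph representation and the helpers `batch_graphs` and `collate` are hypothetical stand-ins, not DGL's internals.

```python
def batch_graphs(graphs):
    """Merge small graphs, each given as (num_nodes, [(src, dst), ...]), into one disjoint union.

    Returns the total node count, the relabeled edge list, and a per-node
    graph-membership list.
    """
    total_nodes = 0
    batched_edges = []
    node_to_graph = []
    for graph_id, (num_nodes, edges) in enumerate(graphs):
        # Shift node IDs by the number of nodes already placed in the batch
        batched_edges.extend((src + total_nodes, dst + total_nodes) for src, dst in edges)
        node_to_graph.extend([graph_id] * num_nodes)
        total_nodes += num_nodes
    return total_nodes, batched_edges, node_to_graph

def collate(samples):
    """Unzip (graph, label) pairs, then batch the graphs and keep labels in order."""
    graphs, labels = map(list, zip(*samples))
    return batch_graphs(graphs), labels

samples = [((3, [(0, 1), (1, 2)]), 5), ((2, [(0, 1), (1, 0)]), 2)]
(num_nodes, edges, membership), labels = collate(samples)
print(num_nodes)   # 5
print(edges)       # [(0, 1), (1, 2), (3, 4), (4, 3)]
print(membership)  # [0, 0, 0, 1, 1]
print(labels)      # [5, 2]
```

Because the components share no edges, message passing on the batched graph gives the same per-node results as running each graph separately, while the work runs as one larger operation.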

In [9]:
def collate_molgraphs_for_classification(data):
    """Batch a list of (graph, label) pairs for the DataLoader in classification tasks."""
    graphs, labels = map(list, zip(*data))
    bg = dgl.batch(graphs)                 # merge the graphs into one batched graph
    labels = torch.stack(labels, dim=0)    # one label per graph, shape (batch_size,)
    return bg, labels

# Shuffle the training set every epoch; keep the validation order fixed
train_loader = DataLoader(trainset, batch_size=256, shuffle=True,
                          collate_fn=collate_molgraphs_for_classification)
val_loader = DataLoader(valset, batch_size=256,
                        collate_fn=collate_molgraphs_for_classification)

## Prepare Model and Optimizer

Here we use a two-layer Graph Convolutional Network (GCN) to classify the graphs. The detailed source code can be found [here](https://github.com/dmlc/dgl/blob/master/python/dgl/model_zoo/chem/classifiers.py#L111).
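To make the data flow concrete, here is a NumPy sketch of a forward pass loosely modeled on the module structure printed in the next cell's output. It assumes sum aggregation (the printed `normalization=False`), ReLU, a ReLU-activated residual projection added to the convolution output, and a sigmoid-gated sum readout in the spirit of `WeightAndSum`. Biases, dropout, batch normalization and the MLP classifier are omitted; a single linear map stands in for the classifier. The helper names `gcn_layer` and `weight_and_sum`, the toy graph and the random weights are all hypothetical, so this is not the model zoo source.

```python
import numpy as np

def gcn_layer(adj, h, weight, res_weight):
    """Sum-aggregation graph convolution plus a residual projection.

    adj[i, j] = 1 means an edge from node j to node i, so adj @ h sums
    each node's in-neighbour features.
    """
    conv = np.maximum(adj @ h @ weight, 0.0)  # aggregate, transform, ReLU
    res = np.maximum(h @ res_weight, 0.0)     # residual branch projected to the new width
    return conv + res

def weight_and_sum(h, node_to_graph, num_graphs, w, b):
    """Gated sum readout: weight each node by sigmoid(h @ w + b), then sum per graph."""
    gate = 1.0 / (1.0 + np.exp(-(h @ w + b)))  # shape (num_nodes, 1), values in (0, 1)
    out = np.zeros((num_graphs, h.shape[1]))
    np.add.at(out, node_to_graph, gate * h)    # scatter-add node rows into their graph's row
    return out

rng = np.random.default_rng(0)

# Toy batch: a 3-node path (graph 0) and a 2-node pair (graph 1), edges in both directions
edges = [(0, 1), (1, 0), (1, 2), (2, 1), (3, 4), (4, 3)]
adj = np.zeros((5, 5))
for src, dst in edges:
    adj[dst, src] = 1.0
node_to_graph = np.array([0, 0, 0, 1, 1])

h = rng.normal(size=(5, 18))  # 18 input features per node
for in_dim, out_dim in [(18, 64), (64, 64)]:  # two layers with 64 hidden units
    weight = 0.1 * rng.normal(size=(in_dim, out_dim))
    res_weight = 0.1 * rng.normal(size=(in_dim, out_dim))
    h = gcn_layer(adj, h, weight, res_weight)

graph_repr = weight_and_sum(h, node_to_graph, 2, rng.normal(size=(64, 1)), 0.0)
logits = graph_repr @ rng.normal(size=(64, 6))  # one score per class
print(graph_repr.shape, logits.shape)  # (2, 64) (2, 6)
```

The readout turns a variable number of node vectors into one fixed-size vector per graph, which is what lets graphs of different sizes share a single classifier.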

In [10]:
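# 18 input node features, two GCN layers with 64 hidden units, 6 outputs (one score per class).
# .cuda() requires a CUDA-capable GPU.
model = model_zoo.chem.GCNClassifier(in_feats=18, gcn_hidden_feats=[64, 64], n_tasks=6).cuda()

# Treat the 6 outputs as class logits for single-label classification
loss_criterion = torch.nn.CrossEntropyLoss()
optimizer = Adam(model.parameters())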
print(model)

GCNClassifier(
  (gnn_layers): ModuleList(
    (0): GCNLayer(
      (graph_conv): GraphConv(in=18, out=64, normalization=False, activation=<function relu at 0x7f7093160ae8>)
      (dropout): Dropout(p=0.0)
      (res_connection): Linear(in_features=18, out_features=64, bias=True)
      (bn_layer): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): GCNLayer(
      (graph_conv): GraphConv(in=64, out=64, normalization=False, activation=<function relu at 0x7f7093160ae8>)
      (dropout): Dropout(p=0.0)
      (res_connection): Linear(in_features=64, out_features=64, bias=True)
      (bn_layer): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (weighted_sum_readout): WeightAndSum(
    (atom_weighting): Sequential(
      (0): Linear(in_features=64, out_features=1, bias=True)
      (1): Sigmoid()
    )
  )
  (soft_classifier): MLPBinaryClassifier(
    (predict): Sequential(
      (0): Dropout(p=0.0)
      (1

### Training

In [11]:
epochs = 50
model.train()
for i in range(epochs):
    loss_list = []
    true_samples = 0
    num_samples = 0
    for batch_id, batch_data in enumerate(train_loader):
        bg, labels = batch_data
        atom_feats = bg.ndata.pop('feat')
        atom_feats, labels = atom_feats.to('cuda'), labels.to('cuda')
        logits = model(bg, atom_feats)
        loss = loss_criterion(logits, labels)
        true_samples += (logits.argmax(1)==labels.long()).float().sum().item()
        num_samples += len(labels)
        loss_list.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print("Epoch {:05d} | Loss: {:.4f} | Accuracy: {:.4f}".format(i, np.mean(loss_list), true_samples/num_samples))

Epoch 00000 | Loss: 1.8064 | Accuracy: 0.1750
Epoch 00001 | Loss: 1.7262 | Accuracy: 0.2562
Epoch 00002 | Loss: 1.6770 | Accuracy: 0.3271
Epoch 00003 | Loss: 1.6389 | Accuracy: 0.3688
Epoch 00004 | Loss: 1.6069 | Accuracy: 0.3937
Epoch 00005 | Loss: 1.5790 | Accuracy: 0.4042
Epoch 00006 | Loss: 1.5534 | Accuracy: 0.4167
Epoch 00007 | Loss: 1.5281 | Accuracy: 0.4104
Epoch 00008 | Loss: 1.5024 | Accuracy: 0.4292
Epoch 00009 | Loss: 1.4768 | Accuracy: 0.4500
Epoch 00010 | Loss: 1.4512 | Accuracy: 0.4646
Epoch 00011 | Loss: 1.4260 | Accuracy: 0.4688
Epoch 00012 | Loss: 1.3999 | Accuracy: 0.4875
Epoch 00013 | Loss: 1.3728 | Accuracy: 0.5042
Epoch 00014 | Loss: 1.3454 | Accuracy: 0.5188
Epoch 00015 | Loss: 1.3174 | Accuracy: 0.5354
Epoch 00016 | Loss: 1.2892 | Accuracy: 0.5500
Epoch 00017 | Loss: 1.2607 | Accuracy: 0.5708
Epoch 00018 | Loss: 1.2320 | Accuracy: 0.5813
Epoch 00019 | Loss: 1.2031 | Accuracy: 0.5979
Epoch 00020 | Loss: 1.1734 | Accuracy: 0.6250
Epoch 00021 | Loss: 1.1432 | Accur
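The training loop minimizes `CrossEntropyLoss` over the 6 class logits and measures accuracy with `argmax`. The NumPy sketch below computes both quantities for illustration; the helpers `cross_entropy` and `accuracy` are hypothetical, and the loss uses mean reduction, which is PyTorch's default. With uniform logits the loss equals ln 6, about 1.79, which is consistent with the first epoch's loss of 1.8064 for a model that starts near chance.

```python
import numpy as np

def cross_entropy(logits, labels):
    """Mean negative log-probability of the true class under a softmax over logits."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

def accuracy(logits, labels):
    """Fraction of samples whose highest-scoring class equals the label."""
    return (logits.argmax(axis=1) == labels).mean()

uniform = np.zeros((4, 6))  # equal scores for all 6 classes
labels = np.array([0, 3, 5, 2])
print(round(cross_entropy(uniform, labels), 4))  # 1.7918, i.e. ln(6)
print(accuracy(uniform, labels))                 # 0.25 (argmax picks class 0 on ties)
```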

### Inference

In [12]:

model.eval()
true_samples = 0
num_samples = 0
with torch.no_grad():
    for batch_id, batch_data in enumerate(val_loader):
        bg, labels = batch_data
        atom_feats = bg.ndata.pop('feat')
        atom_feats, labels = atom_feats.to('cuda'), labels.to('cuda')
        logits = model(bg, atom_feats)
        num_samples += len(labels)
        true_samples += (logits.argmax(1)==labels.long()).float().sum().item()
print(true_samples/num_samples)

0.6333333333333333


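The model reaches a validation accuracy of about 0.63 on this 6-class task, well above the 1/6 (about 0.17) expected from random guessing. Accuracy is a fair metric here because every ENZYMES class contains the same number of graphs. On datasets with imbalanced labels, such as Tox21, where most labels are negative, accuracy is misleading: a model that always predicts the majority label already scores highly. A more detailed analysis of molecular property prediction on Tox21 can be found in our [model zoo](https://github.com/dmlc/dgl/tree/master/examples/pytorch/model_zoo/chem/property_prediction).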
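To see why accuracy can mislead on imbalanced labels, compare it with balanced accuracy (the unweighted mean of per-class recall) on a toy binary example where the model always predicts the majority class. This is a NumPy sketch with hypothetical helper names, not the evaluation used in the DGL model zoo.

```python
import numpy as np

def per_class_recall(preds, labels, num_classes):
    """Recall per class: correct predictions of class c / samples labeled c (NaN if absent)."""
    recalls = np.full(num_classes, np.nan)
    for c in range(num_classes):
        mask = labels == c
        if mask.any():
            recalls[c] = (preds[mask] == c).mean()
    return recalls

def balanced_accuracy(preds, labels, num_classes):
    """Unweighted mean of per-class recall, skipping classes absent from labels."""
    return np.nanmean(per_class_recall(preds, labels, num_classes))

# 9 negatives, 1 positive, and a model that always predicts 0
labels = np.array([0] * 9 + [1])
preds = np.zeros(10, dtype=int)
print((preds == labels).mean())             # 0.9
print(balanced_accuracy(preds, labels, 2))  # 0.5
```

Plain accuracy rewards the trivial majority-class predictor with 0.9, while balanced accuracy exposes that the positive class is never recovered.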