
# Graph Classification with DGL

Here we demonstrate how to use DGL for graph classification tasks, using the ENZYMES dataset described below. The same workflow also applies to molecular property datasets such as Tox21. Tox21 is a public database measuring the toxicity of compounds; it contains qualitative toxicity measurements for 8014 compounds on 12 different targets, including nuclear receptors and stress response pathways.

In [None]:
import dgl
from dgl.data import TUDataset
from dgl.data.utils import split_dataset
from dgl.nn.pytorch import conv

In [None]:
import numpy as np
import networkx as nx

import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
import torch.nn.functional as F

## Load Dataset

<img src="./asset/enzymes.png" width="500"/>

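Here we use the ENZYMES dataset from the TU Dortmund graph benchmark collection (loaded via `TUDataset`). Each graph represents one enzyme: nodes are its secondary structure elements, and edges connect elements that are neighbors in the protein sequence or in 3D space. Each graph has a label from 0 to 5, indicating which of the six top-level enzyme classes it belongs to.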

In [None]:
dataset = TUDataset("ENZYMES")

# Store the graph labels as a tensor, so indexing the dataset yields label
# tensors that ``torch.stack`` (used by the collate function) can batch.
# Depending on the DGL version, ``graph_labels`` may be a NumPy array or already
# a tensor. ``torch.as_tensor`` accepts both, whereas ``torch.tensor`` on an
# existing tensor emits a copy-construct warning.
dataset.graph_labels = torch.as_tensor(dataset.graph_labels)

# Cast node attributes to float32 once, up front. This mutates the graphs held
# by the dataset, so later batches already carry float features.
for idx in range(len(dataset)):
    g, _ = dataset[idx]
    g.ndata['node_attr'] = g.ndata['node_attr'].float()

In [None]:
graph, label = dataset[0]
# Show the first sample: the graph summary on one line, its label on the next.
print(graph, label, sep="\n")

In [None]:
# Convert the first graph to NetworkX and plot it with a spring layout.
nx_graph = graph.to_networkx()
nx.draw_spring(nx_graph)

### Split dataset into train and val

In [None]:
# Shuffle, then hold out 20% of the graphs for validation. The fixed seed makes
# the split reproducible across runs.
trainset, valset = split_dataset(
    dataset,
    frac_list=[0.8, 0.2],
    shuffle=True,
    random_state=42,
)

## Prepare Dataloader

DGL can batch multiple small graphs together into one large graph to accelerate computation. Details of batching can be found [here](https://docs.dgl.ai/tutorials/basics/4_batch.html).

<img src="https://s3.us-east-2.amazonaws.com/dgl.ai/tutorial/batch/batch.png" width="500"/>

In [None]:
def collate_molgraphs_for_classification(data):
    """Merge a list of ``(graph, label)`` samples into one classification batch.

    Parameters
    ----------
    data : list of tuple
        Samples as produced by indexing the dataset, each a
        ``(graph, label)`` pair.

    Returns
    -------
    bg : DGLGraph
        All graphs of the batch combined with ``dgl.batch``.
    labels : torch.Tensor
        The per-sample labels stacked along a new leading dimension.
    """
    graph_list, label_list = [], []
    for graph, label in data:
        graph_list.append(graph)
        label_list.append(label)
    batched_graph = dgl.batch(graph_list)
    return batched_graph, torch.stack(label_list, dim=0)

# Reshuffle the training graphs every epoch, so minibatches are not drawn in
# the same fixed order each time. Validation order does not affect accuracy,
# so that loader stays deterministic.
train_loader = DataLoader(trainset, batch_size=512, shuffle=True,
                          collate_fn=collate_molgraphs_for_classification)
val_loader = DataLoader(valset, batch_size=512,
                        collate_fn=collate_molgraphs_for_classification)

## Prepare Model and Optimizer

Here we use a three-layer Graph Convolutional Network (GCN), followed by sum pooling over each graph's nodes and a linear classifier, to classify the graphs. Detailed source code can be found [here](https://github.com/dmlc/dgl/blob/master/python/dgl/model_zoo/chem/classifiers.py#L111).

In [None]:
class GCNModel(nn.Module):
    """Graph classifier: stacked GraphConv layers, sum readout, linear head.

    Parameters
    ----------
    in_feats : int
        Size of each input node feature vector.
    n_hidden : int
        Size of the node representations produced by every GraphConv layer.
    out_feats : int
        Number of output logits (one per class).
    """

    def __init__(self,
                 in_feats,
                 n_hidden,
                 out_feats):
        super().__init__()
        # Three ReLU-activated GraphConv layers:
        # in_feats -> n_hidden -> n_hidden -> n_hidden.
        widths = [in_feats, n_hidden, n_hidden, n_hidden]
        self.layers = nn.ModuleList(
            conv.GraphConv(d_in, d_out, activation=F.relu)
            for d_in, d_out in zip(widths[:-1], widths[1:])
        )
        # Maps the pooled per-graph representation to class logits.
        self.classifier = nn.Linear(n_hidden, out_feats)

    def forward(self, g, features):
        """Return class logits for the (possibly batched) graph ``g``.

        ``features`` are the input node features of ``g``.
        """
        h = features
        for gcn in self.layers:
            h = gcn(g, h)
        # Sum-pool node embeddings per graph. local_scope() keeps the
        # temporary 'feat' field from being left on the caller's graph.
        with g.local_scope():
            g.ndata['feat'] = h
            pooled = dgl.sum_nodes(g, 'feat')
        return self.classifier(pooled)

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
# Derive the model's input/output sizes from the data instead of hard-coding
# them (for ENZYMES they are 18 node-attribute dimensions and 6 classes), so
# the notebook keeps working if the dataset is swapped.
# Assumes labels are 0-based class indices, as CrossEntropyLoss requires.
num_node_feats = dataset[0][0].ndata['node_attr'].shape[-1]
num_classes = int(torch.as_tensor(dataset.graph_labels).max()) + 1
model = GCNModel(in_feats=num_node_feats, n_hidden=64,
                 out_feats=num_classes).to(device)
loss_criterion = torch.nn.CrossEntropyLoss()
optimizer = Adam(model.parameters())
print(device)
print(model)

## Training

In [None]:
# Run many more epochs when a GPU is available; keep CPU runs short.
epochs = 500 if torch.cuda.is_available() else 50
model.train()
for i in range(epochs):
    loss_list = []
    true_samples = 0
    num_samples = 0
    for bg, labels in train_loader:
        # Pop the features before moving the graph, so they are not copied twice.
        atom_feats = bg.ndata.pop('node_attr').float().to(device)
        # The graph structure must be on the same device as the model and the
        # features; newer DGL versions do not move it implicitly.
        bg = bg.to(device)
        # Flatten to shape (batch,). This works whether labels were collated as
        # (batch,) or (batch, 1), and, unlike squeeze(-1), it keeps a batch of
        # size 1 one-dimensional. CrossEntropyLoss expects int64 class indices.
        labels = labels.to(device).reshape(-1).long()
        logits = model(bg, atom_feats)
        loss = loss_criterion(logits, labels)
        true_samples += (logits.argmax(1) == labels).sum().item()
        num_samples += len(labels)
        loss_list.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print("Epoch {:05d} | Loss: {:.4f} | Accuracy: {:.4f}".format(
        i, np.mean(loss_list), true_samples / num_samples))

## Validation

In [None]:
model.eval()
true_samples = 0
num_samples = 0
with torch.no_grad():
    for bg, labels in val_loader:
        # Same preprocessing as in training:
        # - float features
        # - graph and tensors on the model's device
        # - labels flattened to (batch,) int64 class indices
        atom_feats = bg.ndata.pop('node_attr').float().to(device)
        bg = bg.to(device)
        labels = labels.to(device).reshape(-1).long()
        logits = model(bg, atom_feats)
        num_samples += len(labels)
        true_samples += (logits.argmax(1) == labels).sum().item()
print("Validation Accuracy: {:.4f}".format(true_samples / num_samples))