# Deep Learning on Graphs with Message Passing Neural Networks

At this point in the Straight Dope, we've seen many different types of data fed as input to our models.

We started with linear regression models and MLPs, which take simple, 1-dimensional vectors of real numbers as input.  Then we met CNNs, which take images represented as 3-dimensional tensors as input.  Next we saw how RNNs can take sequence data, like time-series or natural-language sentences, or really anything we can represent as a sequence of tensors, as input.  And we even saw how to consume tree-structured data, like a parse tree of a natural-language sentence, using a Tree LSTM.

In this chapter we'll see how to build models to handle yet another type of data: graph-structured data. We'll learn how to build Message Passing Neural Networks (MPNNs), which are a class of deep model that can take arbitrary graphs as input.

**Wait, "graphs"?**

When I say "graph", I mean that word the way a mathematician means it.  [Wikipedia explains the concept well][1], if you're not familiar.  Going forward I'll assume we're familiar with graph-lingo like "directed edge" and "adjacency matrix", so take a gander at that link if you need to.

**So what exactly does "taking graphs as input" mean?**

Good question!  Reading papers or blogs about this topic can be confusing, since (at least) two distinct learning scenarios both go by the name "learning on graphs":
1. *We're trying to learn a model whose inputs are arbitrary graphs.*  Our dataset consists of (graph, label) pairs.  E.g. predicting the pharmacological activity of a molecule based on how its atoms are connected.
    
2. *We're trying to learn a model whose inputs are vertices in some graph.*  Our dataset is one big graph whose vertices are datapoints with edges between them, some labeled, some unlabeled.  E.g. predicting the impact factor of an article given a bag-of-words representation of the article and edges connecting it to its references.

In this chapter, we're focusing only on scenario 1, but MPNNs can be used for scenario 2 as well.

**Aren't sequences and trees just special types of graphs?  We already know models that handle those. (RNNs and Tree-RNNs.)**

Yes they are!  In fact you can (and people do) even think of images as graphs where each pixel is a vertex with edges to all its adjacent pixels.  But MPNNs can operate on *any* type of graph: directed or not, cyclic or not, etc.  Be careful though: MPNNs likely won't perform as well on sequences, trees, or images as models designed specifically for these data types will.

**But can't you basically represent anything as a graph if you try hard enough?**

Yeah, that's partly why graphs are ubiquitous in math and computer science.  They're a super general concept.  

This generality should make us veeeery suspicious that deep learning on graphs won't work as consistently well as, say, deep learning on real-world images does.  If it did, we could use deep networks to reason about nearly anything, and that would smell like a free lunch.

But MPNNs are still worth learning about.  They're among the most general tools we currently have for learning from graph-structured data, and they're an active area of research.

[1]:https://en.wikipedia.org/wiki/Graph_(discrete_mathematics)

## Message Passing Neural Networks

Message Passing Neural Networks were introduced in [this paper](https://arxiv.org/pdf/1704.01212.pdf).  MPNNs are actually a family of models rather than a specific implementation, like how RNNs are a general model family, one implementation of which is an LSTM.  We'll first go over the general MPNN idea and then build a specific implementation.

### The Setup

We've got a dataset of `(graph, label)` pairs.  In each graph, each vertex $v$ has associated features $x_v$, and each edge has features $e_{vw}$.  For simplicity of explanation we'll assume each graph is undirected, but once you understand MPNNs it's easy to see how to extend them to directed graphs or multigraphs.

### The Model

The goal of an MPNN is to take in a `graph` and output the correct `label`.  It does this by the following procedure:
1. Initialize a "hidden state" $h_v^0$ for each vertex $v$ in the graph as a function of the vertex's features: $$h_v^0 = \mathrm{init\_hidden}(x_v).$$
2. For each round $t = 0, 1, \dots, T-1$:
    1. Each vertex $v$ receives a "message" $m_v^{t+1}$, which is the sum of the messages passed by $v$'s neighbors, each a function of the current hidden states at both ends of the edge and of the edge features: $$m_v^{t+1} = \sum_{w \in \text{neighbors of }v} M_t(h_v^t, h_w^t, e_{vw}).$$
    2. Each vertex $v$ updates its hidden state as a function of its current hidden state and the message it received: $$h_v^{t+1} = U_t(h_v^t, m_v^{t+1}).$$
3. The output is computed by a "readout" function $R$ of all the final hidden states: $$\hat{y} = R(\{h_v^T \mid v \text{ is in the graph} \}).$$
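As a minimal sketch of this procedure, here is a plain NumPy version on a hypothetical 4-vertex graph, with neighbors stored in a dictionary. The concrete functions (a tanh of a random linear map for initialization, a message of $h_w^t - h_v^t$, an additive tanh update, and a sum for the readout) are arbitrary placeholders chosen only to make the loop concrete; the same message function is reused in every round, and edge features are omitted.

```python
import numpy as np

# Hypothetical undirected graph: a triangle (0, 1, 2) plus vertex 3 hanging off vertex 2
neighbors = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2]}
x = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]])  # vertex features x_v

rng = np.random.default_rng(0)
W_init = rng.normal(size=(2, 3))  # placeholder weights: 2 input features -> 3 hidden units

def init_hidden(x_v):   # placeholder for init_hidden
    return np.tanh(x_v @ W_init)

def M(h_v, h_w):        # placeholder message function (no edge features)
    return h_w - h_v

def U(h_v, m_v):        # placeholder update function
    return np.tanh(h_v + m_v)

def R(hidden_states):   # placeholder readout: sum over all vertices
    return hidden_states.sum(axis=0)

def run_mpnn(x, neighbors, T=3):
    h = np.stack([init_hidden(x_v) for x_v in x])  # h_v^0 for every vertex
    for t in range(T):
        # All messages use the round-t states; updates happen only after every message is computed
        m = np.stack([sum(M(h[v], h[w]) for w in neighbors[v]) for v in range(len(x))])
        h = np.stack([U(h[v], m[v]) for v in range(len(x))])  # h_v^{t+1}
    return R(h)

print(run_mpnn(x, neighbors).shape)  # (3,)
```

Because every step treats vertices symmetrically and the readout is a sum, relabeling the vertices leaves the output unchanged, which is the property that lets an MPNN accept graphs whose vertices have no natural order.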

Here's a base class for any type of MPNN that encapsulates this procedure:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import sklearn.metrics as metrics
import numpy as np
import scipy as sp
import scipy.sparse  # ensures sp.sparse is loaded, even on SciPy versions without lazy submodule imports
import math
np.random.seed(1)
torch.manual_seed(1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

class MPNN(nn.Module):
    '''
    General base class for all varieties of Message Passing Neural Network.
    '''
    def __init__(self, n_msg_pass_iters, *args, **kwargs):
        super(MPNN, self).__init__()
        self.n_msg_pass_iters = n_msg_pass_iters
    
    def init_hidden_states_and_edges(self, graph):
        # Performs "init_hidden" from above and prepares adjacency information from the graph
        # (This function is here so the model can be flexible about what format the graph is given to us in.)
        raise NotImplementedError()
    
    def compute_messages(self, hidden_states, edges, t):
        # Computes M_t from above and sums the messages
        raise NotImplementedError()
    
    def update_hidden_states(self, hidden_states, messages, t):
        # Performs U_t from above
        raise NotImplementedError()
    
    def readout(self, hidden_states):
        # Performs the readout function from above
        raise NotImplementedError()
        
    def forward(self, graph):
        hidden_states, edges = self.init_hidden_states_and_edges(graph)
        for t in range(self.n_msg_pass_iters):
            messages = self.compute_messages(hidden_states, edges, t)
            hidden_states = self.update_hidden_states(hidden_states, messages, t)
            
        return self.readout(hidden_states)
```

Using device: cuda


Different flavors of MPNN use different functions for $\mathrm{init\_hidden}$, $M_t$, $U_t$, and $R$, and more often than not these functions are simpler than the fully general versions described above.  For example, in the GGSNN version of MPNN we'll implement below, $M_t$ is the same function for each $t$, and it depends only on the neighboring vertex's hidden state $h_w^t$: not on the receiving vertex's hidden state $h_v^t$, and not on any edge features.

## Gated Graph Sequence Neural Networks

Now that we've got the MPNN framework down, let's grab some real data and implement a particular type of MPNN, called a Gated Graph Sequence Neural Network (GGSNN), to learn on it.

### An actual dataset

As a demonstration task, we'll use the [Tox21 dataset](https://tripod.nih.gov/tox21/challenge/).  The objective of this dataset is to take in the [chemical structure of a molecule](https://en.wikipedia.org/wiki/Structural_formula), represented as an undirected graph with atoms as vertices and bonds as edges, and predict the toxicity of the molecule.  In particular, we'll try to predict whether a molecule might [activate a particular cellular response to pollutants in your body](https://pubchem.ncbi.nlm.nih.gov/bioassay/743122#section=Top).

Rather than depending on a chemistry-specific deep learning library such as DeepChem, we'll write our own loader for Tox21, using RDKit to parse the molecules and extract atom features.

Now we'll load the data and convert each molecule into a graph.  If you're not fluent in chemistry, don't worry about the details of the following preprocessing.  We're just transforming the data from a molecular format (SMILES strings) into the format we're used to seeing from above.

What we'll end up with is a dataset of `(graph, label)` pairs, stored as a list of dictionaries.  Each `label` is a binary label (toxic or not), and each `graph` is an undirected graph represented by a matrix of per-vertex features (`vertex_features`) and a sparse adjacency matrix (`adj_mat`).

Our implementation will:
1. Load the Tox21 dataset (SMILES strings plus assay labels) from a local CSV file
2. Convert SMILES strings to molecular graphs using RDKit
3. Extract atom features and adjacency matrices
4. Create train/validation/test splits
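Concretely, each element of the dataset will be a dictionary like the one below.  This is a hand-built, hypothetical example: the three heavy atoms of ethanol (C-C-O), with hydrogens left implicit as RDKit does by default, and made-up 3-element feature vectors.  The real pipeline uses 75 features per atom and stores `adj_mat` as a SciPy sparse matrix.

```python
import numpy as np

# Hypothetical toy "molecule": heavy atoms of ethanol (C-C-O), hydrogens implicit.
# Each row is a made-up feature vector: [atomic number, degree, is_aromatic].
vertex_features = np.array([
    [6, 1, 0],  # carbon bonded to one heavy atom
    [6, 2, 0],  # carbon bonded to the other carbon and to the oxygen
    [8, 1, 0],  # oxygen
], dtype=np.float32)

# Bonds C0-C1 and C1-O2; an undirected graph gives a symmetric adjacency matrix
adj_mat = np.zeros((3, 3), dtype=np.float32)
for v, w in [(0, 1), (1, 2)]:
    adj_mat[v, w] = adj_mat[w, v] = 1.0

toy_graph = {'vertex_features': vertex_features, 'adj_mat': adj_mat, 'label': 0}

# Sanity checks on the format
n_vertices = toy_graph['vertex_features'].shape[0]
assert toy_graph['adj_mat'].shape == (n_vertices, n_vertices)
assert (toy_graph['adj_mat'] == toy_graph['adj_mat'].T).all()
print(toy_graph['adj_mat'].sum(axis=1))  # vertex degrees: [1. 2. 1.]
```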

```python
# Install RDKit for molecular processing
!pip install rdkit
```



```python
from rdkit import Chem
import pandas as pd
from sklearn.model_selection import train_test_split

def tox21_data(filename="xAI_Drug/tox21.csv"):
    """Load the Tox21 dataset (SMILES strings plus assay labels) from a local CSV file"""
    return pd.read_csv(filename)

def get_atom_features(atom):
    """Extract a 75-element feature vector for one RDKit atom"""
    features = [
        atom.GetAtomicNum(),
        atom.GetDegree(),
        atom.GetTotalDegree(),
        atom.GetFormalCharge(),
        int(atom.GetHybridization()),
        int(atom.GetIsAromatic()),
        int(atom.IsInRing()),
        int(atom.IsInRingSize(3)),
        int(atom.IsInRingSize(4)),
        int(atom.IsInRingSize(5)),
        int(atom.IsInRingSize(6)),
        int(atom.IsInRingSize(7)),
        int(atom.IsInRingSize(8)),
    ]
    
    # One-hot encode atomic number for common elements
    atomic_nums = [6, 7, 8, 9, 15, 16, 17, 35, 53]  # C, N, O, F, P, S, Cl, Br, I
    for atomic_num in atomic_nums:
        features.append(int(atom.GetAtomicNum() == atomic_num))
    
    # Zero-pad to 75 features, the input size the model below expects
    # (only the first 22 entries carry information)
    while len(features) < 75:
        features.append(0)
    
    return features[:75]

def smiles_to_graph(smiles):
    """Convert a SMILES string to a dict with vertex features and a sparse adjacency matrix"""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:  # RDKit could not parse this SMILES string
        return None
    
    # One feature vector per atom (hydrogens are implicit, so they are not vertices)
    atom_features = [get_atom_features(atom) for atom in mol.GetAtoms()]
    
    # Symmetric 0/1 adjacency matrix built from the bonds, stored as a sparse matrix
    adj_mat = Chem.GetAdjacencyMatrix(mol)
    adj_sparse = sp.sparse.csr_matrix(adj_mat.astype(np.float32))
    
    return {
        'vertex_features': np.array(atom_features, dtype=np.float32),
        'adj_mat': adj_sparse
    }

def create_tox21_dataset():
    """Build a list of labeled graph dicts from the Tox21 SMILES strings"""
    df = tox21_data()
    
    # Focus on the NR-AhR assay (aryl hydrocarbon receptor activation)
    target_column = 'NR-AhR'
    
    # If that column is missing, fall back to the first nuclear receptor (NR-) or stress response (SR-) column
    toxicity_columns = [col for col in df.columns if col.startswith('NR-') or col.startswith('SR-')]
    if target_column not in df.columns and toxicity_columns:
        target_column = toxicity_columns[0]
        print(f"Using {target_column} as target column")
    
    # Filter out rows with missing SMILES or target values
    df = df.dropna(subset=['smiles', target_column])
    
    # Convert SMILES to graph representations, skipping molecules RDKit cannot parse
    dataset = []
    for _, row in df.iterrows():
        graph = smiles_to_graph(row['smiles'])
        if graph is not None:
            graph['label'] = int(row[target_column])
            dataset.append(graph)
    
    print(f"Created dataset with {len(dataset)} molecules")
    return dataset

# Create the dataset
full_dataset = create_tox21_dataset()

# Split into train (70%), validation (15%), and test (15%) sets
train_dataset, temp_dataset = train_test_split(full_dataset, test_size=0.3, random_state=42)
valid_dataset, test_dataset = train_test_split(temp_dataset, test_size=0.5, random_state=42)

print(f"Train set: {len(train_dataset)} molecules")
print(f"Validation set: {len(valid_dataset)} molecules")
print(f"Test set: {len(test_dataset)} molecules")

# Check class balance
train_labels = [mol['label'] for mol in train_dataset]
print(f"Training set class balance: {sum(train_labels)}/{len(train_labels)} positive samples")
```



Created dataset with 6549 molecules
Train set: 4584 molecules
Validation set: 982 molecules
Test set: 983 molecules
Training set class balance: 523/4584 positive samples


### The GGSNN model

Now we'll implement a Gated Graph Sequence Neural Network, introduced in [this paper](https://arxiv.org/pdf/1511.05493.pdf), and train it on this dataset.

A GGSNN is an MPNN with the following customizations:
1. Hidden states $h_v^0$ for each vertex are initialized with a single-layer MLP.
2. The messages passed to $v$ by its neighbors are a simple matrix multiplication of each neighbor's hidden state: $$m_v^{t+1} = \sum_{w \in \text{neighbors of }v} W_{\mathtt{msg\_fxn}}h_w^t.$$
3. Each vertex $v$ updates its hidden state to be the output of a [GRU cell](https://pytorch.org/docs/stable/generated/torch.nn.GRUCell.html) (a type of RNN cell) whose hidden state is the vertex's hidden state and whose input is the message the vertex received: $$h_v^{t+1} = \text{GRU}(m_v^{t+1}, h_v^t).$$
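4. The "readout" function is this funny little beast: $$\hat{y} = \text{softmax}\left(f_{\text{out}}\left(\sum_{v} \sigma\left(f_1([h_v^T, h_v^0])\right) \odot f_2(h_v^T)\right)\right),$$ where $[h_v^T, h_v^0]$ denotes concatenation, the $f$s are small learned networks (MLPs for $f_1$ and $f_2$, and a single linear layer for $f_{\text{out}}$ in the code below), $\sigma$ is the sigmoid function, and $\odot$ is elementwise multiplication.  The sigmoid term acts like a sort of attention mechanism: since it sees both the initial and the final hidden state of each vertex, it can weight vertices by how much their hidden states changed during message passing.  (The code below leaves the softmax out of the model: `nn.CrossEntropyLoss` applies a log-softmax to the raw outputs during training, and `evaluate_roc_score` calls `F.softmax` explicitly.)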
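Item 2 is what lets `compute_messages` get away with a single sparse matrix product: summing $W h_w^t$ over the neighbors of every vertex at once is the same as multiplying the adjacency matrix $A$ by the matrix of transformed hidden states.  Here's a small NumPy check of that identity on a random graph; the sizes and the random matrix `W` are arbitrary stand-ins, not values from the model.

```python
import numpy as np

rng = np.random.default_rng(0)
n_vertices, hidden_size = 5, 4

# Random undirected graph: symmetric 0/1 adjacency matrix with no self-loops
upper = np.triu(rng.integers(0, 2, size=(n_vertices, n_vertices)), k=1)
A = (upper + upper.T).astype(float)

H = rng.normal(size=(n_vertices, hidden_size))   # row v holds h_v^t
W = rng.normal(size=(hidden_size, hidden_size))  # stand-in for W_msg_fxn

# Explicit version: m_v = sum over neighbors w of W @ h_w
explicit = np.zeros_like(H)
for v in range(n_vertices):
    for w in np.flatnonzero(A[v]):
        explicit[v] += W @ H[w]

# Matrix version: nn.Linear computes H @ W.T, and the adjacency product sums over neighbors
vectorized = A @ (H @ W.T)

print(np.allclose(explicit, vectorized))  # True
```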

Here's an implementation of GGSNN that fills out the details of the MPNN base class from above:

> *A key implementation note about what follows:* You'll notice below that the GGSNN is coded as though it takes in a single graph, rather than a minibatch of graphs as you might expect.  This is intentional.  We want to reserve the 0th/batch dimension of the tensors in our implementation to index over the vertices of the graph.  This makes the implementation more elegant, since PyTorch operations are built to handle inputs that vary in size along the 0th dimension, and the number of vertices in each graph is usually different.

> But of course, we DO want to process minibatches of data.  To do this, combine a minibatch of graphs into a single, large, disconnected graph, do all the message passing on this graph (no messages will get passed between minibatch elements, because their graphs are disconnected), and use the `batch_sizes` list to produce separate outputs for each graph in the minibatch in the `readout` step.
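Here's a small sketch of that batching trick, written in NumPy rather than PyTorch so it stands alone.  It stacks two hypothetical graphs (a single edge and a 3-vertex path) into one block-diagonal adjacency matrix, checks that a round of summed messages gives each graph the same result it would get alone, and then uses the list of graph sizes to sum vertex vectors per graph, as `readout` does.  The graphs, the one-unit hidden states, and the identity message function are made up for illustration.

```python
import numpy as np

# Two hypothetical graphs: a single edge (2 vertices) and a path (3 vertices)
adj_a = np.array([[0, 1],
                  [1, 0]], dtype=float)
adj_b = np.array([[0, 1, 0],
                  [1, 0, 1],
                  [0, 1, 0]], dtype=float)
h_a = np.array([[1.0], [2.0]])            # one hidden unit per vertex
h_b = np.array([[10.0], [20.0], [30.0]])

# Combine them into one disconnected graph: block-diagonal adjacency, stacked states
batch_sizes = [adj_a.shape[0], adj_b.shape[0]]
offsets = np.cumsum([0] + batch_sizes)    # [0, 2, 5]
adj = np.zeros((offsets[-1], offsets[-1]))
for start, stop, block in zip(offsets[:-1], offsets[1:], [adj_a, adj_b]):
    adj[start:stop, start:stop] = block
h = np.vstack([h_a, h_b])

# One round of summed messages (identity message function, for simplicity)
messages = adj @ h
# No messages cross between graphs: each block matches the graph processed alone
assert np.allclose(messages[:2], adj_a @ h_a)
assert np.allclose(messages[2:], adj_b @ h_b)

# Readout: split the rows back apart and sum each graph's vertices.
# (np.split takes boundary indices; torch.split takes the sizes themselves.)
per_graph = np.stack([chunk.sum(axis=0) for chunk in np.split(messages, offsets[1:-1])])
print(per_graph.ravel().tolist())  # [3.0, 80.0]
```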

[1]:https://arxiv.org/pdf/1511.05493.pdf

```python
class GGSNN(MPNN):
    '''
    GGSNN model for operating on the Tox21 dataset
    '''
    def __init__(self, vertex_feature_size, hidden_size, output_size, **kwargs):
        super(GGSNN, self).__init__(**kwargs)
        
        # Initializing model components
        self.vertex_init = nn.Linear(vertex_feature_size, hidden_size)
        self.message_fxn = nn.Linear(hidden_size, hidden_size, bias=False)
        self.gru = nn.GRUCell(hidden_size, hidden_size)
        
        # Readout networks: readout_1 is f_1 (the gate), readout_2 is f_2, readout_final is f_out
        self.readout_1 = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size * 2),
            nn.Tanh(),
            nn.Linear(hidden_size * 2, hidden_size)
        )
        self.readout_2 = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, hidden_size)
        )
        self.readout_final = nn.Linear(hidden_size, output_size)
                
    def init_hidden_states_and_edges(self, graph):
        # vertex_features is a (num_vertices x num_features) tensor
        # edges is a (num_vertices x num_vertices) sparse adjacency tensor
        # batch_sizes lists the number of vertices in each graph that was combined into this one
        vertex_features, edges, batch_sizes = graph
        init_hidden_states = torch.tanh(self.vertex_init(vertex_features))
        # Saving these for use in the readout function later: not every MPNN requires this, but GGSNNs do
        self.init_hidden_states = init_hidden_states.clone()
        self.batch_sizes = list(batch_sizes)
        return init_hidden_states, edges
    
    def compute_messages(self, hidden_states, edges, t):
        # Apply W_msg_fxn to every vertex's hidden state...
        passed_msgs = self.message_fxn(hidden_states)
        # ...then sum over each vertex's neighbors with one sparse adjacency-matrix product
        summed_msgs = torch.sparse.mm(edges, passed_msgs)
        return summed_msgs
    
    def update_hidden_states(self, hidden_states, messages, t):
        # GRUCell takes (input, hidden): the message is the input, the vertex state is the hidden state
        hidden_states = self.gru(messages, hidden_states)
        return hidden_states
    
    def readout(self, hidden_states):
        # Per-vertex gated vectors: sigmoid(f_1([h_v^T, h_v^0])) * f_2(h_v^T)
        readout_in_1 = torch.cat([hidden_states, self.init_hidden_states], dim=1)
        readout_hid_1 = torch.sigmoid(self.readout_1(readout_in_1))
        readout_hid_2 = self.readout_2(hidden_states)
        readout_hid = readout_hid_1 * readout_hid_2
        # Sum the gated vectors separately for each graph in the minibatch
        per_graph = torch.split(readout_hid, self.batch_sizes, dim=0)
        readout_attention = torch.stack([chunk.sum(dim=0) for chunk in per_graph])
        # Returns logits; the softmax is applied by the loss function or at evaluation time
        return self.readout_final(readout_attention)
```
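`update_hidden_states` hands all of the work of $U_t$ to `nn.GRUCell`.  For intuition about what that update does, here's a NumPy sketch of the GRU equations given in the PyTorch `GRUCell` documentation, applied to all vertices at once.  The weights are random placeholders; the sketch shows the gating structure, not PyTorch's parameter layout or initialization, and `gru_update` is a hypothetical helper name.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_update(m, h, params):
    '''One GRU step for every vertex at once.
    m: (num_vertices, input_size) summed messages; h: (num_vertices, hidden_size) hidden states.'''
    W_ir, W_iz, W_in, W_hr, W_hz, W_hn, b_ir, b_iz, b_in, b_hr, b_hz, b_hn = params
    r = sigmoid(m @ W_ir.T + b_ir + h @ W_hr.T + b_hr)        # reset gate
    z = sigmoid(m @ W_iz.T + b_iz + h @ W_hz.T + b_hz)        # update gate
    n = np.tanh(m @ W_in.T + b_in + r * (h @ W_hn.T + b_hn))  # candidate state
    return (1 - z) * n + z * h  # per-unit blend of candidate and previous state

rng = np.random.default_rng(0)
num_vertices, hidden_size = 4, 3
# Placeholder parameters: six weight matrices, then six bias vectors
params = [rng.normal(size=(hidden_size, hidden_size)) for _ in range(6)] + \
         [rng.normal(size=hidden_size) for _ in range(6)]
h = np.tanh(rng.normal(size=(num_vertices, hidden_size)))  # like tanh-initialized h_v^0
m = rng.normal(size=(num_vertices, hidden_size))           # summed messages m_v^{t+1}

h_next = gru_update(m, h, params)
print(h_next.shape)  # (4, 3)
```

Since the new state is a blend of a tanh candidate and the old state, hidden states that start in $(-1, 1)$ (as the tanh initialization guarantees) stay there across rounds, and an update gate near 1 lets a vertex carry its state through a round almost unchanged.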

### Let's train!

We'll create a new GGSNN instance and initialize our model parameters, loss function, and optimizer as usual:

```python
model = GGSNN(vertex_feature_size=75, hidden_size=100, output_size=2, n_msg_pass_iters=6)
model.to(device)
# Initialize weights: small Gaussian noise for weight matrices, zeros for biases
for param in model.parameters():
    if param.dim() > 1:
        nn.init.normal_(param, std=0.01)
    else:
        nn.init.constant_(param, 0)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0002)
```
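The model's `readout` returns raw scores (logits) rather than probabilities because `nn.CrossEntropyLoss` takes care of the softmax: by default it computes a log-softmax over each row and averages the negative log-probability of the true class over the batch.  Here's that computation written out in NumPy as a sketch; `cross_entropy_from_logits` is a hypothetical name, and subtracting the row maximum is a standard numerical-stability trick that doesn't change the result.

```python
import numpy as np

def cross_entropy_from_logits(logits, labels):
    '''Mean negative log-softmax probability of the true class.'''
    shifted = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

# Two molecules, two classes (0 = "not toxic", 1 = "toxic")
logits = np.array([[0.0, 0.0],    # undecided: softmax gives 50/50
                   [2.0, -1.0]])  # confident "not toxic"
labels = np.array([0, 0])
print(round(float(cross_entropy_from_logits(logits, labels)), 4))  # 0.3709
```

The undecided row costs $\log 2 \approx 0.693$ and the confident, correct row costs only $\log(1 + e^{-3}) \approx 0.049$, so the loss rewards confidence in the right class even when the argmax prediction (and hence accuracy) is the same.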

Then we'll add a few helper functions to keep the training loop code clean:

```python
def batchify_graphs(graphs):
    '''
    Args:
        graphs: List of graph dicts with 'vertex_features', 'adj_mat', and 'label' keys
        
    Returns:
        The input graphs combined into one big disconnected graph, as a
            (vertex_features, sparse adjacency tensor, batch_sizes) tuple
        The labels of the input graphs, as a tensor
    '''
    vertex_features = np.concatenate([g['vertex_features'] for g in graphs])
    vertex_features = torch.tensor(vertex_features, dtype=torch.float32, device=device)
    
    # Block-diagonal adjacency: no edges connect vertices from different graphs
    adj_mat = sp.sparse.block_diag([g['adj_mat'] for g in graphs]).tocoo()
    indices = torch.tensor(np.vstack([adj_mat.row, adj_mat.col]), dtype=torch.long, device=device)
    values = torch.tensor(adj_mat.data, dtype=torch.float32, device=device)
    # A coalesced COO tensor works with torch.sparse.mm and avoids the beta-state warning
    # that some PyTorch versions emit for sparse CSR tensors
    adj_mat = torch.sparse_coo_tensor(indices, values, adj_mat.shape, device=device).coalesce()
    
    batch_sizes = [g['vertex_features'].shape[0] for g in graphs]
    labels = torch.tensor([g['label'] for g in graphs], dtype=torch.long, device=device)
    return (vertex_features, adj_mat, batch_sizes), labels
```

```python
def evaluate_accuracy(dataset, model, n_batch):
    '''
    Measures the accuracy of the model on the provided dataset, in batches
    '''
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for i in range(0, math.ceil(len(dataset)/n_batch)):
            data = dataset[n_batch*i:n_batch*(i+1)]
            graph, label = batchify_graphs(data)
            output = model(graph)
            predicted = output.argmax(dim=1)  # index of the highest-scoring class
            total += label.size(0)
            correct += (predicted == label).sum().item()
    
    return correct / total
```

```python
def evaluate_roc_score(dataset, model, n_batch):
    '''
    Measures the area under the ROC curve of the model on the provided dataset, in batches
    '''
    model.eval()
    pos_probs = []
    labels = []
    
    with torch.no_grad():
        for i in range(0, math.ceil(len(dataset)/n_batch)):
            data = dataset[n_batch*i:n_batch*(i+1)]
            graph, label = batchify_graphs(data)
            output = model(graph)
            probs = F.softmax(output, dim=1)
            pos_probs.append(probs[:, 1])  # predicted probability of the "toxic" class
            labels.append(label)
    
    labels = torch.cat(labels, dim=0).cpu().numpy()
    pos_probs = torch.cat(pos_probs, dim=0).cpu().numpy()
    return metrics.roc_auc_score(labels, pos_probs)
```

Notice that the class balance in the dataset is heavily skewed toward the "not toxic" label:

```python
labels = np.array([i['label'] for i in train_dataset])
print('Percentage of "not toxic" labels in training data = {}'.format(sum(labels == 0)/len(labels)))
```

Percentage of "not toxic" labels in training data = 0.8859075043630017


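This is why in the training loop below we're measuring the [ROC AUC](https://en.wikipedia.org/wiki/Receiver_operating_characteristic) in addition to just the accuracy.  A model that always answers "not toxic" would already reach about 88.6% training accuracy without learning anything, whereas ROC AUC measures how well the model ranks toxic molecules above non-toxic ones, regardless of where the decision threshold is set.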
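Here is a small NumPy sketch, on made-up toy labels rather than the Tox21 data, of why the two metrics disagree for an uninformative model.  It computes ROC AUC from its probabilistic definition: the probability that a randomly chosen positive example is scored higher than a randomly chosen negative one, with ties counted as half.  For binary labels this gives the same number as `sklearn.metrics.roc_auc_score`; the helper names `accuracy` and `pairwise_roc_auc` are hypothetical.

```python
import numpy as np

def accuracy(labels, scores, threshold=0.5):
    '''Fraction of examples whose thresholded score matches the label.'''
    return np.mean((scores >= threshold) == labels)

def pairwise_roc_auc(labels, scores):
    '''P(score of a random positive > score of a random negative), ties counted as 1/2.'''
    pos = scores[labels == 1]
    neg = scores[labels == 0]
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return (greater + 0.5 * ties) / (len(pos) * len(neg))

# Imbalanced toy labels: 8 negatives, 2 positives
labels = np.array([0, 0, 0, 0, 0, 0, 0, 1, 1, 0])
constant = np.full(len(labels), 0.1)  # same score for every example: always "not toxic"
print(accuracy(labels, constant), pairwise_roc_auc(labels, constant))  # 0.8 0.5

# A model that ranks most (but not all) positives above negatives
print(pairwise_roc_auc(np.array([0, 0, 1, 1]), np.array([0.1, 0.4, 0.35, 0.8])))  # 0.75
```

The constant-score model matches the majority-class rate on accuracy but gets an AUC of exactly 0.5, the same as random guessing, so only the AUC exposes that it has learned nothing.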

Now let's train!

```python
n_epochs = 30
n_batch = 128

for e in range(n_epochs):
    model.train()
    cumulative_loss = 0
    
    for i in range(0, math.ceil(len(train_dataset)/n_batch)):
        data = train_dataset[n_batch*i:n_batch*(i+1)]
        graph, label = batchify_graphs(data)
        
        optimizer.zero_grad()
        output = model(graph)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()
        
        # The loss is a batch mean, so weight it by batch size to report a per-molecule average
        cumulative_loss += loss.item() * len(data)
    
    valid_accuracy = evaluate_accuracy(valid_dataset, model, n_batch)
    train_accuracy = evaluate_accuracy(train_dataset, model, n_batch)
    valid_roc = evaluate_roc_score(valid_dataset, model, n_batch)
    train_roc = evaluate_roc_score(train_dataset, model, n_batch)
    
    print('Epoch {}. Loss: {:.4f}, \n\tTrain_acc {:.4f}, Valid_acc {:.4f}\n\tTrain_roc_auc {:.4f}, Valid_roc_auc {:.4f}'.format(
            e, cumulative_loss/len(train_dataset), train_accuracy, valid_accuracy, train_roc, valid_roc))
    
print('Test Accuracy: {:.4f}'.format(evaluate_accuracy(test_dataset, model, n_batch)))
print('Test ROC AUC: {:.4f}'.format(evaluate_roc_score(test_dataset, model, n_batch)))
```



Epoch 0. Loss: 0.6012, 
	Train_acc 0.8859, Valid_acc 0.8768
	Train_roc_auc 0.3968, Valid_roc_auc 0.4038
Epoch 1. Loss: 0.4381, 
	Train_acc 0.8859, Valid_acc 0.8768
	Train_roc_auc 0.4104, Valid_roc_auc 0.4195
Epoch 2. Loss: 0.4287, 
	Train_acc 0.8859, Valid_acc 0.8768
	Train_roc_auc 0.4228, Valid_roc_auc 0.4339
Epoch 3. Loss: 0.4212, 
	Train_acc 0.8859, Valid_acc 0.8768
	Train_roc_auc 0.4385, Valid_roc_auc 0.4515
Epoch 4. Loss: 0.4132, 
	Train_acc 0.8859, Valid_acc 0.8768
	Train_roc_auc 0.4618, Valid_roc_auc 0.4782
Epoch 5. Loss: 0.4024, 
	Train_acc 0.8859, Valid_acc 0.8768
	Tra

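Alright!  Unsurprisingly, given the class imbalance, accuracy doesn't budge: through the epochs shown, training accuracy sits at exactly the fraction of "not toxic" labels in the training set (0.8859), which is what you'd get by predicting "not toxic" for every molecule.  ROC AUC is the metric to watch, and it climbs steadily over these epochs, although at epoch 4 it is still below 0.5 (the score of random guessing), so the model needs more training before its rankings become useful.  For comparison, published results on this dataset are listed in the physiology section [here](http://moleculenet.ai/latest-results).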

Now go forth and invent your own types of MPNNs!