# Predicting Molecular Properties with Graph Neural Networks

This notebook provides you with a complete code example to predict the properties of small molecules using graph convolutional layers with message passing.

## Understanding Graph Convolutions

Calculate the adjacency matrix of a simple cycle graph ...

In [1]:
import numpy as np

node_num = 5

# Directed cycle graph: node i has a single edge pointing to node
# (i + 1) mod node_num, so the last node connects back to the first.
A = np.zeros((node_num, node_num))
sources = np.arange(node_num)
A[sources, (sources + 1) % node_num] = 1

In [None]:
# Show the adjacency matrix built in the previous cell.
print(f"A: {A}")

... assign node features ...

In [3]:
# All node features start at zero, except for a unit feature on node 0.
x = np.zeros(node_num)
x[0] = 1

In [None]:
# Show the initial node features.
print("x:", x)

... implement a function to perform a graph convolution ...

In [5]:
def graph_convolution(A, x):
    """Propagate node features one step along the edges of a graph.

    Each node ``j`` receives the sum of the features of every node ``i``,
    weighted by the edge from ``i`` to ``j``:
    ``conv[j] = sum_i A[i, j] * x[i]``.

    Parameters
    ----------
    A : array_like, shape (n, n)
        Adjacency matrix; ``A[i, j]`` is the weight of the edge from node
        ``i`` to node ``j``.
    x : array_like, shape (n,)
        Node features.

    Returns
    -------
    numpy.ndarray
        Propagated node features as floats, with the same length as ``x``.
    """
    # The graph size comes from the inputs themselves rather than from the
    # notebook-level `node_num`, so the function works for any graph size.
    A = np.asarray(A, dtype=float)
    x = np.asarray(x, dtype=float)
    # Summing A[i, j] * x[i] over i is the product of A transposed with x.
    return A.T @ x

... and apply the graph convolution.

In [None]:
# Apply the graph convolution five times in a row. `x` is overwritten at
# each step, so re-running this cell continues from the last result.
for c in range(5):
    x = graph_convolution(A, x)
    print(f"Convolution {c + 1}: ", x)

## Predicting Molecular Properties with Graph Convolutions

### Implementing a Graph Convolution Layer

Define a class to implement a graph convolution ...

In [7]:
import torch.nn as nn

class GraphConvolution(nn.Module):
    """Parameter-free propagation step of a graph convolution."""

    def forward(self, A, node_attr):
        """Return the matrix product of `A` with `node_attr`.

        Row `i` of the result is the sum of the rows of `node_attr`,
        weighted by row `i` of `A`.
        """
        propagated_node_attr = A.matmul(node_attr)
        return propagated_node_attr

... and a class to implement the graph convolution layer.

In [8]:
import deeplay as dl
import torch

class GCL(dl.DeeplayModule):
    """Graph convolution layer.

    Applies three steps to the node attributes: a linear transform, a
    propagation over the normalized adjacency matrix (with self-loops
    added), and a ReLU update.
    """

    def __init__(self, in_feats, out_feats):
        """Initialize graph convolution layer.

        Parameters
        ----------
        in_feats : int
            Number of input features per node.
        out_feats : int
            Number of output features per node.
        """
        super().__init__()
        self.transform = dl.Layer(nn.Linear, in_feats, out_feats)
        self.propagate = dl.Layer(GraphConvolution)
        self.update = dl.Layer(nn.ReLU)

    def diagonalize(self, A):
        """Add the identity to the adjacency matrix (self-loops)."""
        # Create the identity directly on A's device instead of building it
        # on the CPU and copying it over.
        return A + torch.eye(A.size(0), device=A.device)

    def normalize(self, A):
        """Normalize adjacency matrix symmetrically by the node degrees.

        Returns ``D @ A @ D``, where ``D`` is the diagonal matrix holding
        the inverse square root of each row sum of ``A``.
        """
        node_degrees = torch.sum(A, dim=1)
        inv_sqrt_node_degrees = node_degrees.pow(-0.5)
        # Nodes with zero degree give an infinite factor; zero it out so
        # that their rows and columns stay empty instead of becoming NaN.
        inv_sqrt_node_degrees[inv_sqrt_node_degrees == float("inf")] = 0
        degree_matrix = torch.diag(inv_sqrt_node_degrees)
        return degree_matrix @ A @ degree_matrix

    def forward(self, A, node_attr):
        """Transform, propagate and update the node attributes.

        Parameters
        ----------
        A : torch.Tensor
            Square adjacency matrix (without self-loops; they are added
            here).
        node_attr : torch.Tensor
            Node attributes, one row per node.

        Returns
        -------
        torch.Tensor
            Updated node attributes, one row per node.
        """
        A = self.normalize(self.diagonalize(A))
        transformed_node_attr = self.transform(node_attr)
        propagated_node_attr = self.propagate(A, transformed_node_attr)
        updated_node_attr = self.update(propagated_node_attr)
        return updated_node_attr

### Using the ZINC Dataset

Download the ZINC dataset ...

In [9]:
from torch_geometric.datasets import ZINC

# Get the train, validation, and test splits of the ZINC subset
# (`subset=True`), all stored under the same root folder.
train_set = ZINC(root="ZINC_dataset/", subset=True, split="train")
val_set = ZINC(root="ZINC_dataset/", subset=True, split="val")
test_set = ZINC(root="ZINC_dataset/", subset=True, split="test")

... implement a function to plot the adjacency matrix of a molecule ...

In [10]:
import matplotlib.pyplot as plt
from matplotlib import colormaps
from torch_geometric.utils import to_dense_adj

def plot_molecule(molecule):
    """Show the adjacency matrix of a molecule, labeled by atom type.

    The figure title reports the value stored in `molecule["y"]` as logP,
    and both axes are ticked with the atom type numbers in `molecule["x"]`.
    """
    atom_types = molecule["x"].numpy().squeeze()  # Atom type numbers.
    adjacency = to_dense_adj(molecule["edge_index"]).numpy().squeeze(0)
    logP = molecule["y"].item()
    tick_positions = np.arange(len(atom_types))

    plt.matshow(adjacency, cmap=colormaps["gray"].reversed())
    plt.title(f"LogP={np.round(logP, 2)}", fontsize=24)
    plt.xlabel("Atom type", fontsize=16)
    plt.ylabel("Atom type", fontsize=16)
    plt.xticks(tick_positions, atom_types, fontsize=12)
    plt.yticks(tick_positions, atom_types, fontsize=12)
    # Move the x ticks from the top (matshow default) to the bottom.
    plt.gca().xaxis.set_ticks_position("bottom")
    plt.show()

... and use it to visualize the structures of some molecules.

In [None]:
# Plot three molecules of the training set, picked by index.
for molecule_index in [2, 1235, 9887]:
    plot_molecule(molecule=train_set[molecule_index])

### Implementing a Graph Convolutional Network

Create a class to implement a graph convolutional network ...

In [12]:
class GCN(dl.DeeplayModule):
    """Graph convolutional network.

    Embeds the node attributes, passes them through a stack of graph
    convolution layers, averages them over the nodes of each graph, and
    maps the result to the output with a small dense network.
    """

    def __init__(self, num_atoms, embed_dim, hidden_feats, out_feats):
        """Initialize graph convolutional network.

        Parameters
        ----------
        num_atoms : int
            Number of entries of the node embedding.
        embed_dim : int
            Dimension of the node embedding.
        hidden_feats : list of int
            Output features of each graph convolution layer (one layer per
            entry).
        out_feats : int
            Number of output features per graph.
        """
        super().__init__()
        self.node_embedding = dl.Layer(nn.Embedding, num_atoms, embed_dim)
        self.blocks = dl.LayerList()
        # Each layer takes as input the output of the previous one; the
        # first layer takes the embedding.
        for f_in, f_out in zip([embed_dim, *hidden_feats[:-1]],
                               hidden_feats):
            self.blocks.append(GCL(in_feats=f_in, out_feats=f_out))
        self.dense_top = dl.Sequential(
            dl.Layer(nn.Linear, hidden_feats[-1], hidden_feats[-1] // 4),
            dl.Layer(nn.ReLU),
            dl.Layer(nn.Linear, hidden_feats[-1] // 4, out_feats),
        )

    def forward(self, G):
        """Predict graph properties.

        Reads `G["node_attr"]`, `G["A"]`, and `G["graph_ids"]`. Note that
        `G["node_attr"]` is overwritten in place with the node
        representations computed by the network.
        """
        G["node_attr"] = self.node_embedding(G["node_attr"])
        for block in self.blocks:
            G["node_attr"] = block(G["A"], G["node_attr"])

        # Average the node attributes over the nodes of each graph: sum
        # them per graph id, then divide by the number of nodes per graph.
        num_graphs = torch.max(G["graph_ids"]) + 1
        aggregated_node_attr = torch.zeros(
            num_graphs, G["node_attr"].shape[1], device=G["node_attr"].device,
        )
        aggregated_node_attr = aggregated_node_attr.scatter_add(
            0, G["graph_ids"][:, None].expand_as(G["node_attr"]),
            G["node_attr"],
        )
        node_counts = torch.bincount(G["graph_ids"])
        aggregated_node_attr = aggregated_node_attr / node_counts[:, None]

        # NOTE(review): `.squeeze()` removes every size-1 dimension, so a
        # batch with a single graph also loses its batch dimension.
        return self.dense_top(aggregated_node_attr).squeeze()  # LogP.

... instantiate it ...

In [None]:
# Four graph convolution layers with 64 features each, on top of a
# 64-dimensional node embedding with 28 entries; a single output per graph.
gcn_model = GCN(num_atoms=28, embed_dim=64, hidden_feats=[64,] * 4, 
                out_feats=1).create()

print(gcn_model)

... define the data loaders ...

In [14]:
from torch_geometric.loader import DataLoader

# All loaders use batches of 32 graphs; only the training loader shuffles.
train_loader = DataLoader(dataset=train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(dataset=val_set, batch_size=32, shuffle=False)
test_loader = DataLoader(dataset=test_set, batch_size=32, shuffle=False)

... implement a class to compile, train, and evaluate the graph convolutional network ...

In [15]:
class MolecularRegressor(dl.Regressor):
    """Regressor that feeds batched molecular graphs to a graph model."""

    def __init__(self, model, **kwargs):
        """Set up the regressor around `model`, forwarding other options."""
        super().__init__(model, **kwargs)

    def batch_preprocess(self, G):
        """Add the entries read by the model and move the batch to device.

        Stores the squeezed `G["x"]` as `G["node_attr"]`, the dense
        adjacency matrix built from `G["edge_index"]` as `G["A"]`, and
        `G["batch"]` as `G["graph_ids"]`.
        """
        node_attr = G["x"].squeeze()
        A = to_dense_adj(G["edge_index"]).squeeze(0)
        graph_ids = G["batch"]
        G["node_attr"], G["A"], G["graph_ids"] = node_attr, A, graph_ids
        return G.to(self.device)

    def forward(self, G):
        """Preprocess the graph batch and return the model output for it."""
        preprocessed = self.batch_preprocess(G)
        return self.model(preprocessed)

... and train the graph convolutional network.

In [None]:
from lightning.pytorch.callbacks import ModelCheckpoint
import os

# Train with the L1 loss (mean absolute error) and Adam.
gcn = MolecularRegressor(gcn_model, loss=nn.L1Loss(), 
                         optimizer=dl.Adam(lr=1e-3)).create()
# Save checkpoints under models/gcn, selected by the monitored "val_loss";
# the file name records the epoch and the validation loss.
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss", dirpath=os.path.join("models", "gcn"), 
    filename="ZINC-{epoch:02d}-{val_loss:.2f}", auto_insert_metric_name=False,
)
# Long-running cell: trains for up to 400 epochs.
trainer = dl.Trainer(max_epochs=400, callbacks=[checkpoint_callback])
trainer.fit(gcn, train_loader, val_loader)

### Evaluating the Trained Graph Convolutional Network

Evaluate the performance of the trained graph convolutional network ...

In [None]:
import glob

# Pick the most recently modified checkpoint among those saved under
# models/gcn, load it into the regressor, and evaluate it on the test set.
model_paths = glob.glob(os.path.join("models", "gcn", "ZINC-*.ckpt"))
best_model_path = sorted(model_paths, key=os.path.getmtime)[-1]
gcn_best = MolecularRegressor.load_from_checkpoint(best_model_path, 
                                                   model=gcn_model)
test_results = trainer.test(gcn_best, test_loader)

... obtain the predicted and actual logP ...

In [18]:
# Collect the ground-truth and predicted logP of every test batch.
logP_gts, logP_preds = [], []
# Gradients are not needed for inference; disabling them avoids keeping an
# autograd graph alive for every batch stored in `logP_preds`.
with torch.no_grad():
    for G in test_loader:
        # Take the targets out of the batch before passing it to the model.
        logP_gts.append(G.pop("y"))
        logP_preds.append(gcn_best(G))
logP_gts = torch.cat(logP_gts).cpu().numpy()
logP_preds = torch.cat(logP_preds).detach().cpu().numpy()

... and plot the predicted logP versus their ground truth values.

In [None]:
# 2D histogram of the predictions (x axis) against the ground truth (y axis).
heatmap, xedges, yedges = np.histogram2d(logP_preds, logP_gts, bins=50)
extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]

# Dashed identity line spanning the range of the ground-truth values.
logP_range = [min(logP_gts), max(logP_gts)]
plt.plot(logP_range, logP_range, "r--")
# The histogram is transposed so that its first axis runs along x.
plt.imshow(heatmap.T, extent=extent, origin="lower")
plt.xlabel("Predicted LogP")
plt.ylabel("True LogP")
plt.show()

## Predicting Molecular Properties with Message Passing

### Implementing a Message Passing Layer

Define a class implementing a layer that performs the transform step ...

In [20]:
class TransformLayer(nn.Module):
    """Transform step of a message passing layer: one message per edge."""

    def __init__(self, hidden_feats):
        """Create the dense layer and activation that produce the messages.

        `hidden_feats` is the number of features of each message.
        """
        super().__init__()
        self.linear = nn.LazyLinear(hidden_feats)
        self.activation = nn.ReLU()

    def forward(self, G):
        """Store the edge messages in `G["msg"]` and return `G`.

        For each edge, the attributes of its source node, of its target
        node, and of the edge itself are concatenated and passed through
        the dense layer and the activation.
        """
        sources, targets = G["edge_index"][0], G["edge_index"][1]
        node_attr = G["node_attr"]
        edge_inputs = torch.cat(
            (node_attr[sources], node_attr[targets], G["edge_attr"]), dim=-1,
        )
        G["msg"] = self.activation(self.linear(edge_inputs))
        return G

... a class to implement a layer to perform the propagate step ...

In [21]:
class PropagateLayer(nn.Module):
    """Propagate step of a message passing layer: sum messages per node."""

    def __init__(self, hidden_feats):
        """Store the number of features of the aggregated messages."""
        super().__init__()
        self.hidden_feats = hidden_feats

    def forward(self, G):
        """Store the summed incoming messages in `G["aggregated_msg"]`.

        Every message in `G["msg"]` is added to the row of its target node
        (second row of `G["edge_index"]`). Nodes without incoming edges
        keep a row of zeros. Returns `G`.
        """
        node_attr, msg = G["node_attr"], G["msg"]
        # Repeat each target index across the feature axis, as required by
        # `scatter_add`.
        targets = G["edge_index"][1][:, None].expand_as(msg)
        # Zeros sharing dtype and device with the node attributes.
        summed_msg = node_attr.new_zeros(node_attr.size(0), self.hidden_feats)
        G["aggregated_msg"] = summed_msg.scatter_add(0, targets, msg)
        return G

... a class to implement a layer to perform the update step ...

In [22]:
class UpdateLayer(nn.Module):
    """Update step of a message passing layer."""

    def __init__(self, hidden_feats):
        """Create the dense layer and activation that update the nodes.

        `hidden_feats` is the number of features of the updated nodes.
        """
        super().__init__()
        self.linear = nn.LazyLinear(hidden_feats)
        self.activation = nn.ReLU()

    def forward(self, G):
        """Update the node and edge attributes of `G` and return it.

        The new node attributes are computed from the current ones
        concatenated with the aggregated messages; the messages become the
        new edge attributes.
        """
        combined = torch.cat((G["node_attr"], G["aggregated_msg"]), dim=-1)
        updated_node_attr = self.activation(self.linear(combined))
        G["node_attr"] = updated_node_attr
        G["edge_attr"] = G["msg"]
        return G

... and use them to define a message passing layer.

In [None]:
# A message passing layer chains the three steps in order: transform
# (compute messages), propagate (aggregate them), and update (apply them).
mpl = dl.Sequential(
    dl.Layer(TransformLayer, hidden_feats=64),
    dl.Layer(PropagateLayer, hidden_feats=64),
    dl.Layer(UpdateLayer, hidden_feats=64),
).create()

print(mpl)

### Implementing a Message Passing Network

Create a class to implement a message passing network ...

In [24]:
class MPN(dl.DeeplayModule):
    """Message passing network.

    Embeds the node and edge attributes, passes the graph through a stack
    of message passing layers (transform, propagate, update), averages the
    node attributes over the nodes of each graph, and maps the result to
    the output with a small dense network.
    """

    def __init__(self, num_node_embed, num_edge_embed, embed_dim, hidden_feats,
                 out_feats):
        """Initialize message passing network.

        Parameters
        ----------
        num_node_embed : int
            Number of entries of the node embedding.
        num_edge_embed : int
            Number of entries of the edge embedding.
        embed_dim : int
            Dimension of the node and edge embeddings.
        hidden_feats : list of int
            Features of each message passing layer (one layer per entry).
        out_feats : int
            Number of output features per graph.
        """
        super().__init__()

        self.node_embedding = dl.Layer(nn.Embedding, num_node_embed, embed_dim)
        self.edge_embedding = dl.Layer(nn.Embedding, num_edge_embed, embed_dim)

        self.blocks = dl.LayerList()
        for f_out in hidden_feats:
            mpl = dl.Sequential(
                dl.Layer(TransformLayer, f_out),
                dl.Layer(PropagateLayer, f_out),
                dl.Layer(UpdateLayer, f_out),
            )
            self.blocks.append(mpl)

        self.dense_top = dl.Sequential(
            dl.Layer(nn.Linear, hidden_feats[-1], hidden_feats[-1] // 4),
            dl.Layer(nn.ReLU),
            dl.Layer(nn.Linear, hidden_feats[-1] // 4, out_feats),
        )

    def forward(self, G):
        """Calculate forward pass.

        Reads `G["node_attr"]`, `G["edge_attr"]`, and `G["graph_ids"]`,
        plus whatever the message passing layers read. Note that the node
        and edge attributes of `G` are overwritten in place.
        """
        G["node_attr"] = self.node_embedding(G["node_attr"])
        G["edge_attr"] = self.edge_embedding(G["edge_attr"])
        for block in self.blocks:
            G = block(G)

        # Average the node attributes over the nodes of each graph: sum
        # them per graph id, then divide by the number of nodes per graph.
        batch_size = torch.max(G["graph_ids"]) + 1
        aggregated_node_feats = torch.zeros(batch_size,
                                            G["node_attr"].shape[1],
                                            device=G["node_attr"].device)
        aggregated_node_feats = aggregated_node_feats.scatter_add(
            0, G["graph_ids"][:, None].expand_as(G["node_attr"]),
            G["node_attr"],
        )
        node_counts = torch.bincount(G["graph_ids"])
        aggregated_node_feats = aggregated_node_feats / node_counts[:, None]

        # NOTE(review): `.squeeze()` removes every size-1 dimension, so a
        # batch with a single graph also loses its batch dimension.
        return self.dense_top(aggregated_node_feats).squeeze()  # LogP.

... instantiate the message passing network ...

In [None]:
# Four message passing layers with 64 features each, on top of
# 64-dimensional embeddings with 28 node entries and 4 edge entries;
# a single output per graph.
mpn_model = MPN(num_node_embed=28, num_edge_embed=4, embed_dim=64, 
                hidden_feats=[64,] * 4, out_feats=1).create()

print(mpn_model)

... train the message passing network ...

In [None]:
# Train with the L1 loss (mean absolute error) and Adam.
mpn = MolecularRegressor(
    mpn_model, loss=nn.L1Loss(), optimizer=dl.Adam(lr=1e-3),
).create()
# Save checkpoints under models/mpn, selected by the monitored "val_loss";
# the file name records the epoch and the validation loss.
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss", dirpath=os.path.join("models", "mpn"),
    filename="ZINC-{epoch:02d}-{val_loss:.2f}",auto_insert_metric_name=False,
)
# Long-running cell: trains for up to 400 epochs.
trainer = dl.Trainer(max_epochs=400, callbacks=[checkpoint_callback])
trainer.fit(mpn, train_loader, val_loader)

... evaluate the performance of the trained message passing network ...

In [None]:
# Pick the most recently modified checkpoint among those saved under
# models/mpn, load it into the regressor, and evaluate it on the test set.
model_paths = glob.glob(os.path.join("models", "mpn", "ZINC-*.ckpt"))
best_model_path = sorted(model_paths, key=os.path.getmtime)[-1]
mpn_best = MolecularRegressor.load_from_checkpoint(best_model_path, 
                                                   model=mpn_model)
test_results = trainer.test(mpn_best, test_loader)

... obtain the predicted and actual logP ...

In [28]:
# Collect the ground-truth and predicted logP of every test batch.
logP_gts, logP_preds = [], []
# Gradients are not needed for inference; disabling them avoids keeping an
# autograd graph alive for every batch stored in `logP_preds`.
with torch.no_grad():
    for G in test_loader:
        # Take the targets out of the batch before passing it to the model.
        logP_gts.append(G.pop("y"))
        logP_preds.append(mpn_best(G))
logP_gts = torch.cat(logP_gts).cpu().numpy()
logP_preds = torch.cat(logP_preds).detach().cpu().numpy()

... and plot the predicted logP versus their ground truth values.

In [None]:
# 2D histogram of the predictions (x axis) against the ground truth (y axis).
heatmap, xedges, yedges = np.histogram2d(logP_preds, logP_gts, bins=50)
extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]

# Dashed identity line spanning the range of the ground-truth values.
logP_range = [min(logP_gts), max(logP_gts)]
plt.plot(logP_range, logP_range, "r--")
# The histogram is transposed so that its first axis runs along x.
plt.imshow(heatmap.T, extent=extent, origin="lower")
plt.xlabel("Predicted LogP")
plt.ylabel("True LogP")
plt.show()