# Predicting Molecular Properties with Graph Neural Networks

This notebook provides you with a complete code example to predict the properties of small molecules using graph convolutional layers with message passing.

## Understanding Graph Convolutions

Calculate the adjacency matrix of a simple cycle graph ...

In [1]:
import numpy as np

node_num = 5

# Directed cycle: node i has exactly one outgoing edge, to node (i + 1) mod
# node_num, so row i of A has a single 1 in that column.
sources = np.arange(node_num)
targets = (sources + 1) % node_num
A = np.zeros((node_num, node_num))
A[sources, targets] = 1

In [None]:
# Show the adjacency matrix built above.
print("A: {}".format(A))

... assign node features ...

In [3]:
# One scalar feature per node; only node 0 starts with a non-zero value.
x = np.zeros(node_num, dtype=float)
x[0] = 1.0

In [None]:
# Show the initial node features.
print(f"x: {x}")

... implement a function to perform a graph convolution ...

In [5]:
def graph_convolution(A, x):
    """Calculate graph convolution.

    Every node j receives the sum of the features of the nodes i that have an
    edge i -> j, weighted by A[i, j]; that is, ``conv[j] = sum_i A[i, j] * x[i]``.

    The graph size is taken from the arguments themselves, so the function
    works for any graph and not only for the module-level ``node_num``.

    Parameters
    ----------
    A : array-like, shape (n, n)
        Adjacency matrix, with A[i, j] the weight of the edge from i to j.
    x : array-like, shape (n,) or (n, f)
        Node features (one scalar or one feature vector per node).

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as ``x`` holding the aggregated
        features.
    """
    A = np.asarray(A, dtype=float)
    x = np.asarray(x, dtype=float)
    # Summing over the source index i for each target j is a product with A^T.
    return A.T @ x

... and apply the graph convolution.

In [None]:
# Apply the convolution five times in a row, feeding each result back in.
for step in range(1, 6):
    x = graph_convolution(A, x)
    print(f"Convolution {step}: ", x)

## Predicting Molecular Properties with Graph Convolutions

### Implementing a Graph Convolution Layer

Define a class to implement a graph convolution ...

In [7]:
import torch.nn as nn

class GraphConvolution(nn.Module):
    """Parameter-free propagation step of a graph convolution."""

    def forward(self, A, node_feats):
        """Return the node features aggregated through the matrix product A @ node_feats."""
        propagated = A @ node_feats
        return propagated

... and a class to implement the graph convolution layer.

In [8]:
import deeplay as dl
import torch

class GCL(dl.DeeplayModule):
    """Graph convolution layer.

    Applies a linear transform to the node features, propagates them over the
    adjacency matrix (with the identity added and symmetric degree
    normalization applied), and passes the result through a ReLU.
    """

    def __init__(self, in_feats, out_feats):
        """Initialize graph convolution layer.

        Parameters
        ----------
        in_feats : int
            Number of input features per node.
        out_feats : int
            Number of output features per node.
        """
        super().__init__()
        self.transform = dl.Layer(nn.Linear, in_feats, out_feats)
        self.propagate = dl.Layer(GraphConvolution)
        self.update = dl.Layer(nn.ReLU)

    def diagonalize(self, A):
        """Add diagonal to adjacency matrix."""
        # Create the identity directly on A's device instead of on the CPU.
        return A + torch.eye(A.size(0), device=A.device)

    def normalize(self, A):
        """Normalize adjacency matrix.

        Returns D^(-1/2) A D^(-1/2), where D is the diagonal matrix of the row
        sums of A. Rows whose sum is zero are mapped to zero rather than inf.
        """
        node_degrees = torch.sum(A, dim=1)
        inv_sqrt_node_degrees = node_degrees.pow(-0.5)
        inv_sqrt_node_degrees[inv_sqrt_node_degrees == float("inf")] = 0
        # Multiplying by a diagonal matrix on the left/right only rescales the
        # rows/columns, so broadcasting gives the same entries without building
        # a dense degree matrix and doing two full matrix products.
        return (
            inv_sqrt_node_degrees[:, None] * A * inv_sqrt_node_degrees[None, :]
        )

    def forward(self, A, node_feats):
        """Transform, propagate and update the node features.

        Parameters
        ----------
        A : torch.Tensor
            Square adjacency matrix (without the added diagonal).
        node_feats : torch.Tensor
            Node features with one row per node.

        Returns
        -------
        torch.Tensor
            Updated node features with ``out_feats`` columns.
        """
        A = self.normalize(self.diagonalize(A))
        transformed_node_feats = self.transform(node_feats)
        propagated_node_feats = self.propagate(A, transformed_node_feats)
        updated_node_feats = self.update(propagated_node_feats)
        return updated_node_feats

### Using the ZINC Dataset

Download the ZINC dataset ...

In [9]:
from torch_geometric.datasets import ZINC

# Load the three splits of the ZINC subset from the same root folder.
train_set, val_set, test_set = (
    ZINC(root="ZINC_dataset/", subset=True, split=split_name)
    for split_name in ("train", "val", "test")
)

... implement a function to plot the adjacency matrix of a molecule ...

In [41]:
import matplotlib.pyplot as plt
from matplotlib import colormaps
from torch_geometric.utils import to_dense_adj

def plot_molecule(molecule):
    """Show the dense adjacency matrix of a molecule.

    The tick labels on both axes are the entries of ``molecule["x"]`` and the
    title reports ``molecule["y"]`` rounded to two decimals.
    """
    atom_types = molecule["x"].numpy().squeeze()  # Atom type numbers.
    adjacency = to_dense_adj(molecule["edge_index"]).numpy().squeeze(0)
    logP = molecule["y"].item()
    tick_positions = np.arange(len(atom_types))

    plt.matshow(adjacency, cmap=colormaps["gray"].reversed())
    plt.title(f"LogP={np.round(logP, 2)}", fontsize=24)
    plt.xlabel("Atom type", fontsize=16)
    plt.ylabel("Atom type", fontsize=16)
    # Put the x ticks under the matrix.
    plt.gca().xaxis.set_ticks_position("bottom")
    plt.xticks(tick_positions, atom_types, fontsize=12)
    plt.yticks(tick_positions, atom_types, fontsize=12)
    plt.show()

... and use it to visualize the structures of some molecules.

In [None]:
# Plot a few molecules from the training split. The split is named
# `train_set` in this notebook; `train_dataset` is never defined.
for molecule_index in [2, 1235, 9887]:
    plot_molecule(molecule=train_set[molecule_index])

### Implementing a Graph Convolutional Network

Create a class to implement a graph convolutional network ...

In [13]:
class GCN(dl.DeeplayModule):
    """Graph convolutional network.

    Embeds the node features, applies a stack of graph convolution layers,
    averages the node features of each graph, and maps the result through a
    small dense head.
    """

    def __init__(self, num_node_embed, embed_dim, hidden_feats,
                 out_feats):
        """Initialize graph convolutional network.

        Parameters
        ----------
        num_node_embed : int
            Number of entries in the node embedding table.
        embed_dim : int
            Size of the node embeddings.
        hidden_feats : sequence of int
            Output size of each graph convolution layer (one layer per entry).
        out_feats : int
            Number of outputs of the dense head.
        """
        super().__init__()
        self.node_embedding = dl.Layer(nn.Embedding, num_node_embed, embed_dim)

        # One GCL per entry of hidden_feats: embed_dim -> hidden_feats[0] ->
        # ... -> hidden_feats[-1]. out_feats is only used by the dense head.
        self.blocks = dl.LayerList()
        for f_in, f_out in zip([embed_dim, *hidden_feats[:-1]], hidden_feats):
            self.blocks.append(GCL(in_feats=f_in, out_feats=f_out))

        self.dense_top = dl.Sequential(
            dl.Layer(nn.Linear, hidden_feats[-1], hidden_feats[-1] // 4),
            dl.Layer(nn.ReLU),
            dl.Layer(nn.Linear, hidden_feats[-1] // 4, out_feats),
        )

    def forward(self, G):
        """Calculate forward pass.

        Parameters
        ----------
        G : mapping
            Graph batch providing ``"x"`` (node indices for the embedding),
            ``"A"`` (adjacency matrix) and ``"batch"`` (graph index of each
            node). It is only read, never modified, so the same batch can be
            passed through the network more than once.

        Returns
        -------
        torch.Tensor
            One prediction per graph; the trailing axis is dropped when
            ``out_feats`` is 1.
        """
        node_feats = self.node_embedding(G["x"])
        for block in self.blocks:
            node_feats = block(G["A"], node_feats)

        # Mean-pool the node features of each graph in the batch.
        batch_index = G["batch"]
        num_graphs = int(batch_index.max()) + 1
        h = torch.zeros(num_graphs, node_feats.shape[1],
                        dtype=node_feats.dtype, device=node_feats.device)
        h = h.scatter_add(
            0, batch_index[:, None].expand_as(node_feats), node_feats
        )
        h = h / torch.bincount(batch_index)[:, None]

        # Squeeze only the feature axis so that a batch holding a single graph
        # still yields a 1-D tensor instead of a 0-D scalar.
        return self.dense_top(h).squeeze(-1)

... instantiate it ...

In [None]:
# Four graph convolution layers of width 64 on top of 64-dimensional node
# embeddings, with a single regression output.
gcn_model = GCN(
    num_node_embed=28,
    embed_dim=64,
    hidden_feats=[64, 64, 64, 64],
    out_feats=1,
).create()

print(gcn_model)

... define the data loaders ...

In [15]:
from torch_geometric.loader import DataLoader

# Only the training split is shuffled; validation and test keep their order.
train_loader = DataLoader(train_set, shuffle=True, batch_size=32)
val_loader = DataLoader(val_set, shuffle=False, batch_size=32)
test_loader = DataLoader(test_set, shuffle=False, batch_size=32)

... implement a class to compile, train, and evaluate the graph convolutional network ...

In [16]:
class MolecularRegressor(dl.Regressor):
    """Molecular regressor.

    Regressor that prepares each graph batch (flattened node features and a
    dense adjacency matrix) before passing it to the wrapped model.
    """

    def __init__(self, model, **kwargs):
        """Initialize molecular regressor."""
        super().__init__(model, **kwargs)

    def batch_preprocess(self, batch):
        """Preprocess batch.

        Drops the trailing singleton axis of ``batch["x"]``, stores the dense
        adjacency matrix of the whole batch under ``batch["A"]``, and moves
        the batch to the device of this module.
        """
        node_feats = batch["x"]
        # Remove only the trailing axis: a bare squeeze() would also remove
        # the node axis of a batch that contains a single node.
        if node_feats.dim() > 1:
            node_feats = node_feats.squeeze(-1)
        batch["x"] = node_feats
        # Fix the matrix size to the number of nodes. By default it is
        # inferred from the largest index in edge_index, which comes out too
        # small when the last nodes of the batch have no edges.
        batch["A"] = to_dense_adj(
            batch["edge_index"], max_num_nodes=node_feats.size(0)
        ).squeeze(0)
        return batch.to(self.device)

    def forward(self, batch):
        """Calculate forward pass."""
        return self.model(self.batch_preprocess(batch))

... and train the graph convolutional network.

In [None]:
from lightning.pytorch.callbacks import ModelCheckpoint
import os

# Wrap the network in a regressor trained with an L1 loss and Adam.
gcn = MolecularRegressor(
    gcn_model,
    loss=nn.L1Loss(),
    optimizer=dl.Adam(lr=1e-3),
).create()

# Checkpoints are written to models/gcn and monitored on the validation loss.
checkpoint_callback = ModelCheckpoint(
    dirpath=os.path.join("models", "gcn"),
    filename="ZINC-{epoch:02d}-{val_loss:.2f}",
    monitor="val_loss",
    auto_insert_metric_name=False,
)

trainer = dl.Trainer(callbacks=[checkpoint_callback], max_epochs=400)
trainer.fit(gcn, train_loader, val_loader)

### Evaluating the Trained Graph Convolutional Network

Evaluate the performance of the trained graph convolutional network ...

In [None]:
import glob

# Collect the saved GCN checkpoints and load the most recently modified one.
model_paths = glob.glob(os.path.join("models", "gcn", "ZINC-*.ckpt"))
best_model_path = sorted(model_paths, key=os.path.getmtime)[-1]

gcn_best = MolecularRegressor.load_from_checkpoint(
    best_model_path,
    model=gcn_model,
    loss=nn.L1Loss(),
)

test_results = trainer.test(gcn_best, test_loader)

... obtain the predicted and actual logP ...

In [33]:
logP_gts, logP_preds = [], []
# Pure inference: switch to evaluation mode and disable gradient tracking so
# that no autograd graph is built for the test batches.
gcn_best.eval()
with torch.no_grad():
    for G_batch in test_loader:
        logP_gts.append(G_batch.pop("y"))
        logP_preds.append(gcn_best(G_batch))
logP_gts = torch.cat(logP_gts).cpu().numpy()
logP_preds = torch.cat(logP_preds).detach().cpu().numpy()

... and plot the predicted logP versus their ground truth values.

In [None]:
# 2D histogram with predictions along x and ground truth along y.
heatmap, xedges, yedges = np.histogram2d(x=logP_preds, y=logP_gts, bins=50)
extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]

# End points of the dashed identity line (perfect prediction).
gt_range = [min(logP_gts), max(logP_gts)]

plt.clf()
plt.plot(gt_range, gt_range, "r--")
plt.imshow(heatmap.T, extent=extent, origin="lower")
plt.xlabel("Predicted LogP")
plt.ylabel("True LogP")
plt.show()

## Predicting Molecular Properties with Message Passing

### Implementing a Message Passing Layer

Define a class to implement a layer that performs the transform step ...

In [21]:
class TransformLayer(nn.Module):
    """Transform step of message passing: one message per edge."""

    def __init__(self, hidden_feats):
        """Create the lazily-sized linear map and the ReLU applied to it."""
        super().__init__()
        self.layer = nn.LazyLinear(hidden_feats)
        self.activation = nn.ReLU()

    def forward(self, G):
        """Store the edge messages in G["message"] and return G.

        Each message is computed from the features of the edge's source node,
        the features of its target node and the edge features, concatenated
        along the last axis.
        """
        senders, receivers = G["edge_index"][0], G["edge_index"][1]
        edge_inputs = torch.cat(
            (G["x"][senders], G["x"][receivers], G["edge_attr"]), dim=-1
        )
        G["message"] = self.activation(self.layer(edge_inputs))
        return G

... a class to implement a layer to perform the propagate step ...

In [22]:
class PropagateLayer(nn.Module):
    """Propagate layer.

    Sums, for every node, the messages of the edges that point to it.
    """

    def __init__(self, hidden_feats):
        """Initialize propagate layer.

        Parameters
        ----------
        hidden_feats : int
            Number of columns of the aggregated messages.
        """
        super().__init__()
        self.hidden_feats = hidden_feats

    def forward(self, G):
        """Calculate forward pass.

        Writes the per-node sum of incoming messages to ``G["aggregate"]``
        (one row per row of ``G["x"]``) and returns ``G``.
        """
        messages = G["message"]
        # Allocate the buffer directly with the dtype and device of the node
        # features, instead of creating it on the CPU and converting it twice.
        aggregate = G["x"].new_zeros(G["x"].size(0), self.hidden_feats)

        # Row k of `messages` is added to the row of its edge's target node.
        indices = G["edge_index"][1].unsqueeze(1).expand_as(messages)
        G["aggregate"] = aggregate.scatter_add(0, indices, messages)
        return G

... a class to implement a layer to perform the update step ...

In [23]:
class UpdateLayer(nn.Module):
    """Update step of message passing: refresh node and edge features."""

    def __init__(self, hidden_feats):
        """Create the lazily-sized linear map and the ReLU applied to it."""
        super().__init__()
        self.layer = nn.LazyLinear(hidden_feats)
        self.activation = nn.ReLU()

    def forward(self, G):
        """Update G["x"] and G["edge_attr"] in place and return G.

        The new node features are computed from the current node features
        concatenated with G["aggregate"]; the edge features are replaced by
        the messages stored in G["message"].
        """
        node_inputs = torch.cat((G["x"], G["aggregate"]), dim=-1)
        G["x"] = self.activation(self.layer(node_inputs))
        G["edge_attr"] = G["message"]
        return G

... and use them to define a message passing layer.

In [None]:
# Chain the three steps (transform, propagate, update) into a single layer.
mpl = dl.Sequential(
    *[
        dl.Layer(step_class, hidden_feats=64)
        for step_class in (TransformLayer, PropagateLayer, UpdateLayer)
    ]
).create()

print(mpl)

### Implementing a Message Passing Network

Create a class to implement a message passing network ...

In [25]:
class MPN(dl.DeeplayModule):
    """Message passing network.

    Embeds the node and edge features, applies a stack of message passing
    layers (transform, propagate, update), averages the node features of each
    graph, and maps the result through a small dense head.
    """

    def __init__(self, num_node_embed, num_edge_embed, embed_dim,
                 hidden_feats, out_feats):
        """Initialize message passing network.

        Parameters
        ----------
        num_node_embed : int
            Number of entries in the node embedding table.
        num_edge_embed : int
            Number of entries in the edge embedding table.
        embed_dim : int
            Size of the node and edge embeddings.
        hidden_feats : sequence of int
            Feature size of each message passing layer (one layer per entry).
        out_feats : int
            Number of outputs of the dense head.
        """
        super().__init__()

        self.node_embedding = dl.Layer(nn.Embedding, num_node_embed,
                                       embed_dim)
        self.edge_embedding = dl.Layer(nn.Embedding, num_edge_embed,
                                       embed_dim)

        self.blocks = dl.LayerList()
        for f_out in hidden_feats:
            mpl = dl.Sequential(
                dl.Layer(TransformLayer, f_out),
                dl.Layer(PropagateLayer, f_out),
                dl.Layer(UpdateLayer, f_out),
            )
            self.blocks.append(mpl)

        self.dense_top = dl.Sequential(
            dl.Layer(nn.Linear, hidden_feats[-1], hidden_feats[-1] // 4),
            dl.Layer(nn.ReLU),
            dl.Layer(nn.Linear, hidden_feats[-1] // 4, out_feats),
        )

    def forward(self, G):
        """Calculate forward pass.

        Parameters
        ----------
        G : mapping
            Graph batch providing ``"x"`` and ``"edge_attr"`` (indices for the
            node and edge embeddings), ``"edge_index"`` and ``"batch"`` (graph
            index of each node). ``G["x"]`` and ``G["edge_attr"]`` are
            overwritten in place with the embedded and updated features.

        Returns
        -------
        torch.Tensor
            One prediction per graph; the trailing axis is dropped when
            ``out_feats`` is 1.
        """
        G["x"] = self.node_embedding(G["x"])
        G["edge_attr"] = self.edge_embedding(G["edge_attr"])
        for block in self.blocks:
            G = block(G)

        # Mean-pool the node features of each graph in the batch.
        node_feats, batch_index = G["x"], G["batch"]
        num_graphs = int(batch_index.max()) + 1
        h = torch.zeros(num_graphs, node_feats.shape[1],
                        dtype=node_feats.dtype, device=node_feats.device)
        h = h.scatter_add(
            0, batch_index[:, None].expand_as(node_feats), node_feats
        )
        h = h / torch.bincount(batch_index)[:, None]

        # Squeeze only the feature axis so that a batch holding a single graph
        # still yields a 1-D tensor instead of a 0-D scalar.
        return self.dense_top(h).squeeze(-1)

... instantiate the message passing network ...

In [None]:
# Four message passing layers of width 64 on top of 64-dimensional node and
# edge embeddings, with a single regression output.
mpn_model = MPN(
    num_node_embed=28,
    num_edge_embed=4,
    embed_dim=64,
    hidden_feats=[64, 64, 64, 64],
    out_feats=1,
).create()

print(mpn_model)

... train the message passing network ...

In [None]:
# Wrap the network in a regressor trained with an L1 loss and Adam.
mpn = MolecularRegressor(
    mpn_model,
    loss=nn.L1Loss(),
    optimizer=dl.Adam(lr=1e-3),
).create()

# Checkpoints are written to models/mpn and monitored on the validation loss.
checkpoint_callback = ModelCheckpoint(
    dirpath=os.path.join("models", "mpn"),
    filename="ZINC-{epoch:02d}-{val_loss:.2f}",
    monitor="val_loss",
    auto_insert_metric_name=False,
)

trainer = dl.Trainer(callbacks=[checkpoint_callback], max_epochs=400)
trainer.fit(mpn, train_loader, val_loader)

... evaluate the performance of the trained message passing network ...

In [None]:
# Collect the saved MPN checkpoints and load the most recently modified one.
model_paths = glob.glob(os.path.join("models", "mpn", "ZINC-*.ckpt"))
best_model_path = sorted(model_paths, key=os.path.getmtime)[-1]

mpn_best = MolecularRegressor.load_from_checkpoint(
    best_model_path,
    model=mpn_model,
    loss=nn.L1Loss(),
)

test_results = trainer.test(mpn_best, test_loader)

... obtain the predicted and actual logP ...

In [31]:
logP_gts, logP_preds = [], []
# Pure inference: switch to evaluation mode and disable gradient tracking so
# that no autograd graph is built for the test batches.
mpn_best.eval()
with torch.no_grad():
    for G_batch in test_loader:
        logP_gts.append(G_batch.pop("y"))
        logP_preds.append(mpn_best(G_batch))
logP_gts = torch.cat(logP_gts).cpu().numpy()
logP_preds = torch.cat(logP_preds).detach().cpu().numpy()

... and plot the predicted logP versus their ground truth values.

In [None]:
# 2D histogram with predictions along x and ground truth along y.
heatmap, xedges, yedges = np.histogram2d(x=logP_preds, y=logP_gts, bins=50)
extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]

# End points of the dashed identity line (perfect prediction).
gt_range = [min(logP_gts), max(logP_gts)]

plt.clf()
plt.plot(gt_range, gt_range, "r--")
plt.imshow(heatmap.T, extent=extent, origin="lower")
plt.xlabel("Predicted LogP")
plt.ylabel("True LogP")
plt.show()