# 02456 Molecular Property Prediction

Basic example of how to train the PaiNN model to predict the QM9 property
"internal energy at 0K". This property (and the majority of the other QM9
properties) is computed as a sum of atomic contributions.

In [64]:
%%capture

!pip install pytorch_lightning
!pip install torch_geometric

import torch
import torch.nn as nn
import math
import argparse
import torch.nn.functional as F
import numpy as np
import pytorch_lightning as pl
from tqdm import trange
from pytorch_lightning import seed_everything
from torch_geometric.data import Data
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MessagePassing
from torch_geometric.transforms import BaseTransform
from typing import Optional, List, Union, Tuple

## QM9 Datamodule

In [2]:
class GetTarget(BaseTransform):

    def __init__(self, target: Optional[int] = None) -> None:
        self.target = [target]


    def forward(self, data: Data) -> Data:
        if self.target is not None:
            data.y = data.y[:, self.target]
        return data


class QM9DataModule(pl.LightningDataModule):

    target_types = ['atomwise' for _ in range(19)]
    target_types[0] = 'dipole_moment'
    target_types[5] = 'electronic_spatial_extent'

    # Specify unit conversions (eV to meV).
    unit_conversion = {
        i: (lambda t: 1000*t) if i not in [0, 1, 5, 11, 16, 17, 18]
        else (lambda t: t)
        for i in range(19)
    }


    def __init__(
        self,
        target: int = 7,
        data_dir: str = 'data/',
        batch_size_train: int = 100,
        batch_size_inference: int = 1000,
        num_workers: int = 0,
        splits: Union[List[int], List[float]] = [110000, 10000, 10831],
        seed: int = 0,
        subset_size: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.target = target
        self.data_dir = data_dir
        self.batch_size_train = batch_size_train
        self.batch_size_inference = batch_size_inference
        self.num_workers = num_workers
        self.splits = splits
        self.seed = seed
        self.subset_size = subset_size

        self.data_train = None
        self.data_val = None
        self.data_test = None


    def prepare_data(self) -> None:
        # Download data
        QM9(root=self.data_dir)


    def setup(self, stage: Optional[str] = None) -> None:
        dataset = QM9(root=self.data_dir, transform=GetTarget(self.target))

        # Shuffle dataset
        rng = np.random.default_rng(seed=self.seed)
        dataset = dataset[rng.permutation(len(dataset))]

        # Subset dataset
        if self.subset_size is not None:
            dataset = dataset[:self.subset_size]

        # Split dataset
        if all([type(split) == int for split in self.splits]):
            split_sizes = self.splits
        elif all([type(split) == float for split in self.splits]):
            split_sizes = [int(len(dataset) * prop) for prop in self.splits]

        split_idx = np.cumsum(split_sizes)
        self.data_train = dataset[:split_idx[0]]
        self.data_val = dataset[split_idx[0]:split_idx[1]]
        self.data_test = dataset[split_idx[1]:]


    def get_target_stats(
        self,
        remove_atom_refs: bool = True,
        divide_by_atoms: bool = True
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        atom_refs = self.data_train.atomref(self.target)                        # Atom reference energy

        ys = list()
        for batch in self.train_dataloader(shuffle=False):
            y = batch.y.clone()
            if remove_atom_refs and atom_refs is not None:
                y.index_add_(
                    dim=0, index=batch.batch, source=-atom_refs[batch.z]
                )
            if divide_by_atoms:                                                 # Normalize internal energy by the number of atoms
                _, num_atoms  = torch.unique(batch.batch, return_counts=True)
                y = y / num_atoms.unsqueeze(-1)
            ys.append(y)

        y = torch.cat(ys, dim=0)
        return y.mean(), y.std(), atom_refs


    def train_dataloader(self, shuffle: bool = True) -> DataLoader:
        return DataLoader(
            self.data_train,
            batch_size=self.batch_size_train,
            num_workers=self.num_workers,
            shuffle=shuffle,
            pin_memory=True,
        )


    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            self.data_val,
            batch_size=self.batch_size_inference,
            num_workers=self.num_workers,
            shuffle=False,
            pin_memory=True,
        )


    def test_dataloader(self) -> DataLoader:
        return DataLoader(
            self.data_test,
            batch_size=self.batch_size_inference,
            num_workers=self.num_workers,
            shuffle=False,
            pin_memory=True,
        )

## Post-processing module

In [3]:
class AtomwisePostProcessing(nn.Module):
    """
    Post-processing for (QM9) properties that are predicted as sums of atomic
    contributions.
    """
    def __init__(
        self,
        num_outputs: int,
        mean: torch.FloatTensor,
        std: torch.FloatTensor,
        atom_refs: torch.FloatTensor,
    ) -> None:
        """
        Args:
            num_outputs: Integer with the number of model outputs. In most
                cases 1.
            mean: torch.FloatTensor with mean value to shift atomwise
                contributions by.
            std: torch.FloatTensor with standard deviation to scale atomwise
                contributions by.
            atom_refs: torch.FloatTensor of size [num_atom_types, 1] with
                atomic reference values.
        """
        super().__init__()
        self.num_outputs = num_outputs
        self.register_buffer('scale', std)
        self.register_buffer('shift', mean)
        self.atom_refs = nn.Embedding.from_pretrained(atom_refs, freeze=True)


    def forward(
        self,
        atomic_contributions: torch.FloatTensor,
        atoms: torch.LongTensor,
        graph_indexes: torch.LongTensor,
    ) -> torch.FloatTensor:
        """
        Atomwise post-processing operations and atomic sum.

        Args:
            atomic_contributions: torch.FloatTensor of size [num_nodes,
                num_outputs] with each node's contribution to the overall graph
                prediction, i.e., each atom's contribution to the overall
                molecular property prediction.
            atoms: torch.LongTensor of size [num_nodes] with atom type of each
                node in the graph.
            graph_indexes: torch.LongTensor of size [num_nodes] with the graph
                index each node belongs to.

        Returns:
            A torch.FLoatTensor of size [num_graphs, num_outputs] with
            predictions for each graph (molecule).
        """
        num_graphs = torch.unique(graph_indexes).shape[0]

        atomic_contributions = atomic_contributions*self.scale + self.shift
        atomic_contributions = atomic_contributions + self.atom_refs(atoms)

        # Sum contributions for each graph
        output_per_graph = torch.zeros(
            (num_graphs, self.num_outputs),
            device=atomic_contributions.device,
        )
        output_per_graph.index_add_(
            dim=0,
            index=graph_indexes,
            source=atomic_contributions,
        )

        return output_per_graph

## PaiNN

In [46]:
def LocalEdges(atom_positions,
               graph_indexes,
               cutoff_dist):

    # The number of atoms in the batch.
    num_atoms = graph_indexes.size(0)

    # Pairing of the atoms across all molecules in the batch.
    pos_i = atom_positions.unsqueeze(0).repeat(num_atoms, 1, 1)
    pos_j = atom_positions.unsqueeze(1).repeat(1, num_atoms, 1)

    # Compute all r_ij vectors and their norms (distances).
    r_ij = pos_j - pos_i                                          # Pairwise vector differences.
    r_ij_norm = torch.norm(r_ij, dim=2)                           # Pairwise distances.

    # We will not consider the distance between an atom and itself (i == j).
    # We only consider atoms within the same molecule (graph_indexes[i] == graph_indexes[j]).
    # We only want the pairs of close atoms specified by the cutoff.
    # Thus, we create masks to filter pairs.
    same_graph_mask     = graph_indexes.unsqueeze(0) == graph_indexes.unsqueeze(1)
    different_atom_mask = torch.arange(num_atoms).unsqueeze(1) != torch.arange(num_atoms).unsqueeze(0)
    within_cutoff_mask  = r_ij_norm <= cutoff_dist

    # Combine masks: same graph, different atoms, within cutoff.
    valid_pairs_mask = same_graph_mask & different_atom_mask & within_cutoff_mask

    # Filter indices and values based on the mask.
    edge_indexes = valid_pairs_mask.nonzero(as_tuple=False).t()   # Edge indexes, shape: (2, num_edges) - nonzero returns the indices of the elements that are non-zero (False is interpreted as 0).
    edge_vector = r_ij[valid_pairs_mask]                          # Edge vectors, shape: (num_edges, 3)
    edge_distance = r_ij_norm[valid_pairs_mask]                   # Edge distances, shape: (num_edges,)

    return edge_indexes, edge_vector, edge_distance

def RadialBasis(edge_distance,
                num_rbf_features,
                cutoff_dist):

    # Number of local edges.
    num_egdes = edge_distance.size()[0]

    # Generate n values evenly spaced between 1 and 20.
    n_values = torch.linspace(1, 20, num_rbf_features)

    # Expand the n_values to match the shape of edge_distance.
    n_values_expanded = n_values.unsqueeze(0).expand(num_egdes, num_rbf_features)
    edge_distance_expanded = edge_distance.unsqueeze(1).expand(num_egdes, num_rbf_features)

    # Compute the RBF for each pair of (r_ij, n).
    edge_rbf = torch.sin(n_values_expanded * torch.pi * edge_distance_expanded / cutoff_dist) / edge_distance_expanded

    return edge_rbf

def CosineCutoff(edge_distance,
                 cutoff_dist):

    # Compute values of cutoff function.
    fcut = 0.5 * (torch.cos(edge_distance * math.pi / cutoff_dist) + 1.0)

    # Remove contributions beyond the cutoff radius.
    fcut *= (fcut < cutoff_dist).float()

    return fcut

In [75]:
class MessageBlock(nn.Module):
  def __init__(self,
               num_features,
               num_rbf_features):
    super().__init__()

    self.num_features = num_features

    self.num_rbf_features = num_rbf_features

    self.linear_s = nn.Sequential(
        nn.Linear(num_features, num_features),
        nn.SiLU(),
        nn.Linear(num_features, num_features * 3),
        )

    self.linear_rbf = nn.Linear(num_rbf_features, num_features * 3)


  def forward(self,
              s,
              vec,
              edge_indexes,
              edge_vector,
              edge_distance,
              cutoff_dist):

    # Atomwise layers.
    phi = self.linear_s(s)

    # Compute radial basis functions.
    edge_rbf = RadialBasis(edge_distance,
                           self.num_features,
                           cutoff_dist)

    # Linear combination of the radial basis functions.
    edge_rbf_linear = self.linear_rbf(edge_rbf)

    # Cosine cutoff.
    fcut = CosineCutoff(edge_distance,
                        cutoff_dist)

    W = edge_rbf_linear * fcut[..., None]

    # Split of W.
    vec_Ws, vec_Wvv, vec_Wvs = torch.split(phi * W, self.num_features, -1)

    # Aggregate contributions from neighboring atoms
    ds = vec_Ws

    vec_n = edge_vector / edge_distance[..., None]

    dvec = vec * Wvv.unsqueeze(1) + Wvs * vec_n.unsqueeze(2)

    # Aggregate contributions from neighboring atoms

    return ds, dvec

class UpdateBlock(nn.Module):
    def __init__(self,
                 num_features):
        super().__init__()

        self.num_features = num_features

        self.linear_vec = nn.Linear(num_features, num_features * 2, bias=False)

        self.linear_svec = nn.Sequential(
            nn.Linear(num_features * 2, num_features),
            nn.SiLU(),
            nn.Linear(num_features, num_features * 3),
        )

    def forward(self,
                s,
                vec):

        vec_U, vec_V = torch.split(self.linear_vec(vec), self.num_features, dim = -1)

        vec_dot = (vec_U * vec_V).sum(dim=1) #* self.inv_sqrt_h

        vec_Vn = torch.sqrt(torch.sum(vec_V**2, dim = -2) + 1e-8)      # Add an epsilon offset to make sure sqrt is always positive.

        vec_W = self.linear_svec(torch.cat([s, vec_Vn], dim = -1))

        a_vv, a_sv, a_ss = torch.split(vec_W, self.num_features, dim = -1)

        ds = a_ss + a_sv * vec_dot    # * self.inv_sqrt_2

        dvec = a_vv.unsqueeze(1) * vec_U

        return ds, dvec

class PaiNN(nn.Module):
    """
    Polarizable Atom Interaction Neural Network with PyTorch.
    """
    def __init__(
        self,
        num_message_passing_layers: int = 3,
        num_features: int = 128,
        num_outputs: int = 1,
        num_rbf_features: int = 20,
        num_unique_atoms: int = 100,
        cutoff_dist: float = 5.0,
    ) -> None:
        """
        Args:
            num_message_passing_layers: Number of message passing layers in
                the PaiNN model.
            num_features: Size of the node embeddings (scalar features) and
                vector features.
            num_outputs: Number of model outputs. In most cases 1.
            num_rbf_features: Number of radial basis functions to represent
                distances.
            num_unique_atoms: Number of unique atoms in the data that we want
                to learn embeddings for.
            cutoff_dist: Euclidean distance threshold for determining whether
                two nodes (atoms) are neighbours.
        """
        super().__init__()

        self.num_message_passing_layers = num_message_passing_layers
        self.num_features = num_features
        self.num_outputs = num_outputs
        self.num_rbf_features = num_rbf_features
        self.num_unique_atoms = num_unique_atoms
        self.cutoff_dist = cutoff_dist
        self.device = device

        self.embedding_s = nn.Embedding(num_unique_atoms, num_features)
        self.embedding_v = nn.Embedding(3, num_features, sparse=True)

        self.message_1 = MessageBlock() #INSERT ARGUEMENTS
        self.message_2 = MessageBlock() #INSERT ARGUEMENTS
        self.message_3 = MessageBlock() #INSERT ARGUEMENTS

        self.update_1 = UpdateBlock()
        self.update_2 = UpdateBlock()
        self.update_3 = UpdateBlock()

        self.output = nn.Sequential(
            nn.Linear(num_features, num_features // 2),
            nn.SiLU(),
            nn.Linear(num_features // 2, 1),
        )

    def forward(
        self,
        atoms: torch.LongTensor,
        atom_positions: torch.FloatTensor,
        graph_indexes: torch.LongTensor,
    ) -> torch.FloatTensor:
        """
        Forward pass of PaiNN. Includes the readout network highlighted in blue
        in Figure 2 in (Schütt et al., 2021) with normal linear layers which is
        used for predicting properties as sums of atomic contributions. The
        post-processing and final sum is perfomed with
        src.models.AtomwisePostProcessing.

        Args:
            atoms: torch.LongTensor of size [num_nodes] with atom type of each
                node in the graph.
            atom_positions: torch.FloatTensor of size [num_nodes, 3] with
                euclidean coordinates of each node / atom.
            graph_indexes: torch.LongTensor of size [num_nodes] with the graph
                index each node belongs to.

        Returns:
            A torch.FloatTensor of size [num_nodes, num_outputs] with atomic
            contributions to the overall molecular property prediction.
        """
        # ----------------------------------------------------------------------
        # EMBEDDING
        # We initialize learnable embeddings for the atomtype.
        # The directions v_i are embedded by a null vector.

        s = self.embedding_s(atoms)
        v = self.embedding_v(atoms)

        # ----------------------------------------------------------------------
        # LOCAL NEIGHBORHOOD
        # We create edges by the relative position of nodes from a specified
        # cutoff within the same molecule (local interactions)

        edge_indexes, edge_vector, edge_distance = LocalEdges(atom_positions,
                                                              graph_indexes,
                                                              self.cutoff_dist)

        # ----------------------------------------------------------------------
        # RADIAL BASIS

        edge_rbf = RadialBasis(edge_distance,
                               num_rbf_features,
                               cutoff_dist)

        # MESSAGE AND UPDATE

        # Message and update 1
        ds, dv = self.message_1()
        s = s + ds
        v = v + dv

        ds, dv = self.update_1()
        s = s + ds
        v = v + dv

        # Message and update 2
        ds, dv = self.message_2()
        s = s + ds
        v = v + dv

        ds, dv = self.update_2()
        s = s + ds
        v = v + dv

        # Message and update 3
        ds, dv = self.message_3()
        s = s + ds
        v = v + dv

        ds, dv = self.update_3()
        s = s + ds
        v = v + dv

        # ----------------------------------------------------------------------
        # ATOMIC CONTRIBUTIONS

        atomic_contributions = self.output(s)

        # ----------------------------------------------------------------------

        # Final output
        return atomic_contributions

## Hyperparameters

In [5]:
def cli(args: list = []):
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', default=0)

    # Data
    parser.add_argument('--target', default=7, type=int) # 7 => Internal energy at 0K
    parser.add_argument('--data_dir', default='data/', type=str)
    parser.add_argument('--batch_size_train', default=100, type=int)
    parser.add_argument('--batch_size_inference', default=1000, type=int)
    parser.add_argument('--num_workers', default=0, type=int)
    parser.add_argument('--splits', nargs=3, default=[110000, 10000, 10831], type=int) # [num_train, num_val, num_test]
    parser.add_argument('--subset_size', default=None, type=int)

    # Model
    parser.add_argument('--num_message_passing_layers', default=3, type=int)
    parser.add_argument('--num_features', default=128, type=int)
    parser.add_argument('--num_outputs', default=1, type=int)
    parser.add_argument('--num_rbf_features', default=20, type=int)
    parser.add_argument('--num_unique_atoms', default=100, type=int)
    parser.add_argument('--cutoff_dist', default=5.0, type=float)

    # Training
    parser.add_argument('--lr', default=5e-4, type=float)
    parser.add_argument('--weight_decay', default=0.01, type=float)
    parser.add_argument('--num_epochs', default=1000, type=int)

    args = parser.parse_args(args=args)
    return args

## Training and testing

In [6]:
# Specify non-default arguments in this list.
args = []
args = cli(args)
seed_everything(args.seed)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load and prepare data from the QM9 data set.
dm = QM9DataModule(
    target=args.target,
    data_dir=args.data_dir,
    batch_size_train=args.batch_size_train,
    batch_size_inference=args.batch_size_inference,
    num_workers=args.num_workers,
    splits=args.splits,
    seed=args.seed,
    subset_size=args.subset_size,
)
dm.prepare_data()
dm.setup()

# Calculate target statistics.
y_mean, y_std, atom_refs = dm.get_target_stats(
    remove_atom_refs=True, divide_by_atoms=True
)

INFO:lightning_fabric.utilities.seed:Seed set to 0
Downloading https://data.pyg.org/datasets/qm9_v3.zip
Extracting data/raw/qm9_v3.zip
Processing...
Using a pre-processed version of the dataset. Please install 'rdkit' to alternatively process the raw data.
Done!


In [None]:

# Initialize the model.
painn = PaiNN(
    num_message_passing_layers=args.num_message_passing_layers,     # 3
    num_features=args.num_features,                                 # 128
    num_outputs=args.num_outputs,                                   # 1
    num_rbf_features=args.num_rbf_features,
    num_unique_atoms=args.num_unique_atoms,
    cutoff_dist=args.cutoff_dist,                                   # 5
)

post_processing = AtomwisePostProcessing(
    args.num_outputs, y_mean, y_std, atom_refs
)

painn.to(device)
post_processing.to(device)

# Define optimizer.
optimizer = torch.optim.AdamW(
    painn.parameters(),
    lr=args.lr,
    weight_decay=args.weight_decay,
)

# Train the model.
painn.train()
pbar = trange(args.num_epochs)
for epoch in pbar:

    loss_epoch = 0.
    for batch in dm.train_dataloader():
        batch = batch.to(device)

        atomic_contributions = painn(
            atoms=batch.z,
            atom_positions=batch.pos,
            graph_indexes=batch.batch
        )
        preds = post_processing(
            atoms=batch.z,
            graph_indexes=batch.batch,
            atomic_contributions=atomic_contributions,
        )
        loss_step = F.mse_loss(preds, batch.y, reduction='sum')

        loss = loss_step / len(batch.y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_epoch += loss_step.detach().item()
    loss_epoch /= len(dm.data_train)
    pbar.set_postfix_str(f'Train loss: {loss_epoch:.3e}')

mae = 0
painn.eval()
with torch.no_grad():
    for batch in dm.test_dataloader():
        batch = batch.to(device)

        atomic_contributions = painn(
            atoms=batch.z,
            atom_positions=batch.pos,
            graph_indexes=batch.batch,
        )
        preds = post_processing(
            atoms=batch.z,
            graph_indexes=batch.batch,
            atomic_contributions=atomic_contributions,
        )
        mae += F.l1_loss(preds, batch.y, reduction='sum')

mae /= len(dm.data_test)
unit_conversion = dm.unit_conversion[args.target]
print(f'Test MAE: {unit_conversion(mae):.3f}')

INFO:lightning_fabric.utilities.seed:Seed set to 0
  0%|          | 0/1000 [00:00<?, ?it/s]


TypeError: MessagePassingLayer.forward() missing 1 required positional argument: 'edge_index'

In [None]:
# Check sizes of train, validation, and test splits:
print("Training set size:", len(dm.data_train))
print("Validation set size:", len(dm.data_val))
print("Test set size:", len(dm.data_test))

# View the first sample in the training dataset:
sample = dm.data_train[0]
print("Sample features:", sample)

# Access individual attributes of the sample which we will use:
print("Atom type (z):", sample.z)                       # Atom type for each node in the graph
print("Atom position (pos):", sample.pos)               # Atom position for each node in the graph
print("Edge indices (edge_index):", sample.edge_index)  # Connectivity info between atoms
print("Target properties (y):", sample.y)               # Target property (energy)

# Print the mean and standard deviation for the target property
print("Target mean:", y_mean)
print("Target standard deviation:", y_std)

# Print atom reference values (standardized contributions of individual atoms to internal energy)
#print("Atom reference values:", atom_refs)

Training set size: 110000
Validation set size: 10000
Test set size: 10831
Sample features: Data(x=[13, 11], edge_index=[2, 26], edge_attr=[26, 4], y=[1, 1], pos=[13, 3], idx=[1], name='gdb_2329', z=[13])
Atom type (z): tensor([8, 6, 6, 6, 8, 6, 8, 1, 1, 1, 1, 1, 1])
Atom position (pos): tensor([[-0.4970,  1.2608, -0.4083],
        [-0.2214, -0.0731, -0.1197],
        [-0.3092, -0.6333,  1.2750],
        [-1.3466, -1.0139,  0.2482],
        [-2.6112, -0.4155,  0.4511],
        [-3.2778, -0.0487, -0.6749],
        [-4.3477,  0.4762, -0.6321],
        [-1.0897,  1.6111,  0.2655],
        [ 0.5573, -0.4431, -0.7776],
        [ 0.4061, -1.3830,  1.5939],
        [-0.6500,  0.0318,  2.0625],
        [-1.3897, -2.0286, -0.1362],
        [-2.7250, -0.2967, -1.5982]])
Edge indices (edge_index): tensor([[ 0,  0,  1,  1,  1,  1,  2,  2,  2,  2,  3,  3,  3,  3,  4,  4,  5,  5,
          5,  6,  7,  8,  9, 10, 11, 12],
        [ 1,  7,  0,  2,  3,  8,  1,  3,  9, 10,  1,  2,  4, 11,  3,  5,  4,  6,