# 02456 Molecular Property Prediction

Basic example of how to train the PaiNN model to predict the QM9 property
"internal energy at 0K". This property (and the majority of the other QM9
properties) is computed as a sum of atomic contributions.

In [2]:
%%capture

!pip install pytorch_lightning
!pip install torch_geometric

import torch
import argparse
import torch.nn.functional as F
import numpy as np
import pytorch_lightning as pl
from tqdm import trange
from pytorch_lightning import seed_everything
from torch_geometric.data import Data
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from typing import Optional, List, Union, Tuple
from torch_geometric.transforms import BaseTransform

## QM9 Datamodule

In [3]:
class GetTarget(BaseTransform):

    def __init__(self, target: Optional[int] = None) -> None:
        self.target = [target]


    def forward(self, data: Data) -> Data:
        if self.target is not None:
            data.y = data.y[:, self.target]
        return data


class QM9DataModule(pl.LightningDataModule):

    target_types = ['atomwise' for _ in range(19)]
    target_types[0] = 'dipole_moment'
    target_types[5] = 'electronic_spatial_extent'

    # Specify unit conversions (eV to meV).
    unit_conversion = {
        i: (lambda t: 1000*t) if i not in [0, 1, 5, 11, 16, 17, 18]
        else (lambda t: t)
        for i in range(19)
    }


    def __init__(
        self,
        target: int = 7,
        data_dir: str = 'data/',
        batch_size_train: int = 100,
        batch_size_inference: int = 1000,
        num_workers: int = 0,
        splits: Union[List[int], List[float]] = [110000, 10000, 10831],
        seed: int = 0,
        subset_size: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.target = target
        self.data_dir = data_dir
        self.batch_size_train = batch_size_train
        self.batch_size_inference = batch_size_inference
        self.num_workers = num_workers
        self.splits = splits
        self.seed = seed
        self.subset_size = subset_size

        self.data_train = None
        self.data_val = None
        self.data_test = None


    def prepare_data(self) -> None:
        # Download data
        QM9(root=self.data_dir)


    def setup(self, stage: Optional[str] = None) -> None:
        dataset = QM9(root=self.data_dir, transform=GetTarget(self.target))

        # Shuffle dataset
        rng = np.random.default_rng(seed=self.seed)
        dataset = dataset[rng.permutation(len(dataset))]

        # Subset dataset
        if self.subset_size is not None:
            dataset = dataset[:self.subset_size]

        # Split dataset
        if all([type(split) == int for split in self.splits]):
            split_sizes = self.splits
        elif all([type(split) == float for split in self.splits]):
            split_sizes = [int(len(dataset) * prop) for prop in self.splits]

        split_idx = np.cumsum(split_sizes)
        self.data_train = dataset[:split_idx[0]]
        self.data_val = dataset[split_idx[0]:split_idx[1]]
        self.data_test = dataset[split_idx[1]:]


    def get_target_stats(
        self,
        remove_atom_refs: bool = True,
        divide_by_atoms: bool = True
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        atom_refs = self.data_train.atomref(self.target)                        # Atom reference energy

        ys = list()
        for batch in self.train_dataloader(shuffle=False):
            y = batch.y.clone()
            if remove_atom_refs and atom_refs is not None:
                y.index_add_(
                    dim=0, index=batch.batch, source=-atom_refs[batch.z]
                )
            if divide_by_atoms:                                                 # Normalize internal energy by the number of atoms
                _, num_atoms  = torch.unique(batch.batch, return_counts=True)
                y = y / num_atoms.unsqueeze(-1)
            ys.append(y)

        y = torch.cat(ys, dim=0)
        return y.mean(), y.std(), atom_refs


    def train_dataloader(self, shuffle: bool = True) -> DataLoader:
        return DataLoader(
            self.data_train,
            batch_size=self.batch_size_train,
            num_workers=self.num_workers,
            shuffle=shuffle,
            pin_memory=True,
        )


    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            self.data_val,
            batch_size=self.batch_size_inference,
            num_workers=self.num_workers,
            shuffle=False,
            pin_memory=True,
        )


    def test_dataloader(self) -> DataLoader:
        return DataLoader(
            self.data_test,
            batch_size=self.batch_size_inference,
            num_workers=self.num_workers,
            shuffle=False,
            pin_memory=True,
        )

## Post-processing module

In [None]:
class AtomwisePostProcessing(nn.Module):
    """
    Post-processing for (QM9) properties that are predicted as sums of atomic
    contributions.
    """
    def __init__(
        self,
        num_outputs: int,
        mean: torch.FloatTensor,
        std: torch.FloatTensor,
        atom_refs: torch.FloatTensor,
    ) -> None:
        """
        Args:
            num_outputs: Integer with the number of model outputs. In most
                cases 1.
            mean: torch.FloatTensor with mean value to shift atomwise
                contributions by.
            std: torch.FloatTensor with standard deviation to scale atomwise
                contributions by.
            atom_refs: torch.FloatTensor of size [num_atom_types, 1] with
                atomic reference values.
        """
        super().__init__()
        self.num_outputs = num_outputs
        self.register_buffer('scale', std)
        self.register_buffer('shift', mean)
        self.atom_refs = nn.Embedding.from_pretrained(atom_refs, freeze=True)


    def forward(
        self,
        atomic_contributions: torch.FloatTensor,
        atoms: torch.LongTensor,
        graph_indexes: torch.LongTensor,
    ) -> torch.FloatTensor:
        """
        Atomwise post-processing operations and atomic sum.

        Args:
            atomic_contributions: torch.FloatTensor of size [num_nodes,
                num_outputs] with each node's contribution to the overall graph
                prediction, i.e., each atom's contribution to the overall
                molecular property prediction.
            atoms: torch.LongTensor of size [num_nodes] with atom type of each
                node in the graph.
            graph_indexes: torch.LongTensor of size [num_nodes] with the graph
                index each node belongs to.

        Returns:
            A torch.FLoatTensor of size [num_graphs, num_outputs] with
            predictions for each graph (molecule).
        """
        num_graphs = torch.unique(graph_indexes).shape[0]

        atomic_contributions = atomic_contributions*self.scale + self.shift
        atomic_contributions = atomic_contributions + self.atom_refs(atoms)

        # Sum contributions for each graph
        output_per_graph = torch.zeros(
            (num_graphs, self.num_outputs),
            device=atomic_contributions.device,
        )
        output_per_graph.index_add_(
            dim=0,
            index=graph_indexes,
            source=atomic_contributions,
        )

        return output_per_graph

## PaiNN

In [None]:
def atom_distances(atom_positions,
                   graph_indexes,
                   cutoff_dist):
    """
    Computes pairwise distances between atoms within each molecule in the batch,
    and filters out pairs beyond the cutoff distance.

    Args:
        atom_positions: torch.FloatTensor of size [num_nodes, 3] with
                euclidean coordinates of each node / atom.
        graph_indexes: torch.LongTensor of size [num_nodes] with the graph
                index each node belongs to.
        cutoff_dist: Euclidean distance threshold for determining whether
                two nodes (atoms) are neighbours.
    """
    edge_index = []
    distances = []

    # Iterate over unique molecules in the batch
    for mol_id in torch.unique(graph_indexes):

        # Mask to filter atoms belonging to the current molecule
        atom_mask = (graph_indexes == mol_id)

        # Extract positions of atoms in this molecule
        pos = atom_positions[atom_mask]

        # Compute pairwise distance matrix for atoms in this molecule
        # torch.cdist computes batched the p-norm distance between each pair of the two collections of row vectors.
        dist_matrix = torch.cdist(pos, pos)

        # Find indices where distance is below cutoff (ignoring self-loops)
        src, dst = torch.where((dist_matrix < cutoff_dist) & (dist_matrix > 0))

        # Store the edges and distances
        edge_index.append(torch.stack([src, dst]))
        distances.append(dist_matrix[src, dst])

    # Concatenate all edges and distances across molecules
    edge_index = torch.cat(edge_index, dim=1) if edge_index else torch.empty(2, 0, dtype=torch.long)
    distances = torch.cat(distances) if distances else torch.empty(0)

    return edge_index, distances


def atom_distiance(atom_positions,
                   graph_indexes,
                   cutoff_dist):

  j, i = graph_indexes
  distance_vec = atom_positions[j] - atom_positions[i]
  edge_distance = distance_vec.norm()

  return edge_distance






def Message(sj, vj, r_ij, num_features):

  # Linear combinations of the atom embeddings.
  s_proj = nn.Sequential(nn.Linear(num_features, num_features),
                     ScaledSiLU(),
                     nn.Linear(num_features, num_features * 3))

  # Linear combinations of the radial basis funtions.
  rbf_proj = nn.Linear(num_rbf_features, num_features * 3)

  phi_W_product = sj * # Linear combinations of RBF

  split_1, split_2, split_3 = phi_W_product #splitted

  vj = sum(vj * split_1) + sum(r_ij * split_3)

  return ds, dv

def Update(sj, vj):

  self.hidden_channels = hidden_channels

  linear_link_U = nn.Linear(hidden_channels, hidden_channels * 2, bias=False)

  linear_link_V = nn.Linear(hidden_channels, hidden_channels * 2, bias=False)

  linear_link_W = nn.Sequential(nn.Linear(hidden_channels * 2, hidden_channels),
                                ScaledSiLU(),
                                nn.Linear(hidden_channels, hidden_channels * 3),
                                )

  vj_dot = (Uvj * Vvj).sum(dim=1) * split

IndentationError: unexpected indent (<ipython-input-2-27c6b005b73d>, line 48)

In [None]:
class MessageBlock(nn.Module):
  def __init__(self, *args, **kwargs)
    super().__init__()

    self.linear_1 = nn.Linear(...)


  def forward(self, *args, **kwargs):
    ... = self.linear_1(...)

    return ...

class PaiNN(nn.Module):
    """
    Polarizable Atom Interaction Neural Network with PyTorch.
    """
    def __init__(
        self,
        num_message_passing_layers: int = 3,
        num_features: int = 128,
        num_outputs: int = 1,
        num_rbf_features: int = 20,
        num_unique_atoms: int = 100,
        cutoff_dist: float = 5.0,
    ) -> None:
        """
        Args:
            num_message_passing_layers: Number of message passing layers in
                the PaiNN model.
            num_features: Size of the node embeddings (scalar features) and
                vector features.
            num_outputs: Number of model outputs. In most cases 1.
            num_rbf_features: Number of radial basis functions to represent
                distances.
            num_unique_atoms: Number of unique atoms in the data that we want
                to learn embeddings for.
            cutoff_dist: Euclidean distance threshold for determining whether
                two nodes (atoms) are neighbours.
        """
        super().__init__()

        self.num_message_passing_layers = num_message_passing_layers
        self.num_features = num_features
        self.num_outputs = num_outputs
        self.num_rbf_features = num_rbf_features
        self.num_unique_atoms = num_unique_atoms
        self.cutoff_dist = cutoff_dist
        self.device = device

        self.embedding_s = nn.Embedding(num_unique_atoms, num_features)
        self.embedding_v = nn.Embedding(3, num_features, sparse=True)

        self.message_1 = MessageBlock(...)
        self.message_2 = MessageBlock(..)
        self.update = Update()
        self.output = Output()

    def forward(
        self,
        atoms: torch.LongTensor,
        atom_positions: torch.FloatTensor,
        graph_indexes: torch.LongTensor,
    ) -> torch.FloatTensor:
        """
        Forward pass of PaiNN. Includes the readout network highlighted in blue
        in Figure 2 in (Schütt et al., 2021) with normal linear layers which is
        used for predicting properties as sums of atomic contributions. The
        post-processing and final sum is perfomed with
        src.models.AtomwisePostProcessing.

        Args:
            atoms: torch.LongTensor of size [num_nodes] with atom type of each
                node in the graph.
            atom_positions: torch.FloatTensor of size [num_nodes, 3] with
                euclidean coordinates of each node / atom.
            graph_indexes: torch.LongTensor of size [num_nodes] with the graph
                index each node belongs to.

        Returns:
            A torch.FloatTensor of size [num_nodes, num_outputs] with atomic
            contributions to the overall molecular property prediction.
        """
        # ----------------------------------------------------------------------
        # Initialization

        s = self.embedding_s(atoms)
        v = self.embedding_v(atoms)

        # ----------------------------------------------------------------------

        # Connections???

        # ----------------------------------------------------------------------
        # Message and update

        for message, update in zip(self.message_blocks, self.update_blocks):
            messa
        for i in range(3):

          # Message
          ds, dv = self.message(s, v, edge_index, edge_rbf, edge_vector)
          s = s + ds
          v = v + dv

          # Update
          ds, dv = self.update(s, v)
          s = s + ds
          v = v + dv

        # ----------------------------------------------------------------------

        # Output step

        # ----------------------------------------------------------------------


        # Final output
        return output

## Hyperparameters

In [4]:
def cli(args: list = []):
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', default=0)

    # Data
    parser.add_argument('--target', default=7, type=int) # 7 => Internal energy at 0K
    parser.add_argument('--data_dir', default='data/', type=str)
    parser.add_argument('--batch_size_train', default=100, type=int)
    parser.add_argument('--batch_size_inference', default=1000, type=int)
    parser.add_argument('--num_workers', default=0, type=int)
    parser.add_argument('--splits', nargs=3, default=[110000, 10000, 10831], type=int) # [num_train, num_val, num_test]
    parser.add_argument('--subset_size', default=None, type=int)

    # Model
    parser.add_argument('--num_message_passing_layers', default=3, type=int)
    parser.add_argument('--num_features', default=128, type=int)
    parser.add_argument('--num_outputs', default=1, type=int)
    parser.add_argument('--num_rbf_features', default=20, type=int)
    parser.add_argument('--num_unique_atoms', default=100, type=int)
    parser.add_argument('--cutoff_dist', default=5.0, type=float)

    # Training
    parser.add_argument('--lr', default=5e-4, type=float)
    parser.add_argument('--weight_decay', default=0.01, type=float)
    parser.add_argument('--num_epochs', default=1000, type=int)

    args = parser.parse_args(args=args)
    return args

## Training and testing

In [5]:
# Specify non-default arguments in this list.
args = []
args = cli(args)
seed_everything(args.seed)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load and prepare data from the QM9 data set.
dm = QM9DataModule(
    target=args.target,
    data_dir=args.data_dir,
    batch_size_train=args.batch_size_train,
    batch_size_inference=args.batch_size_inference,
    num_workers=args.num_workers,
    splits=args.splits,
    seed=args.seed,
    subset_size=args.subset_size,
)
dm.prepare_data()
dm.setup()

# Calculate target statistics.
y_mean, y_std, atom_refs = dm.get_target_stats(
    remove_atom_refs=True, divide_by_atoms=True
)

INFO:lightning_fabric.utilities.seed:Seed set to 0
Downloading https://data.pyg.org/datasets/qm9_v3.zip
Extracting data/raw/qm9_v3.zip
Processing...
Using a pre-processed version of the dataset. Please install 'rdkit' to alternatively process the raw data.
Done!


In [None]:

# Initialize the model.
painn = PaiNN(
    num_message_passing_layers=args.num_message_passing_layers,     # 3
    num_features=args.num_features,                                 # 128
    num_outputs=args.num_outputs,                                   # 1
    num_rbf_features=args.num_rbf_features,
    num_unique_atoms=args.num_unique_atoms,
    cutoff_dist=args.cutoff_dist,                                   # 5
)

post_processing = AtomwisePostProcessing(
    args.num_outputs, y_mean, y_std, atom_refs
)

painn.to(device)
post_processing.to(device)

# Define optimizer.
optimizer = torch.optim.AdamW(
    painn.parameters(),
    lr=args.lr,
    weight_decay=args.weight_decay,
)

# Train the model.
painn.train()
pbar = trange(args.num_epochs)
for epoch in pbar:

    loss_epoch = 0.
    for batch in dm.train_dataloader():
        batch = batch.to(device)

        atomic_contributions = painn(
            atoms=batch.z,
            atom_positions=batch.pos,
            graph_indexes=batch.batch
        )
        preds = post_processing(
            atoms=batch.z,
            graph_indexes=batch.batch,
            atomic_contributions=atomic_contributions,
        )
        loss_step = F.mse_loss(preds, batch.y, reduction='sum')

        loss = loss_step / len(batch.y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_epoch += loss_step.detach().item()
    loss_epoch /= len(dm.data_train)
    pbar.set_postfix_str(f'Train loss: {loss_epoch:.3e}')

mae = 0
painn.eval()
with torch.no_grad():
    for batch in dm.test_dataloader():
        batch = batch.to(device)

        atomic_contributions = painn(
            atoms=batch.z,
            atom_positions=batch.pos,
            graph_indexes=batch.batch,
        )
        preds = post_processing(
            atoms=batch.z,
            graph_indexes=batch.batch,
            atomic_contributions=atomic_contributions,
        )
        mae += F.l1_loss(preds, batch.y, reduction='sum')

mae /= len(dm.data_test)
unit_conversion = dm.unit_conversion[args.target]
print(f'Test MAE: {unit_conversion(mae):.3f}')

INFO:lightning_fabric.utilities.seed:Seed set to 0
  0%|          | 0/1000 [00:00<?, ?it/s]


TypeError: MessagePassingLayer.forward() missing 1 required positional argument: 'edge_index'

In [6]:
# Check sizes of train, validation, and test splits:
print("Training set size:", len(dm.data_train))
print("Validation set size:", len(dm.data_val))
print("Test set size:", len(dm.data_test))

# View the first sample in the training dataset:
sample = dm.data_train[0]
print("Sample features:", sample)

# Access individual attributes of the sample which we will use:
print("Atom type (z):", sample.z)                       # Atom type for each node in the graph
print("Atom position (pos):", sample.pos)               # Atom position for each node in the graph
print("Edge indices (edge_index):", sample.edge_index)  # Connectivity info between atoms
print("Target properties (y):", sample.y)               # Target property (energy)

# Print the mean and standard deviation for the target property
print("Target mean:", y_mean)
print("Target standard deviation:", y_std)

# Print atom reference values (standardized contributions of individual atoms to internal energy)
#print("Atom reference values:", atom_refs)

Training set size: 110000
Validation set size: 10000
Test set size: 10831
Sample features: Data(x=[13, 11], edge_index=[2, 26], edge_attr=[26, 4], y=[1, 1], pos=[13, 3], idx=[1], name='gdb_2329', z=[13])
Atom type (z): tensor([8, 6, 6, 6, 8, 6, 8, 1, 1, 1, 1, 1, 1])
Atom position (pos): tensor([[-0.4970,  1.2608, -0.4083],
        [-0.2214, -0.0731, -0.1197],
        [-0.3092, -0.6333,  1.2750],
        [-1.3466, -1.0139,  0.2482],
        [-2.6112, -0.4155,  0.4511],
        [-3.2778, -0.0487, -0.6749],
        [-4.3477,  0.4762, -0.6321],
        [-1.0897,  1.6111,  0.2655],
        [ 0.5573, -0.4431, -0.7776],
        [ 0.4061, -1.3830,  1.5939],
        [-0.6500,  0.0318,  2.0625],
        [-1.3897, -2.0286, -0.1362],
        [-2.7250, -0.2967, -1.5982]])
Edge indices (edge_index): tensor([[ 0,  0,  1,  1,  1,  1,  2,  2,  2,  2,  3,  3,  3,  3,  4,  4,  5,  5,
          5,  6,  7,  8,  9, 10, 11, 12],
        [ 1,  7,  0,  2,  3,  8,  1,  3,  9, 10,  1,  2,  4, 11,  3,  5,  4,  6,

In [77]:
sample = dm.data_train[0:2]


def atom_vector(atom_positions,
                graph_indexes,
                cutoff_dist):

  distance_vec = torch.zeros(graph_indexes.size()[1], 3)
  k = 0
  #print(distance_vec)
  print(distance_vec[k])

  for i in graph_indexes[0].tolist():
    for j in graph_indexes[1].tolist():
      r_ij = atom_positions[j] - atom_positions[i]
      #print(r_ij)
      distance_vec[k] = r_ij
      k = k + 1

  return distance_vec
#edge_distance[i, j, 3] = distance_vec.norm()


#z = torch.zeros((10,1))
#torch.cat((your_tensor,z),1)

#return edge_distance


atom_vector(sample.pos,
            sample.edge_index,
            2)

#compute_pairwise_distances(sample.pos,
                           #torch.ones(sample.z.size()),
                           #2)

tensor([0., 0., 0.])


IndexError: index 64 is out of bounds for dimension 0 with size 64

In [73]:
distance_vec = torch.zeros(10, 3)

distance_vec[1] = torch.ones(3)

distance_vec

tensor([[0., 0., 0.],
        [1., 1., 1.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])

In [59]:
sample = dm.data_train[0:2]
for i in sample.edge_index.tolist():
  print(i)

sample.edge_index[0].tolist()

[0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 5, 5, 5, 6, 7, 8, 9, 10, 11, 12, 13, 13, 13, 13, 14, 14, 14, 14, 15, 15, 15, 15, 16, 16, 17, 17, 17, 17, 18, 18, 18, 18, 19, 19, 19, 20, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
[1, 7, 0, 2, 3, 8, 1, 3, 9, 10, 1, 2, 4, 11, 3, 5, 4, 6, 12, 5, 0, 1, 2, 2, 3, 5, 14, 22, 23, 24, 13, 15, 25, 26, 14, 16, 17, 27, 15, 17, 15, 16, 18, 28, 17, 19, 20, 29, 18, 30, 31, 18, 21, 20, 13, 13, 13, 14, 14, 15, 17, 18, 19, 19]


[0,
 0,
 1,
 1,
 1,
 1,
 2,
 2,
 2,
 2,
 3,
 3,
 3,
 3,
 4,
 4,
 5,
 5,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 13,
 13,
 13,
 14,
 14,
 14,
 14,
 15,
 15,
 15,
 15,
 16,
 16,
 17,
 17,
 17,
 17,
 18,
 18,
 18,
 18,
 19,
 19,
 19,
 20,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31]