# 02456 Molecular Property Prediction

Basic example of how to train the PaiNN model to predict the QM9 property
"internal energy at 0K". This property (and the majority of the other QM9
properties) is computed as a sum of atomic contributions.

In [1]:
import torch
from torch_scatter import scatter

# Sum the entries of `x` that share the same group id in `index`:
# group 0 -> 1 + 2 = 3, group 1 -> 3 + 4 = 7.
x = torch.tensor([1, 2, 3, 4])
index = torch.tensor([0, 0, 1, 1])  # group id of each entry of x
result = scatter(src=x, index=index, dim=0, reduce="sum")
print(result)  # Output: tensor([3, 7])


tensor([3, 7])


In [2]:
import torch
import argparse
from tqdm import trange
import torch.nn.functional as F
from pytorch_lightning import seed_everything
from torch_scatter import scatter


## QM9 Datamodule

In [3]:
import numpy as np
import pytorch_lightning as pl
from torch_geometric.data import Data
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from typing import Optional, List, Union, Tuple
from torch_geometric.transforms import BaseTransform


class GetTarget(BaseTransform):
    """
    Transform that keeps a single regression target column of ``data.y``.
    """
    def __init__(self, target: Optional[int] = None) -> None:
        """
        Args:
            target: Column index into ``data.y`` to keep. ``None`` leaves
                ``data.y`` unchanged (all targets are kept).
        """
        # Store the plain index: wrapping it in a list here made the
        # ``is not None`` check below always true, so ``target=None`` ended up
        # indexing with ``[None]`` (which inserts a new axis).
        self.target = target

    def forward(self, data: Data) -> Data:
        if self.target is not None:
            # Index with a one-element list so y stays 2-D ([num_graphs, 1]).
            data.y = data.y[:, [self.target]]
        return data


class QM9DataModule(pl.LightningDataModule):
    """
    Lightning data module for the QM9 dataset.

    Downloads QM9, shuffles it with a fixed seed, optionally keeps only a
    subset, and splits it into train / validation / test sets that are served
    through PyTorch Geometric ``DataLoader``s.
    """

    # 'atomwise' marks targets predicted as sums of atomic contributions; the
    # dipole moment (0) and the electronic spatial extent (5) are flagged
    # separately.
    target_types = ['atomwise' for _ in range(19)]
    target_types[0] = 'dipole_moment'
    target_types[5] = 'electronic_spatial_extent'

    # Specify unit conversions (eV to meV). Targets 0, 1, 5, 11, 16, 17 and 18
    # are returned unchanged.
    unit_conversion = {
        i: (lambda t: 1000*t) if i not in [0, 1, 5, 11, 16, 17, 18]
        else (lambda t: t)
        for i in range(19)
    }

    def __init__(
        self,
        target: int = 7,
        data_dir: str = 'data/',
        batch_size_train: int = 100,
        batch_size_inference: int = 1000,
        num_workers: int = 0,
        splits: Union[List[int], List[float]] = [110000, 10000, 10831],
        seed: int = 0,
        subset_size: Optional[int] = None,
    ) -> None:
        """
        Args:
            target: Index (0-18) of the QM9 target column to predict.
            data_dir: Directory the QM9 files are downloaded to / read from.
            batch_size_train: Batch size of the training loader.
            batch_size_inference: Batch size of the validation / test loaders.
            num_workers: Number of DataLoader worker processes.
            splits: Either absolute sizes (all ints) or fractions of the
                (possibly subset) dataset (all floats). The first two entries
                define the train / validation boundaries; the test set
                receives every remaining molecule.
            seed: Seed of the shuffling permutation.
            subset_size: If given, keep only this many molecules after
                shuffling.
        """
        super().__init__()
        self.target = target
        self.data_dir = data_dir
        self.batch_size_train = batch_size_train
        self.batch_size_inference = batch_size_inference
        self.num_workers = num_workers
        self.splits = splits
        self.seed = seed
        self.subset_size = subset_size

        # Populated by setup().
        self.data_train = None
        self.data_val = None
        self.data_test = None


    def prepare_data(self) -> None:
        """Download QM9 into ``data_dir`` (no-op if it is already there)."""
        QM9(root=self.data_dir)


    def setup(self, stage: Optional[str] = None) -> None:
        """
        Shuffle, optionally subset, and split the dataset.

        Raises:
            ValueError: If ``splits`` mixes ints and floats (or contains other
                types), or has fewer than two entries.
        """
        dataset = QM9(root=self.data_dir, transform=GetTarget(self.target))

        # Shuffle dataset with a reproducible permutation.
        rng = np.random.default_rng(seed=self.seed)
        dataset = dataset[rng.permutation(len(dataset))]

        # Subset dataset
        if self.subset_size is not None:
            dataset = dataset[:self.subset_size]

        # Split dataset: sizes are given either directly (ints) or as
        # fractions of the dataset (floats).
        if all(isinstance(split, int) for split in self.splits):
            split_sizes = list(self.splits)
        elif all(isinstance(split, float) for split in self.splits):
            split_sizes = [int(len(dataset) * prop) for prop in self.splits]
        else:
            # Previously this fell through and failed with a NameError on
            # `split_sizes`.
            raise ValueError(
                'splits must contain only ints (sizes) or only floats '
                f'(fractions), got {self.splits!r}'
            )
        if len(split_sizes) < 2:
            raise ValueError(
                'splits needs at least a train and a validation entry, '
                f'got {self.splits!r}'
            )

        split_idx = np.cumsum(split_sizes)
        self.data_train = dataset[:split_idx[0]]
        self.data_val = dataset[split_idx[0]:split_idx[1]]
        self.data_test = dataset[split_idx[1]:]


    def get_target_stats(
        self,
        remove_atom_refs: bool = True,
        divide_by_atoms: bool = True
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """
        Mean and standard deviation of the training targets.

        Args:
            remove_atom_refs: Subtract the summed atomic reference values of
                each molecule from its target (when the dataset provides
                references for this target).
            divide_by_atoms: Divide each (corrected) target by the number of
                atoms in its molecule.

        Returns:
            Tuple ``(mean, std, atom_refs)``; ``atom_refs`` is whatever
            ``atomref(target)`` returns and may be ``None``.

        Raises:
            RuntimeError: If called before ``setup()``.
        """
        if self.data_train is None:
            raise RuntimeError('Call setup() before get_target_stats().')
        atom_refs = self.data_train.atomref(self.target)

        ys = list()
        for batch in self.train_dataloader(shuffle=False):
            y = batch.y.clone()
            if remove_atom_refs and atom_refs is not None:
                # Subtract each atom's reference value from its molecule's y.
                y.index_add_(
                    dim=0, index=batch.batch, source=-atom_refs[batch.z]
                )
            if divide_by_atoms:
                _, num_atoms = torch.unique(batch.batch, return_counts=True)
                y = y / num_atoms.unsqueeze(-1)
            ys.append(y)

        y = torch.cat(ys, dim=0)
        return y.mean(), y.std(), atom_refs


    def train_dataloader(self, shuffle: bool = True) -> DataLoader:
        """Loader over the training split (shuffled unless told otherwise)."""
        return DataLoader(
            self.data_train,
            batch_size=self.batch_size_train,
            num_workers=self.num_workers,
            shuffle=shuffle,
            pin_memory=True,
        )


    def val_dataloader(self) -> DataLoader:
        """Loader over the validation split, in fixed order."""
        return DataLoader(
            self.data_val,
            batch_size=self.batch_size_inference,
            num_workers=self.num_workers,
            shuffle=False,
            pin_memory=True,
        )


    def test_dataloader(self) -> DataLoader:
        """Loader over the test split, in fixed order."""
        return DataLoader(
            self.data_test,
            batch_size=self.batch_size_inference,
            num_workers=self.num_workers,
            shuffle=False,
            pin_memory=True,
        )

Valid values for `target` are the integers 0 to 18 (inclusive), where:
<br>
0: Dipole moment (μ). <br>
1: Isotropic polarizability (α). <br>
2: HOMO (highest occupied molecular orbital) energy. <br>
3: LUMO (lowest unoccupied molecular orbital) energy. <br>
4: Gap between HOMO and LUMO. <br>
... <br>
18: Internal energy (U) at 298.15K. <br>

In [4]:
# Load QM9 keeping only target 7 and look at the first molecule and the
# dataset summary.
dataset = QM9(root='data/', transform=GetTarget(target=7))
first_molecule = dataset[0]
print(first_molecule)
print(dataset)

Data(x=[5, 11], edge_index=[2, 8], edge_attr=[8, 4], y=[1, 1], pos=[5, 3], idx=[1], name='gdb_1', z=[5])
QM9(130831)


1. x=[5, 11]
Description: This is the node feature matrix. It contains features for each node in the graph (each atom in the molecule).
The shape [5, 11] means there are 5 nodes (atoms) in this molecule, and each atom has a feature vector of length 11. In PyTorch Geometric's QM9 these are: a one-hot encoding of the element (H, C, N, O, F), the atomic number, an aromaticity flag, a one-hot encoding of the hybridization (sp, sp2, sp3) and the number of attached hydrogens.
Note: The PaiNN model in this notebook does not use `x`; it only uses the atomic numbers `z` and the positions `pos`.

2. edge_index=[2, 8]
Description: This is the edge index that represents the connectivity (bonds) between atoms. It's a sparse matrix in COO (Coordinate) format, where each column represents a directed edge between two atoms.
The shape [2, 8] means there are 8 directed edges: row 0 holds the source atom and row 1 the target atom of each edge. Every chemical bond is stored once in each direction, so 8 directed edges correspond to 4 bonds.
Example: The first molecule is methane (gdb_1), a carbon atom (index 0) bonded to four hydrogens (indices 1-4). Its edge_index could look like this (the column order may differ):

edge_index = [[0, 0, 0, 0, 1, 2, 3, 4],
              [1, 2, 3, 4, 0, 0, 0, 0]]

3. edge_attr=[8, 4]
Description: This is the edge attribute matrix, which stores additional information about each edge (bond). The shape [8, 4] means there are 8 edges, and each edge has a feature vector of length 4.
In PyTorch Geometric's QM9 this is a one-hot encoding of the bond type (single, double, triple, aromatic). Bond lengths are not stored here, but they can be computed from `pos`.
Example: In methane every bond is a single C–H bond, so each of the 8 rows is [1, 0, 0, 0].

4. y=[1, 1]
Description: This is the target value for the molecular property that is being predicted. The shape [1, 1] indicates that there is a single target property for the entire graph (molecule).
Since GetTarget(target=7) is applied, y keeps only column 7 (zero-based) of the 19 QM9 targets. This is the internal energy at 0 K (U0), stored as a [1, 1] tensor.
Example: For this molecule, y stores its U0 value. The notebook reports errors for this target in meV (see `QM9DataModule.unit_conversion`).

5. pos=[5, 3]
Description: This is the position matrix for each node (atom) in the graph. The shape [5, 3] means there are 5 atoms, and each atom has a 3D coordinate (x, y, z).
This attribute is typically used in molecular graph representations where the spatial coordinates of atoms are important for geometric properties, especially for 3D molecule-related tasks.
Example: For each atom in the molecule, pos could represent its 3D position in space (e.g., pos[0] could be [x_0, y_0, z_0] for atom 0).

6. idx=[1]
Description: This is the index of the molecule within the QM9 dataset, stored as a one-element tensor (shape [1]). It stays attached to the molecule after shuffling or batching, which helps track individual molecules. It is distinct from the textual `name` below.

7. name='gdb_1'
Description: This is the name or ID of the molecule in the dataset. In this case, 'gdb_1' indicates that this molecule is the first molecule in the dataset, corresponding to the name gdb_1 in the QM9 dataset. Example: This could be used for logging or identifying which molecule is being processed.

8. z=[5]
Description: These are the atomic numbers (chemical elements) of the atoms in the graph. The shape [5] means there are 5 atoms. For the first molecule, methane (gdb_1: one carbon and four hydrogens), z = [6, 1, 1, 1, 1]. Example: PaiNN embeds z as the atom types (`atoms=batch.z` in the training loop).






Download the data and inspect one training batch. `train_loader` yields batches of 100 molecules (`batch_size_train`).

In [5]:
data_module = QM9DataModule(target=7)
data_module.prepare_data()
data_module.setup()

train_loader = data_module.train_dataloader()
# Show the structure of a single training mini-batch.
batch = next(iter(train_loader), None)
if batch is not None:
    print(batch)

DataBatch(x=[1840, 11], edge_index=[2, 3804], edge_attr=[3804, 4], y=[100, 1], pos=[1840, 3], idx=[100], name=[100], z=[1840], batch=[1840], ptr=[101])



### Scalar message function

#### 1. Feature Vector 

\begin{align*}
    x \in \mathbb{R}^{F\times1}
\end{align*}

* E.g., 

\begin{align*}
    x_A = [ 0.5, 1.2 ], x_B = [ 0.8, 0.9 ]
\end{align*}

#### 2. Coordinates

\begin{align*}
    \vec{r} \in \mathbb{R}^{ 1 \times 3 }
\end{align*}

* E.g., 

\begin{align*}
    \vec{r}_A = [ 1.0, 0.0, 0.0 ], \ \vec{r}_B = [ 0.0, 1.0, 0.0 ]
\end{align*}

#### 3. Vectorial features

\begin{align*}
    \vec{x} \in \mathbb{R}^{ F \times 3 }
\end{align*}

* E.g., 

\begin{align*}
    \vec{x}_A = 
    \begin{bmatrix}
    1.0 & 0.0 & 0.0 \\
    0.0 & 1.0 & 0.0
    \end{bmatrix}
    , \ 
    \vec{x}_B = 
    \begin{bmatrix}
    0.5 & 0.0 & 0.5 \\
    0.0 & 0.5 & 1.0
    \end{bmatrix}
\end{align*}







## Post-processing module

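This module turns the per-atom outputs of a PaiNN model into molecular predictions. Each atomic contribution is rescaled with the training-set standard deviation and mean, then shifted by its atom type's reference value. Finally, the contributions are summed over the atoms of each molecule.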

In [6]:
import torch.nn as nn

class AtomwisePostProcessing(nn.Module):
    """
    Post-processing for (QM9) properties that are predicted as sums of atomic
    contributions.

    Each atomic contribution is de-standardised (``x * std + mean``), shifted
    by the reference value of its atom type (when reference values are given)
    and summed over the atoms of each graph.
    """
    def __init__(
        self,
        num_outputs: int,
        mean: torch.FloatTensor,
        std: torch.FloatTensor,
        atom_refs: Optional[torch.FloatTensor],
    ) -> None:
        """
        Args:
            num_outputs: Integer with the number of model outputs. In most
                cases 1.
            mean: torch.FloatTensor with mean value to shift atomwise
                contributions by.
            std: torch.FloatTensor with standard deviation to scale atomwise
                contributions by.
            atom_refs: torch.FloatTensor of size [num_atom_types, 1] with
                atomic reference values, or None when the target has no
                atomic references (the reference shift is then skipped).
        """
        super().__init__()
        self.num_outputs = num_outputs
        self.register_buffer('scale', std)
        self.register_buffer('shift', mean)
        # Frozen lookup table: atom type -> reference value(s).
        self.atom_refs = (
            nn.Embedding.from_pretrained(atom_refs, freeze=True)
            if atom_refs is not None
            else None
        )


    def forward(
        self,
        atomic_contributions: torch.FloatTensor,
        atoms: torch.LongTensor,
        graph_indexes: torch.LongTensor,
    ) -> torch.FloatTensor:
        """
        Atomwise post-processing operations and atomic sum.

        Args:
            atomic_contributions: torch.FloatTensor of size [num_nodes,
                num_outputs] with each node's contribution to the overall graph
                prediction, i.e., each atom's contribution to the overall
                molecular property prediction.
            atoms: torch.LongTensor of size [num_nodes] with atom type of each
                node in the graph.
            graph_indexes: torch.LongTensor of size [num_nodes] with the
                0-based graph index each node belongs to.

        Returns:
            A torch.FloatTensor of size [num_graphs, num_outputs] with
            predictions for each graph (molecule), where num_graphs is
            ``max(graph_indexes) + 1``. Its dtype follows the de-standardised
            contributions.
        """
        # Size the output from the largest index (not the number of distinct
        # ids) so every index is a valid row even if some id is absent.
        if graph_indexes.numel() == 0:
            num_graphs = 0
        else:
            num_graphs = int(graph_indexes.max()) + 1

        atomic_contributions = atomic_contributions*self.scale + self.shift
        if self.atom_refs is not None:
            atomic_contributions = atomic_contributions + self.atom_refs(atoms)

        # Sum contributions for each graph; match the source dtype because
        # index_add_ requires identical dtypes.
        output_per_graph = torch.zeros(
            (num_graphs, self.num_outputs),
            device=atomic_contributions.device,
            dtype=atomic_contributions.dtype,
        )
        output_per_graph.index_add_(
            dim=0,
            index=graph_indexes,
            source=atomic_contributions,
        )

        return output_per_graph

In [7]:
import torch

# Example data
# Toy post-processing setup: 5 atom types, a single output, and made-up
# normalisation statistics and atomic reference values.
num_atom_types = 5
num_outputs = 1
mean = torch.tensor([0.5])
std = torch.tensor([2.0])
atom_refs = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]).unsqueeze(-1)  # [num_atom_types, 1]

post_processor = AtomwisePostProcessing(num_outputs, mean, std, atom_refs)
print(post_processor)

AtomwisePostProcessing(
  (atom_refs): Embedding(5, 1)
)


In [8]:
# Four atoms, one scalar contribution each (shape [4, 1]).
atomic_contributions = torch.tensor([0.1, 0.2, 0.3, 0.4]).unsqueeze(1)
print(atomic_contributions)

# Atom type of each atom, i.e. the row of atom_refs it is shifted by.
atoms = torch.arange(4)
print(atoms)

# Graph membership: atoms 0-1 form graph 0, atoms 2-3 form graph 1.
graph_indexes = torch.tensor([0, 0, 1, 1])
print(graph_indexes)

output = post_processor(
    atomic_contributions=atomic_contributions,
    atoms=atoms,
    graph_indexes=graph_indexes,
)
print(output)



tensor([[0.1000],
        [0.2000],
        [0.3000],
        [0.4000]])
tensor([0, 1, 2, 3])
tensor([0, 0, 1, 1])
tensor([[1.9000],
        [3.1000]])




### 1. Compute Scalar Messages

\begin{align*}
    m_{ij} & = \phi_m \left( x_i, \ x_j, \ \| \vec{d}_{ij} \| \right) = \mathsf{MLP}\left( [ x_i, x_j, \| \vec{d}_{ij} \| ] \right) && \text{message} \\
    M_i & = \sum_{j \in \mathcal{N}(i)} m_{ij} && \text{aggregation} \\
    x_i' & = \phi_u ( x_i, M_i ) = x_i + M_i && \text{update (residual)}
\end{align*}

* e.g. 

\begin{align*}
    m_{AB} & = \phi_m (x_A, \ x_B, \ \| \vec{d}_{AB} \| ) \\
           & = \mathsf{MLP}([0.5, 1.2, 0.8, 0.9, \sqrt{2} ]) \\
\end{align*}

##### Note. Displacement magnitude

\begin{align*}
    \vec{d}_{AB} & = \vec{r}_{A} - \vec{r}_{B} = [ 1.0, -1.0, 0.0 ] \\
    \| \vec{d}_{AB} \| & = \sqrt{ (1)^2 + (-1)^2 + (0)^2 } \\
    & = \sqrt{2}
\end{align*}


We'll start with the following setting for the MLP: a 2-layer network with input size 5, hidden size 4 and output size 2.

1.3. Linear layer (linear combination)

* e.g., with $h^0 = [x_A, x_B, \| \vec{d}_{AB} \|]$ as the input of the first layer,
\begin{align*}
    h^n = w_n h^{n-1} + b_n
\end{align*}

1.4. Initialize the weights and biases (typically they are initialized randomly).

e.g.,

\begin{align*}
    h^1 = [0.996,0.802,−0.168,0.009]
\end{align*}

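1.5. Apply the SiLU activation function, $\mathrm{SiLU}(z) = z \, \sigma(z) = \dfrac{z}{1 + e^{-z}}$, elementwise:

\begin{align*}
    \mathrm{SiLU}(h^1) \approx [0.727, 0.554, -0.077, 0.005]
\end{align*}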

1.6. Feed $\mathrm{SiLU}(h^1)$ into the next layer; the output of the last layer is the message.
e.g.
\begin{align*}
    m_{AB} \text{ or } (h^2) = w_2 \, \mathrm{SiLU}(h^1) + b_2
\end{align*}


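### 2. Compute Vectorial Messages

Vectorial messages follow the same recipe, but they act on the $F \times 3$ vector features. The neighbour's vector features and the unit direction $\vec{d}_{ij} / \| \vec{d}_{ij} \|$ are scaled by scalar filters produced by an MLP like the one above. As a result, the messages rotate together with the molecule.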



## PaiNN

In [9]:
class PaiNN(nn.Module):
    """
    Polarizable Atom Interaction Neural Network with PyTorch.

    Equivariant message-passing network of (Schütt et al., 2021). Every atom
    carries scalar features of shape [num_features] and vector features of
    shape [3, num_features]. Messages are only exchanged between distinct
    atoms of the same graph (equal entries in ``graph_indexes``) that are
    closer than ``cutoff_dist``, so molecules batched together never
    influence each other's predictions.
    """
    def __init__(
        self,
        num_message_passing_layers: int = 3,
        num_features: int = 128,
        num_outputs: int = 1,
        num_rbf_features: int = 20,
        num_unique_atoms: int = 100,
        cutoff_dist: float = 5.0,
    ) -> None:
        """
        Args:
            num_message_passing_layers: Number of message passing layers in
                the PaiNN model.
            num_features: Size of the node embeddings (scalar features) and
                vector features.
            num_outputs: Number of model outputs. In most cases 1.
            num_rbf_features: Number of radial basis functions to represent
                distances.
            num_unique_atoms: Number of unique atoms in the data that we want
                to learn embeddings for.
            cutoff_dist: Euclidean distance threshold for determining whether 
                two nodes (atoms) are neighbours.
        """
        super().__init__()
        self.num_features = num_features
        self.num_message_passing_layers = num_message_passing_layers
        self.num_rbf_features = num_rbf_features
        self.cutoff_dist = cutoff_dist

        # Initial scalar features: one learned embedding per atom type.
        self.atom_embedding = nn.Embedding(num_unique_atoms, num_features)

        # Each layer is a message block followed by an update block.
        self.message_blocks = nn.ModuleList([
            _PaiNNMessage(num_features, num_rbf_features, cutoff_dist)
            for _ in range(num_message_passing_layers)
        ])
        self.update_blocks = nn.ModuleList([
            _PaiNNUpdate(num_features)
            for _ in range(num_message_passing_layers)
        ])

        # Final readout layer producing per-atom contributions.
        self.readout = nn.Sequential(
            nn.Linear(num_features, 64),
            nn.SiLU(),
            nn.Linear(64, num_outputs),
        )

    def forward(
        self,
        atoms: torch.LongTensor,
        atom_positions: torch.FloatTensor,
        graph_indexes: torch.LongTensor,
    ) -> torch.FloatTensor:
        """
        Forward pass of PaiNN. Includes the readout network highlighted in blue
        in Figure 2 in (Schütt et al., 2021) with normal linear layers which is
        used for predicting properties as sums of atomic contributions. The
        post-processing and final sum is perfomed with
        src.models.AtomwisePostProcessing.

        Args:
            atoms: torch.LongTensor of size [num_nodes] with atom type of each
                node in the graph.
            atom_positions: torch.FloatTensor of size [num_nodes, 3] with
                euclidean coordinates of each node / atom.
            graph_indexes: torch.LongTensor of size [num_nodes] with the graph 
                index each node belongs to.

        Returns:
            A torch.FloatTensor of size [num_nodes, num_outputs] with atomic
            contributions to the overall molecular property prediction.
        """
        receivers, senders, distances, directions = self._neighbour_pairs(
            atom_positions, graph_indexes
        )

        scalars = self.atom_embedding(atoms)  # [num_nodes, F]
        # Vector features start at zero and are built up by the messages.
        vectors = scalars.new_zeros((atoms.shape[0], 3, self.num_features))

        for message, update in zip(self.message_blocks, self.update_blocks):
            d_scalars, d_vectors = message(
                scalars, vectors, receivers, senders, distances, directions
            )
            scalars = scalars + d_scalars
            vectors = vectors + d_vectors

            d_scalars, d_vectors = update(scalars, vectors)
            scalars = scalars + d_scalars
            vectors = vectors + d_vectors

        return self.readout(scalars)  # [num_nodes, num_outputs]

    def _neighbour_pairs(
        self,
        atom_positions: torch.FloatTensor,
        graph_indexes: torch.LongTensor,
    ):
        """
        Enumerate directed neighbour pairs (receiver i, sender j).

        Candidate pairs are generated per graph (cost ~ sum of squared graph
        sizes rather than the squared batch size); only pairs of distinct
        atoms closer than ``cutoff_dist`` are kept.

        Returns:
            receivers [num_edges], senders [num_edges], distances
            [num_edges, 1] and unit vectors pointing from receiver to sender
            [num_edges, 3].
        """
        device = graph_indexes.device
        num_nodes = graph_indexes.shape[0]

        # Group the nodes by graph: `order[k]` is the node id at sorted
        # position k.
        sorted_graphs, order = torch.sort(graph_indexes)
        graph_sizes = torch.bincount(sorted_graphs)
        graph_starts = torch.cumsum(graph_sizes, dim=0) - graph_sizes
        own_size = graph_sizes[sorted_graphs]    # [num_nodes]
        own_start = graph_starts[sorted_graphs]  # [num_nodes]

        # Pair every sorted position with every position of its own graph.
        receiver_pos = torch.repeat_interleave(
            torch.arange(num_nodes, device=device), own_size
        )
        block_start = torch.cumsum(own_size, dim=0) - own_size
        offset = (
            torch.arange(receiver_pos.shape[0], device=device)
            - block_start[receiver_pos]
        )
        sender_pos = own_start[receiver_pos] + offset
        receivers = order[receiver_pos]
        senders = order[sender_pos]

        r_ij = atom_positions[senders] - atom_positions[receivers]
        distances = r_ij.norm(dim=-1, keepdim=True)

        keep = (receivers != senders) & (distances.squeeze(-1) < self.cutoff_dist)
        receivers, senders = receivers[keep], senders[keep]
        r_ij, distances = r_ij[keep], distances[keep]

        # Guard against division by zero for (unphysical) coincident atoms.
        distances = distances.clamp_min(1e-9)
        return receivers, senders, distances, r_ij / distances


class _PaiNNMessage(nn.Module):
    """PaiNN message block: filtered, equivariant aggregation over neighbours."""
    def __init__(self, num_features: int, num_rbf_features: int, cutoff_dist: float) -> None:
        super().__init__()
        import math

        self.num_features = num_features
        self.cutoff_dist = cutoff_dist
        self._pi_over_cutoff = math.pi / cutoff_dist
        # sin(n * pi * r / cutoff) / r radial basis, n = 1..num_rbf_features.
        self.register_buffer(
            'rbf_frequencies',
            torch.arange(1, num_rbf_features + 1, dtype=torch.float32)
            * self._pi_over_cutoff,
            persistent=False,
        )
        self.phi = nn.Sequential(
            nn.Linear(num_features, num_features),
            nn.SiLU(),
            nn.Linear(num_features, 3 * num_features),
        )
        self.filter = nn.Linear(num_rbf_features, 3 * num_features)

    def forward(self, scalars, vectors, receivers, senders, distances, directions):
        """Return the residual updates (d_scalars [N, F], d_vectors [N, 3, F])."""
        rbf = torch.sin(distances * self.rbf_frequencies) / distances  # [E, R]
        # Cosine cutoff so messages fade smoothly to zero at cutoff_dist.
        envelope = 0.5 * (torch.cos(distances * self._pi_over_cutoff) + 1.0)
        envelope = envelope * (distances < self.cutoff_dist)           # [E, 1]
        filters = self.filter(rbf) * envelope                           # [E, 3F]

        split = self.phi(scalars)[senders] * filters                    # [E, 3F]
        gate_v, gate_r, delta_s = torch.split(split, self.num_features, dim=-1)

        # Vector message: gated neighbour vectors plus gated directions.
        message_v = (
            vectors[senders] * gate_v.unsqueeze(1)
            + gate_r.unsqueeze(1) * directions.unsqueeze(-1)
        )                                                               # [E, 3, F]

        d_scalars = torch.zeros_like(scalars).index_add_(0, receivers, delta_s)
        d_vectors = torch.zeros_like(vectors).index_add_(0, receivers, message_v)
        return d_scalars, d_vectors


class _PaiNNUpdate(nn.Module):
    """PaiNN update block: per-atom mixing of scalar and vector channels."""
    def __init__(self, num_features: int) -> None:
        super().__init__()
        self.num_features = num_features
        # Bias-free linear maps over the feature axis keep vectors equivariant.
        self.U = nn.Linear(num_features, num_features, bias=False)
        self.V = nn.Linear(num_features, num_features, bias=False)
        self.mlp = nn.Sequential(
            nn.Linear(2 * num_features, num_features),
            nn.SiLU(),
            nn.Linear(num_features, 3 * num_features),
        )

    def forward(self, scalars, vectors):
        """Return the residual updates (d_scalars [N, F], d_vectors [N, 3, F])."""
        u_v = self.U(vectors)                                      # [N, 3, F]
        v_v = self.V(vectors)                                      # [N, 3, F]
        # Small epsilon keeps the gradient of the norm finite at zero.
        v_norm = torch.sqrt(torch.sum(v_v ** 2, dim=1) + 1e-8)     # [N, F]
        a = self.mlp(torch.cat([scalars, v_norm], dim=-1))         # [N, 3F]
        a_vv, a_sv, a_ss = torch.split(a, self.num_features, dim=-1)

        d_vectors = a_vv.unsqueeze(1) * u_v
        d_scalars = a_sv * torch.sum(u_v * v_v, dim=1) + a_ss
        return d_scalars, d_vectors




class MessagePassingLayer(nn.Module):
    """
    A single (simplified, rotation-invariant) message-passing layer.

    Each atom i aggregates ``sum_j mask_ij * rbf_ij * linear(x_j)`` over its
    neighbours j and adds two linear projections of that aggregate to its own
    features (residual updates).
    """
    def __init__(self, num_features, num_rbf_features):
        super().__init__()
        # Projects neighbour features before they are filtered by distance.
        self.linear = nn.Linear(num_features, num_features)
        self.vector_update = nn.Linear(num_features, num_features)
        self.scalar_update = nn.Linear(num_features, num_features)

    def forward(self, x, atom_positions, rbf_features, adjacency_mask):
        """
        Args:
            x: Node features [num_nodes, num_features].
            atom_positions: Unused; kept for interface compatibility.
            rbf_features: Pairwise distance features
                [num_nodes, num_nodes, num_features].
            adjacency_mask: [num_nodes, num_nodes] mask (bool or 0/1 float)
                selecting which j send messages to i.

        Returns:
            Updated node features [num_nodes, num_features].
        """
        messages = self.aggregate_messages(x, rbf_features, adjacency_mask)

        # Residual updates from the aggregated messages.
        x = x + self.scalar_update(messages)
        x = x + self.vector_update(messages)

        return x

    def aggregate_messages(self, x, rbf_features, adjacency_mask):
        """
        Sum distance-filtered neighbour states for every node.

        Previously the messages ignored ``x`` entirely (and ``self.linear`` was
        unused), so every layer received identical, geometry-only messages.
        """
        neighbour_states = self.linear(x)                        # [N, F]
        weights = adjacency_mask.unsqueeze(-1) * rbf_features    # [N, N, F]
        return torch.einsum('ijf,jf->if', weights, neighbour_states)


def radial_basis_function_transform(distances, cutoff, num_rbf_features: int = 20):
    """
    Expand distances in Gaussian radial basis functions.

    Centers are ``num_rbf_features`` evenly spaced points on [0, cutoff] and
    the Gaussian width equals the spacing between neighbouring centers.

    Args:
        distances: Tensor of distances of any shape.
        cutoff: Largest RBF center.
        num_rbf_features: Number of basis functions (default 20, the value
            that used to be hard-coded); must be at least 2.

    Returns:
        Tensor of shape ``distances.shape + (num_rbf_features,)``.

    Raises:
        ValueError: If ``num_rbf_features < 2`` (the width would be undefined).
    """
    if num_rbf_features < 2:
        raise ValueError(f'num_rbf_features must be >= 2, got {num_rbf_features}')

    # Build centers on the same device and (floating) dtype as the input.
    dtype = distances.dtype if distances.is_floating_point() else torch.get_default_dtype()
    rbf_centers = torch.linspace(0, cutoff, num_rbf_features, device=distances.device, dtype=dtype)
    rbf_width = rbf_centers[1] - rbf_centers[0]

    return torch.exp(-((distances.unsqueeze(-1) - rbf_centers) ** 2) / (2 * rbf_width ** 2))

## Hyperparameters

In [10]:
def cli(args: Optional[list] = None):
    """
    Parse the notebook's hyperparameters.

    Args:
        args: Command-line style strings, e.g. ``['--num_epochs', '50']``.
            ``None`` (default) or an empty list yields all defaults; the
            kernel's own ``sys.argv`` is never parsed.

    Returns:
        argparse.Namespace with the data, model and training settings.
    """
    if args is None:
        args = []
    parser = argparse.ArgumentParser()
    # type=int so a seed given on the command line is an int, not a str.
    parser.add_argument('--seed', default=0, type=int)

    # Data
    parser.add_argument('--target', default=7, type=int) # 7 => Internal energy at 0K
    parser.add_argument('--data_dir', default='data/', type=str)
    parser.add_argument('--batch_size_train', default=100, type=int)
    parser.add_argument('--batch_size_inference', default=1000, type=int)
    parser.add_argument('--num_workers', default=0, type=int)
    parser.add_argument('--splits', nargs=3, default=[110000, 10000, 10831], type=int) # [num_train, num_val, num_test]
    parser.add_argument('--subset_size', default=None, type=int)

    # Model
    parser.add_argument('--num_message_passing_layers', default=3, type=int)
    parser.add_argument('--num_features', default=128, type=int)
    parser.add_argument('--num_outputs', default=1, type=int)
    parser.add_argument('--num_rbf_features', default=20, type=int)
    parser.add_argument('--num_unique_atoms', default=100, type=int)
    parser.add_argument('--cutoff_dist', default=5.0, type=float)

    # Training
    parser.add_argument('--lr', default=5e-4, type=float)
    parser.add_argument('--weight_decay', default=0.01, type=float)
    # Short default run (the earlier default was 1000 epochs).
    parser.add_argument('--num_epochs', default=10, type=int)

    return parser.parse_args(args=args)

## Training and testing

In [None]:
# Non-default arguments go in this list, e.g. ['--num_epochs', '50'].
cli_overrides = []
args = cli(cli_overrides)
seed_everything(args.seed)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# Data: download, shuffle/split, and target normalisation statistics.
data_kwargs = {
    name: getattr(args, name)
    for name in (
        'target', 'data_dir', 'batch_size_train', 'batch_size_inference',
        'num_workers', 'splits', 'seed', 'subset_size',
    )
}
dm = QM9DataModule(**data_kwargs)
dm.prepare_data()
dm.setup()
y_mean, y_std, atom_refs = dm.get_target_stats(
    remove_atom_refs=True, divide_by_atoms=True
)

# Model and post-processing (per-atom outputs -> molecular predictions).
model_kwargs = {
    name: getattr(args, name)
    for name in (
        'num_message_passing_layers', 'num_features', 'num_outputs',
        'num_rbf_features', 'num_unique_atoms', 'cutoff_dist',
    )
}
painn = PaiNN(**model_kwargs)
post_processing = AtomwisePostProcessing(
    args.num_outputs, y_mean, y_std, atom_refs
)
for module in (painn, post_processing):
    module.to(device)

optimizer = torch.optim.AdamW(
    painn.parameters(),
    lr=args.lr,
    weight_decay=args.weight_decay,
)


def predict(batch):
    """Return (per-atom contributions, per-molecule predictions) for a batch."""
    contributions = painn(
        atoms=batch.z,
        atom_positions=batch.pos,
        graph_indexes=batch.batch,
    )
    molecule_preds = post_processing(
        atoms=batch.z,
        graph_indexes=batch.batch,
        atomic_contributions=contributions,
    )
    return contributions, molecule_preds


# Training: the loss is averaged per batch; the reported epoch loss is the
# per-molecule mean squared error over the training set.
painn.train()
pbar = trange(args.num_epochs)
for epoch in pbar:
    loss_epoch = 0.
    for batch in dm.train_dataloader():
        batch = batch.to(device)
        atomic_contributions, preds = predict(batch)
        loss_step = F.mse_loss(preds, batch.y, reduction='sum')

        loss = loss_step / len(batch.y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_epoch += loss_step.detach().item()
    loss_epoch /= len(dm.data_train)
    pbar.set_postfix_str(f'Train loss: {loss_epoch:.3e}')

# Testing: mean absolute error over the test set, reported in target units.
mae = 0
painn.eval()
with torch.no_grad():
    for batch in dm.test_dataloader():
        batch = batch.to(device)
        atomic_contributions, preds = predict(batch)
        mae += F.l1_loss(preds, batch.y, reduction='sum')

mae /= len(dm.data_test)
unit_conversion = dm.unit_conversion[args.target]
print(f'Test MAE: {unit_conversion(mae):.3f}')


Seed set to 0


cuda


  0%|          | 0/10 [00:00<?, ?it/s]

In [None]:
# Sanity check on the last test batch: one prediction row per atom.
for shape in (atomic_contributions.shape, batch.z.shape):
    print(shape)

for num_rows in (batch.batch.shape[0], atomic_contributions.shape[0]):
    print(num_rows)
