In [1]:
from typing import Literal, Callable, Any
from math import ceil

import torch
from torch import nn
from torch.utils.data import random_split
from torch.optim import AdamW
import e3nn
from e3nn import o3
from torch_scatter import scatter
from torch_geometric.data import Data as PyGGraph
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
import pytorch_lightning as pl

from equitriton.sph_harm.direct import triton_spherical_harmonic
from equitriton.utils import spherical_harmonics_irreps

In [2]:
seed = torch.manual_seed(215162)

In [11]:
class AtomEmbedding(nn.Module):
    def __init__(self, num_atoms: int, atom_dim: int):
        super().__init__()
        self.embedding = nn.Embedding(num_atoms, atom_dim, padding_idx=0)

    def forward(self, atomic_numbers: torch.LongTensor) -> torch.Tensor:
        return self.embedding(atomic_numbers)


class EdgeEmbedding(nn.Module):
    def __init__(self, num_basis: int, radius_cutoff: float = 6.0, **kwargs):
        """
        This module embeds edges in a graph with an EdgeEmbedding object.

        Parameters
        ----------
        num_basis : int, optional
            The number of basis functions. Defaults to 1.
        radius_cutoff : float, optional
            The maximum radius up to which basis functions are defined. Defaults to 6.0.

        Optional kwargs
        ---------------
        basis : str, optional
            The type of basis function to use. Defaults to 'bessel'.
        start : float, optional
            The starting point in the distance grid used in the radial basis.
        cutoff : bool, optional
            Whether or not to apply a cutoff to the basis functions.

        Returns
        -------
        torch.Tensor
            A tensor representing the embedding of edges with shape (num_edges, num_basis).

        Examples
        --------
        >>> # Define an instance of EdgeEmbedding with 4 basis functions and a radius cutoff of 10.
        >>> embedder = EdgeEmbedding(num_basis=4, radius_cutoff=10.0)
        """
        super().__init__()
        kwargs.setdefault("basis", "bessel")
        kwargs.setdefault("start", 0.0)
        kwargs.setdefault("cutoff", True)
        self.num_basis = num_basis
        self.radius_cutoff = radius_cutoff
        self.basis_kwargs = kwargs

    def forward(self, distances: torch.Tensor) -> torch.Tensor:
        basis_funcs = e3nn.math.soft_one_hot_linspace(
            distances,
            number=self.num_basis,
            end=self.radius_cutoff,
            **self.basis_kwargs,
        )
        return basis_funcs * self.num_basis**0.5


class SphericalHarmonicEmbedding(nn.Module):
    def __init__(
        self,
        l_values: list[int],
        normalize: bool = True,
        normalization: Literal["norm", "integral", "component"] = "integral",
    ):
        """
        Projects cartesian positions onto spherical harmonic functions.
        """
        super().__init__()
        self.l_values = list(sorted(l_values))
        self.irreps = spherical_harmonics_irreps(self.l_values, num_feat=1)
        self.normalize = normalize
        self.normalization = normalization

    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        outputs = [triton_spherical_harmonic(l, coords) for l in self.l_values]
        return torch.cat(outputs, dim=-1)


class InteractionBlock(nn.Module):
    def __init__(
        self,
        atomic_dim: int | o3.Irreps,
        l_values: int,
        edge_dim: int,
        hidden_dim: int,
        radius_cutoff: float,
        degree_norm: float,
        edge_kwargs: dict[str, Any] = {},
        sph_harm_kwargs: dict[str, Any] = {},
        activation: Callable = nn.functional.silu,
    ):
        """
        A module that combines radial basis with spherical harmonics to
        describe molecular interactions.

        Parameters
        ----------
        atomic_dim : int | o3.Irreps
            Dimension of the atomic features. If int, it is treated as a
            single irreducible representation.
        l_values : int
            Values of the spherical harmonic order.
        edge_dim : int
            Dimension of the edge features.
        hidden_dim : int
            Hidden dimension for the fully connected network.
        radius_cutoff : float
            Cutoff radius for the radial basis.
        degree_norm : float
            Normalization factor for the degree of the graph.
        edge_kwargs : dict[str, Any], optional
            Keyword arguments for the EdgeEmbedding module. Defaults to {}.
        sph_harm_kwargs : dict[str, Any], optional
            Keyword arguments for the SphericalHarmonicEmbedding module.
            Defaults to {}.
        activation : Callable, optional
            Activation function for the fully connected network. Defaults to
            nn.functional.silu.

        Notes
        -----
        The `degree_norm` attribute is set as a property and effectively
        represents the average number of neighbors in other models.

        Examples
        --------
        >>> block = InteractionBlock(atomic_dim=8, l_values=[0, 1],
            edge_dim=16, hidden_dim=32)
        >>> block.sph_irreps
        ['1x0e', '2x0e']
        """

        super().__init__()
        # this is effectively the average number of neighbors in other models
        self.degree_norm = degree_norm
        # treat atom features as invariant
        if isinstance(atomic_dim, int):
            atomic_irreps = f"{atomic_dim}x0e"
        else:
            atomic_irreps = atomic_dim
        self.atomic_irreps = atomic_irreps
        self.l_values = list(sorted(l_values))
        # these two attributes are similar but different: the former is used for describing
        # the basis itself, and the latter is for actually specifying the weights
        self.sph_irreps = spherical_harmonics_irreps(self.l_values, num_feat=1)
        self.output_irreps = spherical_harmonics_irreps(
            self.l_values, num_feat=hidden_dim
        )
        # tensor product is the final bit the combines the radial basis with the spherical
        # harmonics
        self.tensor_product = o3.FullyConnectedTensorProduct(
            self.atomic_irreps,
            self.sph_irreps,
            self.output_irreps,
            shared_weights=False,
        )
        self.edge_basis = EdgeEmbedding(edge_dim, radius_cutoff, **edge_kwargs)
        self.spherical_harmonics = SphericalHarmonicEmbedding(
            l_values, **sph_harm_kwargs
        )
        self.fc = e3nn.nn.FullyConnectedNet(
            [edge_dim, hidden_dim, self.tensor_product.weight_numel], activation
        )

    @property
    def num_projections(self) -> int:
        """Returns the expected number of projections."""
        return sum([2 * l + 1 for l in self.l_values])

    @property
    def output_dim(self) -> int:
        """Returns the dimensionality of the output."""
        return self.output_irreps.dim

    def forward(
        self,
        atomic_features: torch.Tensor,
        coords: torch.Tensor,
        edge_index: torch.LongTensor,
    ) -> torch.Tensor:
        """
        High-level description:

        1. Project cartesian coordinates onto spherical harmonic basis
        2. Project interatomic distances onto radial (bessel) basis
        3. Transform radial basis functions with learnable weights
        4. Compute tensor product between scalar atom features and spherical harmonic basis
        5. Update node features
        """
        edge_dist = coords[edge_index[0]] - coords[edge_index[1]]
        sph_harm = self.spherical_harmonics(edge_dist)
        # calculate atomic distances, embed, and transform them
        edge_basis = self.edge_basis(edge_dist.norm(dim=-1))
        edge_z = self.fc(edge_basis)
        # compute tensor product
        messages = self.tensor_product(atomic_features[edge_index[0]], sph_harm, edge_z)
        # update node features
        hidden_feats = (
            scatter(messages, edge_index[1], dim=0, dim_size=atomic_features.size(0))
            / self.degree_norm
        )
        return hidden_feats


class ScalarReadoutLayer(nn.Module):
    def __init__(self, hidden_irreps: o3.Irreps, output_dim: int):
        super().__init__()
        self.hidden_irreps = hidden_irreps
        self.output_irreps = o3.Irreps(f"{output_dim}x0e")
        self.output_layer = o3.Linear(
            irreps_in=hidden_irreps, irreps_out=self.output_irreps
        )

    def forward(self, node_feats: torch.Tensor) -> torch.Tensor:
        return self.output_layer(node_feats)


class EquiTritonModel(nn.Module):
    def __init__(
        self,
        initial_atom_dim: int,
        num_layers: int,
        output_dim: int,
        l_values: int,
        edge_dim: int,
        hidden_dim: int,
        radius_cutoff: float,
        degree_norm: float,
        edge_kwargs: dict[str, Any] = {},
        sph_harm_kwargs: dict[str, Any] = {},
        activation: Callable = nn.functional.silu,
        num_atoms: int = 100,
        skip_connections: bool = True,
    ):
        """
        A neural network model designed for processing molecular graphs.

        This class implements a hierarchical architecture with multiple interaction blocks,
        allowing for efficient and scalable processing of large molecular datasets.

        Parameters:
            initial_atom_dim (int): The dimensionality of the atomic embeddings.
            num_layers (int): The number of convolutional layers to use.
            output_dim (int): The dimensionality of the final scalar features.
            l_values (int): A list of spherical harmonics order to consider.
            edge_dim (int): The dimensionality of the edge features.
            hidden_dim (int): The dimensionality of the hidden state in each interaction block.
            radius_cutoff (float): The cutoff distance for radial basis functions.
            degree_norm (float): The normalization constant for edge features. Typically square root of the average degree.
            edge_kwargs (dict[str, Any], optional): Keyword arguments to pass to the InteractionBlock. Defaults to {}.
            sph_harm_kwargs (dict[str, Any], optional): Keyword arguments to pass to the InteractionBlock. Defaults to {}.
            activation (Callable, optional): The activation function to use in each interaction block. Defaults to nn.functional.silu.
            num_atoms (int, optional): The number of atoms in the embedding table (i.e. unique elements). Defaults to 100.
            skip_connections (bool, optional): Whether to enable residual connections between layers. Defaults to True.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: A tuple containing the graph-level scalar features and the node features.

        Examples:
            >>> model = EquiTritonModel(...)
            >>> graph = PyGGraph(...).to(device="cuda")
            >>> graph_z, z = model(graph)

        Note: This class uses PyTorch Geometric's Graph data structure and assumes that the input graph has already been processed using a suitable preprocessing step.
        """
        super().__init__()
        self.atomic_embedding = AtomEmbedding(num_atoms, initial_atom_dim)
        self.initial_layer = InteractionBlock(
            initial_atom_dim,
            l_values,
            edge_dim,
            hidden_dim,
            radius_cutoff,
            degree_norm,
            edge_kwargs,
            sph_harm_kwargs,
            activation,
        )
        self.conv_layers = nn.ModuleDict()
        for layer_index in range(num_layers + 1):
            self.conv_layers[f"conv_{layer_index}"] = InteractionBlock(
                self.initial_layer.output_dim,
                l_values,
                edge_dim,
                hidden_dim,
                radius_cutoff,
                degree_norm,
                edge_kwargs,
                sph_harm_kwargs,
                activation,
            )
        self.scalar_readout = ScalarReadoutLayer(
            self.initial_layer.output_irreps, output_dim
        )
        self.skip_connections = skip_connections
        self.output_dim = output_dim

    def forward(self, graph: PyGGraph) -> tuple[torch.Tensor, torch.Tensor]:
        # determine if the graph is batched or not
        is_batched = hasattr(graph, "ptr")
        # get atom embeddings
        atom_z = self.atomic_embedding(graph.z)  # [nodes, initial_atom_dim]
        # first message passing step
        z = self.initial_layer(atom_z, graph.pos, graph.edge_index)
        outputs = {}
        for layer_name, layer in self.conv_layers.items():
            new_z = layer(z, graph.pos, graph.edge_index)
            # add residual connections
            if self.skip_connections and new_z.shape == z.shape:
                new_z += z
            z = new_z
            outputs[layer_name] = z
        # map final output as scalars
        z = self.scalar_readout(z)
        # latest node features are in z; we generate graph-level scalar features
        # by doing a scatter add
        if is_batched:
            graph_z = scatter(z, graph.batch, dim=0, dim_size=graph.batch_size)
        else:
            # for a single graph, just sum up the node features
            graph_z = z.sum(dim=0, keepdims=True)
        return graph_z, z

In [17]:
def make_fake_graph(
    num_nodes: int,
    num_edges: int,
    coord_scale: float = 1.0,
    max_atomic_number: int = 100,
    device="cuda",
):
    coords = torch.rand(num_nodes, 3, device=device) * coord_scale
    edge_index = torch.randint(0, high=num_nodes, size=(2, num_edges), device=device)
    atomic_numbers = torch.randint(
        0, max_atomic_number, size=(num_nodes,), device=device
    )
    return coords, edge_index, atomic_numbers

In [18]:
edge_embedder = EdgeEmbedding(num_basis=10)

In [19]:
coords, edge_index, atomic_numbers = make_fake_graph(
    16,
    12,
)

In [20]:
# coords = torch.ones_like(coords, requires_grad=True)

In [21]:
atom_embedder = AtomEmbedding(100, 64).to("cuda")
atom_z = atom_embedder(atomic_numbers)

In [22]:
edge_index.shape

torch.Size([2, 12])

In [50]:
layer = InteractionBlock(
    64, [2, 3, 4, 5], 10, 32, radius_cutoff=6.0, degree_norm=17**0.5
).to("cuda")



In [51]:
next_layer = InteractionBlock(
    layer.output_irreps, [2, 3, 4, 5], 10, 32, radius_cutoff=6.0, degree_norm=17**0.5
).to("cuda")



In [52]:
layer

InteractionBlock(
  (tensor_product): FullyConnectedTensorProduct(64x0e x 1x2e+1x3o+1x4e+1x5o -> 32x2e+32x3o+32x4e+32x5o | 8192 paths | 8192 weights)
  (edge_basis): EdgeEmbedding()
  (spherical_harmonics): SphericalHarmonicEmbedding()
  (fc): FullyConnectedNet[10, 32, 8192]
)

In [53]:
o = layer(atom_z, coords, edge_index)

In [67]:
o

tensor([[ 0.1235,  0.0030,  0.0861,  ...,  0.3042,  0.1641, -0.0114],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0167, -0.0017,  0.0098,  ..., -0.0773, -0.0233,  0.0344],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0578,  0.0226, -0.0362,  ..., -0.0003,  0.0016,  0.0056],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0', grad_fn=<DivBackward0>)

In [37]:
p = next_layer(o, coords, edge_index)

In [38]:
o.shape

torch.Size([16, 1184])

In [39]:
o + p

tensor([[ 0.3229, -0.0011,  0.2210,  ...,  0.3184,  0.1664, -0.0134],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [-0.7950,  0.0826, -0.4642,  ..., -0.1652, -0.0498,  0.0736],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.1217, -0.0898, -0.0128,  ..., -0.0215,  0.0755, -0.0243],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0', grad_fn=<AddBackward0>)

## Dataset definition

In [23]:
class LightningQM9(pl.LightningDataModule):
    def __init__(
        self,
        root_path: str = "./qm9_data",
        batch_size: int = 16,
        train_frac: float = 0.8,
        val_frac: float = 0.1,
        num_workers: int = 0,
    ):
        """
        Custom data module for QM9 dataset.

        Parameters
        ----------
        root_path : str, optional (default: "./qm9_data")
            Path to the QM9 dataset.
        batch_size : int, optional (default: 16)
            Number of samples in each mini-batch.
        train_frac : float, optional (default: 0.8)
            Fraction of data used for training.
        val_frac : float, optional (default: 0.1)
            Fraction of data used for validation.
        num_workers : int, optional (default: 0)
            Number of worker processes to use for loading data.

        Examples
        --------
        >>> dm = LightningQM9(root_path="/path/to/qm9_data", batch_size=32)

        Attributes
        ----------
        dataset : QM9
            Loaded QM9 dataset.
        hparams : dict
            Hyperparameters of the data module.

        Methods
        -------
        setup(stage: str)
            Setup data splits for training, validation and testing.
        train_dataloader()
            Returns a DataLoader instance for training data.
        val_dataloader()
            Returns a DataLoader instance for validation data.
        test_dataloader()
            Returns a DataLoader instance for testing data.
        """
        super().__init__()
        self.dataset = QM9(root_path)
        self.save_hyperparameters()

    def setup(self, stage: str):
        hparams = self.hparams
        num_samples = len(self.dataset)
        num_train = int(num_samples * hparams["train_frac"])
        num_val = int(num_samples * hparams["val_frac"])
        num_test = ceil(
            num_samples * (1 - (hparams["train_frac"] + hparams["val_frac"]))
        )
        # generate random splits
        train_split, val_split, test_split = random_split(
            self.dataset, lengths=[num_train, num_val, num_test]
        )
        self.splits = {"train": train_split, "val": val_split, "test": test_split}

    def train_dataloader(self):
        return DataLoader(
            self.splits["train"],
            batch_size=self.hparams["batch_size"],
            shuffle=True,
            num_workers=self.hparams["num_workers"],
        )

    def val_dataloader(self):
        return DataLoader(
            self.splits["val"],
            batch_size=self.hparams["batch_size"],
            shuffle=False,
            num_workers=self.hparams["num_workers"],
        )

    def test_dataloader(self):
        return DataLoader(
            self.splits["test"],
            batch_size=self.hparams["batch_size"],
            shuffle=False,
            num_workers=self.hparams["num_workers"],
        )

## Loss and Lightning module

Model trains optionally with a loss target that Nequip and MACE uses, which is the atom-weighted MSE. For now we're only using a single target, but can expand to use the full QM9 set of targets too.

In [33]:
class AtomWeightedMSE(nn.Module):
    """
    Calculates the mean-squared-error between predicted and targets,
    weighted by the number of atoms within each graph.

    From matsciml
    """

    def forward(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
        atoms_per_graph: torch.Tensor,
    ) -> torch.Tensor:
        if atoms_per_graph.size(0) != target.size(0):
            raise RuntimeError(
                "Dimensions for atom-weighted loss do not match:"
                f" expected atoms_per_graph to have {target.size(0)} elements; got {atoms_per_graph.size(0)}."
                "This loss is intended to be applied to scalar targets only."
            )
        # check to make sure we are broad casting correctly
        if (input.ndim != target.ndim) and target.size(-1) == 1:
            input.unsqueeze_(-1)
        # for N-d targets, we might want to keep unsqueezing
        while atoms_per_graph.ndim < target.ndim:
            atoms_per_graph.unsqueeze_(-1)
        # ensures that atoms_per_graph is type cast correctly
        squared_error = ((input - target) / atoms_per_graph.to(input.dtype)) ** 2.0
        return squared_error.mean()


class EquiTritonLitModule(pl.LightningModule):
    def __init__(
        self,
        model_class: type,
        model_kwargs,
        e_mean: float,
        e_std: float,
        lr: float = 1e-3,
        weight_decay: float = 0.0,
        atom_weighted_loss: bool = True,
    ):
        """
        Initializes the EquiTritonLitModule clas.

        Parameters
        ----------
        model_class : type
            Th class of the model to be used.
        model_kwargs : dict
            Keyword argument for the model initialization.
        e_mean : float
            The mean of the energy values.
        e_std : float
            The standard deviation of the energy values.
        lr : float, optional
            The learning rate (default is 1e-3) for AdamW.
        weight_decay : float, optional
            Weight decay value (default is 0.0).
        atom_weighted_loss : bool, optional
            Whether to use atom-weighted loss or not (default is True).
        """
        super().__init__()
        self.model = model_class(**model_kwargs)
        if atom_weighted_loss:
            self.loss = AtomWeightedMSE()
        else:
            self.loss = nn.MSELoss()
        self.output_head = nn.Linear(self.model.output_dim, 1)
        self.save_hyperparameters()

    def configure_optimizers(self):
        return AdamW(
            self.parameters(),
            lr=self.hparams["lr"],
            weight_decay=self.hparams["weight_decay"],
        )

    def step(self, graph: PyGGraph, stage: Literal["train", "test", "val"]):
        """
        Performs a single step of the training, validation or testing
        process.

        Parameters
        ----------
        graph : PyGGraph
            The input graph.
        stage : Literal["train", "test", "val"]
            The current stage (training, testing or validation).

        Returns
        -------
        loss : float
            The calculated loss value.
        """
        g_z, z = self.model(graph)
        pred_energy = self.output_head(g_z)
        target_energy = graph.y[:, 12].unsqueeze(-1)
        norm_energy = (target_energy - self.hparams["e_mean"]) / self.hparams["e_std"]
        if self.hparams["atom_weighted_loss"]:
            loss = self.loss(pred_energy, norm_energy, torch.diff(graph.ptr))
        else:
            loss = self.loss(pred_energy, norm_energy)
        batch_size = getattr(graph, "batch_size", 1)
        self.log(
            f"{stage}_loss", loss, prog_bar=True, on_step=True, batch_size=batch_size
        )
        return loss

    def training_step(self, batch):
        loss = self.step(batch, "train")
        return loss

    def validation_step(self, batch):
        loss = self.step(batch, "val")
        return loss

    def test_step(self, batch):
        loss = self.step(batch, "test")
        return loss

In [27]:
dm = LightningQM9("./qm9_data/", batch_size=64)
dm.setup("fit")

train_loader = dm.train_dataloader()

  if osp.exists(f) and torch.load(f) != _repr(self.pre_transform):
  if osp.exists(f) and torch.load(f) != _repr(self.pre_filter):
  return torch.load(f, map_location)


In [28]:
values = torch.cat([sample.y[:, 12] for sample in dm.dataset])
e_mean = values.mean()
e_std = values.std()

In [40]:
lit_mod = EquiTritonLitModule(
    EquiTritonModel,
    model_kwargs={
        "initial_atom_dim": 64,
        "num_layers": 3,
        "output_dim": 48,
        "l_values": [0, 1, 2, 4, 6, 8],
        "edge_dim": 10,
        "hidden_dim": 16,
        "radius_cutoff": 6.0,
        "degree_norm": 37.5**0.5,
    },
    e_mean=e_mean,
    e_std=e_std,
    atom_weighted_loss=False,
)



In [41]:
trainer = pl.Trainer(max_epochs=30, accelerator="gpu")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [42]:
trainer.fit(lit_mod, datamodule=dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type            | Params
------------------------------------------------
0 | model       | EquiTritonModel | 4.8 M 
1 | loss        | MSELoss         | 0     
2 | output_head | Linear          | 49    
------------------------------------------------
4.8 M     Trainable params
0         Non-trainable params
4.8 M     Total params
19.300    Total estimated model params size (MB)


Sanity Checking: |                                                                                            …

/home/kelvin/miniforge3/envs/equitriton/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=27` in the `DataLoader` to improve performance.
/home/kelvin/miniforge3/envs/equitriton/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=27` in the `DataLoader` to improve performance.


Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

/home/kelvin/miniforge3/envs/equitriton/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
