In [1]:
import pytest
import pytorch_lightning as pl
import torch

from matsciml.datasets import lips
from matsciml.datasets import transforms
from matsciml.datasets.lips import LiPSDataset, lips_devset
from matsciml.datasets.transforms import DistancesTransform
from matsciml.lightning.data_utils import MatSciMLDataModule
from matsciml.models.base import ScalarRegressionTask
from matsciml.models.pyg import EGNN

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# dset = LiPSDataset(lips_devset)
# sample = dset[10]
# construct a scalar regression task with an EGNN encoder
#task = ScalarRegressionTask(
#     encoder_class=EGNN,
#     encoder_kwargs={"hidden_dim": 128, "output_dim": 64},
#     task_keys=["energy"],
# )
# task = ScalarRegressionTask(
#     encoder_class=EGNN,
#     # kwargs passed to the EGNN encoder when it is created
#     encoder_kwargs={
#         "encoder_only": True,
#         "hidden_dim": 128,
#         "output_dim": 1,
#     },
#     # which keys to use as targets
#     task_keys=["energy"],
# )
# # Use the LiPS devset to test the workflow

In [3]:
from matsciml.datasets.transforms import PointCloudToGraphTransform

# Build the data module from the "LiPSDataset" devset. Every sample is passed
# through PointCloudToGraphTransform with the "pyg" backend and a cutoff
# distance of 5.0 (the same value the MACE config uses for r_max).
dm = MatSciMLDataModule.from_devset(
    "LiPSDataset",
    dset_kwargs={
        "transforms": [PointCloudToGraphTransform("pyg", cutoff_dist=5.0)],
    },
)
# Disabled quick training loop; `task` is only defined in commented-out code.
# trainer = pl.Trainer(fast_dev_run=1000)
# trainer.fit(task, datamodule=dm)


In [4]:
# NOTE(review): wildcard import. RealAgnosticResidualInteractionBlock, used in
# later cells, is not imported by name anywhere visible, so it presumably comes
# from here — consider importing the needed names explicitly.
from matsciml.models.pyg.mace.modules.blocks import *
from matsciml.models.pyg.mace.modules.models import ScaleShiftMACE

In [5]:
import matsciml
import e3nn

In [6]:
# Shared MACE hyperparameters; unpacked into the ScaleShiftMACE constructor in
# the next cell.
#
# NOTE(review): atomic_numbers is (1, 6, 8) with three matching atomic
# energies, while the data module is built from the "LiPSDataset" devset —
# confirm these values are intended for that dataset.
# NOTE(review): the sample batch displayed later has 4707 edges over 664 nodes
# (about 7.1 per node) — confirm avg_num_neighbors=14.38 is the intended value.
model_config = dict(
    r_max=5.0,  # same value as cutoff_dist in the graph transform
    num_bessel=8,
    num_polynomial_cutoff=5,
    max_ell=3,
    interaction_cls=RealAgnosticResidualInteractionBlock,
    num_interactions=2,
    num_elements=3,
    hidden_irreps=e3nn.o3.Irreps("16x0e+16x1o+16x2e"),
    # torch.tensor is the documented way to build a tensor from Python data;
    # it yields the default floating dtype, as the legacy torch.Tensor(...) did.
    atomic_energies=torch.tensor(
        [-13.663181292231226, -1029.2809654211628, -2042.0330099956639]
    ),
    avg_num_neighbors=14.38,
    atomic_numbers=(1, 6, 8),
)

In [12]:
# Instantiate ScaleShiftMACE from the shared model_config plus the settings
# that are supplied only at this call site.
model = ScaleShiftMACE(
    correlation=3,
    gate=torch.nn.functional.silu,
    interaction_cls_first=RealAgnosticResidualInteractionBlock,
    MLP_irreps=e3nn.o3.Irreps("16x0e"),
    atomic_inter_scale=1.0,
    atomic_inter_shift=0.0,
    **model_config,
)

In [13]:
# Display the model's repr (its module tree) as the cell output.
model

ScaleShiftMACE(
  (atom_embedding): Embedding(100, 16, padding_idx=0)
  (node_embedding): LinearNodeEmbeddingBlock(
    (linear): Linear(3x0e -> 16x0e | 48 weights)
  )
  (radial_embedding): RadialEmbeddingBlock(
    (bessel_fn): BesselBasis(r_max=5.0, num_basis=8, trainable=False)
    (cutoff_fn): PolynomialCutoff(p=5.0, r_max=5.0)
  )
  (spherical_harmonics): SphericalHarmonics()
  (atomic_energies_fn): AtomicEnergiesBlock(energies=[-13.6632, -1029.2810, -2042.0330])
  (interactions): ModuleList(
    (0): RealAgnosticResidualInteractionBlock(
      (linear_up): Linear(16x0e -> 16x0e | 256 weights)
      (conv_tp): TensorProduct(16x0e x 1x0e+1x1o+1x2e+1x3o -> 16x0e+16x1o+16x2e+16x3o | 64 paths | 64 weights)
      (conv_tp_weights): FullyConnectedNet[8, 64, 64, 64, 64]
      (linear): Linear(16x0e+16x1o+16x2e+16x3o -> 16x0e+16x1o+16x2e+16x3o | 1024 weights)
      (skip_tp): FullyConnectedTensorProduct(16x0e x 3x0e -> 16x0e+16x1o+16x2e | 768 paths | 768 weights)
      (reshape): reshape

In [17]:
# Set up the data module, then pull a single batch from the training
# dataloader so the model can be exercised by hand.
dm.setup()
Train_loader = dm.train_dataloader()
dataset_iter = iter(Train_loader)
batch = next(dataset_iter)

In [18]:
# Inspect the object stored under the batch's "graph" key.
batch['graph']

DataBatch(edge_index=[2, 4707], pos=[664, 3], atomic_numbers=[664], force=[664, 3], batch=[664], ptr=[9])

In [19]:
# Turn the batch into keyword arguments via the model's read_batch, then pass
# them straight to the model's `_forward`.
inputs = model.read_batch(batch)
outputs = model._forward(**inputs)

In [21]:
# List the entries returned by the forward pass.
outputs.keys()

dict_keys(['energy', 'node_energy', 'interaction_energy', 'forces', 'virials', 'stress', 'displacement', 'node_feats', 'edge_feats'])

In [26]:
# Shape of the "node_feats" entry of the forward-pass output.
outputs['node_feats'].shape

torch.Size([664, 16])