In [1]:
import pytest
from matsciml.datasets import lips
from matsciml.datasets.lips import LiPSDataset, lips_devset
from matsciml.datasets import transforms
import pytorch_lightning as pl
from matsciml.models.pyg import EGNN
from matsciml.models.base import ScalarRegressionTask
from matsciml.lightning.data_utils import MatSciMLDataModule
from matsciml.datasets.transforms import DistancesTransform

from matsciml.models.pyg.mace.modules.blocks import *
from matsciml.models.pyg.mace.modules.models import ScaleShiftMACE

import matsciml
import e3nn

In [10]:
from matsciml.datasets.transforms import PointCloudToGraphTransform
# SchNet uses RBFs, and expects edge features corresponding to atom-atom distances
dm = MatSciMLDataModule.from_devset(
    "LiPSDataset",dset_kwargs={"transforms":[PointCloudToGraphTransform("pyg", cutoff_dist=5.0)]}
)
# run a quick training loop
# trainer = pl.Trainer(fast_dev_run=1000)
# trainer.fit(task, datamodule=dm)


In [11]:
model_config = dict(
        r_max=5.0,
        num_bessel=8,
        num_polynomial_cutoff=5,
        max_ell=3,
        interaction_cls= RealAgnosticResidualInteractionBlock ,
        num_interactions=2,
        num_elements=3,
        hidden_irreps=e3nn.o3.Irreps('16x0e+16x1o+16x2e'),
        atomic_energies=torch.Tensor([-13.663181292231226, -1029.2809654211628, -2042.0330099956639]),
        avg_num_neighbors=14.38,
        atomic_numbers=(1,6,8),
        correlation=3,
        gate=torch.nn.functional.silu,
        interaction_cls_first=RealAgnosticResidualInteractionBlock,
        MLP_irreps=e3nn.o3.Irreps('16x0e'),
        atomic_inter_scale=1.0,
        atomic_inter_shift=0.0
    )

In [12]:
task = ScalarRegressionTask(
    encoder_class=ScaleShiftMACE,
    # kwargs to be passed into the creation of SchNet model
    encoder_kwargs=model_config,
    # which keys to use as targets
    task_keys=["energy"],
    output_kwargs={
      "block_type": "IrrepOutputBlock",
      "input_dim": "0e", 
      "hidden_dim": "30x0e + 10x1e", 
      "output_dim": "0e",
      "activation": ["torch.nn.SiLU", None],
      "residual":False
    }
)

In [13]:
task

ScalarRegressionTask(
  (encoder): ScaleShiftMACE(
    (atom_embedding): Embedding(100, 16, padding_idx=0)
    (node_embedding): LinearNodeEmbeddingBlock(
      (linear): Linear(3x0e -> 16x0e | 48 weights)
    )
    (radial_embedding): RadialEmbeddingBlock(
      (bessel_fn): BesselBasis(r_max=5.0, num_basis=8, trainable=False)
      (cutoff_fn): PolynomialCutoff(p=5.0, r_max=5.0)
    )
    (spherical_harmonics): SphericalHarmonics()
    (atomic_energies_fn): AtomicEnergiesBlock(energies=[-13.6632, -1029.2810, -2042.0330])
    (interactions): ModuleList(
      (0): RealAgnosticResidualInteractionBlock(
        (linear_up): Linear(16x0e -> 16x0e | 256 weights)
        (conv_tp): TensorProduct(16x0e x 1x0e+1x1o+1x2e+1x3o -> 16x0e+16x1o+16x2e+16x3o | 64 paths | 64 weights)
        (conv_tp_weights): FullyConnectedNet[8, 64, 64, 64, 64]
        (linear): Linear(16x0e+16x1o+16x2e+16x3o -> 16x0e+16x1o+16x2e+16x3o | 1024 weights)
        (skip_tp): FullyConnectedTensorProduct(16x0e x 3x0e -> 

In [14]:
trainer = pl.Trainer(fast_dev_run=1000)
trainer.fit(task, datamodule=dm)

GPU available: True (cuda), used: False


TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
Running in `fast_dev_run` mode: will run the requested loop using 1000 batch(es). Logging and checkpointing is suppressed.

  | Name         | Type           | Params
------------------------------------------------
0 | encoder      | ScaleShiftMACE | 58.8 K
1 | loss_func    | MSELoss        | 0     
2 | output_heads | ModuleDict     | 1.1 K 
------------------------------------------------
59.9 K    Trainable params
0         Non-trainable params
59.9 K    Total params
0.240     Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Epoch 0:   0%|          | 0/50 [00:00<?, ?it/s] 1)Input data shapes
cell torch.Size([24, 3])
energy torch.Size([8, 1])
ptr torch.Size([9])
batch torch.Size([664])
edge_feats None
graph_feats None
positions torch.Size([664, 3])
forces torch.Size([664, 3])
charges torch.Size([664])
energy_weight torch.Size([8])
forces_weight torch.Size([8])
stress torch.Size([8, 3, 3])
stress_weights torch.Size([8])
virials torch.Size([8, 3, 3])
virials_weights torch.Size([8])
weights torch.Size([8])
node_attrs torch.Size([664, 3])
shifts torch.Size([16732, 3])
unit_shifts torch.Size([16732, 3])
edge_index torch.Size([2, 16732])
2)Node e0 torch.Size([664])
3)e0(After scatter sum over nodes) torch.Size([8])
4)Layer 0: Node feats torch.Size([664, 16])
5)rij,|rij| torch.Size([16732, 3]) torch.Size([16732, 1])
6)Layer 0: edge_attrs(Sph harm), edge_feats(Radial basis) torch.Size([16732, 16]) torch.Size([16732, 8])
7)Layer 1: Node feats(After interaction),sc torch.Size([664, 16, 16]) torch.Size([664, 144])
7)L

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [None]:
model=ScaleShiftMACE(**model_config)

In [None]:
dm.setup()
Train_loader=dm.train_dataloader()
dataset_iter = iter(Train_loader)
batch=next(dataset_iter)

In [None]:
model(batch)

1)Input data shapes
cell torch.Size([24, 3])
energy torch.Size([8, 1])
ptr torch.Size([9])
batch torch.Size([664])
edge_feats None
graph_feats None
positions torch.Size([664, 3])
forces torch.Size([664, 3])
charges torch.Size([664])
energy_weight torch.Size([8])
forces_weight torch.Size([8])
stress torch.Size([8, 3, 3])
stress_weights torch.Size([8])
virials torch.Size([8, 3, 3])
virials_weights torch.Size([8])
weights torch.Size([8])
node_attrs torch.Size([664, 3])
shifts torch.Size([16646, 3])
unit_shifts torch.Size([16646, 3])
edge_index torch.Size([2, 16646])
2)Node e0 torch.Size([664])
3)e0(After scatter sum over nodes) torch.Size([8])
4)Layer 0: Node feats torch.Size([664, 16])
5)rij,|rij| torch.Size([16646, 3]) torch.Size([16646, 1])
6)Layer 0: edge_attrs(Sph harm), edge_feats(Radial basis) torch.Size([16646, 16]) torch.Size([16646, 8])
7)Layer 1: Node feats(After interaction),sc torch.Size([664, 16, 16]) torch.Size([664, 144])
7)Layer 1: Node feats(After product) torch.Size([66

Embeddings(system_embedding=tensor([[-102563.4453],
        [-102563.9375],
        [-102563.9375],
        [-102563.9375],
        [-102563.9375],
        [-102563.9375],
        [-102563.9375],
        [-102563.9375]], grad_fn=<ReshapeAliasBackward0>), point_embedding=tensor([[ 0.0608,  0.3205, -0.0511,  ...,  0.1473, -0.0543,  0.0138],
        [ 0.0801,  0.3038, -0.0524,  ...,  0.1366, -0.0654,  0.0398],
        [ 0.0267,  0.3029, -0.0432,  ...,  0.1472,  0.0092,  0.0188],
        ...,
        [-0.1004,  0.1576, -0.3078,  ...,  0.1584, -0.2943, -0.4481],
        [-0.1004,  0.1576, -0.3078,  ...,  0.1584, -0.2943, -0.4481],
        [-0.1004,  0.1576, -0.3078,  ...,  0.1584, -0.2943, -0.4481]],
       grad_fn=<AddBackward0>), reduction=None, reduction_kwargs={})

In [None]:
batch['graph']

In [None]:
task(batch)

In [None]:
inputs=model.read_batch(batch)

In [None]:
outputs=model._forward(**inputs)

In [None]:
out

In [None]:
outputs