In [60]:
import pytest
from matsciml.datasets.lips import LiPSDataset, lips_devset
from matsciml.datasets import transforms
from matsciml.models.base import MaceEnergyForceTask
from matsciml.lightning.data_utils import MatSciMLDataModule
from matsciml.models.pyg.mace.modules.blocks import *
from matsciml.models.pyg.mace.modules.models import ScaleShiftMACE
from matsciml.models.pyg.mace import data, modules, tools
import pytorch_lightning as pl
from matsciml.models.pyg.mace.modules.utils import compute_mean_std_atomic_inter_energy
from matsciml.models.pyg.mace.tools import (
    atomic_numbers_to_indices,
    to_one_hot)
import e3nn

In [61]:

from matsciml.datasets.transforms import PointCloudToGraphTransform

# Turn each point-cloud sample into a "pyg" graph. The 5.0 cutoff is the same
# value that is passed as r_max in the model configuration further down.
graph_transform = PointCloudToGraphTransform("pyg", cutoff_dist=5.0)

# Data module backed by the LiPS development set, with the transform applied
# through the dataset keyword arguments.
dm = MatSciMLDataModule.from_devset(
    "LiPSDataset",
    dset_kwargs={"transforms": [graph_transform]},
)


In [121]:
def to_numpy(t: torch.Tensor) -> np.ndarray:
    """Return the contents of ``t`` as a NumPy array.

    The tensor is first moved to the CPU and then detached from the autograd
    graph, so tensors that require gradients can be converted as well.
    """
    host_tensor = t.cpu()
    return host_tensor.detach().numpy()

def compute_mean_std_atomic_inter_energy_and_avg_num_neighbors(
    data_loader: torch.utils.data.DataLoader,
    atomic_energies: np.ndarray,
) -> Tuple[float, float, float]:
    """Collect energy statistics and a neighbor-count statistic over a loader.

    For every batch the per-node reference energies (from ``atomic_energies``)
    are summed per graph and subtracted from ``batch['energy']``; the remainder
    is divided by the number of nodes in each graph. The mean and standard
    deviation of those per-graph values over all batches are returned, together
    with the mean over batches of ``edge_index.numel() / n_nodes``.

    Parameters
    ----------
    data_loader
        Iterable of batches. Each batch must provide ``batch.get("graph")``
        (with ``atomic_numbers``, ``batch``, ``ptr``, ``num_graphs`` and
        ``edge_index`` attributes) and ``batch['energy']``.
    atomic_energies
        One reference energy per element. Its length must equal the number of
        distinct atomic numbers found in every batch, because the element table
        is rebuilt from each batch.

    Returns
    -------
    Tuple[float, float, float]
        ``(mean, std, avg_num_neighbors)`` as plain Python floats.

    Raises
    ------
    ValueError
        If the loader yields no batches, or if a batch contains a different
        number of distinct elements than ``atomic_energies`` has entries.
    """
    atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies)

    avg_atom_inter_es_list = []
    avg_num_neighbors_list = []
    for batch in data_loader:
        graph = batch.get("graph")
        atomic_numbers: torch.Tensor = getattr(graph, "atomic_numbers")
        # The element table is derived from the atoms present in *this* batch.
        z_table = tools.get_atomic_number_table_from_zs(atomic_numbers.numpy())
        if len(z_table) != len(atomic_energies):
            # Without this check the mismatch would only show up later as a
            # shape error when the one-hot encoding meets the energy vector.
            raise ValueError(
                f"Batch contains {len(z_table)} distinct elements but "
                f"{len(atomic_energies)} atomic energies were provided."
            )

        indices = atomic_numbers_to_indices(atomic_numbers, z_table=z_table)
        node_attrs = to_one_hot(
            torch.tensor(indices, dtype=torch.long).unsqueeze(-1),
            num_classes=len(z_table),
        )
        node_e0 = atomic_energies_fn(node_attrs)
        graph_e0s = scatter_sum(
            src=node_e0, index=graph.batch, dim=-1, dim_size=graph.num_graphs
        )
        # Consecutive differences of ptr give the number of nodes per graph.
        graph_sizes = graph.ptr[1:] - graph.ptr[:-1]
        avg_atom_inter_es_list.append(
            (batch['energy'] - graph_e0s) / graph_sizes
        )  # {[n_graphs], }
        # NOTE(review): numel() counts every entry of edge_index, i.e. both
        # endpoints of each stored edge. That matches "neighbors per atom" only
        # if each pair is stored once - confirm against the graph transform.
        avg_num_neighbors_list.append(
            graph.edge_index.numel() / len(atomic_numbers)
        )

    if not avg_atom_inter_es_list:
        raise ValueError("data_loader yielded no batches; cannot compute statistics.")

    avg_atom_inter_es = torch.cat(avg_atom_inter_es_list)  # [total_n_graphs]
    mean = to_numpy(torch.mean(avg_atom_inter_es)).item()
    std = to_numpy(torch.std(avg_atom_inter_es)).item()
    # Convert to a Python float so all three returned values share one type.
    avg_num_neighbors = torch.mean(torch.Tensor(avg_num_neighbors_list)).item()
    return mean, std, avg_num_neighbors

In [122]:
# Prepare the data module, then pull a single training batch so its contents
# can be inspected by the following cells.
dm.setup()
Train_loader = dm.train_dataloader()
dataset_iter = iter(Train_loader)
batch = next(dataset_iter)


In [123]:
# Atomic energies table, keyed by atomic number.
E0s = {
    1: -13.663181292231226,
    3: -216.78673811801755,
    6: -1029.2809654211628,
    7: -1484.1187695035828,
    8: -2042.0330099956639,
    15: -1537.0898574856286,
    16: -1867.8202267974733,
}

# Distinct atomic numbers present in the sampled batch; only these elements
# get an entry in the energy array handed to the model.
atomic_numbers = torch.unique(batch["graph"]["atomic_numbers"]).numpy()
atomic_energies = np.array([E0s[z] for z in atomic_numbers])

# Statistics over the full training loader, later used as the model's
# shift, scale and average neighbor count.
(
    atomic_inter_shift,
    atomic_inter_scale,
    avg_num_neighbors,
) = compute_mean_std_atomic_inter_energy_and_avg_num_neighbors(
    Train_loader,
    atomic_energies,
)

In [125]:
# Keyword arguments forwarded to the ScaleShiftMACE encoder.
model_config = dict(
    # Same value as the cutoff_dist given to the graph transform.
    r_max=5.0,
    num_bessel=8,
    num_polynomial_cutoff=5,
    max_ell=3,
    interaction_cls=RealAgnosticResidualInteractionBlock,
    num_interactions=2,
    # Derived from the data rather than hard-coded so that it always agrees
    # with atomic_energies / atomic_numbers below.
    num_elements=len(atomic_numbers),
    hidden_irreps=e3nn.o3.Irreps('16x0e+16x1o+16x2e'),
    atomic_energies=atomic_energies,
    avg_num_neighbors=avg_num_neighbors,
    atomic_numbers=atomic_numbers,
    correlation=3,
    gate=torch.nn.functional.silu,
    interaction_cls_first=RealAgnosticResidualInteractionBlock,
    MLP_irreps=e3nn.o3.Irreps('16x0e'),
    # Statistics computed from the training loader.
    atomic_inter_scale=atomic_inter_scale,
    atomic_inter_shift=atomic_inter_shift,
    training=True,
)

In [68]:
# Energy/force task around the ScaleShiftMACE encoder. Both targets share the
# same IdentityOutputBlock settings and differ only in output_dim.
task = MaceEnergyForceTask(
    encoder_class=ScaleShiftMACE,
    encoder_kwargs=model_config,
    task_keys=["energy", "force"],
    output_kwargs={
        target: {
            "block_type": "IdentityOutputBlock",
            "output_dim": width,
            "hidden_dim": None,
        }
        for target, width in (("energy", 1), ("force", 3))
    },
    # The force term is weighted 1000x relative to the energy term.
    loss_coeff={"energy": 1.0, "force": 1000.0},
)

In [69]:
# Bare expression: shows the task's module tree as the cell output.
task

MaceEnergyForceTask(
  (encoder): ScaleShiftMACE(
    (atom_embedding): Embedding(100, 16, padding_idx=0)
    (node_embedding): LinearNodeEmbeddingBlock(
      (linear): Linear(3x0e -> 16x0e | 48 weights)
    )
    (radial_embedding): RadialEmbeddingBlock(
      (bessel_fn): BesselBasis(r_max=5.0, num_basis=8, trainable=False)
      (cutoff_fn): PolynomialCutoff(p=5.0, r_max=5.0)
    )
    (spherical_harmonics): SphericalHarmonics()
    (atomic_energies_fn): AtomicEnergiesBlock(energies=[-216.7867, -1537.0898, -1867.8202])
    (interactions): ModuleList(
      (0): RealAgnosticResidualInteractionBlock(
        (linear_up): Linear(16x0e -> 16x0e | 256 weights)
        (conv_tp): TensorProduct(16x0e x 1x0e+1x1o+1x2e+1x3o -> 16x0e+16x1o+16x2e+16x3o | 64 paths | 64 weights)
        (conv_tp_weights): FullyConnectedNet[8, 64, 64, 64, 64]
        (linear): Linear(16x0e+16x1o+16x2e+16x3o -> 16x0e+16x1o+16x2e+16x3o | 1024 weights)
        (skip_tp): FullyConnectedTensorProduct(16x0e x 3x0e -> 

In [70]:
# Train for up to 100 epochs, running 2 validation sanity-check steps first and
# logging metrics every 10 steps.
# NOTE(review): the captured log reports a CUDA device as available but unused;
# set the Trainer's accelerator explicitly if GPU training is intended.
trainer = pl.Trainer(
    max_epochs=100,
    num_sanity_val_steps=2,
    log_every_n_steps=10,
)
trainer.fit(task, datamodule=dm)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name         | Type           | Params
------------------------------------------------
0 | encoder      | ScaleShiftMACE | 58.8 K
1 | loss_func    | MSELoss        | 0     
2 | output_heads | ModuleDict     | 0     
------------------------------------------------
58.8 K    Trainable params
0         Non-trainable params
58.8 K    Total params
0.235     Total estimated model params size (MB)


Epoch 0:   8%|▊         | 4/50 [00:14<02:42,  3.52s/it, loss=0.243, v_num=33, train_force=0.205, train_energy=0.0393]