In [1]:
%load_ext autoreload
%autoreload 2
import moldb as mdb
import molpot as mpot
import torch
from pony.orm import *

  from torch.distributed.optim import ZeroRedundancyOptimizer


In [2]:
from pprint import pprint
from matplotlib import category
import requests
from pathlib import Path

# Download the CMR cubic-perovskites ASE database once; later runs reuse the local copy.
db_path = Path("cubic_perovskites.db")
if not db_path.exists():
    response = requests.get(
        "https://cmr.fysik.dtu.dk/_downloads/03d2580a2f33d61c6998b803d2d72af0/cubic_perovskites.db",
        timeout=60,
    )
    # Fail loudly instead of writing an HTML error page to db_path, which would
    # then "exist" and be silently reused by every subsequent run.
    response.raise_for_status()
    db_path.write_bytes(response.content)


# Convert the ASE database into a moldb SQLite database (one-time), or bind to
# the previously converted copy.
db = mdb.Database()
# NOTE(review): absolute, machine-specific path — consider deriving it from a
# configurable data directory so the notebook runs outside this workspace.
mdb_path = "/workspaces/molcrafts/moldb/example/cubic_perovskites.sqlite"
provider = "sqlite"

if Path(mdb_path).exists():
    # NOTE(review): existence alone is trusted; a conversion that crashed midway
    # leaves a partial file that is reused here — delete it to force a rebuild.
    print("Database already exists")
    db.bind(provider=provider, filename=mdb_path)
else:
    print("Creating database")
    db.bind(provider=provider, filename=mdb_path, create_db=True)
    db.load_ase(
        db_path=db_path,
        selection="combination",
        table_name="CubicPerovskites",
        # Import the ASE key "heat_of_formation_all" as an extra column named "Ef".
        extra={
            "heat_of_formation_all": {
                "kind": Required(float),
                "name": "Ef",
                "unit": "eV/atom",
                "dtype": "float",
                "shape": [],
                "comment": "Heat of formation",
                "category": "",
            }
        },
    )
# Declare the entity schema used by the rest of the notebook.
structure = db.def_entity('CubicPerovskites', {
    "numbers":  Required(IntArray),
    "positions": Required(FloatArray),
    "cell": Required(FloatArray),
    "pbc": Required(StrArray),
    "Ef": Required(float),
})
db.show_table("CubicPerovskitesNameSpace", n_rows=None)
pprint(structure[1].to_dict())

Database already exists


{'Ef': 1.2,
 'cell': [3.9214237611052916,
          0.0,
          0.0,
          0.0,
          3.9214237611052916,
          0.0,
          0.0,
          0.0,
          3.9214237611052916],
 'id': 1,
 'numbers': [22, 33, 7, 8, 8],
 'pbc': [1, 1, 1],
 'positions': [0.8384829022255794,
               0.0,
               0.0,
               2.1902264842929,
               1.960711880552646,
               1.960711880552646,
               1.8287397655746789,
               0.0,
               1.960711880552646,
               1.7885253110150061,
               1.960711880552646,
               0.0,
               0.06959507274448885,
               1.960711880552646,
               1.960711880552646]}


In [3]:
import molpot as mpot
from typing import Callable

from tqdm import tqdm


class MolDBDataset(mpot.Dataset):
    """Map-style dataset that reads molecular frames from a moldb table.

    Rows of ``table_name`` are converted to ``mpot.Frame`` objects. Column
    metadata (the non-``id`` columns of the companion entity
    ``f"{table_name}NameSpace"``) is registered on ``self.labels``.
    """

    def __init__(
        self,
        name,
        db,
        table_name: str | None = None,
        mapping: dict[str, str] | None = None,
    ):
        """Bind the dataset to ``db`` and register column labels.

        Args:
            name: dataset name passed to ``mpot.Dataset``; also the default
                table name.
            db: bound moldb database whose ``entities`` contain the table and
                its ``...NameSpace`` metadata entity.
            table_name: entity to read rows from; defaults to ``name``.
            mapping: optional renames from database column names to label
                names, e.g. ``{"positions": "xyz"}``. ``None`` means no renames.
        """
        super().__init__(name)
        self.db = db
        self.name = name
        self.table_name = table_name or name
        self.table = db.entities[self.table_name]
        # Private copy: avoids the shared-mutable-default pitfall and isolates
        # the dataset from later mutation of a caller-supplied dict.
        self.mapping = dict(mapping or {})

        for row in db.entities[f"{self.table_name}NameSpace"].select():
            # Every namespace column except the primary key is forwarded to
            # labels.set(), with the column name translated through `mapping`.
            alias = {k: v for k, v in row.to_dict().items() if k != "id"}
            alias["name"] = self._label_name(alias["name"])
            self.labels.set(**alias)

        # Populated by preload().
        self._data = []

    def _label_name(self, column):
        """Return the label name for a database column (identity unless remapped)."""
        return self.mapping.get(column, column)

    def preload(self, selection: Callable | None = None):
        """Read rows, convert them to frames and apply the registered transforms.

        Args:
            selection: optional query predicate forwarded to
                ``self.table.select``; ``None`` selects every row.

        Returns:
            The list of transformed frames (also stored on ``self._data``).
        """
        query = (
            self.table.select() if selection is None else self.table.select(selection)
        )
        self._data = [self.apply_transforms(self.to_frame(row)) for row in tqdm(query)]
        return self._data

    def __len__(self):
        """Number of rows in the backing table (not the preload selection)."""
        return self.table.select().count()

    def to_frame(self, d):
        """Convert one table row into an ``mpot.Frame`` keyed by label keys."""
        fields = {}
        for column, value in d.to_dict().items():
            label = self._label_name(column)
            fields[self.labels[label].key[1:]] = self.labels.format(label, value)
        return mpot.Frame(fields)

    def __getitem__(self, idx):
        """Fetch row ``idx`` (0-based) from the database and transform it.

        Assumes primary keys are contiguous and start at 1 — TODO confirm for
        tables that have had rows deleted.

        NOTE(review): frames cached by ``preload()`` are not used here; every
        access re-reads the row and re-applies the transforms.

        Raises:
            IndexError: when no row exists for ``idx``.
        """
        if idx < 0:
            # Support Python-style negative indexing.
            idx += len(self)
        try:
            row = self.table[idx + 1]
        except ObjectNotFound:
            # Sequence protocol: out-of-range access must raise IndexError.
            raise IndexError(f"{self.name} index {idx} out of range") from None
        return self.apply_transforms(self.to_frame(row))


# Database columns renamed to the label names the model pipeline expects.
column_to_label = {"positions": "xyz", "numbers": "Z"}
perovskites = MolDBDataset("CubicPerovskites", db, mapping=column_to_label)

# Neighbor pairs are built within 5.0 (larger than the model's 4.0 cutoff).
neighbor_list = mpot.pipline.nblist.NeighborList(cutoff=5.0)
perovskites.add_transform(neighbor_list)

frames = perovskites.preload()

100%|██████████| 18928/18928 [00:03<00:00, 4911.94it/s]


In [4]:
dl = mpot.DataLoader(perovskites, batch_size=10)

# Sanity check: inspect one collated batch (field names, shapes, dtypes).
d = next(iter(dl), None)
if d is not None:
    print(d)

TensorDict(
    fields={
        Ef: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
        atoms: TensorDict(
            fields={
                Z: Tensor(shape=torch.Size([50]), device=cpu, dtype=torch.int32, is_shared=False),
                atom_batch_mask: Tensor(shape=torch.Size([50]), device=cpu, dtype=torch.int64, is_shared=False),
                xyz: Tensor(shape=torch.Size([50, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        cell: Tensor(shape=torch.Size([30, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        id: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False),
        pairs: TensorDict(
            fields={
                diff: Tensor(shape=torch.Size([100, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                dist: Tensor(shape=torch.Size([100]), device=cpu, dtype=t

In [5]:
# The radial basis and the smooth cutoff function share one radius.
pinet_cutoff = 4.0

pinet = mpot.potential.nnp.PiNet(
    depth=5,
    basis_fn=mpot.potential.nnp.radial.GaussianRBF(10, pinet_cutoff),
    cutoff_fn=mpot.potential.nnp.cutoff.CosineCutoff(pinet_cutoff),
    pp_nodes=[64, 64],
    pi_nodes=[64, 64],
    ii_nodes=[64, 64],
    activation=torch.nn.Tanh(),
)

# Per-atom readout of PiNet's "p1" features into a scalar energy prediction.
readout = mpot.potential.nnp.readout.Atomwise(
    n_in=64, n_out=1, from_key=("pinet", "p1"), to_key=("predict", "energy")
)

model = mpot.potential.PotentialSeq("pinet", pinet, readout)

In [14]:
from ignite.handlers import (
    Checkpoint,
    DiskSaver,
    global_step_from_engine,
    TensorboardLogger,
    ProgressBar,
)
from ignite.engine import Events
from ignite.metrics import Accuracy, MeanAbsoluteError
from pathlib import Path

# Fall back to CPU so the notebook still runs on machines without a GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def energy_mae_transform(td):
    """Map the engine output TensorDict to ignite's ``{"y_pred", "y"}`` format.

    Reads the prediction from ``td["predict"]["energy"]`` and the reference
    heat of formation from ``td["Ef"]``.
    """
    return {"y_pred": td["predict"]["energy"], "y": td["Ef"]}


trainer = mpot.PotentialTrainer(
    model,
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
    # Target key is the plain string "Ef" (not a 1-tuple).
    loss_fn=mpot.loss(torch.nn.MSELoss(), ("predict", "energy"), "Ef"),
    device=DEVICE,
    # Expose the full prediction TensorDict as the engine output;
    # energy_mae_transform reads both prediction and "Ef" from it.
    output_transform=lambda x, y, y_pred, loss: y_pred,
)
# No default-transform metrics here: Accuracy is a classification metric and
# cannot score this regression target, and a plain MeanAbsoluteError expects a
# (y_pred, y) tuple. The evaluator's MAE is attached explicitly below.
trainer.add_evaluator(metrics={})
trainer.add_event(
    "trainer",
    Events.COMPLETED,
    Checkpoint(
        {"model": model, "optimizer": trainer.optimizer},
        # A bare path becomes DiskSaver(require_empty=True), which raises when a
        # previous run already left a checkpoint there; allow re-runs instead.
        save_handler=DiskSaver("pinet_perovskites.ckpt", require_empty=False),
        n_saved=1,
        global_step_transform=global_step_from_engine(trainer.trainer),
    ),
)

# Attach the metrics BEFORE the TensorBoard handlers: ignite fires handlers in
# registration order, so the metric must be computed before it is logged.
trainer.add_metric("mae", MeanAbsoluteError(output_transform=energy_mae_transform), "trainer")
trainer.add_metric("mae", MeanAbsoluteError(output_transform=energy_mae_transform), "evaluator")

tb_logger = TensorboardLogger(
    log_dir=Path("pinet_perovskites"),
)
# Epoch-wise metrics only exist in engine.state.metrics after EPOCH_COMPLETED,
# so log them on that event rather than on every iteration.
tb_logger.attach_output_handler(
    trainer.trainer,
    event_name=Events.EPOCH_COMPLETED,
    tag="training",
    metric_names=["mae"],
    global_step_transform=global_step_from_engine(trainer.trainer),
)
tb_logger.attach_output_handler(
    trainer.evaluator,
    event_name=Events.EPOCH_COMPLETED,
    tag="validation",
    metric_names=["mae"],
    global_step_transform=global_step_from_engine(trainer.trainer),
)

trainer.enable_progressbar()
try:
    trainer.run(dl, max_epochs=1)
finally:
    # Flush and release the TensorBoard event writer even if training fails.
    tb_logger.close()



[1/1893]   0%|           [00:00<?]