In [1]:
%load_ext autoreload
%autoreload 2
import moldb as mdb
from pony.orm import *

In [2]:
import requests
from pathlib import Path
if not Path('cubic_perovskites.db').exists():
    session = requests.get('https://cmr.fysik.dtu.dk/_downloads/03d2580a2f33d61c6998b803d2d72af0/cubic_perovskites.db')
    with open('cubic_perovskites.db', 'wb') as f:
        f.write(session.content)

In [5]:
!pwd
!rm ./cubic_perovskites.sqlite
db = mdb.Database()
db.bind(provider='sqlite', filename='/workspaces/molcrafts/moldb/example/cubic_perovskites.sqlite', create_db=True)
db.load_ase(
    "cubic_perovskites.db",
    selection="combination",
    table_name="CubicPerovskites"
    extra={"heat_of_formation_all": {"kind": Required(float), "unit": "eV/atom", 'dtype': 'float', "comment": 'Heat of formation per atom'}},
)


/workspaces/molcrafts/moldb/example


In [6]:
db.show_table('NameSpace', n_rows=None)

In [7]:
import molpot as mpot
from moldb import from_orm
from typing import Callable


ase_field_mapping = {
    "numbers": "Z",
    "positions": "xyz",
}


class MolDBDataset(mpot.Dataset):

    def __init__(self, name, db, table_name:str|None=None):
        super().__init__(name)
        self.db = db
        self.name = name
        self.table_name = table_name or name
        self.table = db.entities[self.table_name]

        for i in db.entities[f"{self.table_name}NameSpace"].select():
            alias = {k: v for k, v in i.to_dict().items() if k != "id"}
            alias['name'] = ase_field_mapping.get(alias['name'], alias['name'])
            self.labels.set(**alias)

    def preload(self, selection: Callable | None = None):
        data = []
        if selection is None:
            selection = lambda _: True

        for d in self.table.select(selection):
            data.append(self.to_frame(d))
        return data

    def __len__(self):
        return self.table.select().count()

    def to_frame(self, d):
        return mpot.Frame(
            {
                self.labels[ase_field_mapping.get(k, k)].key[1:]: from_orm(v)
                for k, v in d.to_dict().items()  # promote dataset to top by remove namespace
            }
        )

    def __getitem__(self, idx):
        d = self.table[idx + 1]
        return self.to_frame(d)


perovskites = MolDBDataset("cubic_perovskites", db)
frames = perovskites.preload()

KeyError: 'cubic_perovskites'

In [127]:
frames[0].keys()

_StringKeys(dict_keys(['heat_of_formation_all', 'id', 'unique_id', 'ctime', 'mtime', 'user', 'atoms', 'pbc', 'cell']))

In [128]:
frame = frames[0]
frame['atoms']

TensorDict(
    fields={
        Z: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int32, is_shared=False),
        xyz: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float64, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

In [129]:
dl = mpot.DataLoader(perovskites, batch_size=10)
for d in dl:
    print(d)
    break

AttributeError: module 'molpot.alias' has no attribute 'atom_batch_mask'