In [1]:
from __future__ import annotations

from typing import Any

import ase
import datasets
import numpy as np
from tqdm.auto import tqdm


def process(
    dataset_name: str,
    *,
    id_column: str,
    atomic_numbers_column: str,
    pos_column: str,
    cell_column: str | None,
    pbc_column: str | None,
    extra_columns: list[str] | None = None,
    split: str = "train",
    count: int = 10_000,
):
    ids = set()

    atoms_list: list[ase.Atoms] = []

    i = 0
    pbar = tqdm(total=count)
    printed = False
    for data_dict in datasets.load_dataset(dataset_name, split=split, streaming=True):
        if not printed:
            print({k: type(v) for k, v in data_dict.items()})
            printed = True
        if i >= count:
            break

        if (id := data_dict[id_column]) in ids:
            continue

        try:
            atoms_kwargs: dict[str, Any] = {}
            atoms_kwargs["numbers"] = np.array(
                data_dict[atomic_numbers_column], dtype=int
            )
            atoms_kwargs["positions"] = np.array(
                data_dict[pos_column], dtype=float
            ).reshape(-1, 3)
            if cell_column is not None:
                atoms_kwargs["cell"] = np.array(
                    data_dict[cell_column], dtype=float
                ).reshape(3, 3)
                atoms_kwargs["pbc"] = (
                    True if pbc_column is None else data_dict[pbc_column]
                )

            if extra_columns:
                info: dict[str, Any] = {}
                for col in extra_columns:
                    info[col] = data_dict[col]
                atoms_kwargs["info"] = info

            atoms = ase.Atoms(**atoms_kwargs)
            atoms_list.append(atoms)
        finally:
            i += 1
            pbar.update(1)
            ids.add(id)

    return atoms_list


# Results keyed by a short dataset label; the cells below each store the list
# returned by `process` under their own key.
stats: "dict[str, list[ase.Atoms]]" = dict()

In [2]:
# `nimashoghi/mptrj`: default split and count, de-duplicated on `mp_id`, with
# the dataset's own `pbc` column; energy, forces and stress go into `info`.
stats["mptrj"] = process(
    dataset_name="nimashoghi/mptrj",
    id_column="mp_id",
    atomic_numbers_column="numbers",
    pos_column="positions",
    cell_column="cell",
    pbc_column="pbc",
    extra_columns=["corrected_total_energy", "forces", "stress"],
)

  0%|          | 0/10000 [00:00<?, ?it/s]

{'numbers': <class 'list'>, 'positions': <class 'list'>, 'forces': <class 'list'>, 'cell': <class 'list'>, 'pbc': <class 'list'>, 'energy': <class 'float'>, 'stress': <class 'list'>, 'e_per_atom_relaxed': <class 'float'>, 'mp_id': <class 'str'>, 'energy_per_atom': <class 'float'>, 'ef_per_atom_relaxed': <class 'float'>, 'corrected_total_energy': <class 'float'>, 'ef_per_atom': <class 'float'>, 'task_id': <class 'str'>, 'calc_id': <class 'int'>, 'ionic_step': <class 'int'>, 'filename': <class 'str'>, 'extxyz_id': <class 'int'>, 'num_atoms': <class 'int'>, 'corrected_total_energy_relaxed': <class 'float'>, 'energy_referenced': <class 'float'>, 'corrected_total_energy_referenced': <class 'float'>, 'corrected_total_energy_relaxed_referenced': <class 'float'>, 'composition': <class 'list'>}


In [3]:
# `nimashoghi/wbm`: Cartesian positions come from `cart_pos`. No pbc column is
# passed, so `process` sets `pbc=True` for every structure.
stats["wbm"] = process(
    dataset_name="nimashoghi/wbm",
    id_column="material_id",
    atomic_numbers_column="atomic_numbers",
    pos_column="cart_pos",
    cell_column="cell",
    pbc_column=None,
    extra_columns=[
        "uncorrected_energy",
        "e_correction_per_atom_mp2020",
    ],
)

  0%|          | 0/10000 [00:00<?, ?it/s]

{'formula': <class 'str'>, 'n_sites': <class 'float'>, 'volume': <class 'float'>, 'uncorrected_energy': <class 'float'>, 'e_form_per_atom_wbm': <class 'float'>, 'e_above_hull_wbm': <class 'float'>, 'bandgap_pbe': <class 'float'>, 'wyckoff_spglib_initial_structure': <class 'str'>, 'uncorrected_energy_from_cse': <class 'float'>, 'e_correction_per_atom_mp2020': <class 'float'>, 'e_correction_per_atom_mp_legacy': <class 'float'>, 'e_form_per_atom_uncorrected': <class 'float'>, 'e_form_per_atom_mp2020_corrected': <class 'float'>, 'e_above_hull_mp2020_corrected_ppd_mp': <class 'float'>, 'site_stats_fingerprint_init_final_norm_diff': <class 'float'>, 'wyckoff_spglib': <class 'str'>, 'unique_prototype': <class 'bool'>, 'formula_from_cse': <class 'str'>, 'initial_structure': <class 'dict'>, 'id': <class 'str'>, 'material_id': <class 'str'>, 'frac_pos': <class 'list'>, 'cart_pos': <class 'list'>, 'pos': <class 'list'>, 'cell': <class 'list'>, 'num_atoms': <class 'int'>, 'atomic_numbers': <class 

In [4]:
# `nimashoghi/oc20-s2ef`: read from the "2M" split rather than the default
# "train", de-duplicated on `sid`; `pbc=True` because no pbc column is passed.
stats["oc20"] = process(
    dataset_name="nimashoghi/oc20-s2ef",
    split="2M",
    id_column="sid",
    atomic_numbers_column="atomic_numbers",
    pos_column="pos",
    cell_column="cell",
    pbc_column=None,
    extra_columns=[
        "energy",
        "forces",
    ],
)

  0%|          | 0/10000 [00:00<?, ?it/s]

{'sid': <class 'str'>, 'fid': <class 'str'>, 'reference_energy': <class 'float'>, 'num_atoms': <class 'int'>, 'atomic_numbers': <class 'list'>, 'pos': <class 'list'>, 'energy': <class 'float'>, 'forces': <class 'list'>, 'cell': <class 'list'>, 'fixed': <class 'list'>, 'tags': <class 'list'>}


In [5]:
# `nimashoghi/oc22`: de-duplicated on `sid`; the `y` and `force` columns are
# carried along in `info`, and `pbc=True` because no pbc column is passed.
stats["oc22"] = process(
    dataset_name="nimashoghi/oc22",
    id_column="sid",
    atomic_numbers_column="atomic_numbers",
    pos_column="pos",
    cell_column="cell",
    pbc_column=None,
    extra_columns=[
        "y",
        "force",
    ],
)

  0%|          | 0/10000 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/82 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/82 [00:00<?, ?it/s]

{'pos': <class 'list'>, 'cell': <class 'list'>, 'atomic_numbers': <class 'list'>, 'natoms': <class 'int'>, 'fixed': <class 'list'>, 'tags': <class 'list'>, 'sid': <class 'int'>, 'fid': <class 'int'>, 'id': <class 'str'>, 'oc22': <class 'int'>, 'composition': <class 'list'>, 'y': <class 'float'>, 'force': <class 'list'>}


In [7]:
# `nimashoghi/ani1x`: no cell column is passed, so `process` builds the Atoms
# without `cell` or `pbc` (the dataset's own `pbc` column is not used).
stats["ani1x"] = process(
    dataset_name="nimashoghi/ani1x",
    id_column="sid",
    atomic_numbers_column="atomic_numbers",
    pos_column="pos",
    cell_column=None,
    pbc_column=None,
    extra_columns=[
        "y",
        "force",
    ],
)

  0%|          | 0/10000 [00:00<?, ?it/s]

{'y': <class 'float'>, 'pos': <class 'list'>, 'atomic_numbers': <class 'list'>, 'sid': <class 'int'>, 'fid': <class 'int'>, 'natoms': <class 'int'>, 'tags': <class 'list'>, 'force': <class 'list'>, 'pbc': <class 'list'>, 'id': <class 'str'>}


In [9]:
# `nimashoghi/transition1x`: same column layout as the ani1x call, with no
# cell/pbc passed to the Atoms.
# NOTE(review): `count` is lowered to 9000 here — presumably the stream has
# fewer than 10,000 distinct `sid` values; confirm.
stats["trans1x"] = process(
    dataset_name="nimashoghi/transition1x",
    count=9000,
    id_column="sid",
    atomic_numbers_column="atomic_numbers",
    pos_column="pos",
    cell_column=None,
    pbc_column=None,
    extra_columns=[
        "y",
        "force",
    ],
)

  0%|          | 0/9000 [00:00<?, ?it/s]

{'y': <class 'float'>, 'pos': <class 'list'>, 'atomic_numbers': <class 'list'>, 'sid': <class 'int'>, 'fid': <class 'int'>, 'natoms': <class 'int'>, 'tags': <class 'list'>, 'force': <class 'list'>, 'pbc': <class 'list'>, 'id': <class 'str'>}


In [10]:
import pickle
from pathlib import Path

# Create the output directory from Python rather than through `!mkdir -p`:
# the shell escape depends on the kernel's shell environment (the recorded
# output of this cell shows it printing a pixi error), while `Path.mkdir`
# does not go through a shell at all.
output_dir = Path("/mnt/shared/jmp-distributions-all-11-4")
output_dir.mkdir(parents=True, exist_ok=True)

# Persist the per-dataset lists of `ase.Atoms` gathered in `stats`.
# NOTE: pickle files execute code on load; only unpickle this from a trusted
# location.
with open(output_dir / "ase_atoms_list.pkl", "wb") as f:
    pickle.dump(stats, f)

  [31m×[0m could not find pixi.toml or pyproject.toml which is configured to use pixi

