Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 34 additions & 22 deletions torch_sim/models/sevennet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import traceback
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any

import torch
Expand All @@ -27,6 +28,7 @@
import torch
from sevenn.atom_graph_data import AtomGraphData
from sevenn.calculator import torch_script_type
from sevenn.util import load_checkpoint
from torch_geometric.loader.dataloader import Collater

except ImportError as exc:
Expand All @@ -44,6 +46,27 @@ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None:
raise err


def _validate(model: AtomGraphSequential, modal: str) -> None:
if not model.type_map:
raise ValueError("type_map is missing")

if model.cutoff == 0.0:
raise ValueError("Model cutoff seems not initialized")

modal_map = model.modal_map
if modal_map:
modal_ava = list(modal_map)
if not modal:
raise ValueError(f"modal argument missing (avail: {modal_ava})")
if modal not in modal_ava:
raise ValueError(f"unknown modal {modal} (not in {modal_ava})")
elif not model.modal_map and modal:
warnings.warn(
f"modal={modal} is ignored as model has no modal_map",
stacklevel=2,
)


class SevenNetModel(ModelInterface):
"""Computes atomistic energies, forces and stresses using an SevenNet model.

Expand All @@ -59,7 +82,7 @@ class SevenNetModel(ModelInterface):

def __init__(
self,
model: AtomGraphSequential,
model: AtomGraphSequential | str | Path,
*, # force remaining arguments to be keyword-only
modal: str | None = None,
neighbor_list_fn: Callable = vesin_nl_ts,
Expand All @@ -72,7 +95,9 @@ def __init__(
Sets up the model parameters for subsequent use in energy and force calculations.

Args:
model (AtomGraphSequential): The SevenNet model to wrap.
model (str | Path | AtomGraphSequential): The SevenNet model to wrap.
Accepts either 1) a path to a checkpoint file, 2) a model instance,
or 3) a pretrained model name.
modal (str | None): modal (fidelity) if given model is multi-modal model.
for 7net-mf-ompa, it should be one of 'mpa' (MPtrj + sAlex) or 'omat24'
(OMat24).
Expand Down Expand Up @@ -103,39 +128,26 @@ def __init__(
stacklevel=2,
)

if not model.type_map:
raise ValueError("type_map is missing")
if isinstance(model, (str, Path)):
cp = load_checkpoint(model)
model = cp.build_model()

_validate(model, modal)

model.eval_type_map = torch.tensor(data=True)

self._dtype = dtype
self._memory_scales_with = "n_atoms_x_density"
self._compute_stress = True
self._compute_forces = True

if model.cutoff == 0.0:
raise ValueError("Model cutoff seems not initialized")

model.set_is_batch_data(True)
model_loaded = model
self.cutoff = torch.tensor(model.cutoff)
self.neighbor_list_fn = neighbor_list_fn

self.model = model_loaded

self.modal = None
modal_map = self.model.modal_map
if modal_map:
modal_ava = list(modal_map)
if not modal:
raise ValueError(f"modal argument missing (avail: {modal_ava})")
if modal not in modal_ava:
raise ValueError(f"unknown modal {modal} (not in {modal_ava})")
self.modal = modal
elif not self.model.modal_map and modal:
warnings.warn(
f"modal={modal} is ignored as model has no modal_map",
stacklevel=2,
)
self.modal = modal

self.model = model.to(self._device)
self.model = self.model.eval()
Expand Down