diff --git a/torch_sim/models/sevennet.py b/torch_sim/models/sevennet.py index 6a5d81b2..90c56892 100644 --- a/torch_sim/models/sevennet.py +++ b/torch_sim/models/sevennet.py @@ -4,6 +4,7 @@ import traceback import warnings +from pathlib import Path from typing import TYPE_CHECKING, Any import torch @@ -27,6 +28,7 @@ import torch from sevenn.atom_graph_data import AtomGraphData from sevenn.calculator import torch_script_type + from sevenn.util import load_checkpoint from torch_geometric.loader.dataloader import Collater except ImportError as exc: @@ -44,6 +46,27 @@ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: raise err +def _validate(model: AtomGraphSequential, modal: str) -> None: + if not model.type_map: + raise ValueError("type_map is missing") + + if model.cutoff == 0.0: + raise ValueError("Model cutoff seems not initialized") + + modal_map = model.modal_map + if modal_map: + modal_ava = list(modal_map) + if not modal: + raise ValueError(f"modal argument missing (avail: {modal_ava})") + if modal not in modal_ava: + raise ValueError(f"unknown modal {modal} (not in {modal_ava})") + elif not model.modal_map and modal: + warnings.warn( + f"modal={modal} is ignored as model has no modal_map", + stacklevel=2, + ) + + class SevenNetModel(ModelInterface): """Computes atomistic energies, forces and stresses using an SevenNet model. @@ -59,7 +82,7 @@ class SevenNetModel(ModelInterface): def __init__( self, - model: AtomGraphSequential, + model: AtomGraphSequential | str | Path, *, # force remaining arguments to be keyword-only modal: str | None = None, neighbor_list_fn: Callable = vesin_nl_ts, @@ -72,7 +95,9 @@ def __init__( Sets up the model parameters for subsequent use in energy and force calculations. Args: - model (AtomGraphSequential): The SevenNet model to wrap. + model (str | Path | AtomGraphSequential): The SevenNet model to wrap. + Accepts either 1) a path to a checkpoint file, 2) a model instance, + or 3) a pretrained model name. modal (str | None): modal (fidelity) if given model is multi-modal model. for 7net-mf-ompa, it should be one of 'mpa' (MPtrj + sAlex) or 'omat24' (OMat24). @@ -103,8 +128,12 @@ def __init__( stacklevel=2, ) - if not model.type_map: - raise ValueError("type_map is missing") + if isinstance(model, (str, Path)): + cp = load_checkpoint(model) + model = cp.build_model() + + _validate(model, modal) + model.eval_type_map = torch.tensor(data=True) self._dtype = dtype @@ -112,30 +141,13 @@ def __init__( self._compute_stress = True self._compute_forces = True - if model.cutoff == 0.0: - raise ValueError("Model cutoff seems not initialized") - model.set_is_batch_data(True) model_loaded = model self.cutoff = torch.tensor(model.cutoff) self.neighbor_list_fn = neighbor_list_fn self.model = model_loaded - - self.modal = None - modal_map = self.model.modal_map - if modal_map: - modal_ava = list(modal_map) - if not modal: - raise ValueError(f"modal argument missing (avail: {modal_ava})") - if modal not in modal_ava: - raise ValueError(f"unknown modal {modal} (not in {modal_ava})") - self.modal = modal - elif not self.model.modal_map and modal: - warnings.warn( - f"modal={modal} is ignored as model has no modal_map", - stacklevel=2, - ) + self.modal = modal self.model = model.to(self._device) self.model = self.model.eval()