diff --git a/torch_sim/models/orb.py b/torch_sim/models/orb.py index cff515e3..5ddb1d09 100644 --- a/torch_sim/models/orb.py +++ b/torch_sim/models/orb.py @@ -33,8 +33,6 @@ from orb_models.forcefield import featurization_utilities as feat_util from orb_models.forcefield.atomic_system import SystemConfig from orb_models.forcefield.base import AtomGraphs, _map_concat - from orb_models.forcefield.featurization_utilities import EdgeCreationMethod - from orb_models.forcefield.graph_regressor import GraphRegressor except ImportError as exc: warnings.warn(f"Orb import failed: {traceback.format_exc()}", stacklevel=2) @@ -55,8 +53,8 @@ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: from orb_models.forcefield.conservative_regressor import ( ConservativeForcefieldRegressor, ) + from orb_models.forcefield.direct_regressor import DirectForcefieldRegressor from orb_models.forcefield.featurization_utilities import EdgeCreationMethod - from orb_models.forcefield.graph_regressor import GraphRegressor from torch_sim.typing import StateDict @@ -256,7 +254,8 @@ class OrbModel(ModelInterface): predictions. Attributes: - model (Union[GraphRegressor, ConservativeForcefieldRegressor]): The ORB model + model (Union[DirectForcefieldRegressor, ConservativeForcefieldRegressor]): + The ORB model system_config (SystemConfig): Configuration for the atomic system conservative (bool): Whether to use conservative forces/stresses calculation implemented_properties (list): Properties the model can compute @@ -274,7 +273,7 @@ class OrbModel(ModelInterface): def __init__( self, - model: GraphRegressor | ConservativeForcefieldRegressor | str | Path, + model: DirectForcefieldRegressor | ConservativeForcefieldRegressor | str | Path, *, # force remaining arguments to be keyword-only conservative: bool | None = None, compute_stress: bool = True, @@ -292,7 +291,7 @@ def __init__( Sets up the model parameters for subsequent use in energy and force calculations. Args: - model (Union[GraphRegressor, ConservativeForcefieldRegressor, str, Path]): + model (DirectForcefieldRegressor|ConservativeForcefieldRegressor|str|Path): Either a model object or a path to a saved model conservative (bool | None): Whether to use conservative forces/stresses If None, determined based on model type