Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions torch_sim/models/orb.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@
from orb_models.forcefield import featurization_utilities as feat_util
from orb_models.forcefield.atomic_system import SystemConfig
from orb_models.forcefield.base import AtomGraphs, _map_concat
from orb_models.forcefield.featurization_utilities import EdgeCreationMethod
from orb_models.forcefield.graph_regressor import GraphRegressor

except ImportError as exc:
warnings.warn(f"Orb import failed: {traceback.format_exc()}", stacklevel=2)
Expand All @@ -55,8 +53,8 @@ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None:
from orb_models.forcefield.conservative_regressor import (
ConservativeForcefieldRegressor,
)
from orb_models.forcefield.direct_regressor import DirectForcefieldRegressor
from orb_models.forcefield.featurization_utilities import EdgeCreationMethod
from orb_models.forcefield.graph_regressor import GraphRegressor

from torch_sim.typing import StateDict

Expand Down Expand Up @@ -256,7 +254,8 @@ class OrbModel(ModelInterface):
predictions.

Attributes:
model (Union[GraphRegressor, ConservativeForcefieldRegressor]): The ORB model
model (Union[DirectForcefieldRegressor, ConservativeForcefieldRegressor]):
The ORB model
system_config (SystemConfig): Configuration for the atomic system
conservative (bool): Whether to use conservative forces/stresses calculation
implemented_properties (list): Properties the model can compute
Expand All @@ -274,7 +273,7 @@ class OrbModel(ModelInterface):

def __init__(
self,
model: GraphRegressor | ConservativeForcefieldRegressor | str | Path,
model: DirectForcefieldRegressor | ConservativeForcefieldRegressor | str | Path,
*, # force remaining arguments to be keyword-only
conservative: bool | None = None,
compute_stress: bool = True,
Expand All @@ -292,7 +291,7 @@ def __init__(
Sets up the model parameters for subsequent use in energy and force calculations.

Args:
model (Union[GraphRegressor, ConservativeForcefieldRegressor, str, Path]):
model (DirectForcefieldRegressor|ConservativeForcefieldRegressor|str|Path):
Either a model object or a path to a saved model
conservative (bool | None): Whether to use conservative forces/stresses
If None, determined based on model type
Expand Down
Loading