Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 39 additions & 38 deletions tests/models/test_orb.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

try:
from orb_models.forcefield import pretrained
from orb_models.forcefield.atomic_system import SystemConfig
from orb_models.forcefield.calculator import ORBCalculator

from torch_sim.models.orb import OrbModel
Expand All @@ -19,63 +18,65 @@


@pytest.fixture
def dtype() -> torch.dtype:
"""Fixture to provide the default dtype for testing."""
return torch.float32


@pytest.fixture
def pretrained_orb_model(device: torch.device):
"""Load a pretrained ORB model for testing."""
return pretrained.orb_v2(device=device)
def orbv3_conservative_inf_omat_model(device: torch.device) -> OrbModel:
orbff = pretrained.orb_v3_conservative_inf_omat(
device=device,
precision="float32-high",
)
return OrbModel(model=orbff, device=device)


@pytest.fixture
def orb_model(pretrained_orb_model: torch.nn.Module, device: torch.device) -> OrbModel:
"""Create an OrbModel wrapper for the pretrained model."""
return OrbModel(
model=pretrained_orb_model,
def orbv3_direct_20_omat_model(device: torch.device) -> OrbModel:
orbff = pretrained.orb_v3_direct_20_omat(
device=device,
system_config=SystemConfig(radius=6.0, max_num_neighbors=20),
precision="float32-high",
)
return OrbModel(model=orbff, device=device)


@pytest.fixture
def orb_calculator(
pretrained_orb_model: torch.nn.Module, device: torch.device
) -> ORBCalculator:
def orbv3_conservative_inf_omat_calculator(device: torch.device) -> ORBCalculator:
"""Create an ORBCalculator for the pretrained model."""
return ORBCalculator(
model=pretrained_orb_model,
system_config=SystemConfig(radius=6.0, max_num_neighbors=20),
orbff = pretrained.orb_v3_conservative_inf_omat(
device=device,
precision="float32-high",
)
return ORBCalculator(model=orbff, device=device)


def test_orb_initialization(
pretrained_orb_model: torch.nn.Module, device: torch.device
) -> None:
"""Test that the ORB model initializes correctly."""
model = OrbModel(
model=pretrained_orb_model,
@pytest.fixture
def orbv3_direct_20_omat_calculator(device: torch.device) -> ORBCalculator:
"""Create an ORBCalculator for the pretrained model."""
orbff = pretrained.orb_v3_direct_20_omat(
device=device,
precision="float32-high",
)
# Check that properties were set correctly
assert "energy" in model.implemented_properties
assert "forces" in model.implemented_properties
assert model._device == device # noqa: SLF001
return ORBCalculator(model=orbff, device=device)


test_orb_conservative_consistency = make_model_calculator_consistency_test(
test_name="orbv3_conservative_inf_omat",
model_fixture_name="orbv3_conservative_inf_omat_model",
calculator_fixture_name="orbv3_conservative_inf_omat_calculator",
sim_state_names=consistency_test_simstate_fixtures,
atol=5e-4,
rtol=5e-4,
)

test_orb_consistency = make_model_calculator_consistency_test(
test_name="orb",
model_fixture_name="orb_model",
calculator_fixture_name="orb_calculator",
test_orb_direct_consistency = make_model_calculator_consistency_test(
test_name="orbv3_direct_20_omat",
model_fixture_name="orbv3_direct_20_omat_model",
calculator_fixture_name="orbv3_direct_20_omat_calculator",
sim_state_names=consistency_test_simstate_fixtures,
atol=5e-4,
rtol=1,
rtol=5e-4,
)

test_validate_conservative_model_outputs = make_validate_model_outputs_test(
model_fixture_name="orbv3_conservative_inf_omat_model",
)

test_validate_model_outputs = make_validate_model_outputs_test(
model_fixture_name="orb_model",
test_validate_direct_model_outputs = make_validate_model_outputs_test(
model_fixture_name="orbv3_direct_20_omat_model",
)
19 changes: 5 additions & 14 deletions torch_sim/models/orb.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,17 +312,15 @@
self._device = torch.device(self._device)

self._dtype = dtype
self._memory_scales_with = "n_atoms"
self._compute_stress = compute_stress
self._compute_forces = compute_forces

# Set up system configuration
self.system_config = system_config or SystemConfig(
radius=6.0, max_num_neighbors=20
)
self.system_config = system_config or model.system_config

Check warning on line 319 in torch_sim/models/orb.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/models/orb.py#L319

Added line #L319 was not covered by tests
self._max_num_neighbors = max_num_neighbors
self._edge_method = edge_method
self._half_supercell = half_supercell
self.conservative = conservative

Check warning on line 323 in torch_sim/models/orb.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/models/orb.py#L323

Added line #L323 was not covered by tests

# Load model if path is provided
if isinstance(model, str | Path):
Expand All @@ -336,9 +334,8 @@

# Determine if the model is conservative
model_is_conservative = hasattr(self.model, "grad_forces_name")
self.conservative = (
conservative if conservative is not None else model_is_conservative
)
if self.conservative is None:
self.conservative = model_is_conservative

Check warning on line 338 in torch_sim/models/orb.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/models/orb.py#L337-L338

Added lines #L337 - L338 were not covered by tests

if self.conservative and not model_is_conservative:
raise ValueError(
Expand All @@ -351,10 +348,7 @@

# Add forces and stress to implemented properties if conservative model
if self.conservative:
if "forces" not in self.implemented_properties:
self.implemented_properties.append("forces")
if compute_stress and "stress" not in self.implemented_properties:
self.implemented_properties.append("stress")
self.implemented_properties.extend(["forces", "stress"])

Check warning on line 351 in torch_sim/models/orb.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/models/orb.py#L351

Added line #L351 was not covered by tests

def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]:
"""Perform forward pass to compute energies, forces, and other properties.
Expand Down Expand Up @@ -417,13 +411,10 @@
results[prop] = predictions[_property].squeeze()

if self.conservative:
results["direct_forces"] = results["forces"]
results["direct_stress"] = results["stress"]
results["forces"] = results[self.model.grad_forces_name]
results["stress"] = results[self.model.grad_stress_name]

if "stress" in results and results["stress"].shape[-1] == 6:
# TODO: is there a point to converting the direct stress if conservative?
results["stress"] = voigt_6_to_full_3x3_stress(results["stress"])

return results