diff --git a/tests/models/test_orb.py b/tests/models/test_orb.py index a12b87c1..038aa5e7 100644 --- a/tests/models/test_orb.py +++ b/tests/models/test_orb.py @@ -10,7 +10,6 @@ try: from orb_models.forcefield import pretrained - from orb_models.forcefield.atomic_system import SystemConfig from orb_models.forcefield.calculator import ORBCalculator from torch_sim.models.orb import OrbModel @@ -19,63 +18,65 @@ @pytest.fixture -def dtype() -> torch.dtype: - """Fixture to provide the default dtype for testing.""" - return torch.float32 - - -@pytest.fixture -def pretrained_orb_model(device: torch.device): - """Load a pretrained ORB model for testing.""" - return pretrained.orb_v2(device=device) +def orbv3_conservative_inf_omat_model(device: torch.device) -> OrbModel: + orbff = pretrained.orb_v3_conservative_inf_omat( + device=device, + precision="float32-high", + ) + return OrbModel(model=orbff, device=device) @pytest.fixture -def orb_model(pretrained_orb_model: torch.nn.Module, device: torch.device) -> OrbModel: - """Create an OrbModel wrapper for the pretrained model.""" - return OrbModel( - model=pretrained_orb_model, +def orbv3_direct_20_omat_model(device: torch.device) -> OrbModel: + orbff = pretrained.orb_v3_direct_20_omat( device=device, - system_config=SystemConfig(radius=6.0, max_num_neighbors=20), + precision="float32-high", ) + return OrbModel(model=orbff, device=device) @pytest.fixture -def orb_calculator( - pretrained_orb_model: torch.nn.Module, device: torch.device -) -> ORBCalculator: +def orbv3_conservative_inf_omat_calculator(device: torch.device) -> ORBCalculator: """Create an ORBCalculator for the pretrained model.""" - return ORBCalculator( - model=pretrained_orb_model, - system_config=SystemConfig(radius=6.0, max_num_neighbors=20), + orbff = pretrained.orb_v3_conservative_inf_omat( device=device, + precision="float32-high", ) + return ORBCalculator(model=orbff, device=device) -def test_orb_initialization( - pretrained_orb_model: torch.nn.Module, device: torch.device -) -> None: - """Test that the ORB model initializes correctly.""" - model = OrbModel( - model=pretrained_orb_model, +@pytest.fixture +def orbv3_direct_20_omat_calculator(device: torch.device) -> ORBCalculator: + """Create an ORBCalculator for the pretrained model.""" + orbff = pretrained.orb_v3_direct_20_omat( device=device, + precision="float32-high", ) - # Check that properties were set correctly - assert "energy" in model.implemented_properties - assert "forces" in model.implemented_properties - assert model._device == device # noqa: SLF001 + return ORBCalculator(model=orbff, device=device) + +test_orb_conservative_consistency = make_model_calculator_consistency_test( + test_name="orbv3_conservative_inf_omat", + model_fixture_name="orbv3_conservative_inf_omat_model", + calculator_fixture_name="orbv3_conservative_inf_omat_calculator", + sim_state_names=consistency_test_simstate_fixtures, + atol=5e-4, + rtol=5e-4, +) -test_orb_consistency = make_model_calculator_consistency_test( - test_name="orb", - model_fixture_name="orb_model", - calculator_fixture_name="orb_calculator", +test_orb_direct_consistency = make_model_calculator_consistency_test( + test_name="orbv3_direct_20_omat", + model_fixture_name="orbv3_direct_20_omat_model", + calculator_fixture_name="orbv3_direct_20_omat_calculator", sim_state_names=consistency_test_simstate_fixtures, atol=5e-4, - rtol=1, + rtol=5e-4, ) +test_validate_conservative_model_outputs = make_validate_model_outputs_test( + model_fixture_name="orbv3_conservative_inf_omat_model", +) -test_validate_model_outputs = make_validate_model_outputs_test( - model_fixture_name="orb_model", +test_validate_direct_model_outputs = make_validate_model_outputs_test( + model_fixture_name="orbv3_direct_20_omat_model", ) diff --git a/torch_sim/models/orb.py b/torch_sim/models/orb.py index 49c7f6ea..6a1b3066 100644 --- a/torch_sim/models/orb.py +++ b/torch_sim/models/orb.py @@ -312,17 +312,15 @@ def __init__( self._device = torch.device(self._device) self._dtype = dtype - self._memory_scales_with = "n_atoms" self._compute_stress = compute_stress self._compute_forces = compute_forces # Set up system configuration - self.system_config = system_config or SystemConfig( - radius=6.0, max_num_neighbors=20 - ) + self.system_config = system_config or model.system_config self._max_num_neighbors = max_num_neighbors self._edge_method = edge_method self._half_supercell = half_supercell + self.conservative = conservative # Load model if path is provided if isinstance(model, str | Path): @@ -336,9 +334,8 @@ def __init__( # Determine if the model is conservative model_is_conservative = hasattr(self.model, "grad_forces_name") - self.conservative = ( - conservative if conservative is not None else model_is_conservative - ) + if self.conservative is None: + self.conservative = model_is_conservative if self.conservative and not model_is_conservative: raise ValueError( @@ -351,10 +348,7 @@ def __init__( # Add forces and stress to implemented properties if conservative model if self.conservative: - if "forces" not in self.implemented_properties: - self.implemented_properties.append("forces") - if compute_stress and "stress" not in self.implemented_properties: - self.implemented_properties.append("stress") + self.implemented_properties.extend(["forces", "stress"]) def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]: """Perform forward pass to compute energies, forces, and other properties. @@ -417,13 +411,10 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]: results[prop] = predictions[_property].squeeze() if self.conservative: - results["direct_forces"] = results["forces"] - results["direct_stress"] = results["stress"] results["forces"] = results[self.model.grad_forces_name] results["stress"] = results[self.model.grad_stress_name] if "stress" in results and results["stress"].shape[-1] == 6: - # TODO: is there a point to converting the direct stress if conservative? results["stress"] = voigt_6_to_full_3x3_stress(results["stress"]) return results