From 6c3551ec250eddb9870369ae434986c51cf2642a Mon Sep 17 00:00:00 2001 From: Adeesh Kolluru Date: Wed, 9 Apr 2025 14:32:27 -0700 Subject: [PATCH 01/10] fix: revert a few changes and add a proper fix for pbc handling with caution --- torch_sim/models/fairchem.py | 40 ++++++++++++++++++++++++++++-------- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py index e2ce6bc1..6d4980af 100644 --- a/torch_sim/models/fairchem.py +++ b/torch_sim/models/fairchem.py @@ -76,7 +76,6 @@ class FairChemModel(torch.nn.Module, ModelInterface): Attributes: neighbor_list_fn (Callable | None): Function to compute neighbor lists - r_max (float): Maximum cutoff radius for atomic interactions in Ångström config (dict): Complete model configuration dictionary trainer: FairChem trainer object that contains the model data_object (Batch): Data object containing system information @@ -108,9 +107,9 @@ def __init__( # noqa: C901, PLR0915 trainer: str | None = None, cpu: bool = False, seed: int | None = None, - r_max: float | None = None, # noqa: ARG002 dtype: torch.dtype | None = None, compute_stress: bool = False, + pbc: bool = True, ) -> None: """Initialize the FairChemModel with specified configuration. @@ -128,9 +127,9 @@ def __init__( # noqa: C901, PLR0915 trainer (str | None): Name of trainer class to use cpu (bool): Whether to use CPU instead of GPU for computation seed (int | None): Random seed for reproducibility - r_max (float | None): Maximum cutoff radius (overrides model default) dtype (torch.dtype | None): Data type to use for computation compute_stress (bool): Whether to compute stress tensor + pbc (bool): Whether to use periodic boundary conditions Raises: RuntimeError: If both model_name and model are specified @@ -150,6 +149,7 @@ def __init__( # noqa: C901, PLR0915 self._compute_stress = compute_stress self._compute_forces = True self._memory_scales_with = "n_atoms" + self.pbc = pbc if model_name is not None: if model is not None: @@ -215,6 +215,14 @@ def __init__( # noqa: C901, PLR0915 ) if "backbone" in config["model"]: + if config["model"]["backbone"]["use_pbc"] != pbc: + print( + f"WARNING: PBC mismatch between model and state. " + "The model loaded was trained with" + f"PBC={config['model']['backbone']['use_pbc']} " + f"and you are using PBC={pbc}." + ) + config["model"]["backbone"]["use_pbc"] = pbc config["model"]["backbone"]["use_pbc_single"] = False if dtype is not None: try: @@ -224,14 +232,26 @@ def __init__( # noqa: C901, PLR0915 {"dtype": _DTYPE_DICT[dtype]} ) except KeyError: - print("dtype not found in backbone, using default float32") + print( + "WARNING: dtype not found in backbone, using default model dtype" + ) else: + if config["model"]["use_pbc"] != pbc: + print( + f"WARNING: PBC mismatch between model and state. " + f"The model loaded was trained with" + f"PBC={config['model']['use_pbc']} " + f"and you are using PBC={pbc}." 
+ ) + config["model"]["use_pbc"] = pbc config["model"]["use_pbc_single"] = False if dtype is not None: try: config["model"].update({"dtype": _DTYPE_DICT[dtype]}) except KeyError: - print("dtype not found in backbone, using default dtype") + print( + "WARNING: dtype not found in backbone, using default model dtype" + ) ### backwards compatibility with OCP v<2.0 config = update_config(config) @@ -257,11 +277,9 @@ def __init__( # noqa: C901, PLR0915 inference_only=True, ) - self.trainer.model = self.trainer.model.eval() - if dtype is not None: # Convert model parameters to specified dtype - self.trainer.model = self.trainer.model.to(dtype=self.dtype) + self.trainer.model = self.trainer.model.to(dtype=self._dtype) if model is not None: self.load_checkpoint(checkpoint_path=model, checkpoint=checkpoint) @@ -338,6 +356,12 @@ def forward(self, state: SimState | StateDict) -> dict: if state.batch is None: state.batch = torch.zeros(state.positions.shape[0], dtype=torch.int) + if self.pbc != state.pbc: + raise ValueError( + "PBC mismatch between model and state. " + "For FairChemModel PBC needs to be defined in the model class." + ) + natoms = torch.bincount(state.batch) pbc = torch.tensor( [state.pbc, state.pbc, state.pbc] * len(natoms), dtype=torch.bool From bdb00c1f404d77dc0b516411fa3a3bbb7178a0b9 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Thu, 10 Apr 2025 18:26:17 -0400 Subject: [PATCH 02/10] test: remove benezene from fairchem tests. --- tests/models/test_fairchem.py | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/tests/models/test_fairchem.py b/tests/models/test_fairchem.py index b2af7d35..30f5f289 100644 --- a/tests/models/test_fairchem.py +++ b/tests/models/test_fairchem.py @@ -27,12 +27,13 @@ def model_path(tmp_path_factory: pytest.TempPathFactory) -> str: @pytest.fixture -def fairchem_model(model_path: str, device: torch.device) -> FairChemModel: +def fairchem_model_pbc(model_path: str, device: torch.device) -> FairChemModel: cpu = device.type == "cpu" return FairChemModel( model=model_path, cpu=cpu, seed=0, + pbc=True, ) @@ -41,28 +42,20 @@ def ocp_calculator(model_path: str) -> OCPCalculator: return OCPCalculator(checkpoint_path=model_path, cpu=False, seed=0) -test_fairchem_ocp_consistency = make_model_calculator_consistency_test( +test_fairchem_ocp_consistency_pbc = make_model_calculator_consistency_test( test_name="fairchem_ocp", - model_fixture_name="fairchem_model", + model_fixture_name="fairchem_model_pbc", calculator_fixture_name="ocp_calculator", - sim_state_names=consistency_test_simstate_fixtures, + sim_state_names=consistency_test_simstate_fixtures[:-1], rtol=5e-4, # NOTE: fairchem doesn't pass at the 1e-5 level used for other models atol=5e-4, ) +# TODO: add test for non-PBC model # fairchem batching is broken on CPU, do not replicate this skipping -# logic in other models tests -# @pytest.mark.skipif( -# not torch.cuda.is_available(), -# reason="Batching does not work properly on CPU for FAIRchem", -# ) -# def test_validate_model_outputs( -# fairchem_model: FairChemModel, device: torch.device -# ) -> None: -# validate_model_outputs(fairchem_model, device, torch.float32) - - +# logic in other models tests. 
This is due to issues with how the models +# handle supercells (see related issue here: https://github.com/FAIR-Chem/fairchem/issues/428) test_fairchem_ocp_model_outputs = pytest.mark.skipif( not torch.cuda.is_available(), reason="Batching does not work properly on CPU for FAIRchem", From 73268bfea789b8aca73685ecd049c18c0fac62fc Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Thu, 10 Apr 2025 19:28:08 -0400 Subject: [PATCH 03/10] nit: be consistent about using dtype vs _dtype and device vs _device in models --- torch_sim/models/fairchem.py | 8 ++++---- torch_sim/models/graphpes.py | 2 +- torch_sim/models/lennard_jones.py | 16 +++++++--------- torch_sim/models/mace.py | 4 ++-- torch_sim/models/mattersim.py | 4 ++-- torch_sim/models/morse.py | 16 ++++++++-------- torch_sim/models/orb.py | 6 +++--- torch_sim/models/sevennet.py | 4 ++-- torch_sim/unbatched/models/lennard_jones.py | 16 +++++++--------- torch_sim/unbatched/models/morse.py | 16 ++++++++-------- torch_sim/unbatched/models/particle_life.py | 8 ++++---- torch_sim/unbatched/models/soft_sphere.py | 8 ++++---- 12 files changed, 52 insertions(+), 56 deletions(-) diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py index f6d0ad94..24d2f884 100644 --- a/torch_sim/models/fairchem.py +++ b/torch_sim/models/fairchem.py @@ -279,7 +279,7 @@ def __init__( # noqa: C901, PLR0915 if dtype is not None: # Convert model parameters to specified dtype - self.trainer.model = self.trainer.model.to(dtype=self._dtype) + self.trainer.model = self.trainer.model.to(dtype=self.dtype) if model is not None: self.load_checkpoint(checkpoint_path=model, checkpoint=checkpoint) @@ -374,9 +374,9 @@ def forward(self, state: SimState | StateDict) -> dict: pbc=pbc, ) - if self._dtype is not None: - self.data_object.pos = self.data_object.pos.to(self._dtype) - self.data_object.cell = self.data_object.cell.to(self._dtype) + if self.dtype is not None: + self.data_object.pos = self.data_object.pos.to(self.dtype) + self.data_object.cell = self.data_object.cell.to(self.dtype) predictions = self.trainer.predict( self.data_object, per_image=False, disable_tqdm=True diff --git a/torch_sim/models/graphpes.py b/torch_sim/models/graphpes.py index 6f6ed5b5..e2d42da3 100644 --- a/torch_sim/models/graphpes.py +++ b/torch_sim/models/graphpes.py @@ -154,7 +154,7 @@ def __init__( model if isinstance(model, GraphPESModel) else load_model(model) # type: ignore[arg-type] ), ) - self._gp_model = _model.to(device=self._device, dtype=self._dtype) + self._gp_model = _model.to(device=self.device, dtype=self.dtype) self._compute_forces = compute_forces self._compute_stress = compute_stress diff --git a/torch_sim/models/lennard_jones.py b/torch_sim/models/lennard_jones.py index 25af4b3c..0e34b67f 100644 --- a/torch_sim/models/lennard_jones.py +++ b/torch_sim/models/lennard_jones.py @@ -138,11 +138,9 @@ def __init__( self.use_neighbor_list = use_neighbor_list # Convert parameters to tensors - self.sigma = torch.tensor(sigma, dtype=dtype, device=self._device) - self.cutoff = torch.tensor( - cutoff or 2.5 * sigma, dtype=dtype, device=self._device - ) - self.epsilon = torch.tensor(epsilon, dtype=dtype, device=self._device) + self.sigma = torch.tensor(sigma, dtype=dtype, device=self.device) + self.cutoff = torch.tensor(cutoff or 2.5 * sigma, dtype=dtype, device=self.device) + self.epsilon = torch.tensor(epsilon, dtype=dtype, device=self.device) def unbatched_forward( self, @@ -209,7 +207,7 @@ def unbatched_forward( pbc=pbc, ) # Mask out self-interactions - mask = 
torch.eye(positions.shape[0], dtype=torch.bool, device=self._device) + mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self.device) distances = distances.masked_fill(mask, float("inf")) # Apply cutoff mask = distances < self.cutoff @@ -233,7 +231,7 @@ def unbatched_forward( if self.per_atom_energies: atom_energies = torch.zeros( - positions.shape[0], dtype=self._dtype, device=self._device + positions.shape[0], dtype=self.dtype, device=self.device ) # Each atom gets half of the pair energy atom_energies.index_add_(0, mapping[0], 0.5 * pair_energies) @@ -268,8 +266,8 @@ def unbatched_forward( if self.per_atom_stresses: atom_stresses = torch.zeros( (state.positions.shape[0], 3, 3), - dtype=self._dtype, - device=self._device, + dtype=self.dtype, + device=self.device, ) atom_stresses.index_add_(0, mapping[0], -0.5 * stress_per_pair) atom_stresses.index_add_(0, mapping[1], -0.5 * stress_per_pair) diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index 64c4338c..7c21fa15 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -160,8 +160,8 @@ def __init__( self.model = model.to(self._device) self.model = self.model.eval() - if self._dtype is not None: - self.model = self.model.to(dtype=self._dtype) + if self.dtype is not None: + self.model = self.model.to(dtype=self.dtype) if enable_cueq: print("Converting models to CuEq for acceleration") diff --git a/torch_sim/models/mattersim.py b/torch_sim/models/mattersim.py index 1aa89932..e431ee96 100644 --- a/torch_sim/models/mattersim.py +++ b/torch_sim/models/mattersim.py @@ -85,8 +85,8 @@ def __init__( self.model = model.to(self._device) self.model = self.model.eval() - if self._dtype is not None: - self.model = self.model.to(dtype=self._dtype) + if self.dtype is not None: + self.model = self.model.to(dtype=self.dtype) model_args = self.model.model.model_args self.two_body_cutoff = model_args["cutoff"] diff --git a/torch_sim/models/morse.py b/torch_sim/models/morse.py index c0fda40b..8cde515d 100644 --- a/torch_sim/models/morse.py +++ b/torch_sim/models/morse.py @@ -142,12 +142,12 @@ def __init__( self._per_atom_stresses = per_atom_stresses self.use_neighbor_list = use_neighbor_list # Convert parameters to tensors - self.sigma = torch.tensor(sigma, dtype=self._dtype, device=self._device) + self.sigma = torch.tensor(sigma, dtype=self.dtype, device=self.device) self.cutoff = torch.tensor( - cutoff or 2.5 * sigma, dtype=self._dtype, device=self._device + cutoff or 2.5 * sigma, dtype=self.dtype, device=self.device ) - self.epsilon = torch.tensor(epsilon, dtype=self._dtype, device=self._device) - self.alpha = torch.tensor(alpha, dtype=self._dtype, device=self._device) + self.epsilon = torch.tensor(epsilon, dtype=self.dtype, device=self.device) + self.alpha = torch.tensor(alpha, dtype=self.dtype, device=self.device) def unbatched_forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]: """Compute Morse potential properties for a single unbatched system. 
@@ -205,7 +205,7 @@ def unbatched_forward(self, state: SimState | StateDict) -> dict[str, torch.Tens cell=cell, pbc=pbc, ) - mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self._device) + mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self.device) distances = distances.masked_fill(mask, float("inf")) mask = distances < self.cutoff i, j = torch.where(mask) @@ -225,7 +225,7 @@ def unbatched_forward(self, state: SimState | StateDict) -> dict[str, torch.Tens if self._per_atom_energies: atom_energies = torch.zeros( - positions.shape[0], dtype=self._dtype, device=self._device + positions.shape[0], dtype=self.dtype, device=self.device ) atom_energies.index_add_(0, mapping[0], 0.5 * pair_energies) atom_energies.index_add_(0, mapping[1], 0.5 * pair_energies) @@ -254,8 +254,8 @@ def unbatched_forward(self, state: SimState | StateDict) -> dict[str, torch.Tens if self._per_atom_stresses: atom_stresses = torch.zeros( (state.positions.shape[0], 3, 3), - dtype=self._dtype, - device=self._device, + dtype=self.dtype, + device=self.device, ) atom_stresses.index_add_(0, mapping[0], -0.5 * stress_per_pair) atom_stresses.index_add_(0, mapping[1], -0.5 * stress_per_pair) diff --git a/torch_sim/models/orb.py b/torch_sim/models/orb.py index adab5895..49c7f6ea 100644 --- a/torch_sim/models/orb.py +++ b/torch_sim/models/orb.py @@ -331,8 +331,8 @@ def __init__( self.model = model.to(self._device) self.model = self.model.eval() - if self._dtype is not None: - self.model = self.model.to(dtype=self._dtype) + if self.dtype is not None: + self.model = self.model.to(dtype=self.dtype) # Determine if the model is conservative model_is_conservative = hasattr(self.model, "grad_forces_name") @@ -397,7 +397,7 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]: max_num_neighbors=self._max_num_neighbors, edge_method=self._edge_method, half_supercell=half_supercell, - device=self._device, + device=self.device, ) # Run forward pass diff --git a/torch_sim/models/sevennet.py b/torch_sim/models/sevennet.py index d82d06be..fb907689 100644 --- a/torch_sim/models/sevennet.py +++ b/torch_sim/models/sevennet.py @@ -126,8 +126,8 @@ def __init__( self.model = model.to(self._device) self.model = self.model.eval() - if self._dtype is not None: - self.model = self.model.to(dtype=self._dtype) + if self.dtype is not None: + self.model = self.model.to(dtype=self.dtype) self.implemented_properties = [ "energy", diff --git a/torch_sim/unbatched/models/lennard_jones.py b/torch_sim/unbatched/models/lennard_jones.py index 8b1e5758..d129ec5d 100644 --- a/torch_sim/unbatched/models/lennard_jones.py +++ b/torch_sim/unbatched/models/lennard_jones.py @@ -141,11 +141,9 @@ def __init__( self.use_neighbor_list = use_neighbor_list # Convert parameters to tensors - self.sigma = torch.tensor(sigma, dtype=dtype, device=self._device) - self.cutoff = torch.tensor( - cutoff or 2.5 * sigma, dtype=dtype, device=self._device - ) - self.epsilon = torch.tensor(epsilon, dtype=dtype, device=self._device) + self.sigma = torch.tensor(sigma, dtype=dtype, device=self.device) + self.cutoff = torch.tensor(cutoff or 2.5 * sigma, dtype=dtype, device=self.device) + self.epsilon = torch.tensor(epsilon, dtype=dtype, device=self.device) def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]: """Compute energies and forces. 
@@ -192,7 +190,7 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]: pbc=pbc, ) # Mask out self-interactions - mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self._device) + mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self.device) distances = distances.masked_fill(mask, float("inf")) # Apply cutoff mask = distances < self.cutoff @@ -216,7 +214,7 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]: if self._per_atom_energies: atom_energies = torch.zeros( - positions.shape[0], dtype=self._dtype, device=self._device + positions.shape[0], dtype=self.dtype, device=self.device ) # Each atom gets half of the pair energy atom_energies.index_add_(0, mapping[0], 0.5 * pair_energies) @@ -251,8 +249,8 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]: if self._per_atom_stresses: atom_stresses = torch.zeros( (positions.shape[0], 3, 3), - dtype=self._dtype, - device=self._device, + dtype=self.dtype, + device=self.device, ) atom_stresses.index_add_(0, mapping[0], -0.5 * stress_per_pair) atom_stresses.index_add_(0, mapping[1], -0.5 * stress_per_pair) diff --git a/torch_sim/unbatched/models/morse.py b/torch_sim/unbatched/models/morse.py index 03fd8b96..1b917585 100644 --- a/torch_sim/unbatched/models/morse.py +++ b/torch_sim/unbatched/models/morse.py @@ -133,12 +133,12 @@ def __init__( self._per_atom_stresses = per_atom_stresses self.use_neighbor_list = use_neighbor_list # Convert parameters to tensors - self.sigma = torch.tensor(sigma, dtype=self._dtype, device=self._device) + self.sigma = torch.tensor(sigma, dtype=self.dtype, device=self.device) self.cutoff = torch.tensor( - cutoff or 2.5 * sigma, dtype=self._dtype, device=self._device + cutoff or 2.5 * sigma, dtype=self.dtype, device=self.device ) - self.epsilon = torch.tensor(epsilon, dtype=self._dtype, device=self._device) - self.alpha = torch.tensor(alpha, dtype=self._dtype, device=self._device) + self.epsilon = torch.tensor(epsilon, dtype=self.dtype, device=self.device) + self.alpha = torch.tensor(alpha, dtype=self.dtype, device=self.device) def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]: """Compute energies and forces. 
@@ -180,7 +180,7 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]: cell=cell, pbc=pbc, ) - mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self._device) + mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self.device) distances = distances.masked_fill(mask, float("inf")) mask = distances < self.cutoff i, j = torch.where(mask) @@ -200,7 +200,7 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]: if self._per_atom_energies: atom_energies = torch.zeros( - positions.shape[0], dtype=self._dtype, device=self._device + positions.shape[0], dtype=self.dtype, device=self.device ) atom_energies.index_add_(0, mapping[0], 0.5 * pair_energies) atom_energies.index_add_(0, mapping[1], 0.5 * pair_energies) @@ -229,8 +229,8 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]: if self._per_atom_stresses: atom_stresses = torch.zeros( (positions.shape[0], 3, 3), - dtype=self._dtype, - device=self._device, + dtype=self.dtype, + device=self.device, ) atom_stresses.index_add_(0, mapping[0], -0.5 * stress_per_pair) atom_stresses.index_add_(0, mapping[1], -0.5 * stress_per_pair) diff --git a/torch_sim/unbatched/models/particle_life.py b/torch_sim/unbatched/models/particle_life.py index 6debc982..d2caa5b7 100644 --- a/torch_sim/unbatched/models/particle_life.py +++ b/torch_sim/unbatched/models/particle_life.py @@ -118,11 +118,11 @@ def __init__( self.use_neighbor_list = use_neighbor_list # Convert parameters to tensors - self.sigma = torch.tensor(sigma, dtype=self._dtype, device=self._device) + self.sigma = torch.tensor(sigma, dtype=self.dtype, device=self.device) self.cutoff = torch.tensor( - cutoff or 2.5 * sigma, dtype=self._dtype, device=self._device + cutoff or 2.5 * sigma, dtype=self.dtype, device=self.device ) - self.epsilon = torch.tensor(epsilon, dtype=self._dtype, device=self._device) + self.epsilon = torch.tensor(epsilon, dtype=self.dtype, device=self.device) def forward(self, state: SimState) -> dict[str, torch.Tensor]: """Compute energies and forces. 
@@ -170,7 +170,7 @@ def forward(self, state: SimState) -> dict[str, torch.Tensor]: pbc=pbc, ) # Mask out self-interactions - mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self._device) + mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self.device) distances = distances.masked_fill(mask, float("inf")) # Apply cutoff mask = distances < self.cutoff diff --git a/torch_sim/unbatched/models/soft_sphere.py b/torch_sim/unbatched/models/soft_sphere.py index 3ecca0ed..64ab8dad 100644 --- a/torch_sim/unbatched/models/soft_sphere.py +++ b/torch_sim/unbatched/models/soft_sphere.py @@ -175,7 +175,7 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]: pbc=pbc, ) # Remove self-interactions and apply cutoff - mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self._device) + mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self.device) distances = distances.masked_fill(mask, float("inf")) mask = distances < self.cutoff @@ -196,7 +196,7 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]: if self._per_atom_energies: # Compute per-atom energy contributions atom_energies = torch.zeros( - positions.shape[0], dtype=self._dtype, device=self._device + positions.shape[0], dtype=self.dtype, device=self.device ) # Each atom gets half of the pair energy atom_energies.index_add_(0, mapping[0], 0.5 * pair_energies) @@ -231,8 +231,8 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]: # Compute per-atom stress contributions atom_stresses = torch.zeros( (positions.shape[0], 3, 3), - dtype=self._dtype, - device=self._device, + dtype=self.dtype, + device=self.device, ) atom_stresses.index_add_(0, mapping[0], -0.5 * stress_per_pair) atom_stresses.index_add_(0, mapping[1], -0.5 * stress_per_pair) From faf6391e5358c9cafd9cf8da381ce99aff7f86c7 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Thu, 10 Apr 2025 19:31:53 -0400 Subject: [PATCH 04/10] nit: be consistent about using compute_stress and compute_forces vs _compute_stress and _compute_forces --- torch_sim/models/graphpes.py | 4 ++-- torch_sim/models/lennard_jones.py | 6 +++--- torch_sim/models/mace.py | 8 ++++---- torch_sim/models/morse.py | 4 ++-- torch_sim/unbatched/models/lennard_jones.py | 6 +++--- torch_sim/unbatched/models/mace.py | 8 ++++---- torch_sim/unbatched/models/morse.py | 4 ++-- torch_sim/unbatched/models/soft_sphere.py | 14 +++++++------- 8 files changed, 27 insertions(+), 27 deletions(-) diff --git a/torch_sim/models/graphpes.py b/torch_sim/models/graphpes.py index e2d42da3..41d55df3 100644 --- a/torch_sim/models/graphpes.py +++ b/torch_sim/models/graphpes.py @@ -160,9 +160,9 @@ def __init__( self._compute_stress = compute_stress self._properties: list[PropertyKey] = ["energy"] - if self._compute_forces: + if self.compute_forces: self._properties.append("forces") - if self._compute_stress: + if self.compute_stress: self._properties.append("stress") if self._gp_model.cutoff.item() < 0.5: diff --git a/torch_sim/models/lennard_jones.py b/torch_sim/models/lennard_jones.py index 0e34b67f..85f691c1 100644 --- a/torch_sim/models/lennard_jones.py +++ b/torch_sim/models/lennard_jones.py @@ -238,7 +238,7 @@ def unbatched_forward( atom_energies.index_add_(0, mapping[1], 0.5 * pair_energies) results["energies"] = atom_energies - if self._compute_forces or self._compute_stress: + if self.compute_forces or self.compute_stress: # Calculate forces and apply cutoff pair_forces = lennard_jones_pair_force( distances, sigma=self.sigma, 
epsilon=self.epsilon @@ -248,7 +248,7 @@ def unbatched_forward( # Project forces along displacement vectors force_vectors = (pair_forces / distances)[:, None] * dr_vec - if self._compute_forces: + if self.compute_forces: # Initialize forces tensor forces = torch.zeros_like(positions) # Add force contributions (f_ij on i, -f_ij on j) @@ -256,7 +256,7 @@ def unbatched_forward( forces.index_add_(0, mapping[1], force_vectors) results["forces"] = forces - if self._compute_stress and cell is not None: + if self.compute_stress and cell is not None: # Compute stress tensor stress_per_pair = torch.einsum("...i,...j->...ij", dr_vec, force_vectors) volume = torch.abs(torch.linalg.det(cell)) diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index 7c21fa15..3e02f9f0 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -334,8 +334,8 @@ def forward( # noqa: C901 unit_shifts=unit_shifts, shifts=shifts_list, ), - compute_force=self._compute_forces, - compute_stress=self._compute_stress, + compute_force=self.compute_forces, + compute_stress=self.compute_stress, ) results = {} @@ -348,13 +348,13 @@ def forward( # noqa: C901 results["energy"] = torch.zeros(self.n_systems, device=self.device) # Process forces - if self._compute_forces: + if self.compute_forces: forces = out["forces"] if forces is not None: results["forces"] = forces.detach() # Process stress - if self._compute_stress: + if self.compute_stress: stress = out["stress"] if stress is not None: results["stress"] = stress.detach() diff --git a/torch_sim/models/morse.py b/torch_sim/models/morse.py index 8cde515d..357ab86f 100644 --- a/torch_sim/models/morse.py +++ b/torch_sim/models/morse.py @@ -239,13 +239,13 @@ def unbatched_forward(self, state: SimState | StateDict) -> dict[str, torch.Tens force_vectors = (pair_forces / distances)[:, None] * dr_vec - if self._compute_forces: + if self.compute_forces: forces = torch.zeros_like(state.positions) forces.index_add_(0, mapping[0], -force_vectors) forces.index_add_(0, mapping[1], force_vectors) results["forces"] = forces - if self._compute_stress and state.cell is not None: + if self.compute_stress and state.cell is not None: stress_per_pair = torch.einsum("...i,...j->...ij", dr_vec, force_vectors) volume = torch.abs(torch.linalg.det(state.cell)) diff --git a/torch_sim/unbatched/models/lennard_jones.py b/torch_sim/unbatched/models/lennard_jones.py index d129ec5d..a3a1c42d 100644 --- a/torch_sim/unbatched/models/lennard_jones.py +++ b/torch_sim/unbatched/models/lennard_jones.py @@ -221,7 +221,7 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]: atom_energies.index_add_(0, mapping[1], 0.5 * pair_energies) results["energies"] = atom_energies - if self._compute_forces or self._compute_stress: + if self.compute_forces or self.compute_stress: # Calculate forces and apply cutoff pair_forces = lennard_jones_pair_force( distances, sigma=self.sigma, epsilon=self.epsilon @@ -231,7 +231,7 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]: # Project forces along displacement vectors force_vectors = (pair_forces / distances)[:, None] * dr_vec - if self._compute_forces: + if self.compute_forces: # Initialize forces tensor forces = torch.zeros_like(positions) # Add force contributions (f_ij on i, -f_ij on j) @@ -239,7 +239,7 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]: forces.index_add_(0, mapping[1], force_vectors) results["forces"] = forces - if self._compute_stress and cell is not None: + if 
self.compute_stress and cell is not None: # Compute stress tensor stress_per_pair = torch.einsum("...i,...j->...ij", dr_vec, force_vectors) volume = torch.abs(torch.linalg.det(cell)) diff --git a/torch_sim/unbatched/models/mace.py b/torch_sim/unbatched/models/mace.py index 6842f392..2b87f857 100644 --- a/torch_sim/unbatched/models/mace.py +++ b/torch_sim/unbatched/models/mace.py @@ -226,8 +226,8 @@ def forward( # noqa: C901 unit_shifts=shifts_idx, shifts=shifts, ), - compute_force=self._compute_forces, - compute_stress=self._compute_stress, + compute_force=self.compute_forces, + compute_stress=self.compute_stress, ) energy = out["energy"] @@ -239,11 +239,11 @@ def forward( # noqa: C901 else: results["energy"] = torch.tensor(0.0, device=self.device) - if self._compute_forces: + if self.compute_forces: forces = out["forces"] results["forces"] = forces - if self._compute_stress: + if self.compute_stress: stress = out["stress"].squeeze() results["stress"] = stress diff --git a/torch_sim/unbatched/models/morse.py b/torch_sim/unbatched/models/morse.py index 1b917585..2038d622 100644 --- a/torch_sim/unbatched/models/morse.py +++ b/torch_sim/unbatched/models/morse.py @@ -214,13 +214,13 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]: force_vectors = (pair_forces / distances)[:, None] * dr_vec - if self._compute_forces: + if self.compute_forces: forces = torch.zeros_like(positions) forces.index_add_(0, mapping[0], -force_vectors) forces.index_add_(0, mapping[1], force_vectors) results["forces"] = forces - if self._compute_stress and cell is not None: + if self.compute_stress and cell is not None: stress_per_pair = torch.einsum("...i,...j->...ij", dr_vec, force_vectors) volume = torch.abs(torch.linalg.det(cell)) diff --git a/torch_sim/unbatched/models/soft_sphere.py b/torch_sim/unbatched/models/soft_sphere.py index 64ab8dad..4ad3b665 100644 --- a/torch_sim/unbatched/models/soft_sphere.py +++ b/torch_sim/unbatched/models/soft_sphere.py @@ -203,7 +203,7 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]: atom_energies.index_add_(0, mapping[1], 0.5 * pair_energies) results["energies"] = atom_energies - if self._compute_forces or self._compute_stress: + if self.compute_forces or self.compute_stress: # Calculate pair forces pair_forces = soft_sphere_pair_force( distances, sigma=self.sigma, epsilon=self.epsilon, alpha=self.alpha @@ -212,7 +212,7 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]: # Project scalar forces onto displacement vectors force_vectors = (pair_forces / distances)[:, None] * dr_vec - if self._compute_forces: + if self.compute_forces: # Compute atomic forces by accumulating pair contributions forces = torch.zeros_like(positions) # Add force contributions (f_ij on j, -f_ij on i) @@ -220,7 +220,7 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]: forces.index_add_(0, mapping[1], -force_vectors) results["forces"] = forces - if self._compute_stress and cell is not None: + if self.compute_stress and cell is not None: # Compute stress tensor using virial formula stress_per_pair = torch.einsum("...i,...j->...ij", dr_vec, force_vectors) volume = torch.abs(torch.linalg.det(cell)) @@ -244,7 +244,7 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]: # TODO: Standardize the interface for multi-species models -class UnbatchedSoftSphereMultiModel(torch.nn.Module): +class UnbatchedSoftSphereMultiModel(torch.nn.Module, ModelInterface): """Calculator for soft sphere 
potential with multiple atomic species. This model implements a multi-species soft sphere potential where the interaction @@ -462,7 +462,7 @@ def forward( atom_energies.index_add_(0, mapping[1], 0.5 * pair_energies) results["energies"] = atom_energies - if self._compute_forces or self._compute_stress: + if self.compute_forces or self.compute_stress: # Calculate pair forces pair_forces = soft_sphere_pair_force( distances, sigma=pair_sigmas, epsilon=pair_epsilons, alpha=pair_alphas @@ -471,7 +471,7 @@ def forward( # Project scalar forces onto displacement vectors force_vectors = (pair_forces / distances)[:, None] * dr_vec - if self._compute_forces: + if self.compute_forces: # Compute atomic forces by accumulating pair contributions forces = torch.zeros_like(positions) # Add force contributions (f_ij on j, -f_ij on i) @@ -479,7 +479,7 @@ def forward( forces.index_add_(0, mapping[1], -force_vectors) results["forces"] = forces - if self._compute_stress and cell is not None: + if self.compute_stress and cell is not None: # Compute stress tensor using virial formula stress_per_pair = torch.einsum("...i,...j->...ij", dr_vec, force_vectors) volume = torch.abs(torch.linalg.det(cell)) From 336d0aca4fe466b3eaf6eccf1a4471e07bfcb66c Mon Sep 17 00:00:00 2001 From: Adeesh Kolluru Date: Fri, 11 Apr 2025 08:50:19 -0700 Subject: [PATCH 05/10] adds disable_amp, fixes previous use_pbc, adds omat24 model test --- .github/workflows/test.yml | 10 ++++++++ tests/models/test_fairchem.py | 48 +++++++++++++++++++++++++++-------- torch_sim/models/fairchem.py | 20 ++++----------- 3 files changed, 52 insertions(+), 26 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index db059fe9..0ac21a58 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -88,6 +88,16 @@ jobs: - name: Set up uv uses: astral-sh/setup-uv@v2 + - name: HuggingFace Hub Login + env: + HF_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }} + run: | + if [ -n "$HF_TOKEN" ]; then + huggingface-cli login --token "$HF_TOKEN" + else + echo "HF_TOKEN is not set. Skipping login." + fi + - name: Install fairchem repository and dependencies if: ${{ matrix.model.name == 'fairchem' }} run: | diff --git a/tests/models/test_fairchem.py b/tests/models/test_fairchem.py index 30f5f289..fc3488bf 100644 --- a/tests/models/test_fairchem.py +++ b/tests/models/test_fairchem.py @@ -1,3 +1,5 @@ +import os + import pytest import torch @@ -21,9 +23,15 @@ @pytest.fixture(scope="session") def model_path(tmp_path_factory: pytest.TempPathFactory) -> str: tmp_path = tmp_path_factory.mktemp("fairchem_checkpoints") - return model_name_to_local_file( - "EquiformerV2-31M-S2EF-OC20-All+MD", local_cache=str(tmp_path) - ) + from huggingface_hub.utils._auth import get_token + + if get_token(): + # To test OMat24 trained models, you need to set HF_TOKEN env variable. 
+ model_name = "EquiformerV2-31M-OMAT24-MP-sAlex" + else: + model_name = "EquiformerV2-31M-S2EF-OC20-All+MD" + + return model_name_to_local_file(model_name, local_cache=str(tmp_path)) @pytest.fixture @@ -37,6 +45,17 @@ def fairchem_model_pbc(model_path: str, device: torch.device) -> FairChemModel: ) +@pytest.fixture +def fairchem_model_non_pbc(model_path: str, device: torch.device) -> FairChemModel: + cpu = device.type == "cpu" + return FairChemModel( + model=model_path, + cpu=cpu, + seed=0, + pbc=False, + ) + + @pytest.fixture def ocp_calculator(model_path: str) -> OCPCalculator: return OCPCalculator(checkpoint_path=model_path, cpu=False, seed=0) @@ -47,20 +66,27 @@ def ocp_calculator(model_path: str) -> OCPCalculator: model_fixture_name="fairchem_model_pbc", calculator_fixture_name="ocp_calculator", sim_state_names=consistency_test_simstate_fixtures[:-1], - rtol=5e-4, # NOTE: fairchem doesn't pass at the 1e-5 level used for other models + rtol=5e-4, # NOTE: EqV2-OC20 doesn't pass at the 1e-5 level used for other models + atol=5e-4, +) + +test_fairchem_non_pbc_benzene = make_model_calculator_consistency_test( + test_name="fairchem_non_pbc_benzene", + model_fixture_name="fairchem_model_non_pbc", + calculator_fixture_name="ocp_calculator", + sim_state_names=["benzene_sim_state"], + rtol=5e-4, # NOTE: EqV2-OC20 doesn't pass at the 1e-5 level used for other models atol=5e-4, ) -# TODO: add test for non-PBC model +# Skip this test due to issues with how the older models +# handled supercells (see related issue here: https://github.com/FAIR-Chem/fairchem/issues/428) -# fairchem batching is broken on CPU, do not replicate this skipping -# logic in other models tests. This is due to issues with how the models -# handle supercells (see related issue here: https://github.com/FAIR-Chem/fairchem/issues/428) test_fairchem_ocp_model_outputs = pytest.mark.skipif( - not torch.cuda.is_available(), - reason="Batching does not work properly on CPU for FAIRchem", + os.environ.get("HF_TOKEN") is None, + reason="Issues in graph construction of older models", )( make_validate_model_outputs_test( - model_fixture_name="fairchem_model", + model_fixture_name="fairchem_model_pbc", ) ) diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py index 24d2f884..962f7fa6 100644 --- a/torch_sim/models/fairchem.py +++ b/torch_sim/models/fairchem.py @@ -110,6 +110,7 @@ def __init__( # noqa: C901, PLR0915 dtype: torch.dtype | None = None, compute_stress: bool = False, pbc: bool = True, + disable_amp: bool = True, ) -> None: """Initialize the FairChemModel with specified configuration. @@ -130,7 +131,7 @@ def __init__( # noqa: C901, PLR0915 dtype (torch.dtype | None): Data type to use for computation compute_stress (bool): Whether to compute stress tensor pbc (bool): Whether to use periodic boundary conditions - + disable_amp (bool): Whether to disable AMP Raises: RuntimeError: If both model_name and model are specified NotImplementedError: If local_cache is not set when model_name is used @@ -215,13 +216,6 @@ def __init__( # noqa: C901, PLR0915 ) if "backbone" in config["model"]: - if config["model"]["backbone"]["use_pbc"] != pbc: - print( - f"WARNING: PBC mismatch between model and state. " - "The model loaded was trained with" - f"PBC={config['model']['backbone']['use_pbc']} " - f"and you are using PBC={pbc}." 
- ) config["model"]["backbone"]["use_pbc"] = pbc config["model"]["backbone"]["use_pbc_single"] = False if dtype is not None: @@ -236,13 +230,6 @@ def __init__( # noqa: C901, PLR0915 "WARNING: dtype not found in backbone, using default model dtype" ) else: - if config["model"]["use_pbc"] != pbc: - print( - f"WARNING: PBC mismatch between model and state. " - f"The model loaded was trained with" - f"PBC={config['model']['use_pbc']} " - f"and you are using PBC={pbc}." - ) config["model"]["use_pbc"] = pbc config["model"]["use_pbc_single"] = False if dtype is not None: @@ -293,6 +280,9 @@ def __init__( # noqa: C901, PLR0915 else: self.trainer.set_seed(seed) + if disable_amp: + self.trainer.scaler = None + self.implemented_properties = list(self.config["outputs"]) self._device = self.trainer.device From 446acf273cc8768d91b751fa5ed15a087d3fbb78 Mon Sep 17 00:00:00 2001 From: Adeesh Kolluru Date: Fri, 11 Apr 2025 11:14:34 -0700 Subject: [PATCH 06/10] download huggingface_hub in tests --- .github/workflows/test.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0ac21a58..5ad4309c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -88,6 +88,9 @@ jobs: - name: Set up uv uses: astral-sh/setup-uv@v2 + - name: Install HuggingFace Hub CLI + run: uv pip install huggingface_hub --system + - name: HuggingFace Hub Login env: HF_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }} From 1c7430b237b2fb440acfb5e683d04a297dc220ab Mon Sep 17 00:00:00 2001 From: Adeesh Kolluru Date: Fri, 11 Apr 2025 11:38:49 -0700 Subject: [PATCH 07/10] increase the tolerance --- tests/models/test_fairchem.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/test_fairchem.py b/tests/models/test_fairchem.py index fc3488bf..48ddb5b3 100644 --- a/tests/models/test_fairchem.py +++ b/tests/models/test_fairchem.py @@ -66,8 +66,8 @@ def ocp_calculator(model_path: str) -> OCPCalculator: model_fixture_name="fairchem_model_pbc", calculator_fixture_name="ocp_calculator", sim_state_names=consistency_test_simstate_fixtures[:-1], - rtol=5e-4, # NOTE: EqV2-OC20 doesn't pass at the 1e-5 level used for other models - atol=5e-4, + rtol=1e-4, # NOTE: EqV2 doesn't pass at the 1e-5 level used for other models + atol=1e-4, ) test_fairchem_non_pbc_benzene = make_model_calculator_consistency_test( From c4e87361cb83c134f76ce7064f8f0bcea6e0fbe3 Mon Sep 17 00:00:00 2001 From: Adeesh Kolluru Date: Fri, 11 Apr 2025 11:43:07 -0700 Subject: [PATCH 08/10] increase the tolerance --- tests/models/test_fairchem.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/test_fairchem.py b/tests/models/test_fairchem.py index 48ddb5b3..2bf9856b 100644 --- a/tests/models/test_fairchem.py +++ b/tests/models/test_fairchem.py @@ -66,8 +66,8 @@ def ocp_calculator(model_path: str) -> OCPCalculator: model_fixture_name="fairchem_model_pbc", calculator_fixture_name="ocp_calculator", sim_state_names=consistency_test_simstate_fixtures[:-1], - rtol=1e-4, # NOTE: EqV2 doesn't pass at the 1e-5 level used for other models - atol=1e-4, + rtol=5e-3, # NOTE: EqV2 doesn't pass at the 1e-5 level used for other models + atol=5e-3, ) test_fairchem_non_pbc_benzene = make_model_calculator_consistency_test( From fbcd389d506ec2626c13e23cedb6cbc290feb63b Mon Sep 17 00:00:00 2001 From: Adeesh Kolluru Date: Fri, 11 Apr 2025 11:57:51 -0700 Subject: [PATCH 09/10] test larger machine for CI --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5ad4309c..79212e9c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -55,7 +55,7 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-latest, macos-14] + os: [ubuntu-latest-8-cores, macos-14] version: - { python: "3.11", resolution: highest } - { python: "3.12", resolution: lowest-direct } From a295a25d09f2ab46161afa5f7b8e1840da8640c3 Mon Sep 17 00:00:00 2001 From: Adeesh Kolluru Date: Fri, 11 Apr 2025 12:44:38 -0700 Subject: [PATCH 10/10] revert to smaller machine, use omat24 model only for batch test --- .github/workflows/test.yml | 2 +- tests/models/test_fairchem.py | 58 ++++++++++++++++++++++------------- 2 files changed, 38 insertions(+), 22 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 79212e9c..5ad4309c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -55,7 +55,7 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-latest-8-cores, macos-14] + os: [ubuntu-latest, macos-14] version: - { python: "3.11", resolution: highest } - { python: "3.12", resolution: lowest-direct } diff --git a/tests/models/test_fairchem.py b/tests/models/test_fairchem.py index 2bf9856b..d83cdf80 100644 --- a/tests/models/test_fairchem.py +++ b/tests/models/test_fairchem.py @@ -13,6 +13,7 @@ try: from fairchem.core import OCPCalculator from fairchem.core.models.model_registry import model_name_to_local_file + from huggingface_hub.utils._auth import get_token from torch_sim.models.fairchem import FairChemModel @@ -21,24 +22,17 @@ @pytest.fixture(scope="session") -def model_path(tmp_path_factory: pytest.TempPathFactory) -> str: +def model_path_oc20(tmp_path_factory: pytest.TempPathFactory) -> str: tmp_path = tmp_path_factory.mktemp("fairchem_checkpoints") - from huggingface_hub.utils._auth import get_token - - if get_token(): - # To test OMat24 trained models, you need to set HF_TOKEN env variable. 
- model_name = "EquiformerV2-31M-OMAT24-MP-sAlex" - else: - model_name = "EquiformerV2-31M-S2EF-OC20-All+MD" - + model_name = "EquiformerV2-31M-S2EF-OC20-All+MD" return model_name_to_local_file(model_name, local_cache=str(tmp_path)) @pytest.fixture -def fairchem_model_pbc(model_path: str, device: torch.device) -> FairChemModel: +def eqv2_oc20_model_pbc(model_path_oc20: str, device: torch.device) -> FairChemModel: cpu = device.type == "cpu" return FairChemModel( - model=model_path, + model=model_path_oc20, cpu=cpu, seed=0, pbc=True, @@ -46,39 +40,61 @@ def fairchem_model_pbc(model_path: str, device: torch.device) -> FairChemModel: @pytest.fixture -def fairchem_model_non_pbc(model_path: str, device: torch.device) -> FairChemModel: +def eqv2_oc20_model_non_pbc(model_path_oc20: str, device: torch.device) -> FairChemModel: cpu = device.type == "cpu" return FairChemModel( - model=model_path, + model=model_path_oc20, cpu=cpu, seed=0, pbc=False, ) +if get_token(): + + @pytest.fixture(scope="session") + def model_path_omat24(tmp_path_factory: pytest.TempPathFactory) -> str: + tmp_path = tmp_path_factory.mktemp("fairchem_checkpoints") + model_name = "EquiformerV2-31M-OMAT24-MP-sAlex" + return model_name_to_local_file(model_name, local_cache=str(tmp_path)) + + @pytest.fixture + def eqv2_omat24_model_pbc( + model_path_omat24: str, device: torch.device + ) -> FairChemModel: + cpu = device.type == "cpu" + return FairChemModel( + model=model_path_omat24, + cpu=cpu, + seed=0, + pbc=True, + ) + + @pytest.fixture -def ocp_calculator(model_path: str) -> OCPCalculator: - return OCPCalculator(checkpoint_path=model_path, cpu=False, seed=0) +def ocp_calculator(model_path_oc20: str) -> OCPCalculator: + return OCPCalculator(checkpoint_path=model_path_oc20, cpu=False, seed=0) test_fairchem_ocp_consistency_pbc = make_model_calculator_consistency_test( test_name="fairchem_ocp", - model_fixture_name="fairchem_model_pbc", + model_fixture_name="eqv2_oc20_model_pbc", calculator_fixture_name="ocp_calculator", sim_state_names=consistency_test_simstate_fixtures[:-1], - rtol=5e-3, # NOTE: EqV2 doesn't pass at the 1e-5 level used for other models - atol=5e-3, + rtol=5e-4, # NOTE: EqV2 doesn't pass at the 1e-5 level used for other models + atol=5e-4, ) test_fairchem_non_pbc_benzene = make_model_calculator_consistency_test( test_name="fairchem_non_pbc_benzene", - model_fixture_name="fairchem_model_non_pbc", + model_fixture_name="eqv2_oc20_model_non_pbc", calculator_fixture_name="ocp_calculator", sim_state_names=["benzene_sim_state"], - rtol=5e-4, # NOTE: EqV2-OC20 doesn't pass at the 1e-5 level used for other models + rtol=5e-4, # NOTE: EqV2 doesn't pass at the 1e-5 level used for other models atol=5e-4, ) + # Skip this test due to issues with how the older models # handled supercells (see related issue here: https://github.com/FAIR-Chem/fairchem/issues/428) @@ -87,6 +103,6 @@ def ocp_calculator(model_path: str) -> OCPCalculator: reason="Issues in graph construction of older models", )( make_validate_model_outputs_test( - model_fixture_name="fairchem_model_pbc", + model_fixture_name="eqv2_omat24_model_pbc", ) )
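Usage note (not part of the patches above): a minimal sketch of how the `pbc` and `disable_amp` keyword arguments introduced in this series are meant to be used. The checkpoint path and the `sim_state` variable are hypothetical placeholders; the keyword arguments follow the `FairChemModel.__init__` signature as changed in PATCH 01 and PATCH 05, and the PBC check reflects the `ValueError` added to `forward` in PATCH 01.

    # Hedged sketch: assumes a local EquiformerV2 checkpoint is available at the
    # (hypothetical) path below and that `sim_state` is a torch_sim SimState.
    import torch
    from torch_sim.models.fairchem import FairChemModel

    model = FairChemModel(
        model="/path/to/eqv2_checkpoint.pt",  # hypothetical local checkpoint path
        cpu=not torch.cuda.is_available(),
        seed=0,
        pbc=True,           # must match sim_state.pbc; forward() raises ValueError on mismatch
        disable_amp=True,   # default after PATCH 05; sets trainer.scaler to None
        compute_stress=False,
    )

    # results = model(sim_state)  # sim_state.pbc must equal model.pbc (here True);
    #                             # returns a dict of model outputs (energy, forces, ...)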