diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index db059fe9..5ad4309c 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -88,6 +88,19 @@ jobs:
       - name: Set up uv
         uses: astral-sh/setup-uv@v2

+      - name: Install HuggingFace Hub CLI
+        run: uv pip install huggingface_hub --system
+
+      - name: HuggingFace Hub Login
+        env:
+          HF_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }}
+        run: |
+          if [ -n "$HF_TOKEN" ]; then
+            huggingface-cli login --token "$HF_TOKEN"
+          else
+            echo "HF_TOKEN is not set. Skipping login."
+          fi
+
       - name: Install fairchem repository and dependencies
         if: ${{ matrix.model.name == 'fairchem' }}
         run: |
diff --git a/tests/models/test_fairchem.py b/tests/models/test_fairchem.py
index b2af7d35..d83cdf80 100644
--- a/tests/models/test_fairchem.py
+++ b/tests/models/test_fairchem.py
@@ -1,3 +1,5 @@
+import os
+
 import pytest
 import torch

@@ -11,6 +13,7 @@
 try:
     from fairchem.core import OCPCalculator
     from fairchem.core.models.model_registry import model_name_to_local_file
+    from huggingface_hub.utils._auth import get_token

     from torch_sim.models.fairchem import FairChemModel

@@ -19,55 +22,87 @@

 @pytest.fixture(scope="session")
-def model_path(tmp_path_factory: pytest.TempPathFactory) -> str:
+def model_path_oc20(tmp_path_factory: pytest.TempPathFactory) -> str:
     tmp_path = tmp_path_factory.mktemp("fairchem_checkpoints")
-    return model_name_to_local_file(
-        "EquiformerV2-31M-S2EF-OC20-All+MD", local_cache=str(tmp_path)
+    model_name = "EquiformerV2-31M-S2EF-OC20-All+MD"
+    return model_name_to_local_file(model_name, local_cache=str(tmp_path))
+
+
+@pytest.fixture
+def eqv2_oc20_model_pbc(model_path_oc20: str, device: torch.device) -> FairChemModel:
+    cpu = device.type == "cpu"
+    return FairChemModel(
+        model=model_path_oc20,
+        cpu=cpu,
+        seed=0,
+        pbc=True,
     )


 @pytest.fixture
-def fairchem_model(model_path: str, device: torch.device) -> FairChemModel:
+def eqv2_oc20_model_non_pbc(model_path_oc20: str, device: torch.device) -> FairChemModel:
     cpu = device.type == "cpu"
     return FairChemModel(
-        model=model_path,
+        model=model_path_oc20,
         cpu=cpu,
         seed=0,
+        pbc=False,
     )


+if get_token():
+
+    @pytest.fixture(scope="session")
+    def model_path_omat24(tmp_path_factory: pytest.TempPathFactory) -> str:
+        tmp_path = tmp_path_factory.mktemp("fairchem_checkpoints")
+        model_name = "EquiformerV2-31M-OMAT24-MP-sAlex"
+        return model_name_to_local_file(model_name, local_cache=str(tmp_path))
+
+    @pytest.fixture
+    def eqv2_omat24_model_pbc(
+        model_path_omat24: str, device: torch.device
+    ) -> FairChemModel:
+        cpu = device.type == "cpu"
+        return FairChemModel(
+            model=model_path_omat24,
+            cpu=cpu,
+            seed=0,
+            pbc=True,
+        )
+
+
 @pytest.fixture
-def ocp_calculator(model_path: str) -> OCPCalculator:
-    return OCPCalculator(checkpoint_path=model_path, cpu=False, seed=0)
+def ocp_calculator(model_path_oc20: str) -> OCPCalculator:
+    return OCPCalculator(checkpoint_path=model_path_oc20, cpu=False, seed=0)


-test_fairchem_ocp_consistency = make_model_calculator_consistency_test(
+test_fairchem_ocp_consistency_pbc = make_model_calculator_consistency_test(
     test_name="fairchem_ocp",
-    model_fixture_name="fairchem_model",
+    model_fixture_name="eqv2_oc20_model_pbc",
     calculator_fixture_name="ocp_calculator",
-    sim_state_names=consistency_test_simstate_fixtures,
-    rtol=5e-4,  # NOTE: fairchem doesn't pass at the 1e-5 level used for other models
+    sim_state_names=consistency_test_simstate_fixtures[:-1],
+    rtol=5e-4,  # NOTE: EqV2 doesn't pass at the 1e-5 level used for other models
     atol=5e-4,
 )

+test_fairchem_non_pbc_benzene = make_model_calculator_consistency_test(
+    test_name="fairchem_non_pbc_benzene",
+    model_fixture_name="eqv2_oc20_model_non_pbc",
+    calculator_fixture_name="ocp_calculator",
+    sim_state_names=["benzene_sim_state"],
+    rtol=5e-4,  # NOTE: EqV2 doesn't pass at the 1e-5 level used for other models
+    atol=5e-4,
+)

-# fairchem batching is broken on CPU, do not replicate this skipping
-# logic in other models tests
-# @pytest.mark.skipif(
-#     not torch.cuda.is_available(),
-#     reason="Batching does not work properly on CPU for FAIRchem",
-# )
-# def test_validate_model_outputs(
-#     fairchem_model: FairChemModel, device: torch.device
-# ) -> None:
-#     validate_model_outputs(fairchem_model, device, torch.float32)

+# This test uses the HF-gated OMAT24 model because the older models had issues
+# with how they handled supercells (see https://github.com/FAIR-Chem/fairchem/issues/428)
 test_fairchem_ocp_model_outputs = pytest.mark.skipif(
-    not torch.cuda.is_available(),
-    reason="Batching does not work properly on CPU for FAIRchem",
+    os.environ.get("HF_TOKEN") is None,
+    reason="Issues in graph construction of older models",
 )(
     make_validate_model_outputs_test(
-        model_fixture_name="fairchem_model",
+        model_fixture_name="eqv2_omat24_model_pbc",
     )
 )
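For context, a minimal sketch (not part of the patch) of how the HF-gated OMAT24 checkpoint can be resolved outside pytest, using the same helpers the new fixtures rely on; the cache directory is an arbitrary assumption:

    from fairchem.core.models.model_registry import model_name_to_local_file
    from huggingface_hub.utils._auth import get_token

    # Mirror the `if get_token():` guard above: only touch the gated
    # checkpoint when HuggingFace credentials are configured.
    if get_token():
        ckpt = model_name_to_local_file(
            "EquiformerV2-31M-OMAT24-MP-sAlex", local_cache="./fairchem_checkpoints"
        )
        print(f"OMAT24 checkpoint cached at {ckpt}")
    else:
        print("No HuggingFace token configured; skipping gated checkpoint download.")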
diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py
index 12b31500..962f7fa6 100644
--- a/torch_sim/models/fairchem.py
+++ b/torch_sim/models/fairchem.py
@@ -76,7 +76,6 @@ class FairChemModel(torch.nn.Module, ModelInterface):

     Attributes:
         neighbor_list_fn (Callable | None): Function to compute neighbor lists
-        r_max (float): Maximum cutoff radius for atomic interactions in Ångström
         config (dict): Complete model configuration dictionary
         trainer: FairChem trainer object that contains the model
         data_object (Batch): Data object containing system information
@@ -108,9 +107,10 @@ def __init__(  # noqa: C901, PLR0915
         trainer: str | None = None,
         cpu: bool = False,
         seed: int | None = None,
-        r_max: float | None = None,  # noqa: ARG002
         dtype: torch.dtype | None = None,
         compute_stress: bool = False,
+        pbc: bool = True,
+        disable_amp: bool = True,
     ) -> None:
         """Initialize the FairChemModel with specified configuration.

@@ -128,10 +128,10 @@ def __init__(  # noqa: C901, PLR0915
             trainer (str | None): Name of trainer class to use
             cpu (bool): Whether to use CPU instead of GPU for computation
             seed (int | None): Random seed for reproducibility
-            r_max (float | None): Maximum cutoff radius (overrides model default)
             dtype (torch.dtype | None): Data type to use for computation
             compute_stress (bool): Whether to compute stress tensor
-
+            pbc (bool): Whether to use periodic boundary conditions
+            disable_amp (bool): Whether to disable automatic mixed precision (AMP)
         Raises:
             RuntimeError: If both model_name and model are specified
             NotImplementedError: If local_cache is not set when model_name is used
@@ -150,6 +150,7 @@ def __init__(  # noqa: C901, PLR0915
         self._compute_stress = compute_stress
         self._compute_forces = True
         self._memory_scales_with = "n_atoms"
+        self.pbc = pbc

         if model_name is not None:
             if model is not None:
@@ -215,6 +216,7 @@ def __init__(  # noqa: C901, PLR0915
             )

         if "backbone" in config["model"]:
+            config["model"]["backbone"]["use_pbc"] = pbc
             config["model"]["backbone"]["use_pbc_single"] = False
             if dtype is not None:
                 try:
@@ -224,14 +226,19 @@ def __init__(  # noqa: C901, PLR0915
                         {"dtype": _DTYPE_DICT[dtype]}
                     )
                 except KeyError:
-                    print("dtype not found in backbone, using default float32")
+                    print(
+                        "WARNING: dtype not found in backbone, using default model dtype"
+                    )
         else:
+            config["model"]["use_pbc"] = pbc
             config["model"]["use_pbc_single"] = False
             if dtype is not None:
                 try:
                     config["model"].update({"dtype": _DTYPE_DICT[dtype]})
                 except KeyError:
-                    print("dtype not found in backbone, using default dtype")
+                    print(
+                        "WARNING: dtype not found in model config, using default model dtype"
+                    )

         ### backwards compatibility with OCP v<2.0
         config = update_config(config)
@@ -257,8 +264,6 @@ def __init__(  # noqa: C901, PLR0915
             inference_only=True,
         )

-        self.trainer.model = self.trainer.model.eval()
-
         if dtype is not None:
             # Convert model parameters to specified dtype
             self.trainer.model = self.trainer.model.to(dtype=self.dtype)
@@ -275,6 +280,9 @@ def __init__(  # noqa: C901, PLR0915
         else:
             self.trainer.set_seed(seed)

+        if disable_amp:
+            self.trainer.scaler = None
+
         self.implemented_properties = list(self.config["outputs"])

         self._device = self.trainer.device
@@ -335,6 +343,12 @@ def forward(self, state: SimState | StateDict) -> dict:
         if state.batch is None:
             state.batch = torch.zeros(state.positions.shape[0], dtype=torch.int)

+        if self.pbc != state.pbc:
+            raise ValueError(
+                "PBC mismatch between model and state. "
+                "For FairChemModel, PBC needs to be defined in the model class."
+            )
+
         natoms = torch.bincount(state.batch)
         pbc = torch.tensor(
             [state.pbc, state.pbc, state.pbc] * len(natoms), dtype=torch.bool
@@ -350,9 +364,9 @@ def forward(self, state: SimState | StateDict) -> dict:
             pbc=pbc,
         )

-        if self._dtype is not None:
-            self.data_object.pos = self.data_object.pos.to(self._dtype)
-            self.data_object.cell = self.data_object.cell.to(self._dtype)
+        if self.dtype is not None:
+            self.data_object.pos = self.data_object.pos.to(self.dtype)
+            self.data_object.cell = self.data_object.cell.to(self.dtype)

         predictions = self.trainer.predict(
             self.data_object, per_image=False, disable_tqdm=True
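A minimal usage sketch of the new constructor flags (the checkpoint path and state variable are assumed; the keyword arguments mirror the diff above). PBC is now fixed when the model is built, and forward() raises a ValueError if the incoming state disagrees:

    from torch_sim.models.fairchem import FairChemModel

    model = FairChemModel(
        model="/path/to/eqv2_checkpoint.pt",  # hypothetical local checkpoint
        cpu=True,
        seed=0,
        pbc=True,           # must match state.pbc at call time
        disable_amp=True,   # run inference without mixed-precision scaling
    )
    # `periodic_state` is an assumed SimState with pbc=True; a non-periodic
    # state would trigger the new PBC-mismatch ValueError.
    results = model(periodic_state)
    print(results["energy"])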
diff --git a/torch_sim/models/graphpes.py b/torch_sim/models/graphpes.py
index 6f6ed5b5..41d55df3 100644
--- a/torch_sim/models/graphpes.py
+++ b/torch_sim/models/graphpes.py
@@ -154,15 +154,15 @@ def __init__(
                 model if isinstance(model, GraphPESModel) else load_model(model)  # type: ignore[arg-type]
             ),
         )
-        self._gp_model = _model.to(device=self._device, dtype=self._dtype)
+        self._gp_model = _model.to(device=self.device, dtype=self.dtype)

         self._compute_forces = compute_forces
         self._compute_stress = compute_stress

         self._properties: list[PropertyKey] = ["energy"]
-        if self._compute_forces:
+        if self.compute_forces:
             self._properties.append("forces")
-        if self._compute_stress:
+        if self.compute_stress:
             self._properties.append("stress")

         if self._gp_model.cutoff.item() < 0.5:
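The graphpes.py change above and the remaining model files below make essentially the same mechanical change: internal reads of self._device, self._dtype, self._compute_forces and self._compute_stress switch to the public names. A rough sketch of the property pattern this assumes on ModelInterface (not shown in this diff; names inferred from the changes):

    import torch

    class ModelInterface:
        # Backing attributes are still set by each model's __init__.
        _device: torch.device
        _dtype: torch.dtype
        _compute_forces: bool = True
        _compute_stress: bool = False

        @property
        def device(self) -> torch.device:
            return self._device

        @property
        def dtype(self) -> torch.dtype:
            return self._dtype

        @property
        def compute_forces(self) -> bool:
            return self._compute_forces

        @property
        def compute_stress(self) -> bool:
            return self._compute_stress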
diff --git a/torch_sim/models/lennard_jones.py b/torch_sim/models/lennard_jones.py
index 25af4b3c..85f691c1 100644
--- a/torch_sim/models/lennard_jones.py
+++ b/torch_sim/models/lennard_jones.py
@@ -138,11 +138,9 @@ def __init__(
         self.use_neighbor_list = use_neighbor_list

         # Convert parameters to tensors
-        self.sigma = torch.tensor(sigma, dtype=dtype, device=self._device)
-        self.cutoff = torch.tensor(
-            cutoff or 2.5 * sigma, dtype=dtype, device=self._device
-        )
-        self.epsilon = torch.tensor(epsilon, dtype=dtype, device=self._device)
+        self.sigma = torch.tensor(sigma, dtype=dtype, device=self.device)
+        self.cutoff = torch.tensor(cutoff or 2.5 * sigma, dtype=dtype, device=self.device)
+        self.epsilon = torch.tensor(epsilon, dtype=dtype, device=self.device)

     def unbatched_forward(
         self,
@@ -209,7 +207,7 @@ def unbatched_forward(
             pbc=pbc,
         )
         # Mask out self-interactions
-        mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self._device)
+        mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self.device)
         distances = distances.masked_fill(mask, float("inf"))
         # Apply cutoff
         mask = distances < self.cutoff
@@ -233,14 +231,14 @@ def unbatched_forward(

         if self.per_atom_energies:
             atom_energies = torch.zeros(
-                positions.shape[0], dtype=self._dtype, device=self._device
+                positions.shape[0], dtype=self.dtype, device=self.device
             )
             # Each atom gets half of the pair energy
             atom_energies.index_add_(0, mapping[0], 0.5 * pair_energies)
             atom_energies.index_add_(0, mapping[1], 0.5 * pair_energies)
             results["energies"] = atom_energies

-        if self._compute_forces or self._compute_stress:
+        if self.compute_forces or self.compute_stress:
             # Calculate forces and apply cutoff
             pair_forces = lennard_jones_pair_force(
                 distances, sigma=self.sigma, epsilon=self.epsilon
@@ -250,7 +248,7 @@ def unbatched_forward(
             # Project forces along displacement vectors
             force_vectors = (pair_forces / distances)[:, None] * dr_vec

-            if self._compute_forces:
+            if self.compute_forces:
                 # Initialize forces tensor
                 forces = torch.zeros_like(positions)
                 # Add force contributions (f_ij on i, -f_ij on j)
@@ -258,7 +256,7 @@ def unbatched_forward(
                 forces.index_add_(0, mapping[1], force_vectors)
                 results["forces"] = forces

-            if self._compute_stress and cell is not None:
+            if self.compute_stress and cell is not None:
                 # Compute stress tensor
                 stress_per_pair = torch.einsum("...i,...j->...ij", dr_vec, force_vectors)
                 volume = torch.abs(torch.linalg.det(cell))
@@ -268,8 +266,8 @@ def unbatched_forward(
             if self.per_atom_stresses:
                 atom_stresses = torch.zeros(
                     (state.positions.shape[0], 3, 3),
-                    dtype=self._dtype,
-                    device=self._device,
+                    dtype=self.dtype,
+                    device=self.device,
                 )
                 atom_stresses.index_add_(0, mapping[0], -0.5 * stress_per_pair)
                 atom_stresses.index_add_(0, mapping[1], -0.5 * stress_per_pair)
diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py
index 64c4338c..3e02f9f0 100644
--- a/torch_sim/models/mace.py
+++ b/torch_sim/models/mace.py
@@ -160,8 +160,8 @@ def __init__(
         self.model = model.to(self._device)
         self.model = self.model.eval()

-        if self._dtype is not None:
-            self.model = self.model.to(dtype=self._dtype)
+        if self.dtype is not None:
+            self.model = self.model.to(dtype=self.dtype)

         if enable_cueq:
             print("Converting models to CuEq for acceleration")
@@ -334,8 +334,8 @@ def forward(  # noqa: C901
                 unit_shifts=unit_shifts,
                 shifts=shifts_list,
             ),
-            compute_force=self._compute_forces,
-            compute_stress=self._compute_stress,
+            compute_force=self.compute_forces,
+            compute_stress=self.compute_stress,
         )

         results = {}
@@ -348,13 +348,13 @@ def forward(  # noqa: C901
             results["energy"] = torch.zeros(self.n_systems, device=self.device)

         # Process forces
-        if self._compute_forces:
+        if self.compute_forces:
             forces = out["forces"]
             if forces is not None:
                 results["forces"] = forces.detach()

         # Process stress
-        if self._compute_stress:
+        if self.compute_stress:
             stress = out["stress"]
             if stress is not None:
                 results["stress"] = stress.detach()
diff --git a/torch_sim/models/mattersim.py b/torch_sim/models/mattersim.py
index 1aa89932..e431ee96 100644
--- a/torch_sim/models/mattersim.py
+++ b/torch_sim/models/mattersim.py
@@ -85,8 +85,8 @@ def __init__(
         self.model = model.to(self._device)
         self.model = self.model.eval()

-        if self._dtype is not None:
-            self.model = self.model.to(dtype=self._dtype)
+        if self.dtype is not None:
+            self.model = self.model.to(dtype=self.dtype)

         model_args = self.model.model.model_args
         self.two_body_cutoff = model_args["cutoff"]
diff --git a/torch_sim/models/morse.py b/torch_sim/models/morse.py
index c0fda40b..357ab86f 100644
--- a/torch_sim/models/morse.py
+++ b/torch_sim/models/morse.py
@@ -142,12 +142,12 @@ def __init__(
         self._per_atom_stresses = per_atom_stresses
         self.use_neighbor_list = use_neighbor_list
         # Convert parameters to tensors
-        self.sigma = torch.tensor(sigma, dtype=self._dtype, device=self._device)
+        self.sigma = torch.tensor(sigma, dtype=self.dtype, device=self.device)
         self.cutoff = torch.tensor(
-            cutoff or 2.5 * sigma, dtype=self._dtype, device=self._device
+            cutoff or 2.5 * sigma, dtype=self.dtype, device=self.device
         )
-        self.epsilon = torch.tensor(epsilon, dtype=self._dtype, device=self._device)
-        self.alpha = torch.tensor(alpha, dtype=self._dtype, device=self._device)
+        self.epsilon = torch.tensor(epsilon, dtype=self.dtype, device=self.device)
+        self.alpha = torch.tensor(alpha, dtype=self.dtype, device=self.device)

     def unbatched_forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]:
         """Compute Morse potential properties for a single unbatched system.
@@ -205,7 +205,7 @@ def unbatched_forward(self, state: SimState | StateDict) -> dict[str, torch.Tens
             cell=cell,
             pbc=pbc,
         )
-        mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self._device)
+        mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self.device)
         distances = distances.masked_fill(mask, float("inf"))
         mask = distances < self.cutoff
         i, j = torch.where(mask)
@@ -225,7 +225,7 @@ def unbatched_forward(self, state: SimState | StateDict) -> dict[str, torch.Tens

         if self._per_atom_energies:
             atom_energies = torch.zeros(
-                positions.shape[0], dtype=self._dtype, device=self._device
+                positions.shape[0], dtype=self.dtype, device=self.device
             )
             atom_energies.index_add_(0, mapping[0], 0.5 * pair_energies)
             atom_energies.index_add_(0, mapping[1], 0.5 * pair_energies)
@@ -239,13 +239,13 @@ def unbatched_forward(self, state: SimState | StateDict) -> dict[str, torch.Tens

         force_vectors = (pair_forces / distances)[:, None] * dr_vec

-        if self._compute_forces:
+        if self.compute_forces:
             forces = torch.zeros_like(state.positions)
             forces.index_add_(0, mapping[0], -force_vectors)
             forces.index_add_(0, mapping[1], force_vectors)
             results["forces"] = forces

-        if self._compute_stress and state.cell is not None:
+        if self.compute_stress and state.cell is not None:
             stress_per_pair = torch.einsum("...i,...j->...ij", dr_vec, force_vectors)
             volume = torch.abs(torch.linalg.det(state.cell))

@@ -254,8 +254,8 @@ def unbatched_forward(self, state: SimState | StateDict) -> dict[str, torch.Tens
         if self._per_atom_stresses:
             atom_stresses = torch.zeros(
                 (state.positions.shape[0], 3, 3),
-                dtype=self._dtype,
-                device=self._device,
+                dtype=self.dtype,
+                device=self.device,
             )
             atom_stresses.index_add_(0, mapping[0], -0.5 * stress_per_pair)
             atom_stresses.index_add_(0, mapping[1], -0.5 * stress_per_pair)
diff --git a/torch_sim/models/orb.py b/torch_sim/models/orb.py
index adab5895..49c7f6ea 100644
--- a/torch_sim/models/orb.py
+++ b/torch_sim/models/orb.py
@@ -331,8 +331,8 @@ def __init__(
         self.model = model.to(self._device)
         self.model = self.model.eval()

-        if self._dtype is not None:
-            self.model = self.model.to(dtype=self._dtype)
+        if self.dtype is not None:
+            self.model = self.model.to(dtype=self.dtype)

         # Determine if the model is conservative
         model_is_conservative = hasattr(self.model, "grad_forces_name")
@@ -397,7 +397,7 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]:
             max_num_neighbors=self._max_num_neighbors,
             edge_method=self._edge_method,
             half_supercell=half_supercell,
-            device=self._device,
+            device=self.device,
         )

         # Run forward pass
diff --git a/torch_sim/models/sevennet.py b/torch_sim/models/sevennet.py
index d82d06be..fb907689 100644
--- a/torch_sim/models/sevennet.py
+++ b/torch_sim/models/sevennet.py
@@ -126,8 +126,8 @@ def __init__(
         self.model = model.to(self._device)
         self.model = self.model.eval()

-        if self._dtype is not None:
-            self.model = self.model.to(dtype=self._dtype)
+        if self.dtype is not None:
+            self.model = self.model.to(dtype=self.dtype)

         self.implemented_properties = [
             "energy",
diff --git a/torch_sim/unbatched/models/lennard_jones.py b/torch_sim/unbatched/models/lennard_jones.py
index 8b1e5758..a3a1c42d 100644
--- a/torch_sim/unbatched/models/lennard_jones.py
+++ b/torch_sim/unbatched/models/lennard_jones.py
@@ -141,11 +141,9 @@ def __init__(
         self.use_neighbor_list = use_neighbor_list

         # Convert parameters to tensors
-        self.sigma = torch.tensor(sigma, dtype=dtype, device=self._device)
-        self.cutoff = torch.tensor(
-            cutoff or 2.5 * sigma, dtype=dtype, device=self._device
-        )
-        self.epsilon = torch.tensor(epsilon, dtype=dtype, device=self._device)
+        self.sigma = torch.tensor(sigma, dtype=dtype, device=self.device)
+        self.cutoff = torch.tensor(cutoff or 2.5 * sigma, dtype=dtype, device=self.device)
+        self.epsilon = torch.tensor(epsilon, dtype=dtype, device=self.device)

     def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]:
         """Compute energies and forces.
@@ -192,7 +190,7 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]:
             pbc=pbc,
         )
         # Mask out self-interactions
-        mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self._device)
+        mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self.device)
         distances = distances.masked_fill(mask, float("inf"))
         # Apply cutoff
         mask = distances < self.cutoff
@@ -216,14 +214,14 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]:

         if self._per_atom_energies:
             atom_energies = torch.zeros(
-                positions.shape[0], dtype=self._dtype, device=self._device
+                positions.shape[0], dtype=self.dtype, device=self.device
             )
             # Each atom gets half of the pair energy
             atom_energies.index_add_(0, mapping[0], 0.5 * pair_energies)
             atom_energies.index_add_(0, mapping[1], 0.5 * pair_energies)
             results["energies"] = atom_energies

-        if self._compute_forces or self._compute_stress:
+        if self.compute_forces or self.compute_stress:
             # Calculate forces and apply cutoff
             pair_forces = lennard_jones_pair_force(
                 distances, sigma=self.sigma, epsilon=self.epsilon
@@ -233,7 +231,7 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]:
             # Project forces along displacement vectors
             force_vectors = (pair_forces / distances)[:, None] * dr_vec

-            if self._compute_forces:
+            if self.compute_forces:
                 # Initialize forces tensor
                 forces = torch.zeros_like(positions)
                 # Add force contributions (f_ij on i, -f_ij on j)
@@ -241,7 +239,7 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]:
                 forces.index_add_(0, mapping[1], force_vectors)
                 results["forces"] = forces

-            if self._compute_stress and cell is not None:
+            if self.compute_stress and cell is not None:
                 # Compute stress tensor
                 stress_per_pair = torch.einsum("...i,...j->...ij", dr_vec, force_vectors)
                 volume = torch.abs(torch.linalg.det(cell))
@@ -251,8 +249,8 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]:
             if self._per_atom_stresses:
                 atom_stresses = torch.zeros(
                     (positions.shape[0], 3, 3),
-                    dtype=self._dtype,
-                    device=self._device,
+                    dtype=self.dtype,
+                    device=self.device,
                 )
                 atom_stresses.index_add_(0, mapping[0], -0.5 * stress_per_pair)
                 atom_stresses.index_add_(0, mapping[1], -0.5 * stress_per_pair)
diff --git a/torch_sim/unbatched/models/mace.py b/torch_sim/unbatched/models/mace.py
index 6842f392..2b87f857 100644
--- a/torch_sim/unbatched/models/mace.py
+++ b/torch_sim/unbatched/models/mace.py
@@ -226,8 +226,8 @@ def forward(  # noqa: C901
                 unit_shifts=shifts_idx,
                 shifts=shifts,
             ),
-            compute_force=self._compute_forces,
-            compute_stress=self._compute_stress,
+            compute_force=self.compute_forces,
+            compute_stress=self.compute_stress,
         )

         energy = out["energy"]
@@ -239,11 +239,11 @@ def forward(  # noqa: C901
         else:
             results["energy"] = torch.tensor(0.0, device=self.device)

-        if self._compute_forces:
+        if self.compute_forces:
             forces = out["forces"]
             results["forces"] = forces

-        if self._compute_stress:
+        if self.compute_stress:
             stress = out["stress"].squeeze()
             results["stress"] = stress
diff --git a/torch_sim/unbatched/models/morse.py b/torch_sim/unbatched/models/morse.py
index 03fd8b96..2038d622 100644
--- a/torch_sim/unbatched/models/morse.py
+++ b/torch_sim/unbatched/models/morse.py
@@ -133,12 +133,12 @@ def __init__(
         self._per_atom_stresses = per_atom_stresses
         self.use_neighbor_list = use_neighbor_list
         # Convert parameters to tensors
-        self.sigma = torch.tensor(sigma, dtype=self._dtype, device=self._device)
+        self.sigma = torch.tensor(sigma, dtype=self.dtype, device=self.device)
         self.cutoff = torch.tensor(
-            cutoff or 2.5 * sigma, dtype=self._dtype, device=self._device
+            cutoff or 2.5 * sigma, dtype=self.dtype, device=self.device
         )
-        self.epsilon = torch.tensor(epsilon, dtype=self._dtype, device=self._device)
-        self.alpha = torch.tensor(alpha, dtype=self._dtype, device=self._device)
+        self.epsilon = torch.tensor(epsilon, dtype=self.dtype, device=self.device)
+        self.alpha = torch.tensor(alpha, dtype=self.dtype, device=self.device)

     def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]:
         """Compute energies and forces.
@@ -180,7 +180,7 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]:
             cell=cell,
             pbc=pbc,
         )
-        mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self._device)
+        mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self.device)
         distances = distances.masked_fill(mask, float("inf"))
         mask = distances < self.cutoff
         i, j = torch.where(mask)
@@ -200,7 +200,7 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]:

         if self._per_atom_energies:
             atom_energies = torch.zeros(
-                positions.shape[0], dtype=self._dtype, device=self._device
+                positions.shape[0], dtype=self.dtype, device=self.device
             )
             atom_energies.index_add_(0, mapping[0], 0.5 * pair_energies)
             atom_energies.index_add_(0, mapping[1], 0.5 * pair_energies)
@@ -214,13 +214,13 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]:

         force_vectors = (pair_forces / distances)[:, None] * dr_vec

-        if self._compute_forces:
+        if self.compute_forces:
             forces = torch.zeros_like(positions)
             forces.index_add_(0, mapping[0], -force_vectors)
             forces.index_add_(0, mapping[1], force_vectors)
             results["forces"] = forces

-        if self._compute_stress and cell is not None:
+        if self.compute_stress and cell is not None:
             stress_per_pair = torch.einsum("...i,...j->...ij", dr_vec, force_vectors)
             volume = torch.abs(torch.linalg.det(cell))

@@ -229,8 +229,8 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]:
         if self._per_atom_stresses:
             atom_stresses = torch.zeros(
                 (positions.shape[0], 3, 3),
-                dtype=self._dtype,
-                device=self._device,
+                dtype=self.dtype,
+                device=self.device,
             )
             atom_stresses.index_add_(0, mapping[0], -0.5 * stress_per_pair)
             atom_stresses.index_add_(0, mapping[1], -0.5 * stress_per_pair)
diff --git a/torch_sim/unbatched/models/particle_life.py b/torch_sim/unbatched/models/particle_life.py
index 6debc982..d2caa5b7 100644
--- a/torch_sim/unbatched/models/particle_life.py
+++ b/torch_sim/unbatched/models/particle_life.py
@@ -118,11 +118,11 @@ def __init__(
         self.use_neighbor_list = use_neighbor_list

         # Convert parameters to tensors
-        self.sigma = torch.tensor(sigma, dtype=self._dtype, device=self._device)
+        self.sigma = torch.tensor(sigma, dtype=self.dtype, device=self.device)
         self.cutoff = torch.tensor(
-            cutoff or 2.5 * sigma, dtype=self._dtype, device=self._device
+            cutoff or 2.5 * sigma, dtype=self.dtype, device=self.device
         )
-        self.epsilon = torch.tensor(epsilon, dtype=self._dtype, device=self._device)
+        self.epsilon = torch.tensor(epsilon, dtype=self.dtype, device=self.device)

     def forward(self, state: SimState) -> dict[str, torch.Tensor]:
         """Compute energies and forces.
@@ -170,7 +170,7 @@ def forward(self, state: SimState) -> dict[str, torch.Tensor]:
             pbc=pbc,
         )
         # Mask out self-interactions
-        mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self._device)
+        mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self.device)
         distances = distances.masked_fill(mask, float("inf"))
         # Apply cutoff
         mask = distances < self.cutoff
diff --git a/torch_sim/unbatched/models/soft_sphere.py b/torch_sim/unbatched/models/soft_sphere.py
index 3ecca0ed..4ad3b665 100644
--- a/torch_sim/unbatched/models/soft_sphere.py
+++ b/torch_sim/unbatched/models/soft_sphere.py
@@ -175,7 +175,7 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]:
             pbc=pbc,
         )
         # Remove self-interactions and apply cutoff
-        mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self._device)
+        mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self.device)
         distances = distances.masked_fill(mask, float("inf"))
         mask = distances < self.cutoff

@@ -196,14 +196,14 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]:
         if self._per_atom_energies:
             # Compute per-atom energy contributions
             atom_energies = torch.zeros(
-                positions.shape[0], dtype=self._dtype, device=self._device
+                positions.shape[0], dtype=self.dtype, device=self.device
             )
             # Each atom gets half of the pair energy
             atom_energies.index_add_(0, mapping[0], 0.5 * pair_energies)
             atom_energies.index_add_(0, mapping[1], 0.5 * pair_energies)
             results["energies"] = atom_energies

-        if self._compute_forces or self._compute_stress:
+        if self.compute_forces or self.compute_stress:
             # Calculate pair forces
             pair_forces = soft_sphere_pair_force(
                 distances, sigma=self.sigma, epsilon=self.epsilon, alpha=self.alpha
@@ -212,7 +212,7 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]:
             # Project scalar forces onto displacement vectors
             force_vectors = (pair_forces / distances)[:, None] * dr_vec

-            if self._compute_forces:
+            if self.compute_forces:
                 # Compute atomic forces by accumulating pair contributions
                 forces = torch.zeros_like(positions)
                 # Add force contributions (f_ij on j, -f_ij on i)
@@ -220,7 +220,7 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]:
                 forces.index_add_(0, mapping[1], -force_vectors)
                 results["forces"] = forces

-            if self._compute_stress and cell is not None:
+            if self.compute_stress and cell is not None:
                 # Compute stress tensor using virial formula
                 stress_per_pair = torch.einsum("...i,...j->...ij", dr_vec, force_vectors)
                 volume = torch.abs(torch.linalg.det(cell))
@@ -231,8 +231,8 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]:
                 # Compute per-atom stress contributions
                 atom_stresses = torch.zeros(
                     (positions.shape[0], 3, 3),
-                    dtype=self._dtype,
-                    device=self._device,
+                    dtype=self.dtype,
+                    device=self.device,
                 )
                 atom_stresses.index_add_(0, mapping[0], -0.5 * stress_per_pair)
                 atom_stresses.index_add_(0, mapping[1], -0.5 * stress_per_pair)
@@ -244,7 +244,7 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]:


 # TODO: Standardize the interface for multi-species models
-class UnbatchedSoftSphereMultiModel(torch.nn.Module):
+class UnbatchedSoftSphereMultiModel(torch.nn.Module, ModelInterface):
     """Calculator for soft sphere potential with multiple atomic species.

     This model implements a multi-species soft sphere potential where the interaction
@@ -462,7 +462,7 @@ def forward(
             atom_energies.index_add_(0, mapping[1], 0.5 * pair_energies)
             results["energies"] = atom_energies

-        if self._compute_forces or self._compute_stress:
+        if self.compute_forces or self.compute_stress:
             # Calculate pair forces
             pair_forces = soft_sphere_pair_force(
                 distances, sigma=pair_sigmas, epsilon=pair_epsilons, alpha=pair_alphas
@@ -471,7 +471,7 @@ def forward(
             # Project scalar forces onto displacement vectors
             force_vectors = (pair_forces / distances)[:, None] * dr_vec

-            if self._compute_forces:
+            if self.compute_forces:
                 # Compute atomic forces by accumulating pair contributions
                 forces = torch.zeros_like(positions)
                 # Add force contributions (f_ij on j, -f_ij on i)
@@ -479,7 +479,7 @@ def forward(
                 forces.index_add_(0, mapping[1], -force_vectors)
                 results["forces"] = forces

-            if self._compute_stress and cell is not None:
+            if self.compute_stress and cell is not None:
                 # Compute stress tensor using virial formula
                 stress_per_pair = torch.einsum("...i,...j->...ij", dr_vec, force_vectors)
                 volume = torch.abs(torch.linalg.det(cell))
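With ModelInterface mixed into UnbatchedSoftSphereMultiModel, the multi-species model now exposes the same public compute_forces/compute_stress properties as the other models touched here, so generic driver code can check capabilities uniformly. A short sketch (the import path for ModelInterface is an assumption):

    from torch_sim.models.interface import ModelInterface  # assumed module path
    from torch_sim.unbatched.models.soft_sphere import UnbatchedSoftSphereMultiModel

    assert issubclass(UnbatchedSoftSphereMultiModel, ModelInterface)

    def wants_stress(model: ModelInterface) -> bool:
        # Uniform capability check across batched and unbatched models.
        return model.compute_stress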