Merged
13 changes: 13 additions & 0 deletions .github/workflows/test.yml
@@ -88,6 +88,19 @@ jobs:
      - name: Set up uv
        uses: astral-sh/setup-uv@v2

      - name: Install HuggingFace Hub CLI
        run: uv pip install huggingface_hub --system

      - name: HuggingFace Hub Login
        env:
          HF_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }}
        run: |
          if [ -n "$HF_TOKEN" ]; then
            huggingface-cli login --token "$HF_TOKEN"
          else
            echo "HF_TOKEN is not set. Skipping login."
          fi

      - name: Install fairchem repository and dependencies
        if: ${{ matrix.model.name == 'fairchem' }}
        run: |
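For local debugging, the guarded login that this workflow step performs with huggingface-cli can be reproduced through the huggingface_hub Python API. A minimal sketch; it assumes HF_TOKEN is exported in the environment, mirroring the secret above:

import os

from huggingface_hub import login

# Log in only when a token is present, mirroring the workflow's guard.
token = os.environ.get("HF_TOKEN")
if token:
    login(token=token)  # non-interactive equivalent of `huggingface-cli login --token`
else:
    print("HF_TOKEN is not set. Skipping login.")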
83 changes: 59 additions & 24 deletions tests/models/test_fairchem.py
@@ -1,3 +1,5 @@
import os

import pytest
import torch

@@ -11,6 +13,7 @@
try:
from fairchem.core import OCPCalculator
from fairchem.core.models.model_registry import model_name_to_local_file
from huggingface_hub.utils._auth import get_token

from torch_sim.models.fairchem import FairChemModel

@@ -19,55 +22,87 @@


@pytest.fixture(scope="session")
def model_path(tmp_path_factory: pytest.TempPathFactory) -> str:
def model_path_oc20(tmp_path_factory: pytest.TempPathFactory) -> str:
tmp_path = tmp_path_factory.mktemp("fairchem_checkpoints")
return model_name_to_local_file(
"EquiformerV2-31M-S2EF-OC20-All+MD", local_cache=str(tmp_path)
model_name = "EquiformerV2-31M-S2EF-OC20-All+MD"
return model_name_to_local_file(model_name, local_cache=str(tmp_path))


@pytest.fixture
def eqv2_oc20_model_pbc(model_path_oc20: str, device: torch.device) -> FairChemModel:
cpu = device.type == "cpu"
return FairChemModel(
model=model_path_oc20,
cpu=cpu,
seed=0,
pbc=True,
)


@pytest.fixture
def fairchem_model(model_path: str, device: torch.device) -> FairChemModel:
def eqv2_oc20_model_non_pbc(model_path_oc20: str, device: torch.device) -> FairChemModel:
cpu = device.type == "cpu"
return FairChemModel(
model=model_path,
model=model_path_oc20,
cpu=cpu,
seed=0,
pbc=False,
)


if get_token():

@pytest.fixture(scope="session")
def model_path_omat24(tmp_path_factory: pytest.TempPathFactory) -> str:
tmp_path = tmp_path_factory.mktemp("fairchem_checkpoints")
model_name = "EquiformerV2-31M-OMAT24-MP-sAlex"
return model_name_to_local_file(model_name, local_cache=str(tmp_path))

@pytest.fixture
def eqv2_omat24_model_pbc(
model_path_omat24: str, device: torch.device
) -> FairChemModel:
cpu = device.type == "cpu"
return FairChemModel(
model=model_path_omat24,
cpu=cpu,
seed=0,
pbc=True,
)


@pytest.fixture
def ocp_calculator(model_path: str) -> OCPCalculator:
return OCPCalculator(checkpoint_path=model_path, cpu=False, seed=0)
def ocp_calculator(model_path_oc20: str) -> OCPCalculator:
return OCPCalculator(checkpoint_path=model_path_oc20, cpu=False, seed=0)


test_fairchem_ocp_consistency = make_model_calculator_consistency_test(
test_fairchem_ocp_consistency_pbc = make_model_calculator_consistency_test(
test_name="fairchem_ocp",
model_fixture_name="fairchem_model",
model_fixture_name="eqv2_oc20_model_pbc",
calculator_fixture_name="ocp_calculator",
sim_state_names=consistency_test_simstate_fixtures,
rtol=5e-4, # NOTE: fairchem doesn't pass at the 1e-5 level used for other models
sim_state_names=consistency_test_simstate_fixtures[:-1],
rtol=5e-4, # NOTE: EqV2 doesn't pass at the 1e-5 level used for other models
atol=5e-4,
)

test_fairchem_non_pbc_benzene = make_model_calculator_consistency_test(
test_name="fairchem_non_pbc_benzene",
model_fixture_name="eqv2_oc20_model_non_pbc",
calculator_fixture_name="ocp_calculator",
sim_state_names=["benzene_sim_state"],
rtol=5e-4, # NOTE: EqV2 doesn't pass at the 1e-5 level used for other models
atol=5e-4,
)

# fairchem batching is broken on CPU; do not replicate this skipping
# logic in other models' tests
# @pytest.mark.skipif(
# not torch.cuda.is_available(),
# reason="Batching does not work properly on CPU for FAIRchem",
# )
# def test_validate_model_outputs(
# fairchem_model: FairChemModel, device: torch.device
# ) -> None:
# validate_model_outputs(fairchem_model, device, torch.float32)

# Skip this test due to issues with how the older models
# handled supercells (see the related issue: https://github.com/FAIR-Chem/fairchem/issues/428)

test_fairchem_ocp_model_outputs = pytest.mark.skipif(
not torch.cuda.is_available(),
reason="Batching does not work properly on CPU for FAIRchem",
os.environ.get("HF_TOKEN") is None,
reason="Issues in graph construction of older models",
)(
make_validate_model_outputs_test(
model_fixture_name="fairchem_model",
model_fixture_name="eqv2_omat24_model_pbc",
)
)
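Wrapping a factory-generated test in pytest.mark.skipif, as done above, works because applying a mark to a function returns that function with the mark attached. A self-contained sketch of the pattern (the test body is hypothetical, not taken from this suite):

import os

import pytest


def make_dummy_test():
    # stand-in for make_validate_model_outputs_test
    def test_dummy() -> None:
        assert True

    return test_dummy


# the mark decorator is applied to the already-built test function
test_gated_dummy = pytest.mark.skipif(
    os.environ.get("HF_TOKEN") is None,
    reason="requires HuggingFace authentication",
)(make_dummy_test())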
36 changes: 25 additions & 11 deletions torch_sim/models/fairchem.py
@@ -76,7 +76,6 @@

Attributes:
neighbor_list_fn (Callable | None): Function to compute neighbor lists
r_max (float): Maximum cutoff radius for atomic interactions in Ångström
config (dict): Complete model configuration dictionary
trainer: FairChem trainer object that contains the model
data_object (Batch): Data object containing system information
@@ -108,9 +107,10 @@
trainer: str | None = None,
cpu: bool = False,
seed: int | None = None,
r_max: float | None = None, # noqa: ARG002
dtype: torch.dtype | None = None,
compute_stress: bool = False,
pbc: bool = True,
disable_amp: bool = True,
) -> None:
"""Initialize the FairChemModel with specified configuration.

@@ -128,10 +128,10 @@
trainer (str | None): Name of trainer class to use
cpu (bool): Whether to use CPU instead of GPU for computation
seed (int | None): Random seed for reproducibility
r_max (float | None): Maximum cutoff radius (overrides model default)
dtype (torch.dtype | None): Data type to use for computation
compute_stress (bool): Whether to compute stress tensor

pbc (bool): Whether to use periodic boundary conditions
disable_amp (bool): Whether to disable automatic mixed precision (AMP)

Raises:
RuntimeError: If both model_name and model are specified
NotImplementedError: If local_cache is not set when model_name is used
@@ -150,6 +150,7 @@
self._compute_stress = compute_stress
self._compute_forces = True
self._memory_scales_with = "n_atoms"
self.pbc = pbc

if model_name is not None:
if model is not None:
@@ -215,6 +216,7 @@
)

if "backbone" in config["model"]:
config["model"]["backbone"]["use_pbc"] = pbc
config["model"]["backbone"]["use_pbc_single"] = False
if dtype is not None:
try:
@@ -224,14 +226,19 @@
{"dtype": _DTYPE_DICT[dtype]}
)
except KeyError:
print("dtype not found in backbone, using default float32")
print(
"WARNING: dtype not found in backbone, using default model dtype"
)
else:
config["model"]["use_pbc"] = pbc
config["model"]["use_pbc_single"] = False
if dtype is not None:
try:
config["model"].update({"dtype": _DTYPE_DICT[dtype]})
except KeyError:
print("dtype not found in backbone, using default dtype")
print(
"WARNING: dtype not found in backbone, using default model dtype"
)

### backwards compatibility with OCP v<2.0
config = update_config(config)
@@ -257,8 +264,6 @@
inference_only=True,
)

self.trainer.model = self.trainer.model.eval()

if dtype is not None:
# Convert model parameters to specified dtype
self.trainer.model = self.trainer.model.to(dtype=self.dtype)
@@ -275,6 +280,9 @@
else:
self.trainer.set_seed(seed)

if disable_amp:
self.trainer.scaler = None

self.implemented_properties = list(self.config["outputs"])

self._device = self.trainer.device
@@ -335,6 +343,12 @@
if state.batch is None:
state.batch = torch.zeros(state.positions.shape[0], dtype=torch.int)

if self.pbc != state.pbc:
raise ValueError(
"PBC mismatch between model and state. "
"For FairChemModel PBC needs to be defined in the model class."
)

natoms = torch.bincount(state.batch)
pbc = torch.tensor(
[state.pbc, state.pbc, state.pbc] * len(natoms), dtype=torch.bool
@@ -350,9 +364,9 @@
pbc=pbc,
)

if self._dtype is not None:
self.data_object.pos = self.data_object.pos.to(self._dtype)
self.data_object.cell = self.data_object.cell.to(self._dtype)
if self.dtype is not None:
self.data_object.pos = self.data_object.pos.to(self.dtype)
self.data_object.cell = self.data_object.cell.to(self.dtype)

predictions = self.trainer.predict(
self.data_object, per_image=False, disable_tqdm=True
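Since PBC is now fixed at construction and checked against each incoming state, a typical instantiation under the new signature looks roughly like this (the checkpoint path is hypothetical; the keyword names are taken from the diff above):

import torch

from torch_sim.models.fairchem import FairChemModel

model = FairChemModel(
    model="/path/to/checkpoint.pt",  # hypothetical local checkpoint
    cpu=not torch.cuda.is_available(),
    seed=0,
    pbc=True,  # must match state.pbc, or forward() raises ValueError
    disable_amp=True,  # clears trainer.scaler so inference runs at full precision
)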
6 changes: 3 additions & 3 deletions torch_sim/models/graphpes.py
@@ -154,15 +154,15 @@ def __init__(
model if isinstance(model, GraphPESModel) else load_model(model) # type: ignore[arg-type]
),
)
self._gp_model = _model.to(device=self._device, dtype=self._dtype)
self._gp_model = _model.to(device=self.device, dtype=self.dtype)

self._compute_forces = compute_forces
self._compute_stress = compute_stress

self._properties: list[PropertyKey] = ["energy"]
if self._compute_forces:
if self.compute_forces:
self._properties.append("forces")
if self._compute_stress:
if self.compute_stress:
self._properties.append("stress")

if self._gp_model.cutoff.item() < 0.5:
22 changes: 10 additions & 12 deletions torch_sim/models/lennard_jones.py
@@ -138,11 +138,9 @@
self.use_neighbor_list = use_neighbor_list

# Convert parameters to tensors
self.sigma = torch.tensor(sigma, dtype=dtype, device=self._device)
self.cutoff = torch.tensor(
cutoff or 2.5 * sigma, dtype=dtype, device=self._device
)
self.epsilon = torch.tensor(epsilon, dtype=dtype, device=self._device)
self.sigma = torch.tensor(sigma, dtype=dtype, device=self.device)
self.cutoff = torch.tensor(cutoff or 2.5 * sigma, dtype=dtype, device=self.device)
self.epsilon = torch.tensor(epsilon, dtype=dtype, device=self.device)

def unbatched_forward(
self,
@@ -209,7 +207,7 @@
pbc=pbc,
)
# Mask out self-interactions
mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self._device)
mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self.device)
distances = distances.masked_fill(mask, float("inf"))
# Apply cutoff
mask = distances < self.cutoff
@@ -233,14 +231,14 @@

if self.per_atom_energies:
atom_energies = torch.zeros(
positions.shape[0], dtype=self._dtype, device=self._device
positions.shape[0], dtype=self.dtype, device=self.device
)
# Each atom gets half of the pair energy
atom_energies.index_add_(0, mapping[0], 0.5 * pair_energies)
atom_energies.index_add_(0, mapping[1], 0.5 * pair_energies)
results["energies"] = atom_energies

if self._compute_forces or self._compute_stress:
if self.compute_forces or self.compute_stress:
# Calculate forces and apply cutoff
pair_forces = lennard_jones_pair_force(
distances, sigma=self.sigma, epsilon=self.epsilon
@@ -250,15 +248,15 @@
# Project forces along displacement vectors
force_vectors = (pair_forces / distances)[:, None] * dr_vec

if self._compute_forces:
if self.compute_forces:
# Initialize forces tensor
forces = torch.zeros_like(positions)
# Add force contributions (f_ij on i, -f_ij on j)
forces.index_add_(0, mapping[0], -force_vectors)
forces.index_add_(0, mapping[1], force_vectors)
results["forces"] = forces

if self._compute_stress and cell is not None:
if self.compute_stress and cell is not None:
# Compute stress tensor
stress_per_pair = torch.einsum("...i,...j->...ij", dr_vec, force_vectors)
volume = torch.abs(torch.linalg.det(cell))
@@ -268,8 +266,8 @@
if self.per_atom_stresses:
atom_stresses = torch.zeros(
(state.positions.shape[0], 3, 3),
dtype=self._dtype,
device=self._device,
dtype=self.dtype,
device=self.device,
)
atom_stresses.index_add_(0, mapping[0], -0.5 * stress_per_pair)
atom_stresses.index_add_(0, mapping[1], -0.5 * stress_per_pair)
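The pair force used above (lennard_jones_pair_force) follows the standard 12-6 Lennard-Jones form. A self-contained sketch of the underlying math, assuming the textbook convention rather than torch_sim's exact implementation:

import torch


def lj_pair_energy(r: torch.Tensor, sigma: float, epsilon: float) -> torch.Tensor:
    # E(r) = 4 * eps * [(sigma/r)^12 - (sigma/r)^6]
    sr6 = (sigma / r) ** 6
    return 4 * epsilon * (sr6**2 - sr6)


def lj_pair_force(r: torch.Tensor, sigma: float, epsilon: float) -> torch.Tensor:
    # F(r) = -dE/dr = 24 * eps * [2 * (sigma/r)^12 - (sigma/r)^6] / r
    # positive for r below the minimum at 2^(1/6) * sigma (repulsive regime)
    sr6 = (sigma / r) ** 6
    return 24 * epsilon * (2 * sr6**2 - sr6) / r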
12 changes: 6 additions & 6 deletions torch_sim/models/mace.py
@@ -160,8 +160,8 @@ def __init__(
self.model = model.to(self._device)
self.model = self.model.eval()

if self._dtype is not None:
self.model = self.model.to(dtype=self._dtype)
if self.dtype is not None:
self.model = self.model.to(dtype=self.dtype)

if enable_cueq:
print("Converting models to CuEq for acceleration")
@@ -334,8 +334,8 @@ def forward( # noqa: C901
unit_shifts=unit_shifts,
shifts=shifts_list,
),
compute_force=self._compute_forces,
compute_stress=self._compute_stress,
compute_force=self.compute_forces,
compute_stress=self.compute_stress,
)

results = {}
@@ -348,13 +348,13 @@
results["energy"] = torch.zeros(self.n_systems, device=self.device)

# Process forces
if self._compute_forces:
if self.compute_forces:
forces = out["forces"]
if forces is not None:
results["forces"] = forces.detach()

# Process stress
if self._compute_stress:
if self.compute_stress:
stress = out["stress"]
if stress is not None:
results["stress"] = stress.detach()
4 changes: 2 additions & 2 deletions torch_sim/models/mattersim.py
@@ -85,8 +85,8 @@ def __init__(
self.model = model.to(self._device)
self.model = self.model.eval()

if self._dtype is not None:
self.model = self.model.to(dtype=self._dtype)
if self.dtype is not None:
self.model = self.model.to(dtype=self.dtype)

model_args = self.model.model.model_args
self.two_body_cutoff = model_args["cutoff"]
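A change that recurs across the model files in this PR is swapping private attribute reads (self._device, self._dtype, self._compute_forces, self._compute_stress) for their public counterparts. A minimal sketch of the property pattern this implies on the shared model interface (class body hypothetical, illustrating the idea rather than torch_sim's actual base class):

import torch


class ModelInterface:
    # subclasses assign the private fields; callers read the public properties
    _device: torch.device
    _dtype: torch.dtype
    _compute_forces: bool = False
    _compute_stress: bool = False

    @property
    def device(self) -> torch.device:
        return self._device

    @property
    def dtype(self) -> torch.dtype:
        return self._dtype

    @property
    def compute_forces(self) -> bool:
        return self._compute_forces

    @property
    def compute_stress(self) -> bool:
        return self._compute_stress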