Skip to content
Merged
151 changes: 98 additions & 53 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import TYPE_CHECKING, Any

import numpy as np
import pytest
import torch
from ase import Atoms
Expand All @@ -9,6 +10,7 @@
from pymatgen.core import Structure

import torch_sim as ts
from torch_sim.io import atoms_to_state
from torch_sim.models.lennard_jones import LennardJonesModel
from torch_sim.models.mace import MaceModel, MaceUrls
from torch_sim.state import concatenate_states
Expand All @@ -29,6 +31,66 @@ def dtype() -> torch.dtype:
return torch.float64


@pytest.fixture
def unbatched_lj_model(
device: torch.device, dtype: torch.dtype
) -> UnbatchedLennardJonesModel:
"""Create a Lennard-Jones model with reasonable parameters for Ar."""
return UnbatchedLennardJonesModel(
use_neighbor_list=True,
sigma=3.405,
epsilon=0.0104,
device=device,
dtype=dtype,
compute_forces=True,
compute_stress=True,
cutoff=2.5 * 3.405,
)


@pytest.fixture
def lj_model(device: torch.device, dtype: torch.dtype) -> LennardJonesModel:
"""Create a Lennard-Jones model with reasonable parameters for Ar."""
return LennardJonesModel(
use_neighbor_list=True,
sigma=3.405,
epsilon=0.0104,
device=device,
dtype=dtype,
compute_forces=True,
compute_stress=True,
cutoff=2.5 * 3.405,
)


@pytest.fixture
def ase_mace_mpa() -> "MACECalculator":
"""Provides an ASE MACECalculator instance using mace_mp."""
from mace.calculators.foundations_models import mace_mp

# Ensure dtype matches the one used in the torchsim fixture (float64)
return mace_mp(model=MaceUrls.mace_mp_small, default_dtype="float64")


@pytest.fixture
def torchsim_mace_mpa() -> MaceModel:
"""Provides a MACE MP model instance for the optimizer tests."""
from mace.calculators.foundations_models import mace_mp

# Use float64 for potentially higher precision needed in optimization
dtype = getattr(torch, dtype_str := "float64")
raw_mace = mace_mp(
model=MaceUrls.mace_mp_small, return_raw_model=True, default_dtype=dtype_str
)
return MaceModel(
model=raw_mace,
device="cpu",
dtype=dtype,
compute_forces=True,
compute_stress=True,
)


@pytest.fixture
def ar_atoms() -> Atoms:
"""Create a face-centered cubic (FCC) Argon structure."""
Expand Down Expand Up @@ -293,60 +355,43 @@ def mixed_double_sim_state(


@pytest.fixture
def unbatched_lj_model(
device: torch.device, dtype: torch.dtype
) -> UnbatchedLennardJonesModel:
"""Create a Lennard-Jones model with reasonable parameters for Ar."""
return UnbatchedLennardJonesModel(
use_neighbor_list=True,
sigma=3.405,
epsilon=0.0104,
device=device,
dtype=dtype,
compute_forces=True,
compute_stress=True,
cutoff=2.5 * 3.405,
)


@pytest.fixture
def lj_model(device: torch.device, dtype: torch.dtype) -> LennardJonesModel:
"""Create a Lennard-Jones model with reasonable parameters for Ar."""
return LennardJonesModel(
use_neighbor_list=True,
sigma=3.405,
epsilon=0.0104,
device=device,
dtype=dtype,
compute_forces=True,
compute_stress=True,
cutoff=2.5 * 3.405,
def osn2_sim_state(torchsim_mace_mpa: MaceModel) -> ts.state.SimState:
"""Provides an initial SimState for rhombohedral OsN2."""
# For pymatgen Structure initialization
from pymatgen.core import Lattice, Structure

a = 3.211996
lattice = Lattice.from_parameters(a, a, a, 60, 60, 60)
species = ["Os", "N"]
frac_coords = [[0.75, 0.7501, -0.25], [0, 0, 0]] # Slightly perturbed
structure = Structure(lattice, species, frac_coords, coords_are_cartesian=False)
return ts.initialize_state(
structure, dtype=torchsim_mace_mpa.dtype, device=torchsim_mace_mpa.device
)


@pytest.fixture
def ase_mace_mpa() -> "MACECalculator":
"""Provides an ASE MACECalculator instance using mace_mp."""
from mace.calculators.foundations_models import mace_mp

# Ensure dtype matches the one used in the torchsim fixture (float64)
return mace_mp(model=MaceUrls.mace_mp_small, default_dtype="float64")


@pytest.fixture
def torchsim_mace_mpa() -> MaceModel:
"""Provides a MACE MP model instance for the optimizer tests."""
from mace.calculators.foundations_models import mace_mp

# Use float64 for potentially higher precision needed in optimization
dtype = getattr(torch, dtype_str := "float64")
raw_mace = mace_mp(
model=MaceUrls.mace_mp_small, return_raw_model=True, default_dtype=dtype_str
)
return MaceModel(
model=raw_mace,
device="cpu",
dtype=dtype,
compute_forces=True,
compute_stress=True,
)
def distorted_fcc_al_conventional_sim_state(
torchsim_mace_mpa: MaceModel,
) -> ts.state.SimState:
"""Initial SimState for a slightly distorted FCC Al conventional cell (4 atoms)."""
# Create a standard 4-atom conventional FCC Al cell
atoms_fcc = bulk("Al", crystalstructure="fcc", a=4.05, cubic=True)

# Define a small triclinic strain matrix (deviations from identity)
strain_matrix = np.array([[1.0, 0.05, -0.03], [0.04, 1.0, 0.06], [-0.02, 0.03, 1.0]])

original_cell = atoms_fcc.get_cell()
new_cell = original_cell @ strain_matrix.T # Apply strain
atoms_fcc.set_cell(new_cell, scale_atoms=True)

# Slightly perturb atomic positions to break perfect symmetry after strain
positions = atoms_fcc.get_positions()
np_rng = np.random.default_rng(seed=42)
positions += np_rng.normal(scale=0.01, size=positions.shape)
atoms_fcc.set_positions(positions)

dtype = torchsim_mace_mpa.dtype
device = torchsim_mace_mpa.device
# Convert the ASE Atoms object to SimState (will be a single batch with 4 atoms)
return atoms_to_state(atoms_fcc, device=device, dtype=dtype)
Loading
Loading