diff --git a/tests/conftest.py b/tests/conftest.py index 699e9505..906eca08 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,6 @@ from typing import TYPE_CHECKING, Any +import numpy as np import pytest import torch from ase import Atoms @@ -9,6 +10,7 @@ from pymatgen.core import Structure import torch_sim as ts +from torch_sim.io import atoms_to_state from torch_sim.models.lennard_jones import LennardJonesModel from torch_sim.models.mace import MaceModel, MaceUrls from torch_sim.state import concatenate_states @@ -29,6 +31,66 @@ def dtype() -> torch.dtype: return torch.float64 +@pytest.fixture +def unbatched_lj_model( + device: torch.device, dtype: torch.dtype +) -> UnbatchedLennardJonesModel: + """Create a Lennard-Jones model with reasonable parameters for Ar.""" + return UnbatchedLennardJonesModel( + use_neighbor_list=True, + sigma=3.405, + epsilon=0.0104, + device=device, + dtype=dtype, + compute_forces=True, + compute_stress=True, + cutoff=2.5 * 3.405, + ) + + +@pytest.fixture +def lj_model(device: torch.device, dtype: torch.dtype) -> LennardJonesModel: + """Create a Lennard-Jones model with reasonable parameters for Ar.""" + return LennardJonesModel( + use_neighbor_list=True, + sigma=3.405, + epsilon=0.0104, + device=device, + dtype=dtype, + compute_forces=True, + compute_stress=True, + cutoff=2.5 * 3.405, + ) + + +@pytest.fixture +def ase_mace_mpa() -> "MACECalculator": + """Provides an ASE MACECalculator instance using mace_mp.""" + from mace.calculators.foundations_models import mace_mp + + # Ensure dtype matches the one used in the torchsim fixture (float64) + return mace_mp(model=MaceUrls.mace_mp_small, default_dtype="float64") + + +@pytest.fixture +def torchsim_mace_mpa() -> MaceModel: + """Provides a MACE MP model instance for the optimizer tests.""" + from mace.calculators.foundations_models import mace_mp + + # Use float64 for potentially higher precision needed in optimization + dtype = getattr(torch, dtype_str := "float64") + raw_mace = mace_mp( + model=MaceUrls.mace_mp_small, return_raw_model=True, default_dtype=dtype_str + ) + return MaceModel( + model=raw_mace, + device="cpu", + dtype=dtype, + compute_forces=True, + compute_stress=True, + ) + + @pytest.fixture def ar_atoms() -> Atoms: """Create a face-centered cubic (FCC) Argon structure.""" @@ -293,60 +355,43 @@ def mixed_double_sim_state( @pytest.fixture -def unbatched_lj_model( - device: torch.device, dtype: torch.dtype -) -> UnbatchedLennardJonesModel: - """Create a Lennard-Jones model with reasonable parameters for Ar.""" - return UnbatchedLennardJonesModel( - use_neighbor_list=True, - sigma=3.405, - epsilon=0.0104, - device=device, - dtype=dtype, - compute_forces=True, - compute_stress=True, - cutoff=2.5 * 3.405, - ) - - -@pytest.fixture -def lj_model(device: torch.device, dtype: torch.dtype) -> LennardJonesModel: - """Create a Lennard-Jones model with reasonable parameters for Ar.""" - return LennardJonesModel( - use_neighbor_list=True, - sigma=3.405, - epsilon=0.0104, - device=device, - dtype=dtype, - compute_forces=True, - compute_stress=True, - cutoff=2.5 * 3.405, +def osn2_sim_state(torchsim_mace_mpa: MaceModel) -> ts.state.SimState: + """Provides an initial SimState for rhombohedral OsN2.""" + # For pymatgen Structure initialization + from pymatgen.core import Lattice, Structure + + a = 3.211996 + lattice = Lattice.from_parameters(a, a, a, 60, 60, 60) + species = ["Os", "N"] + frac_coords = [[0.75, 0.7501, -0.25], [0, 0, 0]] # Slightly perturbed + structure = Structure(lattice, species, frac_coords, coords_are_cartesian=False) + return ts.initialize_state( + structure, dtype=torchsim_mace_mpa.dtype, device=torchsim_mace_mpa.device ) @pytest.fixture -def ase_mace_mpa() -> "MACECalculator": - """Provides an ASE MACECalculator instance using mace_mp.""" - from mace.calculators.foundations_models import mace_mp - - # Ensure dtype matches the one used in the torchsim fixture (float64) - return mace_mp(model=MaceUrls.mace_mp_small, default_dtype="float64") - - -@pytest.fixture -def torchsim_mace_mpa() -> MaceModel: - """Provides a MACE MP model instance for the optimizer tests.""" - from mace.calculators.foundations_models import mace_mp - - # Use float64 for potentially higher precision needed in optimization - dtype = getattr(torch, dtype_str := "float64") - raw_mace = mace_mp( - model=MaceUrls.mace_mp_small, return_raw_model=True, default_dtype=dtype_str - ) - return MaceModel( - model=raw_mace, - device="cpu", - dtype=dtype, - compute_forces=True, - compute_stress=True, - ) +def distorted_fcc_al_conventional_sim_state( + torchsim_mace_mpa: MaceModel, +) -> ts.state.SimState: + """Initial SimState for a slightly distorted FCC Al conventional cell (4 atoms).""" + # Create a standard 4-atom conventional FCC Al cell + atoms_fcc = bulk("Al", crystalstructure="fcc", a=4.05, cubic=True) + + # Define a small triclinic strain matrix (deviations from identity) + strain_matrix = np.array([[1.0, 0.05, -0.03], [0.04, 1.0, 0.06], [-0.02, 0.03, 1.0]]) + + original_cell = atoms_fcc.get_cell() + new_cell = original_cell @ strain_matrix.T # Apply strain + atoms_fcc.set_cell(new_cell, scale_atoms=True) + + # Slightly perturb atomic positions to break perfect symmetry after strain + positions = atoms_fcc.get_positions() + np_rng = np.random.default_rng(seed=42) + positions += np_rng.normal(scale=0.01, size=positions.shape) + atoms_fcc.set_positions(positions) + + dtype = torchsim_mace_mpa.dtype + device = torchsim_mace_mpa.device + # Convert the ASE Atoms object to SimState (will be a single batch with 4 atoms) + return atoms_to_state(atoms_fcc, device=device, dtype=dtype) diff --git a/tests/test_optimizers_vs_ase.py b/tests/test_optimizers_vs_ase.py index 42d9978b..e52a82d1 100644 --- a/tests/test_optimizers_vs_ase.py +++ b/tests/test_optimizers_vs_ase.py @@ -1,126 +1,229 @@ import copy -import functools +from typing import TYPE_CHECKING, Any import pytest import torch -from ase.filters import FrechetCellFilter +from ase.filters import FrechetCellFilter, UnitCellFilter from ase.optimize import FIRE import torch_sim as ts from torch_sim.io import state_to_atoms from torch_sim.models.mace import MaceModel -from torch_sim.optimizers import frechet_cell_fire +from torch_sim.optimizers import frechet_cell_fire, unit_cell_fire -try: +if TYPE_CHECKING: from mace.calculators import MACECalculator -except ImportError: - MACECalculator = None -@pytest.mark.skipif(MACECalculator is None, reason="MACECalculator not installed") -def test_torchsim_frechet_cell_fire_vs_ase_mace( - rattled_sio2_sim_state: ts.state.SimState, +def _run_and_compare_optimizers( + initial_sim_state_fixture: ts.state.SimState, torchsim_mace_mpa: MaceModel, - ase_mace_mpa: MACECalculator, + ase_mace_mpa: "MACECalculator", + torch_sim_optimizer_type: str, + ase_filter_class: Any, + checkpoints: list[int], + force_tol: float, + tolerances: dict[str, float], + test_id_prefix: str, ) -> None: - """Compare torch-sim's Frechet Cell FIRE optimizer with ASE's FIRE + FrechetCellFilter - using MACE-MPA-0. - - This test ensures that the custom Frechet Cell FIRE implementation behaves comparably - to the established ASE equivalent when using a MACE force field. - It checks for consistency in final energies, forces, positions, and cell parameters. - """ - # Use float64 for consistency with the MACE model fixture and for precision + """Run and compare optimizations between torch-sim and ASE.""" + pytest.importorskip("mace") dtype = torch.float64 - device = torchsim_mace_mpa.device # Use device from the model + device = torchsim_mace_mpa.device - # --- Setup Initial State with float64 --- - # Deepcopy to avoid modifying the fixture state for other tests - initial_state = copy.deepcopy(rattled_sio2_sim_state).to(dtype=dtype, device=device) - - # Ensure grads are enabled for both positions and cell for optimization - initial_state.positions = initial_state.positions.detach().requires_grad_( - requires_grad=True + ts_current_system_state = copy.deepcopy(initial_sim_state_fixture).to( + dtype=dtype, device=device ) - initial_state.cell = initial_state.cell.detach().requires_grad_(requires_grad=True) - - n_steps = 20 # Number of optimization steps - force_tol = 0.02 # Convergence criterion for forces - - # --- Run torch-sim Frechet Cell FIRE with MACE model --- - # Use functools.partial to set md_flavor for the frechet_cell_fire optimizer - torch_sim_optimizer = functools.partial(frechet_cell_fire, md_flavor="ase_fire") - - custom_opt_state = ts.optimize( - system=initial_state, - model=torchsim_mace_mpa, - optimizer=torch_sim_optimizer, - max_steps=n_steps, - convergence_fn=ts.generate_force_convergence_fn(force_tol=force_tol), + ts_current_system_state.positions = ( + ts_current_system_state.positions.detach().requires_grad_() ) - - # --- Setup ASE System with native MACE calculator --- - # Convert initial SimState to ASE Atoms object - ase_atoms = state_to_atoms(initial_state)[0] # state_to_atoms returns a list - ase_atoms.calc = ase_mace_mpa # Assign the MACE calculator - - # --- Run ASE FIRE with FrechetCellFilter --- - # Apply FrechetCellFilter for cell optimization - filtered_ase_atoms = FrechetCellFilter(ase_atoms) - ase_optimizer = FIRE(filtered_ase_atoms) - - # Run ASE optimization - ase_optimizer.run(fmax=force_tol, steps=n_steps) - - # --- Compare Results --- - final_custom_energy = custom_opt_state.energy.item() - final_custom_forces_max = torch.norm(custom_opt_state.forces, dim=-1).max().item() - final_custom_positions = custom_opt_state.positions.detach() - # Ensure cell is in row vector format and squeezed for comparison - final_custom_cell = custom_opt_state.row_vector_cell.squeeze(0).detach() - - final_ase_energy = ase_atoms.get_potential_energy() - ase_forces_raw = ase_atoms.get_forces() - if ase_forces_raw is not None: - final_ase_forces = torch.tensor(ase_forces_raw, device=device, dtype=dtype) - final_ase_forces_max = torch.norm(final_ase_forces, dim=-1).max().item() - else: - # Should not happen if calculator ran and produced forces - final_ase_forces_max = float("nan") - - final_ase_positions = torch.tensor( - ase_atoms.get_positions(), device=device, dtype=dtype - ) - final_ase_cell = torch.tensor(ase_atoms.get_cell(), device=device, dtype=dtype) - - # Compare energies (looser tolerance for ML potentials due to potential minor - # numerical differences) - energy_diff = abs(final_custom_energy - final_ase_energy) - assert energy_diff < 5e-2, ( - f"Final energies differ significantly after {n_steps} steps: " - f"torch-sim={final_custom_energy:.6f}, ASE={final_ase_energy:.6f}, " - f"Diff={energy_diff:.2e}" + ts_current_system_state.cell = ts_current_system_state.cell.detach().requires_grad_() + ts_optimizer_state = None + + optimizer_builders = { + "frechet": frechet_cell_fire, + "unit_cell": unit_cell_fire, + } + if torch_sim_optimizer_type not in optimizer_builders: + raise ValueError(f"Unknown torch_sim_optimizer_type: {torch_sim_optimizer_type}") + ts_optimizer_builder = optimizer_builders[torch_sim_optimizer_type] + + optimizer_callable_for_ts_optimize = lambda model, **_kwargs: ts_optimizer_builder( # noqa: E731 + model, md_flavor="ase_fire" ) - # Report forces for diagnostics - print( - f"Max Force ({n_steps} steps): torch-sim={final_custom_forces_max:.4f}, " - f"ASE={final_ase_forces_max:.4f}" - ) - - # Compare positions (average displacement, looser tolerance) - avg_displacement = ( - torch.norm(final_custom_positions - final_ase_positions, dim=-1).mean().item() - ) - assert avg_displacement < 1.0, ( - f"Final positions differ significantly (avg displacement: {avg_displacement:.4f})" - ) - - # Compare cell matrices (Frobenius norm, looser tolerance) - cell_diff = torch.norm(final_custom_cell - final_ase_cell).item() - assert cell_diff < 1.0, ( - f"Final cell matrices differ significantly (Frobenius norm: {cell_diff:.4f})" - f"\nTorch-sim Cell:\n{final_custom_cell}" - f"\nASE Cell:\n{final_ase_cell}" + ase_atoms_for_run = state_to_atoms( + copy.deepcopy(initial_sim_state_fixture).to(dtype=dtype, device=device) + )[0] + ase_atoms_for_run.calc = ase_mace_mpa + filtered_ase_atoms_for_run = ase_filter_class(ase_atoms_for_run) + ase_optimizer = FIRE(filtered_ase_atoms_for_run, logfile=None) + + last_checkpoint_step_count = 0 + convergence_fn = ts.generate_force_convergence_fn(force_tol=force_tol) + + for checkpoint_step in checkpoints: + steps_for_current_segment = checkpoint_step - last_checkpoint_step_count + + if steps_for_current_segment > 0: + # Ensure requires_grad is set for the input to ts.optimize + # ts.optimize is expected to return a state suitable for further optimization + # if optimizer_state is passed. + ts_current_system_state.positions = ( + ts_current_system_state.positions.detach().requires_grad_() + ) + ts_current_system_state.cell = ( + ts_current_system_state.cell.detach().requires_grad_() + ) + new_ts_state_and_optimizer_state = ts.optimize( + system=ts_current_system_state, + model=torchsim_mace_mpa, + optimizer=optimizer_callable_for_ts_optimize, + max_steps=steps_for_current_segment, + convergence_fn=convergence_fn, + optimizer_state=ts_optimizer_state, + ) + ts_current_system_state = new_ts_state_and_optimizer_state + ts_optimizer_state = new_ts_state_and_optimizer_state + + ase_optimizer.run(fmax=force_tol, steps=steps_for_current_segment) + + current_test_id = f"{test_id_prefix} (Step {checkpoint_step})" + + final_custom_energy = ts_current_system_state.energy.item() + final_custom_forces_max = ( + torch.norm(ts_current_system_state.forces, dim=-1).max().item() + ) + final_custom_positions = ts_current_system_state.positions.detach() + final_custom_cell = ts_current_system_state.row_vector_cell.squeeze(0).detach() + + final_ase_atoms = filtered_ase_atoms_for_run.atoms + final_ase_energy = final_ase_atoms.get_potential_energy() + ase_forces_raw = final_ase_atoms.get_forces() + final_ase_forces_max = torch.norm( + torch.tensor(ase_forces_raw, device=device, dtype=dtype), dim=-1 + ).max() + final_ase_positions = torch.tensor( + final_ase_atoms.get_positions(), device=device, dtype=dtype + ) + final_ase_cell = torch.tensor( + final_ase_atoms.get_cell(), device=device, dtype=dtype + ) + + energy_diff = abs(final_custom_energy - final_ase_energy) + assert energy_diff < tolerances["energy"], ( + f"{current_test_id}: Final energies differ significantly: " + f"torch-sim={final_custom_energy:.6f}, ASE={final_ase_energy:.6f}, " + f"Diff={energy_diff:.2e}" + ) + + avg_displacement = ( + torch.norm(final_custom_positions - final_ase_positions, dim=-1).mean().item() + ) + assert avg_displacement < tolerances["pos"], ( + f"{current_test_id}: Final positions differ ({avg_displacement=:.4f})" + ) + + cell_diff = torch.norm(final_custom_cell - final_ase_cell).item() + assert cell_diff < tolerances["cell"], ( + f"{current_test_id}: Final cell matrices differ (Frobenius norm: " + f"{cell_diff:.4f})\nTorch-sim Cell:\n{final_custom_cell}" + f"\nASE Cell:\n{final_ase_cell}" + ) + + force_max_diff = abs(final_custom_forces_max - final_ase_forces_max) + assert force_max_diff < tolerances["force_max"], ( + f"{current_test_id}: Max forces differ significantly: " + f"torch-sim={final_custom_forces_max:.4f}, ASE={final_ase_forces_max:.4f}, " + f"Diff={force_max_diff:.2e}" + ) + + last_checkpoint_step_count = checkpoint_step + + +@pytest.mark.parametrize( + ( + "sim_state_fixture_name", + "torch_sim_optimizer_type", + "ase_filter_class", + "checkpoints", + "force_tol", + "tolerances", + "test_id_prefix", + ), + [ + ( + "rattled_sio2_sim_state", + "frechet", + FrechetCellFilter, + [33, 66, 100], + 0.02, + {"energy": 1e-2, "pos": 1.5e-2, "cell": 1.8e-2, "force_max": 1.5e-1}, + "SiO2 (Frechet)", + ), + ( + "osn2_sim_state", + "frechet", + FrechetCellFilter, + [16, 33, 50], + 0.02, + {"energy": 1e-4, "pos": 1e-3, "cell": 1.8e-3, "force_max": 5e-2}, + "OsN2 (Frechet)", + ), + ( + "distorted_fcc_al_conventional_sim_state", + "frechet", + FrechetCellFilter, + [33, 66, 100], + 0.01, + {"energy": 1e-2, "pos": 5e-3, "cell": 2e-2, "force_max": 5e-2}, + "Triclinic Al (Frechet)", + ), + ( + "distorted_fcc_al_conventional_sim_state", + "unit_cell", + UnitCellFilter, + [33, 66, 100], + 0.01, + {"energy": 1e-2, "pos": 3e-2, "cell": 1e-1, "force_max": 5e-2}, + "Triclinic Al (UnitCell)", + ), + ( + "rattled_sio2_sim_state", + "unit_cell", + UnitCellFilter, + [33, 66, 100], + 0.02, + {"energy": 1.5e-2, "pos": 2.5e-2, "cell": 5e-2, "force_max": 0.25}, + "SiO2 (UnitCell)", + ), + ], +) +def test_optimizer_vs_ase_parametrized( + sim_state_fixture_name: str, + torch_sim_optimizer_type: str, + ase_filter_class: Any, + checkpoints: list[int], + force_tol: float, + tolerances: dict[str, float], + test_id_prefix: str, + torchsim_mace_mpa: MaceModel, + ase_mace_mpa: "MACECalculator", + request: pytest.FixtureRequest, +) -> None: + """Compare torch-sim optimizers with ASE FIRE and relevant filters at multiple + checkpoints.""" + initial_sim_state_fixture = request.getfixturevalue(sim_state_fixture_name) + + _run_and_compare_optimizers( + initial_sim_state_fixture=initial_sim_state_fixture, + torchsim_mace_mpa=torchsim_mace_mpa, + ase_mace_mpa=ase_mace_mpa, + torch_sim_optimizer_type=torch_sim_optimizer_type, + ase_filter_class=ase_filter_class, + checkpoints=checkpoints, + force_tol=force_tol, + tolerances=tolerances, + test_id_prefix=test_id_prefix, ) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index b16ab750..a4bbfa3c 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -514,8 +514,7 @@ def fire( f_alpha (float): Factor for mixing parameter decrease max_step (float): Maximum distance an atom can move per iteration (default value is 0.2). Only used when md_flavor='ase_fire'. - md_flavor (MdFlavor): Optimization flavor, either "vv_fire" or "ase_fire". - Default is "ase_fire". + md_flavor ("vv_fire" | "ase_fire"): Optimization flavor. Default is "ase_fire". Returns: tuple[Callable, Callable]: @@ -736,8 +735,7 @@ def unit_cell_fire( constant_volume (bool): Whether to maintain constant volume during optimization scalar_pressure (float): Applied external pressure in GPa max_step (float): Maximum allowed step size for ase_fire - md_flavor (MdFlavor): Optimization flavor, either "vv_fire" or "ase_fire". - Default is "ase_fire". + md_flavor ("vv_fire" | "ase_fire"): Optimization flavor. Default is "ase_fire". Returns: tuple: A pair of functions: @@ -1023,8 +1021,7 @@ def frechet_cell_fire( constant_volume (bool): Whether to maintain constant volume during optimization scalar_pressure (float): Applied external pressure in GPa max_step (float): Maximum allowed step size for ase_fire - md_flavor (MdFlavor): Optimization flavor, either "vv_fire" or "ase_fire". - Default is "ase_fire". + md_flavor ("vv_fire" | "ase_fire"): Optimization flavor. Default is "ase_fire". Returns: tuple: A pair of functions: @@ -1534,6 +1531,7 @@ def _ase_fire_step( # noqa: C901, PLR0915 mask_atom_max_step, max_step * dr_atom / (dr_norm_atom + eps), dr_atom ) + old_row_vector_cell: torch.Tensor | None = None if is_cell_optimization: assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) # Cell clamp to max_step (Frobenius norm) @@ -1545,7 +1543,12 @@ def _ase_fire_step( # noqa: C901, PLR0915 dr_cell, ) - # 7. Position / cell update + # 7. Position / cell update + # Store old cell for scaling atoms + # Ensure old_row_vector_cell is cloned before any modification to state.cell or + # state.row_vector_cell + old_row_vector_cell = state.row_vector_cell.clone() + state.positions = state.positions + dr_atom # F_new stores F_new for Frechet's ucf_cell_grad if needed @@ -1580,8 +1583,32 @@ def _ase_fire_step( # noqa: C901, PLR0915 F_new_scaled = current_F_scaled + dr_cell state.cell_positions = F_new_scaled # track the scaled deformation gradient F_new = F_new_scaled / (cell_factor_exp_mult + eps) # Division by (N,3,1) - new_cell = torch.bmm(state.reference_cell, F_new.transpose(-2, -1)) - state.cell = new_cell + # When state.cell is set, state.row_vector_cell is auto-updated + new_cell_column_vectors = torch.bmm( + state.reference_cell, F_new.transpose(-2, -1) + ) + state.cell = new_cell_column_vectors + + # Scale atomic positions according to cell change (mimicking scale_atoms=True) + if is_cell_optimization and old_row_vector_cell is not None: + current_new_row_vector_cell = state.row_vector_cell # This is A_new after update + + inv_old_cell_batch = torch.linalg.inv(old_row_vector_cell) + # Transform matrix T such that A_new = A_old @ T (for row vectors A) + # This means cartesian positions P_new_row = P_old_row @ T + transform_matrix_batch = torch.bmm( + inv_old_cell_batch, current_new_row_vector_cell + ) # Shape [N_batch, 3, 3] + + # Shape: [N_atoms, 3, 3] + atom_specific_transform = transform_matrix_batch[state.batch] + + # state.positions is [N_atoms, 3]. Unsqueeze to [N_atoms, 1, 3] for bmm + # Result of bmm will be [N_atoms, 1, 3], then squeeze + scaled_positions = torch.bmm( + state.positions.unsqueeze(1), atom_specific_transform + ).squeeze(1) + state.positions = scaled_positions # 8. Force / stress refresh & new cell forces results = model(state) @@ -1593,11 +1620,10 @@ def _ase_fire_step( # noqa: C901, PLR0915 state.stress = results["stress"] volumes = torch.linalg.det(state.cell).view(-1, 1, 1) if torch.any(volumes <= 0): - bad_idx = torch.where(volumes <= 0)[0] + bad_indices = torch.where(volumes <= 0)[0].tolist() print( f"WARNING: Non-positive volume(s) detected during _ase_fire_step: " - f"{volumes[bad_idx].tolist()} at indices {bad_idx.tolist()} " - f"(is_frechet={is_frechet})" + f"{volumes[bad_indices].tolist()} at {bad_indices=} ({is_frechet=})" ) # volumes = torch.clamp(volumes, min=eps) # Optional: for stability