diff --git a/.github/workflows/link-check.yml b/.github/workflows/link-check.yml index 2b57b818..25a0aa68 100644 --- a/.github/workflows/link-check.yml +++ b/.github/workflows/link-check.yml @@ -18,5 +18,3 @@ jobs: with: # ignore ipynb links since they're generated on the fly args: --exclude-path dist --exclude '\.ipynb$' --accept 100..=103,200..=299,403,429,500 -- ./**/*.{md,py,yml,json} - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py b/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py index 2fc90f3b..399c5a5e 100644 --- a/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py +++ b/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py @@ -171,7 +171,9 @@ def run_optimization_ts( # noqa: PLR0915 convergence_steps = torch.full( (total_structures,), -1, dtype=torch.long, device=device ) - convergence_fn = ts.generate_force_convergence_fn(force_tol=force_tol) + convergence_fn = ts.generate_force_convergence_fn( + force_tol=force_tol, include_cell_forces=ts_use_frechet + ) converged_tensor_global = torch.zeros( total_structures, dtype=torch.bool, device=device ) @@ -194,7 +196,7 @@ def run_optimization_ts( # noqa: PLR0915 current_indices_list, dtype=torch.long, device=device ) - steps_this_round = 10 + steps_this_round = 1 for _ in range(steps_this_round): opt_state = update_fn_opt(opt_state) global_step += steps_this_round diff --git a/tests/test_optimizers_vs_ase.py b/tests/test_optimizers_vs_ase.py index e52a82d1..acf5e0b9 100644 --- a/tests/test_optimizers_vs_ase.py +++ b/tests/test_optimizers_vs_ase.py @@ -1,13 +1,13 @@ -import copy from typing import TYPE_CHECKING, Any import pytest import torch from ase.filters import FrechetCellFilter, UnitCellFilter from ase.optimize import FIRE +from pymatgen.analysis.structure_matcher import StructureMatcher import torch_sim as ts -from torch_sim.io import state_to_atoms +from torch_sim.io import atoms_to_state, state_to_atoms, state_to_structures from torch_sim.models.mace import MaceModel from torch_sim.optimizers import frechet_cell_fire, unit_cell_fire @@ -16,6 +16,65 @@ from mace.calculators import MACECalculator +def _compare_ase_and_ts_states( + ts_current_system_state: ts.state.SimState, + filtered_ase_atoms_for_run: Any, + tolerances: dict[str, float], + current_test_id: str, +) -> None: + structure_matcher = StructureMatcher( + ltol=tolerances["lattice_tol"], + stol=tolerances["site_tol"], + angle_tol=tolerances["angle_tol"], + scale=False, + ) + + tensor_kwargs = { + "device": ts_current_system_state.device, + "dtype": ts_current_system_state.dtype, + } + + final_custom_energy = ts_current_system_state.energy.item() + final_custom_forces_max = ( + torch.norm(ts_current_system_state.forces, dim=-1).max().item() + ) + + # Convert torch-sim state to pymatgen Structure + ts_structure = state_to_structures(ts_current_system_state)[0] + + # Convert ASE atoms to pymatgen Structure + final_ase_atoms = filtered_ase_atoms_for_run.atoms + final_ase_energy = final_ase_atoms.get_potential_energy() + ase_forces_raw = final_ase_atoms.get_forces() + final_ase_forces_max = torch.norm( + torch.tensor(ase_forces_raw, **tensor_kwargs), dim=-1 + ).max() + ts_state = atoms_to_state(final_ase_atoms, **tensor_kwargs) + ase_structure = state_to_structures(ts_state)[0] + + # Compare energies + energy_diff = abs(final_custom_energy - final_ase_energy) + assert energy_diff < tolerances["energy"], ( + f"{current_test_id}: Final energies differ significantly: " + 
f"torch-sim={final_custom_energy:.6f}, ASE={final_ase_energy:.6f}, " + f"Diff={energy_diff:.2e}" + ) + + # Compare forces + force_max_diff = abs(final_custom_forces_max - final_ase_forces_max) + assert force_max_diff < tolerances["force_max"], ( + f"{current_test_id}: Max forces differ significantly: " + f"torch-sim={final_custom_forces_max:.4f}, ASE={final_ase_forces_max:.4f}, " + f"Diff={force_max_diff:.2e}" + ) + + # Compare structures using StructureMatcher + assert structure_matcher.fit(ts_structure, ase_structure), ( + f"{current_test_id}: Structures do not match according to StructureMatcher\n" + f"{ts_structure=}\n{ase_structure=}" + ) + + def _run_and_compare_optimizers( initial_sim_state_fixture: ts.state.SimState, torchsim_mace_mpa: MaceModel, @@ -32,14 +91,7 @@ def _run_and_compare_optimizers( dtype = torch.float64 device = torchsim_mace_mpa.device - ts_current_system_state = copy.deepcopy(initial_sim_state_fixture).to( - dtype=dtype, device=device - ) - ts_current_system_state.positions = ( - ts_current_system_state.positions.detach().requires_grad_() - ) - ts_current_system_state.cell = ts_current_system_state.cell.detach().requires_grad_() - ts_optimizer_state = None + ts_current_system_state = initial_sim_state_fixture.clone() optimizer_builders = { "frechet": frechet_cell_fire, @@ -54,89 +106,53 @@ def _run_and_compare_optimizers( ) ase_atoms_for_run = state_to_atoms( - copy.deepcopy(initial_sim_state_fixture).to(dtype=dtype, device=device) + initial_sim_state_fixture.clone().to(dtype=dtype, device=device) )[0] ase_atoms_for_run.calc = ase_mace_mpa filtered_ase_atoms_for_run = ase_filter_class(ase_atoms_for_run) ase_optimizer = FIRE(filtered_ase_atoms_for_run, logfile=None) last_checkpoint_step_count = 0 - convergence_fn = ts.generate_force_convergence_fn(force_tol=force_tol) + convergence_fn = ts.generate_force_convergence_fn( + force_tol=force_tol, include_cell_forces=True + ) + + results = torchsim_mace_mpa(ts_current_system_state) + ts_initial_system_state = ts_current_system_state.clone() + ts_initial_system_state.forces = results["forces"] + ts_initial_system_state.energy = results["energy"] + ase_atoms_for_run.calc.calculate(ase_atoms_for_run) + + _compare_ase_and_ts_states( + ts_initial_system_state, + filtered_ase_atoms_for_run, + tolerances, + f"{test_id_prefix} (Initial)", + ) for checkpoint_step in checkpoints: steps_for_current_segment = checkpoint_step - last_checkpoint_step_count if steps_for_current_segment > 0: - # Ensure requires_grad is set for the input to ts.optimize - # ts.optimize is expected to return a state suitable for further optimization - # if optimizer_state is passed. 
- ts_current_system_state.positions = ( - ts_current_system_state.positions.detach().requires_grad_() - ) - ts_current_system_state.cell = ( - ts_current_system_state.cell.detach().requires_grad_() - ) - new_ts_state_and_optimizer_state = ts.optimize( + updated_ts_state = ts.optimize( system=ts_current_system_state, model=torchsim_mace_mpa, optimizer=optimizer_callable_for_ts_optimize, max_steps=steps_for_current_segment, convergence_fn=convergence_fn, - optimizer_state=ts_optimizer_state, + steps_between_swaps=1, ) - ts_current_system_state = new_ts_state_and_optimizer_state - ts_optimizer_state = new_ts_state_and_optimizer_state + ts_current_system_state = updated_ts_state.clone() ase_optimizer.run(fmax=force_tol, steps=steps_for_current_segment) current_test_id = f"{test_id_prefix} (Step {checkpoint_step})" - final_custom_energy = ts_current_system_state.energy.item() - final_custom_forces_max = ( - torch.norm(ts_current_system_state.forces, dim=-1).max().item() - ) - final_custom_positions = ts_current_system_state.positions.detach() - final_custom_cell = ts_current_system_state.row_vector_cell.squeeze(0).detach() - - final_ase_atoms = filtered_ase_atoms_for_run.atoms - final_ase_energy = final_ase_atoms.get_potential_energy() - ase_forces_raw = final_ase_atoms.get_forces() - final_ase_forces_max = torch.norm( - torch.tensor(ase_forces_raw, device=device, dtype=dtype), dim=-1 - ).max() - final_ase_positions = torch.tensor( - final_ase_atoms.get_positions(), device=device, dtype=dtype - ) - final_ase_cell = torch.tensor( - final_ase_atoms.get_cell(), device=device, dtype=dtype - ) - - energy_diff = abs(final_custom_energy - final_ase_energy) - assert energy_diff < tolerances["energy"], ( - f"{current_test_id}: Final energies differ significantly: " - f"torch-sim={final_custom_energy:.6f}, ASE={final_ase_energy:.6f}, " - f"Diff={energy_diff:.2e}" - ) - - avg_displacement = ( - torch.norm(final_custom_positions - final_ase_positions, dim=-1).mean().item() - ) - assert avg_displacement < tolerances["pos"], ( - f"{current_test_id}: Final positions differ ({avg_displacement=:.4f})" - ) - - cell_diff = torch.norm(final_custom_cell - final_ase_cell).item() - assert cell_diff < tolerances["cell"], ( - f"{current_test_id}: Final cell matrices differ (Frobenius norm: " - f"{cell_diff:.4f})\nTorch-sim Cell:\n{final_custom_cell}" - f"\nASE Cell:\n{final_ase_cell}" - ) - - force_max_diff = abs(final_custom_forces_max - final_ase_forces_max) - assert force_max_diff < tolerances["force_max"], ( - f"{current_test_id}: Max forces differ significantly: " - f"torch-sim={final_custom_forces_max:.4f}, ASE={final_ase_forces_max:.4f}, " - f"Diff={force_max_diff:.2e}" + _compare_ase_and_ts_states( + ts_current_system_state, + filtered_ase_atoms_for_run, + tolerances, + current_test_id, ) last_checkpoint_step_count = checkpoint_step @@ -157,47 +173,92 @@ def _run_and_compare_optimizers( "rattled_sio2_sim_state", "frechet", FrechetCellFilter, - [33, 66, 100], + [1, 33, 66, 100], 0.02, - {"energy": 1e-2, "pos": 1.5e-2, "cell": 1.8e-2, "force_max": 1.5e-1}, + { + "energy": 1e-2, + "force_max": 5e-2, + "lattice_tol": 3e-2, + "site_tol": 3e-2, + "angle_tol": 1e-1, + }, "SiO2 (Frechet)", ), ( "osn2_sim_state", "frechet", FrechetCellFilter, - [16, 33, 50], + [1, 16, 33, 50], 0.02, - {"energy": 1e-4, "pos": 1e-3, "cell": 1.8e-3, "force_max": 5e-2}, + { + "energy": 1e-2, + "force_max": 5e-2, + "lattice_tol": 3e-2, + "site_tol": 3e-2, + "angle_tol": 1e-1, + }, "OsN2 (Frechet)", ), ( 
"distorted_fcc_al_conventional_sim_state", "frechet", FrechetCellFilter, - [33, 66, 100], + [1, 33, 66, 100], 0.01, - {"energy": 1e-2, "pos": 5e-3, "cell": 2e-2, "force_max": 5e-2}, + { + "energy": 1e-2, + "force_max": 5e-2, + "lattice_tol": 3e-2, + "site_tol": 3e-2, + "angle_tol": 5e-1, + }, "Triclinic Al (Frechet)", ), ( "distorted_fcc_al_conventional_sim_state", "unit_cell", UnitCellFilter, - [33, 66, 100], + [1, 33, 66, 100], 0.01, - {"energy": 1e-2, "pos": 3e-2, "cell": 1e-1, "force_max": 5e-2}, + { + "energy": 1e-2, + "force_max": 5e-2, + "lattice_tol": 3e-2, + "site_tol": 3e-2, + "angle_tol": 5e-1, + }, "Triclinic Al (UnitCell)", ), ( "rattled_sio2_sim_state", "unit_cell", UnitCellFilter, - [33, 66, 100], + [1, 33, 66, 100], 0.02, - {"energy": 1.5e-2, "pos": 2.5e-2, "cell": 5e-2, "force_max": 0.25}, + { + "energy": 1e-2, + "force_max": 5e-2, + "lattice_tol": 3e-2, + "site_tol": 3e-2, + "angle_tol": 1e-1, + }, "SiO2 (UnitCell)", ), + ( + "osn2_sim_state", + "unit_cell", + UnitCellFilter, + [1, 16, 33, 50], + 0.02, + { + "energy": 1e-2, + "force_max": 5e-2, + "lattice_tol": 3e-2, + "site_tol": 3e-2, + "angle_tol": 1e-1, + }, + "OsN2 (UnitCell)", + ), ], ) def test_optimizer_vs_ase_parametrized( diff --git a/tests/test_runners.py b/tests/test_runners.py index f90cfee7..1b7c6260 100644 --- a/tests/test_runners.py +++ b/tests/test_runners.py @@ -297,7 +297,7 @@ def test_optimize_fire( # Check force convergence assert torch.all(final_state.forces < 3e-1) - assert energies.shape[0] > 10 + assert energies.shape[0] >= 10 assert energies[0] > energies[-1] assert not torch.allclose(original_state.positions, final_state.positions) @@ -327,7 +327,8 @@ def test_default_converged_fn( with TorchSimTrajectory(traj_file) as traj: energies = traj.get_array("energy") - assert energies[-3] > energies[-1] + # Check that overall energy decreases (first to last) + assert energies[0] > energies[-1] assert not torch.allclose(original_state.positions, final_state.positions) diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index 3b87d489..e7781d26 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -1030,7 +1030,9 @@ def next_batch( # Increment attempt counters and check for max attempts in a single loop for cur_idx, abs_idx in enumerate(self.current_idx): self.swap_attempts[abs_idx] += 1 - if self.max_attempts and (self.swap_attempts[abs_idx] >= self.max_attempts): + if self.max_attempts is not None and ( + self.swap_attempts[abs_idx] >= self.max_attempts + ): # Force convergence for states that have reached max attempts convergence_tensor[cur_idx] = torch.tensor(True) # noqa: FBT003 diff --git a/torch_sim/math.py b/torch_sim/math.py index fd9182d2..07037e0e 100644 --- a/torch_sim/math.py +++ b/torch_sim/math.py @@ -987,3 +987,36 @@ def matrix_log_33( print(msg) # Fall back to scipy implementation return matrix_log_scipy(matrix).to(sim_dtype) + + +def batched_vdot( + x: torch.Tensor, y: torch.Tensor, batch_indices: torch.Tensor +) -> torch.Tensor: + """Computes batched vdot (sum of element-wise product) for groups of vectors. + + Args: + x: Tensor of shape [N_total_entities, D] (e.g., forces, velocities). + y: Tensor of shape [N_total_entities, D]. + batch_indices: Tensor of shape [N_total_entities] indicating batch membership. + + Returns: + Tensor: shape [n_batches] where each element is the sum(x_i * y_i) + for entities belonging to that batch, + summed over all components D and all entities in the batch. 
+ """ + if ( + x.ndim != 2 + or y.ndim != 2 + or batch_indices.ndim != 1 + or x.shape != y.shape + or x.shape[0] != batch_indices.shape[0] + ): + raise ValueError(f"Invalid input shapes: {x.shape=}, {batch_indices.shape=}") + + if batch_indices.min() < 0: + raise ValueError("batch_indices must be non-negative") + + output = torch.zeros(batch_indices.max() + 1, dtype=x.dtype, device=x.device) + output.scatter_add_(dim=0, index=batch_indices, src=(x * y).sum(dim=1)) + + return output diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 0d94b4a5..186df609 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -270,6 +270,8 @@ def gd_init( if not isinstance(state, SimState): state = SimState(**state) + n_batches = state.n_batches + # Setup cell_factor if cell_factor is None: # Count atoms per batch @@ -283,7 +285,7 @@ def gd_init( ) # Reshape to (n_batches, 1, 1) for broadcasting - cell_factor = cell_factor.view(-1, 1, 1) + cell_factor = cell_factor.view(n_batches, 1, 1) scalar_pressure = torch.full( (state.n_batches, 1, 1), scalar_pressure, device=device, dtype=dtype @@ -316,7 +318,7 @@ def gd_init( ) # shape: (n_batches, 3, 3) # Calculate virial - volumes = torch.linalg.det(state.cell).view(-1, 1, 1) + volumes = torch.linalg.det(state.cell).view(n_batches, 1, 1) virial = -volumes * (stress + pressure) if hydrostatic_strain: @@ -391,7 +393,7 @@ def gd_step( # Get per-atom and per-cell learning rates atom_wise_lr = positions_lr[state.batch].unsqueeze(-1) - cell_wise_lr = cell_lr.view(-1, 1, 1) # shape: (n_batches, 1, 1) + cell_wise_lr = cell_lr.view(n_batches, 1, 1) # shape: (n_batches, 1, 1) # Update atomic and cell positions atomic_positions_new = state.positions + atom_wise_lr * state.forces @@ -415,7 +417,7 @@ def gd_step( state.stress = model_output["stress"] # Calculate virial for cell forces - volumes = torch.linalg.det(new_row_vector_cell).view(-1, 1, 1) + volumes = torch.linalg.det(new_row_vector_cell).view(n_batches, 1, 1) virial = -volumes * (state.stress + state.pressure) if state.hydrostatic_strain: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) @@ -474,7 +476,7 @@ class FireState(SimState): # Required attributes not in SimState forces: torch.Tensor energy: torch.Tensor - velocities: torch.Tensor + velocities: torch.Tensor | None # FIRE algorithm parameters dt: torch.Tensor @@ -589,8 +591,7 @@ def fire_init( atomic_numbers=state.atomic_numbers.clone(), batch=state.batch.clone(), pbc=state.pbc, - # New attributes - velocities=torch.zeros_like(state.positions), + velocities=None, forces=forces, energy=energy, # Optimization attributes @@ -811,7 +812,7 @@ def fire_init( ) # Reshape to (n_batches, 1, 1) for broadcasting - cell_factor = cell_factor.view(-1, 1, 1) + cell_factor = cell_factor.view(n_batches, 1, 1) # Setup pressure tensor pressure = scalar_pressure * torch.eye(3, device=device, dtype=dtype) @@ -824,7 +825,7 @@ def fire_init( forces = model_output["forces"] # [n_total_atoms, 3] stress = model_output["stress"] # [n_batches, 3, 3] - volumes = torch.linalg.det(state.cell).view(-1, 1, 1) + volumes = torch.linalg.det(state.cell).view(n_batches, 1, 1) virial = -volumes * (stress + pressure) # P is P_ext * I if hydrostatic_strain: @@ -863,14 +864,13 @@ def fire_init( atomic_numbers=state.atomic_numbers.clone(), batch=state.batch.clone(), pbc=state.pbc, - # New attributes - velocities=torch.zeros_like(state.positions), + velocities=None, forces=forces, energy=energy, stress=stress, # Cell attributes 
cell_positions=torch.zeros(n_batches, 3, 3, device=device, dtype=dtype), - cell_velocities=torch.zeros(n_batches, 3, 3, device=device, dtype=dtype), + cell_velocities=None, cell_forces=cell_forces, cell_masses=cell_masses, # Optimization attributes @@ -967,7 +967,7 @@ class FrechetCellFIREState(SimState, DeformGradMixin): # Cell attributes cell_positions: torch.Tensor - cell_velocities: torch.Tensor + cell_velocities: torch.Tensor | None cell_forces: torch.Tensor cell_masses: torch.Tensor @@ -1097,7 +1097,7 @@ def fire_init( ) # Reshape to (n_batches, 1, 1) for broadcasting - cell_factor = cell_factor.view(-1, 1, 1) + cell_factor = cell_factor.view(n_batches, 1, 1) # Setup pressure tensor pressure = scalar_pressure * torch.eye(3, device=device, dtype=dtype) @@ -1121,7 +1121,7 @@ def fire_init( cell_positions = torch.zeros((n_batches, 3, 3), device=device, dtype=dtype) # Calculate virial for cell forces - volumes = torch.linalg.det(state.cell).view(-1, 1, 1) + volumes = torch.linalg.det(state.cell).view(n_batches, 1, 1) virial = -volumes * (stress + pressure) # P is P_ext * I if hydrostatic_strain: @@ -1163,14 +1163,13 @@ def fire_init( atomic_numbers=state.atomic_numbers, batch=state.batch, pbc=state.pbc, - # New attributes - velocities=torch.zeros_like(state.positions), + velocities=None, forces=forces, energy=energy, stress=stress, # Cell attributes cell_positions=cell_positions, - cell_velocities=torch.zeros((n_batches, 3, 3), device=device, dtype=dtype), + cell_velocities=None, cell_forces=cell_forces, cell_masses=cell_masses, # Optimization attributes @@ -1202,8 +1201,11 @@ def fire_init( return fire_init, functools.partial(step_func, **step_func_kwargs) +AnyFireCellState = UnitCellFireState | FrechetCellFIREState + + def _vv_fire_step( # noqa: C901, PLR0915 - state: FireState | UnitCellFireState | FrechetCellFIREState, + state: FireState | AnyFireCellState, model: torch.nn.Module, *, dt_max: torch.Tensor, @@ -1215,7 +1217,7 @@ def _vv_fire_step( # noqa: C901, PLR0915 eps: float, is_cell_optimization: bool = False, is_frechet: bool = False, -) -> FireState | UnitCellFireState | FrechetCellFIREState: +) -> FireState | AnyFireCellState: """Perform one Velocity-Verlet based FIRE optimization step. Implements one step of the Fast Inertial Relaxation Engine (FIRE) algorithm for @@ -1244,6 +1246,17 @@ def _vv_fire_step( # noqa: C901, PLR0915 dtype = state.positions.dtype deform_grad_new: torch.Tensor | None = None + if state.velocities is None: + state.velocities = torch.zeros_like(state.positions) + if is_cell_optimization: + if not isinstance(state, AnyFireCellState): + raise ValueError( + f"Cell optimization requires one of {get_args(AnyFireCellState)}." 
+ ) + state.cell_velocities = torch.zeros( + (n_batches, 3, 3), device=device, dtype=dtype + ) + alpha_start_batch = torch.full( (n_batches,), alpha_start.item(), device=device, dtype=dtype ) @@ -1252,7 +1265,6 @@ def _vv_fire_step( # noqa: C901, PLR0915 state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1) if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) cell_wise_dt = state.dt.unsqueeze(-1).unsqueeze(-1) state.cell_velocities += ( 0.5 * cell_wise_dt * state.cell_forces / state.cell_masses.unsqueeze(-1) @@ -1261,7 +1273,6 @@ def _vv_fire_step( # noqa: C901, PLR0915 state.positions = state.positions + atom_wise_dt * state.velocities if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) cell_factor_reshaped = state.cell_factor.view(n_batches, 1, 1) if is_frechet: assert isinstance(state, FrechetCellFIREState) @@ -1284,7 +1295,6 @@ def _vv_fire_step( # noqa: C901, PLR0915 else: assert isinstance(state, UnitCellFireState) cur_deform_grad = state.deform_grad() - # cell_factor is (N,1,1) cell_factor_expanded = state.cell_factor.expand(n_batches, 3, 1) current_cell_positions_scaled = ( cur_deform_grad.view(n_batches, 3, 3) * cell_factor_expanded @@ -1305,9 +1315,8 @@ def _vv_fire_step( # noqa: C901, PLR0915 state.energy = results["energy"] if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) state.stress = results["stress"] - volumes = torch.linalg.det(state.cell).view(-1, 1, 1) + volumes = torch.linalg.det(state.cell).view(n_batches, 1, 1) virial = -volumes * (state.stress + state.pressure) if state.hydrostatic_strain: @@ -1351,63 +1360,62 @@ def _vv_fire_step( # noqa: C901, PLR0915 state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1) if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) state.cell_velocities += ( 0.5 * cell_wise_dt * state.cell_forces / state.cell_masses.unsqueeze(-1) ) - atomic_power = (state.forces * state.velocities).sum(dim=1) - atomic_power_per_batch = torch.zeros( - n_batches, device=device, dtype=atomic_power.dtype - ) - atomic_power_per_batch.scatter_add_(dim=0, index=state.batch, src=atomic_power) - batch_power = atomic_power_per_batch + batch_power = tsm.batched_vdot(state.forces, state.velocities, state.batch) if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) - cell_power = (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) - batch_power += cell_power - - for batch_idx in range(n_batches): - if batch_power[batch_idx] > 0: - state.n_pos[batch_idx] += 1 - if state.n_pos[batch_idx] > n_min: - state.dt[batch_idx] = torch.minimum(state.dt[batch_idx] * f_inc, dt_max) - state.alpha[batch_idx] = state.alpha[batch_idx] * f_alpha - else: - state.n_pos[batch_idx] = 0 - state.dt[batch_idx] = state.dt[batch_idx] * f_dec - state.alpha[batch_idx] = alpha_start_batch[batch_idx] - state.velocities[state.batch == batch_idx] = 0 - if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) - state.cell_velocities[batch_idx] = 0 - - v_norm = torch.norm(state.velocities, dim=1, keepdim=True) - f_norm = torch.norm(state.forces, dim=1, keepdim=True) - atom_wise_alpha = state.alpha[state.batch].unsqueeze(-1) - state.velocities = (1.0 - atom_wise_alpha) * state.velocities + ( - atom_wise_alpha * state.forces * v_norm / (f_norm + eps) - ) + batch_power += (state.cell_forces * 
state.cell_velocities).sum(dim=(1, 2)) + + # 2. Update dt, alpha, n_pos + pos_mask_batch = batch_power > 0.0 + neg_mask_batch = ~pos_mask_batch + + state.n_pos[pos_mask_batch] += 1 + inc_mask = (state.n_pos > n_min) & pos_mask_batch + state.dt[inc_mask] = torch.minimum(state.dt[inc_mask] * f_inc, dt_max) + state.alpha[inc_mask] *= f_alpha + + state.dt[neg_mask_batch] *= f_dec + state.alpha[neg_mask_batch] = alpha_start_batch[neg_mask_batch] + state.n_pos[neg_mask_batch] = 0 + + v_scaling_batch = tsm.batched_vdot(state.velocities, state.velocities, state.batch) + f_scaling_batch = tsm.batched_vdot(state.forces, state.forces, state.batch) if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) - cell_v_norm = torch.norm(state.cell_velocities, dim=(1, 2), keepdim=True) - cell_f_norm = torch.norm(state.cell_forces, dim=(1, 2), keepdim=True) - cell_wise_alpha = state.alpha.unsqueeze(-1).unsqueeze(-1) - cell_mask = (cell_f_norm > eps).expand_as(state.cell_velocities) + v_scaling_batch += state.cell_velocities.pow(2).sum(dim=(1, 2)) + f_scaling_batch += state.cell_forces.pow(2).sum(dim=(1, 2)) + + v_scaling_cell = torch.sqrt(v_scaling_batch.view(n_batches, 1, 1)) + f_scaling_cell = torch.sqrt(f_scaling_batch.view(n_batches, 1, 1)) + v_mixing_cell = state.cell_forces / (f_scaling_cell + eps) * v_scaling_cell + + alpha_cell_bc = state.alpha.view(n_batches, 1, 1) state.cell_velocities = torch.where( - cell_mask, - (1.0 - cell_wise_alpha) * state.cell_velocities - + cell_wise_alpha * state.cell_forces * cell_v_norm / (cell_f_norm + eps), - state.cell_velocities, + pos_mask_batch.view(n_batches, 1, 1), + (1.0 - alpha_cell_bc) * state.cell_velocities + alpha_cell_bc * v_mixing_cell, + torch.zeros_like(state.cell_velocities), ) + v_scaling_atom = torch.sqrt(v_scaling_batch[state.batch].unsqueeze(-1)) + f_scaling_atom = torch.sqrt(f_scaling_batch[state.batch].unsqueeze(-1)) + v_mixing_atom = state.forces * (v_scaling_atom / (f_scaling_atom + eps)) + + alpha_atom = state.alpha[state.batch].unsqueeze(-1) # per-atom alpha + state.velocities = torch.where( + pos_mask_batch[state.batch].unsqueeze(-1), + (1.0 - alpha_atom) * state.velocities + alpha_atom * v_mixing_atom, + torch.zeros_like(state.velocities), + ) + return state def _ase_fire_step( # noqa: C901, PLR0915 - state: FireState | UnitCellFireState | FrechetCellFIREState, + state: FireState | AnyFireCellState, model: torch.nn.Module, *, dt_max: torch.Tensor, @@ -1420,7 +1428,7 @@ def _ase_fire_step( # noqa: C901, PLR0915 eps: float, is_cell_optimization: bool = False, is_frechet: bool = False, -) -> FireState | UnitCellFireState | FrechetCellFIREState: +) -> FireState | AnyFireCellState: """Perform one ASE-style FIRE optimization step. Implements one step of the Fast Inertial Relaxation Engine (FIRE) algorithm @@ -1447,184 +1455,158 @@ def _ase_fire_step( # noqa: C901, PLR0915 device, dtype = state.positions.device, state.positions.dtype n_batches = state.n_batches - # Setup batch-wise alpha_start for potential reset - # alpha_start is a 0-dim tensor from the factory - alpha_start_batch = torch.full( - (n_batches,), alpha_start.item(), device=device, dtype=dtype - ) + if state.velocities is None: + state.velocities = torch.zeros_like(state.positions) + forces = state.forces + if is_cell_optimization: + if not isinstance(state, AnyFireCellState): + raise ValueError( + f"Cell optimization requires one of {get_args(AnyFireCellState)}." 
+ ) + state.cell_velocities = torch.zeros( + (n_batches, 3, 3), device=device, dtype=dtype + ) + cur_deform_grad = state.deform_grad() + else: + alpha_start_batch = torch.full( + (n_batches,), alpha_start.item(), device=device, dtype=dtype + ) + + if is_cell_optimization: + cur_deform_grad = state.deform_grad() + forces = torch.bmm( + state.forces.unsqueeze(1), cur_deform_grad[state.batch] + ).squeeze(1) + else: + forces = state.forces - # 1. Current power (F·v) per batch (atoms + cell) - atomic_power = (state.forces * state.velocities).sum(dim=1) - batch_power = torch.zeros(n_batches, device=device, dtype=dtype) - batch_power.scatter_add_(0, state.batch, atomic_power) + # 1. Current power (F·v) per batch (atoms + cell) + batch_power = tsm.batched_vdot(forces, state.velocities, state.batch) - if is_cell_optimization: - valid_states = (UnitCellFireState, FrechetCellFIREState) - assert isinstance(state, valid_states), ( - f"Cell optimization requires one of {valid_states}." - ) - cell_power = (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) - batch_power += cell_power + if is_cell_optimization: + batch_power += (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) - # 2. Update dt, alpha, n_pos - pos_mask_batch = batch_power > 0.0 - neg_mask_batch = ~pos_mask_batch + # 2. Update dt, alpha, n_pos + pos_mask_batch = batch_power > 0.0 + neg_mask_batch = ~pos_mask_batch - state.n_pos[pos_mask_batch] += 1 - inc_mask = (state.n_pos > n_min) & pos_mask_batch - state.dt[inc_mask] = torch.minimum(state.dt[inc_mask] * f_inc, dt_max) - state.alpha[inc_mask] *= f_alpha + inc_mask = (state.n_pos > n_min) & pos_mask_batch + state.dt[inc_mask] = torch.minimum(state.dt[inc_mask] * f_inc, dt_max) + state.alpha[inc_mask] *= f_alpha + state.n_pos[pos_mask_batch] += 1 - state.dt[neg_mask_batch] *= f_dec - state.alpha[neg_mask_batch] = alpha_start_batch[neg_mask_batch] - state.n_pos[neg_mask_batch] = 0 + state.dt[neg_mask_batch] *= f_dec + state.alpha[neg_mask_batch] = alpha_start_batch[neg_mask_batch] + state.n_pos[neg_mask_batch] = 0 - # 3. Velocity mixing BEFORE acceleration (ASE ordering) - # Atoms - v_norm_atom = torch.norm(state.velocities, dim=1, keepdim=True) - f_norm_atom = torch.norm(state.forces, dim=1, keepdim=True) - f_unit_atom = state.forces / (f_norm_atom + eps) - alpha_atom = state.alpha[state.batch].unsqueeze(-1) - pos_mask_atom = pos_mask_batch[state.batch].unsqueeze(-1) - v_new_atom = ( - 1.0 - alpha_atom - ) * state.velocities + alpha_atom * f_unit_atom * v_norm_atom - state.velocities = torch.where( - pos_mask_atom, v_new_atom, torch.zeros_like(state.velocities) - ) + # 3. 
Velocity mixing BEFORE acceleration (ASE ordering) + v_scaling_batch = tsm.batched_vdot( + state.velocities, state.velocities, state.batch + ) + f_scaling_batch = tsm.batched_vdot(forces, forces, state.batch) + + if is_cell_optimization: + v_scaling_batch += state.cell_velocities.pow(2).sum(dim=(1, 2)) + f_scaling_batch += state.cell_forces.pow(2).sum(dim=(1, 2)) + + v_scaling_cell = torch.sqrt(v_scaling_batch.view(n_batches, 1, 1)) + f_scaling_cell = torch.sqrt(f_scaling_batch.view(n_batches, 1, 1)) + v_mixing_cell = state.cell_forces / (f_scaling_cell + eps) * v_scaling_cell + + alpha_cell_bc = state.alpha.view(n_batches, 1, 1) + state.cell_velocities = torch.where( + pos_mask_batch.view(n_batches, 1, 1), + (1.0 - alpha_cell_bc) * state.cell_velocities + + alpha_cell_bc * v_mixing_cell, + torch.zeros_like(state.cell_velocities), + ) - if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) - # Cell velocity mixing - cv_norm = torch.norm(state.cell_velocities, dim=(1, 2), keepdim=True) - cf_norm = torch.norm(state.cell_forces, dim=(1, 2), keepdim=True) - cf_unit = state.cell_forces / (cf_norm + eps) - alpha_cell_bc = state.alpha.view(-1, 1, 1) - pos_mask_cell_bc = pos_mask_batch.view(-1, 1, 1) - v_new_cell = ( - 1.0 - alpha_cell_bc - ) * state.cell_velocities + alpha_cell_bc * cf_unit * cv_norm - state.cell_velocities = torch.where( - pos_mask_cell_bc, v_new_cell, torch.zeros_like(state.cell_velocities) + v_scaling_atom = torch.sqrt(v_scaling_batch[state.batch].unsqueeze(-1)) + f_scaling_atom = torch.sqrt(f_scaling_batch[state.batch].unsqueeze(-1)) + v_mixing_atom = forces * (v_scaling_atom / (f_scaling_atom + eps)) + + alpha_atom = state.alpha[state.batch].unsqueeze(-1) # per-atom alpha + state.velocities = torch.where( + pos_mask_batch[state.batch].unsqueeze(-1), + (1.0 - alpha_atom) * state.velocities + alpha_atom * v_mixing_atom, + torch.zeros_like(state.velocities), ) # 4. Acceleration (single forward-Euler, no mass for ASE FIRE) - atom_dt = state.dt[state.batch].unsqueeze(-1) - state.velocities += atom_dt * state.forces - - if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) - cell_dt = state.dt.view(-1, 1, 1) - state.cell_velocities += cell_dt * state.cell_forces + state.velocities += forces * state.dt[state.batch].unsqueeze(-1) + dr_atom = state.velocities * state.dt[state.batch].unsqueeze(-1) + dr_scaling_batch = tsm.batched_vdot(dr_atom, dr_atom, state.batch) - # 5. Displacements - dr_atom = atom_dt * state.velocities if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) - dr_cell = cell_dt * state.cell_velocities - - # 6. 
Clamp to max_step - dr_norm_atom = torch.norm(dr_atom, dim=1, keepdim=True) - mask_atom_max_step = dr_norm_atom > max_step - dr_atom = torch.where( - mask_atom_max_step, max_step * dr_atom / (dr_norm_atom + eps), dr_atom - ) + state.cell_velocities += state.cell_forces * state.dt.view(n_batches, 1, 1) + dr_cell = state.cell_velocities * state.dt.view(n_batches, 1, 1) - old_row_vector_cell: torch.Tensor | None = None - if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) - # Cell clamp to max_step (Frobenius norm) - dr_cell_norm_fro = torch.norm(dr_cell.view(n_batches, -1), dim=1, keepdim=True) - mask_cell_max_step = dr_cell_norm_fro.view(n_batches, 1, 1) > max_step + dr_scaling_batch += dr_cell.pow(2).sum(dim=(1, 2)) + dr_scaling_cell = torch.sqrt(dr_scaling_batch).view(n_batches, 1, 1) dr_cell = torch.where( - mask_cell_max_step, - max_step * dr_cell / (dr_cell_norm_fro.view(n_batches, 1, 1) + eps), + dr_scaling_cell > max_step, + max_step * dr_cell / (dr_scaling_cell + eps), dr_cell, ) - # 7. Position / cell update - # Store old cell for scaling atoms - # Ensure old_row_vector_cell is cloned before any modification to state.cell or - # state.row_vector_cell - old_row_vector_cell = state.row_vector_cell.clone() - - state.positions = state.positions + dr_atom + dr_scaling_atom = torch.sqrt(dr_scaling_batch)[state.batch].unsqueeze(-1) - # F_new stores F_new for Frechet's ucf_cell_grad if needed - F_new: torch.Tensor | None = None - # logm_F_new stores logm_F_new for Frechet's cell_forces recalc if needed - logm_F_new: torch.Tensor | None = None + dr_atom = torch.where( + dr_scaling_atom > max_step, max_step * dr_atom / (dr_scaling_atom + eps), dr_atom + ) if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) + state.positions = ( + torch.linalg.solve( + cur_deform_grad[state.batch], state.positions.unsqueeze(-1) + ).squeeze(-1) + + dr_atom + ) + if is_frechet: assert isinstance(state, FrechetCellFIREState) - # Frechet cell update logic new_logm_F_scaled = state.cell_positions + dr_cell state.cell_positions = new_logm_F_scaled - # cell_factor is (N,1,1) logm_F_new = new_logm_F_scaled / (state.cell_factor + eps) F_new = torch.matrix_exp(logm_F_new) new_row_vector_cell = torch.bmm( state.reference_row_vector_cell, F_new.transpose(-2, -1) ) state.row_vector_cell = new_row_vector_cell - else: # UnitCellFire + else: assert isinstance(state, UnitCellFireState) - # Unit cell update logic F_current = state.deform_grad() - # state.cell_factor is (N,1,1), F_current is (N,3,3) - # cell_factor_exp for element-wise F_current * cell_factor_exp should be - # (N,3,3) or broadcast from (N,1,1) or (N,3,1) cell_factor_exp_mult = state.cell_factor.expand(n_batches, 3, 1) current_F_scaled = F_current * cell_factor_exp_mult F_new_scaled = current_F_scaled + dr_cell - state.cell_positions = F_new_scaled # track the scaled deformation gradient - F_new = F_new_scaled / (cell_factor_exp_mult + eps) # Division by (N,3,1) - # When state.cell is set, state.row_vector_cell is auto-updated - new_cell_column_vectors = torch.bmm( - state.reference_cell, F_new.transpose(-2, -1) + state.cell_positions = F_new_scaled + F_new = F_new_scaled / (cell_factor_exp_mult + eps) + new_row_vector_cell = torch.bmm( + state.reference_row_vector_cell, F_new.transpose(-2, -1) ) - state.cell = new_cell_column_vectors - - # Scale atomic positions according to cell change (mimicking scale_atoms=True) - if is_cell_optimization and old_row_vector_cell is not None: 
- current_new_row_vector_cell = state.row_vector_cell # This is A_new after update - - inv_old_cell_batch = torch.linalg.inv(old_row_vector_cell) - # Transform matrix T such that A_new = A_old @ T (for row vectors A) - # This means cartesian positions P_new_row = P_old_row @ T - transform_matrix_batch = torch.bmm( - inv_old_cell_batch, current_new_row_vector_cell - ) # Shape [N_batch, 3, 3] - - # Shape: [N_atoms, 3, 3] - atom_specific_transform = transform_matrix_batch[state.batch] - - # state.positions is [N_atoms, 3]. Unsqueeze to [N_atoms, 1, 3] for bmm - # Result of bmm will be [N_atoms, 1, 3], then squeeze - scaled_positions = torch.bmm( - state.positions.unsqueeze(1), atom_specific_transform + state.row_vector_cell = new_row_vector_cell + + state.positions = torch.bmm( + state.positions.unsqueeze(1), F_new[state.batch].transpose(-2, -1) ).squeeze(1) - state.positions = scaled_positions + else: + state.positions = state.positions + dr_atom - # 8. Force / stress refresh & new cell forces + # 7. Force / stress refresh & new cell forces results = model(state) state.forces = results["forces"] state.energy = results["energy"] if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) state.stress = results["stress"] - volumes = torch.linalg.det(state.cell).view(-1, 1, 1) + volumes = torch.linalg.det(state.cell).view(n_batches, 1, 1) if torch.any(volumes <= 0): bad_indices = torch.where(volumes <= 0)[0].tolist() print( f"WARNING: Non-positive volume(s) detected during _ase_fire_step: " f"{volumes[bad_indices].tolist()} at {bad_indices=} ({is_frechet=})" ) - # volumes = torch.clamp(volumes, min=eps) # Optional: for stability virial = -volumes * (state.stress + state.pressure) @@ -1633,7 +1615,8 @@ def _ase_fire_step( # noqa: C901, PLR0915 virial = diag_mean.unsqueeze(-1) * torch.eye( 3, device=device, dtype=dtype ).unsqueeze(0).expand(n_batches, -1, -1) - if state.constant_volume: # Can be true even if hydrostatic_strain is false + + if state.constant_volume: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) virial = virial - diag_mean.unsqueeze(-1) * torch.eye( 3, device=device, dtype=dtype @@ -1647,7 +1630,6 @@ def _ase_fire_step( # noqa: C901, PLR0915 assert logm_F_new is not None, ( "logm_F_new should be defined for Frechet cell force calculation" ) - # Frechet cell force recalculation ucf_cell_grad = torch.bmm( virial, torch.linalg.inv(torch.transpose(F_new, 1, 2)) ) @@ -1659,7 +1641,6 @@ def _ase_fire_step( # noqa: C901, PLR0915 new_cell_forces_log_space = torch.zeros_like(state.cell_forces) for b_idx in range(n_batches): - # logm_F_new[b_idx] is the current point in log-space expm_derivs = torch.stack( [ tsm.expm_frechet(logm_F_new[b_idx], direction, compute_expm=False) @@ -1670,12 +1651,9 @@ def _ase_fire_step( # noqa: C901, PLR0915 expm_derivs * ucf_cell_grad[b_idx].unsqueeze(0), dim=(1, 2) ) new_cell_forces_log_space[b_idx] = forces_flat.reshape(3, 3) - state.cell_forces = new_cell_forces_log_space / ( - state.cell_factor + eps - ) # cell_factor is (N,1,1) - else: # UnitCellFire + state.cell_forces = new_cell_forces_log_space / (state.cell_factor + eps) + else: assert isinstance(state, UnitCellFireState) - # Unit cell force recalculation - state.cell_forces = virial / state.cell_factor # cell_factor is (N,1,1) + state.cell_forces = virial / state.cell_factor return state diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 737312b6..5925872b 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -16,6 
+16,12 @@ from torch_sim.autobatching import BinningAutoBatcher, InFlightAutoBatcher from torch_sim.models.interface import ModelInterface +from torch_sim.optimizers import ( + FireState, + FrechetCellFIREState, + UnitCellFireState, + UnitCellGDState, +) from torch_sim.quantities import batchwise_max_force, calc_kinetic_energy, calc_kT from torch_sim.state import SimState, concatenate_states, initialize_state from torch_sim.trajectory import TrajectoryReporter @@ -278,12 +284,16 @@ def _chunked_apply( return concatenate_states(ordered_states) -def generate_force_convergence_fn(force_tol: float = 1e-1) -> Callable: +def generate_force_convergence_fn( + force_tol: float = 1e-1, *, include_cell_forces: bool = False +) -> Callable: """Generate a force-based convergence function for the convergence_fn argument of the optimize function. Args: force_tol (float): Force tolerance for convergence + include_cell_forces (bool): Whether to include the `cell_forces` in + the convergence check. Returns: Convergence function that takes a state and last energy and @@ -295,7 +305,16 @@ def convergence_fn( last_energy: torch.Tensor | None = None, # noqa: ARG001 ) -> bool: """Check if the system has converged.""" - return batchwise_max_force(state) < force_tol + force_conv = batchwise_max_force(state) < force_tol + + if include_cell_forces: + if (cell_forces := getattr(state, "cell_forces", None)) is None: + raise ValueError("cell_forces not found in state") + cell_forces_norm, _ = cell_forces.norm(dim=2).max(dim=1) + cell_force_conv = cell_forces_norm < force_tol + return force_conv & cell_force_conv + + return force_conv return convergence_fn @@ -378,13 +397,17 @@ def optimize( # noqa: C901 autobatcher = _configure_in_flight_autobatcher( model, state, autobatcher, max_attempts ) - state = _chunked_apply( - init_fn, - state, - model, - max_memory_scaler=autobatcher.max_memory_scaler, - memory_scales_with=autobatcher.memory_scales_with, - ) + + if not isinstance( + state, (FireState, UnitCellFireState, UnitCellGDState, FrechetCellFIREState) + ): + state = _chunked_apply( + init_fn, + state, + model, + max_memory_scaler=autobatcher.max_memory_scaler, + memory_scales_with=autobatcher.memory_scales_with, + ) autobatcher.load_states(state) trajectory_reporter = _configure_reporter( trajectory_reporter, @@ -514,8 +537,8 @@ class StaticState(type(state)): pbar_kwargs.setdefault("disable", None) tqdm_pbar = tqdm(total=state.n_batches, **pbar_kwargs) - for substate, batch_indices in batch_iterator: - print(substate.atomic_numbers) + for sub_state, batch_indices in batch_iterator: + print(sub_state.atomic_numbers) # set up trajectory reporters if autobatcher and trajectory_reporter and og_filenames is not None: # we must remake the trajectory reporter for each batch @@ -523,20 +546,20 @@ class StaticState(type(state)): filenames=[og_filenames[idx] for idx in batch_indices] ) - model_outputs = model(substate) + model_outputs = model(sub_state) - substate = StaticState( - **vars(substate), + sub_state = StaticState( + **vars(sub_state), energy=model_outputs["energy"], forces=model_outputs["forces"] if model.compute_forces else None, stress=model_outputs["stress"] if model.compute_stress else None, ) - props = trajectory_reporter.report(substate, 0, model=model) + props = trajectory_reporter.report(sub_state, 0, model=model) all_props.extend(props) if tqdm_pbar: - tqdm_pbar.update(substate.n_batches) + tqdm_pbar.update(sub_state.n_batches) trajectory_reporter.finish()
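
Usage note for the new include_cell_forces flag: with a cell-aware optimizer such
as frechet_cell_fire, convergence is now gated on both atomic and cell forces. A
minimal sketch, assuming a loaded MaceModel named `model` and an initial SimState
named `state` (both hypothetical stand-ins):

    import torch_sim as ts
    from torch_sim.optimizers import frechet_cell_fire

    # Converge only once both the max atomic force and the max cell force
    # fall below force_tol (eV/A).
    convergence_fn = ts.generate_force_convergence_fn(
        force_tol=0.02, include_cell_forces=True
    )
    relaxed = ts.optimize(
        system=state,  # hypothetical SimState
        model=model,  # hypothetical MaceModel
        optimizer=frechet_cell_fire,  # its state exposes cell_forces
        convergence_fn=convergence_fn,
    )

With a fixed-cell optimizer the state carries no cell_forces, so
include_cell_forces=True raises a ValueError instead of silently passing.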
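
The tests above replace raw position/cell-norm comparisons with pymatgen's
StructureMatcher. The round-trip they rely on, sketched for a single-batch state;
`final_state` and `atoms` are hypothetical stand-ins for the relaxed torch-sim
state and the relaxed ASE atoms:

    from pymatgen.analysis.structure_matcher import StructureMatcher
    from torch_sim.io import atoms_to_state, state_to_structures

    # Tolerances mirror the parametrized values used in the tests above.
    matcher = StructureMatcher(ltol=3e-2, stol=3e-2, angle_tol=1e-1, scale=False)

    ts_structure = state_to_structures(final_state)[0]
    ase_structure = state_to_structures(
        atoms_to_state(atoms, device=final_state.device, dtype=final_state.dtype)
    )[0]
    assert matcher.fit(ts_structure, ase_structure)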
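
Both FIRE step functions now compute the velocity-mixing norms per batch with
batched_vdot instead of a Python loop over batch indices. Schematically (a sketch
with hypothetical names `velocities`, `forces`, `batch`, `alpha`, and `eps`,
following the _vv_fire_step code above; the `tsm` alias matches the one used in
torch_sim/optimizers.py):

    import torch
    import torch_sim.math as tsm

    # FIRE mixing, v <- (1 - alpha) * v + alpha * F * |v| / |F|, where |v|
    # and |F| are per-batch L2 norms over all atoms in each batch (cell
    # velocity/force terms are folded in the same way for cell optimization).
    v_sq = tsm.batched_vdot(velocities, velocities, batch)  # [n_batches]
    f_sq = tsm.batched_vdot(forces, forces, batch)  # [n_batches]
    v_scale = torch.sqrt(v_sq)[batch].unsqueeze(-1)  # per-atom |v|
    f_scale = torch.sqrt(f_sq)[batch].unsqueeze(-1)  # per-atom |F|
    alpha_atom = alpha[batch].unsqueeze(-1)
    v_mixed = (1.0 - alpha_atom) * velocities + alpha_atom * forces * (
        v_scale / (f_scale + eps)
    )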