From 5356d0bb018aa9ede7622d5a787d109a184aeb17 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Thu, 22 May 2025 16:44:19 -0400 Subject: [PATCH 1/7] fea: use batched vdot --------- Co-authored-by: Janosh Riebesell --- torch_sim/math.py | 25 +++++++++++++++++++++ torch_sim/optimizers.py | 48 ++++++++++++++++++++++++++++++----------- 2 files changed, 61 insertions(+), 12 deletions(-) diff --git a/torch_sim/math.py b/torch_sim/math.py index fd9182d2..0adfbe0a 100644 --- a/torch_sim/math.py +++ b/torch_sim/math.py @@ -987,3 +987,28 @@ def matrix_log_33( print(msg) # Fall back to scipy implementation return matrix_log_scipy(matrix).to(sim_dtype) + + +def batched_vdot( + x: torch.Tensor, y: torch.Tensor, batch_indices: torch.Tensor +) -> torch.Tensor: + """Computes batched vdot (sum of element-wise product) for groups of vectors. + If is_sum_sq is True, computes sum of x_i * x_i (squared norm components). + + Args: + x: Tensor of shape [N_total_entities, D] (e.g., forces, velocities). + y: Tensor of shape [N_total_entities, D]. Ignored if is_sum_sq is True. + batch_indices: Tensor of shape [N_total_entities] indicating batch membership. + + Returns: + Tensor: shape [n_batches] where each element is the sum(x_i * y_i) + (or sum(x_i * x_i) if is_sum_sq) for entities belonging to that batch, + summed over all components D and all entities in the batch. + """ + if x.ndim != 2 or batch_indices.ndim != 1 or x.shape[0] != batch_indices.shape[0]: + raise ValueError(f"Invalid input shapes: {x.shape=}, {batch_indices.shape=}") + + output = torch.zeros(batch_indices.max() + 1, dtype=x.dtype, device=x.device) + output.scatter_add_(dim=0, index=batch_indices, src=(x * y).sum(dim=1)) + + return output diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 0d94b4a5..1a58b4f6 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -1481,14 +1481,28 @@ def _ase_fire_step( # noqa: C901, PLR0915 # 3. Velocity mixing BEFORE acceleration (ASE ordering) # Atoms - v_norm_atom = torch.norm(state.velocities, dim=1, keepdim=True) - f_norm_atom = torch.norm(state.forces, dim=1, keepdim=True) - f_unit_atom = state.forces / (f_norm_atom + eps) - alpha_atom = state.alpha[state.batch].unsqueeze(-1) - pos_mask_atom = pos_mask_batch[state.batch].unsqueeze(-1) - v_new_atom = ( - 1.0 - alpha_atom - ) * state.velocities + alpha_atom * f_unit_atom * v_norm_atom + # print(f"{state.velocities.shape=}") + v_sum_sq_batch = tsm.batched_vdot(state.velocities, state.velocities, state.batch) + # sum_sq per batch, shape [n_batches] + f_sum_sq_batch = tsm.batched_vdot(state.forces, state.forces, state.batch) + # sum_sq per batch, shape [n_batches] + + # Expand to per-atom for applying to vectors + # These are sqrt(sum ||v_i||^2)_batch and sqrt(sum ||f_i||^2)_batch + # Effectively |V|_batch and |F|_batch for the mixing formula + sqrt_v_sum_sq_batch_expanded = torch.sqrt(v_sum_sq_batch[state.batch].unsqueeze(-1)) + sqrt_f_sum_sq_batch_expanded = torch.sqrt(f_sum_sq_batch[state.batch].unsqueeze(-1)) + + alpha_atom = state.alpha[state.batch].unsqueeze(-1) # per-atom alpha + pos_mask_atom = pos_mask_batch[state.batch].unsqueeze(-1) # per-atom mask + + # ASE formula: v_new = (1-a)*v + a * (f / |F|_batch) * |V|_batch + # = (1-a)*v + a * f * (|V|_batch / |F|_batch) + mixing_term_atom = state.forces * ( + sqrt_v_sum_sq_batch_expanded / (sqrt_f_sum_sq_batch_expanded + eps) + ) + + v_new_atom = (1.0 - alpha_atom) * state.velocities + alpha_atom * mixing_term_atom state.velocities = torch.where( pos_mask_atom, v_new_atom, torch.zeros_like(state.velocities) ) @@ -1524,12 +1538,22 @@ def _ase_fire_step( # noqa: C901, PLR0915 dr_cell = cell_dt * state.cell_velocities # 6. Clamp to max_step - dr_norm_atom = torch.norm(dr_atom, dim=1, keepdim=True) - mask_atom_max_step = dr_norm_atom > max_step - dr_atom = torch.where( - mask_atom_max_step, max_step * dr_atom / (dr_norm_atom + eps), dr_atom + dr_atom_sum_sq_batch = tsm.batched_vdot(dr_atom, dr_atom, state.batch) + norm_dr_atom_per_batch = torch.sqrt(dr_atom_sum_sq_batch) # shape [n_batches] + + mask_clamp_batch = norm_dr_atom_per_batch > max_step # shape [n_batches] + + scaling_factor_batch = torch.ones_like(norm_dr_atom_per_batch) + safe_norm_for_clamped_batches = norm_dr_atom_per_batch[mask_clamp_batch] + scaling_factor_batch[mask_clamp_batch] = max_step / ( + safe_norm_for_clamped_batches + eps ) + # shape [N_atoms, 1] + atom_wise_scaling_factor = scaling_factor_batch[state.batch].unsqueeze(-1) + + dr_atom = dr_atom * atom_wise_scaling_factor + old_row_vector_cell: torch.Tensor | None = None if is_cell_optimization: assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) From fa0830aaef516fdbc7ddc383116272386de751c0 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Thu, 22 May 2025 20:27:51 -0400 Subject: [PATCH 2/7] clean: remove ai slop --- torch_sim/optimizers.py | 221 ++++++++++++++++------------------------ 1 file changed, 90 insertions(+), 131 deletions(-) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 1a58b4f6..a2fa19e6 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -1447,85 +1447,81 @@ def _ase_fire_step( # noqa: C901, PLR0915 device, dtype = state.positions.device, state.positions.dtype n_batches = state.n_batches - # Setup batch-wise alpha_start for potential reset - # alpha_start is a 0-dim tensor from the factory - alpha_start_batch = torch.full( - (n_batches,), alpha_start.item(), device=device, dtype=dtype - ) + if state.velocities is None: + state.velocities = torch.zeros_like(state.positions) + if is_cell_optimization: + state.cell_velocities = torch.zeros( + (n_batches, 3, 3), device=device, dtype=dtype + ) + else: + alpha_start_batch = torch.full( + (n_batches,), alpha_start.item(), device=device, dtype=dtype + ) - # 1. Current power (F·v) per batch (atoms + cell) - atomic_power = (state.forces * state.velocities).sum(dim=1) - batch_power = torch.zeros(n_batches, device=device, dtype=dtype) - batch_power.scatter_add_(0, state.batch, atomic_power) + # 1. Current power (F·v) per batch (atoms + cell) + batch_power = tsm.batched_vdot(state.forces, state.velocities, state.batch) - if is_cell_optimization: - valid_states = (UnitCellFireState, FrechetCellFIREState) - assert isinstance(state, valid_states), ( - f"Cell optimization requires one of {valid_states}." + if is_cell_optimization: + valid_states = (UnitCellFireState, FrechetCellFIREState) + assert isinstance(state, valid_states), ( + f"Cell optimization requires one of {valid_states}." + ) + batch_power += (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) + + # 2. Update dt, alpha, n_pos + pos_mask_batch = batch_power > 0.0 + neg_mask_batch = ~pos_mask_batch + + state.n_pos[pos_mask_batch] += 1 + inc_mask = (state.n_pos > n_min) & pos_mask_batch + state.dt[inc_mask] = torch.minimum(state.dt[inc_mask] * f_inc, dt_max) + state.alpha[inc_mask] *= f_alpha + + state.dt[neg_mask_batch] *= f_dec + state.alpha[neg_mask_batch] = alpha_start_batch[neg_mask_batch] + state.n_pos[neg_mask_batch] = 0 + + # 3. Velocity mixing BEFORE acceleration (ASE ordering) + v_scaling_batch = tsm.batched_vdot( + state.velocities, state.velocities, state.batch ) - cell_power = (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) - batch_power += cell_power + f_scaling_batch = tsm.batched_vdot(state.forces, state.forces, state.batch) - # 2. Update dt, alpha, n_pos - pos_mask_batch = batch_power > 0.0 - neg_mask_batch = ~pos_mask_batch - - state.n_pos[pos_mask_batch] += 1 - inc_mask = (state.n_pos > n_min) & pos_mask_batch - state.dt[inc_mask] = torch.minimum(state.dt[inc_mask] * f_inc, dt_max) - state.alpha[inc_mask] *= f_alpha - - state.dt[neg_mask_batch] *= f_dec - state.alpha[neg_mask_batch] = alpha_start_batch[neg_mask_batch] - state.n_pos[neg_mask_batch] = 0 - - # 3. Velocity mixing BEFORE acceleration (ASE ordering) - # Atoms - # print(f"{state.velocities.shape=}") - v_sum_sq_batch = tsm.batched_vdot(state.velocities, state.velocities, state.batch) - # sum_sq per batch, shape [n_batches] - f_sum_sq_batch = tsm.batched_vdot(state.forces, state.forces, state.batch) - # sum_sq per batch, shape [n_batches] - - # Expand to per-atom for applying to vectors - # These are sqrt(sum ||v_i||^2)_batch and sqrt(sum ||f_i||^2)_batch - # Effectively |V|_batch and |F|_batch for the mixing formula - sqrt_v_sum_sq_batch_expanded = torch.sqrt(v_sum_sq_batch[state.batch].unsqueeze(-1)) - sqrt_f_sum_sq_batch_expanded = torch.sqrt(f_sum_sq_batch[state.batch].unsqueeze(-1)) - - alpha_atom = state.alpha[state.batch].unsqueeze(-1) # per-atom alpha - pos_mask_atom = pos_mask_batch[state.batch].unsqueeze(-1) # per-atom mask - - # ASE formula: v_new = (1-a)*v + a * (f / |F|_batch) * |V|_batch - # = (1-a)*v + a * f * (|V|_batch / |F|_batch) - mixing_term_atom = state.forces * ( - sqrt_v_sum_sq_batch_expanded / (sqrt_f_sum_sq_batch_expanded + eps) - ) + if is_cell_optimization: + assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) + v_scaling_batch += ( + state.cell_velocities.pow(2).sum(dim=(1, 2), keepdim=True).squeeze(-1) + ) + f_scaling_batch += ( + state.cell_forces.pow(2).sum(dim=(1, 2), keepdim=True).squeeze(-1) + ) - v_new_atom = (1.0 - alpha_atom) * state.velocities + alpha_atom * mixing_term_atom - state.velocities = torch.where( - pos_mask_atom, v_new_atom, torch.zeros_like(state.velocities) - ) + v_scaling_cell = torch.sqrt(v_scaling_batch.view(n_batches, 1, 1)) + f_scaling_cell = torch.sqrt(f_scaling_batch.view(n_batches, 1, 1)) + v_mixing_cell = state.cell_forces / (f_scaling_cell + eps) * v_scaling_cell - if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) - # Cell velocity mixing - cv_norm = torch.norm(state.cell_velocities, dim=(1, 2), keepdim=True) - cf_norm = torch.norm(state.cell_forces, dim=(1, 2), keepdim=True) - cf_unit = state.cell_forces / (cf_norm + eps) - alpha_cell_bc = state.alpha.view(-1, 1, 1) - pos_mask_cell_bc = pos_mask_batch.view(-1, 1, 1) - v_new_cell = ( - 1.0 - alpha_cell_bc - ) * state.cell_velocities + alpha_cell_bc * cf_unit * cv_norm - state.cell_velocities = torch.where( - pos_mask_cell_bc, v_new_cell, torch.zeros_like(state.cell_velocities) + alpha_cell_bc = state.alpha.view(-1, 1, 1) + state.cell_velocities = torch.where( + pos_mask_batch.view(-1, 1, 1), + (1.0 - alpha_cell_bc) * state.cell_velocities + + alpha_cell_bc * v_mixing_cell, + torch.zeros_like(state.cell_velocities), + ) + + v_scaling_atom = torch.sqrt(v_scaling_batch[state.batch].unsqueeze(-1)) + f_scaling_atom = torch.sqrt(f_scaling_batch[state.batch].unsqueeze(-1)) + v_mixing_atom = state.forces * (v_scaling_atom / (f_scaling_atom + eps)) + + alpha_atom = state.alpha[state.batch].unsqueeze(-1) # per-atom alpha + state.velocities = torch.where( + pos_mask_batch[state.batch].unsqueeze(-1), + (1.0 - alpha_atom) * state.velocities + alpha_atom * v_mixing_atom, + torch.zeros_like(state.velocities), ) # 4. Acceleration (single forward-Euler, no mass for ASE FIRE) atom_dt = state.dt[state.batch].unsqueeze(-1) state.velocities += atom_dt * state.forces - if is_cell_optimization: assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) cell_dt = state.dt.view(-1, 1, 1) @@ -1537,103 +1533,71 @@ def _ase_fire_step( # noqa: C901, PLR0915 assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) dr_cell = cell_dt * state.cell_velocities - # 6. Clamp to max_step - dr_atom_sum_sq_batch = tsm.batched_vdot(dr_atom, dr_atom, state.batch) - norm_dr_atom_per_batch = torch.sqrt(dr_atom_sum_sq_batch) # shape [n_batches] - - mask_clamp_batch = norm_dr_atom_per_batch > max_step # shape [n_batches] - - scaling_factor_batch = torch.ones_like(norm_dr_atom_per_batch) - safe_norm_for_clamped_batches = norm_dr_atom_per_batch[mask_clamp_batch] - scaling_factor_batch[mask_clamp_batch] = max_step / ( - safe_norm_for_clamped_batches + eps - ) - - # shape [N_atoms, 1] - atom_wise_scaling_factor = scaling_factor_batch[state.batch].unsqueeze(-1) - - dr_atom = dr_atom * atom_wise_scaling_factor + # 6. Position / cell update + dr_scaling_batch = tsm.batched_vdot(dr_atom, dr_atom, state.batch) old_row_vector_cell: torch.Tensor | None = None if is_cell_optimization: assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) - # Cell clamp to max_step (Frobenius norm) - dr_cell_norm_fro = torch.norm(dr_cell.view(n_batches, -1), dim=1, keepdim=True) - mask_cell_max_step = dr_cell_norm_fro.view(n_batches, 1, 1) > max_step + dr_scaling_batch += dr_cell.pow(2).sum(dim=(1, 2), keepdim=True).squeeze(-1) + dr_scaling_cell = torch.sqrt(dr_scaling_batch.view(n_batches, 1, 1)) dr_cell = torch.where( - mask_cell_max_step, - max_step * dr_cell / (dr_cell_norm_fro.view(n_batches, 1, 1) + eps), + dr_scaling_cell > max_step, + max_step * dr_cell / (dr_scaling_cell + eps), dr_cell, ) - # 7. Position / cell update - # Store old cell for scaling atoms - # Ensure old_row_vector_cell is cloned before any modification to state.cell or - # state.row_vector_cell + # save the old cell to allow rescaling of the positions after cell update old_row_vector_cell = state.row_vector_cell.clone() + dr_scaling_atom = torch.sqrt(dr_scaling_batch[state.batch]) + dr_atom = torch.where( + dr_scaling_atom > max_step, max_step * dr_atom / (dr_scaling_atom + eps), dr_atom + ) state.positions = state.positions + dr_atom - # F_new stores F_new for Frechet's ucf_cell_grad if needed - F_new: torch.Tensor | None = None - # logm_F_new stores logm_F_new for Frechet's cell_forces recalc if needed - logm_F_new: torch.Tensor | None = None - if is_cell_optimization: + F_new: torch.Tensor | None = None + logm_F_new: torch.Tensor | None = None + assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) if is_frechet: assert isinstance(state, FrechetCellFIREState) - # Frechet cell update logic new_logm_F_scaled = state.cell_positions + dr_cell state.cell_positions = new_logm_F_scaled - # cell_factor is (N,1,1) logm_F_new = new_logm_F_scaled / (state.cell_factor + eps) F_new = torch.matrix_exp(logm_F_new) new_row_vector_cell = torch.bmm( state.reference_row_vector_cell, F_new.transpose(-2, -1) ) state.row_vector_cell = new_row_vector_cell - else: # UnitCellFire + else: assert isinstance(state, UnitCellFireState) - # Unit cell update logic F_current = state.deform_grad() - # state.cell_factor is (N,1,1), F_current is (N,3,3) - # cell_factor_exp for element-wise F_current * cell_factor_exp should be - # (N,3,3) or broadcast from (N,1,1) or (N,3,1) cell_factor_exp_mult = state.cell_factor.expand(n_batches, 3, 1) current_F_scaled = F_current * cell_factor_exp_mult F_new_scaled = current_F_scaled + dr_cell - state.cell_positions = F_new_scaled # track the scaled deformation gradient - F_new = F_new_scaled / (cell_factor_exp_mult + eps) # Division by (N,3,1) - # When state.cell is set, state.row_vector_cell is auto-updated + state.cell_positions = F_new_scaled + F_new = F_new_scaled / (cell_factor_exp_mult + eps) new_cell_column_vectors = torch.bmm( state.reference_cell, F_new.transpose(-2, -1) ) state.cell = new_cell_column_vectors - # Scale atomic positions according to cell change (mimicking scale_atoms=True) - if is_cell_optimization and old_row_vector_cell is not None: - current_new_row_vector_cell = state.row_vector_cell # This is A_new after update - + # rescale the positions after cell update + current_new_row_vector_cell = state.row_vector_cell inv_old_cell_batch = torch.linalg.inv(old_row_vector_cell) - # Transform matrix T such that A_new = A_old @ T (for row vectors A) - # This means cartesian positions P_new_row = P_old_row @ T transform_matrix_batch = torch.bmm( inv_old_cell_batch, current_new_row_vector_cell - ) # Shape [N_batch, 3, 3] - - # Shape: [N_atoms, 3, 3] + ) atom_specific_transform = transform_matrix_batch[state.batch] - - # state.positions is [N_atoms, 3]. Unsqueeze to [N_atoms, 1, 3] for bmm - # Result of bmm will be [N_atoms, 1, 3], then squeeze scaled_positions = torch.bmm( state.positions.unsqueeze(1), atom_specific_transform ).squeeze(1) state.positions = scaled_positions - # 8. Force / stress refresh & new cell forces + # 7. Force / stress refresh & new cell forces results = model(state) state.forces = results["forces"] state.energy = results["energy"] @@ -1648,7 +1612,6 @@ def _ase_fire_step( # noqa: C901, PLR0915 f"WARNING: Non-positive volume(s) detected during _ase_fire_step: " f"{volumes[bad_indices].tolist()} at {bad_indices=} ({is_frechet=})" ) - # volumes = torch.clamp(volumes, min=eps) # Optional: for stability virial = -volumes * (state.stress + state.pressure) @@ -1657,7 +1620,8 @@ def _ase_fire_step( # noqa: C901, PLR0915 virial = diag_mean.unsqueeze(-1) * torch.eye( 3, device=device, dtype=dtype ).unsqueeze(0).expand(n_batches, -1, -1) - if state.constant_volume: # Can be true even if hydrostatic_strain is false + + if state.constant_volume: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) virial = virial - diag_mean.unsqueeze(-1) * torch.eye( 3, device=device, dtype=dtype @@ -1671,7 +1635,6 @@ def _ase_fire_step( # noqa: C901, PLR0915 assert logm_F_new is not None, ( "logm_F_new should be defined for Frechet cell force calculation" ) - # Frechet cell force recalculation ucf_cell_grad = torch.bmm( virial, torch.linalg.inv(torch.transpose(F_new, 1, 2)) ) @@ -1683,7 +1646,6 @@ def _ase_fire_step( # noqa: C901, PLR0915 new_cell_forces_log_space = torch.zeros_like(state.cell_forces) for b_idx in range(n_batches): - # logm_F_new[b_idx] is the current point in log-space expm_derivs = torch.stack( [ tsm.expm_frechet(logm_F_new[b_idx], direction, compute_expm=False) @@ -1694,12 +1656,9 @@ def _ase_fire_step( # noqa: C901, PLR0915 expm_derivs * ucf_cell_grad[b_idx].unsqueeze(0), dim=(1, 2) ) new_cell_forces_log_space[b_idx] = forces_flat.reshape(3, 3) - state.cell_forces = new_cell_forces_log_space / ( - state.cell_factor + eps - ) # cell_factor is (N,1,1) - else: # UnitCellFire + state.cell_forces = new_cell_forces_log_space / (state.cell_factor + eps) + else: assert isinstance(state, UnitCellFireState) - # Unit cell force recalculation - state.cell_forces = virial / state.cell_factor # cell_factor is (N,1,1) + state.cell_forces = virial / state.cell_factor return state From 53f6839bad850f6ce08da8c14978f70b77ca1daf Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Thu, 22 May 2025 21:59:06 -0400 Subject: [PATCH 3/7] clean: further attempts to clean but still not matching PR --- torch_sim/optimizers.py | 41 ++++++++++++++++++----------------------- 1 file changed, 18 insertions(+), 23 deletions(-) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index a2fa19e6..1bb3f1bb 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -1164,13 +1164,13 @@ def fire_init( batch=state.batch, pbc=state.pbc, # New attributes - velocities=torch.zeros_like(state.positions), + velocities=None, forces=forces, energy=energy, stress=stress, # Cell attributes cell_positions=cell_positions, - cell_velocities=torch.zeros((n_batches, 3, 3), device=device, dtype=dtype), + cell_velocities=None, cell_forces=cell_forces, cell_masses=cell_masses, # Optimization attributes @@ -1449,6 +1449,7 @@ def _ase_fire_step( # noqa: C901, PLR0915 if state.velocities is None: state.velocities = torch.zeros_like(state.positions) + forces = state.forces if is_cell_optimization: state.cell_velocities = torch.zeros( (n_batches, 3, 3), device=device, dtype=dtype @@ -1458,8 +1459,15 @@ def _ase_fire_step( # noqa: C901, PLR0915 (n_batches,), alpha_start.item(), device=device, dtype=dtype ) + if is_cell_optimization: + forces = torch.bmm( + state.forces.unsqueeze(1), state.deform_grad()[state.batch] + ).squeeze(1) + else: + forces = state.forces + # 1. Current power (F·v) per batch (atoms + cell) - batch_power = tsm.batched_vdot(state.forces, state.velocities, state.batch) + batch_power = tsm.batched_vdot(forces, state.velocities, state.batch) if is_cell_optimization: valid_states = (UnitCellFireState, FrechetCellFIREState) @@ -1520,27 +1528,17 @@ def _ase_fire_step( # noqa: C901, PLR0915 ) # 4. Acceleration (single forward-Euler, no mass for ASE FIRE) - atom_dt = state.dt[state.batch].unsqueeze(-1) - state.velocities += atom_dt * state.forces - if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) - cell_dt = state.dt.view(-1, 1, 1) - state.cell_velocities += cell_dt * state.cell_forces - - # 5. Displacements - dr_atom = atom_dt * state.velocities - if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) - dr_cell = cell_dt * state.cell_velocities - - # 6. Position / cell update + state.velocities += state.forces * state.dt[state.batch].unsqueeze(-1) + dr_atom = state.forces * state.dt[state.batch].unsqueeze(-1) dr_scaling_batch = tsm.batched_vdot(dr_atom, dr_atom, state.batch) - old_row_vector_cell: torch.Tensor | None = None if is_cell_optimization: assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) + state.cell_velocities += state.cell_forces * state.dt.view(-1, 1, 1) + dr_cell = state.cell_velocities * state.dt.view(-1, 1, 1) + dr_scaling_batch += dr_cell.pow(2).sum(dim=(1, 2), keepdim=True).squeeze(-1) - dr_scaling_cell = torch.sqrt(dr_scaling_batch.view(n_batches, 1, 1)) + dr_scaling_cell = torch.sqrt(dr_scaling_batch).view(n_batches, 1, 1) dr_cell = torch.where( dr_scaling_cell > max_step, max_step * dr_cell / (dr_scaling_cell + eps), @@ -1550,16 +1548,13 @@ def _ase_fire_step( # noqa: C901, PLR0915 # save the old cell to allow rescaling of the positions after cell update old_row_vector_cell = state.row_vector_cell.clone() - dr_scaling_atom = torch.sqrt(dr_scaling_batch[state.batch]) + dr_scaling_atom = torch.sqrt(dr_scaling_batch)[state.batch] dr_atom = torch.where( dr_scaling_atom > max_step, max_step * dr_atom / (dr_scaling_atom + eps), dr_atom ) state.positions = state.positions + dr_atom if is_cell_optimization: - F_new: torch.Tensor | None = None - logm_F_new: torch.Tensor | None = None - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) if is_frechet: assert isinstance(state, FrechetCellFIREState) From 75ca9a55e47a78d3123407e25ea1d65eae2caa7a Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Fri, 23 May 2025 10:45:42 -0400 Subject: [PATCH 4/7] fix: dr is vdt rather than fdt --- torch_sim/optimizers.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 1bb3f1bb..22d10ecd 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -1406,8 +1406,11 @@ def _vv_fire_step( # noqa: C901, PLR0915 return state +VALID_FIRE_CELL_STATES = (UnitCellFireState, FrechetCellFIREState) + + def _ase_fire_step( # noqa: C901, PLR0915 - state: FireState | UnitCellFireState | FrechetCellFIREState, + state: FireState | VALID_FIRE_CELL_STATES, model: torch.nn.Module, *, dt_max: torch.Tensor, @@ -1451,6 +1454,10 @@ def _ase_fire_step( # noqa: C901, PLR0915 state.velocities = torch.zeros_like(state.positions) forces = state.forces if is_cell_optimization: + if not isinstance(state, VALID_FIRE_CELL_STATES): + raise ValueError( + "Cell optimization requires one of {VALID_FIRE_CELL_STATES}." + ) state.cell_velocities = torch.zeros( (n_batches, 3, 3), device=device, dtype=dtype ) @@ -1470,10 +1477,6 @@ def _ase_fire_step( # noqa: C901, PLR0915 batch_power = tsm.batched_vdot(forces, state.velocities, state.batch) if is_cell_optimization: - valid_states = (UnitCellFireState, FrechetCellFIREState) - assert isinstance(state, valid_states), ( - f"Cell optimization requires one of {valid_states}." - ) batch_power += (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) # 2. Update dt, alpha, n_pos @@ -1496,7 +1499,6 @@ def _ase_fire_step( # noqa: C901, PLR0915 f_scaling_batch = tsm.batched_vdot(state.forces, state.forces, state.batch) if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) v_scaling_batch += ( state.cell_velocities.pow(2).sum(dim=(1, 2), keepdim=True).squeeze(-1) ) @@ -1529,11 +1531,10 @@ def _ase_fire_step( # noqa: C901, PLR0915 # 4. Acceleration (single forward-Euler, no mass for ASE FIRE) state.velocities += state.forces * state.dt[state.batch].unsqueeze(-1) - dr_atom = state.forces * state.dt[state.batch].unsqueeze(-1) + dr_atom = state.velocities * state.dt[state.batch].unsqueeze(-1) dr_scaling_batch = tsm.batched_vdot(dr_atom, dr_atom, state.batch) if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) state.cell_velocities += state.cell_forces * state.dt.view(-1, 1, 1) dr_cell = state.cell_velocities * state.dt.view(-1, 1, 1) @@ -1555,7 +1556,6 @@ def _ase_fire_step( # noqa: C901, PLR0915 state.positions = state.positions + dr_atom if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) if is_frechet: assert isinstance(state, FrechetCellFIREState) new_logm_F_scaled = state.cell_positions + dr_cell @@ -1598,7 +1598,6 @@ def _ase_fire_step( # noqa: C901, PLR0915 state.energy = results["energy"] if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) state.stress = results["stress"] volumes = torch.linalg.det(state.cell).view(-1, 1, 1) if torch.any(volumes <= 0): From 971a4f5737100da6adcca92618324110f579acfd Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Fri, 23 May 2025 10:57:10 -0400 Subject: [PATCH 5/7] typing: fix typing issue --- torch_sim/optimizers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 22d10ecd..c24ea85b 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -1410,7 +1410,7 @@ def _vv_fire_step( # noqa: C901, PLR0915 def _ase_fire_step( # noqa: C901, PLR0915 - state: FireState | VALID_FIRE_CELL_STATES, + state: FireState | UnitCellFireState | FrechetCellFIREState, model: torch.nn.Module, *, dt_max: torch.Tensor, From cc94facb1c164c8def7823a464cc7711fa956d97 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Fri, 23 May 2025 13:45:30 -0400 Subject: [PATCH 6/7] wip: still not sure where the difference is now --- torch_sim/math.py | 2 +- torch_sim/optimizers.py | 150 +++++++++++++++++++++------------------- 2 files changed, 78 insertions(+), 74 deletions(-) diff --git a/torch_sim/math.py b/torch_sim/math.py index 0adfbe0a..ff78757d 100644 --- a/torch_sim/math.py +++ b/torch_sim/math.py @@ -997,7 +997,7 @@ def batched_vdot( Args: x: Tensor of shape [N_total_entities, D] (e.g., forces, velocities). - y: Tensor of shape [N_total_entities, D]. Ignored if is_sum_sq is True. + y: Tensor of shape [N_total_entities, D]. batch_indices: Tensor of shape [N_total_entities] indicating batch membership. Returns: diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index c24ea85b..bdfef729 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -270,6 +270,8 @@ def gd_init( if not isinstance(state, SimState): state = SimState(**state) + n_batches = state.n_batches + # Setup cell_factor if cell_factor is None: # Count atoms per batch @@ -283,7 +285,7 @@ def gd_init( ) # Reshape to (n_batches, 1, 1) for broadcasting - cell_factor = cell_factor.view(-1, 1, 1) + cell_factor = cell_factor.view(n_batches, 1, 1) scalar_pressure = torch.full( (state.n_batches, 1, 1), scalar_pressure, device=device, dtype=dtype @@ -316,7 +318,7 @@ def gd_init( ) # shape: (n_batches, 3, 3) # Calculate virial - volumes = torch.linalg.det(state.cell).view(-1, 1, 1) + volumes = torch.linalg.det(state.cell).view(n_batches, 1, 1) virial = -volumes * (stress + pressure) if hydrostatic_strain: @@ -391,7 +393,7 @@ def gd_step( # Get per-atom and per-cell learning rates atom_wise_lr = positions_lr[state.batch].unsqueeze(-1) - cell_wise_lr = cell_lr.view(-1, 1, 1) # shape: (n_batches, 1, 1) + cell_wise_lr = cell_lr.view(n_batches, 1, 1) # shape: (n_batches, 1, 1) # Update atomic and cell positions atomic_positions_new = state.positions + atom_wise_lr * state.forces @@ -415,7 +417,7 @@ def gd_step( state.stress = model_output["stress"] # Calculate virial for cell forces - volumes = torch.linalg.det(new_row_vector_cell).view(-1, 1, 1) + volumes = torch.linalg.det(new_row_vector_cell).view(n_batches, 1, 1) virial = -volumes * (state.stress + state.pressure) if state.hydrostatic_strain: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) @@ -811,7 +813,7 @@ def fire_init( ) # Reshape to (n_batches, 1, 1) for broadcasting - cell_factor = cell_factor.view(-1, 1, 1) + cell_factor = cell_factor.view(n_batches, 1, 1) # Setup pressure tensor pressure = scalar_pressure * torch.eye(3, device=device, dtype=dtype) @@ -824,7 +826,7 @@ def fire_init( forces = model_output["forces"] # [n_total_atoms, 3] stress = model_output["stress"] # [n_batches, 3, 3] - volumes = torch.linalg.det(state.cell).view(-1, 1, 1) + volumes = torch.linalg.det(state.cell).view(n_batches, 1, 1) virial = -volumes * (stress + pressure) # P is P_ext * I if hydrostatic_strain: @@ -1097,7 +1099,7 @@ def fire_init( ) # Reshape to (n_batches, 1, 1) for broadcasting - cell_factor = cell_factor.view(-1, 1, 1) + cell_factor = cell_factor.view(n_batches, 1, 1) # Setup pressure tensor pressure = scalar_pressure * torch.eye(3, device=device, dtype=dtype) @@ -1121,7 +1123,7 @@ def fire_init( cell_positions = torch.zeros((n_batches, 3, 3), device=device, dtype=dtype) # Calculate virial for cell forces - volumes = torch.linalg.det(state.cell).view(-1, 1, 1) + volumes = torch.linalg.det(state.cell).view(n_batches, 1, 1) virial = -volumes * (stress + pressure) # P is P_ext * I if hydrostatic_strain: @@ -1202,8 +1204,11 @@ def fire_init( return fire_init, functools.partial(step_func, **step_func_kwargs) +VALID_FIRE_CELL_STATES = UnitCellFireState | FrechetCellFIREState + + def _vv_fire_step( # noqa: C901, PLR0915 - state: FireState | UnitCellFireState | FrechetCellFIREState, + state: FireState | VALID_FIRE_CELL_STATES, model: torch.nn.Module, *, dt_max: torch.Tensor, @@ -1215,7 +1220,7 @@ def _vv_fire_step( # noqa: C901, PLR0915 eps: float, is_cell_optimization: bool = False, is_frechet: bool = False, -) -> FireState | UnitCellFireState | FrechetCellFIREState: +) -> FireState | VALID_FIRE_CELL_STATES: """Perform one Velocity-Verlet based FIRE optimization step. Implements one step of the Fast Inertial Relaxation Engine (FIRE) algorithm for @@ -1244,6 +1249,17 @@ def _vv_fire_step( # noqa: C901, PLR0915 dtype = state.positions.dtype deform_grad_new: torch.Tensor | None = None + if state.velocities is None: + state.velocities = torch.zeros_like(state.positions) + if is_cell_optimization: + if not isinstance(state, VALID_FIRE_CELL_STATES): + raise ValueError( + "Cell optimization requires one of {VALID_FIRE_CELL_STATES}." + ) + state.cell_velocities = torch.zeros( + (n_batches, 3, 3), device=device, dtype=dtype + ) + alpha_start_batch = torch.full( (n_batches,), alpha_start.item(), device=device, dtype=dtype ) @@ -1252,7 +1268,6 @@ def _vv_fire_step( # noqa: C901, PLR0915 state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1) if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) cell_wise_dt = state.dt.unsqueeze(-1).unsqueeze(-1) state.cell_velocities += ( 0.5 * cell_wise_dt * state.cell_forces / state.cell_masses.unsqueeze(-1) @@ -1261,7 +1276,6 @@ def _vv_fire_step( # noqa: C901, PLR0915 state.positions = state.positions + atom_wise_dt * state.velocities if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) cell_factor_reshaped = state.cell_factor.view(n_batches, 1, 1) if is_frechet: assert isinstance(state, FrechetCellFIREState) @@ -1284,7 +1298,6 @@ def _vv_fire_step( # noqa: C901, PLR0915 else: assert isinstance(state, UnitCellFireState) cur_deform_grad = state.deform_grad() - # cell_factor is (N,1,1) cell_factor_expanded = state.cell_factor.expand(n_batches, 3, 1) current_cell_positions_scaled = ( cur_deform_grad.view(n_batches, 3, 3) * cell_factor_expanded @@ -1305,9 +1318,8 @@ def _vv_fire_step( # noqa: C901, PLR0915 state.energy = results["energy"] if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) state.stress = results["stress"] - volumes = torch.linalg.det(state.cell).view(-1, 1, 1) + volumes = torch.linalg.det(state.cell).view(n_batches, 1, 1) virial = -volumes * (state.stress + state.pressure) if state.hydrostatic_strain: @@ -1351,66 +1363,62 @@ def _vv_fire_step( # noqa: C901, PLR0915 state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1) if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) state.cell_velocities += ( 0.5 * cell_wise_dt * state.cell_forces / state.cell_masses.unsqueeze(-1) ) - atomic_power = (state.forces * state.velocities).sum(dim=1) - atomic_power_per_batch = torch.zeros( - n_batches, device=device, dtype=atomic_power.dtype - ) - atomic_power_per_batch.scatter_add_(dim=0, index=state.batch, src=atomic_power) - batch_power = atomic_power_per_batch + batch_power = tsm.batched_vdot(state.forces, state.velocities, state.batch) if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) - cell_power = (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) - batch_power += cell_power - - for batch_idx in range(n_batches): - if batch_power[batch_idx] > 0: - state.n_pos[batch_idx] += 1 - if state.n_pos[batch_idx] > n_min: - state.dt[batch_idx] = torch.minimum(state.dt[batch_idx] * f_inc, dt_max) - state.alpha[batch_idx] = state.alpha[batch_idx] * f_alpha - else: - state.n_pos[batch_idx] = 0 - state.dt[batch_idx] = state.dt[batch_idx] * f_dec - state.alpha[batch_idx] = alpha_start_batch[batch_idx] - state.velocities[state.batch == batch_idx] = 0 - if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) - state.cell_velocities[batch_idx] = 0 - - v_norm = torch.norm(state.velocities, dim=1, keepdim=True) - f_norm = torch.norm(state.forces, dim=1, keepdim=True) - atom_wise_alpha = state.alpha[state.batch].unsqueeze(-1) - state.velocities = (1.0 - atom_wise_alpha) * state.velocities + ( - atom_wise_alpha * state.forces * v_norm / (f_norm + eps) - ) + batch_power += (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) + + # 2. Update dt, alpha, n_pos + pos_mask_batch = batch_power > 0.0 + neg_mask_batch = ~pos_mask_batch + + state.n_pos[pos_mask_batch] += 1 + inc_mask = (state.n_pos > n_min) & pos_mask_batch + state.dt[inc_mask] = torch.minimum(state.dt[inc_mask] * f_inc, dt_max) + state.alpha[inc_mask] *= f_alpha + + state.dt[neg_mask_batch] *= f_dec + state.alpha[neg_mask_batch] = alpha_start_batch[neg_mask_batch] + state.n_pos[neg_mask_batch] = 0 + + v_scaling_batch = tsm.batched_vdot(state.velocities, state.velocities, state.batch) + f_scaling_batch = tsm.batched_vdot(state.forces, state.forces, state.batch) if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) - cell_v_norm = torch.norm(state.cell_velocities, dim=(1, 2), keepdim=True) - cell_f_norm = torch.norm(state.cell_forces, dim=(1, 2), keepdim=True) - cell_wise_alpha = state.alpha.unsqueeze(-1).unsqueeze(-1) - cell_mask = (cell_f_norm > eps).expand_as(state.cell_velocities) + v_scaling_batch += state.cell_velocities.pow(2).sum(dim=(1, 2)) + f_scaling_batch += state.cell_forces.pow(2).sum(dim=(1, 2)) + + v_scaling_cell = torch.sqrt(v_scaling_batch.view(n_batches, 1, 1)) + f_scaling_cell = torch.sqrt(f_scaling_batch.view(n_batches, 1, 1)) + v_mixing_cell = state.cell_forces / (f_scaling_cell + eps) * v_scaling_cell + + alpha_cell_bc = state.alpha.view(n_batches, 1, 1) state.cell_velocities = torch.where( - cell_mask, - (1.0 - cell_wise_alpha) * state.cell_velocities - + cell_wise_alpha * state.cell_forces * cell_v_norm / (cell_f_norm + eps), - state.cell_velocities, + pos_mask_batch.view(n_batches, 1, 1), + (1.0 - alpha_cell_bc) * state.cell_velocities + alpha_cell_bc * v_mixing_cell, + torch.zeros_like(state.cell_velocities), ) - return state + v_scaling_atom = torch.sqrt(v_scaling_batch[state.batch].unsqueeze(-1)) + f_scaling_atom = torch.sqrt(f_scaling_batch[state.batch].unsqueeze(-1)) + v_mixing_atom = state.forces * (v_scaling_atom / (f_scaling_atom + eps)) + alpha_atom = state.alpha[state.batch].unsqueeze(-1) # per-atom alpha + state.velocities = torch.where( + pos_mask_batch[state.batch].unsqueeze(-1), + (1.0 - alpha_atom) * state.velocities + alpha_atom * v_mixing_atom, + torch.zeros_like(state.velocities), + ) -VALID_FIRE_CELL_STATES = (UnitCellFireState, FrechetCellFIREState) + return state def _ase_fire_step( # noqa: C901, PLR0915 - state: FireState | UnitCellFireState | FrechetCellFIREState, + state: FireState | VALID_FIRE_CELL_STATES, model: torch.nn.Module, *, dt_max: torch.Tensor, @@ -1423,7 +1431,7 @@ def _ase_fire_step( # noqa: C901, PLR0915 eps: float, is_cell_optimization: bool = False, is_frechet: bool = False, -) -> FireState | UnitCellFireState | FrechetCellFIREState: +) -> FireState | VALID_FIRE_CELL_STATES: """Perform one ASE-style FIRE optimization step. Implements one step of the Fast Inertial Relaxation Engine (FIRE) algorithm @@ -1499,20 +1507,16 @@ def _ase_fire_step( # noqa: C901, PLR0915 f_scaling_batch = tsm.batched_vdot(state.forces, state.forces, state.batch) if is_cell_optimization: - v_scaling_batch += ( - state.cell_velocities.pow(2).sum(dim=(1, 2), keepdim=True).squeeze(-1) - ) - f_scaling_batch += ( - state.cell_forces.pow(2).sum(dim=(1, 2), keepdim=True).squeeze(-1) - ) + v_scaling_batch += state.cell_velocities.pow(2).sum(dim=(1, 2)) + f_scaling_batch += state.cell_forces.pow(2).sum(dim=(1, 2)) v_scaling_cell = torch.sqrt(v_scaling_batch.view(n_batches, 1, 1)) f_scaling_cell = torch.sqrt(f_scaling_batch.view(n_batches, 1, 1)) v_mixing_cell = state.cell_forces / (f_scaling_cell + eps) * v_scaling_cell - alpha_cell_bc = state.alpha.view(-1, 1, 1) + alpha_cell_bc = state.alpha.view(n_batches, 1, 1) state.cell_velocities = torch.where( - pos_mask_batch.view(-1, 1, 1), + pos_mask_batch.view(n_batches, 1, 1), (1.0 - alpha_cell_bc) * state.cell_velocities + alpha_cell_bc * v_mixing_cell, torch.zeros_like(state.cell_velocities), @@ -1535,10 +1539,10 @@ def _ase_fire_step( # noqa: C901, PLR0915 dr_scaling_batch = tsm.batched_vdot(dr_atom, dr_atom, state.batch) if is_cell_optimization: - state.cell_velocities += state.cell_forces * state.dt.view(-1, 1, 1) - dr_cell = state.cell_velocities * state.dt.view(-1, 1, 1) + state.cell_velocities += state.cell_forces * state.dt.view(n_batches, 1, 1) + dr_cell = state.cell_velocities * state.dt.view(n_batches, 1, 1) - dr_scaling_batch += dr_cell.pow(2).sum(dim=(1, 2), keepdim=True).squeeze(-1) + dr_scaling_batch += dr_cell.pow(2).sum(dim=(1, 2)) dr_scaling_cell = torch.sqrt(dr_scaling_batch).view(n_batches, 1, 1) dr_cell = torch.where( dr_scaling_cell > max_step, @@ -1549,7 +1553,7 @@ def _ase_fire_step( # noqa: C901, PLR0915 # save the old cell to allow rescaling of the positions after cell update old_row_vector_cell = state.row_vector_cell.clone() - dr_scaling_atom = torch.sqrt(dr_scaling_batch)[state.batch] + dr_scaling_atom = torch.sqrt(dr_scaling_batch)[state.batch].unsqueeze(-1) dr_atom = torch.where( dr_scaling_atom > max_step, max_step * dr_atom / (dr_scaling_atom + eps), dr_atom ) @@ -1599,7 +1603,7 @@ def _ase_fire_step( # noqa: C901, PLR0915 if is_cell_optimization: state.stress = results["stress"] - volumes = torch.linalg.det(state.cell).view(-1, 1, 1) + volumes = torch.linalg.det(state.cell).view(n_batches, 1, 1) if torch.any(volumes <= 0): bad_indices = torch.where(volumes <= 0)[0].tolist() print( From 4a494e6de2f7f52bed22de386635bc1e17e9667b Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Mon, 26 May 2025 20:15:53 -0400 Subject: [PATCH 7/7] update forces per comment --- torch_sim/optimizers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index bdfef729..108b87a8 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -1504,7 +1504,7 @@ def _ase_fire_step( # noqa: C901, PLR0915 v_scaling_batch = tsm.batched_vdot( state.velocities, state.velocities, state.batch ) - f_scaling_batch = tsm.batched_vdot(state.forces, state.forces, state.batch) + f_scaling_batch = tsm.batched_vdot(forces, forces, state.batch) if is_cell_optimization: v_scaling_batch += state.cell_velocities.pow(2).sum(dim=(1, 2)) @@ -1524,7 +1524,7 @@ def _ase_fire_step( # noqa: C901, PLR0915 v_scaling_atom = torch.sqrt(v_scaling_batch[state.batch].unsqueeze(-1)) f_scaling_atom = torch.sqrt(f_scaling_batch[state.batch].unsqueeze(-1)) - v_mixing_atom = state.forces * (v_scaling_atom / (f_scaling_atom + eps)) + v_mixing_atom = forces * (v_scaling_atom / (f_scaling_atom + eps)) alpha_atom = state.alpha[state.batch].unsqueeze(-1) # per-atom alpha state.velocities = torch.where( @@ -1534,7 +1534,7 @@ def _ase_fire_step( # noqa: C901, PLR0915 ) # 4. Acceleration (single forward-Euler, no mass for ASE FIRE) - state.velocities += state.forces * state.dt[state.batch].unsqueeze(-1) + state.velocities += forces * state.dt[state.batch].unsqueeze(-1) dr_atom = state.velocities * state.dt[state.batch].unsqueeze(-1) dr_scaling_batch = tsm.batched_vdot(dr_atom, dr_atom, state.batch)