diff --git a/torch_sim/math.py b/torch_sim/math.py index fd9182d2..ff78757d 100644 --- a/torch_sim/math.py +++ b/torch_sim/math.py @@ -987,3 +987,28 @@ def matrix_log_33( print(msg) # Fall back to scipy implementation return matrix_log_scipy(matrix).to(sim_dtype) + + +def batched_vdot( + x: torch.Tensor, y: torch.Tensor, batch_indices: torch.Tensor +) -> torch.Tensor: + """Computes batched vdot (sum of element-wise product) for groups of vectors. + Passing the same tensor as both x and y gives the per-batch sum of squares. + + Args: + x: Tensor of shape [N_total_entities, D] (e.g., forces, velocities). + y: Tensor of shape [N_total_entities, D]; its shape is not checked here. + batch_indices: Tensor of shape [N_total_entities] indicating batch membership. + + Returns: + Tensor: shape [n_batches] where each element is the sum(x_i * y_i) + for entities belonging to that batch (n_batches is one more than batch_indices.max()), + summed over all components D and all entities in the batch. + """ + if x.ndim != 2 or batch_indices.ndim != 1 or x.shape[0] != batch_indices.shape[0]: + raise ValueError(f"Invalid input shapes: {x.shape=}, {batch_indices.shape=}") + + output = torch.zeros(batch_indices.max() + 1, dtype=x.dtype, device=x.device) + output.scatter_add_(dim=0, index=batch_indices, src=(x * y).sum(dim=1)) + + return output diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 0d94b4a5..108b87a8 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -270,6 +270,8 @@ def gd_init( if not isinstance(state, SimState): state = SimState(**state) + n_batches = state.n_batches + # Setup cell_factor if cell_factor is None: # Count atoms per batch @@ -283,7 +285,7 @@ def gd_init( ) # Reshape to (n_batches, 1, 1) for broadcasting - cell_factor = cell_factor.view(-1, 1, 1) + cell_factor = cell_factor.view(n_batches, 1, 1) scalar_pressure = torch.full( (state.n_batches, 1, 1), scalar_pressure, device=device, dtype=dtype @@ -316,7 +318,7 @@ def gd_init( ) # shape: (n_batches, 3, 3) # Calculate 
virial - volumes = torch.linalg.det(state.cell).view(-1, 1, 1) + volumes = torch.linalg.det(state.cell).view(n_batches, 1, 1) virial = -volumes * (stress + pressure) if hydrostatic_strain: @@ -391,7 +393,7 @@ def gd_step( # Get per-atom and per-cell learning rates atom_wise_lr = positions_lr[state.batch].unsqueeze(-1) - cell_wise_lr = cell_lr.view(-1, 1, 1) # shape: (n_batches, 1, 1) + cell_wise_lr = cell_lr.view(n_batches, 1, 1) # shape: (n_batches, 1, 1) # Update atomic and cell positions atomic_positions_new = state.positions + atom_wise_lr * state.forces @@ -415,7 +417,7 @@ def gd_step( state.stress = model_output["stress"] # Calculate virial for cell forces - volumes = torch.linalg.det(new_row_vector_cell).view(-1, 1, 1) + volumes = torch.linalg.det(new_row_vector_cell).view(n_batches, 1, 1) virial = -volumes * (state.stress + state.pressure) if state.hydrostatic_strain: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) @@ -811,7 +813,7 @@ def fire_init( ) # Reshape to (n_batches, 1, 1) for broadcasting - cell_factor = cell_factor.view(-1, 1, 1) + cell_factor = cell_factor.view(n_batches, 1, 1) # Setup pressure tensor pressure = scalar_pressure * torch.eye(3, device=device, dtype=dtype) @@ -824,7 +826,7 @@ def fire_init( forces = model_output["forces"] # [n_total_atoms, 3] stress = model_output["stress"] # [n_batches, 3, 3] - volumes = torch.linalg.det(state.cell).view(-1, 1, 1) + volumes = torch.linalg.det(state.cell).view(n_batches, 1, 1) virial = -volumes * (stress + pressure) # P is P_ext * I if hydrostatic_strain: @@ -1097,7 +1099,7 @@ def fire_init( ) # Reshape to (n_batches, 1, 1) for broadcasting - cell_factor = cell_factor.view(-1, 1, 1) + cell_factor = cell_factor.view(n_batches, 1, 1) # Setup pressure tensor pressure = scalar_pressure * torch.eye(3, device=device, dtype=dtype) @@ -1121,7 +1123,7 @@ def fire_init( cell_positions = torch.zeros((n_batches, 3, 3), device=device, dtype=dtype) # Calculate virial for cell forces 
- volumes = torch.linalg.det(state.cell).view(-1, 1, 1) + volumes = torch.linalg.det(state.cell).view(n_batches, 1, 1) virial = -volumes * (stress + pressure) # P is P_ext * I if hydrostatic_strain: @@ -1164,13 +1166,13 @@ def fire_init( batch=state.batch, pbc=state.pbc, # New attributes - velocities=torch.zeros_like(state.positions), + velocities=None, forces=forces, energy=energy, stress=stress, # Cell attributes cell_positions=cell_positions, - cell_velocities=torch.zeros((n_batches, 3, 3), device=device, dtype=dtype), + cell_velocities=None, cell_forces=cell_forces, cell_masses=cell_masses, # Optimization attributes @@ -1202,8 +1204,11 @@ def fire_init( return fire_init, functools.partial(step_func, **step_func_kwargs) +VALID_FIRE_CELL_STATES = UnitCellFireState | FrechetCellFIREState + + def _vv_fire_step( # noqa: C901, PLR0915 - state: FireState | UnitCellFireState | FrechetCellFIREState, + state: FireState | VALID_FIRE_CELL_STATES, model: torch.nn.Module, *, dt_max: torch.Tensor, @@ -1215,7 +1220,7 @@ def _vv_fire_step( # noqa: C901, PLR0915 eps: float, is_cell_optimization: bool = False, is_frechet: bool = False, -) -> FireState | UnitCellFireState | FrechetCellFIREState: +) -> FireState | VALID_FIRE_CELL_STATES: """Perform one Velocity-Verlet based FIRE optimization step. Implements one step of the Fast Inertial Relaxation Engine (FIRE) algorithm for @@ -1244,6 +1249,17 @@ def _vv_fire_step( # noqa: C901, PLR0915 dtype = state.positions.dtype deform_grad_new: torch.Tensor | None = None + if state.velocities is None: + state.velocities = torch.zeros_like(state.positions) + if is_cell_optimization: + if not isinstance(state, VALID_FIRE_CELL_STATES): + raise ValueError( + "Cell optimization requires one of {VALID_FIRE_CELL_STATES}." 
+ ) + state.cell_velocities = torch.zeros( + (n_batches, 3, 3), device=device, dtype=dtype + ) + alpha_start_batch = torch.full( (n_batches,), alpha_start.item(), device=device, dtype=dtype ) @@ -1252,7 +1268,6 @@ def _vv_fire_step( # noqa: C901, PLR0915 state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1) if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) cell_wise_dt = state.dt.unsqueeze(-1).unsqueeze(-1) state.cell_velocities += ( 0.5 * cell_wise_dt * state.cell_forces / state.cell_masses.unsqueeze(-1) @@ -1261,7 +1276,6 @@ def _vv_fire_step( # noqa: C901, PLR0915 state.positions = state.positions + atom_wise_dt * state.velocities if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) cell_factor_reshaped = state.cell_factor.view(n_batches, 1, 1) if is_frechet: assert isinstance(state, FrechetCellFIREState) @@ -1284,7 +1298,6 @@ def _vv_fire_step( # noqa: C901, PLR0915 else: assert isinstance(state, UnitCellFireState) cur_deform_grad = state.deform_grad() - # cell_factor is (N,1,1) cell_factor_expanded = state.cell_factor.expand(n_batches, 3, 1) current_cell_positions_scaled = ( cur_deform_grad.view(n_batches, 3, 3) * cell_factor_expanded @@ -1305,9 +1318,8 @@ def _vv_fire_step( # noqa: C901, PLR0915 state.energy = results["energy"] if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) state.stress = results["stress"] - volumes = torch.linalg.det(state.cell).view(-1, 1, 1) + volumes = torch.linalg.det(state.cell).view(n_batches, 1, 1) virial = -volumes * (state.stress + state.pressure) if state.hydrostatic_strain: @@ -1351,63 +1363,62 @@ def _vv_fire_step( # noqa: C901, PLR0915 state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1) if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) state.cell_velocities += ( 0.5 * cell_wise_dt * 
state.cell_forces / state.cell_masses.unsqueeze(-1) ) - atomic_power = (state.forces * state.velocities).sum(dim=1) - atomic_power_per_batch = torch.zeros( - n_batches, device=device, dtype=atomic_power.dtype - ) - atomic_power_per_batch.scatter_add_(dim=0, index=state.batch, src=atomic_power) - batch_power = atomic_power_per_batch + batch_power = tsm.batched_vdot(state.forces, state.velocities, state.batch) if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) - cell_power = (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) - batch_power += cell_power - - for batch_idx in range(n_batches): - if batch_power[batch_idx] > 0: - state.n_pos[batch_idx] += 1 - if state.n_pos[batch_idx] > n_min: - state.dt[batch_idx] = torch.minimum(state.dt[batch_idx] * f_inc, dt_max) - state.alpha[batch_idx] = state.alpha[batch_idx] * f_alpha - else: - state.n_pos[batch_idx] = 0 - state.dt[batch_idx] = state.dt[batch_idx] * f_dec - state.alpha[batch_idx] = alpha_start_batch[batch_idx] - state.velocities[state.batch == batch_idx] = 0 - if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) - state.cell_velocities[batch_idx] = 0 - - v_norm = torch.norm(state.velocities, dim=1, keepdim=True) - f_norm = torch.norm(state.forces, dim=1, keepdim=True) - atom_wise_alpha = state.alpha[state.batch].unsqueeze(-1) - state.velocities = (1.0 - atom_wise_alpha) * state.velocities + ( - atom_wise_alpha * state.forces * v_norm / (f_norm + eps) - ) + batch_power += (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) + + # 2. 
Update dt, alpha, n_pos + pos_mask_batch = batch_power > 0.0 + neg_mask_batch = ~pos_mask_batch + + state.n_pos[pos_mask_batch] += 1 + inc_mask = (state.n_pos > n_min) & pos_mask_batch + state.dt[inc_mask] = torch.minimum(state.dt[inc_mask] * f_inc, dt_max) + state.alpha[inc_mask] *= f_alpha + + state.dt[neg_mask_batch] *= f_dec + state.alpha[neg_mask_batch] = alpha_start_batch[neg_mask_batch] + state.n_pos[neg_mask_batch] = 0 + + v_scaling_batch = tsm.batched_vdot(state.velocities, state.velocities, state.batch) + f_scaling_batch = tsm.batched_vdot(state.forces, state.forces, state.batch) if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) - cell_v_norm = torch.norm(state.cell_velocities, dim=(1, 2), keepdim=True) - cell_f_norm = torch.norm(state.cell_forces, dim=(1, 2), keepdim=True) - cell_wise_alpha = state.alpha.unsqueeze(-1).unsqueeze(-1) - cell_mask = (cell_f_norm > eps).expand_as(state.cell_velocities) + v_scaling_batch += state.cell_velocities.pow(2).sum(dim=(1, 2)) + f_scaling_batch += state.cell_forces.pow(2).sum(dim=(1, 2)) + + v_scaling_cell = torch.sqrt(v_scaling_batch.view(n_batches, 1, 1)) + f_scaling_cell = torch.sqrt(f_scaling_batch.view(n_batches, 1, 1)) + v_mixing_cell = state.cell_forces / (f_scaling_cell + eps) * v_scaling_cell + + alpha_cell_bc = state.alpha.view(n_batches, 1, 1) state.cell_velocities = torch.where( - cell_mask, - (1.0 - cell_wise_alpha) * state.cell_velocities - + cell_wise_alpha * state.cell_forces * cell_v_norm / (cell_f_norm + eps), - state.cell_velocities, + pos_mask_batch.view(n_batches, 1, 1), + (1.0 - alpha_cell_bc) * state.cell_velocities + alpha_cell_bc * v_mixing_cell, + torch.zeros_like(state.cell_velocities), ) + v_scaling_atom = torch.sqrt(v_scaling_batch[state.batch].unsqueeze(-1)) + f_scaling_atom = torch.sqrt(f_scaling_batch[state.batch].unsqueeze(-1)) + v_mixing_atom = state.forces * (v_scaling_atom / (f_scaling_atom + eps)) + + alpha_atom = 
state.alpha[state.batch].unsqueeze(-1) # per-atom alpha + state.velocities = torch.where( + pos_mask_batch[state.batch].unsqueeze(-1), + (1.0 - alpha_atom) * state.velocities + alpha_atom * v_mixing_atom, + torch.zeros_like(state.velocities), + ) + return state def _ase_fire_step( # noqa: C901, PLR0915 - state: FireState | UnitCellFireState | FrechetCellFIREState, + state: FireState | VALID_FIRE_CELL_STATES, model: torch.nn.Module, *, dt_max: torch.Tensor, @@ -1420,7 +1431,7 @@ def _ase_fire_step( # noqa: C901, PLR0915 eps: float, is_cell_optimization: bool = False, is_frechet: bool = False, -) -> FireState | UnitCellFireState | FrechetCellFIREState: +) -> FireState | VALID_FIRE_CELL_STATES: """Perform one ASE-style FIRE optimization step. Implements one step of the Fast Inertial Relaxation Engine (FIRE) algorithm @@ -1447,184 +1458,158 @@ def _ase_fire_step( # noqa: C901, PLR0915 device, dtype = state.positions.device, state.positions.dtype n_batches = state.n_batches - # Setup batch-wise alpha_start for potential reset - # alpha_start is a 0-dim tensor from the factory - alpha_start_batch = torch.full( - (n_batches,), alpha_start.item(), device=device, dtype=dtype - ) + if state.velocities is None: + state.velocities = torch.zeros_like(state.positions) + forces = state.forces + if is_cell_optimization: + if not isinstance(state, VALID_FIRE_CELL_STATES): + raise ValueError( + "Cell optimization requires one of {VALID_FIRE_CELL_STATES}." + ) + state.cell_velocities = torch.zeros( + (n_batches, 3, 3), device=device, dtype=dtype + ) + else: + alpha_start_batch = torch.full( + (n_batches,), alpha_start.item(), device=device, dtype=dtype + ) - # 1. 
Current power (F·v) per batch (atoms + cell) - atomic_power = (state.forces * state.velocities).sum(dim=1) - batch_power = torch.zeros(n_batches, device=device, dtype=dtype) - batch_power.scatter_add_(0, state.batch, atomic_power) + if is_cell_optimization: + forces = torch.bmm( + state.forces.unsqueeze(1), state.deform_grad()[state.batch] + ).squeeze(1) + else: + forces = state.forces - if is_cell_optimization: - valid_states = (UnitCellFireState, FrechetCellFIREState) - assert isinstance(state, valid_states), ( - f"Cell optimization requires one of {valid_states}." - ) - cell_power = (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) - batch_power += cell_power + # 1. Current power (F·v) per batch (atoms + cell) + batch_power = tsm.batched_vdot(forces, state.velocities, state.batch) - # 2. Update dt, alpha, n_pos - pos_mask_batch = batch_power > 0.0 - neg_mask_batch = ~pos_mask_batch + if is_cell_optimization: + batch_power += (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) - state.n_pos[pos_mask_batch] += 1 - inc_mask = (state.n_pos > n_min) & pos_mask_batch - state.dt[inc_mask] = torch.minimum(state.dt[inc_mask] * f_inc, dt_max) - state.alpha[inc_mask] *= f_alpha + # 2. Update dt, alpha, n_pos + pos_mask_batch = batch_power > 0.0 + neg_mask_batch = ~pos_mask_batch - state.dt[neg_mask_batch] *= f_dec - state.alpha[neg_mask_batch] = alpha_start_batch[neg_mask_batch] - state.n_pos[neg_mask_batch] = 0 + state.n_pos[pos_mask_batch] += 1 + inc_mask = (state.n_pos > n_min) & pos_mask_batch + state.dt[inc_mask] = torch.minimum(state.dt[inc_mask] * f_inc, dt_max) + state.alpha[inc_mask] *= f_alpha - # 3. 
Velocity mixing BEFORE acceleration (ASE ordering) - # Atoms - v_norm_atom = torch.norm(state.velocities, dim=1, keepdim=True) - f_norm_atom = torch.norm(state.forces, dim=1, keepdim=True) - f_unit_atom = state.forces / (f_norm_atom + eps) - alpha_atom = state.alpha[state.batch].unsqueeze(-1) - pos_mask_atom = pos_mask_batch[state.batch].unsqueeze(-1) - v_new_atom = ( - 1.0 - alpha_atom - ) * state.velocities + alpha_atom * f_unit_atom * v_norm_atom - state.velocities = torch.where( - pos_mask_atom, v_new_atom, torch.zeros_like(state.velocities) - ) + state.dt[neg_mask_batch] *= f_dec + state.alpha[neg_mask_batch] = alpha_start_batch[neg_mask_batch] + state.n_pos[neg_mask_batch] = 0 - if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) - # Cell velocity mixing - cv_norm = torch.norm(state.cell_velocities, dim=(1, 2), keepdim=True) - cf_norm = torch.norm(state.cell_forces, dim=(1, 2), keepdim=True) - cf_unit = state.cell_forces / (cf_norm + eps) - alpha_cell_bc = state.alpha.view(-1, 1, 1) - pos_mask_cell_bc = pos_mask_batch.view(-1, 1, 1) - v_new_cell = ( - 1.0 - alpha_cell_bc - ) * state.cell_velocities + alpha_cell_bc * cf_unit * cv_norm - state.cell_velocities = torch.where( - pos_mask_cell_bc, v_new_cell, torch.zeros_like(state.cell_velocities) + # 3. 
Velocity mixing BEFORE acceleration (ASE ordering) + v_scaling_batch = tsm.batched_vdot( + state.velocities, state.velocities, state.batch ) + f_scaling_batch = tsm.batched_vdot(forces, forces, state.batch) + + if is_cell_optimization: + v_scaling_batch += state.cell_velocities.pow(2).sum(dim=(1, 2)) + f_scaling_batch += state.cell_forces.pow(2).sum(dim=(1, 2)) + + v_scaling_cell = torch.sqrt(v_scaling_batch.view(n_batches, 1, 1)) + f_scaling_cell = torch.sqrt(f_scaling_batch.view(n_batches, 1, 1)) + v_mixing_cell = state.cell_forces / (f_scaling_cell + eps) * v_scaling_cell + + alpha_cell_bc = state.alpha.view(n_batches, 1, 1) + state.cell_velocities = torch.where( + pos_mask_batch.view(n_batches, 1, 1), + (1.0 - alpha_cell_bc) * state.cell_velocities + + alpha_cell_bc * v_mixing_cell, + torch.zeros_like(state.cell_velocities), + ) - # 4. Acceleration (single forward-Euler, no mass for ASE FIRE) - atom_dt = state.dt[state.batch].unsqueeze(-1) - state.velocities += atom_dt * state.forces + v_scaling_atom = torch.sqrt(v_scaling_batch[state.batch].unsqueeze(-1)) + f_scaling_atom = torch.sqrt(f_scaling_batch[state.batch].unsqueeze(-1)) + v_mixing_atom = forces * (v_scaling_atom / (f_scaling_atom + eps)) - if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) - cell_dt = state.dt.view(-1, 1, 1) - state.cell_velocities += cell_dt * state.cell_forces - - # 5. Displacements - dr_atom = atom_dt * state.velocities - if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) - dr_cell = cell_dt * state.cell_velocities + alpha_atom = state.alpha[state.batch].unsqueeze(-1) # per-atom alpha + state.velocities = torch.where( + pos_mask_batch[state.batch].unsqueeze(-1), + (1.0 - alpha_atom) * state.velocities + alpha_atom * v_mixing_atom, + torch.zeros_like(state.velocities), + ) - # 6. 
Clamp to max_step - dr_norm_atom = torch.norm(dr_atom, dim=1, keepdim=True) - mask_atom_max_step = dr_norm_atom > max_step - dr_atom = torch.where( - mask_atom_max_step, max_step * dr_atom / (dr_norm_atom + eps), dr_atom - ) + # 4. Acceleration (single forward-Euler, no mass for ASE FIRE) + state.velocities += forces * state.dt[state.batch].unsqueeze(-1) + dr_atom = state.velocities * state.dt[state.batch].unsqueeze(-1) + dr_scaling_batch = tsm.batched_vdot(dr_atom, dr_atom, state.batch) - old_row_vector_cell: torch.Tensor | None = None if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) - # Cell clamp to max_step (Frobenius norm) - dr_cell_norm_fro = torch.norm(dr_cell.view(n_batches, -1), dim=1, keepdim=True) - mask_cell_max_step = dr_cell_norm_fro.view(n_batches, 1, 1) > max_step + state.cell_velocities += state.cell_forces * state.dt.view(n_batches, 1, 1) + dr_cell = state.cell_velocities * state.dt.view(n_batches, 1, 1) + + dr_scaling_batch += dr_cell.pow(2).sum(dim=(1, 2)) + dr_scaling_cell = torch.sqrt(dr_scaling_batch).view(n_batches, 1, 1) dr_cell = torch.where( - mask_cell_max_step, - max_step * dr_cell / (dr_cell_norm_fro.view(n_batches, 1, 1) + eps), + dr_scaling_cell > max_step, + max_step * dr_cell / (dr_scaling_cell + eps), dr_cell, ) - # 7. 
Position / cell update - # Store old cell for scaling atoms - # Ensure old_row_vector_cell is cloned before any modification to state.cell or - # state.row_vector_cell + # save the old cell to allow rescaling of the positions after cell update old_row_vector_cell = state.row_vector_cell.clone() + dr_scaling_atom = torch.sqrt(dr_scaling_batch)[state.batch].unsqueeze(-1) + dr_atom = torch.where( + dr_scaling_atom > max_step, max_step * dr_atom / (dr_scaling_atom + eps), dr_atom + ) state.positions = state.positions + dr_atom - # F_new stores F_new for Frechet's ucf_cell_grad if needed - F_new: torch.Tensor | None = None - # logm_F_new stores logm_F_new for Frechet's cell_forces recalc if needed - logm_F_new: torch.Tensor | None = None - if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) if is_frechet: assert isinstance(state, FrechetCellFIREState) - # Frechet cell update logic new_logm_F_scaled = state.cell_positions + dr_cell state.cell_positions = new_logm_F_scaled - # cell_factor is (N,1,1) logm_F_new = new_logm_F_scaled / (state.cell_factor + eps) F_new = torch.matrix_exp(logm_F_new) new_row_vector_cell = torch.bmm( state.reference_row_vector_cell, F_new.transpose(-2, -1) ) state.row_vector_cell = new_row_vector_cell - else: # UnitCellFire + else: assert isinstance(state, UnitCellFireState) - # Unit cell update logic F_current = state.deform_grad() - # state.cell_factor is (N,1,1), F_current is (N,3,3) - # cell_factor_exp for element-wise F_current * cell_factor_exp should be - # (N,3,3) or broadcast from (N,1,1) or (N,3,1) cell_factor_exp_mult = state.cell_factor.expand(n_batches, 3, 1) current_F_scaled = F_current * cell_factor_exp_mult F_new_scaled = current_F_scaled + dr_cell - state.cell_positions = F_new_scaled # track the scaled deformation gradient - F_new = F_new_scaled / (cell_factor_exp_mult + eps) # Division by (N,3,1) - # When state.cell is set, state.row_vector_cell is auto-updated + state.cell_positions = 
F_new_scaled + F_new = F_new_scaled / (cell_factor_exp_mult + eps) new_cell_column_vectors = torch.bmm( state.reference_cell, F_new.transpose(-2, -1) ) state.cell = new_cell_column_vectors - # Scale atomic positions according to cell change (mimicking scale_atoms=True) - if is_cell_optimization and old_row_vector_cell is not None: - current_new_row_vector_cell = state.row_vector_cell # This is A_new after update - + # rescale the positions after cell update + current_new_row_vector_cell = state.row_vector_cell inv_old_cell_batch = torch.linalg.inv(old_row_vector_cell) - # Transform matrix T such that A_new = A_old @ T (for row vectors A) - # This means cartesian positions P_new_row = P_old_row @ T transform_matrix_batch = torch.bmm( inv_old_cell_batch, current_new_row_vector_cell - ) # Shape [N_batch, 3, 3] - - # Shape: [N_atoms, 3, 3] + ) atom_specific_transform = transform_matrix_batch[state.batch] - - # state.positions is [N_atoms, 3]. Unsqueeze to [N_atoms, 1, 3] for bmm - # Result of bmm will be [N_atoms, 1, 3], then squeeze scaled_positions = torch.bmm( state.positions.unsqueeze(1), atom_specific_transform ).squeeze(1) state.positions = scaled_positions - # 8. Force / stress refresh & new cell forces + # 7. 
Force / stress refresh & new cell forces results = model(state) state.forces = results["forces"] state.energy = results["energy"] if is_cell_optimization: - assert isinstance(state, (UnitCellFireState, FrechetCellFIREState)) state.stress = results["stress"] - volumes = torch.linalg.det(state.cell).view(-1, 1, 1) + volumes = torch.linalg.det(state.cell).view(n_batches, 1, 1) if torch.any(volumes <= 0): bad_indices = torch.where(volumes <= 0)[0].tolist() print( f"WARNING: Non-positive volume(s) detected during _ase_fire_step: " f"{volumes[bad_indices].tolist()} at {bad_indices=} ({is_frechet=})" ) - # volumes = torch.clamp(volumes, min=eps) # Optional: for stability virial = -volumes * (state.stress + state.pressure) @@ -1633,7 +1618,8 @@ def _ase_fire_step( # noqa: C901, PLR0915 virial = diag_mean.unsqueeze(-1) * torch.eye( 3, device=device, dtype=dtype ).unsqueeze(0).expand(n_batches, -1, -1) - if state.constant_volume: # Can be true even if hydrostatic_strain is false + + if state.constant_volume: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) virial = virial - diag_mean.unsqueeze(-1) * torch.eye( 3, device=device, dtype=dtype @@ -1647,7 +1633,6 @@ def _ase_fire_step( # noqa: C901, PLR0915 assert logm_F_new is not None, ( "logm_F_new should be defined for Frechet cell force calculation" ) - # Frechet cell force recalculation ucf_cell_grad = torch.bmm( virial, torch.linalg.inv(torch.transpose(F_new, 1, 2)) ) @@ -1659,7 +1644,6 @@ def _ase_fire_step( # noqa: C901, PLR0915 new_cell_forces_log_space = torch.zeros_like(state.cell_forces) for b_idx in range(n_batches): - # logm_F_new[b_idx] is the current point in log-space expm_derivs = torch.stack( [ tsm.expm_frechet(logm_F_new[b_idx], direction, compute_expm=False) @@ -1670,12 +1654,9 @@ def _ase_fire_step( # noqa: C901, PLR0915 expm_derivs * ucf_cell_grad[b_idx].unsqueeze(0), dim=(1, 2) ) new_cell_forces_log_space[b_idx] = forces_flat.reshape(3, 3) - state.cell_forces = 
new_cell_forces_log_space / ( - state.cell_factor + eps - ) # cell_factor is (N,1,1) - else: # UnitCellFire + state.cell_forces = new_cell_forces_log_space / (state.cell_factor + eps) + else: assert isinstance(state, UnitCellFireState) - # Unit cell force recalculation - state.cell_forces = virial / state.cell_factor # cell_factor is (N,1,1) + state.cell_forces = virial / state.cell_factor return state