From 564547fcb25a7406940a9c1210e74a5853689ca0 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Thu, 17 Apr 2025 15:34:03 -0400 Subject: [PATCH 1/3] fix virial calculations in optimizers and integrators --- torch_sim/integrators.py | 2 +- torch_sim/optimizers.py | 25 +++++++++------------ torch_sim/unbatched/unbatched_optimizers.py | 8 +++---- 3 files changed, 15 insertions(+), 20 deletions(-) diff --git a/torch_sim/integrators.py b/torch_sim/integrators.py index 8bca00aa..f2c872fe 100644 --- a/torch_sim/integrators.py +++ b/torch_sim/integrators.py @@ -603,7 +603,7 @@ def _compute_cell_force( # Calculate virials from stress and external pressure # Internal stress is negative of virial tensor divided by volume - virial = -volumes * state.stress + pressure_tensor * volumes + virial = -volumes * (state.stress + pressure_tensor) # Add kinetic contribution (kT * Identity) batch_kT = kT diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index b503e7b5..543cce42 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -315,7 +315,7 @@ def gd_init( # Calculate virial volumes = torch.linalg.det(state.cell).view(-1, 1, 1) - virial = -volumes * stress + pressure + virial = -volumes * (stress + pressure) if hydrostatic_strain: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) @@ -420,7 +420,7 @@ def gd_step( # Calculate virial for cell forces volumes = torch.linalg.det(new_row_vector_cell).view(-1, 1, 1) - virial = -volumes * state.stress + state.pressure + virial = -volumes * (state.stress + state.pressure) if state.hydrostatic_strain: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) virial = diag_mean.unsqueeze(-1) * torch.eye(3, device=device).unsqueeze( @@ -537,12 +537,7 @@ def fire( # Setup parameters params = [dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha, n_min] dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha, n_min = [ - ( - p - if isinstance(p, torch.Tensor) - else torch.tensor(p, device=device, dtype=dtype) - ) - for p in params + torch.as_tensor(p, device=device, dtype=dtype) for p in params ] def fire_init( @@ -680,10 +675,10 @@ def fire_step( # + state.alpha * state.forces * v_norm / f_norm, # state.velocity, # ) - batch_wise_alpha = state.alpha[state.batch].unsqueeze(-1) + atom_wise_alpha = state.alpha[state.batch].unsqueeze(-1) state.velocities = ( - 1.0 - batch_wise_alpha - ) * state.velocities + batch_wise_alpha * state.forces * v_norm / (f_norm + eps) + 1.0 - atom_wise_alpha + ) * state.velocities + atom_wise_alpha * state.forces * v_norm / (f_norm + eps) return state @@ -890,7 +885,7 @@ def fire_init( stress = model_output["stress"] # [n_batches, 3, 3] volumes = torch.linalg.det(state.cell).view(-1, 1, 1) - virial = -volumes * stress + pressure + virial = -volumes * (stress + pressure) # P is P_ext * I if hydrostatic_strain: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) @@ -1022,7 +1017,7 @@ def fire_step( # noqa: PLR0915 state.stress = stress # Calculate virial volumes = torch.linalg.det(new_cell).view(-1, 1, 1) - virial = -volumes * stress + state.pressure + virial = -volumes * (stress + state.pressure) if state.hydrostatic_strain: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) virial = diag_mean.unsqueeze(-1) * torch.eye(3, device=device).unsqueeze( @@ -1318,7 +1313,7 @@ def fire_init( # Calculate virial for cell forces volumes = torch.linalg.det(state.cell).view(-1, 1, 1) - virial = -volumes * stress + pressure + virial = -volumes * (stress + pressure) # P is P_ext * I if hydrostatic_strain: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) @@ -1466,7 +1461,7 @@ def fire_step( # noqa: PLR0915 # Calculate virial volumes = torch.linalg.det(state.cell).view(-1, 1, 1) - virial = -volumes * stress + state.pressure + virial = -volumes * (stress + state.pressure) # P is P_ext * I if state.hydrostatic_strain: diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True) virial = diag_mean.unsqueeze(-1) * torch.eye(3, device=device).unsqueeze( diff --git a/torch_sim/unbatched/unbatched_optimizers.py b/torch_sim/unbatched/unbatched_optimizers.py index 5648fbfd..af6689ce 100644 --- a/torch_sim/unbatched/unbatched_optimizers.py +++ b/torch_sim/unbatched/unbatched_optimizers.py @@ -648,7 +648,7 @@ def fire_init( # Calculate virial volume = torch.linalg.det(state.cell).view(1, 1) - virial = -volume * stress + pressure + virial = -volume * (stress + pressure) if hydrostatic_strain: diag_mean = torch.diagonal(virial).mean().view(1, 1) @@ -742,7 +742,7 @@ def fire_step( # noqa: PLR0915 # Calculate virial for cell forces volume = torch.linalg.det(new_row_vector_cell).view(1, 1) - virial = -volume * stress + state.pressure + virial = -volume * (stress + state.pressure) if state.hydrostatic_strain: diag_mean = torch.diagonal(virial).mean().view(1, 1) @@ -964,7 +964,7 @@ def fire_init( # Calculate virial volume = torch.linalg.det(state.cell).view(1, 1) - virial = -volume * stress + pressure + virial = -volume * (stress + pressure) if hydrostatic_strain: diag_mean = torch.diagonal(virial).mean().view(1, 1) @@ -1060,7 +1060,7 @@ def fire_step( # noqa: PLR0915 # Calculate virial for cell forces volume = torch.linalg.det(state.cell).view(1, 1) - virial = -volume * stress + state.pressure + virial = -volume * (stress + state.pressure) if state.hydrostatic_strain: diag_mean = torch.diagonal(virial).mean().view(1, 1) From 40d0654b738ab661ef206b4dd2c5480ffb03035e Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Thu, 17 Apr 2025 15:37:52 -0400 Subject: [PATCH 2/3] refactor cell_forces in optimizers.py --- torch_sim/optimizers.py | 19 ++++--------------- torch_sim/unbatched/unbatched_optimizers.py | 8 ++------ 2 files changed, 6 insertions(+), 21 deletions(-) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 543cce42..e7b8cb47 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -329,12 +329,6 @@ def gd_init( 3, device=device ).unsqueeze(0).expand(state.n_batches, -1, -1) - # Scale virial by cell_factor - virial = virial / cell_factor - - # Reshape virial for cell forces - cell_forces = virial # shape: (n_batches, 3, 3) - return UnitCellGDState( positions=state.positions, forces=forces, @@ -351,7 +345,7 @@ def gd_init( atomic_numbers=state.atomic_numbers, batch=state.batch, cell_positions=cell_positions, - cell_forces=cell_forces, + cell_forces=virial / cell_factor, cell_masses=cell_masses, ) @@ -432,12 +426,9 @@ def gd_step( 3, device=device ).unsqueeze(0).expand(n_batches, -1, -1) - # Scale virial by cell_factor - virial = virial / state.cell_factor - # Update cell forces state.cell_positions = cell_positions_new - state.cell_forces = virial + state.cell_forces = virial / state.cell_factor return state @@ -899,8 +890,7 @@ def fire_init( 3, device=device ).unsqueeze(0).expand(n_batches, -1, -1) - virial = virial / cell_factor - cell_forces = virial + cell_forces = virial / cell_factor # Sum masses per batch using segment_reduce # TODO (AG): check this @@ -1029,8 +1019,7 @@ def fire_step( # noqa: PLR0915 3, device=device ).unsqueeze(0).expand(n_batches, -1, -1) - virial = virial / state.cell_factor - state.cell_forces = virial + state.cell_forces = virial / state.cell_factor # Velocity Verlet first half step (v += 0.5*a*dt) state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1) diff --git a/torch_sim/unbatched/unbatched_optimizers.py b/torch_sim/unbatched/unbatched_optimizers.py index af6689ce..11dc7afe 100644 --- a/torch_sim/unbatched/unbatched_optimizers.py +++ b/torch_sim/unbatched/unbatched_optimizers.py @@ -658,9 +658,6 @@ def fire_init( diag_mean = torch.diagonal(virial).mean().view(1, 1) virial = virial - diag_mean * torch.eye(3, device=device) - virial = virial / cell_factor - cell_forces = virial - # Create cell masses cell_masses = torch.full((3,), state.masses.sum(), device=device, dtype=dtype) @@ -684,7 +681,7 @@ def fire_init( atomic_numbers=atomic_numbers, cell_positions=cell_positions, cell_velocities=torch.zeros_like(cell_positions), - cell_forces=cell_forces, + cell_forces=virial / cell_factor, cell_masses=cell_masses, ) @@ -752,8 +749,7 @@ def fire_step( # noqa: PLR0915 diag_mean = torch.diagonal(virial).mean().view(1, 1) virial = virial - diag_mean * torch.eye(3, device=device) - virial = virial / state.cell_factor - state.cell_forces = virial + state.cell_forces = virial / state.cell_factor # Velocity Verlet second half step state.velocities += 0.5 * state.dt * state.forces / state.masses.unsqueeze(-1) From 536657e56d95f1419ea5b70eb10848631f204036 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Thu, 17 Apr 2025 15:46:32 -0400 Subject: [PATCH 3/3] clarify test_state_round_trip not testing round trip for masses with pymatgen and phonopy --- tests/test_io.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_io.py b/tests/test_io.py index b0b27db1..350f6a55 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -252,5 +252,7 @@ def test_state_round_trip( assert sim_state.pbc == round_trip_state.pbc if isinstance(intermediate_format[0], Atoms): - # TODO: the round trip for pmg and phonopy masses is not exact. + # TODO: masses round trip for pmg and phonopy masses is not exact + # since both use their own isotope masses based on species, + # not the ones in the state assert torch.allclose(sim_state.masses, round_trip_state.masses)