Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,5 +252,7 @@ def test_state_round_trip(
assert sim_state.pbc == round_trip_state.pbc

if isinstance(intermediate_format[0], Atoms):
# TODO: the round trip for pmg and phonopy masses is not exact.
# TODO: masses round trip for pmg and phonopy masses is not exact
# since both use their own isotope masses based on species,
# not the ones in the state
assert torch.allclose(sim_state.masses, round_trip_state.masses)
2 changes: 1 addition & 1 deletion torch_sim/integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@

# Calculate virials from stress and external pressure
# Internal stress is negative of virial tensor divided by volume
virial = -volumes * state.stress + pressure_tensor * volumes
virial = -volumes * (state.stress + pressure_tensor)

Check warning on line 606 in torch_sim/integrators.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/integrators.py#L606

Added line #L606 was not covered by tests

# Add kinetic contribution (kT * Identity)
batch_kT = kT
Expand Down
44 changes: 14 additions & 30 deletions torch_sim/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@

# Calculate virial
volumes = torch.linalg.det(state.cell).view(-1, 1, 1)
virial = -volumes * stress + pressure
virial = -volumes * (stress + pressure)

Check warning on line 318 in torch_sim/optimizers.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/optimizers.py#L318

Added line #L318 was not covered by tests

if hydrostatic_strain:
diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True)
Expand All @@ -329,12 +329,6 @@
3, device=device
).unsqueeze(0).expand(state.n_batches, -1, -1)

# Scale virial by cell_factor
virial = virial / cell_factor

# Reshape virial for cell forces
cell_forces = virial # shape: (n_batches, 3, 3)

return UnitCellGDState(
positions=state.positions,
forces=forces,
Expand All @@ -351,7 +345,7 @@
atomic_numbers=state.atomic_numbers,
batch=state.batch,
cell_positions=cell_positions,
cell_forces=cell_forces,
cell_forces=virial / cell_factor,
cell_masses=cell_masses,
)

Expand Down Expand Up @@ -420,7 +414,7 @@

# Calculate virial for cell forces
volumes = torch.linalg.det(new_row_vector_cell).view(-1, 1, 1)
virial = -volumes * state.stress + state.pressure
virial = -volumes * (state.stress + state.pressure)

Check warning on line 417 in torch_sim/optimizers.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/optimizers.py#L417

Added line #L417 was not covered by tests
if state.hydrostatic_strain:
diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True)
virial = diag_mean.unsqueeze(-1) * torch.eye(3, device=device).unsqueeze(
Expand All @@ -432,12 +426,9 @@
3, device=device
).unsqueeze(0).expand(n_batches, -1, -1)

# Scale virial by cell_factor
virial = virial / state.cell_factor

# Update cell forces
state.cell_positions = cell_positions_new
state.cell_forces = virial
state.cell_forces = virial / state.cell_factor

Check warning on line 431 in torch_sim/optimizers.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/optimizers.py#L431

Added line #L431 was not covered by tests

return state

Expand Down Expand Up @@ -537,12 +528,7 @@
# Setup parameters
params = [dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha, n_min]
dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha, n_min = [
(
p
if isinstance(p, torch.Tensor)
else torch.tensor(p, device=device, dtype=dtype)
)
for p in params
torch.as_tensor(p, device=device, dtype=dtype) for p in params
]

def fire_init(
Expand Down Expand Up @@ -680,10 +666,10 @@
# + state.alpha * state.forces * v_norm / f_norm,
# state.velocity,
# )
batch_wise_alpha = state.alpha[state.batch].unsqueeze(-1)
atom_wise_alpha = state.alpha[state.batch].unsqueeze(-1)

Check warning on line 669 in torch_sim/optimizers.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/optimizers.py#L669

Added line #L669 was not covered by tests
state.velocities = (
1.0 - batch_wise_alpha
) * state.velocities + batch_wise_alpha * state.forces * v_norm / (f_norm + eps)
1.0 - atom_wise_alpha
) * state.velocities + atom_wise_alpha * state.forces * v_norm / (f_norm + eps)

return state

Expand Down Expand Up @@ -890,7 +876,7 @@
stress = model_output["stress"] # [n_batches, 3, 3]

volumes = torch.linalg.det(state.cell).view(-1, 1, 1)
virial = -volumes * stress + pressure
virial = -volumes * (stress + pressure) # P is P_ext * I

Check warning on line 879 in torch_sim/optimizers.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/optimizers.py#L879

Added line #L879 was not covered by tests

if hydrostatic_strain:
diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True)
Expand All @@ -904,8 +890,7 @@
3, device=device
).unsqueeze(0).expand(n_batches, -1, -1)

virial = virial / cell_factor
cell_forces = virial
cell_forces = virial / cell_factor

Check warning on line 893 in torch_sim/optimizers.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/optimizers.py#L893

Added line #L893 was not covered by tests

# Sum masses per batch using segment_reduce
# TODO (AG): check this
Expand Down Expand Up @@ -1022,7 +1007,7 @@
state.stress = stress
# Calculate virial
volumes = torch.linalg.det(new_cell).view(-1, 1, 1)
virial = -volumes * stress + state.pressure
virial = -volumes * (stress + state.pressure)

Check warning on line 1010 in torch_sim/optimizers.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/optimizers.py#L1010

Added line #L1010 was not covered by tests
if state.hydrostatic_strain:
diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True)
virial = diag_mean.unsqueeze(-1) * torch.eye(3, device=device).unsqueeze(
Expand All @@ -1034,8 +1019,7 @@
3, device=device
).unsqueeze(0).expand(n_batches, -1, -1)

virial = virial / state.cell_factor
state.cell_forces = virial
state.cell_forces = virial / state.cell_factor

Check warning on line 1022 in torch_sim/optimizers.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/optimizers.py#L1022

Added line #L1022 was not covered by tests

# Velocity Verlet first half step (v += 0.5*a*dt)
state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1)
Expand Down Expand Up @@ -1318,7 +1302,7 @@

# Calculate virial for cell forces
volumes = torch.linalg.det(state.cell).view(-1, 1, 1)
virial = -volumes * stress + pressure
virial = -volumes * (stress + pressure) # P is P_ext * I

Check warning on line 1305 in torch_sim/optimizers.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/optimizers.py#L1305

Added line #L1305 was not covered by tests

if hydrostatic_strain:
diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True)
Expand Down Expand Up @@ -1466,7 +1450,7 @@

# Calculate virial
volumes = torch.linalg.det(state.cell).view(-1, 1, 1)
virial = -volumes * stress + state.pressure
virial = -volumes * (stress + state.pressure) # P is P_ext * I

Check warning on line 1453 in torch_sim/optimizers.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/optimizers.py#L1453

Added line #L1453 was not covered by tests
if state.hydrostatic_strain:
diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True)
virial = diag_mean.unsqueeze(-1) * torch.eye(3, device=device).unsqueeze(
Expand Down
16 changes: 6 additions & 10 deletions torch_sim/unbatched/unbatched_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,7 @@ def fire_init(

# Calculate virial
volume = torch.linalg.det(state.cell).view(1, 1)
virial = -volume * stress + pressure
virial = -volume * (stress + pressure)

if hydrostatic_strain:
diag_mean = torch.diagonal(virial).mean().view(1, 1)
Expand All @@ -658,9 +658,6 @@ def fire_init(
diag_mean = torch.diagonal(virial).mean().view(1, 1)
virial = virial - diag_mean * torch.eye(3, device=device)

virial = virial / cell_factor
cell_forces = virial

# Create cell masses
cell_masses = torch.full((3,), state.masses.sum(), device=device, dtype=dtype)

Expand All @@ -684,7 +681,7 @@ def fire_init(
atomic_numbers=atomic_numbers,
cell_positions=cell_positions,
cell_velocities=torch.zeros_like(cell_positions),
cell_forces=cell_forces,
cell_forces=virial / cell_factor,
cell_masses=cell_masses,
)

Expand Down Expand Up @@ -742,7 +739,7 @@ def fire_step( # noqa: PLR0915

# Calculate virial for cell forces
volume = torch.linalg.det(new_row_vector_cell).view(1, 1)
virial = -volume * stress + state.pressure
virial = -volume * (stress + state.pressure)

if state.hydrostatic_strain:
diag_mean = torch.diagonal(virial).mean().view(1, 1)
Expand All @@ -752,8 +749,7 @@ def fire_step( # noqa: PLR0915
diag_mean = torch.diagonal(virial).mean().view(1, 1)
virial = virial - diag_mean * torch.eye(3, device=device)

virial = virial / state.cell_factor
state.cell_forces = virial
state.cell_forces = virial / state.cell_factor

# Velocity Verlet second half step
state.velocities += 0.5 * state.dt * state.forces / state.masses.unsqueeze(-1)
Expand Down Expand Up @@ -964,7 +960,7 @@ def fire_init(

# Calculate virial
volume = torch.linalg.det(state.cell).view(1, 1)
virial = -volume * stress + pressure
virial = -volume * (stress + pressure)

if hydrostatic_strain:
diag_mean = torch.diagonal(virial).mean().view(1, 1)
Expand Down Expand Up @@ -1060,7 +1056,7 @@ def fire_step( # noqa: PLR0915

# Calculate virial for cell forces
volume = torch.linalg.det(state.cell).view(1, 1)
virial = -volume * stress + state.pressure
virial = -volume * (stress + state.pressure)

if state.hydrostatic_strain:
diag_mean = torch.diagonal(virial).mean().view(1, 1)
Expand Down