Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions tests/test_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import pytest
import torch
from ase.build.bulk import bulk

import torch_sim as ts
from tests.conftest import DEVICE, DTYPE
Expand Down Expand Up @@ -778,6 +779,27 @@ def test_static_no_filenames(
assert isinstance(props[0]["potential_energy"], torch.Tensor)


def test_static_after_optimize(lj_model: LennardJonesModel):
"""Tests that we can calculate static properties after an optimize run."""
atoms = bulk("Si", "diamond", a=5.43, cubic=True)
initial_state = ts.io.atoms_to_state(
atoms, device=lj_model.device, dtype=lj_model.dtype
)

final_state = ts.optimize(
system=initial_state,
model=lj_model,
optimizer=ts.Optimizer.fire,
max_steps=100,
)

results = ts.static(
system=final_state,
model=lj_model,
)
assert results[0]["potential_energy"] == final_state.energy


def test_readme_example(lj_model: LennardJonesModel, tmp_path: Path) -> None:
# this tests the example from the readme, update as needed

Expand Down
18 changes: 11 additions & 7 deletions torch_sim/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,8 +579,8 @@ class StaticState(SimState):
forces: torch.Tensor
stress: torch.Tensor

_atom_attributes = state._atom_attributes | {"forces"} # noqa: SLF001
_system_attributes = state._system_attributes | { # noqa: SLF001
_atom_attributes = SimState._atom_attributes | {"forces"} # noqa: SLF001
Copy link
Collaborator Author

@curtischong curtischong Nov 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like changing the atom attributes from state to SimState. But I'm okay with it for now since torchsim currently only supports calculating energies, forces, and stresses.

When we support magnetic moments (and arbitrary properties in general) we will need to rethink how to create this

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason why I needed to set these atom attributes to be derived from SimState is because if I left it as state, then inside the trajectory reporter, it would think that StaticState would have extra attributes (that we did not initialize when we created it) the extra attribute in particular is velocity.

So by explicitly defining the _atom_attributes to be from SimState, we only say that this StaticState has a fixed limited set of attributes

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I plan to add mag moms soon so I'd prefer we pick a solution here that will be generalizable.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I need to think through that some more as well. On another note, I'd prefer to do that in a different PR so this one can be constrained to fixing just this bug

_system_attributes = SimState._system_attributes | { # noqa: SLF001
"energy",
"stress",
}
Expand All @@ -605,9 +605,13 @@ class StaticState(SimState):
)

model_outputs = model(sub_state)

sub_state = StaticState(
**vars(sub_state),
static_state = StaticState(
positions=sub_state.positions,
masses=sub_state.masses,
cell=sub_state.cell,
pbc=sub_state.pbc,
atomic_numbers=sub_state.atomic_numbers,
system_idx=sub_state.system_idx,
energy=model_outputs["energy"],
forces=(
model_outputs["forces"]
Expand All @@ -621,11 +625,11 @@ class StaticState(SimState):
),
)

props = trajectory_reporter.report(sub_state, 0, model=model)
props = trajectory_reporter.report(static_state, 0, model=model)
all_props.extend(props)

if tqdm_pbar:
tqdm_pbar.update(sub_state.n_systems)
tqdm_pbar.update(static_state.n_systems)

trajectory_reporter.finish()

Expand Down