diff --git a/tests/test_runners.py b/tests/test_runners.py index b417e581..fa7cf522 100644 --- a/tests/test_runners.py +++ b/tests/test_runners.py @@ -4,6 +4,7 @@ import numpy as np import pytest import torch +from ase.build.bulk import bulk import torch_sim as ts from tests.conftest import DEVICE, DTYPE @@ -778,6 +779,27 @@ def test_static_no_filenames( assert isinstance(props[0]["potential_energy"], torch.Tensor) +def test_static_after_optimize(lj_model: LennardJonesModel): + """Tests that we can calculate static properties after an optimize run.""" + atoms = bulk("Si", "diamond", a=5.43, cubic=True) + initial_state = ts.io.atoms_to_state( + atoms, device=lj_model.device, dtype=lj_model.dtype + ) + + final_state = ts.optimize( + system=initial_state, + model=lj_model, + optimizer=ts.Optimizer.fire, + max_steps=100, + ) + + results = ts.static( + system=final_state, + model=lj_model, + ) + assert results[0]["potential_energy"] == final_state.energy + + def test_readme_example(lj_model: LennardJonesModel, tmp_path: Path) -> None: # this tests the example from the readme, update as needed diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 4ac6f3be..5cad1fe3 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -579,8 +579,8 @@ class StaticState(SimState): forces: torch.Tensor stress: torch.Tensor - _atom_attributes = state._atom_attributes | {"forces"} # noqa: SLF001 - _system_attributes = state._system_attributes | { # noqa: SLF001 + _atom_attributes = SimState._atom_attributes | {"forces"} # noqa: SLF001 + _system_attributes = SimState._system_attributes | { # noqa: SLF001 "energy", "stress", } @@ -605,9 +605,13 @@ class StaticState(SimState): ) model_outputs = model(sub_state) - - sub_state = StaticState( - **vars(sub_state), + static_state = StaticState( + positions=sub_state.positions, + masses=sub_state.masses, + cell=sub_state.cell, + pbc=sub_state.pbc, + atomic_numbers=sub_state.atomic_numbers, + system_idx=sub_state.system_idx, energy=model_outputs["energy"], forces=( model_outputs["forces"] @@ -621,11 +625,11 @@ class StaticState(SimState): ), ) - props = trajectory_reporter.report(sub_state, 0, model=model) + props = trajectory_reporter.report(static_state, 0, model=model) all_props.extend(props) if tqdm_pbar: - tqdm_pbar.update(sub_state.n_systems) + tqdm_pbar.update(static_state.n_systems) trajectory_reporter.finish()