Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
4157a20
fix:orb squeeze incorrect energy shape
thomasloux Sep 18, 2025
4042130
feature: add batching for nvt nosé hoover
thomasloux Sep 29, 2025
ee5ff16
feature: add tests for nvt nosé hoover
thomasloux Sep 29, 2025
7f24199
fix typo from development
thomasloux Sep 30, 2025
347a547
fix comments
thomasloux Oct 6, 2025
d7d3e78
add test for output shape if single system as input
thomasloux Oct 6, 2025
9585c7c
make ruff happy
thomasloux Oct 6, 2025
eba7657
npt nose hoover is batchable
thomasloux Oct 7, 2025
d91f94a
add single system output shape test in test_model_output_validation
thomasloux Oct 7, 2025
646ddf5
Merge branch 'TorchSim:main' into main
thomasloux Oct 8, 2025
322a16a
Merge branch 'main' into feature/batched-nvt-nose-hoover
thomasloux Oct 8, 2025
3af5c04
pass lint test
thomasloux Oct 8, 2025
a13b87d
Merge branch 'feature/batched-nvt-nose-hoover' of https://github.com/…
thomasloux Oct 8, 2025
69ee796
Merge branch 'TorchSim:main' into main
thomasloux Oct 10, 2025
746f39d
Merge remote-tracking branch 'origin/main' into feature/batched-nvt-n…
thomasloux Oct 10, 2025
e35160c
check final temperature (algorithm preservation)
thomasloux Oct 10, 2025
520fa5d
solve nose hoover for chain_length=1
thomasloux Oct 16, 2025
2b87b67
2nd part correction chain_length
thomasloux Oct 16, 2025
6e90828
Merge branch 'main' into feature/batched-nvt-nose-hoover
CompRhys Oct 18, 2025
69364e7
Merge branch 'main' into feature/batched-nvt-nose-hoover
thomasloux Oct 21, 2025
e32c269
Merge branch 'main' into feature/batched-nvt-nose-hoover
abhijeetgangan Oct 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions tests/models/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def test_model_calculator_consistency(
return test_model_calculator_consistency


def make_validate_model_outputs_test(
def make_validate_model_outputs_test( # noqa: PLR0915
model_fixture_name: str,
device: torch.device = DEVICE,
dtype: torch.dtype = torch.float64,
Expand All @@ -135,7 +135,7 @@ def make_validate_model_outputs_test(
model_fixture_name: Name of the model fixture to validate
"""

def test_model_output_validation(request: pytest.FixtureRequest) -> None:
def test_model_output_validation(request: pytest.FixtureRequest) -> None: # noqa: PLR0915
"""Test that a model implementation follows the ModelInterface contract."""
# Get the model fixture dynamically
model: ModelInterface = request.getfixturevalue(model_fixture_name)
Expand Down Expand Up @@ -224,6 +224,15 @@ def test_model_output_validation(request: pytest.FixtureRequest) -> None:
# atol=10e-3,
# )

# Test single system output
assert fe_model_output["energy"].shape == (1,)
# forces should be shape (n_atoms, 3) for n_atoms in the system
if force_computed:
assert fe_model_output["forces"].shape == (12, 3)
# stress should be shape (1, 3, 3) for 1 system
if stress_computed:
assert fe_model_output["stress"].shape == (1, 3, 3)

# Rename the function to include the test name
test_model_output_validation.__name__ = f"test_{model_fixture_name}_output_validation"
return test_model_output_validation
Loading