Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions tests/models/test_pair_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def test_half_list_matches_full(si_double_sim_state: ts.SimState, key: str) -> N
model_half = PairPotentialModel(**common, reduce_to_half_list=True)
out_full = model_full(si_double_sim_state)
out_half = model_half(si_double_sim_state)
torch.testing.assert_close(out_half[key], out_full[key], rtol=1e-10, atol=1e-14)
torch.testing.assert_close(out_half[key], out_full[key], rtol=1e-6, atol=1e-7)


@pytest.mark.parametrize("potential", ["bmhtf", "morse", "soft_sphere"])
Expand Down Expand Up @@ -199,7 +199,9 @@ def force_fn(dr: torch.Tensor, zi: torch.Tensor, zj: torch.Tensor) -> torch.Tens
out_pp = model_pp(sim_state)
out_pf = model_pf(sim_state)

assert (out_pp["forces"] != 0.0).all()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

!= on float value is not great, could you replace it by >=tol with tol=1e-6 for instance

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a check for exact zero's which would be a logic failure as opposed to numerically close to zero which would be fine. It being brittle is okay the problem was that seemingly we got one force that evaluated as equal to zero in the latest torch

Copy link
Copy Markdown
Collaborator

@thomasloux thomasloux Mar 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok then could you just add a short comment about that, because it's really unusual to do that, good for me otherwise

assert (out_pp["forces"] != torch.zeros_like(out_pp["forces"])).any(), (
"All force components are exactly zero - no gradient propagated"
)

for key in ("forces", "stress", "stresses"):
torch.testing.assert_close(out_pp[key], out_pf[key], rtol=1e-5, atol=1e-6)
Expand Down
Loading