diff --git a/tests/models/test_pair_potential.py b/tests/models/test_pair_potential.py index 355d81715..ff103b48e 100644 --- a/tests/models/test_pair_potential.py +++ b/tests/models/test_pair_potential.py @@ -146,7 +146,7 @@ def test_half_list_matches_full(si_double_sim_state: ts.SimState, key: str) -> N model_half = PairPotentialModel(**common, reduce_to_half_list=True) out_full = model_full(si_double_sim_state) out_half = model_half(si_double_sim_state) - torch.testing.assert_close(out_half[key], out_full[key], rtol=1e-10, atol=1e-14) + torch.testing.assert_close(out_half[key], out_full[key], rtol=1e-6, atol=1e-7) @pytest.mark.parametrize("potential", ["bmhtf", "morse", "soft_sphere"]) @@ -199,7 +199,9 @@ def force_fn(dr: torch.Tensor, zi: torch.Tensor, zj: torch.Tensor) -> torch.Tens out_pp = model_pp(sim_state) out_pf = model_pf(sim_state) - assert (out_pp["forces"] != 0.0).all() + assert (out_pp["forces"] != torch.zeros_like(out_pp["forces"])).any(), ( + "All force components are exactly zero - no gradient propagated" + ) for key in ("forces", "stress", "stresses"): torch.testing.assert_close(out_pp[key], out_pf[key], rtol=1e-5, atol=1e-6)