Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 4 additions & 38 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from dataclasses import asdict
from pathlib import Path
from typing import Any

import pytest
Expand All @@ -13,9 +11,7 @@
from torch_sim.io import atoms_to_state
from torch_sim.models.lennard_jones import LennardJonesModel
from torch_sim.state import SimState, concatenate_states
from torch_sim.trajectory import TrajectoryReporter
from torch_sim.unbatched.models.lennard_jones import UnbatchedLennardJonesModel
from torch_sim.unbatched.unbatched_integrators import nve


@pytest.fixture
Expand Down Expand Up @@ -203,8 +199,10 @@ def rattled_sio2_sim_state(sio2_sim_state: SimState) -> SimState:
try:
# Temporarily set a fixed seed
torch.manual_seed(3)
weibull = torch.distributions.weibull.Weibull(scale=0.5, concentration=1.0)
shifts = weibull.sample((sim_state.n_atoms, 3))
weibull = torch.distributions.weibull.Weibull(scale=0.1, concentration=1)
rnd = torch.randn_like(sim_state.positions)
rnd = rnd / torch.norm(rnd, dim=-1, keepdim=True)
shifts = weibull.sample(rnd.shape) * rnd
sim_state.positions = sim_state.positions + shifts
finally:
# Restore the original RNG state
Expand Down Expand Up @@ -293,35 +291,3 @@ def lj_model(device: torch.device, dtype: torch.dtype) -> LennardJonesModel:
compute_stress=True,
cutoff=2.5 * 3.405,
)


@pytest.fixture
def torchsim_trajectory(
si_sim_state: SimState,
lj_model: Any,
tmp_path: Path,
device: torch.device,
dtype: torch.dtype,
):
"""Test NVE integration conserves energy."""
# Initialize integrator
kT = torch.tensor(300.0, device=device, dtype=dtype) # Temperature in K
dt = torch.tensor(0.001, device=device, dtype=dtype) # Small timestep for stability

state, update_fn = nve(
**asdict(si_sim_state),
model=lj_model,
dt=dt,
kT=kT,
)

reporter = TrajectoryReporter(tmp_path / "test.hdf5", state_frequency=1)

# Run several steps
for step in range(10):
state = update_fn(state, dt)
reporter.report(state, step)

yield reporter.trajectory

reporter.close()
Loading