diff --git a/docs/user/overview.md b/docs/user/overview.md index a7c0655d..56fdf136 100644 --- a/docs/user/overview.md +++ b/docs/user/overview.md @@ -27,6 +27,6 @@ Learn more in [Understanding Reporting](../tutorials/reporting_tutorial.ipynb) ## High-level vs Low-Level -Under the hood, TorchSim takes a modular functional approach to atomistic simulation. Each integrator or optimizer function, such as `nvt_langevin,` takes in a model and parameters and returns `init` and `update` functions that act on a unique `State.` The state inherits from `SimState` and tracks the fixed and fluctuating parameters of the simulation, such as the `momenta` for NVT or the timestep for FIRE. The runner functions take this basic structure and wrap it in a convenient interface with autobatching and reporting. +Under the hood, TorchSim takes a modular functional approach to atomistic simulation. Each integrator or optimizer has associated `init` and `update` functions that initialize and update a unique `State.` The state inherits from `SimState` and tracks the fixed and fluctuating parameters of the simulation, such as the `momenta` for NVT or the timestep for FIRE. The runner functions take this basic structure and wrap it in a convenient interface with autobatching and reporting. Learn more in [Fundamentals of TorchSim](../tutorials/low_level_tutorial.ipynb) and [Implementing New Methods](../tutorials/hybrid_swap_tutorial.ipynb) diff --git a/examples/tutorials/low_level_tutorial.py b/examples/tutorials/low_level_tutorial.py index 7677e468..eb8d63ed 100644 --- a/examples/tutorials/low_level_tutorial.py +++ b/examples/tutorials/low_level_tutorial.py @@ -117,10 +117,10 @@ """ ## Optimizers and Integrators -All optimizers and integrators share a similar interface. They accept a model and -return two functions: `init_fn` and step_fn`. The `init_fn` function returns the -initialized optimizer-specific state, while the step_fn` function updates the -simulation state. +All optimizers and integrators have an associated `init_fn` and a `step_fn`. +The `init_fn` function returns the initialized optimizer-specific state, +while the `step_fn` function updates the simulation state. The formal pairings +are stored in the `ts.INTEGRATOR_REGISTRY` and `ts.OPTIM_REGISTRY` dictionaries. ### Unit Cell Fire @@ -128,6 +128,10 @@ """ # %% +state = ts.fire_init(state=state, model=model, cell_filter=ts.CellFilter.unit) + +# add a little noise so we have something to relax +state.positions = state.positions + torch.randn_like(state.positions) * 0.05 # %% [markdown] @@ -139,10 +143,6 @@ """ # %% -state = ts.fire_init(state=state, model=model, cell_filter=ts.CellFilter.unit) - -# add a little noise so we have something to relax -state.positions = state.positions + torch.randn_like(state.positions) * 0.05 for step in range(20): state = ts.fire_step(state=state, model=model) @@ -151,10 +151,9 @@ # %% [markdown] """ -You can set the optimizer-specific arguments in the `optimize` function -optimizer=ts.Optimizer.fire, cell_filter=ts.CellFilter.unit. Fixed -parameters can usually be passed to the `init_fn` and parameters that vary over -the course of the simulation can be passed to the step_fn`. +Fixed parameters can usually be passed to the `init_fn` and parameters that vary over +the course of the simulation can be passed to the `step_fn`. In the `optimize` +function, you set these with the `init_kwargs` and `optimizer_kwargs` arguments. """ # %% diff --git a/torch_sim/__init__.py b/torch_sim/__init__.py index a5a8af76..3bc1bc92 100644 --- a/torch_sim/__init__.py +++ b/torch_sim/__init__.py @@ -66,6 +66,7 @@ from torch_sim.quantities import ( calc_kinetic_energy, calc_kT, + calc_temperature, get_pressure, system_wise_max_force, ) diff --git a/torch_sim/quantities.py b/torch_sim/quantities.py index 8ffc61ed..35f7d6f0 100644 --- a/torch_sim/quantities.py +++ b/torch_sim/quantities.py @@ -88,7 +88,7 @@ def calc_temperature( units (object): Units to return the temperature in Returns: - torch.Tensor: Temperature value in specified units + torch.Tensor: Temperature value in specified units (default, K) """ kT = calc_kT( masses=masses, diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 6e956750..4ac6f3be 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -5,6 +5,7 @@ converting between different atomistic representations and handling simulation state. """ +import copy import warnings from collections.abc import Callable from dataclasses import dataclass @@ -29,7 +30,6 @@ def _configure_reporter( trajectory_reporter: TrajectoryReporter | dict, *, - state_kwargs: dict | None = None, properties: list[str] | None = None, prop_frequency: int = 10, state_frequency: int = 100, @@ -53,12 +53,12 @@ def _configure_reporter( } # ordering is important to ensure we can override defaults + trajectory_reporter = copy.deepcopy(trajectory_reporter) return TrajectoryReporter( prop_calculators=trajectory_reporter.pop( "prop_calculators", {prop_frequency: prop_calculators} ), state_frequency=trajectory_reporter.pop("state_frequency", state_frequency), - state_kwargs=state_kwargs or {}, **trajectory_reporter, ) @@ -561,13 +561,15 @@ def static( properties.append("forces") if model.compute_stress: properties.append("stress") - trajectory_reporter = _configure_reporter( - trajectory_reporter or dict(filenames=None), - state_kwargs={ + if isinstance(trajectory_reporter, dict): + trajectory_reporter = copy.deepcopy(trajectory_reporter) + trajectory_reporter["state_kwargs"] = { "variable_atomic_numbers": True, "variable_masses": True, "save_forces": model.compute_forces, - }, + } + trajectory_reporter = _configure_reporter( + trajectory_reporter or dict(filenames=None), properties=properties, ) diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index c9167d0f..3220b24b 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -76,7 +76,6 @@ class TrajectoryReporter: trajectories (list): TorchSimTrajectory instances filenames (list): Trajectory file paths array_registry (dict): Map of array names to (shape, dtype) tuples - shape_warned (bool): Whether a shape warning has been issued Examples: >>> reporter = TrajectoryReporter( @@ -96,7 +95,6 @@ class TrajectoryReporter: prop_calculators: dict[int, dict[str, Callable]] state_kwargs: dict[str, Any] metadata: dict[str, str] | None - shape_warned: bool trajectories: list["TorchSimTrajectory"] filenames: list[str | pathlib.Path] | None @@ -140,7 +138,6 @@ def __init__( self.prop_calculators = prop_calculators or {} self.state_kwargs = state_kwargs or {} - self.shape_warned = False self.metadata = metadata self.trajectories = [] @@ -258,9 +255,6 @@ def report( all_props: list[dict[str, torch.Tensor]] = [] # Process each system separately for idx, substate in enumerate(split_states): - # Slice the state once to get only the data for this system - self.shape_warned = True - # Write state to trajectory if it's time if ( self.state_frequency