Merged
2 changes: 1 addition & 1 deletion docs/user/overview.md
@@ -27,6 +27,6 @@ Learn more in [Understanding Reporting](../tutorials/reporting_tutorial.ipynb)

## High-level vs Low-Level

Under the hood, TorchSim takes a modular functional approach to atomistic simulation. Each integrator or optimizer function, such as `nvt_langevin,` takes in a model and parameters and returns `init` and `update` functions that act on a unique `State.` The state inherits from `SimState` and tracks the fixed and fluctuating parameters of the simulation, such as the `momenta` for NVT or the timestep for FIRE. The runner functions take this basic structure and wrap it in a convenient interface with autobatching and reporting.
Under the hood, TorchSim takes a modular functional approach to atomistic simulation. Each integrator or optimizer has associated `init` and `update` functions that initialize and update a unique `State`. The state inherits from `SimState` and tracks the fixed and fluctuating parameters of the simulation, such as the `momenta` for NVT or the timestep for FIRE. The runner functions take this basic structure and wrap it in a convenient interface with autobatching and reporting.

Learn more in [Fundamentals of TorchSim](../tutorials/low_level_tutorial.ipynb) and [Implementing New Methods](../tutorials/hybrid_swap_tutorial.ipynb)
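The init/update pattern described above can be sketched in plain Python. This is a minimal mock to illustrate the functional style, not TorchSim's actual implementation: the class and function names below are hypothetical, and real TorchSim states inherit from `SimState`.

```python
from dataclasses import dataclass, replace

# Minimal mock of the init/update pattern: the state carries both fixed
# and fluctuating parameters, and the step function returns a new state.

@dataclass(frozen=True)
class FireStateMock:
    positions: list[float]  # fluctuating parameter
    dt: float               # timestep, which FIRE adapts over the run

def fire_init_mock(positions, dt=0.1):
    """Return the initialized optimizer-specific state."""
    return FireStateMock(positions=list(positions), dt=dt)

def fire_step_mock(state):
    """Update the state: move downhill on f(x) = x**2 and grow the timestep."""
    new_positions = [x - state.dt * 2 * x for x in state.positions]
    return replace(state, positions=new_positions, dt=state.dt * 1.1)

state = fire_init_mock([1.0, -2.0])
for _ in range(5):
    state = fire_step_mock(state)
```

A runner would wrap exactly this init-then-loop structure, adding autobatching and reporting around it.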
23 changes: 11 additions & 12 deletions examples/tutorials/low_level_tutorial.py
@@ -117,17 +117,21 @@
"""
## Optimizers and Integrators

All optimizers and integrators share a similar interface. They accept a model and
return two functions: `init_fn` and step_fn`. The `init_fn` function returns the
initialized optimizer-specific state, while the step_fn` function updates the
simulation state.
All optimizers and integrators have an associated `init_fn` and a `step_fn`.
The `init_fn` function returns the initialized optimizer-specific state,
while the `step_fn` function updates the simulation state. The formal pairings
are stored in the `ts.INTEGRATOR_REGISTRY` and `ts.OPTIM_REGISTRY` dictionaries.

### Unit Cell Fire

We will walk through the fire optimizer with unit cell filter as an example.
"""

# %%
state = ts.fire_init(state=state, model=model, cell_filter=ts.CellFilter.unit)

# add a little noise so we have something to relax
state.positions = state.positions + torch.randn_like(state.positions) * 0.05


# %% [markdown]
@@ -139,10 +143,6 @@
"""

# %%
state = ts.fire_init(state=state, model=model, cell_filter=ts.CellFilter.unit)

# add a little noise so we have something to relax
state.positions = state.positions + torch.randn_like(state.positions) * 0.05

for step in range(20):
state = ts.fire_step(state=state, model=model)
@@ -151,10 +151,9 @@

# %% [markdown]
"""
You can set the optimizer-specific arguments in the `optimize` function
optimizer=ts.Optimizer.fire, cell_filter=ts.CellFilter.unit. Fixed
parameters can usually be passed to the `init_fn` and parameters that vary over
the course of the simulation can be passed to the step_fn`.
Fixed parameters can usually be passed to the `init_fn` and parameters that vary over
the course of the simulation can be passed to the `step_fn`. In the `optimize`
function, you set these with the `init_kwargs` and `optimizer_kwargs` arguments.
"""

# %%
1 change: 1 addition & 0 deletions torch_sim/__init__.py
@@ -66,6 +66,7 @@
from torch_sim.quantities import (
calc_kinetic_energy,
calc_kT,
calc_temperature,
get_pressure,
system_wise_max_force,
)
2 changes: 1 addition & 1 deletion torch_sim/quantities.py
@@ -88,7 +88,7 @@ def calc_temperature(
units (object): Units to return the temperature in

Returns:
torch.Tensor: Temperature value in specified units
torch.Tensor: Temperature value in specified units (default: K)
"""
kT = calc_kT(
masses=masses,
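The kT-to-temperature conversion that `calc_temperature` wraps is just a division by Boltzmann's constant in the chosen unit system. As a sketch of the arithmetic, assuming an eV-based unit system (not necessarily TorchSim's internal units):

```python
# Sketch of the kT -> temperature conversion performed by calc_temperature,
# assuming thermal energies in eV so the result comes out in kelvin.
BOLTZMANN_EV_PER_K = 8.617333262e-5  # eV/K

def kT_to_temperature(kT_ev: float) -> float:
    """Convert a thermal energy in eV to a temperature in kelvin."""
    return kT_ev / BOLTZMANN_EV_PER_K

room_temp = kT_to_temperature(0.025852)  # roughly 300 K
```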
14 changes: 8 additions & 6 deletions torch_sim/runners.py
@@ -5,6 +5,7 @@
converting between different atomistic representations and handling simulation state.
"""

import copy
import warnings
from collections.abc import Callable
from dataclasses import dataclass
@@ -29,7 +30,6 @@
def _configure_reporter(
trajectory_reporter: TrajectoryReporter | dict,
*,
state_kwargs: dict | None = None,
properties: list[str] | None = None,
prop_frequency: int = 10,
state_frequency: int = 100,
@@ -53,12 +53,12 @@ def _configure_reporter(
}

# ordering is important to ensure we can override defaults
trajectory_reporter = copy.deepcopy(trajectory_reporter)
return TrajectoryReporter(
prop_calculators=trajectory_reporter.pop(
"prop_calculators", {prop_frequency: prop_calculators}
),
state_frequency=trajectory_reporter.pop("state_frequency", state_frequency),
state_kwargs=state_kwargs or {},
**trajectory_reporter,
)
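The `copy.deepcopy` added above matters because `_configure_reporter` consumes its dict with `.pop`; without the copy, the caller's dict would silently lose keys. A minimal illustration of the hazard, using hypothetical helper names:

```python
import copy

# Illustration of why the deepcopy was added: .pop mutates the dict in place.

def configure_without_copy(options: dict) -> int:
    # Pops directly from the caller's dict, mutating it as a side effect.
    return options.pop("state_frequency", 100)

def configure_with_copy(options: dict) -> int:
    # Copy first, so the caller's dict is left untouched.
    options = copy.deepcopy(options)
    return options.pop("state_frequency", 100)

opts = {"state_frequency": 50}
freq_safe = configure_with_copy(opts)
still_there = "state_frequency" in opts  # True: caller's dict intact

freq_unsafe = configure_without_copy(opts)
gone = "state_frequency" not in opts  # True: key was popped away
```

This is why copying before popping lets the same reporter dict be reused across multiple runner calls.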

@@ -561,13 +561,15 @@ def static(
properties.append("forces")
if model.compute_stress:
properties.append("stress")
trajectory_reporter = _configure_reporter(
trajectory_reporter or dict(filenames=None),
state_kwargs={
if isinstance(trajectory_reporter, dict):
trajectory_reporter = copy.deepcopy(trajectory_reporter)
trajectory_reporter["state_kwargs"] = {
"variable_atomic_numbers": True,
"variable_masses": True,
"save_forces": model.compute_forces,
},
}
trajectory_reporter = _configure_reporter(
trajectory_reporter or dict(filenames=None),
properties=properties,
)

6 changes: 0 additions & 6 deletions torch_sim/trajectory.py
@@ -76,7 +76,6 @@ class TrajectoryReporter:
trajectories (list): TorchSimTrajectory instances
filenames (list): Trajectory file paths
array_registry (dict): Map of array names to (shape, dtype) tuples
shape_warned (bool): Whether a shape warning has been issued

Examples:
>>> reporter = TrajectoryReporter(
@@ -96,7 +95,6 @@ class TrajectoryReporter:
prop_calculators: dict[int, dict[str, Callable]]
state_kwargs: dict[str, Any]
metadata: dict[str, str] | None
shape_warned: bool
Collaborator: this is technically a breaking change, so I'll label it as breaking

trajectories: list["TorchSimTrajectory"]
filenames: list[str | pathlib.Path] | None

@@ -140,7 +138,6 @@ def __init__(

self.prop_calculators = prop_calculators or {}
self.state_kwargs = state_kwargs or {}
self.shape_warned = False
self.metadata = metadata

self.trajectories = []
@@ -258,9 +255,6 @@
all_props: list[dict[str, torch.Tensor]] = []
# Process each system separately
for idx, substate in enumerate(split_states):
# Slice the state once to get only the data for this system
self.shape_warned = True

# Write state to trajectory if it's time
if (
self.state_frequency