2 changes: 2 additions & 0 deletions .gitignore
@@ -34,3 +34,5 @@ coverage.xml

# env
uv.lock

.vscode/
4 changes: 2 additions & 2 deletions examples/tutorials/state_tutorial.py
@@ -112,7 +112,7 @@
f"Multi-state has {multi_state.n_atoms} total atoms across {multi_state.n_batches} batches"
)

# we can see how the shapes of batchwise, atomwise, and global properties change
# we can see how the shapes of batchwise, atomwise, and per-graph properties change
print(f"Positions shape: {multi_state.positions.shape}")
print(f"Cell shape: {multi_state.cell.shape}")
print(f"PBC: {multi_state.pbc}")
@@ -258,7 +258,7 @@

print("MDState properties:")
scope = infer_property_scope(md_state)
print("Global properties:", scope["global"])
print("Per-graph properties:", scope["per_graph"])
print("Per-atom properties:", scope["per_atom"])
print("Per-batch properties:", scope["per_batch"])

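For context on what these tutorial lines print, here is a minimal standalone sketch of the shape conventions being described; the sizes (2 batches, 16 atoms) are made up for illustration and plain torch tensors stand in for a real SimState:

import torch

# Hypothetical batched system: 2 sub-states, 16 atoms in total (illustrative only).
n_batches, n_atoms = 2, 16

positions = torch.zeros(n_atoms, 3)   # atomwise: leading dim equals n_atoms
cell = torch.zeros(n_batches, 3, 3)   # batchwise: leading dim equals n_batches
pbc = True                            # per-graph/shared: a plain scalar, identical for every sub-state

print(f"Positions shape: {positions.shape}")  # torch.Size([16, 3])
print(f"Cell shape: {cell.shape}")            # torch.Size([2, 3, 3])
print(f"PBC: {pbc}")                          # True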
4 changes: 2 additions & 2 deletions tests/test_state.py
@@ -26,7 +26,7 @@
def test_infer_sim_state_property_scope(si_sim_state: SimState) -> None:
"""Test inference of property scope."""
scope = infer_property_scope(si_sim_state)
assert set(scope["global"]) == {"pbc"}
Review comment (Member):
Here it is global because things like the PBC cannot differ between entries in a batched state. per_batch is the misleading one here, as that is the per_graph level of abstraction. If Pylance doesn't like global, I agree that changing it is a good thing to do, but I think a different name, for example shared, might be better. I would also consider renaming per_batch, since I don't think it's clear that per_batch means per substate within a batched SimState.

assert set(scope["per_graph"]) == {"pbc"}
assert set(scope["per_atom"]) == {"positions", "masses", "atomic_numbers", "batch"}
assert set(scope["per_batch"]) == {"cell"}

@@ -40,7 +40,7 @@ def test_infer_md_state_property_scope(si_sim_state: SimState) -> None:
energy=torch.zeros((1,)),
)
scope = infer_property_scope(state)
assert set(scope["global"]) == {"pbc"}
assert set(scope["per_graph"]) == {"pbc"}
assert set(scope["per_atom"]) == {
"positions",
"masses",
2 changes: 1 addition & 1 deletion torch_sim/models/lennard_jones.py
@@ -325,7 +325,7 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]:

# we always return tensors
# per atom properties are returned as (atoms, ...) tensors
# global properties are returned as shape (..., n) tensors
# per-graph properties are returned as shape (..., n) tensors
results = {}
for key in ("stress", "energy"):
if key in properties:
2 changes: 1 addition & 1 deletion torch_sim/models/morse.py
@@ -308,7 +308,7 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]:

# we always return tensors
# per atom properties are returned as (atoms, ...) tensors
# global properties are returned as shape (..., n) tensors
# per-graph properties are returned as shape (..., n) tensors
results = {}
for key in ("stress", "energy"):
if key in properties:
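Both the Lennard-Jones and Morse comments describe the same return-shape convention. A rough sketch of what a results dict following that convention could look like (the shapes below are assumptions for illustration, not the models' actual output):

import torch

n_batches, n_atoms = 2, 16  # illustrative sizes

results = {
    "forces": torch.zeros(n_atoms, 3),       # per-atom: (n_atoms, ...) tensor
    "energy": torch.zeros(n_batches),        # per-graph: (..., n_batches) tensor
    "stress": torch.zeros(n_batches, 3, 3),  # assumed one 3x3 stress per batch entry
}

for key, value in results.items():
    print(key, tuple(value.shape))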
72 changes: 42 additions & 30 deletions torch_sim/state.py
@@ -8,7 +8,7 @@
import importlib
import warnings
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Literal, Self
from typing import TYPE_CHECKING, Literal, Self, TypedDict, cast

import torch

@@ -152,7 +152,7 @@
return torch.unique(self.batch).shape[0]

@property
def volume(self) -> torch.Tensor:
def volume(self) -> torch.Tensor | None:
"""Volume of the system."""
return torch.det(self.cell) if self.pbc else None

@@ -226,7 +226,7 @@
"""
return ts.io.state_to_phonopy(self)

def split(self) -> list[Self]:
def split(self) -> list["SimState"]:
"""Split the SimState into a list of single-batch SimStates.

Divides the current state into separate states, each containing a single batch,
@@ -237,7 +237,9 @@
"""
return _split_state(self)

def pop(self, batch_indices: int | list[int] | slice | torch.Tensor) -> list[Self]:
def pop(
self, batch_indices: int | list[int] | slice | torch.Tensor
) -> list["SimState"]:
"""Pop off states with the specified batch indices.

This method modifies the original state object by removing the specified
@@ -268,7 +270,7 @@

def to(
self, device: torch.device | None = None, dtype: torch.dtype | None = None
) -> Self:
) -> "SimState":
"""Convert the SimState to a new device and/or data type.

Args:
@@ -282,7 +284,9 @@
"""
return state_to_device(self, device, dtype)

def __getitem__(self, batch_indices: int | list[int] | slice | torch.Tensor) -> Self:
def __getitem__(
self, batch_indices: int | list[int] | slice | torch.Tensor
) -> "SimState":
"""Enable standard Python indexing syntax for slicing batches.

Args:
@@ -387,7 +391,7 @@

def state_to_device(
state: SimState, device: torch.device | None = None, dtype: torch.dtype | None = None
) -> Self:
) -> SimState:
"""Convert the SimState to a new device and dtype.

Creates a new SimState with all tensors moved to the specified device and
@@ -419,10 +423,18 @@
return type(state)(**attrs)


class Scope(TypedDict):
"""Dictionary mapping each scope category to a list of property names"""

per_graph: list[str]
per_atom: list[str]
per_batch: list[str]


def infer_property_scope(
state: SimState,
ambiguous_handling: Literal["error", "globalize", "globalize_warn"] = "error",
) -> dict[Literal["global", "per_atom", "per_batch"], list[str]]:
) -> Scope:
"""Infer whether a property is global, per-atom, or per-batch.

Analyzes the shapes of tensor attributes to determine their scope within
@@ -437,14 +449,13 @@
- "globalize_warn": Treat ambiguous properties as global with a warning

Returns:
dict[Literal["global", "per_atom", "per_batch"], list[str]]: Dictionary mapping
each scope category to a list of property names
Scope: Dictionary mapping each scope category to a list of property names

Raises:
ValueError: If n_atoms equals n_batches (making scope inference ambiguous) or
if ambiguous_handling="error" and an ambiguous property is encountered
"""
# TODO: this cannot effectively resolve global properties with
# TODO: this cannot effectively resolve per-graph properties with
# length of n_atoms or n_batches, they will be classified incorrectly,
# no clear fix

@@ -454,25 +465,25 @@
"which means shapes cannot be inferred unambiguously."
)

scope = {
"global": [],
scope: Scope = {
"per_graph": [],
"per_atom": [],
"per_batch": [],
}

# Iterate through all attributes
for attr_name, attr_value in vars(state).items():
# Handle scalar values (global properties)
# Handle scalar values (per-graph properties)
if not isinstance(attr_value, torch.Tensor):
scope["global"].append(attr_name)
scope["per_graph"].append(attr_name)
continue

# Handle tensor properties based on shape
shape = attr_value.shape

# Empty tensor case
if len(shape) == 0:
scope["global"].append(attr_name)
scope["per_graph"].append(attr_name)

Codecov warning (torch_sim/state.py#L486): added line not covered by tests.
# Vector/matrix with first dimension matching number of atoms
elif shape[0] == state.n_atoms:
scope["per_atom"].append(attr_name)
@@ -487,7 +498,7 @@
f"{state.n_batches} (per-batch), or a scalar (global)."
)
elif ambiguous_handling in ("globalize", "globalize_warn"):
scope["global"].append(attr_name)
scope["per_graph"].append(attr_name)

Codecov warning (torch_sim/state.py#L501): added line not covered by tests.

if ambiguous_handling == "globalize_warn":
warnings.warn(
@@ -519,11 +530,11 @@
"""
scope = infer_property_scope(state, ambiguous_handling=ambiguous_handling)

attrs = {"global": {}, "per_atom": {}, "per_batch": {}}
attrs: dict[str, dict] = {"per_graph": {}, "per_atom": {}, "per_batch": {}}

# Process global properties
for attr_name in scope["global"]:
attrs["global"][attr_name] = getattr(state, attr_name)
# Process per-graph properties
for attr_name in scope["per_graph"]:
attrs["per_graph"][attr_name] = getattr(state, attr_name)

# Process per-atom properties
for attr_name in scope["per_atom"]:
@@ -537,7 +548,7 @@


def _filter_attrs_by_mask(
attrs: dict[str, dict],
attrs: Scope,
atom_mask: torch.Tensor,
batch_mask: torch.Tensor,
) -> dict:
@@ -546,7 +557,7 @@
Selects subsets of attributes based on boolean masks for atoms and batches.

Args:
attrs (dict[str, dict]): Dictionary with 'global', 'per_atom', and 'per_batch'
attrs (Scope): Dictionary with 'per_graph', 'per_atom', and 'per_batch'
attributes
atom_mask (torch.Tensor): Boolean mask for atoms to include with shape
(n_atoms,)
@@ -559,7 +570,7 @@
filtered_attrs = {}

# Copy global attributes directly
filtered_attrs.update(attrs["global"])
filtered_attrs.update(attrs["per_graph"])

# Filter per-atom attributes
for attr_name, attr_value in attrs["per_atom"].items():
@@ -638,7 +649,7 @@
# Add the split per-batch attributes
**{attr_name: split_per_batch[attr_name][i] for attr_name in split_per_batch},
# Add the global attributes
**attrs["global"],
**attrs["per_graph"],
}
states.append(type(state)(**batch_attrs))

@@ -753,7 +764,7 @@
"""Concatenate a list of SimStates into a single SimState.

Combines multiple states into a single state with multiple batches.
Global properties are taken from the first state, and per-atom and per-batch
Per-graph properties are taken from the first state, and per-atom and per-batch
properties are concatenated.

Args:
@@ -783,14 +794,14 @@
target_device = device or first_state.device

# Get property scopes from the first state to identify
# global/per-atom/per-batch properties
# per-graph/per-atom/per-batch properties
first_scope = infer_property_scope(first_state)
global_props = set(first_scope["global"])
per_graph_props = set(first_scope["per_graph"])
per_atom_props = set(first_scope["per_atom"])
per_batch_props = set(first_scope["per_batch"])

# Initialize result with global properties from first state
concatenated = {prop: getattr(first_state, prop) for prop in global_props}
# Initialize result with per-graph properties from first state
concatenated = {prop: getattr(first_state, prop) for prop in per_graph_props}

# Pre-allocate lists for tensors to concatenate
per_atom_tensors = {prop: [] for prop in per_atom_props}
@@ -864,6 +875,7 @@
return state_to_device(system, device, dtype)

if isinstance(system, list) and all(isinstance(s, SimState) for s in system):
system = cast(list[SimState], system)
if not all(state.n_batches == 1 for state in system):
raise ValueError(
"When providing a list of states, to the initialize_state function, "
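Pulling the state.py changes together, here is a stripped-down sketch of the shape-based classification that infer_property_scope performs, using the same Scope layout this diff introduces (standalone toy code, not the library implementation):

from typing import TypedDict

import torch


class Scope(TypedDict):
    per_graph: list[str]
    per_atom: list[str]
    per_batch: list[str]


def classify(attrs: dict, n_atoms: int, n_batches: int) -> Scope:
    """Toy version of the shape-based scope inference shown in the diff above."""
    scope: Scope = {"per_graph": [], "per_atom": [], "per_batch": []}
    for name, value in attrs.items():
        if not isinstance(value, torch.Tensor) or value.ndim == 0:
            scope["per_graph"].append(name)  # non-tensors and scalars are shared
        elif value.shape[0] == n_atoms:
            scope["per_atom"].append(name)   # leading dim matches the atom count
        elif value.shape[0] == n_batches:
            scope["per_batch"].append(name)  # leading dim matches the batch count
        else:
            raise ValueError(f"Ambiguous scope for {name!r}")
    return scope


# Illustrative usage with made-up shapes (16 atoms split over 2 batches):
attrs = {
    "positions": torch.zeros(16, 3),
    "cell": torch.zeros(2, 3, 3),
    "pbc": True,
}
print(classify(attrs, n_atoms=16, n_batches=2))
# {'per_graph': ['pbc'], 'per_atom': ['positions'], 'per_batch': ['cell']}

As in the real function, anything whose leading dimension matches neither count is ambiguous; the diffed code handles that case through the ambiguous_handling argument rather than always raising.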