2 changes: 1 addition & 1 deletion torch_sim/io.py
@@ -241,7 +241,7 @@ def atoms_to_state(
cell=cell,
pbc=all(atoms_list[0].pbc),
atomic_numbers=atomic_numbers,
-        system_idx=system_idx,
+        init_system_idx=system_idx,
)


2 changes: 1 addition & 1 deletion torch_sim/optimizers.py
Original file line number Diff line number Diff line change
@@ -587,7 +587,7 @@ def fire_init(
masses=state.masses.clone(),
cell=state.cell.clone(),
atomic_numbers=state.atomic_numbers.clone(),
-        system_idx=state.system_idx.clone(),
+        init_system_idx=state.system_idx.clone(),
pbc=state.pbc,
velocities=None,
forces=forces,
115 changes: 92 additions & 23 deletions torch_sim/state.py
@@ -6,14 +6,16 @@

import copy
import importlib
+import inspect
+import typing
import warnings
-from dataclasses import dataclass, field
+from dataclasses import InitVar, dataclass, field
from typing import TYPE_CHECKING, Literal, Self

import torch

import torch_sim as ts
-from torch_sim.typing import StateLike
+from torch_sim.typing import SimStateVar, StateLike


if TYPE_CHECKING:
@@ -47,9 +49,11 @@ class SimState:
used by ASE.
pbc (bool): Boolean indicating whether to use periodic boundary conditions
atomic_numbers (torch.Tensor): Atomic numbers with shape (n_atoms,)
-        system_idx (torch.Tensor, optional): Maps each atom index to its system index.
-            Has shape (n_atoms,), defaults to None, must be unique consecutive
-            integers starting from 0
+        system_idx (torch.Tensor): Maps each atom index to its system index.
+            Has shape (n_atoms,), must be unique consecutive integers starting from 0.
+            This attribute is initialized by the init_system_idx parameter in the
+            constructor. If init_system_idx is not provided, it is initialized to
+            zeros.

Properties:
wrap_positions (torch.Tensor): Positions wrapped according to periodic boundary
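To make the new constructor contract concrete, here is a hypothetical usage sketch; the argument names follow the attributes documented above, while the values and the exact validation behavior are assumptions:

```python
import torch

from torch_sim.state import SimState

# Hypothetical two-system batch: system 0 has 2 atoms, system 1 has 3.
state = SimState(
    positions=torch.zeros(5, 3),
    masses=torch.ones(5),
    cell=torch.eye(3).repeat(2, 1, 1),  # (n_systems, 3, 3)
    pbc=True,
    atomic_numbers=torch.ones(5, dtype=torch.int64),
    init_system_idx=torch.tensor([0, 0, 1, 1, 1]),
)
assert torch.equal(state.system_idx, torch.tensor([0, 0, 1, 1, 1]))

# Passing init_system_idx=None assigns every atom to system 0.
single = SimState(
    positions=torch.zeros(3, 3),
    masses=torch.ones(3),
    cell=torch.eye(3).unsqueeze(0),
    pbc=True,
    atomic_numbers=torch.ones(3, dtype=torch.int64),
    init_system_idx=None,
)
assert torch.equal(single.system_idx, torch.zeros(3, dtype=torch.int64))

# system_idx is no longer accepted by __init__; SimState(..., system_idx=...)
# now raises TypeError.
```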
@@ -81,9 +85,10 @@ class SimState:
cell: torch.Tensor
pbc: bool # TODO: do all calculators support mixed pbc?
atomic_numbers: torch.Tensor
-    system_idx: torch.Tensor | None = field(default=None, kw_only=True)
+    system_idx: torch.Tensor = field(init=False)
+    init_system_idx: InitVar[torch.Tensor | None]

-    def __post_init__(self) -> None:
+    def __post_init__(self, init_system_idx: torch.Tensor | None) -> None:
"""Validate and process the state after initialization."""
# data validation and fill system_idx
# should make pbc a tensor here
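For readers unfamiliar with the pattern adopted here: a dataclass InitVar is accepted by the generated __init__ and forwarded to __post_init__, but it is never stored on the instance; the field(init=False) twin is what actually lives there. A minimal standalone sketch with toy names (the default on the InitVar is added for brevity and is not part of the diff):

```python
from dataclasses import InitVar, dataclass, field


@dataclass
class Example:
    value: int = field(init=False)  # stored field, absent from __init__
    init_value: InitVar[int | None] = None  # __init__-only argument

    def __post_init__(self, init_value: int | None) -> None:
        # Derive the stored field from the InitVar, defaulting to 0.
        self.value = 0 if init_value is None else init_value


assert Example().value == 0
assert Example(init_value=7).value == 7
assert "init_value" not in vars(Example(init_value=7))  # InitVar is not stored
```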
@@ -107,24 +112,25 @@ def __post_init__(self) -> None:
f"masses {shapes[1]}, atomic_numbers {shapes[2]}"
)

-        if self.cell.ndim != 3 and self.system_idx is None:

[Review comment from the PR author on the removed checks] I moved these checks down since it depends on self.system_idx (also to bundle it with the self.cell.shape[0] check)

-            self.cell = self.cell.unsqueeze(0)
-
-        if self.cell.shape[-2:] != (3, 3):
-            raise ValueError("Cell must have shape (n_systems, 3, 3)")
-
-        if self.system_idx is None:
+        if init_system_idx is None:
self.system_idx = torch.zeros(
self.n_atoms, device=self.device, dtype=torch.int64
)
else:
+            self.system_idx = init_system_idx
# assert that system indices are unique consecutive integers
# TODO(curtis): I feel like this logic is not reliable.
# I'll come up with something better later.
_, counts = torch.unique_consecutive(self.system_idx, return_counts=True)
if not torch.all(counts == torch.bincount(self.system_idx)):
raise ValueError("System indices must be unique consecutive integers")

+        if self.cell.ndim != 3 and self.system_idx is None:
+            self.cell = self.cell.unsqueeze(0)
+
+        if self.cell.shape[-2:] != (3, 3):
+            raise ValueError("Cell must have shape (n_systems, 3, 3)")
+
if self.cell.shape[0] != self.n_systems:
raise ValueError(
f"Cell must have shape (n_systems, 3, 3), got {self.cell.shape}"
@@ -272,7 +278,7 @@ def clone(self) -> Self:
else:
attrs[attr_name] = copy.deepcopy(attr_value)

-        return self.__class__(**attrs)
+        return _construct_state(self, attrs)

def to_atoms(self) -> list["Atoms"]:
"""Convert the SimState to a list of ASE Atoms objects.
@@ -372,6 +378,37 @@ def __getitem__(self, system_indices: int | list[int] | slice | torch.Tensor) ->

return _slice_state(self, system_indices)

def __init_subclass__(cls, **kwargs) -> None:
"""Enforce that all subclasses have valid InitVar fields."""
type_hints = typing.get_type_hints(cls)

# Validate InitVar fields
for attr_name, attr_typehint in cls.__annotations__.items():
# 1) validate InitVar fields
if type(attr_typehint) is InitVar:
# make sure its prefix is "init_"
if not attr_name.startswith("init_"):
raise TypeError(
f"Attribute '{attr_name}' in class '{cls.__name__}' is not "
"allowed to be an InitVar. It must be prefixed with 'init_'"
)
# make sure there is a corresponding non-InitVar field
non_init_attr_name = attr_name.removeprefix("init_")
if non_init_attr_name not in type_hints:
raise TypeError(
f"Attribute '{attr_name}' in class '{cls.__name__}' is not "
"allowed to be an InitVar. It must have a corresponding "
f"non-InitVar field {non_init_attr_name}"
)

            # 2) forbid non-InitVar fields from having an "init_" prefix
elif attr_name.startswith("init_"):
raise TypeError(
f"Attribute '{attr_name}' in class '{cls.__name__}' is not "
"allowed to have an 'init_' prefix as it's a non-InitVar field."
)
super().__init_subclass__(**kwargs)


class DeformGradMixin:
"""Mixin for states that support deformation gradients."""
@@ -490,7 +527,7 @@ def state_to_device(
attrs["masses"] = attrs["masses"].to(dtype=dtype)
attrs["cell"] = attrs["cell"].to(dtype=dtype)
attrs["atomic_numbers"] = attrs["atomic_numbers"].to(dtype=torch.int)
-    return type(state)(**attrs)
+    return _construct_state(state, attrs)


def infer_property_scope(
@@ -719,11 +756,44 @@ def _split_state(
# Add the global attributes
**attrs["global"],
}
-        states.append(type(state)(**system_attrs))
+        states.append(_construct_state(state, system_attrs))

return states


def _construct_state(
old_state: SimStateVar,
new_state_attrs: dict[str, typing.Any],
) -> SimStateVar:
"""Construct a new state of the same class as the old state with the specified
attrs.

Args:
        old_state (SimStateVar): The state whose class the new state will share
new_state_attrs (dict[str, typing.Any]): The attributes to use to construct
the new state

Returns:
SimStateVar: A new state of the same class as the old state with the specified
attributes
"""
# 1) process the attrs so they are the init params
processed_params = {}
for param in inspect.signature(old_state.__class__).parameters:
if param.startswith("init_"):
# this is an InitVar field
            # we need to rename the corresponding field in new_state_attrs to have
# an "init_" prefix
non_init_attr_name = param.removeprefix("init_")
processed_params[param] = new_state_attrs[non_init_attr_name]
else:
processed_params[param] = new_state_attrs[param]

# 2) construct the new state
return type(old_state)(**processed_params)


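The renaming step exists because attribute dicts harvested from a live state contain system_idx, while the generated __init__ now only accepts init_system_idx. A toy demonstration of the same remapping, using an invented class rather than torch-sim code:

```python
import inspect
from dataclasses import InitVar, dataclass, field


@dataclass
class Toy:
    value: int = field(init=False)
    init_value: InitVar[int | None] = None

    def __post_init__(self, init_value: int | None) -> None:
        self.value = 0 if init_value is None else init_value


old = Toy(init_value=3)
attrs = vars(old)  # {"value": 3}; the InitVar itself is never stored
# Toy(**attrs) would raise TypeError: "value" is not an __init__ parameter.
params = {
    name: attrs[name.removeprefix("init_")]
    for name in inspect.signature(Toy).parameters
}
new = Toy(**params)  # equivalent to Toy(init_value=3)
assert new.value == 3
```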
def _pop_states(
state: SimState,
pop_indices: list[int] | torch.Tensor,
@@ -770,10 +840,10 @@ def _pop_states(
pop_attrs = _filter_attrs_by_mask(attrs, pop_atom_mask, pop_system_mask)

# Create the keep state
-    keep_state = type(state)(**keep_attrs)
+    keep_state = _construct_state(state, keep_attrs)

    # Create and split the pop state
-    pop_state = type(state)(**pop_attrs)
+    pop_state = _construct_state(state, pop_attrs)
pop_states = _split_state(pop_state, ambiguous_handling)

return keep_state, pop_states
@@ -823,7 +893,7 @@ def _slice_state(
filtered_attrs = _filter_attrs_by_mask(attrs, atom_mask, system_mask)

# Create the sliced state
-    return type(state)(**filtered_attrs)
+    return _construct_state(state, filtered_attrs)


def concatenate_states(
@@ -854,8 +924,7 @@ def concatenate_states(
first_state = states[0]

# Ensure all states are of the same class
-    state_class = type(first_state)
-    if not all(isinstance(state, state_class) for state in states):
+    if not all(isinstance(state, type(first_state)) for state in states):
raise TypeError("All states must be of the same type")

# Use the target device or default to the first state's device
@@ -912,7 +981,7 @@
concatenated["system_idx"] = torch.cat(new_system_indices)

# Create a new instance of the same class
-    return state_class(**concatenated)
+    return _construct_state(first_state, concatenated)


def initialize_state(
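When states are merged, each incoming state's system indices must be shifted so the combined system_idx remains consecutive; the offset loop itself sits outside the shown hunk, so the following is an assumed sketch of that bookkeeping with an invented helper name, not the library's exact code:

```python
import torch


def offset_system_indices(per_state: list[torch.Tensor]) -> torch.Tensor:
    # Shift each state's indices by the number of systems already merged.
    shifted, offset = [], 0
    for idx in per_state:
        shifted.append(idx + offset)
        offset += int(idx.max()) + 1
    return torch.cat(shifted)


# Two 1-system states plus one 2-system state -> indices 0, 1, 2, 3.
merged = offset_system_indices(
    [torch.tensor([0, 0]), torch.tensor([0]), torch.tensor([0, 0, 1])]
)
assert torch.equal(merged, torch.tensor([0, 0, 1, 2, 2, 3]))
```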