diff --git a/torch_sim/io.py b/torch_sim/io.py index 8f61bd49..bdfa4067 100644 --- a/torch_sim/io.py +++ b/torch_sim/io.py @@ -241,7 +241,7 @@ def atoms_to_state( cell=cell, pbc=all(atoms_list[0].pbc), atomic_numbers=atomic_numbers, - system_idx=system_idx, + init_system_idx=system_idx, ) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 98a83d64..f8877764 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -587,7 +587,7 @@ def fire_init( masses=state.masses.clone(), cell=state.cell.clone(), atomic_numbers=state.atomic_numbers.clone(), - system_idx=state.system_idx.clone(), + init_system_idx=state.system_idx.clone(), pbc=state.pbc, velocities=None, forces=forces, diff --git a/torch_sim/state.py b/torch_sim/state.py index ce21ef9b..c0bdfbf8 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -6,14 +6,16 @@ import copy import importlib +import inspect +import typing import warnings -from dataclasses import dataclass, field +from dataclasses import InitVar, dataclass, field from typing import TYPE_CHECKING, Literal, Self import torch import torch_sim as ts -from torch_sim.typing import StateLike +from torch_sim.typing import SimStateVar, StateLike if TYPE_CHECKING: @@ -47,9 +49,11 @@ class SimState: used by ASE. pbc (bool): Boolean indicating whether to use periodic boundary conditions atomic_numbers (torch.Tensor): Atomic numbers with shape (n_atoms,) - system_idx (torch.Tensor, optional): Maps each atom index to its system index. - Has shape (n_atoms,), defaults to None, must be unique consecutive - integers starting from 0 + system_idx (torch.Tensor): Maps each atom index to its system index. + Has shape (n_atoms,), must be unique consecutive integers starting from 0. + This attribute is initialized by the init_system_idx parameter in the + constructor. If init_system_idx is not provided, it is initialized to + zeros. Properties: wrap_positions (torch.Tensor): Positions wrapped according to periodic boundary @@ -81,9 +85,10 @@ class SimState: cell: torch.Tensor pbc: bool # TODO: do all calculators support mixed pbc? atomic_numbers: torch.Tensor - system_idx: torch.Tensor | None = field(default=None, kw_only=True) + system_idx: torch.Tensor = field(init=False) + init_system_idx: InitVar[torch.Tensor | None] - def __post_init__(self) -> None: + def __post_init__(self, init_system_idx: torch.Tensor | None) -> None: """Validate and process the state after initialization.""" # data validation and fill system_idx # should make pbc a tensor here @@ -107,17 +112,12 @@ def __post_init__(self) -> None: f"masses {shapes[1]}, atomic_numbers {shapes[2]}" ) - if self.cell.ndim != 3 and self.system_idx is None: - self.cell = self.cell.unsqueeze(0) - - if self.cell.shape[-2:] != (3, 3): - raise ValueError("Cell must have shape (n_systems, 3, 3)") - - if self.system_idx is None: + if init_system_idx is None: self.system_idx = torch.zeros( self.n_atoms, device=self.device, dtype=torch.int64 ) else: + self.system_idx = init_system_idx # assert that system indices are unique consecutive integers # TODO(curtis): I feel like this logic is not reliable. # I'll come up with something better later. @@ -125,6 +125,12 @@ def __post_init__(self) -> None: if not torch.all(counts == torch.bincount(self.system_idx)): raise ValueError("System indices must be unique consecutive integers") + if self.cell.ndim != 3 and self.system_idx is None: + self.cell = self.cell.unsqueeze(0) + + if self.cell.shape[-2:] != (3, 3): + raise ValueError("Cell must have shape (n_systems, 3, 3)") + if self.cell.shape[0] != self.n_systems: raise ValueError( f"Cell must have shape (n_systems, 3, 3), got {self.cell.shape}" @@ -272,7 +278,7 @@ def clone(self) -> Self: else: attrs[attr_name] = copy.deepcopy(attr_value) - return self.__class__(**attrs) + return _construct_state(self, attrs) def to_atoms(self) -> list["Atoms"]: """Convert the SimState to a list of ASE Atoms objects. @@ -372,6 +378,37 @@ def __getitem__(self, system_indices: int | list[int] | slice | torch.Tensor) -> return _slice_state(self, system_indices) + def __init_subclass__(cls, **kwargs) -> None: + """Enforce that all subclasses have valid InitVar fields.""" + type_hints = typing.get_type_hints(cls) + + # Validate InitVar fields + for attr_name, attr_typehint in cls.__annotations__.items(): + # 1) validate InitVar fields + if type(attr_typehint) is InitVar: + # make sure its prefix is "init_" + if not attr_name.startswith("init_"): + raise TypeError( + f"Attribute '{attr_name}' in class '{cls.__name__}' is not " + "allowed to be an InitVar. It must be prefixed with 'init_'" + ) + # make sure there is a corresponding non-InitVar field + non_init_attr_name = attr_name.removeprefix("init_") + if non_init_attr_name not in type_hints: + raise TypeError( + f"Attribute '{attr_name}' in class '{cls.__name__}' is not " + "allowed to be an InitVar. It must have a corresponding " + f"non-InitVar field {non_init_attr_name}" + ) + + # 2) forbid non init vars to have a "init_" prefix + elif attr_name.startswith("init_"): + raise TypeError( + f"Attribute '{attr_name}' in class '{cls.__name__}' is not " + "allowed to have an 'init_' prefix as it's a non-InitVar field." + ) + super().__init_subclass__(**kwargs) + class DeformGradMixin: """Mixin for states that support deformation gradients.""" @@ -490,7 +527,7 @@ def state_to_device( attrs["masses"] = attrs["masses"].to(dtype=dtype) attrs["cell"] = attrs["cell"].to(dtype=dtype) attrs["atomic_numbers"] = attrs["atomic_numbers"].to(dtype=torch.int) - return type(state)(**attrs) + return _construct_state(state, attrs) def infer_property_scope( @@ -719,11 +756,44 @@ def _split_state( # Add the global attributes **attrs["global"], } - states.append(type(state)(**system_attrs)) + states.append(_construct_state(state, system_attrs)) return states +def _construct_state( + old_state: SimStateVar, + new_state_attrs: dict[str, typing.Any], +) -> SimStateVar: + """Construct a new state of the same class as the old state with the specified + attrs. + + Args: + old_state (SimStateVar): We will construct a new state of the same class as this + one + new_state_attrs (dict[str, typing.Any]): The attributes to use to construct + the new state + + Returns: + SimStateVar: A new state of the same class as the old state with the specified + attributes + """ + # 1) process the attrs so they are the init params + processed_params = {} + for param in inspect.signature(old_state.__class__).parameters: + if param.startswith("init_"): + # this is an InitVar field + # we need to rename the corresponding field in system_attrs to have + # an "init_" prefix + non_init_attr_name = param.removeprefix("init_") + processed_params[param] = new_state_attrs[non_init_attr_name] + else: + processed_params[param] = new_state_attrs[param] + + # 2) construct the new state + return type(old_state)(**processed_params) + + def _pop_states( state: SimState, pop_indices: list[int] | torch.Tensor, @@ -770,10 +840,10 @@ def _pop_states( pop_attrs = _filter_attrs_by_mask(attrs, pop_atom_mask, pop_system_mask) # Create the keep state - keep_state = type(state)(**keep_attrs) + keep_state = _construct_state(state, keep_attrs) # Create and split the pop state - pop_state = type(state)(**pop_attrs) + pop_state = _construct_state(state, pop_attrs) pop_states = _split_state(pop_state, ambiguous_handling) return keep_state, pop_states @@ -823,7 +893,7 @@ def _slice_state( filtered_attrs = _filter_attrs_by_mask(attrs, atom_mask, system_mask) # Create the sliced state - return type(state)(**filtered_attrs) + return _construct_state(state, filtered_attrs) def concatenate_states( @@ -854,8 +924,7 @@ def concatenate_states( first_state = states[0] # Ensure all states are of the same class - state_class = type(first_state) - if not all(isinstance(state, state_class) for state in states): + if not all(isinstance(state, type(first_state)) for state in states): raise TypeError("All states must be of the same type") # Use the target device or default to the first state's device @@ -912,7 +981,7 @@ def concatenate_states( concatenated["system_idx"] = torch.cat(new_system_indices) # Create a new instance of the same class - return state_class(**concatenated) + return _construct_state(first_state, concatenated) def initialize_state(