InitVar dataclass initialization (and subclass checks) #230
Closed
Changes from all commits (12):
- 2d77db5 (curtischong): we can properly identify variables that are tensor | None
- 6767059 (curtischong): simstate properly handles system_idx_init now
- a1fc8b5 (curtischong): properly set the system idx field init = false
- 80cc62a (curtischong): ran into problem with using the system_index hack
- 5cf6b76 (curtischong): init subclass runs properly
- fcb9e7b (curtischong): revert setting velocities to | None
- fa5c697 (curtischong): cleanup
- 904b268 (curtischong): more cleanup
- afd7aaa (curtischong): fix types of firestates
- e5ca217 (curtischong): document system_idx better
- 456ad46 (curtischong): cleanup _construct_state
- 5efbb7c (curtischong): rm state_class
```diff
@@ -6,14 +6,16 @@
 import copy
 import importlib
+import inspect
+import typing
 import warnings
-from dataclasses import dataclass, field
+from dataclasses import InitVar, dataclass, field
 from typing import TYPE_CHECKING, Literal, Self

 import torch

 import torch_sim as ts
-from torch_sim.typing import StateLike
+from torch_sim.typing import SimStateVar, StateLike


 if TYPE_CHECKING:
```
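
For readers unfamiliar with the mechanism: a `dataclasses.InitVar` field is accepted as an `__init__` parameter and forwarded to `__post_init__`, but is never stored on the instance. A minimal stdlib-only sketch (the `Point` class is invented for illustration, not part of this PR):

```python
from dataclasses import InitVar, dataclass, field


@dataclass
class Point:
    x: float = field(init=False)  # real attribute, absent from __init__
    init_x: InitVar[float | None] = None  # __init__ parameter, never stored

    def __post_init__(self, init_x: float | None) -> None:
        # The InitVar value is only visible here; the real field is derived from it.
        self.x = 0.0 if init_x is None else init_x


p = Point(init_x=2.5)
assert p.x == 2.5
assert "init_x" not in vars(p)  # not an instance attribute
```

This is exactly the pattern the diff applies to `system_idx`/`init_system_idx`.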
|
|
```diff
@@ -47,9 +49,11 @@ class SimState:
             used by ASE.
         pbc (bool): Boolean indicating whether to use periodic boundary conditions
         atomic_numbers (torch.Tensor): Atomic numbers with shape (n_atoms,)
-        system_idx (torch.Tensor, optional): Maps each atom index to its system index.
-            Has shape (n_atoms,), defaults to None, must be unique consecutive
-            integers starting from 0
+        system_idx (torch.Tensor): Maps each atom index to its system index.
+            Has shape (n_atoms,), must be unique consecutive integers starting from 0.
+            This attribute is initialized by the init_system_idx parameter in the
+            constructor. If init_system_idx is not provided, it is initialized to
+            zeros.

     Properties:
         wrap_positions (torch.Tensor): Positions wrapped according to periodic boundary
```
|
|
```diff
@@ -81,9 +85,10 @@ class SimState:
     cell: torch.Tensor
     pbc: bool  # TODO: do all calculators support mixed pbc?
     atomic_numbers: torch.Tensor
-    system_idx: torch.Tensor | None = field(default=None, kw_only=True)
+    system_idx: torch.Tensor = field(init=False)
+    init_system_idx: InitVar[torch.Tensor | None]

-    def __post_init__(self) -> None:
+    def __post_init__(self, init_system_idx: torch.Tensor | None) -> None:
        """Validate and process the state after initialization."""
        # data validation and fill system_idx
        # should make pbc a tensor here
```
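
With this change, callers pass `init_system_idx` (or `None`) to the constructor and read the result back from `system_idx`. A usage sketch, assuming `SimState` is exported as `torch_sim.SimState` and using made-up tensor values:

```python
import torch
import torch_sim as ts

# Two atoms in one system; init_system_idx=None asks __post_init__ to fill
# system_idx with zeros.
state = ts.SimState(
    positions=torch.zeros(2, 3),
    masses=torch.ones(2),
    cell=torch.eye(3).unsqueeze(0),  # shape (n_systems=1, 3, 3)
    pbc=True,
    atomic_numbers=torch.tensor([1, 1]),
    init_system_idx=None,
)
assert torch.equal(state.system_idx, torch.tensor([0, 0]))

# system_idx itself is declared with field(init=False), so passing it
# directly would now raise a TypeError:
#   ts.SimState(..., system_idx=torch.tensor([0, 0]))
```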
|
|
```diff
@@ -107,24 +112,25 @@ def __post_init__(self) -> None:
                 f"masses {shapes[1]}, atomic_numbers {shapes[2]}"
             )

-        if self.cell.ndim != 3 and self.system_idx is None:
-            self.cell = self.cell.unsqueeze(0)
-
-        if self.cell.shape[-2:] != (3, 3):
-            raise ValueError("Cell must have shape (n_systems, 3, 3)")
-
-        if self.system_idx is None:
+        if init_system_idx is None:
             self.system_idx = torch.zeros(
                 self.n_atoms, device=self.device, dtype=torch.int64
             )
         else:
+            self.system_idx = init_system_idx
             # assert that system indices are unique consecutive integers
             # TODO(curtis): I feel like this logic is not reliable.
             # I'll come up with something better later.
             _, counts = torch.unique_consecutive(self.system_idx, return_counts=True)
             if not torch.all(counts == torch.bincount(self.system_idx)):
                 raise ValueError("System indices must be unique consecutive integers")

+        if self.cell.ndim != 3 and self.system_idx is None:
+            self.cell = self.cell.unsqueeze(0)
+
+        if self.cell.shape[-2:] != (3, 3):
+            raise ValueError("Cell must have shape (n_systems, 3, 3)")
+
         if self.cell.shape[0] != self.n_systems:
             raise ValueError(
                 f"Cell must have shape (n_systems, 3, 3), got {self.cell.shape}"
```

> Review comment on the removed cell checks: "I moved these checks down since it depends on `self.system_idx` (also to bundle it with the `self.cell.shape[0]` check)."
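
The validation above compares run lengths from `torch.unique_consecutive` against per-value counts from `torch.bincount`. A small illustration of what passes and of the failure mode the TODO is worried about (values invented):

```python
import torch

# Consecutive blocks: one run per system, so run lengths equal bincount.
ok = torch.tensor([0, 0, 1, 1, 1])
_, runs = torch.unique_consecutive(ok, return_counts=True)
print(runs, torch.bincount(ok))  # tensor([2, 3]) tensor([2, 3]) -> check passes

# Interleaved indices give four runs of length 1 vs. bincount [2, 2]; the
# count tensors differ in length, so the equality comparison raises a
# RuntimeError on shape mismatch instead of the intended ValueError.
bad = torch.tensor([0, 1, 0, 1])
_, runs = torch.unique_consecutive(bad, return_counts=True)
print(runs, torch.bincount(bad))  # tensor([1, 1, 1, 1]) tensor([2, 2])

# Out-of-order blocks such as [1, 1, 0, 0] still pass, which is likely why
# the TODO calls this logic unreliable.
```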
|
|
```diff
@@ -272,7 +278,7 @@ def clone(self) -> Self:
             else:
                 attrs[attr_name] = copy.deepcopy(attr_value)

-        return self.__class__(**attrs)
+        return _construct_state(self, attrs)

     def to_atoms(self) -> list["Atoms"]:
         """Convert the SimState to a list of ASE Atoms objects.
```
|
|
```diff
@@ -372,6 +378,37 @@ def __getitem__(self, system_indices: int | list[int] | slice | torch.Tensor) ->

         return _slice_state(self, system_indices)

+    def __init_subclass__(cls, **kwargs) -> None:
+        """Enforce that all subclasses have valid InitVar fields."""
+        type_hints = typing.get_type_hints(cls)
+
+        for attr_name, attr_typehint in cls.__annotations__.items():
+            # 1) validate InitVar fields
+            if type(attr_typehint) is InitVar:
+                # make sure the name is prefixed with "init_"
+                if not attr_name.startswith("init_"):
+                    raise TypeError(
+                        f"Attribute '{attr_name}' in class '{cls.__name__}' is not "
+                        "allowed to be an InitVar. It must be prefixed with 'init_'"
+                    )
+                # make sure there is a corresponding non-InitVar field
+                non_init_attr_name = attr_name.removeprefix("init_")
+                if non_init_attr_name not in type_hints:
+                    raise TypeError(
+                        f"Attribute '{attr_name}' in class '{cls.__name__}' is not "
+                        "allowed to be an InitVar. It must have a corresponding "
+                        f"non-InitVar field {non_init_attr_name}"
+                    )
+
+            # 2) forbid non-InitVar fields from having an "init_" prefix
+            elif attr_name.startswith("init_"):
+                raise TypeError(
+                    f"Attribute '{attr_name}' in class '{cls.__name__}' is not "
+                    "allowed to have an 'init_' prefix as it's a non-InitVar field."
+                )
+        super().__init_subclass__(**kwargs)
+
+
 class DeformGradMixin:
     """Mixin for states that support deformation gradients."""
```
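
In practice the hook forces subclasses to declare paired fields. A hypothetical subclass that satisfies the checks (`VelocityState` and its fields are invented for illustration; the import path is assumed):

```python
from dataclasses import InitVar, dataclass, field

import torch

from torch_sim.state import SimState  # import path assumed for illustration


@dataclass
class VelocityState(SimState):
    velocities: torch.Tensor = field(init=False)  # real field
    init_velocities: InitVar[torch.Tensor | None] = None  # paired InitVar

    def __post_init__(
        self,
        init_system_idx: torch.Tensor | None,
        init_velocities: torch.Tensor | None,
    ) -> None:
        super().__post_init__(init_system_idx)
        self.velocities = (
            torch.zeros_like(self.positions)
            if init_velocities is None
            else init_velocities
        )


# Either annotation below would instead fail at class-creation time:
#   init_foo: InitVar[torch.Tensor | None] = None  # no matching "foo" field
#   init_bar: torch.Tensor  # "init_" prefix on a non-InitVar field
```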
|
|
```diff
@@ -490,7 +527,7 @@ def state_to_device(
     attrs["masses"] = attrs["masses"].to(dtype=dtype)
     attrs["cell"] = attrs["cell"].to(dtype=dtype)
     attrs["atomic_numbers"] = attrs["atomic_numbers"].to(dtype=torch.int)
-    return type(state)(**attrs)
+    return _construct_state(state, attrs)


 def infer_property_scope(
```
|
|
```diff
@@ -719,11 +756,44 @@ def _split_state(
             # Add the global attributes
             **attrs["global"],
         }
-        states.append(type(state)(**system_attrs))
+        states.append(_construct_state(state, system_attrs))

     return states


+def _construct_state(
+    old_state: SimStateVar,
+    new_state_attrs: dict[str, typing.Any],
+) -> SimStateVar:
+    """Construct a new state of the same class as the old state with the specified
+    attrs.
+
+    Args:
+        old_state (SimStateVar): We will construct a new state of the same class as
+            this one
+        new_state_attrs (dict[str, typing.Any]): The attributes to use to construct
+            the new state
+
+    Returns:
+        SimStateVar: A new state of the same class as the old state with the
+            specified attributes
+    """
+    # 1) process the attrs so they are valid init params
+    processed_params = {}
+    for param in inspect.signature(old_state.__class__).parameters:
+        if param.startswith("init_"):
+            # this is an InitVar field: rename the corresponding key in
+            # new_state_attrs to carry the "init_" prefix
+            non_init_attr_name = param.removeprefix("init_")
+            processed_params[param] = new_state_attrs[non_init_attr_name]
+        else:
+            processed_params[param] = new_state_attrs[param]
+
+    # 2) construct the new state
+    return type(old_state)(**processed_params)
+
+
 def _pop_states(
     state: SimState,
     pop_indices: list[int] | torch.Tensor,
```
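
The re-keying step can be exercised with a round trip, assuming the module path `torch_sim.state` for the private helper and illustrative tensor values:

```python
import torch
import torch_sim as ts

from torch_sim.state import _construct_state  # module path assumed

state = ts.SimState(
    positions=torch.zeros(3, 3),
    masses=torch.ones(3),
    cell=torch.eye(3).repeat(2, 1, 1),  # two systems
    pbc=True,
    atomic_numbers=torch.tensor([1, 1, 8]),
    init_system_idx=torch.tensor([0, 0, 1]),
)

# vars(state) carries "system_idx" (the stored field), but SimState's
# __init__ wants "init_system_idx"; _construct_state re-keys it.
clone = _construct_state(state, dict(vars(state)))
assert type(clone) is ts.SimState
assert torch.equal(clone.system_idx, state.system_idx)
```

This single helper is what lets `clone`, `state_to_device`, `_split_state`, `_pop_states`, `_slice_state`, and `concatenate_states` all rebuild states through one code path even though `system_idx` is no longer an `__init__` parameter.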
|
|
```diff
@@ -770,10 +840,10 @@ def _pop_states(
     pop_attrs = _filter_attrs_by_mask(attrs, pop_atom_mask, pop_system_mask)

     # Create the keep state
-    keep_state = type(state)(**keep_attrs)
+    keep_state = _construct_state(state, keep_attrs)

     # Create and split the pop state
-    pop_state = type(state)(**pop_attrs)
+    pop_state = _construct_state(state, pop_attrs)
     pop_states = _split_state(pop_state, ambiguous_handling)

     return keep_state, pop_states
```
|
|
```diff
@@ -823,7 +893,7 @@ def _slice_state(
     filtered_attrs = _filter_attrs_by_mask(attrs, atom_mask, system_mask)

     # Create the sliced state
-    return type(state)(**filtered_attrs)
+    return _construct_state(state, filtered_attrs)


 def concatenate_states(
```
|
|
```diff
@@ -854,8 +924,7 @@ def concatenate_states(
     first_state = states[0]

     # Ensure all states are of the same class
-    state_class = type(first_state)
-    if not all(isinstance(state, state_class) for state in states):
+    if not all(isinstance(state, type(first_state)) for state in states):
         raise TypeError("All states must be of the same type")

     # Use the target device or default to the first state's device
```
|
|
```diff
@@ -912,7 +981,7 @@ def concatenate_states(
     concatenated["system_idx"] = torch.cat(new_system_indices)

     # Create a new instance of the same class
-    return state_class(**concatenated)
+    return _construct_state(first_state, concatenated)


 def initialize_state(
```
|
|