From 059e57d6757c19dad7458e679a84891ad3fa796f Mon Sep 17 00:00:00 2001 From: Timo Reents Date: Thu, 17 Jul 2025 16:54:24 +0200 Subject: [PATCH 1/9] Fix concatenation of states in `InFlightAutoBatcher` The `velocities` and `cell_velocities` are initialized to `None` in the `(FrechetCell)FIREState`. However, when using the `InFlightAutoBatcher` during an optimization, the current and new states are concatenated in `torch_sim.state.concatenate_states`. When trying to merge states that were already processed for a few iterations (i.e., velocities are not None anymore) and newly initialized ones, an error is raised because the code tries to merge a `Tensor` with a `None`. Here, we initialize the `(cell_)velocities` as tensors full of `nan` instead, so that one can merge already processed and newly initialized states. During the first initialization, the `fire` methods look for `nan` rows and replace them with zeros. --- tests/test_autobatching.py | 2 +- torch_sim/optimizers.py | 44 +++++++++++++++++++++++++++----------- 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index 7be28997..73ade42e 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -490,7 +490,7 @@ def convergence_fn(state: ts.SimState) -> bool: break # run 10 steps, arbitrary number - for _ in range(10): + for _ in range(5): state = fire_update(state) convergence_tensor = convergence_fn(state) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 98a83d64..de8e839d 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -589,7 +589,9 @@ def fire_init( atomic_numbers=state.atomic_numbers.clone(), system_idx=state.system_idx.clone(), pbc=state.pbc, - velocities=None, + velocities=torch.full( + state.positions.shape, torch.nan, device=device, dtype=dtype + ), forces=forces, energy=energy, # Optimization attributes @@ -862,13 +864,17 @@ def fire_init( atomic_numbers=state.atomic_numbers.clone(), system_idx=state.system_idx.clone(), pbc=state.pbc, - velocities=None, + velocities=torch.full( + state.positions.shape, torch.nan, device=device, dtype=dtype + ), forces=forces, energy=energy, stress=stress, # Cell attributes cell_positions=torch.zeros(n_systems, 3, 3, device=device, dtype=dtype), - cell_velocities=None, + cell_velocities=torch.full( + cell_forces.shape, torch.nan, device=device, dtype=dtype + ), cell_forces=cell_forces, cell_masses=cell_masses, # Optimization attributes @@ -1161,13 +1167,17 @@ def fire_init( atomic_numbers=state.atomic_numbers, system_idx=state.system_idx, pbc=state.pbc, - velocities=None, + velocities=torch.full( + state.positions.shape, torch.nan, device=device, dtype=dtype + ), forces=forces, energy=energy, stress=stress, # Cell attributes cell_positions=cell_positions, - cell_velocities=None, + cell_velocities=torch.full( + cell_forces.shape, torch.nan, device=device, dtype=dtype + ), cell_forces=cell_forces, cell_masses=cell_masses, # Optimization attributes @@ -1244,15 +1254,19 @@ def _vv_fire_step( # noqa: C901, PLR0915 dtype = state.positions.dtype deform_grad_new: torch.Tensor | None = None - if state.velocities is None: - state.velocities = torch.zeros_like(state.positions) + nan_velocities = state.velocities.isnan().any(dim=1) + if nan_velocities.any(): + state.velocities[nan_velocities] = torch.zeros_like( + state.positions[nan_velocities] + ) if is_cell_optimization: if not isinstance(state, AnyFireCellState): raise ValueError( f"Cell optimization requires one of 
{get_args(AnyFireCellState)}." ) - state.cell_velocities = torch.zeros( - (n_systems, 3, 3), device=device, dtype=dtype + nan_cell_velocities = state.cell_velocities.isnan().any(dim=(1, 2)) + state.cell_velocities[nan_cell_velocities] = torch.zeros_like( + state.cell_positions[nan_cell_velocities] ) alpha_start_system = torch.full( @@ -1461,16 +1475,20 @@ def _ase_fire_step( # noqa: C901, PLR0915 cur_deform_grad = None # Initialize cur_deform_grad to prevent UnboundLocalError - if state.velocities is None: - state.velocities = torch.zeros_like(state.positions) + nan_velocities = state.velocities.isnan().any(dim=1) + if nan_velocities.any(): + state.velocities[nan_velocities] = torch.zeros_like( + state.positions[nan_velocities] + ) forces = state.forces if is_cell_optimization: if not isinstance(state, AnyFireCellState): raise ValueError( f"Cell optimization requires one of {get_args(AnyFireCellState)}." ) - state.cell_velocities = torch.zeros( - (n_systems, 3, 3), device=device, dtype=dtype + nan_cell_velocities = state.cell_velocities.isnan().any(dim=(1, 2)) + state.cell_velocities[nan_cell_velocities] = torch.zeros_like( + state.cell_positions[nan_cell_velocities] ) cur_deform_grad = state.deform_grad() else: From 3ecef41b642896bdffdbb7fccd80c50e7a2bd5ea Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Thu, 31 Jul 2025 20:35:13 -0400 Subject: [PATCH 2/9] wip code will prob delete --- examples/scripts/reproduce_err.py | 53 +++++++++++++++++++++++++++++++ tests/test_autobatching.py | 52 +++++++++++++++++++++++++++++- 2 files changed, 104 insertions(+), 1 deletion(-) create mode 100644 examples/scripts/reproduce_err.py diff --git a/examples/scripts/reproduce_err.py b/examples/scripts/reproduce_err.py new file mode 100644 index 00000000..3a2b2b60 --- /dev/null +++ b/examples/scripts/reproduce_err.py @@ -0,0 +1,53 @@ +import torch +from ase.build import bulk +from mace.calculators.foundations_models import mace_mp + +import torch_sim as ts +from torch_sim.models.mace import MaceModel + + +si_atoms = bulk("Si", "fcc", a=3.26, cubic=True) +si_atoms.rattle(0.05) + +cu_atoms = bulk("Cu", "fcc", a=5.26, cubic=True) +cu_atoms.rattle(0.5) + +many_cu_atoms = [si_atoms] * 5 + [cu_atoms] * 20 + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +print(f"Using device: {device}") + +state = ts.initialize_state(many_cu_atoms, device=device, dtype=torch.float64) + +mace = mace_mp(model="small", return_raw_model=True) +mace_model = MaceModel(model=mace, device=device) + +fire_init, fire_update = ts.optimizers.fire(mace_model) +fire_state = fire_init(state) + +batcher = ts.InFlightAutoBatcher( + model=mace_model, + memory_scales_with="n_atoms", + max_memory_scaler=40, + max_iterations=10000, # Optional: maximum convergence attempts per state +) + +batcher.load_states(fire_state) + +convergence_fn = ts.generate_force_convergence_fn(5e-3, include_cell_forces=False) + +all_converged_states, convergence_tensor = [], None +while (result := batcher.next_batch(fire_state, convergence_tensor))[0] is not None: + fire_state, converged_states = result + all_converged_states.extend(converged_states) + + for _ in range(3): + fire_state = fire_update(fire_state) + + convergence_tensor = convergence_fn(fire_state, None) + print(f"Convergence tensor: {convergence_tensor}") + print(f"Convergence tensor: {batcher.current_idx}") + +all_converged_states.extend(result[1]) + +final_states = batcher.restore_original_order(all_converged_states) diff --git a/tests/test_autobatching.py 
b/tests/test_autobatching.py index 73ade42e..9e57c44c 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -490,7 +490,7 @@ def convergence_fn(state: ts.SimState) -> bool: break # run 10 steps, arbitrary number - for _ in range(5): + for _ in range(10): state = fire_update(state) convergence_tensor = convergence_fn(state) @@ -588,3 +588,53 @@ def test_in_flight_max_iterations( # Verify swap_attempts tracking for i in range(len(states)): assert batcher.swap_attempts[i] == max_attempts + + +def test_concatenation_works_with_undefined_initial_vectors( + lj_model: LennardJonesModel, +): + """Test that concatenation works with undefined initial vectors.""" + si_atoms = bulk("Si", "fcc", a=3.26, cubic=True) + si_atoms.rattle(0.05) + + cu_atoms = bulk("Cu", "fcc", a=5.26, cubic=True) + cu_atoms.rattle(0.5) + + many_cu_atoms = [si_atoms] * 5 + [cu_atoms] * 20 + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + state = ts.initialize_state(many_cu_atoms, device=device, dtype=torch.float64) + + + fire_init, fire_update = ts.optimizers.fire(mace_model) + fire_state = fire_init(state) + + batcher = ts.InFlightAutoBatcher( + model=LennardJonesModel, + memory_scales_with="n_atoms", + max_memory_scaler=40, + max_iterations=10000, # Optional: maximum convergence attempts per state + ) + + batcher.load_states(fire_state) + + convergence_fn = ts.generate_force_convergence_fn(5e-3, include_cell_forces=False) + + all_converged_states, convergence_tensor = [], None + while (result := batcher.next_batch(fire_state, convergence_tensor))[0] is not None: + fire_state, converged_states = result + all_converged_states.extend(converged_states) + + for _ in range(3): + fire_state = fire_update(fire_state) + + convergence_tensor = convergence_fn(fire_state, None) + print(f"Convergence tensor: {convergence_tensor}") + print(f"Convergence tensor: {batcher.current_idx}") + + else: + all_converged_states.extend(result[1]) + + final_states = batcher.restore_original_order(all_converged_states)" From 235bd06c3ee67979e46c1ab34b48215bc855394d Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Thu, 31 Jul 2025 19:54:08 -0700 Subject: [PATCH 3/9] we can properly identify variables that are tensor | None --- torch_sim/state.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/torch_sim/state.py b/torch_sim/state.py index ce21ef9b..1e9905fd 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -6,6 +6,7 @@ import copy import importlib +import typing import warnings from dataclasses import dataclass, field from typing import TYPE_CHECKING, Literal, Self @@ -372,6 +373,32 @@ def __getitem__(self, system_indices: int | list[int] | slice | torch.Tensor) -> return _slice_state(self, system_indices) + def __init_subclass__(cls, **kwargs) -> None: + """Enforce that all derived states cannot have tensor attributes that can also be + None. This is because torch.concatenate cannot concat between a tensor and a None. + See https://github.com/Radical-AI/torch-sim/pull/219 for more details. 
+ """ + # We need to use get_type_hints to correctly inspect the types + type_hints = typing.get_type_hints(cls) + for attr_name, attr_type in type_hints.items(): + origin = typing.get_origin(attr_type) + + is_union = origin is typing.Union + if not is_union and origin is not None: + # For Python 3.10+ `|` syntax, origin is types.UnionType + # We check by name to be robust against module reloading/patching issues + is_union = origin.__module__ == "types" and origin.__name__ == "UnionType" + if is_union: + args = typing.get_args(attr_type) + if torch.Tensor in args and type(None) in args: + raise TypeError( + f"Attribute '{attr_name}' in class '{cls.__name__}' is not " + "allowed to be of type 'torch.Tensor | None'. " + "Optional tensor attributes are disallowed in SimState " + "subclasses to prevent concatenation errors." + ) + super().__init_subclass__(**kwargs) + class DeformGradMixin: """Mixin for states that support deformation gradients.""" From d8f15e079d9f6365af80d7ce8284161b352c5350 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Thu, 31 Jul 2025 20:07:24 -0700 Subject: [PATCH 4/9] simstate properly handles system_idx_init now --- torch_sim/state.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index 1e9905fd..e88bef02 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -8,7 +8,7 @@ import importlib import typing import warnings -from dataclasses import dataclass, field +from dataclasses import InitVar, dataclass from typing import TYPE_CHECKING, Literal, Self import torch @@ -82,9 +82,10 @@ class SimState: cell: torch.Tensor pbc: bool # TODO: do all calculators support mixed pbc? atomic_numbers: torch.Tensor - system_idx: torch.Tensor | None = field(default=None, kw_only=True) + system_idx: torch.Tensor + system_idx_init: InitVar[torch.Tensor | None] - def __post_init__(self) -> None: + def __post_init__(self, system_idx_init: torch.Tensor | None) -> None: """Validate and process the state after initialization.""" # data validation and fill system_idx # should make pbc a tensor here @@ -108,11 +109,9 @@ def __post_init__(self) -> None: f"masses {shapes[1]}, atomic_numbers {shapes[2]}" ) - if self.cell.ndim != 3 and self.system_idx is None: - self.cell = self.cell.unsqueeze(0) - - if self.cell.shape[-2:] != (3, 3): - raise ValueError("Cell must have shape (n_systems, 3, 3)") + if system_idx_init is not None and self.system_idx is None: + # we check if system_idx is none to prevent overriding system_idx^ + self.system_idx = system_idx_init if self.system_idx is None: self.system_idx = torch.zeros( @@ -126,6 +125,12 @@ def __post_init__(self) -> None: if not torch.all(counts == torch.bincount(self.system_idx)): raise ValueError("System indices must be unique consecutive integers") + if self.cell.ndim != 3 and self.system_idx is None: + self.cell = self.cell.unsqueeze(0) + + if self.cell.shape[-2:] != (3, 3): + raise ValueError("Cell must have shape (n_systems, 3, 3)") + if self.cell.shape[0] != self.n_systems: raise ValueError( f"Cell must have shape (n_systems, 3, 3), got {self.cell.shape}" From ce77a81adba69318afedfdc82a2d7d5d064d6d68 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Thu, 31 Jul 2025 20:10:50 -0700 Subject: [PATCH 5/9] properly set the system idx field init = false --- torch_sim/io.py | 2 +- torch_sim/optimizers.py | 4 ++-- torch_sim/state.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/torch_sim/io.py b/torch_sim/io.py index 8f61bd49..3ead1849 
100644 --- a/torch_sim/io.py +++ b/torch_sim/io.py @@ -241,7 +241,7 @@ def atoms_to_state( cell=cell, pbc=all(atoms_list[0].pbc), atomic_numbers=atomic_numbers, - system_idx=system_idx, + system_idx_init=system_idx, ) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index de8e839d..40bf78eb 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -474,7 +474,7 @@ class FireState(SimState): # Required attributes not in SimState forces: torch.Tensor energy: torch.Tensor - velocities: torch.Tensor | None + velocities: torch.Tensor # FIRE algorithm parameters dt: torch.Tensor @@ -971,7 +971,7 @@ class FrechetCellFIREState(SimState, DeformGradMixin): # Cell attributes cell_positions: torch.Tensor - cell_velocities: torch.Tensor | None + cell_velocities: torch.Tensor cell_forces: torch.Tensor cell_masses: torch.Tensor diff --git a/torch_sim/state.py b/torch_sim/state.py index e88bef02..a281722f 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -8,7 +8,7 @@ import importlib import typing import warnings -from dataclasses import InitVar, dataclass +from dataclasses import InitVar, dataclass, field from typing import TYPE_CHECKING, Literal, Self import torch @@ -82,7 +82,7 @@ class SimState: cell: torch.Tensor pbc: bool # TODO: do all calculators support mixed pbc? atomic_numbers: torch.Tensor - system_idx: torch.Tensor + system_idx: torch.Tensor = field(init=False) system_idx_init: InitVar[torch.Tensor | None] def __post_init__(self, system_idx_init: torch.Tensor | None) -> None: From c30ffa32328d9f40daf147de5546aa965aa2331b Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Thu, 31 Jul 2025 20:27:00 -0700 Subject: [PATCH 6/9] ran into problem with using the system_index hack --- examples/scripts/reproduce_err.py | 3 ++- torch_sim/io.py | 2 +- torch_sim/optimizers.py | 2 +- torch_sim/state.py | 11 ++++------- 4 files changed, 8 insertions(+), 10 deletions(-) diff --git a/examples/scripts/reproduce_err.py b/examples/scripts/reproduce_err.py index 3a2b2b60..2845f355 100644 --- a/examples/scripts/reproduce_err.py +++ b/examples/scripts/reproduce_err.py @@ -1,4 +1,4 @@ -import torch +import torch # noqa: D100, INP001 from ase.build import bulk from mace.calculators.foundations_models import mace_mp @@ -18,6 +18,7 @@ print(f"Using device: {device}") state = ts.initialize_state(many_cu_atoms, device=device, dtype=torch.float64) +print("state", state) mace = mace_mp(model="small", return_raw_model=True) mace_model = MaceModel(model=mace, device=device) diff --git a/torch_sim/io.py b/torch_sim/io.py index 3ead1849..350d2993 100644 --- a/torch_sim/io.py +++ b/torch_sim/io.py @@ -241,7 +241,7 @@ def atoms_to_state( cell=cell, pbc=all(atoms_list[0].pbc), atomic_numbers=atomic_numbers, - system_idx_init=system_idx, + system_index=system_idx, ) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 40bf78eb..36bd2d54 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -587,7 +587,7 @@ def fire_init( masses=state.masses.clone(), cell=state.cell.clone(), atomic_numbers=state.atomic_numbers.clone(), - system_idx=state.system_idx.clone(), + system_index=state.system_idx.clone(), pbc=state.pbc, velocities=torch.full( state.positions.shape, torch.nan, device=device, dtype=dtype diff --git a/torch_sim/state.py b/torch_sim/state.py index a281722f..5722eecd 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -83,9 +83,9 @@ class SimState: pbc: bool # TODO: do all calculators support mixed pbc? 
atomic_numbers: torch.Tensor system_idx: torch.Tensor = field(init=False) - system_idx_init: InitVar[torch.Tensor | None] + system_index: InitVar[torch.Tensor | None] - def __post_init__(self, system_idx_init: torch.Tensor | None) -> None: + def __post_init__(self, system_index: torch.Tensor | None) -> None: """Validate and process the state after initialization.""" # data validation and fill system_idx # should make pbc a tensor here @@ -109,15 +109,12 @@ def __post_init__(self, system_idx_init: torch.Tensor | None) -> None: f"masses {shapes[1]}, atomic_numbers {shapes[2]}" ) - if system_idx_init is not None and self.system_idx is None: - # we check if system_idx is none to prevent overriding system_idx^ - self.system_idx = system_idx_init - - if self.system_idx is None: + if system_index is None: self.system_idx = torch.zeros( self.n_atoms, device=self.device, dtype=torch.int64 ) else: + self.system_idx = system_index # assert that system indices are unique consecutive integers # TODO(curtis): I feel like this logic is not reliable. # I'll come up with something better later. From f7a6e4d0e9fda44a9e76245e3d939765320a8f99 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 2 Aug 2025 12:14:14 -0700 Subject: [PATCH 7/9] init subclass runs properly --- torch_sim/io.py | 2 +- torch_sim/optimizers.py | 2 +- torch_sim/state.py | 86 +++++++++++++++++++++++++++++++++-------- 3 files changed, 72 insertions(+), 18 deletions(-) diff --git a/torch_sim/io.py b/torch_sim/io.py index 350d2993..bdfa4067 100644 --- a/torch_sim/io.py +++ b/torch_sim/io.py @@ -241,7 +241,7 @@ def atoms_to_state( cell=cell, pbc=all(atoms_list[0].pbc), atomic_numbers=atomic_numbers, - system_index=system_idx, + init_system_idx=system_idx, ) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 36bd2d54..91462f3f 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -587,7 +587,7 @@ def fire_init( masses=state.masses.clone(), cell=state.cell.clone(), atomic_numbers=state.atomic_numbers.clone(), - system_index=state.system_idx.clone(), + init_system_idx=state.system_idx.clone(), pbc=state.pbc, velocities=torch.full( state.positions.shape, torch.nan, device=device, dtype=dtype diff --git a/torch_sim/state.py b/torch_sim/state.py index 5722eecd..67b40e12 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -6,10 +6,11 @@ import copy import importlib +import inspect import typing import warnings from dataclasses import InitVar, dataclass, field -from typing import TYPE_CHECKING, Literal, Self +from typing import TYPE_CHECKING, Literal, Self, TypeVar import torch @@ -83,9 +84,9 @@ class SimState: pbc: bool # TODO: do all calculators support mixed pbc? 
atomic_numbers: torch.Tensor system_idx: torch.Tensor = field(init=False) - system_index: InitVar[torch.Tensor | None] + init_system_idx: InitVar[torch.Tensor | None] - def __post_init__(self, system_index: torch.Tensor | None) -> None: + def __post_init__(self, init_system_idx: torch.Tensor | None) -> None: """Validate and process the state after initialization.""" # data validation and fill system_idx # should make pbc a tensor here @@ -109,12 +110,12 @@ def __post_init__(self, system_index: torch.Tensor | None) -> None: f"masses {shapes[1]}, atomic_numbers {shapes[2]}" ) - if system_index is None: + if init_system_idx is None: self.system_idx = torch.zeros( self.n_atoms, device=self.device, dtype=torch.int64 ) else: - self.system_idx = system_index + self.system_idx = init_system_idx # assert that system indices are unique consecutive integers # TODO(curtis): I feel like this logic is not reliable. # I'll come up with something better later. @@ -275,7 +276,7 @@ def clone(self) -> Self: else: attrs[attr_name] = copy.deepcopy(attr_value) - return self.__class__(**attrs) + return construct_state(self, attrs) def to_atoms(self) -> list["Atoms"]: """Convert the SimState to a list of ASE Atoms objects. @@ -382,8 +383,8 @@ def __init_subclass__(cls, **kwargs) -> None: """ # We need to use get_type_hints to correctly inspect the types type_hints = typing.get_type_hints(cls) - for attr_name, attr_type in type_hints.items(): - origin = typing.get_origin(attr_type) + for attr_name, attr_typehint in type_hints.items(): + origin = typing.get_origin(attr_typehint) is_union = origin is typing.Union if not is_union and origin is not None: @@ -391,14 +392,43 @@ def __init_subclass__(cls, **kwargs) -> None: # We check by name to be robust against module reloading/patching issues is_union = origin.__module__ == "types" and origin.__name__ == "UnionType" if is_union: - args = typing.get_args(attr_type) + args = typing.get_args(attr_typehint) if torch.Tensor in args and type(None) in args: raise TypeError( f"Attribute '{attr_name}' in class '{cls.__name__}' is not " "allowed to be of type 'torch.Tensor | None'. " "Optional tensor attributes are disallowed in SimState " - "subclasses to prevent concatenation errors." + "subclasses to prevent concatenation errors.\n" + "If this attribute will take on a default value in the " + "post_init method, please use an InitVar for that attribute " + "but with a prepended 'init_' to the name. (e.g. init_system_idx)" ) + + # Validate InitVar fields + for attr_name, attr_typehint in cls.__annotations__.items(): + # 1) validate InitVar fields + if type(attr_typehint) is InitVar: + # make sure its prefix is "init_" + if not attr_name.startswith("init_"): + raise TypeError( + f"Attribute '{attr_name}' in class '{cls.__name__}' is not " + "allowed to be an InitVar. It must be prefixed with 'init_'" + ) + # make sure there is a corresponding non-InitVar field + non_init_attr_name = attr_name.removeprefix("init_") + if non_init_attr_name not in type_hints: + raise TypeError( + f"Attribute '{attr_name}' in class '{cls.__name__}' is not " + "allowed to be an InitVar. It must have a corresponding " + f"non-InitVar field {non_init_attr_name}" + ) + + # 2) forbid non init vars to have a "init_" prefix + elif attr_name.startswith("init_"): + raise TypeError( + f"Attribute '{attr_name}' in class '{cls.__name__}' is not " + "allowed to have an 'init_' prefix as it's a non-InitVar field." 
+ ) super().__init_subclass__(**kwargs) @@ -519,7 +549,7 @@ def state_to_device( attrs["masses"] = attrs["masses"].to(dtype=dtype) attrs["cell"] = attrs["cell"].to(dtype=dtype) attrs["atomic_numbers"] = attrs["atomic_numbers"].to(dtype=torch.int) - return type(state)(**attrs) + return construct_state(state, attrs) def infer_property_scope( @@ -748,11 +778,35 @@ def _split_state( # Add the global attributes **attrs["global"], } - states.append(type(state)(**system_attrs)) + states.append(construct_state(state, system_attrs)) return states +SimStateT = TypeVar("SimStateT", bound=SimState) + + +def construct_state( + old_state: SimStateT, + new_state_attrs: dict[str, typing.Any], +) -> SimStateT: + """Construct a new state from an old state and new state parameters.""" + # 1) process the attrs so they are the init params + processed_params = {} + for param in inspect.signature(old_state.__class__).parameters: + if param.startswith("init_"): + # this is an InitVar field + # we need to rename the corresponding field in system_attrs to have + # an "init_" prefix + non_init_attr_name = param.removeprefix("init_") + processed_params[param] = new_state_attrs[non_init_attr_name] + else: + processed_params[param] = new_state_attrs[param] + + # 2) construct the new state + return type(old_state)(**processed_params) + + def _pop_states( state: SimState, pop_indices: list[int] | torch.Tensor, @@ -799,10 +853,10 @@ def _pop_states( pop_attrs = _filter_attrs_by_mask(attrs, pop_atom_mask, pop_system_mask) # Create the keep state - keep_state = type(state)(**keep_attrs) + keep_state = construct_state(state, keep_attrs) # Create and split the pop state - pop_state = type(state)(**pop_attrs) + pop_state = construct_state(state, pop_attrs) pop_states = _split_state(pop_state, ambiguous_handling) return keep_state, pop_states @@ -852,7 +906,7 @@ def _slice_state( filtered_attrs = _filter_attrs_by_mask(attrs, atom_mask, system_mask) # Create the sliced state - return type(state)(**filtered_attrs) + return construct_state(state, filtered_attrs) def concatenate_states( @@ -941,7 +995,7 @@ def concatenate_states( concatenated["system_idx"] = torch.cat(new_system_indices) # Create a new instance of the same class - return state_class(**concatenated) + return construct_state(first_state, concatenated) def initialize_state( From 6c46fa9994fe4ddee3a1a71a10f853b6fd928f35 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 2 Aug 2025 13:50:03 -0700 Subject: [PATCH 8/9] revert system idx init --- torch_sim/state.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index 67b40e12..a1783575 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -9,7 +9,7 @@ import inspect import typing import warnings -from dataclasses import InitVar, dataclass, field +from dataclasses import InitVar, dataclass from typing import TYPE_CHECKING, Literal, Self, TypeVar import torch @@ -83,10 +83,9 @@ class SimState: cell: torch.Tensor pbc: bool # TODO: do all calculators support mixed pbc? 
atomic_numbers: torch.Tensor - system_idx: torch.Tensor = field(init=False) - init_system_idx: InitVar[torch.Tensor | None] + system_idx: torch.Tensor | None - def __post_init__(self, init_system_idx: torch.Tensor | None) -> None: + def __post_init__(self) -> None: """Validate and process the state after initialization.""" # data validation and fill system_idx # should make pbc a tensor here @@ -110,12 +109,17 @@ def __post_init__(self, init_system_idx: torch.Tensor | None) -> None: f"masses {shapes[1]}, atomic_numbers {shapes[2]}" ) - if init_system_idx is None: + if self.cell.ndim != 3 and self.system_idx is None: + self.cell = self.cell.unsqueeze(0) + + if self.cell.shape[-2:] != (3, 3): + raise ValueError("Cell must have shape (n_systems, 3, 3)") + + if self.system_idx is None: self.system_idx = torch.zeros( self.n_atoms, device=self.device, dtype=torch.int64 ) else: - self.system_idx = init_system_idx # assert that system indices are unique consecutive integers # TODO(curtis): I feel like this logic is not reliable. # I'll come up with something better later. @@ -123,12 +127,6 @@ def __post_init__(self, init_system_idx: torch.Tensor | None) -> None: if not torch.all(counts == torch.bincount(self.system_idx)): raise ValueError("System indices must be unique consecutive integers") - if self.cell.ndim != 3 and self.system_idx is None: - self.cell = self.cell.unsqueeze(0) - - if self.cell.shape[-2:] != (3, 3): - raise ValueError("Cell must have shape (n_systems, 3, 3)") - if self.cell.shape[0] != self.n_systems: raise ValueError( f"Cell must have shape (n_systems, 3, 3), got {self.cell.shape}" From e4878b95e48919ffff582bcbfa5a94bbad0782b7 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 2 Aug 2025 13:51:52 -0700 Subject: [PATCH 9/9] cleanup init subclass --- torch_sim/io.py | 2 +- torch_sim/optimizers.py | 2 +- torch_sim/state.py | 34 +++++----------------------------- 3 files changed, 7 insertions(+), 31 deletions(-) diff --git a/torch_sim/io.py b/torch_sim/io.py index bdfa4067..8f61bd49 100644 --- a/torch_sim/io.py +++ b/torch_sim/io.py @@ -241,7 +241,7 @@ def atoms_to_state( cell=cell, pbc=all(atoms_list[0].pbc), atomic_numbers=atomic_numbers, - init_system_idx=system_idx, + system_idx=system_idx, ) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 91462f3f..40bf78eb 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -587,7 +587,7 @@ def fire_init( masses=state.masses.clone(), cell=state.cell.clone(), atomic_numbers=state.atomic_numbers.clone(), - init_system_idx=state.system_idx.clone(), + system_idx=state.system_idx.clone(), pbc=state.pbc, velocities=torch.full( state.positions.shape, torch.nan, device=device, dtype=dtype diff --git a/torch_sim/state.py b/torch_sim/state.py index a1783575..57d8b298 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -9,7 +9,7 @@ import inspect import typing import warnings -from dataclasses import InitVar, dataclass +from dataclasses import dataclass from typing import TYPE_CHECKING, Literal, Self, TypeVar import torch @@ -397,36 +397,12 @@ def __init_subclass__(cls, **kwargs) -> None: "allowed to be of type 'torch.Tensor | None'. " "Optional tensor attributes are disallowed in SimState " "subclasses to prevent concatenation errors.\n" - "If this attribute will take on a default value in the " - "post_init method, please use an InitVar for that attribute " - "but with a prepended 'init_' to the name. (e.g. 
init_system_idx)" + # TODO(curtis): fix this + # "If this attribute will take on a default value in the " + # "post_init method, please use an InitVar for that attribute " + # "but with a prepended 'init_' to the name. (e.g. system_idx)" ) - # Validate InitVar fields - for attr_name, attr_typehint in cls.__annotations__.items(): - # 1) validate InitVar fields - if type(attr_typehint) is InitVar: - # make sure its prefix is "init_" - if not attr_name.startswith("init_"): - raise TypeError( - f"Attribute '{attr_name}' in class '{cls.__name__}' is not " - "allowed to be an InitVar. It must be prefixed with 'init_'" - ) - # make sure there is a corresponding non-InitVar field - non_init_attr_name = attr_name.removeprefix("init_") - if non_init_attr_name not in type_hints: - raise TypeError( - f"Attribute '{attr_name}' in class '{cls.__name__}' is not " - "allowed to be an InitVar. It must have a corresponding " - f"non-InitVar field {non_init_attr_name}" - ) - - # 2) forbid non init vars to have a "init_" prefix - elif attr_name.startswith("init_"): - raise TypeError( - f"Attribute '{attr_name}' in class '{cls.__name__}' is not " - "allowed to have an 'init_' prefix as it's a non-InitVar field." - ) super().__init_subclass__(**kwargs)