diff --git a/examples/scripts/reproduce_err.py b/examples/scripts/reproduce_err.py
new file mode 100644
index 00000000..2845f355
--- /dev/null
+++ b/examples/scripts/reproduce_err.py
@@ -0,0 +1,54 @@
+import torch  # noqa: D100, INP001
+from ase.build import bulk
+from mace.calculators.foundations_models import mace_mp
+
+import torch_sim as ts
+from torch_sim.models.mace import MaceModel
+
+
+si_atoms = bulk("Si", "fcc", a=3.26, cubic=True)
+si_atoms.rattle(0.05)
+
+cu_atoms = bulk("Cu", "fcc", a=5.26, cubic=True)
+cu_atoms.rattle(0.5)
+
+many_cu_atoms = [si_atoms] * 5 + [cu_atoms] * 20
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print(f"Using device: {device}")
+
+state = ts.initialize_state(many_cu_atoms, device=device, dtype=torch.float64)
+print("state", state)
+
+mace = mace_mp(model="small", return_raw_model=True)
+mace_model = MaceModel(model=mace, device=device)
+
+fire_init, fire_update = ts.optimizers.fire(mace_model)
+fire_state = fire_init(state)
+
+batcher = ts.InFlightAutoBatcher(
+    model=mace_model,
+    memory_scales_with="n_atoms",
+    max_memory_scaler=40,
+    max_iterations=10000,  # Optional: maximum convergence attempts per state
+)
+
+batcher.load_states(fire_state)
+
+convergence_fn = ts.generate_force_convergence_fn(5e-3, include_cell_forces=False)
+
+all_converged_states, convergence_tensor = [], None
+while (result := batcher.next_batch(fire_state, convergence_tensor))[0] is not None:
+    fire_state, converged_states = result
+    all_converged_states.extend(converged_states)
+
+    for _ in range(3):
+        fire_state = fire_update(fire_state)
+
+    convergence_tensor = convergence_fn(fire_state, None)
+    print(f"Convergence tensor: {convergence_tensor}")
+    print(f"Current batch indices: {batcher.current_idx}")
+
+all_converged_states.extend(result[1])
+
+final_states = batcher.restore_original_order(all_converged_states)
diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py
index 7be28997..9e57c44c 100644
--- a/tests/test_autobatching.py
+++ b/tests/test_autobatching.py
@@ -588,3 +588,55 @@ def test_in_flight_max_iterations(
     # Verify swap_attempts tracking
     for i in range(len(states)):
         assert batcher.swap_attempts[i] == max_attempts
+
+
+def test_concatenation_works_with_undefined_initial_vectors(
+    lj_model: LennardJonesModel,
+):
+    """Test that state concatenation works while velocities are still uninitialized."""
+    si_atoms = bulk("Si", "fcc", a=3.26, cubic=True)
+    si_atoms.rattle(0.05)
+
+    cu_atoms = bulk("Cu", "fcc", a=5.26, cubic=True)
+    cu_atoms.rattle(0.5)
+
+    many_cu_atoms = [si_atoms] * 5 + [cu_atoms] * 20
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    print(f"Using device: {device}")
+
+    state = ts.initialize_state(many_cu_atoms, device=device, dtype=torch.float64)
+
+    fire_init, fire_update = ts.optimizers.fire(lj_model)
+    fire_state = fire_init(state)
+
+    batcher = ts.InFlightAutoBatcher(
+        model=lj_model,
+        memory_scales_with="n_atoms",
+        max_memory_scaler=40,
+        max_iterations=10000,  # Optional: maximum convergence attempts per state
+    )
+
+    batcher.load_states(fire_state)
+
+    convergence_fn = ts.generate_force_convergence_fn(5e-3, include_cell_forces=False)
+
+    all_converged_states, convergence_tensor = [], None
+    while (result := batcher.next_batch(fire_state, convergence_tensor))[0] is not None:
+        fire_state, converged_states = result
+        all_converged_states.extend(converged_states)
+
+        for _ in range(3):
+            fire_state = fire_update(fire_state)
+
+        convergence_tensor = convergence_fn(fire_state, None)
+
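+        # Diagnostic output: per-system convergence flags and the batch
+        # indices still in flight.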
print(f"Convergence tensor: {convergence_tensor}") + print(f"Convergence tensor: {batcher.current_idx}") + + else: + all_converged_states.extend(result[1]) + + final_states = batcher.restore_original_order(all_converged_states)" diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 98a83d64..40bf78eb 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -474,7 +474,7 @@ class FireState(SimState): # Required attributes not in SimState forces: torch.Tensor energy: torch.Tensor - velocities: torch.Tensor | None + velocities: torch.Tensor # FIRE algorithm parameters dt: torch.Tensor @@ -589,7 +589,9 @@ def fire_init( atomic_numbers=state.atomic_numbers.clone(), system_idx=state.system_idx.clone(), pbc=state.pbc, - velocities=None, + velocities=torch.full( + state.positions.shape, torch.nan, device=device, dtype=dtype + ), forces=forces, energy=energy, # Optimization attributes @@ -862,13 +864,17 @@ def fire_init( atomic_numbers=state.atomic_numbers.clone(), system_idx=state.system_idx.clone(), pbc=state.pbc, - velocities=None, + velocities=torch.full( + state.positions.shape, torch.nan, device=device, dtype=dtype + ), forces=forces, energy=energy, stress=stress, # Cell attributes cell_positions=torch.zeros(n_systems, 3, 3, device=device, dtype=dtype), - cell_velocities=None, + cell_velocities=torch.full( + cell_forces.shape, torch.nan, device=device, dtype=dtype + ), cell_forces=cell_forces, cell_masses=cell_masses, # Optimization attributes @@ -965,7 +971,7 @@ class FrechetCellFIREState(SimState, DeformGradMixin): # Cell attributes cell_positions: torch.Tensor - cell_velocities: torch.Tensor | None + cell_velocities: torch.Tensor cell_forces: torch.Tensor cell_masses: torch.Tensor @@ -1161,13 +1167,17 @@ def fire_init( atomic_numbers=state.atomic_numbers, system_idx=state.system_idx, pbc=state.pbc, - velocities=None, + velocities=torch.full( + state.positions.shape, torch.nan, device=device, dtype=dtype + ), forces=forces, energy=energy, stress=stress, # Cell attributes cell_positions=cell_positions, - cell_velocities=None, + cell_velocities=torch.full( + cell_forces.shape, torch.nan, device=device, dtype=dtype + ), cell_forces=cell_forces, cell_masses=cell_masses, # Optimization attributes @@ -1244,15 +1254,19 @@ def _vv_fire_step( # noqa: C901, PLR0915 dtype = state.positions.dtype deform_grad_new: torch.Tensor | None = None - if state.velocities is None: - state.velocities = torch.zeros_like(state.positions) + nan_velocities = state.velocities.isnan().any(dim=1) + if nan_velocities.any(): + state.velocities[nan_velocities] = torch.zeros_like( + state.positions[nan_velocities] + ) if is_cell_optimization: if not isinstance(state, AnyFireCellState): raise ValueError( f"Cell optimization requires one of {get_args(AnyFireCellState)}." 
+        nan_cell_velocities = state.cell_velocities.isnan().any(dim=(1, 2))
+        state.cell_velocities[nan_cell_velocities] = torch.zeros_like(
+            state.cell_positions[nan_cell_velocities]
         )

     alpha_start_system = torch.full(
@@ -1461,16 +1475,20 @@ def _ase_fire_step(  # noqa: C901, PLR0915
     cur_deform_grad = None  # Initialize cur_deform_grad to prevent UnboundLocalError

-    if state.velocities is None:
-        state.velocities = torch.zeros_like(state.positions)
+    nan_velocities = state.velocities.isnan().any(dim=1)
+    if nan_velocities.any():
+        state.velocities[nan_velocities] = torch.zeros_like(
+            state.positions[nan_velocities]
+        )
     forces = state.forces
     if is_cell_optimization:
         if not isinstance(state, AnyFireCellState):
             raise ValueError(
                 f"Cell optimization requires one of {get_args(AnyFireCellState)}."
             )
-        state.cell_velocities = torch.zeros(
-            (n_systems, 3, 3), device=device, dtype=dtype
+        nan_cell_velocities = state.cell_velocities.isnan().any(dim=(1, 2))
+        state.cell_velocities[nan_cell_velocities] = torch.zeros_like(
+            state.cell_positions[nan_cell_velocities]
         )
         cur_deform_grad = state.deform_grad()
     else:
diff --git a/torch_sim/state.py b/torch_sim/state.py
index ce21ef9b..57d8b298 100644
--- a/torch_sim/state.py
+++ b/torch_sim/state.py
@@ -6,9 +6,11 @@
 import copy
 import importlib
+import inspect
+import typing
 import warnings
-from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Literal, Self
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Literal, Self, TypeVar

 import torch
@@ -81,7 +83,7 @@ class SimState:
     cell: torch.Tensor
     pbc: bool  # TODO: do all calculators support mixed pbc?
     atomic_numbers: torch.Tensor
-    system_idx: torch.Tensor | None = field(default=None, kw_only=True)
+    system_idx: torch.Tensor | None

     def __post_init__(self) -> None:
         """Validate and process the state after initialization."""
@@ -272,7 +274,7 @@ def clone(self) -> Self:
             else:
                 attrs[attr_name] = copy.deepcopy(attr_value)

-        return self.__class__(**attrs)
+        return construct_state(self, attrs)

     def to_atoms(self) -> list["Atoms"]:
         """Convert the SimState to a list of ASE Atoms objects.
@@ -372,6 +374,37 @@ def __getitem__(self, system_indices: int | list[int] | slice | torch.Tensor) ->
         return _slice_state(self, system_indices)
+
+    def __init_subclass__(cls, **kwargs) -> None:
+        """Forbid derived states from declaring tensor attributes that may also be
+        None, since torch.concatenate cannot concatenate a tensor with a None.
+        See https://github.com/Radical-AI/torch-sim/pull/219 for more details.
+        """
+        # get_type_hints resolves string annotations, which cls.__annotations__ would not
+        type_hints = typing.get_type_hints(cls)
+        for attr_name, attr_typehint in type_hints.items():
+            origin = typing.get_origin(attr_typehint)
+
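+            # Both annotation spellings are caught: `Optional[torch.Tensor]`
+            # resolves to typing.Union, `torch.Tensor | None` to types.UnionType.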
" + "Optional tensor attributes are disallowed in SimState " + "subclasses to prevent concatenation errors.\n" + # TODO(curtis): fix this + # "If this attribute will take on a default value in the " + # "post_init method, please use an InitVar for that attribute " + # "but with a prepended 'init_' to the name. (e.g. system_idx)" + ) + + super().__init_subclass__(**kwargs) + class DeformGradMixin: """Mixin for states that support deformation gradients.""" @@ -490,7 +523,7 @@ def state_to_device( attrs["masses"] = attrs["masses"].to(dtype=dtype) attrs["cell"] = attrs["cell"].to(dtype=dtype) attrs["atomic_numbers"] = attrs["atomic_numbers"].to(dtype=torch.int) - return type(state)(**attrs) + return construct_state(state, attrs) def infer_property_scope( @@ -719,11 +752,35 @@ def _split_state( # Add the global attributes **attrs["global"], } - states.append(type(state)(**system_attrs)) + states.append(construct_state(state, system_attrs)) return states +SimStateT = TypeVar("SimStateT", bound=SimState) + + +def construct_state( + old_state: SimStateT, + new_state_attrs: dict[str, typing.Any], +) -> SimStateT: + """Construct a new state from an old state and new state parameters.""" + # 1) process the attrs so they are the init params + processed_params = {} + for param in inspect.signature(old_state.__class__).parameters: + if param.startswith("init_"): + # this is an InitVar field + # we need to rename the corresponding field in system_attrs to have + # an "init_" prefix + non_init_attr_name = param.removeprefix("init_") + processed_params[param] = new_state_attrs[non_init_attr_name] + else: + processed_params[param] = new_state_attrs[param] + + # 2) construct the new state + return type(old_state)(**processed_params) + + def _pop_states( state: SimState, pop_indices: list[int] | torch.Tensor, @@ -770,10 +827,10 @@ def _pop_states( pop_attrs = _filter_attrs_by_mask(attrs, pop_atom_mask, pop_system_mask) # Create the keep state - keep_state = type(state)(**keep_attrs) + keep_state = construct_state(state, keep_attrs) # Create and split the pop state - pop_state = type(state)(**pop_attrs) + pop_state = construct_state(state, pop_attrs) pop_states = _split_state(pop_state, ambiguous_handling) return keep_state, pop_states @@ -823,7 +880,7 @@ def _slice_state( filtered_attrs = _filter_attrs_by_mask(attrs, atom_mask, system_mask) # Create the sliced state - return type(state)(**filtered_attrs) + return construct_state(state, filtered_attrs) def concatenate_states( @@ -912,7 +969,7 @@ def concatenate_states( concatenated["system_idx"] = torch.cat(new_system_indices) # Create a new instance of the same class - return state_class(**concatenated) + return construct_state(first_state, concatenated) def initialize_state(