54 changes: 54 additions & 0 deletions examples/scripts/reproduce_err.py
@@ -0,0 +1,54 @@
import torch # noqa: D100, INP001
from ase.build import bulk
from mace.calculators.foundations_models import mace_mp

import torch_sim as ts
from torch_sim.models.mace import MaceModel


si_atoms = bulk("Si", "fcc", a=3.26, cubic=True)
si_atoms.rattle(0.05)

cu_atoms = bulk("Cu", "fcc", a=5.26, cubic=True)
cu_atoms.rattle(0.5)

many_cu_atoms = [si_atoms] * 5 + [cu_atoms] * 20

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

state = ts.initialize_state(many_cu_atoms, device=device, dtype=torch.float64)
print("state", state)

mace = mace_mp(model="small", return_raw_model=True)
mace_model = MaceModel(model=mace, device=device)

fire_init, fire_update = ts.optimizers.fire(mace_model)
fire_state = fire_init(state)

batcher = ts.InFlightAutoBatcher(
model=mace_model,
memory_scales_with="n_atoms",
max_memory_scaler=40,
max_iterations=10000, # Optional: maximum convergence attempts per state
)

batcher.load_states(fire_state)

convergence_fn = ts.generate_force_convergence_fn(5e-3, include_cell_forces=False)

all_converged_states, convergence_tensor = [], None
while (result := batcher.next_batch(fire_state, convergence_tensor))[0] is not None:
fire_state, converged_states = result
all_converged_states.extend(converged_states)

for _ in range(3):
fire_state = fire_update(fire_state)

convergence_tensor = convergence_fn(fire_state, None)
print(f"Convergence tensor: {convergence_tensor}")
print(f"Convergence tensor: {batcher.current_idx}")

all_converged_states.extend(result[1])

final_states = batcher.restore_original_order(all_converged_states)
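The script above reproduces the failure this PR fixes: states coming straight out of fire_init used to carry velocities=None, and the in-flight batcher cannot concatenate a None with the tensor velocities of states that have already been stepped. A minimal sketch of the underlying error, outside the batcher (variable names are illustrative):

import torch

stepped = torch.zeros(8, 3)  # velocities of a state that has taken a FIRE step
fresh = None                 # velocities of a freshly initialized state (pre-fix)
torch.cat([stepped, fresh])  # TypeError: expected Tensor as element 1 in argument 0, but got NoneType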
50 changes: 50 additions & 0 deletions tests/test_autobatching.py
@@ -588,3 +588,53 @@ def test_in_flight_max_iterations(
# Verify swap_attempts tracking
for i in range(len(states)):
assert batcher.swap_attempts[i] == max_attempts


def test_concatenation_works_with_undefined_initial_vectors(
lj_model: LennardJonesModel,
):
"""Test that concatenation works with undefined initial vectors."""
si_atoms = bulk("Si", "fcc", a=3.26, cubic=True)
si_atoms.rattle(0.05)

cu_atoms = bulk("Cu", "fcc", a=5.26, cubic=True)
cu_atoms.rattle(0.5)

many_cu_atoms = [si_atoms] * 5 + [cu_atoms] * 20

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

state = ts.initialize_state(many_cu_atoms, device=device, dtype=torch.float64)


fire_init, fire_update = ts.optimizers.fire(lj_model)
fire_state = fire_init(state)

batcher = ts.InFlightAutoBatcher(
model=lj_model,
memory_scales_with="n_atoms",
max_memory_scaler=40,
max_iterations=10000, # Optional: maximum convergence attempts per state
)

batcher.load_states(fire_state)

convergence_fn = ts.generate_force_convergence_fn(5e-3, include_cell_forces=False)

all_converged_states, convergence_tensor = [], None
while (result := batcher.next_batch(fire_state, convergence_tensor))[0] is not None:
fire_state, converged_states = result
all_converged_states.extend(converged_states)

for _ in range(3):
fire_state = fire_update(fire_state)

convergence_tensor = convergence_fn(fire_state, None)
print(f"Convergence tensor: {convergence_tensor}")
print(f"Convergence tensor: {batcher.current_idx}")

else:
all_converged_states.extend(result[1])

final_states = batcher.restore_original_order(all_converged_states)
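Note on the loop shape shared by the reproduction script and this test: next_batch returns a (state, converged_states) pair, and the walrus-while runs until the state element is None. Because the loop body contains no break, the else clause used here and the plain trailing extend used in the script are equivalent; both collect the converged states returned by the final next_batch call.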
48 changes: 33 additions & 15 deletions torch_sim/optimizers.py
@@ -474,7 +474,7 @@ class FireState(SimState):
# Required attributes not in SimState
forces: torch.Tensor
energy: torch.Tensor
velocities: torch.Tensor | None
velocities: torch.Tensor

# FIRE algorithm parameters
dt: torch.Tensor
@@ -589,7 +589,9 @@ def fire_init(
atomic_numbers=state.atomic_numbers.clone(),
system_idx=state.system_idx.clone(),
pbc=state.pbc,
velocities=None,
velocities=torch.full(
state.positions.shape, torch.nan, device=device, dtype=dtype
),
forces=forces,
energy=energy,
# Optimization attributes
@@ -862,13 +864,17 @@ def fire_init(
atomic_numbers=state.atomic_numbers.clone(),
system_idx=state.system_idx.clone(),
pbc=state.pbc,
velocities=None,
velocities=torch.full(
state.positions.shape, torch.nan, device=device, dtype=dtype
),
forces=forces,
energy=energy,
stress=stress,
# Cell attributes
cell_positions=torch.zeros(n_systems, 3, 3, device=device, dtype=dtype),
cell_velocities=None,
cell_velocities=torch.full(
cell_forces.shape, torch.nan, device=device, dtype=dtype
),
cell_forces=cell_forces,
cell_masses=cell_masses,
# Optimization attributes
@@ -965,7 +971,7 @@ class FrechetCellFIREState(SimState, DeformGradMixin):

# Cell attributes
cell_positions: torch.Tensor
cell_velocities: torch.Tensor | None
cell_velocities: torch.Tensor
cell_forces: torch.Tensor
cell_masses: torch.Tensor

@@ -1161,13 +1167,17 @@ def fire_init(
atomic_numbers=state.atomic_numbers,
system_idx=state.system_idx,
pbc=state.pbc,
velocities=None,
velocities=torch.full(
state.positions.shape, torch.nan, device=device, dtype=dtype
),
forces=forces,
energy=energy,
stress=stress,
# Cell attributes
cell_positions=cell_positions,
cell_velocities=None,
cell_velocities=torch.full(
cell_forces.shape, torch.nan, device=device, dtype=dtype
),
cell_forces=cell_forces,
cell_masses=cell_masses,
# Optimization attributes
@@ -1244,15 +1254,19 @@ def _vv_fire_step( # noqa: C901, PLR0915
dtype = state.positions.dtype
deform_grad_new: torch.Tensor | None = None

if state.velocities is None:
state.velocities = torch.zeros_like(state.positions)
nan_velocities = state.velocities.isnan().any(dim=1)
if nan_velocities.any():
state.velocities[nan_velocities] = torch.zeros_like(
state.positions[nan_velocities]
)
if is_cell_optimization:
if not isinstance(state, AnyFireCellState):
raise ValueError(
f"Cell optimization requires one of {get_args(AnyFireCellState)}."
)
state.cell_velocities = torch.zeros(
(n_systems, 3, 3), device=device, dtype=dtype
nan_cell_velocities = state.cell_velocities.isnan().any(dim=(1, 2))
state.cell_velocities[nan_cell_velocities] = torch.zeros_like(
state.cell_positions[nan_cell_velocities]
)

alpha_start_system = torch.full(
@@ -1461,16 +1475,20 @@ def _ase_fire_step( # noqa: C901, PLR0915

cur_deform_grad = None # Initialize cur_deform_grad to prevent UnboundLocalError

if state.velocities is None:
state.velocities = torch.zeros_like(state.positions)
nan_velocities = state.velocities.isnan().any(dim=1)
if nan_velocities.any():
state.velocities[nan_velocities] = torch.zeros_like(
state.positions[nan_velocities]
)
forces = state.forces
if is_cell_optimization:
if not isinstance(state, AnyFireCellState):
raise ValueError(
f"Cell optimization requires one of {get_args(AnyFireCellState)}."
)
state.cell_velocities = torch.zeros(
(n_systems, 3, 3), device=device, dtype=dtype
nan_cell_velocities = state.cell_velocities.isnan().any(dim=(1, 2))
state.cell_velocities[nan_cell_velocities] = torch.zeros_like(
state.cell_positions[nan_cell_velocities]
)
cur_deform_grad = state.deform_grad()
else:
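The change above replaces the None sentinel with NaN-filled tensors: fire_init now allocates velocities (and cell_velocities) via torch.full(..., torch.nan), so every state always carries a concatenable tensor, and the first call to _vv_fire_step or _ase_fire_step zeroes any row that is still NaN. A minimal standalone sketch of the pattern:

import torch

positions = torch.randn(8, 3, dtype=torch.float64)

# At init: a real tensor (so concatenation works), filled with NaN
# to mark "no velocity assigned yet".
velocities = torch.full(positions.shape, torch.nan, dtype=torch.float64)

# At the first step: rows that are still NaN are reset to zero.
nan_rows = velocities.isnan().any(dim=1)
if nan_rows.any():
    velocities[nan_rows] = torch.zeros_like(positions[nan_rows])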
77 changes: 67 additions & 10 deletions torch_sim/state.py
@@ -6,9 +6,11 @@

import copy
import importlib
import inspect
import typing
import warnings
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Literal, Self
from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal, Self, TypeVar

import torch

@@ -81,7 +83,7 @@ class SimState:
cell: torch.Tensor
pbc: bool # TODO: do all calculators support mixed pbc?
atomic_numbers: torch.Tensor
system_idx: torch.Tensor | None = field(default=None, kw_only=True)
system_idx: torch.Tensor | None

def __post_init__(self) -> None:
"""Validate and process the state after initialization."""
@@ -272,7 +274,7 @@ def clone(self) -> Self:
else:
attrs[attr_name] = copy.deepcopy(attr_value)

return self.__class__(**attrs)
return construct_state(self, attrs)

def to_atoms(self) -> list["Atoms"]:
"""Convert the SimState to a list of ASE Atoms objects.
@@ -372,6 +374,37 @@ def __getitem__(self, system_indices: int | list[int] | slice | torch.Tensor) ->

return _slice_state(self, system_indices)

def __init_subclass__(cls, **kwargs) -> None:
"""Enforce that all derived states cannot have tensor attributes that can also be
None. This is because torch.concatenate cannot concat between a tensor and a None.
See https://github.com/Radical-AI/torch-sim/pull/219 for more details.
"""
# We need to use get_type_hints to correctly inspect the types
type_hints = typing.get_type_hints(cls)
for attr_name, attr_typehint in type_hints.items():
origin = typing.get_origin(attr_typehint)

is_union = origin is typing.Union
if not is_union and origin is not None:
# For Python 3.10+ `|` syntax, origin is types.UnionType
# We check by name to be robust against module reloading/patching issues
is_union = origin.__module__ == "types" and origin.__name__ == "UnionType"
if is_union:
args = typing.get_args(attr_typehint)
if torch.Tensor in args and type(None) in args:
raise TypeError(
f"Attribute '{attr_name}' in class '{cls.__name__}' is not "
"allowed to be of type 'torch.Tensor | None'. "
"Optional tensor attributes are disallowed in SimState "
"subclasses to prevent concatenation errors.\n"
# TODO(curtis): fix this
# "If this attribute will take on a default value in the "
# "post_init method, please use an InitVar for that attribute "
# "but with a prepended 'init_' to the name. (e.g. system_idx)"
)

super().__init_subclass__(**kwargs)
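The guard above has to recognize both typing.Optional/typing.Union annotations and the PEP 604 form torch.Tensor | None, whose origin is types.UnionType rather than typing.Union (hence the check on origin.__module__ and origin.__name__). A minimal standalone sketch of that detection logic, independent of SimState (Demo is a hypothetical class for illustration):

import types
import typing

import torch

class Demo:
    velocities: torch.Tensor | None

hint = typing.get_type_hints(Demo)["velocities"]
assert typing.get_origin(hint) is types.UnionType  # PEP 604 union
assert torch.Tensor in typing.get_args(hint)
assert type(None) in typing.get_args(hint)  # the rejected combination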


class DeformGradMixin:
"""Mixin for states that support deformation gradients."""
@@ -490,7 +523,7 @@ def state_to_device(
attrs["masses"] = attrs["masses"].to(dtype=dtype)
attrs["cell"] = attrs["cell"].to(dtype=dtype)
attrs["atomic_numbers"] = attrs["atomic_numbers"].to(dtype=torch.int)
return type(state)(**attrs)
return construct_state(state, attrs)


def infer_property_scope(
Expand Down Expand Up @@ -719,11 +752,35 @@ def _split_state(
# Add the global attributes
**attrs["global"],
}
states.append(type(state)(**system_attrs))
states.append(construct_state(state, system_attrs))

return states


SimStateT = TypeVar("SimStateT", bound=SimState)


def construct_state(
old_state: SimStateT,
new_state_attrs: dict[str, typing.Any],
) -> SimStateT:
"""Construct a new state from an old state and new state parameters."""
# 1) map the stored attribute names onto the constructor's init parameters
processed_params = {}
for param in inspect.signature(old_state.__class__).parameters:
if param.startswith("init_"):
# this is an InitVar field; the corresponding key in
# new_state_attrs lacks the "init_" prefix, so rename it here
non_init_attr_name = param.removeprefix("init_")
processed_params[param] = new_state_attrs[non_init_attr_name]
else:
processed_params[param] = new_state_attrs[param]

# 2) construct the new state
return type(old_state)(**processed_params)
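construct_state exists because type(state)(**attrs) breaks once a subclass renames a constructor parameter to an InitVar with an init_ prefix: the instance stores the plain attribute name, but the constructor expects the prefixed one. A hypothetical toy illustrating the remapping (Toy is not part of the library; construct_state is assumed importable from torch_sim.state):

from dataclasses import InitVar, dataclass, field

from torch_sim.state import construct_state

@dataclass
class Toy:
    value: int
    init_tag: InitVar[str]        # constructor-only parameter
    tag: str = field(init=False)  # stored attribute

    def __post_init__(self, init_tag: str) -> None:
        self.tag = init_tag.upper()

old = Toy(value=1, init_tag="a")
attrs = {"value": 2, "tag": old.tag}  # keyed by stored-attribute names
new = construct_state(old, attrs)     # attrs["tag"] is passed as init_tag
assert (new.value, new.tag) == (2, "A")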


def _pop_states(
state: SimState,
pop_indices: list[int] | torch.Tensor,
@@ -770,10 +827,10 @@ def _pop_states(
pop_attrs = _filter_attrs_by_mask(attrs, pop_atom_mask, pop_system_mask)

# Create the keep state
keep_state = type(state)(**keep_attrs)
keep_state = construct_state(state, keep_attrs)

# Create and split the pop state
pop_state = type(state)(**pop_attrs)
pop_state = construct_state(state, pop_attrs)
pop_states = _split_state(pop_state, ambiguous_handling)

return keep_state, pop_states
@@ -823,7 +880,7 @@ def _slice_state(
filtered_attrs = _filter_attrs_by_mask(attrs, atom_mask, system_mask)

# Create the sliced state
return type(state)(**filtered_attrs)
return construct_state(state, filtered_attrs)


def concatenate_states(
Expand Down Expand Up @@ -912,7 +969,7 @@ def concatenate_states(
concatenated["system_idx"] = torch.cat(new_system_indices)

# Create a new instance of the same class
return state_class(**concatenated)
return construct_state(first_state, concatenated)


def initialize_state(