From a7aa7cf535de6eeb330509721c781d14c0c6b6cd Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 2 Aug 2025 17:06:20 -0700 Subject: [PATCH 1/5] more renaming --- torch_sim/runners.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index ef1dcb16..041de0ba 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -174,14 +174,14 @@ def integrate( pbar_kwargs.setdefault("disable", None) tqdm_pbar = tqdm(total=state.n_systems, **pbar_kwargs) - for state, batch_indices in batch_iterator: + for state, system_indices in batch_iterator: state = init_fn(state) # set up trajectory reporters if autobatcher and trajectory_reporter: - # we must remake the trajectory reporter for each batch + # we must remake the trajectory reporter for each system trajectory_reporter.load_new_trajectories( - filenames=[og_filenames[i] for i in batch_indices] + filenames=[og_filenames[i] for i in system_indices] ) # run the simulation @@ -278,7 +278,7 @@ def _chunked_apply( autobatcher.load_states(states) initialized_states = [] - initialized_states = [fn(batch) for batch in autobatcher] + initialized_states = [fn(system) for system in autobatcher] ordered_states = autobatcher.restore_original_order(initialized_states) return concatenate_states(ordered_states) @@ -308,7 +308,7 @@ def convergence_fn( Returns: torch.Tensor: Boolean tensor of shape (n_systems,) indicating - convergence status for each batch. + convergence status for each system. """ force_conv = batchwise_max_force(state) < force_tol @@ -344,7 +344,7 @@ def convergence_fn( Returns: torch.Tensor: Boolean tensor of shape (n_systems,) indicating - convergence status for each batch. + convergence status for each system. """ return torch.abs(state.energy - last_energy) < energy_tol @@ -437,13 +437,13 @@ def optimize( # noqa: C901 tqdm_pbar = tqdm(total=state.n_systems, **pbar_kwargs) while (result := autobatcher.next_batch(state, convergence_tensor))[0] is not None: - state, converged_states, batch_indices = result + state, converged_states, system_indices = result all_converged_states.extend(converged_states) # need to update the trajectory reporter if any states have converged if trajectory_reporter and (step == 1 or len(converged_states) > 0): trajectory_reporter.load_new_trajectories( - filenames=[og_filenames[i] for i in batch_indices] + filenames=[og_filenames[i] for i in system_indices] ) for _step in range(steps_between_swaps): @@ -487,8 +487,8 @@ def static( """Run single point calculations on a batch of systems. Unlike the other runners, this function does not return a state. Instead, it - returns a list of dictionaries, one for each batch in the input state. Each - dictionary contains the properties calculated for that batch. It will also + returns a list of dictionaries, one for each system in the input state. Each + dictionary contains the properties calculated for that system. It will also modify the state in place with the "energy", "forces", and "stress" properties if they are present in the model output. 
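The convergence-generator docstrings revised earlier in this runners.py diff describe a common systemwise interface: the returned callable receives the current state plus the last recorded energy and yields one boolean per system. Below is a minimal sketch of a custom criterion with that shape — it is not part of this patch, and it assumes `state.energy` and `last_energy` are per-system tensors of shape (n_systems,), as those docstrings state.

import torch

def my_energy_convergence_fn(state, last_energy: torch.Tensor) -> torch.Tensor:
    # One boolean per system, shape (n_systems,): True once the energy change
    # since the previous check drops below a fixed tolerance. This mirrors the
    # body of generate_energy_convergence_fn shown in the hunks above.
    return torch.abs(state.energy - last_energy) < 1e-3

Any callable with this systemwise signature can stand in for the generated functions when driving the optimize loop shown earlier in this diff.
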
@@ -547,12 +547,12 @@ class StaticState(type(state)): pbar_kwargs.setdefault("disable", None) tqdm_pbar = tqdm(total=state.n_systems, **pbar_kwargs) - for sub_state, batch_indices in batch_iterator: + for sub_state, system_indices in batch_iterator: # set up trajectory reporters if autobatcher and trajectory_reporter and og_filenames is not None: - # we must remake the trajectory reporter for each batch + # we must remake the trajectory reporter for each system trajectory_reporter.load_new_trajectories( - filenames=[og_filenames[idx] for idx in batch_indices] + filenames=[og_filenames[idx] for idx in system_indices] ) model_outputs = model(sub_state) From f5076c89b4f786eba3e28ae979553922ba8c6c51 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 2 Aug 2025 17:10:00 -0700 Subject: [PATCH 2/5] more renaming --- examples/tutorials/low_level_tutorial.py | 6 +++--- examples/tutorials/state_tutorial.py | 22 +++++++++++----------- torch_sim/quantities.py | 2 +- torch_sim/runners.py | 8 ++++---- 4 files changed, 19 insertions(+), 19 deletions(-) diff --git a/examples/tutorials/low_level_tutorial.py b/examples/tutorials/low_level_tutorial.py index 69d22e78..99c8702d 100644 --- a/examples/tutorials/low_level_tutorial.py +++ b/examples/tutorials/low_level_tutorial.py @@ -107,7 +107,7 @@ """ `SimState` objects can be passed directly to the model and it will compute the properties of the systems in the batch. The properties will be returned -either batchwise, like the energy, or atomwise, like the forces. +either systemwise, like the energy, or atomwise, like the forces. Note that the energy here refers to the potential energy of the system. """ @@ -116,9 +116,9 @@ model_outputs = model(state) print(f"Model outputs: {', '.join(list(model_outputs))}") -print(f"Energy is a batchwise property with shape: {model_outputs['energy'].shape}") +print(f"Energy is a systemwise property with shape: {model_outputs['energy'].shape}") print(f"Forces are an atomwise property with shape: {model_outputs['forces'].shape}") -print(f"Stress is a batchwise property with shape: {model_outputs['stress'].shape}") +print(f"Stress is a systemwise property with shape: {model_outputs['stress'].shape}") # %% [markdown] diff --git a/examples/tutorials/state_tutorial.py b/examples/tutorials/state_tutorial.py index 0bc9e341..0d3eca96 100644 --- a/examples/tutorials/state_tutorial.py +++ b/examples/tutorials/state_tutorial.py @@ -71,11 +71,11 @@ # %% [markdown] """ -SimState attributes fall into three categories: atomwise, batchwise, and global. +SimState attributes fall into three categories: atomwise, systemwise, and global. * Atomwise attributes are tensors with shape (n_atoms, ...), these are `positions`, - `masses`, `atomic_numbers`, and `batch`. Names are plural. -* Batchwise attributes are tensors with shape (n_systems, ...), this is just `cell` for + `masses`, `atomic_numbers`, and `system_idx`. Names are plural. +* Systemwise attributes are tensors with shape (n_systems, ...), this is just `cell` for the base SimState. Names are singular. * Global attributes have any other shape or type, just `pbc` here. Names are singular. 
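The attribute categories spelled out in the tutorial text above translate directly into tensor shapes. The following sketch is not part of this patch; it assumes the SimState constructor accepts these fields by name (as listed in the tutorial) and stores one 3x3 cell per system.

import torch
from torch_sim.state import SimState

# Two toy systems of two atoms each; the values are illustrative only.
state = SimState(
    positions=torch.zeros(4, 3),                # atomwise, (n_atoms, 3)
    masses=torch.ones(4),                       # atomwise, (n_atoms,)
    atomic_numbers=torch.tensor([1, 1, 8, 8]),  # atomwise, (n_atoms,)
    system_idx=torch.tensor([0, 0, 1, 1]),      # atomwise, (n_atoms,)
    cell=torch.eye(3).repeat(2, 1, 1) * 5.0,    # systemwise, (n_systems, 3, 3) assumed layout
    pbc=True,                                   # global, a single bool
)
print(state.n_atoms, state.n_systems)  # 4 atoms across 2 systems
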
@@ -112,7 +112,7 @@ f"Multi-state has {multi_state.n_atoms} total atoms across {multi_state.n_systems} systems" ) -# we can see how the shapes of batchwise, atomwise, and global properties change +# we can see how the shapes of atomwise, systemwise, and global properties change print(f"Positions shape: {multi_state.positions.shape}") print(f"Cell shape: {multi_state.cell.shape}") print(f"PBC: {multi_state.pbc}") @@ -142,7 +142,7 @@ SimState supports many convenience operations for manipulating batched states. Slicing is supported through fancy indexing, e.g. `state[[0, 1, 2]]` will return a new state -containing only the first three batches. The other operations are available through the +containing only the first three systems. The other operations are available through the `pop`, `split`, `clone`, and `to` methods. """ @@ -182,19 +182,19 @@ # %% [markdown] """ -You can extract specific batches from a batched state using Python's slicing syntax. +You can extract specific systems from a batched state using Python's slicing syntax. This is extremely useful for analyzing specific systems or for implementing complex workflows where different systems need separate processing: The slicing interface follows Python's standard indexing conventions, making it intuitive to use. Behind the scenes, TorchSim is creating a new SimState with only the -selected batches, maintaining all the necessary properties and relationships. +selected systems, maintaining all the necessary properties and relationships. Note the difference between these operations: -- `split()` returns all batches as separate states but doesn't modify the original -- `pop()` removes specified batches from the original state and returns them as +- `split()` returns all systems as separate states but doesn't modify the original +- `pop()` removes specified systems from the original state and returns them as separate states -- `__getitem__` (slicing) creates a new state with specified batches without modifying +- `__getitem__` (slicing) creates a new state with specified systems without modifying the original This flexibility allows you to structure your simulation workflows in the most @@ -203,7 +203,7 @@ ### Splitting and Popping Batches SimState provides methods to split a batched state into separate states or to remove -specific batches: +specific systems: """ # %% [markdown] diff --git a/torch_sim/quantities.py b/torch_sim/quantities.py index 5d9898c2..971b1b54 100644 --- a/torch_sim/quantities.py +++ b/torch_sim/quantities.py @@ -141,7 +141,7 @@ def get_pressure( return 1 / dim * ((2 * kinetic_energy / volume) - torch.einsum("...ii", stress)) -def batchwise_max_force(state: SimState) -> torch.Tensor: +def systemwise_max_force(state: SimState) -> torch.Tensor: """Compute the maximum force per system. 
Args: diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 041de0ba..3c83724c 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -22,7 +22,7 @@ UnitCellFireState, UnitCellGDState, ) -from torch_sim.quantities import batchwise_max_force, calc_kinetic_energy, calc_kT +from torch_sim.quantities import calc_kinetic_energy, calc_kT, systemwise_max_force from torch_sim.state import SimState, concatenate_states, initialize_state from torch_sim.trajectory import TrajectoryReporter from torch_sim.typing import StateLike @@ -297,7 +297,7 @@ def generate_force_convergence_fn( Returns: Convergence function that takes a state and last energy and - returns a batchwise boolean function + returns a systemwise boolean function """ def convergence_fn( @@ -310,7 +310,7 @@ def convergence_fn( torch.Tensor: Boolean tensor of shape (n_systems,) indicating convergence status for each system. """ - force_conv = batchwise_max_force(state) < force_tol + force_conv = systemwise_max_force(state) < force_tol if include_cell_forces: if (cell_forces := getattr(state, "cell_forces", None)) is None: @@ -333,7 +333,7 @@ def generate_energy_convergence_fn(energy_tol: float = 1e-3) -> Callable: Returns: Convergence function that takes a state and last energy and - returns a batchwise boolean function + returns a systemwise boolean function """ def convergence_fn( From 2c5810de5c42fe5cc280f46b4d7cc50f7a3700a8 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 2 Aug 2025 17:17:11 -0700 Subject: [PATCH 3/5] more renaming --- .../7_Others/7.3_Batched_neighbor_list.py | 12 +++--- tests/test_autobatching.py | 2 +- tests/test_neighbors.py | 14 +++---- tests/test_transforms.py | 14 +++---- torch_sim/neighbors.py | 40 +++++++++---------- 5 files changed, 41 insertions(+), 41 deletions(-) diff --git a/examples/scripts/7_Others/7.3_Batched_neighbor_list.py b/examples/scripts/7_Others/7.3_Batched_neighbor_list.py index 91141fd3..2b845c07 100644 --- a/examples/scripts/7_Others/7.3_Batched_neighbor_list.py +++ b/examples/scripts/7_Others/7.3_Batched_neighbor_list.py @@ -25,26 +25,26 @@ # Fix: Ensure pbc has the correct shape [n_systems, 3] pbc_tensor = torch.tensor([[pbc] * 3] * len(atoms_list), dtype=torch.bool) -mapping, mapping_batch, shifts_idx = torch_nl_linked_cell( +mapping, mapping_system, shifts_idx = torch_nl_linked_cell( cutoff, pos, cell, pbc_tensor, system_idx, self_interaction ) -cell_shifts = transforms.compute_cell_shifts(cell, shifts_idx, mapping_batch) +cell_shifts = transforms.compute_cell_shifts(cell, shifts_idx, mapping_system) dds = transforms.compute_distances_with_cell_shifts(pos, mapping, cell_shifts) print(mapping.shape) -print(mapping_batch.shape) +print(mapping_system.shape) print(shifts_idx.shape) print(cell_shifts.shape) print(dds.shape) -mapping_n2, mapping_batch_n2, shifts_idx_n2 = torch_nl_n2( +mapping_n2, mapping_system_n2, shifts_idx_n2 = torch_nl_n2( cutoff, pos, cell, pbc_tensor, system_idx, self_interaction ) -cell_shifts_n2 = transforms.compute_cell_shifts(cell, shifts_idx_n2, mapping_batch_n2) +cell_shifts_n2 = transforms.compute_cell_shifts(cell, shifts_idx_n2, mapping_system_n2) dds_n2 = transforms.compute_distances_with_cell_shifts(pos, mapping_n2, cell_shifts_n2) print(mapping_n2.shape) -print(mapping_batch_n2.shape) +print(mapping_system_n2.shape) print(shifts_idx_n2.shape) print(cell_shifts_n2.shape) print(dds_n2.shape) diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index 7be28997..30544f64 100644 --- a/tests/test_autobatching.py +++ 
b/tests/test_autobatching.py @@ -149,7 +149,7 @@ def test_binning_auto_batcher( # Get batches until None is returned batches = list(batcher) - # Check we got the expected number of batches + # Check we got the expected number of systems assert len(batches) == len(batcher.batched_states) # Test restore_original_order diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py index 205b626a..ac948acb 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -342,13 +342,13 @@ def test_torch_nl_implementations( ) # Get the neighbor list from the implementation being tested - mapping, mapping_batch, shifts_idx = nl_implementation( + mapping, mapping_system, shifts_idx = nl_implementation( cutoff, pos, row_vector_cell, pbc, batch, self_interaction ) # Calculate distances cell_shifts = transforms.compute_cell_shifts( - row_vector_cell, shifts_idx, mapping_batch + row_vector_cell, shifts_idx, mapping_system ) dds = transforms.compute_distances_with_cell_shifts(pos, mapping, cell_shifts) dds = np.sort(dds.numpy()) @@ -496,7 +496,7 @@ def test_strict_nl_edge_cases( # Test with no cell shifts mapping = torch.tensor([[0], [1]], device=device, dtype=torch.long) - batch_mapping = torch.tensor([0], device=device, dtype=torch.long) + system_mapping = torch.tensor([0], device=device, dtype=torch.long) shifts_idx = torch.zeros((1, 3), device=device, dtype=torch.long) new_mapping, new_batch, new_shifts = neighbors.strict_nl( @@ -504,14 +504,14 @@ def test_strict_nl_edge_cases( positions=pos, cell=cell, mapping=mapping, - batch_mapping=batch_mapping, + system_mapping=system_mapping, shifts_idx=shifts_idx, ) assert len(new_mapping[0]) > 0 # Should find neighbors # Test with different batch mappings mapping = torch.tensor([[0, 1], [1, 0]], device=device, dtype=torch.long) - batch_mapping = torch.tensor([0, 1], device=device, dtype=torch.long) + system_mapping = torch.tensor([0, 1], device=device, dtype=torch.long) shifts_idx = torch.zeros((2, 3), device=device, dtype=torch.long) new_mapping, new_batch, new_shifts = neighbors.strict_nl( @@ -519,7 +519,7 @@ def test_strict_nl_edge_cases( positions=pos, cell=cell, mapping=mapping, - batch_mapping=batch_mapping, + system_mapping=system_mapping, shifts_idx=shifts_idx, ) assert len(new_mapping[0]) > 0 # Should find neighbors @@ -559,7 +559,7 @@ def test_neighbor_lists_time_and_memory( system_idx = torch.zeros(n_atoms, dtype=torch.long, device=device) # Fix pbc tensor shape pbc = torch.tensor([[True, True, True]], device=device) - mapping, mapping_batch, shifts_idx = nl_fn( + mapping, mapping_system, shifts_idx = nl_fn( cutoff, pos, cell, pbc, system_idx, self_interaction=False ) else: diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 4c05e658..ca965c69 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -1183,9 +1183,9 @@ def test_compute_cell_shifts_basic() -> None: """Test compute_cell_shifts function.""" cell = torch.eye(3).unsqueeze(0) * 2.0 shifts_idx = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]) - batch_mapping = torch.tensor([0, 0]) + system_mapping = torch.tensor([0, 0]) - cell_shifts = tst.compute_cell_shifts(cell, shifts_idx, batch_mapping) + cell_shifts = tst.compute_cell_shifts(cell, shifts_idx, system_mapping) expected = torch.tensor([[2.0, 0.0, 0.0], [0.0, 2.0, 0.0]]) torch.testing.assert_close(cell_shifts, expected) @@ -1272,16 +1272,16 @@ def test_build_linked_cell_neighborhood_basic() -> None: cutoff = 1.5 n_atoms = torch.tensor([2, 2]) - mapping, batch_mapping, cell_shifts_idx = 
tst.build_linked_cell_neighborhood( + mapping, system_mapping, cell_shifts_idx = tst.build_linked_cell_neighborhood( positions, cell, pbc, cutoff, n_atoms, self_interaction=False ) # Check that atoms in the same structure are neighbors assert mapping.shape[1] >= 2 # At least 2 neighbor pairs - # Verify batch_mapping has correct length - assert batch_mapping.shape[0] == mapping.shape[1] + # Verify system_mapping has correct length + assert system_mapping.shape[0] == mapping.shape[1] # Verify that there are neighbors from both batches - assert torch.any(batch_mapping == 0) - assert torch.any(batch_mapping == 1) + assert torch.any(system_mapping == 0) + assert torch.any(system_mapping == 1) diff --git a/torch_sim/neighbors.py b/torch_sim/neighbors.py index 9c997eac..40091eee 100644 --- a/torch_sim/neighbors.py +++ b/torch_sim/neighbors.py @@ -642,7 +642,7 @@ def strict_nl( positions: torch.Tensor, cell: torch.Tensor, mapping: torch.Tensor, - batch_mapping: torch.Tensor, + system_mapping: torch.Tensor, shifts_idx: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Apply a strict cutoff to the neighbor list defined in the mapping. @@ -663,7 +663,7 @@ def strict_nl( mapping (torch.Tensor): A tensor of shape (2, n_pairs) that specifies pairs of indices in `positions` for which to compute distances. - batch_mapping (torch.Tensor): + system_mapping (torch.Tensor): A tensor that maps the shifts to the corresponding cells, used in conjunction with `shifts_idx` to compute the correct periodic shifts. shifts_idx (torch.Tensor): @@ -675,8 +675,8 @@ def strict_nl( A tuple containing: - mapping (torch.Tensor): A filtered tensor of shape (2, n_filtered_pairs) with pairs of indices that are within the cutoff distance. - - mapping_batch (torch.Tensor): A tensor of shape (n_filtered_pairs,) - that maps the filtered pairs to their corresponding batches. + - mapping_system (torch.Tensor): A tensor of shape (n_filtered_pairs,) + that maps the filtered pairs to their corresponding systems. - shifts_idx (torch.Tensor): A tensor of shape (n_filtered_pairs, 3) containing the periodic shift indices for the filtered pairs. @@ -689,7 +689,7 @@ def strict_nl( References: - https://github.com/felixmusil/torch_nl """ - cell_shifts = transforms.compute_cell_shifts(cell, shifts_idx, batch_mapping) + cell_shifts = transforms.compute_cell_shifts(cell, shifts_idx, system_mapping) if cell_shifts is None: d2 = (positions[mapping[0]] - positions[mapping[1]]).square().sum(dim=1) else: @@ -701,9 +701,9 @@ def strict_nl( mask = d2 < cutoff * cutoff mapping = mapping[:, mask] - mapping_batch = batch_mapping[mask] + mapping_system = system_mapping[mask] shifts_idx = shifts_idx[mask] - return mapping, mapping_batch, shifts_idx + return mapping, mapping_system, shifts_idx @torch.jit.script @@ -712,7 +712,7 @@ def torch_nl_n2( positions: torch.Tensor, cell: torch.Tensor, pbc: torch.Tensor, - batch: torch.Tensor, + system_idx: torch.Tensor, self_interaction: bool = False, # noqa: FBT001, FBT002 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Compute the neighbor list for a set of atomic structures using a @@ -729,7 +729,7 @@ def torch_nl_n2( pbc (torch.Tensor [n_structure, 3] bool): A tensor indicating the periodic boundary conditions to apply. Partial PBC are not supported yet. - batch (torch.Tensor [n_atom,] torch.long): + system_idx (torch.Tensor [n_atom,] torch.long): A tensor containing the index of the structure to which each atom belongs. 
self_interaction (bool, optional): A flag to indicate whether to keep the center atoms as their own neighbors. @@ -741,7 +741,7 @@ def torch_nl_n2( A tensor containing the indices of the neighbor list for the given positions array. `mapping[0]` corresponds to the central atom indices, and `mapping[1]` corresponds to the neighbor atom indices. - batch_mapping (torch.Tensor [n_neighbors]): + system_mapping (torch.Tensor [n_neighbors]): A tensor mapping the neighbor atoms to their respective structures. shifts_idx (torch.Tensor [n_neighbors, 3]): A tensor containing the cell shift indices used to reconstruct the @@ -750,14 +750,14 @@ def torch_nl_n2( References: - https://github.com/felixmusil/torch_nl """ - n_atoms = torch.bincount(batch) - mapping, batch_mapping, shifts_idx = transforms.build_naive_neighborhood( + n_atoms = torch.bincount(system_idx) + mapping, system_mapping, shifts_idx = transforms.build_naive_neighborhood( positions, cell, pbc, cutoff, n_atoms, self_interaction ) - mapping, mapping_batch, shifts_idx = strict_nl( - cutoff, positions, cell, mapping, batch_mapping, shifts_idx + mapping, mapping_system, shifts_idx = strict_nl( + cutoff, positions, cell, mapping, system_mapping, shifts_idx ) - return mapping, mapping_batch, shifts_idx + return mapping, mapping_system, shifts_idx @torch.jit.script @@ -797,7 +797,7 @@ def torch_nl_linked_cell( A tensor containing the indices of the neighbor list for the given positions array. `mapping[0]` corresponds to the central atom indices, and `mapping[1]` corresponds to the neighbor atom indices. - - batch_mapping (torch.Tensor [n_neighbors]): + - system_mapping (torch.Tensor [n_neighbors]): A tensor mapping the neighbor atoms to their respective structures. - shifts_idx (torch.Tensor [n_neighbors, 3]): A tensor containing the cell shift indices used to reconstruct the @@ -807,11 +807,11 @@ def torch_nl_linked_cell( - https://github.com/felixmusil/torch_nl """ n_atoms = torch.bincount(system_idx) - mapping, batch_mapping, shifts_idx = transforms.build_linked_cell_neighborhood( + mapping, system_mapping, shifts_idx = transforms.build_linked_cell_neighborhood( positions, cell, pbc, cutoff, n_atoms, self_interaction ) - mapping, mapping_batch, shifts_idx = strict_nl( - cutoff, positions, cell, mapping, batch_mapping, shifts_idx + mapping, mapping_system, shifts_idx = strict_nl( + cutoff, positions, cell, mapping, system_mapping, shifts_idx ) - return mapping, mapping_batch, shifts_idx + return mapping, mapping_system, shifts_idx From 87f29401f47fb4c055cd463d5cc6f8fedcc97e34 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 2 Aug 2025 17:20:40 -0700 Subject: [PATCH 4/5] more renaming --- tests/test_integrators.py | 48 +++++++++++++++++++-------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/tests/test_integrators.py b/tests/test_integrators.py index b6923aa5..ac7bf4b8 100644 --- a/tests/test_integrators.py +++ b/tests/test_integrators.py @@ -20,66 +20,66 @@ def test_calculate_momenta_basic(device: torch.device): seed = 42 dtype = torch.float64 - # Create test inputs for 3 batches with 2 atoms each + # Create test inputs for 3 systems with 2 atoms each n_atoms = 8 positions = torch.randn(n_atoms, 3, dtype=dtype, device=device) masses = torch.rand(n_atoms, dtype=dtype, device=device) + 0.5 - batch = torch.tensor( + system_idx = torch.tensor( [0, 0, 1, 1, 2, 2, 3, 3], device=device - ) # 3 batches with 2 atoms each + ) # 3 systems with 2 atoms each kT = torch.tensor([0.1, 0.2, 0.3, 0.4], dtype=dtype, 
device=device) # Run the function - momenta = calculate_momenta(positions, masses, batch, kT, seed=seed) + momenta = calculate_momenta(positions, masses, system_idx, kT, seed=seed) # Basic checks assert momenta.shape == positions.shape assert momenta.dtype == dtype assert momenta.device == device - # Check that each batch has zero center of mass momentum + # Check that each system has zero center of mass momentum for b in range(4): - batch_mask = batch == b - batch_momenta = momenta[batch_mask] - com_momentum = torch.mean(batch_momenta, dim=0) + system_mask = system_idx == b + system_momenta = momenta[system_mask] + com_momentum = torch.mean(system_momenta, dim=0) assert torch.allclose( com_momentum, torch.zeros(3, dtype=dtype, device=device), atol=1e-10 ) def test_calculate_momenta_single_atoms(device: torch.device): - """Test that calculate_momenta preserves momentum for batches with single atoms.""" + """Test that calculate_momenta preserves momentum for systems with single atoms.""" seed = 42 dtype = torch.float64 - # Create test inputs with some batches having single atoms + # Create test inputs with some systems having single atoms positions = torch.randn(5, 3, dtype=dtype, device=device) masses = torch.rand(5, dtype=dtype, device=device) + 0.5 - batch = torch.tensor( + system_idx = torch.tensor( [0, 1, 1, 2, 3], device=device - ) # Batches 0, 2, and 3 have single atoms + ) # systems 0, 2, and 3 have single atoms kT = torch.tensor([0.1, 0.2, 0.3, 0.4], dtype=dtype, device=device) # Generate momenta and save the raw values before COM correction generator = torch.Generator(device=device).manual_seed(seed) raw_momenta = torch.randn( positions.shape, device=device, dtype=dtype, generator=generator - ) * torch.sqrt(masses * kT[batch]).unsqueeze(-1) + ) * torch.sqrt(masses * kT[system_idx]).unsqueeze(-1) # Run the function - momenta = calculate_momenta(positions, masses, batch, kT, seed=seed) + momenta = calculate_momenta(positions, masses, system_idx, kT, seed=seed) - # Check that single-atom batches have unchanged momenta - for b in [0, 2, 3]: # Single atom batches - batch_mask = batch == b + # Check that single-atom systems have unchanged momenta + for b in [0, 2, 3]: # Single atom systems + system_mask = system_idx == b # The momentum should be exactly the same as the raw value for single atoms - assert torch.allclose(momenta[batch_mask], raw_momenta[batch_mask]) + assert torch.allclose(momenta[system_mask], raw_momenta[system_mask]) - # Check that multi-atom batches have zero COM - for b in [1]: # Multi-atom batches - batch_mask = batch == b - batch_momenta = momenta[batch_mask] - com_momentum = torch.mean(batch_momenta, dim=0) + # Check that multi-atom systems have zero COM + for b in [1]: # Multi-atom systems + system_mask = system_idx == b + system_momenta = momenta[system_mask] + com_momentum = torch.mean(system_momenta, dim=0) assert torch.allclose( com_momentum, torch.zeros(3, dtype=dtype, device=device), atol=1e-10 ) @@ -378,7 +378,7 @@ def test_compute_cell_force_atoms_per_system(): Covers fix in https://github.com/Radical-AI/torch-sim/pull/153.""" from torch_sim.integrators.npt import _compute_cell_force - # Setup minimal state with two batches having 8:1 atom ratio + # Setup minimal state with two systems having 8:1 atom ratio s1, s2 = torch.zeros(8, dtype=torch.long), torch.ones(64, dtype=torch.long) state = NPTLangevinState( From 81ef5541d2b7c474edf5c4e6f48a5d8026c9d67e Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 2 Aug 2025 17:25:02 -0700 Subject: [PATCH 5/5] 
more renaming --- torch_sim/models/graphpes.py | 6 +++--- torch_sim/models/mace.py | 20 ++++++++++---------- torch_sim/models/orb.py | 12 ++++++------ torch_sim/models/sevennet.py | 6 +++--- 4 files changed, 22 insertions(+), 22 deletions(-) diff --git a/torch_sim/models/graphpes.py b/torch_sim/models/graphpes.py index fe51fe01..6ce52753 100644 --- a/torch_sim/models/graphpes.py +++ b/torch_sim/models/graphpes.py @@ -68,9 +68,9 @@ def state_to_atomic_graph(state: ts.SimState, cutoff: torch.Tensor) -> AtomicGra graphs = [] for i in range(state.n_systems): - batch_mask = state.system_idx == i - R = state.positions[batch_mask] - Z = state.atomic_numbers[batch_mask] + system_mask = state.system_idx == i + R = state.positions[system_mask] + Z = state.atomic_numbers[system_mask] cell = state.row_vector_cell[i] nl, shifts = vesin_nl_ts( R, diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index 4ec3ec3f..cfd34142 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -184,17 +184,17 @@ def __init__( # Store flag to track if atomic numbers were provided at init self.atomic_numbers_in_init = atomic_numbers is not None - # Set up batch information if atomic numbers are provided + # Set up system_idx information if atomic numbers are provided if atomic_numbers is not None: if system_idx is None: - # If batch is not provided, assume all atoms belong to same system + # If system_idx is not provided, assume all atoms belong to same system system_idx = torch.zeros( len(atomic_numbers), dtype=torch.long, device=self.device ) - self.setup_from_batch(atomic_numbers, system_idx) + self.setup_from_system_idx(atomic_numbers, system_idx) - def setup_from_batch( + def setup_from_system_idx( self, atomic_numbers: torch.Tensor, system_idx: torch.Tensor ) -> None: """Set up internal state from atomic numbers and system indices. 
@@ -286,7 +286,7 @@ def forward( # noqa: C901 ) state.system_idx = self.system_idx - # Update batch information if new atomic numbers are provided + # Update system_idx information if new atomic numbers are provided if ( state.atomic_numbers is not None and not self.atomic_numbers_in_init @@ -295,7 +295,7 @@ def forward( # noqa: C901 getattr(self, "atomic_numbers", torch.zeros(0, device=self.device)), ) ): - self.setup_from_batch(state.atomic_numbers, state.system_idx) + self.setup_from_system_idx(state.atomic_numbers, state.system_idx) # Process each system's neighbor list separately edge_indices = [] @@ -305,16 +305,16 @@ def forward( # noqa: C901 # TODO (AG): Currently doesn't work for batched neighbor lists for b in range(self.n_systems): - batch_mask = state.system_idx == b + system_mask = state.system_idx == b # Calculate neighbor list for this system edge_idx, shifts_idx = self.neighbor_list_fn( - positions=state.positions[batch_mask], + positions=state.positions[system_mask], cell=state.row_vector_cell[b], pbc=state.pbc, cutoff=self.r_max, ) - # Adjust indices for the batch + # Adjust indices for the system edge_idx = edge_idx + offset shifts = torch.mm(shifts_idx, state.row_vector_cell[b]) @@ -322,7 +322,7 @@ def forward( # noqa: C901 unit_shifts_list.append(shifts_idx) shifts_list.append(shifts) - offset += len(state.positions[batch_mask]) + offset += len(state.positions[system_mask]) # Combine all neighbor lists edge_index = torch.cat(edge_indices, dim=1) diff --git a/torch_sim/models/orb.py b/torch_sim/models/orb.py index cf015fb2..7b4bffd7 100644 --- a/torch_sim/models/orb.py +++ b/torch_sim/models/orb.py @@ -157,10 +157,10 @@ def state_to_atom_graphs( # noqa: PLR0915 # Process each system in a single loop offset = 0 for i in range(n_systems): - batch_mask = state.system_idx == i - positions_per_system = positions[batch_mask] - atomic_numbers_per_system = atomic_numbers[batch_mask] - atomic_numbers_embedding_per_system = atomic_numbers_embedding[batch_mask] + system_mask = state.system_idx == i + positions_per_system = positions[system_mask] + atomic_numbers_per_system = atomic_numbers[system_mask] + atomic_numbers_embedding_per_system = atomic_numbers_embedding[system_mask] cell_per_system = row_vector_cell[i] pbc_per_system = pbc @@ -223,7 +223,7 @@ def state_to_atom_graphs( # noqa: PLR0915 # Concatenate all the edge data edge_index = torch.cat(all_edges, dim=1) unit_shifts = torch.cat(all_unit_shifts, dim=0) - batch_num_edges = torch.tensor(num_edges, dtype=torch.int64, device=device) + system_num_edges = torch.tensor(num_edges, dtype=torch.int64, device=device) senders, receivers = edge_index[0], edge_index[1] @@ -232,7 +232,7 @@ def state_to_atom_graphs( # noqa: PLR0915 senders=senders, receivers=receivers, n_node=n_node, - n_edge=batch_num_edges, + n_edge=system_num_edges, node_features=_map_concat(node_feats_list), edge_features=_map_concat(edge_feats_list), system_features=_map_concat(graph_feats_list), diff --git a/torch_sim/models/sevennet.py b/torch_sim/models/sevennet.py index c4e3d96b..6156fc17 100644 --- a/torch_sim/models/sevennet.py +++ b/torch_sim/models/sevennet.py @@ -182,13 +182,13 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: data_list = [] for b in range(state.system_idx.max().item() + 1): - batch_mask = state.system_idx == b + system_mask = state.system_idx == b - pos = state.positions[batch_mask] + pos = state.positions[system_mask] # SevenNet uses row vector cell convention for neighbor list row_vector_cell = 
state.row_vector_cell[b] pbc = state.pbc - atomic_numbers = state.atomic_numbers[batch_mask] + atomic_numbers = state.atomic_numbers[system_mask] edge_idx, shifts_idx = self.neighbor_list_fn( positions=pos,