diff --git a/torch_sim/models/orb.py b/torch_sim/models/orb.py index 3563f2a9..3b763aeb 100644 --- a/torch_sim/models/orb.py +++ b/torch_sim/models/orb.py @@ -104,10 +104,7 @@ def state_to_atom_graphs( # noqa: PLR0915 system_config = SystemConfig(radius=6.0, max_num_neighbors=20) # Handle batch information if present - if state.batch is not None: - n_node = torch.bincount(state.batch) - else: - n_node = torch.tensor([len(state.positions)]) + n_node = torch.bincount(state.batch) # Set default dtype if not provided output_dtype = torch.get_default_dtype() if output_dtype is None else output_dtype @@ -148,45 +145,46 @@ def state_to_atom_graphs( # noqa: PLR0915 if wrap and (torch.any(row_vector_cell != 0) and torch.any(pbc)): positions = feat_util.batch_map_to_pbc_cell(positions, row_vector_cell, n_node) - # Compute edges of the graph - edge_index, edge_vectors, unit_shifts, batch_num_edges = ( - feat_util.batch_compute_pbc_radius_graph( - positions=positions, - cells=row_vector_cell, - pbc=pbc.unsqueeze(0).repeat(len(n_node), 1), - radius=system_config.radius, - n_node=n_node, - max_number_neighbors=torch.tensor([max_num_neighbors] * len(n_node)), - edge_method=edge_method, - half_supercell=half_supercell, - device=device, - ) - ) - senders, receivers = edge_index[0], edge_index[1] - n_systems = state.batch.max().item() + 1 + + # Prepare lists to collect data from each system + all_edges = [] + all_vectors = [] + all_unit_shifts = [] + num_edges = [] node_feats_list = [] edge_feats_list = [] graph_feats_list = [] - system_edges = torch.repeat_interleave( - torch.arange(n_systems, device=state.device), batch_num_edges - ) + + # Process each system in a single loop + offset = 0 for i in range(n_systems): batch_mask = state.batch == i - system_edge_mask = system_edges == i - try: - positions_per_system = positions[batch_mask] - atomic_numbers_per_system = atomic_numbers[batch_mask] - atomic_numbers_embedding_per_system = atomic_numbers_embedding[batch_mask] - edge_vectors_per_system = edge_vectors[system_edge_mask] - unit_shifts_per_system = unit_shifts[system_edge_mask] - except Exception: # noqa: BLE001 - import pdb # noqa: T100 - - pdb.set_trace() # noqa: T100 - + positions_per_system = positions[batch_mask] + atomic_numbers_per_system = atomic_numbers[batch_mask] + atomic_numbers_embedding_per_system = atomic_numbers_embedding[batch_mask] cell_per_system = row_vector_cell[i] pbc_per_system = pbc + + # Compute edges directly for this system + edges, vectors, unit_shifts = feat_util.compute_pbc_radius_graph( + positions=positions_per_system, + cell=cell_per_system, + pbc=pbc_per_system, + radius=system_config.radius, + max_number_neighbors=max_num_neighbors, + edge_method=edge_method, + half_supercell=half_supercell, + device=device, + ) + + # Adjust indices for the global batch + all_edges.append(edges + offset) + all_vectors.append(vectors) + all_unit_shifts.append(unit_shifts) + num_edges.append(len(edges[0])) + + # Calculate lattice parameters lattice_per_system = torch.from_numpy( cell_to_cellpar(cell_per_system.squeeze(0).cpu().numpy()) ) @@ -202,8 +200,8 @@ def state_to_atom_graphs( # noqa: PLR0915 } edge_feats = { - "vectors": edge_vectors_per_system, - "unit_shifts": unit_shifts_per_system, + "vectors": vectors, + "unit_shifts": unit_shifts, } graph_feats = { @@ -221,6 +219,16 @@ def state_to_atom_graphs( # noqa: PLR0915 edge_feats_list.append(edge_feats) graph_feats_list.append(graph_feats) + # Update offset for next system + offset += len(positions_per_system) + + # Concatenate all the edge data + edge_index = torch.cat(all_edges, dim=1) + unit_shifts = torch.cat(all_unit_shifts, dim=0) + batch_num_edges = torch.tensor(num_edges, dtype=torch.int64, device=device) + + senders, receivers = edge_index[0], edge_index[1] + # Create and return AtomGraphs object return AtomGraphs( senders=senders,