diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py index 962f7fa6..0130e819 100644 --- a/torch_sim/models/fairchem.py +++ b/torch_sim/models/fairchem.py @@ -35,7 +35,7 @@ update_config, ) from fairchem.core.models.model_registry import model_name_to_local_file - from torch_geometric.data import Batch + from torch_geometric.data import Batch, Data except ImportError: @@ -350,19 +350,22 @@ def forward(self, state: SimState | StateDict) -> dict: ) natoms = torch.bincount(state.batch) - pbc = torch.tensor( - [state.pbc, state.pbc, state.pbc] * len(natoms), dtype=torch.bool - ).view(-1, 3) fixed = torch.zeros((state.batch.size(0), natoms.sum()), dtype=torch.int) - self.data_object = Batch( - pos=state.positions, - cell=state.row_vector_cell, - atomic_numbers=state.atomic_numbers, - natoms=natoms, - batch=state.batch, - fixed=fixed, - pbc=pbc, - ) + data_list = [] + for i, (n, c) in enumerate( + zip(natoms, torch.cumsum(natoms, dim=0), strict=False) + ): + data_list.append( + Data( + pos=state.positions[c - n : c].clone(), + cell=state.row_vector_cell[i, None].clone(), + atomic_numbers=state.atomic_numbers[c - n : c].clone(), + fixed=fixed[c - n : c].clone(), + natoms=n, + pbc=torch.tensor([state.pbc, state.pbc, state.pbc], dtype=torch.bool), + ) + ) + self.data_object = Batch.from_data_list(data_list) if self.dtype is not None: self.data_object.pos = self.data_object.pos.to(self.dtype)