From cb76265f3ed13a49008a8f2fe60d2344c5da658c Mon Sep 17 00:00:00 2001 From: YutackPark Date: Wed, 23 Apr 2025 10:43:15 +0900 Subject: [PATCH 1/2] fix wrong dtype compare --- torch_sim/models/sevennet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_sim/models/sevennet.py b/torch_sim/models/sevennet.py index d21449d4..456313f2 100644 --- a/torch_sim/models/sevennet.py +++ b/torch_sim/models/sevennet.py @@ -77,7 +77,7 @@ def __init__( neighbor_list_fn (Callable): Neighbor list function to use. Default is vesin_nl_ts. device (torch.device | str | None): Device to run the model on - dtype (torch.dtype | None): Data type for computation + dtype (torch.dtype): Data type for computation Raises: ValueError: the model doesn't have a cutoff @@ -92,7 +92,7 @@ def __init__( if isinstance(self._device, str): self._device = torch.device(self._device) - if torch.dtype is not torch.float32: + if dtype is not torch.float32: warnings.warn( "SevenNetModel currently only supports" "float32, but received different dtype", From 13780498651e7a19eb05b7c220bd21e0903c7e34 Mon Sep 17 00:00:00 2001 From: YutackPark Date: Wed, 23 Apr 2025 10:51:06 +0900 Subject: [PATCH 2/2] make sure sevennet type_map exists --- torch_sim/models/sevennet.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torch_sim/models/sevennet.py b/torch_sim/models/sevennet.py index 456313f2..82353ec4 100644 --- a/torch_sim/models/sevennet.py +++ b/torch_sim/models/sevennet.py @@ -83,6 +83,7 @@ def __init__( ValueError: the model doesn't have a cutoff ValueError: the model has a modal_map but modal is not given ValueError: the modal given is not in the modal_map + ValueError: the model doesn't have a type_map """ super().__init__() @@ -100,6 +101,10 @@ def __init__( stacklevel=2, ) + if not model.type_map: + raise ValueError("type_map is missing") + model.eval_type_map = torch.tensor(data=True) + self._dtype = dtype self._memory_scales_with = "n_atoms_x_density" self._compute_stress = True @@ -108,9 +113,6 @@ def __init__( if model.cutoff == 0.0: raise ValueError("Model cutoff seems not initialized") - model.eval_type_map = torch.tensor( - data=True, - ) # TODO: from sevenn not sure if needed model.set_is_batch_data(True) model_loaded = model self.cutoff = torch.tensor(model.cutoff)