diff --git a/torch_sim/models/sevennet.py b/torch_sim/models/sevennet.py index d21449d4..82353ec4 100644 --- a/torch_sim/models/sevennet.py +++ b/torch_sim/models/sevennet.py @@ -77,12 +77,13 @@ def __init__( neighbor_list_fn (Callable): Neighbor list function to use. Default is vesin_nl_ts. device (torch.device | str | None): Device to run the model on - dtype (torch.dtype | None): Data type for computation + dtype (torch.dtype): Data type for computation Raises: ValueError: the model doesn't have a cutoff ValueError: the model has a modal_map but modal is not given ValueError: the modal given is not in the modal_map + ValueError: the model doesn't have a type_map """ super().__init__() @@ -92,7 +93,7 @@ def __init__( if isinstance(self._device, str): self._device = torch.device(self._device) - if torch.dtype is not torch.float32: + if dtype is not torch.float32: warnings.warn( "SevenNetModel currently only supports" "float32, but received different dtype", @@ -100,6 +101,10 @@ def __init__( stacklevel=2, ) + if not model.type_map: + raise ValueError("type_map is missing") + model.eval_type_map = torch.tensor(data=True) + self._dtype = dtype self._memory_scales_with = "n_atoms_x_density" self._compute_stress = True @@ -108,9 +113,6 @@ def __init__( if model.cutoff == 0.0: raise ValueError("Model cutoff seems not initialized") - model.eval_type_map = torch.tensor( - data=True, - ) # TODO: from sevenn not sure if needed model.set_is_batch_data(True) model_loaded = model self.cutoff = torch.tensor(model.cutoff)