Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions torch_sim/models/sevennet.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,13 @@ def __init__(
neighbor_list_fn (Callable): Neighbor list function to use.
Default is vesin_nl_ts.
device (torch.device | str | None): Device to run the model on
dtype (torch.dtype | None): Data type for computation
dtype (torch.dtype): Data type for computation

Raises:
ValueError: the model doesn't have a cutoff
ValueError: the model has a modal_map but modal is not given
ValueError: the modal given is not in the modal_map
ValueError: the model doesn't have a type_map
"""
super().__init__()

Expand All @@ -92,14 +93,18 @@ def __init__(
if isinstance(self._device, str):
self._device = torch.device(self._device)

if torch.dtype is not torch.float32:
if dtype is not torch.float32:
warnings.warn(
"SevenNetModel currently only supports"
"float32, but received different dtype",
UserWarning,
stacklevel=2,
)

if not model.type_map:
raise ValueError("type_map is missing")
model.eval_type_map = torch.tensor(data=True)

self._dtype = dtype
self._memory_scales_with = "n_atoms_x_density"
self._compute_stress = True
Expand All @@ -108,9 +113,6 @@ def __init__(
if model.cutoff == 0.0:
raise ValueError("Model cutoff seems not initialized")

model.eval_type_map = torch.tensor(
data=True,
) # TODO: from sevenn not sure if needed
model.set_is_batch_data(True)
model_loaded = model
self.cutoff = torch.tensor(model.cutoff)
Expand Down