
SevenNetModel does not work with torch.float64 #92

@orionarcher

Description

`SevenNetModel` currently fails with a dtype-mismatch `RuntimeError` when `torch.float64` is used.
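
The `Cell In[7]` and `Cell In[6]` tracebacks end in e3nn's fully connected layer computing `x @ w`. That failure can be reproduced in plain PyTorch, independent of SevenNet. This is a minimal sketch, assuming the layer input arrives as `float64` while the layer weight is still `float32`; the tensor shapes are arbitrary stand-ins.

```python
import torch

# Stand-ins for the layer input and weight; shapes are arbitrary.
x = torch.randn(4, 8, dtype=torch.float64)   # e.g. features derived from a float64 state
w = torch.randn(8, 16, dtype=torch.float32)  # weight left in the default float32

mismatch_error = None
try:
    x @ w  # matmul does not promote mixed dtypes; it raises instead
except RuntimeError as err:
    mismatch_error = str(err)
    print(mismatch_error)

# Casting the weight to the input's dtype makes the product well defined.
y = x @ w.to(x.dtype)
print(y.dtype)  # torch.float64
```
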

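Casting the loaded model to `torch.float64` and also passing `dtype=torch.float64` to `SevenNetModel`:

---------------------------------------------------------------------------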
RuntimeError                              Traceback (most recent call last)
Cell In[8], line 20
     12 pretrained_sevenn_model = model_loaded.to(device, dtype=torch.float64)
     14 model = SevenNetModel(
     15     model=pretrained_sevenn_model,
     16     modal="omat24",
     17     device=device,
     18     dtype=torch.float64,
     19 )
---> 20 model(state)
     22 # max_scaler = ts.autobatching.estimate_max_memory_scaler(
     23 #     model=model,
     24 #     state_list=[ts.initialize_state(atoms, dtype=torch.float64, device=device)],
     25 #     metric_values=[1],
     26 #     max_atoms=1000000,
     27 # )

File /workspaces/propfoliotorchsim/propfolio/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File /workspaces/propfoliotorchsim/propfolio/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
...
               ~~~~~~~~~~~~~ <--- HERE
    else:
        return _VF.tensordot(a, b, dims_a, dims_b, out=out)  # type: ignore[attr-defined]
RuntimeError: both inputs should have same dtype
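
The `Cell In[8]` traceback instead ends inside `torch.tensordot`, which likewise rejects operands of different dtypes. The e3nn call that reaches `tensordot` is elided in the truncated output, so this minimal sketch only illustrates the error and one way around it (promoting both operands to a common dtype), not the exact call site. The shapes are arbitrary.

```python
import torch

a = torch.randn(3, 5, dtype=torch.float64)
b = torch.randn(5, 2, dtype=torch.float32)

tensordot_error = None
try:
    # Contract the last dim of `a` with the first dim of `b`.
    torch.tensordot(a, b, dims=1)
except RuntimeError as err:
    tensordot_error = str(err)

# Promoting both operands to a common dtype first avoids the error.
common = torch.promote_types(a.dtype, b.dtype)  # torch.float64 here
out = torch.tensordot(a.to(common), b.to(common), dims=1)
print(tensordot_error)
print(out.dtype, tuple(out.shape))
```
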
Casting the loaded model to `torch.float64` without passing `dtype` to `SevenNetModel`:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[7], line 19
     12 pretrained_sevenn_model = model_loaded.to(device, dtype=torch.float64)
     14 model = SevenNetModel(
     15     model=pretrained_sevenn_model,
     16     modal="omat24",
     17     device=device,
     18 )
---> 19 model(state)
     21 # max_scaler = ts.autobatching.estimate_max_memory_scaler(
     22 #     model=model,
     23 #     state_list=[ts.initialize_state(atoms, dtype=torch.float64, device=device)],
     24 #     metric_values=[1],
     25 #     max_atoms=1000000,
     26 # )

File /workspaces/propfoliotorchsim/propfolio/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File /workspaces/propfoliotorchsim/propfolio/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
...
---> 42     x = x @ w
     43     x = self.act(x)
     44     x = x * self.var_out**0.5

RuntimeError: expected mat1 and mat2 to have the same dtype, but got: double != float
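Leaving the loaded model in its default dtype (`model_loaded.to(device)`) and not passing `dtype` to `SevenNetModel`. The full traceback shows the mismatch arising in the e3nn fully connected layer that `IrrepsConvolution` uses as its `weight_nn`:

---------------------------------------------------------------------------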
RuntimeError                              Traceback (most recent call last)
Cell In[6], line 19
     12 pretrained_sevenn_model = model_loaded.to(device)
     14 model = SevenNetModel(
     15     model=pretrained_sevenn_model,
     16     modal="omat24",
     17     device=device,
     18 )
---> 19 model(state)
     21 # max_scaler = ts.autobatching.estimate_max_memory_scaler(
     22 #     model=model,
     23 #     state_list=[ts.initialize_state(atoms, dtype=torch.float64, device=device)],
     24 #     metric_values=[1],
     25 #     max_atoms=1000000,
     26 # )

File /workspaces/propfoliotorchsim/propfolio/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File /workspaces/propfoliotorchsim/propfolio/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()

File /workspaces/propfoliotorchsim/torch-sim/build/__editable__.torch_sim-0.0.0rc0-py3-none-any/torch_sim/models/sevennet.py:226, in SevenNetModel.forward(self, state)
    223     batched_data = batched_data.to_dict()
    224     del batched_data["data_info"]
--> 226 output = self.model(batched_data)
    228 results = {}
    229 # Process energy

File /workspaces/propfoliotorchsim/propfolio/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File /workspaces/propfoliotorchsim/propfolio/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()

File /workspaces/propfoliotorchsim/propfolio/.venv/lib/python3.12/site-packages/sevenn/nn/sequential.py:182, in AtomGraphSequential.forward(self, input)
    180 data = self._preprocess(input)
    181 for module in self:
--> 182     data = module(data)
    183 return data

File /workspaces/propfoliotorchsim/propfolio/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File /workspaces/propfoliotorchsim/propfolio/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()

File /workspaces/propfoliotorchsim/propfolio/.venv/lib/python3.12/site-packages/sevenn/nn/convolution.py:125, in IrrepsConvolution.forward(self, data)
    123 assert self.convolution is not None, 'Convolution is not instantiated'
    124 assert self.weight_nn is not None, 'Weight_nn is not instantiated'
--> 125 weight = self.weight_nn(data[self.key_weight_input])
    126 x = data[self.key_x]
    127 if self.is_parallel:

File /workspaces/propfoliotorchsim/propfolio/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File /workspaces/propfoliotorchsim/propfolio/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()

File /workspaces/propfoliotorchsim/propfolio/.venv/lib/python3.12/site-packages/torch/nn/modules/container.py:250, in Sequential.forward(self, input)
    248 def forward(self, input):
    249     for module in self:
--> 250         input = module(input)
    251     return input

File /workspaces/propfoliotorchsim/propfolio/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File /workspaces/propfoliotorchsim/propfolio/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()

File /workspaces/propfoliotorchsim/propfolio/.venv/lib/python3.12/site-packages/e3nn/nn/_fc.py:42, in _Layer.forward(self, x)
     40 if self.act is not None:
     41     w = self.weight / (self.h_in * self.var_in) ** 0.5
---> 42     x = x @ w
     43     x = self.act(x)
     44     x = x * self.var_out**0.5

RuntimeError: expected mat1 and mat2 to have the same dtype, but got: double != float
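
In `Cell In[7]` the model was cast with `model_loaded.to(device, dtype=torch.float64)`, yet the error still reports `double != float`. Before calling `model(state)`, it can help to count which floating-point dtypes actually appear in the module's parameters and buffers. This is a minimal sketch with a hypothetical helper, `floating_dtypes`, demonstrated on a toy `nn.Sequential` rather than a real SevenNet model. It only inspects registered parameters and buffers, so tensors created inside `forward` or stored as plain attributes are not seen.

```python
from collections import Counter

import torch
from torch import nn


def floating_dtypes(module: nn.Module) -> Counter:
    """Count floating-point dtypes across a module's parameters and buffers.

    Hypothetical debugging helper: tensors that are not registered as
    parameters or buffers (e.g. created inside forward) are not counted.
    """
    counts: Counter = Counter()
    for tensor in list(module.parameters()) + list(module.buffers()):
        if tensor.is_floating_point():
            counts[tensor.dtype] += 1
    return counts


# Toy stand-in for a loaded model; a real SevenNet model would go here.
toy = nn.Sequential(nn.Linear(8, 16), nn.SiLU(), nn.Linear(16, 1))
toy.register_buffer("scale", torch.ones(1))  # float32 buffer by default

print(floating_dtypes(toy))  # every entry float32 before casting
toy = toy.to(dtype=torch.float64)
print(floating_dtypes(toy))  # every entry float64 after casting
```

If a real model reports only `torch.float64` and the error persists, the mismatched operand may be created at call time rather than stored on the module; that is a lead to check, not a confirmed diagnosis.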

Metadata

Assignees

No one assigned

Labels

bug (Something isn't working)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
