
Don't modify torch default dtype #328

Open
janosh opened this issue Feb 24, 2024 · 1 comment
janosh commented Feb 24, 2024

The line below prevents running other models for relaxation after MACE in the same Python session, since MACE recommends float64 for geometry optimization while e.g. chgnet and m3gnet use float32:

torch_tools.set_default_dtype(default_dtype)

The resulting error messages are not helpful, so users will likely spend time troubleshooting this issue when they encounter it. The only current workaround is to manually reset the default dtype to float32 with

torch.set_default_dtype(torch.float32)

after every MACE call.

Suggested fix

Only convert the model's inputs to the model's dtype instead of modifying the default dtype of all float tensors everywhere.
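
As a minimal sketch of this suggestion (the function and attribute names below are hypothetical, not part of the MACE API), the calculator could cast only its own input tensors to the model's dtype and leave torch.get_default_dtype() untouched:

import torch

def cast_inputs_to_model_dtype(batch: dict, model_dtype: torch.dtype) -> dict:
    # Cast only floating-point input tensors to the model's dtype,
    # leaving the global torch default dtype unchanged.
    return {
        key: val.to(model_dtype) if torch.is_tensor(val) and val.is_floating_point() else val
        for key, val in batch.items()
    }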

Minimal example

import torch
from ase.build import bulk
from mace.calculators import mace_mp

orig_dtype = torch.get_default_dtype()
print(f"{orig_dtype=}")
>>> orig_dtype=torch.float32

atoms = bulk("Cu") * (2, 2, 2)
atoms.calc = mace_mp(default_dtype="float64")
atoms.get_potential_energy()

new_dtype = torch.get_default_dtype()
print(f"{new_dtype=}")
>>> new_dtype=torch.float64
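
For completeness, the manual workaround from above applied at the end of this example restores the session to its original state:

# manual workaround: restore the original default before constructing a float32 model
torch.set_default_dtype(orig_dtype)
print(f"{torch.get_default_dtype()=}")
>>> torch.get_default_dtype()=torch.float32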
hatemhelal commented:

FWIW you can use a context manager to run different models with different torch default dtypes. There is an implementation in PR #310, commit: 80211fd
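
For illustration, a generic dtype context manager along those lines (a sketch, not the implementation from PR #310) could look like:

import contextlib
import torch

@contextlib.contextmanager
def default_dtype(dtype: torch.dtype):
    # Temporarily set the torch default dtype and restore the previous value on exit.
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(previous)

# usage: run a float64 model without leaking the dtype change into the rest of the session
with default_dtype(torch.float64):
    ...  # build and evaluate the float64 model here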
