Skip to content

Commit

Permalink
CHGNet now supports Apple MPS:tada:, pytorch-MPS has been tested with…
Browse files Browse the repository at this point in the history
… torch-2.0.1
  • Loading branch information
BowenD-UCB committed Sep 25, 2023
1 parent 2ebc57f commit 516c422
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 15 deletions.
15 changes: 6 additions & 9 deletions chgnet/model/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,16 +79,13 @@ def __init__(
"""
super().__init__(**kwargs)

# mps is disabled before stable version of pytorch on apple mps is released
if use_device == "mps":
raise NotImplementedError("'mps' backend is not supported yet")
# elif torch.backends.mps.is_available():
# self.device = 'mps'

# Determine the device to use
self.device = use_device or ("cuda" if torch.cuda.is_available() else "cpu")
if self.device == "cuda":
self.device = f"cuda:{cuda_devices_sorted_by_free_mem()[-1]}"
if use_device == "mps" and torch.backends.mps.is_available():
self.device = "mps"
else:
self.device = use_device or ("cuda" if torch.cuda.is_available() else "cpu")
if self.device == "cuda":
self.device = f"cuda:{cuda_devices_sorted_by_free_mem()[-1]}"

# Move the model to the specified device
self.model = (model or CHGNet.load()).to(self.device)
Expand Down
8 changes: 2 additions & 6 deletions tests/test_relaxation.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,5 @@ def test_relaxation(algorithm: Literal["legacy", "fast"]):
"use_device", ["cpu", param("cuda", marks=no_cuda), param("mps", marks=no_mps)]
)
def test_structure_optimizer_passes_kwargs_to_model(use_device) -> None:
try:
relaxer = StructOptimizer(use_device=use_device)
assert re.match(rf"{use_device}(:\d+)?", relaxer.calculator.device)
except NotImplementedError as exc:
# TODO: remove try/except once mps is supported
assert str(exc) == "'mps' backend is not supported yet" # noqa: PT017
relaxer = StructOptimizer(use_device=use_device)
assert re.match(rf"{use_device}(:\d+)?", relaxer.calculator.device)

0 comments on commit 516c422

Please sign in to comment.