Skip to content

Commit

Permalink
Added Torch::Backends::MPS.available? method
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Apr 13, 2023
1 parent d8230d7 commit b0174ec
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 2 deletions.
11 changes: 9 additions & 2 deletions examples/mnist/main.rb
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,15 @@ def test(model, device, test_loader)

Torch.manual_seed(seed)

use_cuda = Torch::CUDA.available?
device = Torch.device(use_cuda ? "cuda" : "cpu")
device_type =
if Torch::CUDA.available?
"cuda"
elsif Torch::Backends::MPS.available?
"mps"
else
"cpu"
end
device = Torch.device(device_type)
puts "Device type: #{device.type}"

root = File.join(__dir__, "data")
Expand Down
4 changes: 4 additions & 0 deletions ext/torch/backends.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,8 @@ void init_backends(Rice::Module& m) {
Rice::define_module_under(rb_mBackends, "MKL")
.add_handler<torch::Error>(handle_error)
.define_singleton_function("available?", &torch::hasMKL);

Rice::define_module_under(rb_mBackends, "MPS")
.add_handler<torch::Error>(handle_error)
.define_singleton_function("available?", &torch::hasMPS);
}
4 changes: 4 additions & 0 deletions test/backends_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,8 @@ def test_openmp
def test_mkl
Torch::Backends::MKL.available?
end

def test_mps
Torch::Backends::MPS.available?
end
end

0 comments on commit b0174ec

Please sign in to comment.