Skip to content

Commit

Permalink
Fixed error with inspect for MPS tensors - fixes #50
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Feb 22, 2024
1 parent 633b10c commit f5a12ed
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
## 0.15.0 (unreleased)

- Updated LibTorch to 2.2.0
- Fixed error with `inspect` for MPS tensors

## 0.14.1 (2023-12-26)

Expand Down
1 change: 1 addition & 0 deletions ext/torch/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions

rb_cTensor
.define_method("cuda?", [](Tensor& self) { return self.is_cuda(); })
.define_method("mps?", [](Tensor& self) { return self.is_mps(); })
.define_method("sparse?", [](Tensor& self) { return self.is_sparse(); })
.define_method("quantized?", [](Tensor& self) { return self.is_quantized(); })
.define_method("dim", [](Tensor& self) { return self.dim(); })
Expand Down
11 changes: 8 additions & 3 deletions lib/torch/inspector.rb
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ def initialize(tensor)
return if nonzero_finite_vals.numel == 0

# Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU.
nonzero_finite_abs = nonzero_finite_vals.abs.double
nonzero_finite_min = nonzero_finite_abs.min.double
nonzero_finite_max = nonzero_finite_abs.max.double
nonzero_finite_abs = tensor_totype(nonzero_finite_vals.abs)
nonzero_finite_min = tensor_totype(nonzero_finite_abs.min)
nonzero_finite_max = tensor_totype(nonzero_finite_abs.max)

nonzero_finite_vals.each do |value|
if value.item != value.item.ceil
Expand Down Expand Up @@ -107,6 +107,11 @@ def format(value)
# Ruby throws error when negative, Python doesn't
" " * [@max_width - ret.size, 0].max + ret
end

def tensor_totype(t)
dtype = t.mps? ? :float : :double
t.to(dtype: dtype)
end
end

def inspect
Expand Down

0 comments on commit f5a12ed

Please sign in to comment.