Skip to content

Commit

Permalink
from elementSizeInBytes to element_size, following upstream commit ht… (
Browse files Browse the repository at this point in the history
horovod#919)

* from elementSizeInBytes to element_size, following upstream commit pytorch/pytorch#17785

Signed-off-by: labor00 <abrvb@outlook.com>

* uses TORCH_VERSION macro to ensure backward compatibility

Signed-off-by: labor00 <abrvb@outlook.com>

* if pytorch version string contains dev return a very big number

Signed-off-by: labor00 <abrvb@outlook.com>
Signed-off-by: Yana Shchyokotova <yana.shchyokotova@intel.com>
  • Loading branch information
labor00 authored and shirosankaku committed May 30, 2019
1 parent bdc7bf0 commit 5a05d6a
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 0 deletions.
4 changes: 4 additions & 0 deletions horovod/torch/adapter_v2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,11 @@ const TensorShape TorchTensor::shape() const {
const void* TorchTensor::data() const { return tensor_.data_ptr(); }

int64_t TorchTensor::size() const {
# if TORCH_VERSION >= 1001000000
return tensor_.element_size() * tensor_.numel();
#else
return tensor_.type().elementSizeInBytes() * tensor_.numel();
#endif
}

TorchOpContext::TorchOpContext(int device, ::torch::Tensor output)
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,8 @@ def build_tf_extension(build_ext, options):


def parse_version(version_str):
if "dev" in version_str:
return 9999999999
m = re.match('^(\d+)(?:\.(\d+))?(?:\.(\d+))?(?:\.(\d+))?', version_str)
if m is None:
return None
Expand Down

0 comments on commit 5a05d6a

Please sign in to comment.