From e2e0fbd4188fcbcc6bf69d1ef22b3f6f0a927f84 Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Mon, 26 Jun 2017 10:36:49 -0700 Subject: [PATCH 1/3] Add tesnor.h --- paddle/framework/tensor.h | 91 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 paddle/framework/tensor.h diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h new file mode 100644 index 0000000000000..a658537430e27 --- /dev/null +++ b/paddle/framework/tensor.h @@ -0,0 +1,91 @@ +/* + Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ +#pragma once + +namespace paddle { +namespace framework { + +class Tensor { + using paddle::platform::Place; + using paddle::platform::get_place; + + public: + explicit Tensor(DDim dims) : dims_(dims), place_(get_place()) {} + explicit Tensor(DDim dims, Place place) : dims_(dims), place_(place) {} + + template + const T* data() const { + PADDLE_ASSERT(holder_ != nullptr); + PADDLE_ASSERT(holder_->Place() == place_); + PADDLE_ASSERT(holder_->Size() >= dims_.product() * sizeof(T)); + return static_cast(holder->Ptr()); + } + + template ::value>::type> + T* mutable_data() { + if (holder_ == nullptr || holder_->Place() != place_ || + holder_->Size() < dims_.product() * sizeof(T)) { + holder_.reset(new PlaceholderImpl(place_, dims.product() * sizeof(T))); + } + return static_cast(holder_->Ptr()); + } + + template ::value>::type> + T* mutable_data(DDim dims) { + dims_ = dims; + return mutable_data(); + } + + template ::value>::type> + T* mutable_data(DDim dims, Place place) { + dims_ = dims; + place_ = place; + return mutable_data(); + } + + private: + // Placeholder hides type T, so it doesn't appear as a template + // parameter of Variable. + struct Placeholder { + virtual ~Placeholder() {} + virtual void* Ptr() const = 0; + virtual Place Place() const = 0; + virtual size_t Size() const = 0; + }; + + template + struct PlaceholderImpl : public Placeholder { + PlaceholderImpl(Place pl, size_t size) + : ptr_(memory::Alloc(pl, size), paddle::memory::Deleter(pl)), + place_(pl), + size_(size) {} + + virtual void* Ptr() const { return static_cast(ptr_.get()); } + virtual size_t Size() const { return size_; } + virtual Place Place() const { return place_; } + + std::unique_ptr ptr_; + Place place_; // record the place of ptr_. + size_t size_; // size of the memory block. + }; + + std::unique_ptr holder_; // holds the memory block if allocated. + DDim dims_; // could be smallers than the holder_->Size(). + paddle::platform::Place place_; +}; + +} // namespace framework +} // namespace paddle From 2c188a20de53741e6f965738636eb7d6f797a821 Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Tue, 27 Jun 2017 12:00:41 -0700 Subject: [PATCH 2/3] Follow QingQing's suggestion --- paddle/framework/tensor.h | 31 ++++++++----------------------- 1 file changed, 8 insertions(+), 23 deletions(-) diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index a658537430e27..8962b76a12cdb 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -20,23 +20,19 @@ class Tensor { using paddle::platform::get_place; public: - explicit Tensor(DDim dims) : dims_(dims), place_(get_place()) {} - explicit Tensor(DDim dims, Place place) : dims_(dims), place_(place) {} - template const T* data() const { - PADDLE_ASSERT(holder_ != nullptr); - PADDLE_ASSERT(holder_->Place() == place_); - PADDLE_ASSERT(holder_->Size() >= dims_.product() * sizeof(T)); + PADDLE_ASSERT(holder_ != nullptr, + "Tensor::data must be called after Tensor::mutable_data"); return static_cast(holder->Ptr()); } template ::value>::type> - T* mutable_data() { - if (holder_ == nullptr || holder_->Place() != place_ || - holder_->Size() < dims_.product() * sizeof(T)) { - holder_.reset(new PlaceholderImpl(place_, dims.product() * sizeof(T))); + T* mutable_data(DDim dims, Place place) { + if (holder_ == nullptr || holder_->Place() != place || + holder_->Size() < dims.product() * sizeof(T)) { + holder_.reset(new PlaceholderImpl(place, dims.product() * sizeof(T))); } return static_cast(holder_->Ptr()); } @@ -44,16 +40,7 @@ class Tensor { template ::value>::type> T* mutable_data(DDim dims) { - dims_ = dims; - return mutable_data(); - } - - template ::value>::type> - T* mutable_data(DDim dims, Place place) { - dims_ = dims; - place_ = place; - return mutable_data(); + return mutable_data(dims, paddle::platform::get_place()); } private: @@ -69,7 +56,7 @@ class Tensor { template struct PlaceholderImpl : public Placeholder { PlaceholderImpl(Place pl, size_t size) - : ptr_(memory::Alloc(pl, size), paddle::memory::Deleter(pl)), + : ptr_(paddle::memory::Alloc(pl, size), paddle::memory::Deleter(pl)), place_(pl), size_(size) {} @@ -83,8 +70,6 @@ class Tensor { }; std::unique_ptr holder_; // holds the memory block if allocated. - DDim dims_; // could be smallers than the holder_->Size(). - paddle::platform::Place place_; }; } // namespace framework From c263c21f7e0feebca20ab33cd606330de81e9aee Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Tue, 27 Jun 2017 17:35:27 -0700 Subject: [PATCH 3/3] Update copyright informaiton --- paddle/framework/tensor.h | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 8962b76a12cdb..067f2a85264b4 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -1,15 +1,17 @@ -/* - Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + #pragma once namespace paddle {