Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/lod #7714

Merged
merged 27 commits into from
Jan 31, 2018
Merged

Fix/lod #7714

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)

cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto)
cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor paddle_memory)
nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor init)

cc_test(variable_test SRCS variable_test.cc)

Expand Down
2 changes: 0 additions & 2 deletions paddle/framework/lod_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ limitations under the License. */
#include <algorithm>
#include <iterator>

#include <glog/logging.h>

namespace paddle {
namespace framework {

Expand Down
26 changes: 14 additions & 12 deletions paddle/framework/lod_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ limitations under the License. */
#ifdef PADDLE_WITH_CUDA
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/system/cuda/experimental/pinned_allocator.h>
#endif

#include <glog/logging.h>
#include "paddle/framework/ddim.h"
#include "paddle/framework/mixed_vector.h"
#include "paddle/framework/tensor.h"
#include "paddle/framework/tensor_util.h"
#include "paddle/platform/enforce.h"
Expand All @@ -31,15 +31,6 @@ limitations under the License. */
namespace paddle {
namespace framework {

#ifndef PADDLE_WITH_CUDA
template <typename T>
using Vector = std::vector<T>;
#else
template <typename T>
using Vector = thrust::host_vector<
T, thrust::system::cuda::experimental::pinned_allocator<T>>;
#endif

/*
* LoD is short for Level of Details.
*
Expand All @@ -55,7 +46,15 @@ using Vector = thrust::host_vector<
* 0 2 4 7
* 0 2 5 7 10 12 15 20
*/
using LoD = std::vector<Vector<size_t>>;
struct LoD : public std::vector<Vector<size_t>> {
using std::vector<Vector<size_t>>::vector;

void CopyFromCUDA() {
for (auto it = this->begin(); it != this->end(); ++it) {
it->CopyFromCUDA();
}
}
};

std::ostream& operator<<(std::ostream& os, const LoD& lod);
std::ostream& operator<<(std::ostream& os, const LoDTensor& t);
Expand Down Expand Up @@ -109,7 +108,10 @@ bool CheckAbsLoD(const LoD& in, int tensor_height = -1);
*/
class LoDTensor : public Tensor {
public:
LoDTensor() {}
LoDTensor() : Tensor() {}

/* Constructor with place should only be used in pybind */
explicit LoDTensor(const platform::Place& place) : Tensor(place) {}

explicit LoDTensor(const LoD& lod) : lod_(lod) {}

Expand Down
11 changes: 11 additions & 0 deletions paddle/framework/lod_tensor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,17 @@
namespace paddle {
namespace framework {

TEST(LoD, data) {
LoD lod{{0, 1, 2}};
lod.push_back({0, 2, 4, 5});
lod.push_back(std::vector<size_t>({0, 1, 6, 8, 10, 11}));

auto& v = lod[0];
for (size_t i = 0; i < v.size(); ++i) {
EXPECT_EQ(v[i], i);
}
}

TEST(LodExpand, test) {
LoD lod{{0, 2}};
LoDTensor tensor;
Expand Down
46 changes: 45 additions & 1 deletion paddle/framework/lod_tensor_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>
#include "paddle/framework/init.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/platform/assert.h"

Expand All @@ -26,7 +28,48 @@ __global__ void test(size_t* a, int size) {
}
}

TEST(Vector, Normal) {
using namespace paddle::framework;
using namespace paddle::platform;
using namespace paddle::memory;

paddle::framework::InitDevices();

paddle::framework::Vector<size_t> vec({1, 2, 3});
size_t* ptr = vec.data();
for (size_t i = 0; i < vec.size(); ++i) {
EXPECT_EQ(vec[i], *(ptr + i));
}

vec.clear();
vec.CopyFromCUDA();

std::vector<size_t> v = {1, 2, 3};
for (size_t i = 0; i < v.size(); ++i) {
EXPECT_EQ(v[i], vec[i]);
}
}

TEST(LoD, data) {
paddle::framework::InitDevices();

paddle::framework::LoD lod{{0, 1, 2}};
lod.push_back({0, 2, 4, 5});
lod.push_back(std::vector<size_t>({0, 1, 6, 8, 10, 11}));

auto& v = lod[0];
test<<<1, 1>>>(v.cuda_data(), v.size());
cudaDeviceSynchronize();

v.CopyFromCUDA();
for (size_t i = 0; i < v.size(); ++i) {
EXPECT_EQ(v[i], i * 2);
}
}

TEST(LoDTensor, LoDInGPU) {
paddle::framework::InitDevices();

paddle::framework::LoDTensor lod_tensor;
paddle::platform::CUDAPlace place(0);

Expand All @@ -42,8 +85,9 @@ TEST(LoDTensor, LoDInGPU) {

auto lod = lod_tensor.lod();

test<<<1, 8>>>(lod[0].data(), lod[0].size());
test<<<1, 8>>>(lod[0].cuda_data(), lod[0].size());
cudaDeviceSynchronize();
lod.CopyFromCUDA();

for (size_t i = 0; i < src_lod[0].size(); ++i) {
EXPECT_EQ(lod[0].data()[i], src_lod[0].data()[i] * 2);
Expand Down
154 changes: 154 additions & 0 deletions paddle/framework/mixed_vector.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <initializer_list>
#include <vector>

#include "paddle/memory/memcpy.h"
#include "paddle/memory/memory.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h"

namespace paddle {
namespace framework {

/**
* @brief Vector support both cpu and gpu.
* host vector lifetime is same with Vector
* device vector is lazily malloc and modified.
*/

template <typename T>
class Vector : public std::vector<T> {
public:
/* NOTE(dzhwinter):
* Data always store and modified on Host.
* If the data is modified when use cuda_data interface,
* You need to call the CopyFromCUDA explicitly to synchronize data.
*
*/
enum class kDataPosition {
kDataOnHost = 0,
kDataOnDevice = 1,
};

public:
using std::vector<T>::vector;

Vector() {}
Vector(const std::vector<T> &v) : std::vector<T>(v) {} // NOLINT

virtual ~Vector() {
#ifdef PADDLE_WITH_CUDA
if (cuda_ptr_ != nullptr) {
memory::Free<platform::CUDAPlace>(place_, static_cast<void *>(cuda_ptr_));
}
#endif
}

T *cuda_data() {
CopyToCUDA();
PADDLE_ENFORCE_NOT_NULL(
cuda_ptr_, "No data or Insufficient CUDA memory to allocation");
return static_cast<T *>(cuda_ptr_);
}

T *data() { return std::vector<T>::data(); }

const T *data() const { return std::vector<T>::data(); }

void CopyToCUDA();

void CopyFromCUDA();

void CopyToPeer(platform::Place);

private:
void *cuda_ptr_ = nullptr;
size_t cuda_size_ = 0;
/*The DataPosition is unused now,
if we want support random access from cpu and cuda,
we need to overload all the vector method */

kDataPosition position_ = kDataPosition::kDataOnHost;
platform::CUDAPlace place_;
};

template <typename T>
void Vector<T>::CopyToCUDA() {
#ifdef PADDLE_WITH_CUDA
if (cuda_ptr_ == nullptr) {
cuda_ptr_ =
memory::Alloc<platform::CUDAPlace>(place_, this->size() * sizeof(T));
}
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto *cuda_ctx = pool.GetByPlace(place_);

memory::Copy(place_, static_cast<void *>(cuda_ptr_), platform::CPUPlace(),
static_cast<const void *>(this->data()),
this->size() * sizeof(T), cuda_ctx->stream());
cuda_ctx->Wait();

cuda_size_ = this->size();
#endif
}

template <typename T>
void Vector<T>::CopyFromCUDA() {
#ifdef PADDLE_WITH_CUDA
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto *cuda_ctx = pool.GetByPlace(place_);
if (cuda_ptr_ == nullptr) {
LOG(WARNING) << "No uncommited cuda data.";
return;
}
this->resize(cuda_size_);
memory::Copy(platform::CPUPlace(), static_cast<void *>(this->data()), place_,
static_cast<const void *>(cuda_ptr_), this->size() * sizeof(T),
cuda_ctx->stream());
cuda_ctx->Wait();

#endif
}

template <typename T>
void Vector<T>::CopyToPeer(platform::Place peer_place) {
if (platform::is_cpu_place(peer_place)) {
return;
}
#ifdef PADDLE_WITH_CUDA
auto *cuda_ctx = platform::DeviceContextPool::Instance().GetByPlace(place_);
void *peer_cuda_ptr_ = memory::Alloc<platform::CUDAPlace>(
boost::get<platform::CUDAPlace>(peer_place), this->size() * sizeof(T));
memory::Copy(boost::get<platform::CUDAPlace>(peer_place),
static_cast<void *>(peer_cuda_ptr_), place_,
static_cast<const void *>(cuda_ptr_), this->size() * sizeof(T),
cuda_ctx->stream());
cuda_ctx->Wait();
memory::Free<platform::CUDAPlace>(place_, static_cast<void *>(cuda_ptr_));
place_ = boost::get<platform::CUDAPlace>(peer_place);
cuda_ptr_ = peer_cuda_ptr_;
#endif
}

template class Vector<int>;
template class Vector<unsigned>;
template class Vector<size_t>;
template class Vector<int64_t>;

} // namespace framework
} // namespace paddle
7 changes: 7 additions & 0 deletions paddle/framework/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ class Tensor {
public:
Tensor() : offset_(0) {}

/*! Constructor with place should only be used in pybind. */
explicit Tensor(const platform::Place& place) : offset_(0) {
holder_->set_place(place);
}

/*! Return a pointer to mutable memory block. */
template <typename T>
inline T* data();
Expand Down Expand Up @@ -137,6 +142,7 @@ class Tensor {
virtual std::type_index type() const = 0;
virtual platform::Place place() const = 0;
virtual void set_type(std::type_index type) = 0;
virtual void set_place(platform::Place place) = 0;
};

template <typename Place>
Expand All @@ -156,6 +162,7 @@ class Tensor {
virtual void* ptr() const { return static_cast<void*>(ptr_.get()); }
virtual std::type_index type() const { return type_; }
virtual void set_type(std::type_index type) { type_ = type; }
virtual void set_place(platform::Place place) { place_ = place; }

/*! the pointer of memory block. */
std::unique_ptr<uint8_t, memory::PODDeleter<uint8_t, Place>> ptr_;
Expand Down
2 changes: 1 addition & 1 deletion paddle/inference/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
set(FLUID_CORE_MODULES proto_desc paddle_memory executor prune init)
set(FLUID_CORE_MODULES proto_desc paddle_memory lod_tensor executor prune init)

cc_library(paddle_fluid_api
SRCS io.cc
Expand Down
6 changes: 3 additions & 3 deletions paddle/operators/adagrad_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ struct SparseAdagradFunctor<platform::CUDADeviceContext, T> {
math::scatter::MergeAdd<platform::CUDADeviceContext, T> merge_func;
auto grad_merge = merge_func(context, grad);
auto* grad_merge_data = grad_merge.mutable_value()->template data<T>();
auto& merge_rows = grad_merge.rows();
framework::Vector<int64_t> merge_rows(grad_merge.rows());
// 2. m += g_m * g_m
math::scatter::Mul<platform::CUDADeviceContext, T> sqare_func;
auto grad_square = sqare_func(context, grad_merge, grad_merge);
Expand All @@ -101,8 +101,8 @@ struct SparseAdagradFunctor<platform::CUDADeviceContext, T> {
SparseAdagradFunctorKernel<
T, 256><<<grid2, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(grad_merge_data, grad_merge.rows().data(),
lr, param_data, moment_data, grad_width,
.stream()>>>(grad_merge_data, merge_rows.cuda_data(), lr,
param_data, moment_data, grad_width,
epsilon);
}
};
Expand Down
7 changes: 6 additions & 1 deletion paddle/operators/adam_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,12 @@ class AdamOpKernel : public framework::OpKernel<T> {
merge_func(ctx.template device_context<DeviceContext>(), grad);
auto& grad_tensor = grad_merge.value();
const T* grad_data = grad_tensor.template data<T>();
auto* rows = grad_merge.rows().data();
int64_t* rows = nullptr;
if (platform::is_gpu_place(ctx.GetPlace())) {
rows = grad_merge.mutable_rows()->cuda_data();
} else {
rows = grad_merge.mutable_rows()->data();
}
auto row_numel = grad_tensor.numel() / grad_merge.rows().size();

SparseAdamFunctor<T> functor(
Expand Down
Loading