Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Memory CPU allocator #2596

Merged
merged 23 commits into from
Jun 28, 2017
Merged

Conversation

wangkuiyi
Copy link
Collaborator

@wangkuiyi wangkuiyi commented Jun 26, 2017

This PR is a successor of #2552. Please review #2552 before this.

@@ -78,6 +78,10 @@
#
# cc_test(example_test SRCS example_test.cc DEPS example glog gflags)

if(WITH_GPU)
add_definitions(-DPADDLE_WITH_GPU)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why we need a new flag PADDLE_WITH_GPU? I think it's duplicate

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean PADDLE_ONLY_CPU? OK. switch to use it.

@@ -0,0 +1,5 @@
if(${WITH_GPU})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's unnecessary to wrap nv_test in if(${WITH_GPU}), because the internal implementation of nv_test already equipped this.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nv_test was designed to handle *.cu files. But here it is a .cc file, because it doesn't contain CUDA code. However, here the source code calls cudaMallocHost defined in CUDA libraries, but we don't have a external/cuda.cmake.

I think a complete solution here should be a single line:

cc_library(cpu_allocator_test SRCS cpu_allocator_test.cc DEPS cuda cudart)

Actually, I tried and succeeded to add cmake/external/cuda.cmake that defines two CMake targets -- cuda and cudart, but the building of cpu_allocator_test target complains that it cannot find libpthread.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will try this.

public:
void* Alloc(size_t size) {
void* p;
if (cudaMallocHost(&p, size) != cudaSuccess) {
Copy link
Contributor

@gangliao gangliao Jun 26, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why use cudaMallocHost(&p, size) in here? It should be mlock. cudaMallocHost should be used in class GPUAllocator<true>. Because cudaMallocHost's p can be accessed directly by GPU device.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think cudaMallocHost allocates CPU memory, other than GPU memory, so I used it here. Am I wrong?

According to the documentation, cudaMallocHost is more than malloc + mlock, it actually makes the CUDA driver tracks OS paging to make sure that cudaMemcpy works in an efficient way with the allocate memory block.

// the CUDA memory space and accessed by the device rapidly. Don't
// allocate too much staging memory; otherwise system performance will
// degrade because the OS cannot find enough swap memory space.
void* AllocStaging(CPUPlace, size_t);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why we have to expose pinned memory interfaces? I don't think the developer could explicitly invoke them, that will be easily out of control.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am just not sure how the pinned interface would be called, so I just exposed it.

What do you think -- how could we know if GPUAllocator::Alloc should return a malloc-ed block or a pinned block?

@@ -97,6 +97,7 @@ class BuddyAllocator {
struct Block {
size_t size;
Block* left, right;
size_t index; // allocator id
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

different allocation has different malloc and free methods.

// between host and device. Allocates too much would reduce the
// amount of memory available to the system for paging. So, by
// default, we should use CPUAllocator<staging=false>.
template <bool staging>
class CPUAllocator {
public:
void* Alloc(size_t size);
void Free(void* p);
void Free(void* p, size_t size);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need parameter size for unmlock

return nullptr;
}

void Free(void* p, size_t size) {
Copy link
Contributor

@gangliao gangliao Jun 26, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the reason why we still keep size in here is that we need to count the allocated / released size of pinned memory for performance protection.

WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This source file is not a final, just first draft. We need to add more features like remaining size, total size and so on.

@wangkuiyi
Copy link
Collaborator Author

wangkuiyi commented Jun 26, 2017

@gangliao I just noticed that as long as we want to call paddle::platform::get_place, we'd have to have typedef boost::variant<CPUPlace, GPUPlace> Place; otherwise, we cannot give get_place a return type. This seems that we cannot remove the dependency to boost.

@wangkuiyi
Copy link
Collaborator Author

@gangliao @reyoung Another point I noticed when I was working on #2611 is that we need a deleter functor class for unique_ptr. Please find the detail here.

@wangkuiyi
Copy link
Collaborator Author

#include <stdlib.h>
#include <assert.h>
#include <memory>
#include <iostream>

template <bool locking>
void* CPUAlloc(size_t size) {
  std::cout << "malloc" << std::endl;
  return std::malloc(size);
}

template <bool locking>
void CPUFree(void* p, size_t size, boo locking) {
  std::cout << "free" << std::endl;
  if (locking) {
    munlock(p, size);
  }
  std::free(p);
}


struct CPUDeleter {
  CPUDeleter(void* p, size_t size, bool locking) :
      p_(p), size_(size), locking_(locking) {}

  template <typename T>
  void operator()(T* p) {
    assert(static_cast<T*>(p_) == p);
    std::cout << "Deleter" << std::endl;
    CPUFree(p_, size_, locking_);
  }

  void* p_;
  size_t size_;
  bool locking_;
};

int main() {
  void* p = CPUAlloc<false>(1024); // GPUAllocator<false>::Alloc
  int* i = static_cast<int*>(p);
  std::unique_ptr<int, CPUDeleter> ptr(i, CPUDeleter(p, 1024, false));
  std::cout << ptr.get() << std::endl;
}

p = Allocator::Alloc(1024);

int* i = static_cast<int*>(p);
std::shared_ptr<int> ptr(i, [](int* p) { Allocator::Free(p, 1024); });
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wangkuiyi we can use this method to replace deleter

Copy link
Collaborator Author

@wangkuiyi wangkuiyi Jun 27, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a good idea! Just would a lambda be too lengthy for the callers of Alloc?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, maybe.

But you can name it as follows:

auto deleter =  [](int* p) { Allocator::Free(p, 1024); }

int* i = static_cast<int*>(p);
std::shared_ptr<int> ptr(i, deleter);

@gangliao
Copy link
Contributor

@wangkuiyi Shall we merge this pull request? Starting a new one. Because its title is Memory CPU allocator, we already did this.

@wangkuiyi
Copy link
Collaborator Author

Sure. It's just that I am the ower of this PR so I cannot approve and merge it by myself. @gangliao

Copy link
Contributor

@helinwang helinwang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

@wangkuiyi wangkuiyi merged commit 2d840ea into PaddlePaddle:develop Jun 28, 2017
@wangkuiyi wangkuiyi deleted the memory_cpu_allocator branch June 28, 2017 21:43
@gangliao gangliao moved this from Doing to Done in PaddlePaddle Refactoring: Phase 1 Aug 2, 2017
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
No open projects
Development

Successfully merging this pull request may close these issues.

None yet

3 participants