This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

gpu mem pool strategy #11041

Merged
merged 3 commits into from Jun 14, 2018

Conversation

szha (Member) commented May 24, 2018

Description

adjust GPU memory pool strategy

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • To my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • add knob for minimum memory pool chunk size
  • add option (MXNET_GPU_MEM_POOL_TYPE="Round") for using nearest power of 2 size for better memory reuse
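For illustration, here is a minimal sketch (assumed, not the PR's exact code) of how these two knobs might be read; dmlc::GetEnv and the default values mirror the diff discussed in the review below:

#include <dmlc/parameter.h>
#include <string>

void ReadPoolKnobs() {
  // "Round" selects GPUPooledRoundedStorageManager (nearest power of 2);
  // anything else falls back to the naive pool.
  std::string strategy =
      dmlc::GetEnv("MXNET_GPU_MEM_POOL_TYPE", std::string("Naive"));
  // Smallest chunk the pool will hand out, in bytes.
  size_t min_chunk = dmlc::GetEnv("MXNET_GPU_MEM_POOL_MIN_CHUNK", 4096);
  (void)strategy; (void)min_chunk;
}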

Comments

@szha force-pushed the mem_strategy branch 10 times, most recently from fd64b96 to b8b942e on May 25, 2018 03:10
@szha changed the title from "[WIP] gpu mem pool strategy" to "gpu mem pool strategy" on May 25, 2018
LOG(INFO) << "Using GPUPooledRoundedStorageManager.";
} else {
if (strategy != "Naive") {
LOG(INFO) << "Unknown memory pool strategy specified: " << strategy << ".";
Member: log(fatal)?

@szha force-pushed the mem_strategy branch 2 times, most recently from bcba6e2 to de2a823 on May 25, 2018 21:26
zhreshold (Member) commented:

Still no clue what's going wrong with this PR. Nothing specific to Windows; weirdly, python2-GPU-win is good. I will try it on a local Windows PC.

@@ -71,7 +78,7 @@ class GPUPooledStorageManager final : public StorageManager {
 private:
  void DirectFreeNoLock(Storage::Handle handle) {
    cudaError_t err = cudaFree(handle.dptr);
    size_t size = handle.size + NDEV;
Contributor: are you sure +NDEV is not needed any more? What if NDEV=32, min_chunk=33, and handle.size=30? The original code would allocate 62; the new code would allocate 33.

szha (author): cc'd @ptrendx. My understanding was that there need to be enough bytes so that, across 32 devices, each device gets at least 1 byte for NCCL scattering. Could you confirm, @ptrendx?

ptrendx (Member): Yes, that is correct.
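To make the arithmetic in the reviewer's example concrete, here is a sketch with hypothetical values (not the PR's code): the +NDEV padding guarantees each of 32 devices can receive at least one byte during an NCCL scatter, which rounding up to min_chunk alone would not.

#include <algorithm>
#include <cstddef>

constexpr size_t kNDev = 32;  // number of devices to pad for

size_t PaddedSize(size_t requested, size_t min_chunk) {
  // requested=30, min_chunk=33  ->  max(30 + 32, 33) = 62, as the old code did;
  // rounding only to min_chunk would return 33 and lose the guarantee.
  return std::max(requested + kNDev, min_chunk);
}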

@@ -52,6 +54,11 @@ class GPUPooledStorageManager final : public StorageManager {
   */
  GPUPooledStorageManager() {
    reserve_ = dmlc::GetEnv("MXNET_GPU_MEM_POOL_RESERVE", 5);
    min_chunk_ = dmlc::GetEnv("MXNET_GPU_MEM_POOL_MIN_CHUNK", 4096);
Contributor: page size instead of min chunk?

@@ -82,19 +89,19 @@ class GPUPooledStorageManager final : public StorageManager {
 private:
  void ReleaseAll();
  // used memory
- size_t used_memory_ = 0;
+ size_t used_memory_ = 0, min_chunk_;
Contributor: put min_chunk_ on a new line.

private:
#if __SIZEOF_SIZE_T__ == __SIZEOF_LONG__

#if defined(__clang__) || defined(__GNUC__)
Contributor: does this need to be so complicated? You just need to take the highest bit and shift left by 1 if it's smaller than the size. This is called finding the MSB (most significant bit).

szha (author): these builtins utilize hardware instructions when available.

Contributor: Is it really faster? It looks too complicated.

Contributor: Also, the default implementation with pow and log is really slow.

szha (author): I will change the default implementation to use bit shifting and then do a comparison.

szha (author): I compared my current solution, the bit shifting, and static_cast<int>(std::ceil(std::log2(s))), with -O3 turned on, on my Mac (clang). The timings:

Running 10000000 iters.
Addr width 64
It took me 0.00981569 seconds. result: 223222785
It took me 0.128623 seconds. result: 223222785
It took me 0.0801588 seconds. result: 223222785
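For readers following along, here is a sketch of the three approaches being timed (assumed reconstructions, not the PR's exact benchmark code); all compute ceil(log2(s)) for s > 0, and the builtin variant assumes a 64-bit long, matching the __SIZEOF_SIZE_T__ == __SIZEOF_LONG__ guard shown earlier.

#include <cmath>
#include <cstddef>

// 1) compiler builtin: __builtin_clzl maps to a hardware bit-scan on most
//    targets (clang/GCC only; undefined behavior for s == 0).
inline int CeilLog2Builtin(size_t s) {
  int msb = 63 - __builtin_clzl(s);                 // index of highest set bit
  return ((size_t{1} << msb) < s) ? msb + 1 : msb;  // bump if s is not a power of 2
}

// 2) portable bit shifting: smallest k with 2^k >= s.
inline int CeilLog2Shift(size_t s) {
  int bits = 0;
  while ((size_t{1} << bits) < s) ++bits;
  return bits;
}

// 3) floating point: concise, but criticized in the thread as slow.
inline int CeilLog2Float(size_t s) {
  return static_cast<int>(std::ceil(std::log2(s)));
}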

@szha force-pushed the mem_strategy branch 10 times, most recently from 0319b42 to 63aac3f on June 4, 2018 20:39
szha (Member, author) commented Jun 6, 2018:

I've simplified the implementation to exclude the optimizations using intrinsics and bit scans. They are backed up in https://github.com/szha/mxnet/tree/mem_strategy_backup

@@ -23,7 +23,7 @@
 import platform

 blacklist = [
-    'Windows.h', 'cublas_v2.h', 'cuda/tensor_gpu-inl.cuh',
+    'Windows.h', 'intrin.h', 'cublas_v2.h', 'cuda/tensor_gpu-inl.cuh',
Contributor: revert

@szha force-pushed the mem_strategy branch 2 times, most recently from e57bae9 to 9b39b72 on June 8, 2018 22:11

TEST(GPUStorage, Round_GPU) {
  if (mxnet::test::unitTestsWithCuda) {
    putenv("MXNET_GPU_MEM_POOL_ROUND_LINEAR_CUTOFF=20");
Contributor: How long does this variable persist? It could have side effects on other tests.
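One way to limit the lifetime the reviewer asks about (a sketch, not the PR's code): putenv alters the process environment for every later test, so the variable could be explicitly restored when the test finishes. setenv/unsetenv are POSIX; Windows would need _putenv instead.

#include <cstdlib>
#include <gtest/gtest.h>

TEST(GPUStorage, Round_GPU_Scoped) {
  setenv("MXNET_GPU_MEM_POOL_ROUND_LINEAR_CUTOFF", "20", /*overwrite=*/1);
  // ... exercise the rounded pool here ...
  unsetenv("MXNET_GPU_MEM_POOL_ROUND_LINEAR_CUTOFF");  // undo the side effect
}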

#include <gtest/gtest.h>
#include <dmlc/logging.h>
#include <mxnet/storage.h>
#include <cstdio>
#include "test_util.h"
#include "storage/pooled_storage_manager.h"
marcoabreu (Contributor) commented Jun 9, 2018: Duplicate import? I think it's already part of the storage namespace at mxnet/storage.h

@marcoabreu dismissed their stale review June 9, 2018 11:03: "Didn't want to block"

@szha force-pushed the mem_strategy branch 2 times, most recently from d0d8bf7 to 00086f1 on June 11, 2018 02:17
@@ -16,7 +16,7 @@
 # under the License.

 from mxnet.test_utils import *
-from common import setup_module, with_seed
+from common import setup_module, with_seed, teardown
Contributor: Is it really necessary to import this in every single test? Looks a bit ugly, to be honest.

szha (author): Applying this change allows all tests within a module to finish before moving on to the next module, eliminating the case where side effects of tests in one module spill over into the next. In terms of testing practice, including a setup/teardown is common.

Contributor: Yeah, but we're not actually using it in most files, right?

szha (author): now we are

Contributor: Ah, in common.py :) But isn't it sufficient to import it there?

szha (author): Unfortunately no; it is the same case as setup_module. The test runner only picks up module-level fixtures that exist in each test module's own namespace, so defining them only in common.py is not enough.

Contributor: argh :/

@szha force-pushed the mem_strategy branch 6 times, most recently from 37ecc98 to 72b386f on June 12, 2018 03:08
size_t free, total;
cudaMemGetInfo(&free, &total);
if (free <= total * reserve_ / 100 || size > free - total * reserve_ / 100)
  ReleaseAll();
Contributor: What will happen to the storage handles currently pointing to some of the memory?
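A standalone restatement of the check above may help (a sketch, with an assumed helper name): keep reserve_ percent of total GPU memory unused, and if the pool would violate that, or the request cannot fit beside the reserve, flush the pool. Flushing only returns cached blocks sitting in the pool's free lists; handles still held by users are not in those lists, so they remain valid.

#include <cuda_runtime.h>
#include <cstddef>

bool ShouldFlushPool(size_t request, int reserve_percent) {
  size_t free_bytes = 0, total_bytes = 0;
  cudaMemGetInfo(&free_bytes, &total_bytes);
  const size_t reserve = total_bytes * reserve_percent / 100;
  // Flush if we are already inside the reserve, or the request would push us there.
  return free_bytes <= reserve || request > free_bytes - reserve;
}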

std::lock_guard<std::mutex> lock(Storage::Get()->GetMutex(Context::kGPU));
int bucket = get_bucket(handle->size);
size_t size = get_size(bucket);
auto&& reuse_pool = memory_pool_[bucket];
Contributor: Even if it's not an error (the rvalue reference will be deduced to a normal lvalue reference), it's better to be explicit and use auto&.
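The reviewer's point in miniature (a sketch with a stand-in pool type): on an lvalue, auto&& collapses to an lvalue reference anyway, so auto& says the same thing more directly.

#include <vector>

void Touch(std::vector<std::vector<int>>& memory_pool, int bucket) {
  auto&& pool_rref = memory_pool[bucket];  // deduces to std::vector<int>&
  auto&  pool_ref  = memory_pool[bucket];  // identical binding, clearer intent
  pool_ref.push_back(42);
  (void)pool_rref;
}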

@szha merged commit bf26886 into apache:master on Jun 14, 2018
@leezu mentioned this pull request Jun 22, 2018
ThomasDelteil (Contributor) commented:

@szha should we document this new env variable, or is it still experimental?

@szha deleted the mem_strategy branch on June 25, 2018 00:21
szha (Member, author) commented Jun 25, 2018:

@ThomasDelteil I intended to have people experiment with this first.

zheng-da pushed a commit to zheng-da/incubator-mxnet that referenced this pull request Jun 28, 2018:

* use nearest power of 2 for gpu memory pool sizes
* add linear
* add test

XinYao1994 pushed a commit to XinYao1994/incubator-mxnet that referenced this pull request Aug 29, 2018:

* use nearest power of 2 for gpu memory pool sizes
* add linear
* add test
Successfully merging this pull request may close these issues:

  • Bug of CuDNN RNN with variable sequence length