Merge pull request #492 from ShichengChen/dev
Make singa use multiple memory pools
nudles committed Aug 9, 2019
2 parents 0063144 + 6973152 commit d94ba5d
Showing 6 changed files with 49 additions and 26 deletions.
include/singa/core/device.h (5 changes: 4 additions & 1 deletion)

@@ -24,6 +24,7 @@
 #include <string>
 #include <functional>
 #include <memory>
+#include <mutex>

 #include "singa/singa_config.h"
 #include "singa/core/common.h"
@@ -295,7 +296,8 @@ class Platform {
   /// Create a set of CudaGPU Device using given GPU IDs.
   static const std::vector<std::shared_ptr<Device>>
   CreateCudaGPUsOn(const std::vector<int> &devices, size_t init_size = 0);
-
+
+  static std::vector<std::shared_ptr<Device> > UsedDevice;
   /// This function is implemented by Caffe (http://caffe.berkeleyvision.org/).
   /// This function checks the availability of GPU #device_id.
   /// It attempts to create a context on the device by calling cudaFree(0).
@@ -311,6 +313,7 @@ class Platform {
   /// the permission. cudaFree(0) is one of those with no side effect,
   /// except the context initialization.
   static bool CheckDevice(const int device_id);
+  static std::mutex mtx_;
 #endif  // USE_CUDA

 #ifdef USE_OPENCL
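
The two new statics turn Platform into a process-wide device registry: UsedDevice keeps one shared_ptr slot per physical GPU, and mtx_ serializes access to it. A minimal sketch of that caching pattern, with Device as a stand-in type (only the two member names mirror the declarations above; everything else is illustrative):

    #include <memory>
    #include <mutex>
    #include <vector>

    struct Device {};  // stand-in for singa::Device in this sketch

    static std::vector<std::shared_ptr<Device>> used_device;  // one slot per GPU
    static std::mutex mtx;

    std::shared_ptr<Device> GetOrCreate(int id, int num_gpus) {
      std::lock_guard<std::mutex> lock(mtx);  // serialize first-time initialization
      if (used_device.empty()) used_device.resize(num_gpus);
      if (!used_device[id]) used_device[id] = std::make_shared<Device>();
      return used_device[id];  // repeat calls get the cached instance
    }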
include/singa/core/memory.h (5 changes: 4 additions & 1 deletion)

@@ -43,6 +43,9 @@ class DeviceMemPool {
   virtual std::pair<size_t, size_t> GetMemUsage() {
     return std::make_pair(0u, 0u);
   }
+  virtual std::pair<size_t, size_t> GetMemUsage(int id) {
+    return std::make_pair(0u, 0u);
+  }
   virtual ~DeviceMemPool(){};

 protected:
@@ -62,6 +65,7 @@ class CnMemPool : public DeviceMemPool {
   void Free(void* ptr);

   std::pair<size_t, size_t> GetMemUsage() override;
+  std::pair<size_t, size_t> GetMemUsage(int id) override;

   // release all memory and set cnmem manager to uninitialized
   ~CnMemPool();
@@ -78,7 +82,6 @@ class CnMemPool : public DeviceMemPool {
   // lock on the initialized variable
   std::mutex mtx_;

-  static std::atomic<int> pool_count;
 };

 class CudaMemPool : public DeviceMemPool {
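
GetMemUsage(int id) enters the base class as a virtual with a do-nothing default, so pool implementations that cannot report per-device usage keep compiling unchanged, while CnMemPool overrides both forms. One caveat with overloaded virtuals in C++ (a general language note, not from this commit): a subclass that overrides only one overload hides the other at its scope unless it re-imports the base set, e.g.:

    #include <cstddef>
    #include <utility>

    struct Base {
      virtual std::pair<size_t, size_t> GetMemUsage() { return {0u, 0u}; }
      virtual std::pair<size_t, size_t> GetMemUsage(int) { return {0u, 0u}; }
      virtual ~Base() {}
    };

    struct PerDeviceOnly : Base {
      using Base::GetMemUsage;  // without this, the zero-arg form is hidden here
      std::pair<size_t, size_t> GetMemUsage(int) override { return {1u, 2u}; }
    };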
src/core/device/platform.cc (28 changes: 18 additions & 10 deletions)

@@ -20,11 +20,14 @@
 #include "singa/core/device.h"
 #include "singa/singa_config.h"
 #include "singa/utils/opencl_utils.h"
-
+#include <iostream>
+using namespace std;
 namespace singa {

 #ifdef USE_CUDA
-
+
+std::vector<std::shared_ptr<Device> > Platform::UsedDevice;
+std::mutex Platform::mtx_;
 int Platform::GetNumGPUs() {
   int count;
   CUDA_CHECK(cudaGetDeviceCount(&count));
@@ -118,23 +121,28 @@ Platform::CreateCudaGPUs(const size_t num_devices, size_t init_size) {
   return CreateCudaGPUsOn(use_gpus, init_size);
 }

-const vector<shared_ptr<Device>>
-Platform::CreateCudaGPUsOn(const vector<int> &devices, size_t init_size) {
+const vector<shared_ptr<Device> > Platform::CreateCudaGPUsOn(
+    const vector<int>& devices, size_t init_size) {
   MemPoolConf conf;
-  if (init_size > 0)
-    conf.set_init_size(init_size);
+  if (init_size > 0) conf.set_init_size(init_size);
   size_t bytes = conf.init_size() << 20;
   for (auto device : devices) {
     conf.add_device(device);
     CHECK_LE(bytes, Platform::GetGPUMemSize(device).first);
   }
-
+  mtx_.lock();
+  if (UsedDevice.size() == 0) {
+    int count = Platform::GetNumGPUs();
+    for (int i = 0; i < count; i++) UsedDevice.push_back(nullptr);
+  }
   auto pool = std::make_shared<CnMemPool>(conf);

   vector<shared_ptr<Device> > ret;
-  for (auto device : devices) {
-    auto dev = std::make_shared<CudaGPU>(device, pool);
-    ret.push_back(dev);
+  for (size_t i = 0; i < devices.size(); i++) {
+    if (UsedDevice[devices[i]] == nullptr)
+      UsedDevice[devices[i]] = std::make_shared<CudaGPU>(devices[i], pool);
+    ret.push_back(UsedDevice[devices[i]]);
   }
+  mtx_.unlock();
   return ret;
 }
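
The rewritten CreateCudaGPUsOn consults the UsedDevice table before constructing anything, so a GPU id that was seen before yields the same CudaGPU (and therefore keeps the memory pool it was created with); only first-time ids get a CudaGPU bound to the freshly made pool. A hypothetical check of that behaviour, assuming a machine with at least one CUDA GPU:

    #include "singa/core/device.h"
    #include <cassert>

    int main() {
      auto first = singa::Platform::CreateCudaGPUsOn({0});
      auto second = singa::Platform::CreateCudaGPUsOn({0});
      assert(first[0].get() == second[0].get());  // cached: same device object
      return 0;
    }

One design note: the explicit mtx_.lock()/mtx_.unlock() pair is not exception-safe; if anything between them threw, the mutex would stay locked. A std::lock_guard scoped to the function body would be the RAII equivalent.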
src/core/memory/memory.cc (12 changes: 8 additions & 4 deletions)

@@ -25,26 +25,31 @@
 #ifdef USE_CUDA

 namespace singa {
-std::atomic<int> CnMemPool::pool_count(0);
 std::pair<size_t, size_t> CnMemPool::GetMemUsage() {
   size_t free, total;
   auto status = cnmemMemGetInfo(&free, &total, NULL);
   CHECK_EQ(status, cnmemStatus_t::CNMEM_STATUS_SUCCESS)
       << cnmemGetErrorString(status);
   return std::make_pair(free, total);
 }
+std::pair<size_t, size_t> CnMemPool::GetMemUsage(int id) {
+  CHECK_EQ(cudaSetDevice(id), cudaError_t::cudaSuccess);
+  size_t free, total;
+  auto status = cnmemMemGetInfo(&free, &total, NULL);
+  CHECK_EQ(status, cnmemStatus_t::CNMEM_STATUS_SUCCESS)
+      << cnmemGetErrorString(status);
+  return std::make_pair(free, total);
+}

 CnMemPool::CnMemPool(int numDevices, size_t init_size, size_t max_size) {
   for (int i = 0; i < numDevices; i++)
     conf_.add_device(i);
   conf_.set_init_size(init_size);
   conf_.set_max_size(max_size);
-  CHECK_LT(++pool_count, 2) << "CnMemPool must be used as a singleton.";
 }

 CnMemPool::CnMemPool(const MemPoolConf &conf) {
   conf_ = conf;
-  CHECK_LT(++pool_count, 2) << "CnMemPool must be used as a singleton.";
 }

 void CnMemPool::Init() {
@@ -79,7 +84,6 @@ CnMemPool::~CnMemPool() {
     CHECK_EQ(status, cnmemStatus_t::CNMEM_STATUS_SUCCESS)
         << " " << cnmemGetErrorString(status);
     initialized_ = false;
-    --pool_count;
   }
   mtx_.unlock();
 }
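
Dropping pool_count removes the old one-pool-per-process check, which is exactly what the Platform change above relies on when it builds a CnMemPool per CreateCudaGPUsOn call. The new GetMemUsage(int id) reports usage for one specific GPU by switching the current CUDA device before querying cnmem. A small illustrative helper (hypothetical, not part of the commit):

    #include "singa/core/memory.h"
    #include <cstdio>

    // Print free/total bytes for each GPU visible to an initialized pool.
    void PrintUsage(singa::CnMemPool &pool, int num_gpus) {
      for (int id = 0; id < num_gpus; id++) {
        auto usage = pool.GetMemUsage(id);  // calls cudaSetDevice(id) internally
        std::printf("gpu %d: %zu free / %zu total\n",
                    id, usage.first, usage.second);
      }
    }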
test/python/test_memoryPool.py (new empty file)
test/singa/test_platform.cc (25 changes: 15 additions & 10 deletions)

@@ -20,9 +20,22 @@
 #include "gtest/gtest.h"
 #include "singa/core/device.h"
 #include "singa/core/tensor.h"
-
+#include <iostream>
+using namespace std;
 #ifdef USE_CUDA
 using singa::Platform;
+
+TEST(Platform, CreateMultDevice) {
+  int n = Platform::GetNumGPUs();
+  auto devs = Platform::CreateCudaGPUs(n);
+  for (int i = 0; i < devs.size(); i++) {
+    auto b = devs[i]->NewBlock(512 + 512 * (2 - i));
+    EXPECT_EQ(512 + 512 * (2 - i), devs[i]->GetAllocatedMem());
+    devs[i]->FreeBlock(b);
+  }
+}
+
+
 TEST(Platform, NumGPUs) {
   int n = Platform::GetNumGPUs();
   EXPECT_GE(n, 0);
@@ -68,15 +81,7 @@ TEST(Platform, CreateDevice) {
   }
 }

-TEST(Platform, CreateMultDevice) {
-  int n = Platform::GetNumGPUs();
-  auto devs = Platform::CreateCudaGPUs(n);
-  for (auto dev : devs) {
-    auto b = dev->NewBlock(32);
-    EXPECT_LE(32u, dev->GetAllocatedMem());
-    dev->FreeBlock(b);
-  }
-}
+

 TEST(Platform, CreatTensor) {
   auto cuda = Platform::CreateCudaGPUs(1)[0];
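
The relocated CreateMultDevice test is also stronger than the version it replaces: it allocates a different block size on each device (1536, 1024 and 512 bytes for the first three devices) and asserts with EXPECT_EQ that each device's allocation counter matches its own block exactly, which only holds if every GPU accounts memory in its own pool. A hypothetical companion test for the new device caching could look like:

    #include "gtest/gtest.h"
    #include "singa/core/device.h"

    TEST(Platform, CachedDevices) {
      int n = singa::Platform::GetNumGPUs();
      if (n == 0) return;  // nothing to check on CPU-only machines
      auto first = singa::Platform::CreateCudaGPUs(n);
      auto second = singa::Platform::CreateCudaGPUs(n);
      for (int i = 0; i < n; i++)
        EXPECT_EQ(first[i].get(), second[i].get());  // same cached devices
    }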
