forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
conv_op_shared_gpu.cc
27 lines (23 loc) · 910 Bytes
/
conv_op_shared_gpu.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
#include <functional>
#include <memory>
#include <mutex>

#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/conv_op_shared.h"
namespace caffe2 {
// CUDA specialization: creates the workspace-wide shared convolution buffer
// and the mutex that guards it. Must be called before
// runWithSharedBuffer<CUDAContext>(), which enforces the mutex blob's
// existence.
//
// @param ws  workspace in which the two blobs are created (ws owns them).
template <>
void createSharedBuffer<CUDAContext>(Workspace* ws) {
  // The mutex lives inside a blob so it is shared by every operator that
  // uses the same workspace.
  auto* mutexPtr = ws->CreateBlob("__CAFFE2_SHARED_CONV_BUFFER_CUDA_MUTEX__")
                       ->GetMutable<std::unique_ptr<std::mutex>>();
  // make_unique instead of reset(new ...): no raw new, exception-safe.
  *mutexPtr = std::make_unique<std::mutex>();
  // The buffer blob itself; its tensor is materialized lazily in
  // runWithSharedBuffer.
  ws->CreateBlob("__CAFFE2_SHARED_CONV_BUFFER_CUDA__");
}
// CUDA specialization: runs `f` on the workspace-shared convolution buffer
// while holding the shared mutex, so concurrent operators in the same
// workspace never touch the buffer simultaneously.
//
// @param ws  workspace holding the buffer/mutex blobs made by
//            createSharedBuffer<CUDAContext>().
// @param f   callback receiving the shared CUDA tensor buffer; only invoked
//            under the lock.
template <>
void runWithSharedBuffer<CUDAContext>(
    Workspace* ws,
    std::function<void(Tensor* buffer)> f) {
  auto* lockBlob = ws->GetBlob("__CAFFE2_SHARED_CONV_BUFFER_CUDA_MUTEX__");
  // A missing mutex blob means createSharedBuffer() was never called.
  CAFFE_ENFORCE(lockBlob, "Must call createSharedBuffer() first");
  // Double dereference: blob holds a unique_ptr<std::mutex>.
  std::lock_guard<std::mutex> guard(
      **lockBlob->GetMutable<std::unique_ptr<std::mutex>>());
  f(BlobGetMutableTensor(
      ws->GetBlob("__CAFFE2_SHARED_CONV_BUFFER_CUDA__"), CUDA));
}
}