Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Custom Operator Random Number Generator Support (#17762)
Browse files Browse the repository at this point in the history
Add random number generator support for custom operator libraries.

Design: We pass from MXNet the initialized and seeded states, located on CPU and GPU, to custom library. So user could use those seeds to generate deterministic values from a given seed passed to MXNet. Basically this workflow:

mx.random.seed(128)
r1 = mx.nd.some_custom_random_op(data)
mx.random.seed(128)
r2 = mx.nd.some_custom_random_op(data)
assert (r1 == r2)

This PR does not let custom library generate exactly the same sequence of random numbers comparing to MXNet

This is a continuation of the custom operator project #15921 and #17270
  • Loading branch information
rondogency committed Apr 8, 2020
1 parent f906a02 commit 16ddc6d
Show file tree
Hide file tree
Showing 8 changed files with 220 additions and 59 deletions.
17 changes: 6 additions & 11 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,7 @@ if(USE_CUDA)
message("-- CUDA: Using the following NVCC architecture flags ${CUDA_ARCH_FLAGS}")
set(arch_code_list)
foreach(arch_str ${CUDA_ARCH_FLAGS})
if((arch_str MATCHES ".*sm_[0-9]+"))
if((arch_str MATCHES ".*sm_[0-9]+"))
string( REGEX REPLACE ".*sm_([0-9]+)" "\\1" arch_code ${arch_str} )
list(APPEND arch_code_list ${arch_code})
endif()
Expand Down Expand Up @@ -719,7 +719,7 @@ elseif(MSVC)
"$<$<COMPILE_LANGUAGE:CUDA>:--gpu-code=sm_${arch},compute_${arch}>"
)
target_compile_options(
mxnet_${arch}
mxnet_${arch}
PRIVATE "$<$<AND:$<CONFIG:DEBUG>,$<COMPILE_LANGUAGE:CUDA>>:-Xcompiler=-MTd -Gy /bigobj>")
target_compile_options(
mxnet_${arch}
Expand Down Expand Up @@ -748,26 +748,21 @@ elseif(MSVC)
endif()
endif()

# extension libraries (custom operators, custom subgraphs) are built by default
add_library(customop_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_custom_op/gemm_lib.cc)
add_library(subgraph_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_subgraph/subgraph_lib.cc)
target_include_directories(customop_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet)
target_include_directories(subgraph_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet)
if (USE_CUDA)
if(USE_CUDA)
add_library(customop_gpu_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_custom_op/relu_lib.cu)
target_include_directories(customop_gpu_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet)
endif()
if(UNIX)
target_compile_options(customop_lib PUBLIC -shared)
target_compile_options(subgraph_lib PUBLIC -shared)
if (USE_CUDA)
target_compile_options(customop_gpu_lib PUBLIC -shared)
endif()
elseif(MSVC)
if(MSVC)
target_compile_options(customop_lib PUBLIC /LD)
target_compile_options(subgraph_lib PUBLIC /LD)
set_target_properties(customop_lib PROPERTIES PREFIX "lib")
set_target_properties(subgraph_lib PROPERTIES PREFIX "lib")
if (USE_CUDA)
if(USE_CUDA)
target_compile_options(customop_gpu_lib PUBLIC "$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-fPIC>")
set_target_properties(customop_gpu_lib PROPERTIES PREFIX "lib")
endif()
Expand Down
90 changes: 83 additions & 7 deletions example/extensions/lib_custom_op/relu_lib.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@
/*!
* Copyright (c) 2020 by Contributors
* \file relu_lib.cu
* \brief simple custom relu operator implemented using CUDA function
* \brief simple custom relu and noisy relu operator implemented using CUDA function
*/

#include <iostream>
#include "lib_api.h"

#define NumThreadPerBlock 256 // mxnet recommended cuda thread number per block

__global__ void relu_gpu_forward(float *out, float *in, int64_t N) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < N)
Expand Down Expand Up @@ -72,9 +74,9 @@ MXReturnValue forwardGPU(std::map<std::string, std::string> attrs,

mx_stream_t cuda_stream = res.get_cuda_stream();
int64_t N = inputs[0].size();
int block = 256;
int grid = (N + (block - 1)) / block;
relu_gpu_forward<<<grid,block,0,cuda_stream>>>(out_data, in_data, N);
int num_block = (N + NumThreadPerBlock - 1) / NumThreadPerBlock;

relu_gpu_forward<<<num_block,NumThreadPerBlock,0,cuda_stream>>>(out_data, in_data, N);

return MX_SUCCESS;
}
Expand All @@ -89,9 +91,9 @@ MXReturnValue backwardGPU(std::map<std::string, std::string> attrs,

mx_stream_t cuda_stream = res.get_cuda_stream();
int64_t N = inputs[0].size();
int block = 256;
int grid = (N + (block - 1)) / block;
relu_gpu_backward<<<grid,block,0,cuda_stream>>>(in_grad, out_grad, in_data, N);
int num_block = (N + NumThreadPerBlock - 1) / NumThreadPerBlock;

relu_gpu_backward<<<num_block,NumThreadPerBlock,0,cuda_stream>>>(in_grad, out_grad, in_data, N);

return MX_SUCCESS;
}
Expand Down Expand Up @@ -180,6 +182,80 @@ REGISTER_OP(my_state_relu)
.setCreateOpState(createOpStateCPU, "cpu")
.setCreateOpState(createOpStateGPU, "gpu");

/*
* Below is noisy ReLU operator example
* noisy ReLU is made from ReLU extended to include Gaussian noise
* forward - add Gaussian noise generated from normal distribution to each unit
* backward - gradient doesn't need to change since noise is constant
*/

#define NumRandomPerThread 64 // mxnet recommended random numbers generated per thread

__global__ void noisy_relu_gpu_forward(float *out, float *in, int64_t N, mx_gpu_rand_t* states, int step) {
// the launcher logic ensures tid less than NumGPURandomStates
int tid = blockIdx.x * blockDim.x + threadIdx.x;
// each thread generates unique sequence of random numbers
mx_gpu_rand_t thread_state = states[tid];
// each thread works on <step> number of calculation
int start = tid * step;
int end = start + step;
for (int i=start; i<end && i<N; ++i) {
float noise = curand_normal(&thread_state);
out[i] = in[i] + noise > 0 ? in[i] + noise : 0;
}
}

MXReturnValue noisyForwardCPU(std::map<std::string, std::string> attrs,
std::vector<MXTensor> inputs,
std::vector<MXTensor> outputs,
OpResource res) {
float* in_data = inputs[0].data<float>();
float* out_data = outputs[0].data<float>();

mx_cpu_rand_t* states = res.get_cpu_rand_states();
std::normal_distribution<float> dist_normal;

for (int i=0; i<inputs[0].size(); ++i) {
float noise = dist_normal(*states);
out_data[i] = in_data[i] + noise > 0 ? in_data[i] + noise : 0;
}
return MX_SUCCESS;
}

MXReturnValue noisyForwardGPU(std::map<std::string, std::string> attrs,
std::vector<MXTensor> inputs,
std::vector<MXTensor> outputs,
OpResource res) {
float* in_data = inputs[0].data<float>();
float* out_data = outputs[0].data<float>();

mx_stream_t cuda_stream = res.get_cuda_stream();
int64_t N = inputs[0].size();

// below is mxnet recommended workflow to parallel random number generating
int nthread = (N + NumRandomPerThread - 1) / NumRandomPerThread;
// we should not launch more threads than mxnet supported random number GPU states
int num_thread_need = nthread < MX_NUM_GPU_RANDOM_STATES ? nthread : MX_NUM_GPU_RANDOM_STATES;
// each cuda thread processes [step * tid, step * id + step) snippet of input tensor
int step = (N + num_thread_need - 1) / num_thread_need;
// this can ensure number of parallel threads less than mxnet supported random number states
int num_block = (num_thread_need + NumThreadPerBlock - 1) / NumThreadPerBlock;

noisy_relu_gpu_forward<<<num_block,NumThreadPerBlock,0,cuda_stream>>>(
out_data, in_data, N, res.get_gpu_rand_states(), step);

return MX_SUCCESS;
}

REGISTER_OP(my_noisy_relu)
.setParseAttrs(parseAttrs)
.setInferType(inferType)
.setInferShape(inferShape)
.setForward(noisyForwardCPU, "cpu")
.setForward(noisyForwardGPU, "gpu")
.setBackward(backwardCPU, "cpu")
.setBackward(backwardGPU, "gpu");

MXReturnValue initialize(int version) {
if (version >= 10400) {
std::cout << "MXNet version " << version << " supported" << std::endl;
Expand Down
43 changes: 27 additions & 16 deletions example/extensions/lib_custom_op/test_relu.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@
a = mx.nd.array([[-2,-1],[1,2]], ctx=mx.cpu())
b = mx.nd.array([[-2,-1],[1,2]], ctx=mx.gpu())

print("--------start ndarray compute---------")
print("--------ndarray compute---------")
print(mx.nd.my_relu(a))
print(mx.nd.my_relu(b))
print(mx.nd.my_state_relu(a))
print(mx.nd.my_state_relu(b))

print("--------start symbolic compute--------")
print("--------symbolic compute--------")
c = mx.sym.Variable('c')
d = mx.sym.Variable('d')
e = mx.sym.my_relu(c)
Expand All @@ -55,30 +55,41 @@
print(out)
print(out_base)

print("--------start backward compute--------")
print("--------backward compute--------")
out_grad = mx.nd.ones((2,2), ctx=mx.gpu())
exe.backward([out_grad])
exe_base.backward([out_grad])
print(in_grad)
print(in_grad_base)

print("--------start testing larger ndarray---------")
a = mx.nd.uniform(shape=(100,100,100), ctx=mx.cpu())
print("--------test ndarray with size of 1 million---------")
b = mx.nd.uniform(shape=(100,100,100), ctx=mx.gpu())
mx.nd.waitall()
t1 = time.time()
r1 = mx.nd.my_relu(a)
r1 = mx.nd.my_relu(b)
mx.nd.waitall()
t2 = time.time()
r2 = mx.nd.my_relu(b)
r2 = mx.nd.relu(b)
mx.nd.waitall()
t3 = time.time()
r3 = mx.nd.relu(b)
mx.nd.waitall()
t4 = time.time()
print("CPU running time:")
print(t2 - t1)
print("GPU running time:")
print(t3 - t2)
print("Baseline GPU running time:")
print(t4 - t3)
print("Custom ReLU running time in ms:")
print((t2 - t1) * 1000)
print("Native ReLU running time in ms:")
print((t3 - t2) * 1000)

print("--------test noisy relu identical sequence---------")

a = mx.nd.ones(shape=(13,5), ctx=mx.cpu())
b = mx.nd.ones(shape=(13,5), ctx=mx.gpu())

mx.random.seed(128, ctx=mx.cpu())
print(mx.nd.my_noisy_relu(a))

mx.random.seed(128, ctx=mx.cpu())
print(mx.nd.my_noisy_relu(a))

mx.random.seed(128, ctx=mx.gpu())
print(mx.nd.my_noisy_relu(b))

mx.random.seed(128, ctx=mx.gpu())
print(mx.nd.my_noisy_relu(b))
57 changes: 46 additions & 11 deletions include/mxnet/lib_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,14 @@
#include <iostream>
#include <utility>
#include <stdexcept>
#include <random>

#define MX_LIBRARY_VERSION 5
#if defined(__NVCC__)
#include <curand_kernel.h>
#endif

/* Make sure to update the version number everytime you make changes */
#define MX_LIBRARY_VERSION 6

/*!
* \brief For loading multiple custom op libraries in Linux, exporting same symbol multiple
Expand Down Expand Up @@ -395,8 +401,8 @@ struct MXTensor {
stype == oth.stype;
}

// For dense, data_ptr points to data.
// For sparse, data_ptr points to MXSparse.
// For dense, data_ptr points to 1D flattened tensor data
// For sparse, data_ptr points to MXSparse
void *data_ptr;

// shape is in [2,3,4] format to represent high-dim tensor
Expand Down Expand Up @@ -426,9 +432,17 @@ typedef void (*sparse_malloc_t)(void*, int, int, int, void**, int64_t**, int64_t

#if defined(__NVCC__)
typedef cudaStream_t mx_stream_t;
typedef curandStatePhilox4_32_10_t mx_gpu_rand_t;
#else
typedef void* mx_stream_t;
typedef void* mx_gpu_rand_t;
#endif
typedef std::mt19937 mx_cpu_rand_t;

/*! \brief MXNet initialized random states for each device, used for parallelism */
/* Each thread should generate random number unique sequence out of different states */
#define MX_NUM_CPU_RANDOM_STATES 1024
#define MX_NUM_GPU_RANDOM_STATES 32768

/*!
* \brief provide resource APIs memory allocation mechanism to Forward/Backward functions
Expand All @@ -437,10 +451,12 @@ class OpResource {
public:
OpResource(xpu_malloc_t cpu_malloc_fp, void* cpu_alloc_fp,
xpu_malloc_t gpu_malloc_fp, void* gpu_alloc_fp, void* stream,
sparse_malloc_t sparse_malloc_fp, void* sparse_alloc_fp)
sparse_malloc_t sparse_malloc_fp, void* sparse_alloc_fp,
void* rng_cpu_states, void* rng_gpu_states)
: cpu_malloc(cpu_malloc_fp), gpu_malloc(gpu_malloc_fp),
cpu_alloc(cpu_alloc_fp), gpu_alloc(gpu_alloc_fp), cuda_stream(stream),
sparse_malloc(sparse_malloc_fp), sparse_alloc(sparse_alloc_fp) {}
sparse_malloc(sparse_malloc_fp), sparse_alloc(sparse_alloc_fp),
rand_cpu_states(rng_cpu_states), rand_gpu_states(rng_gpu_states) {}

/*! \brief allocate cpu memory controlled by MXNet */
void* alloc_cpu(int size) {
Expand All @@ -463,6 +479,19 @@ class OpResource {
&(sparse->data), &(sparse->indices), &(sparse->indptr));
}

/*! \brief get pointer to initialized and seeded random number states located on CPU */
/* Access each state by states[id], but this id should be <= MX_NUM_CPU_RANDOM_STATES */
mx_cpu_rand_t* get_cpu_rand_states() {
return static_cast<mx_cpu_rand_t*>(rand_cpu_states);
}

/*! \brief get pointer to initialized and seeded random number states located on GPU */
/* Access each state by states[id], but this id should be <= MX_NUM_GPU_RANDOM_STATES */
/* Note that if you are using cpu build, it will return a nullptr */
mx_gpu_rand_t* get_gpu_rand_states() {
return static_cast<mx_gpu_rand_t*>(rand_gpu_states);
}

private:
/*! \brief allocation lambda function */
xpu_malloc_t cpu_malloc, gpu_malloc;
Expand All @@ -474,6 +503,8 @@ class OpResource {
sparse_malloc_t sparse_malloc;
/*! \brief lambda function to return allocated sparse memory handle */
void *sparse_alloc;
/*! \brief cpu and gpu rng fully inited and seeded states */
void *rand_cpu_states, *rand_gpu_states;
};

/*!
Expand Down Expand Up @@ -997,7 +1028,8 @@ typedef int (*opCallFComp_t)(fcomp_t fcomp, const char* const* keys,
void** in_indices, void** out_indices,
void** in_indptr, void** out_indptr,
int64_t* in_indices_shapes, int64_t* out_indices_shapes,
int64_t* in_indptr_shapes, int64_t* out_indptr_shapes);
int64_t* in_indptr_shapes, int64_t* out_indptr_shapes,
void* rng_cpu_states, void* rng_gpu_states);

#define MXLIB_OPCALLMUTATEINPUTS_STR "_opCallMutateInputs"
typedef int (*opCallMutateInputs_t)(mutateInputs_t mutate, const char* const* keys,
Expand Down Expand Up @@ -1026,7 +1058,8 @@ typedef int (*opCallFStatefulComp_t)(int is_forward, void* state_op,
void** in_indices, void** out_indices,
void** in_indptr, void** out_indptr,
int64_t* in_indices_shapes, int64_t* out_indices_shapes,
int64_t* in_indptr_shapes, int64_t* out_indptr_shapes);
int64_t* in_indptr_shapes, int64_t* out_indptr_shapes,
void* rng_cpu_states, void* rng_gpu_states);

#define MXLIB_PARTREGSIZE_STR "_partRegSize"
typedef int (*partRegSize_t)(void);
Expand Down Expand Up @@ -1284,7 +1317,8 @@ extern "C" {
int* instypes, int* outstypes, void** in_indices, void** out_indices,
void** in_indptr, void** out_indptr,
int64_t* in_indices_shapes, int64_t* out_indices_shapes,
int64_t* in_indptr_shapes, int64_t* out_indptr_shapes) {
int64_t* in_indptr_shapes, int64_t* out_indptr_shapes,
void* rng_cpu_states, void* rng_gpu_states) {
// create map of attributes from list
std::map<std::string, std::string> attrs;
for (int i = 0; i < num; i++) {
Expand Down Expand Up @@ -1345,7 +1379,7 @@ extern "C" {
}

OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc,
cuda_stream, sparse_malloc, sparse_alloc);
cuda_stream, sparse_malloc, sparse_alloc, rng_cpu_states, rng_gpu_states);
return fcomp(attrs, inputs, outputs, res);
}

Expand Down Expand Up @@ -1419,7 +1453,8 @@ extern "C" {
int* instypes, int* outstypes, void** in_indices, void** out_indices,
void** in_indptr, void** out_indptr,
int64_t* in_indices_shapes, int64_t* out_indices_shapes,
int64_t* in_indptr_shapes, int64_t* out_indptr_shapes) {
int64_t* in_indptr_shapes, int64_t* out_indptr_shapes,
void* rng_cpu_states, void* rng_gpu_states) {
// create a vector of tensors for inputs
std::vector<MXTensor> inputs(num_in);
// create a vector for sparse inputs
Expand Down Expand Up @@ -1476,7 +1511,7 @@ extern "C" {
}

OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc,
stream, sparse_malloc, sparse_alloc);
stream, sparse_malloc, sparse_alloc, rng_cpu_states, rng_gpu_states);

CustomStatefulOp* op_ptr = reinterpret_cast<CustomStatefulOp*>(state_op);
if (is_forward) {
Expand Down

0 comments on commit 16ddc6d

Please sign in to comment.