Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[1.7] MXNet Extension PRs (#17623, #17569, #17762) (#18063)
Browse files Browse the repository at this point in the history
* Dynamic subgraph compile support (#17623)

This PR adds support for passing the NDArrays from the existing optimize_for API down to the reviewSubgraph function in an external library. It also adds a new API for HybridBlock called optimize_for that can partition the model without running a forward pass.

Feature changes

    Adds new API to HybridBlock optimize_for that partitions the model but does not call the cachedOp
    Modifies the subgraph library example to optionally require args to be provided
    Adds annotation on subgraph inputs for the name of the original param so that inputs can be mapped and passes annotations to input nodes of subgraphs
    Adds support for tensors in MKLDNN format, calls Reorder2Default

New tests

    Adds a new test to partition operators that directly consume params
    add a new model to test where ops to be partitioned have args/params

Bug Fixes

    fixes bug in passing ids vector by value instead of by reference
    fixes bug in passing copies of attributes instead of by reference
    fixes bug where _cached_graph was not updated after partitioning
    fixes memory leak where user-specified attributes on subgraph ops were not freed if subgraph was rejected
    fixes problem incorrectly indexing into shape/dtype maps when annotating the graph

Docs

    Updates the README doc with the latest changes described above

* Adding sparse support to MXTensor for custom operators (#17569)

* Added enum for sparse storage

* Add structure for Dense and Sparse

* redesign the data structure for MXSparse

* pull out aux data from sparse NDArray

* Added more sparse arguments to API interface

* Passed sparse from c_api to lib_api.h and set in MXTensor

* Fix indent

* fix segfault

* Fix NDArray to MXTensor errors

* Add a sample of sparse(CSR) transpose

* Make CSR transpose temporarily work by hardcoding

* Fixed sparse output size(Refined)

* Add tests for symbolic and stateful ops

* Added a sample for row sparse transpose

* Added real row sparse transpose

* Fix output size issue by adding lambda for CheckAndAlloc()

* Fix mixed storage formats error

* Added infer storage type function

* resolve comments

* Set inferSType as optional function

* Resolve comments

* Add error messages

* Resolve comments

* verify transpose ops results

* fix sanity check

* update MX_LIBRARY_VERSION to 5

* Custom Operator Random Number Generator Support (#17762)

Add random number generator support for custom operator libraries.

Design: We pass from MXNet the initialized and seeded states, located on CPU and GPU, to custom library. So user could use those seeds to generate deterministic values from a given seed passed to MXNet. Basically this workflow:

mx.random.seed(128)
r1 = mx.nd.some_custom_random_op(data)
mx.random.seed(128)
r2 = mx.nd.some_custom_random_op(data)
assert (r1 == r2)

This PR does not let custom library generate exactly the same sequence of random numbers comparing to MXNet

This is a continuation of the custom operator project #15921 and #17270

Co-authored-by: guanxinq <58794120+guanxinq@users.noreply.github.com>
Co-authored-by: Ziyi Mu <ziyi.mu@columbia.edu>
  • Loading branch information
3 people committed Apr 16, 2020
1 parent 13f5ad9 commit bf99f27
Show file tree
Hide file tree
Showing 26 changed files with 1,779 additions and 220 deletions.
15 changes: 5 additions & 10 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ if(USE_CUDA)
message("-- CUDA: Using the following NVCC architecture flags ${CUDA_ARCH_FLAGS}")
set(arch_code_list)
foreach(arch_str ${CUDA_ARCH_FLAGS})
if((arch_str MATCHES ".*sm_[0-9]+"))
if((arch_str MATCHES ".*sm_[0-9]+"))
string( REGEX REPLACE ".*sm_([0-9]+)" "\\1" arch_code ${arch_str} )
list(APPEND arch_code_list ${arch_code})
endif()
Expand Down Expand Up @@ -730,26 +730,21 @@ elseif(MSVC)

endif()

# extension libraries (custom operators, custom subgraphs) are built by default
add_library(customop_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_custom_op/gemm_lib.cc)
add_library(subgraph_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_subgraph/subgraph_lib.cc)
target_include_directories(customop_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet)
target_include_directories(subgraph_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet)
if (USE_CUDA)
if(USE_CUDA)
add_library(customop_gpu_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_custom_op/relu_lib.cu)
target_include_directories(customop_gpu_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet)
endif()
if(UNIX)
target_compile_options(customop_lib PUBLIC -shared)
target_compile_options(subgraph_lib PUBLIC -shared)
if (USE_CUDA)
target_compile_options(customop_gpu_lib PUBLIC -shared)
endif()
elseif(MSVC)
if(MSVC)
target_compile_options(customop_lib PUBLIC /LD)
target_compile_options(subgraph_lib PUBLIC /LD)
set_target_properties(customop_lib PROPERTIES PREFIX "lib")
set_target_properties(subgraph_lib PROPERTIES PREFIX "lib")
if (USE_CUDA)
if(USE_CUDA)
target_compile_options(customop_gpu_lib PUBLIC "$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-fPIC>")
set_target_properties(customop_gpu_lib PROPERTIES PREFIX "lib")
endif()
Expand Down
10 changes: 8 additions & 2 deletions example/extensions/lib_custom_op/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,19 @@
# specific language governing permissions and limitations
# under the License.

all: gemm_lib relu_lib
all: gemm_lib relu_lib transposecsr_lib transposerowsp_lib

gemm_lib:
g++ -shared -fPIC -std=c++11 gemm_lib.cc -o libgemm_lib.so -I ../../../include/mxnet

relu_lib:
nvcc -shared -std=c++11 -Xcompiler -fPIC relu_lib.cu -o librelu_lib.so -I ../../../include/mxnet

transposecsr_lib:
g++ -shared -fPIC -std=c++11 transposecsr_lib.cc -o libtransposecsr_lib.so -I ../../../include/mxnet

transposerowsp_lib:
g++ -shared -fPIC -std=c++11 transposerowsp_lib.cc -o libtransposerowsp_lib.so -I ../../../include/mxnet

clean:
rm -rf libgemm_lib.so librelu_lib.so
rm -rf libgemm_lib.so librelu_lib.so libtransposecsr_lib.so libtransposerowsp_lib.so
90 changes: 83 additions & 7 deletions example/extensions/lib_custom_op/relu_lib.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@
/*!
* Copyright (c) 2020 by Contributors
* \file relu_lib.cu
* \brief simple custom relu operator implemented using CUDA function
* \brief simple custom relu and noisy relu operator implemented using CUDA function
*/

#include <iostream>
#include "lib_api.h"

#define NumThreadPerBlock 256 // mxnet recommended cuda thread number per block

__global__ void relu_gpu_forward(float *out, float *in, int64_t N) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < N)
Expand Down Expand Up @@ -72,9 +74,9 @@ MXReturnValue forwardGPU(std::map<std::string, std::string> attrs,

mx_stream_t cuda_stream = res.get_cuda_stream();
int64_t N = inputs[0].size();
int block = 256;
int grid = (N + (block - 1)) / block;
relu_gpu_forward<<<grid,block,0,cuda_stream>>>(out_data, in_data, N);
int num_block = (N + NumThreadPerBlock - 1) / NumThreadPerBlock;

relu_gpu_forward<<<num_block,NumThreadPerBlock,0,cuda_stream>>>(out_data, in_data, N);

return MX_SUCCESS;
}
Expand All @@ -89,9 +91,9 @@ MXReturnValue backwardGPU(std::map<std::string, std::string> attrs,

mx_stream_t cuda_stream = res.get_cuda_stream();
int64_t N = inputs[0].size();
int block = 256;
int grid = (N + (block - 1)) / block;
relu_gpu_backward<<<grid,block,0,cuda_stream>>>(in_grad, out_grad, in_data, N);
int num_block = (N + NumThreadPerBlock - 1) / NumThreadPerBlock;

relu_gpu_backward<<<num_block,NumThreadPerBlock,0,cuda_stream>>>(in_grad, out_grad, in_data, N);

return MX_SUCCESS;
}
Expand Down Expand Up @@ -180,6 +182,80 @@ REGISTER_OP(my_state_relu)
.setCreateOpState(createOpStateCPU, "cpu")
.setCreateOpState(createOpStateGPU, "gpu");

/*
* Below is noisy ReLU operator example
* noisy ReLU is made from ReLU extended to include Gaussian noise
* forward - add Gaussian noise generated from normal distribution to each unit
* backward - gradient doesn't need to change since noise is constant
*/

#define NumRandomPerThread 64 // mxnet recommended random numbers generated per thread

__global__ void noisy_relu_gpu_forward(float *out, float *in, int64_t N, mx_gpu_rand_t* states, int step) {
// the launcher logic ensures tid less than NumGPURandomStates
int tid = blockIdx.x * blockDim.x + threadIdx.x;
// each thread generates unique sequence of random numbers
mx_gpu_rand_t thread_state = states[tid];
// each thread works on <step> number of calculation
int start = tid * step;
int end = start + step;
for (int i=start; i<end && i<N; ++i) {
float noise = curand_normal(&thread_state);
out[i] = in[i] + noise > 0 ? in[i] + noise : 0;
}
}

MXReturnValue noisyForwardCPU(std::map<std::string, std::string> attrs,
std::vector<MXTensor> inputs,
std::vector<MXTensor> outputs,
OpResource res) {
float* in_data = inputs[0].data<float>();
float* out_data = outputs[0].data<float>();

mx_cpu_rand_t* states = res.get_cpu_rand_states();
std::normal_distribution<float> dist_normal;

for (int i=0; i<inputs[0].size(); ++i) {
float noise = dist_normal(*states);
out_data[i] = in_data[i] + noise > 0 ? in_data[i] + noise : 0;
}
return MX_SUCCESS;
}

MXReturnValue noisyForwardGPU(std::map<std::string, std::string> attrs,
std::vector<MXTensor> inputs,
std::vector<MXTensor> outputs,
OpResource res) {
float* in_data = inputs[0].data<float>();
float* out_data = outputs[0].data<float>();

mx_stream_t cuda_stream = res.get_cuda_stream();
int64_t N = inputs[0].size();

// below is mxnet recommended workflow to parallel random number generating
int nthread = (N + NumRandomPerThread - 1) / NumRandomPerThread;
// we should not launch more threads than mxnet supported random number GPU states
int num_thread_need = nthread < MX_NUM_GPU_RANDOM_STATES ? nthread : MX_NUM_GPU_RANDOM_STATES;
// each cuda thread processes [step * tid, step * id + step) snippet of input tensor
int step = (N + num_thread_need - 1) / num_thread_need;
// this can ensure number of parallel threads less than mxnet supported random number states
int num_block = (num_thread_need + NumThreadPerBlock - 1) / NumThreadPerBlock;

noisy_relu_gpu_forward<<<num_block,NumThreadPerBlock,0,cuda_stream>>>(
out_data, in_data, N, res.get_gpu_rand_states(), step);

return MX_SUCCESS;
}

REGISTER_OP(my_noisy_relu)
.setParseAttrs(parseAttrs)
.setInferType(inferType)
.setInferShape(inferShape)
.setForward(noisyForwardCPU, "cpu")
.setForward(noisyForwardGPU, "gpu")
.setBackward(backwardCPU, "cpu")
.setBackward(backwardGPU, "gpu");

MXReturnValue initialize(int version) {
if (version >= 10400) {
std::cout << "MXNet version " << version << " supported" << std::endl;
Expand Down
43 changes: 27 additions & 16 deletions example/extensions/lib_custom_op/test_relu.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@
a = mx.nd.array([[-2,-1],[1,2]], ctx=mx.cpu())
b = mx.nd.array([[-2,-1],[1,2]], ctx=mx.gpu())

print("--------start ndarray compute---------")
print("--------ndarray compute---------")
print(mx.nd.my_relu(a))
print(mx.nd.my_relu(b))
print(mx.nd.my_state_relu(a))
print(mx.nd.my_state_relu(b))

print("--------start symbolic compute--------")
print("--------symbolic compute--------")
c = mx.sym.Variable('c')
d = mx.sym.Variable('d')
e = mx.sym.my_relu(c)
Expand All @@ -55,30 +55,41 @@
print(out)
print(out_base)

print("--------start backward compute--------")
print("--------backward compute--------")
out_grad = mx.nd.ones((2,2), ctx=mx.gpu())
exe.backward([out_grad])
exe_base.backward([out_grad])
print(in_grad)
print(in_grad_base)

print("--------start testing larger ndarray---------")
a = mx.nd.uniform(shape=(100,100,100), ctx=mx.cpu())
print("--------test ndarray with size of 1 million---------")
b = mx.nd.uniform(shape=(100,100,100), ctx=mx.gpu())
mx.nd.waitall()
t1 = time.time()
r1 = mx.nd.my_relu(a)
r1 = mx.nd.my_relu(b)
mx.nd.waitall()
t2 = time.time()
r2 = mx.nd.my_relu(b)
r2 = mx.nd.relu(b)
mx.nd.waitall()
t3 = time.time()
r3 = mx.nd.relu(b)
mx.nd.waitall()
t4 = time.time()
print("CPU running time:")
print(t2 - t1)
print("GPU running time:")
print(t3 - t2)
print("Baseline GPU running time:")
print(t4 - t3)
print("Custom ReLU running time in ms:")
print((t2 - t1) * 1000)
print("Native ReLU running time in ms:")
print((t3 - t2) * 1000)

print("--------test noisy relu identical sequence---------")

a = mx.nd.ones(shape=(13,5), ctx=mx.cpu())
b = mx.nd.ones(shape=(13,5), ctx=mx.gpu())

mx.random.seed(128, ctx=mx.cpu())
print(mx.nd.my_noisy_relu(a))

mx.random.seed(128, ctx=mx.cpu())
print(mx.nd.my_noisy_relu(a))

mx.random.seed(128, ctx=mx.gpu())
print(mx.nd.my_noisy_relu(b))

mx.random.seed(128, ctx=mx.gpu())
print(mx.nd.my_noisy_relu(b))
78 changes: 78 additions & 0 deletions example/extensions/lib_custom_op/test_transposecsr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#!/usr/bin/env python3

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=arguments-differ

# This test checks dynamic loading of custom library into MXNet
# and checks end to end compute of a simple 2D gemm custom op

import mxnet as mx
import os

#load library
if (os.name=='posix'):
path = os.path.abspath('libtransposecsr_lib.so')
mx.library.load(path)
elif (os.name=='nt'):
path = os.path.abspath('libtransposecsr_lib.dll')
mx.library.load(path)

a = mx.nd.array([[1,3,0,2,1],[0,1,0,0,0],[0,2,4,5,3]])
a = a.tostype('csr')
print("--------Input CSR Array---------")
print("data:", a.data.asnumpy())
print("indices:", a.indices.asnumpy())
print("indptr:", a.indptr.asnumpy())

print("--------Start NDArray Compute---------")
b = mx.nd.my_transposecsr(a)
print("Compute Results:")
print("data:", b.data.asnumpy())
print("indices:", b.indices.asnumpy())
print("indptr:", b.indptr.asnumpy())

print("Stateful Compute Result:")
c = mx.nd.my_state_transposecsr(a, test_kw=100)
print("data:", c.data.asnumpy())
print("indices:", c.indices.asnumpy())
print("indptr:", c.indptr.asnumpy())

print("--------start symbolic compute--------")
d = mx.sym.Variable('d')
e = mx.sym.my_transposecsr(d)
f = mx.sym.my_state_transposecsr(d, test_kw=200)

exe = e.bind(ctx=mx.cpu(),args={'d':a})
exe2 = f.bind(ctx=mx.cpu(),args={'d':a})
out = exe.forward()
print("Compute Results:")
print("data:", out[0].data.asnumpy())
print("indices:", out[0].indices.asnumpy())
print("indptr:", out[0].indptr.asnumpy())

out2 = exe2.forward()
out2 = exe2.forward()
print("Stateful Compute Result:")
print("data:", out2[0].data.asnumpy())
print("indices:", out2[0].indices.asnumpy())
print("indptr:", out2[0].indptr.asnumpy())

print("--------Baseline(dense)--------")
print(mx.nd.transpose(a.tostype('default')))

0 comments on commit bf99f27

Please sign in to comment.