Speed fused_op compilation by caching ptx and jit-compiled functions (#…
DickJC123 authored and apeforest committed Nov 19, 2019
1 parent e007dcd commit 169ed69
Showing 5 changed files with 205 additions and 142 deletions.
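
In short: before this change, each FusedOp instance compiled its kernels itself (NVRTC compile plus driver JIT) and stored the results in per-instance members, so many identical fused ops meant many identical compilations. After it, results are cached process-wide: NVRTC runs at most once per (SM architecture, kernel source) pair, and the resulting PTX is jit-compiled at most once per device. A condensed sketch of that scheme, using the names from the diff below; GetOrCompile is a hypothetical stand-in for the body of FusedOp::CompileCode, with the NVRTC and driver plumbing elided:

#include <cstdint>
#include <map>
#include <string>
#include <vector>
#include <cuda.h>

struct KernelInfo {
  std::string mangled_name;           // lowered kernel name reported by NVRTC
  std::string ptx;                    // filled once per (arch, source)
  std::vector<CUfunction> functions;  // entry [dev_id] filled once per device
};
// kernel source (minus the common header) -> compile results
using KernelCache = std::map<std::string, KernelInfo>;
// SM architecture (e.g. 70) -> cache for that architecture
static std::map<int32_t, KernelCache> compiled_kernels;

CUfunction GetOrCompile(const std::string& code, int sm_arch, int dev_id) {
  KernelInfo& kinfo = compiled_kernels[sm_arch][code];  // default-constructed on first miss
  if (kinfo.ptx.empty()) {
    // NVRTC compile of (common_header + code), once per (arch, source):
    // fills kinfo.ptx and kinfo.mangled_name.
  }
  if (kinfo.functions.size() <= static_cast<size_t>(dev_id))
    kinfo.functions.resize(dev_id + 1, nullptr);
  if (kinfo.functions[dev_id] == nullptr) {
    // Driver JIT, once per device: cuModuleLoadData + cuModuleGetFunction
    // fill kinfo.functions[dev_id].
  }
  return kinfo.functions[dev_id];
}

Accordingly, FusedOp itself keeps only the jit-compiled handles and the device they were built for (kernel_functions_, kernel_function_dev_id_); the per-instance ptx_/kernel_name_/kernel_ members and the cc_major_/cc_minor_ tracking disappear (those member changes presumably live in fused_op.h, one of the two changed files not reproduced in this excerpt).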
src/operator/fusion/fused_op-inl.h (1 addition, 3 deletions)
@@ -982,11 +982,9 @@ const char kernel_begin[] = R"code(
 const int tid = threadIdx.x + blockIdx.x * blockDim.x;
 for (int i = tid; i < N; i+= gridDim.x * blockDim.x) {
     int offset = i*nvec;
 )code";
 
-const char kernel_end[] = R"code(
-}
+const char kernel_end[] = R"code(}
 }
 )code";

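For context, these fragments are wrapped around the generated per-op body by GenerateCode in fused_op.cu (see its return statement further down). A hypothetical assembly helper, assuming kernel_begin/kernel_end above are in scope, shows the effect of the edit: with kernel_end now opening as R"code(}, the brace that closes the grid-stride loop follows the body directly instead of after a blank line.

#include <string>

// Hypothetical sketch of the concatenation performed by FusedOp::GenerateCode.
std::string AssembleKernel(const std::string& aux_code, const std::string& kernel_name,
                           const std::string& kernel_params, const std::string& body,
                           int nthreads) {
  return aux_code + "\n" +
         "__launch_bounds__(" + std::to_string(nthreads) + ")\n" +
         "__global__ void FusedKernel_" + kernel_name +
         "(size_t N, " + kernel_params + ") {\n" +
         kernel_begin + "\n" +  // declares tid and opens the grid-stride loop
         body + "\n" +
         kernel_end;            // "}" for the loop, then "}" for the kernel
}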
src/operator/fusion/fused_op.cc (24 additions, 26 deletions)
@@ -49,31 +49,30 @@ void FusedOpParamParser(nnvm::NodeAttrs* attrs) {
   attrs->parsed = FusedOpPtr(new FusedOp(attrs, param));
 }
 
-FusedOp::FusedOp(const nnvm::NodeAttrs* attrs, const FusedOpConfig& config) {
-  this->inputs_ = std::vector<FusedOpEntry>(config.num_inputs);
-  this->outputs_ = std::vector<FusedOpEntry>(config.num_outputs);
-  this->subgraph_ = nnvm::Graph();
-  this->subgraph_.outputs = attrs->subgraphs[0]->outputs;
-  this->initialized_ = false;
-  this->cc_major_ = -1;
-  this->cc_minor_ = -1;
+FusedOp::FusedOp(const nnvm::NodeAttrs* attrs, const FusedOpConfig& config) :
+    initialized_(false),
+    kernel_function_dev_id_(-1) {
+  inputs_ = std::vector<FusedOpEntry>(config.num_inputs);
+  outputs_ = std::vector<FusedOpEntry>(config.num_outputs);
+  subgraph_ = nnvm::Graph();
+  subgraph_.outputs = attrs->subgraphs[0]->outputs;
 }
 
 bool FusedOp::InferShape(const nnvm::NodeAttrs &attrs,
                          std::vector<mxnet::TShape> *in_attrs,
                          std::vector<mxnet::TShape> *out_attrs) {
-  this->subgraph_.attrs.erase("shape");
-  this->subgraph_.attrs.erase("shape_inputs");
+  subgraph_.attrs.erase("shape");
+  subgraph_.attrs.erase("shape_inputs");
   std::vector<mxnet::TShape> input_shapes(*in_attrs);
-  this->subgraph_ = mxnet::exec::InferShape(std::move(this->subgraph_),
-                                            std::move(input_shapes),
-                                            "__shape__");
+  subgraph_ = mxnet::exec::InferShape(std::move(subgraph_),
+                                      std::move(input_shapes),
+                                      "__shape__");
 
-  const auto& g = this->subgraph_.indexed_graph();
+  const auto& g = subgraph_.indexed_graph();
   const auto& input_nids = g.input_nodes();
 
   std::vector<mxnet::TShape> out_shapes;
-  const std::vector<mxnet::TShape> shapes = this->subgraph_.GetAttr<mxnet::ShapeVector>("shape");
+  const std::vector<mxnet::TShape> shapes = subgraph_.GetAttr<mxnet::ShapeVector>("shape");
   for (auto& e : g.outputs()) {
     out_shapes.push_back(shapes[g.entry_id(e)]);
   }
@@ -105,18 +104,18 @@ bool FusedOp::InferShape(const nnvm::NodeAttrs &attrs,
 bool FusedOp::InferType(const nnvm::NodeAttrs &attrs,
                         std::vector<int> *in_attrs,
                         std::vector<int> *out_attrs) {
-  this->subgraph_.attrs.erase("dtype");
-  this->subgraph_.attrs.erase("dtype_inputs");
+  subgraph_.attrs.erase("dtype");
+  subgraph_.attrs.erase("dtype_inputs");
   std::vector<int> input_types(*in_attrs);
-  this->subgraph_ = mxnet::exec::InferType(std::move(this->subgraph_),
-                                           std::move(input_types),
-                                           "__dtype__");
+  subgraph_ = mxnet::exec::InferType(std::move(subgraph_),
+                                     std::move(input_types),
+                                     "__dtype__");
 
-  const auto& g = this->subgraph_.indexed_graph();
+  const auto& g = subgraph_.indexed_graph();
   const auto& input_nids = g.input_nodes();
 
   std::vector<int> out_types;
-  const std::vector<int> types = this->subgraph_.GetAttr<nnvm::DTypeVector>("dtype");
+  const std::vector<int> types = subgraph_.GetAttr<nnvm::DTypeVector>("dtype");
   for (auto& e : g.outputs()) {
     out_types.push_back(types[g.entry_id(e)]);
   }
@@ -149,10 +148,9 @@ template <typename Attr>
 std::tuple<const nnvm::NodePtr,
            std::vector<Attr>,
            std::vector<Attr>>
-FusedOp::GetAttrs(const std::string& attr_name,
-                  const uint32_t node_id) {
-  const auto& g = this->subgraph_.indexed_graph();
-  const std::vector<Attr> attrs = this->subgraph_.GetAttr<std::vector<Attr>>(attr_name);
+FusedOp::GetAttrs(const std::string& attr_name, const uint32_t node_id) {
+  const auto& g = subgraph_.indexed_graph();
+  const std::vector<Attr> attrs = subgraph_.GetAttr<std::vector<Attr>>(attr_name);
   const auto& node = g[node_id];
   std::vector<Attr> inputs, outputs;
   for (const auto& e : node.inputs) {
src/operator/fusion/fused_op.cu (134 additions, 91 deletions)
@@ -163,9 +163,33 @@ void AddPointerAndShape(const TBlob& data,
   });
 }
 
+// Obtain compilation log from the program.
+std::string GetCompileLog(nvrtcProgram program) {
+  size_t log_size_including_null;
+  NVRTC_CALL(nvrtcGetProgramLogSize(program, &log_size_including_null));
+  // For most std::string implementations, this is probably 1 char bigger than needed. OK though.
+  std::string log(log_size_including_null, '\0');
+  NVRTC_CALL(nvrtcGetProgramLog(program, &log[0]));
+  // Make sure the string reflects the true size (so minus the null terminator).
+  log.resize(log_size_including_null - 1);
+  return log;
+}
+
+// Obtain compilation result (ptx assembly) from the program.
+std::string GetPtx(nvrtcProgram program) {
+  size_t ptx_size_including_null;
+  NVRTC_CALL(nvrtcGetPTXSize(program, &ptx_size_including_null));
+  // For most std::string implementations, this is probably 1 char bigger than needed. OK though.
+  std::string ptx(ptx_size_including_null, '\0');
+  NVRTC_CALL(nvrtcGetPTX(program, &ptx[0]));
+  // Make sure the string reflects the true size (so minus the null terminator).
+  ptx.resize(ptx_size_including_null - 1);
+  return ptx;
+}
+
 }  // namespace
 
-void FusedOp::GenerateCode(int kernel_index, const std::vector<OpReqType> &req,
+std::string FusedOp::GenerateCode(const std::vector<OpReqType> &req,
                            const std::vector<int> &in_dtypes,
                            const std::vector<int> &out_dtypes,
                            const std::vector<int> &in_ndims,
@@ -175,7 +199,7 @@ void FusedOp::GenerateCode(int kernel_index, const std::vector<OpReqType> &req,
                            const int nvec,
                            const std::string &kernel_name,
                            std::vector<uint32_t>* check_shapes) {
-  const auto& g = this->subgraph_.indexed_graph();
+  const auto& g = subgraph_.indexed_graph();
   std::string code = "";
   int temp_name_counter = 0;
   using NodeEntry = nnvm::IndexedGraph::NodeEntry;
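The two helpers above encode NVRTC's query-the-size-then-fetch convention. They also fix a latent bug visible in the removed code further down: the old path called ptx_[kernel_index].reserve(ptx_size) and then wrote through &ptx_[kernel_index][0], but reserve() does not change a string's size, so writing past size() was undefined behavior; the helpers construct the string at full size and shrink it afterwards. A standalone sketch of the same pattern against a self-contained program (CompileToPtx and Check are hypothetical; the diff uses NVRTC_CALL/CHECK_EQ instead):

#include <cstdio>
#include <cstdlib>
#include <string>
#include <nvrtc.h>

static void Check(nvrtcResult res, const char* what) {
  if (res != NVRTC_SUCCESS) {
    std::fprintf(stderr, "%s: %s\n", what, nvrtcGetErrorString(res));
    std::abort();
  }
}

std::string CompileToPtx(const std::string& src, const char* arch_opt) {
  nvrtcProgram prog;
  Check(nvrtcCreateProgram(&prog, src.c_str(), "example.cu", 0, nullptr, nullptr),
        "nvrtcCreateProgram");
  const char* opts[] = {arch_opt, "--std=c++11"};
  Check(nvrtcCompileProgram(prog, 2, opts), "nvrtcCompileProgram");
  size_t n;                  // NVRTC reports the size including the trailing '\0'
  Check(nvrtcGetPTXSize(prog, &n), "nvrtcGetPTXSize");
  std::string ptx(n, '\0');  // sized (not merely reserved) before writing into the buffer
  Check(nvrtcGetPTX(prog, &ptx[0]), "nvrtcGetPTX");
  ptx.resize(n - 1);         // drop the '\0', as GetPtx above does
  Check(nvrtcDestroyProgram(&prog), "nvrtcDestroyProgram");
  return ptx;
}

Called as, e.g., CompileToPtx("__global__ void k(float* x) { x[0] = 1.f; }", "--gpu-architecture=compute_70").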
@@ -459,16 +483,11 @@ void FusedOp::GenerateCode(int kernel_index, const std::vector<OpReqType> &req,
     ++counter;
   }
 
-  this->code_[kernel_index] = code;
-
   // Add boilerplate and type information
-  if (dmlc::GetEnv("MXNET_FUSION_VERBOSE", false)) {
-    LOG(INFO) << code_[kernel_index];
-  }
   std::string kernel_params = "";
   std::string tensor_params = "";
   nnvm::Symbol sym;
-  sym.outputs = this->subgraph_.outputs;
+  sym.outputs = subgraph_.outputs;
   const std::vector<std::string> input_names = sym.ListInputNames(nnvm::Symbol::kAll);
   size_t num_params = in_dtypes.size() + out_dtypes.size();
   size_t i = 0;
@@ -513,85 +532,102 @@ void FusedOp::GenerateCode(int kernel_index, const std::vector<OpReqType> &req,
   }
   kernel_params += tensor_params;
 
-  code_[kernel_index] = std::string(fusion::fp16_support_string) + "\n" +
-                        fusion::type_support_string + "\n" +
-                        fusion::function_definitions + "\n" +
-                        fusion::backward_function_definitions + "\n" +
-                        aux_code + "\n" +
-                        "__launch_bounds__(" + std::to_string(FusedOp::NTHREADS) + ")\n" +
-                        "__global__ void FusedKernel_" + kernel_name +
-                        "(size_t N, " + kernel_params + ") {\n" +
-                        fusion::kernel_begin + "\n" +
-                        code_[kernel_index] + "\n" +
-                        fusion::kernel_end;
+  // Create kernel source (minus the common header)
+  return aux_code + "\n" +
+         "__launch_bounds__(" + std::to_string(FusedOp::NTHREADS) + ")\n" +
+         "__global__ void FusedKernel_" + kernel_name +
+         "(size_t N, " + kernel_params + ") {\n" +
+         fusion::kernel_begin + "\n" +
+         code + "\n" +
+         fusion::kernel_end;
 }

-void FusedOp::CompileCode(int kernel_index, const std::string &kernel_name) {
+CUfunction FusedOp::CompileCode(const std::string &code,
+                                const std::string &kernel_name,
+                                int dev_id) {
   // Guard NVRTC calls
   std::lock_guard<std::mutex> lock_nvrtc(mutex_);
-  nvrtcProgram program;
-  NVRTC_CALL(
-      nvrtcCreateProgram(&program,                              // prog
-                         &code_[kernel_index][0],               // buffer
-                         (kernel_name + "_kernel.cu").c_str(),  // name
-                         0,                                     // num headers
-                         NULL,                                  // headers
-                         NULL));                                // include names
-  std::string gpu_arch = "--gpu-architecture=compute_" +
-                         std::to_string(this->cc_major_) +
-                         std::to_string(this->cc_minor_);
-
-  const char *opts[] = {gpu_arch.c_str(),
-                        "--std=c++11",
-                        "-default-device"};
-  const std::string kernel_name_demangled = "FusedKernel_" + kernel_name;
-  NVRTC_CALL(nvrtcAddNameExpression(program, (kernel_name_demangled).c_str()));
-
-  nvrtcResult compileResult = nvrtcCompileProgram(program,  // prog
-                                                  3,        // num options
-                                                  opts);    // options
-  // Obtain compilation log from the program.
-  size_t log_size;
-  NVRTC_CALL(nvrtcGetProgramLogSize(program, &log_size));
-  std::string log(log_size, '\0');
-  NVRTC_CALL(nvrtcGetProgramLog(program, &log[0]));
-  CHECK_EQ(compileResult, NVRTC_SUCCESS)
-      << "NVRTC Compilation failed. Please set environment variable MXNET_USE_FUSION to 0.\n" << log;
-  // Obtain PTX from the program.
-  size_t ptx_size;
-  NVRTC_CALL(nvrtcGetPTXSize(program, &ptx_size));
-  ptx_[kernel_index].reserve(ptx_size);
-  NVRTC_CALL(nvrtcGetPTX(program, &ptx_[kernel_index][0]));
-  const char *name;
-  NVRTC_CALL(nvrtcGetLoweredName(program,
-                                 kernel_name_demangled.c_str(),
-                                 &name));
-  kernel_name_[kernel_index] = name;
-  // Destroy the program.
-  NVRTC_CALL(nvrtcDestroyProgram(&program));
-  int device;
-  CUdevice cu_device;
-  CUcontext context;
-  CUmodule module;
-  CUDA_CALL(cudaGetDevice(&device));
-  CUDA_DRIVER_CALL(cuDeviceGet(&cu_device, device));
-  CUDA_DRIVER_CALL(cuDevicePrimaryCtxRetain(&context, cu_device));
-  CUDA_DRIVER_CALL(cuModuleLoadData(&module, &ptx_[kernel_index][0]));
-  CUDA_DRIVER_CALL(cuModuleGetFunction(&kernel_[kernel_index],
-                                       module,
-                                       kernel_name_[kernel_index].c_str()));
+  // Local class for value type of compile cache
+  struct KernelInfo {
+    std::string mangled_name;
+    std::string ptx;
+    std::vector<CUfunction> functions;
+  };
+  // Maps from the cuda source code (minus header) to the ptx and jit-compiled CUfunctions.
+  using KernelCache = std::map<std::string, KernelInfo>;
+  // Per-gpu-architecture compiled kernel cache with jit-compiled function for each device context
+  static std::map<int32_t, KernelCache> compiled_kernels;
+  int sm_arch = SMArch(dev_id);
+  KernelCache& compiled_kernels_this_arch = compiled_kernels[sm_arch];  // make null map as needed
+  KernelInfo& kinfo = compiled_kernels_this_arch[code];                 // make KernelInfo as needed
+  if (kinfo.ptx.size() == 0) {
+    // It's the first time we've seen this kernel, so we need to generate the ptx and mangled_name.
+    static std::string common_header =
+        std::string(fusion::fp16_support_string) + "\n" +
+        fusion::type_support_string + "\n" +
+        fusion::function_definitions + "\n" +
+        fusion::backward_function_definitions + "\n";
+    std::string code_with_header = common_header + code;
+    // If verbose mode, output kernel source, though not including the common header
+    if (dmlc::GetEnv("MXNET_FUSION_VERBOSE", false)) {
+      LOG(INFO) << "\n" << std::string(80, '-') << "\n" << code;
+    }
+    if (compiled_kernels_this_arch.size() == CACHESIZE_WARN_THRESHOLD + 1 &&
+        dmlc::GetEnv("MXNET_FUSION_SIZE_WARNING", true)) {
+      LOG(WARNING) << "The number of different fused ops exceeds " << CACHESIZE_WARN_THRESHOLD
+                   << ". Set MXNET_FUSION_SIZE_WARNING=0 to quiet this warning.";
+    }
+    nvrtcProgram program;
+    NVRTC_CALL(nvrtcCreateProgram(&program,                              // prog
+                                  &code_with_header[0],                  // buffer
+                                  (kernel_name + "_kernel.cu").c_str(),  // name
+                                  0,                                     // num headers
+                                  NULL,                                  // headers
+                                  NULL));                                // include names
+
+    std::string gpu_arch_arg = "--gpu-architecture=compute_" + std::to_string(sm_arch);
+    const char *opts[] = {gpu_arch_arg.c_str(),
+                          "--std=c++11",
+                          "-default-device"};
+    const std::string kernel_name_demangled = "FusedKernel_" + kernel_name;
+    NVRTC_CALL(nvrtcAddNameExpression(program, (kernel_name_demangled).c_str()));
+
+    nvrtcResult compileResult = nvrtcCompileProgram(program,  // prog
+                                                    3,        // num options
+                                                    opts);    // options
+    CHECK_EQ(compileResult, NVRTC_SUCCESS)
+        << "NVRTC Compilation failed. Please set environment variable MXNET_USE_FUSION to 0.\n"
+        << GetCompileLog(program);
+
+    kinfo.ptx = GetPtx(program);
+    const char *mangled_name;
+    NVRTC_CALL(nvrtcGetLoweredName(program,
+                                   kernel_name_demangled.c_str(),
+                                   &mangled_name));
+    kinfo.mangled_name = mangled_name;
+    // Destroy the program.
+    NVRTC_CALL(nvrtcDestroyProgram(&program));
+  }
+  // Ensure function array is deep enough to index by dev_id
+  while (kinfo.functions.size() <= static_cast<size_t>(dev_id))
+    kinfo.functions.push_back(static_cast<CUfunction>(nullptr));
+  // Jit-compile ptx for the device as needed
+  if (kinfo.functions[dev_id] == static_cast<CUfunction>(nullptr)) {
+    // Make sure driver context is set to the proper device
+    CUdevice cu_device;
+    CUcontext context;
+    CUDA_DRIVER_CALL(cuDeviceGet(&cu_device, dev_id));
+    CUDA_DRIVER_CALL(cuDevicePrimaryCtxRetain(&context, cu_device));
+    // Jit-compile ptx for the driver's current context
+    CUmodule module;
+    CUDA_DRIVER_CALL(cuModuleLoadData(&module, kinfo.ptx.c_str()));
+    CUDA_DRIVER_CALL(cuModuleGetFunction(&kinfo.functions[dev_id],
+                                         module,
+                                         kinfo.mangled_name.c_str()));
+  }
+  return kinfo.functions[dev_id];
 }

-bool FusedOp::CheckComputeCapability(const OpContext &ctx) {
-  const int dev_id = ctx.run_ctx.ctx.dev_id;
-  const int cc_major = ComputeCapabilityMajor(dev_id);
-  const int cc_minor = ComputeCapabilityMinor(dev_id);
-
-  const bool ret = cc_major == this->cc_major_ && cc_minor == this->cc_minor_;
-  this->cc_major_ = cc_major;
-  this->cc_minor_ = cc_minor;
-  return ret;
-}

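With the cache keyed by SM architecture (and functions indexed per device), a device or architecture change can never silently reuse mismatched PTX, which is why CheckComputeCapability and its cc_major_/cc_minor_ bookkeeping can be deleted outright; Forward below instead treats a device switch on an already-initialized op as a hard error. The remaining per-device work is the driver-API JIT step. A standalone sketch (LoadPtx and Check are hypothetical; unlike the code above, it makes the retained primary context current explicitly rather than relying on the CUDA runtime having already done so):

#include <cstdio>
#include <cstdlib>
#include <string>
#include <cuda.h>

static void Check(CUresult res, const char* what) {
  if (res != CUDA_SUCCESS) {
    const char* msg = nullptr;
    cuGetErrorString(res, &msg);
    std::fprintf(stderr, "%s: %s\n", what, msg ? msg : "unknown error");
    std::abort();
  }
}

CUfunction LoadPtx(const std::string& ptx, const std::string& mangled_name, int dev_id) {
  CUdevice dev;
  CUcontext ctx;
  CUmodule mod;
  CUfunction fn;
  Check(cuDeviceGet(&dev, dev_id), "cuDeviceGet");
  Check(cuDevicePrimaryCtxRetain(&ctx, dev), "cuDevicePrimaryCtxRetain");
  Check(cuCtxSetCurrent(ctx), "cuCtxSetCurrent");  // JIT targets the current context's device
  Check(cuModuleLoadData(&mod, ptx.c_str()), "cuModuleLoadData");  // driver jit-compiles the PTX
  Check(cuModuleGetFunction(&fn, mod, mangled_name.c_str()), "cuModuleGetFunction");
  return fn;
}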
 void FusedOp::CheckShapesAndTypes(const std::vector<TBlob> &inputs,
                                   const std::vector<TBlob> &outputs,
@@ -665,23 +701,30 @@ void FusedOp::Forward<gpu>(const nnvm::NodeAttrs& attrs,
   const auto& node_shapes = intermediate_shapes_[0].internal_attr;
   const auto& node_dtypes = intermediate_dtypes_[0].internal_attr;
 
-  // Check and save compute capability of the current GPU
-  if (!CheckComputeCapability(ctx)) initialized_ = false;
+  int dev_id = ctx.run_ctx.ctx.dev_id;
 
+  // A change between training and inference modes may require different kernel functions
   initialized_ = initialized_ && (req == saved_reqs_);
   saved_reqs_ = req;
 
   if (!initialized_) {
-    this->GenerateCode(0, req, in_dtypes, out_dtypes, in_ndims, out_ndims,
-                       node_shapes, node_dtypes, nvec, attrs.name, &check_shape_args_);
-    this->CompileCode(0, attrs.name);
+    const auto& code = GenerateCode(req, in_dtypes, out_dtypes, in_ndims, out_ndims,
+                                    node_shapes, node_dtypes, nvec, attrs.name, &check_shape_args_);
+    kernel_functions_[fusion::kGeneral] = CompileCode(code, attrs.name, dev_id);
     if (check_shape_args_.size() > 0) {
-      this->GenerateCode(1, req, in_dtypes, out_dtypes, in_ndims, out_ndims,
-                         node_shapes, node_dtypes, nvec, attrs.name, NULL);
-      this->CompileCode(1, attrs.name);
+      const auto& code = GenerateCode(req, in_dtypes, out_dtypes, in_ndims, out_ndims,
                                       node_shapes, node_dtypes, nvec, attrs.name, NULL);
+      kernel_functions_[fusion::kShapeOptimized] = CompileCode(code, attrs.name, dev_id);
     }
     initialized_ = true;
+    kernel_function_dev_id_ = dev_id;
   }
 
+  // A change in device would force recompiling, but this is unexpected so signal as an error
+  if (dev_id != kernel_function_dev_id_)
+    LOG(FATAL) << "Fused op compiled for device " << kernel_function_dev_id_
+               << ", not expecting switch to device " << dev_id;
+
   Stream<gpu>* s = ctx.get_stream<gpu>();
   auto stream = Stream<gpu>::GetStream(s);
   std::vector<void*> args;
@@ -713,18 +756,18 @@ void FusedOp::Forward<gpu>(const nnvm::NodeAttrs& attrs,
   for (auto &ptr : ptrs) {
     args.push_back(reinterpret_cast<void *>(&ptr));
   }
-  int kernel_index = 0;
+  int kernel_variant = fusion::kGeneral;
   if (check_shape_args_.size() > 0) {
-    kernel_index = 1;
+    kernel_variant = fusion::kShapeOptimized;
     for (const auto &shape_id : check_shape_args_) {
       const auto& shape = node_shapes[shape_id];
       if (shape[shape.ndim()-1] % nvec != 0) {
-        kernel_index = 0;
+        kernel_variant = fusion::kGeneral;
       }
     }
   }
   CUDA_DRIVER_CALL(
-      cuLaunchKernel(kernel_[kernel_index],
+      cuLaunchKernel(kernel_functions_[kernel_variant],
                      num_blocks, 1, 1,         // grid dim
                      FusedOp::NTHREADS, 1, 1,  // block dim
                      0, stream,                // shared mem and stream
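One detail worth flagging on the launch side: cuLaunchKernel's kernelParams argument is an array of pointers to the arguments, not of the argument values themselves, which is why the loop above pushes &ptr rather than ptr into args. A minimal sketch for a kernel taking (size_t N, in, out), with a hypothetical Launch helper and error checking elided:

#include <cstddef>
#include <cuda.h>

void Launch(CUfunction fn, CUstream stream, size_t N, CUdeviceptr in, CUdeviceptr out,
            unsigned num_blocks, unsigned nthreads) {
  void* args[] = {&N, &in, &out};    // addresses of the arguments, not their values
  cuLaunchKernel(fn,
                 num_blocks, 1, 1,   // grid dim
                 nthreads, 1, 1,     // block dim
                 0, stream,          // shared mem and stream
                 args, nullptr);     // kernelParams, extra
}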
