Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Support extra inputs for subgraph ops #18779

Merged
merged 27 commits into from
Aug 14, 2020
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 3 additions & 1 deletion example/extensions/lib_api/init_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@
#include <iostream>
#include "lib_api.h"

using namespace mxnet::ext;

MXReturnValue initialize(int version) {
if (version >= 10700) {
std::cout << "MXNet version " << version << " supported" << std::endl;
return MX_SUCCESS;
} else {
std::cout << "MXNet version " << version << " not supported" << std::endl;
MX_ERROR_MSG << "MXNet version " << version << " not supported";
return MX_FAIL;
}
}
16 changes: 8 additions & 8 deletions example/extensions/lib_custom_op/gemm_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
#include <utility>
#include "lib_api.h"

using namespace mxnet::ext;

// main matrix multiplication routine
void gemm(const float* A, const float* B, float* C,
const unsigned n, const unsigned k, const unsigned m) {
Expand Down Expand Up @@ -128,12 +130,12 @@ MXReturnValue inferType(const std::unordered_map<std::string, std::string>& attr
std::vector<int> *outtypes) {
// validate inputs
if (intypes->size() != 2) {
std::cout << "Expected 2 inputs to inferType" << std::endl;
MX_ERROR_MSG << "Expected 2 inputs to inferType";
return MX_FAIL;
}
for (unsigned i = 0; i < intypes->size(); i++) {
if (intypes->at(i) != kFloat32) {
std::cout << "Expected input " << i << " to have float32 type" << std::endl;
MX_ERROR_MSG << "Expected input " << i << " to have float32 type";
return MX_FAIL;
}
}
Expand All @@ -147,11 +149,11 @@ MXReturnValue inferShape(const std::unordered_map<std::string, std::string>& att
std::vector<std::vector<unsigned int>>* outshapes) {
// validate inputs
if (inshapes->size() != 2) {
std::cout << "Expected 2 inputs to inferShape" << std::endl;
MX_ERROR_MSG << "Expected 2 inputs to inferShape";
return MX_FAIL;
}
if (inshapes->at(0).size() != 2 || inshapes->at(1).size() != 2) {
std::cout << "Expected 2D matrices for both inputs to inferShape" << std::endl;
MX_ERROR_MSG << "Expected 2D matrices for both inputs to inferShape";
return MX_FAIL;
}

Expand All @@ -160,7 +162,7 @@ MXReturnValue inferShape(const std::unordered_map<std::string, std::string>& att
unsigned kk = inshapes->at(1)[0];
unsigned m = inshapes->at(1)[1];
if (k != kk) {
std::cout << "Exected first input axis 1 equals to second input axis 0" << std::endl;
MX_ERROR_MSG << "Exected first input axis 1 equals to second input axis 0";
return MX_FAIL;
}

Expand Down Expand Up @@ -196,8 +198,6 @@ class MyStatefulGemm : public CustomStatefulOp {
return backward(attrs_, inputs, outputs, op_res);
}

~MyStatefulGemm() = default;

private:
int count;
const std::unordered_map<std::string, std::string> attrs_;
Expand Down Expand Up @@ -231,7 +231,7 @@ MXReturnValue initialize(int version) {
std::cout << "MXNet version " << version << " supported" << std::endl;
return MX_SUCCESS;
} else {
std::cout << "MXNet version " << version << " not supported" << std::endl;
MX_ERROR_MSG << "MXNet version " << version << " not supported";
return MX_FAIL;
}
}
4 changes: 3 additions & 1 deletion example/extensions/lib_custom_op/relu_lib.cu
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
#include <iostream>
#include "lib_api.h"

using namespace mxnet::ext;

#define NumThreadPerBlock 256 // mxnet recommended cuda thread number per block

__global__ void relu_gpu_forward(float *out, float *in, int64_t N) {
Expand Down Expand Up @@ -263,7 +265,7 @@ MXReturnValue initialize(int version) {
std::cout << "MXNet version " << version << " supported" << std::endl;
return MX_SUCCESS;
} else {
std::cout << "MXNet version " << version << " not supported" << std::endl;
MX_ERROR_MSG << "MXNet version " << version << " not supported";
return MX_FAIL;
}
}
22 changes: 12 additions & 10 deletions example/extensions/lib_custom_op/transposecsr_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
#include <utility>
#include "lib_api.h"

using namespace mxnet::ext;

void transpose(MXTensor& src, MXTensor& dst, const OpResource& res) {
MXSparse* A = src.data<MXSparse>();
MXSparse* B = dst.data<MXSparse>();
Expand Down Expand Up @@ -71,11 +73,11 @@ MXReturnValue forward(const std::unordered_map<std::string, std::string>& attrs,
// The data types and storage types of inputs and outputs should be the same.
if(inputs->at(0).dtype != outputs->at(0).dtype ||
inputs->at(0).stype != outputs->at(0).stype) {
std::cout << "Error! Expected all inputs and outputs to be the same type."
<< "Found input storage type:" << inputs->at(0).stype
<< " Found output storage type:" << outputs->at(0).stype
<< " Found input data type:" << inputs->at(0).dtype
<< " Found output data type:" << outputs->at(0).dtype << std::endl;
MX_ERROR_MSG << "Error! Expected all inputs and outputs to be the same type."
<< "Found input storage type:" << inputs->at(0).stype
<< " Found output storage type:" << outputs->at(0).stype
<< " Found input data type:" << inputs->at(0).dtype
<< " Found output data type:" << outputs->at(0).dtype;
return MX_FAIL;
}

Expand All @@ -102,11 +104,11 @@ MXReturnValue inferType(const std::unordered_map<std::string, std::string>& attr
std::vector<int>* outtypes) {
// validate inputs
if (intypes->size() != 1) {
std::cout << "Expected 1 inputs to inferType" << std::endl;
MX_ERROR_MSG << "Expected 1 inputs to inferType";
return MX_FAIL;
}
if (intypes->at(0) != kFloat32) {
std::cout << "Expected input to have float32 type" << std::endl;
MX_ERROR_MSG << "Expected input to have float32 type";
return MX_FAIL;
}

Expand All @@ -118,7 +120,7 @@ MXReturnValue inferSType(const std::unordered_map<std::string, std::string>& att
std::vector<int>* instypes,
std::vector<int>* outstypes) {
if (instypes->at(0) != kCSRStorage) {
std::cout << "Expected storage type is kCSRStorage" << std::endl;
MX_ERROR_MSG << "Expected storage type is kCSRStorage";
return MX_FAIL;
}
outstypes->at(0) = instypes->at(0);
Expand All @@ -130,7 +132,7 @@ MXReturnValue inferShape(const std::unordered_map<std::string, std::string>& att
std::vector<std::vector<unsigned int>>* outshapes) {
// validate inputs
if (inshapes->size() != 1) {
std::cout << "Expected 1 inputs to inferShape" << std::endl;
MX_ERROR_MSG << "Expected 1 inputs to inferShape";
return MX_FAIL;
}

Expand Down Expand Up @@ -195,7 +197,7 @@ MXReturnValue initialize(int version) {
std::cout << "MXNet version " << version << " supported" << std::endl;
return MX_SUCCESS;
} else {
std::cout << "MXNet version " << version << " not supported" << std::endl;
MX_ERROR_MSG << "MXNet version " << version << " not supported";
return MX_FAIL;
}
}
22 changes: 12 additions & 10 deletions example/extensions/lib_custom_op/transposerowsp_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
#include <utility>
#include "lib_api.h"

using namespace mxnet::ext;

void transpose(MXTensor& src, MXTensor& dst, const OpResource& res) {
MXSparse* A = src.data<MXSparse>();
MXSparse* B = dst.data<MXSparse>();
Expand Down Expand Up @@ -74,11 +76,11 @@ MXReturnValue forward(const std::unordered_map<std::string, std::string>& attrs,
// The data types and storage types of inputs and outputs should be the same.
if(inputs->at(0).dtype != outputs->at(0).dtype ||
inputs->at(0).stype != outputs->at(0).stype) {
std::cout << "Error! Expected all inputs and outputs to be the same type."
<< "Found input storage type:" << inputs->at(0).stype
<< " Found output storage type:" << outputs->at(0).stype
<< " Found input data type:" << inputs->at(0).dtype
<< " Found output data type:" << outputs->at(0).dtype << std::endl;
MX_ERROR_MSG << "Error! Expected all inputs and outputs to be the same type."
<< "Found input storage type:" << inputs->at(0).stype
<< " Found output storage type:" << outputs->at(0).stype
<< " Found input data type:" << inputs->at(0).dtype
<< " Found output data type:" << outputs->at(0).dtype;
return MX_FAIL;
}
transpose(inputs->at(0), outputs->at(0), res);
Expand All @@ -104,11 +106,11 @@ MXReturnValue inferType(const std::unordered_map<std::string, std::string>& attr
std::vector<int>* outtypes) {
// validate inputs
if (intypes->size() != 1) {
std::cout << "Expected 1 inputs to inferType" << std::endl;
MX_ERROR_MSG << "Expected 1 inputs to inferType";
return MX_FAIL;
}
if (intypes->at(0) != kFloat32) {
std::cout << "Expected input to have float32 type" << std::endl;
MX_ERROR_MSG << "Expected input to have float32 type";
return MX_FAIL;
}

Expand All @@ -120,7 +122,7 @@ MXReturnValue inferSType(const std::unordered_map<std::string, std::string>& att
std::vector<int>* instypes,
std::vector<int>* outstypes) {
if (instypes->at(0) != kRowSparseStorage) {
std::cout << "Expected storage type is kRowSparseStorage" << std::endl;
MX_ERROR_MSG << "Expected storage type is kRowSparseStorage";
return MX_FAIL;
}
outstypes->at(0) = instypes->at(0);
Expand All @@ -132,7 +134,7 @@ MXReturnValue inferShape(const std::unordered_map<std::string, std::string>& att
std::vector<std::vector<unsigned int>>* outshapes) {
// validate inputs
if (inshapes->size() != 1) {
std::cout << "Expected 1 inputs to inferShape" << std::endl;
MX_ERROR_MSG << "Expected 1 inputs to inferShape";
return MX_FAIL;
}

Expand Down Expand Up @@ -197,7 +199,7 @@ MXReturnValue initialize(int version) {
std::cout << "MXNet version " << version << " supported" << std::endl;
return MX_SUCCESS;
} else {
std::cout << "MXNet version " << version << " not supported" << std::endl;
MX_ERROR_MSG << "MXNet version " << version << " not supported";
return MX_FAIL;
}
}