This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-100] Support float16 in Correlation operator #10125

Merged · 1 commit · Mar 16, 2018
src/operator/correlation-inl.h (40 additions & 18 deletions)

@@ -64,7 +64,7 @@ struct CorrelationParam : public dmlc::Parameter<CorrelationParam> {
     .describe("operation type is either multiplication or subtraction");
   }
 };
-template<typename xpu>
+template<typename xpu, typename DType>
 class CorrelationOp : public Operator {
  public:
   explicit CorrelationOp(CorrelationParam param) {
@@ -79,14 +79,14 @@ class CorrelationOp : public Operator {
     CHECK_EQ(in_data.size(), 2U);
     CHECK_EQ(out_data.size(), 3U);
     Stream<xpu> *s = ctx.get_stream<xpu>();
-    Tensor<xpu, 4> data1 = in_data[Correlation::kData1].get<xpu, 4, real_t>(s);
-    Tensor<xpu, 4> data2 = in_data[Correlation::kData2].get<xpu, 4, real_t>(s);
-    Tensor<xpu, 4> out = out_data[Correlation::kOut].get<xpu, 4, real_t>(s);
-    Tensor<xpu, 4> tmp1 = out_data[Correlation::kTemp1].get<xpu, 4, real_t>(s);
-    Tensor<xpu, 4> tmp2 = out_data[Correlation::kTemp2].get<xpu, 4, real_t>(s);
-    tmp1 = 0.0f;
-    tmp2 = 0.0f;
-    out = 0.0f;
+    Tensor<xpu, 4, DType> data1 = in_data[Correlation::kData1].get<xpu, 4, DType>(s);
+    Tensor<xpu, 4, DType> data2 = in_data[Correlation::kData2].get<xpu, 4, DType>(s);
+    Tensor<xpu, 4, DType> out = out_data[Correlation::kOut].get<xpu, 4, DType>(s);
+    Tensor<xpu, 4, DType> tmp1 = out_data[Correlation::kTemp1].get<xpu, 4, DType>(s);
+    Tensor<xpu, 4, DType> tmp2 = out_data[Correlation::kTemp2].get<xpu, 4, DType>(s);
+    tmp1 = DType(0.0f);
+    tmp2 = DType(0.0f);
Review thread:

Member: Preferably, use static_cast.

Contributor Author: Sure, will address this issue shortly.

Contributor Author: Seems like static_cast is causing some compilation errors; taking a look at it now.
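For context, here is a minimal sketch of the two initialization styles under discussion; the `ZeroFill` helper is hypothetical and not part of the PR. mshadow broadcasts a scalar assignment to every element of the tensor, so the scalar must already carry the tensor's element type:

```cpp
#include <mshadow/tensor.h>

// Hypothetical helper, for illustration only.
template<typename DType>
void ZeroFill(mshadow::Tensor<mshadow::cpu, 4, DType> t) {
  t = DType(0.0f);                  // functional-style cast, as merged in this PR
  // t = static_cast<DType>(0.0f);  // reviewer's suggestion; equivalent for plain
                                    // arithmetic types, but the half-precision
                                    // conversion reportedly broke the build here
}
```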

+    out = DType(0.0f);
     CHECK_EQ(data1.CheckContiguous(), true);
     CHECK_EQ(data2.CheckContiguous(), true);
     CHECK_EQ(out.CheckContiguous(), true);
@@ -124,13 +124,13 @@ class CorrelationOp : public Operator {
                        const std::vector<TBlob> &aux_args) {
     using namespace mshadow;
     Stream<xpu> *s = ctx.get_stream<xpu>();
-    Tensor<xpu, 4> grad_data1 = in_grad[Correlation::kData1].get<xpu, 4, real_t>(s);
-    Tensor<xpu, 4> grad_data2 = in_grad[Correlation::kData2].get<xpu, 4, real_t>(s);
-    Tensor<xpu, 4> out_g = out_grad[Correlation::kOut].get<xpu, 4, real_t>(s);
-    Tensor<xpu, 4> tmp1 = out_data[Correlation::kTemp1].get<xpu, 4, real_t>(s);
-    Tensor<xpu, 4> tmp2 = out_data[Correlation::kTemp2].get<xpu, 4, real_t>(s);
-    if (req[0] != kAddTo) grad_data1 = 0.0f;
-    if (req[1] != kAddTo) grad_data2 = 0.0f;
+    Tensor<xpu, 4, DType> grad_data1 = in_grad[Correlation::kData1].get<xpu, 4, DType>(s);
+    Tensor<xpu, 4, DType> grad_data2 = in_grad[Correlation::kData2].get<xpu, 4, DType>(s);
+    Tensor<xpu, 4, DType> out_g = out_grad[Correlation::kOut].get<xpu, 4, DType>(s);
+    Tensor<xpu, 4, DType> tmp1 = out_data[Correlation::kTemp1].get<xpu, 4, DType>(s);
+    Tensor<xpu, 4, DType> tmp2 = out_data[Correlation::kTemp2].get<xpu, 4, DType>(s);
+    if (req[0] != kAddTo) grad_data1 = DType(0.0f);
+    if (req[1] != kAddTo) grad_data2 = DType(0.0f);
     CHECK_EQ(grad_data1.CheckContiguous(), true);
     CHECK_EQ(grad_data2.CheckContiguous(), true);
     CHECK_EQ(out_g.CheckContiguous(), true);
@@ -163,7 +163,7 @@ class CorrelationOp : public Operator {
 };  // class CorrelationOp
 // Declare Factory function
 template<typename xpu>
-Operator* CreateOp(CorrelationParam param);
+Operator* CreateOp(CorrelationParam param, int dtype);
 #if DMLC_USE_CXX11
 class CorrelationProp : public OperatorProperty {
  public:
@@ -228,6 +228,22 @@ void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override
     out_shape->push_back(Shape4(dshape1[0], paddedbottomheight, paddedbottomwidth, dshape1[1]));
     return true;
   }
+  bool InferType(std::vector<int> *in_type,
+                 std::vector<int> *out_type,
+                 std::vector<int> *aux_type) const override {
+    int dtype = (*in_type)[0];
+    type_assign(&(*in_type)[1], dtype);
+    type_assign(&(*out_type)[0], dtype);
+    type_assign(&(*out_type)[1], dtype);
+    type_assign(&(*out_type)[2], dtype);
+
+    TYPE_ASSIGN_CHECK(*in_type, 0, dtype);
Review thread:

Member: Why do you need both type_assign and TYPE_ASSIGN_CHECK?

Contributor Author: It was advice from Jun; he suggested that we can do mutual inference by reducing all datatypes to one and then assigning the reduced type back to everything.

Contributor Author: If the reduced datatype is not -1, then all inputs and outputs will have the same datatype at the end of this function.
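For reference, `type_assign` (declared in src/operator/operator_common.h) behaves roughly as in the simplified paraphrase below: an unknown type (-1) adopts the other side, and only two known, conflicting types fail. `TYPE_ASSIGN_CHECK` applies the same logic but raises a fatal error on conflict, which is why the reduction pass above uses the former and the write-back pass uses the latter.

```cpp
// Simplified paraphrase of type_assign (not the verbatim MXNet source).
// The convention is that -1 means "type not yet known".
inline bool type_assign(int *y, const int &x) {
  if (*y == -1) { *y = x; return true; }  // y unknown: adopt x
  if (x == -1 || *y == x) return true;    // nothing to learn, or already consistent
  return false;                           // two known, conflicting types
}
```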

+    TYPE_ASSIGN_CHECK(*in_type, 1, dtype);
+    TYPE_ASSIGN_CHECK(*out_type, 0, dtype);
+    TYPE_ASSIGN_CHECK(*out_type, 1, dtype);
+    TYPE_ASSIGN_CHECK(*out_type, 2, dtype);
+    return dtype != -1;
+  }
   OperatorProperty* Copy() const override {
     CorrelationProp* Correlation_sym = new CorrelationProp();
     Correlation_sym->param_ = this->param_;
@@ -244,7 +260,13 @@ void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override
     return {out_grad[Correlation::kOut],
             out_data[Correlation::kTemp1], out_data[Correlation::kTemp2]};
   }
-  Operator* CreateOperator(Context ctx) const override;
+  Operator* CreateOperator(Context ctx) const override {
+    LOG(FATAL) << "Not Implemented.";
+    return NULL;
+  }
+
+  Operator* CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
+                             std::vector<int> *in_type) const override;

  private:
   CorrelationParam param_;
src/operator/correlation.cc (9 additions & 4 deletions)

@@ -148,11 +148,16 @@ inline void CorrelationBackward(const Tensor<cpu, 4, Dtype> &out_grad,
 namespace mxnet {
 namespace op {
 template<>
-Operator *CreateOp<cpu>(CorrelationParam param) {
-  return new CorrelationOp<cpu>(param);
+Operator *CreateOp<cpu>(CorrelationParam param, int dtype) {
+  Operator* op = NULL;
+  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+    op = new CorrelationOp<cpu, DType>(param);
+  });
+  return op;
 }
-Operator* CorrelationProp::CreateOperator(Context ctx) const {
-  DO_BIND_DISPATCH(CreateOp, param_);
+Operator* CorrelationProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
+                                            std::vector<int> *in_type) const {
+  DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0));
 }
 DMLC_REGISTER_PARAMETER(CorrelationParam);
 MXNET_REGISTER_OP_PROPERTY(Correlation, CorrelationProp)
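For readers unfamiliar with the macro, `MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {...})` (defined in mshadow/base.h) expands to roughly the switch below. This is a simplified sketch, not the literal expansion; the point is that it binds `DType` to the concrete element type before instantiating the operator, and rejects non-floating-point dtypes:

```cpp
// Rough shape of the macro expansion inside CreateOp<cpu> (simplified sketch).
switch (dtype) {
  case mshadow::kFloat32: { typedef float DType;
    op = new CorrelationOp<cpu, DType>(param); } break;
  case mshadow::kFloat64: { typedef double DType;
    op = new CorrelationOp<cpu, DType>(param); } break;
  case mshadow::kFloat16: { typedef mshadow::half::half_t DType;
    op = new CorrelationOp<cpu, DType>(param); } break;
  default:
    LOG(FATAL) << "Unknown type enum " << dtype;
}
```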
src/operator/correlation.cu (6 additions & 2 deletions)

@@ -621,8 +621,12 @@ inline void CorrelationBackward(const Tensor<gpu, 4, Dtype> &out_grad,
 namespace mxnet {
 namespace op {
 template<>
-Operator* CreateOp<gpu>(CorrelationParam param) {
-  return new CorrelationOp<gpu>(param);
+Operator* CreateOp<gpu>(CorrelationParam param, int dtype) {
+  Operator* op = NULL;
+  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+    op = new CorrelationOp<gpu, DType>(param);
+  });
+  return op;
 }
 }  // namespace op
 }  // namespace mxnet
tests/python/unittest/test_operator.py (13 additions & 12 deletions)

@@ -2161,12 +2161,12 @@ def correlation_backward(out_grad,tmp1,tmp2,data1,data2,pad_size,kernel_size,str
     return tmp1_grad[:,:,pad_size:pad_size+data1.shape[2],pad_size:pad_size+data1.shape[3]],tmp2_grad[:,:,pad_size:pad_size+data1.shape[2],pad_size:pad_size+data1.shape[3]],


-def unittest_correlation(data_shape,kernel_size,max_displacement,stride1,stride2,pad_size,is_multiply):
+def unittest_correlation(data_shape,kernel_size,max_displacement,stride1,stride2,pad_size,is_multiply,dtype):

     img1 = np.random.random(data_shape)
-    img1 = img1.astype(np.float32)
+    img1 = img1.astype(dtype)
     img2 = np.random.random(data_shape)
-    img2 = img2.astype(np.float32)
+    img2 = img2.astype(dtype)

     net1 = get_correlation(img1,img2,kernel_size,max_displacement,stride1,stride2,pad_size,is_multiply)
     net2 = get_correlation(img1,img2,kernel_size,max_displacement,stride1,stride2,pad_size,is_multiply )
@@ -2198,15 +2198,16 @@ def unittest_correlation(data_shape,kernel_size,max_displacement,stride1,stride2

 @with_seed()
 def test_correlation():
-    unittest_correlation((1,3,10,10), kernel_size = 1,max_displacement = 4,stride1 = 1,stride2 = 1,pad_size = 4,is_multiply = False)
-    unittest_correlation((5,1,15,15), kernel_size = 1,max_displacement = 5,stride1 = 1,stride2 = 1,pad_size = 5,is_multiply = False)
-    unittest_correlation((5,1,15,15), kernel_size = 1,max_displacement = 5,stride1 = 1,stride2 = 1,pad_size = 5,is_multiply = True)
-    unittest_correlation((5,1,15,15), kernel_size = 1,max_displacement = 10,stride1 = 1,stride2 = 2,pad_size = 10,is_multiply = True)
-    unittest_correlation((5,1,4,4), kernel_size = 3,max_displacement = 1,stride1 = 1,stride2 = 1,pad_size = 2,is_multiply = True)
-    unittest_correlation((5,1,4,4), kernel_size = 3,max_displacement = 1,stride1 = 2,stride2 = 1,pad_size = 2,is_multiply = True)
-    unittest_correlation((5,1,4,4), kernel_size = 3,max_displacement = 1,stride1 = 2,stride2 = 1,pad_size = 2,is_multiply = False)
-    unittest_correlation((5,1,6,4), kernel_size = 3,max_displacement = 1,stride1 = 2,stride2 = 1,pad_size = 2,is_multiply = False)
-    unittest_correlation((5,1,11,11), kernel_size = 5,max_displacement = 1,stride1 = 1,stride2 = 1,pad_size = 2,is_multiply = False)
+    for dtype in ['float16', 'float32', 'float64']:
+        unittest_correlation((1,3,10,10), kernel_size = 1,max_displacement = 4,stride1 = 1,stride2 = 1,pad_size = 4,is_multiply = False, dtype = dtype)
+        unittest_correlation((5,1,15,15), kernel_size = 1,max_displacement = 5,stride1 = 1,stride2 = 1,pad_size = 5,is_multiply = False, dtype = dtype)
+        unittest_correlation((5,1,15,15), kernel_size = 1,max_displacement = 5,stride1 = 1,stride2 = 1,pad_size = 5,is_multiply = True, dtype = dtype)
+        unittest_correlation((5,1,15,15), kernel_size = 1,max_displacement = 10,stride1 = 1,stride2 = 2,pad_size = 10,is_multiply = True, dtype = dtype)
+        unittest_correlation((5,1,4,4), kernel_size = 3,max_displacement = 1,stride1 = 1,stride2 = 1,pad_size = 2,is_multiply = True, dtype = dtype)
+        unittest_correlation((5,1,4,4), kernel_size = 3,max_displacement = 1,stride1 = 2,stride2 = 1,pad_size = 2,is_multiply = True, dtype = dtype)
+        unittest_correlation((5,1,4,4), kernel_size = 3,max_displacement = 1,stride1 = 2,stride2 = 1,pad_size = 2,is_multiply = False, dtype = dtype)
+        unittest_correlation((5,1,6,4), kernel_size = 3,max_displacement = 1,stride1 = 2,stride2 = 1,pad_size = 2,is_multiply = False, dtype = dtype)
+        unittest_correlation((5,1,11,11), kernel_size = 5,max_displacement = 1,stride1 = 1,stride2 = 1,pad_size = 2,is_multiply = False, dtype = dtype)


 @with_seed()
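With this change merged, float16 arrays flow through the operator end to end. A minimal usage sketch against the NDArray API (hypothetical shapes and values; assumes a build that includes this PR):

```python
import mxnet as mx
import numpy as np

# Hypothetical example, not part of the PR's test suite.
img1 = mx.nd.array(np.random.random((1, 3, 10, 10)), dtype='float16')
img2 = mx.nd.array(np.random.random((1, 3, 10, 10)), dtype='float16')
out = mx.nd.Correlation(img1, img2, kernel_size=1, max_displacement=4,
                        stride1=1, stride2=1, pad_size=4, is_multiply=False)
assert out.dtype == np.float16  # InferType propagates the input dtype to the outputs
```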