This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
DType regression #3018
Merged
Merged
DType regression #3018
Changes from 2 commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
761f8bd
DTypeRegressionOutput
Godricly 8922f70
Update DType test for pooling and regression
Godricly 3272038
nullptr fix
Godricly bad103e
fix infershape with {} and nullptr
Godricly 63ee12e
nullptr fix
Godricly e9ae240
Merge branch 'master' into DTypeRegression_commit
Godricly File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,7 +33,7 @@ struct RegressionOutputParam : public dmlc::Parameter<RegressionOutputParam> { | |
|
||
// Special Operator to output regression value in forward | ||
// And get gradient in calculation. | ||
template<typename xpu, typename ForwardOp, typename BackwardOp> | ||
template<typename xpu, typename ForwardOp, typename BackwardOp, typename DType> | ||
class RegressionOutputOp : public Operator { | ||
public: | ||
explicit RegressionOutputOp(RegressionOutputParam param) : param_(param) {} | ||
|
@@ -48,8 +48,8 @@ class RegressionOutputOp : public Operator { | |
CHECK_EQ(in_data.size(), 2) << "RegressionOutputOp Input: [data, label]"; | ||
CHECK_EQ(out_data.size(), 1) << "RegressionOutputOp Output: [output]"; | ||
Stream<xpu> *s = ctx.get_stream<xpu>(); | ||
Tensor<xpu, 2> data = in_data[reg_enum::kData].FlatTo2D<xpu, real_t>(s); | ||
Tensor<xpu, 2> out = out_data[reg_enum::kOut].FlatTo2D<xpu, real_t>(s); | ||
Tensor<xpu, 2, DType> data = in_data[reg_enum::kData].FlatTo2D<xpu, DType>(s); | ||
Tensor<xpu, 2, DType> out = out_data[reg_enum::kOut].FlatTo2D<xpu, DType>(s); | ||
Assign(out, req[reg_enum::kOut], F<ForwardOp>(data)); | ||
} | ||
|
||
|
@@ -69,11 +69,11 @@ class RegressionOutputOp : public Operator { | |
Stream<xpu> *s = ctx.get_stream<xpu>(); | ||
real_t num_output = | ||
in_data[reg_enum::kLabel].Size()/in_data[reg_enum::kLabel].shape_[0]; | ||
Tensor<xpu, 2> out = out_data[reg_enum::kOut].FlatTo2D<xpu, real_t>(s); | ||
Tensor<xpu, 2> grad = in_grad[reg_enum::kData].FlatTo2D<xpu, real_t>(s); | ||
Tensor<xpu, 2> label = in_data[reg_enum::kLabel] | ||
.get_with_shape<xpu, 2, real_t>(out.shape_, s); | ||
Assign(grad, req[reg_enum::kData], param_.grad_scale/num_output* | ||
Tensor<xpu, 2, DType> out = out_data[reg_enum::kOut].FlatTo2D<xpu, DType>(s); | ||
Tensor<xpu, 2, DType> grad = in_grad[reg_enum::kData].FlatTo2D<xpu, DType>(s); | ||
Tensor<xpu, 2, DType> label = in_data[reg_enum::kLabel] | ||
.get_with_shape<xpu, 2, DType>(out.shape_, s); | ||
Assign(grad, req[reg_enum::kData], scalar<DType>(param_.grad_scale/num_output)* | ||
F<BackwardOp>(out, reshape(label, grad.shape_))); | ||
} | ||
|
||
|
@@ -84,7 +84,7 @@ class RegressionOutputOp : public Operator { | |
// Decalre Factory function, used for dispatch specialization | ||
template<typename xpu> | ||
Operator* CreateRegressionOutputOp(reg_enum::RegressionOutputType type, | ||
RegressionOutputParam param); | ||
RegressionOutputParam param, int dtype); | ||
|
||
#if DMLC_USE_CXX11 | ||
template<reg_enum::RegressionOutputType type> | ||
|
@@ -129,6 +129,27 @@ class RegressionOutputProp : public OperatorProperty { | |
return true; | ||
} | ||
|
||
bool InferType(std::vector<int> *in_type, | ||
std::vector<int> *out_type, | ||
std::vector<int> *aux_type) const override { | ||
CHECK_EQ(in_type->size(), 2) << "Input:[data, label]"; | ||
int dtype = (*in_type)[0]; | ||
|
||
auto nin = in_type->size(); | ||
in_type->clear(); | ||
in_type->push_back(dtype); | ||
for (index_t i = 1; i < nin; ++i) in_type->push_back(dtype); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please follow the coding convention: use {} to wrap the loop body, and put it on a separate line. |
||
|
||
if (dtype == -1) { | ||
LOG(FATAL) << "Input type to regression_output is not specified."; | ||
return false; | ||
} | ||
|
||
out_type->clear(); | ||
out_type->push_back(dtype); | ||
return true; | ||
} | ||
|
||
OperatorProperty* Copy() const override { | ||
auto ptr = new RegressionOutputProp<type>(); | ||
ptr->param_ = param_; | ||
|
@@ -165,7 +186,13 @@ class RegressionOutputProp : public OperatorProperty { | |
return {{in_data[reg_enum::kData], out_data[reg_enum::kOut]}}; | ||
} | ||
|
||
Operator* CreateOperator(Context ctx) const override; | ||
Operator* CreateOperator(Context ctx) const override { | ||
LOG(FATAL) << "Not Implemented."; | ||
return NULL; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nullptr |
||
} | ||
|
||
Operator* CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape, | ||
std::vector<int> *in_type) const override; | ||
|
||
protected: | ||
RegressionOutputParam param_; | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,24 +11,38 @@ namespace op { | |
|
||
template<> | ||
Operator *CreateRegressionOutputOp<cpu>(reg_enum::RegressionOutputType type, | ||
RegressionOutputParam param) { | ||
switch (type) { | ||
case reg_enum::kLinear: | ||
return new RegressionOutputOp<cpu, mshadow::op::identity, mshadow::op::minus>(param); | ||
case reg_enum::kLogistic: | ||
return new RegressionOutputOp<cpu, mshadow_op::sigmoid, mshadow::op::minus>(param); | ||
case reg_enum::kMAE: | ||
return new RegressionOutputOp<cpu, mshadow::op::identity, mshadow_op::minus_sign>(param); | ||
default: | ||
LOG(FATAL) << "unknown activation type " << type; | ||
} | ||
return nullptr; | ||
RegressionOutputParam param, int dtype) { | ||
Operator *op = NULL; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use nullptr? |
||
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { | ||
switch (type) { | ||
case reg_enum::kLinear: | ||
op = new RegressionOutputOp | ||
<cpu, mshadow::op::identity, mshadow::op::minus, DType>(param); | ||
break; | ||
case reg_enum::kLogistic: | ||
op = new RegressionOutputOp | ||
<cpu, mshadow_op::sigmoid, mshadow::op::minus, DType>(param); | ||
break; | ||
case reg_enum::kMAE: | ||
op = new RegressionOutputOp | ||
<cpu, mshadow::op::identity, mshadow_op::minus_sign, DType>(param); | ||
break; | ||
default: | ||
LOG(FATAL) << "unknown RegressionOutput type " << type; | ||
} | ||
}); | ||
return op; | ||
} | ||
|
||
// DO_BIND_DISPATCH comes from operator_common.h | ||
template<reg_enum::RegressionOutputType type> | ||
Operator *RegressionOutputProp<type>::CreateOperator(Context ctx) const { | ||
DO_BIND_DISPATCH(CreateRegressionOutputOp, type, param_); | ||
Operator *RegressionOutputProp<type>::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape, | ||
std::vector<int> *in_type) const { | ||
std::vector<TShape> out_shape, aux_shape; | ||
std::vector<int> out_type, aux_type; | ||
CHECK(InferType(in_type, &out_type, &aux_type)); | ||
CHECK(InferShape(in_shape, &out_shape, &aux_shape)); | ||
DO_BIND_DISPATCH(CreateRegressionOutputOp, type, param_, (*in_type)[0]); | ||
} | ||
|
||
DMLC_REGISTER_PARAMETER(RegressionOutputParam); | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this correct? What if Dtype is int8?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for your review. Which line are you worried about? I think int8 is currently not supported, as we are using
MSHADOW_REAL_TYPE_SWITCH
to create operators. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, merge as is for now. We will have to change a lot for the coming uint8 support...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yep... I'll propose a PR to fix the nullptr issue in other operators later. Is there a list of papers to share about uint8 networks?
There are some other issues to be fixed to make DType networks really work, like kvstore and param init. For the LSTM case, an extra data_type for the states is needed. We need to discuss this DType support further and draw up a more detailed todo list.