Skip to content

Commit

Permalink
Fix mkldnn reshape (apache#16455)
Browse files Browse the repository at this point in the history
Change-Id: I7e3ed84d02eaac4dbc0637167519e53e7eb8e168
  • Loading branch information
ZhennanQin authored and aaronmarkham committed Oct 16, 2019
1 parent 8c9a8b0 commit 9ec2316
Show file tree
Hide file tree
Showing 9 changed files with 69 additions and 261 deletions.
1 change: 0 additions & 1 deletion src/operator/nn/mkldnn/mkldnn_base-inl.h
Expand Up @@ -194,7 +194,6 @@ bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input)
bool SupportMKLDNNSoftmax(const SoftmaxParam& param, const NDArray &input, const NDArray &output);
bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam &param);
bool SupportMKLDNNTranspose(const TransposeParam& param, const NDArray &data);
bool SupportMKLDNNReshape(const ReshapeParam &param, const NDArray &data);
} // namespace op

static int GetTypeSize(int dtype) {
Expand Down
6 changes: 5 additions & 1 deletion src/operator/nn/mkldnn/mkldnn_base.cc
Expand Up @@ -428,6 +428,7 @@ void FallBackCompute(Compute fn, const AttrState &attrs_states,
const std::vector<NDArray> &outputs) {
std::vector<TBlob> in_blobs(inputs.size());
std::vector<NDArray> in_bufs;
std::vector<OpReqType> new_req = req;
for (size_t i = 0; i < in_blobs.size(); i++) {
// If the input data isn't stored in the default format, we shouldn't
// call data() directly, which will change the layout of the NDArray.
Expand All @@ -452,6 +453,9 @@ void FallBackCompute(Compute fn, const AttrState &attrs_states,
// for inplace, we already converted & copied input above.
if ((req[i] == kWriteTo) || (req[i] == kWriteInplace)) {
const_cast<NDArray &>(output).InvalidateMKLDNNData();
if (req[i] == kWriteInplace) {
new_req[i] = kWriteTo;
}
} else if (req[i] == kAddTo && output.IsMKLDNNData()) {
NDArray temp = outputs[i].Reorder2Default();
temp_src.emplace_back(temp);
Expand All @@ -462,7 +466,7 @@ void FallBackCompute(Compute fn, const AttrState &attrs_states,
out_blobs[i] = output.data();
}

fn(attrs_states, ctx, in_blobs, req, out_blobs);
fn(attrs_states, ctx, in_blobs, new_req, out_blobs);
for (size_t i = 0; i < out_blobs.size(); i++) {
if (req[i] == kAddTo && outputs[i].IsMKLDNNData())
mxnet::common::CastNonDefaultStorage(temp_src, temp_dst, ctx, false);
Expand Down
48 changes: 0 additions & 48 deletions src/operator/nn/mkldnn/mkldnn_flatten-inl.h

This file was deleted.

79 changes: 0 additions & 79 deletions src/operator/nn/mkldnn/mkldnn_flatten.cc

This file was deleted.

5 changes: 0 additions & 5 deletions src/operator/nn/mkldnn/mkldnn_ops-inl.h
Expand Up @@ -131,11 +131,6 @@ void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs,
const OpReqType &req,
const NDArray &output);

void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &input,
const OpReqType &req,
const NDArray &output);
} // namespace op
} // namespace mxnet
#endif // MXNET_USE_MKLDNN == 1
Expand Down
27 changes: 9 additions & 18 deletions src/operator/nn/mkldnn/mkldnn_reshape-inl.h
Expand Up @@ -35,30 +35,21 @@ namespace mxnet {
namespace op {

class MKLDNNReshapeFwd {
protected:
public:
MKLDNNReshapeFwd(const OpReqType &req, const NDArray &input, const NDArray &output);
int GetWorkspaceSize();
void SetNewMem(const NDArray &input, const NDArray &output, void *workspace = nullptr);
void Execute(const NDArray &input, const NDArray &output, void *workspace = nullptr);

private:
std::shared_ptr<mkldnn::memory> data_;
std::shared_ptr<mkldnn::memory> out_;
std::shared_ptr<mkldnn::memory> temp_;
std::vector<mkldnn::primitive> prims_;
bool needInvalidateInput = false;

public:
MKLDNNReshapeFwd(const OpReqType &req,
const NDArray &input,
const NDArray &output);
int GetWorkspaceSize();
void SetNewMem(const NDArray &input,
const NDArray &output,
void* workspace = nullptr);
void Execute(const NDArray &input,
const NDArray &output,
void* workspace = nullptr);
};

typedef ParamOpSign<ReshapeParam> MKLDNNReshapeSignature;
MKLDNNReshapeFwd &GetReshapeForward(const ReshapeParam& param,
const OpReqType &req,
const NDArray &input,
typedef OpSignature MKLDNNReshapeSignature;
MKLDNNReshapeFwd &GetReshapeForward(const OpReqType &req, const NDArray &input,
const NDArray &output);

} // namespace op
Expand Down

0 comments on commit 9ec2316

Please sign in to comment.