Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix deconvolution / PR 13421 (#13433)
Browse files Browse the repository at this point in the history
* add test case

* revert refactor

* use with seed decorator

* retrigger

* remove seed

* remove iteration

* remove old test

* update deconvolution test to have filter length that triggers mkldnn reorder
  • Loading branch information
azai91 authored and anirudh2290 committed Dec 1, 2018
1 parent ff4c178 commit 79532d9
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
6 changes: 4 additions & 2 deletions src/operator/nn/mkldnn/mkldnn_deconvolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -262,12 +262,14 @@ void MKLDNNDeconvForward::SetDataHandle(const DeconvolutionParam& param,
// For inference, we want to reorder the weight array so we don't need to
// reorder data every time.
if (weight.IsDefaultData()) {
weight_mem = GetWeights(weight, fwd_pd.weights_primitive_desc(), param.num_group);
// We also need to modify the layout on the original weight array. The
// data conversion happens after the weight array is used.
const_cast<NDArray&>(weight).MKLDNNDataReorderAsync(fwd_pd.weights_primitive_desc());
} else {
weight_mem = weight.GetMKLDNNData();
CHECK(weight_mem->get_primitive_desc() == fwd_pd.weights_primitive_desc());
}
weight_mem = weight.GetMKLDNNData();
CHECK(weight_mem->get_primitive_desc() == fwd_pd.weights_primitive_desc());
}
auto out_mem = CreateMKLDNNMem(out_data[deconv::kOut],
fwd_pd.diff_src_primitive_desc(), req[deconv::kOut]);
Expand Down
10 changes: 5 additions & 5 deletions tests/python/gpu/test_operator_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,11 +683,11 @@ def test_deconvolution_options():
check_consistency_NxM([sym, sym_no_cudnn], ctx_list)

# 2D deconvolution
ctx_list = [{'ctx': mx.gpu(0), 'deconv_data': (2, 2, 10, 10), 'type_dict': {'deconv_data': np.float64}},
{'ctx': mx.gpu(0), 'deconv_data': (2, 2, 10, 10), 'type_dict': {'deconv_data': np.float32}},
{'ctx': mx.gpu(0), 'deconv_data': (2, 2, 10, 10), 'type_dict': {'deconv_data': np.float16}},
{'ctx': mx.cpu(0), 'deconv_data': (2, 2, 10, 10), 'type_dict': {'deconv_data': np.float64}},
{'ctx': mx.cpu(0), 'deconv_data': (2, 2, 10, 10), 'type_dict': {'deconv_data': np.float32}}]
ctx_list = [{'ctx': mx.gpu(0), 'deconv_data': (2, 8, 10, 10), 'type_dict': {'deconv_data': np.float64}},
{'ctx': mx.gpu(0), 'deconv_data': (2, 8, 10, 10), 'type_dict': {'deconv_data': np.float32}},
{'ctx': mx.gpu(0), 'deconv_data': (2, 8, 10, 10), 'type_dict': {'deconv_data': np.float16}},
{'ctx': mx.cpu(0), 'deconv_data': (2, 8, 10, 10), 'type_dict': {'deconv_data': np.float64}},
{'ctx': mx.cpu(0), 'deconv_data': (2, 8, 10, 10), 'type_dict': {'deconv_data': np.float32}}]
# Pad > 0
sym = mx.sym.Deconvolution(num_filter=2, kernel=(3,3), pad=(1,1), name='deconv')
sym_no_cudnn = mx.sym.Deconvolution(num_filter=2, kernel=(3,3), pad=(1,1), cudnn_off=True, name='deconv')
Expand Down

0 comments on commit 79532d9

Please sign in to comment.