Commit 8160f39
Summary:
Added NHWC support for:
1. cudnn_batch_norm & cudnn_batch_norm_backward
2. cudnn_convolution_forward & cudnn_convolution_backward
3. cudnn_convolution_transpose & cudnn_convolution_transpose_backward

Also patched how suggest_memory_format is used for convolution.

suggest_memory_format is ambiguous in two cases:
1. a tensor in NCHW where C == 1.
   The stride of C could be used as a hint to infer the intended memory format.
2. a tensor in NCHW where H == W == 1.
   There is no way to identify the intended memory format from the strides alone.

Currently we fall back to NCHW whenever we see a contiguous tensor, which avoids the
ambiguity for some of these special cases (see the sketch below).
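
A minimal ATen sketch of the two ambiguous cases (shapes and stride values chosen for illustration; assumes a libtorch build where is_contiguous accepts a memory format):

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Case 1: NCHW with C == 1. A tensor produced in NHWC order and viewed as
  // NCHW has strides {16, 1, 4, 1}; it passes both contiguity checks, so only
  // the channel stride hints at the intended layout.
  auto t = at::randn({2, 4, 4, 1}).permute({0, 3, 1, 2});  // sizes {2, 1, 4, 4}
  std::cout << t.is_contiguous() << " "
            << t.is_contiguous(at::MemoryFormat::ChannelsLast) << "\n";  // 1 1

  // Case 2: NCHW with H == W == 1. NCHW and NHWC describe the same physical
  // memory, so the intended format cannot be recovered from strides at all.
  auto u = at::randn({2, 3, 1, 1});
  std::cout << u.is_contiguous() << " "
            << u.is_contiguous(at::MemoryFormat::ChannelsLast) << "\n";  // 1 1

  // With the fallback described above, both report MemoryFormat::Contiguous.
  std::cout << (t.suggest_memory_format() == at::MemoryFormat::Contiguous) << " "
            << (u.suggest_memory_format() == at::MemoryFormat::Contiguous) << "\n";
}
```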
Pull Request resolved: pytorch#23861

Differential Revision: D18263434

Pulled By: VitalyFedyunin

fbshipit-source-id: dd9f69576ec12fec879cd87a3d446931371360d9
jjsjann123 authored and facebook-github-bot committed Nov 4, 2019
1 parent 70f3f23 commit 8160f39
Showing 14 changed files with 301 additions and 55 deletions.
8 changes: 6 additions & 2 deletions aten/src/ATen/cudnn/Descriptors.cpp
@@ -110,7 +110,7 @@ void FilterDescriptor::set(const at::Tensor &t, int64_t pad) {
throw std::runtime_error("cuDNN supports only up to " STR(CUDNN_DIM_MAX) " dimensions");
#undef _STR
#undef STR
if (!t.is_contiguous()) {
if (!t.is_contiguous(t.suggest_memory_format())) {
// NB: It is possible for this test to be insufficient, because the
// Tensor passed in to set the filter descriptor may not be the actual
// Tensor whose data pointer is passed to cuDNN. Nevertheless,
@@ -125,7 +125,11 @@ void FilterDescriptor::set(const at::Tensor &t, int64_t pad) {
size[i] = (int) 1;
}
dim = std::max(dim, pad);
set(getDataType(t), (int) dim, size);
cudnnTensorFormat_t filter_format = CUDNN_TENSOR_NCHW;
if (t.suggest_memory_format() == at::MemoryFormat::ChannelsLast) {
filter_format = CUDNN_TENSOR_NHWC;
}
set(getDataType(t), (int) dim, size, filter_format);
}

}}
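
A standalone sketch of the selection the patched code performs: a channels-last weight maps to the NHWC filter format, everything else keeps NCHW. The helper name is hypothetical; ATen and cuDNN headers are assumed available.

```cpp
#include <ATen/ATen.h>
#include <cudnn.h>

// Hypothetical helper mirroring the branch added above in FilterDescriptor::set.
cudnnTensorFormat_t filter_format_for(const at::Tensor& weight) {
  return weight.suggest_memory_format() == at::MemoryFormat::ChannelsLast
      ? CUDNN_TENSOR_NHWC
      : CUDNN_TENSOR_NCHW;
}
```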
4 changes: 2 additions & 2 deletions aten/src/ATen/cudnn/Descriptors.h
@@ -140,8 +140,8 @@ class FilterDescriptor
void set(const at::Tensor &t, int64_t pad = 0);

private:
void set(cudnnDataType_t dataType, int dim, int* size) {
AT_CUDNN_CHECK(cudnnSetFilterNdDescriptor(mut_desc(), dataType, CUDNN_TENSOR_NCHW, dim, size));
void set(cudnnDataType_t dataType, int dim, int* size, cudnnTensorFormat_t filter_format) {
AT_CUDNN_CHECK(cudnnSetFilterNdDescriptor(mut_desc(), dataType, filter_format, dim, size));
}
};

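For reference, a minimal direct use of cudnnSetFilterNdDescriptor with the NHWC flag (a sketch, not PyTorch code). The dims array stays in (out_channels, in_channels, kH, kW) order, matching the sizes the ATen code passes; only the physical layout cuDNN assumes for the weight buffer changes.

```cpp
#include <cudnn.h>
#include <stdexcept>

// Sketch: tag a 4-d filter descriptor as NHWC. Error handling is simplified.
void set_filter_nhwc(cudnnFilterDescriptor_t desc,
                     int out_c, int in_c, int kh, int kw) {
  int dims[4] = {out_c, in_c, kh, kw};  // logical KCRS sizes, unchanged by the format flag
  cudnnStatus_t st = cudnnSetFilterNdDescriptor(
      desc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NHWC, 4, dims);
  if (st != CUDNN_STATUS_SUCCESS) {
    throw std::runtime_error(cudnnGetErrorString(st));
  }
}
```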
22 changes: 10 additions & 12 deletions aten/src/ATen/native/Convolution.cpp
@@ -539,9 +539,6 @@ at::Tensor _convolution(

const bool input_is_mkldnn = input_r.is_mkldnn();
auto input = input_r;
if (!input_is_mkldnn) {
input = input.contiguous();
}
auto weight = weight_r;
auto bias = bias_r;
auto k = weight.ndimension();
@@ -583,15 +580,15 @@
auto dilation = params.dilation;
if (params.use_cudnn_depthwise(input, weight)) {
output = at::cudnn_convolution(
input, weight, bias,
input.contiguous(input.suggest_memory_format()), weight, bias,
padding, stride, dilation, params.groups, params.benchmark, params.deterministic);

} else if (params.use_miopen(input)){
output = at::miopen_depthwise_convolution(
input, weight, bias,
input.contiguous(), weight, bias,
padding, stride, dilation, params.groups, params.benchmark, params.deterministic);
} else {
output = at::thnn_conv_depthwise2d(input, weight, kernel_size, bias, stride, padding, dilation);
output = at::thnn_conv_depthwise2d(input.contiguous(), weight, kernel_size, bias, stride, padding, dilation);
}
} else if (params.use_cudnn(input)) {
TORCH_CHECK(input.type() == weight.type(),
@@ -603,11 +600,11 @@

if (params.transposed) {
output = at::cudnn_convolution_transpose(
input, weight, bias,
input.contiguous(input.suggest_memory_format()), weight, bias,
params.padding, params.output_padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic);
} else {
output = at::cudnn_convolution(
input, weight, bias,
input.contiguous(input.suggest_memory_format()), weight, bias,
params.padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic);
}
} else if (params.use_miopen(input)) {
@@ -620,11 +617,11 @@

if (params.transposed) {
output = at::miopen_convolution_transpose(
input, weight, bias,
input.contiguous(), weight, bias,
params.padding, params.output_padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic);
} else {
output = at::miopen_convolution(
input, weight, bias,
input.contiguous(), weight, bias,
params.padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic);
}
} else if (params.use_mkldnn(input)) {
@@ -636,7 +633,7 @@
"Input type (", input.type().toString(), ") and bias type (", bias.type().toString(),
") should be the same");
if (!input_is_mkldnn) {
output = at::mkldnn_convolution(input, weight.contiguous(), bias.defined() ? bias.contiguous() : bias,
output = at::mkldnn_convolution(input.contiguous(), weight.contiguous(), bias.defined() ? bias.contiguous() : bias,
params.padding, params.stride, params.dilation, params.groups);
} else {
// do not call contiguous on mkldnn tensor
@@ -650,9 +647,10 @@
input.device().type(), input, weight, bias, params.padding, params.stride, params.groups);
} else if (params.groups == 1) {
output = at::_convolution_nogroup(
input, weight, bias, params.stride, params.padding, params.dilation, params.transposed, params.output_padding);
input.contiguous(), weight, bias, params.stride, params.padding, params.dilation, params.transposed, params.output_padding);
} else {
std::vector<Tensor> outputs(params.groups);
input = input.contiguous();
for (int g = 0; g < params.groups; ++g) {
auto input_g = subtensor(input, 1, params.groups, g);
auto weight_g = subtensor(weight, 0, params.groups, g);
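
A hedged usage sketch of the cuDNN path above: with channels-last input and weight, _convolution now calls input.contiguous(input.suggest_memory_format()), so the NHWC layout is carried into at::cudnn_convolution rather than being flattened back to NCHW. Requires a CUDA build with cuDNN; shapes are illustrative.

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  auto input  = at::randn({8, 3, 32, 32}, at::kCUDA)
                    .contiguous(at::MemoryFormat::ChannelsLast);
  auto weight = at::randn({16, 3, 3, 3}, at::kCUDA)
                    .contiguous(at::MemoryFormat::ChannelsLast);
  auto bias   = at::randn({16}, at::kCUDA);

  // Dispatches through _convolution; on the cuDNN branch the input keeps its
  // suggested (channels-last) memory format.
  auto out = at::conv2d(input, weight, bias, /*stride=*/{1, 1}, /*padding=*/{1, 1});

  // Inspect whether the output also reports channels-last.
  std::cout << (out.suggest_memory_format() == at::MemoryFormat::ChannelsLast)
            << std::endl;
}
```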
14 changes: 10 additions & 4 deletions aten/src/ATen/native/Normalization.cpp
@@ -354,7 +354,7 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu_template(const Tensor
// of backends, while enabling it to keep the information about the used backend, so that it can
// use its corresponding backward implementation.
// XXX: The indices of backends need to be kept synchronized between this function and its _backward.
std::tuple<Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
const Tensor& input, const Tensor& weight /* optional */, const Tensor& bias /* optional */,
const Tensor& running_mean /* optional */, const Tensor& running_var /* optional */,
bool training, double momentum, double eps, bool cudnn_enabled) {
@@ -390,14 +390,16 @@ std::tuple<Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
if (use_cudnn && eps >= detail::getCUDAHooks().batchnormMinEpsilonCuDNN()) {
return std::tuple_cat(
at::cudnn_batch_norm(
input.contiguous(), weight.contiguous(),
input.contiguous(input.suggest_memory_format()), weight.contiguous(),
bias.contiguous(),
running_mean.defined() ? running_mean.contiguous() : running_mean,
running_var.defined() ? running_var.contiguous() : running_var,
training, momentum, eps),
std::make_tuple(1));
}

Tensor reserve = at::empty({0}, input.options().dtype(kByte));

bool use_miopen = (input.is_cuda()
&& input.dim() <= MIOPEN_DIM_MAX
&& input.scalar_type() != at::kDouble
@@ -415,12 +417,14 @@ std::tuple<Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
running_mean.defined() ? running_mean.contiguous() : running_mean,
running_var.defined() ? running_var.contiguous() : running_var,
training, momentum, eps),
std::tuple<Tensor>(reserve),
std::make_tuple(2));
}

return std::tuple_cat(
at::native_batch_norm(
input, weight, bias, running_mean, running_var, training, momentum, eps),
std::tuple<Tensor>(reserve),
std::make_tuple(0));
}

@@ -429,11 +433,13 @@ std::tuple<Tensor, Tensor, Tensor> _batch_norm_impl_index_backward(
const Tensor& input, const Tensor& grad_output, const Tensor& weight /* optional */,
const Tensor& running_mean /* optional */, const Tensor& running_var /* optional */,
const Tensor& save_mean /* optional */, const Tensor& save_var_transform /* optional */,
bool train, double epsilon, std::array<bool, 3> output_mask) {
bool train, double epsilon, std::array<bool, 3> output_mask, const Tensor &reservedSpace) {
if (impl_index == 0) {
return at::native_batch_norm_backward(grad_output, input, weight, running_mean, running_var, save_mean, save_var_transform, train, epsilon, output_mask);
} else if (impl_index == 1) {
return at::cudnn_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var_transform, epsilon);
// TODO: _batch_norm_impl_index_backward is only used in JIT. cudnn NHWC
// format conversion is done inside cudnn_batch_norm_backward instead
return at::cudnn_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var_transform, epsilon, reservedSpace);
} else if (impl_index == 2) {
return at::miopen_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var_transform, epsilon);
}
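
A similar sketch for the batch-norm path: a channels-last input reaching the cuDNN branch is now made contiguous in its suggested memory format instead of being forced back to NCHW. A CUDA build with cuDNN is assumed and the parameter values are illustrative.

```cpp
#include <ATen/ATen.h>

int main() {
  auto input  = at::randn({8, 32, 16, 16}, at::kCUDA)
                    .contiguous(at::MemoryFormat::ChannelsLast);
  auto weight = at::ones({32}, at::kCUDA);
  auto bias   = at::zeros({32}, at::kCUDA);
  auto mean   = at::zeros({32}, at::kCUDA);
  auto var    = at::ones({32}, at::kCUDA);

  // Dispatches to _batch_norm_impl_index; with cudnn_enabled=true and a
  // suitable eps this takes the at::cudnn_batch_norm branch shown above.
  auto out = at::batch_norm(input, weight, bias, mean, var,
                            /*training=*/true, /*momentum=*/0.1, /*eps=*/1e-5,
                            /*cudnn_enabled=*/true);
  return 0;
}
```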
