diff --git a/aten/src/ATen/cudnn/Descriptors.cpp b/aten/src/ATen/cudnn/Descriptors.cpp
index eef5e52581d0b..b8a0491b82974 100644
--- a/aten/src/ATen/cudnn/Descriptors.cpp
+++ b/aten/src/ATen/cudnn/Descriptors.cpp
@@ -110,7 +110,7 @@ void FilterDescriptor::set(const at::Tensor &t, int64_t pad) {
     throw std::runtime_error("cuDNN supports only up to " STR(CUDNN_DIM_MAX) " dimensions");
 #undef _STR
 #undef STR
-  if (!t.is_contiguous()) {
+  if (!t.is_contiguous(t.suggest_memory_format())) {
     // NB: It is possible for this test to be insufficient, because the
     // Tensor passed in to set the filter descriptor may not be the actual
     // Tensor whose data pointer is passed to cuDNN.  Nevertheless,
@@ -125,7 +125,11 @@ void FilterDescriptor::set(const at::Tensor &t, int64_t pad) {
     size[i] = (int) 1;
   }
   dim = std::max(dim, pad);
-  set(getDataType(t), (int) dim, size);
+  cudnnTensorFormat_t filter_format = CUDNN_TENSOR_NCHW;
+  if (t.suggest_memory_format() == at::MemoryFormat::ChannelsLast) {
+    filter_format = CUDNN_TENSOR_NHWC;
+  }
+  set(getDataType(t), (int) dim, size, filter_format);
 }
 
 }}
diff --git a/aten/src/ATen/cudnn/Descriptors.h b/aten/src/ATen/cudnn/Descriptors.h
index c6f48fe50c6e2..c65bf2351b4f4 100644
--- a/aten/src/ATen/cudnn/Descriptors.h
+++ b/aten/src/ATen/cudnn/Descriptors.h
@@ -140,8 +140,8 @@ class FilterDescriptor
   void set(const at::Tensor &t, int64_t pad = 0);
 
 private:
-  void set(cudnnDataType_t dataType, int dim, int* size) {
-    AT_CUDNN_CHECK(cudnnSetFilterNdDescriptor(mut_desc(), dataType, CUDNN_TENSOR_NCHW, dim, size));
+  void set(cudnnDataType_t dataType, int dim, int* size, cudnnTensorFormat_t filter_format) {
+    AT_CUDNN_CHECK(cudnnSetFilterNdDescriptor(mut_desc(), dataType, filter_format, dim, size));
   }
 };
 
diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp
index 630301ba6e019..48e3ad1715532 100644
--- a/aten/src/ATen/native/Convolution.cpp
+++ b/aten/src/ATen/native/Convolution.cpp
@@ -539,9 +539,6 @@ at::Tensor _convolution(
 
   const bool input_is_mkldnn = input_r.is_mkldnn();
   auto input = input_r;
-  if (!input_is_mkldnn) {
-    input = input.contiguous();
-  }
   auto weight = weight_r;
   auto bias = bias_r;
   auto k = weight.ndimension();
@@ -583,15 +580,15 @@ at::Tensor _convolution(
       auto dilation = params.dilation;
       if (params.use_cudnn_depthwise(input, weight)) {
         output = at::cudnn_convolution(
-            input, weight, bias,
+            input.contiguous(input.suggest_memory_format()), weight, bias,
             padding, stride, dilation, params.groups, params.benchmark, params.deterministic);
 
       } else if (params.use_miopen(input)){
         output = at::miopen_depthwise_convolution(
-            input, weight, bias,
+            input.contiguous(), weight, bias,
             padding, stride, dilation, params.groups, params.benchmark, params.deterministic);
       } else {
-          output = at::thnn_conv_depthwise2d(input, weight, kernel_size, bias, stride, padding, dilation);
+          output = at::thnn_conv_depthwise2d(input.contiguous(), weight, kernel_size, bias, stride, padding, dilation);
       }
   } else if (params.use_cudnn(input)) {
     TORCH_CHECK(input.type() == weight.type(),
@@ -603,11 +600,11 @@ at::Tensor _convolution(
 
     if (params.transposed) {
       output = at::cudnn_convolution_transpose(
-        input, weight, bias,
+        input.contiguous(input.suggest_memory_format()), weight, bias,
         params.padding, params.output_padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic);
     } else {
       output = at::cudnn_convolution(
-        input, weight, bias,
+        input.contiguous(input.suggest_memory_format()), weight, bias,
         params.padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic);
     }
   } else if (params.use_miopen(input)) {
@@ -620,11 +617,11 @@ at::Tensor _convolution(
 
     if (params.transposed) {
       output = at::miopen_convolution_transpose(
-        input, weight, bias,
+        input.contiguous(), weight, bias,
         params.padding, params.output_padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic);
     } else {
       output = at::miopen_convolution(
-        input, weight, bias,
+        input.contiguous(), weight, bias,
         params.padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic);
     }
   } else if (params.use_mkldnn(input)) {
@@ -636,7 +633,7 @@ at::Tensor _convolution(
              "Input type (", input.type().toString(), ") and bias type (", bias.type().toString(),
              ") should be the same");
     if (!input_is_mkldnn) {
-      output = at::mkldnn_convolution(input, weight.contiguous(), bias.defined() ? bias.contiguous() : bias,
+      output = at::mkldnn_convolution(input.contiguous(), weight.contiguous(), bias.defined() ? bias.contiguous() : bias,
                                       params.padding, params.stride, params.dilation, params.groups);
     } else {
       // do not call contiguous on mkldnn tensor
@@ -650,9 +647,10 @@ at::Tensor _convolution(
         input.device().type(), input, weight, bias, params.padding, params.stride, params.groups);
   } else if (params.groups == 1) {
     output = at::_convolution_nogroup(
-        input, weight, bias, params.stride, params.padding, params.dilation, params.transposed, params.output_padding);
+        input.contiguous(), weight, bias, params.stride, params.padding, params.dilation, params.transposed, params.output_padding);
   } else {
     std::vector<Tensor> outputs(params.groups);
+    input = input.contiguous();
     for (int g = 0; g < params.groups; ++g) {
       auto input_g = subtensor(input, 1, params.groups, g);
       auto weight_g = subtensor(weight, 0, params.groups, g);
diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp
index b7cf25e50c2b6..f3bd190dcfeef 100644
--- a/aten/src/ATen/native/Normalization.cpp
+++ b/aten/src/ATen/native/Normalization.cpp
@@ -354,7 +354,7 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu_template(const Tensor
 // of backends, while enabling it to keep the information about the used backend, so that it can
 // use its corresponding backward implementation.
 // XXX: The indices of backends need to be kept synchronized between this function and its _backward.
-std::tuple<Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
+std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
     const Tensor& input, const Tensor& weight /* optional */, const Tensor& bias /* optional */,
     const Tensor& running_mean /* optional */, const Tensor& running_var /* optional */,
     bool training, double momentum, double eps, bool cudnn_enabled) {
@@ -390,7 +390,7 @@ std::tuple<Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
   if (use_cudnn && eps >= detail::getCUDAHooks().batchnormMinEpsilonCuDNN()) {
     return std::tuple_cat(
              at::cudnn_batch_norm(
-                 input.contiguous(), weight.contiguous(),
+                 input.contiguous(input.suggest_memory_format()), weight.contiguous(),
                  bias.contiguous(),
                  running_mean.defined() ? running_mean.contiguous() : running_mean,
                  running_var.defined() ? running_var.contiguous() : running_var,
@@ -398,6 +398,8 @@ std::tuple<Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
              std::make_tuple(1));
   }
 
+  Tensor reserve = at::empty({0}, input.options().dtype(kByte));
+
   bool use_miopen = (input.is_cuda()
                && input.dim() <= MIOPEN_DIM_MAX
                && input.scalar_type() != at::kDouble
@@ -415,12 +417,14 @@ std::tuple<Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
                running_mean.defined() ? running_mean.contiguous() : running_mean,
                running_var.defined() ? running_var.contiguous() : running_var,
                training, momentum, eps),
+             std::tuple<Tensor>(reserve),
              std::make_tuple(2));
   }
 
   return std::tuple_cat(
            at::native_batch_norm(
              input, weight, bias, running_mean, running_var, training, momentum, eps),
+          std::tuple<Tensor>(reserve),
           std::make_tuple(0));
 }
 
@@ -429,11 +433,13 @@ std::tuple<Tensor, Tensor, Tensor> _batch_norm_impl_index_backward(
     const Tensor& input, const Tensor& grad_output, const Tensor& weight /* optional */,
     const Tensor& running_mean /* optional */, const Tensor& running_var /* optional */,
     const Tensor& save_mean /* optional */, const Tensor& save_var_transform /* optional */,
-    bool train, double epsilon, std::array<bool,3> output_mask) {
+    bool train, double epsilon, std::array<bool,3> output_mask, const Tensor &reservedSpace) {
   if (impl_index == 0) {
     return at::native_batch_norm_backward(grad_output, input, weight, running_mean, running_var, save_mean, save_var_transform, train, epsilon, output_mask);
   } else if (impl_index == 1) {
-    return at::cudnn_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var_transform, epsilon);
+    // TODO: _batch_norm_impl_index_backward is only used in JIT. cudnn NHWC
+    // format conversion is done inside cudnn_batch_norm_backward instead
+    return at::cudnn_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var_transform, epsilon, reservedSpace);
   } else if (impl_index == 2) {
     return at::miopen_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var_transform, epsilon);
   }
diff --git a/aten/src/ATen/native/cudnn/BatchNorm.cpp b/aten/src/ATen/native/cudnn/BatchNorm.cpp
index e0b104f9b32fc..07595534fe33c 100644
--- a/aten/src/ATen/native/cudnn/BatchNorm.cpp
+++ b/aten/src/ATen/native/cudnn/BatchNorm.cpp
@@ -9,7 +9,7 @@ namespace at { namespace native {
 
 // See Note [ATen preprocessor philosophy]
 
-std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm(
+std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
     const Tensor& input, const Tensor& weight, const Tensor& bias,
     const Tensor& running_mean, const Tensor& running_var,
     bool training, double exponential_average_factor, double epsilon) {
@@ -20,7 +20,7 @@ std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm_backward(
     const Tensor& input, const Tensor& grad_output, const Tensor& weight,
     const Tensor& running_mean, const Tensor& running_var,
     const Tensor& save_mean, const Tensor& save_var,
-    double epsilon) {
+    double epsilon, const Tensor& reservedSpace) {
   AT_ERROR("cudnn_batch_norm_backward: ATen not compiled with cuDNN support");
 }
 
@@ -49,7 +49,7 @@ Tensor expandScale(const Tensor& t, int64_t dim) {
 
 }  // namespace
 
-std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm(
+std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
     const Tensor& input_t, const Tensor& weight_t, const Tensor& bias_t,
     const Tensor& running_mean_t, const Tensor& running_var_t,
     bool training, double exponential_average_factor, double epsilon)
@@ -74,7 +74,10 @@ std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm(
   }
   checkAllSameType(c, {weight, bias, running_mean, running_var});
   // TODO: is weight required to be contiguous?
-  checkAllContiguous(c, {input, weight, bias, running_mean, running_var});
+  checkAllContiguous(c, {weight, bias, running_mean, running_var});
+  // TODO: TensorArg check should start handle memory format
+  TORCH_CHECK(input->is_contiguous(input->suggest_memory_format()));
+
   checkDimRange(c, input, 2, 6 /* exclusive */);
   auto num_features = input->size(1);
   for (auto t : {weight, bias, running_mean, running_var}) {
@@ -86,6 +89,12 @@ std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm(
   cudnnBatchNormMode_t mode;
   if (input->dim() == 2) {
     mode = CUDNN_BATCHNORM_PER_ACTIVATION;
+  } else if (training) {
+#if CUDNN_VERSION >= 7400
+    mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
+#else
+    mode = CUDNN_BATCHNORM_SPATIAL;
+#endif // CUDNN_VERSION >= 7400
   } else {
     mode = CUDNN_BATCHNORM_SPATIAL;
     // TODO: The new CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode was
@@ -94,7 +103,8 @@ std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm(
     // video R(2+1)D. We will fall back to the normal CUDNN_BATCHNORM_SPATIAL
   }
 
-  auto output_t = at::empty(input->sizes(), input->options());
+  auto output_t = at::empty_like(*input, input->options(), input->suggest_memory_format());
+  TensorArg output{ output_t, "output", 0 };
 
   auto handle = getCudnnHandle();
@@ -106,10 +116,68 @@ std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm(
   Constant zero(dataType, 0);
   Tensor save_mean, save_var;
+  Tensor reserve;
+
   if (training) {
+    int64_t num_features = input_t.size(1);
     save_mean = at::empty({ num_features }, weight_t.options());
     save_var = at::empty({ num_features }, weight_t.options());
+
+#if CUDNN_VERSION >= 7400
+    auto op = CUDNN_BATCHNORM_OPS_BN;
+    size_t workspace_size;
+    AT_CUDNN_CHECK(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
+        handle,
+        mode,
+        op,
+        idesc.desc(),
+        idesc.desc(),
+        idesc.desc(),
+        wdesc.desc(),
+        nullptr,
+        &workspace_size));
+    Tensor workspace = at::empty(workspace_size, input->options().dtype(kByte));
+
+    // get the reserved size and allocate as tensor
+    size_t reserve_size;
+    AT_CUDNN_CHECK(cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
+        handle,
+        mode,
+        op,
+        nullptr,
+        idesc.desc(),
+        &reserve_size));
+    reserve = at::empty(reserve_size, input->options().dtype(kByte));
+
+    AT_CUDNN_CHECK(cudnnBatchNormalizationForwardTrainingEx(
+        handle,
+        mode,
+        op,
+        &one,
+        &zero,
+        idesc.desc(),
+        input->data_ptr(),
+        nullptr,  // z descriptor for BN-Add-Relu
+        nullptr,  // z for BN-Add-ReLU
+        idesc.desc(),
+        output->data_ptr(),
+        wdesc.desc(),
+        weight->data_ptr(),
+        bias->data_ptr(),
+        exponential_average_factor,
+        at::maybe_data_ptr(running_mean),
+        at::maybe_data_ptr(running_var),
+        epsilon,
+        save_mean.data_ptr(),
+        save_var.data_ptr(),
+        nullptr,
+        workspace.data_ptr(),
+        workspace_size,
+        reserve.data_ptr(),
+        reserve_size));
+#else
+    reserve = at::empty({0}, input->options().dtype(kByte));
     AT_CUDNN_CHECK(cudnnBatchNormalizationForwardTraining(
       handle, mode, &one, &zero,
       idesc.desc(), input->data_ptr(),
@@ -123,7 +191,9 @@ std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm(
       epsilon,
       save_mean.data_ptr(),
       save_var.data_ptr()));
+#endif // CUDNN_VERSION >= 7400
   } else {
+    reserve = at::empty({0}, input->options().dtype(kByte));
     AT_CUDNN_CHECK(cudnnBatchNormalizationForwardInference(
       handle, mode, &one, &zero,
       idesc.desc(), input->data_ptr(),
@@ -139,25 +209,29 @@ std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm(
   // save_mean and save_var can be undefined
   // If this causes problems, we can initialize them to empty tensors
   // of the correct type
-  return std::tuple<Tensor, Tensor, Tensor>{output_t, save_mean, save_var};
+  return std::tuple<Tensor, Tensor, Tensor, Tensor>{output_t, save_mean, save_var, reserve};
 }
 
 // NB: CuDNN only implements the backward algorithm for batchnorm
 // in training mode (evaluation mode batchnorm has a different algorithm),
 // which is why this doesn't accept a 'training' parameter.
 std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm_backward(
-    const Tensor& input_t, const Tensor& grad_output_t, const Tensor& weight_t,
+    const Tensor& input_t, const Tensor& grad_output_t,
+    const Tensor& weight_t,
     // Unused: but we require them to be passed so that double backwards
     // has access
     const Tensor& running_mean, const Tensor& running_var,
     const Tensor& save_mean_t, const Tensor& save_var_t,
-    double epsilon)
+    double epsilon, const Tensor& reserveSpace)
 {
+  // TODO: Is it worth it to have a contiguous call or maybe we should go with
+  // whatever format is given here.
   TensorArg input{ input_t, "input", 1 },
-            grad_output{ grad_output_t, "grad_output", 2 },
+            grad_output{ grad_output_t.contiguous(input_t.suggest_memory_format()), "grad_output", 2 },
             weight{ weight_t, "weight", 3 },
             save_mean{ save_mean_t, "save_mean", 4 },
-            save_var{ save_var_t, "save_var", 5 };
+            save_var{ save_var_t, "save_var", 5 },
+            reserve{ reserveSpace, "reserve_space", 6 };
   CheckedFrom c = "cudnn_batch_norm_backward";
   setCuDNNStreamToCurrent();
 
@@ -171,7 +245,10 @@ std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm_backward(
   checkAllSameType(c, {input, grad_output});
   checkAllSameType(c, {weight, save_mean, save_var});
   // TODO: is weight required to be contiguous?
-  checkAllContiguous(c, {input, grad_output, save_mean, save_var});
+  checkAllContiguous(c, {save_mean, save_var});
+  // TODO: TensorArg check should start handle memory format
+  TORCH_CHECK(input->is_contiguous(input->suggest_memory_format()));
+  TORCH_CHECK(grad_output->is_contiguous(grad_output->suggest_memory_format()));
   checkDimRange(c, input, 2, 6 /* exclusive */);
   checkSameSize(c, input, grad_output);
   auto num_features = input->size(1);
@@ -183,30 +260,73 @@ std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm_backward(
   if (input->dim() == 2) {
     mode = CUDNN_BATCHNORM_PER_ACTIVATION;
   } else {
+#if CUDNN_VERSION >= 7400
+    mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
+#else
+    mode = CUDNN_BATCHNORM_SPATIAL;
+#endif // CUDNN_VERSION >= 7400
     // TODO: The new CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode was
     // introduced in CuDNN 7 for performance optimization, but it results in
     // accuracy losses in convolution models such as ResNeXt-101 and
     // video R(2+1)D. We will fall back to the normal CUDNN_BATCHNORM_SPATIAL
-    mode = CUDNN_BATCHNORM_SPATIAL;
   }
 
-  auto grad_input_t = at::empty(input->sizes(), input->options());
+  auto grad_input_t  = at::empty(input->sizes(), input->options(), input->suggest_memory_format());
   auto grad_weight_t = at::empty(weight->sizes(), weight->options());
   auto grad_bias_t   = at::empty(weight->sizes(), weight->options());
 
   auto handle = getCudnnHandle();
   auto dataType = getCudnnDataType(*input);
 
-  TensorDescriptor idesc{ *input, 4 };  // input, output, grad_output descriptor
-  TensorDescriptor wdesc{ expandScale(*weight, input->dim()), 4 };  // descriptor for weight, bias, save_mean, etc.
+  TensorDescriptor idesc{ *input, 4 };  // input, grad_output descriptor
+  TensorDescriptor odesc{ *grad_output, 4 };  // input, grad_output descriptor
+  TensorDescriptor wdesc{ expandScale(*weight, input->dim()), 4 };  // descriptor for weight, save_mean, etc.
 
   Constant one(dataType, 1);
   Constant zero(dataType, 0);
 
+#if CUDNN_VERSION >= 7400
+  auto op = CUDNN_BATCHNORM_OPS_BN;
+
+  size_t workspace_size;
+  AT_CUDNN_CHECK(cudnnGetBatchNormalizationBackwardExWorkspaceSize(
+      handle,
+      mode,
+      op,
+      idesc.desc(),
+      idesc.desc(),
+      idesc.desc(),
+      nullptr,
+      odesc.desc(),
+      wdesc.desc(),
+      nullptr,
+      &workspace_size));
+  Tensor workspace = at::empty(workspace_size, input->options().dtype(kByte));
+
+  AT_CUDNN_CHECK(cudnnBatchNormalizationBackwardEx(
+    handle, mode, op, &one, &zero, &one, &zero,
+    idesc.desc(), input->data_ptr(),
+    nullptr, nullptr,
+    odesc.desc(), grad_output->data_ptr(),
+    nullptr, nullptr,
+    idesc.desc(), grad_input_t.data_ptr(),
+    wdesc.desc(), weight->data_ptr(),
+    nullptr,
+    grad_weight_t.data_ptr(),
+    grad_bias_t.data_ptr(),
+    epsilon,
+    save_mean->data_ptr(),
+    save_var->data_ptr(),
+    nullptr,
+    workspace.data_ptr(),
+    workspace_size,
+    reserve->data_ptr(),
+    reserve->numel()));
+#else
   AT_CUDNN_CHECK(cudnnBatchNormalizationBackward(
     handle, mode, &one, &zero, &one, &zero,
     idesc.desc(), input->data_ptr(),
-    idesc.desc(), grad_output->data_ptr(),
+    odesc.desc(), grad_output->data_ptr(),
     idesc.desc(), grad_input_t.data_ptr(),
     wdesc.desc(), weight->data_ptr(),
     grad_weight_t.data_ptr(),
@@ -214,6 +334,7 @@ std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm_backward(
     epsilon,
     save_mean->data_ptr(),
     save_var->data_ptr()));
+#endif // CUDNN_VERSION >= 7400
 
   return std::tuple<Tensor, Tensor, Tensor>{grad_input_t, grad_weight_t, grad_bias_t};
 }
diff --git a/aten/src/ATen/native/cudnn/Conv.cpp b/aten/src/ATen/native/cudnn/Conv.cpp
index 18997f2a6e8f3..0ad0ddf4570e2 100644
--- a/aten/src/ATen/native/cudnn/Conv.cpp
+++ b/aten/src/ATen/native/cudnn/Conv.cpp
@@ -906,7 +906,8 @@ Tensor cudnn_convolution_forward(
   auto output_t = at::empty(
                     conv_output_size(input->sizes(), weight->sizes(),
                                      padding, stride, dilation, groups),
-                    input->options());
+                    input->options(),
+                    input->suggest_memory_format());
 
   if (output_t.numel() == 0) {
     return output_t;
@@ -917,7 +918,7 @@ Tensor cudnn_convolution_forward(
   convolution_shape_check(c, input, weight, output, padding, stride, dilation, groups);
 
   // See #4500
-  Tensor weight_contig = weight->contiguous();
+  Tensor weight_contig = weight->contiguous(input->suggest_memory_format());
 
   raw_cudnn_convolution_forward_out(
       *output, *input, weight_contig,
@@ -964,7 +965,7 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> cudnn_convolution_transpose_backward(
     IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
     bool benchmark, bool deterministic, std::array<bool,3> output_mask) {
 
-  Tensor grad_output = grad_output_t.contiguous();
+  Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format());
 
   Tensor grad_input, grad_weight, grad_bias;
   if (output_mask[0]) {
@@ -1043,14 +1044,14 @@ Tensor cudnn_convolution_backward_input(
   checkAllSameType(c, {grad_output, weight});
   checkAllSameGPU(c, {grad_output, weight});
 
-  auto grad_input_t = at::empty(input_size, grad_output->options());
+  auto grad_input_t = at::empty(input_size, grad_output->options(), grad_output->suggest_memory_format());
 
   // Avoid "grad_input" when this is being used as transposed convolution
   TensorArg grad_input{ grad_input_t, "result", 0 };
   convolution_shape_check(c, grad_input, weight, grad_output, padding, stride, dilation, groups);
 
   // See #4500
-  Tensor weight_contig = weight->contiguous();
+  Tensor weight_contig = weight->contiguous(grad_output->suggest_memory_format());
 
   raw_cudnn_convolution_backward_input_out(
       *grad_input, *grad_output, weight_contig,
@@ -1090,7 +1091,7 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> cudnn_convolution_backward(
     IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
     bool benchmark, bool deterministic, std::array<bool,3> output_mask) {
 
-  Tensor grad_output = grad_output_t.contiguous();
+  Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format());
 
   Tensor grad_input, grad_weight, grad_bias;
   if (input.numel() == 0) {
@@ -1185,7 +1186,7 @@ Tensor cudnn_convolution_backward_weight(
   checkAllSameType(c, {grad_output, input});
   checkAllSameGPU(c, {grad_output, input});
 
-  auto grad_weight_t = at::empty(weight_size, grad_output->options());
+  auto grad_weight_t = at::empty(weight_size, grad_output->options(), grad_output->suggest_memory_format());
 
   // For uniformity with everything else, although it seems grad_weight
   // would be unambiguous too.
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 581f775460093..aef1193ea2a0f 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -470,9 +470,9 @@
 
 - func: batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor
 
-- func: _batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, int)
+- func: _batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int)
 
-- func: _batch_norm_impl_index_backward(int impl_index, Tensor input, Tensor grad_output, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var_transform, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
+- func: _batch_norm_impl_index_backward(int impl_index, Tensor input, Tensor grad_output, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var_transform, bool train, float eps, bool[3] output_mask, Tensor reservedSpace) -> (Tensor, Tensor, Tensor)
 
 # Sample bernoulli with values in `self` as probability.
 - func: bernoulli(Tensor self, *, Generator? generator=None) -> Tensor
@@ -771,12 +771,12 @@
   dispatch:
     CUDA: cudnn_affine_grid_generator_backward
 
-- func: cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor)
+- func: cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor, Tensor)
   dispatch:
     CUDA: cudnn_batch_norm
 
 # NB: You can only use this if you used cudnn_batch_norm training=True
-- func: cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon) -> (Tensor, Tensor, Tensor)
+- func: cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, Tensor reserveSpace) -> (Tensor, Tensor, Tensor)
   dispatch:
     CUDA: cudnn_batch_norm_backward
 
diff --git a/aten/src/ATen/templates/TensorBody.h b/aten/src/ATen/templates/TensorBody.h
index 02f2b16370a76..3e6362a8bb15a 100644
--- a/aten/src/ATen/templates/TensorBody.h
+++ b/aten/src/ATen/templates/TensorBody.h
@@ -196,7 +196,7 @@ class CAFFE2_API Tensor {
   }
 
   at::MemoryFormat suggest_memory_format() const {
-    if (impl_->is_strides_like_channels_last()) {
+    if (!is_mkldnn() && !is_sparse() && !impl_->is_contiguous() && impl_->is_strides_like_channels_last()) {
       return at::MemoryFormat::ChannelsLast;
     }
     return at::MemoryFormat::Contiguous;
diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/backward_compatibility/check_backward_compatibility.py
index a3742c316895b..aa878caff08ed 100644
--- a/test/backward_compatibility/check_backward_compatibility.py
+++ b/test/backward_compatibility/check_backward_compatibility.py
@@ -27,6 +27,10 @@
     ('ones_like', datetime.date(2019, 11, 11)),
     ('full_like', datetime.date(2019, 11, 11)),
     ('AutogradAnyNonZero', datetime.date(2019, 11, 11)),
+    ('_batch_norm_impl_index', datetime.date(2019, 11, 15)),
+    ('_batch_norm_impl_index_backward', datetime.date(2019, 11, 15)),
+    ('cudnn_batch_norm', datetime.date(2019, 11, 15)),
+    ('cudnn_batch_norm_backward', datetime.date(2019, 11, 15)),
 ]
 
 
diff --git a/test/test_nn.py b/test/test_nn.py
index 52787f1fd5c9b..e0c152f9d29b3 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -5299,6 +5299,36 @@ def func(root):
         gradcheck(func, [v])
         gradgradcheck(func, [v])
 
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
+    @skipIfRocm
+    def test_batchnorm_cudnn_nhwc(self):
+        input = torch.randint(1, 10, (4, 8, 2, 2), dtype=torch.float32, device="cuda", requires_grad=True)
+        input = input.contiguous(memory_format=torch.channels_last)
+        input.retain_grad()
+        grad = torch.randint(1, 10, (4, 8, 2, 2), dtype=torch.float32, device="cuda")
+        grad = grad.contiguous(memory_format=torch.channels_last)
+        bn = nn.BatchNorm2d(8).cuda().float()
+        bn.weight.data.uniform_()
+        bn.bias.data.uniform_()
+
+        ref_input = input.detach().clone().contiguous().requires_grad_(True)
+        ref_grad = grad.detach().clone().contiguous()
+        ref_bn = nn.BatchNorm2d(8).cuda().float()
+        ref_bn.load_state_dict(bn.state_dict())
+
+        out = bn(input)
+        out.backward(grad)
+        ref_out = ref_bn(ref_input)
+        ref_out.backward(ref_grad)
+
+        self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
+        self.assertTrue(ref_out.is_contiguous())
+        self.assertEqual(out, ref_out)
+        self.assertEqual(bn.weight.grad, ref_bn.weight.grad)
+        self.assertEqual(bn.bias.grad, ref_bn.bias.grad)
+        self.assertEqual(input.grad, ref_input.grad)
+
     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
     def test_batchnorm_cudnn_half(self):
         # THNN
@@ -6612,6 +6642,63 @@ def func(*inputs):
 
         return gradgradcheck(func, inputs, (grad_y,))
 
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
+    @skipIfRocm
+    def test_conv_cudnn_nhwc(self):
+        input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, device="cuda", requires_grad=True)
+        input = input.contiguous(memory_format=torch.channels_last)
+        input.retain_grad()
+        grad = torch.rand(2, 4, 2, 2, dtype=torch.float32, device="cuda")
+        grad = grad.contiguous(memory_format=torch.channels_last)
+        conv = nn.Conv2d(8, 4, 3).cuda().float()
+        conv.weight.data = conv.weight.contiguous(memory_format=torch.channels_last)
+
+        ref_input = input.detach().clone().contiguous().requires_grad_(True)
+        ref_grad = grad.detach().clone().contiguous()
+        ref_conv = nn.Conv2d(8, 4, 3).cuda().float()
+        # load_state_dict will restore the stride & memory_layout on ref_conv.weight.
+        ref_conv.load_state_dict(conv.state_dict())
+
+        out = conv(input)
+        out.backward(grad)
+        ref_out = ref_conv(ref_input)
+        ref_out.backward(ref_grad)
+
+        self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
+        self.assertTrue(ref_out.is_contiguous())
+        self.assertEqual(out, ref_out)
+        self.assertEqual(conv.weight.grad, ref_conv.weight.grad)
+        self.assertEqual(conv.bias.grad, ref_conv.bias.grad)
+        self.assertEqual(input.grad, ref_input.grad)
+
+    @unittest.expectedFailure
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
+    @skipIfRocm
+    def test_conv_cudnn_memory_layout_dominance(self):
+        # desired behavior here is to have the memory_layout of conv.weight
+        # dominate the layout of output.
+        # which is not the same as current behavior, we'll fix this in
+        # follow-up PRs and remove the `expectedFailure` tag
+        input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, device="cuda", requires_grad=True)
+        conv = nn.Conv2d(8, 4, 3).cuda().float()
+
+        out = conv(input)
+        self.assertTrue(out.is_contiguous())
+
+        input = input.contiguous(memory_format=torch.channels_last)
+        out = conv(input)
+        self.assertTrue(out.is_contiguous())
+
+        conv.weight.data = conv.weight.contiguous(memory_format=torch.channels_last)
+        out = conv(input)
+        self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
+
+        input = input.contiguous()
+        out = conv(input)
+        self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
+
     def test_conv_double_backward(self):
         batch_size = 2
         for kern, inp_size, dilations in [(3, 6, [1, 2]), (3, 7, [1]), (4, 9, [1])]:
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 35be26f0c747b..8f9291c101ae2 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -1445,15 +1445,16 @@
 # because it should be merged into the previous convolution (left for future
 # work.)
 # NB2: The quotes around the gradient are needed to appease YAML parsing rules.
-- name: cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor)
-  input, weight, bias: "training ? cudnn_batch_norm_backward(input, grad.contiguous(), weight, running_mean, running_var, result1, result2, epsilon) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)"
+- name: cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor, Tensor)
+  input, weight, bias: "training ? cudnn_batch_norm_backward(input, grad.contiguous(input.suggest_memory_format()), weight, running_mean, running_var, result1, result2, epsilon, retain_variables ? result3.clone() : result3) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)"
 
 # HACK: save_mean and save_var are going to be passed in as
 # requires_grad variables (even though we'll never backprop through
 # them) so we need to prevent the unpacking from triggering an error.
-- name: cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon) -> (Tensor, Tensor, Tensor)
+- name: cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, Tensor reserveSpace) -> (Tensor, Tensor, Tensor)
   save_mean: not_implemented("cudnn_batch_norm_backward save_mean")
   save_var: not_implemented("cudnn_batch_norm_backward save_var")
+  reserveSpace: not_implemented("cudnn_batch_norm_backward reserveSpace")
   input, weight, grad_output: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_output, running_mean, running_var, true, epsilon, save_mean, save_var, grad_input_mask)
 
 # nnpack
diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py
index 2fb93baf1d1ee..5c677d018eee8 100644
--- a/tools/autograd/gen_python_functions.py
+++ b/tools/autograd/gen_python_functions.py
@@ -178,6 +178,8 @@
     'std::tuple',
     'std::tuple',
    'std::tuple',
+    'std::tuple<Tensor,Tensor,Tensor,Tensor,int64_t>',
+    'std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor>',
     'std::vector<Tensor>',
     'Scalar', 'bool', 'int64_t', 'void*', 'void',
     'QScheme', 'double',
diff --git a/torch/csrc/autograd/utils/wrap_outputs.h b/torch/csrc/autograd/utils/wrap_outputs.h
index c0edf425a1248..b70646c776805 100644
--- a/torch/csrc/autograd/utils/wrap_outputs.h
+++ b/torch/csrc/autograd/utils/wrap_outputs.h
@@ -125,6 +125,28 @@ inline PyObject* wrap(std::tuple tensors
   return r.release();
 }
 
+inline PyObject* wrap(std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> tensors) {
+  auto r = THPObjectPtr{PyTuple_New(5)};
+  if (!r) throw python_error();
+  PyTuple_SET_ITEM(r.get(), 0, wrap(std::move(std::get<0>(tensors))));
+  PyTuple_SET_ITEM(r.get(), 1, wrap(std::move(std::get<1>(tensors))));
+  PyTuple_SET_ITEM(r.get(), 2, wrap(std::move(std::get<2>(tensors))));
+  PyTuple_SET_ITEM(r.get(), 3, wrap(std::move(std::get<3>(tensors))));
+  PyTuple_SET_ITEM(r.get(), 4, wrap(std::get<4>(tensors)));
+  return r.release();
+}
+
+inline PyObject* wrap(std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> tensors) {
+  auto r = THPObjectPtr{PyTuple_New(5)};
+  if (!r) throw python_error();
+  PyTuple_SET_ITEM(r.get(), 0, wrap(std::move(std::get<0>(tensors))));
+  PyTuple_SET_ITEM(r.get(), 1, wrap(std::move(std::get<1>(tensors))));
+  PyTuple_SET_ITEM(r.get(), 2, wrap(std::move(std::get<2>(tensors))));
+  PyTuple_SET_ITEM(r.get(), 3, wrap(std::move(std::get<3>(tensors))));
+  PyTuple_SET_ITEM(r.get(), 4, wrap(std::move(std::get<4>(tensors))));
+  return r.release();
+}
+
 inline PyObject* wrap(std::tuple tensors) {
   auto r = THPObjectPtr{PyTuple_New(4)};
   if (!r) throw python_error();
diff --git a/torch/csrc/jit/symbolic_script.cpp b/torch/csrc/jit/symbolic_script.cpp
index 9e05f2462a658..0f9db73a54831 100644
--- a/torch/csrc/jit/symbolic_script.cpp
+++ b/torch/csrc/jit/symbolic_script.cpp
@@ -1078,7 +1078,7 @@ const std::vector<std::string> functions = {
                       eps : float,
                       cudnn_enabled : bool):
 
-            output, save1, save2, impl_idx = torch._batch_norm_impl_index(
+            output, save1, save2, reserve, impl_idx = torch._batch_norm_impl_index(
                 input, weight, bias, running_mean, running_var, training,
                 momentum, eps, cudnn_enabled)
             has_weight = weight is not None
@@ -1087,7 +1087,7 @@ const std::vector<std::string> functions = {
             def backward(grad_output):
                 dinput, dweight, dbias = torch._batch_norm_impl_index_backward(
                     impl_idx, input, grad_output, weight, running_mean, running_var,
-                    save1, save2, training, eps, [True, has_weight, has_bias])
+                    save1, save2, training, eps, [True, has_weight, has_bias], reserve)
                 return dinput, dweight, dbias, None, None, None, None, None, None
 
             return output, backward
@@ -1108,7 +1108,7 @@ const std::vector<std::string> functions = {
 
             input_reshape = input.contiguous().view(1, n, -1)
 
-            bn_out, save1, save2, impl_idx = torch._batch_norm_impl_index(
+            bn_out, save1, save2, reserve, impl_idx = torch._batch_norm_impl_index(
                 input_reshape, None, None, None, None, True,
                 0.0, eps, cudnn_enable)
 
@@ -1145,7 +1145,7 @@ const std::vector<std::string> functions = {
 
                 grad_input, _, _ = torch._batch_norm_impl_index_backward(
                     impl_idx, input_reshape, grad_bn_out, None, None, None,
-                    save1, save2, True, eps, [True, False, False])
+                    save1, save2, True, eps, [True, False, False], reserve)
                 grad_input = grad_input.view(input.size())
                 return grad_input, None, grad_weight, grad_bias, None, None
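
Usage note (illustrative only, not part of the patch): the Python sketch below mirrors the new test_nn.py tests and shows how the channels-last (NHWC) cuDNN paths touched above are reached from user code. It assumes a CUDA build with cuDNN available; the shapes and module sizes are arbitrary, and the layout propagation it prints is the behavior this patch enables on the cuDNN path rather than a guarantee of the sketch itself.

    import torch
    import torch.nn as nn

    # Input laid out as NHWC (channels_last). With this patch, cudnn_convolution and
    # cudnn_batch_norm keep that layout instead of forcing a contiguous NCHW copy.
    x = torch.randn(2, 8, 4, 4, device="cuda").contiguous(memory_format=torch.channels_last)
    x.requires_grad_(True)

    conv = nn.Conv2d(8, 4, kernel_size=3).cuda()
    bn = nn.BatchNorm2d(4).cuda()

    out = bn(conv(x))
    out.sum().backward()

    # On the cuDNN path the output stays channels_last and gradients flow as usual.
    print(out.is_contiguous(memory_format=torch.channels_last))
    print(x.grad.shape)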