Skip to content

Commit

Permalink
add zhangshen op-test (#5600)
Browse files Browse the repository at this point in the history
* add some op-test

* fix dims_error in my branch

* Fix the bad backward kernel function by using 'cuda::atomic::Add' (#5614)

* Test `nn.AdaptiveAvgPoolXd` (#5615)

* Fix the bad backward kernel function by using 'cuda::atomic::Add'

* Support the 'NoneType' annotation

* Support objects of 'collections.abc.Iterable' as 'output_size'

* Test with all cases of 'output_size'

* Update adaptive_pool_gpu_kernel.cu

* Skip testing `nn.AdaptiveAvgPool3d` for the current PyTorch

* remove some useless test

* Format TODO

* Add the assertion messages for 'output_size'

* Reformat codes

* Remove raw tests for `flow.negative`

* Remove unnecessary codes and add the assertion messages

* Merge updates for 'generators.py' from master

* Remove unnecessary 'random()'

* Delete the separate test for `AvgPool2d`

* Fix import paths

* Fix import problems

* Remove the PyTorch import

* Denote the annotations for `tile` and `repeat` ops

* Add the test for `nn.AvgPool1d`

* Choose better generators for `nn.MaxPoolXd`

* Randomly choose `dilation` and default values

* auto format by CI

* Test more kwargs for `nn.AvgPoolXd`

* Add tests for `return_indices`

* auto format by CI

Co-authored-by: Tianyu Zhao <guikarist@gmail.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
  • Loading branch information
4 people committed Aug 19, 2021
1 parent 4a4b382 commit a99929d
Show file tree
Hide file tree
Showing 12 changed files with 173 additions and 1,629 deletions.
12 changes: 4 additions & 8 deletions oneflow/user/kernels/adaptive_pool_cpu_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ void AvgForwardCompute(user_op::KernelComputeContext* ctx, const int32_t& dim) {
const Shape& x_shape = ctx->TensorDesc4ArgNameAndIndex("x", 0)->shape();
const Shape& y_shape = ctx->TensorDesc4ArgNameAndIndex("y", 0)->shape();

// TODO: Support 'channels_last'
// TODO (Tianyu): Support 'channels_last'
std::string data_format = "channels_first";
const Shape& in = GetShape5D(x_shape, data_format, dim);
const Shape& out = GetShape5D(y_shape, data_format, dim);
Expand Down Expand Up @@ -100,7 +100,7 @@ void AvgBackwardCompute(user_op::KernelComputeContext* ctx, const int32_t& dim)
const Shape& dx_shape = ctx->TensorDesc4ArgNameAndIndex("dx", 0)->shape();
const Shape& dy_shape = ctx->TensorDesc4ArgNameAndIndex("dy", 0)->shape();

// TODO: Support 'channels_last'
// TODO (Tianyu): Support 'channels_last'
std::string data_format = "channels_first";
const Shape& in = GetShape5D(dx_shape, data_format, dim);
const Shape& out = GetShape5D(dy_shape, data_format, dim);
Expand Down Expand Up @@ -234,9 +234,7 @@ class AdaptivePool3DCpuGradKernel final : public user_op::OpKernel {
#define REGISTER_ADAPTIVE_POOL_KERNEL_WITH_DEVICE(device) \
REGISTER_ADAPTIVE_POOL_KERNEL(device, float) \
REGISTER_ADAPTIVE_POOL_KERNEL(device, double) \
REGISTER_ADAPTIVE_POOL_KERNEL(device, int8_t) \
REGISTER_ADAPTIVE_POOL_KERNEL(device, int32_t) \
REGISTER_ADAPTIVE_POOL_KERNEL(device, int64_t)
REGISTER_ADAPTIVE_POOL_KERNEL(device, int)

REGISTER_ADAPTIVE_POOL_KERNEL_WITH_DEVICE(DeviceType::kCPU)

Expand All @@ -257,9 +255,7 @@ REGISTER_ADAPTIVE_POOL_KERNEL_WITH_DEVICE(DeviceType::kCPU)
#define REGISTER_ADAPTIVE_POOL_BACKWARD_KERNEL_WITH_DEVICE(device) \
REGISTER_ADAPTIVE_POOL_BACKWARD_KERNEL(device, float) \
REGISTER_ADAPTIVE_POOL_BACKWARD_KERNEL(device, double) \
REGISTER_ADAPTIVE_POOL_BACKWARD_KERNEL(device, int8_t) \
REGISTER_ADAPTIVE_POOL_BACKWARD_KERNEL(device, int32_t) \
REGISTER_ADAPTIVE_POOL_BACKWARD_KERNEL(device, int64_t)
REGISTER_ADAPTIVE_POOL_BACKWARD_KERNEL(device, int)

REGISTER_ADAPTIVE_POOL_BACKWARD_KERNEL_WITH_DEVICE(DeviceType::kCPU)
} // namespace oneflow
20 changes: 11 additions & 9 deletions oneflow/user/kernels/adaptive_pool_gpu_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.
#include "oneflow/core/kernel/kernel_util.cuh"
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/kernel/util/cuda_half_util.h"
#include "oneflow/core/cuda/atomic.cuh"
#include "oneflow/core/operator/operator_util.h"
#include "oneflow/user/utils/pool_util.h"

Expand Down Expand Up @@ -59,6 +60,7 @@ __global__ void AdaptiveAvgPoolCudaKernel(const T* input, T* output, int num_ele
const int in_panel_size = in_d * in_h * in_w;

CUDA_1D_KERNEL_LOOP(idx, num_elems) {
// TODO (Tianyu): Replace the following code with 'NdIndexOffsetHelper'
int bc_idx = idx / out_panel_size;
int out_d_idx = (idx % out_panel_size) / out_w / out_h;
int out_h_idx = (idx % out_panel_size) % (out_h * out_w) / out_w;
Expand Down Expand Up @@ -100,6 +102,7 @@ __global__ void AdaptiveAvgPoolGradCudaKernel(T* input, const T* output, int num
const int in_panel_size = in_d * in_h * in_w;

CUDA_1D_KERNEL_LOOP(idx, num_elems) {
// TODO (Tianyu): Replace the following code with 'NdIndexOffsetHelper'
int bc_idx = idx / out_panel_size;
int out_d_idx = (idx % out_panel_size) / out_w / out_h;
int out_h_idx = (idx % out_panel_size) % (out_h * out_w) / out_w;
Expand All @@ -122,7 +125,10 @@ __global__ void AdaptiveAvgPoolGradCudaKernel(T* input, const T* output, int num
input + bc_idx * in_panel_size + in_start_d * in_h * in_w + in_start_h * in_w + in_start_w;
for (int id = 0; id < k_d; ++id) {
for (int ih = 0; ih < k_h; ++ih) {
for (int iw = 0; iw < k_w; ++iw) { *(input_ptr + ih * in_w + iw) += grad_delta; }
for (int iw = 0; iw < k_w; ++iw) {
// TODO (Tianyu): Use 'atomic::Add' when necessary
cuda::atomic::Add(input_ptr + ih * in_w + iw, grad_delta);
}
}
input_ptr += in_h * in_w; // next input depth
}
Expand All @@ -139,7 +145,7 @@ void AvgForwardCompute(KernelComputeContext* ctx, const int32_t& dim) {
const Shape& x_shape = ctx->TensorDesc4ArgNameAndIndex("x", 0)->shape();
const Shape& y_shape = ctx->TensorDesc4ArgNameAndIndex("y", 0)->shape();

// TODO: Support 'channels_last'
// TODO (Tianyu): Support 'channels_last'
std::string data_format = "channels_first";
const Shape& in = GetShape5D(x_shape, data_format, dim);
const Shape& out = GetShape5D(y_shape, data_format, dim);
Expand All @@ -160,7 +166,7 @@ void AvgBackwardCompute(KernelComputeContext* ctx, const int32_t& dim) {
const Shape& dx_shape = ctx->TensorDesc4ArgNameAndIndex("dx", 0)->shape();
const Shape& dy_shape = ctx->TensorDesc4ArgNameAndIndex("dy", 0)->shape();

// TODO: Support 'channels_last'
// TODO (Tianyu): Support 'channels_last'
std::string data_format = "channels_first";
const Shape& in = GetShape5D(dx_shape, data_format, dim);
const Shape& out = GetShape5D(dy_shape, data_format, dim);
Expand Down Expand Up @@ -258,9 +264,7 @@ class GpuAdaptiveAvgPool3dGradKernel final : public OpKernel {

REGISTER_GPU_ADAPTIVE_AVGPOOL_KERNEL(DeviceType::kGPU, float);
REGISTER_GPU_ADAPTIVE_AVGPOOL_KERNEL(DeviceType::kGPU, double);
REGISTER_GPU_ADAPTIVE_AVGPOOL_KERNEL(DeviceType::kGPU, int8_t);
REGISTER_GPU_ADAPTIVE_AVGPOOL_KERNEL(DeviceType::kGPU, int32_t);
REGISTER_GPU_ADAPTIVE_AVGPOOL_KERNEL(DeviceType::kGPU, int64_t);
REGISTER_GPU_ADAPTIVE_AVGPOOL_KERNEL(DeviceType::kGPU, int);

#define REGISTER_GPU_ADAPTIVE_AVGPOOL_BACKWARD_KERNEL(device, dtype) \
REGISTER_USER_KERNEL("adaptive_avg_pool1d_grad") \
Expand All @@ -278,9 +282,7 @@ REGISTER_GPU_ADAPTIVE_AVGPOOL_KERNEL(DeviceType::kGPU, int64_t);

REGISTER_GPU_ADAPTIVE_AVGPOOL_BACKWARD_KERNEL(DeviceType::kGPU, float);
REGISTER_GPU_ADAPTIVE_AVGPOOL_BACKWARD_KERNEL(DeviceType::kGPU, double);
REGISTER_GPU_ADAPTIVE_AVGPOOL_BACKWARD_KERNEL(DeviceType::kGPU, int8_t);
REGISTER_GPU_ADAPTIVE_AVGPOOL_BACKWARD_KERNEL(DeviceType::kGPU, int32_t);
REGISTER_GPU_ADAPTIVE_AVGPOOL_BACKWARD_KERNEL(DeviceType::kGPU, int64_t);
REGISTER_GPU_ADAPTIVE_AVGPOOL_BACKWARD_KERNEL(DeviceType::kGPU, int);

} // namespace user_op

Expand Down
70 changes: 36 additions & 34 deletions python/oneflow/nn/modules/adaptive_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,23 @@
"""
import oneflow as flow
from oneflow.nn.module import Module
from oneflow.nn.common_types import _size_1_t
from oneflow.nn.modules.utils import _single, _pair, _triple


def _generate_output_size(input_size, output_size):
new_output_size = []
if isinstance(output_size, int):
for _ in range(len(input_size) - 2):
new_output_size.append(output_size)
elif isinstance(output_size, tuple):
assert len(input_size) - 2 == len(
output_size
), f"The length of 'output_size' does not match the input size, {len(input_size) - 2} expected"
for i in range(len(output_size)):
if output_size[i] is None:
new_output_size.append(input_size[i + 2])
else:
assert isinstance(
output_size[i], int
), "numbers in 'output_size' should be integer"
new_output_size.append(output_size[i])
else:
raise ValueError("invalid 'output_size', 'int' or 'tuple' expected")
assert len(input_size) - 2 == len(
output_size
), f"the length of 'output_size' does not match the input size, {len(input_size) - 2} expected"
for i in range(len(output_size)):
if output_size[i] is None:
new_output_size.append(input_size[i + 2])
else:
assert isinstance(
output_size[i], int
), "numbers in 'output_size' should be integer"
new_output_size.append(output_size[i])
return tuple(new_output_size)


Expand All @@ -55,7 +51,7 @@ class AdaptiveAvgPool1d(Module):
>>> import numpy as np
>>> import oneflow as flow
>>> import oneflow.nn as nn
>>> m = nn.AdaptiveAvgPool1d(5)
>>> input = flow.Tensor(np.random.randn(1, 64, 8))
>>> output = m(input)
Expand All @@ -64,19 +60,19 @@ class AdaptiveAvgPool1d(Module):
"""

def __init__(self, output_size) -> None:
def __init__(self, output_size: _size_1_t) -> None:
super().__init__()
self.output_size = output_size
assert output_size is not None, "'output_size' cannot be NoneType"
self.output_size = _single(output_size)

def forward(self, x):
assert len(x.shape) == 3
if isinstance(self.output_size, tuple):
new_output_size = self.output_size[0]
elif isinstance(self.output_size, int):
new_output_size = self.output_size
else:
raise ValueError("'output_size' should be integer or tuple")
return flow.F.adaptive_avg_pool1d(x, output_size=(new_output_size,))
assert (
len(x.shape) == 3 and len(self.output_size) == 1
), "the length of 'output_size' does not match the input size, 1 expected"
assert isinstance(
self.output_size[0], int
), "numbers in 'output_size' should be integer"
return flow.F.adaptive_avg_pool1d(x, output_size=self.output_size)


def adaptive_avg_pool1d(input, output_size):
Expand Down Expand Up @@ -110,7 +106,7 @@ class AdaptiveAvgPool2d(Module):
>>> import numpy as np
>>> import oneflow as flow
>>> import oneflow.nn as nn
>>> m = nn.AdaptiveAvgPool2d((5,7))
>>> input = flow.Tensor(np.random.randn(1, 64, 8, 9))
>>> output = m(input)
Expand All @@ -133,10 +129,13 @@ class AdaptiveAvgPool2d(Module):

def __init__(self, output_size) -> None:
super().__init__()
self.output_size = output_size
assert output_size is not None, "'output_size' cannot be NoneType"
self.output_size = _pair(output_size)

def forward(self, x):
assert len(x.shape) == 4
assert (
len(x.shape) == 4
), f"expected 4-dimensional tensor, but got {len(x.shape)}-dimensional tensor"
new_output_size = _generate_output_size(x.shape, self.output_size)
return flow.F.adaptive_avg_pool2d(x, output_size=new_output_size)

Expand Down Expand Up @@ -172,7 +171,7 @@ class AdaptiveAvgPool3d(Module):
>>> import numpy as np
>>> import oneflow as flow
>>> import oneflow.nn as nn
>>> m = nn.AdaptiveAvgPool3d((5,7,9))
>>> input = flow.Tensor(np.random.randn(1, 64, 8, 9, 10))
>>> output = m(input)
Expand All @@ -195,10 +194,13 @@ class AdaptiveAvgPool3d(Module):

def __init__(self, output_size) -> None:
super().__init__()
self.output_size = output_size
assert output_size is not None, "'output_size' cannot be NoneType"
self.output_size = _triple(output_size)

def forward(self, x):
assert len(x.shape) == 5
assert (
len(x.shape) == 5
), f"expected 5-dimensional tensor, but got {len(x.shape)}-dimensional tensor"
new_output_size = _generate_output_size(x.shape, self.output_size)
return flow.F.adaptive_avg_pool3d(x, output_size=new_output_size)

Expand Down
Loading

0 comments on commit a99929d

Please sign in to comment.