Skip to content

Commit

Permalink
fix topi.nn.global_pool layout="NHWC" (apache#4656)
Browse files Browse the repository at this point in the history
* Update topi.cc

fix topi.nn.global_pool layout="NHWC"

* add topi.nn.global_pool layout=NHWC test
  • Loading branch information
qihaitao authored and alexwong committed Feb 26, 2020
1 parent 6379157 commit f519f14
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
2 changes: 1 addition & 1 deletion topi/src/topi.cc
Expand Up @@ -527,7 +527,7 @@ TVM_REGISTER_GLOBAL("topi.nn.pool_grad")
TVM_REGISTER_GLOBAL("topi.nn.global_pool")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::global_pool(args[0],
static_cast<nn::PoolType>(static_cast<int>(args[1])));
static_cast<nn::PoolType>(static_cast<int>(args[1])), args[2]);
});

TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool")
Expand Down
16 changes: 12 additions & 4 deletions topi/tests/python/test_topi_pooling.py
Expand Up @@ -178,16 +178,20 @@ def test_pool_grad():
verify_pool_grad(1, 256, 32, 2, 2, [0, 0, 0, 0], 'max', False, add_relu=True)


def verify_global_pool(n, c, h, w, pool_type):
def verify_global_pool(n, c, h, w, pool_type, layout='NCHW'):

assert layout in ["NCHW", "NHWC"]
A = tvm.placeholder((n, c, h, w), name='A')
B = topi.nn.global_pool(A, pool_type=pool_type)
B = topi.nn.global_pool(A, pool_type=pool_type, layout=layout)
B = topi.nn.relu(B)

a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)

axis = (layout.find('H'), layout.find('W'))
if pool_type == 'avg':
b_np = np.mean(a_np, axis=(2,3), keepdims=True)
b_np = np.mean(a_np, axis=axis, keepdims=True)
elif pool_type =='max':
b_np = np.max(a_np, axis=(2,3), keepdims=True)
b_np = np.max(a_np, axis=axis, keepdims=True)
b_np = np.maximum(b_np, 0.0)

def check_device(device):
Expand All @@ -212,6 +216,10 @@ def test_global_pool():
verify_global_pool(4, 1024, 7, 7, 'avg')
verify_global_pool(1, 1024, 7, 7, 'max')
verify_global_pool(4, 1024, 7, 7, 'max')
verify_global_pool(1, 1024, 7, 7, 'avg', 'NHWC')
verify_global_pool(4, 1024, 7, 7, 'avg', 'NHWC')
verify_global_pool(1, 1024, 7, 7, 'max', 'NHWC')
verify_global_pool(4, 1024, 7, 7, 'max', 'NHWC')

def verify_adaptive_pool(dshape, out_size, pool_type, layout="NCHW", dtype="float32"):
def start_index(index, odim, idim):
Expand Down

0 comments on commit f519f14

Please sign in to comment.