In [19]:
import numpy as np

def np_maxpool2d(x, pool_size=2, stride=2):
    """Reference (pure NumPy) max pooling over the last two axes of a 4-D array.

    Args:
        x: Array whose shape unpacks as (N, C, H, W).
        pool_size: Side length of the square pooling window.
        stride: Step between consecutive windows along H and W.

    Returns:
        Array of shape (N, C, HO, WO) with HO = int((H - pool_size) / stride + 1)
        and WO computed the same way from W. Each entry is the maximum of the
        corresponding window. The buffer comes from ``np.zeros`` with its
        default dtype, so the result is float64.
    """
    batch, channels, height, width = x.shape
    out_h = int((height - pool_size) / stride + 1)
    out_w = int((width - pool_size) / stride + 1)
    pooled = np.zeros((batch, channels, out_h, out_w))

    # Walk every output cell in row-major order and reduce the window it maps to.
    for n, c, row, col in np.ndindex(*pooled.shape):
        top, left = row * stride, col * stride
        window = x[n, c, top:top + pool_size, left:left + pool_size]
        pooled[n, c, row, col] = np.max(window)

    return pooled

def make_maxpool2d(shapeX, pool_size, stride, tgt, tgt_host, func_name, dtype="float32"):
    """Build a TVM function that max-pools an NCHW tensor with a square window.

    Args:
        shapeX: Input shape, unpacked as (N, C, H, W).
        pool_size: Side length of the square pooling window.
        stride: Step between consecutive windows along H and W.
        tgt: Target passed to ``tvm.build``.
        tgt_host: Host target passed to ``tvm.build`` as ``target_host``.
        func_name: Name given to the built function.
        dtype: Element type of the input placeholder.

    Returns:
        The function returned by ``tvm.build``. Its arguments are
        (input, output), where output has shape
        (N, C, (H - pool_size) // stride + 1, (W - pool_size) // stride + 1).
    """
    N, C, H, W = shapeX

    input_mat = tvm.te.placeholder(shapeX, dtype=dtype, name='input_mat')

    di = tvm.te.reduce_axis((0, pool_size), name='di')
    dj = tvm.te.reduce_axis((0, pool_size), name='dj')
    # Count only windows that lie fully inside the input. ``H // stride`` gives
    # the same number only when pool_size == stride; with pool_size > stride it
    # would make ``h*stride+di`` index past the last row/column.
    oh = (H - pool_size) // stride + 1
    ow = (W - pool_size) // stride + 1
    output_mat = tvm.te.compute((N, C, oh, ow),
                                lambda n, c, h, w: tvm.te.max(
                                    input_mat[n, c, h*stride+di, w*stride+dj],
                                    axis=[di, dj]
                                ))

    s = tvm.te.create_schedule(output_mat.op)
    f = tvm.build(s, [input_mat, output_mat], tgt, target_host=tgt_host, name=func_name)

    return f

def np_maxpool2d_backward(grad_output, x, pool_size=2, stride=2):
    """Reference (pure NumPy) gradient of max pooling with respect to its input.

    Args:
        grad_output: Upstream gradient; its shape unpacks as (_, _, HO, WO).
        x: The array that was pooled in the forward pass, shape (N, C, H, W).
        pool_size: Side length of the square pooling window.
        stride: Step between consecutive windows along H and W.

    Returns:
        Array shaped and typed like ``x``. For every window, the matching
        upstream gradient is added at the position of the window's maximum
        (the first one in row-major order if several are equal); all other
        positions stay zero.
    """
    batch, channels, _, _ = x.shape
    _, _, out_h, out_w = grad_output.shape
    dx = np.zeros_like(x)

    for n, c, row, col in np.ndindex(batch, channels, out_h, out_w):
        top, left = row * stride, col * stride
        window = x[n, c, top:top + pool_size, left:left + pool_size]
        # Turn the flat argmax into (row, col) offsets inside the window.
        off_r, off_c = np.unravel_index(np.argmax(window), window.shape)
        # "+=" because overlapping windows may select the same input element.
        dx[n, c, top + off_r, left + off_c] += grad_output[n, c, row, col]

    return dx

def make_maxpool2d_grad(shapeX, pool_size, stride, tgt, tgt_host, func_name, dtype="float32"):
    """Build a TVM function computing the input gradient of 2-D max pooling.

    The computation has two stages:

    1. ``pooled`` recomputes the forward max of every window. It is a separate
       tensor because a TE reduction has to be the top-level expression of a
       compute; it cannot be nested inside another reduction.
    2. ``input_mat_grad`` visits, for each input element (h, w), every window
       that contains it and adds that window's upstream gradient when the
       element equals the window's maximum.

    When several elements of a window equal its maximum, each of them receives
    the window's gradient (np_maxpool2d_backward routes it to the first one
    only), so the two agree whenever window maxima are unique.

    Args:
        shapeX: Input shape, unpacked as (N, C, H, W).
        pool_size: Side length of the square pooling window.
        stride: Step between consecutive windows along H and W.
        tgt: Target passed to ``tvm.build``.
        tgt_host: Host target passed to ``tvm.build`` as ``target_host``.
        func_name: Name given to the built function.
        dtype: Element type of the placeholders.

    Returns:
        The function returned by ``tvm.build``. Its arguments are
        (input, output_grad, input_grad): input and input_grad have shape
        ``shapeX``; output_grad has shape
        (N, C, (H - pool_size) // stride + 1, (W - pool_size) // stride + 1).
    """
    N, C, H, W = shapeX
    # Number of windows that fit entirely inside the input along each axis.
    oh = (H - pool_size) // stride + 1
    ow = (W - pool_size) // stride + 1

    input_mat = tvm.te.placeholder(shapeX, dtype=dtype, name='input_mat')
    output_grad_mat = tvm.te.placeholder((N, C, oh, ow), dtype=dtype, name='output_grad_mat')

    # Stage 1: forward maxima, one per pooling window.
    ki = tvm.te.reduce_axis((0, pool_size), name='ki')
    kj = tvm.te.reduce_axis((0, pool_size), name='kj')
    pooled = tvm.te.compute((N, C, oh, ow),
                            lambda n, c, i, j: tvm.te.max(
                                input_mat[n, c, i*stride+ki, j*stride+kj],
                                axis=[ki, kj]
                            ), name='pooled')

    # Stage 2: (di, dj) is the offset of (h, w) inside a candidate window, so
    # that window's top-left corner is (h - di, w - dj).
    di = tvm.te.reduce_axis((0, pool_size), name='di')
    dj = tvm.te.reduce_axis((0, pool_size), name='dj')
    zero = tvm.tir.const(0, dtype)

    def _grad_at(n, c, h, w):
        top = h - di
        left = w - dj
        i = tvm.tir.floordiv(top, stride)
        j = tvm.tir.floordiv(left, stride)
        # The candidate is a real pooling window only if its corner is
        # non-negative, lies on the stride grid, and maps to an existing output.
        is_window = tvm.te.all(top >= 0, left >= 0,
                               tvm.tir.floormod(top, stride) == 0,
                               tvm.tir.floormod(left, stride) == 0,
                               i < oh, j < ow)
        # Tensor reads that use (i, j) are kept inside the guarded branch so
        # they are evaluated only for valid window indices.
        routed = tvm.te.if_then_else(input_mat[n, c, h, w] == pooled[n, c, i, j],
                                     output_grad_mat[n, c, i, j],
                                     zero)
        return tvm.te.sum(tvm.te.if_then_else(is_window, routed, zero), axis=[di, dj])

    input_mat_grad = tvm.te.compute(shapeX, _grad_at, name='input_mat_grad')

    s = tvm.te.create_schedule(input_mat_grad.op)
    s[input_mat_grad].parallel(input_mat_grad.op.axis[0])

    f = tvm.build(s, [input_mat, output_grad_mat, input_mat_grad], tgt, target_host=tgt_host, name=func_name)

    return f



In [20]:
# Draw the example tensors from a seeded generator so that every
# Restart & Run All produces the same input and upstream gradient.
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4, 4))     # input, laid out as (N, C, H, W)
dout = rng.standard_normal((2, 3, 2, 2))  # upstream gradient, one value per 2x2 window
pool_size = 2
stride = 2

In [21]:
# Forward reference result for the example input.
out = np_maxpool2d(x, pool_size, stride)
# The backward reference defined in this notebook is np_maxpool2d_backward,
# and it takes the upstream gradient first, then the forward input.
dx = np_maxpool2d_backward(dout, x, pool_size, stride)

In [22]:
# First sample, first channel of the input.
x[0, 0]

array([[ 1.11889023,  0.62000086,  0.96244895, -0.63596303],
       [-1.12377302, -0.03996   ,  0.2699869 , -0.05489106],
       [-0.39580806,  0.73798365,  0.16091689,  1.04872109],
       [-0.78010466, -1.05803005, -0.2180376 , -0.22106726]])

In [23]:
# Pooled output for the first sample, first channel.
out[0, 0]

array([[1.11889023, 0.96244895],
       [0.73798365, 1.04872109]])

In [26]:
# Upstream gradient for the first sample, first channel.
dout[0, 0]

array([[-1.08377042,  1.93676884],
       [ 1.05419877, -1.22055425]])

In [24]:
# Input gradient for the first sample, first channel.
dx[0, 0]

array([[-1.08377042,  0.        ,  1.93676884,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  1.05419877,  0.        , -1.22055425],
       [ 0.        ,  0.        ,  0.        ,  0.        ]])