In [1]:
import numpy as np

In [260]:
def separable_conv2d(C, H, W, R_H, R_W, F, input_t, dweights, pweights):
    """Depthwise-separable 2-D convolution with stride 1 and zero "same" padding.

    The input is zero-padded so that the depthwise stage produces an H x W map
    per channel, every channel is filtered with its own R_H x R_W kernel by
    ``conv``, and the C filtered maps are then combined into F output maps by
    a per-pixel weighted sum (the 1x1 "pointwise" stage).

    Parameters
    ----------
    C, H, W : int
        Number of channels, height and width of ``input_t``.
    R_H, R_W : int
        Height and width of each depthwise kernel.
    F : int
        Number of output maps produced by the pointwise stage.
    input_t : ndarray, indexed [channel, row, column]
        Input tensor of shape [C, H, W].
    dweights : ndarray, indexed [channel, kernel_row, kernel_col]
        Depthwise kernels, one per channel.
    pweights : ndarray, indexed [output_map, channel]
        Pointwise weights; row ``f`` holds the C mixing weights of output ``f``.

    Returns
    -------
    ndarray
        Array of shape [F, H, W].
    """
    # With stride 1 the total padding per axis is (kernel size - 1). When that
    # amount is odd, the extra row/column goes to the bottom/right.
    pad_along_height = max((H - 1) + R_H - H, 0)
    pad_along_width = max((W - 1) + R_W - W, 0)
    pad_top = pad_along_height // 2
    pad_bottom = pad_along_height - pad_top
    pad_left = pad_along_width // 2
    pad_right = pad_along_width - pad_left

    # Zero-pad the two spatial axes only; the channel axis is left untouched.
    input_t_pad = np.pad(
        input_t,
        ((0, 0), (pad_top, pad_bottom), (pad_left, pad_right)),
        mode="constant",
    )

    # Depthwise stage: one kernel per channel -> [C, H, W].
    depthwise = conv(C, H, W, R_H, R_W, input_t_pad, dweights)

    # Pointwise stage: each output map is a weighted sum of the C depthwise
    # maps. Broadcasting a (C, 1, 1) view of the weights avoids materialising
    # a full [C, H, W] copy of them for every output map.
    out_t = np.empty([F, H, W])
    for f in range(F):
        out_t[f] = np.sum(depthwise * pweights[f].reshape(C, 1, 1), axis=0)

    return out_t

In [258]:
def conv(C, H, W, R_H, R_W, input_t_pad, dweights):
    """Per-channel (depthwise) 2-D filtering with stride 1 and no extra padding.

    Channel ``l`` of ``input_t_pad`` is filtered with kernel ``dweights[l]``.
    The kernel is applied without flipping: each output value is the sum of
    the element-wise product of the kernel and the window it covers.

    The sliding windows of a channel are unrolled into the rows of a matrix so
    that the whole channel is filtered with a single matrix product.
    See details here:
    https://hal.inria.fr/file/index/docid/112631/filename/p1038112283956.pdf

    Parameters
    ----------
    C : int
        Number of channels.
    H, W : int
        Height and width of the *output* map.
    R_H, R_W : int
        Kernel height and width.
    input_t_pad : ndarray, indexed [channel, row, column]
        Input that is already padded; rows ``0 .. H + R_H - 2`` and columns
        ``0 .. W + R_W - 2`` of every channel are read.
    dweights : ndarray, indexed [channel, kernel_row, kernel_col]
        One R_H x R_W kernel per channel.

    Returns
    -------
    ndarray
        Array of shape [C, H, W].
    """
    res = np.empty([C, H, W])
    window_size = R_H * R_W

    # One row per output pixel (row-major order), one column per kernel tap.
    # The buffer is reused for every channel; all H * W rows are rewritten.
    unrolled = np.empty([H * W, window_size])

    for l in range(C):
        for i in range(H):
            for j in range(W):
                # Row-major index of pixel (i, j) in an H x W map is i * W + j.
                # Using i * H + j here would only be right for square outputs.
                unrolled[i * W + j] = np.ravel(input_t_pad[l, i:i + R_H, j:j + R_W])
        kernel_column = np.ravel(dweights[l]).reshape(window_size, 1)
        res[l] = unrolled.dot(kernel_column).reshape([H, W])

    # Output of the depthwise convolution layer.
    return res

In [261]:
# Smoke test on all-ones tensors: 3 channels, 5x5 input, 2x2 kernels, 8 outputs.
separable_conv2d(
    C=3,
    H=5,
    W=5,
    R_H=2,
    R_W=2,
    F=8,
    input_t=np.ones([3, 5, 5]),
    dweights=np.ones([3, 2, 2]),
    pweights=np.ones([8, 3]),
)

[[[ 4.  4.  4.  4.  2.]
  [ 4.  4.  4.  4.  2.]
  [ 4.  4.  4.  4.  2.]
  [ 4.  4.  4.  4.  2.]
  [ 2.  2.  2.  2.  1.]]

 [[ 4.  4.  4.  4.  2.]
  [ 4.  4.  4.  4.  2.]
  [ 4.  4.  4.  4.  2.]
  [ 4.  4.  4.  4.  2.]
  [ 2.  2.  2.  2.  1.]]

 [[ 4.  4.  4.  4.  2.]
  [ 4.  4.  4.  4.  2.]
  [ 4.  4.  4.  4.  2.]
  [ 4.  4.  4.  4.  2.]
  [ 2.  2.  2.  2.  1.]]]


array([[[ 12.,  12.,  12.,  12.,   6.],
        [ 12.,  12.,  12.,  12.,   6.],
        [ 12.,  12.,  12.,  12.,   6.],
        [ 12.,  12.,  12.,  12.,   6.],
        [  6.,   6.,   6.,   6.,   3.]],

       [[ 12.,  12.,  12.,  12.,   6.],
        [ 12.,  12.,  12.,  12.,   6.],
        [ 12.,  12.,  12.,  12.,   6.],
        [ 12.,  12.,  12.,  12.,   6.],
        [  6.,   6.,   6.,   6.,   3.]],

       [[ 12.,  12.,  12.,  12.,   6.],
        [ 12.,  12.,  12.,  12.,   6.],
        [ 12.,  12.,  12.,  12.,   6.],
        [ 12.,  12.,  12.,  12.,   6.],
        [  6.,   6.,   6.,   6.,   3.]],

       [[ 12.,  12.,  12.,  12.,   6.],
        [ 12.,  12.,  12.,  12.,   6.],
        [ 12.,  12.,  12.,  12.,   6.],
        [ 12.,  12.,  12.,  12.,   6.],
        [  6.,   6.,   6.,   6.,   3.]],

       [[ 12.,  12.,  12.,  12.,   6.],
        [ 12.,  12.,  12.,  12.,   6.],
        [ 12.,  12.,  12.,  12.,   6.],
        [ 12.,  12.,  12.,  12.,   6.],
        [  6.,   6.,   6.,   6.,

In [245]:
# Small 3 x 3 x 3 integer sample tensor, written one 3x3 slice per line.
# Presumably laid out as [channel, row, column] to match how conv indexes its input.
input_t = np.array([
    [[1, 2, 0], [1, 1, 3], [0, 2, 2]],
    [[0, 2, 1], [0, 3, 2], [1, 1, 0]],
    [[1, 2, 1], [0, 1, 3], [3, 3, 2]],
])

In [246]:
# Three 2x2 integer kernels, written one kernel per line.
# Presumably one kernel per channel, matching how conv indexes dweights[l].
dweights = np.array([
    [[1, 1], [2, 2]],
    [[1, 1], [1, 1]],
    [[0, 1], [1, 0]],
])

In [256]:
# A 2x2 output with 2x2 kernels reads exactly a 3x3 region per channel, which
# is the full extent of input_t, so it is passed directly. The previous
# argument, input_t_pad, is not defined by any cell and fails on a fresh kernel.
conv(3, 2, 2, 2, 2, input_t, dweights)

0
1
2


array([[[  7.,  10.],
        [  6.,  12.]],

       [[  5.,   8.],
        [  5.,   6.]],

       [[  2.,   2.],
        [  4.,   6.]]])