In [1]:
import numpy as np

In [2]:
def get_im2col_indices(x_shape, field_height=3, field_width=3, padding=1, stride=1):
    # First figure out what the size of the output should be
    N, C, H, W = x_shape
    assert (H + 2 * padding - field_height) % stride == 0
    assert (W + 2 * padding - field_height) % stride == 0
    out_height = (H + 2 * padding - field_height) / stride + 1
    out_width = (W + 2 * padding - field_width) / stride + 1
    print(out_height)
    print(out_width)
    i0 = np.repeat(np.arange(field_height,dtype='int32'), field_width)
    i0 = np.tile(i0, C)
    i1 = stride * np.repeat(np.arange(out_height,dtype='int32'), out_width)
    j0 = np.tile(np.arange(field_width), field_height * C)
    j1 = stride * np.tile(np.arange(out_width,dtype='int32'), int(out_height))
    i = i0.reshape(-1, 1) + i1.reshape(1, -1)
    j = j0.reshape(-1, 1) + j1.reshape(1, -1)

    k = np.repeat(np.arange(C,dtype='int32'), field_height * field_width).reshape(-1, 1)

    return (k, i, j)

In [3]:
def im2col_indices(x, field_height=3, field_width=3, padding=1, stride=1):
    """ An implementation of im2col based on some fancy indexing """
    # Zero-pad the input
    p = padding
    x_padded = np.pad(x, ((0, 0), (0, 0), (p, p), (p, p)), mode='constant')

    k, i, j = get_im2col_indices(x.shape, field_height, field_width, padding,
                               stride)

    cols = x_padded[:, k, i, j]
    C = x.shape[1]
    cols = cols.transpose(1, 2, 0).reshape(field_height * field_width * C, -1)
    return cols

In [4]:
def col2im_indices(cols, x_shape, field_height=3, field_width=3, padding=1,
                   stride=1):
    """ An implementation of col2im based on fancy indexing and np.add.at """
    N, C, H, W = x_shape
    H_padded, W_padded = H + 2 * padding, W + 2 * padding

    x_padded = np.zeros((N, C, H_padded, W_padded), dtype=cols.dtype)
    k, i, j = get_im2col_indices(x_shape, field_height, field_width, padding,
                               stride)
    cols_reshaped = cols.reshape(C * field_height * field_width, -1, N)
    cols_reshaped = cols_reshaped.transpose(2, 0, 1)
    np.add.at(x_padded, (slice(None), k, i, j), cols_reshaped)
    if padding == 0:
        return x_padded
    return x_padded[:, :, padding:-padding, padding:-padding]

In [5]:
x = np.array([1,2,3,4,5,6,7,8,9,9,8,7,6,5,4,3,2,1,1,2,3,4,5,6,7,8,9,9,8,7,6,5,4,3,2,1])
x = x.reshape(2,2,3,3)
cols = im2col_indices(x)
print(cols)

im = col2im_indices(cols, x.shape)
print(im)

3.0
3.0
[[0 0 0 0 0 0 0 0 1 1 2 2 0 0 4 4 5 5]
 [0 0 0 0 0 0 1 1 2 2 3 3 4 4 5 5 6 6]
 [0 0 0 0 0 0 2 2 3 3 0 0 5 5 6 6 0 0]
 [0 0 1 1 2 2 0 0 4 4 5 5 0 0 7 7 8 8]
 [1 1 2 2 3 3 4 4 5 5 6 6 7 7 8 8 9 9]
 [2 2 3 3 0 0 5 5 6 6 0 0 8 8 9 9 0 0]
 [0 0 4 4 5 5 0 0 7 7 8 8 0 0 0 0 0 0]
 [4 4 5 5 6 6 7 7 8 8 9 9 0 0 0 0 0 0]
 [5 5 6 6 0 0 8 8 9 9 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 9 9 8 8 0 0 6 6 5 5]
 [0 0 0 0 0 0 9 9 8 8 7 7 6 6 5 5 4 4]
 [0 0 0 0 0 0 8 8 7 7 0 0 5 5 4 4 0 0]
 [0 0 9 9 8 8 0 0 6 6 5 5 0 0 3 3 2 2]
 [9 9 8 8 7 7 6 6 5 5 4 4 3 3 2 2 1 1]
 [8 8 7 7 0 0 5 5 4 4 0 0 2 2 1 1 0 0]
 [0 0 6 6 5 5 0 0 3 3 2 2 0 0 0 0 0 0]
 [6 6 5 5 4 4 3 3 2 2 1 1 0 0 0 0 0 0]
 [5 5 4 4 0 0 2 2 1 1 0 0 0 0 0 0 0 0]]
3.0
3.0
[[[[ 4 12 12]
   [24 45 36]
   [28 48 36]]

  [[36 48 28]
   [36 45 24]
   [12 12  4]]]


 [[[ 4 12 12]
   [24 45 36]
   [28 48 36]]

  [[36 48 28]
   [36 45 24]
   [12 12  4]]]]


In [6]:
def conv_forward(X, W, b, stride=1, padding=1):
    cache = W, b, stride, padding
    n_filters, d_filter, h_filter, w_filter = W.shape
    n_x, d_x, h_x, w_x = X.shape
    h_out = (h_x - h_filter + 2 * padding) / stride + 1
    w_out = (w_x - w_filter + 2 * padding) / stride + 1

    if not h_out.is_integer() or not w_out.is_integer():
        raise Exception('Invalid output dimension!')

    h_out, w_out = int(h_out), int(w_out)

    X_col = im2col_indices(X, h_filter, w_filter, padding=padding, stride=stride)
    W_col = W.reshape(n_filters, -1)

    out = W_col @ X_col + b
    out = out.reshape(n_filters, h_out, w_out, n_x)
    out = out.transpose(3, 0, 1, 2)

    cache = (X, W, b, stride, padding, X_col)

    return out, cache


def conv_backward(dout, cache):
    X, W, b, stride, padding, X_col = cache
    n_filter, d_filter, h_filter, w_filter = W.shape

    db = np.sum(dout, axis=(0, 2, 3))
    db = db.reshape(n_filter, -1)

    dout_reshaped = dout.transpose(1, 2, 3, 0).reshape(n_filter, -1)
    dW = dout_reshaped @ X_col.T
    dW = dW.reshape(W.shape)

    W_reshape = W.reshape(n_filter, -1)
    dX_col = W_reshape.T @ dout_reshaped
    dX = col2im_indices(dX_col, X.shape, h_filter, w_filter, padding=padding, stride=stride)

    return dX, dW, db

In [None]:
(out, cache)= conv_forward(x, )