In [50]:
import numpy as np

In [51]:
def get_im2col_indices(x_shape, field_height, field_width, padding=1, stride=1):
    """Build fancy-indexing arrays that gather every sliding window of an NCHW input.

    Parameters
    ----------
    x_shape : tuple of int
        Input shape ``(N, C, H, W)``; only ``C``, ``H`` and ``W`` are used.
    field_height, field_width : int
        Height and width of the sliding window (filter).
    padding : int, default 1
        Zero-padding on each side of both spatial dimensions. The returned
        ``i`` / ``j`` index into the *padded* input.
    stride : int, default 1
        Step between consecutive windows.

    Returns
    -------
    (k, i, j) : tuple of ndarray
        ``k`` has shape ``(C*field_height*field_width, 1)`` (channel index);
        ``i`` and ``j`` have shape
        ``(C*field_height*field_width, out_height*out_width)`` (row / column in
        the padded input). Rows are ordered channel-major, then window row,
        then window column; columns enumerate output positions row-major.
        ``x_padded[:, k, i, j]`` therefore has shape
        ``(N, C*field_height*field_width, out_height*out_width)``.

    Raises
    ------
    AssertionError
        If the window/padding/stride do not tile the padded input exactly.
    """
    _, C, H, W = x_shape
    # The window must tile the padded input exactly; otherwise the integer
    # division below would silently drop the last partial step.
    assert (H + 2 * padding - field_height) % stride == 0, (
        "field_height/padding/stride do not tile the input height")
    # Width has to be checked against field_width (checking field_height here
    # rejects valid non-square filters and lets invalid ones through).
    assert (W + 2 * padding - field_width) % stride == 0, (
        "field_width/padding/stride do not tile the input width")
    out_height = (H + 2 * padding - field_height) // stride + 1
    out_width = (W + 2 * padding - field_width) // stride + 1

    # Row offset inside one window, repeated for every channel ...
    i0 = np.tile(np.repeat(np.arange(field_height), field_width), C)
    # ... plus the top row of every output position.
    i1 = stride * np.repeat(np.arange(out_height), out_width)
    # Same construction for columns.
    j0 = np.tile(np.arange(field_width), field_height * C)
    j1 = stride * np.tile(np.arange(out_width), out_height)
    i = i0.reshape(-1, 1) + i1.reshape(1, -1)
    j = j0.reshape(-1, 1) + j1.reshape(1, -1)

    # Channel index per row; broadcasts across all output positions.
    k = np.repeat(np.arange(C), field_height * field_width).reshape(-1, 1)

    return (k, i, j)

def im2col_indices(x, field_height, field_width, padding=1, stride=1):
    """Unfold every sliding window of ``x`` into a column matrix (im2col).

    ``x`` is zero-padded by ``padding`` on both spatial sides, then every
    ``field_height`` x ``field_width`` window (stepping by ``stride``) is
    gathered with fancy indexing.

    Returns a 2-D array with ``C * field_height * field_width`` rows, where
    ``C = x.shape[1]``. Each column is one flattened window. Columns are
    ordered by output position, with the batch index varying fastest.
    """
    pad_spec = ((0, 0), (0, 0), (padding, padding), (padding, padding))
    padded = np.pad(x, pad_spec, mode='constant')

    chan_idx, row_idx, col_idx = get_im2col_indices(
        x.shape, field_height, field_width, padding, stride)

    # Gathered shape: (N, C*fh*fw, positions); move batch last -> (C*fh*fw, positions, N).
    patches = padded[:, chan_idx, row_idx.astype(int), col_idx.astype(int)]
    patches = np.moveaxis(patches, 0, -1)
    n_rows = field_height * field_width * x.shape[1]
    return patches.reshape(n_rows, -1)


def col2im_indices(self, cols, x_shape, field_height=3, field_width=3, padding=1,
                   stride=1):
    """Fold a column matrix back into an NCHW array (reverse layout of im2col).

    Values from overlapping windows are *summed* via ``np.add.at``. This is
    not a true inverse of ``im2col_indices`` when windows overlap.

    Parameters
    ----------
    self : object or None
        Vestigial parameter (the function appears to have been lifted out of a
        class). If ``self`` provides ``get_im2col_indices``, that method is used,
        as before. Otherwise (e.g. ``None``), the module-level
        ``get_im2col_indices`` is used.
    cols : ndarray
        Column matrix with ``C*field_height*field_width`` rows, laid out as
        ``im2col_indices`` produces it (batch index fastest across columns).
    x_shape : tuple of int
        ``(N, C, H, W)`` of the array to reconstruct.
    field_height, field_width, padding, stride : int
        Must match the values used to build ``cols``.

    Returns
    -------
    ndarray
        Shape ``(N, C, H, W)``, dtype ``cols.dtype``.
    """
    # `or` short-circuits: the module-level helper is only looked up when
    # `self` does not supply one, so the function now also works with self=None.
    index_fn = getattr(self, "get_im2col_indices", None) or get_im2col_indices
    N, C, H, W = x_shape
    H_padded, W_padded = H + 2 * padding, W + 2 * padding
    x_padded = np.zeros((N, C, H_padded, W_padded), dtype=cols.dtype)
    k, i, j = index_fn(x_shape, field_height, field_width, padding, stride)
    # Undo im2col's final transpose/reshape: (rows, positions*N) -> (N, rows, positions).
    cols_reshaped = cols.reshape(C * field_height * field_width, -1, N)
    cols_reshaped = cols_reshaped.transpose(2, 0, 1)
    # Unbuffered add so that pixels covered by several windows accumulate.
    np.add.at(x_padded, (slice(None), k, i, j), cols_reshaped)
    if padding == 0:
        # A slice of 0:-0 would be empty, so return the buffer as-is.
        return x_padded
    return x_padded[:, :, padding:-padding, padding:-padding]



In [68]:
N = 3
C = 3
W = 5
H = 5
# NCHW layout (batch, channels, height, width), matching how the im2col
# helpers unpack x.shape; the old [N, C, W, H] order was only harmless because H == W.
data = np.arange(N * C * H * W).reshape([N, C, H, W])

# With N = 3, column 3 is output position 1 of sample 0 (the batch index varies
# fastest across columns): the 3x3 window whose top-left corner is data[0, :, 0, 1].
print(im2col_indices(data, 3, 3, 0, 1)[:, 3])
print(data)


[ 1  2  3  6  7  8 11 12 13 26 27 28 31 32 33 36 37 38 51 52 53 56 57 58
 61 62 63]
[[[[  0   1   2   3   4]
   [  5   6   7   8   9]
   [ 10  11  12  13  14]
   [ 15  16  17  18  19]
   [ 20  21  22  23  24]]

  [[ 25  26  27  28  29]
   [ 30  31  32  33  34]
   [ 35  36  37  38  39]
   [ 40  41  42  43  44]
   [ 45  46  47  48  49]]

  [[ 50  51  52  53  54]
   [ 55  56  57  58  59]
   [ 60  61  62  63  64]
   [ 65  66  67  68  69]
   [ 70  71  72  73  74]]]


 [[[ 75  76  77  78  79]
   [ 80  81  82  83  84]
   [ 85  86  87  88  89]
   [ 90  91  92  93  94]
   [ 95  96  97  98  99]]

  [[100 101 102 103 104]
   [105 106 107 108 109]
   [110 111 112 113 114]
   [115 116 117 118 119]
   [120 121 122 123 124]]

  [[125 126 127 128 129]
   [130 131 132 133 134]
   [135 136 137 138 139]
   [140 141 142 143 144]
   [145 146 147 148 149]]]


 [[[150 151 152 153 154]
   [155 156 157 158 159]
   [160 161 162 163 164]
   [165 166 167 168 169]
   [170 171 172 173 174]]

  [[175 176 177 178 179

In [36]:
# NOTE(review): the saved output below looks stale. Its columns step through
# four samples (0, 75, 150, 225, ...), but `data` is now built with N = 3. Re-run to refresh.
print(im2col_indices(data, field_height=3, field_width=3, padding=0, stride=1))

[[  0  75 150 225   1  76 151 226   2  77 152 227   5  80 155 230   6  81
  156 231   7  82 157 232  10  85 160 235  11  86 161 236  12  87 162 237]
 [  1  76 151 226   2  77 152 227   3  78 153 228   6  81 156 231   7  82
  157 232   8  83 158 233  11  86 161 236  12  87 162 237  13  88 163 238]
 [  2  77 152 227   3  78 153 228   4  79 154 229   7  82 157 232   8  83
  158 233   9  84 159 234  12  87 162 237  13  88 163 238  14  89 164 239]
 [  5  80 155 230   6  81 156 231   7  82 157 232  10  85 160 235  11  86
  161 236  12  87 162 237  15  90 165 240  16  91 166 241  17  92 167 242]
 [  6  81 156 231   7  82 157 232   8  83 158 233  11  86 161 236  12  87
  162 237  13  88 163 238  16  91 166 241  17  92 167 242  18  93 168 243]
 [  7  82 157 232   8  83 158 233   9  84 159 234  12  87 162 237  13  88
  163 238  14  89 164 239  17  92 167 242  18  93 168 243  19  94 169 244]
 [ 10  85 160 235  11  86 161 236  12  87 162 237  15  90 165 240  16  91
  166 241  17  92 167 242  20  9

In [73]:
# Spatial size of each filter: W_w fills axis 2 and W_h fills axis 3 of `w`
# (later passed as field_height and field_width, respectively).
W_w = 3
W_h = 3
N_w = 2  # number of filters
N_c = C  # filter depth equals the input channel count
w = np.arange(N_w * N_c * W_w * W_h).reshape(N_w, N_c, W_w, W_h)
print(w)

[[[[ 0  1  2]
   [ 3  4  5]
   [ 6  7  8]]

  [[ 9 10 11]
   [12 13 14]
   [15 16 17]]

  [[18 19 20]
   [21 22 23]
   [24 25 26]]]


 [[[27 28 29]
   [30 31 32]
   [33 34 35]]

  [[36 37 38]
   [39 40 41]
   [42 43 44]]

  [[45 46 47]
   [48 49 50]
   [51 52 53]]]]


In [70]:
# One row per filter, flattened in (channel, row, col) order: the same row
# order im2col_indices uses for its output.
print(w.reshape(len(w), -1).shape)
print(w.reshape(len(w), -1))

(2, 27)
[[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
  24 25 26]
 [27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
  51 52 53]]


In [71]:
# Convolution as a single matrix product: (N_w, C*W_w*W_h) @ (C*W_w*W_h, positions*N).
dot_prod = w.reshape(w.shape[0], -1) @ im2col_indices(data, W_w, W_h, 0, 1)

In [72]:
# One row per filter. Each run of N consecutive columns holds the N samples
# at one output position (batch index varies fastest, as in im2col_indices).
print(str(dot_prod))

[[ 15219  41544  67869  15570  41895  68220  15921  42246  68571  16974
   43299  69624  17325  43650  69975  17676  44001  70326  18729  45054
   71379  19080  45405  71730  19431  45756  72081]
 [ 37818 118818 199818  38898 119898 200898  39978 120978 201978  43218
  124218 205218  44298 125298 206298  45378 126378 207378  48618 129618
  210618  49698 130698 211698  50778 131778 212778]]
