In [2]:
import numpy as np

In [3]:
# Sources:
# + https://leonardoaraujosantos.gitbook.io/artificial-inteligence/machine_learning/deep_learning/convolution_layer/making_faster
# + https://jott.live/markdown/as_strided

In [4]:
import torch
import torch.nn as nn
def test_conv2d(func):
  r = np.random.random((3, 3, 10, 10)).astype(np.float32)

  strides = 3

  conv = nn.Conv2d(3, 3, 3, strides, padding=0, dilation=1, bias=False)
  kernel = conv.weight.detach().numpy()

  tout = conv(torch.tensor(r)).detach().numpy()
  out = func(r, kernel, strides)

  print(np.allclose(tout, out))

In [5]:
# Classic 2d convolution (actually it's cross-correlation 🤓)
# https://ezyang.github.io/convolution-visualizer/index.html
def naive_conv2d(x, kernel, strides):
  if isinstance(strides, int):
    strides = (strides, strides)
  
  bs, c, h, w = x.shape
  out_features, in_features, kh, kw = kernel.shape
  sh, sw = strides

  if c != in_features:
    raise ValueError("Input channels must match kernel channels")

  out_h = (h - kh) // sh + 1
  out_w = (w - kw) // sw + 1

  out = np.zeros((bs, out_features, out_h, out_w))

  for b in range(bs):
    for of in range(out_features):
      for i in range(out_h):
        for j in range(out_w):
          for f in range(in_features):
            for k in range(kh):
              for l in range(kw):
                out[b, of, i, j] += x[b, f, i*sh + k, j*sw + l] * kernel[of, f, k, l]

  return out

In [6]:
test_conv2d(naive_conv2d)

True


In [7]:
# To speed up the convolution, we can use the im2col trick and matmul
def naive_im2col(x, sh, sw, kh, kw):
  bs, c, h, w = x.shape

  conv_h = (h - kh) // sh + 1
  conv_w = (w - kw) // sw + 1

  # im2col matrix shape
  out_h = (kh*kw*c)
  out_w = conv_h * conv_w

  out = np.zeros((bs, out_h, out_w))

  for b in range(bs):
    for ch in range(c):
      for i in range(conv_h):
        for j in range(conv_w):
          for k in range(kh):
            for l in range(kw):
              out[b, ch*kh*kw + k*kw + l, i*conv_w + j] = x[b, ch, i*sh + k, j*sw + l]

  return out

def naive_im2col_conv2d(x, kernel, strides):
  if isinstance(strides, int):
    strides = (strides, strides)
  
  bs, c, h, w = x.shape
  out_features, in_features, kh, kw = kernel.shape
  sh, sw = strides

  if c != in_features:
    raise ValueError("Input channels must match kernel channels")

  out_h = (h - kh) // sh + 1
  out_w = (w - kw) // sw + 1

  x_col = naive_im2col(x, sh, sw, kh, kw)
  kernel_col = kernel.reshape(out_features, -1)

  return np.matmul(kernel_col, x_col).reshape(bs, out_features, out_h, out_w)


In [8]:
# Example from https://leonardoaraujosantos.gitbook.io/artificial-inteligence/machine_learning/deep_learning/convolution_layer/making_faster
r = np.arange(3*4*4).reshape(1,3, 4, 4)
print(r)
naive_im2col(r+1, 1, 1, 2, 2)

[[[[ 0  1  2  3]
   [ 4  5  6  7]
   [ 8  9 10 11]
   [12 13 14 15]]

  [[16 17 18 19]
   [20 21 22 23]
   [24 25 26 27]
   [28 29 30 31]]

  [[32 33 34 35]
   [36 37 38 39]
   [40 41 42 43]
   [44 45 46 47]]]]


array([[[ 1.,  2.,  3.,  5.,  6.,  7.,  9., 10., 11.],
        [ 2.,  3.,  4.,  6.,  7.,  8., 10., 11., 12.],
        [ 5.,  6.,  7.,  9., 10., 11., 13., 14., 15.],
        [ 6.,  7.,  8., 10., 11., 12., 14., 15., 16.],
        [17., 18., 19., 21., 22., 23., 25., 26., 27.],
        [18., 19., 20., 22., 23., 24., 26., 27., 28.],
        [21., 22., 23., 25., 26., 27., 29., 30., 31.],
        [22., 23., 24., 26., 27., 28., 30., 31., 32.],
        [33., 34., 35., 37., 38., 39., 41., 42., 43.],
        [34., 35., 36., 38., 39., 40., 42., 43., 44.],
        [37., 38., 39., 41., 42., 43., 45., 46., 47.],
        [38., 39., 40., 42., 43., 44., 46., 47., 48.]]])

In [9]:
test_conv2d(naive_im2col_conv2d)

True


In [10]:
# Benchmark the two implementations
import timeit
iterations = 10
r = np.random.random((3, 3, 100, 100)).astype(np.float32)
kernel = np.random.random((3, 3, 3, 3)).astype(np.float32)

torch_conv = nn.Conv2d(3, 3, 3, 1, padding=0, dilation=1, bias=False)
torch_conv.weight.data = torch.tensor(kernel)

print("Raw python: ", timeit.timeit(lambda: naive_conv2d(r, kernel, 1), number=iterations))
print("im2col: ", timeit.timeit(lambda: naive_im2col_conv2d(r, kernel, 1), number=iterations))
print("pytorch's: ", timeit.timeit(lambda: torch_conv(torch.tensor(r)), number=iterations))

Raw python:  21.136627309000687
im2col:  5.1284166109999205
pytorch's:  0.10551550399941334


*im2col* convolution is clearly faster. But our naive implementation lags behind pytorch's. We can improve it by using as_strided tricks.

In [11]:
def as_strided(x, shape, strides):
  nb = x.dtype.itemsize
  strides = tuple(stride*nb for stride in strides)
  return np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides)

def im2col(x, sh, sw, kh, kw):
  bs, c, h, w = x.shape

  conv_h = (h - kh) // sh + 1
  conv_w = (w - kw) // sw + 1

  # Former loop rewritten for easier as_strided usage
  # out_h = (kh*kw*c)
  # out_w = conv_h * conv_w
  # out = np.zeros((bs * out_h * out_w))
  # out_bs_strides = out_h * out_w
  # out_h_strides = out_w
  # x = x.ravel()
  # x_bs_strides = c * h * w
  # x_c_strides = h * w
  # x_h_strides = w
  # for b in range(bs):
  #   for ch in range(c):
  #     for i in range(conv_h):
  #       for j in range(conv_w):
  #         for k in range(kh):
  #           for l in range(kw):
  #             out[b*out_bs_strides + 
  #                 (ch*kh*kw + k*kw + l)*out_h_strides +
  #                 (i*conv_w + j)] = \
  #             x[b*x_bs_strides + ch * x_c_strides +
  #               (i*sh + k)*x_h_strides +
  #               j*sw + l]
  # out = out.reshape(bs, out_h, out_w)

  # Now it's easier to use as_strided

  # x_idx = b*x_bs_strides + ch * x_c_strides +
  #         (i*sh + k)*x_h_strides +
  #         j*sw + l
  # d (x_idx) / db = x_bs_strides
  # d (x_idx) / dch = x_c_strides
  # d (x_idx) / di = sh*x_h_strides
  # d (x_idx) / dj = sw
  # d (x_idx) / dk = x_h_strides
  # d (x_idx) / dl = 1

  x_bs_strides = c * h * w
  x_c_strides = h * w
  x_h_strides = w
  out = as_strided(x, (bs, c, conv_h, conv_w, kh, kw), (x_bs_strides, x_c_strides, sh*x_h_strides, sw, x_h_strides, 1))
  # for b in range(bs):
  #   for ch in range(c):
  #     for i in range(conv_h):
  #       for j in range(conv_w):
  #         for k in range(kh):
  #           for l in range(kw):
  #             out[b*out_bs_strides + (ch*kh*kw + k*kw + l)*out_h_strides + (i*conv_w + j)] = ...
  # Notice (i*conv_w + j) is at the end so:
  out = out.transpose(0, 1, 4, 5, 2, 3)
  # out.shape = (bs, c, kh, kw, conv_h, conv_w)
  # Reshape to (bs, out_h, out_w) = (bs, c * kh * kw, conv_h * conv_w)
  out = out.reshape(bs, c * kh * kw, conv_h * conv_w)
  
  return out

def as_strided_im2col_conv2d(x, kernel, strides):
  if isinstance(strides, int):
    strides = (strides, strides)
  
  bs, c, h, w = x.shape
  out_features, in_features, kh, kw = kernel.shape
  sh, sw = strides

  if c != in_features:
    raise ValueError("Input channels must match kernel channels")

  out_h = (h - kh) // sh + 1
  out_w = (w - kw) // sw + 1

  x_col = np.ascontiguousarray(im2col(x, sh, sw, kh, kw))
  kernel_col = kernel.reshape(out_features, -1)

  return np.matmul(kernel_col, x_col).reshape(bs, out_features, out_h, out_w)

In [12]:
# Example from https://leonardoaraujosantos.gitbook.io/artificial-inteligence/machine_learning/deep_learning/convolution_layer/making_faster
r = np.arange(3*4*4).reshape(1,3, 4, 4)
print(r)
im2col(r+1, 1, 1, 2, 2)

[[[[ 0  1  2  3]
   [ 4  5  6  7]
   [ 8  9 10 11]
   [12 13 14 15]]

  [[16 17 18 19]
   [20 21 22 23]
   [24 25 26 27]
   [28 29 30 31]]

  [[32 33 34 35]
   [36 37 38 39]
   [40 41 42 43]
   [44 45 46 47]]]]


array([[[ 1,  2,  3,  5,  6,  7,  9, 10, 11],
        [ 2,  3,  4,  6,  7,  8, 10, 11, 12],
        [ 5,  6,  7,  9, 10, 11, 13, 14, 15],
        [ 6,  7,  8, 10, 11, 12, 14, 15, 16],
        [17, 18, 19, 21, 22, 23, 25, 26, 27],
        [18, 19, 20, 22, 23, 24, 26, 27, 28],
        [21, 22, 23, 25, 26, 27, 29, 30, 31],
        [22, 23, 24, 26, 27, 28, 30, 31, 32],
        [33, 34, 35, 37, 38, 39, 41, 42, 43],
        [34, 35, 36, 38, 39, 40, 42, 43, 44],
        [37, 38, 39, 41, 42, 43, 45, 46, 47],
        [38, 39, 40, 42, 43, 44, 46, 47, 48]]])

In [13]:
iterations = 10
r = np.random.random((3, 3, 100, 100)).astype(np.float32)
kernel = np.random.random((3, 3, 3, 3)).astype(np.float32)

torch_conv = nn.Conv2d(3, 3, 3, 1, padding=0, dilation=1, bias=False)
torch_conv.weight.data = torch.tensor(kernel)

print("Raw python: ", timeit.timeit(lambda: naive_conv2d(r, kernel, 1), number=iterations))
print("im2col: ", timeit.timeit(lambda: naive_im2col_conv2d(r, kernel, 1), number=iterations))
print("im2col as_strided: ", timeit.timeit(lambda: as_strided_im2col_conv2d(r, kernel, 1), number=iterations))
print("pytorch's: ", timeit.timeit(lambda: torch_conv(torch.tensor(r)), number=iterations))

Raw python:  20.648738496998703
im2col:  5.13208071400004
im2col as_strided:  0.005101021999507793
pytorch's:  0.002748929000517819


I must be tripping! Our new implementation is almost as fast as pytorch's!. Let's do some more tests to confirm.

In [14]:
iterations = 200
r = np.random.random((3, 3, 1000, 1000)).astype(np.float32)
kernel = np.random.random((3, 3, 3, 3)).astype(np.float32)

torch_conv = nn.Conv2d(3, 3, 3, 1, padding=0, dilation=1, bias=False)
torch_conv.weight.data = torch.tensor(kernel)

print("im2col as_strided: ", timeit.timeit(lambda: as_strided_im2col_conv2d(r, kernel, 1), number=iterations))
print("pytorch's: ", timeit.timeit(lambda: torch_conv(torch.tensor(r)), number=iterations))

im2col as_strided:  28.469897312001194
pytorch's:  10.562133766001352


Siiiiiiiiike. Our implementation is actually slower. Who would have guessed? But it's still faster than the naive implementation.

If we do the as_strided trick directly on the naive convolution instead of the im2col, would it be faster?

In [15]:
def as_strided_conv2d(x, kernel, strides):
  if isinstance(strides, int):
    strides = (strides, strides)
  
  bs, c, h, w = x.shape
  out_features, in_features, kh, kw = kernel.shape
  sh, sw = strides

  if c != in_features:
    raise ValueError("Input channels must match kernel channels")

  out_h = (h - kh) // sh + 1
  out_w = (w - kw) // sw + 1

  # out = np.zeros((bs, out_features, out_h, out_w))
  # for b in range(bs):
  #   for of in range(out_features):
  #     for i in range(out_h):
  #       for j in range(out_w):
  #         for f in range(in_features):
  #           for k in range(kh):
  #             for l in range(kw):
  #               out[b, of, i, j] += x[b, f, i*sh + k, j*sw + l] * kernel[of, f, k, l]


  # Rewrite the loop to use as_strided easily
  # out = np.zeros((bs * out_features * out_h * out_w))
  # out_bs_strides = out_features * out_h * out_w
  # out_f_strides = out_h * out_w
  # out_h_strides = out_w

  # x = x.ravel()
  x_bs_strides = c * h * w
  x_c_strides = h * w
  x_h_strides = w

  # kernel = kernel.ravel()
  kernel_of_strides = in_features * kh * kw
  kernel_if_strides = kh * kw
  kernel_k_strides = kw
  # for b in range(bs):
  #   for of in range(out_features):
  #     for i in range(out_h):
  #       for j in range(out_w):
  #         for f in range(in_features):
  #           for k in range(kh):
  #             for l in range(kw):
  #               out_idx = b * out_bs_strides + of * out_f_strides + i * out_h_strides + j
  #               x_idx = b * x_bs_strides + f * x_c_strides + (i*sh + k) * x_h_strides + (j*sw + l)
  #               kernel_idx = of * kernel_of_strides + f * kernel_if_strides + k * kernel_k_strides + l
  #               out[out_idx] += x[x_idx] * kernel[kernel_idx]
  # out = out.reshape(bs, out_features, out_h, out_w)

  loop_shape = (bs, out_features, out_h, out_w, in_features, kh, kw)
  # d (x_idx) / db = x_bs_strides
  # d (x_idx) / dof = 0
  # d (x_idx) / di = sh*x_h_strides
  # d (x_idx) / dj = sw
  # d (x_idx) / df = x_c_strides
  # d (x_idx) / dk = x_h_strides
  # d (x_idx) / dl = 1
  x = as_strided(x, loop_shape, (x_bs_strides, 0, sh*x_h_strides, sw, x_c_strides, x_h_strides, 1))

  # d (kernel_idx) / db = 0
  # d (kernel_idx) / dof = kernel_of_strides
  # d (kernel_idx) / di = 0
  # d (kernel_idx) / dj = 0
  # d (kernel_idx) / df = kernel_if_strides
  # d (kernel_idx) / dk = kernel_k_strides
  # d (kernel_idx) / dl = 1
  kernel = as_strided(kernel, loop_shape, (0, kernel_of_strides, 0, 0, kernel_if_strides, kernel_k_strides, 1))

  x = np.ascontiguousarray(x)
  kernel = np.ascontiguousarray(kernel)
  
  out = np.sum(x * kernel, axis=(4, 5, 6))

  return out

In [16]:
r = np.random.random((3, 3, 10, 10)).astype(np.float32)
kernel = np.random.random((3, 3, 3, 3)).astype(np.float32)

o1 = naive_conv2d(r, kernel, 1)
o2 = as_strided_conv2d(r, kernel, 1)

print(np.allclose(o1, o2))

True


In [17]:
iterations = 10
r = np.random.random((3, 3, 1000, 1000)).astype(np.float32)
kernel = np.random.random((3, 3, 3, 3)).astype(np.float32)

torch_conv = nn.Conv2d(3, 3, 3, 1, padding=0, dilation=1, bias=False)
torch_conv.weight.data = torch.tensor(kernel)

# print("Raw python: ", timeit.timeit(lambda: naive_conv2d(r, kernel, 1), number=iterations))
# print("im2col: ", timeit.timeit(lambda: naive_im2col_conv2d(r, kernel, 1), number=iterations))
print("im2col as_strided: ", timeit.timeit(lambda: as_strided_im2col_conv2d(r, kernel, 1), number=iterations))
print("conv as_strided: ", timeit.timeit(lambda: as_strided_conv2d(r, kernel, 1), number=iterations))
print("pytorch's: ", timeit.timeit(lambda: torch_conv(torch.tensor(r)), number=iterations))

im2col as_strided:  1.4121128540009522
conv as_strided:  15.061238737000167
pytorch's:  0.539920388999235


Its slower wtf

might revisit this later