### Imports

In [61]:
from numpy.lib.stride_tricks import as_strided
import numpy as np
import torch
import torch.nn as nn

### Input storage
Z : [BATCHES] x [HEIGHT] x [WIDTH] x [INPUT_CHANNELS] (NHWC format) <br>
W : [KERNEL_SIZE] x [KERNEL_SIZE] x [INPUT_CHANNELS] x [OUTPUT_CHANNELS]

<br>
PyTorch expects Z in NCHW format and W as [OUT_CHANNELS] x [IN_CHANNELS] x [KERNEL_SIZE] x [KERNEL_SIZE].

In [62]:
# Draw the sample inputs from a seeded generator so that Restart & Run All
# reproduces the same data (and therefore the same error norms) every time.
rng = np.random.default_rng(0)
Z = rng.random((10, 32, 32, 8))   # activations: batch x height x width x C_in (NHWC)
W = rng.random((3, 3, 8, 16))     # filters: K x K x C_in x C_out

### Reference Implementation

In [63]:
def conv_reference(Z, weights):
  """Ground-truth convolution computed with PyTorch's ``conv2d``.

  Z is an NHWC NumPy array and ``weights`` is laid out as
  (K_h, K_w, C_in, C_out). Both are re-ordered to the layouts ``conv2d``
  expects, convolved with its default settings, and the result is handed
  back as an NHWC NumPy array.
  """
  # Activations: NHWC -> NCHW.
  activations = torch.permute(torch.tensor(Z), (0, 3, 1, 2))
  # Filters: (K_h, K_w, C_in, C_out) -> (C_out, C_in, K_h, K_w).
  filters = torch.permute(torch.tensor(weights), (3, 2, 0, 1))

  result_nchw = nn.functional.conv2d(activations, filters)

  # Back to NHWC so the result is comparable with the NumPy implementations.
  result_nhwc = torch.permute(result_nchw, (0, 2, 3, 1))
  return result_nhwc.numpy()

In [64]:
%%time
# Time the PyTorch-backed reference on the sample inputs; the trailing
# expression displays the output shape.
out_reference = conv_reference(Z,W)
out_reference.shape

CPU times: user 3.96 ms, sys: 0 ns, total: 3.96 ms
Wall time: 4.79 ms


(10, 30, 30, 16)

### Naive Implementation involving 7 for loops

In [65]:
def conv_naive(Z, weights):
  """Loop-based 2-D convolution: slow, but easy to check by eye.

  Every output element is accumulated one multiply-add at a time using
  seven nested Python loops (stride 1, no padding, no kernel flip).

  Args:
    Z: input activations, indexed as (batch, height, width, C_in).
    weights: filters, indexed as (K_h, K_w, C_in, C_out). The kernel does
      not have to be square.

  Returns:
    Array of shape (batches, height-K_h+1, width-K_w+1, C_out) in NumPy's
    default float dtype.

  Raises:
    ValueError: if the channel axis of ``Z`` and the input-channel axis of
      ``weights`` differ.
  """
  batches, height, width, C_in = Z.shape
  K_h, K_w, weight_c_in, C_out = weights.shape

  # Without this check a filter with *more* input channels than Z would be
  # silently truncated by the c_in loop below.
  if weight_c_in != C_in:
    raise ValueError(
      f"channel mismatch: Z has {C_in} channels, weights expect {weight_c_in}")

  # Output extent of a stride-1 convolution without padding.
  out_h = height - K_h + 1
  out_w = width - K_w + 1

  out = np.zeros(shape=(batches, out_h, out_w, C_out))

  for batch in range(batches):
    for c_in in range(C_in):
      for c_out in range(C_out):
        for h in range(out_h):
          for w in range(out_w):
            for i in range(K_h):
              for j in range(K_w):
                out[batch, h, w, c_out] += Z[batch, h+i, w+j, c_in] * weights[i,j,c_in,c_out]

  return out

In [66]:
%%time
# Time the seven-loop implementation on the sample inputs; the trailing
# expression displays the output shape.
out_naive = conv_naive(Z,W)
out_naive.shape

CPU times: user 11.9 s, sys: 3.02 ms, total: 11.9 s
Wall time: 12.7 s


(10, 30, 30, 16)

In [67]:
# Fail the run if the naive result drifts from the reference beyond
# floating-point round-off, then display the L2 norm of the difference.
np.testing.assert_allclose(out_naive, out_reference, rtol=1e-9)
np.linalg.norm(out_naive - out_reference)

np.float64(2.0491806431769934e-12)

### Matrix Multiplication based convolution
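First think about the 1x1 convolution case, where the output is just the matmul of Z (viewed as a matrix whose last axis has C_in entries) with the C_in x C_out matrix W[0][0]. This extends to the k x k convolution: for every filter offset [i, j], we matmul the shifted slice of Z that the filter's [i, j]-th element acts upon with W[i][j], and sum the k*k partial results.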

In [68]:
def conv_matmul(Z, weights):
  """Convolution expressed as a sum of K_h*K_w batched matrix products.

  For each kernel offset (i, j) the input window shifted by that offset is
  multiplied by the (C_in, C_out) matrix ``weights[i, j]`` and the partial
  results are accumulated (stride 1, no padding, no kernel flip).

  Args:
    Z: input activations, indexed as (batch, height, width, C_in).
    weights: filters, indexed as (K_h, K_w, C_in, C_out). The kernel does
      not have to be square.

  Returns:
    Array of shape (batches, height-K_h+1, width-K_w+1, C_out).
  """
  batches, height, width, _ = Z.shape
  K_h, K_w, _, C_out = weights.shape

  # Output extent of a stride-1 convolution without padding.
  out_h = height - K_h + 1
  out_w = width - K_w + 1

  out = np.zeros(shape=(batches, out_h, out_w, C_out))

  for i in range(K_h):
    for j in range(K_w):
      # The slice holds every input pixel that kernel element (i, j) touches;
      # `@` contracts its trailing C_in axis against weights[i, j].
      out += Z[:, i:i+out_h, j:j+out_w, :] @ weights[i, j, :, :]

  return out

In [69]:
%%time
# Time the matmul-per-kernel-offset implementation on the sample inputs; the
# trailing expression displays the output shape.
out_matmul = conv_matmul(Z, W)
out_matmul.shape

CPU times: user 16 ms, sys: 0 ns, total: 16 ms
Wall time: 8.65 ms


(10, 30, 30, 16)

In [70]:
# Fail the run if the matmul result drifts from the reference beyond
# floating-point round-off, then display the L2 norm of the difference.
np.testing.assert_allclose(out_matmul, out_reference, rtol=1e-9)
np.linalg.norm(out_matmul - out_reference)

np.float64(1.7882188248097304e-12)

### im2col based convolution
We treat the convolution conv(X,W) as a matrix multiplication of X_cap with W, where X_cap is the im2col matrix whose rows are the flattened K x K x C_in patches of X and W is reshaped to (K*K*C_in) x C_out. In backprop we need v * d[conv(X,W)]/dW, where v is the gradient signal propagated so far; in this formulation it can be computed as conv(v, flip(X)) = X_cap.T @ v.

In [71]:
def conv_im2col(Z, weights):
  """Convolution as one large matrix product (im2col).

  A strided view exposes every kernel-sized patch of ``Z`` without copying;
  flattening the patches gives the matrix ``X_cap``, which is multiplied by
  the flattened filters in a single matmul (stride 1, no padding, no kernel
  flip).

  Args:
    Z: input activations, indexed as (batch, height, width, C_in).
    weights: filters, indexed as (K_h, K_w, C_in, C_out). The kernel does
      not have to be square.

  Returns:
    Array of shape (batches, height-K_h+1, width-K_w+1, C_out).
  """
  batches, height, width, C_in = Z.shape
  K_h, K_w, _, C_out = weights.shape
  Ns, Hs, Ws, Cs = Z.strides

  # Output extent of a stride-1 convolution without padding.
  out_h = height - K_h + 1
  out_w = width - K_w + 1

  # 6-D patch view: stepping along either kernel axis reuses the matching
  # spatial stride, so neighbouring patches alias the same memory. The view
  # is read-only because a write through it would hit several patches at once.
  patches = as_strided(
    Z,
    shape=(batches, out_h, out_w, K_h, K_w, C_in),
    strides=(Ns, Hs, Ws, Hs, Ws, Cs),
    writeable=False,
  )
  # One row per output position, one column per (i, j, c_in) combination.
  X_cap = patches.reshape(-1, K_h*K_w*C_in)

  out = X_cap @ weights.reshape(K_h*K_w*C_in, C_out)

  return out.reshape(batches, out_h, out_w, C_out)

In [72]:
%%time
# Time the im2col implementation on the sample inputs; the trailing
# expression displays the output shape.
out_im2col = conv_im2col(Z, W)
out_im2col.shape

CPU times: user 8.31 ms, sys: 0 ns, total: 8.31 ms
Wall time: 6.03 ms


(10, 30, 30, 16)

In [73]:
# Fail the run if the im2col result drifts from the reference beyond
# floating-point round-off, then display the L2 norm of the difference.
np.testing.assert_allclose(out_im2col, out_reference, rtol=1e-9)
np.linalg.norm(out_im2col - out_reference)

np.float64(0.0)