In [26]:
from utils import im2col, col2im
import numpy as np

class Conv():
    
    def __init__(self, nb_filters, filter_size, nb_channels, stride=1, padding=0):
        self.n_F = nb_filters
        self.f = filter_size
        self.n_C = nb_channels
        self.s = stride
        self.p = padding

        # Xavier-Glorot initialization - used for sigmoid, tanh.
        self.W = {'val': np.random.randn(self.n_F, self.n_C, self.f, self.f) * np.sqrt(1. / (self.f)),
                  'grad': np.zeros((self.n_F, self.n_C, self.f, self.f))}  
        self.b = {'val': np.random.randn(self.n_F) * np.sqrt(1. / self.n_F), 'grad': np.zeros((self.n_F))}

        self.cache = None

    def forward(self, X):
        """
            Performs a forward convolution.
           
            Parameters:
            - X : Last conv layer of shape (m, n_C_prev, n_H_prev, n_W_prev).
            Returns:
            - out: previous layer convolved.
        """
        m, n_C_prev, n_H_prev, n_W_prev = X.shape

        n_C = self.n_F
        n_H = int((n_H_prev + 2 * self.p - self.f)/ self.s) + 1
        n_W = int((n_W_prev + 2 * self.p - self.f)/ self.s) + 1
        
        X_col = im2col(X, self.f, self.f, self.s, self.p)
        w_col = self.W['val'].reshape((self.n_F, -1))
        b_col = self.b['val'].reshape(-1, 1)
        # Perform matrix multiplication.
        out = w_col @ X_col + b_col
        # Reshape back matrix to image.
        out = np.array(np.hsplit(out, m)).reshape((m, n_C, n_H, n_W))
        self.cache = X, X_col, w_col
        return out

    def backward(self, dout):
        """
            Distributes error from previous layer to convolutional layer and
            compute error for the current convolutional layer.

            Parameters:
            - dout: error from previous layer.
            
            Returns:
            - dX: error of the current convolutional layer.
            - self.W['grad']: weights gradient.
            - self.b['grad']: bias gradient.
        """
        X, X_col, w_col = self.cache
        m, _, _, _ = X.shape
        # Compute bias gradient.
        self.b['grad'] = np.sum(dout, axis=(0,2,3))
        # Reshape dout properly.
        print(dout.shape)
        dout = dout.reshape(dout.shape[0] * dout.shape[1], dout.shape[2] * dout.shape[3])
        print(dout.shape)
        dout = np.array(np.vsplit(dout, m))
        print(dout.shape)
        dout = np.concatenate(dout, axis=-1)
        print(dout.shape)
        # Perform matrix multiplication between reshaped dout and w_col to get dX_col.
        dX_col = w_col.T @ dout
        # Perform matrix multiplication between reshaped dout and X_col to get dW_col.
        dw_col = dout @ X_col.T
        # Reshape back to image (col2im).
        dX = col2im(dX_col, X.shape, self.f, self.f, self.s, self.p)
        # Reshape dw_col into dw.
        self.W['grad'] = dw_col.reshape((dw_col.shape[0], self.n_C, self.f, self.f))
                
        return dX, self.W['grad'], self.b['grad']


In [27]:
conv = Conv(3, 3, 3)

In [28]:
X = np.random.randn(2, 3, 13, 13)
dout = np.random.randn(2, 3, 11, 11)
conv.forward(X)
conv.backward(dout)

(2, 3, 11, 11)
(6, 121)
(2, 3, 121)
(3, 242)


(array([[[[-0.55562889, -0.86237926, -3.02367533, ..., -2.50776927,
            1.05470823,  0.69871079],
          [ 1.10809196,  4.62023288, -3.4724309 , ..., -0.77335149,
            1.83344728,  1.05738722],
          [-1.68788565, -0.6923119 , -0.07026534, ..., -0.24754246,
           -0.3759313 ,  1.3095935 ],
          ...,
          [-3.86933986,  2.1604432 ,  3.80712143, ..., -3.31037181,
           -1.85733637, -0.34719003],
          [ 2.93907322, -1.05342663, -0.66916303, ...,  2.9535061 ,
            2.20370249,  1.19101881],
          [ 1.05426932,  0.16157848,  1.2425575 , ..., -1.31893627,
           -1.05997226, -1.38368711]],
 
         [[ 0.360244  , -0.3387647 ,  1.7040068 , ...,  1.02572506,
           -1.01612056, -0.39746905],
          [-0.76451918, -3.58928411,  4.52713683, ...,  2.2402647 ,
           -1.88721682, -1.42354057],
          [ 0.24271413, -0.70826555,  4.54180053, ..., -0.40472293,
           -1.16113409, -1.58488388],
          ...,
          [ 0

In [50]:
cache = dict()
cache['x'] = np.random.randn(2, 6, 13, 13)
cache['w'] = np.random.randn(16, 6, 3, 3)
cache['b'] = np.random.randn(16)
cache = (cache['x'], cache['w'], cache['b'])
dout = np.random.randn(2, 16, 11, 11)
dx, dw, db = conv_backward_naive(dout, cache)

(121, 16) (16, 54)
11
11
3
3
(6, 13, 13)
0 0
0 1
0 2
0 3
0 4
0 5
0 6
0 7
0 8
0 9
0 10
1 0
1 1
1 2
1 3
1 4
1 5
1 6
1 7
1 8
1 9
1 10
2 0
2 1
2 2
2 3
2 4
2 5
2 6
2 7
2 8
2 9
2 10
3 0
3 1
3 2
3 3
3 4
3 5
3 6
3 7
3 8
3 9
3 10
4 0
4 1
4 2
4 3
4 4
4 5
4 6
4 7
4 8
4 9
4 10
5 0
5 1
5 2
5 3
5 4
5 5
5 6
5 7
5 8
5 9
5 10
6 0
6 1
6 2
6 3
6 4
6 5
6 6
6 7
6 8
6 9
6 10
7 0
7 1
7 2
7 3
7 4
7 5
7 6
7 7
7 8
7 9
7 10
8 0
8 1
8 2
8 3
8 4
8 5
8 6
8 7
8 8
8 9
8 10
9 0
9 1
9 2
9 3
9 4
9 5
9 6
9 7
9 8
9 9
9 10
10 0
10 1
10 2
10 3
10 4
10 5
10 6
10 7
10 8
10 9
10 10
(6, 13, 13)
(121, 16) (16, 54)
11
11
3
3
(6, 13, 13)
0 0
0 1
0 2
0 3
0 4
0 5
0 6
0 7
0 8
0 9
0 10
1 0
1 1
1 2
1 3
1 4
1 5
1 6
1 7
1 8
1 9
1 10
2 0
2 1
2 2
2 3
2 4
2 5
2 6
2 7
2 8
2 9
2 10
3 0
3 1
3 2
3 3
3 4
3 5
3 6
3 7
3 8
3 9
3 10
4 0
4 1
4 2
4 3
4 4
4 5
4 6
4 7
4 8
4 9
4 10
5 0
5 1
5 2
5 3
5 4
5 5
5 6
5 7
5 8
5 9
5 10
6 0
6 1
6 2
6 3
6 4
6 5
6 6
6 7
6 8
6 9
6 10
7 0
7 1
7 2
7 3
7 4
7 5
7 6
7 7
7 8
7 9
7 10
8 0
8 1
8 2
8 3
8 4
8 5
8 6
8 7
8 8
8 9


In [17]:
dx

array([[[[ -3.36827113,  -0.65160937,   7.12135642, ...,   5.84980384,
           -1.74683695,  -0.67293779],
         [  2.1693469 ,  -0.48887211,  10.3373096 , ...,  -3.53481662,
           16.31029144,  -3.29171516],
         [ -0.06495796,  -7.75884848,  -3.57374069, ...,  12.6859035 ,
           26.63869245,  10.35722145],
         ...,
         [  2.98768258,  -4.54868289, -11.7673802 , ...,   2.13664031,
            1.23112468,  12.48650446],
         [ -1.38737368,   7.56916801,   5.70164278, ...,   5.41080627,
           11.49004462,   1.35352731],
         [  7.25921345,   4.64831587,   2.67893909, ...,  -1.75974035,
            5.21661409,   0.03039314]],

        [[ -4.24182066,   8.83162925,  -9.65781209, ...,   4.76023625,
           -7.03690438,  -2.17598465],
         [  0.75796496,  -6.24083929, -17.57993266, ...,   3.89014721,
            6.42156898,  -2.22227202],
         [ -7.97511166, -14.58480416,  -5.58985523, ...,   9.9890899 ,
          -14.03607268,  -2.30362