In [2]:
import numpy as np 
import h5py
import matplotlib.pyplot as plt
from joblib import Memory

# Seed NumPy's global RNG so the random test inputs below repeat on a fresh, in-order run.
np.random.seed(1)
# joblib on-disk cache rooted at ./tmp; applied to the layer functions below via @mem.cache.
mem = Memory(location='./tmp', verbose=0)


# Utilities 


In [106]:
def plot_image(x):
    """Show `x` as a grayscale image.

    The values are divided by 255 before plotting.
    NOTE(review): this presumably expects pixel intensities on a 0-255
    scale -- confirm against the data that is passed in.
    """
    scaled = x / 255
    plt.gray()  # switch matplotlib's default colormap to gray
    plt.imshow(scaled)
    plt.show()

def zero_pad(x, pad):
    """Zero-pad the two middle axes of a 4-D array.

    Axes 1 and 2 each receive `pad` zeros before and after; axes 0 and 3
    are left unchanged.

    Returns a new array; `x` is not modified.
    """
    untouched = (0, 0)
    border = (pad, pad)
    return np.pad(
        x,
        pad_width=(untouched, border, border, untouched),
        mode='constant',
        constant_values=0,
    )

def create_mask(A):
    """Boolean mask marking where each example attains its maximum.

    A      -> m x f x f x c
    return -> m x f x f x c (bool)

    The maximum is taken once per example, over the height, width AND
    channel axes together (not per channel). Every position equal to that
    maximum is True, so ties produce several True entries.
    """
    # keepdims leaves the maxima as m x 1 x 1 x 1 so the comparison broadcasts.
    per_example_max = np.max(A, axis=(1, 2, 3), keepdims=True)
    return A == per_example_max

In [66]:
# Sanity check on a random 0/1 tensor of shape (2,2,2,2): the printed mask is
# True wherever an entry equals the maximum of its example (first-axis slice).
a = np.random.randint(2, size=(2,2,2,2))
print(create_mask(a))


[[[[False  True]
   [False  True]]

  [[False False]
   [ True  True]]]


 [[[ True  True]
   [False  True]]

  [[ True False]
   [False  True]]]]


# Convolve Forward

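1. `nh_prev` (and `nw_prev`) are measured on the zero-padded input, so they already include the padding.
2. `w_shifted` is the filter bank with a leading axis of size 1, so that it broadcasts across all `m` training examples.
3. <strike>assumption -> prefer looping over number of filters rather than training examples</strike>
4. The input slice gets a trailing axis of size 1 and the filters get a leading axis of size 1, so the two arrays broadcast against each other in both directions (examples × filters).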

In [47]:
@mem.cache
def convolve_forward(A_prev, w, b, hparams):
    """Forward pass of a convolution layer over a batch (channels-last).

    Parameters
    ----------
    A_prev : ndarray, shape (m, nh_in, nw_in, nc_prev)
        Input activations, NOT yet padded; padding is applied here.
    w : ndarray, shape (fh, fw, nc_prev, nc)
        Filter bank (one filter per index of the last axis).
    b : ndarray with one bias per filter, e.g. shape (1, 1, 1, nc)
    hparams : dict
        "stride" -- step between window positions,
        "pad"    -- zeros added on each side of the height and width axes.

    Returns
    -------
    Z : ndarray, shape (m, nh, nw, nc)
        nh / nw are the number of window positions that fit along the
        padded height / width (0 if the filter does not fit).

    Raises
    ------
    ValueError
        If the input's channel count differs from the filters' input
        channel count. Without this check a 1-channel operand would be
        broadcast silently and give a wrong result.
    """
    stride = hparams["stride"]
    pad = hparams["pad"]

    A_prev_pad = zero_pad(A_prev, pad)
    # nh_prev / nw_prev are the padded height / width.
    m, nh_prev, nw_prev, nc_in = A_prev_pad.shape
    fh, fw, nc_prev, nc = w.shape
    if nc_in != nc_prev:
        raise ValueError(
            "channel mismatch: input has {} channels but filters expect {}".format(
                nc_in, nc_prev
            )
        )

    # Top-left corners of every window; their counts are the output size.
    rows = range(0, nh_prev - fh + 1, stride)
    cols = range(0, nw_prev - fw + 1, stride)
    Z = np.zeros((m, len(rows), len(cols), nc))

    # Loop-invariant: leading axis of 1 lets the filters broadcast over the
    # batch. A view is enough; the filters are only read.
    w_shifted = w[np.newaxis]           # (1, fh, fw, nc_prev, nc)
    bias = np.reshape(b, (1, -1))       # (1, nc), broadcasts over the batch

    for out_i, i in enumerate(rows):
        for out_j, j in enumerate(cols):
            # Trailing axis of 1 lets the window broadcast over the filters.
            A_slice = A_prev_pad[:, i:i + fh, j:j + fw, :, np.newaxis]
            Z[:, out_i, out_j, :] = np.sum(A_slice * w_shifted, axis=(1, 2, 3)) + bias

    return Z

In [48]:
# Smoke test: 10 examples of 5x5x3 input and ten 3x3x3 filters.
a = np.random.randn(10, 5, 5, 3)
w = np.random.randn(3,3,3,10)
b = np.random.randn(1,1,1,10)
# stride 1 with pad 1 and a 3x3 filter keeps the 5x5 spatial size.
hparams={"stride":1, "pad":1}
z = convolve_forward(a, w, b, hparams)

# Dimension names unpacked for reference; they are not used again in this cell.
m,nh,nw,nc = a.shape
f,f,nc,nc_next = w.shape

print(z.shape)


(10, 5, 5, 10)


# Pool forward

In [49]:
@mem.cache
def pool_forward(A_prev, hparams):
    """Forward pass of a pooling layer over a batch (channels-last).

    Parameters
    ----------
    A_prev : ndarray, shape (m, nh_prev, nw_prev, nc_prev)
    hparams : dict
        "stride" -- step between window positions,
        "f"      -- window height and width,
        "type"   -- "max" for max pooling; any other value selects
                    average pooling.

    Returns
    -------
    Z : ndarray, shape (m, nh, nw, nc_prev)
        Each channel is pooled independently. nh / nw are the number of
        window positions that fit (0 if the window does not fit).
    """
    stride = hparams["stride"]
    f = hparams["f"]
    t = hparams["type"]
    m, nh_prev, nw_prev, nc_prev = A_prev.shape

    # Top-left corners of every window; their counts are the output size.
    rows = range(0, nh_prev - f + 1, stride)
    cols = range(0, nw_prev - f + 1, stride)
    Z = np.zeros((m, len(rows), len(cols), nc_prev))

    # NumPy has no `np.avg`; the arithmetic mean is `np.mean`.
    reduce_window = np.max if t == "max" else np.mean

    for out_i, i in enumerate(rows):
        for out_j, j in enumerate(cols):
            A_slice = A_prev[:, i:i + f, j:j + f, :]
            # Reduce over the window's height and width only.
            Z[:, out_i, out_j, :] = reduce_window(A_slice, axis=(1, 2))

    return Z

In [50]:
a = np.random.randint(99,size=(1,4,4,1))
# "max" pooling: this is what the output recorded for this cell shows
# (each value is the maximum of its 2x2 window).
hparams={"stride":2, "f":2, "type":"max"}

# The pooling function defined above is `pool_forward`; no `max_pool_forward`
# is defined in the visible notebook, so the old call only worked on leftover
# kernel state.
z = pool_forward(a, hparams)
print("a",a[0])
print("b",z)

a [[[12]
  [ 6]
  [54]
  [52]]

 [[15]
  [91]
  [28]
  [13]]

 [[33]
  [39]
  [71]
  [93]]

 [[79]
  [11]
  [97]
  [79]]]
b [[[[91.]
   [54.]]

  [[79.]
   [97.]]]]


# Convolve Backward

In [71]:
@mem.cache
def convolve_backward(dA, w, b, hparams, A_prev):
    """Backward pass of the convolution layer.

    Parameters
    ----------
    dA : ndarray, shape (m, nh, nw, nc)
        Gradient with respect to the layer output.
    w : ndarray, shape (f, f, nc_prev, nc)
    b : ndarray; only its shape is used, to size `db`.
    hparams : dict with keys "stride" and "pad".
    A_prev : ndarray, shape (m, nh_prev, nw_prev, nc_prev)
        Layer input, used exactly as given.

    Returns
    -------
    (dA_prev, dw, db) with the shapes of A_prev, w and b respectively.

    NOTE(review): hparams["pad"] is read but never applied -- A_prev is not
    padded here and dA_prev keeps A_prev's shape. The demo cell in this
    notebook passes an A_prev whose height/width already include the
    padding; confirm that this is the intended contract.
    """
    # Unpack every shape up front so a wrong rank fails immediately.
    _, _, _, _ = dA.shape
    _, in_h, in_w, _ = A_prev.shape
    _, f, _, n_filters = w.shape
    stride = hparams["stride"]
    _unused_pad = hparams["pad"]  # looked up (missing key -> KeyError) but not used

    grad_input = np.zeros(A_prev.shape)
    grad_w = np.zeros(w.shape)
    grad_b = np.zeros(b.shape)

    for top in range(0, in_h - f + 1, stride):
        out_r = top // stride
        rows = slice(top, top + f)
        for left in range(0, in_w - f + 1, stride):
            out_c = left // stride
            cols = slice(left, left + f)
            for k in range(n_filters):
                # Upstream gradient for this output position and filter,
                # shaped (m, 1, 1, 1) to broadcast over a whole window.
                upstream = dA[:, out_r, out_c, k][:, None, None, None]
                kernel = w[np.newaxis, :, :, :, k]  # (1, f, f, nc_prev)

                grad_input[:, rows, cols, :] += upstream * kernel
                grad_w[:, :, :, k] += np.sum(A_prev[:, rows, cols, :] * upstream, axis=0)
                grad_b[:, :, :, k] += np.sum(upstream, axis=0)

    return grad_input, grad_w, grad_b

In [None]:
# Smoke test with small random integers (values 0..2).
da = np.random.randint(3, size=(10, 5, 5, 12))
w = np.random.randint(3, size=(3,3,3,12))
b = np.random.randint(3, size=(1,1,1,12))

hparams={"stride":1, "pad":1}
# 7x7 input: with a 3x3 filter at stride 1 this gives the 5x5 positions of `da`.
# convolve_backward does not pad A_prev itself, so the padded size is passed here.
a_prev = np.random.randint(3, size=(10, 7,7,3))
# Index [2] of the returned tuple is db (the bias gradient).
print(convolve_backward(da,w,b,hparams,a_prev)[2])
# Same arguments again; presumably served from the joblib cache rather than recomputed.
print(convolve_backward(da,w,b,hparams,a_prev)[2].shape)

# Pool Backward

In [107]:
@mem.cache
def pool_backward(dA, hparams, A_prev):
    """Backward pass of a pooling layer over a batch (channels-last).

    Parameters
    ----------
    dA : ndarray, shape (m, nh, nw, nc)
        Gradient with respect to the pooling output.
    hparams : dict
        "stride" -- step between window positions,
        "f"      -- window height and width,
        "type"   -- optional; "max" (default when absent) or any other
                    value for average pooling, mirroring pool_forward.
    A_prev : ndarray, shape (m, nh_prev, nw_prev, nc)
        The input that was given to the forward pass.

    Returns
    -------
    dA_prev : ndarray with A_prev's shape.
        Max pooling: each window's gradient goes to the position(s) holding
        that window's maximum in the SAME channel (ties all receive it).
        Average pooling: each window's gradient is split evenly over its
        f * f positions.
    """
    stride = hparams["stride"]
    f = hparams["f"]
    mode = hparams.get("type", "max")
    m, nh_prev, nw_prev, nc_prev = A_prev.shape
    dA_prev = np.zeros(A_prev.shape)

    for out_i, i in enumerate(range(0, nh_prev - f + 1, stride)):
        for out_j, j in enumerate(range(0, nw_prev - f + 1, stride)):
            # Gradient of this output position for all channels at once,
            # shaped (m, 1, 1, nc) to broadcast over the window.
            grad = dA[:, out_i, out_j, :][:, None, None, :]

            if mode == "max":
                window = A_prev[:, i:i + f, j:j + f, :]
                # Per-channel mask. create_mask takes one maximum per example
                # across all channels, which does not match pool_forward's
                # per-channel max, so it is not used here.
                mask = window == window.max(axis=(1, 2), keepdims=True)
                dA_prev[:, i:i + f, j:j + f, :] += mask * grad
            else:
                dA_prev[:, i:i + f, j:j + f, :] += grad / (f * f)

    return dA_prev

In [108]:
# Smoke test with small random integers (values 0..2).
da = np.random.randint(3, size=(2, 3, 3, 3))
hparams={"stride":2, "f":2}
# 6x6 input with 2x2 windows at stride 2 gives 3x3 window positions, matching da.
a_prev = np.random.randint(3, size=(2, 6,6,3))

print(pool_backward(da,hparams,a_prev))
# Same arguments again; the result has a_prev's shape.
print(pool_backward(da,hparams,a_prev).shape)

(2, 2, 2, 3)
(2,)
[[[[ True  True False]
   [ True False  True]]

  [[False False False]
   [False False False]]]


 [[[ True False False]
   [False False False]]

  [[False False False]
   [ True  True False]]]]
(2, 2, 2, 3)
(2,)
[[[[ True  True False]
   [ True False  True]]

  [[False False False]
   [False False False]]]


 [[[ True False False]
   [False False False]]

  [[False False False]
   [ True  True False]]]]
(2, 2, 2, 3)
(2,)
[[[[ True  True False]
   [ True False  True]]

  [[False False False]
   [False False False]]]


 [[[ True False False]
   [False False False]]

  [[False False False]
   [ True  True False]]]]
(2, 2, 2, 3)
(2,)
[[[[False False  True]
   [ True False  True]]

  [[ True False  True]
   [ True  True False]]]


 [[[ True False  True]
   [False  True False]]

  [[ True  True  True]
   [False  True False]]]]
(2, 2, 2, 3)
(2,)
[[[[False False  True]
   [ True False  True]]

  [[ True False  True]
   [ True  True False]]]


 [[[ True False  True]
   [False