## Numerical Gradient

In [5]:
import torch
import numpy as np

In [8]:
def compute_numeric_gradient(f, x, dLdf=None, h=1e-7):
    
    flat_x = x.contiguous().flatten()
    grad = torch.zeros_like(x)
    flat_grad = grad.flatten()

    # Initialize upstream gradient to be ones if not provide
    if dLdf is None:
        y = f(x)
        dLdf = torch.ones_like(y)
    dLdf = dLdf.flatten()

    # iterate over all indexes in x
    for i in range(flat_x.shape[0]):
        oldval = flat_x[i].item()  # Store the original value
        flat_x[i] = oldval + h  # Increment by h
        fxph = f(x).flatten()  # Evaluate f(x + h)
        flat_x[i] = oldval - h  # Decrement by h
        fxmh = f(x).flatten()  # Evaluate f(x - h)
        flat_x[i] = oldval  # Restore original value

        # compute the partial derivative with centered formula
        dfdxi = (fxph - fxmh) / (2 * h)

        # use chain rule to compute dLdx
        flat_grad[i] = dLdf.dot(dfdxi).item()
        #flat_grad[i] = dLdf.float().dot(dfdxi).item()

    # Note that since flat_grad was only a reference to grad,
    # we can just return the object in the shape of x by returning grad
    return grad


def rel_error(x, y, eps=1e-10):

    top = (x - y).abs().max().item()
    bot = (x.abs() + y.abs()).clamp(min=eps).max().item()
    return top / bot


# Forward

In [19]:
class ConvFward(object):
    def forward(x, w, b, conv_param):
        out = None
        num_train, Channel, H_x, W_x = x.shape
        num_f, Channel, H_f, W_f = w.shape
        hi = 0
        wi = 0
        stride = conv_param['stride']
        pad = conv_param['pad']

        H_out = 1 + (H_x + 2 * pad - H_f) // stride #2
        W_out = 1 + (W_x + 2 * pad - W_f) // stride

        p2d = (pad, pad, pad, pad) # pad last dim by (pad, pad) and 2nd to last by (pad, pad)
        x_pad = torch.nn.functional.pad(x, p2d, "constant", 0)
        out = torch.zeros(num_train, num_f, H_out, W_out).to(x.dtype)

        for k in range(num_train):
            for i in range(num_f):
                for hi in range(H_out):
                    step_h = hi * stride
                    for wi in range(W_out):
                        step_w = wi * stride
                        sample = x_pad[k, :, step_h:(step_h+H_f), step_w:(step_w+W_f)]

                        out[k, i, hi, wi] = torch.sum(sample * w[i,]) + b[i]

        cache = (x, w, b, conv_param)
        return out, cache

# Backprop

In [43]:
class ConvBprop(object):    
    def backward(dout, cache):
        
        num_train, num_dout, H_dout, W_dout = dout.shape
        x, w, b, conv_param = cache
        num_train, Channel, H_x, W_x = x.shape         
        dx, dw, db = None, None, None
        dx, dw, db = torch.zeros(x.shape), torch.zeros(w.shape), torch.zeros(b.shape)
        num_f, Channel, H_f, W_f = w.shape
        pad = conv_param['pad']
        stride = conv_param['stride']

        p2d = (pad, pad, pad, pad) # pad last dim by (pad, pad) and 2nd to last by (pad, pad)
        dout_pad = torch.nn.functional.pad(dout, p2d, "constant", 0)

        x_pad = torch.nn.functional.pad(x, p2d, "constant", 0).to(x.dtype).to(x.device)
        dx_pad = torch.zeros(x_pad.shape).to(x.dtype).to(x.device)
        w_rot = w.rot90(2, [2, 3])
        _, _, H_wrot, W_wrot = w_rot.shape

        H_dw = (H_x + 2 * pad - H_dout) // stride + 1
        W_dw = (W_x + 2 * pad - W_dout) // stride + 1
        H_dx = (H_dout + 2 * pad - H_wrot) // stride + 1
        W_dx = (W_dout + 2 * pad - W_wrot) // stride + 1
       
        for k in range(num_train):
            for fi in range(num_f):
                for hi in range(H_dout):
                    step_h = hi * stride
                    for wi in range(W_dout):
                        step_w = wi * stride
                        sample = x_pad[k, :, step_h:(step_h + H_f), step_w:(step_w + W_f)]
                        dw[fi,] += sample * dout[k, fi, hi, wi]
                        dx_pad[k, :, step_h:(step_h+H_f), step_w:(step_w+W_f)] += w[fi,] * dout[k, fi, hi, wi]

        dx = dx_pad[:, :, pad:(pad+H_x), pad:(pad+W_x)]
        
        # another method                
#         for k in range(num_train):
#             for i in range(num_f):
#                 step_h_dw, step_h_dx = 0, 0
#                 for hi_dw in range(H_dw):
#                     step_w_dw = 0
#                     for wi_dw in range(W_dw):
#                         for c in range(Channel):
#                             sample_x = x_pad[k, c, step_h_dw:(step_h_dw + H_dout),step_w_dw:(step_w_dw + W_dout)]
#                             # dW = X conv. dout
#                             dw[i, c, hi_dw, wi_dw] += torch.sum(sample_x * dout[k, i,])
#                         step_w_dw += stride
#                     step_h_dw += stride
#                 #print(H_dx)
#                 for hi_dx in range(H_dx):
#                     #print(hi_dx)
#                     step_w_dx = 0
#                     for wi_dx in range(W_dx):
#                         for c in range(Channel):            
#                             sample_dout = dout_pad[k,i,step_h_dx:(step_h_dx+H_wrot),step_w_dx:(step_w_dx+W_wrot)]

#                             # dX = dout_pad conv. W_rot
#                             dx[k, c, hi_dx, wi_dx] += torch.sum(sample_dout * w_rot[i, c,])
#                         step_w_dx += stride
#                     step_h_dx += stride
                    
        db = dout.sum(dim=3).sum(dim=2).sum(dim=0)

        
        
        
        
        return dx, dw, db

## check

In [38]:
# too large!!!
# x = torch.randn(10, 3, 31, 31, dtype=torch.float64, device='cpu')
# w = torch.randn(25, 3, 3, 3, dtype=torch.float64, device='cpu')
# b = torch.randn(25, dtype=torch.float64, device='cpu')
# dout = torch.randn(10, 25, 16, 16, dtype=torch.float64, device='cpu')
# #x_cuda, w_cuda, b_cuda, dout_cuda = x.to('cuda'), w.to('cuda'), b.to('cuda'), dout.to('cuda')
# conv_param = {'stride': 2, 'pad': 1}

In [45]:
dx_num = compute_numeric_gradient(lambda x: ConvFward.forward(x, w, b, conv_param)[0], x, dout)
dw_num = compute_numeric_gradient(lambda w: ConvFward.forward(x, w, b, conv_param)[0], w, dout)
db_num = compute_numeric_gradient(lambda b: ConvFward.forward(x, w, b, conv_param)[0], b, dout)

out, cache = ConvFward.forward(x, w, b, conv_param)
dx, dw, db = ConvBprop.backward(dout, cache)

print('Testing Conv.backward function')
print('dx error: ', rel_error(dx, dx_num))
print('dw error: ', rel_error(dw, dw_num))
print('db error: ', rel_error(db, db_num))

Testing Conv.backward function
dx error:  5.6149568876129245e-09
dw error:  4.2332766033218214e-08
db error:  1.3408573952983589e-10


In [44]:
x_shape = torch.tensor((2, 3, 4, 4))
w_shape = torch.tensor((5, 3, 4, 4))
x = torch.linspace(-0.1, 0.5, steps=torch.prod(x_shape), dtype=torch.float64, device='cpu').reshape(*x_shape)
w = torch.linspace(-0.2, 0.3, steps=torch.prod(w_shape), dtype=torch.float64, device='cpu').reshape(*w_shape)
b = torch.linspace(-0.1, 0.2, steps=5, dtype=torch.float64, device='cpu')

conv_param = {"stride": 2, "pad": 1}
out, cache = ConvFward.forward(x, w, b, conv_param)
dout = torch.randn(2, 5, 2, 2, dtype=torch.float64, device='cpu')
dx, dw, db = ConvBprop.backward(dout, cache)
print(dx)

tensor([[[[-0.1606, -0.5360, -0.5324, -0.3892],
          [-0.4095, -0.6521, -0.6530, -0.2448],
          [-0.4088, -0.6555, -0.6563, -0.2489],
          [-0.3314, -0.1663, -0.1708,  0.1766]],

         [[-0.0486, -0.4780, -0.4743, -0.4431],
          [-0.4065, -0.6654, -0.6663, -0.2612],
          [-0.4057, -0.6688, -0.6696, -0.2653],
          [-0.4404, -0.2377, -0.2421,  0.2142]],

         [[ 0.0634, -0.4199, -0.4163, -0.4971],
          [-0.4034, -0.6788, -0.6796, -0.2776],
          [-0.4027, -0.6821, -0.6829, -0.2817],
          [-0.5493, -0.3091, -0.3135,  0.2518]]],


        [[[-0.0272, -0.4024, -0.4007, -0.3654],
          [-0.3586, -0.6414, -0.6417, -0.2801],
          [-0.3630, -0.6426, -0.6429, -0.2769],
          [-0.3144, -0.2616, -0.2636,  0.0457]],

         [[-0.0614, -0.3754, -0.3737, -0.3042],
          [-0.3761, -0.6462, -0.6465, -0.2674],
          [-0.3804, -0.6474, -0.6477, -0.2643],
          [-0.2977, -0.2934, -0.2954, -0.0029]],

         [[-0.0957, -0.3484,

In [52]:
[1,2,3,4]

[1, 2, 3, 4]