In [1]:
import numpy as np
from tqdm import tqdm

In [2]:
def np_conv2d_grad(dout, cache):
    """
    Backward pass for a 2D convolutional layer (cross-correlation, no bias).

    The gradients correspond to the forward pass
        out[n, m, i, j] = sum_{c, r, s} x_pad[n, c, stride*i + r, stride*j + s] * w[m, c, r, s]
    where x_pad is x zero-padded by `pad` pixels on every side of both
    spatial axes.

    Inputs:
    - dout: Upstream derivatives, of shape (N, M, HO, WO) with
      HO = (H + 2*pad - R) // stride + 1 and WO = (W + 2*pad - S) // stride + 1.
    - cache: A tuple of (x, w, conv_param)
      - x: Input data of shape (N, C, H, W)
      - w: Filter weights of shape (M, C, R, S)
      - conv_param: A dictionary with the following keys:
        - 'stride': Step between adjacent receptive fields (>= 1).
        - 'pad': Number of zero pixels added on each spatial border (>= 0).

    Returns a tuple of:
    - dx: Gradient with respect to x, same shape as x
    - dw: Gradient with respect to w, same shape as w

    Raises:
    - ValueError: if stride/pad are out of range, or the shapes of x, w and
      dout are not consistent with each other.
    """
    x, w, conv_param = cache
    x, w, dout = np.asarray(x), np.asarray(w), np.asarray(dout)
    pad = conv_param['pad']
    stride = conv_param['stride']
    if pad < 0:
        raise ValueError(f"pad must be non-negative, got {pad}")
    if stride < 1:
        raise ValueError(f"stride must be at least 1, got {stride}")

    N, C, H, W = x.shape
    M, w_channels, R, S = w.shape
    if w_channels != C:
        raise ValueError(
            f"w has {w_channels} input channels but x has {C}"
        )

    # Fail loudly on inconsistent shapes instead of silently producing
    # partial gradients (dout too small) or a confusing indexing error.
    expected_shape = (
        N,
        M,
        (H + 2 * pad - R) // stride + 1,
        (W + 2 * pad - S) // stride + 1,
    )
    if tuple(dout.shape) != expected_shape:
        raise ValueError(
            f"dout has shape {tuple(dout.shape)}, expected {expected_shape}"
        )
    _, _, HO, WO = dout.shape

    # Keep floating dtypes of x / w as they are; for non-floating inputs use
    # the common dtype of all operands so fractional contributions are not
    # truncated by an integer accumulator.
    common_dtype = np.result_type(x.dtype, w.dtype, dout.dtype)
    dx_dtype = x.dtype if np.issubdtype(x.dtype, np.inexact) else common_dtype
    dw_dtype = w.dtype if np.issubdtype(w.dtype, np.inexact) else common_dtype

    if pad:
        x_pad = np.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad)), mode='constant')
    else:
        x_pad = x
    dx_pad = np.zeros(x_pad.shape, dtype=dx_dtype)
    dw = np.zeros(w.shape, dtype=dw_dtype)

    # One pass over every output position. The channel and kernel loops are
    # expressed as slice operations; each gradient element still receives
    # its contributions in the same (n, m, i, j) order as a fully scalar
    # loop nest would produce.
    for n in range(N):
        for m in range(M):
            for i in range(HO):
                h1 = i * stride
                h2 = h1 + R
                for j in range(WO):
                    w1 = j * stride
                    w2 = w1 + S
                    g = dout[n, m, i, j]
                    # Receptive field of output (i, j), shape (C, R, S).
                    dw[m] += x_pad[n, :, h1:h2, w1:w2] * g
                    dx_pad[n, :, h1:h2, w1:w2] += w[m] * g

    if pad:
        # Gradient that fell on the zero border has no corresponding input.
        dx = np.ascontiguousarray(dx_pad[:, :, pad:pad + H, pad:pad + W])
    else:
        dx = dx_pad

    return dx, dw


In [3]:
import torch  # only needed for the reference gradients in this cell

# Seeded generator so the comparison is reproducible after a kernel restart.
rng = np.random.default_rng(0)
x_np = rng.standard_normal((2, 3, 5, 5))
w_np = rng.standard_normal((3, 3, 3, 3))
dout_np = rng.standard_normal((2, 3, 3, 3))
conv_param = {'stride': 1, 'pad': 0}

dx, dw = np_conv2d_grad(dout_np, (x_np, w_np, conv_param))

# Torch tensors over the same data; the NumPy arrays keep their own names so
# nothing in this cell changes type half-way through.
x = torch.from_numpy(x_np)
w = torch.from_numpy(w_np)
dout = torch.from_numpy(dout_np)

grad_input = torch.nn.grad.conv2d_input(
    x.shape, w, dout, stride=conv_param['stride'], padding=conv_param['pad']
)
grad_weight = torch.nn.grad.conv2d_weight(
    x, w.shape, dout, stride=conv_param['stride'], padding=conv_param['pad']
)

# Compare the NumPy gradients with the torch reference; print the result and
# also fail the cell on a mismatch so a Run All cannot pass silently.
dx_matches = np.allclose(grad_input.numpy(), dx)
dw_matches = np.allclose(grad_weight.numpy(), dw)
print(dx_matches)
print(dw_matches)
assert dx_matches and dw_matches, "np_conv2d_grad disagrees with torch.nn.grad"

True
True
