In [4]:
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

class MaxPool2d:
    def __init__(self, pool_height, pool_width):
        self.pool_height = pool_height
        self.pool_width = pool_width
        
    def forward(self, X):
        batch_size, in_channels, in_height, in_width = X.shape
        
        out_height = in_height // self.pool_height
        out_width = in_width // self.pool_width
        
        X = sliding_window_view(X, (self.pool_height, self.pool_width), axis=(2, 3)).reshape(
            batch_size,
            in_channels,
            out_height,
            self.pool_height,
            out_width,
            self.pool_width,
        )
        
        return np.max(X, axis=(3, 5))
    
    def backward(self, X, dL_dY, lr):
        batch_size, in_channels, in_height, in_width = X.shape
        
        out_height = in_height // self.pool_height
        out_width = in_width // self.pool_width
        
        X = sliding_window_view(X, (self.pool_height, self.pool_width), axis=(2, 3)).reshape(
            batch_size,
            in_channels,
            out_height,
            self.pool_height,
            out_width,
            self.pool_width,
        )
        
        dL_dY = dL_dY[:, :, :, np.newaxis, :, np.newaxis]
        mask = X == np.max(X, axis=(3, 5), keepdims=True)
        mask = mask.astype(np.float32)
        
        return mask * dL_dY

In [5]:
maxpool = MaxPool2d(2, 2)
X = np.random.randn(1, 1, 4, 4)
print(X)
print(maxpool.forward(X))

[[[[ 0.50882453  0.5201186   0.39930726 -0.52588223]
   [-0.99224041 -0.92532551  0.67849861 -0.55676403]
   [-0.04568105 -1.07491862  0.91999961  1.19673476]
   [-0.98182725  1.06632788  0.0884675   2.48598056]]]]


ValueError: cannot reshape array of size 36 into shape (1,1,2,2,2,2)