In [670]:
import numpy as np

from nn import Module, Pad
from abc import ABC, abstractmethod

In [690]:
class Pool(Module):
    
    def __init__(self, channels, kernel_shape, *, stride=1, padding=0, dilation=1):
        # parameters
        self.x = None
        self.pool = None # I dislike having to use abstract property to similate abstract attribute
        
        # modules 
        self.input_to_windows = PoolWindows(channels=channels, kernel_shape=kernel_shape, stride=stride, 
                                           padding=padding, dilation=dilation)
        
    def forward(self, x):
        self.x = x 
        x = self.input_to_windows(x)
        return self.pool(x)
    
    def backward(self, dy):
        dy = self.pool.backward(dy)
        return self.input_to_windows.backward(dy)

In [691]:
class MaxPool(Pool):
    
    def __init__(self, channels, kernel_shape, *, stride=1, padding=0, dilation=1):
        super().__init__(channels=channels, kernel_shape=kernel_shape, stride=stride, padding=padding,
                         dilation=dilation)
        self.pool = Max(axis=tuple(range(-len(kernel_shape), 0)))

In [692]:
class AvgPool(Pool):
    
    def __init__(self, channels, kernel_shape, *, stride=1, padding=0, dilation=1):
        super().__init__(channels=channels, kernel_shape=kernel_shape, stride=stride, padding=padding,
                         dilation=dilation)
        self.pool = Mean(axis=tuple(range(-len(kernel_shape), 0)))

In [693]:
from numpy.lib.stride_tricks import as_strided 
from nn import Div 


def _fill_reduced_axis(axis, orig=1, fill=0, ndim=None, dtype=int):
    """
    Perhaps could be done cleaner, but for now implement in Sum and Max. Soon implement in Mean as well. 
    """
    axis = np.asarray(axis)
    if ndim is None:
        ndim = len(orig) + axis.size
    axis %= ndim
    axis_complement = np.setdiff1d(np.arange(ndim), axis)
    ret = np.zeros(ndim, dtype=dtype)
    ret[axis] = fill
    ret[axis_complement] = orig
    return ret


class Max(Module):
    
    def __init__(self, axis=None):
        self.axis = axis
        self.x = None
        
    def forward(self, x):
        self.x = x 
        return np.max(x, axis=self.axis)
    
    def backward(self, dy):
        dy = as_strided(dy, self.x.shape, _fill_reduced_axis(self.axis, dy.strides, 0))
        # TODO: save max computation early and just reshape it here 
        mask = np.max(self.x, axis=self.axis, keepdims=True) == self.x
        dx = np.where(mask, dy, 0)
        return dx
    

class Sum(Module):

    def __init__(self, axis=None):
        self.axis = axis
        self.x = None

    def forward(self, x):
        self.x = x
        return np.sum(x, axis=self.axis)

    def backward(self, dy):
        return as_strided(
            dy, 
            self.x.shape, 
            _fill_reduced_axis(self.axis, dy.strides, 0), 
            writeable=False
        )
    
class Mean(Module):
    
    def __init__(self, axis=None):
        self.axis = axis 
        self.sum = Sum(axis=axis)
        self.div = Div()
        
    def _size_of_axis(self, x_shape):
        foo = sum(map(x_shape.__getitem__, self.axis))
        return foo
        
    def forward(self, x):
        return self.div(self.sum(x), self._size_of_axis(x.shape))
    
    def backward(self, dy):
        dy, _ = self.div.backward(dy)
        dx = self.sum.backward(dy)
        return dx

In [847]:
from math import prod

img_shape = (5, 5)

kernel_shape=(3, 3)
mp = MaxPool(1, kernel_shape)
ap = AvgPool(1, kernel_shape)
img = np.arange(prod(img_shape)).reshape(*img_shape, 1).astype(float)
mp(img), ap(img)
dy = np.ones((3, 3, 1))
ap.backward(dy).reshape(*img_shape), mp.backward(dy).reshape(*img_shape)

(array([[0.16666667, 0.33333333, 0.5       , 0.33333333, 0.16666667],
        [0.33333333, 0.66666667, 1.        , 0.66666667, 0.33333333],
        [0.5       , 1.        , 1.5       , 1.        , 0.5       ],
        [0.33333333, 0.66666667, 1.        , 0.66666667, 0.33333333],
        [0.16666667, 0.33333333, 0.5       , 0.33333333, 0.16666667]]),
 array([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 1., 1., 1.],
        [0., 0., 1., 1., 1.],
        [0., 0., 1., 1., 1.]]))

In [843]:
class AsStrided(Module):
    """
    For some reason breaks when input was created using np.arange(...).reshape(...), but works in all other
    cases tested.
    """

    def __init__(self, shape, strides, *, writeable: bool = False):
        self.shape = shape
        self.strides = strides
        self.writeable = writeable
        self.x = None

    def forward(self, x):
        self.x = x
        return as_strided(x, self.shape, self.strides, writeable=self.writeable)

    def backward(self, dy):
        """
        The input array to forward i.e. self.x is destroyed when backward is called. This can be 
        remedied by passing in a stride function instead of a value, but it is not implemented this 
        way for now. 
        """
        # dx is an alias for self.x due to technical reasons, ideally it would not be so
        self.x.fill(dy.dtype.type(0))
        dx = self.x
        dx_strided = as_strided(dx, self.shape, self.strides, writeable=True)
        for i in np.ndindex(dy.shape):
            dx_strided[i] += dy[i]
        return dx

class Windows(Module):
    
    def __init__(self, kernel_shape, *, stride=1, padding=0, dilation=1):
        # Parameters
        self.kernel_shape = kernel_shape
        self.stride = stride
        self.padding = padding
        self.dilation = dilation 
        
        # Modules 
        self.pad = Pad(self.padding)
        self.as_strided = None 
        
    def _outer_shape(self, x_shape):
        # adjust for kernel shape and dilation 
        outer_shape = np.subtract(x_shape[:-1], np.subtract(self.kernel_shape, 1) * self.dilation)
        
        # adjust for stride 
        return np.ceil(np.divide(outer_shape, self.stride)).astype(int)
    
    def _outer_strides(self, x_strides):
        return np.multiply(x_strides[:-1], self.stride)
    
    def _kernel_strides(self, x_strides):
        return np.multiply(x_strides[:-1], self.dilation)
    
    @abstractmethod 
    def _shape(self, x_shape):
        """
        Returns the shape of the window view. Utilize _outer_shape and kernel_shape.
        """
        
    @abstractmethod 
    def _strides(self, x_strides):
        """
        Returns the strides of the window view. Utilize _outer_strides and _kernel_strides
        """
    
    def forward(self, x):
        x = self.pad(x)
        self.as_strided = AsStrided(self._shape(x.shape), self._strides(x.strides))
        return self.as_strided(x)
    
    def backward(self, dy):
        dy = self.as_strided.backward(dy)
        return self.pad.backward(dy)
    

class PoolWindows(Windows):
    
    def __init__(self, channels, kernel_shape, *, stride=1, padding=0, dilation=1):
        super().__init__(kernel_shape=kernel_shape, stride=stride, padding=padding, dilation=dilation)
        self.channels = channels 
    
    def _shape(self, x_shape):
        return *self._outer_shape(x_shape=x_shape), self.channels, *self.kernel_shape
    
    def _strides(self, x_strides):
        return *self._outer_strides(x_strides=x_strides), x_strides[-1], \
    *self._kernel_strides(x_strides=x_strides)

    
class ConvWindows(Windows):
    
    def __init__(self, in_channels, out_channels, kernel_shape, *, stride=1, padding=0, dilation=1):
        super().__init__(kernel_shape=kernel_shape, stride=stride, padding=padding, dilation=dilation)
        self.in_channels = in_channels
        self.out_channels = out_channels 
        
    def _shape(self, x_shape):
        return *self._outer_shape(x_shape=x_shape), self.out_channels, *self.kernel_shape, self.in_channels
    
    def _strides(self, x_strides):
        return *self._outer_strides(x_strides=x_strides), 0, *self._kernel_strides(x_strides=x_strides), \
    x_strides[-1]