In [4]:
import numpy as np

In [6]:
class WindowsNd(Module):
    """Sliding-window view over a padded, channels-last input.

    The input is laid out as ``(*spatial, channels)``. The last axis is
    treated as the in-channel axis and every preceding axis is treated as
    spatial. ``forward`` pads the input with a ``Pad`` module and then
    returns a strided view of shape
    ``(*window_counts, out_channels, *kernel_shape, in_channels)``.
    The ``out_channels`` axis has stride 0, so it repeats the same window
    once per output channel.
    """

    def __init__(self, in_channels, out_channels, kernel_size, *, stride=1, padding=0, dilation=1):
        """Store the window geometry.

        Args:
            in_channels: size reported for the trailing channel axis of the view.
            out_channels: number of (stride-0) repeats of each window.
            kernel_size: per-spatial-axis kernel extent (sequence or scalar).
            stride: step between window origins (scalar or per-axis).
            padding: forwarded to ``Pad`` / ``np.pad`` as ``pad_width``.
            dilation: spacing between kernel taps (scalar or per-axis).
        """
        # Parameters
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        # _shape reads `kernel_shape`. The original only set `kernel_size`, so
        # forward() crashed. Keep both names so either attribute works.
        self.kernel_shape = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation

        # Modules
        self.pad = Pad(self.padding)
        # Rebuilt on every forward() because shape/strides depend on the input.
        self.as_strided = None

    def _shape(self, x_shape):
        """Return the view shape for a padded input of shape ``x_shape``."""
        # Number of window positions per spatial axis: ceil((n - (k-1)*d) / s),
        # which equals floor((n - (k-1)*d - 1) / s) + 1.
        outer_shape = np.subtract(x_shape[:-1], np.subtract(self.kernel_shape, 1) * self.dilation)
        outer_shape = np.ceil(np.divide(outer_shape, self.stride)).astype(int)
        return *outer_shape, self.out_channels, *self.kernel_shape, self.in_channels

    def _strides(self, x_strides):
        """Return the byte strides matching ``_shape`` for a padded input."""
        # Moving to the next window advances `stride` elements along each spatial axis.
        outer_strides = np.multiply(x_strides[:-1], self.stride)
        # Moving to the next kernel tap advances `dilation` elements.
        kernel_strides = np.multiply(x_strides[:-1], self.dilation)
        # The 0 stride broadcasts each window across the out_channels axis.
        return *outer_strides, 0, *kernel_strides, x_strides[-1]

    def forward(self, x):
        """Pad ``x`` and return the sliding-window view of it."""
        x = self.pad(x)
        # Read-only view (matching Windows2d): the windows overlap and the
        # out_channels axis has stride 0, so writes through it would alias.
        self.as_strided = AsStrided(self._shape(x.shape), self._strides(x.strides), writeable=False)
        return self.as_strided(x)

    def backward(self, dy):
        """Backpropagate ``dy`` through the strided view, then through the padding.

        Requires that ``forward`` was called first; it populates ``self.as_strided``.
        """
        dy = self.as_strided.backward(dy)
        return self.pad.backward(dy)
    

In [None]:
# Example instance. WindowsNd requires in_channels, out_channels and kernel_size;
# the original call passed none of them and raised TypeError.
w = WindowsNd(3, 8, (3, 3))

In [2]:
from nn import Module

In [3]:
from utils import unpad

class Pad(Module):
    """Pad an array in forward() and hand the gradient to ``unpad`` in backward().

    ``padding`` is passed unchanged both to ``np.pad`` (as ``pad_width``,
    default constant mode, so zeros) and to ``utils.unpad``. Presumably
    ``unpad`` strips exactly what ``np.pad`` added; confirm against utils.
    """

    def __init__(self, padding):
        # Shared spec for np.pad (forward) and unpad (backward).
        self.padding = padding

    def forward(self, x):
        """Return ``x`` padded according to ``self.padding``."""
        padded = np.pad(x, pad_width=self.padding)
        return padded

    def backward(self, dy):
        """Return ``unpad(dy, self.padding)`` as the input gradient."""
        dx = unpad(dy, self.padding)
        return dx

In [None]:
class Windows2d(Module):
    """Sliding-window view over a padded, channels-last input.

    All axes but the last are treated as spatial, and the last axis supplies
    the in-channel stride. The default ``padding`` has three (before, after)
    pairs, i.e. a 3-D input, presumably ``(H, W, C)``; confirm with callers.
    ``forward`` returns a view of shape
    ``(*window_counts, out_channels, *kernel_shape, in_channels)``.
    """

    def __init__(self, in_channels, out_channels, kernel_shape, *, stride=1,
                 padding=((0, 0), (0, 0), (0, 0)), dilation=1):
        self.in_channels, self.out_channels = in_channels, out_channels
        self.kernel_shape = kernel_shape
        self.stride, self.padding, self.dilation = stride, padding, dilation
        # Set by forward(); backward() depends on it.
        self.as_strided = None

    def compute_shape(self, x_shape):
        """Return the view shape for a padded input of shape ``x_shape``."""
        spatial = x_shape[:-1]
        # Span of one dilated kernel, minus one element, along each spatial axis.
        reach = np.subtract(self.kernel_shape, 1) * self.dilation
        counts = np.ceil(np.divide(np.subtract(spatial, reach), self.stride))
        window_counts = counts.astype(int)
        return (*window_counts, self.out_channels, *self.kernel_shape, self.in_channels)

    def compute_strides(self, x_strides):
        """Return the byte strides matching ``compute_shape``."""
        spatial_strides = x_strides[:-1]
        step = np.multiply(spatial_strides, self.stride)
        within = np.multiply(spatial_strides, self.dilation)
        # Stride 0 on the out_channels axis repeats each window per output channel.
        return (*step, 0, *within, x_strides[-1])

    def forward(self, x):
        """Pad ``x`` and return a (presumably read-only) windowed view of it."""
        padded = np.pad(x, self.padding)
        self.as_strided = AsStrided(
            self.compute_shape(padded.shape),
            self.compute_strides(padded.strides),
            writeable=False,
        )
        return self.as_strided(padded)

    def backward(self, dy):
        """Route ``dy`` back through the strided view, then strip the padding."""
        return unpad(self.as_strided.backward(dy), self.padding)
