In [1]:
import numpy as np

from numpy.lib.stride_tricks import as_strided 
from nn import Module, Mul, Add, AsStrided, Sum, Windows2d, Flatten, Pad
from init import normal
from utils import unpad, _as_pairs

In [11]:
class Windows(Module):
    """Pad an input and expose it as a strided view of kernel-sized windows.

    The view produced by ``forward`` has shape
    ``(*outer, out_channels, *kernel_shape, in_channels)`` where ``outer``
    counts the window positions along every axis of the input except the
    last one. The ``out_channels`` axis is given stride 0.

    Parameters
    ----------
    in_channels : size used for the last axis of the view.
    out_channels : size of the stride-0 axis inserted after the outer axes.
    kernel_shape : sequence with the window size per windowed axis.
    stride : step between consecutive windows (keyword only).
    padding : forwarded unchanged to ``Pad`` (keyword only).
    dilation : spacing between elements inside one window (keyword only).
    """

    def __init__(self, in_channels, out_channels, kernel_shape, *, stride=1, padding=0, dilation=1):
        # Window configuration
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_shape = kernel_shape
        self.stride = stride
        self.padding = padding
        self.dilation = dilation

        # Sub-modules; the strided view is rebuilt on every forward call
        # because its shape and strides depend on the incoming array.
        self.pad = Pad(self.padding)
        self.as_strided = None

    def _shape(self, x_shape):
        """Return the view shape for an input of shape ``x_shape``."""
        windowed_dims = x_shape[:-1]
        # (kernel size - 1) * dilation: how far one window reaches past its origin.
        kernel_reach = np.subtract(self.kernel_shape, 1) * self.dilation
        usable = np.subtract(windowed_dims, kernel_reach)
        # Number of window origins per axis when stepping by `stride`.
        n_positions = np.ceil(np.divide(usable, self.stride)).astype(int)
        return (*n_positions, self.out_channels, *self.kernel_shape, self.in_channels)

    def _strides(self, x_strides):
        """Return the view strides for an input with strides ``x_strides``."""
        windowed_strides = x_strides[:-1]
        # Moving to the next window skips `stride` elements ...
        between_windows = np.multiply(windowed_strides, self.stride)
        # ... while moving inside a window skips `dilation` elements.
        inside_window = np.multiply(windowed_strides, self.dilation)
        # Stride 0 on the out_channels axis; the last axis keeps its own stride.
        return (*between_windows, 0, *inside_window, x_strides[-1])

    def forward(self, x):
        """Pad ``x`` and return the windowed view of the padded array."""
        padded = self.pad(x)
        view_shape = self._shape(padded.shape)
        view_strides = self._strides(padded.strides)
        self.as_strided = AsStrided(view_shape, view_strides)
        return self.as_strided(padded)

    def backward(self, dy):
        """Send ``dy`` back through the strided view, then through the padding."""
        d_padded = self.as_strided.backward(dy)
        return self.pad.backward(d_padded)
    

In [12]:
class Conv2d(Module):
    """Convolution layer assembled from strided views, a product and a sum.

    ``forward`` turns the input into windows (``Windows``), views the weights
    with the same shape through ``AsStrided``, multiplies both element-wise,
    sums over the last three axes and optionally adds a bias view.
    ``backward`` walks the same modules in reverse and updates the weights
    (and bias) in place with a plain gradient step.

    Parameters
    ----------
    in_channels, out_channels : channel counts; the weights are created with
        shape ``(out_channels, *kernel_shape, in_channels)``.
    kernel_shape : sequence with the kernel size per windowed axis.
    stride, padding, dilation : forwarded to ``Windows`` (keyword only).
    bias : any truthy value enables a bias of shape ``(out_channels, 1)``.

    Notes
    -----
    The stride tuples built in ``forward`` use two leading zeros, so the
    layer assumes exactly two windowed axes (a 2-element ``kernel_shape``).
    """

    def __init__(self, in_channels, out_channels, kernel_shape, *, stride=1,
                 padding=0, dilation=1, bias=True):
        # Use truthiness rather than `is True`, so flags such as
        # numpy.bool_(True) do not silently disable the bias.
        use_bias = bool(bias)

        # Parameters
        self.weights = normal((out_channels, *kernel_shape, in_channels))
        self.bias = normal((out_channels, 1)) if use_bias else None

        # Modules
        self.input_to_windows = Windows(out_channels=out_channels, kernel_shape=kernel_shape,
                                        in_channels=in_channels, stride=stride, padding=padding,
                                        dilation=dilation)
        # The two views below depend on the input shape and are built in forward().
        self.weights_to_windows = None
        self.bias_to_windows = None
        self.mul = Mul()
        self.sum = Sum(axis=(-3, -2, -1))
        self.add = Add() if use_bias else None

    def forward(self, x):
        """Apply the layer to ``x`` and return the result of the summed product."""
        x = self.input_to_windows(x)

        # View the weights with the window shape: stride 0 on the two leading
        # (window position) axes, the weights' own strides on the rest.
        self.weights_to_windows = AsStrided(x.shape, (0, 0, *self.weights.strides), writeable=False)
        k = self.weights_to_windows(self.weights)

        y = self.mul(x, k)
        y = self.sum(y)

        if self.bias is not None:
            # The bias has shape (out_channels, 1); the step from one channel
            # to the next is the stride of axis 0, not of the length-1 axis.
            self.bias_to_windows = AsStrided(y.shape, (0, 0, self.bias.strides[0]), writeable=False)
            b = self.bias_to_windows(self.bias)
            y = self.add(y, b)
        return y

    def backward(self, dy, lr=0.1):
        """Backpropagate ``dy`` and take one gradient step of size ``lr``.

        The weights and (if present) the bias are modified in place.
        Returns the gradient with respect to the layer input.
        """
        if self.bias is not None:
            dy, dy_b = self.add.backward(dy)

            db = self.bias_to_windows.backward(dy_b)
            self.bias -= db * lr

        dy = self.sum.backward(dy)
        dy_x, dy_w = self.mul.backward(dy)

        dw = self.weights_to_windows.backward(dy_w)
        self.weights -= dw * lr

        dx = self.input_to_windows.backward(dy_x)
        return dx

In [None]:
from nn import Div, Exp, Sum


class Softmax2(Module):
    """Softmax that subtracts the maximum before exponentiating.

    Computes ``exp(x - max(x)) / sum(exp(x - max(x)))`` using the ``Exp``,
    ``Sum`` and ``Div`` modules.

    Parameters
    ----------
    axis : axis (or axes) used for both the maximum and the sum; ``None``
        reduces over the whole array.

    NOTE(review): with a non-None ``axis`` the division relies on how ``Sum``
    and ``Div`` handle the reduced axis (kept vs. dropped, broadcasting in
    forward and backward). That is not visible here; confirm before relying
    on anything other than ``axis=None``.
    """

    def __init__(self, axis=None):
        # forward() reads self.axis, so it has to be stored here.
        self.axis = axis
        self.exp = Exp()
        self.sum = Sum(axis=axis)
        self.div = Div()

    def forward(self, x):
        """Return the softmax of ``x``."""
        # keepdims=True so the maximum broadcasts against x for any axis.
        shifted = x - x.max(axis=self.axis, keepdims=True)
        exp_x = self.exp(shifted)
        return self.div(exp_x, self.sum(exp_x))

    def backward(self, dy):
        """Return the gradient with respect to ``x`` given upstream ``dy``.

        ``exp_x`` feeds both the numerator and, through the sum, the
        denominator of the division, so both gradient paths are added before
        going back through the exponential. The subtracted maximum is treated
        as a constant: softmax does not change when its input is shifted, so
        that path contributes no gradient.
        """
        dy_exp_x, dy_sum_exp_x = self.div.backward(dy)
        # Gradient reaching exp_x through the denominator.
        dy_exp_x = dy_exp_x + self.sum.backward(dy_sum_exp_x)
        dx = self.exp.backward(dy_exp_x)
        return dx

In [None]:
class Softmax(Module):
    """Softmax computed as ``exp(x) / sum(exp(x))`` from ``Exp``, ``Sum`` and ``Div``.

    Parameters
    ----------
    axis : axis (or axes) passed to ``Sum``; ``None`` sums the whole array.

    NOTE(review): for a non-None ``axis`` the division depends on how ``Sum``
    and ``Div`` treat the reduced axis, which is not visible here; confirm.
    """

    def __init__(self, axis=None):
        self.exp = Exp()
        self.sum = Sum(axis=axis)
        self.div = Div()

    def forward(self, x):
        """Return the softmax of ``x``."""
        exp_x = self.exp(x)
        return self.div(exp_x, self.sum(exp_x))

    def backward(self, dy):
        """Return the gradient with respect to ``x`` given upstream ``dy``.

        ``exp_x`` is used twice in forward (numerator and, via the sum,
        denominator), so the gradient from the denominator is routed back
        through ``Sum`` and added to the numerator gradient.
        """
        dy_exp_x, dy_sum_exp_x = self.div.backward(dy)
        # Previously dropped: the contribution flowing through the denominator.
        dy_exp_x = dy_exp_x + self.sum.backward(dy_sum_exp_x)
        dx = self.exp.backward(dy_exp_x)
        return dx

In [13]:
# Smoke test: two stacked convolutions followed by a flatten.
c1 = Conv2d(channels1, channels2, kernel_shape, bias=True)
c2 = Conv2d(channels2, channels3, kernel_shape, bias=False)
f1 = Flatten()

# Forward pass, one layer at a time.
y = c1(img)
y = c2(y)
y = f1(y)

# `loss` is a random array shaped like the output; it is fed to backward as
# the upstream gradient. Each Conv2d.backward also updates that layer's
# parameters in place.
loss = np.random.normal(size=y.shape)
dy = f1.backward(loss)
dy = c2.backward(dy)
dy = c1.backward(dy)

# Shape of the gradient that reaches the input, next to the output shape.
dy.shape, y.shape

In [None]:
# Show the shapes of the current `dy` and `y` (whatever cell assigned them last).
dy.shape, y.shape

In [None]:
# Single convolution layer without bias, used by the cells below.
model = Conv2d(channels1, channels2, kernel_shape, bias=False)

In [None]:
# Forward pass of `model` on `img`; rebinds the notebook-level name `y`.
y = model(img)

In [None]:
# Random array shaped like `y`, used below as the gradient passed to backward.
# No seed is set, so the values differ from run to run.
dy = np.random.normal(size=y.shape)

In [None]:
# Shape of the gradient returned by backward, next to the output shape.
# Note: Conv2d.backward also updates model.weights in place (default lr=0.1),
# so re-running this cell changes the model.
model.backward(dy).shape, y.shape

In [5]:
# Shared test configuration. The cells above use these names (for example
# Conv2d(channels1, channels2, kernel_shape, ...)), so this cell has to be
# executed before them even though it appears last.
channels1, channels2, channels3 = 2, 5, 1

# Kernel size along the two windowed axes.
kernel_shape = (3, 3)

# All-ones input whose last axis has `channels1` entries.
img = np.ones((8, 10, channels1))