In [None]:
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

In [None]:
#| export
class Module():
    """Minimal layer base class with a hand-written backward pass.

    Calling an instance runs ``forward`` on the given arguments and caches
    both the arguments (``self.args``) and the result (``self.out``), so that
    ``backward`` can later pass them on to ``bwd``. Subclasses override
    ``forward`` and ``bwd``.
    """

    def __call__(self, *args):
        """
        Run the subclass's forward pass on ``args``, caching inputs and output.

        Returns:
            Whatever ``forward`` returns (also stored as ``self.out``).
        """
        self.args = args
        self.out = self.forward(*args)
        return self.out

    def forward(self, *args):
        # Accept the call arguments so that a subclass which forgot to
        # override this reports the intended message, not a TypeError about
        # the number of positional arguments.
        raise NotImplementedError("Forward has not been implemented.")

    def backward(self):
        """Call ``bwd`` with the output and inputs cached by the last call."""
        # The arguments were cached on the instance by ``__call__``.
        self.bwd(self.out, *self.args)

    def bwd(self, *args):
        raise NotImplementedError("Backward has not been implemented")

In [None]:
#| export
class Linear(Module):
    """Affine layer ``inputs @ weight + bias`` with a hand-written backward pass.

    Gradients are stored on a ``.g`` attribute of each tensor (inputs, weight,
    bias). ``output.g`` must already hold the gradient of the loss w.r.t. this
    layer's output before ``bwd`` is called.

    NOTE(review): the ``.t()`` calls require 2-D tensors; presumably inputs are
    (batch, n_in) and weight is (n_in, n_out) -- confirm against callers.
    """

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def forward(self, inputs):
        return inputs @ self.weight + self.bias

    def bwd(self, output, inputs):
        # .t() transposes a 2-D tensor so the chain rule can be written as
        # matrix multiplications (@).

        # dL/d(inputs) = dL/d(output) @ weight^T
        inputs.g = output.g @ self.weight.t()
        # dL/d(weight) = inputs^T @ dL/d(output); the operand order yields the
        # weight's own (n_in, n_out) shape.
        self.weight.g = inputs.t() @ output.g
        # d(output)/d(bias) is 1 and the bias is broadcast over every row, so
        # its gradient is dL/d(output) summed over the batch axis only.
        self.bias.g = output.g.sum(0)

In [None]:
#| export
class Linear(nn.Module):
    """Affine layer ``inp @ w + b`` whose weights are trained via autograd.

    Args:
        n_in: number of input features (rows of ``w``).
        n_out: number of output features (columns of ``w``, length of ``b``).
    """

    def __init__(self, n_in, n_out):
        super().__init__()
        # Wrapping in nn.Parameter registers the tensors with the module, so
        # parameters()/state_dict()/.to() see them; Parameters require grad
        # by default.
        self.w = nn.Parameter(torch.randn(n_in, n_out))
        self.b = nn.Parameter(torch.zeros(n_out))

    def forward(self, inp):
        return inp @ self.w + self.b

In [None]:
#| export
class Model(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) that returns the MSE loss.

    Args:
        n_in: number of input features.
        nh: width of the hidden layer.
        n_out: number of outputs of the final layer.
    """

    def __init__(self, n_in, nh, n_out):
        super().__init__()
        # nn.ModuleList (rather than a plain list) registers the layers as
        # submodules, so parameters()/state_dict()/.to() include them.
        self.layers = nn.ModuleList([Linear(n_in, nh), nn.ReLU(), Linear(nh, n_out)])

    def forward(self, x, targ):
        """Run ``x`` through the layers and return the MSE loss against ``targ``.

        Defined as ``forward`` so that ``nn.Module.__call__`` dispatches to it;
        ``model(x, targ)`` still works.

        NOTE(review): ``targ[:, None]`` adds a trailing axis, so ``targ`` is
        presumably 1-D with one value per row of ``x`` (and ``n_out == 1``) --
        confirm against callers.
        """
        for layer in self.layers:
            x = layer(x)
        return F.mse_loss(x, targ[:, None])