# Dense Layer from Scratch
Forward and backward passes of a dense (fully connected) layer, implemented by hand as a basic exercise.

Notes:
- TODO: conv2D from scratch, and then attention layer from scratch

In [None]:
"""
20 mins to implement

Questions:
- What should you do if training diverges?
1. The learning rate may be too high; try lower learning rates.
"""

import torch
import torch.nn as nn
# Seed torch's global RNG so the torch.randn calls below produce the same values on every run.
torch.manual_seed(0)

class FeedForwardNetwork(nn.Module):
    """Two-layer feed-forward block (linear -> ReLU -> linear) with a manual backward pass.

    Gradients are derived by hand in ``backward`` instead of using autograd,
    and ``step`` applies a plain gradient-descent update.

    Naming convention: a tensor name ends with its dimensions, e.g.
    ``X_BND`` has shape ``(B, N, D)`` and ``W1_DH`` has shape ``(D, H)``,
    where ``H = D * expansion``.

    Typical usage is ``forward`` -> ``backward`` -> ``step``; calling them out
    of order raises ``RuntimeError``.
    """

    # Names of the trainable tensors. The gradient of ``<name>`` is stored in
    # ``self.grads`` under the key ``"d<name>"``.
    _PARAM_NAMES = ("W1_DH", "b1_H", "W2_HD", "b2_D")

    def __init__(self, D, expansion=4):
        """Create randomly initialised weights for a ``D -> D * expansion -> D`` network.

        Args:
            D: input and output feature size.
            expansion: multiplier giving the hidden size ``H = D * expansion``.
        """
        super().__init__()

        H = D * expansion
        # Registered as buffers rather than nn.Parameter: autograd is not used
        # here, but buffers still show up in state_dict() and follow
        # .to(device) / .double(), which plain tensor attributes would not.
        self.register_buffer("W1_DH", torch.randn(D, H))
        self.register_buffer("b1_H", torch.randn(H))

        self.act1 = nn.ReLU()

        self.register_buffer("W2_HD", torch.randn(H, D))
        self.register_buffer("b2_D", torch.randn(D))

        # Activations saved by forward() for use in backward().
        self.cache = None
        # Gradients produced by the most recent backward() call.
        self.grads = {}

    def forward(self, X_BND: torch.Tensor) -> torch.Tensor:
        """Compute ``relu(X @ W1 + b1) @ W2 + b2`` and cache the activations.

        Time and memory complexity analysis:
            Assumptions:
            1. expansion=1 so h = d

            Total FLOPs: O(bndh) = O(bnd^2)
            Memory: O(bnd + d^2)

            Ratio of memory to FLOPs = O(1/d + 1/bn).
            So if we want to decrease the ratio (because hardware usually
            offers more FLOPs than memory bandwidth), we should increase d or bn.

        Args:
            X_BND: input of shape ``(B, N, D)``.

        Returns:
            Output ``Y_BND`` of shape ``(B, N, D)``.
        """
        Z1_BNH = torch.einsum("bnd,dh->bnh", X_BND, self.W1_DH) + self.b1_H
        H1_BNH = self.act1(Z1_BNH)
        Y_BND = torch.einsum("bnh,hd->bnd", H1_BNH, self.W2_HD) + self.b2_D

        self.cache = (X_BND, Z1_BNH, H1_BNH)
        return Y_BND

    def backward(self, dY_BND: torch.Tensor) -> torch.Tensor:
        """Back-propagate the upstream gradient through the network.

        Args:
            dY_BND: gradient of the loss with respect to the output of the
                most recent ``forward`` call (dLoss/dY), shape ``(B, N, D)``.

        Returns:
            Gradient of the loss with respect to the input, shape ``(B, N, D)``.
            Parameter gradients are stored in ``self.grads``.

        Raises:
            RuntimeError: if ``forward`` has not been called yet.
        """
        if self.cache is None:
            raise RuntimeError("backward() called before forward(): no cached activations")
        X_BND, Z1_BNH, H1_BNH = self.cache

        # Second linear layer: Y = H1 @ W2 + b2
        dW2_HD = torch.einsum("bnd,bnh->hd", dY_BND, H1_BNH)
        db2_D = torch.einsum("bnd->d", dY_BND)
        dH1_BNH = torch.einsum("bnd,hd->bnh", dY_BND, self.W2_HD)

        # Derivative of ReLU: 1 where the pre-activation is positive, else 0.
        dZ1_BNH = dH1_BNH * (Z1_BNH > 0).to(dH1_BNH.dtype)

        # First linear layer: Z1 = X @ W1 + b1
        dW1_DH = torch.einsum("bnh,bnd->dh", dZ1_BNH, X_BND)
        db1_H = torch.einsum("bnh->h", dZ1_BNH)
        dX_BND = torch.einsum("bnh,dh->bnd", dZ1_BNH, self.W1_DH)

        self.grads = {
            "dW1_DH": dW1_DH,
            "db1_H": db1_H,
            "dW2_HD": dW2_HD,
            "db2_D": db2_D,
        }

        return dX_BND

    def step(self, lr=1e-3):
        """Apply one gradient-descent update, ``param -= lr * grad``, in place.

        Args:
            lr: learning rate.

        Raises:
            RuntimeError: if ``backward`` has not been called yet.
        """
        if not self.grads:
            raise RuntimeError("step() called before backward(): no gradients available")
        with torch.no_grad():
            for name in self._PARAM_NAMES:
                getattr(self, name).sub_(self.grads["d" + name] * lr)

    # ------------------------------------------------------------------
    # Self-test helpers
    # ------------------------------------------------------------------

    def _snapshot_params(self):
        """Return a detached copy of every trainable tensor, keyed by name."""
        return {name: getattr(self, name).clone() for name in self._PARAM_NAMES}

    def _restore_params(self, snapshot):
        """Copy the values from ``snapshot`` back into the trainable tensors."""
        with torch.no_grad():
            for name, value in snapshot.items():
                getattr(self, name).copy_(value)

    def _random_batch(self, B=2, N=3):
        """Return a random ``(input, target)`` pair matching this model's D, dtype and device."""
        D = self.W1_DH.shape[0]
        opts = {"dtype": self.W1_DH.dtype, "device": self.W1_DH.device}
        return torch.randn(B, N, D, **opts), torch.randn(B, N, D, **opts)

    @staticmethod
    def _mse_grad(Y_BND, target_BND):
        """Gradient of ``((Y - target) ** 2).mean()`` with respect to ``Y``."""
        return 2 * (Y_BND - target_BND) / Y_BND.numel()

    def _run_zero_case(self, B=2, N=3):
        """Zero the biases, then run forward/backward on all-zero input and target.

        With zero biases and zero input every activation is zero, so the
        output, the loss gradient and all parameter gradients must be zero.
        The caller is responsible for restoring the biases afterwards.

        Returns:
            ``(Y_BND, dX_BND)`` from the forward and backward pass.
        """
        self.b1_H.zero_()
        self.b2_D.zero_()

        X_BND = torch.zeros(
            B, N, self.W1_DH.shape[0],
            dtype=self.W1_DH.dtype, device=self.W1_DH.device,
        )
        Y_BND = self.forward(X_BND)
        dX_BND = self.backward(self._mse_grad(Y_BND, torch.zeros_like(Y_BND)))
        return Y_BND, dX_BND

    # ------------------------------------------------------------------
    # Self-tests (work for any D / expansion and leave the parameters untouched)
    # ------------------------------------------------------------------

    def _test_forward_shapes(self):
        """forward() must preserve the input shape."""
        X_BND, _ = self._random_batch()
        Y_BND = self.forward(X_BND)
        assert Y_BND.shape == X_BND.shape

    def _test_forward_values(self):
        """Zero input with zero biases must give zero output."""
        snapshot = self._snapshot_params()
        try:
            Y_BND, _ = self._run_zero_case()
            assert torch.allclose(Y_BND, torch.zeros_like(Y_BND))
        finally:
            self._restore_params(snapshot)

    def _test_backward_shapes(self):
        """backward() must return an input-shaped gradient and parameter-shaped grads."""
        X_BND, target_BND = self._random_batch()
        Y_BND = self.forward(X_BND)
        assert Y_BND.shape == X_BND.shape

        # assume a basic MSE
        dX_BND = self.backward(self._mse_grad(Y_BND, target_BND))
        assert dX_BND.shape == X_BND.shape
        for name in self._PARAM_NAMES:
            assert self.grads["d" + name].shape == getattr(self, name).shape

    def _test_backward_values(self):
        """Check backward() on the trivial zero case and against autograd."""
        snapshot = self._snapshot_params()
        try:
            Y_BND, dX_BND = self._run_zero_case()
            assert torch.allclose(Y_BND, torch.zeros_like(Y_BND))
            assert torch.allclose(dX_BND, torch.zeros_like(dX_BND))
        finally:
            self._restore_params(snapshot)

        # Reference gradients from autograd, using the same einsum ops as
        # forward() so that the ReLU mask is bit-for-bit identical.
        X_BND, target_BND = self._random_batch()
        with torch.enable_grad():
            X_ref = X_BND.clone().requires_grad_()
            ref = {
                name: getattr(self, name).clone().requires_grad_()
                for name in self._PARAM_NAMES
            }
            Z1_ref = torch.einsum("bnd,dh->bnh", X_ref, ref["W1_DH"]) + ref["b1_H"]
            Y_ref = (
                torch.einsum("bnh,hd->bnd", torch.relu(Z1_ref), ref["W2_HD"])
                + ref["b2_D"]
            )
            ((Y_ref - target_BND) ** 2).mean().backward()

        Y_BND = self.forward(X_BND)
        dX_BND = self.backward(self._mse_grad(Y_BND, target_BND))

        tol = {"rtol": 1e-4, "atol": 1e-4}
        assert torch.allclose(Y_BND, Y_ref.detach(), **tol)
        assert torch.allclose(dX_BND, X_ref.grad, **tol)
        for name in self._PARAM_NAMES:
            assert torch.allclose(self.grads["d" + name], ref[name].grad, **tol)

    def _test_step(self):
        """step() must be a no-op for zero gradients and apply ``-lr * grad`` otherwise."""
        snapshot = self._snapshot_params()
        try:
            # Zero gradients: parameters must not move.
            Y_BND, dX_BND = self._run_zero_case()
            assert torch.allclose(Y_BND, torch.zeros_like(Y_BND))
            assert torch.allclose(dX_BND, torch.zeros_like(dX_BND))

            previous = self._snapshot_params()
            self.step()
            for name in self._PARAM_NAMES:
                assert torch.allclose(previous[name], getattr(self, name))

            # Non-zero gradients: each parameter moves by exactly -lr * grad.
            X_BND, target_BND = self._random_batch()
            self.backward(self._mse_grad(self.forward(X_BND), target_BND))

            previous = self._snapshot_params()
            lr = 1e-2
            self.step(lr=lr)
            for name in self._PARAM_NAMES:
                expected = previous[name] - lr * self.grads["d" + name]
                assert torch.allclose(getattr(self, name), expected)
        finally:
            self._restore_params(snapshot)

    
# Build a small network and run every self-test in order.
ffn = FeedForwardNetwork(D=4)
for _self_test in (
    ffn._test_forward_shapes,
    ffn._test_forward_values,
    ffn._test_backward_shapes,
    ffn._test_backward_values,
    ffn._test_step,
):
    _self_test()