# Backpopragation Homework

In [6]:
!jt -t monokai -T -kl

In [56]:
import torch

class MLP:
    def __init__(
        self,
        linear_1_in_features,
        linear_1_out_features,
        f_function,
        linear_2_in_features,
        linear_2_out_features,
        g_function
    ):
        """
        Args:
            linear_1_in_features: the in features of first linear layer
            linear_1_out_features: the out features of first linear layer
            linear_2_in_features: the in features of second linear layer
            linear_2_out_features: the out features of second linear layer
            f_function: string for the f function: relu | sigmoid | identity
            g_function: string for the g function: relu | sigmoid | identity
        """
        self.f_function = f_function
        self.g_function = g_function

        self.parameters = dict(
            W1 = torch.randn(linear_1_out_features, linear_1_in_features),
            b1 = torch.randn(linear_1_out_features),
            W2 = torch.randn(linear_2_out_features, linear_2_in_features),
            b2 = torch.randn(linear_2_out_features),
        )
        self.grads = dict(
            dJdW1 = torch.zeros(linear_1_out_features, linear_1_in_features),
            dJdb1 = torch.zeros(linear_1_out_features),
            dJdW2 = torch.zeros(linear_2_out_features, linear_2_in_features),
            dJdb2 = torch.zeros(linear_2_out_features),
        )

        # put all the cache value you need in self.cache
        self.cache = dict()

    def forward(self, x):
        """
        Args:
            x: tensor shape (batch_size, linear_1_in_features)
        """
        # Done: Implement the forward function
        self.cache["z0"] = x
        self.cache["s1"] = torch.matmul(self.cache["z0"], self.parameters["W1"].T) \
                        + self.parameters["b1"]

        if self.f_function == "relu":
            self.cache["z1"] = torch.relu(self.cache["s1"])
        elif self.f_function == "sigmoid":
            self.cache["z1"] = torch.sigmoid(self.cache["s1"])
        elif self.f_function == "identity":
            self.cache["z1"] = self.cache["s1"]
            
        self.cache["s2"] = torch.matmul(self.cache["z1"], self.parameters["W2"].T) \
                        + self.parameters["b2"]
        
        if self.f_function == "relu":
            self.cache["y_hat"] = torch.relu(self.cache["s2"])
        elif self.f_function == "sigmoid":
            self.cache["y_hat"] = torch.sigmoid(self.cache["s2"])
        elif self.f_function == "identity":
            self.cache["y_hat"] = self.cache["s2"]    
        
        return self.cache["y_hat"]
        
    def sigmoid_backward(self, x):
        return x * (1 - x)
    
    def relu_backward(self, x):
        return torch.where(x<=0, torch.zeros(x.shape), torch.ones(x.shape))
    
    def backward(self, dJdy_hat):
        """
        Args:
            dJdy_hat: The gradient tensor of shape (batch_size, linear_2_out_features)
        """
        # TODO: Implement the backward function
        dJdy_hat = dJdy_hat.clone()
        
        if self.g_function == "sigmoid":
            self.cache["dy_hat_ds2"] = self.sigmoid_backward(self.cache["y_hat"])
        elif self.g_function == "relu":
            self.cache["dy_hat_ds2"] = self.relu_backward(self.cache["s2"])
        elif self.g_function  == "identity":
            self.cache["dy_hat_ds2"] = torch.ones(self.cache["s2"].shape)
        
        self.cache["dJ_ds2"] = dJdy_hat * self.cache["dy_hat_ds2"]/dJdy_hat.shape[0]
        self.grads["dJdW2"] = torch.matmul(self.cache["dJ_ds2"].T, self.cache["z1"])
        self.grads["dJdb2"] = torch.matmul(self.cache["dJ_ds2"].T, 
                                          torch.ones(dJdy_hat.shape[0]))
        
        if self.g_function == "sigmoid":
            self.cache["dz1_ds1"] = self.sigmoid_backward(self.cache["z1"])
        elif self.g_function == "relu":
            self.cache["dz1_ds1"] = self.relu_backward(self.cache["s1"])
        elif self.g_function  == "identity":
            self.cache["dz1_ds1"] = torch.ones(self.cache["s1"].shape)
        
        self.cache["dJ_dz1"] = torch.matmul(self.cache["dJ_ds2"], self.parameters["W2"])
        self.cache["dJ_ds1"] = self.cache["dJ_dz1"] * self.cache["dz1_ds1"]
        
        self.grads["dJdW1"] = torch.matmul(self.cache["dJ_ds1"].T, self.cache["z0"])
        self.grads["dJdb1"] = torch.matmul(self.cache["dJ_ds1"].T, 
                                           torch.ones(self.cache["dJ_ds1"].shape[0]))
        
    def clear_grad_and_cache(self):
        for grad in self.grads:
            self.grads[grad].zero_()
        self.cache = dict()

def mse_loss(y, y_hat):
    """
    Args:
        y: the label tensor (batch_size, linear_2_out_features)
        y_hat: the prediction tensor (batch_size, linear_2_out_features)

    Return:
        J: scalar of loss
        dJdy_hat: The gradient tensor of shape (batch_size, linear_2_out_features)
    """
    # TODO: Implement the mse loss
    dJdy_hat = y - y_hat
    squaredDelta = dJdy_hat ** 2
    J = squaredDelta.mean() 
    
    # return loss, dJdy_hat
    return J, dJdy_hat

def bce_loss(y, y_hat):
    """
    Args:
        y_hat: the prediction tensor
        y: the label tensor
        
    Return:
        loss: scalar of loss
        dJdy_hat: The gradient tensor of shape (batch_size, linear_2_out_features)
    """
    # TODO: Implement the bce loss
    J = (-(y * torch.log(y_hat) + (1-y) * torch.log(1-y_hat))).mean()
    dJdy_hat = -y/y_hat + (1-y)/1-y_hat
    
    # return loss, dJdy_hat
    return J, dJdy_hat

### Test 1

In [57]:
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
# from mlp import MLP, mse_loss, bce_loss

net = MLP(
    linear_1_in_features=2,
    linear_1_out_features=20,
    f_function='relu',
    linear_2_in_features=20,
    linear_2_out_features=5,
    g_function='identity'
)
x = torch.randn(10, 2)
y = torch.randn(10, 5)

net.clear_grad_and_cache()
y_hat = net.forward(x)
J, dJdy_hat = mse_loss(y, y_hat)
net.backward(dJdy_hat)

#------------------------------------------------
# compare the result with autograd
net_autograd = nn.Sequential(
    OrderedDict([
        ('linear1', nn.Linear(2, 20)),
        ('relu', nn.ReLU()),
        ('linear2', nn.Linear(20, 5)),
    ])
)
net_autograd.linear1.weight.data = net.parameters['W1']
net_autograd.linear1.bias.data = net.parameters['b1']
net_autograd.linear2.weight.data = net.parameters['W2']
net_autograd.linear2.bias.data = net.parameters['b2']

y_hat_autograd = net_autograd(x)

J_autograd = F.mse_loss(y_hat_autograd, y)

net_autograd.zero_grad()
J_autograd.backward()

print((net_autograd.linear1.weight.grad.data - net.grads['dJdW1']).norm() < 1e-3)
print((net_autograd.linear1.bias.grad.data - net.grads['dJdb1']).norm() < 1e-3)
print((net_autograd.linear2.weight.grad.data - net.grads['dJdW2']).norm() < 1e-3)
print((net_autograd.linear2.bias.grad.data - net.grads['dJdb2']).norm()< 1e-3)
#------------------------------------------------

tensor(False)
tensor(False)
tensor(False)
tensor(False)
