In [17]:
import torch
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(-2 * np.pi, 2 * np.pi, 1000)
y = np.cos(x)

x_train = torch.tensor(x, dtype=torch.float32).unsqueeze(1)
y_train = torch.tensor(y, dtype=torch.float32).unsqueeze(1)

class SimpleMLP:
    def __init__(self):
        self.w1 = torch.randn(1, 100, dtype=torch.float32) * 0.1
        self.w1.requires_grad_()
        self.b1 = torch.zeros(100, dtype=torch.float32)
        self.b1.requires_grad_()
        self.w2 = torch.randn(100, 100, dtype=torch.float32) * 0.1
        self.w2.requires_grad_()
        self.b2 = torch.zeros(100, dtype=torch.float32)
        self.b2.requires_grad_()
        self.w3 = torch.randn(100, 1, dtype=torch.float32) * 0.1
        self.w3.requires_grad_()
        self.b3 = torch.zeros(1, dtype=torch.float32)
        self.b3.requires_grad_()

    def forward(self, x):
        self.z1 = torch.matmul(x, self.w1) + self.b1
        self.a1 = torch.tanh(self.z1)
        self.z2 = torch.matmul(self.a1, self.w2) + self.b2
        self.a2 = torch.tanh(self.z2)
        self.z3 = torch.matmul(self.a2, self.w3) + self.b3
        return self.z3

model = SimpleMLP()

def cmp(s, dt, t):
  ex = torch.all(dt == t.grad).item()
  app = torch.allclose(dt, t.grad)
  maxdiff = (dt - t.grad).abs().max().item()
  print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')

y_pred = model.forward(x_train)
loss = ((y_pred - y_train) ** 2).mean()
loss.backward()

grad_z3 = 2 * (y_pred - y_train) / y_train.size(0)
grad_w3 = torch.matmul(model.a2.t(), grad_z3)
cmp('grad_w3', grad_w3, model.w3)
grad_b3 = grad_z3.sum(0)
cmp('grad_b3', grad_b3, model.b3)

grad_a2 = torch.matmul(grad_z3, model.w3.t())
grad_z2 = grad_a2 * (1 - model.a2 ** 2)
grad_w2 = torch.matmul(model.a1.t(), grad_z2)
cmp('grad_w2', grad_w2, model.w2)

grad_b2 = grad_z2.sum(0)
cmp('grad_b2', grad_b2, model.b2)


grad_a1 = torch.matmul(grad_z2, model.w2.t())
grad_z1 = grad_a1 * (1 - model.a1 ** 2)
grad_w1 = torch.matmul(x_train.t(), grad_z1)
cmp('grad_w1', grad_w1, model.w1)

grad_b1 = grad_z1.sum(0)
cmp('grad_w1', grad_w1, model.w1)


grad_w3         | exact: False | approximate: True  | maxdiff: 4.470348358154297e-08
grad_b3         | exact: False | approximate: True  | maxdiff: 9.546056389808655e-09
grad_w2         | exact: False | approximate: True  | maxdiff: 2.2351741790771484e-08
grad_b2         | exact: False | approximate: True  | maxdiff: 1.1175870895385742e-08
grad_w1         | exact: False | approximate: True  | maxdiff: 8.940696716308594e-08
grad_w1         | exact: False | approximate: True  | maxdiff: 8.940696716308594e-08
