In [1]:
import torch

We want to implement the pytorch `.backward()` function such that we get the same gradients for x1, x2, w1, and w2

In [58]:
x1 = torch.Tensor([2.0]).double()                ; x1.requires_grad = True
x2 = torch.Tensor([0.0]).double()                ; x2.requires_grad = True
w1 = torch.Tensor([-3.0]).double()               ; w1.requires_grad = True
w2 = torch.Tensor([1.0]).double()                ; w2.requires_grad = True
b = torch.Tensor([6.8813735870195432]).double()  ; b.requires_grad = True
n = x1*w1 + x2*w2 + b
o = torch.tanh(n)

print(o.data.item())
o.backward()

print('---')
print('x2', x2.grad.item())
print('w2', w2.grad.item())
print('x1', x1.grad.item())
print('w1', w1.grad.item())
print('b', b.grad.item())

0.7071066904050358
---
x2 0.5000001283844369
w2 0.0
x1 -1.5000003851533106
w1 1.0000002567688737
b 0.5000001283844369


In [63]:
import math
from typing import Union
class Value():
    def __init__(self, data, children=(), op=""):
        self.data = data
        self.grad = 0
        self.children = set(children)
        self.op = op # the operation that resulted in this Value
        self._backward = lambda: None
    
    def __add__(self, other: Union[int,"Value"]):
        other = Value(other) if isinstance(other, int) else other
        out = Value(self.data + other.data, (self, other), "+")

        def _backward():
            self.grad += 1.0 * out.grad # must be += because the same variable can be used multiple times thus this is cumulative
            other.grad += 1.0 * out.grad
        self._backward = _backward

        return out
    
    def __mul__(self, other: Union[int,"Value"]):
        other = Value(other) if isinstance(other, int) else other
        out = Value(self.data * other.data, (self, other), "*")
        
        def _backward():
            self.grad += other.data * out.grad # must be += because the same variable can be used multiple times thus this is cumulative
            other.grad += self.data * out.grad
        self._backward = _backward

        return out
    
    def tanh(self):
        t = (math.exp(2*self.data) - 1)/(math.exp(2*self.data) + 1)
        out = Value(t, (self, ), "tanh")

        def _backward():
            self.grad += out.grad * (1 - t**2)
        self._backward = _backward

        return out
    
    def backward(self):
        # topological sort
        visited = set()
        topo_sorted = []
        to_visit = [self]
        
        while to_visit:
            node = to_visit.pop()
            visited.add(node)
            topo_sorted.append(node)
            for child in node.children:
                if child not in visited:
                    to_visit.append(child)

        for node in topo_sorted:
            node._backward()
        


In [64]:
x1 = Value(2)
x2 = Value(0)
w1 = Value(-3)
w2 = Value(1)
b = Value(6.8813735870195432)
n = x1*w1 + x2*w2 + b
o = n.tanh()

o.grad = 1.0
o.backward()

In [65]:
print('x2', x2.grad)
print('w2', w2.grad)
print('x1', x1.grad)
print('w1', w1.grad)
print('b', b.grad)

x2 0
w2 0
x1 -1.4999999999999996
w1 0.9999999999999998
b 0.4999999999999999
