In [31]:
value_error = TypeError("Value type can only hold float scalars")

def parse_value(x):
        if isinstance(x, (int, float)):
            x = Value(float(x))
        elif not isinstance(x, Value):
            raise value_error
        return x
    
class Value:
    def __init__(self, data, _children=()):
        if isinstance(data, (int, float)):
            data = float(data)
        else:
            raise value_error
        self.data = data
        self.grad = 0.0
        self._backward = lambda: None
        self._prev = set(_children)
    
    def __add__(self, other):
        other = parse_value(other)
        res = Value(self.data + other.data, (self, other))
        def _backward():
            self.grad += res.grad
            other.grad += res.grad
        res._backward = _backward
        return res
    def __radd__(self, other):
        return self + other

    def __mul__(self, other):
        other = parse_value(other)
        res = Value(self.data * other.data, (self, other))
        def _backward():
            self.grad += other.data * res.grad
            other.grad += self.data * res.grad
        res._backward = _backward
        return res
    def __rmul__(self, other):
        return self * other
    def __neg__(self):
        return self * -1
        
    def __sub__(self, other):
        return self + (-other)
    def __rsub__(self, other):
        return -self + other

    def __truediv__(self, other):
        other = parse_value(other)
        if other.data == 0:
            raise ZeroDivisionError()
        return self * other ** -1
    def __rtruediv__(self, other):
        other = parse_value(other)
        return other / self

    def tanh(self):
        pass

    def exp(self):
        pass

    def relu(self):
        pass

    def __pow__(self, other):
        if not isinstance(other, (int, float)):
            raise NotImplementedError("Value type can only be raised to int/float powers")
        res = Value(self.data ** other, (self,))
        def _backward():
            self.grad += (other * self.data ** (other - 1)) * res.grad
        res._backward = _backward
        return res
    def __rpow__(self, other):
        raise NotImplementedError("Value type can only be raised to int/float powers")

    def zero_grad(self):
        self.grad = 0.0
        
    def backward(self):
        _ordered = list()
        _visited = set()
        def add_list(curr):
            if curr not in _visited:
                _visited.add(curr)
                for p in curr._prev:
                    add_list(p)
                _ordered.append(curr)
        add_list(self)
        self.grad = 1.0
        for value in reversed(_ordered):
            value._backward()
        
    def __repr__(self):
        return f"Value({self.data}, grad={self.grad})"

In [88]:
# minimize (x-6)^4
x = Value(20)
lr = 0.001
epochs = 100000
for i in range(epochs + 1):
    f = (x - 6) ** 4
    f.backward()
    x.data -= lr * x.grad
    if i % (epochs / 10) == 0:        
        print(f"[{i}]\tx: {x.data}: f(x): {f.data}")
    x.zero_grad()

[0]    x: 9.024: f(x): 38416.0
[10000]    x: 6.111699338768181: f(x): 0.00015570018082493382
[20000]    x: 6.079019100305596: f(x): 3.899165892830442e-05
[30000]    x: 6.064528792767858: f(x): 1.7339760565493914e-05
[40000]    x: 6.055887952988646: f(x): 9.756510508045974e-06
[50000]    x: 6.0499900795218355: f(x): 6.2452909507266015e-06
[60000]    x: 6.045635947353427: f(x): 4.33753261876546e-06
[70000]    x: 6.042251647303016: f(x): 3.187036141397209e-06
[80000]    x: 6.0395234814104: f(x): 2.4402348424715804e-06
[90000]    x: 6.037263599927445: f(x): 1.9281858123114093e-06
[100000]    x: 6.035351739272228: f(x): 1.5618949703313464e-06
