In [153]:
# Shared TypeError instance: raised by parse_value() and Value.__init__()
# when they are handed something that is not an int or float (parse_value
# additionally lets existing Value objects through).
value_error = TypeError("Value type can only hold float scalars")

def parse_value(x):
    """Coerce *x* into a ``Value``.

    An ``int`` or ``float`` is wrapped as ``Value(float(x))``; an existing
    ``Value`` is handed back unchanged.

    Raises:
        TypeError: the module-level ``value_error`` instance, when *x* is
            neither a number nor a ``Value``.
    """
    if isinstance(x, (float, int)):
        return Value(float(x))
    if isinstance(x, Value):
        return x
    raise value_error
    
class Value:
    """Float scalar that remembers the operands it was computed from.

    Instance attributes (all assigned in ``__init__``):
      data      -- the wrapped number, always stored as a ``float``
      grad      -- starts at 0.0; nothing in this class updates it
      _backward -- a no-op callable placeholder
      _prev     -- set of the ``Value`` objects passed as ``_children``

    Arithmetic operators accept ``Value``, ``int`` or ``float`` operands
    (coerced through ``parse_value``) and return a new ``Value`` whose
    ``_prev`` holds the operands of that operation.
    """

    def __init__(self, data, _children=()):
        """Wrap *data* (int or float) and record *_children* as operands.

        Raises:
            TypeError: the module-level ``value_error`` if *data* is not an
                int or float.
        """
        if not isinstance(data, (float, int)):
            raise value_error
        self.data = float(data)
        self.grad = 0.0
        self._backward = lambda: None
        self._prev = set(_children)

    def __hash__(self):
        """Hash by object identity.

        Because this class defines ``__eq__``, Python would otherwise set
        ``__hash__`` to ``None``, and ``set(_children)`` in ``__init__``
        would fail with "unhashable type" for every binary operation.
        Identity hashing keeps two distinct nodes that happen to hold the
        same number as separate entries in ``_prev``.
        """
        return object.__hash__(self)

    def __add__(self, other):
        """Return ``self + other`` as a new ``Value``."""
        other = parse_value(other)
        return Value(self.data + other.data, (self, other))

    def __radd__(self, other):
        """Return ``other + self``; addition commutes, so reuse ``__add__``."""
        return self + other

    def __mul__(self, other):
        """Return ``self * other`` as a new ``Value``."""
        other = parse_value(other)
        return Value(self.data * other.data, (self, other))

    def __rmul__(self, other):
        """Return ``other * self``; multiplication commutes."""
        return self * other

    def __neg__(self):
        """Return ``-self`` as a new ``Value`` with ``self`` as its operand."""
        return Value(-self.data, (self,))

    def __sub__(self, other):
        """Return ``self - other``, expressed as ``self + (-1 * other)``."""
        return self + (-1 * other)

    def __rsub__(self, other):
        """Return ``other - self``, expressed as ``(-1 * self) + other``."""
        return -1 * self + other

    def __truediv__(self, other):
        """Return ``self / other`` as a new ``Value``.

        Raises:
            ZeroDivisionError: if *other* holds 0.
        """
        other = parse_value(other)
        if other.data == 0:
            raise ZeroDivisionError()
        return Value(self.data / other.data, (self, other))

    def __rtruediv__(self, other):
        """Return ``other / self`` as a new ``Value``.

        Raises:
            ZeroDivisionError: if ``self`` holds 0.
        """
        # The zero check stays ahead of operand validation so a zero divisor
        # is reported first, as before.
        if self.data == 0:
            raise ZeroDivisionError()
        other = parse_value(other)
        return Value(other.data / self.data, (other, self))

    def __pow__(self, other):
        """Return ``self ** other`` as a new ``Value``."""
        other = parse_value(other)
        return Value(self.data ** other.data, (self, other))

    def __rpow__(self, other):
        """Return ``other ** self`` as a new ``Value``."""
        other = parse_value(other)
        return Value(other.data ** self.data, (other, self))

    def __eq__(self, other):
        """Compare the wrapped floats.

        *other* may be a ``Value``, int or float; anything else raises the
        ``TypeError`` produced by ``parse_value``.
        """
        other = parse_value(other)
        return self.data == other.data

    def __req__(self, other):
        """Explicit-call alias for equality.

        NOTE: ``__req__`` is not part of Python's data model; the
        interpreter uses ``__eq__`` itself for the reflected comparison.
        Kept so any code calling it by name keeps working.
        """
        other = parse_value(other)
        return self == other

    def __repr__(self):
        """Return ``Value(<data>, grad=<grad>)``."""
        return f"Value({self.data}, grad={self.grad})"

In [154]:
# Value.__eq__ compares the wrapped floats, so two separate instances
# holding the same number compare equal.
Value(2) == Value(2)

True

In [157]:
# float ** Value dispatches to Value.__rpow__, which returns a Value; the
# float == Value comparison then falls back to Value.__eq__, which coerces
# the float side through parse_value.
3.2 ** 13 == 3.2 ** Value(13)

True