In [33]:
from enum import Enum, auto

class Operator(Enum):
    """Operators that can label a node in a Tensor's expression graph."""
    ADD = auto()
    MUL = auto()
    POW = auto()
    NEG = auto()

    def __str__(self):
        """Return the operator's conventional symbol (e.g. '+' for ADD)."""
        # Built inside the method: a dict assigned in the Enum body would become a member.
        symbol_for = {
            Operator.ADD: '+',
            Operator.MUL: '*',
            Operator.POW: '^',
            Operator.NEG: '-',
        }
        return symbol_for[self]

In [40]:
from graphviz import Digraph
from IPython.display import display

def visualize(tensor, graph='compute graph'):
    """Draw the expression graph behind `tensor` with graphviz and display it inline.

    Args:
        tensor: root Tensor; its history is followed through `.op.left` / `.op.right`.
        graph: 'compute graph' (default) labels value nodes with `.data` and points
            edges from operands toward results. Any other value labels value nodes
            with `.grad` and reverses every edge (the gradient view).

    Value nodes are white when `requires_grad` is True and light coral otherwise;
    op nodes are labelled with `str(op.type)`.
    """
    forward = graph == 'compute graph'
    dot = Digraph(comment=graph)
    added_edges = set()

    def add_node(x):
        fillcolor = 'white' if x.requires_grad else 'lightcoral'
        value = x.data if forward else x.grad
        dot.node(str(id(x)), f"{value:.3f}", style='filled', fillcolor=fillcolor)

    def add_op_node(op):
        dot.node(str(id(op)), str(op.type))

    def add_edge(src, dst):
        """Add the data-flow edge src -> dst once; reversed in the gradient view."""
        if not forward:
            src, dst = dst, src
        edge = (str(id(src)), str(id(dst)))
        if edge not in added_edges:
            dot.edges([edge])
            added_edges.add(edge)

    # Iterative walk so deep graphs cannot hit the recursion limit. `visited` is keyed
    # by id() because Tensor is an eq=True dataclass (unhashable). It makes each
    # shared sub-expression get walked once; without it, chains like `a = a + a`
    # are traversed exponentially many times.
    visited = set()
    stack = [tensor]
    while stack:
        t = stack.pop()
        if id(t) in visited:
            continue
        visited.add(id(t))
        add_node(t)
        if t.op is None:  # leaf tensor
            continue
        add_op_node(t.op)
        add_edge(t.op, t)
        for operand in (t.op.left, t.op.right):
            if operand is not None:  # right is None for unary ops
                add_edge(operand, t.op)
                stack.append(operand)

    display(dot)

In [41]:
from dataclasses import dataclass
class Op:
    """Record of how a Tensor was produced: an operator together with its operand(s).

    Args:
        type: the operator. Callers in this notebook pass `Operator` members
            (e.g. `Operator.ADD`), not strings; `Tensor._backward` matches on them
            and `visualize` labels op nodes with `str(type)`.
        left: first (or only) operand Tensor.
        right: second operand Tensor, or None for a unary op such as negation.
    """
    def __init__(self, type: "Operator", left, right=None):
        self.type = type
        self.left = left
        self.right = right

In [257]:
import math
@dataclass
class Tensor:
    data: float
    op: Op | None = None
    grad: float = 0.0
    requires_grad: bool = True
    # no_grad: bool = False

    @staticmethod
    def ensure_tensor(value, requires_grad=False):
        return value if isinstance(value, Tensor) else Tensor(value, requires_grad=requires_grad)

    def validate_arg(fn):
        def wrapper(self, other):
            other = Tensor.ensure_tensor(other, requires_grad=False)
            return fn(self, other)
        return wrapper

    @validate_arg
    def __add__(self, other):
        n = self.data + other.data
        op = Op(Operator.ADD, self, other)
        return Tensor(n, op)
    def __radd__(self, other):
        return self + other
    @validate_arg
    def __mul__(self, other):
        n = self.data * other.data
        op = Op(Operator.MUL, self, other)
        return Tensor(n, op)
    
    def __rmul__(self, other):
        return self * other
    
    @validate_arg
    def __pow__(self, power):
        try:
            n = self.data ** power.data
        except ZeroDivisionError:
            n = float('nan')
        op = Op(Operator.POW, self, power)
        return Tensor(n, op)
    
    @validate_arg
    def __rpow__(self, base):
        return base ** self
    
    @validate_arg
    def __sub__(self, other):
        other.data = -other.data
        return self + other
    
    @validate_arg
    def __truediv__(self, other):
        return self * other ** -1 
    
    def __rtruediv__(self, other):
        return other * self ** -1

    def __neg__(self):
        data = -self.data
        op = Op(Operator.NEG, self)
        return Tensor(data, op)

    def backward(self):
        self.grad += 1.0
        self._backward()

    def _backward(self):
        op = self.op
        if not op:
            return
        match op.type:
            case Operator.ADD:
                op.left.grad += self.grad
                op.right.grad += self.grad

            case Operator.MUL:
                op.left.grad += self.grad * op.right.data 
                op.right.grad += self.grad * op.left.data

            case Operator.POW:
                n = op.right.data
                op.left.grad += self.grad * n * op.left.data ** (n-1)

                base = op.left.data
                op.right.grad += self.grad * self.data * math.log(base)
            case Operator.NEG:
                assert op.right is None, "Unary Operation can't have operands"
                op.left.grad += -self.grad


        if op.left:
            op.left._backward()
        if op.right:
            op.right._backward()
    

@dataclass
class Optimizer:
    """Vanilla gradient-descent optimizer over a list of scalar Tensors.

    Fields:
        params: the tensors to update.
        lr: learning rate, the step size applied to each gradient.
    """
    params: list[Tensor]
    lr: float = 0.01

    def step(self):
        """Apply data <- data - lr * grad to every param whose requires_grad is set."""
        trainable = (p for p in self.params if p.requires_grad)
        for param in trainable:
            param.data = param.data - self.lr * param.grad

    def zero_grad(self):
        """Reset the grad of every param, trainable or not, to 0.0."""
        for param in self.params:
            param.grad = 0.0

In [352]:
# XOR truth table: all four binary input pairs and their exclusive-or.
x = [(a, b) for a in (0, 1) for b in (0, 1)]
y = [a ^ b for a, b in x]

In [353]:
def sigmoid(x):
    """Logistic function 1 / (1 + e**-x).

    Plain ints/floats use an overflow-safe formulation; the naive expression
    raises OverflowError once -x exceeds about 709. Anything else (e.g. this
    notebook's Tensor) goes through `1/(1+math.e**-x)`, so its operator
    overloads record the computation.
    """
    if isinstance(x, (int, float)):
        if x >= 0:
            return 1 / (1 + math.exp(-x))
        z = math.exp(x)  # x < 0, so z is in (0, 1) and cannot overflow
        return z / (1 + z)
    return 1/(1+math.e**-x)

In [354]:
def chunks(lst, n):
    """Lazily yield consecutive slices of `lst` of length `n`; the last may be shorter."""
    yield from (lst[start:start + n] for start in range(0, len(lst), n))

In [355]:
from random import random
class Linear:
    """Fully connected layer over scalar Tensors: out[j] = sum_i x[i] * W[j][i] + b[j].

    Weights are one flat list with `in_features` consecutive entries per output unit.
    Each entry is initialised with `random()`, i.e. uniform in [0, 1).
    """
    def __init__(self, in_features:int , out_features: int, bias=True):
        self.input = in_features
        self.output = out_features

        self.weights = [Tensor(random()) for _ in range(self.input * self.output)]
        if bias:
            self.bias = [Tensor(random()) for _ in range(self.output)]
        else:
            # One plain 0 per output unit; `self.output` is an int, so no len().
            # Assumes raw numbers added to a Tensor get wrapped as requires_grad=False
            # constants (Tensor.ensure_tensor), so these never get optimizer updates.
            self.bias = [0] * self.output

    def __call__(self, x: list):
        """Alias for `forward`."""
        return self.forward(x)

    def forward(self, x: list):
        """Apply the layer to a batch of samples.

        Args:
            x: list of samples, each a sequence of `in_features` numbers or Tensors.

        Returns:
            One list per sample holding its `out_features` outputs; an empty batch
            gives an empty list.

        Raises:
            AssertionError: if any sample does not have exactly `in_features` entries.
        """
        rows = list(chunks(self.weights, self.input))  # weight row of each output unit
        output = []
        # Un-batch the input: handle a single sample at a time.
        for inp in x:
            assert len(inp) == self.input, f"Got {len(inp)} features, expected {self.input}"
            out = [sum((xi * wi for xi, wi in zip(inp, w)), b) for w, b in zip(rows, self.bias)]
            output.append(out)
        return output

    def params(self):
        """Return the `(weights, bias)` lists.

        NOTE(review): this is a pair of lists, not a flat list of Tensors, and with
        bias=False the bias list holds plain 0s. Flatten/filter before passing it to
        code that iterates Tensors.
        """
        return self.weights, self.bias