In [2]:
import numpy as np
from collections import defaultdict
from typing import Callable, Any, Tuple, Optional, Dict, Iterator, Union

# Global switch read by every wrapped forward op; NoGrad flips it temporarily.
grad_tracking_enabled = True


def wrap_forward_function(func):
    """Decorate a forward op so its output records how it was produced.

    When tracking is enabled and at least one positional argument is a Parameter
    with ``requires_grad``, the returned Parameter gets ``parents`` (all positional
    args), ``func`` (this wrapper, the key used in ``lookup``), ``kwargs`` and
    ``requires_grad=True``.
    """
    def new_function(*args, **kwargs):
        parents = tuple(args)
        output = func(*args, **kwargs)
        tracks_grad = grad_tracking_enabled and any(
            isinstance(arg, Parameter) and arg.requires_grad for arg in args
        )
        if tracks_grad:
            output.parents = parents
            output.func = new_function
            output.kwargs = kwargs
            output.requires_grad = True
        return output

    new_function.__name__ = func.__name__
    return new_function


class Parameter:
    """Array wrapper that records operations for reverse-mode automatic differentiation.

    Leaf instances are created directly. Instances returned by functions decorated
    with ``wrap_forward_function`` (while grad tracking is on and some input requires
    grad) also carry ``parents`` (the positional inputs), ``func`` (the recorded
    forward function) and ``kwargs`` (its keyword arguments). ``backward`` uses these
    to look up gradient rules in the module-level ``lookup`` table.

    NOTE(review): ``sin``, ``cos``, ``tanh``, ``sigmoid``, ``sqrt``, ``abs_func`` and
    ``reshape`` are not defined in this cell. The corresponding methods raise
    ``NameError`` unless another cell defines those functions.
    """

    def __init__(
        self,
        array: Optional[np.ndarray] = None,
        requires_grad: bool = False,
        parents: Optional[Tuple['Parameter', ...]] = None,
        func: Optional[Callable] = None,
        kwargs: Optional[dict] = None,
    ):
        self.array = array if array is not None else np.array(0.0)
        self.grad: Optional[np.ndarray] = None
        self.requires_grad = requires_grad
        self.parents = parents if parents is not None else ()
        self.func = func
        self.kwargs = kwargs if kwargs is not None else {}

    @staticmethod
    def _lift(value: Union['Parameter', float, int]) -> 'Parameter':
        """Return ``value`` unchanged if it is a Parameter, else wrap it as a constant."""
        if isinstance(value, Parameter):
            return value
        return Parameter(np.array(value), requires_grad=False)

    # Arithmetic operators accept a Parameter or a plain scalar/array on either side.
    # Plain values become constants that never receive gradients.
    def __add__(self, other: Union['Parameter', float, int]) -> 'Parameter':
        return add(self, self._lift(other))

    def __radd__(self, other: Union['Parameter', float, int]) -> 'Parameter':
        return add(self._lift(other), self)

    def __sub__(self, other: Union['Parameter', float, int]) -> 'Parameter':
        return subtract(self, self._lift(other))

    def __rsub__(self, other: Union['Parameter', float, int]) -> 'Parameter':
        return subtract(self._lift(other), self)

    def __mul__(self, other: Union['Parameter', float, int]) -> 'Parameter':
        return multiply(self, self._lift(other))

    def __rmul__(self, other: Union['Parameter', float, int]) -> 'Parameter':
        return multiply(self._lift(other), self)

    def __truediv__(self, other: Union['Parameter', float, int]) -> 'Parameter':
        return divide_op(self, self._lift(other))

    def __rtruediv__(self, other: Union['Parameter', float, int]) -> 'Parameter':
        return divide_op(self._lift(other), self)

    def __pow__(self, power: Union['Parameter', float, int]) -> 'Parameter':
        return power_func(self, self._lift(power))

    def __neg__(self) -> 'Parameter':
        return negate(self)

    def __matmul__(self, other: Union['Parameter', float, int]) -> 'Parameter':
        if not isinstance(other, Parameter):
            raise ValueError("Cannot perform matmul with non-Parameter type")
        return matmul(self, other)

    def __rmatmul__(self, other: Union['Parameter', float, int]) -> 'Parameter':
        if not isinstance(other, Parameter):
            raise ValueError("Cannot perform matmul with non-Parameter type")
        return matmul(other, self)

    # Method forms of the module-level ops.
    def exp(self) -> 'Parameter':
        return exp(self)

    def log(self) -> 'Parameter':
        return log(self)

    def sin(self) -> 'Parameter':
        return sin(self)

    def cos(self) -> 'Parameter':
        return cos(self)

    def tanh(self) -> 'Parameter':
        return tanh(self)

    def sigmoid(self) -> 'Parameter':
        return sigmoid(self)

    def relu(self) -> 'Parameter':
        return relu(self)

    def sqrt(self) -> 'Parameter':
        return sqrt(self)

    def abs(self) -> 'Parameter':
        return abs_func(self)

    def reshape(self, *shape) -> 'Parameter':
        """Reshape to ``shape``; accepts both ``x.reshape(2, 3)`` and ``x.reshape((2, 3))``."""
        if len(shape) == 1 and isinstance(shape[0], (tuple, list)):
            shape = tuple(shape[0])
        # Pass the shape by keyword so it is stored in ``kwargs``. A positional
        # argument would be recorded as a graph parent and passed to the backward
        # rule ahead of the gradient.
        return reshape(self, shape=shape)

    def sum(self, axis=None, keepdims=False) -> 'Parameter':
        return sum_func(self, axis=axis, keepdims=keepdims)

    @property
    def T(self) -> 'Parameter':
        return transpose(self)

    def backward(self, grad: Optional[np.ndarray] = None):
        """Back-propagate from this node into every ancestor that requires grad.

        Args:
            grad: upstream gradient w.r.t. this node. Defaults to ones shaped like
                ``self.array``, i.e. the seed for a scalar loss.

        Gradients are accumulated into ``.grad``. Accumulation is done out of place
        (``a = a + b``) because backward rules may return the incoming gradient
        object itself or a read-only ``np.broadcast_to`` view. With ``+=``, the
        former would silently mutate another node's gradient and the latter would
        raise ``ValueError: output array is read-only``.
        """
        if not self.requires_grad:
            return
        if grad is None:
            grad = np.ones_like(self.array)
        self.grad = grad if self.grad is None else self.grad + grad
        for tensor in reversed(self._topological_sort()):
            # Leaves have no rule to apply; a node without a gradient has nothing to pass on.
            if tensor.func is None or tensor.grad is None:
                continue
            for idx, parent in enumerate(tensor.parents):
                # Constants (and non-Parameter args) would discard the gradient anyway.
                if not (isinstance(parent, Parameter) and parent.requires_grad):
                    continue
                backward_func = lookup.get_backward_function(tensor.func, idx)
                parent_grad = backward_func(*tensor.parents, tensor.grad, **tensor.kwargs)
                if parent_grad is None:
                    continue
                parent.grad = parent_grad if parent.grad is None else parent.grad + parent_grad

    def _topological_sort(self) -> list:
        """Return graph nodes in post-order (every parent before its children).

        Uses an explicit stack, so deep graphs cannot hit Python's recursion limit.
        Entries of ``parents`` that are not Parameters are skipped.
        """
        visited = set()
        topo_order = []
        stack = [(self, False)]
        while stack:
            node, expanded = stack.pop()
            if expanded:
                # All of this node's parents have been emitted already.
                topo_order.append(node)
                continue
            if node in visited:
                continue
            visited.add(node)
            stack.append((node, True))
            # Push parents reversed so they are explored in their original order.
            for parent in reversed(node.parents):
                if isinstance(parent, Parameter) and parent not in visited:
                    stack.append((parent, False))
        return topo_order

    def zero_grad(self):
        """Forget the accumulated gradient."""
        self.grad = None

    def __repr__(self) -> str:
        return f"Parameter(shape={self.array.shape}, requires_grad={self.requires_grad})"

class BackwardLookupTable:
    """Registry mapping a forward function and argument position to its backward rule."""

    def __init__(self):
        self.table: Dict[Callable, Dict[int, Callable]] = defaultdict(dict)

    def add_element(self, forward_function: Callable, position: int, backward_func: Callable):
        """Register ``backward_func`` as the gradient of ``forward_function`` w.r.t. argument ``position``."""
        per_function = self.table[forward_function]
        per_function[position] = backward_func

    def get_backward_function(self, forward_function: Callable, position: int) -> Callable:
        """Return the registered rule; raise ``KeyError`` if none exists (never inserts)."""
        registered = self.table.get(forward_function, {})
        if position in registered:
            return registered[position]
        raise KeyError(f"No backward function found for {forward_function} at position {position}")


# Shared table that the forward/backward op definitions register into.
lookup = BackwardLookupTable()

def _broadcast_backward(grad_out, target_shape):
    grad = grad_out
    while len(grad.shape) > len(target_shape):
        grad = grad.sum(axis=0)
    for axis, size in enumerate(target_shape):
        if size == 1:
            grad = grad.sum(axis=axis, keepdims=True)
    return grad


@wrap_forward_function
def add(x, y):
    """Elementwise ``x + y`` (numpy broadcasting applies)."""
    total = x.array + y.array
    return Parameter(array=total)

def add_back0(x, y, grad_out):
    """d(x + y)/dx = 1: the upstream gradient, reduced to x's shape."""
    return _broadcast_backward(grad_out, x.array.shape)

def add_back1(x, y, grad_out):
    """d(x + y)/dy = 1: the upstream gradient, reduced to y's shape."""
    return _broadcast_backward(grad_out, y.array.shape)

lookup.add_element(add, 0, add_back0)
lookup.add_element(add, 1, add_back1)


@wrap_forward_function
def subtract(x, y):
    """Elementwise ``x - y`` (numpy broadcasting applies)."""
    difference = x.array - y.array
    return Parameter(array=difference)

def subtract_back0(x, y, grad_out):
    """d(x - y)/dx = 1: pass the upstream gradient through, reduced to x's shape."""
    return _broadcast_backward(grad_out, x.array.shape)

def subtract_back1(x, y, grad_out):
    """d(x - y)/dy = -1: negate the upstream gradient reduced to y's shape."""
    reduced = _broadcast_backward(grad_out, y.array.shape)
    return -reduced

lookup.add_element(subtract, 0, subtract_back0)
lookup.add_element(subtract, 1, subtract_back1)


@wrap_forward_function
def multiply(x, y):
    """Elementwise ``x * y`` (numpy broadcasting applies)."""
    product = x.array * y.array
    return Parameter(array=product)

def multiply_back0(x, y, grad_out):
    """d(x * y)/dx = y."""
    return _broadcast_backward(grad_out * y.array, x.array.shape)

def multiply_back1(x, y, grad_out):
    """d(x * y)/dy = x."""
    return _broadcast_backward(grad_out * x.array, y.array.shape)

lookup.add_element(multiply, 0, multiply_back0)
lookup.add_element(multiply, 1, multiply_back1)


@wrap_forward_function
def divide_op(x, y):
    """Elementwise ``x / y`` (numpy broadcasting applies)."""
    return Parameter(array=x.array / y.array)

def divide_back0(x, y, grad_out):
    """d(x / y)/dx = 1 / y."""
    grad = grad_out / y.array
    return _broadcast_backward(grad, x.array.shape)

def divide_back1(x, y, grad_out):
    """d(x / y)/dy = -x / y**2; returns None when ``y`` is a constant.

    No epsilon is added to the denominator. The forward pass divides by the raw
    ``y``, so padding ``y**2`` with 1e-12 made the gradient disagree with the
    forward value whenever ``|y|`` is around 1e-6 or smaller.
    """
    if not y.requires_grad:
        return None
    grad = -grad_out * x.array / (y.array ** 2)
    return _broadcast_backward(grad, y.array.shape)

lookup.add_element(divide_op, 0, divide_back0)
lookup.add_element(divide_op, 1, divide_back1)




@wrap_forward_function
def power_func(x, y):
    """Elementwise ``x ** y``."""
    raised = x.array ** y.array
    return Parameter(array=raised)

def power_back0(x, y, grad_out):
    """d(x ** y)/dx = y * x ** (y - 1)."""
    base, exponent = x.array, y.array
    grad = grad_out * exponent * (base ** (exponent - 1))
    return _broadcast_backward(grad, base.shape)

def power_back1(x, y, grad_out):
    """d(x ** y)/dy = x ** y * log(x); 1e-12 guards log(0) (negative bases still give NaN)."""
    base, exponent = x.array, y.array
    grad = grad_out * (base ** exponent) * np.log(base + 1e-12)
    return _broadcast_backward(grad, exponent.shape)

lookup.add_element(power_func, 0, power_back0)
lookup.add_element(power_func, 1, power_back1)


@wrap_forward_function
def negate(x):
    """Elementwise ``-x``."""
    negated = -x.array
    return Parameter(array=negated)

def negate_back(x, grad_out):
    """d(-x)/dx = -1."""
    return _broadcast_backward(-grad_out, x.array.shape)

lookup.add_element(negate, 0, negate_back)


@wrap_forward_function
def matmul(x, y):
    """Matrix product ``x @ y`` with ``np.matmul`` semantics (1-D promotion, batch broadcasting)."""
    return Parameter(array=x.array @ y.array)

def _matmul_grad_operands(x_arr, y_arr, grad_out):
    """Promote 1-D operands, and the matching axis of ``grad_out``, to matrices.

    ``np.matmul`` treats a 1-D left operand as a row vector and a 1-D right
    operand as a column vector, then drops that axis from the result.
    Re-inserting it lets both backward rules use plain (batched) matrix algebra.
    """
    g = np.asarray(grad_out)
    if y_arr.ndim == 1:
        y_arr = y_arr[:, np.newaxis]
        g = np.expand_dims(g, -1)
    if x_arr.ndim == 1:
        x_arr = x_arr[np.newaxis, :]
        g = np.expand_dims(g, -2)
    return x_arr, y_arr, g

def matmul_back0(x, y, grad_out):
    """d(x @ y)/dx = grad_out @ y^T (last two axes), reduced to x's shape."""
    x2, y2, g = _matmul_grad_operands(x.array, y.array, grad_out)
    # Swap only the matrix axes; ``.T`` would also reverse any batch axes.
    grad = g @ np.swapaxes(y2, -1, -2)
    if x.array.ndim == 1:
        grad = np.squeeze(grad, axis=-2)
    return _broadcast_backward(grad, x.array.shape)

def matmul_back1(x, y, grad_out):
    """d(x @ y)/dy = x^T (last two axes) @ grad_out, reduced to y's shape."""
    x2, y2, g = _matmul_grad_operands(x.array, y.array, grad_out)
    grad = np.swapaxes(x2, -1, -2) @ g
    if y.array.ndim == 1:
        grad = np.squeeze(grad, axis=-1)
    return _broadcast_backward(grad, y.array.shape)

lookup.add_element(matmul, 0, matmul_back0)
lookup.add_element(matmul, 1, matmul_back1)


@wrap_forward_function
def exp(x):
    """Elementwise exponential."""
    return Parameter(array=np.exp(x.array))

def exp_back(x, grad_out):
    """d(e^x)/dx = e^x, recomputed from the input."""
    local = np.exp(x.array)
    return _broadcast_backward(grad_out * local, x.array.shape)

lookup.add_element(exp, 0, exp_back)


@wrap_forward_function
def log(x):
    """Elementwise natural log of ``x + 1e-12`` (the shift keeps log(0) finite)."""
    shifted = x.array + 1e-12
    return Parameter(array=np.log(shifted))

def log_back(x, grad_out):
    """d log(x + eps)/dx = 1 / (x + eps), using the same 1e-12 shift as the forward pass."""
    return _broadcast_backward(grad_out / (x.array + 1e-12), x.array.shape)

lookup.add_element(log, 0, log_back)

@wrap_forward_function
def sum_func(x, axis=None, keepdims=False):
    """Sum of ``x`` over ``axis`` (all axes when None), like ``np.sum``."""
    return Parameter(array=np.sum(x.array, axis=axis, keepdims=keepdims))

def sum_back(x, grad_out, axis=None, keepdims=False):
    """Spread the upstream gradient uniformly over the summed axes.

    When specific axes were reduced with ``keepdims=False``, those axes are
    missing from ``grad_out``. They are re-inserted before broadcasting.
    Otherwise, e.g. a (3,)-shaped gradient cannot be broadcast to (3, 4), and for
    square inputs it is silently broadcast along the wrong axis.

    A writable copy is returned because ``np.broadcast_to`` yields a read-only
    view, and gradients may later be accumulated in place.
    """
    grad = np.asarray(grad_out)
    if axis is not None and not keepdims:
        # Accepts an int or a tuple of axes; negative axes index the restored shape.
        grad = np.expand_dims(grad, axis)
    return np.array(np.broadcast_to(grad, x.array.shape))

lookup.add_element(sum_func, 0, sum_back)


@wrap_forward_function
def max_func(x, axis=None, keepdims=False):
    """Maximum of ``x`` over ``axis`` (all axes when None), like ``np.max``."""
    return Parameter(array=np.max(x.array, axis=axis, keepdims=keepdims))

def max_back(x, grad_out, axis=None, keepdims=False):
    """Route the upstream gradient to the arg-max entries of ``x``.

    Tied maxima share the gradient equally, so the total equals that of a single
    maximum. Previously, every tie received the full gradient. When
    ``keepdims=False`` with explicit axes, the reduced axes are re-inserted into
    ``grad_out`` so it lines up with ``x``.
    """
    is_max = x.array == np.max(x.array, axis=axis, keepdims=True)
    # Ties per reduced slice; clamped to 1 so all-NaN slices get 0 instead of 0/0.
    n_ties = np.maximum(np.sum(is_max, axis=axis, keepdims=True), 1)
    grad = np.asarray(grad_out)
    if axis is not None and not keepdims:
        grad = np.expand_dims(grad, axis)
    grad = grad * (is_max / n_ties)
    return _broadcast_backward(grad, x.array.shape)


lookup.add_element(max_func, 0, max_back)


@wrap_forward_function
def transpose(x):
    """Reverse the axes of ``x`` (``ndarray.T``)."""
    flipped = x.array.T
    return Parameter(array=flipped)

def transpose_back(x, grad_out):
    """Reversing axes is its own inverse, so reverse the upstream gradient's axes back."""
    restored = grad_out.T
    return restored

lookup.add_element(transpose, 0, transpose_back)


@wrap_forward_function
def relu(x):
    """Elementwise ``max(0, x)``."""
    return Parameter(array=np.maximum(0, x.array))

def relu_back(x, grad_out):
    """Pass the gradient where ``x > 0``; zero it where ``x <= 0`` (subgradient 0 at 0)."""
    gated = np.where(x.array <= 0, 0, grad_out)
    return _broadcast_backward(gated, x.array.shape)

lookup.add_element(relu, 0, relu_back)

def softmax(x, axis=1):
    """Numerically stable softmax of ``x`` along ``axis``.

    ``axis`` defaults to 1, the class axis of (N, C) logits. The per-slice
    maximum is subtracted before exponentiating so ``exp`` cannot overflow;
    softmax is invariant to that shift.
    """
    x_max = max_func(x, axis=axis, keepdims=True)
    exp_x = exp(x - x_max)
    sum_exp_x = sum_func(exp_x, axis=axis, keepdims=True)
    return exp_x / sum_exp_x


def one_hot(targets, num_classes):
    """Encode integer class labels as a constant (len(targets), num_classes) one-hot Parameter.

    Args:
        targets: Parameter whose ``array`` holds the labels (cast with ``astype(int)``).
        num_classes: number of columns of the encoding.

    Raises:
        ValueError: if any label lies outside ``[0, num_classes)``. Without this
            check, negative labels silently index from the end (``-1`` -> last class).
    """
    targets_array = targets.array.astype(int)
    if targets_array.size and (targets_array.min() < 0 or targets_array.max() >= num_classes):
        raise ValueError(
            f"targets must lie in [0, {num_classes}), got values in "
            f"[{targets_array.min()}, {targets_array.max()}]"
        )
    one_hot_array = np.zeros((len(targets_array), num_classes))
    one_hot_array[np.arange(len(targets_array)), targets_array] = 1
    return Parameter(one_hot_array, requires_grad=False)

def cross_entropy_loss(predictions, targets):
    """Mean cross-entropy between (N, C) logits and N integer class labels.

    Computes ``log_softmax(z) = (z - max) - log(sum(exp(z - max)))`` directly,
    instead of ``log(softmax(z) + eps)``. When a probability underflows to 0, the
    latter caps the loss at about ``log(eps)`` and its gradient vanishes.
    """
    N = predictions.array.shape[0]
    target_probs = one_hot(targets, num_classes=predictions.array.shape[1])
    shifted = predictions - max_func(predictions, axis=1, keepdims=True)
    # Each row of exp(shifted) contains exp(0) = 1, so the sum is >= 1 and its log is safe.
    log_sum_exp = log(sum_func(exp(shifted), axis=1, keepdims=True))
    log_probs = shifted - log_sum_exp
    loss = -sum_func(target_probs * log_probs) / N
    return loss


class Module:
    """Minimal container that auto-registers ``Parameter`` and sub-``Module`` attributes.

    Assigning a Parameter or Module to an attribute stores it in
    ``_parameters`` / ``_modules`` instead of the instance ``__dict__``.
    ``__getattr__`` resolves those names from the registries, and
    ``parameters()`` walks them recursively.
    """

    def __init__(self):
        self._modules = {}
        self._parameters = {}

    def modules(self):
        """Direct sub-modules (not recursive); empty if none were ever registered."""
        return self.__dict__.get("_modules", {}).values()

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        """Yield this module's Parameters, then (if ``recurse``) those of every sub-module."""
        yield from self.__dict__.get("_parameters", {}).values()
        if recurse:
            for module in self.modules():
                yield from module.parameters(recurse=True)

    def __setattr__(self, key: str, val: Any) -> None:
        # Drop any earlier binding of ``key`` first. Otherwise re-assigning an
        # attribute (e.g. a Parameter, then None) leaves a stale registry entry that
        # ``parameters()`` still yields, or a plain ``__dict__`` value that shadows
        # a newly registered one.
        for registry_name in ("_parameters", "_modules"):
            registry = self.__dict__.get(registry_name)
            if registry is not None:
                registry.pop(key, None)
        if isinstance(val, Parameter):
            self.__dict__.pop(key, None)
            self.__dict__.setdefault("_parameters", {})[key] = val
        elif isinstance(val, Module):
            self.__dict__.pop(key, None)
            self.__dict__.setdefault("_modules", {})[key] = val
        else:
            super().__setattr__(key, val)

    def __getattr__(self, key: str) -> Union[Parameter, "Module"]:
        # Only reached when normal attribute lookup fails.
        if "_parameters" in self.__dict__ and key in self.__dict__["_parameters"]:
            return self.__dict__["_parameters"][key]
        if "_modules" in self.__dict__ and key in self.__dict__["_modules"]:
            return self.__dict__["_modules"][key]
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{key}'")

class Linear(Module):
    """Fully connected layer computing ``x @ weight.T + bias``.

    ``weight`` has shape (out_features, in_features) and is drawn from a normal
    distribution with std ``sqrt(2 / in_features)`` (He initialisation). The
    optional ``bias`` starts at zero.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        scale = np.sqrt(2. / in_features)
        self.weight = Parameter(np.random.randn(out_features, in_features) * scale, requires_grad=True)
        self.bias = Parameter(np.zeros(out_features), requires_grad=True) if bias else None

    def forward(self, x: Parameter) -> Parameter:
        projected = x @ self.weight.T
        return projected if self.bias is None else projected + self.bias

class ReLU(Module):
    """Parameter-free module applying ``relu`` elementwise."""

    def forward(self, x: Parameter) -> Parameter:
        activated = relu(x)
        return activated


class NoGrad:
    """Context manager that disables graph recording inside its ``with`` block.

    The previous value of the module-level ``grad_tracking_enabled`` flag is saved
    on entry and restored on exit (also when the block raises), so nested uses compose.
    """

    def __enter__(self):
        global grad_tracking_enabled
        self.prev_state, grad_tracking_enabled = grad_tracking_enabled, False

    def __exit__(self, exc_type, exc_value, traceback):
        global grad_tracking_enabled
        grad_tracking_enabled = self.prev_state


# Synthetic dataset: standard-normal inputs, plus labels drawn independently and
# uniformly from [0, num_classes). The labels carry no information about the
# inputs, so any loss decrease reflects fitting this particular sample.
# Seeding first makes these draws (and the later np.random.randn weight
# initialisation inside Linear) reproducible on Restart & Run All.
np.random.seed(0)
num_samples = 1000
input_size = 784  
num_classes = 10
X_train = Parameter(np.random.randn(num_samples, input_size), requires_grad=False)
y_train = Parameter(np.random.randint(0, num_classes, size=(num_samples,)), requires_grad=False)
class SimpleNN(Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) producing unnormalised class logits.

    Args:
        in_features: input width (default 784).
        hidden_features: hidden-layer width (default 128).
        out_features: number of output logits (default 10).
    """

    def __init__(self, in_features=784, hidden_features=128, out_features=10):
        super().__init__()
        self.fc1 = Linear(in_features, hidden_features)
        self.relu = ReLU()
        self.fc2 = Linear(hidden_features, out_features)

    def forward(self, x):
        x = self.fc1.forward(x)
        x = self.relu.forward(x)
        x = self.fc2.forward(x)
        return x


model = SimpleNN()


learning_rate = 0.01
num_epochs = 50
losses = []
for epoch in range(num_epochs):
    # A fresh graph is built on every forward pass.
    logits = model.forward(X_train)
    loss = cross_entropy_loss(logits, y_train)
    losses.append(float(loss.array))

    loss.backward()

    # Plain SGD step. A parameter that received no gradient (grad is None)
    # is skipped instead of raising a TypeError.
    for param in model.parameters():
        if param.requires_grad and param.grad is not None:
            param.array -= learning_rate * param.grad
        param.zero_grad()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {losses[-1]}")


Epoch 1/50, Loss: 3.101780981301624
Epoch 2/50, Loss: 3.070237892668103
Epoch 3/50, Loss: 3.0410785745979196
Epoch 4/50, Loss: 3.0139881115115594
Epoch 5/50, Loss: 2.9887099435893743
Epoch 6/50, Loss: 2.965022850172187
Epoch 7/50, Loss: 2.942739397965661
Epoch 8/50, Loss: 2.921701695983401
Epoch 9/50, Loss: 2.9017753837332325
Epoch 10/50, Loss: 2.882835483290364
Epoch 11/50, Loss: 2.8647814339150677
Epoch 12/50, Loss: 2.84752813978086
Epoch 13/50, Loss: 2.8309983348151766
Epoch 14/50, Loss: 2.815122943483556
Epoch 15/50, Loss: 2.79984457761786
Epoch 16/50, Loss: 2.7851056965208616
Epoch 17/50, Loss: 2.770865343749279
Epoch 18/50, Loss: 2.7570832613457616
Epoch 19/50, Loss: 2.743720992507786
Epoch 20/50, Loss: 2.7307420688488
Epoch 21/50, Loss: 2.718117621467126
Epoch 22/50, Loss: 2.705823372609964
Epoch 23/50, Loss: 2.69383065371097
Epoch 24/50, Loss: 2.682119267444408
Epoch 25/50, Loss: 2.6706679663958455
Epoch 26/50, Loss: 2.659460033543948
Epoch 27/50, Loss: 2.648479581156597
Epoch 

In [5]:
import numpy as np
from collections import defaultdict
from typing import Callable, Any, Tuple, Optional, Dict, Iterator, Union

#########################
# Array Abstraction
#########################

Arr = np.ndarray  # Alias for the backing array type; point it at a custom array type if needed.


def as_arr(data, dtype=None) -> Arr:
    """Create a new array from ``data`` (copying, as ``np.array`` does)."""
    return np.array(data, dtype=dtype)


def ones_like(x: Arr) -> Arr:
    """Ones with the shape and dtype of ``x``."""
    return np.ones_like(x)


def zeros_like(x: Arr) -> Arr:
    """Zeros with the shape and dtype of ``x``."""
    return np.zeros_like(x)


def sum_arr(x: Arr, axis=None, keepdims=False) -> Arr:
    """Sum over ``axis`` (all axes when None) using the array's own ``sum`` method."""
    return x.sum(axis=axis, keepdims=keepdims)


def max_arr(x: Arr, axis=None, keepdims=False) -> Arr:
    """Maximum over ``axis`` (all axes when None) using the array's own ``max`` method."""
    return x.max(axis=axis, keepdims=keepdims)


def exp_arr(x: Arr) -> Arr:
    """Elementwise ``e ** x``."""
    return np.exp(x)


def log_arr(x: Arr) -> Arr:
    """Elementwise natural logarithm (no epsilon is added here)."""
    return np.log(x)


def transpose_arr(x: Arr) -> Arr:
    """``x`` with all of its axes reversed (``x.T``)."""
    return x.T


def reshape_arr(x: Arr, shape) -> Arr:
    """``x`` reshaped to ``shape``."""
    return x.reshape(shape)


def matmul_arr(x: Arr, y: Arr) -> Arr:
    """Matrix product ``x @ y``."""
    return x @ y


def abs_arr(x: Arr) -> Arr:
    """Elementwise absolute value."""
    return np.abs(x)


def broadcast_to(x: Arr, shape) -> Arr:
    """Broadcast view of ``x`` with the given ``shape`` (read-only, per numpy)."""
    return np.broadcast_to(x, shape)


def randn(*shape) -> Arr:
    """Standard-normal samples of the given shape from numpy's global RNG."""
    return np.random.randn(*shape)


def randint(low, high, size) -> Arr:
    """Integers drawn uniformly from ``[low, high)`` using numpy's global RNG."""
    return np.random.randint(low, high, size=size)

# Global switch read by every wrapped forward op; NoGrad flips it temporarily.
grad_tracking_enabled = True

#########################
# Wrap forward functions
#########################

def wrap_forward_function(func):
    """Decorate a forward op so its output records how it was produced.

    When tracking is enabled and at least one positional argument is a Parameter
    with ``requires_grad``, the returned Parameter gets ``parents`` (all positional
    args), ``func`` (this wrapper, the key used in ``lookup``), ``kwargs`` and
    ``requires_grad=True``.
    """
    def new_function(*args, **kwargs):
        parents = tuple(args)
        output = func(*args, **kwargs)
        tracks_grad = grad_tracking_enabled and any(
            isinstance(arg, Parameter) and arg.requires_grad for arg in args
        )
        if tracks_grad:
            output.parents = parents
            output.func = new_function
            output.kwargs = kwargs
            output.requires_grad = True
        return output

    new_function.__name__ = func.__name__
    return new_function

#########################
# Parameter Class
#########################

class Parameter:
    """Array wrapper that records operations for reverse-mode automatic differentiation.

    Leaf instances are created directly. Instances returned by functions decorated
    with ``wrap_forward_function`` (while grad tracking is on and some input requires
    grad) also carry ``parents`` (the positional inputs), ``func`` (the recorded
    forward function) and ``kwargs`` (its keyword arguments). ``backward`` uses these
    to look up gradient rules in the module-level ``lookup`` table.
    """

    def __init__(
        self,
        array: Optional[Arr] = None,
        requires_grad: bool = False,
        parents: Optional[Tuple['Parameter', ...]] = None,
        func: Optional[Callable] = None,
        kwargs: Optional[dict] = None,
    ):
        self.array = array if array is not None else as_arr(0.0)
        self.grad: Optional[Arr] = None
        self.requires_grad = requires_grad
        self.parents = parents if parents is not None else ()
        self.func = func
        self.kwargs = kwargs if kwargs is not None else {}

    @staticmethod
    def _lift(value: Union['Parameter', float, int]) -> 'Parameter':
        """Return ``value`` unchanged if it is a Parameter, else wrap it as a constant."""
        if isinstance(value, Parameter):
            return value
        return Parameter(as_arr(value), requires_grad=False)

    # Arithmetic operators accept a Parameter or a plain scalar/array on either side.
    # Plain values become constants that never receive gradients.
    def __add__(self, other: Union['Parameter', float, int]) -> 'Parameter':
        return add(self, self._lift(other))

    def __radd__(self, other: Union['Parameter', float, int]) -> 'Parameter':
        return add(self._lift(other), self)

    def __sub__(self, other: Union['Parameter', float, int]) -> 'Parameter':
        return subtract(self, self._lift(other))

    def __rsub__(self, other: Union['Parameter', float, int]) -> 'Parameter':
        return subtract(self._lift(other), self)

    def __mul__(self, other: Union['Parameter', float, int]) -> 'Parameter':
        return multiply(self, self._lift(other))

    def __rmul__(self, other: Union['Parameter', float, int]) -> 'Parameter':
        return multiply(self._lift(other), self)

    def __truediv__(self, other: Union['Parameter', float, int]) -> 'Parameter':
        return divide_op(self, self._lift(other))

    def __rtruediv__(self, other: Union['Parameter', float, int]) -> 'Parameter':
        return divide_op(self._lift(other), self)

    def __pow__(self, power: Union['Parameter', float, int]) -> 'Parameter':
        return power_func(self, self._lift(power))

    def __neg__(self) -> 'Parameter':
        return negate(self)

    def __matmul__(self, other: Union['Parameter', float, int]) -> 'Parameter':
        if not isinstance(other, Parameter):
            raise ValueError("Cannot perform matmul with non-Parameter type")
        return matmul(self, other)

    def __rmatmul__(self, other: Union['Parameter', float, int]) -> 'Parameter':
        if not isinstance(other, Parameter):
            raise ValueError("Cannot perform matmul with non-Parameter type")
        return matmul(other, self)

    # Method forms of the module-level ops.
    def exp(self) -> 'Parameter':
        return exp(self)

    def log(self) -> 'Parameter':
        return log(self)

    def sin(self) -> 'Parameter':
        return sin(self)

    def cos(self) -> 'Parameter':
        return cos(self)

    def tanh(self) -> 'Parameter':
        return tanh(self)

    def sigmoid(self) -> 'Parameter':
        return sigmoid(self)

    def relu(self) -> 'Parameter':
        return relu(self)

    def sqrt(self) -> 'Parameter':
        return sqrt(self)

    def abs(self) -> 'Parameter':
        return abs_func(self)

    def reshape(self, *shape) -> 'Parameter':
        """Reshape to ``shape``; accepts both ``x.reshape(2, 3)`` and ``x.reshape((2, 3))``."""
        if len(shape) == 1 and isinstance(shape[0], (tuple, list)):
            shape = tuple(shape[0])
        # Pass the shape by keyword so it is stored in ``kwargs``. A positional
        # argument would be recorded as a graph parent and handed to
        # ``reshape_back`` in the gradient's slot.
        return reshape(self, shape=shape)

    def sum(self, axis=None, keepdims=False) -> 'Parameter':
        return sum_func(self, axis=axis, keepdims=keepdims)

    @property
    def T(self) -> 'Parameter':
        return transpose(self)

    def backward(self, grad: Optional[Arr] = None):
        """Back-propagate from this node into every ancestor that requires grad.

        Args:
            grad: upstream gradient w.r.t. this node. Defaults to ones shaped like
                ``self.array``, i.e. the seed for a scalar loss.

        Gradients are accumulated into ``.grad``. Accumulation is done out of place
        (``a = a + b``) because backward rules may return the incoming gradient
        object itself or a read-only ``broadcast_to`` view. With ``+=``, the former
        would silently mutate another node's gradient and the latter would raise
        ``ValueError: output array is read-only``.
        """
        if not self.requires_grad:
            return
        if grad is None:
            grad = ones_like(self.array)
        self.grad = grad if self.grad is None else self.grad + grad
        for tensor in reversed(self._topological_sort()):
            # Leaves have no rule to apply; a node without a gradient has nothing to pass on.
            if tensor.func is None or tensor.grad is None:
                continue
            for idx, parent in enumerate(tensor.parents):
                # Constants (and non-Parameter args) would discard the gradient anyway.
                if not (isinstance(parent, Parameter) and parent.requires_grad):
                    continue
                backward_func = lookup.get_backward_function(tensor.func, idx)
                parent_grad = backward_func(*tensor.parents, tensor.grad, **tensor.kwargs)
                if parent_grad is None:
                    continue
                parent.grad = parent_grad if parent.grad is None else parent.grad + parent_grad

    def _topological_sort(self) -> list:
        """Return graph nodes in post-order (every parent before its children).

        Uses an explicit stack, so deep graphs cannot hit Python's recursion limit.
        Entries of ``parents`` that are not Parameters are skipped.
        """
        visited = set()
        topo_order = []
        stack = [(self, False)]
        while stack:
            node, expanded = stack.pop()
            if expanded:
                # All of this node's parents have been emitted already.
                topo_order.append(node)
                continue
            if node in visited:
                continue
            visited.add(node)
            stack.append((node, True))
            # Push parents reversed so they are explored in their original order.
            for parent in reversed(node.parents):
                if isinstance(parent, Parameter) and parent not in visited:
                    stack.append((parent, False))
        return topo_order

    def zero_grad(self):
        """Forget the accumulated gradient."""
        self.grad = None

    def __repr__(self) -> str:
        return f"Parameter(shape={self.array.shape}, requires_grad={self.requires_grad})"



class BackwardLookupTable:
    '''
    Backpropagation helper lookup table.
    Maps each forward function to a dictionary of {argument position: backward function}.
    '''

    def __init__(self):
        self.table: Dict[Callable, Dict[int, Callable]] = defaultdict(dict)

    def add_element(self, forward_function: Callable, position: int, backward_func: Callable):
        """Register ``backward_func`` as the gradient of ``forward_function`` w.r.t. argument ``position``."""
        per_function = self.table[forward_function]
        per_function[position] = backward_func

    def get_backward_function(self, forward_function: Callable, position: int) -> Callable:
        """Return the registered rule; raise ``KeyError`` if none exists (never inserts)."""
        registered = self.table.get(forward_function, {})
        if position in registered:
            return registered[position]
        raise KeyError(f"No backward function found for {forward_function} at position {position}")

# Shared table that the forward/backward op definitions register into.
lookup = BackwardLookupTable()



def _broadcast_backward(grad_out, target_shape):
    """Reduce ``grad_out`` back to ``target_shape`` after numpy broadcasting.

    First, leading axes that broadcasting prepended are summed away. Then every
    axis of length 1 in ``target_shape`` is summed with ``keepdims=True``.
    Returns ``grad_out`` itself when no reduction is needed.
    """
    reduced = grad_out
    for _ in range(len(reduced.shape) - len(target_shape)):
        reduced = sum_arr(reduced, axis=0)
    for axis in (i for i, size in enumerate(target_shape) if size == 1):
        reduced = sum_arr(reduced, axis=axis, keepdims=True)
    return reduced



@wrap_forward_function
def add(x, y):
    """Elementwise ``x + y`` (numpy broadcasting applies)."""
    total = x.array + y.array
    return Parameter(array=total)

def add_back0(x, y, grad_out):
    """d(x + y)/dx = 1: the upstream gradient, reduced to x's shape."""
    return _broadcast_backward(grad_out, x.array.shape)

def add_back1(x, y, grad_out):
    """d(x + y)/dy = 1: the upstream gradient, reduced to y's shape."""
    return _broadcast_backward(grad_out, y.array.shape)

lookup.add_element(add, 0, add_back0)
lookup.add_element(add, 1, add_back1)


@wrap_forward_function
def subtract(x, y):
    """Elementwise ``x - y`` (numpy broadcasting applies)."""
    difference = x.array - y.array
    return Parameter(array=difference)

def subtract_back0(x, y, grad_out):
    """d(x - y)/dx = 1: pass the upstream gradient through, reduced to x's shape."""
    return _broadcast_backward(grad_out, x.array.shape)

def subtract_back1(x, y, grad_out):
    """d(x - y)/dy = -1: negate the upstream gradient reduced to y's shape."""
    reduced = _broadcast_backward(grad_out, y.array.shape)
    return -reduced

lookup.add_element(subtract, 0, subtract_back0)
lookup.add_element(subtract, 1, subtract_back1)


@wrap_forward_function
def multiply(x, y):
    """Elementwise ``x * y`` (numpy broadcasting applies)."""
    product = x.array * y.array
    return Parameter(array=product)

def multiply_back0(x, y, grad_out):
    """d(x * y)/dx = y."""
    return _broadcast_backward(grad_out * y.array, x.array.shape)

def multiply_back1(x, y, grad_out):
    """d(x * y)/dy = x."""
    return _broadcast_backward(grad_out * x.array, y.array.shape)

lookup.add_element(multiply, 0, multiply_back0)
lookup.add_element(multiply, 1, multiply_back1)


@wrap_forward_function
def divide_op(x, y):
    """Elementwise ``x / y`` (numpy broadcasting applies)."""
    return Parameter(array=x.array / y.array)

def divide_back0(x, y, grad_out):
    """d(x / y)/dx = 1 / y."""
    grad = grad_out / y.array
    return _broadcast_backward(grad, x.array.shape)

def divide_back1(x, y, grad_out):
    """d(x / y)/dy = -x / y**2; returns None when ``y`` is a constant.

    No epsilon is added to the denominator. The forward pass divides by the raw
    ``y``, so padding ``y**2`` with 1e-12 made the gradient disagree with the
    forward value whenever ``|y|`` is around 1e-6 or smaller.
    """
    if not y.requires_grad:
        return None
    grad = -grad_out * x.array / (y.array ** 2)
    return _broadcast_backward(grad, y.array.shape)

lookup.add_element(divide_op, 0, divide_back0)
lookup.add_element(divide_op, 1, divide_back1)


@wrap_forward_function
def power_func(x, y):
    """Elementwise ``x ** y``."""
    raised = x.array ** y.array
    return Parameter(array=raised)

def power_back0(x, y, grad_out):
    """d(x ** y)/dx = y * x ** (y - 1)."""
    base, exponent = x.array, y.array
    grad = grad_out * exponent * (base ** (exponent - 1))
    return _broadcast_backward(grad, base.shape)

def power_back1(x, y, grad_out):
    """d(x ** y)/dy = x ** y * log(x); 1e-12 guards log(0) (negative bases still give NaN)."""
    base, exponent = x.array, y.array
    grad = grad_out * (base ** exponent) * log_arr(base + 1e-12)
    return _broadcast_backward(grad, exponent.shape)

lookup.add_element(power_func, 0, power_back0)
lookup.add_element(power_func, 1, power_back1)


@wrap_forward_function
def negate(x):
    """Elementwise ``-x``."""
    negated = -x.array
    return Parameter(array=negated)

def negate_back(x, grad_out):
    """d(-x)/dx = -1."""
    return _broadcast_backward(-grad_out, x.array.shape)

lookup.add_element(negate, 0, negate_back)


@wrap_forward_function
def matmul(x, y):
    """Matrix product ``x @ y`` with ``np.matmul`` semantics (1-D promotion, batch broadcasting)."""
    return Parameter(array=matmul_arr(x.array, y.array))

def _matmul_grad_operands(x_arr, y_arr, grad_out):
    """Promote 1-D operands, and the matching axis of ``grad_out``, to matrices.

    ``np.matmul`` treats a 1-D left operand as a row vector and a 1-D right
    operand as a column vector, then drops that axis from the result.
    Re-inserting it lets both backward rules use plain (batched) matrix algebra.
    """
    g = np.asarray(grad_out)
    if y_arr.ndim == 1:
        y_arr = y_arr[:, np.newaxis]
        g = np.expand_dims(g, -1)
    if x_arr.ndim == 1:
        x_arr = x_arr[np.newaxis, :]
        g = np.expand_dims(g, -2)
    return x_arr, y_arr, g

def matmul_back0(x, y, grad_out):
    """d(x @ y)/dx = grad_out @ y^T (last two axes), reduced to x's shape."""
    x2, y2, g = _matmul_grad_operands(x.array, y.array, grad_out)
    # Swap only the matrix axes; ``transpose_arr`` (``.T``) would also reverse batch axes.
    grad = matmul_arr(g, np.swapaxes(y2, -1, -2))
    if x.array.ndim == 1:
        grad = np.squeeze(grad, axis=-2)
    return _broadcast_backward(grad, x.array.shape)

def matmul_back1(x, y, grad_out):
    """d(x @ y)/dy = x^T (last two axes) @ grad_out, reduced to y's shape."""
    x2, y2, g = _matmul_grad_operands(x.array, y.array, grad_out)
    grad = matmul_arr(np.swapaxes(x2, -1, -2), g)
    if y.array.ndim == 1:
        grad = np.squeeze(grad, axis=-1)
    return _broadcast_backward(grad, y.array.shape)

lookup.add_element(matmul, 0, matmul_back0)
lookup.add_element(matmul, 1, matmul_back1)


@wrap_forward_function
def exp(x):
    """Elementwise exponential."""
    return Parameter(array=exp_arr(x.array))

def exp_back(x, grad_out):
    """d(e^x)/dx = e^x, recomputed from the input."""
    local = exp_arr(x.array)
    return _broadcast_backward(grad_out * local, x.array.shape)

lookup.add_element(exp, 0, exp_back)


@wrap_forward_function
def log(x):
    """Elementwise natural log of ``x + 1e-12`` (the shift keeps log(0) finite)."""
    shifted = x.array + 1e-12
    return Parameter(array=log_arr(shifted))

def log_back(x, grad_out):
    """d log(x + eps)/dx = 1 / (x + eps), using the same 1e-12 shift as the forward pass."""
    return _broadcast_backward(grad_out / (x.array + 1e-12), x.array.shape)

lookup.add_element(log, 0, log_back)


@wrap_forward_function
def sum_func(x, axis=None, keepdims=False):
    """Sum of ``x`` over ``axis`` (all axes when None)."""
    return Parameter(array=sum_arr(x.array, axis=axis, keepdims=keepdims))

def sum_back(x, grad_out, axis=None, keepdims=False):
    """Spread the upstream gradient uniformly over the summed axes.

    When specific axes were reduced with ``keepdims=False``, those axes are
    missing from ``grad_out``. They are re-inserted before broadcasting.
    Otherwise, e.g. a (3,)-shaped gradient cannot be broadcast to (3, 4), and for
    square inputs it is silently broadcast along the wrong axis.

    A writable copy is returned because ``broadcast_to`` yields a read-only view,
    and gradients may later be accumulated in place.
    """
    grad = np.asarray(grad_out)
    if axis is not None and not keepdims:
        # Accepts an int or a tuple of axes; negative axes index the restored shape.
        grad = np.expand_dims(grad, axis)
    return np.array(broadcast_to(grad, x.array.shape))

lookup.add_element(sum_func, 0, sum_back)


@wrap_forward_function
def max_func(x, axis=None, keepdims=False):
    """Maximum of ``x`` over ``axis`` (all axes when None)."""
    return Parameter(array=max_arr(x.array, axis=axis, keepdims=keepdims))

def max_back(x, grad_out, axis=None, keepdims=False):
    """Route the upstream gradient to the arg-max entries of ``x``.

    Tied maxima share the gradient equally, so the total equals that of a single
    maximum. Previously, every tie received the full gradient. When
    ``keepdims=False`` with explicit axes, the reduced axes are re-inserted into
    ``grad_out`` so it lines up with ``x``.
    """
    is_max = x.array == max_arr(x.array, axis=axis, keepdims=True)
    # Ties per reduced slice; clamped to 1 so all-NaN slices get 0 instead of 0/0.
    n_ties = np.maximum(sum_arr(is_max, axis=axis, keepdims=True), 1)
    grad = np.asarray(grad_out)
    if axis is not None and not keepdims:
        grad = np.expand_dims(grad, axis)
    grad = grad * (is_max / n_ties)
    return _broadcast_backward(grad, x.array.shape)

lookup.add_element(max_func, 0, max_back)


@wrap_forward_function
def transpose(x):
    """Reverse the axes of ``x``."""
    flipped = transpose_arr(x.array)
    return Parameter(array=flipped)

def transpose_back(x, grad_out):
    """Reversing axes is its own inverse, so reverse the upstream gradient's axes back."""
    restored = grad_out.T
    return restored

lookup.add_element(transpose, 0, transpose_back)


@wrap_forward_function
def relu(x):
    """Elementwise ``max(0, x)``."""
    return Parameter(array=np.maximum(0, x.array))

def relu_back(x, grad_out):
    """Pass the gradient where ``x > 0``; zero it where ``x <= 0`` (subgradient 0 at 0)."""
    gated = np.where(x.array <= 0, 0, grad_out)
    return _broadcast_backward(gated, x.array.shape)

lookup.add_element(relu, 0, relu_back)


@wrap_forward_function
def abs_func(x):
    """Elementwise absolute value."""
    magnitude = abs_arr(x.array)
    return Parameter(array=magnitude)

def abs_back(x, grad_out):
    """d|x|/dx = sign(x); at x == 0 the subgradient 0 is used (``np.sign(0) == 0``)."""
    return _broadcast_backward(grad_out * np.sign(x.array), x.array.shape)

lookup.add_element(abs_func, 0, abs_back)


@wrap_forward_function
def sqrt(x):
    """Elementwise square root of ``x + 1e-12`` (the shift keeps the gradient at 0 finite)."""
    root = np.sqrt(x.array + 1e-12)
    return Parameter(array=root)

def sqrt_back(x, grad_out):
    """d sqrt(x + eps)/dx = 1 / (2 sqrt(x + eps)), with the forward pass's 1e-12 shift."""
    local = 1.0 / (2.0 * np.sqrt(x.array + 1e-12))
    return _broadcast_backward(grad_out * local, x.array.shape)

lookup.add_element(sqrt, 0, sqrt_back)


@wrap_forward_function
def sin(x):
    """Elementwise sine."""
    return Parameter(array=np.sin(x.array))

def sin_back(x, grad_out):
    """d sin(x)/dx = cos(x)."""
    local = np.cos(x.array)
    return _broadcast_backward(grad_out * local, x.array.shape)

lookup.add_element(sin, 0, sin_back)


@wrap_forward_function
def cos(x):
    """Elementwise cosine."""
    return Parameter(array=np.cos(x.array))

def cos_back(x, grad_out):
    """d cos(x)/dx = -sin(x)."""
    return _broadcast_backward(-grad_out * np.sin(x.array), x.array.shape)

lookup.add_element(cos, 0, cos_back)


@wrap_forward_function
def tanh(x):
    """Elementwise hyperbolic tangent."""
    return Parameter(array=np.tanh(x.array))

def tanh_back(x, grad_out):
    """d tanh(x)/dx = 1 - tanh(x)**2 (tanh recomputed from the input)."""
    t = np.tanh(x.array)
    return _broadcast_backward(grad_out * (1 - t**2), x.array.shape)

lookup.add_element(tanh, 0, tanh_back)


def _stable_sigmoid(values):
    """Overflow-free logistic function ``1 / (1 + exp(-v))``.

    ``exp`` is only applied to ``-|v|`` (<= 0), so it cannot overflow. The
    naive form raises an overflow RuntimeWarning for ``v < -709``. Negative
    inputs use the equivalent ``exp(v) / (1 + exp(v))``.
    """
    decay = np.exp(-np.abs(values))
    return np.where(values >= 0, 1 / (1 + decay), decay / (1 + decay))

@wrap_forward_function
def sigmoid(x):
    """Elementwise logistic sigmoid."""
    return Parameter(array=_stable_sigmoid(x.array))

def sigmoid_back(x, grad_out):
    """d sigmoid(x)/dx = s * (1 - s), with s recomputed from the input."""
    val = _stable_sigmoid(x.array)
    grad = grad_out * val * (1 - val)
    return _broadcast_backward(grad, x.array.shape)

lookup.add_element(sigmoid, 0, sigmoid_back)


@wrap_forward_function
def reshape(x, shape):
    return Parameter(array=reshape_arr(x.array, shape))

def reshape_back(x, grad_out, shape):
    """Backward rule for ``reshape``: fold the gradient back into x's original shape.

    ``shape`` (the forward target shape) is accepted to match the forward signature
    but is not needed here.
    """
    source_shape = x.array.shape
    return reshape_arr(grad_out, source_shape)

lookup.add_element(reshape, 0, reshape_back)



def softmax(x, axis=1):
    """Numerically stable softmax of ``x`` along ``axis``.

    Args:
        x: Parameter of logits. The default ``axis=1`` treats it as an (N, C) batch
            and normalises over classes, matching the original behaviour.
        axis: axis to normalise over.

    Returns:
        Parameter with the same shape as ``x`` whose entries sum to 1 along ``axis``.

    The maximum along ``axis`` is subtracted before ``exp`` so it cannot overflow.
    The shift is taken from the raw array as a graph constant. softmax(x - c)
    equals softmax(x) for any c that is constant along ``axis``, so treating c as
    a constant gives the same gradient, and the max op stays out of the
    autograd graph.
    """
    shift = Parameter(np.max(x.array, axis=axis, keepdims=True), requires_grad=False)
    exp_x = exp(x - shift)
    sum_exp_x = sum_func(exp_x, axis=axis, keepdims=True)
    return exp_x / sum_exp_x

def one_hot(targets, num_classes):
    """One-hot encode integer class labels.

    Args:
        targets: Parameter whose array holds integer labels. Any shape is
            flattened, so both (N,) and (N, 1) label arrays give N rows.
        num_classes: number of columns in the encoding.

    Returns:
        Non-trainable Parameter of shape (num_labels, num_classes) holding 1.0 at
        each label's column and 0.0 elsewhere.

    Raises:
        ValueError: if any label lies outside [0, num_classes).
    """
    # Flatten first: with an (N, 1) label array the fancy index below would
    # broadcast to (N, N) and set every label in every row.
    labels = np.ravel(targets.array).astype(int)
    # Negative labels would otherwise wrap around silently via numpy indexing.
    if labels.size and (labels.min() < 0 or labels.max() >= num_classes):
        raise ValueError(
            f"one_hot: labels must lie in [0, {num_classes}), "
            f"got range [{labels.min()}, {labels.max()}]"
        )
    one_hot_array = np.zeros((labels.size, num_classes))
    one_hot_array[np.arange(labels.size), labels] = 1
    return Parameter(one_hot_array, requires_grad=False)

def cross_entropy_loss(predictions, targets):
    """Mean softmax cross-entropy between logits and integer class labels.

    Args:
        predictions: Parameter holding (N, C) unnormalised logits.
        targets: Parameter holding N integer class labels in [0, C).

    Returns:
        Scalar Parameter equal to -(1/N) * sum_i log_softmax(predictions)[i, targets[i]].

    Raises:
        ValueError: if the batch is empty or the number of labels differs from N.

    Log-probabilities are computed directly with the log-sum-exp identity
    log_softmax(z) = (z - m) - log(sum(exp(z - m))), where m is the row max.
    The earlier form, log(softmax(z) + 1e-12), capped the per-sample loss near
    -log(1e-12) ~= 27.6. Its gradient also shrank toward zero once the target
    probability fell well below 1e-12, so confidently wrong predictions were
    barely corrected.
    """
    N, num_classes = predictions.array.shape
    if N == 0:
        raise ValueError("cross_entropy_loss requires at least one sample")
    target_probs = one_hot(targets, num_classes=num_classes)
    # A label count that differs from N could otherwise broadcast silently (e.g. 1 label).
    if target_probs.array.shape[0] != N:
        raise ValueError(
            f"cross_entropy_loss: got {target_probs.array.shape[0]} targets for {N} predictions"
        )
    # log_softmax is shift-invariant, so the row max is a graph constant that only
    # guards exp() against overflow.
    row_max = Parameter(np.max(predictions.array, axis=1, keepdims=True), requires_grad=False)
    shifted = predictions - row_max
    log_sum_exp = log(sum_func(exp(shifted), axis=1, keepdims=True))
    log_probs = shifted - log_sum_exp
    loss = -sum_func(target_probs * log_probs) / N
    return loss



class Module:
    """Minimal neural-network container with automatic registration.

    Assigning a ``Parameter`` or ``Module`` to an attribute records it in
    ``_parameters`` / ``_modules``. Any other value becomes a plain instance
    attribute. Reassigning a name first removes every earlier binding of it, so
    stale entries are neither yielded by ``parameters()`` nor able to shadow the
    new value.
    """

    def __init__(self):
        # setdefault (not plain assignment) so that registrations a subclass makes
        # before calling super().__init__() are not discarded.
        self.__dict__.setdefault("_modules", {})
        self.__dict__.setdefault("_parameters", {})

    def __call__(self, *args, **kwargs):
        """Shorthand for ``self.forward(*args, **kwargs)``."""
        return self.forward(*args, **kwargs)

    def modules(self):
        """Return a view of the direct child modules (non-recursive)."""
        return self.__dict__.get("_modules", {}).values()

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        """Yield this module's parameters, then (if ``recurse``) those of all descendants.

        Each distinct Parameter object is yielded once, even when shared between
        submodules, so an optimiser loop does not update a tied weight twice.
        """
        seen = set()
        for param in self._iter_parameters(recurse):
            if id(param) not in seen:
                seen.add(id(param))
                yield param

    def _iter_parameters(self, recurse: bool) -> Iterator[Parameter]:
        # Raw traversal (may contain duplicates); parameters() de-duplicates it.
        yield from self.__dict__.get("_parameters", {}).values()
        if recurse:
            for module in self.modules():
                yield from module.parameters(recurse=True)

    def __setattr__(self, key: str, val: Any) -> None:
        params = self.__dict__.setdefault("_parameters", {})
        modules = self.__dict__.setdefault("_modules", {})
        # Drop earlier bindings so e.g. `self.bias = None` unregisters the old bias.
        params.pop(key, None)
        modules.pop(key, None)
        if isinstance(val, Parameter):
            # A plain attribute with this name would shadow the registry lookup.
            self.__dict__.pop(key, None)
            params[key] = val
        elif isinstance(val, Module):
            self.__dict__.pop(key, None)
            modules[key] = val
        else:
            super().__setattr__(key, val)

    def __getattr__(self, key: str) -> Union[Parameter, "Module"]:
        # Only called when normal lookup fails, i.e. for registered names.
        if "_parameters" in self.__dict__ and key in self.__dict__["_parameters"]:
            return self.__dict__["_parameters"][key]
        if "_modules" in self.__dict__ and key in self.__dict__["_modules"]:
            return self.__dict__["_modules"][key]
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{key}'")

class Linear(Module):
    """Fully connected layer computing ``x @ weight.T + bias``.

    ``weight`` has shape (out_features, in_features) and is scaled by
    sqrt(2 / in_features), which is He/Kaiming scaling assuming ``randn`` samples a
    standard normal. ``bias`` is a zero vector of length out_features, or None
    when ``bias=False``.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        he_scale = np.sqrt(2. / in_features)
        self.weight = Parameter(randn(out_features, in_features) * he_scale, requires_grad=True)
        self.bias = (
            Parameter(as_arr(np.zeros(out_features)), requires_grad=True) if bias else None
        )

    def forward(self, x: Parameter) -> Parameter:
        projected = x @ self.weight.T
        return projected if self.bias is None else projected + self.bias

class ReLU(Module):
    """Parameter-free module that applies the elementwise ``relu`` op."""

    def forward(self, x: Parameter) -> Parameter:
        activated = relu(x)
        return activated

class NoGrad:
    """Context manager that turns off autograd graph recording inside its block.

    On entry it sets the global ``grad_tracking_enabled`` to False. On exit,
    including when the block raises, it restores the value seen on entry. Entry
    states are kept on a stack, so reusing one instance in nested ``with``
    statements restores each level correctly. A single saved value was overwritten
    by the inner entry and left tracking disabled after the outer exit.
    """

    def __init__(self):
        self._saved_states = []

    def __enter__(self):
        global grad_tracking_enabled
        self._saved_states.append(grad_tracking_enabled)
        self.prev_state = grad_tracking_enabled  # kept for backward compatibility
        grad_tracking_enabled = False
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        global grad_tracking_enabled
        grad_tracking_enabled = self._saved_states.pop()
        return False  # never swallow exceptions raised inside the block



# Synthetic classification data: random feature vectors with random integer labels
# in [0, num_classes). The seed makes this reproducible, assuming randn/randint
# draw from NumPy's global RNG.
np.random.seed(0)
num_samples, input_size, num_classes = 1000, 784, 10
X_train = Parameter(randn(num_samples, input_size), requires_grad=False)
y_train = Parameter(randint(0, num_classes, size=(num_samples,)), requires_grad=False)

class SimpleNN(Module):
    """Two-layer MLP: Linear(in, hidden) -> ReLU -> Linear(hidden, out).

    Args:
        in_features: input dimension (default 784).
        hidden_features: hidden-layer width (default 128).
        out_features: number of output logits (default 10).

    The defaults reproduce the original hard-coded 784 -> 128 -> 10 network. The
    dimensions are now parameters so the model can match other input/class
    configurations.
    """

    def __init__(self, in_features=784, hidden_features=128, out_features=10):
        super().__init__()
        self.fc1 = Linear(in_features, hidden_features)
        self.relu = ReLU()
        self.fc2 = Linear(hidden_features, out_features)

    def forward(self, x):
        """Return raw logits for batch ``x``; no softmax is applied here."""
        x = self.fc1.forward(x)
        x = self.relu.forward(x)
        return self.fc2.forward(x)

# Full-batch gradient descent on the synthetic data.
model = SimpleNN()

learning_rate = 0.01
num_epochs = 2  # reduced for brevity
losses = []
for epoch in range(num_epochs):
    logits = model.forward(X_train)
    loss = cross_entropy_loss(logits, y_train)
    # Record a plain float so the history does not keep references to loss arrays.
    losses.append(float(loss.array))

    loss.backward()
    for param in model.parameters():
        # A parameter that did not take part in this forward pass has no gradient;
        # skip it instead of failing on `learning_rate * None`.
        if not param.requires_grad or param.grad is None:
            continue
        param.array -= learning_rate * param.grad
        param.zero_grad()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.array}")



# Cross-check forward values and gradients of the custom autograd against PyTorch.
# Only the import is guarded. An ImportError raised by the comparison code itself
# must surface instead of being misreported as "PyTorch not installed".
try:
    import torch
except ImportError:
    torch = None

if torch is None:
    print("PyTorch not installed. Skipping comparison test.")
else:
    a_np = as_arr([[1.0, 2.0], [3.0, 4.0]])
    b_np = as_arr([[5.0, 6.0], [7.0, 8.0]])

    a_param = Parameter(a_np, requires_grad=True)
    b_param = Parameter(b_np, requires_grad=True)
    out_param = a_param * b_param + a_param @ b_param
    # Presumably seeds a gradient of ones for a non-scalar output (TODO confirm),
    # which is what the torch.ones_like seed below does.
    out_param.backward()

    a_torch = torch.tensor(a_np, requires_grad=True)
    b_torch = torch.tensor(b_np, requires_grad=True)
    out_torch = a_torch * b_torch + a_torch.matmul(b_torch)
    out_torch.backward(torch.ones_like(out_torch))

    print("Comparing with PyTorch:")
    print("Forward result difference:", np.abs(out_param.array - out_torch.detach().numpy()).sum())
    print("a_param grad difference:", np.abs(a_param.grad - a_torch.grad.numpy()).sum())
    print("b_param grad difference:", np.abs(b_param.grad - b_torch.grad.numpy()).sum())


Epoch 1/2, Loss: 3.101780981301624
Epoch 2/2, Loss: 3.070237892668103
Comparing with PyTorch:
Forward result difference: 0.0
a_param grad difference: 0.0
b_param grad difference: 0.0


In [4]:
!pip install torch 

Collecting torch
  Downloading torch-2.5.1-cp312-none-macosx_11_0_arm64.whl.metadata (28 kB)
Collecting filelock (from torch)
  Downloading filelock-3.16.1-py3-none-any.whl.metadata (2.9 kB)
Collecting typing-extensions>=4.8.0 (from torch)
  Using cached typing_extensions-4.12.2-py3-none-any.whl.metadata (3.0 kB)
Collecting networkx (from torch)
  Downloading networkx-3.4.2-py3-none-any.whl.metadata (6.3 kB)
Collecting jinja2 (from torch)
  Using cached jinja2-3.1.4-py3-none-any.whl.metadata (2.6 kB)
Collecting fsspec (from torch)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Collecting setuptools (from torch)
  Using cached setuptools-75.6.0-py3-none-any.whl.metadata (6.7 kB)
Collecting sympy==1.13.1 (from torch)
  Downloading sympy-1.13.1-py3-none-any.whl.metadata (12 kB)
Collecting mpmath<1.4,>=1.1.0 (from sympy==1.13.1->torch)
  Using cached mpmath-1.3.0-py3-none-any.whl.metadata (8.6 kB)
Collecting MarkupSafe>=2.0 (from jinja2->torch)
  Downloading MarkupSafe-3.