# In [1]:
import numpy as np
from collections import defaultdict, deque
from typing import Callable, Any, DefaultDict, Tuple, Optional
from typing import Any, Callable, Iterable, Iterator, TypeAlias, DefaultDict
from typing import Union, Iterator, Any
grad_tracking_enabled = True
Arr : np.ndarray


# In [2]:
def wrap_forward_function(func):
    """Decorate a tensor -> tensor forward function so its result joins the autograd graph.

    The returned wrapper calls ``func(*args, **kwargs)``. If the global
    ``grad_tracking_enabled`` is true and at least one positional argument has
    ``requires_grad=True``, the result is marked ``requires_grad=True``. Its
    ``parents`` are set to the positional arguments. Its ``func`` is set to the
    *wrapper itself*: that is the object the backward lookup table is keyed by
    (``lookup.add_element(multiply, ...)`` registers the decorated function).
    Otherwise the result is detached: ``requires_grad=False``, ``parents=()``,
    ``func=None``.

    Args:
        func: Forward function returning an object with writable ``requires_grad``,
            ``parents`` and ``func`` attributes.

    Returns:
        The wrapping function, carrying ``func``'s name and docstring.
    """
    from functools import wraps

    @wraps(func)
    def new_function(*args, **kwargs):
        result = func(*args, **kwargs)
        # Positional arguments without a `requires_grad` attribute (plain numbers,
        # raw arrays) are treated as constants instead of raising AttributeError.
        requires_grad = grad_tracking_enabled and any(
            getattr(arg, "requires_grad", False) for arg in args
        )
        result.requires_grad = requires_grad
        result.parents = tuple(args) if requires_grad else ()
        # Without `func`, Parameter.backward skips the node and no gradient flows.
        result.func = new_function if requires_grad else None
        return result

    return new_function


# In [3]:
class Parameter:
    """Array value that participates in reverse-mode automatic differentiation.

    Attributes:
        array: The wrapped value (``np.array(0.0)`` when none is given).
        grad: Accumulated gradient, or ``None`` until ``backward`` reaches this node.
        requires_grad: Whether gradients are accumulated into this node.
        parents: Inputs this value was computed from (empty tuple for leaves).
        func: Forward function that produced this value. ``backward`` uses it as
            the key for ``lookup.get_backward_function``; ``None`` marks a leaf.
        kwargs: Keyword arguments recorded for the forward call (empty dict by default).
    """

    def __init__(
        self,
        array: Optional[np.ndarray] = None,
        requires_grad: bool = False,
        parents: Optional[Tuple['Parameter', ...]] = None,
        func: Optional[Callable] = None,
        kwargs: Optional[dict] = None,
    ):
        self.array = array if array is not None else np.array(0.0)
        self.grad: Optional[np.ndarray] = None
        self.requires_grad = requires_grad
        self.parents = parents if parents is not None else ()
        self.func = func
        self.kwargs = kwargs if kwargs is not None else {}

    @staticmethod
    def _as_parameter(value: Any) -> Any:
        """Return tensor-like *value* unchanged; wrap anything else as a constant Parameter."""
        # Anything exposing `.array` is already a tensor. Plain numbers and arrays
        # become constants (requires_grad=False), so e.g. `2 * p` and `p * 2` work.
        return value if hasattr(value, "array") else Parameter(array=np.asarray(value))

    def __add__(self, other: 'Parameter') -> 'Parameter':
        return sum_(self, self._as_parameter(other))

    def __mul__(self, other: 'Parameter') -> 'Parameter':
        return multiply(self, self._as_parameter(other))

    def __rmul__(self, other: 'Parameter') -> 'Parameter':
        # Only reached when the left operand is not a tensor, so it must be wrapped.
        return multiply(self._as_parameter(other), self)

    def __matmul__(self, other: 'Parameter') -> 'Parameter':
        return matmul(self, self._as_parameter(other))

    def __rmatmul__(self, other: 'Parameter') -> 'Parameter':
        return matmul(self._as_parameter(other), self)

    @property
    def T(self) -> 'Parameter':
        """Transposed data as a new Parameter (not linked to this one: no parents/func)."""
        return Parameter(array=self.array.T, requires_grad=self.requires_grad)

    def __repr__(self) -> str:
        return f"Parameter(shape={self.array.shape}, requires_grad={self.requires_grad})"

    def backward(self, grad: Optional[np.ndarray] = None):
        """Back-propagate from this node into every reachable ancestor.

        Does nothing when ``requires_grad`` is False. Otherwise the gradient
        flowing in *this call* is computed per node. It is then added to the
        node's ``.grad`` (which therefore accumulates across calls). Only nodes
        with ``requires_grad=True`` receive gradient.

        Each node's backward rule comes from the module-level ``lookup`` table,
        keyed by ``(node.func, argument index)``. It is called as
        ``rule(*node.parents, upstream_grad)``.

        Args:
            grad: Upstream gradient for this node; defaults to ones shaped like
                ``self.array``.
        """
        if not self.requires_grad:
            return
        seed = np.ones_like(self.array) if grad is None else np.asarray(grad)
        # Per-call gradients, keyed by node identity. They are kept apart from
        # `.grad` so that gradients accumulated by earlier calls are not
        # re-propagated. Sums are always freshly allocated, never in-place,
        # because backward rules may return the upstream array itself.
        pending = {id(self): seed}
        for tensor in reversed(self._topological_sort()):
            upstream = pending.get(id(tensor))
            if upstream is None:
                continue  # no gradient reached this node in this call
            # Copy on first assignment so `.grad` never aliases an array shared elsewhere.
            tensor.grad = np.array(upstream) if tensor.grad is None else tensor.grad + upstream
            if tensor.func is None:
                continue  # leaf: nothing further to propagate
            for idx, parent in enumerate(tensor.parents):
                if not parent.requires_grad:
                    continue
                backward_func = lookup.get_backward_function(tensor.func, idx)
                contribution = backward_func(*tensor.parents, upstream)
                key = id(parent)
                pending[key] = contribution if key not in pending else pending[key] + contribution

    def _topological_sort(self) -> list:
        """Return every node reachable via ``parents``, each listed after all of its parents.

        Iterative depth-first post-order (same order as the natural recursive
        version), so deep graphs do not hit Python's recursion limit.
        """
        visited = {id(self)}
        topo_order = []
        stack = [(self, iter(self.parents))]
        while stack:
            node, remaining = stack[-1]
            for parent in remaining:
                if id(parent) not in visited:
                    visited.add(id(parent))
                    stack.append((parent, iter(parent.parents)))
                    break
            else:
                # All parents of `node` are already emitted.
                stack.pop()
                topo_order.append(node)
        return topo_order


# In [4]:
@wrap_forward_function
def multiply(x, y):
    """Element-wise product ``x.array * y.array`` as a new tensor with parents ``(x, y)``."""
    # NOTE(review): `Tensor` is not defined in the visible cells (the autograd
    # class here is `Parameter`); confirm it is defined before this runs.
    product = x.array * y.array
    return Tensor(product, parents=(x, y))

@wrap_forward_function
def matmul(x, y):
    """Matrix product ``x.array @ y.array`` as a new tensor with parents ``(x, y)``."""
    # NOTE(review): `Tensor` is not defined in the visible cells — confirm.
    matrix_product = x.array @ y.array
    return Tensor(matrix_product, parents=(x, y))

@wrap_forward_function
def sum_(x, y):
    """Element-wise addition ``x.array + y.array``; the result is created with requires_grad=True."""
    # NOTE(review): `Tensor` is not defined in the visible cells — confirm.
    total = x.array + y.array
    return Tensor(total, requires_grad=True, parents=(x, y))

@wrap_forward_function
def log(x):
    """Element-wise natural logarithm of ``x.array`` as a new tensor."""
    # NOTE(review): no backward rule for `log` is registered in the visible
    # lookup setup, and `Tensor` is not defined in the visible cells — confirm both.
    logged = np.log(x.array)
    return Tensor(logged)
@wrap_forward_function
def eq(x, y):
    """Element-wise equality of ``x.array`` and ``y.array`` (boolean result) as a new tensor."""
    # NOTE(review): `Tensor` is not defined in the visible cells — confirm.
    matches = np.equal(x.array, y.array)
    return Tensor(matches)
@wrap_forward_function
def sum(x, dim=None, keepdim=False):
    """Sum ``x.array`` over ``dim`` (all axes when None), optionally keeping reduced dims.

    ``dim`` and ``keepdim`` are recorded in the *result's* ``kwargs``. Note that this
    name shadows the builtin ``sum`` for the rest of the module.
    """
    # NOTE(review): `sum_back` reads dim/keepdim from its *input* argument's kwargs,
    # not from this result — confirm which tensor should carry them. `Tensor` is
    # also not defined in the visible cells.
    reduced = np.sum(x.array, axis=dim, keepdims=keepdim)
    recorded = {'dim': dim, 'keepdim': keepdim}
    return Tensor(reduced, parents=(x,), kwargs=recorded, requires_grad=True)


# In [5]:
class BackwardLookupTable:
    """Registry mapping (forward function, argument position) -> backward function."""

    def __init__(self):
        # forward function -> {argument position -> backward function}
        self.table: DefaultDict[Callable, dict[int, Callable]] = defaultdict(dict)

    def add_element(self, forward_function: Callable, position: int, backward_func: Callable):
        """Register *backward_func* as the gradient rule for argument *position* of *forward_function*."""
        self.table[forward_function][position] = backward_func

    def get_backward_function(self, forward_function: Callable, position: int) -> Callable:
        """Return the backward function registered for (*forward_function*, *position*).

        Raises:
            KeyError: if no rule is registered. The lookup does not insert empty
                entries into the table.
        """
        # `.get` instead of `[]`: indexing a defaultdict would add an empty entry on a miss.
        by_position = self.table.get(forward_function)
        if by_position is None or position not in by_position:
            name = getattr(forward_function, "__name__", repr(forward_function))
            raise KeyError(
                f"no backward function registered for {name!r} at argument position {position}"
            )
        return by_position[position]

   

def _unbroadcast(grad, shape):
    """Reduce *grad* to *shape* by summing over the axes NumPy broadcasting expanded.

    When an operand was broadcast in the forward pass, its gradient is the
    upstream gradient summed over the broadcast axes. If *grad* already has
    *shape*, it is returned unchanged (as an ndarray).
    """
    grad = np.asarray(grad)
    shape = tuple(shape)
    if grad.shape == shape:
        return grad
    # Axes that broadcasting prepended in front of the operand's own axes.
    leading = grad.ndim - len(shape)
    if leading > 0:
        grad = grad.sum(axis=tuple(range(leading)))
    # Length-1 axes of the operand that broadcasting stretched.
    stretched = tuple(
        axis for axis, size in enumerate(shape) if size == 1 and grad.shape[axis] != 1
    )
    if stretched:
        grad = grad.sum(axis=stretched, keepdims=True)
    return grad


def multiply_back0(argument0, argument1, grad_out):
    """Gradient of ``argument0 * argument1`` w.r.t. ``argument0``, reduced to its shape."""
    return _unbroadcast(argument1.array * grad_out, np.shape(argument0.array))


def multiply_back1(argument0, argument1, grad_out):
    """Gradient of ``argument0 * argument1`` w.r.t. ``argument1``, reduced to its shape."""
    return _unbroadcast(argument0.array * grad_out, np.shape(argument1.array))


def matmul_back0(argument0, argument1, grad_out):
    """Gradient of ``argument0 @ argument1`` w.r.t. ``argument0`` (written for 2-D operands)."""
    return grad_out @ argument1.T.array


def matmul_back1(argument0, argument1, grad_out):
    """Gradient of ``argument0 @ argument1`` w.r.t. ``argument1`` (written for 2-D operands)."""
    return argument0.T.array @ grad_out


def sum_back0(argument0, argument1, grad_out):
    """Gradient of ``argument0 + argument1`` w.r.t. ``argument0``, reduced to its shape."""
    return _unbroadcast(grad_out, np.shape(argument0.array))


def sum_back1(argument0, argument1, grad_out):
    """Gradient of ``argument0 + argument1`` w.r.t. ``argument1``, reduced to its shape."""
    return _unbroadcast(grad_out, np.shape(argument1.array))
def sum_back(argument, grad_out):
    """Backward rule for ``sum``: spread ``grad_out`` back over ``argument``'s shape.

    ``dim`` and ``keepdim`` are read from ``argument.kwargs`` (defaults ``None`` and
    ``False``). With ``dim=None``, ``np.full`` fills the original shape with
    ``grad_out``. Otherwise the reduced axes are re-inserted as length-1 dims
    (unless ``keepdim``), and ``grad_out`` is broadcast along them.

    NOTE(review): the backward sweep calls this with the *input* of ``sum`` as
    ``argument``, while ``sum`` stores ``dim``/``keepdim`` on its *output*. Unless
    the input happens to carry those kwargs, the ``dim=None`` path is taken.
    Confirm the intended source of these kwargs.
    """
    kwargs = argument.kwargs
    dim = kwargs.get('dim', None)
    keepdim = kwargs.get('keepdim', False)
    original_shape = argument.array.shape
    if dim is None:
        return np.full(original_shape, grad_out)
    axes = (dim,) if isinstance(dim, int) else tuple(dim)
    if not keepdim:
        # Restore the reduced axes so grad_out broadcasts along them.
        grad_out = np.expand_dims(grad_out, axis=axes)
    return np.ones(original_shape) * grad_out



# Wire each differentiable forward function to its per-argument backward rule.
# `sum` here is the forward function defined above (it shadows the builtin).
lookup = BackwardLookupTable()
for _forward, _position, _backward in (
    (multiply, 0, multiply_back0),
    (multiply, 1, multiply_back1),
    (matmul, 0, matmul_back0),
    (matmul, 1, matmul_back1),
    (sum_, 0, sum_back0),
    (sum_, 1, sum_back1),
    (sum, 0, sum_back),
):
    lookup.add_element(_forward, _position, _backward)
del _forward, _position, _backward

# In [6]:
class _MissingAttributeError(AttributeError, KeyError):
    """Raised by ``Module.__getattr__`` for unknown names.

    It subclasses ``AttributeError``, so ``hasattr``, ``getattr(obj, name, default)``
    and ``copy``/``pickle`` probing behave normally. It also subclasses
    ``KeyError``, so existing ``except KeyError`` handlers keep working.
    """


class Module:
    """Container that auto-registers ``Parameter`` and ``Module`` attributes.

    Assigning a ``Parameter`` or ``Module`` to an attribute stores it in
    ``_parameters`` or ``_modules`` rather than the instance ``__dict__``.
    ``__getattr__`` serves those names back.
    """

    _modules: dict[str, "Module"]
    _parameters: dict[str, "Parameter"]

    def __init__(self):
        self._modules = {}
        self._parameters = {}

    def modules(self):
        '''Return the direct child modules of this module (a dict values view).'''
        return self.__dict__["_modules"].values()

    def parameters(self, recurse: bool = True) -> Iterator["Parameter"]:
        '''
        Return an iterator over a snapshot of this module's parameters.

        recurse: if True, parameters of submodules follow this module's own,
            collected recursively in registration order.
        '''
        collected = list(self.__dict__["_parameters"].values())
        if recurse:
            for child in self.modules():
                collected.extend(child.parameters(recurse=True))
        return iter(collected)

    def __setattr__(self, key: str, val: Any) -> None:
        '''
        Store a Parameter in _parameters and a Module in _modules; anything else
        goes through the normal object attribute machinery.
        '''
        if isinstance(val, Parameter):
            self.__dict__["_parameters"][key] = val
        elif isinstance(val, Module):
            self.__dict__["_modules"][key] = val
        else:
            super().__setattr__(key, val)

    def __getattr__(self, key: str) -> Union["Parameter", "Module"]:
        '''
        Return the registered Parameter or Module named *key*.

        Only called when normal attribute lookup fails. Raises
        _MissingAttributeError (an AttributeError and a KeyError) if *key* is
        not registered.
        '''
        # `.get` keeps this safe on instances whose __init__ has not run
        # (e.g. objects created by copy/pickle before their state is restored).
        state = self.__dict__
        for registry_name in ("_parameters", "_modules"):
            registry = state.get(registry_name, {})
            if key in registry:
                return registry[key]
        raise _MissingAttributeError(
            f"{type(self).__name__!r} object has no attribute {key!r}"
        )
