# In [1]:
import inspect
from collections import defaultdict, deque
from functools import wraps
from typing import Any, Callable, DefaultDict, Dict, Optional, Tuple

import numpy as np

# ----------------------------
# Backward Lookup Table
# ----------------------------
class BackwardLookupTable:
    """
    Registry mapping (forward function, input position) to a backward function.

    The backward function registered for position ``i`` computes the gradient
    with respect to the ``i``-th tensor input of the forward function.
    """

    def __init__(self):
        # forward function -> {input position -> backward function}
        self.table: DefaultDict[Callable, Dict[int, Callable]] = defaultdict(dict)

    def add_element(self, forward_function: Callable, position: int, backward_func: Callable):
        """
        Register ``backward_func`` as the gradient rule for input ``position``
        of ``forward_function``. Registering the same pair again overwrites it.
        """
        self.table[forward_function][position] = backward_func

    def get_backward_function(self, forward_function: Callable, position: int) -> Callable:
        """
        Return the backward function registered for ``forward_function`` at ``position``.

        The lookup never modifies the table (indexing the ``defaultdict``
        directly would insert an empty entry for an unknown function).

        :raises KeyError: if no backward function is registered for the pair.
        """
        rules = self.table.get(forward_function)
        if rules is None or position not in rules:
            name = getattr(forward_function, "__name__", repr(forward_function))
            raise KeyError(
                f"No backward function registered for {name!r} at input position {position}"
            )
        return rules[position]

# Global registry that Parameter.backward queries for each op's gradient rules.
lookup: BackwardLookupTable = BackwardLookupTable()

# ----------------------------
# Backward Functions
# ----------------------------
def multiply_back0(arg0: 'Parameter', arg1: 'Parameter', grad_out: np.ndarray) -> np.ndarray:
    """
    Gradient of ``arg0 * arg1`` (element-wise) with respect to ``arg0``.

    Since d(arg0 * arg1)/d(arg0) = arg1, the upstream gradient is scaled
    element-wise by ``arg1.array``.
    """
    other = arg1.array
    return other * grad_out

def multiply_back1(arg0: 'Parameter', arg1: 'Parameter', grad_out: np.ndarray) -> np.ndarray:
    """
    Gradient of ``arg0 * arg1`` (element-wise) with respect to ``arg1``.

    Since d(arg0 * arg1)/d(arg1) = arg0, the upstream gradient is scaled
    element-wise by ``arg0.array``.
    """
    other = arg0.array
    return other * grad_out

def matmul_back0(arg0: 'Parameter', arg1: 'Parameter', grad_out: np.ndarray) -> np.ndarray:
    """
    Gradient of ``arg0 @ arg1`` with respect to the left operand ``arg0``.

    For out = A @ B, dL/dA = dL/dout @ B^T.
    """
    rhs_transposed = arg1.array.T
    return grad_out @ rhs_transposed

def matmul_back1(arg0: 'Parameter', arg1: 'Parameter', grad_out: np.ndarray) -> np.ndarray:
    """
    Gradient of ``arg0 @ arg1`` with respect to the right operand ``arg1``.

    For out = A @ B, dL/dB = A^T @ dL/dout.
    """
    lhs_transposed = arg0.array.T
    return lhs_transposed @ grad_out

def sum_back(arg: 'Parameter', grad_out: np.ndarray, dim: Any = None, keepdim: bool = False) -> np.ndarray:
    """
    Backward function for the sum operation over a single tensor.

    Every summed element contributes with coefficient 1, so the gradient is
    ``grad_out`` broadcast back to the shape of the summed input.

    :param arg: The forward input that was summed. Backward functions are
                called with the op's parents, so this is the input, not the
                sum's result.
    :param grad_out: Gradient of the loss w.r.t. the sum's output.
    :param dim: Axis (int) or axes (iterable of ints) that were reduced;
                ``None`` means every axis was reduced.
    :param keepdim: Whether the forward pass kept reduced axes as size 1.
    :return: A new, writable array with the same shape as ``arg.array``.
    """
    original_shape = np.shape(arg.array)
    grad_out = np.asarray(grad_out)

    if dim is not None and not keepdim:
        ndim = len(original_shape)
        axes = (dim,) if isinstance(dim, (int, np.integer)) else tuple(dim)
        # Re-insert the reduced axes as size-1 dims (negative axes normalized
        # first) so grad_out lines up with the input when broadcasting.
        grad_out = np.expand_dims(grad_out, axis=tuple(sorted(a % ndim for a in axes)))

    # broadcast_to yields a read-only view; copy so the caller may accumulate into it.
    return np.array(np.broadcast_to(grad_out, original_shape))

def sum_back0(arg0: 'Parameter', arg1: 'Parameter', grad_out: np.ndarray) -> np.ndarray:
    """
    Gradient of ``arg0 + arg1`` with respect to ``arg0``.

    Addition has a unit partial derivative, so the upstream gradient is
    returned as-is (the same object, not a copy).
    """
    del arg0, arg1  # the gradient of a sum does not depend on its operands
    return grad_out

def sum_back1(arg0: 'Parameter', arg1: 'Parameter', grad_out: np.ndarray) -> np.ndarray:
    """
    Gradient of ``arg0 + arg1`` with respect to ``arg1``.

    Addition has a unit partial derivative, so the upstream gradient is
    returned as-is (the same object, not a copy).
    """
    del arg0, arg1  # the gradient of a sum does not depend on its operands
    return grad_out

# ----------------------------
# Forward Operations with Decorator
# ----------------------------
def wrap_forward_function(func: Callable) -> Callable:
    """
    Decorator that adds gradient tracking to a forward function.

    The wrapped function is called unchanged. Afterwards
    ``result.requires_grad`` is set to whether any tensor input requires
    gradients, and if so the result also records:

    * ``parents`` -- the tensor inputs in signature order (their positions are
      the keys used for backward-function lookup);
    * ``func`` -- the *decorated* function, i.e. the object bound to the
      forward function's name, which is what backward rules are registered
      under (e.g. ``lookup.add_element(multiply, ...)``);
    * ``kwargs`` -- every non-tensor argument by parameter name (e.g. ``dim``),
      whether it was passed positionally or by keyword.

    A value counts as a tensor input when it has both ``array`` and
    ``requires_grad`` attributes.
    NOTE: assumes forward functions have no ``*args``/``**kwargs`` parameters.
    """
    signature = inspect.signature(func)

    @wraps(func)
    def new_function(*args: Any, **kwargs: Any) -> 'Parameter':
        result = func(*args, **kwargs)

        tensor_inputs = []
        options = {}
        # Split bound arguments into tensors (graph parents) and plain options,
        # so ``sum(x, 0)`` is recorded exactly like ``sum(x, dim=0)``.
        for name, value in signature.bind(*args, **kwargs).arguments.items():
            if hasattr(value, "array") and hasattr(value, "requires_grad"):
                tensor_inputs.append(value)
            else:
                options[name] = value

        requires_grad = any(t.requires_grad for t in tensor_inputs)
        result.requires_grad = requires_grad
        if requires_grad:
            result.parents = tuple(tensor_inputs)
            # Store the wrapper, not the raw ``func``: the backward lookup table
            # is keyed by the decorated function objects.
            result.func = new_function
            result.kwargs = options
        return result

    return new_function

@wrap_forward_function
def multiply(x: 'Parameter', y: 'Parameter') -> 'Parameter':
    """
    Element-wise multiplication of two tensors (NumPy broadcasting rules apply).

    The result is marked as requiring gradients only when at least one input
    does, so products of constants are not treated as trainable.
    """
    return Parameter(array=x.array * y.array, requires_grad=x.requires_grad or y.requires_grad)

@wrap_forward_function
def matmul(x: 'Parameter', y: 'Parameter') -> 'Parameter':
    """
    Matrix multiplication of two tensors (``x.array @ y.array``).

    The result is marked as requiring gradients only when at least one input
    does, so products of constants are not treated as trainable.
    """
    return Parameter(array=x.array @ y.array, requires_grad=x.requires_grad or y.requires_grad)

@wrap_forward_function
def sum_(x: 'Parameter', y: 'Parameter') -> 'Parameter':
    """
    Element-wise addition of two tensors (NumPy broadcasting rules apply).

    The result is marked as requiring gradients only when at least one input
    does, so sums of constants are not treated as trainable.
    """
    return Parameter(array=x.array + y.array, requires_grad=x.requires_grad or y.requires_grad)

@wrap_forward_function
def sum(x: 'Parameter', dim: Optional[int] = None, keepdim: bool = False) -> 'Parameter':
    """
    Sum of a tensor along specified dimensions.

    :param x: Tensor to reduce.
    :param dim: Axis to reduce over (a tuple of axes is forwarded to
                ``np.sum`` as well); ``None`` reduces over all elements.
    :param keepdim: Keep the reduced axes as size-1 dimensions.
    :return: The reduced tensor; it requires gradients only if ``x`` does.

    NOTE: this name shadows the builtin ``sum`` for the rest of the module.
    """
    return Parameter(array=np.sum(x.array, axis=dim, keepdims=keepdim), requires_grad=x.requires_grad)

# ----------------------------
# Register Backward Functions
# ----------------------------
# Gradient rules as (forward op, input position, backward function) triples,
# registered in this order into the global lookup table.
_BACKWARD_RULES = (
    (multiply, 0, multiply_back0),  # element-wise product
    (multiply, 1, multiply_back1),
    (matmul, 0, matmul_back0),      # matrix product
    (matmul, 1, matmul_back1),
    (sum_, 0, sum_back0),           # addition of two tensors
    (sum_, 1, sum_back1),
    (sum, 0, sum_back),             # reduction of a single tensor
)
for _forward, _position, _backward in _BACKWARD_RULES:
    lookup.add_element(_forward, _position, _backward)
# Keep the module namespace free of the temporary registration names.
del _BACKWARD_RULES, _forward, _position, _backward

# ----------------------------
# Parameter Class
# ----------------------------
class Parameter:
    """
    A NumPy array plus the bookkeeping needed for reverse-mode autodiff.

    Results of the decorated forward functions record the tensors they were
    computed from (``parents``), the forward function (``func``) and its extra
    arguments (``kwargs``); :meth:`backward` walks that graph and stores
    gradients in ``grad``.
    """

    # Make NumPy defer binary operators to Parameter, so ``ndarray * Parameter``
    # dispatches to ``Parameter.__rmul__`` instead of building an object array.
    __array_ufunc__ = None

    def __init__(
        self,
        array: Optional[np.ndarray] = None,
        requires_grad: bool = False,
        parents: Optional[Tuple['Parameter', ...]] = None,
        func: Optional[Callable] = None,
        kwargs: Optional[dict] = None,
    ):
        """
        Initializes a Parameter.

        :param array: The underlying NumPy array (defaults to ``np.array(0.0)``).
        :param requires_grad: Flag indicating whether to track gradients.
        :param parents: Parent tensors involved in creating this tensor.
        :param func: The forward function that created this tensor.
        :param kwargs: Additional keyword arguments passed to the forward function.
        """
        self.array = array if array is not None else np.array(0.0)
        self.grad: Optional[np.ndarray] = None
        self.requires_grad = requires_grad
        self.parents = parents if parents is not None else ()
        self.func = func
        self.kwargs = kwargs if kwargs is not None else {}

    @staticmethod
    def _coerce(value: Any) -> 'Parameter':
        """Return ``value`` if it is a Parameter, else wrap it as a constant (no grad) Parameter."""
        if isinstance(value, Parameter):
            return value
        return Parameter(array=np.asarray(value))

    def __add__(self, other: Any) -> 'Parameter':
        return sum_(self, Parameter._coerce(other))

    def __radd__(self, other: Any) -> 'Parameter':
        return sum_(Parameter._coerce(other), self)

    def __mul__(self, other: Any) -> 'Parameter':
        return multiply(self, Parameter._coerce(other))

    def __rmul__(self, other: Any) -> 'Parameter':
        # Only reached when ``other`` is not a Parameter, so wrap it first.
        return multiply(Parameter._coerce(other), self)

    def __matmul__(self, other: Any) -> 'Parameter':
        return matmul(self, Parameter._coerce(other))

    def __rmatmul__(self, other: Any) -> 'Parameter':
        return matmul(Parameter._coerce(other), self)

    @property
    def T(self) -> 'Parameter':
        """
        Transpose of the tensor.

        NOTE(review): the result is a new leaf. It records no parents, so
        gradients computed for it are NOT propagated back to ``self``.
        """
        return Parameter(array=self.array.T, requires_grad=self.requires_grad)

    def __repr__(self) -> str:
        return f"Parameter(shape={self.array.shape}, requires_grad={self.requires_grad})"

    def backward(self, grad: Optional[np.ndarray] = None):
        """
        Back-propagate from this tensor, accumulating into ``.grad`` of every
        tensor in its graph that requires gradients.

        :param grad: Gradient of the loss with respect to this tensor.
                     If None, an array of ones shaped like this tensor is used
                     (i.e. 1 for a scalar).

        Only the gradient of *this* call is propagated, so repeated calls
        accumulate correctly instead of re-sending gradients already stored on
        intermediate tensors. Stored gradients never share a buffer with another
        tensor's gradient.
        """
        if not self.requires_grad:
            return

        if grad is None:
            grad = np.ones_like(self.array)

        # Gradients flowing in this pass only, keyed by tensor (identity hash).
        pending = {self: np.asarray(grad)}

        # Reverse topological order: every tensor is handled after all its consumers.
        for tensor in reversed(self._topological_sort()):
            tensor_grad = pending.pop(tensor, None)
            if tensor_grad is None:
                continue  # not reached by this pass
            # Out-of-place update; copy on first store so tensors never alias one array.
            if tensor.grad is None:
                tensor.grad = np.array(tensor_grad)
            else:
                tensor.grad = tensor.grad + tensor_grad

            if tensor.func is None:
                continue  # leaf: nothing further to propagate

            for idx, parent in enumerate(tensor.parents):
                if not isinstance(parent, Parameter) or not parent.requires_grad:
                    continue
                backward_func = lookup.get_backward_function(tensor.func, idx)
                # Forward options (e.g. ``dim``/``keepdim`` of ``sum``) are forwarded
                # as keyword arguments; ops without options have empty kwargs.
                parent_grad = backward_func(*tensor.parents, tensor_grad, **tensor.kwargs)
                parent_grad = Parameter._unbroadcast(np.asarray(parent_grad), np.shape(parent.array))
                if parent in pending:
                    pending[parent] = pending[parent] + parent_grad
                else:
                    pending[parent] = parent_grad

    @staticmethod
    def _unbroadcast(grad: np.ndarray, shape: Tuple[int, ...]) -> np.ndarray:
        """
        Sum ``grad`` over the axes that NumPy broadcasting introduced, so its
        shape matches ``shape``. Shapes that already match are returned as-is.
        """
        shape = tuple(shape)
        if grad.shape == shape:
            return grad
        # Leading axes that broadcasting prepended to the input.
        while grad.ndim > len(shape):
            grad = grad.sum(axis=0)
        # Axes where the input had size 1 but the output was stretched.
        if grad.ndim == len(shape):
            for axis, size in enumerate(shape):
                if size == 1 and grad.shape[axis] != 1:
                    grad = grad.sum(axis=axis, keepdims=True)
        return np.asarray(grad)

    def _topological_sort(self) -> list:
        """
        Order this tensor's graph so every tensor appears after all its parents.

        Uses an explicit stack (iterative post-order DFS), so long chains of
        operations do not hit Python's recursion limit.

        :return: List of tensors in topologically sorted order (``self`` last).
        """
        visited = set()
        topo_order = []
        stack = [(self, False)]
        while stack:
            node, parents_done = stack.pop()
            if parents_done:
                topo_order.append(node)
                continue
            if node in visited:
                continue
            visited.add(node)
            stack.append((node, True))
            # Push in reverse so parents are explored left to right.
            for parent in reversed(node.parents):
                if isinstance(parent, Parameter) and parent not in visited:
                    stack.append((parent, False))
        return topo_order
