In [63]:
from collections.abc import Sequence
from __future__ import annotations
from numbers import Number
from itertools import combinations
from typing import Callable
from typing import Type
from typing import Tuple
from icecream import ic
from abc import ABC, abstractmethod
import numpy as np
import graphviz

from sklearn.datasets import load_iris
from numpy import linalg as LA

## Defining the mathematical operators as classes

In [64]:
# operator classes
class function(ABC):
    @property
    @abstractmethod
    def forward_func(self) -> Callable:
        pass
    
    @property
    @abstractmethod
    def backward_func(self) -> Callable:
        pass
    
    def get_functions(self) -> Tuple[Callable, Callable]:
        return self.forward_func, self.backward_func
    
    def __call__(self, *inner) -> expr_node: # simplifies the syntax for changing of functions
        return expr_node(self, inner)

class sin(function):
    forward_func = np.sin
    backward_func = np.cos

class cos(function):
    forward_func = np.cos
    backward_func = lambda x: -np.sin(x)
          
class tanh(function):
    forward_func = np.tanh
    backward_func = lambda x: 1 - np.square(tanh(x))

class log(function):
    forward_func = np.log
    backward_func = lambda x: 1/x


class multiply(function):
    
    forward_func = lambda: None # to be overwitten in constructor
    backward_func = lambda: None # to be overwitten in constructor
    
    def __init__(self,*, allow_arbitrary_many = False): # * makes allow_arbitary_many to keword only
        if not allow_arbitrary_many: # simple case two factors
            self.forward_func = lambda x,y: np.multiply(x,y)
            self.backward_func = lambda x, y: (y, x)
        else:
            forward_func = lambda *x: np.prod(np.vstack(x), axis=0)
            self.backward_func = lambda *x: (np.prod(np.vstack(values), axis=0) for values in combinations(x, len(x) - 1))
                   
class add(function):
    
    forward_func = lambda: None # to be overwitten in constructor
    backward_func = lambda: None # to be overwitten in constructor
    
    def __init__(self,*, allow_arbitrary_many = False): # new is clalled before the constructer (before: "__init__")
        if not allow_arbitrary_many:
            self.forward_func = np.add 
            self.backward_func = lambda x, y: (np.full_like(x, 1), np.full_like(y, 1)) #arrays with same shape of x and y, filled with 1
        else:
            self.forward_func = lambda *x: np.sum(np.vstack(x), axis=0)
            self.backward_func = lambda *x: (np.full_like(x[0], 1) for _ in range(len(x))) #backward_func = lambda *x: (np.ones(len(x[0]) if isinstance(x[0], Iterable) else 1) for _ in range(len(x))))

class power(function):
    forward_func = lambda: None
    backward_func = lambda: None

    def __init__(self):
            self.forward_func = lambda x,y: np.power(x,y)
            self.backward_func = lambda x, y: y * np.power(x,y-1)

class scalar(function):
    forward_func = np.float64
    backward_func = 0

## Creating the node tree for a function

In [65]:
class expr_end_node():
    def __init__(self, value: Number | np.ndarray , grad_value: Number | np.ndarray  = 0):
        self.value = value
        self.grad_value = grad_value
        
class expr_node():
    def __init__(self, func: function, childs: Sequence[expr_node | expr_end_node] = []):
        #self.parents = parents #expr_node
        self.childs = childs  #expr_node
        self.func = func #function 
        # lieber so:
        # self.forward_func, self.backward_func = func.get_functions()
        # und dann self.func entfernen

#graphical depiction of the nodes
def print_graph(node, graph, parent_id=""):
    
    node_id = str(id(node))
    node_label = str(node.value) if isinstance(node, expr_end_node) else type(node.func).__name__
    graph.node(node_id, node_label)  
    
    if parent_id:
        graph.edge(parent_id, node_id)
        
    if hasattr(node, 'childs') and node.childs:
            for child in node.childs:
                print_graph(child, graph, node_id)#, constraint='false')

## Forward and backward propagation

In [66]:
def forward(node):
    return node.func.forward_func(*(forward(child) for child in node.childs)) if type(node) is not expr_end_node else node.value 

"""
def forward(node):
    if type(node) is not expr_end_node:
    
        node.func.forward_func(*(forward(child) for child in node.childs))
    else:
        node.value 
"""

def backward(node, value = 1):
    if type(node) is not expr_end_node:
        child_values = (forward(child) for child in node.childs)
        if len(node.childs) == 1:
            return backward(node.childs[0], value * node.func.backward_func(*child_values))
        else:
            for child, new_value in zip(node.childs, node.func.backward_func(*child_values), strict=True):
                return backward(child, value * new_value)
    else: 
        node.grad_value += value


## Example functions
### The node trees

In [67]:
x = expr_end_node(np.float64(1.2))
x1 = expr_end_node(np.float64(.5))
x2 = expr_end_node(np.float64(3.3))
w = expr_end_node(np.float64(2.8))
b = expr_end_node(np.float64(1.6))


function1 = multiply()(log()( multiply()( x1 , x2 ) ) , 
                       sin()(x1) 
                      ) 
                     

function2 = multiply()( 
                        x1 ,
                        x2 , 
                        add()(x1 , x2)
                       )

function3 = add()( multiply()( expr_node( np.float64(3) ) , power()(x , expr_node( np.float64(2) ) ),
                   multiply()( expr_node( np.float64(4) ), x),
                   expr_node( np.float64(2) )
                             )
                 )

function4 = tanh()(add()( multiply()(x , w),
                          b
                        )
                  )


f1graph = graphviz.Digraph('function1', comment='') 
f1graph.attr(rankdir="LR")
print_graph(function1, f1graph)
# #doctest_mark_exe() 
f1graph.render(directory='graph_out/tt', view=False)

f2graph = graphviz.Digraph('function2', comment='') 
f2graph.attr(rankdir="LR")
print_graph(function2, f2graph)
# #doctest_mark_exe() 
f2graph.render(directory='graph_out/tt', view=False)

f3graph = graphviz.Digraph('function3', comment='') 
f3graph.attr(rankdir="LR")
print_graph(function3, f3graph)
# #doctest_mark_exe() 
f3graph.render(directory='graph_out/tt', view=False)

f4graph = graphviz.Digraph('function4', comment='') 
f4graph.attr(rankdir="LR")
print_graph(function4, f4graph)
# #doctest_mark_exe() 
f4graph.render(directory='graph_out/tt', view=False)

'graph_out/tt/function4.gv.pdf'

### The derivatives

In [77]:
f1n1 = expr_node(multiply())
f1n11 = expr_node(log())
f1n111 = expr_node(multiply())
f1n12 = expr_node(sin())

f1n1.childs = [f1n11, f1n12]
f1n11.childs = [f1n111]
f1n111.childs = [x1,x2]
f1n12.childs = [x2]


normalf1 = np.log(.5 * 3.3) * np.sin(3.3)
ic(normalf1)
ic(forward(f1n1))
# ic(forward(function1)) # doesn't work; probably because it cannot recognize the child nodes

# ic(backward(f1n1))

ex1 = expr_node(multiply())
ex11 = expr_node(sin())
ex111 = expr_node(multiply())
ex1.childs = [ex11, x1]
ex11.childs = [ex111]
ex111.childs = [x1,x2]

normalex1 = .5 * np.sin(3.3 * .5)
ic(forward(ex1))
ic(normalex1)
ic(backward(ex1))

e1n1 = expr_node(cos())
e1n11 = expr_node(multiply())
e1n1.childs = [e1n11]
e1n11.childs = [x1,x2]
ic(forward(e1n1))
# ic(backward(e1n1))

e2n1 = expr_node(cos())
e2n11 = expr_node(multiply())
e2n111 = expr_node(scalar())
e2n1.childs = [e2n11]
e2n11.childs = [x1, e2n111]
e2n111.childs = [expr_end_node(np.float64(2))]

ic(np.cos(1))
ic(forward(e2n1))
# ic(backward(e2n1))

ic| normalf1: -0.07899514540154058
ic| forward(f1n1): -0.07899514540154058
ic| forward(ex1): 0.49843251422695944
ic| normalex1: 0.49843251422695944
ic| backward(ex1): None
ic| forward(e1n1): -0.07912088880673386
ic| np.cos(1): 0.5403023058681398
ic| forward(e2n1): 0.5403023058681398


0.5403023058681398