In [57]:
from collections.abc import Sequence
from __future__ import annotations
from numbers import Number
from itertools import combinations
from typing import Callable
from typing import Type
from typing import Tuple
from icecream import ic
from abc import ABC, abstractmethod
import numpy as np
import graphviz
import traceback

from sklearn.datasets import load_iris
from numpy import linalg as LA

## Defining the mathematical operators as classes

In [58]:
# operator classes
class function(ABC):
    @property
    @abstractmethod
    def forward_func(self) -> Callable:
        pass
    
    @property
    @abstractmethod
    def backward_func(self) -> Callable:
        pass
    
    def get_functions(self) -> Tuple[Callable, Callable]:
        return self.forward_func, self.backward_func
    
    def __call__(self, *inner) -> expr_node: # simplifies the syntax for changing of functions
        return expr_node(self, inner)

class sin(function):
    forward_func = np.sin
    backward_func = np.cos

class cos(function):
    forward_func = np.cos
    backward_func = lambda self, x: -np.sin(x)
          
class tanh(function):
    forward_func = np.tanh
    backward_func = lambda self, x: 1 - np.square(tanh(x))

class log(function):
    forward_func = np.log
    backward_func = lambda self, x: 1/x


class multiply(function):
    
    forward_func = lambda: None # to be overwitten in constructor
    backward_func = lambda: None # to be overwitten in constructor
    
    def __init__(self,*, allow_arbitrary_many = False): # * makes allow_arbitary_many to keyword only
        if not allow_arbitrary_many: # simple case two factors
            self.forward_func = np.multiply
            self.backward_func = lambda x, y: (y, x)
        else:
            forward_func = lambda self, *x: np.prod(np.vstack(x), axis=0)
            self.backward_func = lambda self, *x: (np.prod(np.vstack(values), axis=0) for values in combinations(x, len(x) - 1))
                   
class add(function):
    
    forward_func = lambda: None # to be overwitten in constructor
    backward_func = lambda: None # to be overwitten in constructor
    
    def __init__(self,*, allow_arbitrary_many = False): # new is called before the constructer (before: "__init__")
        if not allow_arbitrary_many:
            # self.forward_func = np.add
            self.forward_func = np.add
            self.backward_func = lambda x, y: (np.full_like(x, 1), np.full_like(y, 1)) #arrays with same shape of x and y, filled with 1
        else:
            self.forward_func = lambda *x: np.sum(np.vstack(*x), axis=0)
            self.backward_func = lambda *x: (np.full_like(*x[0], 1) for _ in range(len(*x))) #backward_func = lambda *x: (np.ones(len(x[0]) if isinstance(x[0], Iterable) else 1) for _ in range(len(x))))

class power(function):
    forward_func = lambda: None
    backward_func = lambda: None

    def __init__(self):
            self.forward_func = lambda x, y: np.power(x,y)
            self.backward_func = lambda x, y: y * np.power(x, y-1)

class scalar(function):
    forward_func = np.float64
    backward_func = np.float64(0)

# class Multiply_with_Scalar(Function)

## Creating the node tree of a function

In [59]:
class expr_end_node():
    def __init__(self, value: Number | np.ndarray , grad_value: Number | np.ndarray  = 0, iname=None):
        self.value = value
        self.grad_value = grad_value
        (filename,line_number,function_name,text) = traceback.extract_stack()[-2]
        iname = text[:text.find('=')].strip()
        self.instance = iname

class expr_node():
    def __init__(self, func: function, childs: Sequence[expr_node | expr_end_node] = []):
        #self.parents = parents #expr_node
        self.childs = childs  #expr_node
        self.func = func #function 
        (filename,line_number,function_name,text) = traceback.extract_stack()[-2]
        iname = text[:text.find('=')].strip()
        self.instance = iname
        # lieber so:
        # self.forward_func, self.backward_func = func.get_functions()
        # und dann self.func entfernen

#graphical depiction of the nodes
def print_graph(node, graph, parent_id=""):
    
    node_id = str(id(node))
    if isinstance(node, expr_end_node):
        node_label = str(node.instance)
    else:
        if type(node.func).__name__ == 'scalar':
            node_label = str(node.instance)
        else:
            node_label = type(node.func).__name__
    graph.node(node_id, node_label)  
    
    if parent_id:
        graph.edge(parent_id, node_id)
        
    if hasattr(node, 'childs') and node.childs:
            for child in node.childs:
                print_graph(child, graph, node_id)#, constraint='false')

## Algorithms for forward and backward propagation

In [60]:
def forward(node):
    return node.func.forward_func(*(forward(child) for child in node.childs)) if type(node) is not expr_end_node else node.value 

# To make sure we take every derivative, we record the variables as a list in endnodefinder.
# The recursive backward function then repeats itself for every variable in the list.
# We then need backwardtotal; a function that accesses both recursive functions but doesn't recurse itself, so that we have access to constant, once-determined parameters,
# like the component list of partial derivatives or the recording parameter endvalue

# introducing the global parameter endvalue to record derivative values
endvalue = 0

def endnodefinder(node, endnodes=[]):
    if type(node) is not expr_end_node:
        if type(node.func).__name__ != 'scalar': # Terminates loop when scalar is detected; scalar is defined as a node, but does not link to an end node
            for child in node.childs:
                endnodefinder(child, endnodes)
    else:
        if node.instance not in endnodes: # to make sure we don't make a list of repeated variables
            endnodes.append(node.instance)
    return endnodes

def backward(node, variable, value = 1):
    global endvalue
    if type(node) is not expr_end_node:
        child_values = (forward(child) for child in node.childs)
        if len(node.childs) == 1:
            backward(node.childs[0], variable, value * node.func.backward_func(*child_values))
        else:
            for child, new_value in zip(node.childs, node.func.backward_func(*child_values), strict=True):
                backward(child, variable, value * new_value)
    else:
        if node.instance == variable:
            pass
        else:
            value = 0
        endvalue += value
    return endvalue

def backwardtotal(node):
    results = np.empty([1,2])
    i = 0
    variables = endnodefinder(node)
    global endvalue
    for variable in variables:
        endvalue = 0
        results[0][i] = backward(node,variable)
        i += 1
    endvalue = 0 # returning global parameter to initial value
    return results
        
        
        


## Example functions
### Depiction of the node trees of each respective function

In [61]:
x = expr_end_node(np.float64(1.2))
x1 = expr_end_node(np.float64(.5))
x2 = expr_end_node(np.float64(3.3))
w = expr_end_node(np.float64(2.8))
b = expr_end_node(np.float64(1.6))

# function 1: log(x1 * x2) * sin(x1)
function1 = multiply()(
                log()( 
                    multiply()( x1 , x2 ) 
                ) , 
                sin()(x2) 
            ) 
                     

function2 = multiply(allow_arbitrary_many=True)( 
                        x1 ,
                        x2 , 
                        add()(x1 , x2)
                       )


function3 = add()( multiply()( expr_node( np.float64(3) ) , power()(x , expr_node( np.float64(2) ) ),
                   multiply()( expr_node( np.float64(4) ), x),
                   expr_node( np.float64(2) )
                             )
                 )

function4 = tanh()(add()( multiply()(x , w),
                          b
                        )
                  )


f1graph = graphviz.Digraph('function1', comment='') 
f1graph.attr(rankdir="LR")
print_graph(function1, f1graph)
# #doctest_mark_exe() 
f1graph.render(directory='graph_out/tt', view=False)

f2graph = graphviz.Digraph('function2', comment='') 
f2graph.attr(rankdir="LR")
print_graph(function2, f2graph)
# #doctest_mark_exe() 
f2graph.render(directory='graph_out/tt', view=False)

f3graph = graphviz.Digraph('function3', comment='') 
f3graph.attr(rankdir="LR")
print_graph(function3, f3graph)
# #doctest_mark_exe() 
f3graph.render(directory='graph_out/tt', view=False)

f4graph = graphviz.Digraph('function4', comment='') 
f4graph.attr(rankdir="LR")
print_graph(function4, f4graph)
# #doctest_mark_exe() 
f4graph.render(directory='graph_out/tt', view=False)

'graph_out/tt/function4.gv.pdf'

### Taking the derivatives via chain rule

In [62]:
# manual function 1 and its derivatives
normalf1 = np.log(x1.value * x2.value) * np.sin(x2.value)
normaldf1_dx1 = np.sin(x2.value) / x1.value
normaldf1_dx2_t1 = np.sin(x2.value) / x2.value  
normaldf1_dx2_t2 = np.log(x1.value * x2.value) * np.cos(x2.value)
normaldf1_dx2 = normaldf1_dx2_t1 + normaldf1_dx2_t2

ic(normalf1)
ic(forward(function1)) # doesn't work; probably because it cannot recognize the child nodes. If it would, it would be much more convenient as we don't have to assign childs manually



ic(normaldf1_dx1)
# ic(normaldf1_dx2_t1)
# ic(normaldf1_dx2_t2)
ic(normaldf1_dx2)


# manual function 2 and its derivatives
normalf2 = x1.value*x2.value * (x1.value + x2.value)
normaldf2_dx1 = x2.value * (x1.value + x2.value) + x1.value*x2.value
normaldf2_dx2 = x1.value * (x1.value + x2.value) + x1.value*x2.value

ic(normaldf2_dx1)
ic(normaldf2_dx2)

# ic(backwardtotal(f2n1)) # error message

ic(backwardtotal(function2))

print("function 1")
print("Calculus values of x1 derivative and x2 derivative:", normaldf1_dx1, normaldf1_dx2 )
print("Compared to derivatives through chain rule:        ", backwardtotal(function1))

ic| normalf1: -0.07899514540154058
ic| forward(function1): -0.07899514540154058
ic| normaldf1_dx1: -0.3154913882864964
ic| normaldf1_dx2: -0.5423071915818244
ic| normaldf2_dx1: 14.19
ic| normaldf2_dx2: 3.55


ValueError: zip() argument 2 is shorter than argument 1