In [None]:
from abc import ABC, abstractmethod
import math
from queue import Queue
from collections import defaultdict
import itertools

from utils.parser import parse_expression

# Operations

In [None]:
class Operator(ABC):
    """Interface for a differentiable operation taking one or two arguments.

    Concrete operators implement ``f`` (the value of the operation) and
    ``df`` (its partial derivatives). The second argument ``y`` defaults to
    ``None`` in both methods.
    """

    @abstractmethod
    def f(self, x, y=None) -> float:
        """Evaluate the function at ``(x, y)``.

        Raises:
            NotImplementedError: always, when reached through ``super()``.
        """
        raise NotImplementedError()

    @abstractmethod
    def df(self, x, y=None) -> list:
        """Evaluate the gradient of the function at ``(x, y)``.

        Raises:
            NotImplementedError: always, when reached through ``super()``.
        """
        raise NotImplementedError()

In [None]:
class Add(Operator):
    """Addition: ``f(x, y) = x + y``."""

    def f(self, x: float, y: float) -> float:
        """Return the sum ``x + y``."""
        return x + y

    def df(self, x: float, y: float) -> list[float]:
        """Return ``[df/dx, df/dy]``, which is ``[1, 1]`` for every input."""
        return [1, 1]


class Sub(Operator):
    """Subtraction: ``f(x, y) = x - y``."""

    def f(self, x: float, y: float) -> float:
        """Return the difference ``x - y``."""
        return x - y

    def df(self, x: float, y: float) -> list[float]:
        """Return ``[df/dx, df/dy]``, which is ``[1, -1]`` for every input."""
        return [1, -1]


class Mul(Operator):
    """Multiplication: ``f(x, y) = x * y``."""

    def f(self, x: float, y: float) -> float:
        """Return the product ``x * y``."""
        return x * y

    def df(self, x: float, y: float) -> list[float]:
        """Return ``[df/dx, df/dy] = [y, x]``."""
        return [y, x]


class Div(Operator):
    """Division: ``f(x, y) = x / y``."""

    def f(self, x: float, y: float) -> float:
        """Return the quotient ``x / y``.

        Raises:
            ZeroDivisionError: if ``y`` is zero.
        """
        return x / y

    def df(self, x: float, y: float) -> list[float]:
        """Return ``[df/dx, df/dy] = [1 / y, -x / y**2]``.

        Raises:
            ZeroDivisionError: if ``y`` is zero.
        """
        return [1 / y, -x / (y**2)]


class Pow(Operator):
    """Power: ``f(x, y) = x ** y``."""

    def f(self, x, y):
        """Return ``x`` raised to the power ``y``."""
        return x**y

    def df(self, x, y):
        """Return the partial derivatives of ``x ** y``.

        The first entry is always ``df/dx = y * x ** (y - 1)``. The second
        entry, ``df/dy = x ** y * ln(x)``, is only included when ``x`` is
        strictly positive; for ``x <= 0`` the returned list has a single
        element because the logarithm is not available there.
        """
        d_base = y * (x ** (y - 1))
        if x <= 0:
            # No df/dy for a non-positive base: report df/dx only.
            return [d_base]
        d_exponent = (x**y) * math.log(x)
        return [d_base, d_exponent]


class Exp(Operator):
    """Exponential: ``f(x) = e ** x``. The ``y`` argument is never read."""

    def f(self, x: float, y: float = None) -> float:
        """Return ``math.exp(x)``."""
        value = math.exp(x)
        return value

    def df(self, x: float, y: float = None) -> list[float]:
        """Return the one-element gradient ``[math.exp(x)]``."""
        slope = math.exp(x)
        return [slope]


class Log(Operator):
    """Natural logarithm: ``f(x) = ln(x)``. The ``y`` argument is never read."""

    def f(self, x: float, y: float = None) -> float:
        """Return ``math.log(x)``.

        Raises:
            ValueError: if ``x`` is not strictly positive (raised by ``math.log``).
        """
        return math.log(x)

    def df(self, x: float = None, y: float = None) -> list[float]:
        """Return the one-element gradient ``[1 / x]``.

        ``y`` is accepted (and ignored) so the signature matches
        ``Operator.df`` and ``f`` above; the default on ``x`` is kept only for
        backward compatibility, a real value must be supplied.

        Raises:
            ZeroDivisionError: if ``x`` is zero.
        """
        return [1 / x]

# Computational graph

In [None]:
class Node:
    """Base vertex of the computational graph.

    Attributes are plain instance fields:
      * ``id``      - identifier passed to the constructor
      * ``value``   - initialised to ``None``
      * ``inputs``  - list of nodes this node reads from (initially empty)
      * ``outputs`` - list of nodes that read from this node (initially empty)
    """

    def __init__(self, id: str):
        self.id = id
        self.value = None
        self.inputs = []
        self.outputs = []

    def __str__(self):
        # Inputs are rendered recursively; the list's own str() quotes each one.
        rendered_inputs = [str(child) for child in self.inputs]
        return f"(value = {self.value}, id = {self.id}, in_nodes: {rendered_inputs})"

    def __repr__(self):
        # Delegate to str() so subclasses overriding __str__ affect repr() too.
        return str(self)


class SymbolNode(Node):
    """Leaf node identified by a symbol name.

    Construction is inherited unchanged from ``Node``: ``SymbolNode(id)``.
    """


class ValueNode(Node):
    """Leaf node whose ``value`` is fixed at construction time."""

    def __init__(self, id: str, value):
        # Let the base class set up id/inputs/outputs, then pin the value.
        Node.__init__(self, id)
        self.value = value


class OperandNode(Node):
    """Interior node: stores an ``operand`` key and the list of its input nodes."""

    def __init__(self, id: str, operand, inputs: list):
        Node.__init__(self, id)
        self.inputs = inputs
        self.operand = operand

    def __str__(self):
        # Reuse the base rendering, replacing its opening "(" with the operand prefix.
        base_text = super().__str__()
        return f"(operand = {self.operand}, {base_text[1:]}"


class Graph:
    """Computational graph built from a nested infix expression.

    The expression is either a bare symbol/constant or a list:
      * ``[operand, arg]``        - unary operation
      * ``[lhs, operand, rhs]``   - binary operation
    where ``arg``, ``lhs`` and ``rhs`` are themselves expressions.

    Every parameter and variable name maps to exactly one ``SymbolNode`` that
    is shared by all of its occurrences; every other leaf becomes a fresh
    ``ValueNode`` and every operation a fresh ``OperandNode``.
    """

    def build_graph(self, infix: list, params: set, variables: set):
        """Recursively convert ``infix`` into nodes and return the root node.

        Args:
            infix: expression in the nested form described on the class.
            params: names to resolve through ``self.param_nodes``.
            variables: names to resolve through ``self.var_nodes``.

        Returns:
            The node representing ``infix``.

        Raises:
            ValueError: if a list expression has neither 2 nor 3 elements.
        """
        if not isinstance(infix, list):
            symbol = infix
            if symbol in params:
                return self.param_nodes[symbol]
            if symbol in variables:
                return self.var_nodes[symbol]
            # Anything that is not a known symbol is treated as a constant.
            return ValueNode(next(self.id_generator), symbol)

        if len(infix) == 2:
            operand, children = infix[0], [infix[1]]
        elif len(infix) == 3:
            operand, children = infix[1], [infix[0], infix[2]]
        else:
            raise ValueError(
                f"Malformed expression (expected 2 or 3 elements): {infix!r}"
            )

        # Reserve this node's id BEFORE recursing so ids follow pre-order.
        node_id = next(self.id_generator)
        inputs = [self.build_graph(child, params, variables) for child in children]
        node = OperandNode(node_id, operand, inputs)
        # Register the back-edges for unary and binary operations alike.
        for in_node in inputs:
            in_node.outputs.append(node)
        return node

    def __init__(self, infix: list, params: list, variables: list):
        """Create the symbol nodes, then build the graph for ``infix``.

        Args:
            infix: expression in the nested form described on the class.
            params: names of the parameters; stored as ``self.params``.
            variables: names of the input variables.
        """
        # Create the nodes for the params and the variables
        self.params = params
        self.param_nodes = {}
        self.var_nodes = {}
        for param in params:
            self.param_nodes[param] = SymbolNode(param)
        for var in variables:
            self.var_nodes[var] = SymbolNode(var)
        # Ids for non-symbol nodes: "0", "1", "2", ...
        self.id_generator = map(str, itertools.count(0))
        self.out_node = self.build_graph(infix, set(params), set(variables))

    def __str__(self):
        return str(self.out_node)

    def __repr__(self):
        return self.__str__()

## Test Computational graph

In [None]:
# Regression check: the string rendering of each parsed graph must match the
# expected text exactly (node ids, operands and nesting included).
expressions = ["exp(y) - (x * 2)", "log(x) / x + y", "1 + 1"]
# One expected rendering per entry in `expressions`, in the same order.
expected_graphs = [
    """(operand = -, value = None, id = 0, in_nodes: ["(operand = exp, value = None, id = 1, in_nodes: ['(value = None, id = y, in_nodes: [])'])", "(operand = *, value = None, id = 2, in_nodes: ['(value = None, id = x, in_nodes: [])', '(value = 2, id = 3, in_nodes: [])'])"])""",
    r"""(operand = +, value = None, id = 0, in_nodes: ['(operand = /, value = None, id = 1, in_nodes: ["(operand = log, value = None, id = 2, in_nodes: [\'(value = None, id = x, in_nodes: [])\'])", \'(value = None, id = x, in_nodes: [])\'])', '(value = None, id = y, in_nodes: [])'])""",
    """(operand = +, value = None, id = 0, in_nodes: ['(value = 1, id = 1, in_nodes: [])', '(value = 1, id = 2, in_nodes: [])'])""",
]
# Every expression is built with "y" as a parameter and "x" as a variable.
for expression, expected_g in zip(expressions, expected_graphs):
    infix = parse_expression(expression)
    g = Graph(infix, params=["y"], variables=["x"])
    assert str(g) == expected_g

In [None]:
# Build the graph for a constant-only expression and print its string form.
g = Graph(parse_expression("1 + 1"), params=["y"], variables=["x"])
print(g)

(operand = +, value = None, id = 0, in_nodes: ['(value = 1, id = 1, in_nodes: [])', '(value = 1, id = 2, in_nodes: [])'])


# Evaluate function and gradient

In [None]:
class Evaluator:
    """Evaluates a ``Graph`` and back-propagates gradients through it.

    Typical cycle: ``forward(var_values)`` -> ``backward()`` ->
    ``optimize_step(lr)``.

    Attributes (plain instance fields):
      * ``graph``        - the graph being evaluated
      * ``map_operands`` - operand key (e.g. ``"+"``, ``"exp"``) -> operator
      * ``param_values`` - current value of every parameter, keyed by name
      * ``derivatives``  - d(out)/d(node) keyed by node id; gradients of leaf
        nodes accumulate across ``backward()`` calls until ``optimize_step``
        clears them
    """

    # Starting value for every parameter not listed in ``init_params``.
    DEFAULT_PARAM_VALUE = 20.0

    def __init__(self, graph: Graph, init_params: dict = None):
        """Set up the operator table and the initial parameter values.

        Args:
            graph: graph to evaluate.
            init_params: optional ``{param_name: initial_value}``; parameters
                missing from it start at ``DEFAULT_PARAM_VALUE``.
        """
        self.graph = graph
        self.map_operands = {
            "+": Add(),
            "-": Sub(),
            "*": Mul(),
            "/": Div(),
            "^": Pow(),
            "exp": Exp(),
            "log": Log(),
        }
        if init_params is None:
            init_params = {}
        self.param_values = {
            param_id: init_params.get(param_id, self.DEFAULT_PARAM_VALUE)
            for param_id in graph.param_nodes
        }
        self.derivatives = defaultdict(int)

    def resolve_forward(self, node: Node, var_values: dict):
        """Compute, store and return the value of ``node``.

        Lookup order: ``var_values`` by node id, then ``self.param_values`` by
        node id, then constants, then operator application on the recursively
        evaluated inputs.

        Raises:
            ValueError: if ``node`` is neither bound to a value, a constant,
                nor an operation (e.g. a variable missing from ``var_values``).
        """
        if node.id in var_values:
            node.value = var_values[node.id]
        elif node.id in self.param_values:
            node.value = self.param_values[node.id]
        elif isinstance(node, ValueNode):
            pass  # constants already carry their value
        elif isinstance(node, OperandNode):
            input_values = [self.resolve_forward(n, var_values) for n in node.inputs]
            node.value = self.map_operands[node.operand].f(*input_values)
        else:
            raise ValueError(f"No value provided for symbol {node.id!r}")
        return node.value

    def forward(self, var_values: dict):
        """Evaluate the whole graph for ``var_values`` and return the output value."""
        return self.resolve_forward(self.graph.out_node, var_values)

    def resolve_backward(self):
        """Propagate d(out)/d(node) from the output node down to the leaves.

        Uses the node values stored by the last forward pass. The gradients
        of this pass are computed in a scratch table first, so values left in
        ``self.derivatives`` by an earlier ``backward()`` call are never fed
        back into the propagation. Afterwards operation nodes (and the output
        node) are overwritten with this pass's gradient, while leaf nodes have
        it added to whatever they already hold.
        """
        out_node = self.graph.out_node
        pass_grads = defaultdict(int)
        pass_grads[out_node.id] = 1  # d(out)/d(out)

        queue = Queue()
        queue.put(out_node)
        visited = set()
        while not queue.empty():
            node = queue.get()
            # Leaves have nothing to propagate; each operation is expanded once.
            if node.id in visited or not isinstance(node, OperandNode):
                continue
            visited.add(node.id)

            input_values = [in_node.value for in_node in node.inputs]
            gradient = self.map_operands[node.operand].df(*input_values)
            # zip() stops at the shorter sequence, so an input for which the
            # operator returned no partial derivative is simply skipped.
            for in_node, partial_derivative in zip(node.inputs, gradient):
                pass_grads[in_node.id] += partial_derivative * pass_grads[node.id]
                queue.put(in_node)

        for node_id, grad in pass_grads.items():
            if node_id in visited or node_id == out_node.id:
                self.derivatives[node_id] = grad
            else:
                self.derivatives[node_id] += grad

    def backward(self):
        """Run one backward pass; results are available in ``self.derivatives``."""
        self.resolve_backward()

    def optimize_step(self, lr=0.1):
        """Apply one gradient-descent update to every parameter, then clear all gradients.

        Args:
            lr: learning rate multiplying each parameter's gradient.
        """
        for param in self.graph.params:
            self.param_values[param] -= lr * self.derivatives[param]
        self.derivatives = defaultdict(int)

In [None]:
# Objective (x - 50)^2 with "x" as the only parameter and no input variables.
g = Graph(parse_expression("(x-50) ^ 2"), params=["x"], variables=[])

In [None]:
# Evaluator for the objective graph; no initial parameter values are passed.
e = Evaluator(g)

In [None]:
# Forward pass with no variable bindings; the cell displays the output value.
e.forward({})

900.0

In [None]:
# Backward pass over the values stored by the forward pass above.
e.backward()

In [None]:
# Inspect the derivative recorded for each node id.
e.derivatives

defaultdict(int, {'0': 1, '1': -60.0, 'x': -60.0, '2': 60.0})

In [None]:
# Objective value and parameter "x" before training.
print(e.forward({}), e.param_values["x"])
# 1000 iterations of forward -> backward -> parameter update with lr = 0.01.
for epoch in range(1000):
    e.forward({})
    e.backward()
    # Uncomment to trace the objective, parameter and gradient every epoch:
    # print(e.forward({}), e.param_values["x"], e.derivatives["x"])
    e.optimize_step(lr=0.01)
# Objective value and parameter "x" after training.
print(e.forward({}), e.param_values["x"])

900.0 20.0
2.345295211941157e-15 49.99999995157175
