In [None]:
from abc import ABC, abstractmethod
import math
from queue import Queue
from collections import defaultdict
import itertools

from utils.parser import parse_expression

# Operations

In [None]:
class Operator(ABC):
    """
    Interface for the operators evaluated at the nodes of the computational graph.

    Binary operators use both ``x`` and ``y``; unary operators only use ``x``
    (their ``y`` defaults to ``None``).
    """

    @abstractmethod
    def f(self, x, y=None) -> float:
        """
        Evaluate the function at (x, y).

        Returns:
            The function value.
        """
        raise NotImplementedError()

    @abstractmethod
    def df(self, x, y=None) -> list:
        """
        Evaluate the gradient of the function at (x, y).

        Returns:
            Partial derivatives with respect to the inputs, in input order,
            i.e. ``[df/dx]`` or ``[df/dx, df/dy]``.
        """
        raise NotImplementedError()

In [None]:
class Add(Operator):
    """Binary addition: f(x, y) = x + y."""

    def f(self, x: float, y: float) -> float:
        total = x + y
        return total

    def df(self, x: float, y: float) -> list[float, float]:
        # Both partial derivatives of x + y are the constant 1.
        return [1] * 2


class Sub(Operator):
    """Binary subtraction: f(x, y) = x - y."""

    def f(self, x: float, y: float) -> float:
        return x - y

    def df(self, x: float, y: float) -> list[float, float]:
        # d(x - y)/dx = 1, d(x - y)/dy = -1 (independent of the inputs).
        return [1, -1]


class Mul(Operator):
    """Binary multiplication: f(x, y) = x * y."""

    def f(self, x: float, y: float) -> float:
        return x * y

    def df(self, x: float, y: float) -> list[float, float]:
        # d(xy)/dx = y, d(xy)/dy = x.
        return [y, x]


class Div(Operator):
    """Binary division: f(x, y) = x / y. Both methods raise ZeroDivisionError when y == 0."""

    def f(self, x: float, y: float) -> float:
        return x / y

    def df(self, x: float, y: float) -> list[float, float]:
        # d(x/y)/dx = 1/y, d(x/y)/dy = -x / y^2.
        return [1 / y, -x / (y**2)]


class Pow(Operator):
    """Power operator: f(x, y) = x ** y."""

    def f(self, x: float, y: float):
        # Note: for x < 0 and non-integer y, Python's ** returns a complex number.
        return x**y

    def df(self, x: float, y: float) -> list:
        """
        Partial derivatives of x ** y.

        Returns:
            ``[df/dx, df/dy]`` when x > 0, otherwise only ``[df/dx]``:
            df/dy = x**y * ln(x) needs ln(x), which is undefined for every
            x <= 0 (zero included, not only negative values).
        """
        # d/dx x**0 is 0; special-casing y == 0 avoids evaluating 0 ** -1,
        # which would raise ZeroDivisionError at x == 0.
        d_dx = 0.0 if y == 0 else y * (x ** (y - 1))
        if x <= 0:
            return [d_dx]
        return [d_dx, (x**y) * math.log(x)]


class Exp(Operator):
    """Unary exponential: f(x) = e**x. ``y`` only exists to match the Operator interface."""

    def f(self, x: float, y: float = None) -> float:
        return math.exp(x)

    def df(self, x: float, y: float = None) -> list[float]:
        # The exponential is its own derivative.
        return [math.exp(x)]


class Log(Operator):
    """Unary natural logarithm: f(x) = ln(x). ``y`` only exists to match the Operator interface."""

    def f(self, x: float, y: float = None) -> float:
        # math.log raises ValueError for x <= 0.
        return math.log(x)

    def df(self, x: float = None, y: float = None) -> list[float]:
        # d ln(x)/dx = 1/x. The `y` parameter gives df the same signature as the
        # Operator interface and the other operators, so callers can pass it uniformly.
        return [1 / x]

# Computational graph

In [None]:
class Node:
    """
    Base node of the computational graph.

    Attributes set here:
        id: node identifier (passed in by the caller).
        value: the node's value; starts as None.
        inputs: nodes this node consumes; starts empty.
        outputs: nodes consuming this one; starts empty and is not filled in by this class.
    """

    def __init__(self, id: str):
        self.id = id
        self.value = None
        self.inputs, self.outputs = [], []

    def __str__(self):
        # Children are rendered recursively through their own __str__.
        rendered_inputs = [str(child) for child in self.inputs]
        return f"(value = {self.value}, id = {self.id}, in_nodes: {rendered_inputs})"

    def __repr__(self):
        return str(self)


class SymbolNode(Node):
    """Leaf node standing for a named parameter or variable; its value starts as None."""

    def __init__(self, id: str):
        super(SymbolNode, self).__init__(id)


class ValueNode(Node):
    """Leaf node created with a fixed value (e.g. a numeric literal of the expression)."""

    def __init__(self, id: str, value):
        super(ValueNode, self).__init__(id)
        # Overrides the None set by Node.__init__.
        self.value = value


class OperandNode(Node):
    """Inner node applying the operator token ``operand`` to the nodes in ``inputs``."""

    def __init__(self, id: str, operand, inputs: list):
        super(OperandNode, self).__init__(id)
        self.operand, self.inputs = operand, inputs

    def __str__(self):
        # Reuse Node's rendering minus its leading "(" so the operand is printed first.
        base_text = super().__str__()
        return f"(operand = {self.operand}, {base_text[1:]}"


class Graph:
    """
    Computational graph built from a parsed (infix) expression.

    One SymbolNode is created per parameter name and per variable name, so
    build_graph can presumably reuse the same node object for every occurrence
    of a name.
    """

    def __init__(self, infix: list, params: list, variables: list):
        self.params = params
        # Name -> SymbolNode, one entry per parameter / variable.
        self.param_nodes = {name: SymbolNode(name) for name in params}
        self.var_nodes = {name: SymbolNode(name) for name in variables}
        # Endless supply of string ids "0", "1", "2", ... for newly created nodes.
        self.id_generator = (str(counter) for counter in itertools.count(0))
        self.out_node = self.build_graph(infix, set(params), set(variables))

    def build_graph(self, infix: list, params: set, variables: set):
        """
        Build the graph for ``infix`` and return its output (root) node.

        TODO: not implemented yet. It currently returns None, so ``out_node`` is None.
        """
        ...

    def __str__(self):
        return f"{self.out_node}"

    def __repr__(self):
        return str(self)

## Test the computational graph

In [None]:
expressions = ["exp(y) - (x * 2)", "log(x) / x + y", "1 + 1"]
expected_graphs = [
    """(operand = -, value = None, id = 0, in_nodes: ["(operand = exp, value = None, id = 1, in_nodes: ['(value = None, id = y, in_nodes: [])'])", "(operand = *, value = None, id = 2, in_nodes: ['(value = None, id = x, in_nodes: [])', '(value = 2, id = 3, in_nodes: [])'])"])""",
    r"""(operand = +, value = None, id = 0, in_nodes: ['(operand = /, value = None, id = 1, in_nodes: ["(operand = log, value = None, id = 2, in_nodes: [\'(value = None, id = x, in_nodes: [])\'])", \'(value = None, id = x, in_nodes: [])\'])', '(value = None, id = y, in_nodes: [])'])""",
    """(operand = +, value = None, id = 0, in_nodes: ['(value = 1, id = 1, in_nodes: [])', '(value = 1, id = 2, in_nodes: [])'])""",
]
# zip() would silently skip cases if the two lists ever drift apart in length.
assert len(expressions) == len(expected_graphs), "each expression needs exactly one expected graph"
for expression, expected_g in zip(expressions, expected_graphs):
    infix = parse_expression(expression)
    # y is declared as a parameter and x as a variable for every test expression.
    g = Graph(infix, params=["y"], variables=["x"])
    assert str(g) == expected_g, (
        f"graph mismatch for {expression!r}:\n  got:      {g}\n  expected: {expected_g}"
    )

In [None]:
# Show the graph of a constant-only expression.
g = Graph(
    parse_expression("1 + 1"),
    params=["y"],
    variables=["x"],
)
print(str(g))

# Evaluate the function and its gradient

In [None]:
class Model:
    """
    Evaluates a computational ``Graph`` and holds the values of its parameters.

    The forward, backward and optimisation steps are still placeholders (TODO).
    """

    # Initial value for every parameter that ``init_params`` does not set.
    DEFAULT_PARAM_VALUE = 20.0

    def __init__(self, graph: Graph, init_params: dict = None):
        """
        Args:
            graph: graph to evaluate; the keys of ``graph.param_nodes`` are the
                parameter names.
            init_params: optional ``{parameter name: initial value}`` mapping.
                Parameters it does not mention start at DEFAULT_PARAM_VALUE.

        Raises:
            ValueError: if ``init_params`` names a parameter the graph does not have.
        """
        self.graph = graph
        # Operator token -> Operator implementation.
        self.map_operands = {
            "+": Add(),
            "-": Sub(),
            "*": Mul(),
            "/": Div(),
            "^": Pow(),
            "exp": Exp(),
            "log": Log(),
        }
        self.param_values = {param_id: self.DEFAULT_PARAM_VALUE for param_id in graph.param_nodes}
        if init_params:
            unknown = set(init_params) - set(self.param_values)
            if unknown:
                raise ValueError(
                    f"init_params contains unknown parameters: {sorted(unknown, key=str)}"
                )
            self.param_values.update(init_params)
        # Accumulated derivatives; missing keys read as 0.
        self.derivatives = defaultdict(int)

    def resolve_forward(self, node: Node, var_values: dict):
        """
        Compute the value of ``node``.

        TODO: not implemented. It currently returns None. It presumably should
        recurse through ``node.inputs``, read parameters from ``self.param_values``
        and variables from ``var_values`` (assumed ``{variable name: value}``,
        to be confirmed), and apply ``self.map_operands[node.operand]`` at operand nodes.
        """
        ...

    def forward(self, var_values: dict):
        """Evaluate the graph's output node for the given variable values."""
        return self.resolve_forward(self.graph.out_node, var_values)

    def resolve_backward(self):
        """
        Back-propagate through the graph.

        TODO: not implemented. It currently does nothing. It is presumably meant
        to fill ``self.derivatives``.
        """
        ...

    def backward(self):
        """Run the backward pass (delegates to resolve_backward)."""
        self.resolve_backward()

    def optimize_step(self, lr=0.1):
        """
        Apply one update with learning rate ``lr``, then clear the accumulated derivatives.

        TODO: the parameter update itself is not implemented. Only the reset below runs.
        """
        ...
        # Reset derivatives for next iteration
        self.derivatives = defaultdict(int)

# Training with gradient descent

In [None]:
def train(num_epochs, model):
    """
    Intended to train ``model`` for ``num_epochs`` epochs.

    TODO: not implemented yet. The body is a placeholder and returns None.
    """
    return None