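# Tree search

A sampling-based minimizer: each iteration draws random points inside the current search box, files them into a grid of sub-boxes obtained by repeatedly halving every axis, and zooms into the sub-box that holds the lowest objective value.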

## Code

In [1]:
import typing
import math
import collections
import dataclasses
import abc
import itertools

# Scalar type used for coordinates and interval bounds.
T = typing.TypeVar("T", int, float)
# A closed interval ``(low, high)`` along a single axis.
Vector = typing.Tuple[T, T]
# A point: one coordinate per axis.
VectorComponent = typing.List[T]
# A box: one ``Vector`` interval per axis.
Vertex = typing.List[Vector]

def product_combination(iterable):
    """Yield the Cartesian product of the iterables contained in ``iterable``.

    Equivalent to ``itertools.product(*iterable)``: one tuple per combination,
    taking one element from each inner iterable, with the last position
    varying fastest. Any iterable of iterables is accepted (not only indexable
    sequences), and an empty ``iterable`` yields a single empty tuple (the
    product of zero factors) rather than failing with ``IndexError``.
    """
    yield from itertools.product(*iterable)


def check_axis_intersect(that: Vector, z: T):
    """Tell whether ``z`` lies in the closed interval ``that = (low, high)``.

    Both endpoints are inclusive, so a value exactly on a bound counts as inside.
    """
    low, high = that
    return low <= z <= high


def is_collision(that: Vertex, other: VectorComponent) -> bool:
    """Return True when the point ``other`` lies inside the box ``that``.

    Args:
        that: the box, one ``(low, high)`` interval per axis.
        other: the point. Only as many iterated items as ``that`` has axes are
            compared (``zip`` stops at the shorter side), so an object whose
            iteration yields extra trailing items, such as ``VectorNode``
            (coordinates, then value), can be passed as long as its ``len()``
            matches.

    Raises:
        ValueError: if ``len(that) != len(other)``.
    """
    if len(that) != len(other):
        raise ValueError("Invalid arguments, different length of values")

    # Generator rather than a list so ``all`` stops at the first axis outside the box.
    return all(check_axis_intersect(t, z) for t, z in zip(that, other))


@dataclasses.dataclass
class VectorNode:
    """A sampled point together with its objective value.

    Iteration yields the coordinates followed by the value, while ``len()``
    and indexing cover the coordinates only. ``is_collision`` relies on this
    split: it checks ``len()`` against the box and zips over the iteration,
    which ignores the trailing value.
    """

    data: VectorComponent  # coordinates (a NumPy array in tree_search_algorithm)
    value: T  # objective value at ``data``; the only field used for ordering

    def __lt__(self, other):
        # Explicit so ``min``/``sorted`` do not depend on the reflected ``__gt__``.
        return self.value < other.value

    def __gt__(self, other):
        return self.value > other.value

    def __str__(self):
        return f"{self.data}, {self.value}"

    def __iter__(self):
        return iter([*self.data, self.value])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        # Index the coordinates, consistent with ``__len__`` (the class has no ``vertex``).
        return self.data[item]

    def __setitem__(self, key, value):
        self.data[key] = value


def create_axis(_array: Vertex):
    """Halve every interval of a box and return all resulting sub-boxes.

    Each ``(low, high)`` pair is cut at its midpoint. The Cartesian product of
    the per-axis halves is returned as a list of tuples, each tuple holding one
    ``(low, high)`` pair per axis.

    Assumes ``low <= high`` on every axis. For a reversed pair ``math.fabs``
    still gives a positive half-width, so the resulting "halves" do not lie
    inside the original interval.
    """
    halves = []
    for low, high in _array:
        mid = low + math.fabs(high - low) / 2
        halves.append(((low, mid), (mid, high)))

    return list(product_combination(halves))


@dataclasses.dataclass
class Node:
    """Bucket of sampled points that fall inside one box.

    Attributes:
        vertex: the box, one ``(low, high)`` interval per axis.
        nodes: the points stored in this box (``VectorNode`` instances when
            filled by ``tree_search_algorithm``).
    """

    vertex: Vertex
    # Fresh list per instance when omitted; callers may still pass their own.
    nodes: list = dataclasses.field(default_factory=list)

    def append(self, node: VectorNode):
        """Store ``node`` in this box; no bounds check is performed here."""
        self.nodes.append(node)

    def is_collide(self, node: typing.Union[VectorNode, VectorComponent]) -> bool:
        """Return True when every coordinate of ``node`` lies inside this box.

        Raises:
            ValueError: if ``node`` has a different number of axes than the box.
        """
        return is_collision(self.vertex, node)

    def __len__(self):
        return len(self.nodes)

    def __iter__(self):
        return iter(self.nodes)


@dataclasses.dataclass
class NodeValue:
    """Pairs a box with a point found in it, ordered by that point's value.

    In ``tree_search_algorithm`` the point is the lowest-valued sample of the
    leaf box, so ``min`` over these picks the most promising box.
    """

    vertex: Vertex  # the box the point belongs to
    node: VectorNode  # the representative point of that box

    def __lt__(self, other: "NodeValue"):
        # Explicit so ``min`` does not depend on the reflected ``__gt__``.
        return self.value < other.value

    def __gt__(self, other: "NodeValue"):
        return self.value > other.value

    @property
    def value(self):
        """Objective value of the wrapped point."""
        return self.node.value


class NodeContainerInterface(abc.ABC):
    """Abstract contract for one box of the search tree."""

    @property
    @abc.abstractmethod
    def axis(self) -> Vertex:
        """The box covered by this container, one ``(low, high)`` pair per axis."""

    @property
    @abc.abstractmethod
    def children(self):
        """Child containers; empty for a leaf."""

    @property
    @abc.abstractmethod
    def node(self) -> Node:
        """Bucket of points stored in this box."""

    @property
    @abc.abstractmethod
    def is_parent(self) -> bool:
        """Whether this container has any children."""

    @abc.abstractmethod
    def insert(self, verx: VectorComponent):
        """Route a point to the leaf box that contains it."""

    @abc.abstractmethod
    def sort(self):
        """Return the stored leaf nodes."""

    @abc.abstractmethod
    def __iter__(self):
        """Iterate over the stored leaf nodes."""


class NodeContainer(NodeContainerInterface):
    """One box of the search tree, recursively halved along every axis.

    Each container owns a :class:`Node` that collects the points inserted into
    it. Only leaf containers (built with ``limit_divisions == 0``) ever store
    points; inner containers just route an insert to the child box that
    contains the point.

    Args:
        data: the box, one ``(low, high)`` interval per axis.
        limit_divisions: levels of halving still to build below this box.
            ``0`` makes a leaf; a ``d``-dimensional root ends up with
            ``2 ** (d * limit_divisions)`` leaves.
    """

    def __init__(self, data: Vertex, limit_divisions=1):
        self.__node = Node(vertex=data, nodes=[])
        self.__axis = data
        self.__children = None

        if limit_divisions > 0:
            self.__children = [
                NodeContainer(vertices, limit_divisions - 1)
                for vertices in create_axis(data)
            ]

    @property
    def axis(self):
        """Shallow copy of the box this container covers."""
        return [*self.__axis]

    @property
    def children(self) -> typing.List[NodeContainerInterface]:
        """Shallow copy of the child containers (empty list for a leaf)."""
        if self.__children:
            return [*self.__children]

        return []

    @property
    def node(self) -> Node:
        """The bucket holding points inserted into this box."""
        return self.__node

    @property
    def is_parent(self):
        """True when this container has child boxes."""
        # Same truth value as ``len(self.children) != 0`` without copying the list.
        return bool(self.__children)

    def insert(self, verx: VectorComponent):
        """Store ``verx`` in the leaf box that contains it.

        Returns:
            The leaf ``NodeContainer`` that received the point, or ``None`` when
            no leaf contains it (point outside the box, or this container has
            no children).
        """
        return self._insert_loop(verx)

    def _insert_loop(self, verx: VectorComponent):
        """Iterative descent used by :meth:`insert`.

        Siblings are tried last-to-first, so a point lying exactly on a shared
        boundary (intervals are closed) goes to the later sibling.
        """
        candidates = list(self.children)

        while candidates:
            child = candidates.pop()
            node = child.node

            if not node.is_collide(verx):
                continue

            if child.is_parent:
                # Sub-boxes cover their parent, so the point belongs to one of
                # this child's children; the remaining siblings can be dropped.
                candidates = child.children
                continue

            node.append(verx)
            return child

        # Also reached immediately when this container is a leaf (no children).
        return None

    def _insert_recursive(self, verx: VectorComponent):
        """Recursive counterpart of :meth:`_insert_loop` (not used by ``insert``).

        Siblings are tried first-to-last, so a point on a shared boundary may
        land in a different leaf than with :meth:`_insert_loop`.
        """
        for child in self.children:
            node = child.node

            if not node.is_collide(verx):
                continue

            if child.is_parent:
                # Stay on the recursive strategy all the way down.
                return child._insert_recursive(verx)

            node.append(verx)
            return child

        return None

    def sort(self):
        """Return the non-empty leaf ``Node`` objects in depth-first order.

        Despite the name, nothing is ordered by value here.
        """
        return self._get_iter_child_recursive()

    def __iter__(self):
        return iter(self._get_iter_child_recursive())

    def _get_iter_child_recursive(self):
        """Collect the non-empty leaf nodes below this container, depth-first."""

        def get_iter_child(root, nodes=None):
            if nodes is None:
                nodes = []

            for child in root.children:
                if child.is_parent:
                    get_iter_child(child, nodes)
                elif len(child.node) > 0:
                    nodes.append(child.node)

            return nodes

        return get_iter_child(self, [])

    def _get_iter_child_loop(self):
        """Iterative equivalent of :meth:`_get_iter_child_recursive` (same order)."""
        leaves = []
        # Push children reversed so they are popped (visited) first-to-last.
        stack = list(reversed(self.children))

        while stack:
            child = stack.pop()
            if child.is_parent:
                stack.extend(reversed(child.children))
            elif len(child.node) > 0:
                leaves.append(child.node)

        return leaves


class TreeNode:
    """Placeholder class; it is not referenced anywhere else in this notebook."""

    def __init__(self) -> None:
        """Create an empty instance; no state is stored."""


## Search algorithm

In [2]:
import collections
import functools
import itertools

import pandas as pd

def tree_search_algorithm(f, axis, n_it=10, n_p=5, n_limit=0, *, limit_divisions=2, rng=None):
    """Minimise ``f`` by repeatedly zooming into the most promising sub-box.

    Each iteration samples ``n_p`` points uniformly inside the current box and
    files them into a tree of sub-boxes (``limit_divisions`` levels of halving
    along every axis). It then picks the leaf box holding the lowest sample and
    uses that leaf as the search box for the next iteration.

    Args:
        f: objective, called with a 1-D NumPy array of coordinates; its return
            values must be comparable with ``<``/``>``.
        axis: initial box, one ``(low, high)`` pair per dimension.
        n_it: number of zoom iterations.
        n_p: samples drawn per iteration.
        n_limit: unused; kept for backward compatibility.
        limit_divisions: tree depth used in every iteration (previously fixed
            at 2, which remains the default); should be at least 1.
        rng: optional ``numpy.random.Generator`` or ``RandomState`` used for
            sampling, for reproducible runs. Defaults to the global
            ``numpy.random`` state.

    Returns:
        ``(all_it, all_it_route)`` as deques with one entry per iteration that
        produced samples: the best point as ``[*coords, value]``, and the leaf
        box it came from.
    """
    # Imported here because this cell does not import NumPy itself.
    import numpy as np

    sampler = np.random if rng is None else rng
    tree = NodeContainer(axis, limit_divisions)

    all_it = collections.deque()
    all_it_route = collections.deque()

    for _ in range(n_it):
        for _ in range(n_p):
            point = np.array([sampler.uniform(low, high) for low, high in tree.axis])
            tree.insert(VectorNode(point, f(point)))

        # Best sample of every non-empty leaf box.
        local_min_point = [
            NodeValue(container.vertex, min(container)) for container in tree.sort()
        ]

        if not local_min_point:
            # Nothing was sampled (e.g. n_p == 0): keep searching the same box.
            continue

        local_min_node = min(local_min_point)

        all_it.append(list(local_min_node.node))
        all_it_route.append(local_min_node.vertex)

        tree = NodeContainer(local_min_node.vertex, limit_divisions)

    return all_it, all_it_route


## Initialize and configure data

In [3]:
# Number of dimensions (recomputed from the symbol list once the objective is built).
M_COMPONENTS = 5
# Bounds applied to every dimension of the initial search box.
M_RANGE = (-3, 3)

In [4]:
import operator
import functools

def _my_func(inputs):
    return functools.reduce(operator.add, inputs, 1)


In [5]:
# Default objective; cell [7] replaces it with a lambdified SymPy expression.
M_OBJ_FUNC = _my_func

In [6]:
import sympy as sp
import numpy as np
import operator
import functools

# One SymPy symbol per dimension: x1, x2, ..., x<M_COMPONENTS>.
ALL_SYMS = syms = sp.symbols(f"x1:{M_COMPONENTS + 1}")

def create_eq0(syms):
    """Build ``x1/x2 + x3**2 + cos(x4) + exp(x5**2 + x4**2)`` from the first five symbols.

    Symbols beyond the fifth are ignored; fewer than five raises ``IndexError``.
    """
    x1, x2, x3, x4, x5 = syms[0], syms[1], syms[2], syms[3], syms[4]

    ratio = x1 / x2
    return ratio + x3**2 + sp.cos(x4) + sp.exp(x5**2 + x4**2)

def create_eq1(syms):
    """Sum all symbols, then rewrite x1 -> 1/x1**2, x2 -> x2**2, x3 -> x3**3 in turn.

    Requires at least three symbols (``IndexError`` otherwise).
    """
    x1, x2, x3 = syms[0], syms[1], syms[2]

    expr = sum(syms)
    # Applied one after another, in the same order as chained ``.subs`` calls.
    for old, new in ((x1, 1 / x1**2), (x2, x2**2), (x3, x3**3)):
        expr = expr.subs(old, new)

    return expr

def create_eq2(syms):
    """Return ``(syms[0] * syms[1] * ...) ** 2 / 100`` (an empty product is 1)."""
    product = 1
    for sym in syms:
        product = product * sym

    return product**2 / 100


In [7]:
# Build the objective symbolically, then compile it to a NumPy function that
# receives the whole coordinate vector as a single argument.
M_COMPONENTS = len(syms)
eq = create_eq0(syms)
M_OBJ_FUNC = sp.lambdify([syms], eq, modules="numpy")

eq

x1/x2 + x3**2 + exp(x4**2 + x5**2) + cos(x4)

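## Results

Note: the `create_eq0` objective contains the term $x_1/x_2$, which is unbounded below as $x_2 \to 0^-$ with $x_1 > 0$. The search therefore keeps shrinking the $x_2$ interval toward $(-\varepsilon, 0]$, and the objective value in the last column of `all_it` diverges toward $-\infty$ instead of settling at a finite minimum.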

In [8]:
# Same bounds on every axis: an (M_COMPONENTS, 2) array of (low, high) rows.
m_range = np.tile(np.array([M_RANGE]), (M_COMPONENTS, 1))
all_it, all_it_route = tree_search_algorithm(M_OBJ_FUNC, m_range, n_it=10, n_p=25)

In [9]:
# Leaf box chosen at each iteration: one column per axis, one row per iteration.
pd.DataFrame(data=all_it_route)

Unnamed: 0,0,1,2,3,4
0,"(1.5, 3)","(-1.5, 0.0)","(-3, -1.5)","(0.0, 1.5)","(-1.5, 0.0)"
1,"(1.875, 2.25)","(-0.375, 0.0)","(-2.25, -1.875)","(0.375, 0.75)","(-1.5, -1.125)"
2,"(1.875, 1.96875)","(-0.09375, 0.0)","(-1.96875, -1.875)","(0.65625, 0.75)","(-1.3125, -1.21875)"
3,"(1.921875, 1.9453125)","(-0.0234375, 0.0)","(-1.921875, -1.8984375)","(0.65625, 0.6796875)","(-1.265625, -1.2421875)"
4,"(1.927734375, 1.93359375)","(-0.005859375, 0.0)","(-1.916015625, -1.91015625)","(0.673828125, 0.6796875)","(-1.265625, -1.259765625)"
5,"(1.9306640625, 1.93212890625)","(-0.00146484375, 0.0)","(-1.91455078125, -1.9130859375)","(0.6767578125, 0.67822265625)","(-1.26123046875, -1.259765625)"
6,"(1.931396484375, 1.9317626953125)","(-0.0003662109375, 0.0)","(-1.913818359375, -1.9134521484375)","(0.6771240234375, 0.677490234375)","(-1.2601318359375, -1.259765625)"
7,"(1.931396484375, 1.931488037109375)","(-9.1552734375e-05, 0.0)","(-1.91363525390625, -1.913543701171875)","(0.677215576171875, 0.67730712890625)","(-1.2601318359375, -1.260040283203125)"
8,"(1.9314422607421875, 1.9314651489257812)","(-2.288818359375e-05, 0.0)","(-1.9135894775390625, -1.9135665893554688)","(0.6772842407226562, 0.67730712890625)","(-1.2600631713867188, -1.260040283203125)"
9,"(1.9314537048339844, 1.9314594268798828)","(-5.7220458984375e-06, 0.0)","(-1.9135723114013672, -1.9135665893554688)","(0.6772899627685547, 0.6772956848144531)","(-1.2600517272949219, -1.2600460052490234)"


In [10]:
# Best point per iteration: the coordinates followed by the objective value.
pd.DataFrame(data=all_it)

Unnamed: 0,0,1,2,3,4,5
0,1.560065,-0.1222364,-1.587477,0.621389,-0.441947,-7.640922
1,1.884419,-0.003130985,-1.918944,0.672676,-1.134492,-591.7022
2,1.89199,-0.01414261,-1.938869,0.744119,-1.258438,-120.8073
3,1.936026,-0.003586514,-1.899405,0.658781,-1.256298,-527.9284
4,1.932531,-0.0002454146,-1.915515,0.676209,-1.26485,-7862.285
5,1.931615,-0.0002370323,-1.913917,0.677649,-1.2606,-8136.966
6,1.931623,-2.905171e-05,-1.913811,0.6774,-1.260107,-66476.94
7,1.931482,-2.172364e-06,-1.91363,0.677287,-1.260058,-889103.3
8,1.931451,-5.846553e-06,-1.91357,0.677293,-1.260049,-330345.1
9,1.931459,-2.791037e-07,-1.913568,0.677295,-1.260047,-6920208.0
