# Tree search

## Code

In [1]:
# -*- coding: utf-8 -*-
"""
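Tree-search helpers: axis-aligned cells, sample nodes and the bisection tree.

Originally created from the Spyder editor's temporary-script template.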
"""

import typing
import math
import collections
import dataclasses
import abc
import itertools

# Scalar type of a coordinate: either int or float.
T = typing.TypeVar("T", int, float)
# Closed range (lower, upper) along a single axis.
Vector = typing.Tuple[T, T]
# One coordinate per axis, i.e. a point.
VectorComponent = typing.List[T]
# One (lower, upper) range per axis, i.e. an axis-aligned cell.
Vertex = typing.List[Vector]

def check_axis_intersect(that: Vector, z: T):
    """Tell whether coordinate *z* lies inside the closed range *that*.

    Args:
        that: ``(lower, upper)`` bounds of one axis.
        z: Coordinate to test.

    Returns:
        The result of ``upper >= z >= lower``; both ends are inclusive.
    """
    lower, upper = that
    return upper >= z >= lower


def is_collision(that: Vertex, other: VectorComponent) -> bool:
    """Tell whether point *other* lies inside the cell *that* on every axis.

    Args:
        that: ``(lower, upper)`` bounds, one pair per axis.
        other: Point coordinates, one per axis.

    Returns:
        True when every coordinate is within its axis range (inclusive).

    Raises:
        ValueError: If ``len(that)`` and ``len(other)`` differ.
    """
    n_axes, n_coords = len(that), len(other)
    if n_axes != n_coords:
        raise ValueError("Invalid arguments, different length of values")

    # Every axis is evaluated before the verdict is computed (no short-circuit).
    per_axis = list(map(check_axis_intersect, that, other))
    return all(per_axis)


@dataclasses.dataclass
class VectorNode:
    """A sampled point together with the value measured at it.

    Ordering (``>``, ``<``, hence ``min``/``max``) compares ``value`` only.
    Iterating yields every coordinate followed by ``value``, whereas ``len()``
    and indexing cover the coordinates in ``data`` only.
    """

    # Coordinates of the sample, one per axis.  Annotated concretely because a
    # bare TypeVar has no meaning on a non-generic dataclass field.
    data: typing.List[float]
    # Value associated with ``data`` (the objective value in this notebook).
    value: float

    def __gt__(self, other):
        return self.value > other.value

    def __lt__(self, other):
        # Explicit counterpart of __gt__ so min()/sorted() do not have to rely
        # on the reflected comparison.
        return self.value < other.value

    def __str__(self):
        return f"{self.data}, {self.value}"

    def __iter__(self):
        # Coordinates first, value last.
        return iter([*self.data, self.value])

    def __len__(self):
        # Number of coordinates; ``value`` is deliberately not counted.
        return len(self.data)

    def __getitem__(self, item):
        # The coordinates live in ``data``; the class has no ``vertex`` field.
        return self.data[item]

    def __setitem__(self, key, value):
        self.data[key] = value


def calc_distance(x: T, y: T):
    """Return half of the absolute difference between *x* and *y*.

    Args:
        x: First value.
        y: Second value.

    Returns:
        ``|y - x| / 2`` as a float.
    """
    span = math.fabs(y - x)
    return span / 2


def create_axis(_array: Vertex):
    """Bisect every axis range into its lower and upper half.

    Args:
        _array: ``(lower, upper)`` bounds, one pair per axis.
            NOTE(review): assumes ``lower <= upper`` for each pair.

    Returns:
        A list with one entry per axis, each entry being
        ``((lower, middle), (middle, upper))``.
    """
    all_data = []

    # Unpack the bounds explicitly: the halves below need both of them.
    for x, y in _array:
        d = calc_distance(x, y)

        all_data.append(((x, x + d), (x + d, y)))

    return all_data


def create_axis_node(arrays):
    """Return every combination that picks one item from each of *arrays*.

    Args:
        arrays: An iterable of iterables, e.g. the per-axis pairs of halves
            returned by ``create_axis``.

    Returns:
        A lazy ``itertools.product`` iterator of tuples, one element per
        input iterable.
    """
    # The iterables must be unpacked: product(arrays) would merely wrap each
    # element of ``arrays`` in a 1-tuple instead of combining them.
    return itertools.product(*arrays)


@dataclasses.dataclass
class Node:
    """Bounds of a cell together with the items stored in that cell."""

    # (lower, upper) bounds of the cell, one pair per axis.
    vertex: Vertex
    # Items appended to this cell so far.
    nodes: list

    def __hash__(self):
        # Only the bounds take part in the hash; the stored items do not.
        return hash(tuple(self.vertex))

    def __iter__(self):
        return iter(self.nodes)

    def __len__(self):
        return len(self.nodes)

    def is_collide(self, node: VectorComponent):
        """Tell whether *node* lies inside this cell's bounds."""
        return is_collision(self.vertex, node)

    def append(self, node: VectorComponent):
        """Store *node* in this cell; no containment check is performed."""
        self.nodes.append(node)


@dataclasses.dataclass
class NodeValue:
    """Pairs a sample with the bounds of the cell it was found in.

    Only ``>`` is defined; it compares the ``value`` attributes of both
    operands.
    """

    # Bounds of the cell that holds ``node``.
    vertex: Vertex
    # The sample itself.
    node: VectorNode

    @property
    def value(self):
        """Value of the wrapped sample."""
        return self.node.value

    def __gt__(self, other: VectorNode):
        mine, theirs = self.value, other.value
        return mine > theirs


class NodeContainerInterface(abc.ABC):
    """Abstract description of a container of cells.

    The member names mirror the public surface of ``TreeNode``.
    """

    @property
    @abc.abstractmethod
    def axis(self) -> Vertex:
        """Bounds of the container, one pair per axis."""

    @property
    @abc.abstractmethod
    def children(self):
        """Child containers."""

    @property
    @abc.abstractmethod
    def node(self) -> Node:
        """The ``Node`` owned by this container."""

    @property
    @abc.abstractmethod
    def is_parent(self) -> bool:
        """Whether the container has children."""

    @abc.abstractmethod
    def insert(self, verx: VectorComponent):
        """Add *verx* to the container."""

    @abc.abstractmethod
    def sort(self):
        """Return the container's content; the order is implementation-defined."""

    @abc.abstractmethod
    def __iter__(self):
        """Iterate over the container's content."""


class TreeNode:
    """Node of a tree that repeatedly bisects an axis-aligned cell.

    Every tree node owns the bounds of one cell (``axis``), a ``Node`` that
    stores the points appended to that cell, and the child cells created so
    far.  Children are created lazily, only for the halves that actually
    receive a point.

    Args:
        axis: ``(lower, upper)`` bounds of the cell, one pair per axis.
        limit_divisions: Remaining subdivision budget.  While it is positive,
            an inserted point is forwarded to the matching child, so a root
            built with ``limit_divisions=k`` stores points ``k + 1``
            bisections below itself.
    """

    def __init__(self, axis: Vertex, limit_divisions=1):
        self.__children = {}
        self.__node = Node(vertex=axis, nodes=[])
        self.__axis = axis
        self.__limit_divisions = limit_divisions

    def __hash__(self):
        return hash(self.__node)

    @property
    def axis(self):
        """A shallow copy of the cell bounds."""
        return [*self.__axis]

    @property
    def children(self):
        """Mapping from a child's bounds (as a tuple) to the child tree node."""
        return self.__children

    @property
    def node(self) -> Node:
        """The ``Node`` holding the points stored directly in this cell."""
        return self.__node

    @property
    def is_parent(self):
        """True when at least one child cell has been created."""
        return len(self.children) != 0

    def insert(self, verx: VectorComponent):
        """Insert *verx* and return the leaf tree node that received it.

        Raises:
            ValueError: If *verx* does not lie inside the selected leaf cell,
                or if its length differs from the number of axes.
        """
        return self._insert_recursive(verx)

    def sort(self):
        """Return the ``Node`` of every non-empty leaf cell.

        Despite the name nothing is ordered: the leaves are listed depth-first
        in the order their cells were created.
        """
        return self._get_iter_child_recursive()

    def __iter__(self):
        return iter(self._get_iter_child_recursive())

    def _child_cell(self, verx: VectorComponent):
        """Return the bounds of the half of this cell, per axis, holding *verx*."""
        cell = []

        for (x, y), c in zip(self.axis, verx):
            d = calc_distance(x, y)

            # Half-open split: [x, x + d) is the lower half and [x + d, y] the
            # upper one, so a point lying exactly on the lower bound is routed
            # to a half that really contains it.
            if x <= c < x + d:
                cell.append((x, x + d))
            else:
                cell.append((x + d, y))

        return cell

    def _insert_recursive(self, verx: VectorComponent):
        cell = self._child_cell(verx)
        # Key children by their bounds, not by a hash value: two distinct
        # cells whose hashes collide must not be merged.
        key = tuple(cell)

        child = self.__children.get(key)
        if child is None:
            child = TreeNode(cell, self.__limit_divisions - 1)
            self.__children[key] = child

        if self.__limit_divisions > 0:
            return child.insert(verx)

        if not child.node.is_collide(verx):
            # Reached when the point lies outside this cell's bounds.
            raise ValueError("Vertex no collide")

        child.node.append(verx)
        return child

    def _get_iter_child_recursive(self):
        def collect(root, found):
            for child in root.children.values():
                if child.is_parent:
                    collect(child, found)
                elif len(child.node) > 0:
                    found.append(child.node)

            return found

        return collect(self, [])


In [2]:
# Demo: bisect the square [-1, 1] x [-1, 1] with the default limit_divisions.
axis = (-1, 1), (-1, 1)
tree = TreeNode(axis)

# Both points end up in the same leaf cell (see the output below).
tree.insert((.1, .1))
tree.insert((.1, .2))

# Despite its name, sort() only lists the leaf cells that hold points.
tree.sort()

[Node(vertex=[(0.0, 0.5), (0.0, 0.5)], nodes=[(0.1, 0.1), (0.1, 0.2)])]

## Initialize and config data

In [8]:
import sympy as sp
import numpy as np

# Number of symbols (dimensions) of the objective function.
N_COMPONENTS = 1000

# Symbols x1 .. x{N_COMPONENTS}.
syms = sp.symbols(f'x1:{N_COMPONENTS + 1}')

# Second name bound to the same symbols object.
ALL_SYMS = syms
# (lower, upper) pair; it has the same value as the per-axis bounds passed to
# tree_search_algorithm in the execution cell.
LIMITS_RANGE = -3, 3

def get_n_syms(_range=(-1, 1)):
    """Return a list that repeats *_range* once per entry of ``syms``.

    Args:
        _range: Value to repeat; the same object is reused for every entry.

    Returns:
        A list of ``len(syms)`` references to *_range*.
    """
    return [_range] * len(syms)

In [9]:
import operator
import functools

# First three symbols, used by the alternative objectives below.
x1, x2, x3 = syms[:3]

# Objective expression: the plain sum of all symbols.
eq = sum(syms)
# Alternative objectives kept for experimentation:
# eq = eq.subs(x1, 1/x1**2).subs(x2, x2**2).subs(x3, x3**3)
# eq = functools.reduce(operator.mul, syms, 1)**2 / 100

# Callable form of ``eq`` taking one positional argument per symbol.
OBJ_FUNC = sp.lambdify(syms, eq, modules='numpy')

# Display the expression as the cell output.
eq

x1 + x10 + x100 + x1000 + x101 + x102 + x103 + x104 + x105 + x106 + x107 + x108 + x109 + x11 + x110 + x111 + x112 + x113 + x114 + x115 + x116 + x117 + x118 + x119 + x12 + x120 + x121 + x122 + x123 + x124 + x125 + x126 + x127 + x128 + x129 + x13 + x130 + x131 + x132 + x133 + x134 + x135 + x136 + x137 + x138 + x139 + x14 + x140 + x141 + x142 + x143 + x144 + x145 + x146 + x147 + x148 + x149 + x15 + x150 + x151 + x152 + x153 + x154 + x155 + x156 + x157 + x158 + x159 + x16 + x160 + x161 + x162 + x163 + x164 + x165 + x166 + x167 + x168 + x169 + x17 + x170 + x171 + x172 + x173 + x174 + x175 + x176 + x177 + x178 + x179 + x18 + x180 + x181 + x182 + x183 + x184 + x185 + x186 + x187 + x188 + x189 + x19 + x190 + x191 + x192 + x193 + x194 + x195 + x196 + x197 + x198 + x199 + x2 + x20 + x200 + x201 + x202 + x203 + x204 + x205 + x206 + x207 + x208 + x209 + x21 + x210 + x211 + x212 + x213 + x214 + x215 + x216 + x217 + x218 + x219 + x22 + x220 + x221 + x222 + x223 + x224 + x225 + x226 + x227 + x228 + x

## Execution

In [None]:
import collections
import functools
import itertools

import pandas as pd

def tree_search_algorithm(f, axis, n_it=10, n_p=5, n_limit=0):
    """Search for a minimum of *f* by random sampling inside shrinking cells.

    Each iteration draws ``n_p`` uniform random points inside the current
    cell, stores them in a ``TreeNode``, takes the best (lowest ``f``) sample
    among the non-empty leaf cells and restarts from the leaf cell that
    holds it.

    Args:
        f: Objective called as ``f(*coordinates)``.
        axis: ``(lower, upper)`` bounds of the initial cell, one per axis.
        n_it: Number of iterations.
        n_p: Number of random samples drawn per iteration.
        n_limit: Currently unused; kept so existing callers keep working.

    Returns:
        A tuple ``(all_it, all_it_route)`` of deques with one entry per
        iteration that produced samples: ``all_it`` holds the best sample as
        ``[*coordinates, value]`` and ``all_it_route`` the bounds of the cell
        it was found in.
    """
    # The first tree uses TreeNode's default subdivision budget.
    tree = TreeNode(axis)

    all_it = collections.deque()
    all_it_route = collections.deque()

    for _iteration in range(n_it):
        for _sample in range(n_p):
            data = [np.random.uniform(*bounds) for bounds in tree.axis]
            value = f(*data)

            tree.insert(VectorNode(data, value))

        # Best sample of every non-empty leaf cell, tagged with its cell.
        local_min_point = [
            NodeValue(container.vertex, min(container))
            for container in tree.sort()
        ]

        if not local_min_point:
            continue

        local_min_node = min(local_min_point)

        all_it.append(list(local_min_node.node))
        all_it_route.append(local_min_node.vertex)

        # Restart inside the winning cell; earlier samples are not carried over.
        tree = TreeNode(local_min_node.vertex, 2)

    return all_it, all_it_route


# Run 25 iterations of 25 samples each, starting from LIMITS_RANGE on every
# axis (one axis per symbol).
all_it, all_it_route = tree_search_algorithm(
    OBJ_FUNC, get_n_syms(LIMITS_RANGE), n_it=25, n_p=25
)


In [None]:
# One row per iteration: the bounds of the cell that held the best sample.
pd.DataFrame(all_it_route)

In [7]:
# One row per iteration: coordinates of the best sample, objective value last.
pd.DataFrame(all_it)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,291,292,293,294,295,296,297,298,299,300
0,-1.414307,2.631321,-2.951846,2.827119,0.075694,-2.765889,-0.95136,1.136446,2.375882,2.841543,...,2.064789,0.14775,-0.969467,-1.2554,0.121373,-1.393173,-1.975903,-2.042187,-2.053925,-24.457547
1,-1.114223,1.924698,-2.666278,2.039895,1.428463,-1.846359,-1.045324,1.361377,1.81562,2.567393,...,1.554758,0.372002,-0.804109,-0.947499,0.49923,-0.15437,-2.18298,-1.549952,-2.07464,-20.733058
2,-1.107772,1.892494,-2.745222,2.060148,1.474537,-1.853,-1.084973,1.485995,1.701676,2.448376,...,1.576371,0.266072,-0.869732,-1.093814,0.549501,-0.009528,-2.1769,-1.580525,-2.078639,-21.684722
3,-1.123478,1.886647,-2.750151,2.058543,1.459924,-1.861435,-1.084539,1.491671,1.689957,2.450226,...,1.575375,0.259438,-0.884439,-1.089373,0.556843,-0.013116,-2.166661,-1.593198,-2.074341,-21.87573
4,-1.123524,1.883819,-2.750877,2.057095,1.459479,-1.860962,-1.086881,1.49263,1.689748,2.45027,...,1.575016,0.258697,-0.883567,-1.087574,0.557769,-0.013388,-2.167854,-1.592083,-2.076689,-21.905779
