**NOTE: This was an initial trial of the PolyHoot implementation; this version is deprecated.**

In [None]:
import os, sys, json
sys.path.append(os.path.abspath(".."))

from velopix.hyperParameterFramework._optimizers import BaseOptimizer
from velopix.hyperParameterFramework import TrackFollowingPipeline, GraphDFSPipeline, SearchByTripletTriePipeline

In [None]:
from typing import Any, Optional, List, Tuple, Callable
import math

# --- 1. HOO tree for n-dimensional continuous actions ---
class HOO_Node:
    """One cell (an axis-aligned hyperrectangle) of a HOO partition tree.

    Besides its geometry, a node carries the statistics used to score it:
    how often it has been sampled, the running mean of those samples, and
    its optimistic B-value (``inf`` until it has been scored).
    """

    def __init__(
        self,
        region: List[Tuple[float, float]],  # one (low, high) pair per dimension
        depth: int,
        c: float,
        alpha: float
    ):
        # Geometry of the cell and its distance from the root.
        self.region = region
        self.depth = depth
        # Coefficient and exponent of the exploration bonus for this cell.
        self.c = c
        self.alpha = alpha
        # Sample statistics; a never-sampled cell is maximally optimistic.
        self.n = 0
        self.mean = 0.0
        self.B = float('inf')
        # Children are created lazily by split().
        self.left: Optional['HOO_Node'] = None
        self.right: Optional['HOO_Node'] = None

    def split(self) -> None:
        """Bisect this cell into ``left`` and ``right`` child cells.

        The axis that gets halved rotates with depth (``depth mod n_dims``);
        every other axis is copied unchanged into both children.
        """
        axis = self.depth % len(self.region)
        low, high = self.region[axis]
        midpoint = (low + high) / 2.0

        def make_child(bounds: Tuple[float, float]) -> 'HOO_Node':
            # Copy the parent's bounds so the parent region is never mutated.
            child_region = list(self.region)
            child_region[axis] = bounds
            return HOO_Node(child_region, self.depth + 1, self.c, self.alpha)

        self.left = make_child((low, midpoint))
        self.right = make_child((midpoint, high))

    def center(self) -> Tuple[float, ...]:
        """Return the midpoint of the cell, one coordinate per dimension."""
        midpoints = [(low + high) / 2.0 for (low, high) in self.region]
        return tuple(midpoints)


class HOO_Tree:
    """HOO partition tree over a hyperrectangular continuous action space.

    ``query`` walks the tree by B-value to choose the next action, growing
    the tree lazily as it goes; ``update`` feeds a reward back along the
    path that produced that action.
    """

    def __init__(
        self,
        full_region: List[Tuple[float, float]],
        c: float,
        alpha: float
    ):
        # The root cell spans the entire action box.
        self.root = HOO_Node(full_region, depth=0, c=c, alpha=alpha)

    def query(self) -> Tuple[Tuple[float, ...], List[HOO_Node]]:
        """Pick the next action to evaluate.

        Starting at the root, repeatedly steps to the child with the larger
        B-value (the left child on ties), splitting any already-sampled cell
        that has no children yet, until a never-sampled cell is reached.

        Returns:
            ``(action, path)``: ``action`` is the centre of the unsampled
            cell and ``path`` lists every node visited, root first.
        """
        current = self.root
        path: List[HOO_Node] = [current]
        while current.n != 0:
            if current.left is None:
                current.split()
            lhs, rhs = current.left, current.right  # type: ignore
            current = lhs if lhs.B >= rhs.B else rhs
            path.append(current)
        return current.center(), path

    def update(self, path: List[HOO_Node], reward: float) -> None:
        """Fold ``reward`` into every node of ``path``.

        First bumps each node's sample count and running mean, then
        recomputes B-values deepest-first so each parent sees its children's
        fresh bounds: ``B = U`` for a leaf and
        ``B = min(U, max(B_left, B_right))`` otherwise, with
        ``U = mean + c * n ** -alpha``.
        """
        for visited in path:
            visited.n += 1
            visited.mean += (reward - visited.mean) / visited.n

        for visited in reversed(path):
            upper = visited.mean + visited.c * visited.n ** -visited.alpha
            if visited.left is None or visited.right is None:
                visited.B = upper
            else:
                visited.B = min(upper, max(visited.left.B, visited.right.B))


# --- 2. Node class with embedded HOO tree ---
class Node:
    """A state node of the MCTS search tree.

    Action selection at this state is delegated to a private ``HOO_Tree``
    built over ``action_space`` with bonus parameters ``c`` and ``alpha``.

    Args:
        state: Environment state this node represents.
        action_space: ``(low, high)`` bounds per action dimension.
        c: Exploration-bonus coefficient passed to the HOO tree.
        alpha: Exploration-bonus exponent passed to the HOO tree.
        last_action: Action vector that led from the parent to this node
            (``None`` for the root).
    """

    def __init__(
        self,
        state: Any,
        action_space: List[Tuple[float, float]],
        c: float,
        alpha: float,
        last_action: Optional[Tuple[float, ...]] = None
    ) -> None:
        # Tree links.
        self.parent: Optional['Node'] = None
        self.children: List['Node'] = []
        # Payload.
        self.state = state
        self.last_action = last_action
        # MCTS statistics.
        self.visits = 0
        self.win_value = 0.0
        self.expanded = False
        # Continuous action selection for this state.
        self.hoo = HOO_Tree(action_space, c=c, alpha=alpha)

    def add_child(self, child: 'Node') -> None:
        """Attach ``child`` below this node and set its back-link."""
        child.parent = self
        self.children.append(child)


# --- 3. POLY-HOOT MCTS implementation ---
class PolyHootMCTS:
    """Monte Carlo tree search with HOO-driven continuous action selection.

    Every ``Node`` owns a ``HOO_Tree`` over the action box. One simulation
    descends the search tree by querying those HOO trees, marks the node it
    stops at as expanded, rolls out from it, and backs the rollout reward up
    into both the MCTS statistics (``visits`` / ``win_value``) and the HOO
    statistics of every node whose HOO tree was queried on the way down.
    """

    def __init__(
        self,
        root: 'Node',
        max_depth: int,
        rollout_fn: Callable[[Any], float],
        env_step_fn: Callable[[Any, Tuple[float, ...]], Any]
    ):
        """
        Args:
            root: Search-tree root holding the initial state.
            max_depth: Maximum number of actions taken during the selection
                phase of a single simulation.
            rollout_fn: Called with a state; returns the scalar reward that is
                backed up.
            env_step_fn: Called with ``(state, action)``; returns the successor
                state. Only invoked when a new child node has to be created.
        """
        self.root = root
        self.max_depth = max_depth
        self.rollout = rollout_fn
        self.env_step = env_step_fn

    def simulate(self, n_simulations: int) -> None:
        """Run ``n_simulations`` select / expand / rollout / backup iterations."""
        for _ in range(n_simulations):
            node = self.root
            # (owner node, HOO path) for every HOO query made on the way down.
            # Keeping the owner avoids re-locating children by float distance
            # during the backup.
            hoo_trace: List[Tuple['Node', List['HOO_Node']]] = []
            depth = 0

            # --- SELECTION & EXPANSION ---
            while node.expanded and depth < self.max_depth:
                action, path = node.hoo.query()
                hoo_trace.append((node, path))

                child = self._find_child(node, action)
                if child is None:
                    # The successor state is only needed for a brand-new child;
                    # an existing child already stores its state.
                    child = Node(
                        state=self.env_step(node.state, action),
                        action_space=node.hoo.root.region,
                        c=node.hoo.root.c,
                        alpha=node.hoo.root.alpha,
                        last_action=action
                    )
                    node.add_child(child)
                node = child
                depth += 1

            node.expanded = True

            # --- SIMULATION / ROLLOUT ---
            reward = self.rollout(node.state)

            # --- BACKPROPAGATION ---
            # MCTS value backup from the reached node to the root.
            walker = node
            while walker is not None:
                walker.visits += 1
                walker.win_value += reward
                walker = walker.parent

            # HOO backup: each queried HOO tree sees the same rollout reward.
            for owner, path in hoo_trace:
                owner.hoo.update(path, reward)

    def best_action(self) -> Tuple[float, ...]:
        """Return the recommended action at the root.

        Descends the root's HOO tree greedily by sample count (ties broken
        by empirical mean, then by preferring the left child). It stops at a
        leaf or at a node none of whose children has been sampled, and
        returns the centre of that cell. B-values are deliberately not used
        here: they are optimistic and equal ``inf`` for unsampled cells, so
        following them would recommend an action that has never been tried.
        """
        node = self.root.hoo.root
        while node.left is not None and node.right is not None:
            sampled = [child for child in (node.left, node.right) if child.n > 0]
            if not sampled:
                break
            node = max(sampled, key=lambda child: (child.n, child.mean))
        return node.center()

    def _find_child(
        self,
        node: 'Node',
        action: Tuple[float, ...],
        eps: float = 1e-6
    ) -> Optional['Node']:
        """Return the first child of ``node`` whose ``last_action`` lies
        strictly within Euclidean distance ``eps`` of ``action``, else ``None``.
        """
        for child in node.children:
            if child.last_action is not None:
                if math.dist(child.last_action, action) < eps:
                    return child
        return None

# --- Usage example ---
#
# full_region = [(-1.0, 1.0) for _ in range(n_dims)]
# root = Node(initial_state, full_region, c=1.0, alpha=0.5)
# mcts = PolyHootMCTS(root, max_depth=5, rollout_fn=my_rollout, env_step_fn=my_env_step)
# mcts.simulate(1000)
# best = mcts.best_action()
# print("Best action vector:", best)
