In [None]:
import os 
import sys
import json
# Put the parent of the current working directory on the import path so the
# local `velopix_wrappers` package imported below can be found.
# NOTE(review): assumes the notebook is run from a subdirectory of the repo root — TODO confirm.
sys.path.append(os.path.abspath(".."))

from velopix_wrappers.parameter_optimisers import optimiserBase
from velopix_wrappers.velopix_pipeline import TrackFollowingPipeline, GraphDFSPipeline, SearchByTripletTriePipeline

In [None]:
from collections.abc import Callable, Generator, Iterable
from itertools import product
from typing import Any, TypeAlias, cast
import random

ITypes: TypeAlias = type[bool] | type[int] | type[float]
Input: TypeAlias = bool | int | float
Axis: TypeAlias = tuple[str, Input]
Point: TypeAlias = tuple[Axis, ...]
Voxel: TypeAlias = tuple[int, ...]
VoxelSamples: TypeAlias = dict[Voxel, frozenset[Point]]
SampleValues: TypeAlias = dict[Point, float]

def _point_to_dict(point: Point) -> dict[str, Any]:
    return {axis: value for axis, value in point}

class GridSearch(optimiserBase):
    """Stratified random search over a voxel grid of the algorithm's parameter space.

    Every float axis is cut into ``num_splits`` equal-width cells, every int axis
    into at most ``num_splits`` cells (fewer when its range is narrower), and
    every bool axis into its two values. The Cartesian product of those cells
    forms the voxels. ``init`` draws ``sample_size`` random points inside each
    voxel (duplicates collapse because samples are stored in a set) and returns
    the first one. ``next`` returns the following ones. Once all points have been
    handed out, ``is_finished`` becomes ``True``.
    """

    num_splits: int  # cells per numeric axis (upper bound for int axes)
    sample_size: int  # random draws per voxel

    voxels: set[Voxel]  # every voxel index tuple of the grid
    samples: VoxelSamples  # voxel -> points sampled inside it
    values: SampleValues  # reset by init(); not populated anywhere in this class yet

    test_points: Generator[dict[str, Any], None, None]  # queue of points still to emit
    last_point: dict[str, Any]  # most recently emitted parameter map
    stop_flag: bool = False  # set once test_points is exhausted

    def __init__(self, num_splits: int = 10, sample_size: int = 10, objective: str = "min"):
        """
        Args:
            num_splits: number of cells per numeric axis; must be >= 1.
            sample_size: number of random draws per voxel; must be >= 1.
            objective: forwarded to ``optimiserBase`` as ``Objective``.

        Raises:
            ValueError: if ``num_splits`` or ``sample_size`` is smaller than 1.
        """
        super().__init__(Objective=objective)
        if num_splits < 1:
            raise ValueError(f"num_splits must be >= 1, got {num_splits}")
        if sample_size < 1:
            raise ValueError(f"sample_size must be >= 1, got {sample_size}")
        self.num_splits = num_splits
        self.sample_size = sample_size

    def init(self) -> dict[str, Any]:
        """
        Build the voxel grid, draw the samples for every voxel and return the first point.

        Raises:
            NotImplementedError: if a config axis has a type other than bool/int/float.
            ValueError: if a numeric axis has no bounds or its bounds are inverted.
        """
        self.stop_flag = False
        self.voxels = set()
        self.samples = {}
        self.values = {}

        # Collected in config order, so voxel index i always refers to axis_names[i].
        axis_names: list[str] = []
        axis_indices: list[tuple[int, ...]] = []
        samplers: list[Callable[[int], Input]] = []
        for axis_name, (axis_type, _) in self._algorithm.get_config().items():  # type: ignore
            indices, sampler = self._axis_sampler(axis_name, axis_type)
            axis_names.append(axis_name)
            axis_indices.append(indices)
            samplers.append(sampler)

        ordered_voxels: list[Voxel] = []
        for voxel in product(*axis_indices):
            ordered_voxels.append(voxel)
            self.voxels.add(voxel)
            self.samples[voxel] = frozenset(
                tuple((name, sample(index)) for name, sample, index in zip(axis_names, samplers, voxel))
                for _ in range(self.sample_size)
            )

        # Emit voxels in grid (product) order rather than set iteration order.
        self.test_points = (
            _point_to_dict(point) for voxel in ordered_voxels for point in self.samples[voxel]
        )
        return self._forward()

    def next(self) -> dict[str, Any]:
        """
        Return the next sampled parameter map.

        Once every point has been emitted, this sets ``stop_flag`` and returns the
        last point again.
        """
        return self._forward()

    def objective_func(self) -> float:
        """
        Convert the results of an experiment into a numeric score.

        NOTE(review): not implemented. The body is a placeholder that returns
        None, and ``self.values`` is never filled. TODO implement.
        """
        ...

    def is_finished(self) -> bool:
        """Return True once every sampled point has been handed out by ``init``/``next``."""
        return self.stop_flag

    def _axis_sampler(self, axis_name: str, axis_type: ITypes) -> tuple[tuple[int, ...], Callable[[int], Input]]:
        """Return the cell indices of one axis and a function that draws a value inside a given cell.

        Each call creates fresh local bounds, so the returned closure cannot be
        affected by later axes (avoids the late-binding-lambda pitfall).
        """
        if axis_type is bool:
            # Cell 0 -> False, cell 1 -> True: each boolean value gets its own voxels.
            return (0, 1), bool
        if axis_type is not float and axis_type is not int:
            raise NotImplementedError(f"Unsupported type: {axis_type}")

        bounds = self._algorithm._bounds().get(axis_name)  # type: ignore
        if bounds is None:
            raise ValueError(f"No bounds defined for parameter '{axis_name}'")
        low, high = cast(tuple[Input, Input], bounds)
        if high < low:
            raise ValueError(f"Invalid bounds for parameter '{axis_name}': {low} > {high}")

        if axis_type is float:
            width = (high - low) / self.num_splits

            def sample_float(i: int) -> Input:
                return random.uniform(low + i * width, low + (i + 1) * width)

            return tuple(range(self.num_splits)), sample_float

        lo, hi = int(low), int(high)
        span = hi - lo
        # Never create more cells than there are integer steps in the range.
        cells = max(1, min(self.num_splits, span))

        def sample_int(i: int) -> Input:
            start = lo + i * span // cells
            stop = lo + (i + 1) * span // cells
            # Cells are half-open [start, stop) so neighbours share no values.
            # The last cell is closed so that `high` itself can be drawn.
            # NOTE(review): treats `high` as inclusive, like the float branch — confirm.
            return random.randint(start, stop if i == cells - 1 else stop - 1)

        return tuple(range(cells)), sample_int

    def _forward(self) -> dict[str, Any]:
        """Advance to the next queued point. Once exhausted, set ``stop_flag`` and repeat the last point."""
        # No `return` inside `finally`: errors raised while generating a point must propagate.
        point = next(self.test_points, None)
        if point is None:
            self.stop_flag = True
        else:
            self.last_point = point
        return self.last_point