In [1]:
from importlib import reload
from time import sleep

import numpy as np
import polars as pl
from plotly import express as px
from plotly import io as pio

import cool_search as cool




## 1D

In [2]:
def f(x):
    """Slow 1D test objective with a clear minimum.

    Evaluates ``sqrt((x - 2)**2 + 1) + sin(x / 2)``. Every call pauses for
    0.1 s before computing, to mimic an expensive evaluation.
    """
    sleep(0.1)  # artificial cost per call
    offset = x - 2
    bowl = np.sqrt(offset ** 2 + 1)
    wave = np.sin(x / 2)
    return bowl + wave


# Sample the objective on 300 evenly spaced points in [-10, 10]. f is called
# once with the whole array, so its 0.1 s delay is paid a single time.
X = np.linspace(-10, 10, 300)
Y = f(X)
# Title and axis labels so the figure stands alone, like the other plots here.
px.line(x=X, y=Y, title="1D objective f(x)", labels={"x": "x", "y": "f(x)"})


In [3]:
# Pick up local edits to cool_search without restarting the kernel.
reload(cool)

# 1D search on the slow objective f, with (-10, 10) given for parameter x.
search = cool.CoolSearch(f, dict(x=(-10, 10)))

# Inspect the freshly created search: its sample table, then its summary.
display(search.samples)
print(search)

# Preview a 5-step grid, then run a grid search with argument 20.
print("example grid:")
display(search.get_grid(5))
search.grid_search(20)

# Score against x, coloured by how long each evaluation took.
px.scatter(
    data_frame=search.samples,
    x="x",
    y="score",
    color="runtime",
    title="current samples",
)

x,score,runtime
f32,f64,f64


1 dimensional search
  - has 0 samples
example grid:


x
f32
-10.0
-5.0
0.0
5.0
10.0


Searching 20 new parameter points
Total runtime: 2.0094 s + overhead: 0.0299 s.


## 2D

In [4]:
# Re-import cool_search so edits to the module take effect without a kernel restart.
reload(cool)
def g(x, y):
    """Slow 2D test objective with a clear minimum.

    Adds a square-root term built from the shifted squared radius
    ``(x + 1)**2 + y**2`` to a ripple ``3 * sin(y**2 / 2)``. Every call
    pauses for 0.1 s before computing, to mimic an expensive evaluation.
    """
    sleep(0.1)  # artificial cost per call
    radius_sq = (x + 1) ** 2 + y ** 2
    bowl = np.sqrt((radius_sq - 2) ** 2 + 1)
    ripple = 3 * np.sin(y ** 2 / 2)
    return bowl + ripple


# 2D search on g: both parameters are given the same (-10, 10) tuple.
search = cool.CoolSearch(g, {"x": (-10, 10), "y": (-10, 10)})

# Ask the search object for 50 random samples to visualise the ground truth.
grid = search.get_random_samples(50)

# Colour each point by g evaluated directly on the first two columns of grid.
px.scatter(
    data_frame=grid,
    x="x",
    y="y",
    color=g(grid[:, 0], grid[:, 1]),
    title=f"GT evaluated on {len(grid)} points",
)


In [5]:
# Grid search sized by a target runtime (7, presumably seconds - TODO confirm).
search.grid_search(target_runtime=7, verbose=2, etr_update_step=10)

# Keep every sample that attains the lowest score; ties are possible.
min_points = search.samples.filter(pl.col("score") == pl.col("score").min())
print(f"mean of {len(min_points)} minimum points:")
# The header above announced a mean that was never shown; display it.
display(min_points.mean())

px.scatter(
    search.samples,
    x="x",
    y="y",
    color="score",
    # size="runtime",
    title=f"grid search of {len(search.samples)} points",
)


No previous samples. Running 1 initial evaluation
choose 8 steps
  -> maximum 64 samples
Searching 63 new parameter points
Estimated runtime: 6.3194 s.
Total runtime: 6.3234 s + overhead: 0.0029 s.
mean of 2 minimum points:


In [7]:
# Fetch the marginals of the current search and show the entry for "y"
# (the recorded output lists mean/std/median of the score per y value).
marginals = search.marginals()

marginals["y"]

y,mean,std,median
f32,f64,f64,f64
-10.0,141.073802,42.369421,130.053436
-7.142857,93.989613,42.367928,82.968959
-4.285714,60.954113,42.362297,49.930669
-1.428571,46.577296,42.238597,35.450064
1.428571,46.577296,42.238597,35.450064
4.285714,60.954113,42.362297,49.930669
7.142857,93.989613,42.367928,82.968959
10.0,141.073802,42.369421,130.053436


In [None]:
# Pick up local edits to cool_search without restarting the kernel.
reload(cool)

# Same objective g, now with a third argument naming a type per parameter
# ("float" for x, "int" for y).
search = cool.CoolSearch(
    g,
    dict(x=(-10, 10), y=(-5, 5)),
    dict(x="float", y="int"),
)
print(search)

# Request a 50-step grid from the typed search.
search.get_grid(50)

In [None]:
# Unseeded generator: the drawn value differs between runs.
rng = np.random.default_rng()
# One integer from [1, 5): the upper bound is exclusive by default.
rng.integers(low=1, high=5)