In [None]:
from importlib import reload
from time import sleep

import numpy as np
import polars as pl
from plotly import express as px
from plotly import graph_objects as go
from plotly import io as pio

import cool_search as cool

# Fixed-size dark theme used as the default for every figure in this notebook.
# Built as a copy of the stock "plotly_dark" template so the built-in registry
# entry itself is not mutated in place.
plot_temp = go.layout.Template(pio.templates["plotly_dark"])
plot_temp.layout.width = 400
plot_temp.layout.height = 300
plot_temp.layout.autosize = False
pio.templates.default = plot_temp


## 1D

In [None]:
def f(x):
    """Slow 1D test objective with a clear minimum.

    Sleeps 0.1 s per call to mimic an expensive evaluation, then returns
    sqrt((x - 2)^2 + 1) + sin(x / 2). Only numpy operations are used, so
    `x` may be a scalar or a numpy array.
    """
    sleep(0.1)  # simulate evaluation cost
    offset = x - 2
    smooth_abs = np.sqrt(offset**2 + 1)
    return smooth_abs + np.sin(x / 2)


# Dense reference evaluation of the objective. f uses only numpy operations, so it is
# evaluated on the whole array in one call and its 0.1 s delay is paid once.
X = np.linspace(-10, 10, 300)
Y = f(X)
px.line(x=X, y=Y, labels={"x": "x", "y": "f(x)"}, title="ground truth f(x)")


In [None]:
# Re-import cool_search so edits to the module are picked up without restarting the kernel.
reload(cool)
# Search f over a single parameter "x"; (-10, 10) presumably gives (low, high) bounds.
search = cool.CoolSearch(
    f,
    {"x": (-10, 10)},
)
# Inspect the search state before grid_search has been called.
display(search.samples)
print(search)

# Preview of get_grid's output.
# NOTE(review): meaning of the argument (points per dimension vs. total) is not
# visible here — confirm against cool_search.
print("example grid:")
display(search.get_grid(5))
# NOTE(review): the positional 20 is presumably a grid size / sample budget — confirm.
search.grid_search(20)

# Samples collected so far: score against x, coloured by the "runtime" column.
px.scatter(
    search.samples,
    x="x",
    y="score",
    color="runtime",
    title="current samples",
)

## 2D

In [None]:
# Re-import cool_search so edits to the module are picked up without restarting the kernel.
reload(cool)
def g(x, y):
    """Slow 2D test objective with a clear minimum.

    Sleeps 0.1 s per call to mimic an expensive evaluation. The first term,
    sqrt((r2 - 2)^2 + 1) with r2 = (x + 1)^2 + y^2, is smallest on the circle
    r2 == 2. The second term adds 3 * sin(y^2 / 2). Only numpy operations are
    used, so `x` and `y` may be scalars or array-likes.
    """
    sleep(0.1)  # simulate evaluation cost
    r2 = (x + 1) ** 2 + y**2
    radial = np.sqrt((r2 - 2) ** 2 + 1)
    wobble = 3 * np.sin(y**2 / 2)
    return radial + wobble


# Search g over x and y, both with (-10, 10).
search = cool.CoolSearch(
    g,
    dict.fromkeys(["x", "y"], (-10, 10)),
)

# Draw 50 random points of the search space and evaluate the ground-truth objective
# on them to visualise its shape. Columns are accessed by name rather than position,
# so the plot does not depend on column order.
random_samples = search.get_random_samples(50)
gt_scores = g(random_samples["x"], random_samples["y"])

px.scatter(
    random_samples,
    x="x",
    y="y",
    color=gt_scores,
    labels={"color": "g(x, y)"},
    title=f"GT evaluated on {len(random_samples)} points",
)


In [None]:
# NOTE(review): target_runtime / etr_update_step semantics (units, update cadence)
# live in cool_search — confirm there.
search.grid_search(target_runtime=7, verbose=2, etr_update_step=10)

# All samples tied for the lowest score, and their mean location in (x, y).
min_points = search.samples.filter(pl.col("score") == pl.col("score").min())
print(f"mean of {len(min_points)} minimum points:")
display(min_points.select("x", "y").mean())

px.scatter(
    search.samples,
    x="x",
    y="y",
    color="score",
    title=f"grid search of {len(search.samples)} points",
)


In [None]:
# Marginal distributions from the current search (the 2D grid search above).
# NOTE(review): presumably a mapping keyed by parameter name — confirm against
# CoolSearch.marginal_distr.
marginals = search.marginal_distr()

marginals["y"]

In [None]:
# Re-import cool_search so edits to the module are picked up without restarting the kernel.
reload(cool)
# Same objective g, now with per-parameter bounds and a third argument declaring a
# type name for each parameter ("y" is declared "int").
search = cool.CoolSearch(
    g,
    {
        "x": (-10, 10),
        "y": (-5, 5),
    },
    {
        "x": "float",
        "y": "int",
    },
)
print(search)

# NOTE(review): check how get_grid treats the "int" parameter (presumably y is
# restricted to integer values); if the result is large, consider showing .head().
search.get_grid(50)

In [None]:
# Scratch check of numpy integer sampling: Generator.integers draws from the
# half-open interval [low, high), so this returns one of 1, 2, 3, 4 (never 5).
# Seeded so the output is reproducible on Restart & Run All.
SEED = 0
rng = np.random.default_rng(SEED)
rng.integers(1, 5)