In [None]:
from importlib import reload

import polars as pl
from plotly import graph_objects as go

import coolsearch.models as cmodel
import coolsearch.search as csearch
import coolsearch.plotting as cplt

# Re-execute the project modules so that source edits are picked up without
# restarting the kernel.
# NOTE(review): cplt is imported above but not reloaded here — confirm that
# is intentional.
reload(csearch)
reload(cmodel)

# Presumably registers/activates the project's plotly template — TODO confirm
# against coolsearch.plotting.
cplt.set_plotly_template()

# Single seed reused as `random_state` for the synthetic dataset, the
# train/validation split and every DecisionTreeClassifier below.
SEED = 137

## Classification problem


In [None]:
from sklearn import datasets, model_selection
from sklearn.metrics import accuracy_score, classification_report
from sklearn.tree import DecisionTreeClassifier

# Synthetic binary classification problem: 1000 samples with 5 features.
X, y = datasets.make_classification(
    n_samples=1000,
    n_features=5,
    random_state=SEED,
)

# Hold out a validation set (scikit-learn's default split size), seeded so the
# split is reproducible.
X_train, X_val, y_train, y_val = model_selection.train_test_split(
    X,
    y,
    random_state=SEED,
)

In [None]:
# Baseline: a decision tree with default hyperparameters, fitted on the
# training split (`fit` returns the estimator itself).
clf = DecisionTreeClassifier(random_state=SEED).fit(X_train, y_train)

# Per-class precision/recall/F1 on the validation split, 4 decimals.
val_pred = clf.predict(X_val)
print(classification_report(y_val, val_pred, digits=4))

## Grid search DT-model


In [None]:
# One (low, high) pair per tuned hyperparameter of the decision tree.
param_range = {
    "max_depth": (1, 30),
    "min_samples_leaf": (1, 50),
    "min_samples_split": (2, 8),
}
# Every tuned hyperparameter is declared with the type tag "int".
param_types = {name: "int" for name in param_range}

search = csearch.CoolSearch.model_validate(
    DecisionTreeClassifier(random_state=SEED),
    param_range,
    param_types,
    data=(X_train, X_val, y_train, y_val),
    loss_fn=accuracy_score,
    # presumably flips the metric so that higher accuracy gives a lower
    # score — TODO confirm against CoolSearch
    invert=True,
)

# The return value is not needed here; binding it to `_` keeps the cell
# output clean.
_ = search.grid_search(steps=20)
print(search)

# Show the first row of the samples after an ascending sort on "score".
samples_by_score = search.samples.sort(pl.col("score"))
display(samples_by_score[0, :])

### Compare to sklearn


## Visualizations


In [None]:
s = search.samples

# Scatter of the evaluated samples in the (max_depth, min_samples_leaf) plane,
# coloured by score. Axis titles and a colour bar are set explicitly so the
# figure can be read on its own.
go.Figure(
    data=go.Scatter(
        x=s["max_depth"],
        y=s["min_samples_leaf"],
        mode="markers",
        marker=dict(
            color=s["score"],
            showscale=True,
            colorbar=dict(title="score"),
        ),
    ),
    layout=dict(
        title="Grid search samples",
        xaxis_title="max_depth",
        yaxis_title="min_samples_leaf",
    ),
)

### Polynomials & marginals


In [None]:
# Polynomial model of the search results, evaluated on a grid obtained from
# the search object.
# NOTE(review): 4 is presumably the polynomial degree and 100 the grid
# resolution — confirm against CoolSearch.
polymod = search.model_poly(4)
polyval = polymod.predict(search.get_grid(100))

# Hand-computed marginals: for each model feature, the mean prediction per
# distinct feature value, ordered by that value.
# NOTE(review): `margpoly` is rebound by `polymod.poly_marginals()` in the
# next cell, so this result is only visible until that cell runs.
mean_pred = pl.col("y_pred").mean().alias("mean")
margpoly = {
    k: polyval.group_by(k).agg(mean_pred).sort(k) for k in polymod.features
}

In [None]:
# Marginals as provided by the search object and by the polynomial model.
# NOTE(review): this rebinds `margpoly`, discarding the per-feature frames
# built by hand in the previous cell — presumably poly_marginals() supersedes
# that manual computation; confirm and keep only one of the two.
marg = search.marginals()
margpoly = polymod.poly_marginals()


In [None]:
def marg_plot(
    marginal: pl.DataFrame,
    mean=True,
    median=False,
    intervals=False,
):
    """Plot a marginal distribution as line traces over one feature.

    The x-axis column is the first column of ``marginal`` that is not one of
    the statistic columns ``"mean"``, ``"std"`` or ``"median"``. Statistic
    columns are read only for the traces that are requested, so a frame that
    holds just the feature and ``"mean"`` can be plotted with the defaults.

    Parameters
    ----------
    marginal
        Frame with one feature column plus the statistic columns required by
        the selected traces.
    mean
        Add a "mean" line (reads column ``"mean"``).
    median
        Add a "median" line (reads column ``"median"``).
    intervals
        Add a shaded band between ``mean - std`` and ``mean + std`` (reads
        columns ``"mean"`` and ``"std"``).

    Returns
    -------
    plotly.graph_objects.Figure
        The figure; it is returned, not shown.

    Raises
    ------
    ValueError
        If ``marginal`` has no column besides the statistic columns.
    """

    shade = "rgba(255,200,200,0.2)"
    stat_cols = ("mean", "std", "median")

    # The feature is whichever column is not a known statistic.
    feat = next((col for col in marginal.columns if col not in stat_cols), None)
    if feat is None:
        raise ValueError(
            f"marginal has no feature column; columns are {marginal.columns}"
        )
    x = marginal[feat]

    fig = go.Figure(
        layout=dict(
            title=feat,
            yaxis_title="value",
            margin=dict(t=50, l=20, r=10, b=10),
            width=400,
            height=200,
        ),
    )

    if mean:
        fig.add_trace(go.Scatter(x=x, y=marginal["mean"], name="mean"))
    if median:
        fig.add_trace(go.Scatter(x=x, y=marginal["median"], name="median"))

    if intervals:
        # Only touch "std" when the band is actually requested.
        mu = marginal["mean"]
        std = marginal["std"]
        # TODO PROPER INTERVALS
        upper = mu + std
        lower = mu - std
        # Order matters: "tonexty" on the upper trace fills down to the trace
        # added immediately before it, i.e. the lower bound.
        fig.add_traces(
            [
                go.Scatter(
                    x=x,
                    y=lower,
                    name="lower",
                    mode="lines",
                    line_color=shade,
                ),
                go.Scatter(
                    x=x,
                    y=upper,
                    name="upper",
                    mode="lines",
                    fill="tonexty",
                    line_color=shade,
                    fillcolor=shade,
                ),
            ]
        )
    return fig


# One figure per key of `marg`, drawn from the polynomial-model marginals with
# mean, median and the shaded interval band.
for k in marg.keys():
    marg_plot(
        margpoly[k],
        mean=True,
        median=True,
        intervals=True,
    ).show()