In [None]:
from importlib import reload

import polars as pl
from plotly import graph_objects as go

import coolsearch.models as cmodel
import coolsearch.search as csearch
import coolsearch.plotting as cplt

# Pick up edits to the local `coolsearch` package without restarting the kernel
# (on a fresh kernel these reloads are no-ops). Order matters: a module that does
# `from X import name` keeps the *old* objects unless X is reloaded before it.
# `models` is reloaded first on the assumption that `search` imports from it —
# NOTE(review): confirm the dependency direction inside coolsearch.
# `plotting` is reloaded as well because `set_plotly_template` is called below.
reload(cmodel)
reload(csearch)
reload(cplt)

cplt.set_plotly_template()

# Single seed used for data generation, the train/val split and the models.
SEED = 137

## Classification problem


In [None]:
from sklearn import datasets, model_selection
from sklearn.metrics import accuracy_score, classification_report
from sklearn.tree import DecisionTreeClassifier

# Synthetic classification task: 10 000 samples, 10 features.
X, y = datasets.make_classification(
    n_samples=10_000,
    n_features=10,
    random_state=SEED,
)

# Hold out a quarter of the rows (sklearn's default split size) for validation.
X_train, X_val, y_train, y_val = model_selection.train_test_split(
    X,
    y,
    test_size=0.25,
    random_state=SEED,
)

In [None]:
# Untuned baseline: default hyperparameters, evaluated on the validation split.
clf = DecisionTreeClassifier(random_state=SEED).fit(X_train, y_train)
y_val_pred = clf.predict(X_val)
print(classification_report(y_val, y_val_pred, digits=4))

## Grid search over decision-tree (DT) hyperparameters


In [None]:
# Search bounds for each decision-tree hyperparameter; all are searched as ints.
param_range = dict(
    max_depth=(1, 30),
    min_samples_leaf=(1, 50),
    min_samples_split=(2, 8),
)
param_types = dict.fromkeys(param_range, "int")

search = csearch.CoolSearch.model_validate(
    DecisionTreeClassifier(random_state=SEED),
    param_range,
    param_types,
    data=(X_train, X_val, y_train, y_val),
    loss_fn=accuracy_score,
    # Accuracy is higher-is-better; presumably `invert=True` flips it so the
    # search minimises. NOTE(review): confirm the sign stored in `score`, since
    # the ascending sort below treats the lowest score as the best.
    invert=True,
)
print(search)

_ = search.grid_search(steps=5)

# Best configuration = lowest score. polars sorts nulls *first* by default, so a
# failed/missing score would otherwise be reported as the best row.
search.samples.sort("score", nulls_last=True).head(1)

### Compare to sklearn

**TODO:** this section has no content yet.


## Visualizations


In [None]:
s = search.samples

# The grid also varies `min_samples_split`, which a 2-D scatter cannot show:
# several rows share each (max_depth, min_samples_leaf) point and only the last
# one drawn would be visible, making the colour arbitrary. Average the score over
# the hidden dimension so each marker's colour is well defined.
s_2d = s.group_by("max_depth", "min_samples_leaf").agg(pl.col("score").mean())

go.Figure(
    go.Scatter(
        x=s_2d["max_depth"],
        y=s_2d["min_samples_leaf"],
        mode="markers",
        marker=dict(
            color=s_2d["score"],
            showscale=True,
            colorbar=dict(title=dict(text="mean score")),
        ),
    )
).update_layout(
    title="Grid-search score (mean over min_samples_split)",
    xaxis_title="max_depth",
    yaxis_title="min_samples_leaf",
)

### Polynomial model & marginals


In [None]:
# Degree-4 polynomial model from the search (presumably fitted to the sampled
# scores), evaluated on the grid returned by `search.get_grid(100)`.
polymod = search.model_poly(4)
polyval = polymod.predict(search.get_grid(100))

# Per-feature marginal of the polynomial prediction: mean `y_pred` at each value
# of the feature, ordered by that value.
# NOTE(review): the next cell rebinds `margpoly` to `polymod.poly_marginals()`,
# so this hand-rolled version only serves as a cross-check — keep one of the two.
margpoly = {
    k: polyval.group_by(k).agg(pl.col("y_pred").mean().alias("mean")).sort(k)
    for k in polymod.features
}

In [None]:
# One figure per hyperparameter: marginals from `search.marginals()` overlaid with
# the marginals of the polynomial model.
marg = search.marginals()
margpoly = polymod.poly_marginals()

# TODO Marginal plot with confint shade
for k, marg_k in marg.items():
    traces = [go.Scatter(x=marg_k[k], y=marg_k["mean"], name="marginal")]
    # Overlay the polynomial curve only when it has a marginal for this parameter,
    # instead of failing with a KeyError.
    if k in margpoly:
        poly_k = margpoly[k]
        traces.append(go.Scatter(x=poly_k[k], y=poly_k["mean"], name="poly"))
    go.Figure(traces).update_layout(
        title=k,
        xaxis_title=k,
        yaxis_title="mean score",
    ).show()