In [1]:
from importlib import reload

import polars as pl
from plotly import graph_objects as go

import coolsearch.models as cmodel
import coolsearch.search as csearch
import coolsearch.plotting as cplt

# Re-import the local coolsearch modules so that edits made on disk are picked
# up without restarting the kernel. coolsearch.plotting is not reloaded here.
# NOTE(review): reload() does not rebind names that other modules have already
# imported; if coolsearch.search imports from coolsearch.models, models should
# probably be reloaded first — confirm the dependency direction.
reload(csearch)
reload(cmodel)

# Presumably installs the project's default Plotly template for the figures
# below — see coolsearch.plotting to confirm.
cplt.set_plotly_template()

# Single seed shared by the data generation, the train/validation split and
# the decision-tree models in the following cells.
SEED = 137

## Classification problem


In [2]:
from sklearn import datasets, model_selection
from sklearn.metrics import accuracy_score, classification_report
from sklearn.tree import DecisionTreeClassifier

# Synthetic classification task: 10 000 samples with 10 features each.
X, y = datasets.make_classification(
    n_samples=10_000,
    n_features=10,
    random_state=SEED,
)

# Hold out a validation set; 0.25 is scikit-learn's default fraction when
# neither test_size nor train_size is given, spelled out here for the reader.
X_train, X_val, y_train, y_val = model_selection.train_test_split(
    X,
    y,
    test_size=0.25,
    random_state=SEED,
)

In [3]:
# Baseline: a decision tree with default hyper-parameters. fit() returns the
# estimator itself, so construction and training can be chained.
clf = DecisionTreeClassifier(random_state=SEED).fit(X_train, y_train)

# Per-class precision / recall / F1 on the held-out validation set.
print(
    classification_report(
        y_val,
        clf.predict(X_val),
        digits=4,
    )
)

              precision    recall  f1-score   support

           0     0.8724    0.9072    0.8894      1228
           1     0.9068    0.8719    0.8890      1272

    accuracy                         0.8892      2500
   macro avg     0.8896    0.8895    0.8892      2500
weighted avg     0.8899    0.8892    0.8892      2500



## Grid search DT-model


In [4]:
# (low, high) range for every hyper-parameter that is searched.
param_range = {
    "max_depth": (1, 30),
    "min_samples_leaf": (1, 50),
    "min_samples_split": (2, 8),
}
# Every searched hyper-parameter is declared with the "int" type.
param_types = {name: "int" for name in param_range}

# invert=True is passed together with accuracy_score; the recorded scores are
# negative accuracies, so the smallest score is the best model.
search = csearch.CoolSearch.model_validate(
    DecisionTreeClassifier(random_state=SEED),
    param_range,
    param_types,
    data=(X_train, X_val, y_train, y_val),
    loss_fn=accuracy_score,
    invert=True,
)

print(search)

# 5 steps along each of the 3 dimensions -> 125 parameter points.
_ = search.grid_search(steps=5)

# Ascending sort puts the best (most negative) score in the first row.
search.samples.sort("score")[0, :]

3 dimensional search
  - has 0 samples
Searching 125 new parameter points
Total runtime: 5.7422 s + overhead: 0.0149 s.


max_depth,min_samples_leaf,min_samples_split,score,runtime
i64,i64,i64,f64,f64
8,50,2,-0.9208,0.045897


### Compare to sklearn


## Visualizations


In [5]:
# All parameter points evaluated so far, with their score and runtime.
s = search.samples

# Scatter of the sampled grid in the (max_depth, min_samples_leaf) plane,
# coloured by score. The colour bar, axis titles and figure title make the
# plot readable on its own; without the colour bar the colour encoding could
# not be interpreted.
# Points that differ only in min_samples_split share the same x/y position and
# are drawn on top of each other.
go.Figure(
    go.Scatter(
        x=s["max_depth"],
        y=s["min_samples_leaf"],
        mode="markers",
        marker=dict(
            color=s["score"],
            showscale=True,
            colorbar=dict(title=dict(text="score")),
        ),
    )
).update_layout(
    title="Grid-search samples coloured by score",
    xaxis_title="max_depth",
    yaxis_title="min_samples_leaf",
)

### Polynomials & marginals


In [6]:
# Fit a polynomial model (argument 4 — presumably the degree, confirm in
# coolsearch) to the samples and predict it on a grid built with
# search.get_grid(100).
polymod = search.model_poly(4)
polyval = polymod.predict(search.get_grid(100))

# Marginal of the polynomial prediction per feature: average "y_pred" over all
# grid points that share the same value of that feature, ordered by the value.
margpoly = {
    feature: (
        polyval.group_by(feature)
        .agg(mean=pl.col("y_pred").mean())
        .sort(feature)
    )
    for feature in polymod.features
}

125 samples
35 poly features
coefficients: [-8.52278920e-01 -1.71021857e-03  3.97781218e-05  6.86113797e-05
 -6.66666667e-06 -1.66789937e-02 -9.68492913e-05  1.75740474e-05
 -1.40256192e-06  1.81941237e-03 -1.78241863e-06 -4.29554528e-09
 -7.15967637e-05  6.53002292e-08  9.51889865e-07 -1.17077506e-03
  1.52228011e-04 -2.93940626e-05  2.14447995e-06 -1.88854490e-04
  2.32369036e-06  1.06125098e-07  3.08922339e-07 -3.23777160e-08
  5.20533160e-08  1.04080812e-04 -1.35117348e-06 -7.29544822e-08
  5.19117002e-06 -3.47746897e-08 -3.48609348e-08 -3.03930252e-06
  2.68338122e-08 -3.93524199e-08  2.69170269e-08], residuals: [0.00026569]
3 -> 35 features:


In [11]:
# Per-parameter marginals: one set from the search object, one from the
# fitted polynomial model.
marg = search.marginals()
margpoly = polymod.poly_marginals()

# TODO Marginal plot with confint shade
# One figure per searched parameter, overlaying both marginal curves.
for k in marg.keys():
    fig = go.Figure()
    fig.add_trace(
        go.Scatter(x=marg[k][k], y=marg[k]["mean"], name="marginal")
    )
    fig.add_trace(
        go.Scatter(x=margpoly[k][k], y=margpoly[k]["mean"], name="poly")
    )
    fig.update_layout(title=k)
    fig.show()

3 -> 35 features:
