In [1]:
from importlib import reload

import polars as pl
from plotly import graph_objects as go

import coolsearch.models as cmodel
import coolsearch.search as csearch
import coolsearch.plotting as cplt

# Re-import the local coolsearch modules so source edits are picked up without
# restarting the kernel (`%load_ext autoreload` + `%autoreload 2` would do this
# automatically). Reload dependencies before their dependents: presumably
# coolsearch.search builds on coolsearch.models (it hands back the polynomial
# model used below) — confirm. Reloading `search` first would leave it bound
# to the stale `models` objects.
reload(cmodel)
reload(csearch)
# Also reload plotting, since its template is applied immediately below.
reload(cplt)

cplt.set_plotly_template()

# Single seed shared by data generation, the train/validation split and the models.
SEED = 137

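## Classification problem

A synthetic binary classification dataset (1000 samples, 5 features) and an untuned decision-tree baseline to compare the search against.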


In [2]:
from sklearn import datasets, model_selection
from sklearn.metrics import accuracy_score, classification_report
from sklearn.tree import DecisionTreeClassifier

# Synthetic binary-classification data (1000 samples, 5 features), split into
# train / validation sets with sklearn's default 75/25 ratio; both steps are
# seeded so the notebook is reproducible.
X, y = datasets.make_classification(
    n_samples=1000,
    n_features=5,
    random_state=SEED,
)
X_train, X_val, y_train, y_val = model_selection.train_test_split(
    X, y, random_state=SEED
)

In [3]:
# Untuned baseline: a decision tree with sklearn's default hyperparameters,
# evaluated on the held-out validation split.
clf = DecisionTreeClassifier(random_state=SEED).fit(X_train, y_train)

print(classification_report(y_true=y_val, y_pred=clf.predict(X_val), digits=4))

              precision    recall  f1-score   support

           0     0.9322    0.9091    0.9205       121
           1     0.9167    0.9380    0.9272       129

    accuracy                         0.9240       250
   macro avg     0.9244    0.9235    0.9239       250
weighted avg     0.9242    0.9240    0.9240       250



## Grid search over decision-tree hyperparameters


In [4]:
# Integer search bounds for each DecisionTreeClassifier hyperparameter.
param_range = {
    "max_depth": (1, 30),
    "min_samples_leaf": (1, 50),
    "min_samples_split": (2, 8),
}
param_types = {name: "int" for name in param_range.keys()}

# Score = validation accuracy. invert=True appears to negate it (the stored
# scores below are negative), so the best point sorts first in ascending order.
search = csearch.CoolSearch.model_validate(
    DecisionTreeClassifier(random_state=SEED),
    param_range,
    param_types,
    data=(X_train, X_val, y_train, y_val),
    loss_fn=accuracy_score,
    invert=True,
)

# steps=20 per parameter; the reported 2800 points = 20 x 20 x 7, presumably
# because min_samples_split has only 7 integer values in (2, 8).
_ = search.grid_search(steps=20)
print(search)

# Best sampled point (lowest inverted score).
display(search.samples.sort(by="score")[0, :])

Searching 2800 new parameter points
Total runtime: 4.6714 s + overhead: 0.0900 s.
3 dimensional search
  - has 2800 samples


max_depth,min_samples_leaf,min_samples_split,score,runtime
i64,i64,i64,f64,f64
5,6,2,-0.932,0.00195


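### Compare to sklearn

**TODO:** this section is still empty — compare the grid-search optimum above with an sklearn equivalent (e.g. `GridSearchCV` over the same parameter grid).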


## Visualizations


In [5]:
s = search.samples

# Every (max_depth, min_samples_leaf) pair was evaluated once per
# min_samples_split value, so a raw scatter stacks several markers on the same
# spot and only the last-drawn colour is visible. Keep the best score per pair
# instead; scores are inverted accuracies (negative), so "best" is the minimum.
best_2d = s.group_by(["max_depth", "min_samples_leaf"]).agg(pl.col("score").min())

go.Figure(
    go.Scatter(
        x=best_2d["max_depth"],
        y=best_2d["min_samples_leaf"],
        mode="markers",
        marker=dict(
            color=best_2d["score"],
            colorscale="Viridis",
            showscale=True,
            colorbar=dict(title=dict(text="score")),
        ),
    ),
    layout=dict(
        title="Best score over min_samples_split",
        xaxis_title="max_depth",
        yaxis_title="min_samples_leaf",
    ),
)

### Polynomial surrogate & marginals

Fit a degree-4 polynomial to the sampled scores and plot its one-parameter marginals.


In [6]:
# Fit a degree-4 polynomial to the sampled scores (3 parameters -> 35
# polynomial features) and evaluate it on a dense grid over the search space.
# `polyval` holds the model's predictions in its "y_pred" column.
# The per-parameter marginals of this polynomial are computed by
# `polymod.poly_marginals()` in the next cell, which provides the
# "mean"/"std"/"median" columns that `marg_plot` needs.
polymod = search.model_poly(4)
polyval = polymod.predict(search.get_grid(100))

2800 samples
35 poly features
coefficients: [-8.96208608e-01 -8.06383341e-04  3.42926288e-04 -5.92740108e-05
  3.10606061e-06 -8.94894609e-03 -8.73571468e-06 -5.54399367e-06
 -5.15707122e-08  9.40809186e-04  7.70184262e-07  7.00344988e-08
 -3.84046930e-05 -6.18853776e-09  5.32607196e-07  1.35181483e-03
  2.05860464e-06  8.31056858e-06 -6.04992635e-08 -5.27444477e-05
  1.98085210e-06  1.21084748e-07  1.53758233e-06 -3.44149453e-08
  4.51456755e-09 -1.54681434e-04 -2.42829377e-06 -1.46487700e-07
  6.08957750e-07 -3.41306672e-08 -3.16145576e-08  5.54645149e-06
  4.89457815e-08  9.04037366e-09 -6.01861015e-08], residuals: [0.04057754]
3 -> 35 features:


In [7]:
# Marginals of the raw sampled scores and of the fitted polynomial,
# presumably keyed by parameter name (they are indexed that way below).
marg, margpoly = search.marginals(), polymod.poly_marginals()


3 -> 35 features:


In [8]:
def marg_plot(
    marginal: pl.DataFrame,
    mean=True,
    median=False,
    intervals=False,
):
    """Plot a one-parameter marginal distribution as a small plotly figure.

    Parameters
    ----------
    marginal : pl.DataFrame
        One parameter column (used as the x axis) plus statistic columns
        named "mean", "std" and/or "median". The parameter column is the
        first column whose name is not one of those statistics.
    mean : bool, default True
        Draw the "mean" column as a line.
    median : bool, default False
        Draw the "median" column as a line.
    intervals : bool, default False
        Shade a mean +/- 1 std band (needs "mean" and "std" columns).

    Returns
    -------
    go.Figure
        The figure, titled with the parameter name; the caller shows it.

    Raises
    ------
    ValueError
        If ``marginal`` has no column besides the statistic columns.

    Only the statistic columns required by the requested traces are read, so
    e.g. a frame with just a "mean" column works with the default arguments.
    """
    shade = "rgba(255,200,200,0.2)"
    stat_cols = {"mean", "std", "median"}

    feat = next((col for col in marginal.columns if col not in stat_cols), None)
    if feat is None:
        raise ValueError(
            f"marginal has no parameter column; columns: {marginal.columns}"
        )
    x = marginal[feat]

    fig = go.Figure(
        layout=dict(
            title=feat,
            yaxis_title="value",
            margin=dict(t=50, l=20, r=10, b=10),
            width=400,
            height=200,
        ),
    )

    # Add the band before the lines: traces are drawn in insertion order, so
    # this keeps the mean/median lines on top of the translucent fill.
    if intervals:
        mu = marginal["mean"]
        std = marginal["std"]
        # TODO: proper intervals — this is only mean +/- 1 std.
        fig.add_traces(
            [
                go.Scatter(
                    x=x,
                    y=mu - std,
                    name="lower",
                    mode="lines",
                    line_color=shade,
                    showlegend=False,
                ),
                # fill="tonexty" fills down to the trace added just before it
                # (the lower bound), so these two traces must stay adjacent.
                go.Scatter(
                    x=x,
                    y=mu + std,
                    name="mean ± 1 std",
                    mode="lines",
                    fill="tonexty",
                    line_color=shade,
                    fillcolor=shade,
                ),
            ]
        )

    if mean:
        fig.add_trace(go.Scatter(x=x, y=marginal["mean"], name="mean"))
    if median:
        fig.add_trace(go.Scatter(x=x, y=marginal["median"], name="median"))

    return fig


# One figure per searched parameter: polynomial marginal with mean, median and
# a +/- 1 std band. NOTE(review): the keys come from the empirical marginals
# (`marg`) while the plotted data is the polynomial marginal (`margpoly`);
# `marg` itself is never plotted — confirm this is intended.
for param_name in marg.keys():
    marg_plot(
        margpoly[param_name],
        mean=True,
        median=True,
        intervals=True,
    ).show()

In [23]:
import numpy as np

max_cols = 3


def r(n, cols=None):
    """Ceiling division: rows needed to place ``n`` items ``cols`` per row.

    Equals ``ceil(n / cols)`` for non-negative ``n``. Pure integer arithmetic
    is used, so it also works element-wise on integer numpy arrays.

    Parameters
    ----------
    n : int or array of int
        Number of items (non-negative).
    cols : int, optional
        Items per row. Defaults to the notebook-level ``max_cols``, looked up
        at call time so later changes to ``max_cols`` still take effect.

    Returns
    -------
    int or array of int
        Number of rows; 0 when ``n == 0``.
    """
    if cols is None:
        cols = max_cols
    return (n - 1) // cols + 1


# Sanity check: r(n) must match float ceiling division for 1..8 items.
nn = np.arange(1, 9)
print(nn)
print(r(nn))
assert (r(nn) == np.ceil(nn / max_cols)).all(), "r(n) must equal ceil(n / max_cols)"


[1 2 3 4 5 6 7 8]
[1 1 1 2 2 2 3 3]
