In [4]:
from source.feature_selection import stepwise_feature_selection, lasso, marginal_screening
from source.pipeline import make_dataset, make_pipeline, make_pipelines
import numpy as np

# Seeded generator for the synthetic regression data drawn later in this cell.
rng = np.random.default_rng(seed=10)

# One entry per feature-selection pipeline:
# (selector factory, scalar hyper-parameter, candidate values).
# NOTE(review): the scalar is presumably the starting/default value and the list the
# tuning grid (the tuned best_candidate values printed below all come from the lists) —
# confirm against source.feature_selection.
selector_specs = [
    (stepwise_feature_selection, 3, [1, 3, 5]),
    (lasso, 0.08, [0.04, 0.08, 0.12]),
    (marginal_screening, 3, [1, 3, 5]),
]

# Each selector is constructed on its own fresh make_dataset() call, in the same order
# as the specs above.
built = []
for selector, default, candidates in selector_specs:
    X, y = make_dataset()
    built.append(make_pipeline(selector(X, y, default, candidates)))
pl1, pl2, pl3 = built

# Synthetic sparse linear model: n samples, p features; only the first three
# coefficients are non-zero (0.3), the rest are exactly zero.
n, p = 200, 10
beta = np.where(np.arange(p) < 3, 0.3, 0.0)
# Draw order matters for reproducibility: design matrix first, then the noise vector.
X = rng.normal(size=(n, p))
noise = rng.normal(size=n)
y = X @ beta + noise

def _report(pipeline):
    """Print a tuned pipeline's best candidate and MSE, then the parameters of the
    component at ``static_order[1]``, then a blank line.

    NOTE(review): static_order[1] is presumably the selector node that sits between
    "start" and "end" in the printed graph — confirm against source.pipeline.
    """
    print(pipeline.best_candidate, pipeline.best_mse)
    chosen = pipeline.components[pipeline.static_order[1]]
    print(chosen.parameters)
    print()


# NOTE(review): the single-pipeline tune() is called with `n_iter`, while the
# make_pipelines(...).tune() calls use `n_iters`; both ran, but confirm which spelling
# each API actually expects.
pl1.tune(X, y, n_iter=3, random_state=0)
_report(pl1)

multi_pls = make_pipelines(pl1, pl2, pl3)
multi_pls.tune(X, y, n_iters=3, random_state=0)
print(multi_pls.best_index)
print()

for pl in multi_pls.pipelines:
    _report(pl)

print(multi_pls)


{'stepwise_feature_selection_5': 3} 1.2088874635213522
3

2

{'stepwise_feature_selection_5': 3} 1.2088874635213522
3

{'lasso_5': 0.12} 1.1968544828111392
0.12

{'marginal_screening_5': 5} 1.1535464494537575
5

start -> stepwise_feature_selection_5
stepwise_feature_selection_5 -> end

start -> lasso_5
lasso_5 -> end

start -> marginal_screening_5
marginal_screening_5 -> end


In [5]:
from source.feature_selection import stepwise_feature_selection, lasso, marginal_screening
from source.pipeline import make_dataset, make_pipeline, make_pipelines
from source.model import option1, option2
import numpy as np

# Seed the generator so the simulated data — and therefore every result printed by this
# cell — is reproducible on Restart & Run All. tune() below is already given
# random_state=0, but an unseeded generator made the data itself change on each run.
SEED = 10
rng = np.random.default_rng(SEED)

# Synthetic sparse linear model: n samples, p features; only the first three
# coefficients carry signal (0.3), the rest are zero.
n, p = 200, 10
beta = np.zeros(p)
beta[:3] = 0.3
X = rng.normal(size=(n, p))
y = X @ beta + rng.normal(size=n)

mpls = make_pipelines(option1(), option2())

# Before tuning: per the saved output, calling the collection yields one
# (features, indices) result per pipeline.
selection_before = mpls(X, y)
print(selection_before)

mpls.tune(X, y, n_iters=1, random_state=0)
print(mpls.best_index)
for candidate_pl in mpls.pipelines:
    print(candidate_pl.best_candidate, candidate_pl.best_mse)

# After tuning: per the saved output, a single result is returned.
selection_after = mpls(X, y)
print(selection_after)

[([0, 1, 2], [14, 21, 24, 34, 45, 53, 76, 96, 115, 127, 128, 140, 160, 161, 182, 188, 199]), ([0, 1, 2], [14, 24, 34, 45, 53, 96, 110, 114, 115, 160, 182, 188])]
1
{} 1.1656999450093672
{} 1.16329872230803
([0, 1, 2], [14, 24, 34, 45, 53, 96, 110, 114, 115, 160, 182, 188])
