More arrows #307
I thought this could work with
MVP that doesn't work
Here are examples of inserting static transformers into composite models. Because of a bug I discovered while preparing this, you will need MLJBase 0.11.10 (just released). It would be awesome if you could use parts of these to improve the documentation and tutorials on static transformers (#393). Note: for now, at least, I recommend that you always explicitly do

```julia
# March 4th
using MLJ, DataFrames, Random
MLJ.color_off() # hide
@load RidgeRegressor pkg=MultivariateStats

Random.seed!(5) # for reproducibility
x1 = rand(300)
x2 = rand(300)
x3 = rand(300)
y = exp.(x1 - x2 - 2x3 + 0.1*rand(300))
X = DataFrame(x1=x1, x2=x2, x3=x3)

using MLJBase

# insert the product x1 .* x2 as a new first column p:
f(X) = insertcols!(copy(X), 1, p=X[:,:x1] .* X[:,:x2])

# Solution 1 (simplest): Just use @pipeline
comp = @pipeline Comp(f, std=Standardizer(),
                      rgs=RidgeRegressor(),
                      target=UnivariateBoxCoxTransformer())
e = evaluate(comp, X, y, measure=mae, resampling=CV())

# What if your transformer has parameters? Then you need a static
# transformer:
mutable struct MyTransformer <: Static
    ftr::Symbol
end

# Static models learn nothing, so the second argument is the (unused)
# fitresult:
MLJBase.transform(transf::MyTransformer, fitresult, X) =
    insertcols!(copy(X), 1, p=X[:,transf.ftr] .* X[:,:x2])

comp2 = @pipeline Comp2(transf=MyTransformer(:x3), std=Standardizer(),
                        rgs=RidgeRegressor(),
                        target=UnivariateBoxCoxTransformer())
comp2.transf.ftr = :x1 # change the parameter
e2 = evaluate(comp2, X, y, measure=mae, resampling=CV())
@assert e2.measurement[1] ≈ e.measurement[1]

# Solution 2: Using a learning network:
Xs = source(X)
ys = source(y, kind=:target)
ridge = RidgeRegressor()

# overload your function for nodes:
f(X::AbstractNode) = node(f, X)

W = Xs |> f |> Standardizer()
z = ys |> UnivariateBoxCoxTransformer()
zhat = (W, z) |> ridge
yhat = zhat |> inverse_transform(z)

# # or, without arrow syntax:
# X2 = f(Xs)
# W = transform(machine(Standardizer(), X2), X2)
# box_mach = machine(UnivariateBoxCoxTransformer(), ys)
# z = transform(box_mach, ys)
# ridge_mach = machine(ridge, W, z)
# zhat = predict(ridge_mach, W)
# yhat = inverse_transform(box_mach, zhat)

comp3 = @from_network Comp3(rgs=ridge) <= yhat
e3 = evaluate(comp3, X, y, measure=mae, resampling=CV())
@assert e3.measurement[1] ≈ e.measurement[1]

# Or if you need parameters for your static transformer:
inserter = MyTransformer(:x3)
W = Xs |> inserter |> Standardizer()
z = ys |> UnivariateBoxCoxTransformer()
zhat = (W, z) |> ridge
yhat = zhat |> inverse_transform(z)

# # or, without arrow syntax:
# inserter_mach = machine(inserter)
# X2 = transform(inserter_mach, Xs)
# W = transform(machine(Standardizer(), X2), X2)
# box_mach = machine(UnivariateBoxCoxTransformer(), ys)
# z = transform(box_mach, ys)
# ridge_mach = machine(ridge, W, z)
# zhat = predict(ridge_mach, W)
# yhat = inverse_transform(box_mach, zhat)

comp4 = @from_network Comp4(transf=inserter, rgs=ridge) <= yhat
comp4.transf.ftr = :x1 # change the parameter
e4 = evaluate(comp4, X, y, measure=mae, resampling=CV())
@assert e4.measurement[1] ≈ e.measurement[1]
```
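Making `MyTransformer` a `mutable struct` is what allows its hyperparameter to be changed after the composite model is built, as in `comp2.transf.ftr = :x1`. The plain-Julia sketch below illustrates just that idea without MLJ; the name `ProductInserter`, the standalone `transform` function, and the use of a `NamedTuple` of columns in place of a `DataFrame` are illustrative assumptions, not MLJ's API.

```julia
# Illustration only: `ProductInserter` and this `transform` are hypothetical,
# not part of MLJ. A NamedTuple of equal-length vectors stands in for a table.

# A "static" transformer: it holds a hyperparameter but learns nothing from data.
mutable struct ProductInserter
    ftr::Symbol   # column to multiply by :x2
end

# Insert p = X[ftr] .* X.x2 as the first column; X itself is not modified.
function transform(t::ProductInserter, X::NamedTuple)
    p = X[t.ftr] .* X.x2
    return merge((p = p,), X)
end

X = (x1 = [1.0, 2.0], x2 = [3.0, 4.0], x3 = [5.0, 6.0])
t = ProductInserter(:x3)
transform(t, X).p   # [15.0, 24.0]

t.ftr = :x1         # mutate the hyperparameter in place
transform(t, X).p   # [3.0, 8.0]
```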
This is to keep track of things that can be added to the arrow syntax.
Stick `source` where needed

In the unsupervised case this works fine. For the supervised case it'd be nice to do something similar, but that doesn't work because `y` is not recognised well; the workaround is fine but of course ugly. That should be an easy fix in the arrow syntax definition: when passed a tuple, if an element is not a node, then slap a node on it.
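One way the suggested fix could look, sketched in plain Julia with toy stand-ins: `AbstractNode`, `Source`, `Node`, `as_node` and `MeanRegressor` are hypothetical names, not MLJ's node types, and the "model" simply predicts the mean of the target. The only thing illustrated is the dispatch: when the left-hand side of `|>` is a tuple, each element that is not already a node is wrapped in a source before the model is applied.

```julia
# Hypothetical stand-ins for learning-network nodes (not MLJ's actual types).
abstract type AbstractNode end

struct Source <: AbstractNode
    data::Any
end

struct Node <: AbstractNode
    op::Function
    args::Tuple
end

# Calling a node evaluates it (and, recursively, its arguments).
(s::Source)() = s.data
(n::Node)() = n.op(map(a -> a(), n.args)...)

# Leave nodes alone; "slap a node on" anything else.
as_node(x::AbstractNode) = x
as_node(x) = Source(x)

# Toy supervised model: predicts the mean of the training target.
struct MeanRegressor end

# Arrow for a tuple of inputs: coerce every element to a node first.
function Base.:|>(inputs::Tuple, model::MeanRegressor)
    X, y = map(as_node, inputs)
    return Node((Xd, yd) -> fill(sum(yd) / length(yd), length(Xd)), (X, y))
end

Xs = Source([1.0, 2.0, 3.0])
y = [2.0, 4.0, 6.0]                # a plain vector, not a node
yhat = (Xs, y) |> MeanRegressor()  # y is wrapped automatically
yhat()                             # [4.0, 4.0, 4.0]
```

Because `as_node` is the identity on nodes, inputs that are already sources pass through unchanged, so the explicit `source(...)` form keeps working.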
Arrow on `hcat`

In stacking, it can be nice to `hcat` the outputs of several nodes and then feed the result into a later layer. Currently this doesn't fully work because `hcat` does not produce a table. The easy way out is just to apply `MLJBase.table` if the arrow receives data that is a matrix.
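A sketch of that "easy way out" in plain Julia: `as_table` is a hypothetical stand-in for `MLJBase.table` that returns a column table as a `NamedTuple` of vectors, with columns named `x1`, `x2`, … (a naming choice made for this sketch); anything that is not a matrix passes through unchanged. `stack_outputs` is likewise a made-up helper for combining first-layer outputs.

```julia
# Hypothetical helper standing in for MLJBase.table: convert a matrix into a
# column table (a NamedTuple of vectors named x1, x2, ...). Non-matrix inputs,
# e.g. something that is already a table, are returned unchanged.
function as_table(A::AbstractMatrix)
    names = Tuple(Symbol("x", j) for j in 1:size(A, 2))
    cols = Tuple(A[:, j] for j in 1:size(A, 2))
    return NamedTuple{names}(cols)
end
as_table(X) = X

# Hypothetical stacking step: hcat the first-layer outputs, then coerce the
# resulting matrix to a table so that a later layer can consume it.
stack_outputs(outputs::AbstractVector...) = as_table(hcat(outputs...))

p1 = [0.1, 0.2, 0.3]   # e.g. predictions of a first model
p2 = [1.0, 2.0, 3.0]   # e.g. predictions of a second model
Z = stack_outputs(p1, p2)
Z.x1   # [0.1, 0.2, 0.3]
Z.x2   # [1.0, 2.0, 3.0]
```

Dispatching on `AbstractMatrix` keeps the coercion a no-op for inputs that are already tables, which is the behaviour the arrow would need.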