Transformers that need to see target (eg, recursive feature elimination) #874

Closed
ablaom opened this issue Dec 21, 2021 · 11 comments

ablaom (Member) commented Dec 21, 2021

A number of feature-reduction strategies only make sense in the context of a supervised learning task, because they must consult a target variable when trained. For example, one might want to drop features that correlate poorly with the target. In fact, all but the first of sklearn's feature selectors are of this kind.

At the level of the basic model API, a transformer (or any other model) can specify any number of arguments to be used in training. So there is nothing wrong with a transformer whose fit method looks like

MLJModelInterface.fit(model::MyTransformer, verbosity, X, y) = ...

There is now a trait defined in MLJModelInterface, fit_data_scitype, to explicitly articulate the acceptable fit signatures (up to scitype). For any model type that subtypes Unsupervised, this falls back to a single argument whose scitype must coincide with input_scitype(model). So, for transformers that need the target in training, you would override the trait with a declaration such as:

MLJModelInterface.fit_data_scitype(M::Type{<:MyTransformer}) =
    Tuple{input_scitype(M), target_scitype(M)}

and be sure to declare a target_scitype, just as you would for a supervised model. That should do it.
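To fix ideas, here is a rough sketch of a complete transformer along these lines, one that keeps only the k features best correlated with a Continuous target. All names and details here are illustrative only, not a tested implementation:

import MLJModelInterface as MMI
using Statistics

mutable struct CorrelationFilter <: MMI.Unsupervised
    k::Int
end

function MMI.fit(model::CorrelationFilter, verbosity, X, y)
    Xmatrix = MMI.matrix(X)
    # score each feature by its absolute correlation with the target:
    scores = [abs(cor(Xmatrix[:, j], y)) for j in 1:size(Xmatrix, 2)]
    fitresult = sortperm(scores, rev=true)[1:model.k]  # indices of retained features
    return fitresult, nothing, NamedTuple()
end

MMI.transform(::CorrelationFilter, fitresult, Xnew) =
    MMI.selectcols(Xnew, fitresult)

MMI.input_scitype(::Type{<:CorrelationFilter}) = MMI.Table(MMI.Continuous)
MMI.target_scitype(::Type{<:CorrelationFilter}) = AbstractVector{MMI.Continuous}
MMI.fit_data_scitype(M::Type{<:CorrelationFilter}) =
    Tuple{MMI.input_scitype(M), MMI.target_scitype(M)}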

It may be that some argument checks for machines have to be tweaked in MLJBase (edit: now done), but this should be very easy and essentially non-breaking.

Most happy to provide support to anyone wishing to implement such transformers.

ablaom (Member, Author) commented Dec 21, 2021

cc @pazzo83

pazzo83 commented Dec 21, 2021

Would it make sense to create a new model type for this? It's like a supervised transformer.

ablaom (Member, Author) commented Dec 21, 2021

I think the general consensus is a move away from types to traits. I'm afraid the discussions are a little scattered. See, for example, #852 (comment).

We could add a trait for "supervised transformers", but perhaps this is unnecessary, as fit_data_scitype (together with subtyping <:Unsupervised, which ultimately might be encoded as is_supervised(model) = false) essentially captures the whole behaviour, no?

@pazzo83
Copy link

pazzo83 commented Dec 23, 2021

So would my transformer that relies on a target subtype Unsupervised? I have tried that, overriding the definition of fit_data_scitype (as above), but I'm not sure what type to pass in where you have MyTransformer. If I create an abstract type that subtypes Unsupervised, and then have my transformer subtype that, I still run into the check on machines:

ArgumentError: `Unsupervised` models should have one training argument, except `Static` models, which have none. Use  `machine(model, X; ...)` (usual case) or `machine(model; ...)` (static case). 

I guess I'm not familiar enough with the inner workings of the API yet to know whether that check needs to be modified, but could you expand a bit on what you mean by using traits to allow for this functionality?

pazzo83 commented Dec 23, 2021

I was able to partially get this working with the following patches (after declaring my own abstract type for my transformers to subtype):

import MLJModelInterface as MMI
import MLJBase

abstract type TargetTransformer <: MMI.Unsupervised end

MMI.fit_data_scitype(M::Type{<:TargetTransformer}) =
    Tuple{MMI.input_scitype(M), MMI.target_scitype(M)}

MLJBase.check(model::TargetTransformer, args...; full=false) =
    MLJBase.check_supervised(model, full, args...)

MLJBase.warn_scitype(model::TargetTransformer, X, y) =
    "The scitype of `y`, in `machine(model, X, y, ...)`, "*
    "is incompatible with `model=$model`:\n"*
    "scitype(y) = $(MLJBase.elscitype(y))\n"*
    "target_scitype(model) = $(MLJBase.target_scitype(model))."

However, when put into a pipeline, it no longer works: it seems that, because the model is still unsupervised, the target is not getting passed through (see here: https://github.com/JuliaAI/MLJBase.jl/blob/dev/src/composition/models/pipelines.jl#L72).

This might be kind of hacky and I think you are suggesting something a bit different?

ablaom (Member, Author) commented Dec 29, 2021

@pazzo83 Thanks for looking at this.

This isn't far from what I was imagining. Only, rather than introduce a new abstract type, I'd overload fit_data_scitype case-by-case. That is, each new model implementation for a concrete "supervised" transformer type MyTransformer includes the declaration

MLJModelInterface.fit_data_scitype(M::Type{<:MyTransformer}) =
    Tuple{input_scitype(M), target_scitype(M)}

Then, to fix the type checking, modify the existing MLJBase.check(::Unsupervised, ...) method so as to catch transformers that legitimately require length(mach.args) == 2, which you can detect because fit_data_scitype(model) is a 2-tuple for such models.

Ditto MLJBase.warn_scitype(::Unsupervised, ...).
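Schematically, the relaxed check might read the expected number of training arguments off the declared signature, something like this (a hypothetical helper, not actual MLJBase code, and ignoring Unknown and Vararg signatures, which would need special-casing):

# expected number of training arguments, as declared by the trait:
function nargs_expected(model)
    F = MLJBase.fit_data_scitype(model)
    return F isa DataType && F <: Tuple ? length(F.parameters) : 1
end

# so, inside check(model::Unsupervised, args...), roughly:
#     length(mach.args) == nargs_expected(model) ||
#         throw(ArgumentError("Wrong number of training arguments."))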

This is all a little tricky, as we want flexibility of design but also want to catch users' unintentional mistakes with informative messages. Warnings are better than errors here, but even warnings should be thrown only when necessary.

And we'd need to test the changes with a dummy "supervised" transformer in tests.
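Something like this, say (a sketch only, reusing the hypothetical CorrelationFilter from my first comment as the dummy):

using MLJBase, Test

X = (x1 = rand(20), x2 = rand(20), x3 = rand(20))
y = rand(20)

# two training arguments on an Unsupervised model: should no longer error
mach = machine(CorrelationFilter(2), X, y)
fit!(mach, verbosity=0)
Xnew = transform(mach, X)
@test length(MLJBase.schema(Xnew).names) == 2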

However, when put into a pipeline, it no longer works

No. That will require a bit more work. Nevertheless, I'm pretty sure one could include these transformers in custom composite models (exported learning networks) without issues. So they would be useful even without the pipeline enhancement.

I'd support a PR that fixes the checks without worrying about pipelines just yet. Would also be great to have an actual "supervised" transformer implementation to try this out on. Have you already started on something?

pazzo83 commented Dec 29, 2021

Thanks for the feedback! I can definitely put together a PR for this - I have some local code I've been working on so I can incorporate your feedback and go from there.

pazzo83 commented Jan 24, 2022

I've been looking at this over the last couple of days based on the feedback here: JuliaAI/MLJBase.jl#705

Would it work if we removed all the various check_* methods and simply kept this method:

function check(model::Model, args...; full=false)
    nowarns = true

    # skip the check entirely when the model declares no expectations:
    F = fit_data_scitype(model)
    (F >: Unknown || F >: Tuple{Unknown} || F >: NTuple{<:Any,Unknown}) &&
        return true

    # compare the scitype of the supplied data with the declared signature:
    S = Tuple{elscitype.(args)...}
    if !(S <: F)
        @warn warn_generic_scitype_mismatch(S, F)
        nowarns = false
    end
    return nowarns
end

I got it working once I rewrote the line S = Tuple{elscitype.(args)...} as S = Tuple{scitype.(args)...}.
Then it just checks whether the data you supply matches the scitype signature the model was declared to have. I think that's what we ultimately want, right?
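To make that concrete, for a two-argument transformer the comparison would look something like this (illustrative values, with CorrelationFilter the hypothetical model sketched earlier in the thread):

using MLJBase

X = (x1 = rand(5), x2 = rand(5))
y = rand(5)

S = Tuple{scitype(X), scitype(y)}           # scitype of the supplied data
F = fit_data_scitype(CorrelationFilter(1))  # the declared signature
S <: F                                      # true, so no warning is issued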

ablaom (Member, Author) commented Jan 25, 2022

Yes! That is what I think we should do. And it indeed looks like you found a bug with the elscitype <-> scitype business. Good catch!

I would expand the return value of warn_generic_scitype_mismatch(S, F) along the lines previously suggested and copied below:

"The number and/or types of data arguments do not match what the specified model supports. Commonly, but non exclusively, supervised models are constructed using the syntax machine(model, X, y) or machine(model, X, y, w) while most other models with machine(model, X). Here X are features, y a target, and w sample or class weights. In general, data in machine(model, data...) must satisfy scitype(data) <: MLJ.fit_data_scitype(model)unless the right-hand side isUnknown`. "

Thanks for getting back to this!

ablaom (Member, Author) commented Apr 29, 2022

Just a note that scitype checks have now (MLJBase 18.0) been relaxed to allow transformers that need a target.

ablaom (Member, Author) commented Jun 9, 2024

Resolved.

ablaom closed this as completed Jun 9, 2024