Deserialisation fails for wrappers like TunedModel when atomic model overloads save/restore #1099

Closed · 4 tasks done
ablaom opened this issue Mar 3, 2024 · 2 comments
Labels: bug (Something isn't working)

ablaom (Member) commented Mar 3, 2024

XGBoost.jl models have non-persistent fitresults, which means they cannot be directly serialised. That's not a problem, because such models can overload MLJModelInterface's save and restore. However, it has been reported that XGBoost models wrapped in TunedModel don't deserialise properly. Here is an MWE:

First, we define a supervised model, EphemeralRegressor, with an ephemeral fitresult. For this model we overload save/restore to ensure deserialisation works, provided you use the correct API.

using Statistics, MLJBase, Test, MLJTuning, Serialization, StatisticalMeasures
import MLJModelInterface

# define a model with non-persistent fitresult:
thing = []
struct EphemeralRegressor <: Deterministic end
function MLJModelInterface.fit(::EphemeralRegressor, verbosity, X, y)
    # if `thing` is serialized/deserialized, then `view` below changes:
    view = objectid(thing)
    fitresult = (thing, view, mean(y))
    return fitresult, nothing, NamedTuple()
end
function MLJModelInterface.predict(::EphemeralRegressor, fitresult, X)
    thing, view, μ = fitresult
    return view == objectid(thing) ? fill(μ, nrows(X)) :
        throw(ErrorException("dead fitresult"))
end
function MLJModelInterface.save(::EphemeralRegressor, fitresult)
    thing, _, μ = fitresult
    return (thing, μ)
end
function MLJModelInterface.restore(::EphemeralRegressor, serialized_fitresult)
    thing, μ = serialized_fitresult
    view = objectid(thing)
    return (thing, view, μ)
end

# EphemeralRegressor cannot be directly serialized:
X, y = (; x = rand(3)), fill(42.0, 3)
model = EphemeralRegressor()
mach = machine(model, X, y) |> fit!
io = IOBuffer()
serialize(io, mach)
seekstart(io)
mach2 = deserialize(io)
@test_throws ErrorException("dead fitresult") predict(mach2, 42)

# But it can be serialized/deserialized using the correct API:
io = IOBuffer()
serialize(io, serializable(mach))
seekstart(io)
mach2 = restore!(deserialize(io))
@test MLJBase.predict(mach2, (; x = rand(2))) ≈ fill(42.0, 2)

But wrapping this model using TunedModel leads to deserialization failure:

tmodel = TunedModel(
    models=fill(EphemeralRegressor(), 2),
    measure = l2,
)
mach = machine(tmodel, X, y) |> fit!
io = IOBuffer()
serialize(io, serializable(mach))
seekstart(io)
mach2 = restore!(deserialize(io))
MLJBase.predict(mach2, (; x = rand(2)))
# ERROR: dead fitresult
# Stacktrace:
#  [1] predict(::EphemeralRegressor, fitresult::Tuple{Vector{Any}, UInt64, Float64}, X::@NamedTuple{x::Vector{Float64}})                                                            
#    @ Main ./REPL[7]:3
#
# < truncated trace >

The remedy is to properly "forward" the save/restore methods of the atomic models. We can exclude any wrapper model implemented as NetworkComposite (i.e., using learning networks), as these already overload save and restore properly.
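
Below is a minimal sketch of this forwarding pattern, under the assumption that the wrapper's fitresult is a named tuple holding the atomic model and its fitresult. The wrapper type SomeWrapper and the field names atomic_model/atomic_fitresult are hypothetical, chosen only to illustrate the idea; the actual TunedModel fitresult layout may differ.

import MLJModelInterface

# Hypothetical wrapper standing in for TunedModel; its fitresult is assumed
# to be the named tuple (; atomic_model, atomic_fitresult).
struct SomeWrapper{M} <: MLJModelInterface.Deterministic
    atom::M
end

# forward `save` to the atomic model, replacing the atomic fitresult with
# its persistent (serializable) form:
function MLJModelInterface.save(::SomeWrapper, fitresult)
    atom = fitresult.atomic_model
    serializable = MLJModelInterface.save(atom, fitresult.atomic_fitresult)
    return (; atomic_model=atom, atomic_fitresult=serializable)
end

# forward `restore` to the atomic model, rebuilding the in-memory fitresult:
function MLJModelInterface.restore(::SomeWrapper, serialized_fitresult)
    atom = serialized_fitresult.atomic_model
    restored = MLJModelInterface.restore(atom, serialized_fitresult.atomic_fitresult)
    return (; atomic_model=atom, atomic_fitresult=restored)
end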

To do (waiting on review):

OkonSamuel (Member) commented Mar 5, 2024

@ablaom Seems like I would also have to overload this for the RFE model I'm working on. Maybe we should also add a note in the MLJ docs concerning this.

ablaom (Member, Author) commented Mar 6, 2024

yep
