Description
It seems that MLJBase 0.21, and in particular the changes to predict for Pipelines, breaks MCMCDiagnosticTools: the output of predict is no longer always the desired predictions but is sometimes a tuple of the predictions and something else (see, e.g., TuringLang/MCMCDiagnosticTools.jl#44).
using MLJBase
using MLJXGBoostInterface
using DataFrames
using Test
iris = DataFrame(load_iris())
y, X = unpack(iris, ==(:target))
train, test = partition(eachindex(y), 0.7)
XGBoostClassifier = @load XGBoostClassifier pkg=XGBoost
model1 = XGBoostClassifier()
fitresult, _ = fit(model1, 0, X[train, :], y[train])
@test predict(model1, fitresult, X[test, :]) isa UnivariateFiniteVector
model2 = Pipeline(XGBoostClassifier(); operation=predict_mode)
fitresult, _ = fit(model2, 0, X[train, :], y[train])
@test predict(model2, fitresult, X[test, :]) isa AbstractVector{<:MLJBase.CategoricalValue}
This example works with MLJBase < 0.21, but in MLJBase 0.21 the last line returns a tuple whose first element is the desired predictions.
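Until the behaviour is clarified, a downstream package could defensively normalise the output of predict. The sketch below is only a stopgap under one assumption taken from the observation above: when predict returns a tuple, its first element holds the predictions. The helper `unwrap_predictions` is hypothetical and is not part of MLJBase or MLJModelInterface; the vectors stand in for real model output.

```julia
# Hypothetical helper, not part of MLJBase or MLJModelInterface.
# Assumes (as observed with MLJBase 0.21 pipelines) that when `predict`
# returns a tuple, its first element holds the predictions.
unwrap_predictions(output::Tuple) = first(output)

# Any other output (e.g. a plain vector) is assumed to already be the predictions.
unwrap_predictions(output) = output

# Stand-in values instead of real model output:
plain = [:setosa, :versicolor]
wrapped = (plain, nothing)  # mimics the (predictions, extra) tuple

@assert unwrap_predictions(plain) === plain
@assert unwrap_predictions(wrapped) === plain
```

This only hides the symptom. It silently relies on the tuple layout staying the same, which is exactly the kind of undocumented detail this issue is about.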
Some more details about our use case:
In MCMCDiagnosticTools we want to keep dependencies minimal, so we depend only on MLJModelInterface. Therefore we call fit and predict on models directly instead of wrapping models in machines and using fit! and predict on those machines.
Additionally, models are provided by users and depending on whether the output of predict is a vector of distributions or not, we dispatch internally to one of two algorithms (https://github.com/TuringLang/MCMCDiagnosticTools.jl/blob/5e3232f500d10bf352ced88530e3f07e3d132ab3/src/rstar.jl#L21-L24).
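One way to express this kind of switch in Julia is multiple dispatch on the element type of the prediction vector. The sketch below is illustrative only and does not claim to reproduce the linked rstar.jl code. `ProbPrediction` is a hypothetical stand-in for a distribution type such as MLJBase's `UnivariateFinite`, and `choose_algorithm` is a hypothetical name.

```julia
# Hypothetical stand-in for a probabilistic prediction; in MLJ this role is
# played by a distribution such as UnivariateFinite.
struct ProbPrediction
    probs::Dict{Symbol,Float64}
end

# Vectors of distributions select the probabilistic algorithm...
choose_algorithm(::AbstractVector{<:ProbPrediction}) = :probabilistic
# ...and any other vector (e.g. class labels) selects the deterministic one.
choose_algorithm(::AbstractVector) = :deterministic

probabilistic = [ProbPrediction(Dict(:a => 0.7, :b => 0.3))]
deterministic = [:a, :b]

@assert choose_algorithm(probabilistic) == :probabilistic
@assert choose_algorithm(deterministic) == :deterministic
```

With a scheme like this, a tuple returned by predict is not an `AbstractVector`, so it matches neither method and raises a `MethodError`. This matches the breakage described above.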
Users are supposed to provide classifiers with deterministic predictions if they want to use one algorithm, and with probabilistic predictions if they want to use the other algorithm.
Hence the ability for users to change a probabilistic classifier to a deterministic one is crucial if they want to use e.g. XGBoost with the algorithm for deterministic predictions.
In previous versions of MLJ and MLJBase this was possible by defining a Pipeline with, e.g., predict_mode as the operation, but, as the example shows, it no longer seems to be possible with MLJBase 0.21.