
Latest release breaks MCMCDiagnosticTools #863

@devmotion


It seems that MLJBase 0.21, and in particular the changes to predict for Pipelines, breaks MCMCDiagnosticTools: the output of predict is no longer always the desired predictions but is sometimes a tuple of the predictions and something else (see, e.g., TuringLang/MCMCDiagnosticTools.jl#44).

```julia
using MLJBase
using MLJXGBoostInterface
using DataFrames
using Test

iris = DataFrame(load_iris())
y, X = unpack(iris, ==(:target))
train, test = partition(eachindex(y), 0.7)

XGBoostClassifier = @load XGBoostClassifier pkg=XGBoost

model1 = XGBoostClassifier()
fitresult, _ = fit(model1, 0, X[train, :], y[train])
@test predict(model1, fitresult, X[test, :]) isa UnivariateFiniteVector

model2 = Pipeline(XGBoostClassifier(); operation=predict_mode)
fitresult, _ = fit(model2, 0, X[train, :], y[train])
@test predict(model2, fitresult, X[test, :]) isa AbstractVector{<:MLJBase.CategoricalValue}
```

This example works with MLJBase < 0.21, but with MLJBase 0.21 the predict call in the last line returns a tuple whose first element contains the desired predictions, so the last test fails.
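As a stopgap, a package that depends only on MLJModelInterface could normalize the output of predict before using it. The sketch below is a hypothetical helper, `unwrap_predictions`, which is not part of MLJBase or MCMCDiagnosticTools. It assumes, as observed in the example above, that whenever a tuple is returned its first element holds the predictions; this relies on undocumented behaviour that may change again.

```julia
# Hypothetical helper (not part of MLJBase or MCMCDiagnosticTools).
# Assumption: if `predict` returns a tuple, its first element holds the
# predictions, as observed with MLJBase 0.21 in the example above.
unwrap_predictions(yhat::Tuple) = first(yhat)

# Any other output (e.g. what MLJBase < 0.21 returns) is passed through unchanged.
unwrap_predictions(yhat) = yhat

# Usage with the pipeline from the example:
# yhat = unwrap_predictions(predict(model2, fitresult, X[test, :]))
```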

Some more details about our use case:
In MCMCDiagnosticTools we want to keep dependencies minimal and hence depend only on MLJModelInterface. Therefore we call fit and predict on models directly instead of wrapping them in machines and using fit! and predict.
Additionally, models are provided by users and depending on whether the output of predict is a vector of distributions or not, we dispatch internally to one of two algorithms (https://github.com/TuringLang/MCMCDiagnosticTools.jl/blob/5e3232f500d10bf352ced88530e3f07e3d132ab3/src/rstar.jl#L21-L24).
Users are supposed to provide classifiers with deterministic predictions if they want to use one algorithm, and with probabilistic predictions if they want to use the other algorithm.
Hence the ability for users to change a probabilistic classifier to a deterministic one is crucial if they want to use e.g. XGBoost with the algorithm for deterministic predictions.
In previous versions of MLJ and MLJBase this was possible by defining a Pipeline with, e.g., predict_mode as its operation, but, as the example shows, this no longer seems to be possible with MLJBase 0.21.
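The kind of dispatch described above can be sketched in plain Julia. `ToyDistribution` and `algorithm` are hypothetical stand-ins, not the actual code in the linked src/rstar.jl; the sketch only illustrates why a tuple returned by predict breaks such a scheme: a Tuple is not an AbstractVector, so it matches neither method.

```julia
# Hypothetical stand-in for a probabilistic prediction
# (with MLJ this would be e.g. an element of a UnivariateFiniteVector).
struct ToyDistribution
    probs::Vector{Float64}
end

# Hypothetical dispatch: a vector of distributions selects one algorithm,
# any other vector (e.g. predicted class labels) selects the other.
algorithm(yhat::AbstractVector{<:ToyDistribution}) = :probabilistic
algorithm(yhat::AbstractVector) = :deterministic

algorithm([ToyDistribution([0.2, 0.8])])  # :probabilistic
algorithm(["setosa", "virginica"])        # :deterministic
# algorithm((["setosa"], nothing))        # MethodError: a Tuple is not an AbstractVector
```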
