Skip to content

Commit

Permalink
Merge 81116bf into 5955f82
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion committed Feb 27, 2023
2 parents 5955f82 + 81116bf commit e1b3f3f
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "MCMCDiagnosticTools"
uuid = "be115224-59cd-429b-ad48-344e309966f0"
authors = ["David Widmann"]
version = "0.3.0"
version = "0.3.1"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand Down
15 changes: 9 additions & 6 deletions src/rstar.jl
Expand Up @@ -43,18 +43,21 @@ function rstar(

xtable = _astable(x)
ycategorical = MMI.categorical(ysplit)
xdata, ydata = MMI.reformat(classifier, xtable, ycategorical)

# train classifier on training data
xtrain, ytrain = MMI.selectrows(classifier, train_ids, xdata, ydata)
fitresult, _ = MMI.fit(classifier, verbosity, xtrain, ytrain)
data = MMI.reformat(classifier, xtable, ycategorical)
train_data = MMI.selectrows(classifier, train_ids, data...)
fitresult, _ = MMI.fit(classifier, verbosity, train_data...)

# compute predictions on test data
xtest, = MMI.selectrows(classifier, test_ids, xdata)
ytest = ycategorical[test_ids]
predictions = _predict(classifier, fitresult, xtest)
# we exploit that MLJ demands that
# reformat(model, args...)[1] = reformat(model, args[1])
# (https://alan-turing-institute.github.io/MLJ.jl/dev/adding_models_for_general_use/#Implementing-a-data-front-end)
test_data = MMI.selectrows(classifier, test_ids, data[1])
predictions = _predict(classifier, fitresult, test_data...)

# compute statistic
ytest = ycategorical[test_ids]
result = _rstar(MMI.scitype(predictions), predictions, ytest)

return result
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Expand Up @@ -6,6 +6,7 @@ FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661"
MLJIteration = "614be32b-d00c-4edb-bd02-1eb411ab5e55"
MLJLIBSVMInterface = "61c7150f-6c77-4bb1-949c-13197eac2a52"
MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
Expand All @@ -25,6 +26,7 @@ FFTW = "1.1"
LogDensityProblems = "0.12, 1, 2"
LogExpFunctions = "0.3"
MLJBase = "0.19, 0.20, 0.21"
MLJDecisionTreeInterface = "0.3"
MLJIteration = "0.5"
MLJLIBSVMInterface = "0.2"
MLJModels = "0.16"
Expand Down
3 changes: 3 additions & 0 deletions test/rstar.jl
Expand Up @@ -3,6 +3,7 @@ using MCMCDiagnosticTools
using Distributions
using EvoTrees
using MLJBase: MLJBase, Pipeline, predict_mode
using MLJDecisionTreeInterface
using MLJLIBSVMInterface
using MLJModels
using MLJXGBoostInterface
Expand All @@ -26,6 +27,7 @@ end
classifiers = (
EvoTreeClassifier(; nrounds=100, eta=0.3),
Pipeline(EvoTreeClassifier(; nrounds=100, eta=0.3); operation=predict_mode),
DecisionTreeClassifier(),
SVC(),
XGBoostClassifiers...,
)
Expand Down Expand Up @@ -131,6 +133,7 @@ end
Pipeline(
EvoTreeClassifier(; rng=rng, nrounds=100, eta=0.3); operation=predict_mode
),
DecisionTreeClassifier(; rng=rng),
SVC(),
XGBoostClassifiers...,
)
Expand Down

0 comments on commit e1b3f3f

Please sign in to comment.