/
rstar.jl
96 lines (71 loc) · 3.99 KB
/
rstar.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""
rstar([rng ,] classif::Supervised, chains::Chains; kwargs...)
rstar([rng ,] classif::Supervised, x::AbstractMatrix, y::AbstractVector; kwargs...)
Compute the R* convergence diagnostic of MCMC.
This implementation is an adaptation of Algorithms 1 & 2 described in [Lambert & Vehtari]. Note that the correctness of the statistic depends on the convergence of the classifier used internally in the statistic. You can track whether the training of the classifier converged by inspecting the printed RMSE values from the XGBoost backend. To adjust the number of iterations used to train the classifier, set `niter` accordingly.
# Keyword Arguments
* `subset = 0.8` ... Subset used to train the classifier, i.e. 0.8 implies 80% of the samples are used.
* `iterations = 10` ... Number of iterations used to estimate the statistic. If the classifier is not probabilistic, i.e. does not return class probabilities, it is advisable to use a value of one.
* `verbosity = 0` ... Verbosity level used during fitting of the classifier.
# Usage
```julia
using MLJ, MLJModels
# You need to load MLJBase and the respective package you are using for classification first.
# Select a classifier to compute the Rstar statistic.
# For example the XGBoost classifier.
classif = @load XGBoostClassifier()
# Compute 20 samples of the R* statistic by sampling according to the prediction probabilities.
Rs = rstar(classif, chn, iterations = 20)
# estimate Rstar
R = mean(Rs)
# visualize distribution
histogram(Rs)
```
## References:
[Lambert & Vehtari] Ben Lambert and Aki Vehtari. "R∗: A robust MCMC convergence diagnostic with uncertainty using gradient-boosted machines." arXiv 2020.
"""
function rstar(
    rng::Random.AbstractRNG,
    classif::MLJModelInterface.Supervised,
    x::AbstractMatrix,
    y::AbstractVector{Int};
    iterations = 10,
    subset = 0.8,
    verbosity = 0
)
    # Core method: fit `classif` on a random subset of the rows of `x` (labels `y`),
    # then score it `iterations` times on the held-out rows.
    #
    # Arguments
    #   rng        - random number generator used for the train/test split and
    #                forwarded to `rstar_score`.
    #   classif    - supervised classifier, fitted via `MLJModelInterface.fit`.
    #   x, y       - one row of `x` per entry of `y`.
    #   iterations - number of scores to compute (length of the returned vector).
    #   subset     - fraction of the rows used for training, strictly in (0, 1).
    #   verbosity  - passed through to `MLJModelInterface.fit`.
    #
    # Returns a vector of `iterations` values, each being the number of distinct
    # labels in `y` times the value returned by `rstar_score` on the test split.
    #
    # Throws `DimensionMismatch` if `x` and `y` disagree in the number of samples and
    # `ArgumentError` if `iterations` or `subset` is invalid.

    # Validate at the API boundary. `cond || throw(...)` raises when `cond` is false.
    size(x, 1) == length(y) || throw(DimensionMismatch(
        "number of rows of `x` ($(size(x, 1))) must equal length of `y` ($(length(y)))"
    ))
    iterations >= 1 || throw(ArgumentError("Number of iterations has to be positive!"))
    0 < subset < 1 || throw(ArgumentError("`subset` has to be a number in (0, 1)!"))

    if iterations > 1 && classif isa MLJModelInterface.Deterministic
        @warn("Classifier is not a probabilistic classifier but number of iterations is > 1.")
    elseif iterations == 1 && classif isa MLJModelInterface.Probabilistic
        @warn("Classifier is probabilistic but number of iterations is equal to one.")
    end

    N = length(y)
    K = length(unique(y))

    # randomly sub-select training and testing set
    Ntrain = round(Int, N * subset)
    # Rounding can still leave one side empty for very small `N`; fail early with a
    # clear message instead of an obscure error during fitting or scoring.
    0 < Ntrain < N || throw(ArgumentError(
        "Training and test set must both be non-empty, got $Ntrain of $N samples for training!"
    ))
    ids = Random.randperm(rng, N)
    train_ids = view(ids, 1:Ntrain)
    test_ids = view(ids, (Ntrain + 1):N)

    # train the supplied classifier on the training split
    xtrain = Tables.table(x[train_ids, :])
    ytrain = MLJModelInterface.categorical(y[train_ids])
    fitresult, _ = MLJModelInterface.fit(classif, verbosity, xtrain, ytrain)

    # held-out data used for every score
    xtest = Tables.table(x[test_ids, :])
    ytest = view(y, test_ids)

    Rstats = map(i -> K * rstar_score(rng, classif, fitresult, xtest, ytest), 1:iterations)
    return Rstats
end
# Convenience method without an explicit RNG: forwards to the `rng`-taking method
# using `Random.GLOBAL_RNG`. All keyword arguments are passed through unchanged.
function rstar(
    classif::MLJModelInterface.Supervised,
    x::AbstractMatrix,
    y::AbstractVector{Int};
    kwargs...
)
    global_rng = Random.GLOBAL_RNG
    return rstar(global_rng, classif, x, y; kwargs...)
end
# Convenience method for `Chains` input without an explicit RNG: forwards to the
# `rng`-taking `Chains` method using `Random.GLOBAL_RNG`, passing keywords through.
function rstar(
    classif::MLJModelInterface.Supervised,
    chn::Chains;
    kwargs...
)
    global_rng = Random.GLOBAL_RNG
    return rstar(global_rng, classif, chn; kwargs...)
end
# `Chains` entry point: converts `chn` into a sample matrix plus one integer label
# per row and delegates to the matrix/vector method. Keywords are passed through.
# Throws `DimensionMismatch` when the third dimension of `chn` has size <= 1.
function rstar(
    rng::Random.AbstractRNG,
    classif::MLJModelInterface.Supervised,
    chn::Chains;
    kwargs...
)
    # the third dimension of `chn` is taken as the number of chains
    if size(chn, 3) <= 1
        throw(DimensionMismatch())
    end

    # collect data
    # NOTE(review): assumes `Array(chn)` stacks the chains row-wise in the same order
    # as `chains(chn)`, so each label lines up with its rows -- confirm.
    samples = Array(chn)
    per_chain = size(chn, 1)
    labels = repeat(chains(chn); inner = per_chain)

    return rstar(rng, classif, samples, labels; kwargs...)
end
# Score for probabilistic classifiers: draw one label per test row from the
# classifier's prediction, then return the fraction of draws equal to `ytest`.
function rstar_score(
    rng::Random.AbstractRNG,
    classif::MLJModelInterface.Probabilistic,
    fitresult,
    xtest,
    ytest
)
    predictions = MLJModelInterface.predict(classif, fitresult, xtest)
    # one random draw per prediction, all taken from the same `rng`
    draws = rand.(Ref(rng), predictions)
    # `get` unwraps each draw to a raw value comparable with the entries of `ytest`
    labels = get.(draws)
    return mean(zip(labels, ytest)) do (drawn, truth)
        drawn == truth
    end
end
# Score for deterministic classifiers: the fraction of predictions equal to `ytest`.
# `rng` is unused here; it is accepted so both `rstar_score` methods share a signature.
function rstar_score(
    rng::Random.AbstractRNG,
    classif::MLJModelInterface.Deterministic,
    fitresult,
    xtest,
    ytest
)
    labels = MLJModelInterface.predict(classif, fitresult, xtest)
    return mean(zip(labels, ytest)) do (predicted, truth)
        predicted == truth
    end
end