Skip to content

Commit

Permalink
Merge branch 'master' into serialize
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion committed Sep 17, 2020
2 parents fe7fba1 + f356e99 commit dbe2491
Show file tree
Hide file tree
Showing 11 changed files with 212 additions and 45 deletions.
11 changes: 7 additions & 4 deletions Project.toml
Expand Up @@ -3,7 +3,7 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
keywords = ["markov chain monte carlo", "probablistic programming"]
license = "MIT"
desc = "Chain types and utility functions for MCMC simulations."
version = "4.0.3"
version = "4.2.1"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand All @@ -14,11 +14,11 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Formatting = "59287772-0a20-5a39-b81b-1366585eb4c0"
IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
NaturalSort = "c020b1a1-e9b0-503a-9c33-f039bfc54a85"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand All @@ -34,10 +34,10 @@ Compat = "2.2, 3"
Distributions = "0.21, 0.22, 0.23"
Formatting = "0.4"
IteratorInterfaceExtensions = "0.1.1, 1"
MLJModelInterface = "0.3.5"
NaturalSort = "1"
PrettyTables = "0.9"
RecipesBase = "0.7, 0.8, 1.0"
Requires = "0.5, 1.0"
SpecialFunctions = "^0.8, 0.9, 0.10"
StatsBase = "0.32, 0.33"
TableTraits = "0.4, 1"
Expand All @@ -49,9 +49,12 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
XGBoost = "009559a3-9522-5dbb-924b-0b6ed2b22bb9"

[targets]
test = ["DataFrames", "FFTW", "KernelDensity", "Logging", "StatsPlots", "Test", "UnicodePlots"]
test = ["DataFrames", "FFTW", "KernelDensity", "Logging", "StatsPlots", "Test", "UnicodePlots", "MLJ", "MLJModels", "XGBoost"]
68 changes: 48 additions & 20 deletions README.md
@@ -1,6 +1,5 @@
# MCMCChains.jl
[![Build Status](https://travis-ci.org/TuringLang/MCMCChains.jl.svg?branch=master)](https://travis-ci.org/TuringLang/MCMCChains.jl)
[![Build status](https://ci.appveyor.com/api/projects/status/1av8osv0099nqw8m/branch/master?svg=true)](https://ci.appveyor.com/project/trappmartin/mcmcchain-jl/branch/master)
[![Coverage Status](https://coveralls.io/repos/github/TuringLang/MCMCChains.jl/badge.svg?branch=master)](https://coveralls.io/github/TuringLang/MCMCChains.jl?branch=master)

Implementation of Julia types for summarizing MCMC simulations and utility functions for diagnostics and visualizations.
Expand Down Expand Up @@ -120,35 +119,38 @@ chn2 = set_section(chn, Dict(:internals => ["d", "e"]))

Any parameters not assigned will be placed into `:parameters`.

Calling `show(chn)` provides the following output:
Calling `display(chn)` provides the following output:

```julia
Log evidence = 0.0
Chains MCMC chain (500×5×2 Array{Float64,3}):

Iterations = 1:500
Thinning interval = 1
Chains = 1, 2, 3
Chains = 1, 2
Samples per chain = 500
parameters = c, b, a
parameters = a, b, c
internals = d, e

Summary Statistics
parameters mean std naive_se mcse ess rhat
Symbol Float64 Float64 Float64 Float64 Float64 Float64

Empirical Posterior Estimates
────────────────────────────────────
parameters
Mean SD Naive SE MCSE ESS
a 0.5169 0.2920 0.0075 0.0066 500
b 0.4891 0.2929 0.0076 0.0070 500
c 0.5102 0.2840 0.0073 0.0068 500
a 0.4930 0.2906 0.0092 0.0095 1044.0585 1.0030
b 0.5148 0.2875 0.0091 0.0087 992.1013 0.9984
c 0.5046 0.2899 0.0092 0.0087 922.6449 0.9987

Quantiles
────────────────────────────────────
parameters
2.5% 25.0% 50.0% 75.0% 97.5%
a 0.0001 0.2620 0.5314 0.7774 0.9978
b 0.0001 0.2290 0.4972 0.7365 0.9998
c 0.0004 0.2739 0.5137 0.7498 0.9997
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64

a 0.0232 0.2405 0.4836 0.7530 0.9687
b 0.0176 0.2781 0.5289 0.7605 0.9742
c 0.0258 0.2493 0.5071 0.7537 0.9754
```

Note that only `a`, `b`, and `c` are being shown. You can explicity show the `:internals`
section by calling `describe(chn; sections = :internals)` or all variables with
Note that only `a`, `b`, and `c` are being shown. You can explicity retrieve
an array of the summary statistics and the quantiles of the `:internals` section by
calling `describe(chn; sections = :internals)`, or of all variables with
`describe(chn; sections = nothing)`. Many functions such as `plot` or `gelmandiag`
support the `sections` keyword argument.

Expand Down Expand Up @@ -223,6 +225,32 @@ heideldiag(c::Chains; alpha=0.05, eps=0.1, etype=:imse)
rafterydiag(c::Chains; q=0.025, r=0.005, s=0.95, eps=0.001)
```

#### Rstar Diagnostic
Rstar diagnostic described in [https://arxiv.org/pdf/2003.07900.pdf](https://arxiv.org/pdf/2003.07900.pdf).
Note that the use requires MLJ and MLJModels to be installed.

Usage:

```julia
using MLJ, MLJModels

chn ... # sampling results of multiple chains

# select classifier used to compute the diagnostic
classif = @load XGBoostClassifier

# estimate diagnostic
Rs = rstar(classif, chn)
R = mean(Rs)

# visualize distribution
using Plots
histogram(Rs)
```

See `? rstar` for more details.


### Model Selection
#### Deviance Information Criterion (DIC)
```julia
Expand Down
5 changes: 4 additions & 1 deletion src/MCMCChains.jl
Expand Up @@ -12,7 +12,7 @@ using SpecialFunctions
using Formatting
import StatsBase: autocov, counts, sem, AbstractWeights,
autocor, describe, quantile, sample, summarystats, cov
using Requires
import MLJModelInterface
import NaturalSort
import PrettyTables
import Tables
Expand All @@ -36,6 +36,8 @@ export summarize
export discretediag, gelmandiag, gewekediag, heideldiag, rafterydiag
export hpd, ess

export rstar

export ESSMethod, FFTESSMethod, BDAESSMethod

"""
Expand Down Expand Up @@ -73,6 +75,7 @@ include("stats.jl")
include("modelstats.jl")
include("plot.jl")
include("tables.jl")
include("rstar.jl")

# deprecations
# TODO: Remove dependency on Serialization if this deprecation is removed
Expand Down
5 changes: 2 additions & 3 deletions src/mcse.jl
Expand Up @@ -11,9 +11,8 @@ function mcse_bm(x::Vector{<:Real}; size::Integer=100)
n = length(x)
m = div(n, size)
if m < 2
throw(
ArgumentError("iterations are < $(2 * size) and batch size is > $(div(n, 2))")
)
@debug "iterations are < $(2 * size) and batch size is > $(div(n, 2))"
return missing
end
mbar = [mean(x[i * size .+ (1:size)]) for i in 0:(m - 1)]
return sem(mbar)
Expand Down
96 changes: 96 additions & 0 deletions src/rstar.jl
@@ -0,0 +1,96 @@
"""
rstar([rng ,] classif::Supervised, chains::Chains; kwargs...)
rstar([rng ,] classif::Supervised, x::AbstractMatrix, y::AbstractVector; kwargs...)
Compute the R* convergence diagnostic of MCMC.
This implementation is an adaption of Algorithm 1 & 2, described in [Lambert & Vehtari]. Note that the correctness of the statistic depends on the convergence of the classifier used internally in the statistic. You can track if the training of the classifier converged by inspection of the printed RMSE values from the XGBoost backend. To adjust the number of iterations used to train the classifier set `niter` accordingly.
# Keyword Arguments
* `subset = 0.8` ... Subset used to train the classifier, i.e. 0.8 implies 80% of the samples are used.
* `iterations = 10` ... Number of iterations used to estimate the statistic. If the classifier is not probabilistic, i.e. does not return class probabilities, it is advisable to use a value of one.
* `verbosity = 0` ... Verbosity level used during fitting of the classifier.
# Usage
```julia
using MLJ, MLJModels
# You need to load MLJBase and the respective package your are using for classification first.
# Select a classifier to compute the Rstar statistic.
# For example the XGBoost classifier.
classif = @load XGBoostClassifier()
# Compute 100 samples of the R* statistic using sampling from according to the prediction probabilities.
Rs = rstar(classif, chn, iterations = 20)
# estimate Rstar
R = mean(Rs)
# visualize distribution
histogram(Rs)
```
## References:
[Lambert & Vehtari] Ben Lambert and Aki Vehtari. "R∗: A robust MCMC convergence diagnostic with uncertainty using gradient-boostined machines." Arxiv 2020.
"""
function rstar(rng::Random.AbstractRNG, classif::MLJModelInterface.Supervised, x::AbstractMatrix, y::AbstractVector{Int}; iterations = 10, subset = 0.8, verbosity = 0)

size(x,1) != length(y) && throw(DimensionMismatch())
iterations >= 1 && ArgumentError("Number of iterations has to be positive!")

if iterations > 1 && classif isa MLJModelInterface.Deterministic
@warn("Classifier is not a probabilistic classifier but number of iterations is > 1.")
elseif iterations == 1 && classif isa MLJModelInterface.Probabilistic
@warn("Classifier is probabilistic but number of iterations is equal to one.")
end

N = length(y)
K = length(unique(y))

# randomly sub-select training and testing set
Ntrain = round(Int, N*subset)
Ntest = N - Ntrain

ids = Random.randperm(rng, N)
train_ids = view(ids, 1:Ntrain)
test_ids = view(ids, (Ntrain+1):N)

# train classifier using XGBoost
fitresult, _ = MLJModelInterface.fit(classif, verbosity, Tables.table(x[train_ids,:]), MLJModelInterface.categorical(y[train_ids]))

xtest = Tables.table(x[test_ids,:])
ytest = view(y, test_ids)

Rstats = map(i -> K*rstar_score(rng, classif, fitresult, xtest, ytest), 1:iterations)
return Rstats
end

function rstar(classif::MLJModelInterface.Supervised, x::AbstractMatrix, y::AbstractVector{Int}; kwargs...)
rstar(Random.GLOBAL_RNG, classif, x, y; kwargs...)
end

function rstar(classif::MLJModelInterface.Supervised, chn::Chains; kwargs...)
return rstar(Random.GLOBAL_RNG, classif, chn; kwargs...)
end

function rstar(rng::Random.AbstractRNG, classif::MLJModelInterface.Supervised, chn::Chains; kwargs...)
nchains = size(chn, 3)
nchains <= 1 && throw(DimensionMismatch())

# collect data
x = Array(chn)
y = repeat(chains(chn); inner = size(chn,1))

return rstar(rng, classif, x, y; kwargs...)
end

function rstar_score(rng::Random.AbstractRNG, classif::MLJModelInterface.Probabilistic, fitresult, xtest, ytest)
pred = get.(rand.(Ref(rng), MLJModelInterface.predict(classif, fitresult, xtest)))
return mean(((p,y),) -> p == y, zip(pred, ytest))
end

function rstar_score(rng::Random.AbstractRNG, classif::MLJModelInterface.Deterministic, fitresult, xtest, ytest)
pred = MLJModelInterface.predict(classif, fitresult, xtest)
return mean(((p,y),) -> p == y, zip(pred, ytest))
end
7 changes: 1 addition & 6 deletions src/stats.jl
Expand Up @@ -248,13 +248,8 @@ function summarystats(
etype = :bm,
kwargs...
)
# Make some functions.
df_mcse(x) = length(x) < 200 ?
missing :
mcse(cskip(x), etype; kwargs...)

# Store everything.
funs = [meancskip, stdcskip, semcskip, df_mcse]
funs = [meancskip, stdcskip, semcskip, x -> mcse(cskip(x), etype; kwargs...)]
func_names = [:mean, :std, :naive_se, :mcse]

# Subset the chain.
Expand Down
5 changes: 2 additions & 3 deletions src/summarize.jl
Expand Up @@ -106,16 +106,15 @@ end
function Base.convert(::Type{Array}, c::C) where C<:ChainDataFrame
T = promote_eltype_namedtuple_tail(c.nt)
arr = Array{T, 2}(undef, c.nrows, c.ncols - 1)

for (i, k) in enumerate(Iterators.drop(keys(c.nt), 1))
arr[:, i] = c.nt[k]
end

return arr
end

Base.convert(::Type{Array{ChainDataFrame,1}}, cs::Array{ChainDataFrame,1}) = cs
function Base.convert(::Type{Array}, cs::Array{C,1}) where C<:ChainDataFrame
function Base.convert(::Type{Array}, cs::Array{ChainDataFrame{T},1}) where T<:NamedTuple
return mapreduce((x, y) -> cat(x, y; dims = Val(3)), cs) do c
reshape(convert(Array, c), Val(3))
end
Expand Down
5 changes: 3 additions & 2 deletions test/diagnostic_tests.jl
Expand Up @@ -82,7 +82,7 @@ end
end

@testset "function tests" begin
tchain = Chains(rand(n_iter, n_name, n_chain), ["a", "b", "c"], Dict(:internals => ["c"]))
tchain = Chains(rand(niter, nparams, nchains), ["a", "b", "c"], Dict(:internals => ["c"]))

# the following tests only check if the function calls work!
@test MCMCChains.diag_all(rand(50, 2), :weiss, 1, 1, 1) != nothing
Expand Down Expand Up @@ -137,9 +137,10 @@ end
end

@testset "sorting" begin
chn_unsorted = Chains(rand(100,3,1), ["2", "1", "3"])
chn_unsorted = Chains(rand(100, nparams, 1), ["2", "1", "3"])
chn_sorted = sort(chn_unsorted)

@test names(chn_sorted) == Symbol.([1, 2, 3])
@test names(chn_unsorted) == Symbol.([2, 1, 3])
end

10 changes: 5 additions & 5 deletions test/ess_tests.jl
Expand Up @@ -8,7 +8,7 @@ using Test
@testset "copy and split" begin
# check a matrix with even number of rows
x = rand(50, 20)

# check incompatible sizes
@test_throws DimensionMismatch MCMCChains.copyto_split!(similar(x, 25, 20), x)
@test_throws DimensionMismatch MCMCChains.copyto_split!(similar(x, 50, 40), x)
Expand Down Expand Up @@ -41,15 +41,15 @@ end
ess_standard2, rhat_standard2 = MCMCChains.ess_rhat(x; method = ESSMethod())
ess_fft, rhat_fft = MCMCChains.ess_rhat(x; method = FFTESSMethod())
ess_bda, rhat_bda = MCMCChains.ess_rhat(x; method = BDAESSMethod())

# check that we get (roughly) the same results
@test ess_standard == ess_standard2
@test ess_standard ess_fft
@test rhat_standard == rhat_standard2 == rhat_fft == rhat_bda

# check that the estimates are reasonable
@test all(x -> isapprox(x, 100_000; atol = 2_500), ess_standard)
@test all(x -> isapprox(x, 100_000; atol = 2_500), ess_bda)
@test all(x -> isapprox(x, 100_000; atol = 5_000), ess_standard)
@test all(x -> isapprox(x, 100_000; atol = 5_000), ess_bda)
@test all(x -> isapprox(x, 1; atol = 0.1), rhat_standard)

# BDA method fluctuates more
Expand Down Expand Up @@ -104,4 +104,4 @@ end
@test ismissing(ess_df[:,2][1]) # since min(maxlag, niter - 1) = 0
@test ismissing(ess_df[:,3][1])
end
end
end

0 comments on commit dbe2491

Please sign in to comment.