diff --git a/Project.toml b/Project.toml index 3a1070bf..2b90f50b 100644 --- a/Project.toml +++ b/Project.toml @@ -23,6 +23,7 @@ Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +StatsModelComparisons = "854dedd9-9477-4a25-907d-7fd989bfdd01" TableTraits = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" @@ -40,6 +41,7 @@ PrettyTables = "0.9, 0.10, 0.11" RecipesBase = "0.7, 0.8, 1.0" SpecialFunctions = "^0.8, 0.9, 0.10, 1.0" StatsBase = "0.32, 0.33" +StatsModelComparisons = "0.1.1" TableTraits = "0.4, 1" Tables = "1" julia = "1" diff --git a/src/MCMCChains.jl b/src/MCMCChains.jl index 2d4ea079..8269b232 100644 --- a/src/MCMCChains.jl +++ b/src/MCMCChains.jl @@ -24,6 +24,8 @@ import Random import Serialization import Statistics: std, cor, mean, var, mean! +using StatsModelComparisons + export Chains, chains, chainscat export setrange, resetrange export set_section, get_params, sections, sort_sections, setinfo @@ -34,7 +36,7 @@ export summarize # Export diagnostics functions export discretediag, gelmandiag, gewekediag, heideldiag, rafterydiag -export hpd, ess +export hpd, ess, dic export rstar diff --git a/src/modelstats.jl b/src/modelstats.jl index a8e6a732..b3f95dce 100644 --- a/src/modelstats.jl +++ b/src/modelstats.jl @@ -1,49 +1,13 @@ -export dic - #################### Posterior Statistics #################### """ - dic(chain::Chains, logpdf::Function) -> (DIC, pD) + dic(chain::Chains, loglik::Symbol) -Compute the deviance information criterion (DIC). -(Smaller is better) +Compute the deviance information criterion (DIC) from `chain` on posterior log likelihood samples specified by parameter name `loglik`. Note: DIC assumes that the posterior distribution is approx. multivariate Gaussian and tends to select overfitted models. - -## Returns: -* `DIC`: The calculated deviance information criterion -* `pD`: The effective number of parameters - -## Usage: - -``` -chn ... # sampling results -lpfun = function f(chain::Chains) # function to compute the logpdf values - niter, nparams, nchains = size(chain) - lp = zeros(niter + nchains) # resulting logpdf values - for i = 1:nparams - lp += map(p -> logpdf( ... , x), Array(chain[:,i,:])) - end - return lp -end - -DIC, pD = dic(chn, lpfun) -``` - """ -function dic(chain::Chains, logpdf::Function) - - # expectation of each parameter - Eθ = reshape(mean(Array(chain), dims = [1,3]), 1,:,1) - Echain = Chains(Eθ) - EθD = -2*mean(logpdf(Echain)) - - D = -2*logpdf(chain) - ED = mean(D) - - pD = 2*(ED - EθD) - - DIC = EθD + pD - - return DIC, pD +function StatsModelComparisons.dic(chain::Chains, loglik::Symbol) + lps = Array(chain[:, loglik, :]) + return dic(vec(lps)) end diff --git a/test/modelstats_test.jl b/test/modelstats_test.jl index e8c21963..d9d63ca8 100644 --- a/test/modelstats_test.jl +++ b/test/modelstats_test.jl @@ -2,43 +2,11 @@ using Test, Random using MCMCChains using Distributions -## Test Chain -# Define the experiment -n_iter = 4000 -n_name = 3 -n_chain = 2 - -# observations -Random.seed!(1234) -x = map(i -> i + 2e-1*randn(), 1:n_name) - -# some sample experiment results -Random.seed!(1234) -val1 = 0.5*randn(n_iter, n_name, n_chain) .+ x' -val2 = 2*randn(n_iter, n_name, n_chain) .+ x' - -# construct a Chains object -chn1 = Chains(val1) -chn2 = Chains(val2) - -lpfun = function f(chain::Chains) - - p1 = Array(chain[:,1,:]) - p2 = Array(chain[:,2,:]) - p3 = Array(chain[:,3,:]) - - lp = map(p -> logpdf(Normal(p), x[1]), p1) - lp += map(p -> logpdf(Normal(p), x[2]), p2) - lp += map(p -> logpdf(Normal(p), x[3]), p3) - - return lp -end - @testset "deviance information criterion" begin - - DIC1, pD1 = dic(chn1, lpfun) - DIC2, pD2 = dic(chn2, lpfun) - - @test DIC1 < DIC2 - @test pD1 < pD2 + # Ensure function at least runs + chain = Chains(rand(100, 2, 1), [:a, :b]) + val = dic(chain, :a) + @test isa(val, Float64) + # Should fail if variable does not exist + @test_throws ArgumentError dic(chain, :c) end