In [1]:
using StatisticalRethinking, Distributed, JLD, LinearAlgebra
using Mamba

Data

In [2]:
# Read the semicolon-delimited Howell1.csv from the `data` directory located one level
# above the StatisticalRethinking package sources, then convert the result to a DataFrame.
howell1 = let srcdir = dirname(Base.pathof(StatisticalRethinking))
  CSV.read(joinpath(srcdir, "..", "data", "Howell1.csv"), delim=';')
end
df = convert(DataFrame, howell1);

Use only adults

In [3]:
# Keep only the rows whose :age value is at least 18.
df2 = filter(df) do rec
  rec[:age] >= 18
end;

Input data for Mamba

In [4]:
# Collect the weight and height columns of `df2` as Float64 vectors under the keys
# :x and :y; `logfgrad` reads its observations from this dictionary.
data = Dict{Symbol,Vector{Float64}}(
  :x => convert(Vector{Float64}, df2[:weight]),
  :y => convert(Vector{Float64}, df2[:height]),
)

Dict{Symbol,Array{Float64,1}} with 2 entries:
  :y => [151.765, 139.7, 136.525, 156.845, 145.415, 163.83, 149.225, 168.91, 14…
  :x => [47.8256, 36.4858, 31.8648, 53.0419, 41.2769, 62.9926, 38.2435, 55.48, …

Log-transformed posterior density of (b0, b1, log(s2)), up to an additive constant, and its gradient vector

In [5]:
# Evaluate the log-transformed posterior density (up to an additive constant) and its
# gradient at the point x = [b0, b1, log(s2)].
#
# The observations are read from the global `data` dictionary: `data[:x]` holds the
# regressor values and `data[:y]` the responses of the line b0 + b1 * x.
#
# Returns the tuple (logf, grad), where `grad` contains the partial derivatives of
# `logf` with respect to b0, b1 and log(s2), in that order.
logfgrad = function(x::DenseVector)
  intercept, slope, logvar = x[1], x[2], x[3]

  # Residuals of the fitted line and the variance on its natural scale.
  resid = data[:y] .- intercept .- slope .* data[:x]
  sigma2 = exp(logvar)

  # Quantities shared between the density and the log(s2) derivative.
  coeff = -0.5 * length(data[:y]) - 0.001
  scaledss = (0.5 * dot(resid, resid) + 0.001) / sigma2

  logf = coeff * logvar - scaledss -
           0.5 * intercept^2 / 1000 - 0.5 * slope^2 / 1000

  dintercept = sum(resid) / sigma2 - intercept / 1000
  dslope = sum(data[:x] .* resid) / sigma2 - slope / 1000
  dlogvar = coeff + scaledss

  return logf, [dintercept, dslope, dlogvar]
end

# MCMC simulation with No-U-Turn Sampling (NUTS).
# The first `burnin` iterations run with adaptation switched on and are not stored;
# every later draw is written to `sim` with the third coordinate mapped from
# log(s2) back to s2.

n = 5000
burnin = 1000
sim = Mamba.Chains(n, 3, start = (burnin + 1), names = ["b0", "b1", "s2"])
theta = NUTSVariate([0.0, 0.0, 0.0], logfgrad)
for iter in 1:n
  warmup = iter <= burnin
  sample!(theta, adapt = warmup)
  # Warm-up draws are discarded.
  warmup && continue
  sim[iter, :, 1] = vcat(theta[1:2], exp(theta[3]))
end

Summarize draws

In [6]:
# Print the empirical posterior estimates and quantiles of the stored draws.
describe(sim)

Iterations = 1001:5000
Thinning interval = 1
Chains = 1
Samples per chain = 4000

Empirical Posterior Estimates:
       Mean        SD       Naive SE       MCSE         ESS   
b0 113.4232857 1.89784414 0.0300075506 0.1010213760  352.93484
b1   0.9150544 0.04180159 0.0006609412 0.0021835237  366.49648
s2  26.0642344 2.01613273 0.0318778574 0.0526214401 1467.95513

Quantiles:
       2.5%       25.0%       50.0%       75.0%        97.5%   
b0 109.7526687 112.2295555 113.3168959 114.7483479 117.08561452
b1   0.8346281   0.8861921   0.9165202   0.9418675   0.99674447
s2  22.2877642  24.6765413  25.9804881  27.3608793  30.30793410



*This notebook was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*