In [1]:
import IJulia

# The julia kernel has built in support for Revise.jl, so this is the 
# recommended approach for long-running sessions:
# https://github.com/JuliaLang/IJulia.jl/blob/9b10fa9b879574bbf720f5285029e07758e50a5e/src/kernel.jl#L46-L51

# Users should enable revise within .julia/config/startup_ijulia.jl:
# https://timholy.github.io/Revise.jl/stable/config/#Using-Revise-automatically-within-Jupyter/IJulia-1

# clear console history
IJulia.clear_history()

fig_width = 7
fig_height = 5
fig_format = :retina
fig_dpi = 96

# no retina format type, use svg for high quality type/marks
if fig_format == :retina
  fig_format = :svg
elseif fig_format == :pdf
  fig_dpi = 96
  # Enable PDF support for IJulia
  IJulia.register_mime(MIME("application/pdf"))
end

# convert inches to pixels
fig_width = fig_width * fig_dpi
fig_height = fig_height * fig_dpi

# Intialize Plots w/ default fig width/height
try
  import Plots

  # Plots.jl doesn't support PDF output for versions < 1.28.1
  # so use png (if the DPI remains the default of 300 then set to 96)
  if (Plots._current_plots_version < v"1.28.1") & (fig_format == :pdf)
    Plots.gr(size=(fig_width, fig_height), fmt = :png, dpi = fig_dpi)
  else
    Plots.gr(size=(fig_width, fig_height), fmt = fig_format, dpi = fig_dpi)
  end
catch e
  # @warn "Plots init" exception=(e, catch_backtrace())
end

# Initialize CairoMakie with default fig width/height
try
  import CairoMakie

  # CairoMakie's display() in PDF format opens an interactive window
  # instead of saving to the ipynb file, so we don't do that.
  # https://github.com/quarto-dev/quarto-cli/issues/7548
  if fig_format == :pdf
    CairoMakie.activate!(type = "png")
  else
    CairoMakie.activate!(type = string(fig_format))
  end
  CairoMakie.update_theme!(resolution=(fig_width, fig_height))
catch e
    # @warn "CairoMakie init" exception=(e, catch_backtrace())
end
  
# Set run_path if specified
try
  run_path = raw"/Users/hirofumi48/162348.github.io/posts/2025/Comp"
  if !isempty(run_path)
    cd(run_path)
  end
catch e
  @warn "Run path init:" exception=(e, catch_backtrace())
end


# emulate old Pkg.installed beahvior, see
# https://discourse.julialang.org/t/how-to-use-pkg-dependencies-instead-of-pkg-installed/36416/9
import Pkg
function isinstalled(pkg::String)
  any(x -> x.name == pkg && x.is_direct_dep, values(Pkg.dependencies()))
end

# ojs_define
if isinstalled("JSON") && isinstalled("DataFrames")
  import JSON, DataFrames
  global function ojs_define(; kwargs...)
    convert(x) = x
    convert(x::DataFrames.AbstractDataFrame) = Tables.rows(x)
    content = Dict("contents" => [Dict("name" => k, "value" => convert(v)) for (k, v) in kwargs])
    tag = "<script type='ojs-define'>$(JSON.json(content))</script>"
    IJulia.display(MIME("text/html"), tag)
  end
elseif isinstalled("JSON")
  import JSON
  global function ojs_define(; kwargs...)
    content = Dict("contents" => [Dict("name" => k, "value" => v) for (k, v) in kwargs])
    tag = "<script type='ojs-define'>$(JSON.json(content))</script>"
    IJulia.display(MIME("text/html"), tag)
  end
else
  global function ojs_define(; kwargs...)
    @warn "JSON package not available. Please install the JSON.jl package to use ojs_define."
  end
end


# don't return kernel dependencies (b/c Revise should take care of dependencies)
nothing


In [2]:
#| output: false
n, p, pₑ = 200, 50, 10

using Random, StatsFuns, Distributions
β_true = vcat(randn(pₑ), zeros(p - pₑ))
X = randn(n, p)

η_true = X * β_true
π_true = logistic.(η_true)

y = rand.(Bernoulli.(π_true))
y = collect(Float64, y)

200-element Vector{Float64}:
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 1.0
 0.0
 1.0
 1.0
 0.0
 1.0
 ⋮
 0.0
 0.0
 1.0
 0.0
 1.0
 1.0
 0.0
 0.0
 1.0
 0.0
 0.0
 1.0

In [3]:
#| output: false
using PolyaGammaHybridSamplers, LinearAlgebra, MCMCChains, Dates, MCMCDiagnosticTools

function pg_logistic_gibbs(
  X::Matrix{Float64},
  y::Vector{Float64};
  n_iter::Int = 5000,
  burnin::Int = 1000,
  σ_prior::Float64 = 10.0,
)
  n, p = size(X)

  # 事前: β ~ N(0, σ_prior^2 I)
  V0_inv = (1 / σ_prior^2) * LinearAlgebra.I  # precision of prior

  # 初期値
  β = zeros(p)
  κ = y .- 0.5  # κ_i = y_i - 1/2

  # サンプル保存用
  n_save = n_iter - burnin
  β_samples = Matrix{Float64}(undef, n_save, p)

  t_start = time()
  for it in 1:n_iter
    # 1. PG 補助変数 ω_i | β のサンプル
    η = X * β
    ω = similar(η)
    for i in 1:n
      pg = PolyaGammaHybridSampler(1.0, η[i])
      ω[i] = rand(pg)
    end

    # 2. β | ω, y のサンプル (多変量ガウス)
    Ω = Diagonal(ω)
    precision = X' * Ω * X + V0_inv          # posterior precision
    cov = inv(Matrix(precision))             # posterior covariance
    m = cov * (X' * κ)                       # posterior mean (μ0=0 のため)

    # β ~ N(m, cov)
    β = rand(MvNormal(m, Symmetric(cov)))

    # burn-in 後に保存
    if it > burnin
      β_samples[it - burnin, :] .= β
    end
  end
  t_stop = time()
  runtime_sec = t_stop - t_start

  names = Symbol.("β[$i]" for i in 1:p)
  values = reshape(β_samples, n_save, p, 1)
  chain = Chains(values, names)
  chain = setinfo(chain, (
    start_time = [t_start],  # 1本チェインなら長さ1のベクトルでOK
    stop_time  = [t_stop],
  ))

  return chain, runtime_sec
end

pg_logistic_gibbs (generic function with 1 method)

In [4]:
σ_prior = 10.0
chain_pg, t_pg = pg_logistic_gibbs(X, y;
    n_iter = 6000,
    burnin = 1000,
    σ_prior = σ_prior,
)
summarize(chain_pg)



 [1m parameters [0m [1m     mean [0m [1m     std [0m [1m    mcse [0m [1m ess_bulk [0m [1m ess_tail [0m [1m    rhat [0m [1m [0m ⋯
 [90m     Symbol [0m [90m  Float64 [0m [90m Float64 [0m [90m Float64 [0m [90m  Float64 [0m [90m  Float64 [0m [90m Float64 [0m [90m [0m ⋯

        β[1]   -10.0344    2.0678    0.3056    45.4967   156.5714    1.0300    ⋯
        β[2]     5.9530    1.5507    0.2234    48.1918   296.8149    1.0450    ⋯
        β[3]     9.3878    2.0021    0.3129    41.5495   244.8472    1.0427    ⋯
        β[4]     5.0165    1.1829    0.1281    83.2698   267.7183    1.0158    ⋯
        β[5]     5.2013    1.7377    0.2494    47.9734   262.7303    1.0356    ⋯
        β[6]   -10.0213    1.9507    0.2854    45.3661   249.9768    1.0406    ⋯
        β[7]    -7.1532    1.7091    0.2154    62.9335   170.7253    1.0309    ⋯
        β[8]    -9.7942    1.9694    0.2451    66.0353   230.7863    1.0391    ⋯
        β[9]    -1.6939    0.9746    0.0618   250.6

In [5]:
#| output: false
using Turing, LinearAlgebra

@model function logreg_turing(x, y, σ_prior)
    n, p = size(x)
    
    # 事前分布
    β ~ MvNormal(zeros(p), (σ_prior^2) * I)
    
    # ベクトル化した尤度（高速化）
    η = x * β
    y ~ arraydist(Bernoulli.(logistic.(η)))
end

model = logreg_turing(X, y, σ_prior)

DynamicPPL.Model{typeof(logreg_turing), (:x, :y, :σ_prior), (), (), Tuple{Matrix{Float64}, Vector{Float64}, Float64}, Tuple{}, DynamicPPL.DefaultContext}(Main.logreg_turing, (x = [0.4768785849580158 -0.8568831005204172 … -0.3230897118074397 1.5131801356517873; 0.6586469284439959 -0.016144067105212836 … -1.8455990854935465 1.7381473442455606; … ; 2.844280173095508 1.4299437496416676 … 0.6923680078774691 0.6867362220667922; 0.3869496582192621 1.5465078871871303 … 0.8465924039417897 0.3739235402854992], y = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0  …  1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0], σ_prior = 10.0), NamedTuple(), DynamicPPL.DefaultContext())

In [6]:
#| output: false
n_samples = 200
n_adapt   = 100

chain_hmc = sample(
    model,
    NUTS(n_adapt, 0.6),
    n_samples,
)

[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39mFound initial step size
[36m[1m└ [22m[39m  ϵ = 0.8


[32mSampling:   7%|██▊                                      |  ETA: 0:00:02[39m

[32mSampling:  17%|██████▊                                  |  ETA: 0:00:03[39m

[32mSampling:  23%|█████████▍                               |  ETA: 0:00:03[39m

[32mSampling:  28%|███████████▍                             |  ETA: 0:00:02[39m

[32mSampling:  32%|█████████████▏                           |  ETA: 0:00:02[39m

[32mSampling:  36%|██████████████▊                          |  ETA: 0:00:02[39m

[32mSampling:  40%|████████████████▍                        |  ETA: 0:00:02[39m

[32mSampling:  45%|██████████████████▎                      |  ETA: 0:00:02[39m

[32mSampling:  50%|████████████████████▍                    |  ETA: 0:00:01[39m

[32mSampling:  55%|██████████████████████▍                  |  ETA: 0:00:01[39m

[32mSampling:  60%|████████████████████████▍                |  ETA: 0:00:01[39m

[32mSampling:  65%|██████████████████████████▌              |  ETA: 0:00:01[39m

[32mSampling:  70%|████████████████████████████▌            |  ETA: 0:00:01[39m

[32mSampling:  75%|██████████████████████████████▋          |  ETA: 0:00:01[39m

[32mSampling:  80%|████████████████████████████████▋        |  ETA: 0:00:01[39m

[32mSampling:  85%|██████████████████████████████████▋      |  ETA: 0:00:00[39m

[32mSampling:  89%|████████████████████████████████████▌    |  ETA: 0:00:00[39m

[32mSampling:  93%|██████████████████████████████████████▏  |  ETA: 0:00:00[39m

[32mSampling:  97%|███████████████████████████████████████▊ |  ETA: 0:00:00[39m

[32mSampling: 100%|█████████████████████████████████████████| Time: 0:00:02[39m


Chains MCMC chain (200×64×1 Array{Float64, 3}):

Iterations        = 101:1:300
Number of chains  = 1
Samples per chain = 200
Wall duration     = 7.45 seconds
Compute duration  = 7.45 seconds
parameters        = β[1], β[2], β[3], β[4], β[5], β[6], β[7], β[8], β[9], β[10], β[11], β[12], β[13], β[14], β[15], β[16], β[17], β[18], β[19], β[20], β[21], β[22], β[23], β[24], β[25], β[26], β[27], β[28], β[29], β[30], β[31], β[32], β[33], β[34], β[35], β[36], β[37], β[38], β[39], β[40], β[41], β[42], β[43], β[44], β[45], β[46], β[47], β[48], β[49], β[50]
internals         = n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size, lp, logprior, loglikelihood

Use `describe(chains)` for summary statistics and quantiles.


In [7]:
summarize(chain_hmc)



 [1m parameters [0m [1m     mean [0m [1m     std [0m [1m    mcse [0m [1m ess_bulk [0m [1m ess_tail [0m [1m    rhat [0m [1m [0m ⋯
 [90m     Symbol [0m [90m  Float64 [0m [90m Float64 [0m [90m Float64 [0m [90m  Float64 [0m [90m  Float64 [0m [90m Float64 [0m [90m [0m ⋯

        β[1]   -10.8624    2.1855    0.2986    52.0479   116.3597    1.0056    ⋯
        β[2]     6.0145    1.5147    0.1829    69.0917   149.0770    1.0133    ⋯
        β[3]    10.0608    2.0227    0.2828    51.3087    76.7798    0.9971    ⋯
        β[4]     5.2886    1.3070    0.1490    77.5819   116.5805    1.0057    ⋯
        β[5]     4.9963    1.8020    0.1836    95.4826   132.5623    1.0199    ⋯
        β[6]   -10.5341    2.0171    0.2802    49.6085    67.7908    1.0055    ⋯
        β[7]    -7.8511    1.8391    0.2302    64.6170    97.7058    1.0011    ⋯
        β[8]   -10.2871    1.9171    0.2564    55.4474   103.4417    0.9987    ⋯
        β[9]    -1.7836    0.9526    0.0805   148.3

In [8]:
using Statistics

# 真の β との誤差
mean_hmc = vec(mean(Array(chain_hmc), dims=1))  # ここは実際のパラメータ名に合わせて調整
mean_pg = vec(mean(Array(chain_pg), dims=1))

println("‖β̂_HMC - β_true‖₂ = ", norm(mean_hmc .- β_true))
println("‖β̂_PG  - β_true‖₂ = ", norm(mean_pg  .- β_true))

# ランタイムや ESS の比較も：
ess_hmc = ess_rhat(chain_hmc)
ess_pg  = ess_rhat(chain_pg)

‖β̂_HMC - β_true‖₂ = 22.859266028123894


‖β̂_PG  - β_true‖₂ = 21.436798077254295


ESS/R-hat

 [1m parameters [0m [1m      ess [0m [1m    rhat [0m [1m ess_per_sec [0m
 [90m     Symbol [0m [90m  Float64 [0m [90m Float64 [0m [90m     Float64 [0m

        β[1]    45.4967    1.0300       40.3339
        β[2]    48.1918    1.0450       42.7232
        β[3]    41.5495    1.0427       36.8346
        β[4]    83.2698    1.0158       73.8207
        β[5]    47.9734    1.0356       42.5296
        β[6]    45.3661    1.0406       40.2182
        β[7]    62.9335    1.0309       55.7921
        β[8]    66.0353    1.0391       58.5420
        β[9]   250.6044    1.0030      222.1670
       β[10]   205.6471    1.0022      182.3112
       β[11]   220.4787    1.0080      195.4598
       β[12]   228.6893    1.0005      202.7388
       β[13]   136.4203    1.0014      120.9399
       β[14]   180.1687    1.0022      159.7241
       β[15]   278.7472    1.0012      247.1163
       β[16]   102.9826    1.0169       91.2966
       β[17]   260.7968    1.0013      231.2029
      