In [1]:
import IJulia

# The julia kernel has built in support for Revise.jl, so this is the 
# recommended approach for long-running sessions:
# https://github.com/JuliaLang/IJulia.jl/blob/9b10fa9b879574bbf720f5285029e07758e50a5e/src/kernel.jl#L46-L51

# Users should enable revise within .julia/config/startup_ijulia.jl:
# https://timholy.github.io/Revise.jl/stable/config/#Using-Revise-automatically-within-Jupyter/IJulia-1

# clear console history
IJulia.clear_history()

fig_width = 7
fig_height = 5
fig_format = :retina
fig_dpi = 96

# no retina format type, use svg for high quality type/marks
if fig_format == :retina
  fig_format = :svg
elseif fig_format == :pdf
  fig_dpi = 96
  # Enable PDF support for IJulia
  IJulia.register_mime(MIME("application/pdf"))
end

# convert inches to pixels
fig_width = fig_width * fig_dpi
fig_height = fig_height * fig_dpi

# Intialize Plots w/ default fig width/height
try
  import Plots

  # Plots.jl doesn't support PDF output for versions < 1.28.1
  # so use png (if the DPI remains the default of 300 then set to 96)
  if (Plots._current_plots_version < v"1.28.1") & (fig_format == :pdf)
    Plots.gr(size=(fig_width, fig_height), fmt = :png, dpi = fig_dpi)
  else
    Plots.gr(size=(fig_width, fig_height), fmt = fig_format, dpi = fig_dpi)
  end
catch e
  # @warn "Plots init" exception=(e, catch_backtrace())
end

# Initialize CairoMakie with default fig width/height
try
  import CairoMakie

  # CairoMakie's display() in PDF format opens an interactive window
  # instead of saving to the ipynb file, so we don't do that.
  # https://github.com/quarto-dev/quarto-cli/issues/7548
  if fig_format == :pdf
    CairoMakie.activate!(type = "png")
  else
    CairoMakie.activate!(type = string(fig_format))
  end
  CairoMakie.update_theme!(resolution=(fig_width, fig_height))
catch e
    # @warn "CairoMakie init" exception=(e, catch_backtrace())
end
  
# Set run_path if specified
try
  run_path = raw"/Users/hirofumi48/162348.github.io/posts/2025/Comp"
  if !isempty(run_path)
    cd(run_path)
  end
catch e
  @warn "Run path init:" exception=(e, catch_backtrace())
end


# emulate old Pkg.installed beahvior, see
# https://discourse.julialang.org/t/how-to-use-pkg-dependencies-instead-of-pkg-installed/36416/9
import Pkg
function isinstalled(pkg::String)
  any(x -> x.name == pkg && x.is_direct_dep, values(Pkg.dependencies()))
end

# ojs_define
if isinstalled("JSON") && isinstalled("DataFrames")
  import JSON, DataFrames
  global function ojs_define(; kwargs...)
    convert(x) = x
    convert(x::DataFrames.AbstractDataFrame) = Tables.rows(x)
    content = Dict("contents" => [Dict("name" => k, "value" => convert(v)) for (k, v) in kwargs])
    tag = "<script type='ojs-define'>$(JSON.json(content))</script>"
    IJulia.display(MIME("text/html"), tag)
  end
elseif isinstalled("JSON")
  import JSON
  global function ojs_define(; kwargs...)
    content = Dict("contents" => [Dict("name" => k, "value" => v) for (k, v) in kwargs])
    tag = "<script type='ojs-define'>$(JSON.json(content))</script>"
    IJulia.display(MIME("text/html"), tag)
  end
else
  global function ojs_define(; kwargs...)
    @warn "JSON package not available. Please install the JSON.jl package to use ojs_define."
  end
end


# don't return kernel dependencies (b/c Revise should take care of dependencies)
nothing


In [2]:
n, p, pₑ = 200, 50, 10

using Random, StatsFuns, Distributions
β_true = vcat(randn(pₑ), zeros(p - pₑ))
X = randn(n, p)

η_true = X * β_true
π_true = logistic.(η_true)

y = rand.(Bernoulli.(π_true))
y = collect(Float64, y)

200-element Vector{Float64}:
 1.0
 1.0
 0.0
 1.0
 1.0
 1.0
 1.0
 0.0
 0.0
 0.0
 1.0
 1.0
 1.0
 ⋮
 1.0
 1.0
 1.0
 1.0
 1.0
 0.0
 0.0
 0.0
 0.0
 0.0
 1.0
 1.0

In [3]:
using PolyaGammaHybridSamplers, LinearAlgebra, MCMCChains, Dates, MCMCDiagnosticTools

function pg_logistic_gibbs(
  X::Matrix{Float64},
  y::Vector{Float64};
  n_iter::Int = 5000,
  burnin::Int = 1000,
  σ_prior::Float64 = 10.0,
)
  n, p = size(X)

  # 事前: β ~ N(0, σ_prior^2 I)
  V0_inv = (1 / σ_prior^2) * LinearAlgebra.I  # precision of prior

  # 初期値
  β = zeros(p)
  κ = y .- 0.5  # κ_i = y_i - 1/2

  # サンプル保存用
  n_save = n_iter - burnin
  β_samples = Matrix{Float64}(undef, n_save, p)

  t_start = time()
  for it in 1:n_iter
    # 1. PG 補助変数 ω_i | β のサンプル
    η = X * β
    ω = similar(η)
    for i in 1:n
      pg = PolyaGammaHybridSampler(1.0, η[i])
      ω[i] = rand(pg)
    end

    # 2. β | ω, y のサンプル (多変量ガウス)
    Ω = Diagonal(ω)
    precision = X' * Ω * X + V0_inv          # posterior precision
    cov = inv(Matrix(precision))             # posterior covariance
    m = cov * (X' * κ)                       # posterior mean (μ0=0 のため)

    # β ~ N(m, cov)
    β = rand(MvNormal(m, Symmetric(cov)))

    # burn-in 後に保存
    if it > burnin
      β_samples[it - burnin, :] .= β
    end
  end
  t_stop = time()
  runtime_sec = t_stop - t_start

  names = Symbol.("β[$i]" for i in 1:p)
  values = reshape(β_samples, n_save, p, 1)
  chain = Chains(values, names)
  chain = setinfo(chain, (
    start_time = [t_start],  # 1本チェインなら長さ1のベクトルでOK
    stop_time  = [t_stop],
  ))

  return chain, runtime_sec
end

pg_logistic_gibbs (generic function with 1 method)

In [4]:
σ_prior = 10.0
chain_pg, t_pg = pg_logistic_gibbs(X, y;
    n_iter = 6000,
    burnin = 1000,
    σ_prior = σ_prior,
)
summarize(chain_pg)



 [1m parameters [0m [1m    mean [0m [1m     std [0m [1m    mcse [0m [1m ess_bulk [0m [1m  ess_tail [0m [1m    rhat [0m [1m [0m ⋯
 [90m     Symbol [0m [90m Float64 [0m [90m Float64 [0m [90m Float64 [0m [90m  Float64 [0m [90m   Float64 [0m [90m Float64 [0m [90m [0m ⋯

        β[1]   -4.5623    1.0663    0.1010   114.9834    245.1213    1.0054    ⋯
        β[2]   -1.9332    0.7811    0.0503   243.7150    634.4626    1.0025    ⋯
        β[3]   -5.4152    1.1210    0.1077   107.9397    346.4901    1.0049    ⋯
        β[4]    5.1372    1.0597    0.0926   134.2516    330.1069    1.0007    ⋯
        β[5]   -1.1637    0.6500    0.0304   458.7530   1210.0802    1.0018    ⋯
        β[6]   -4.7075    0.8917    0.0735   147.1894    420.5086    1.0055    ⋯
        β[7]    3.8180    0.7904    0.0572   191.5226    590.3162    1.0021    ⋯
        β[8]    7.1938    1.2646    0.1233   106.3801    311.6481    1.0046    ⋯
        β[9]    4.7390    1.0691    0.0967   123.88

In [5]:
using Turing, LinearAlgebra

@model function logreg_turing(x, y, σ_prior)
    n, p = size(x)
    
    # 事前分布
    β ~ MvNormal(zeros(p), (σ_prior^2) * I)
    
    # ベクトル化した尤度（高速化）
    η = x * β
    y ~ arraydist(Bernoulli.(logistic.(η)))
end

model = logreg_turing(X, y, σ_prior)

DynamicPPL.Model{typeof(logreg_turing), (:x, :y, :σ_prior), (), (), Tuple{Matrix{Float64}, Vector{Float64}, Float64}, Tuple{}, DynamicPPL.DefaultContext}(Main.logreg_turing, (x = [-1.920710590967877 -0.4254873732592862 … -1.1712469468774718 1.3888179219627055; -0.7618220073987623 -0.73737265407344 … -0.3343958154563674 -1.4692969664995075; … ; 1.0362935802053204 0.6457359321956061 … 0.10234993997390054 0.0665273644228262; -0.5172525976397923 1.0811728531675424 … 1.6200192135754181 -0.16375074797993325], y = [1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0  …  1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0], σ_prior = 10.0), NamedTuple(), DynamicPPL.DefaultContext())

In [6]:
n_samples = 200
n_adapt   = 100

chain_hmc = sample(
    model,
    NUTS(n_adapt, 0.6),
    n_samples,
)
summarize(chain_hmc)

[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39mFound initial step size
[36m[1m└ [22m[39m  ϵ = 0.8


[32mSampling:   6%|██▎                                      |  ETA: 0:00:02[39m

[32mSampling:  41%|████████████████▋                        |  ETA: 0:00:01[39m

[32mSampling:  49%|███████████████████▉                     |  ETA: 0:00:01[39m

[32mSampling:  57%|███████████████████████▎                 |  ETA: 0:00:01[39m

[32mSampling:  65%|██████████████████████████▋              |  ETA: 0:00:01[39m

[32mSampling:  73%|█████████████████████████████▉           |  ETA: 0:00:00[39m

[32mSampling:  81%|█████████████████████████████████        |  ETA: 0:00:00[39m

[32mSampling:  89%|████████████████████████████████████▌    |  ETA: 0:00:00[39m

[32mSampling:  97%|███████████████████████████████████████▋ |  ETA: 0:00:00[39m

[32mSampling: 100%|█████████████████████████████████████████| Time: 0:00:01[39m




 [1m parameters [0m [1m    mean [0m [1m     std [0m [1m    mcse [0m [1m ess_bulk [0m [1m ess_tail [0m [1m    rhat [0m [1m e[0m ⋯
 [90m     Symbol [0m [90m Float64 [0m [90m Float64 [0m [90m Float64 [0m [90m  Float64 [0m [90m  Float64 [0m [90m Float64 [0m [90m  [0m ⋯

        β[1]   -4.4895    1.0807    0.1510    51.4850   115.4252    1.0412     ⋯
        β[2]   -1.7923    0.6929    0.0680   105.9273    99.2758    1.0009     ⋯
        β[3]   -5.1563    1.0966    0.1446    56.5585    97.8706    1.0116     ⋯
        β[4]    4.9742    1.1244    0.1824    41.9260   130.1523    1.0232     ⋯
        β[5]   -1.0627    0.6009    0.0423   199.4297   181.8165    0.9963     ⋯
        β[6]   -4.4956    0.7920    0.1115    48.9272    88.4193    1.0354     ⋯
        β[7]    3.6558    0.7329    0.1035    49.8661    86.5334    1.0331     ⋯
        β[8]    6.9685    1.1649    0.1969    34.7582    69.9585    1.0356     ⋯
        β[9]    4.5229    1.0392    0.1784    35.12

In [7]:
using Statistics

# 真の β との誤差
mean_hmc = vec(mean(Array(chain_hmc), dims=1))  # ここは実際のパラメータ名に合わせて調整
mean_pg = vec(mean(Array(chain_pg), dims=1))

println("‖β̂_HMC - β_true‖₂ = ", norm(mean_hmc .- β_true))
println("‖β̂_PG  - β_true‖₂ = ", norm(mean_pg  .- β_true))

# ランタイムや ESS の比較も：
ess_hmc = ess_rhat(chain_hmc)
ess_pg  = ess_rhat(chain_pg)

‖β̂_HMC - β_true‖₂ = 13.117744122347888


‖β̂_PG  - β_true‖₂ = 13.725005117014707


ESS/R-hat

 [1m parameters [0m [1m      ess [0m [1m    rhat [0m [1m ess_per_sec [0m
 [90m     Symbol [0m [90m  Float64 [0m [90m Float64 [0m [90m     Float64 [0m

        β[1]   114.9834    1.0054       93.4824
        β[2]   243.7150    1.0025      198.1423
        β[3]   107.9397    1.0049       87.7559
        β[4]   134.2516    1.0007      109.1476
        β[5]   458.7530    1.0018      372.9699
        β[6]   147.1894    1.0055      119.6662
        β[7]   191.5226    1.0021      155.7094
        β[8]   106.3801    1.0046       86.4879
        β[9]   123.8801    1.0072      100.7156
       β[10]   304.2753    1.0016      247.3783
       β[11]   231.1660    1.0013      187.9398
       β[12]   215.3553    1.0011      175.0856
       β[13]   624.5507    0.9998      507.7648
       β[14]   381.0539    1.0082      309.7999
       β[15]   167.3431    1.0012      136.0513
       β[16]   403.0752    1.0002      327.7034
       β[17]   605.9495    1.0001      492.6419
      