In [1]:
import IJulia

# The julia kernel has built in support for Revise.jl, so this is the 
# recommended approach for long-running sessions:
# https://github.com/JuliaLang/IJulia.jl/blob/9b10fa9b879574bbf720f5285029e07758e50a5e/src/kernel.jl#L46-L51

# Users should enable revise within .julia/config/startup_ijulia.jl:
# https://timholy.github.io/Revise.jl/stable/config/#Using-Revise-automatically-within-Jupyter/IJulia-1

# clear console history
IJulia.clear_history()

fig_width = 7
fig_height = 5
fig_format = :retina
fig_dpi = 96

# no retina format type, use svg for high quality type/marks
if fig_format == :retina
  fig_format = :svg
elseif fig_format == :pdf
  fig_dpi = 96
  # Enable PDF support for IJulia
  IJulia.register_mime(MIME("application/pdf"))
end

# convert inches to pixels
fig_width = fig_width * fig_dpi
fig_height = fig_height * fig_dpi

# Intialize Plots w/ default fig width/height
try
  import Plots

  # Plots.jl doesn't support PDF output for versions < 1.28.1
  # so use png (if the DPI remains the default of 300 then set to 96)
  if (Plots._current_plots_version < v"1.28.1") & (fig_format == :pdf)
    Plots.gr(size=(fig_width, fig_height), fmt = :png, dpi = fig_dpi)
  else
    Plots.gr(size=(fig_width, fig_height), fmt = fig_format, dpi = fig_dpi)
  end
catch e
  # @warn "Plots init" exception=(e, catch_backtrace())
end

# Initialize CairoMakie with default fig width/height
try
  import CairoMakie

  # CairoMakie's display() in PDF format opens an interactive window
  # instead of saving to the ipynb file, so we don't do that.
  # https://github.com/quarto-dev/quarto-cli/issues/7548
  if fig_format == :pdf
    CairoMakie.activate!(type = "png")
  else
    CairoMakie.activate!(type = string(fig_format))
  end
  CairoMakie.update_theme!(resolution=(fig_width, fig_height))
catch e
    # @warn "CairoMakie init" exception=(e, catch_backtrace())
end
  
# Set run_path if specified
try
  run_path = raw"/Users/hirofumi48/162348.github.io/posts/2025/Comp"
  if !isempty(run_path)
    cd(run_path)
  end
catch e
  @warn "Run path init:" exception=(e, catch_backtrace())
end


# emulate old Pkg.installed beahvior, see
# https://discourse.julialang.org/t/how-to-use-pkg-dependencies-instead-of-pkg-installed/36416/9
import Pkg
function isinstalled(pkg::String)
  any(x -> x.name == pkg && x.is_direct_dep, values(Pkg.dependencies()))
end

# ojs_define
if isinstalled("JSON") && isinstalled("DataFrames")
  import JSON, DataFrames
  global function ojs_define(; kwargs...)
    convert(x) = x
    convert(x::DataFrames.AbstractDataFrame) = Tables.rows(x)
    content = Dict("contents" => [Dict("name" => k, "value" => convert(v)) for (k, v) in kwargs])
    tag = "<script type='ojs-define'>$(JSON.json(content))</script>"
    IJulia.display(MIME("text/html"), tag)
  end
elseif isinstalled("JSON")
  import JSON
  global function ojs_define(; kwargs...)
    content = Dict("contents" => [Dict("name" => k, "value" => v) for (k, v) in kwargs])
    tag = "<script type='ojs-define'>$(JSON.json(content))</script>"
    IJulia.display(MIME("text/html"), tag)
  end
else
  global function ojs_define(; kwargs...)
    @warn "JSON package not available. Please install the JSON.jl package to use ojs_define."
  end
end


# don't return kernel dependencies (b/c Revise should take care of dependencies)
nothing


In [2]:
n, p, pₑ = 200, 50, 10

using Random, StatsFuns, Distributions
β_true = vcat(randn(pₑ), zeros(p - pₑ))
X = randn(n, p)

η_true = X * β_true
π_true = logistic.(η_true)

y = rand.(Bernoulli.(π_true))
y = collect(Float64, y)

200-element Vector{Float64}:
 0.0
 0.0
 0.0
 1.0
 0.0
 1.0
 0.0
 1.0
 0.0
 0.0
 0.0
 1.0
 0.0
 ⋮
 0.0
 1.0
 1.0
 1.0
 0.0
 1.0
 0.0
 1.0
 1.0
 0.0
 0.0
 0.0

In [3]:
using PolyaGammaHybridSamplers, LinearAlgebra, MCMCChains

function pg_logistic_gibbs(
    X::Matrix{Float64},
    y::Vector{Float64};
    n_iter::Int = 5000,
    burnin::Int = 1000,
    σ_prior::Float64 = 10.0,
)
    n, p = size(X)

    # 事前: β ~ N(0, σ_prior^2 I)
    V0_inv = (1 / σ_prior^2) * LinearAlgebra.I  # precision of prior

    # 初期値
    β = zeros(p)
    κ = y .- 0.5  # κ_i = y_i - 1/2

    # サンプル保存用
    n_save = n_iter - burnin
    β_samples = Matrix{Float64}(undef, n_save, p)

    runtime_sec = @elapsed begin
        for it in 1:n_iter
            # 1. PG 補助変数 ω_i | β のサンプル
            η = X * β
            ω = similar(η)
            for i in 1:n
                pg = PolyaGammaHybridSampler(1.0, η[i])
                ω[i] = rand(pg)
            end

            # 2. β | ω, y のサンプル (多変量ガウス)
            Ω = Diagonal(ω)
            precision = X' * Ω * X + V0_inv          # posterior precision
            cov = inv(Matrix(precision))             # posterior covariance
            m = cov * (X' * κ)                       # posterior mean (μ0=0 のため)

            # β ~ N(m, cov)
            β = rand(MvNormal(m, Symmetric(cov)))

            # burn-in 後に保存
            if it > burnin
                β_samples[it - burnin, :] .= β
            end
        end
    end

    names = Symbol.("β[$i]" for i in 1:p)
    values = reshape(β_samples, n_save, p, 1)
    chain = Chains(values, names)
    ess_table = ess(chain; duration = _ -> runtime_sec)

    return chain, ess_table, runtime_sec
end

pg_logistic_gibbs (generic function with 1 method)

In [4]:
σ_prior = 10.0
chain_pg, ess_table, t_pg = pg_logistic_gibbs(X, y;
    n_iter = 6000,
    burnin = 1000,
    σ_prior = σ_prior,
)
summarize(chain_pg)
println(ess_table)

ESS (

50 x 3)


In [5]:
using Turing, LinearAlgebra

@model function logreg_turing(x, y, σ_prior)
    n, p = size(x)
    
    # 事前分布
    β ~ MvNormal(zeros(p), (σ_prior^2) * I)
    
    # ベクトル化した尤度（高速化）
    η = x * β
    y ~ arraydist(Bernoulli.(logistic.(η)))
end

model = logreg_turing(X, y, σ_prior)

DynamicPPL.Model{typeof(logreg_turing), (:x, :y, :σ_prior), (), (), Tuple{Matrix{Float64}, Vector{Float64}, Float64}, Tuple{}, DynamicPPL.DefaultContext}(Main.logreg_turing, (x = [-1.1187056682900378 -1.1269865616249308 … 0.43482916584479453 0.760907455586308; 0.9438931355932827 -1.1384020276548183 … 0.24349764269051588 2.5284229677281154; … ; -0.4296626465151024 1.1139798295491257 … -0.22198717758198588 0.39574695654685094; 0.3343798853171577 0.7150406437142572 … 0.3004970463094861 0.11768215498703054], y = [0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0  …  1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0], σ_prior = 10.0), NamedTuple(), DynamicPPL.DefaultContext())

In [6]:
n_samples = 200
n_adapt   = 100

chain_hmc = sample(
    model,
    NUTS(n_adapt, 0.6),
    n_samples,
)
summarize(chain_hmc)

[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39mFound initial step size
[36m[1m└ [22m[39m  ϵ = 0.8


[32mSampling:   7%|██▊                                      |  ETA: 0:00:01[39m

[32mSampling:  66%|██████████████████████████▉              |  ETA: 0:00:01[39m

[32mSampling:  74%|██████████████████████████████▏          |  ETA: 0:00:01[39m

[32mSampling:  81%|█████████████████████████████████▎       |  ETA: 0:00:00[39m

[32mSampling:  89%|████████████████████████████████████▍    |  ETA: 0:00:00[39m

[32mSampling:  96%|███████████████████████████████████████▎ |  ETA: 0:00:00[39m

[32mSampling: 100%|█████████████████████████████████████████| Time: 0:00:01[39m




 [1m parameters [0m [1m    mean [0m [1m     std [0m [1m    mcse [0m [1m ess_bulk [0m [1m ess_tail [0m [1m    rhat [0m [1m e[0m ⋯
 [90m     Symbol [0m [90m Float64 [0m [90m Float64 [0m [90m Float64 [0m [90m  Float64 [0m [90m  Float64 [0m [90m Float64 [0m [90m  [0m ⋯

        β[1]    8.1527    1.7256    0.2720    42.1212    64.3104    1.0115     ⋯
        β[2]   -3.1124    0.8229    0.1102    55.1294   104.4818    1.0050     ⋯
        β[3]    2.9031    1.1173    0.1602    44.4416   115.8658    1.0059     ⋯
        β[4]    2.2817    0.8007    0.1123    50.3272   104.1194    1.0305     ⋯
        β[5]   -1.9668    0.7441    0.0973    59.7141    62.8135    1.0044     ⋯
        β[6]    6.7309    1.4341    0.2385    37.7861    93.8319    1.0101     ⋯
        β[7]    4.3373    0.9459    0.1102    75.9565   112.3159    1.0193     ⋯
        β[8]   -0.8715    0.8061    0.0859    87.0212   152.8143    1.0040     ⋯
        β[9]    7.6997    1.8669    0.3282    33.30

In [7]:
using Statistics

# 真の β との誤差
mean_hmc = vec(mean(Array(chain_hmc), dims=1))  # ここは実際のパラメータ名に合わせて調整
mean_pg = vec(mean(Array(chain_pg), dims=1))

println("‖β̂_HMC - β_true‖₂ = ", norm(mean_hmc .- β_true))
println("‖β̂_PG  - β_true‖₂ = ", norm(mean_pg  .- β_true))

# ランタイムや ESS の比較も：
ess_hmc = ess_rhat(chain_hmc)
ess_pg  = ess_rhat(chain_pg)

‖β̂_HMC - β_true‖₂ = 16.27084740860733


‖β̂_PG  - β_true‖₂ = 15.63124429533065


ESS/R-hat

 [1m parameters [0m [1m      ess [0m [1m    rhat [0m [1m ess_per_sec [0m
 [90m     Symbol [0m [90m  Float64 [0m [90m Float64 [0m [90m     Missing [0m

        β[1]    93.8193    1.0105       missing
        β[2]   147.3975    1.0126       missing
        β[3]   209.1936    1.0010       missing
        β[4]   123.9315    1.0087       missing
        β[5]   193.6608    1.0050       missing
        β[6]    93.0069    1.0075       missing
        β[7]   136.0048    1.0102       missing
        β[8]   315.1963    1.0020       missing
        β[9]    93.6536    1.0103       missing
       β[10]   374.9009    1.0030       missing
       β[11]   489.9509    1.0002       missing
       β[12]   217.1356    1.0024       missing
       β[13]   526.3917    1.0014       missing
       β[14]   496.5095    1.0013       missing
       β[15]   410.8508    1.0042       missing
       β[16]   198.7431    1.0270       missing
       β[17]   227.4674    1.0022       missing
      