Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
87a9385
feat: add free-element Cholesky gradient engine for GGM NUTS
MaartenMarsman Mar 19, 2026
250aea8
feat: integrate NUTS sampler for GGM model
MaartenMarsman Mar 19, 2026
3645fdf
feat: add GGM NUTS correctness tests and fix gradient inversion
MaartenMarsman Mar 19, 2026
c3f0f98
feat: RATTLE constrained HMC and include_edge for sparse fixed graphs
MaartenMarsman Mar 20, 2026
9026029
feat: dense mass, gradient optimization, and mass-weighted RATTLE pro…
MaartenMarsman Mar 21, 2026
2a97542
refactor: remove null-space sampler and include_edge parameter
MaartenMarsman Mar 21, 2026
a1d6d02
Fix
MaartenMarsman Mar 21, 2026
51e8b5f
fix: split RATTLE projections, re-tune step size for constraints
MaartenMarsman Mar 22, 2026
0f7ec18
perf: PCG warm-start and profiling instrumentation
MaartenMarsman Mar 22, 2026
650b88f
perf: move semantics in NUTS tree and gradient workspace reuse
MaartenMarsman Mar 22, 2026
71227a0
perf: use triangular BLAS dispatch for gradient matrix multiply
MaartenMarsman Mar 22, 2026
7d0eff1
style: fix semicolon lint violations in RATTLE NUTS tests
MaartenMarsman Mar 22, 2026
54677af
refactor: remove profiling instrumentation
MaartenMarsman Mar 22, 2026
d82a1a8
refactor: remove step_size from user-facing output
MaartenMarsman Mar 22, 2026
c9c5328
feat: enable HMC for GGM with RATTLE constraint support
MaartenMarsman Mar 22, 2026
cd0675a
test: add HMC + GGM integration and RATTLE correctness tests
MaartenMarsman Mar 22, 2026
09ca7e7
feat: NUTS sampler for mixed MRF (Cholesky parameterization)
MaartenMarsman Mar 23, 2026
aa30bf7
fix: correct discrete slab prior scale in mixed MRF edge selection
MaartenMarsman Mar 23, 2026
4dd963c
feat: RATTLE constrained HMC infrastructure for mixed MRF
MaartenMarsman Mar 23, 2026
3eec4a8
test: mixed MRF leapfrog tests and mass-weighted projection wrappers
MaartenMarsman Mar 25, 2026
00c7e1f
fix: use indicator array for slab filtering in summarize_slab
MaartenMarsman Mar 25, 2026
1538d12
tests: add scaling and stress diagnostics for NUTS sampler
MaartenMarsman Mar 25, 2026
a920a64
fix: MH prior/Jacobian bugs in diagonal and off-diagonal updates
MaartenMarsman Mar 25, 2026
ee793b1
tests: relax mixed MRF PIP threshold to 0.20 for edge selection condi…
MaartenMarsman Mar 25, 2026
dee30e5
refactor: remove redundant C++ constructor defaults; use MY_LOG in SB…
MaartenMarsman Mar 25, 2026
2db21bc
Merge remote-tracking branch 'origin/main' into feature/ggm-nuts
MaartenMarsman Mar 25, 2026
f295395
fix: add MASS back to Suggests; style: apply bgms_style to R and test…
MaartenMarsman Mar 25, 2026
26eed09
fix: restore <<- in test warning handler broken by styler
MaartenMarsman Mar 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,7 @@ docs/*
/inst/doc/
dev/
/paper/

# testthat problem artefacts
tests/testthat/_problems/
tests/testthat/testthat-problems.rds
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ Suggests:
coda,
covr,
knitr,
MASS,
parallel,
qgraph,
rmarkdown,
Expand Down
52 changes: 50 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,34 @@ rcpp_ieee754_log <- function(x) {
.Call(`_bgms_rcpp_ieee754_log`, x)
}

ggm_test_logp_and_gradient <- function(theta, suf_stat, n, edge_indicators, pairwise_scale) {
.Call(`_bgms_ggm_test_logp_and_gradient`, theta, suf_stat, n, edge_indicators, pairwise_scale)
}

ggm_test_forward_map <- function(theta, edge_indicators) {
.Call(`_bgms_ggm_test_forward_map`, theta, edge_indicators)
}

ggm_test_project_position <- function(x, edge_indicators, inv_mass_in = NULL) {
.Call(`_bgms_ggm_test_project_position`, x, edge_indicators, inv_mass_in)
}

ggm_test_get_full_position <- function(Phi, edge_indicators) {
.Call(`_bgms_ggm_test_get_full_position`, Phi, edge_indicators)
}

ggm_test_logp_and_gradient_full <- function(x, suf_stat, n, edge_indicators, pairwise_scale) {
.Call(`_bgms_ggm_test_logp_and_gradient_full`, x, suf_stat, n, edge_indicators, pairwise_scale)
}

ggm_test_project_momentum <- function(r, x, edge_indicators, inv_mass_in = NULL) {
.Call(`_bgms_ggm_test_project_momentum`, r, x, edge_indicators, inv_mass_in)
}

ggm_test_leapfrog_constrained <- function(x0, r0, step_size, n_steps, suf_stat, n, edge_indicators, pairwise_scale, inv_mass_in = NULL) {
.Call(`_bgms_ggm_test_leapfrog_constrained`, x0, r0, step_size, n_steps, suf_stat, n, edge_indicators, pairwise_scale, inv_mass_in)
}

.compute_ess_cpp <- function(array3d) {
.Call(`_bgms_compute_ess_cpp`, array3d)
}
Expand All @@ -29,6 +57,26 @@ rcpp_ieee754_log <- function(x) {
.Call(`_bgms_compute_indicator_ess_cpp`, array3d)
}

mixed_test_logp_and_gradient <- function(params, discrete_observations, continuous_observations, num_categories, is_ordinal_variable, baseline_category, edge_indicators, pseudolikelihood, pairwise_scale) {
.Call(`_bgms_mixed_test_logp_and_gradient`, params, discrete_observations, continuous_observations, num_categories, is_ordinal_variable, baseline_category, edge_indicators, pseudolikelihood, pairwise_scale)
}

mixed_test_logp_and_gradient_full <- function(params, discrete_observations, continuous_observations, num_categories, is_ordinal_variable, baseline_category, edge_indicators, pseudolikelihood, pairwise_scale) {
.Call(`_bgms_mixed_test_logp_and_gradient_full`, params, discrete_observations, continuous_observations, num_categories, is_ordinal_variable, baseline_category, edge_indicators, pseudolikelihood, pairwise_scale)
}

mixed_test_project_position <- function(x, inv_mass, discrete_observations, continuous_observations, num_categories, is_ordinal_variable, baseline_category, edge_indicators, pseudolikelihood, pairwise_scale) {
.Call(`_bgms_mixed_test_project_position`, x, inv_mass, discrete_observations, continuous_observations, num_categories, is_ordinal_variable, baseline_category, edge_indicators, pseudolikelihood, pairwise_scale)
}

mixed_test_project_momentum <- function(r, x, inv_mass, discrete_observations, continuous_observations, num_categories, is_ordinal_variable, baseline_category, edge_indicators, pseudolikelihood, pairwise_scale) {
.Call(`_bgms_mixed_test_project_momentum`, r, x, inv_mass, discrete_observations, continuous_observations, num_categories, is_ordinal_variable, baseline_category, edge_indicators, pseudolikelihood, pairwise_scale)
}

mixed_test_leapfrog_constrained <- function(x0, r0, step_size, n_steps, discrete_observations, continuous_observations, num_categories, is_ordinal_variable, baseline_category, edge_indicators, pseudolikelihood, pairwise_scale, inv_mass_in = NULL) {
.Call(`_bgms_mixed_test_leapfrog_constrained`, x0, r0, step_size, n_steps, discrete_observations, continuous_observations, num_categories, is_ordinal_variable, baseline_category, edge_indicators, pseudolikelihood, pairwise_scale, inv_mass_in)
}

compute_conditional_ggm <- function(observations, predict_vars, precision) {
.Call(`_bgms_compute_conditional_ggm`, observations, predict_vars, precision)
}
Expand Down Expand Up @@ -69,8 +117,8 @@ run_mixed_simulation_parallel <- function(mux_samples, disc_samples, muy_samples
.Call(`_bgms_run_mixed_simulation_parallel`, mux_samples, disc_samples, muy_samples, cont_samples, cross_samples, draw_indices, num_states, p, q, num_categories, variable_type_r, baseline_category, iter, nThreads, seed, progress_type)
}

sample_ggm <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, edge_prior = "Bernoulli", beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, na_impute = FALSE, missing_index_nullable = NULL) {
.Call(`_bgms_sample_ggm`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, na_impute, missing_index_nullable)
sample_ggm <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior = "Bernoulli", beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, target_acceptance = 0.8, max_tree_depth = 10L, na_impute = FALSE, missing_index_nullable = NULL) {
.Call(`_bgms_sample_ggm`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, na_impute, missing_index_nullable)
}

sample_mixed_mrf <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, edge_prior = "Bernoulli", beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, sampler_type = "mh", target_acceptance = 0.80, max_tree_depth = 10L, num_leapfrogs = 100L, na_impute = FALSE, missing_index_discrete_nullable = NULL, missing_index_continuous_nullable = NULL) {
Expand Down
4 changes: 2 additions & 2 deletions R/bgm.R
Original file line number Diff line number Diff line change
Expand Up @@ -286,8 +286,8 @@
#' with Robbins–Monro proposal adaptation.}
#' \item{"hamiltonian-mc"}{Hamiltonian Monte Carlo with fixed path length
#' (number of leapfrog steps set by \code{hmc_num_leapfrogs}).}
#' \item{"nuts"}{The No-U-Turn Sampler, an adaptive form of HMC with
#' dynamically chosen trajectory lengths.}
#' \item{"nuts"}{The No-U-Turn Sampler with RATTLE constrained integration
#' for Gaussian models with edge selection.}
#' }
#' Default: \code{"nuts"}.
#'
Expand Down
18 changes: 8 additions & 10 deletions R/bgm_spec.R
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,6 @@ validate_bgm_spec = function(spec) {
if(!isTRUE(spec$variables$is_continuous)) {
stop("bgm_spec: model_type = 'ggm' requires is_continuous = TRUE.")
}
if(spec$sampler$update_method != "adaptive-metropolis") {
stop("bgm_spec: model_type = 'ggm' requires update_method = 'adaptive-metropolis'.")
}
}

# Compare invariants
Expand Down Expand Up @@ -215,7 +212,7 @@ validate_bgm_spec = function(spec) {
if(length(spec$data$num_categories) != spec$data$num_discrete) {
stop("bgm_spec: num_categories length doesn't match num_discrete.")
}
allowed = c("adaptive-metropolis", "hybrid-nuts")
allowed = c("adaptive-metropolis", "nuts", "hybrid-nuts")
if(!(spec$sampler$update_method %in% allowed)) {
stop(
"bgm_spec: model_type = 'mixed_mrf' requires update_method in ",
Expand Down Expand Up @@ -273,7 +270,8 @@ bgm_spec = function(x,
difference_probability = 0.5,
# Sampler
update_method = c(
"nuts", "adaptive-metropolis",
"nuts",
"adaptive-metropolis",
"hamiltonian-mc"
),
target_accept = NULL,
Expand Down Expand Up @@ -345,11 +343,11 @@ bgm_spec = function(x,
verbose = verbose
)

# Mixed MRF: remap "nuts" to the hybrid sampler that uses NUTS for the
# unconstrained block and component-wise MH for the SPD-constrained
# continuous precision.
if(is_mixed && sampler$update_method == "nuts") {
sampler$update_method = "hybrid-nuts"
# Mixed MRF: remap "hybrid-nuts" to "nuts" — the full NUTS sampler now
# handles all parameters including the continuous precision via Cholesky
# parameterization. Keep "hybrid-nuts" as a recognized alias.
if(is_mixed && sampler$update_method == "hybrid-nuts") {
sampler$update_method = "nuts"
}

# --- Build by model type ----------------------------------------------------
Expand Down
5 changes: 4 additions & 1 deletion R/build_output.R
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,9 @@ build_output_bgm = function(spec, raw) {
if(!is.null(chain$allocation_samples)) {
res$allocations = t(chain$allocation_samples)
}
if(!is.null(chain$treedepth)) res[["treedepth__"]] = chain$treedepth
if(!is.null(chain$divergent)) res[["divergent__"]] = chain$divergent
if(!is.null(chain$energy)) res[["energy__"]] = chain$energy
res
})
} else {
Expand Down Expand Up @@ -771,7 +774,7 @@ build_output_mixed_mrf = function(spec, raw) {
)

# --- NUTS diagnostics -------------------------------------------------------
if(s$update_method == "hybrid-nuts") {
if(s$update_method %in% c("nuts", "hybrid-nuts")) {
results$nuts_diag = summarize_nuts_diagnostics(
raw,
nuts_max_depth = s$nuts_max_depth
Expand Down
14 changes: 9 additions & 5 deletions R/mcmc_summary.R
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ summarize_indicator = function(fit, component = c("indicator_samples"), param_na
}

# Summarize slab values where indicators are 1
summarize_slab = function(fit, component = c("pairwise_samples"), param_names = NULL, array3d = NULL) {
summarize_slab = function(fit, component = c("pairwise_samples"), param_names = NULL, array3d = NULL, array3d_ind = NULL) {
component = match.arg(component) # Add options later
if(is.null(array3d)) array3d = combine_chains(fit, component)
nparam = dim(array3d)[3]
Expand All @@ -190,8 +190,12 @@ summarize_slab = function(fit, component = c("pairwise_samples"), param_names =
for(j in seq_len(nparam)) {
draws = array3d[, , j]
vec = as.vector(draws)
nonzero = vec != 0
vec = vec[nonzero]
if(!is.null(array3d_ind)) {
selected = as.vector(array3d_ind[, , j]) == 1
} else {
selected = vec != 0
}
vec = vec[selected]
n_total = length(vec)

if(n_total >= 1) {
Expand Down Expand Up @@ -230,7 +234,7 @@ summarize_pair = function(fit,
if(is.null(array3d_id)) array3d_id = combine_chains(fit, indicator_component)
if(is.null(array3d_pw)) array3d_pw = combine_chains(fit, slab_component)
if(is.null(summ_ind)) summ_ind = summarize_indicator(fit, component = indicator_component, array3d = array3d_id)
if(is.null(summ_slab)) summ_slab = summarize_slab(fit, component = slab_component, array3d = array3d_pw)
if(is.null(summ_slab)) summ_slab = summarize_slab(fit, component = slab_component, array3d = array3d_pw, array3d_ind = array3d_id)
nparam = nrow(summ_ind)

# EAP = indicator_mean * slab_mean.
Expand Down Expand Up @@ -277,7 +281,7 @@ summarize_fit = function(fit, edge_selection = FALSE) {

# Compute indicator and slab summaries once
ind_summary = summarize_indicator(fit, component = "indicator_samples", array3d = array3d_ind)
slab_summary = summarize_slab(fit, component = "pairwise_samples", array3d = array3d_pw)
slab_summary = summarize_slab(fit, component = "pairwise_samples", array3d = array3d_pw, array3d_ind = array3d_ind)

all_selected = ind_summary$mean == 1

Expand Down
15 changes: 15 additions & 0 deletions R/run_sampler.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,18 @@ run_sampler = function(spec) {
stop("Unknown model_type: ", spec$model_type)
)

# Check for chain-level errors
chain_errors = vapply(raw, function(ch) isTRUE(ch$error), logical(1L))
if(all(chain_errors)) {
msgs = vapply(raw, function(ch) ch$error_msg %||% "unknown error", character(1L))
stop("All chains failed. First error: ", msgs[1L])
}
if(any(chain_errors)) {
n_fail = sum(chain_errors)
warning(n_fail, " of ", length(raw), " chain(s) failed and will be dropped.")
raw = raw[!chain_errors]
}

# Check for user interrupt across all chains
userInterrupt = any(vapply(raw, `[[`, logical(1L), "userInterrupt"))
attr(raw, "userInterrupt") = userInterrupt
Expand Down Expand Up @@ -69,6 +81,7 @@ run_sampler_ggm = function(spec) {
no_warmup = s$warmup,
no_chains = s$chains,
edge_selection = p$edge_selection,
sampler_type = s$update_method,
seed = s$seed,
no_threads = s$cores,
progress_type = s$progress_type,
Expand All @@ -79,6 +92,8 @@ run_sampler_ggm = function(spec) {
beta_bernoulli_beta_between = bb_beta_between,
dirichlet_alpha = p$dirichlet_alpha,
lambda = p$lambda,
target_acceptance = s$target_accept,
max_tree_depth = s$nuts_max_depth,
na_impute = m$na_impute,
missing_index_nullable = m$missing_index
)
Expand Down
22 changes: 10 additions & 12 deletions R/validate_sampler.R
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ progress_type_from_display_progress = function(display_progress = c("per-chain",
# @param cores Integer: number of CPU cores.
# @param seed Integer or NULL.
# @param display_progress Character or logical: progress display mode.
# @param is_continuous Logical: TRUE for GGM model. Forces adaptive-metropolis.
# @param is_continuous Logical: TRUE for GGM model.
# @param edge_selection Logical: affects warmup warning tiers.
# @param verbose Logical: whether to emit warmup warnings.
#
Expand Down Expand Up @@ -107,19 +107,17 @@ validate_sampler = function(update_method,
choices = c("nuts", "adaptive-metropolis", "hamiltonian-mc")
)

# --- GGM guard: force adaptive-metropolis -----------------------------------
if(is_continuous) {
if(user_chose_method && update_method %in% c("nuts", "hamiltonian-mc")) {
stop(paste0(
"The Gaussian model (variable_type = 'continuous') only supports ",
"update_method = 'adaptive-metropolis'. ",
"Got '", update_method, "'."
))
}
update_method = "adaptive-metropolis"
# --- target_accept ----------------------------------------------------------
if(is_continuous && edge_selection && update_method == "hamiltonian-mc") {
warning(
"hamiltonian-mc with edge selection on a GGM uses constrained ",
"integration (RATTLE), which can be numerically fragile with a ",
"fixed trajectory length. Consider using 'nuts' instead, which ",
"adapts trajectory length and avoids degenerate regions.",
call. = FALSE
)
}

# --- target_accept ----------------------------------------------------------
if(!is.null(target_accept)) {
target_accept = min(target_accept, 1 - sqrt(.Machine$double.eps))
target_accept = max(target_accept, 0 + sqrt(.Machine$double.eps))
Expand Down
4 changes: 2 additions & 2 deletions man/bgm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading