From 92eafb5d081d77ba7619aa5c0c31994bf8662bf1 Mon Sep 17 00:00:00 2001 From: Samuel Jenness Date: Sun, 19 Apr 2026 15:22:12 -0400 Subject: [PATCH] Add backward-compatibility validation harness for joint g-comp refactor Sets up inst/validation/ as the pre/post regression harness for the joint g-computation refactor (issues #61-#65). The harness captures golden snapshots of build_netparams() + build_netstats() output on main-era code, then diffs against the refactor branch to verify that method = "existing" (or the legacy default) reproduces prior behavior byte-identically. Contents: - README.md documents the capture/compare workflow. - validate_backward_compat.R provides capture_snapshot() and compare_to_snapshot() entry points. Iterates PARAM_SETS covering Atlanta+race, national no-geog, and Atlanta no-race. Uses a fixed seed so stochastic bits of build_netstats() are reproducible, and strips additive fields (e.g. $joint_model) before comparison. - epimodelhiv_template_ref/ pins verbatim copies of the downstream consumer scripts from EpiModelHIV-Template/R/A-networks/ so the backward-compat contract is explicit and does not drift silently. - netstats_contract.md distills exactly which netstats fields the template ERGM specs read. - snapshots/*.rds is gitignored (large, local). 
--- .gitignore | 1 + inst/validation/README.md | 72 ++++++ .../epimodelhiv_template_ref/initialize.R | 53 +++++ .../epimodelhiv_template_ref/model_casl.R | 43 ++++ .../epimodelhiv_template_ref/model_main.R | 43 ++++ .../epimodelhiv_template_ref/model_ooff.R | 41 ++++ inst/validation/netstats_contract.md | 60 +++++ inst/validation/snapshots/.gitkeep | 0 inst/validation/validate_backward_compat.R | 220 ++++++++++++++++++ 9 files changed, 533 insertions(+) create mode 100644 inst/validation/README.md create mode 100644 inst/validation/epimodelhiv_template_ref/initialize.R create mode 100644 inst/validation/epimodelhiv_template_ref/model_casl.R create mode 100644 inst/validation/epimodelhiv_template_ref/model_main.R create mode 100644 inst/validation/epimodelhiv_template_ref/model_ooff.R create mode 100644 inst/validation/netstats_contract.md create mode 100644 inst/validation/snapshots/.gitkeep create mode 100644 inst/validation/validate_backward_compat.R diff --git a/.gitignore b/.gitignore index 0443d44..73c9330 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ .DS_Store vignettes/*.html vignettes/*.R +inst/validation/snapshots/*.rds diff --git a/inst/validation/README.md b/inst/validation/README.md new file mode 100644 index 0000000..326f877 --- /dev/null +++ b/inst/validation/README.md @@ -0,0 +1,72 @@ +# Validation Infrastructure + +This directory supports the joint g-computation refactor (issues +[#61](https://github.com/EpiModel/ARTnet/issues/61)–[#65](https://github.com/EpiModel/ARTnet/issues/65)) +by giving us two things a standard `testthat` suite cannot: + +1. A **byte-for-byte reference snapshot** of the output that + `build_netparams()` and `build_netstats()` produce on the pre-refactor + `main` branch, captured once and compared against on every subsequent + commit. +2. A pinned copy of the **downstream consumer** code + (`EpiModelHIV-Template/R/A-networks/`) so we always know exactly which + fields of `netstats` must remain stable — no guessing. 
+ +## Files + +- `epimodelhiv_template_ref/` — verbatim copies of + `~/git/EpiModelHIV-Template/R/A-networks/{initialize,model_main,model_casl,model_ooff}.R` + taken on 2026-04-19. These are the ERGM specifications that consume + `netstats`; they define the backward-compatibility contract. Do not edit + unless the upstream template changes. +- `netstats_contract.md` — distilled list of exactly which `netstats` fields + the template scripts read, by layer. +- `validate_backward_compat.R` — `capture_snapshot()` and + `compare_to_snapshot()` functions. Run `capture_snapshot()` on pre-refactor + `main`; run `compare_to_snapshot()` on the refactor branch with + `method = "existing"` (or equivalent default) and expect zero diffs. +- `snapshots/` — created on first capture run. `.gitignore`d by default; + the captured `.rds` files are large and should not be checked in. + +## Workflow + +### Step A — Before starting the refactor (pre-capture) + +On the pre-refactor `main` branch, with `ARTnetData` installed: + +```r +devtools::load_all() # from the ARTnet repo root +source(system.file("validation/validate_backward_compat.R", package = "ARTnet")) +capture_snapshot() # writes inst/validation/snapshots/*.rds +``` + +This saves one snapshot per parameter set (see the `PARAM_SETS` list in +`validate_backward_compat.R`). Commit the snapshot files only if they are +small enough; otherwise keep them locally and rely on a hash digest that +**is** committed. + +### Step B — During/after the refactor (compare) + +On the refactor branch, with the new joint-GLM code in place: + +```r +devtools::load_all() +source(system.file("validation/validate_backward_compat.R", package = "ARTnet")) +compare_to_snapshot(method = "existing") +``` + +The call must report `ALL MATCH` before the PR is considered mergeable. +Any field-level diff is a backward-compatibility regression. + +## Why not just `testthat::expect_equal()`? + +Two reasons: +1. 
These runs require `ARTnetData` (private) and take minutes to execute — + they do not belong in CI. +2. `testthat` snapshots are text-based and don't roundtrip well for deeply + nested lists containing S3 objects (`glm`, `lm`, `dissolution_coefs`). + `saveRDS()` + `all.equal()` is the simplest reliable approach here. + +Unit tests for individual joint-GLM behaviors (convergence, marginal +recovery, coefficient sanity — see CLAUDE.md §4.5) should still live in +`tests/testthat/` as normal. diff --git a/inst/validation/epimodelhiv_template_ref/initialize.R b/inst/validation/epimodelhiv_template_ref/initialize.R new file mode 100644 index 0000000..397db73 --- /dev/null +++ b/inst/validation/epimodelhiv_template_ref/initialize.R @@ -0,0 +1,53 @@ +## REFERENCE COPY (2026-04-19) of EpiModelHIV-Template/R/A-networks/initialize.R +## DO NOT EDIT — this exists to pin the downstream consumer contract. +## If the upstream file changes, refresh this copy and update +## `inst/validation/netstats_contract.md`. + +## Initialize the ARTnet data objects and the networks to be fitted +## +## This script should not be run directly. 
But `sourced` by `1-estimation.R` + +if (system.file(package = "ARTnetData") == "") { + message( + "=================================================================\n", + "You are currently using the example population provided by ARTnet\n", + "Install ARTnetData to get all the features.\n", + "Follow the instructions at the link below to get access to it.\n", + "https://github.com/EpiModel/ARTnet/tree/main?tab=readme-ov-file#artnetdata-dependency\n", + "=================================================================\n" + ) + + epistats <- readRDS(system.file("epistats-example.rds", package = "ARTnet")) + netstats <- readRDS(system.file("netstats-example.rds", package = "ARTnet")) +} else { + epistats <- build_epistats( + geog.lvl = "city", + geog.cat = "Atlanta", + init.hiv.prev = c(0.33, 0.137, 0.084), + race = TRUE, + time.unit = time_unit + ) + + netparams <- build_netparams( + epistats = epistats, + smooth.main.dur = TRUE + ) + + netstats <- build_netstats( + epistats, + netparams, + expect.mort = 0.000478213, + network.size = networks_size + ) +} + + +nw <- EpiModel::network_initialize(netstats$demog$num) +nw_main <- EpiModel::set_vertex_attribute( + nw, + names(netstats$attr), + netstats$attr +) + +nw_casl <- nw_main +nw_inst <- nw_main diff --git a/inst/validation/epimodelhiv_template_ref/model_casl.R b/inst/validation/epimodelhiv_template_ref/model_casl.R new file mode 100644 index 0000000..18dbd43 --- /dev/null +++ b/inst/validation/epimodelhiv_template_ref/model_casl.R @@ -0,0 +1,43 @@ +## REFERENCE COPY (2026-04-19) of EpiModelHIV-Template/R/A-networks/model_casl.R +## DO NOT EDIT — pins the ERGM specification consuming netstats$casl. + +## Define and fit the *casual* network model +## +## This script should not be run directly. 
But `sourced` by `1-estimation.R` + +# Formula +model_casl <- ~ edges + + nodematch("age.grp", diff = TRUE) + + nodefactor("age.grp", levels = -5) + + nodematch("race", diff = FALSE) + + nodefactor("race", levels = -1) + + nodefactor("deg.main", levels = -3) + + concurrent + + degrange(from = 4) + + nodematch("role.class", diff = TRUE, levels = c(1, 2)) + +# Target Stats +netstats_casl <- c( + edges = netstats$casl$edges, + nodematch_age.grp = netstats$casl$nodematch_age.grp, + nodefactor_age.grp = netstats$casl$nodefactor_age.grp[-5], + nodematch_race = netstats$casl$nodematch_race_diffF, + nodefactor_race = netstats$casl$nodefactor_race[-1], + nodefactor_deg.main = netstats$casl$nodefactor_deg.main[-3], + concurrent = netstats$casl$concurrent, + degrange = 0, + nodematch_role.class = c(0, 0) +) |> unname() + +# Fit model +fit_casl <- EpiModel::netest( + nw_casl, + formation = model_casl, + target.stats = netstats_casl, + coef.diss = netstats$casl$diss.byage, + set.control.ergm = control_ergm, + verbose = FALSE +) |> trim_netest() + +# Keep only the necessary objects +rm(model_casl, netstats_casl) diff --git a/inst/validation/epimodelhiv_template_ref/model_main.R b/inst/validation/epimodelhiv_template_ref/model_main.R new file mode 100644 index 0000000..012243c --- /dev/null +++ b/inst/validation/epimodelhiv_template_ref/model_main.R @@ -0,0 +1,43 @@ +## REFERENCE COPY (2026-04-19) of EpiModelHIV-Template/R/A-networks/model_main.R +## DO NOT EDIT — pins the ERGM specification consuming netstats$main. + +## Define and fit the *main* network model +## +## This script should not be run directly. 
But `sourced` by `1-estimation.R` + +# Formula +model_main <- ~ edges + + nodematch("age.grp", diff = TRUE) + + nodefactor("age.grp", levels = -1) + + nodematch("race", diff = FALSE) + + nodefactor("race", levels = -1) + + nodefactor("deg.casl", levels = -1) + + concurrent + + degrange(from = 3) + + nodematch("role.class", diff = TRUE, levels = c(1, 2)) + +# Target Stats +netstats_main <- c( + edges = netstats$main$edges, + nodematch_age.grp = netstats$main$nodematch_age.grp, + nodefactor_age.grp = netstats$main$nodefactor_age.grp[-1], + nodematch_race = netstats$main$nodematch_race_diffF, + nodefactor_race = netstats$main$nodefactor_race[-1], + nodefactor_deg.casl = netstats$main$nodefactor_deg.casl[-1], + concurrent = netstats$main$concurrent, + degrange = 0, + nodematch_role.class = c(0, 0) +) |> unname() + +# Fit model +fit_main <- EpiModel::netest( + nw_main, + formation = model_main, + target.stats = netstats_main, + coef.diss = netstats$main$diss.byage, + set.control.ergm = control_ergm, + verbose = FALSE +) |> EpiModel::trim_netest() + +# Keep only the necessary objects +rm(model_main, netstats_main) diff --git a/inst/validation/epimodelhiv_template_ref/model_ooff.R b/inst/validation/epimodelhiv_template_ref/model_ooff.R new file mode 100644 index 0000000..f5bc21e --- /dev/null +++ b/inst/validation/epimodelhiv_template_ref/model_ooff.R @@ -0,0 +1,41 @@ +## REFERENCE COPY (2026-04-19) of EpiModelHIV-Template/R/A-networks/model_ooff.R +## DO NOT EDIT — pins the ERGM specification consuming netstats$inst. + +## Define and fit the *one-off* network model +## +## This script should not be run directly. 
But `sourced` by `1-estimation.R` + +# Formula +model_ooff <- ~ edges + + nodematch("age.grp", diff = FALSE) + + nodefactor("age.grp", levels = -1) + + nodematch("race", diff = FALSE) + + nodefactor("race", levels = -1) + + nodefactor("risk.grp", levels = -5) + + nodefactor("deg.tot", levels = -1) + + nodematch("role.class", diff = TRUE, levels = c(1, 2)) + +# Target Stats +netstats_ooff <- c( + edges = netstats$inst$edges, + nodematch_age.grp = sum(netstats$inst$nodematch_age.grp), + nodefactor_age.grp = netstats$inst$nodefactor_age.grp[-1], + nodematch_race = netstats$inst$nodematch_race_diffF, + nodefactor_race = netstats$inst$nodefactor_race[-1], + nodefactor_risk.grp = netstats$inst$nodefactor_risk.grp[-5], + nodefactor_deg.tot = netstats$inst$nodefactor_deg.tot[-1], + nodematch_role.class = c(0, 0) +) |> unname() + +# Fit model +fit_ooff <- EpiModel::netest( + nw_inst, + formation = model_ooff, + target.stats = netstats_ooff, + coef.diss = dissolution_coefs(~ offset(edges), 1), + set.control.ergm = control_ergm, + verbose = FALSE +) |> trim_netest() + +# Keep only the necessary objects +rm(model_ooff, netstats_ooff) diff --git a/inst/validation/netstats_contract.md b/inst/validation/netstats_contract.md new file mode 100644 index 0000000..a040226 --- /dev/null +++ b/inst/validation/netstats_contract.md @@ -0,0 +1,60 @@ +# netstats Backward-Compatibility Contract + +The `netstats` object returned by `build_netstats()` is consumed by the +ERGM estimation scripts in `EpiModelHIV-Template/R/A-networks/` (verbatim +copy pinned in `epimodelhiv_template_ref/`). These are the fields the +template reads — they must remain byte-identical under +`method = "existing"` (or whatever we name the legacy flag). + +Snapshot taken 2026-04-19 against EpiModelHIV-Template@main. 
+ +## `initialize.R` +- `netstats$demog$num` — network size (scalar integer) +- `netstats$attr` — named list of vertex attributes (age, sqrt.age, + age.grp, active.sex, race, deg.casl, deg.main, deg.tot, risk.grp, + role.class, diag.status). The attribute vectors are constructed via + `sample()` / `apportion_lr()` / `rbinom()` and depend on RNG state. + Validation must `set.seed()` before comparison. + +## `model_main.R` +- `netstats$main$edges` +- `netstats$main$nodematch_age.grp` (vector, one per age group) +- `netstats$main$nodefactor_age.grp` (vector, one per age group) +- `netstats$main$nodematch_race_diffF` (scalar) +- `netstats$main$nodefactor_race` (vector, one per race group) +- `netstats$main$nodefactor_deg.casl` (vector, one per deg.casl level) +- `netstats$main$concurrent` (scalar) +- `netstats$main$diss.byage` — `dissolution_coefs` S3 object + +## `model_casl.R` +- `netstats$casl$edges` +- `netstats$casl$nodematch_age.grp` +- `netstats$casl$nodefactor_age.grp` +- `netstats$casl$nodematch_race_diffF` +- `netstats$casl$nodefactor_race` +- `netstats$casl$nodefactor_deg.main` +- `netstats$casl$concurrent` +- `netstats$casl$diss.byage` + +## `model_ooff.R` +- `netstats$inst$edges` +- `netstats$inst$nodematch_age.grp` +- `netstats$inst$nodefactor_age.grp` +- `netstats$inst$nodematch_race_diffF` +- `netstats$inst$nodefactor_race` +- `netstats$inst$nodefactor_risk.grp` +- `netstats$inst$nodefactor_deg.tot` + +## Not directly consumed but still part of the contract +Anything else currently in `netstats$*` — `nodematch_race`, +`absdiff_age`, `absdiff_sqrt.age`, etc. — is also part of the contract +by default because the package ships it publicly. The validation script +does a full-object diff rather than checking only the fields above. + +## `netparams` contract +The validation also captures `netparams` whole (the input to +`build_netstats`). 
Joint models are *additive* outputs (new +`$joint_model` fields), so: +- existing `netparams$main$*`, `$casl$*`, `$inst$*`, `$all$*` fields must be + byte-identical under `method = "existing"`; +- new fields (`$joint_model`) are ignored during comparison. diff --git a/inst/validation/snapshots/.gitkeep b/inst/validation/snapshots/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/inst/validation/validate_backward_compat.R b/inst/validation/validate_backward_compat.R new file mode 100644 index 0000000..f12ea92 --- /dev/null +++ b/inst/validation/validate_backward_compat.R @@ -0,0 +1,220 @@ +## Backward-compatibility validation harness for the joint g-computation +## refactor (issues #61-#65). See inst/validation/README.md for the workflow. +## +## Two entry points: +## capture_snapshot() - run on pre-refactor `main` to save reference +## compare_to_snapshot(...) - run on refactor branch to diff against reference +## +## Both functions iterate over the parameter sets in PARAM_SETS (edit below to +## add coverage). Each set defines the args to build_epistats / build_netparams +## / build_netstats. Keep them small — this is a regression harness, not a +## simulation. + +# ---- Configuration ----------------------------------------------------------- + +# Where snapshots are stored, relative to the package install (or the repo +# root when using devtools::load_all()). +.snapshot_dir <- function() { + candidates <- c( + # Dev mode: ARTnet repo root / inst / validation / snapshots + file.path(getwd(), "inst", "validation", "snapshots"), + # Installed package + system.file("validation", "snapshots", package = "ARTnet") + ) + hit <- candidates[nzchar(candidates) & dir.exists(dirname(candidates))] + if (length(hit) == 0) { + stop("Cannot locate inst/validation/. Run from the ARTnet repo root or ", + "install the package.") + } + dir <- hit[1] + if (!dir.exists(dir)) dir.create(dir, recursive = TRUE) + dir +} + +# Parameter sets to cover. 
# Add more as edge cases surface.
# Each entry is a list with four fields:
#   $name      snapshot key (used as the .rds file name)
#   $epistats  args to build_epistats
#   $netparams args to build_netparams
#   $netstats  args to build_netstats
PARAM_SETS <- local({
  # Pieces shared by every parameter set; only the epistats args vary.
  netparams_args <- list(smooth.main.dur = TRUE)
  netstats_args <- list(expect.mort = 0.000478213, network.size = 5000)

  # Atlanta city-level epistats args, with race stratification on or off.
  atlanta_epistats <- function(race) {
    list(
      geog.lvl = "city",
      geog.cat = "Atlanta",
      init.hiv.prev = c(0.33, 0.137, 0.084),
      race = race,
      time.unit = 7
    )
  }

  param_set <- function(name, epistats) {
    list(
      name = name,
      epistats = epistats,
      netparams = netparams_args,
      netstats = netstats_args
    )
  }

  list(
    param_set("atlanta_default", atlanta_epistats(TRUE)),
    param_set("national_no_geog", list(race = TRUE, time.unit = 7)),
    param_set("atlanta_no_race", atlanta_epistats(FALSE))
  )
})

# Seed applied at the start of every run so that the random draws made while
# building netstats (sample/rbinom/runif) repeat from one run to the next.
.VALIDATION_SEED <- 20260419L


# ---- Utilities ---------------------------------------------------------------

# Abort with an error unless the ARTnetData package is installed.
.require_artnetdata <- function() {
  if (!nzchar(system.file(package = "ARTnetData"))) {
    stop("ARTnetData not installed; validation cannot run. ",
         "See https://github.com/EpiModel/ARTnet#artnetdata-dependency")
  }
}

# Remove fields that only exist after the refactor (currently $joint_model)
# from each netparams layer, so a post-refactor object can be compared with a
# pre-refactor snapshot. Layers that are absent are left alone. Add further
# additive fields here as they are introduced.
.strip_additive <- function(netparams) {
  layers <- c("main", "casl", "inst", "all")
  is_absent <- vapply(
    layers,
    function(nm) is.null(netparams[[nm]]),
    logical(1)
  )
  for (nm in layers[!is_absent]) {
    netparams[[nm]]$joint_model <- NULL
  }
  netparams
}

# Run one parameter set end-to-end. `netparams_extra` lets the caller pass
# e.g. method = "existing" post-refactor without affecting the pre-capture.
.run_one <- function(set, netparams_extra = list()) {
  # `set` is one entry of PARAM_SETS; its $epistats, $netparams and $netstats
  # fields are spliced into the three build_* calls below via do.call().
  #
  # The fixed seed is meant to pin this one run, not the rest of the session:
  # remember the caller's RNG state (if any) and restore it on exit, including
  # when one of the build_* calls fails.
  rng_env <- globalenv()
  had_seed <- exists(".Random.seed", envir = rng_env, inherits = FALSE)
  old_seed <- if (had_seed) {
    get(".Random.seed", envir = rng_env, inherits = FALSE)
  }
  on.exit(
    if (had_seed) {
      assign(".Random.seed", old_seed, envir = rng_env)
    } else if (exists(".Random.seed", envir = rng_env, inherits = FALSE)) {
      rm(".Random.seed", envir = rng_env)
    },
    add = TRUE
  )
  set.seed(.VALIDATION_SEED)

  epistats <- do.call(ARTnet::build_epistats, set$epistats)

  # Caller-supplied extras go last, after the set's own netparams args.
  netparams <- do.call(
    ARTnet::build_netparams,
    c(list(epistats = epistats), set$netparams, netparams_extra)
  )

  netstats <- do.call(
    ARTnet::build_netstats,
    c(list(epistats = epistats, netparams = netparams), set$netstats)
  )

  list(netparams = netparams, netstats = netstats)
}


# ---- Public entry points -----------------------------------------------------

#' Capture golden-reference snapshots from the current code
#'
#' Call once on the pre-refactor `main` branch. Runs every entry of
#' `PARAM_SETS` through `.run_one()` and saves the result as `<name>.rds` in
#' the directory returned by `.snapshot_dir()`. Stops with an error if
#' `ARTnetData` is not installed.
#'
#' @param overwrite If TRUE, overwrite existing snapshot files. If FALSE (the
#'   default), an existing file is skipped and a message is emitted.
#' @return Invisibly, `TRUE`.
capture_snapshot <- function(overwrite = FALSE) {
  .require_artnetdata()
  snapshot_root <- .snapshot_dir()

  for (param_set in PARAM_SETS) {
    target <- file.path(snapshot_root, paste0(param_set$name, ".rds"))
    if (file.exists(target) && !overwrite) {
      message("SKIP (exists): ", target, " -- pass overwrite = TRUE to replace")
      next
    }
    message("CAPTURE: ", param_set$name, " -> ", target)
    saveRDS(.run_one(param_set), target)
  }
  invisible(TRUE)
}


#' Compare current code output against the captured snapshots
#'
#' Call on the refactor branch. Reports per-parameter-set diffs between the
#' current code's output and the snapshot saved by `capture_snapshot()`. The
#' joint g-comp refactor should pass this with zero diffs when the legacy
#' code path is selected (e.g. `method = "existing"`).
#'
#' @param ... Passed as additional args to `build_netparams()`. For the
#'   refactor branch this will typically be `method = "existing"` once the
#'   arg exists; on the pre-refactor branch leave empty.
#' @param tolerance Numeric tolerance for `all.equal()`. Default 0 (exact).
#' @return Invisibly: TRUE iff all sets match; FALSE otherwise.
compare_to_snapshot <- function(..., tolerance = 0) {
  .require_artnetdata()
  snapshot_root <- .snapshot_dir()
  extra_netparams_args <- list(...)

  all_ok <- TRUE
  for (param_set in PARAM_SETS) {
    snapshot_path <- file.path(snapshot_root, paste0(param_set$name, ".rds"))

    if (file.exists(snapshot_path)) {
      message("COMPARE: ", param_set$name)
      reference <- readRDS(snapshot_path)
      current <- .run_one(param_set, netparams_extra = extra_netparams_args)

      # netparams is compared after stripping additive (refactor-only) fields
      # from both sides; netstats is compared whole.
      netparams_diff <- all.equal(
        .strip_additive(reference$netparams),
        .strip_additive(current$netparams),
        tolerance = tolerance
      )
      netstats_diff <- all.equal(
        reference$netstats,
        current$netstats,
        tolerance = tolerance
      )

      # all.equal() returns TRUE on a match, otherwise a character vector
      # describing the differences.
      netparams_ok <- isTRUE(netparams_diff)
      netstats_ok <- isTRUE(netstats_diff)

      if (!(netparams_ok && netstats_ok)) {
        all_ok <- FALSE
        if (!netparams_ok) {
          message(" FAIL netparams:")
          message(paste(" ", netparams_diff, collapse = "\n"))
        }
        if (!netstats_ok) {
          message(" FAIL netstats:")
          message(paste(" ", netstats_diff, collapse = "\n"))
        }
      } else {
        message(" OK (netparams + netstats identical)")
      }
    } else {
      # A missing snapshot counts as a failure for the overall result.
      warning("No snapshot for ", param_set$name, " at ", snapshot_path,
              " -- did you forget to run capture_snapshot()?")
      all_ok <- FALSE
    }
  }

  verdict <- if (all_ok) {
    paste0("ALL MATCH (", length(PARAM_SETS), " parameter sets)")
  } else {
    "REGRESSION DETECTED -- see diffs above"
  }
  message("\n==============================")
  message(verdict)
  message("==============================")

  invisible(all_ok)
}


#' Show which snapshots currently exist on disk.
#'
#' Prints a data frame with the file name, size in KB (one decimal) and
#' modification time of every `.rds` file in the snapshot directory, or emits
#' a message when there are none.
#'
#' @return Invisibly: the full paths of the snapshot files (`character(0)`
#'   when there are none).
list_snapshots <- function() {
  snapshot_root <- .snapshot_dir()
  found <- list.files(snapshot_root, pattern = "\\.rds$", full.names = TRUE)

  if (!length(found)) {
    message("No snapshots in ", snapshot_root)
    return(invisible(character(0)))
  }

  meta <- file.info(found)
  print(data.frame(
    name = basename(found),
    size_kb = round(meta$size / 1024, 1),
    mtime = meta$mtime,
    row.names = NULL
  ))
  invisible(found)
}