Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 122 additions & 13 deletions R/NetStats.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,64 @@
# Aggregate synth-population stratum-level durations under joint_lm.
#
# For each synthetic ego, log(duration) is predicted from `joint_dur_model` at
# same.race = 0 and same.race = 1 and each prediction is exponentiated. The two
# durations are then mixed on the duration scale with weights (1 - p, p), where
# p is the ego's predicted same-race probability from `joint_nm_race_model`.
# When `joint_nm_race_model` is NULL, p is 0 for every ego, so only the
# same.race = 0 prediction is used.
#
# Per-ego durations are median-aggregated within each stratum, and every
# stratum median m is converted with 1 / (1 - 2^(-1 / m)), i.e. the mean of a
# geometric distribution whose median is m.
#
# Strata, in output order (intended to line up with the rows of
# `netparams$<layer>$durs.<layer>.byage`):
#   1.        all egos, evaluated at same.age.grp = 0 (non-matched row)
#   2..N + 1  egos with age.grp == k, evaluated at same.age.grp = 1
#   N + 2     optional deterministic "post-cessation" row, fixed at 1
#
# Used in `build_netstats(method = "joint")` to override the within-ARTnet
# stratum medians that build_netparams emits under duration.method = "joint_lm",
# so dissolution offsets reflect the synthetic target population's joint
# attribute distribution rather than ARTnet's. See issue #73.
#
# Arguments:
#   joint_dur_model      Fitted model whose predict() returns log(duration);
#                        NULL disables the aggregation.
#   joint_nm_race_model  Fitted model whose predict(type = "response") returns
#                        the same-race probability per ego, or NULL.
#   synth_data           Data frame of synthetic egos. Must contain `age.grp`
#                        plus whatever predictors the two models need.
#   n_age_grps           Number of matched age-group strata (N above).
#   sex_cess_extra_row   If TRUE, append the trailing value 1.
#
# Returns NULL if joint_dur_model is NULL (caller falls back to netparams
# values). Otherwise a numeric vector of length 1 + n_age_grps (+ 1 with the
# extra row). An entry is NA when its stratum has no egos, or when its median
# is NA or non-positive.
.aggregate_synth_byage_durations <- function(joint_dur_model,
                                             joint_nm_race_model,
                                             synth_data,
                                             n_age_grps,
                                             sex_cess_extra_row = FALSE) {
  if (is.null(joint_dur_model)) {
    return(NULL)
  }

  n_egos <- nrow(synth_data)
  has_race_model <- !is.null(joint_nm_race_model)
  p_same_race <- if (has_race_model) {
    predict(joint_nm_race_model, newdata = synth_data, type = "response")
  } else {
    rep(0, n_egos)
  }

  # Median per-ego duration for one stratum. `age_grp_select = NA` means
  # "all egos" (used for the non-matched row).
  predict_stratum_median <- function(same_age, age_grp_select) {
    sel <- if (is.na(age_grp_select)) {
      rep(TRUE, n_egos)
    } else {
      # NA-safe membership: an ego with a missing age group belongs to no
      # matched stratum. Without the is.na() guard, `any(sel)` can be NA and
      # the `if` below errors instead of reporting an empty stratum.
      !is.na(synth_data$age.grp) & synth_data$age.grp == age_grp_select
    }
    if (!any(sel)) {
      return(NA_real_)
    }

    sub <- synth_data[sel, , drop = FALSE]
    sub$same.age.grp <- as.integer(same_age)

    sub$same.race <- 0L
    dur_diff_race <- exp(predict(joint_dur_model, newdata = sub))
    if (!has_race_model) {
      # p is 0 for every ego, so the same.race = 1 branch carries no weight.
      return(median(dur_diff_race, na.rm = TRUE))
    }

    sub$same.race <- 1L
    dur_same_race <- exp(predict(joint_dur_model, newdata = sub))

    p <- p_same_race[sel]
    median(p * dur_same_race + (1 - p) * dur_diff_race, na.rm = TRUE)
  }

  medians <- c(
    predict_stratum_median(0, NA),
    vapply(seq_len(n_age_grps),
           function(k) predict_stratum_median(1, k),
           numeric(1))
  )

  # Geometric median -> mean conversion; unusable medians stay NA.
  mean_dur_adj <- rep(NA_real_, length(medians))
  usable <- !is.na(medians) & medians > 0
  mean_dur_adj[usable] <- 1 / (1 - 2^(-1 / medians[usable]))

  if (isTRUE(sex_cess_extra_row)) {
    mean_dur_adj <- c(mean_dur_adj, 1)
  }
  mean_dur_adj
}


#' Calculate Network Target Statistics
#'
Expand Down Expand Up @@ -34,9 +95,14 @@
#' (`netparams$<layer>$joint_nm_{age,race}_model` and
#' `netparams$<layer>$joint_absdiff_{age,sqrtage}_model`). Under `"joint"`, edges and all
#' nodefactor target stats are internally consistent by construction
#' (`sum(nodefactor_<attr>) = 2 * edges`), so `edges.avg` has no effect. Dissolution
#' coefficients and `nodefactor_risk.grp` still use the univariate marginals — the
#' duration refactor lives on #63 and will be addressed in a follow-up PR.
#' (`sum(nodefactor_<attr>) = 2 * edges`), so `edges.avg` has no effect. The
#' `diss.byage` dissolution offset is computed from synth-aggregated stratum-level
#' durations when `build_netparams(..., duration.method = "joint_lm")` was used:
#' per-ego predicted log(duration) from `joint_duration_model`, marginalized over
#' partner-race uncertainty via `joint_nm_race_model`, then median-aggregated within
#' stratum. `nodefactor_risk.grp` and `diss.homog` still use the within-ARTnet
#' univariate / aggregated values — those are not consumed by the standard
#' EpiModelHIV-Template dissolution offset.
#' @param browser If `TRUE`, run `build_netparams` in interactive browser mode.
#'
#' @details
Expand Down Expand Up @@ -409,6 +475,12 @@ build_netstats <- function(epistats, netparams,
# construction (sum(nodefactor_<attr>) == 2 * edges). Inactive-age nodes
# (sex.cess.mod) are zeroed out so they contribute nothing to any layer's
# edges or nodefactor counts.
# Initialize synth-aggregated duration overrides; populated below under
# method = "joint" + duration.method = "joint_lm". When NULL, the dissolution
# offsets fall back to the within-ARTnet values from build_netparams.
synth_dur_main_byage <- NULL
synth_dur_casl_byage <- NULL

if (method == "joint") {
synth <- data.frame(
age.grp = out$attr$age.grp,
Expand All @@ -419,6 +491,11 @@ build_netstats <- function(epistats, netparams,
geogYN = 1L
)
synth$deg.tot3 <- pmin(out$attr$deg.tot, 3)
# Alias for the joint duration model, which uses index.age.grp on the RHS
# rather than age.grp (a leftover from the partnership-level fit on lmain
# in build_netparams). Joint_nm_*_model uses age.grp; we keep both columns
# available so predict() works for both families of models.
synth$index.age.grp <- synth$age.grp

pred_deg_main <- predict(netparams$main$joint_model, newdata = synth, type = "response")
pred_deg_casl <- predict(netparams$casl$joint_model, newdata = synth, type = "response")
Expand Down Expand Up @@ -465,6 +542,29 @@ build_netstats <- function(epistats, netparams,
# dyad predictions are multiplied by pred_deg downstream, so
# zeroing pred_deg above already suppresses their contribution.
}

# Synth-aggregated stratum durations (#73). Override the within-ARTnet
# joint_lm aggregation that build_netparams emits with synth-population
# aggregation. Only fires when duration.method = "joint_lm" was used in
# build_netparams (otherwise joint_duration_model is NULL).
n_age_grps_main <-
nrow(netparams$main$durs.main.byage) - 1L - as.integer(sex.cess.mod)
n_age_grps_casl <-
nrow(netparams$casl$durs.casl.byage) - 1L - as.integer(sex.cess.mod)
synth_dur_main_byage <- .aggregate_synth_byage_durations(
joint_dur_model = netparams$main$joint_duration_model,
joint_nm_race_model = netparams$main$joint_nm_race_model,
synth_data = synth,
n_age_grps = n_age_grps_main,
sex_cess_extra_row = sex.cess.mod
)
synth_dur_casl_byage <- .aggregate_synth_byage_durations(
joint_dur_model = netparams$casl$joint_duration_model,
joint_nm_race_model = netparams$casl$joint_nm_race_model,
synth_data = synth,
n_age_grps = n_age_grps_casl,
sex_cess_extra_row = sex.cess.mod
)
}


Expand Down Expand Up @@ -574,14 +674,21 @@ build_netstats <- function(epistats, netparams,
function(h) sum(pred_deg_main[out$attr$diag.status == h]), numeric(1))
}

# Dissolution (identical under both methods) -----------------------------
# Dissolution -----------------------------------------------------------
# diss.byage uses synth-aggregated durations under method = "joint" +
# duration.method = "joint_lm" (#73). diss.homog still uses the within-
# ARTnet aggregation from build_netparams; it is not consumed by
# EpiModelHIV-Template's tergm offset, and its synth analog can be added
# later without changing the byage interface that matters for production.
out$main$diss.homog <- dissolution_coefs(dissolution = ~offset(edges),
duration = netparams$main$durs.main.homog$mean.dur.adj,
d.rate = expect.mort)
out$main$diss.byage <- dissolution_coefs(dissolution = ~offset(edges) +
offset(nodematch("age.grp", diff = TRUE)),
duration = netparams$main$durs.main.byage$mean.dur.adj,
d.rate = expect.mort)
out$main$diss.byage <- dissolution_coefs(
dissolution = ~offset(edges) + offset(nodematch("age.grp", diff = TRUE)),
duration = if (!is.null(synth_dur_main_byage)) synth_dur_main_byage
else netparams$main$durs.main.byage$mean.dur.adj,
d.rate = expect.mort
)



Expand Down Expand Up @@ -680,14 +787,16 @@ build_netstats <- function(epistats, netparams,
function(h) sum(pred_deg_casl[out$attr$diag.status == h]), numeric(1))
}

# Dissolution (identical under both methods)
# Dissolution (see note on diss.byage / diss.homog at the main layer block)
out$casl$diss.homog <- dissolution_coefs(dissolution = ~offset(edges),
duration = netparams$casl$durs.casl.homog$mean.dur.adj,
d.rate = expect.mort)
out$casl$diss.byage <- dissolution_coefs(dissolution = ~offset(edges) +
offset(nodematch("age.grp", diff = TRUE)),
duration = netparams$casl$durs.casl.byage$mean.dur.adj,
d.rate = expect.mort)
out$casl$diss.byage <- dissolution_coefs(
dissolution = ~offset(edges) + offset(nodematch("age.grp", diff = TRUE)),
duration = if (!is.null(synth_dur_casl_byage)) synth_dur_casl_byage
else netparams$casl$durs.casl.byage$mean.dur.adj,
d.rate = expect.mort
)


# One-Time Model ----------------------------------------------------------
Expand Down
11 changes: 8 additions & 3 deletions man/build_netstats.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

126 changes: 126 additions & 0 deletions tests/testthat/test-duration-gcomp-synth.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Tests for the synth-aggregated duration g-computation (#73). Under
# method = "joint" + duration.method = "joint_lm", build_netstats overrides
# the within-ARTnet stratum medians from build_netparams with synth-aggregated
# medians: predict joint_lm log-duration per synth ego, marginalize over
# partner-race uncertainty via joint_nm_race_model, take median per stratum.

# Skip the calling test when the ARTnetData package is not available.
# Delegates to testthat's built-in helper instead of a hand-rolled
# system.file() lookup.
skip_without_artnetdata <- function() {
  testthat::skip_if_not_installed("ARTnetData")
}

# Build the epistats -> netparams -> netstats chain used by the tests below.
#
# Arguments:
#   race              Passed to build_epistats(race = ).
#   dur_method        Passed to build_netparams(duration.method = ).
#   netparams_method  Passed to build_netparams(method = ); build_netstats is
#                     run with method = "joint" when this is "joint" and
#                     method = "existing" otherwise.
#   race.prop         Passed to build_netstats(race.prop = ).
#   sex.cess          If TRUE, add age.limits / age.sexual.cessation to the
#                     build_epistats call and set young.prop = 0.99 in
#                     build_netstats.
#
# The RNG is re-seeded with the same value before each of the three build
# steps. Returns list(epistats, netparams, netstats).
build_setup <- function(race = TRUE, dur_method = "joint_lm",
                        netparams_method = "joint", race.prop = NULL,
                        sex.cess = FALSE) {
  seed <- 20260420L

  epistats_args <- list(
    geog.lvl = "city", geog.cat = "Atlanta",
    init.hiv.prev = c(0.33, 0.137, 0.084),
    race = race, time.unit = 7
  )
  young_prop <- NULL
  if (sex.cess) {
    epistats_args <- c(
      epistats_args,
      list(age.limits = c(15, 100), age.sexual.cessation = 65)
    )
    young_prop <- 0.99
  }
  netstats_method <- if (netparams_method == "joint") "joint" else "existing"

  set.seed(seed)
  epistats <- do.call(build_epistats, epistats_args)

  set.seed(seed)
  netparams <- build_netparams(epistats, smooth.main.dur = TRUE,
                               method = netparams_method,
                               duration.method = dur_method)

  set.seed(seed)
  netstats <- build_netstats(epistats, netparams,
                             expect.mort = 0.000478213, network.size = 5000,
                             race.prop = race.prop, young.prop = young_prop,
                             method = netstats_method)

  list(epistats = epistats, netparams = netparams, netstats = netstats)
}


test_that("synth override fires under joint + joint_lm and differs from netparams", {
  skip_without_artnetdata()
  fixture <- build_setup()
  artnet_dur <- fixture$netparams$main$durs.main.byage$mean.dur.adj
  synth_dur <- fixture$netstats$main$diss.byage$duration

  # The override must keep one duration per netparams stratum row.
  expect_length(synth_dur, length(artnet_dur))

  # If the override fired, at least one stratum should sit more than 0.5
  # away from the within-ARTnet value.
  abs_gap <- abs(synth_dur - artnet_dur)
  expect_true(any(abs_gap > 0.5))
})


test_that("synth override does not fire under duration.method = 'empirical'", {
  skip_without_artnetdata()
  fixture <- build_setup(dur_method = "empirical")
  main_params <- fixture$netparams$main

  # No joint duration model is expected under the empirical method ...
  expect_null(main_params$joint_duration_model)

  # ... so the dissolution durations must be the netparams values.
  expect_equal(fixture$netstats$main$diss.byage$duration,
               main_params$durs.main.byage$mean.dur.adj)
})


test_that("synth override does not fire under method = 'existing'", {
  skip_without_artnetdata()
  fixture <- build_setup(netparams_method = "existing", dur_method = "joint_lm")

  # build_netstats only builds synth predictions under method = "joint", so
  # with method = "existing" the within-ARTnet aggregation from netparams
  # should pass through unchanged, even with duration.method = "joint_lm".
  artnet_dur <- fixture$netparams$main$durs.main.byage$mean.dur.adj
  used_dur <- fixture$netstats$main$diss.byage$duration
  expect_equal(used_dur, artnet_dur)
})


test_that("synth-aggregated durations diverge under shifted race.prop", {
  skip_without_artnetdata()

  # Baseline uses build_setup's default race.prop (NULL); the comparison run
  # supplies an explicit race distribution.
  casl_dur <- function(setup) setup$netstats$casl$diss.byage$duration
  d_default <- casl_dur(build_setup())
  d_shifted <- casl_dur(build_setup(race.prop = c(0.35, 0.25, 0.40)))

  # Checked on the casual layer: at least one stratum should move by more
  # than 1% (relative to the baseline) under the population shift.
  rel_diff <- abs(d_default - d_shifted) / d_default
  expect_true(
    any(rel_diff > 0.01),
    info = paste("max relative diff =",
                 sprintf("%.4f", max(rel_diff, na.rm = TRUE)))
  )
})


test_that("sex.cess.mod preserves the deterministic post-cessation row", {
  skip_without_artnetdata()
  fixture <- build_setup(sex.cess = TRUE)
  main_dur <- fixture$netstats$main$diss.byage$duration
  casl_dur <- fixture$netstats$casl$diss.byage$duration

  # The trailing entry is the post-cessation row and must be exactly 1.
  n_main <- length(main_dur)
  n_casl <- length(casl_dur)
  expect_equal(main_dur[n_main], 1)
  expect_equal(casl_dur[n_casl], 1)

  # One duration per row of the netparams by-age table (main layer).
  expect_equal(n_main,
               nrow(fixture$netparams$main$durs.main.byage))
})


test_that("dissolution_coefs object is well-formed under override", {
  skip_without_artnetdata()
  fixture <- build_setup()

  # Same three checks for each layer: class, finite dissolution
  # coefficients, finite d.rate.
  check_layer <- function(layer_name) {
    diss_obj <- fixture$netstats[[layer_name]]$diss.byage
    expect_s3_class(diss_obj, "disscoef")
    expect_true(all(is.finite(diss_obj$coef.diss)))
    expect_true(all(is.finite(diss_obj$d.rate)))
  }
  check_layer("main")
  check_layer("casl")
})


test_that("race = FALSE skips the partner-race marginalization gracefully", {
  skip_without_artnetdata()
  fixture <- build_setup(race = FALSE)

  # joint_nm_race_model is not fit when race = FALSE, so the synth
  # aggregation has to run without it and still give usable durations.
  main_dur <- fixture$netstats$main$diss.byage$duration
  expect_true(all(is.finite(main_dur)))
  expect_true(all(main_dur > 0))
})
Loading