diff --git a/R/gs_power_rd.R b/R/gs_power_rd.R index 4bfc3910..f3b0b730 100644 --- a/R/gs_power_rd.R +++ b/R/gs_power_rd.R @@ -343,9 +343,18 @@ gs_power_rd <- function( select(analysis, n, rd, rd0, theta1, theta0, info, info0, info_frac, info_frac0) ) + # Get input parameters to output ---- + input <- list( + p_c = p_c, p_e = p_e, n = n, rd0 = rd0, ratio = ratio, weight = weight, + upper = upper, lower = lower, upar = upar, lpar = lpar, + info_scale = info_scale, binding = binding, test_upper = test_upper, + test_lower = test_lower, r = r, tol = tol + ) + ans <- structure( list( design = "rd", + input = input, bound = bound |> filter(!is.infinite(z)), analysis = analysis ), diff --git a/R/to_integer.R b/R/to_integer.R index 8d19467b..10fb011f 100644 --- a/R/to_integer.R +++ b/R/to_integer.R @@ -576,18 +576,18 @@ to_integer.gs_design <- function(x, round_up_final = TRUE, ratio = x$input$ratio x_new$analysis$n <- round(x_new$analysis$n) if (!is_rd) x_new$analysis$event <- round(x_new$analysis$event) - # Add attributes to x_new to identify whether it is a gs_design_ahr orbject or gs_power_ahr object - if ("analysis_time" %in% names(x$input) && "info_frac" %in% names(x$input) && "ahr" %in% class(x)) { + # Add attributes to x_new to identify whether it is a gs_design_ahr object or gs_power_ahr object + if ("analysis_time" %in% names(x$input) && "info_frac" %in% names(x$input) && is_ahr) { attr(x_new, 'uninteger_is_from') <- "gs_design_ahr" - } else if ("analysis_time" %in% names(x$input) && "event" %in% names(x$input) && "ahr" %in% class(x)) { + } else if ("analysis_time" %in% names(x$input) && "event" %in% names(x$input) && is_ahr) { attr(x_new, 'uninteger_is_from') <- "gs_power_ahr" - } else if ("analysis_time" %in% names(x$input) && "info_frac" %in% names(x$input) && "wlr" %in% class(x)) { + } else if ("analysis_time" %in% names(x$input) && "info_frac" %in% names(x$input) && is_wlr) { attr(x_new, 'uninteger_is_from') <- "gs_design_wlr" - } else if ("analysis_time" %in% names(x$input) && "event" %in% names(x$input) && "wlr" %in% class(x)) { + } else if ("analysis_time" %in% names(x$input) && "event" %in% names(x$input) && is_wlr) { attr(x_new, 'uninteger_is_from') <- "gs_power_wlr" - } else if (!("n" %in% names(x$input)) && "rd" %in% class(x)) { + } else if (!("n" %in% names(x$input)) && is_rd) { attr(x_new, 'uninteger_is_from') <- "gs_design_rd" - } else if ("n" %in% names(x$input) && "rd" %in% class(x)) { + } else if ("n" %in% names(x$input) && is_rd) { attr(x_new, 'uninteger_is_from') <- "gs_power_rd" } diff --git a/tests/testthat/test-developer-to_integer.R b/tests/testthat/test-developer-to_integer.R index ea023047..33e79b8c 100644 --- a/tests/testthat/test-developer-to_integer.R +++ b/tests/testthat/test-developer-to_integer.R @@ -279,3 +279,21 @@ test_that("verify the crossing prob of a MB design at IA1 under null", { expect_equal((x$bounds |> filter(bound == "upper", analysis == 1))$probability0, sfLDOF(alpha = .025, t = x$analysis$info_frac0)$spend[1]) }) + +test_that("The attribute `uninteger_is_from` matches the input design object", { + for (design_func in c("gs_design_ahr", "gs_design_rd", "gs_design_wlr")) { + x <- get(design_func)() |> to_integer() + expect_identical(attr(x, "uninteger_is_from"), design_func) + } + + lpar <- list(sf = gsDesign::sfLDOF, total_spend = 0.1) + for (power_func in c("gs_power_ahr", "gs_power_rd", "gs_power_wlr")) { + if (power_func == "gs_power_rd") { + x <- get(power_func)() + } else { + x <- get(power_func)(lpar = lpar) + } + x <- to_integer(x) + expect_identical(attr(x, "uninteger_is_from"), power_func) + } +})