Skip to content

Commit

Permalink
find_MAP(MH) will track n0
Browse files Browse the repository at this point in the history
  • Loading branch information
PrzeChoj committed Jun 5, 2024
1 parent 5b981ea commit 0478686
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 11 deletions.
7 changes: 7 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
# gips 1.2.2.9000

### Update to functions

- `find_MAP(optimizer = "MH")` tracks the `n0` along the optimization.


# gips 1.2.2

### Bugfix:
Expand Down
19 changes: 14 additions & 5 deletions R/find_MAP.R
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ find_MAP <- function(g, max_iter = NA, optimizer = NA,
n0 <- max(structure_constants[["r"]] * structure_constants[["d"]] / structure_constants[["k"]])
if (attr(g, "was_mean_estimated")) { # correction for estimating the mean
n0 <- n0 + 1
attr(gips_optimized, "optimization_info")[["all_n0"]] <- attr(gips_optimized, "optimization_info")[["all_n0"]] + 1 # when all_n0 is NA, all_n0 + 1 is also an NA
}
if (n0 > number_of_observations) {
rlang::warn(c(
Expand Down Expand Up @@ -432,13 +433,15 @@ Metropolis_Hastings_optimizer <- function(S,

acceptance <- rep(FALSE, max_iter)
log_posteriori_values <- rep(0, max_iter)
all_n0 <- rep(0, max_iter)
if (save_all_perms) {
visited_perms <- list()
visited_perms[[1]] <- start_perm
} else {
visited_perms <- NA
}
current_perm <- start_perm
all_n0[1] <- get_n0_from_perm(current_perm, was_mean_estimated = FALSE) # was_mean_estimated will be corrected in find_MAP()

if (show_progress_bar) {
progressBar <- utils::txtProgressBar(min = 0, max = max_iter, initial = 1)
Expand Down Expand Up @@ -469,6 +472,7 @@ Metropolis_Hastings_optimizer <- function(S,
}
log_posteriori_values[i + 1] <- goal_function_perm_proposal
acceptance[i] <- TRUE
all_n0[i+1] <- get_n0_from_perm(current_perm, was_mean_estimated = FALSE) # was_mean_estimated will be corrected in find_MAP()

if (found_perm_log_posteriori < log_posteriori_values[i + 1]) {
found_perm_log_posteriori <- log_posteriori_values[i + 1]
Expand All @@ -479,6 +483,7 @@ Metropolis_Hastings_optimizer <- function(S,
visited_perms[[i + 1]] <- current_perm
}
log_posteriori_values[i + 1] <- log_posteriori_values[i]
all_n0[i+1] <- all_n0[i]
}
}

Expand Down Expand Up @@ -509,7 +514,8 @@ Metropolis_Hastings_optimizer <- function(S,
"did_converge" = NULL,
"best_perm_log_posteriori" = found_perm_log_posteriori,
"optimization_time" = NA,
"whole_optimization_time" = NA
"whole_optimization_time" = NA,
"all_n0" = all_n0
)


Expand Down Expand Up @@ -664,7 +670,8 @@ hill_climbing_optimizer <- function(S,
"did_converge" = did_converge,
"best_perm_log_posteriori" = goal_function_best_logvalues[iteration],
"optimization_time" = NA,
"whole_optimization_time" = NA
"whole_optimization_time" = NA,
"all_n0" = NA
)


Expand Down Expand Up @@ -717,7 +724,7 @@ brute_force_optimizer <- function(
iterations_to_perform <-
if ((3 <= perm_size) && (perm_size <= 9)) {
# Only the generators are interesting for us:
# perm_group_generators are calculated only for up to perm_size = 9
# We precalculated perm_group_generators only for up to perm_size = 9
# See ISSUE#21 for more information
OEIS_A051625[perm_size]
} else {
Expand Down Expand Up @@ -799,7 +806,8 @@ brute_force_optimizer <- function(
"did_converge" = TRUE,
"best_perm_log_posteriori" = log_posteriori_values[which.max(log_posteriori_values)],
"optimization_time" = NA,
"whole_optimization_time" = NA
"whole_optimization_time" = NA,
"all_n0" = NA
)


Expand Down Expand Up @@ -871,7 +879,8 @@ combine_gips <- function(g1, g2, show_progress_bar = FALSE) {
"did_converge" = optimization_info2[["did_converge"]],
"best_perm_log_posteriori" = max(optimization_info1[["best_perm_log_posteriori"]], optimization_info2[["best_perm_log_posteriori"]]),
"optimization_time" = c(optimization_info1[["optimization_time"]], optimization_info2[["optimization_time"]]),
"whole_optimization_time" = optimization_info1[["whole_optimization_time"]] + optimization_info2[["whole_optimization_time"]]
"whole_optimization_time" = optimization_info1[["whole_optimization_time"]] + optimization_info2[["whole_optimization_time"]],
"all_n0" = c(optimization_info1[["all_n0"]], optimization_info2[["all_n0"]])
)

if (optimization_info1[["best_perm_log_posteriori"]] > optimization_info2[["best_perm_log_posteriori"]]) {
Expand Down
2 changes: 1 addition & 1 deletion R/gips_class.R
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ validate_gips <- function(g) {
}

if (is.list(optimization_info)) { # Validate the `optimization_info` after the optimization
legal_fields <- c("original_perm", "acceptance_rate", "log_posteriori_values", "visited_perms", "start_perm", "last_perm", "last_perm_log_posteriori", "iterations_performed", "optimization_algorithm_used", "post_probabilities", "did_converge", "best_perm_log_posteriori", "optimization_time", "whole_optimization_time")
legal_fields <- c("original_perm", "acceptance_rate", "log_posteriori_values", "visited_perms", "start_perm", "last_perm", "last_perm_log_posteriori", "iterations_performed", "optimization_algorithm_used", "post_probabilities", "did_converge", "best_perm_log_posteriori", "optimization_time", "whole_optimization_time", "all_n0")

lacking_fields <- setdiff(legal_fields, names(optimization_info))
illegal_fields <- setdiff(names(optimization_info), legal_fields)
Expand Down
11 changes: 6 additions & 5 deletions tests/testthat/test-gips_class.R
Original file line number Diff line number Diff line change
Expand Up @@ -214,24 +214,24 @@ test_that("Properly validate the gips class with no optimization or after a sing
attr(g_err, "optimization_info")[["non_existing"]] <- "test"
expect_error(
validate_gips(g_err),
"You have a list of 15 elements."
"You have a list of 16 elements."
)

g_err <- g2
attr(g_err, "optimization_info")[["acceptance_rate"]] <- NULL
expect_error(
validate_gips(g_err),
"You have a list of 13 elements."
"You have a list of 14 elements."
)

g_err <- g2
attr(g_err, "optimization_info")[["non_existing"]] <- "test"
attr(g_err, "optimization_info")[["acceptance_rate"]] <- NULL
expect_error(
validate_gips(g_err),
"You have a list of 14 elements."
"You have a list of 15 elements."
)
# this one showed an error that one have the list of 13 elements, which is actually expected, but the names of the fields are not expected.
# this one showed an error that one have the list of proper number of elements, which is actually expected, but the names of the fields are not expected.

g_err <- g2
attr(g_err, "optimization_info")[["acceptance_rate"]] <- -0.1
Expand Down Expand Up @@ -1202,7 +1202,8 @@ test_that("summary.gips() works", {
optimization_algorithm_used = "Metropolis_Hastings", post_probabilities = NULL,
did_converge = NULL, best_perm_log_posteriori = -16.0120977148862,
optimization_time = structure(0.00564193725585938, class = "difftime", units = "secs"),
whole_optimization_time = structure(0.00564193725585938, class = "difftime", units = "secs")
whole_optimization_time = structure(0.00564193725585938, class = "difftime", units = "secs"),
all_n0 = c(2, 3, 2)
), class = "gips")

expect_equal(
Expand Down

0 comments on commit 0478686

Please sign in to comment.