Skip to content

Commit

Permalink
Improve speed of covariate balance computation.
Browse files Browse the repository at this point in the history
  • Loading branch information
Admin_mschuemi authored and Admin_mschuemi committed Nov 6, 2023
1 parent 105e3bc commit 19ae189
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ Changes:
- Added the `generalizability_max_sdm` and `generalizabiltiy_diagnostic` fields to the `cm_diagnostics_summary` table.
- Added the `mean_before`, `mean_after`, `target_std_diff`, `comparator_std_diff`, and `target_comparator_std_diff` fields to both the `cm_covariate_balance` and `cm_shared_covariate_balance` tables.

7. Improve speed of covariate balance computation.


Bugfixes:

Expand Down
12 changes: 7 additions & 5 deletions R/Balance.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ computeMeansPerGroup <- function(cohorts, cohortMethodData, covariateFilter) {
stratumSize <- cohorts %>%
group_by(.data$stratumId, .data$treatment) %>%
count() %>%
ungroup()
ungroup() %>%
collect()
}

useWeighting <- (hasStrata && any(stratumSize %>% pull(.data$n) > 1)) ||
Expand All @@ -71,19 +72,20 @@ computeMeansPerGroup <- function(cohorts, cohortMethodData, covariateFilter) {
# Variable strata sizes detected: weigh by size of strata set
w <- stratumSize %>%
mutate(weight = 1 / .data$n) %>%
inner_join(cohorts, by = c("stratumId", "treatment")) %>%
inner_join(cohorts, by = c("stratumId", "treatment"), copy = TRUE) %>%
select("rowId", "treatment", "weight")
# Overall weight is for computing mean and SD across T and C
overallW <- stratumSize %>%
group_by(.data$stratumId) %>%
summarise(weight = 1 / sum(.data$n, na.rm = TRUE)) %>%
ungroup() %>%
inner_join(cohorts, by = c("stratumId")) %>%
inner_join(cohorts, by = c("stratumId"), copy = TRUE) %>%
select("rowId", "weight")
} else {
w <- cohorts %>%
mutate(weight = .data$iptw) %>%
select("rowId", "treatment", "weight")
select("rowId", "treatment", "weight") %>%
collect()
overallW <- w
}
# Normalize so sum(weight) == 1 per treatment arm:
Expand Down Expand Up @@ -338,7 +340,7 @@ computeCovariateBalance <- function(population,
on.exit(cohortMethodData$tempCohortsAfterMatching <- NULL, add = TRUE)

beforeMatching <- computeMeansPerGroup(cohortMethodData$tempCohorts, cohortMethodData, covariateFilter)
afterMatching <- computeMeansPerGroup(cohortMethodData$tempCohortsAfterMatching, cohortMethodData, covariateFilter)
afterMatching <- computeMeansPerGroup(cohorts = cohortMethodData$tempCohortsAfterMatching, cohortMethodData, covariateFilter)

beforeMatching <- beforeMatching %>%
select("covariateId",
Expand Down

0 comments on commit 19ae189

Please sign in to comment.