Skip to content

Commit

Permalink
Speeding up balance computation. Fixing issues caused by new dplyr
Browse files Browse the repository at this point in the history
  • Loading branch information
schuemie committed Nov 6, 2023
1 parent e174be3 commit cf8ade5
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 5 deletions.
5 changes: 3 additions & 2 deletions R/Balance.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ computeMeansPerGroup <- function(cohorts, cohortMethodData, covariateFilter) {
stratumSize <- cohorts %>%
group_by(.data$stratumId, .data$treatment) %>%
count() %>%
ungroup()
ungroup() %>%
collect()
}

useWeighting <- (hasStrata && any(stratumSize %>% pull(.data$n) > 1)) ||
Expand All @@ -68,7 +69,7 @@ computeMeansPerGroup <- function(cohorts, cohortMethodData, covariateFilter) {
# Variable strata sizes detected: weigh by size of strata set
w <- stratumSize %>%
mutate(weight = 1 / .data$n) %>%
inner_join(cohorts, by = c("stratumId", "treatment")) %>%
inner_join(cohorts, by = c("stratumId", "treatment"), copy = TRUE) %>%
select("rowId", "treatment", "weight")
} else {
w <- cohorts %>%
Expand Down
4 changes: 2 additions & 2 deletions R/StudyPopulation.R
Original file line number Diff line number Diff line change
Expand Up @@ -229,14 +229,14 @@ createStudyPopulation <- function(cohortMethodData,
.data$daysToEvent > -priorOutcomeLookback &
outcomes$daysToEvent < outcomes$daysToCohortEnd + riskWindowStart
) %>%
select("rowId")
pull("rowId")
} else {
priorOutcomeRowIds <- outcomes %>%
filter(
.data$daysToEvent > -priorOutcomeLookback &
.data$daysToEvent < riskWindowStart
) %>%
select("rowId")
pull("rowId")
}
population <- population %>%
filter(!(.data$rowId %in% priorOutcomeRowIds))
Expand Down
5 changes: 4 additions & 1 deletion tests/testthat/test-parameterSweep.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@ test_that("Create study population functions", {
minDaysAtRisk = 1
)
expect_true(all(studyPop$timeAtRisk > 0))
peopleWithPriorOutcomes <- cohortMethodData$outcomes$rowId[cohortMethodData$outcomes$daysToEvent < 0]
peopleWithPriorOutcomes <- cohortMethodData$outcomes %>%
filter(outcomeId == 194133 & daysToEvent < 0) %>%
distinct(rowId) %>%
pull()
expect_false(any(peopleWithPriorOutcomes %in% studyPop$rowId))

aTable <- getAttritionTable(studyPop)
Expand Down

0 comments on commit cf8ade5

Please sign in to comment.