forked from NSAPH/GPSmatching
/
estimate_gps.R
executable file
·200 lines (172 loc) · 7.08 KB
/
estimate_gps.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
#' @title
#' Estimate generalized propensity score (GPS) values
#'
#' @description
#' Estimates GPS value for each observation using normal or kernel
#' approaches.
#'
#' @details
#' `w` and `c` are joined on `id` with [merge()], so only observations whose
#' `id` appears in both data frames are kept. A model for the exposure given
#' the covariates is trained first. With `gps_density = "normal"` the GPS is
#' the normal density of the exposure around the prediction, using the
#' standard deviation of the residuals. With `gps_density = "kernel"` a second
#' model is trained on the absolute residuals and the GPS is computed from the
#' resulting standardized residuals.
#'
#' @param w A data frame of observed continuous exposure variable. Including
#' `id` and exactly one exposure column (e.g., `w`).
#' @param c A data frame of observed covariates variable. Also includes `id`
#' column.
#' @param gps_density Model type which is used for estimating GPS value,
#' including `normal` (default) and `kernel`.
#' @param params Includes list of parameters that are used internally. Unrelated
#' parameters will be ignored.
#' @param sl_lib A vector of prediction algorithms. For each entry a wrapper is
#' requested from `gen_wrap_sl_lib()`; when one is generated, the
#' `<name>_internal` wrapper is used, otherwise the name is used as is.
#' @param nthread An integer value that represents the number threads to be used
#' in a shared memory system.
#' @param ... Additional arguments passed to the argument checker and to the
#' model training function.
#'
#' @return
#' The function (invisibly) returns a S3 object of class `cgps_gps`. Including
#' the following:
#'   - `dataset`: a data frame with `id`, the exposure column, `gps`,
#'   `e_gps_pred`, `e_gps_std_pred`, and `w_resid` columns.
#'   - `used_params`: parameters reported by the generated wrappers.
#'   - `gps_mx` (min and max of gps)
#'   - `w_mx` (min and max of w).
#'
#' @note
#' Covariate columns are dropped from the returned `dataset`.
#'
#' @export
#'
#' @examples
#' \donttest{
#' m_d <- generate_syn_data(sample_size = 100)
#' data_with_gps <- estimate_gps(m_d[, c("id", "w")],
#'                               m_d[, c("id", "cf1", "cf2", "cf3",
#'                                       "cf4", "cf5", "cf6")],
#'                               gps_density = "normal",
#'                               params = list(xgb_max_depth = c(3, 4, 5),
#'                                        xgb_nrounds = c(10, 20, 30, 40, 50, 60)),
#'                               nthread = 1,
#'                               sl_lib = c("m_xgboost")
#'                              )
#'}
estimate_gps <- function(w,
                         c,
                         gps_density = "normal",
                         params = list(),
                         sl_lib = c("m_xgboost"),
                         nthread = 1,
                         ...) {

  start_time <- proc.time()

  # Check passed arguments -----------------------------------------------------
  check_args_estimate_gps(gps_density, ...)

  if (!("id" %in% colnames(w))) stop("w should include id column.")
  if (!("id" %in% colnames(c))) stop("c should include id column.")

  # Check if data has missing value(s) -----------------------------------------
  if (sum(is.na(w)) > 0) {
    logger::log_warn("Vector w has {sum(is.na(w))} missing values.")
  }
  if (sum(is.na(c)) > 0) {
    logger::log_warn(
      "Confounders data.frame (c) has {sum(is.na(c))} missing values.")
  }

  # Generate SL wrapper library for each type of prediction algorithms ---------
  sl_lib_internal <- NULL
  used_params <- list()
  for (item in sl_lib) {
    wrapper_generated_res <- gen_wrap_sl_lib(lib_name = item, params,
                                             nthread = nthread)
    if (wrapper_generated_res[[1]]) {
      # A wrapper was generated: use it and remember the parameters it used.
      sl_lib_internal <- c(sl_lib_internal, paste0(item, "_internal"))
      used_params <- c(used_params, wrapper_generated_res[[2]])
    } else {
      sl_lib_internal <- c(sl_lib_internal, item)
    }
  }

  # Join exposure and covariates on id (inner join) ----------------------------
  merged_data <- merge(w, c, by = "id")
  if (nrow(merged_data) == 0) {
    stop(paste0("Merged data length is 0.",
                " Make sure that w and c belong to the same observations, ",
                " or partially include same observations."))
  }

  exposure_col <- colnames(w)[colnames(w) != "id"]
  covariate_cols <- colnames(c)[colnames(c) != "id"]

  # Everything below treats the exposure as a single vector.
  if (length(exposure_col) != 1) {
    stop(paste0("w should include exactly one exposure column besides id. ",
                "Found ", length(exposure_col), "."))
  }

  # Guard before any (potentially expensive) training is started.
  if (!(gps_density %in% c("normal", "kernel"))) {
    logger::log_error("Code should nevet get here. Doublecheck check_arguments.")
    stop(paste("Invalide gps_density: ", gps_density,
               ". Use normal or kernel."))
  }

  exposure <- merged_data[[exposure_col]]
  # drop = FALSE keeps a data frame even when there is only one covariate.
  covariates <- merged_data[, covariate_cols, drop = FALSE]

  # Model of the exposure given the covariates (shared by both approaches) -----
  e_gps <- train_it(target = exposure,
                    input = covariates,
                    sl_lib_internal = sl_lib_internal,
                    ...)
  e_gps_pred <- e_gps$SL.predict

  if (gps_density == "normal") {
    # One common standard deviation, estimated from the residuals.
    e_gps_std_pred <- stats::sd(exposure - e_gps_pred)
    w_resid <- compute_resid(exposure, e_gps_pred, e_gps_std_pred)
    gps <- stats::dnorm(exposure, mean = e_gps_pred, sd = e_gps_std_pred)
  } else {
    # Kernel approach: second model trained on the absolute residuals.
    e_gps_std <- train_it(target = abs(exposure - e_gps_pred),
                          input = covariates,
                          sl_lib_internal = sl_lib_internal,
                          ...)
    e_gps_std_pred <- e_gps_std$SL.predict
    w_resid <- compute_resid(exposure, e_gps_pred, e_gps_std_pred)
    gps <- compute_density(w_resid, w_resid)
  }

  w_mx <- compute_min_max(exposure)
  gps_mx <- compute_min_max(gps)

  # Compile the output data set ------------------------------------------------
  merged_data$gps <- gps
  # Drop covariates
  merged_data[covariate_cols] <- NULL
  dataset <- merged_data
  dataset$e_gps_pred <- e_gps_pred
  if (length(e_gps_std_pred) == 1) {
    # A scalar standard deviation is expanded to one value per row.
    e_gps_std_pred <- rep(e_gps_std_pred, nrow(dataset))
  }
  dataset$e_gps_std_pred <- e_gps_std_pred
  dataset$w_resid <- w_resid

  # Logging for debugging purposes ---------------------------------------------
  # NOTE: the messages below are interpolated from local variable names
  # (w_mx, gps_mx, e_gps, e_gps_std, start_time, end_time); keep them in sync.
  logger::log_debug("Min Max of treatment: {paste(w_mx, collapse = ', ')}")
  logger::log_debug("Min Max of gps: {paste(gps_mx, collapse = ', ')}")
  logger::log_debug(
    "Weights for the select libraries in predicting e_gps:",
    " {paste(names(e_gps$coef), collapse = ', ')}",
    " {paste(e_gps$coef, collapse = ', ')}",
    " | Overal Risk: {sum(e_gps$coef * e_gps$cvRisk)/length(e_gps$coef)}")
  logger::log_debug("Wall clock time to estimate e_gps:",
                    " {e_gps$times$everything[3]} seconds.")
  if (gps_density == "kernel") {
    logger::log_debug(
      "Weights for the select libraries in predicting residuals:",
      " {paste(names(e_gps_std$coef), collapse = ', ')}",
      " {paste(e_gps_std$coef, collapse = ', ')} | Overal risk:",
      " {sum(e_gps_std$coef * e_gps_std$cvRisk)/length(e_gps_std$coef)}")
    logger::log_debug("Wall clock time to estimate residuals:",
                      " {e_gps_std$times$everything[3]} seconds.")
  }

  end_time <- proc.time()
  logger::log_debug("Wall clock time to run estimate_gps function: ",
                    " {(end_time - start_time)[[3]]} seconds.")

  # Assemble the S3 result -----------------------------------------------------
  result <- list()
  class(result) <- "cgps_gps"
  result$dataset <- dataset
  result$used_params <- used_params
  result$gps_mx <- gps_mx
  result$w_mx <- w_mx

  invisible(result)
}