Skip to content

Commit

Permalink
speed-up survival
Browse files Browse the repository at this point in the history
  • Loading branch information
anthonydevaux committed Jan 4, 2024
1 parent 6e8fb86 commit 58b8dbf
Show file tree
Hide file tree
Showing 9 changed files with 453 additions and 39 deletions.
55 changes: 26 additions & 29 deletions R/DynTree_surv.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,10 @@ DynTree_surv <- function(Y, Longitudinal = NULL, Numeric = NULL, Factor = NULL,

for (current_node in current_nodes){

current_node_chr <- as.character(current_node)

# mtry predictors
set.seed(seed+p*current_node)
set.seed(seed+p*which(current_node==current_nodes))
mtry_pred <- sample(type_pred, mtry)
mtry_type_pred <- unique(mtry_pred)

Expand Down Expand Up @@ -118,7 +120,7 @@ DynTree_surv <- function(Y, Longitudinal = NULL, Numeric = NULL, Factor = NULL,
model_init <- getParamMM(current_node = current_node, markers = colnames(Longitudinal_current$X),
params = model_init)
}else{
model_init[[current_node]] <- lapply(Longitudinal$model, FUN = function(x) x$init.param)
model_init[[current_node_chr]] <- lapply(Longitudinal$model, FUN = function(x) x$init.param)
}

}
Expand All @@ -138,9 +140,8 @@ DynTree_surv <- function(Y, Longitudinal = NULL, Numeric = NULL, Factor = NULL,
# Try best split on mtry predictors
if (is.element("Factor", mtry_type_pred)){

leaf_split_Factor <- var_split_surv(X = Factor_current, Y = Y_current,
timeVar = timeVar,
cause = cause, nodesize = nodesize)
leaf_split_Factor <- var_split_factor(X = Factor_current, Y = Y_current,
cause = cause, nodesize = nodesize)

if (leaf_split_Factor$Pure==FALSE){
F_SPLIT <- rbind(F_SPLIT,
Expand All @@ -151,28 +152,27 @@ DynTree_surv <- function(Y, Longitudinal = NULL, Numeric = NULL, Factor = NULL,

if (is.element("Longitudinal", mtry_type_pred)){

leaf_split_Longitudinal <- var_split_surv(X = Longitudinal_current, Y = Y_current,
leaf_split_Longitudinal <- var_split_long(X = Longitudinal_current, Y = Y_current,
timeVar = timeVar,
nsplit_option = nsplit_option,
cause = cause, nodesize = nodesize,
init = model_init[[current_node]])
init = model_init[[current_node_chr]])

if (leaf_split_Longitudinal$Pure==FALSE){
model_init[[current_node]] <- leaf_split_Longitudinal$init # update initial values at current node
model_init[[current_node_chr]] <- leaf_split_Longitudinal$init # update initial values at current node
F_SPLIT <- rbind(F_SPLIT,
data.frame(TYPE = "Longitudinal", Impurity = leaf_split_Longitudinal$impur,
stringsAsFactors = FALSE))

conv_issue[[current_node]] <- leaf_split_Longitudinal$conv_issue
conv_issue[[current_node_chr]] <- leaf_split_Longitudinal$conv_issue
}
}

if (is.element("Numeric", mtry_type_pred)){

leaf_split_Numeric <- var_split_surv(X = Numeric_current, Y = Y_current,
timeVar = timeVar,
nsplit_option = nsplit_option,
cause = cause, nodesize = nodesize)
leaf_split_Numeric <- var_split_num(X = Numeric_current, Y = Y_current,
nsplit_option = nsplit_option,
cause = cause, nodesize = nodesize)

if (leaf_split_Numeric$Pure==FALSE){
F_SPLIT <- rbind(F_SPLIT,
Expand Down Expand Up @@ -219,7 +219,7 @@ DynTree_surv <- function(Y, Longitudinal = NULL, Numeric = NULL, Factor = NULL,
threshold = leaf_split$threshold, N = N_current,
Nevent = Nevent_current, stringsAsFactors = FALSE))

model_param[[current_node]] <- leaf_split$model_param
model_param[[current_node_chr]] <- leaf_split$model_param

w_left <- which(X_boot$id%in%left_id)
wY_left <- which(Y_boot$id%in%left_id)
Expand Down Expand Up @@ -257,8 +257,8 @@ DynTree_surv <- function(Y, Longitudinal = NULL, Numeric = NULL, Factor = NULL,
meanFd <- mean(X_boot$X[w_right, best_pred])
}

hist_nodes[[2*current_node]] <- meanFg
hist_nodes[[2*current_node+1]] <- meanFd
hist_nodes[[as.character(2*current_node)]] <- meanFg
hist_nodes[[as.character(2*current_node+1)]] <- meanFd

}else{

Expand Down Expand Up @@ -286,39 +286,36 @@ DynTree_surv <- function(Y, Longitudinal = NULL, Numeric = NULL, Factor = NULL,
rownames(V_split) <- seq(nrow(V_split))
}

for (q in unique(na.omit(id_leaves))){
for (q in sort(unique(na.omit(id_nodes)))){

w <- which(id_leaves == q)
w <- which(id_nodes == q)

datasurv <- data.frame(time_event = Y_boot$Y[w][,1], event = Y_boot$Y[w][,2])
fit <- prodlim(Hist(time_event, event)~1, data = datasurv,
type = "risk")

if (is.null(fit$cuminc)){
pred <- list()
current.cause <- as.character(unique(sort(datasurv$event))[-1])

if (length(current.cause)==0){
current.cause <- "1"
}

pred[[current.cause]] <- data.frame(times=fit$time, traj=1-fit$surv) # 1-KM

if (is.null(pred[[as.character(cause)]])){
if (all(unique(datasurv$event)==0)){ # case with no event
pred[[as.character(cause)]] <- data.frame(times=fit$time, traj = 0) # no event => no risk
}else{ # case with event no matter which one
u_current_causes <- unique(datasurv$event)
current_cause <- u_current_causes[u_current_causes!=0] # keep leaf cause
pred <- list(data.frame(times=fit$time, traj=1-fit$surv)) # 1-KM
names(pred) <- as.character(current_cause)
}

}else{
pred <- lapply(fit$cuminc, FUN = function(x) return(data.frame(times=fit$time, traj=x))) # CIF Aalen-Johansen
}

Y_pred[[q]] <- lapply(pred, function(x){
Y_pred[[as.character(q)]] <- lapply(pred, function(x){
combine_times(pred = x, newtimes = unique(Y$Y[,1]), type = "risk")
})

}

return(list(leaves = id_leaves, idY = Y_boot$id, Ytype = Y_boot$type, V_split = V_split,
return(list(leaves = id_nodes, idY = Y_boot$id, Ytype = Y_boot$type, V_split = V_split,
hist_nodes = hist_nodes, Y_pred = Y_pred, Y = Y, boot = id_boot, conv_issue = conv_issue,
model_param = model_param))

Expand Down
13 changes: 7 additions & 6 deletions R/getParamMM.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,19 @@
#' @keywords internal
getParamMM <- function(current_node, markers, params){

starting_node <- current_node
params[[starting_node]] <- rep(list(NA), length(markers))
names(params[[starting_node]]) <- markers
starting_node_chr <- as.character(current_node)
params[[starting_node_chr]] <- rep(list(NA), length(markers))
names(params[[starting_node_chr]]) <- markers

while (current_node > 1 & length(markers) > 0) { # get initial values from upper nodes
current_node <- current_node %/% 2
current_node_chr <- as.character(current_node)

current_node_marker_indices <- match(names(params[[current_node]]), markers, nomatch = 0)
current_node_marker <- names(params[[current_node]])[current_node_marker_indices > 0]
current_node_marker_indices <- match(names(params[[current_node_chr]]), markers, nomatch = 0)
current_node_marker <- names(params[[current_node_chr]])[current_node_marker_indices > 0]

if (length(current_node_marker) > 0) {
params[[starting_node]][current_node_marker] <- params[[current_node]][current_node_marker]
params[[starting_node_chr]][current_node_marker] <- params[[current_node_chr]][current_node_marker]
markers <- setdiff(markers, current_node_marker)
}
}
Expand Down
8 changes: 4 additions & 4 deletions R/plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,10 @@ plot.DynForest <- function(x, tree = NULL, nodes = NULL, id = NULL, max_tree = N
if (!all(inherits(nodes, "numeric"))){
stop("'nodes' should be a numeric object containing the tree identifier!")
}
if (any(nodes>length(x$rf[,tree]$Y_pred))){
if (!all(nodes%in%names(x$rf[,tree]$Y_pred))){
stop("One selected node do not have CIF! Please verify the 'nodes' identifiers!")
}
if (any(sapply(nodes, FUN = function(node) is.null(x$rf[,tree]$Y_pred[[node]])))){
if (any(sapply(nodes, FUN = function(node) is.null(x$rf[,tree]$Y_pred[[as.character(node)]])))){
stop("One selected node do not have CIF! Please verify the 'nodes' identifiers!")
}
}else{
Expand All @@ -151,7 +151,7 @@ plot.DynForest <- function(x, tree = NULL, nodes = NULL, id = NULL, max_tree = N
# data transformation for ggplot2
CIFs_nodes_list <- lapply(nodes, FUN = function(node){

CIFs_node <- x$rf[,tree]$Y_pred[[node]]
CIFs_node <- x$rf[,tree]$Y_pred[[as.character(node)]]

CIFs_node_list <- lapply(names(CIFs_node), FUN = function(y){

Expand Down Expand Up @@ -198,7 +198,7 @@ plot.DynForest <- function(x, tree = NULL, nodes = NULL, id = NULL, max_tree = N
next()
}

CIFs_node <- x$rf[,tree_id]$Y_pred[[tree_node]]
CIFs_node <- x$rf[,tree_id]$Y_pred[[as.character(tree_node)]]

CIFs_node_list <- lapply(names(CIFs_node), FUN = function(y){

Expand Down
58 changes: 58 additions & 0 deletions R/var_split_factor.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#' Split function to build the two daughter nodes from factor predictor
#'
#' @param X Input data
#' @param Y Outcome data
#' @param cause (Only with competing events) Number indicates the event of interest.
#' @param nodesize Minimal number of subjects required in both child nodes to split. Cannot be smaller than 1.
#'
#' @keywords internal
var_split_factor <- function(X, Y, cause = 1, nodesize = 1){

X_ncol <- ncol(X$X)
split_var <- vector("list", X_ncol)
impur_var <- rep(Inf, X_ncol)
Pure <- FALSE

for (i in 1:X_ncol){

if (length(unique(X$X[,i]))>1){

L <- Fact.partitions(X$X[,i],X$id)

# Find best partition
split_list <- lapply(seq_along(L), FUN = function(x){

split <- rep(2,length(X$id))
split[which(X$id%in%L[[x]])] <- 1

if ((length(unique(split))>1)&(all(table(split)>=nodesize))){
# Evaluate the partition
impur <- impurity_split(Y, split, cause = cause)$impur
}else{
impur <- Inf
}

return(list(split = split, impur = impur))

})

partition_impur <- unlist(lapply(split_list, function(x) return(x$impur)))

if (any(partition_impur!=Inf)){
best_part <- which.min(partition_impur)
split_var[[i]] <- split_list[[best_part]]$split
impur_var[i] <- split_list[[best_part]]$impur
}
}
}

if (all(impur_var==Inf)){
return(list(Pure=TRUE))
}

var_split <- which.min(impur_var)

return(list(split = split_var[[var_split]], impur = min(impur_var),
variable = var_split, variable_summary = NA, threshold = NA,
Pure = Pure))
}
Loading

0 comments on commit 58b8dbf

Please sign in to comment.