Skip to content

Commit

Permalink
improve computation time
Browse files Browse the repository at this point in the history
  • Loading branch information
Anthony Devaux authored and Anthony Devaux committed Jan 5, 2024
1 parent 58b8dbf commit 4aabf18
Show file tree
Hide file tree
Showing 10 changed files with 255 additions and 947 deletions.
431 changes: 184 additions & 247 deletions R/DynTree.R

Large diffs are not rendered by default.

32 changes: 20 additions & 12 deletions R/OOB_rfshape.R
Original file line number Diff line number Diff line change
Expand Up @@ -110,21 +110,23 @@ OOB.rfshape <- function(rf, Longitudinal = NULL, Numeric = NULL, Factor = NULL,
oob <- setdiff(Y$id,BOOT)
if (is.element(indiv, oob)== TRUE){

pred.node <- tryCatch(pred.MMT(rf$rf[,t], Longitudinal = Longitudinal_courant,
pred_node <- tryCatch(pred.MMT(rf$rf[,t], Longitudinal = Longitudinal_courant,
Numeric = Numeric_courant, Factor = Factor_courant,
timeVar = timeVar),
error = function(e) return(NA))

if (is.na(pred.node)){
pred_node_chr <- as.character(pred_node)

if (is.na(pred_node_chr)){
pred.mat[t,] <- NA
}else{
if (IBS.min == 0){
pi_t <- rf$rf[,t]$Y_pred[[pred.node]][[as.character(cause)]]$traj[which(allTimes%in%allTimes_IBS)]
pi_t <- rf$rf[,t]$Y_pred[[pred_node_chr]][[as.character(cause)]]$traj[which(allTimes%in%allTimes_IBS)]
pred.mat[t,] <- pi_t
}else{
pi_t <- rf$rf[,t]$Y_pred[[pred.node]][[as.character(cause)]]$traj[which(allTimes%in%allTimes_IBS)] # pi(t)
pi_s <- rf$rf[,t]$Y_pred[[pred.node]][[as.character(cause)]]$traj[sum(allTimes<IBS.min)] # pi(s)
s_s <- 1 - sum(unlist(lapply(rf$rf[,t]$Y_pred[[pred.node]], FUN = function(x){
pi_t <- rf$rf[,t]$Y_pred[[pred_node_chr]][[as.character(cause)]]$traj[which(allTimes%in%allTimes_IBS)] # pi(t)
pi_s <- rf$rf[,t]$Y_pred[[pred_node_chr]][[as.character(cause)]]$traj[sum(allTimes<IBS.min)] # pi(s)
s_s <- 1 - sum(unlist(lapply(rf$rf[,t]$Y_pred[[pred_node_chr]], FUN = function(x){
return(x$traj[sum(allTimes<IBS.min)])
}))) # s(s)
pred.mat[t,] <- (pi_t - pi_s)/s_s # P(S<T<S+t|T>S)
Expand Down Expand Up @@ -201,12 +203,15 @@ OOB.rfshape <- function(rf, Longitudinal = NULL, Numeric = NULL, Factor = NULL,
Factor_courant <- list(type="Factor", X=Factor$X[w_XFactor,, drop=FALSE], id=Factor$id[w_XFactor])
}

pred.node <- tryCatch(pred.MMT(rf$rf[,t], Longitudinal = Longitudinal_courant,
pred_node <- tryCatch(pred.MMT(rf$rf[,t], Longitudinal = Longitudinal_courant,
Numeric = Numeric_courant, Factor = Factor_courant,
timeVar = timeVar),
error = function(e) return(NA))
pred_courant[t] <- ifelse(!is.null(rf$rf[,t]$Y_pred[[pred.node]]),
rf$rf[,t]$Y_pred[[pred.node]], NA)

pred_node_chr <- as.character(pred_node)

pred_courant[t] <- ifelse(!is.null(rf$rf[,t]$Y_pred[[pred_node_chr]]),
rf$rf[,t]$Y_pred[[pred_node_chr]], NA)
}
}
if (all(is.na(pred_courant))){
Expand Down Expand Up @@ -268,12 +273,15 @@ OOB.rfshape <- function(rf, Longitudinal = NULL, Numeric = NULL, Factor = NULL,
Factor_courant <- list(type="Factor", X=Factor$X[w_XFactor,, drop=FALSE], id=Factor$id[w_XFactor])
}

pred.node <- tryCatch(pred.MMT(rf$rf[,t], Longitudinal = Longitudinal_courant,
pred_node <- tryCatch(pred.MMT(rf$rf[,t], Longitudinal = Longitudinal_courant,
Numeric = Numeric_courant, Factor = Factor_courant,
timeVar = timeVar),
error = function(e) return(NA))
pred_courant[t] <- ifelse(!is.null(rf$rf[,t]$Y_pred[[pred.node]]),
rf$rf[,t]$Y_pred[[pred.node]], NA)

pred_node_chr <- as.character(pred_node)

pred_courant[t] <- ifelse(!is.null(rf$rf[,t]$Y_pred[[pred_node_chr]]),
rf$rf[,t]$Y_pred[[pred_node_chr]], NA)

}
}
Expand Down
16 changes: 10 additions & 6 deletions R/OOB_tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ OOB.tree <- function(tree, Longitudinal = NULL, Numeric = NULL, Factor = NULL, Y
timeVar = timeVar),
error = function(e) return(NA)) # handle permutation issue

if (is.na(pred_current)){
pred_current_chr <- as.character(pred_current)

if (is.na(pred_current_chr)){

xerror[which(OOB_IBS==i)] <- NA

Expand All @@ -101,11 +103,11 @@ OOB.tree <- function(tree, Longitudinal = NULL, Numeric = NULL, Factor = NULL, Y

# CIF
if (IBS.min == 0){
pi_t <- tree$Y_pred[[pred_current]][[as.character(cause)]]$traj[which(allTimes%in%allTimes_IBS)]
pi_t <- tree$Y_pred[[pred_current_chr]][[as.character(cause)]]$traj[which(allTimes%in%allTimes_IBS)]
}else{
pi_t <- tree$Y_pred[[pred_current]][[as.character(cause)]]$traj[which(allTimes%in%allTimes_IBS)] # pi(t)
pi_s <- tree$Y_pred[[pred_current]][[as.character(cause)]]$traj[sum(allTimes<IBS.min)] # pi(s)
s_s <- 1 - sum(unlist(lapply(tree$Y_pred[[pred_current]], FUN = function(x){
pi_t <- tree$Y_pred[[pred_current_chr]][[as.character(cause)]]$traj[which(allTimes%in%allTimes_IBS)] # pi(t)
pi_s <- tree$Y_pred[[pred_current_chr]][[as.character(cause)]]$traj[sum(allTimes<IBS.min)] # pi(s)
s_s <- 1 - sum(unlist(lapply(tree$Y_pred[[pred_current_chr]], FUN = function(x){
return(x$traj[sum(allTimes<IBS.min)])
}))) # s(s)
pi_t <- (pi_t - pi_s)/s_s # P(S<T<S+t|T>S)
Expand Down Expand Up @@ -151,7 +153,9 @@ OOB.tree <- function(tree, Longitudinal = NULL, Numeric = NULL, Factor = NULL, Y
pred_current <- pred.MMT(tree, Longitudinal = Longitudinal_current, Numeric = Numeric_current,
Factor = Factor_current, timeVar = timeVar)

pred <- unlist(sapply(pred_current, FUN = function(x) {
pred_current_chr <- as.character(pred_current)

pred <- unlist(sapply(pred_current_chr, FUN = function(x) {
ifelse(!is.null(tree$Y_pred[[x]]), tree$Y_pred[[x]], NA)
}))

Expand Down
14 changes: 8 additions & 6 deletions R/pred_MMT.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pred.MMT <- function(tree, Longitudinal=NULL, Numeric=NULL, Factor=NULL,

pred <- rep(NA,length(id.pred))

for (i in 1:length(id.pred)){
pred <- sapply(seq_along(id.pred), FUN = function(i){

if (is.element("Longitudinal",Inputs)==TRUE) wLongitudinal <- which(Longitudinal$id==id.pred[i])
if (is.element("Numeric",Inputs)==TRUE) wNumeric <- which(Numeric$id==id.pred[i])
Expand All @@ -34,8 +34,8 @@ pred.MMT <- function(tree, Longitudinal=NULL, Numeric=NULL, Factor=NULL,
var.split.sum <- as.numeric(as.character(tree$V_split[which(tree$V_split[,2]==current_node),4]))
threshold <- as.numeric(as.character(tree$V_split[which(tree$V_split[,2]==current_node),5]))

meanG <- tree$hist_nodes[[2*current_node]]
meanD <- tree$hist_nodes[[2*current_node+1]]
meanG <- tree$hist_nodes[[as.character(2*current_node)]]
meanD <- tree$hist_nodes[[as.character(2*current_node+1)]]

if (type=="longitudinal"){

Expand All @@ -48,7 +48,7 @@ pred.MMT <- function(tree, Longitudinal=NULL, Numeric=NULL, Factor=NULL,
colnames(data_model)[which(colnames(data_model)=="time")] <- timeVar
data_model <- data_model[,c("id",model_var)]

RE <- predRE(tree$model_param[[current_node]][[1]],
RE <- predRE(tree$model_param[[as.character(current_node)]][[1]],
X$model[[var.split]], data_model)$bi

######################
Expand Down Expand Up @@ -96,8 +96,10 @@ pred.MMT <- function(tree, Longitudinal=NULL, Numeric=NULL, Factor=NULL,

}

pred[i] <- current_node
return(current_node)

})

}
return(pred)

}
4 changes: 2 additions & 2 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ predict.DynForest <- function(object,

i.leaf <- pred_leaf[t,][indiv]

pred_leaf_indiv <- object$rf[,t]$Y_pred[[i.leaf]][[cause]]$traj[id.predTimes]
pred_leaf_indiv <- object$rf[,t]$Y_pred[[as.character(i.leaf)]][[cause]]$traj[id.predTimes]

if (!is.null(pred_leaf_indiv)){
pred[[cause]][[indiv]][t,] <- pred_leaf_indiv
Expand All @@ -261,7 +261,7 @@ predict.DynForest <- function(object,

i.leaf <- pred_leaf[t,indiv]

pred_leaf_indiv <- object$rf[,t]$Y_pred[[i.leaf]]
pred_leaf_indiv <- object$rf[,t]$Y_pred[[as.character(i.leaf)]]

if (!is.null(pred_leaf_indiv)){
pred[t,indiv] <- pred_leaf_indiv
Expand Down
Loading

0 comments on commit 4aabf18

Please sign in to comment.