In [None]:
library("aldvmm")

In [None]:
# The input CSV must provide these columns:
#   utility, age, gender, pcs, mcs, effect, burden, symptom.
# The composite kdcs score, the squared terms and the interaction terms used
# by the model formulas are all derived below.
data <- read.csv("your_path.csv")

# Composite score: plain average of the three subscale scores.
data <- transform(data, kdcs = (effect + burden + symptom) / 3)

# Squared terms, computed up front so every model sees the same columns.
data <- transform(
  data,
  pcs_sq = pcs^2,
  mcs_sq = mcs^2,
  kdcs_sq = kdcs^2,
  effect_sq = effect^2,
  burden_sq = burden^2,
  symptom_sq = symptom^2
)

# Pairwise interaction terms.
data <- transform(
  data,
  pcs_kdcs = pcs * kdcs,
  mcs_kdcs = mcs * kdcs,
  pcs_mcs = pcs * mcs,
  pcs_effect = pcs * effect,
  pcs_burden = pcs * burden,
  pcs_symptom = pcs * symptom,
  mcs_effect = mcs * effect,
  mcs_burden = mcs * burden,
  mcs_symptom = mcs * symptom,
  effect_burden = effect * burden,
  effect_symptom = effect * symptom,
  burden_symptom = burden * symptom
)


head(data)

In [63]:
# Mapping KDQoL-36 to any generic QoL instrument using ALDVMM or OLS,
# evaluated with k-fold cross-validation.

# Builds the six candidate model formulas for the requested method.
# The ALDVMM variants carry an extra "| 1" part; the OLS variants do not.
# Term order is kept stable because the single-component prediction path
# multiplies coefficients with columns taken in formula order.
.kfold_model_specs <- function(method) {
  base5 <- c("age", "gender", "pcs", "mcs", "kdcs")
  base7 <- c("age", "gender", "pcs", "mcs", "effect", "burden", "symptom")
  sq5 <- c("pcs_sq", "mcs_sq", "kdcs_sq")
  sq7 <- c("pcs_sq", "mcs_sq", "effect_sq", "burden_sq", "symptom_sq")
  inter5 <- c("pcs_kdcs", "mcs_kdcs", "pcs_mcs")
  inter7 <- c(
    "pcs_mcs", "pcs_effect", "pcs_burden", "pcs_symptom",
    "mcs_effect", "mcs_burden", "mcs_symptom",
    "effect_burden", "effect_symptom", "burden_symptom"
  )

  rhs_terms <- list(
    "5_main" = base5,
    "7_main" = base7,
    "5_sq" = c(base5, sq5),
    "7_sq" = c(base7, sq7),
    "5_inter" = c(base5, sq5, inter5),
    "7_inter" = c(
      "pcs", "mcs", "effect", "burden", "symptom", "age", "gender",
      sq7, inter7
    )
  )

  suffix <- if (method == "aldvmm") " | 1" else ""
  lapply(rhs_terms, function(terms) {
    as.formula(
      paste0("utility ~ ", paste(terms, collapse = " + "), suffix)
    )
  })
}

# Predicts utilities for the held-out rows and caps them at 1.
.kfold_predict <- function(model, spec, test_data, method, ncp) {
  if (method == "ols") {
    preds <- predict(model, test_data)
  } else if (ncp == 1) {
    # Single-component ALDVMM: linear predictor computed by hand.
    # NOTE(review): assumes coef(model) is ordered as intercept, then the
    # formula terms, with one trailing non-beta parameter that is dropped --
    # confirm against the aldvmm documentation.
    predictors <- all.vars(spec)[-1]
    design <- cbind(1, as.matrix(test_data[, predictors, drop = FALSE]))
    beta <- coef(model)
    beta <- beta[-length(beta)]
    preds <- as.vector(design %*% beta)
  } else {
    preds <- predict(model, test_data)$yhat
  }
  pmin(preds, 1)
}

#' K-fold cross-validation of KDQoL-36 mapping models
#'
#' Shuffles the rows of `data`, splits them into `k` folds and, for each of
#' six candidate specifications, fits the model on the training folds and
#' scores the predictions on the held-out fold.
#'
#' @param data Data frame holding `utility` plus every predictor named in the
#'   model formulas (main effects, `*_sq` squares and interaction columns).
#' @param ncp Number of mixture components, passed to `aldvmm()` as `ncmp`.
#'   Ignored for OLS.
#' @param k Number of folds; must satisfy `2 <= k <= nrow(data)`.
#' @param thre Minimum number of successfully fitted folds a specification
#'   needs; below that its summary row is `NA`.
#' @param psiv Passed to `aldvmm()` as `psi` (presumably the utility range
#'   limits of the target instrument -- confirm). Ignored for OLS.
#' @param method Either `"aldvmm"` or `"ols"`.
#' @return Data frame with one row per specification and columns `Model`,
#'   `ME`, `MAE` and `RMSE`, each averaged over the fitted folds.
#' @details Rows are shuffled with `sample()`; call `set.seed()` beforehand
#'   for reproducible folds. A failed fit raises a warning and that fold is
#'   skipped for the affected specification.
k_fold_cross_aldvmm <- function(data, ncp = 2, k = 10, thre = 3, psiv = c(-0.391, 0.9555), method = 'aldvmm') {
  method <- match.arg(method, c("aldvmm", "ols"))

  n <- nrow(data)
  if (!is.numeric(k) || length(k) != 1 || is.na(k) || k < 2 || k > n) {
    stop("`k` must be a single number between 2 and nrow(data).",
         call. = FALSE)
  }

  data <- data[sample(n), ]
  # Spread the shuffled rows over the folds so that every row is held out
  # exactly once, even when n is not a multiple of k.
  fold_id <- rep(seq_len(k), length.out = n)

  model_specs <- .kfold_model_specs(method)

  # One slot per fold and specification; a slot stays NULL when the fit fails.
  fold_metrics <- lapply(model_specs, function(spec) vector("list", k))

  for (i in seq_len(k)) {
    held_out <- fold_id == i
    test_data <- data[held_out, , drop = FALSE]
    train_data <- data[!held_out, , drop = FALSE]

    for (key in names(model_specs)) {
      spec <- model_specs[[key]]
      model <- tryCatch(
        if (method == "aldvmm") {
          aldvmm(spec, data = train_data, psi = psiv, ncmp = ncp)
        } else {
          lm(spec, data = train_data)
        },
        error = function(e) {
          warning(
            sprintf("Model '%s' failed in fold %d: %s",
                    key, i, conditionMessage(e)),
            call. = FALSE
          )
          NULL
        }
      )
      if (is.null(model)) {
        next
      }

      predicted <- .kfold_predict(model, spec, test_data, method, ncp)
      err <- test_data$utility - predicted
      fold_metrics[[key]][[i]] <- c(
        me = mean(err, na.rm = TRUE),
        mae = mean(abs(err), na.rm = TRUE),
        rmse = sqrt(mean(err^2, na.rm = TRUE))
      )
    }
  }

  df_results <- do.call(rbind, lapply(names(fold_metrics), function(name) {
    fitted <- Filter(Negate(is.null), fold_metrics[[name]])
    if (length(fitted) < thre) {
      return(data.frame(
        Model = name, ME = NA_real_, MAE = NA_real_, RMSE = NA_real_
      ))
    }
    avg <- colMeans(do.call(rbind, fitted), na.rm = TRUE)
    data.frame(
      Model = name,
      ME = avg[["me"]],
      MAE = avg[["mae"]],
      RMSE = avg[["rmse"]]
    )
  }))

  df_results
}

In [64]:
# Run the cross-validation with the OLS specifications, keeping the default
# number of folds (k = 10).
res <- k_fold_cross_aldvmm(data, method = 'ols')

In [None]:
# Show the cross-validation summary obtained on the author's local dataset
# (one row per model specification).
res

Model,ME,MAE,RMSE
<chr>,<dbl>,<dbl>,<dbl>
5_main,0.016262577,0.1809562,0.2450289
7_main,0.015564778,0.1776993,0.2425556
5_sq,0.007612875,0.163263,0.231156
7_sq,0.007617405,0.1616168,0.2297086
5_inter,0.00835206,0.1636894,0.2275157
7_inter,0.009451354,0.1661717,0.23099
