Skip to content
Permalink
Browse files

update vivo

  • Loading branch information...
kozaka93 committed May 25, 2019
1 parent 46ce5c0 commit e7b542cff00f9586764525b99dae6d6b42948ee6
@@ -1,6 +1,6 @@
Package: LocalVariableImportanceViaOscillations
Package: vivo
Title: Local variable importance via oscillations of Ceteris Paribus profiles
Version: 0.0.0.9000
Version: 0.0.1
Authors@R: person("Anna", "Kozak", email = "anna1993kozak@gmail.com", role = c("aut", "cre"))
Description: Calculates a local variable importance via oscillations of Ceteris Paribus profiles.
Depends: R (>= 3.0)
@@ -1,7 +1,7 @@
# Generated by roxygen2: do not edit by hand

S3method(plot,local_importance)
export(CalculateWeight)
export(LocalVariableImportance)
export(calculate_weight)
export(local_variable_importance)
import(ggplot2)
import(ingredients)

This file was deleted.

This file was deleted.

@@ -0,0 +1,47 @@
#' Calculated empirical density and weight based on variable split.
#'
#' This function calculate an empirical density of raw data based on variable split from Ceteris Paribus profiles. Then calculated weight for values generated by ingredients::ceteris_paribus().
#'
#' @param profiles data.frame generated by ingredients::ceteris_paribus()
#' @param data data.frame with raw data to model
#' @param variable_split list generated by ingredients::calculate_variable_split()
#'
#' @return Return an weight based on empirical density.
#'
#' @examples
#' \dontrun{
#'
#' split <- ingredients::calculate_variable_split(data, variables = colnames(data))
#' calculate_weight(profiles, data, variable_split = split)
#' }
#'
#' @export
#'
calculate_weight <- function(profiles, data, variable_split){
if (!(c("ceteris_paribus_explainer") %in% class(profiles)))
stop("The CalculateWeight() function requires an object created with ceteris_paribus() function.")
if (!(c("list") %in% class(variable_split)))
stop("The CalculateWeight() function requires an object created with calculate_variable_split() function.")
if (!(c("data.frame") %in% class(data)))
stop("The CalculateWeight() function requires a data.frame.")
cut_range <- lapply(unique(profiles$`_vname_`), function(x){
data.frame(table(cut(data[, as.vector(as.character(x))],
unique(c(min(data[, as.vector(as.character(x))]),
variable_split[[as.character(x)]],
max(data[, as.vector(as.character(x))]))), include.lowest = TRUE))/nrow(data))})
weight_range <- lapply(unique(profiles$`_vname_`), function(x){
data.frame("Var1" = cut(profiles[profiles$`_vname_` == x, as.vector(as.character(x))],
unique(c(min(data[, as.vector(as.character(x))]),
variable_split[[as.character(x)]],
max(data[, as.vector(as.character(x))]))), include.lowest = TRUE),
"Value" = profiles[profiles$`_vname_` == x, as.vector(as.character(x))])
})
names(cut_range) <- as.vector(unique(profiles$`_vname_`))
names(weight_range) <- as.vector(unique(profiles$`_vname_`))
weight <- lapply(as.vector(unique(profiles$`_vname_`)), function(x){
unname(unlist(dplyr::left_join(weight_range[[x]], cut_range[[x]], by = "Var1")["Freq"]))
})
names(weight) <- as.vector(unique(profiles$`_vname_`))
weight
}

@@ -0,0 +1,92 @@
#' Local Variable Importance measure based on Ceteris Paribus profiles.
#'
#' This function calculate local importance measure in eight variants. We obtain eight variants measure through the possible options of three parameters such as `absolute_deviation`, `point` and `density`.
#'
#' @param profiles data.frame generated by ingredients::ceteris_paribus()
#' @param data data.frame with raw data to model
#' @param absolute_deviation logical parameter, if `absolute_deviation = TRUE` then measue is calculated as absolute deviation, else is calculated as a root from average squares
#' @param point logical parameter, if `point = TRUE` then measure is calculated as a distance from f(x), else measure is calculated as a distance from average profiles
#' @param density logical parameter, if `density = TRUE` then measure is weighted based on the density of variable, else is not weighted
#'
#' @return A list of the class 'local_variable_importance'.
#' It's a list with calculated local variable importance measure.
#' @examples
#' \dontrun{
#'
#' local_variable_importance(profiles, data, absolute_deviation = TRUE, point = TRUE, density = FALSE)
#'
#' }
#'
#' @export
#'


local_variable_importance <- function(profiles, data, absolute_deviation = TRUE, point = TRUE, density = TRUE){
if (!(c("ceteris_paribus_explainer") %in% class(profiles)))
stop("The LocalVariableImportance() function requires an object created with ceteris_paribus() function.")
if (!c("data.frame") %in% class(data))
stop("The LocalVariableImportance() function requires a data.frame.")

avg_yhat <- lapply(unique(profiles$`_vname_`), function(x){
mean(profiles$`_yhat_`[profiles$`_vname_` == x])
})
names(avg_yhat) <- unique(profiles$`_vname_`)

variable_split <- ingredients::calculate_variable_split(data, variables = colnames(data))

weight <- vivo::calculate_weight(profiles, data, variable_split = variable_split)

obs <- attr(profiles, "observations")


if(absolute_deviation == TRUE){
if(point == TRUE){
if(density == TRUE){
result <- unlist(lapply(unique(profiles$`_vname_`), function(m){
sum(abs(weight[[m]] * (profiles[profiles$`_vname_` == m, "_yhat_"] - unlist(unname(obs["_yhat_"])))))
}))
}else{
result <- unlist(lapply(unique(profiles$`_vname_`), function(w){
sum(abs((profiles[profiles$`_vname_` == w, "_yhat_"] - unlist(unname(obs["_yhat_"])))))
}))
}
}else{
if(density == TRUE){
result <- unlist(lapply(unique(profiles$`_vname_`), function(m){
sum(abs(weight[[m]] *(profiles[profiles$`_vname_` == m, "_yhat_"] - avg_yhat[[m]])))
}))
}else{
result <- unlist(lapply(unique(profiles$`_vname_`), function(w){
sum(abs((profiles[profiles$`_vname_` == w, "_yhat_"] - avg_yhat[[w]])))
}))
}
}
}else{
if(point == TRUE){
if(density == TRUE){
result <- unlist(lapply(unique(profiles$`_vname_`), function(m){
sqrt(sum((weight[[m]] * (profiles[profiles$`_vname_` == m, "_yhat_"] - unlist(unname(obs["_yhat_"]))))^2)/length(profiles[profiles$`_vname_` == m, "_yhat_"]))
}))
}else{
result <- unlist(lapply(unique(profiles$`_vname_`), function(w){
sqrt(sum((profiles[profiles$`_vname_` == w, "_yhat_"] - unlist(unname(obs["_yhat_"])))^2)/length(profiles[profiles$`_vname_` == w, "_yhat_"]))
}))
}
}else{
if(density == TRUE){
result <- unlist(lapply(unique(profiles$`_vname_`), function(m){
sqrt(sum((weight[[m]] * (profiles[profiles$`_vname_` == m, "_yhat_"] - avg_yhat[[m]]))^2)/length(profiles[profiles$`_vname_` == m, "_yhat_"]))
}))
}else{
result <- unlist(lapply(unique(profiles$`_vname_`), function(w){
sqrt(sum((profiles[profiles$`_vname_` == w, "_yhat_"] - avg_yhat[[w]])^2)/(length(profiles[profiles$`_vname_` == w, "_yhat_"])))
}))
}
}
}

lvivo = data.frame(variable_name = unique(profiles$`_vname_`), measure = result)
class(lvivo) = c("local_importance", "data.frame")
lvivo
}

@@ -2,13 +2,13 @@
#'
#' Function plot.local_importance plots local importance measure based on Ceteris Paribus profiles.
#'
#' @param x object returned from `LocalVariableImportance()` function
#' @param x object returned from `local_variable_importance()` function
#' @param ... other parameters
#' @return a ggplot2 object
#'
#' @examples
#' \dontrun{
#' measure <- LocalVariableImportance(cp, df, absolute_deviation = TRUE, point = TRUE, density = FALSE)
#' measure <- local_variable_importance(profiles, data, absolute_deviation = TRUE, point = TRUE, density = FALSE)
#' plot(measure)
#' }
#'

0 comments on commit e7b542c

Please sign in to comment.
You can’t perform that action at this time.