-
Notifications
You must be signed in to change notification settings - Fork 11
/
reg_tree_imp.R
175 lines (143 loc) · 6.07 KB
/
reg_tree_imp.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
# Regression Tree ---------------------------------------------------------
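# This is the splitting criterion we minimize (SSE, the sum of squared errors):
# $SSE = \sum_{i \in S_1} (y_i - \bar{y}_1)^2 + \sum_{i \in S_2} (y_i - \bar{y}_2)^2$,
# where S_1 and S_2 are the observations on either side of the split and
# \bar{y}_k is the mean response within S_k.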
#' sse_var
#' Best single split of one predictor under the SSE criterion.
#'
#' Every distinct value `sp` of `x` is tried as a threshold: observations
#' with `x < sp` form one side, those with `x >= sp` the other. The smallest
#' threshold leaves the `x < sp` side empty, so its SSE equals that of the
#' unsplit node (for a constant `x` it is the only candidate). Ties are
#' resolved in favour of the smallest threshold.
#'
#' @param x numeric predictor values
#' @param y numeric response values, same length as `x`
#' @return named numeric vector `c(sse = , split = )`: the minimal SSE and
#'   the threshold achieving it. Both entries are `NA` when `x` has no
#'   non-missing values or no candidate has a usable SSE (e.g. `NA` in `y`),
#'   so callers such as `apply()` always get a length-2 result.
#' @noRd
sse_var <- function(x, y) {
  splits <- sort(unique(x))
  # SSE of one side: squared deviations from its own mean (0 when empty)
  side_sse <- function(v) sum((v - mean(v))^2)
  sse <- vapply(splits, function(sp) {
    side_sse(y[x < sp]) + side_sse(y[x >= sp])
  }, numeric(1))
  best <- which.min(sse)
  if (length(best) == 0) {
    return(c(sse = NA_real_, split = NA_real_))
  }
  c(sse = min(sse), split = splits[best])
}
#' reg_tree_imp
#' Fits a simple regression tree with the SSE splitting criterion and records
#' the impurity-decrease ("Gini") importance of each split. The estimator in
#' every leaf is the mean response of its observations.
#'
#' Candidate split variables are the columns of `model.matrix(formula, data)`
#' without the intercept. Split rules are stored as text such as `"x >= 5"`
#' and evaluated against `data`, so every split column must be computable
#' from `data` under its design-matrix name: numeric predictors (and
#' expressions such as `I(x^2)`) work, expanded factor dummy columns do not.
#'
#' @param formula an object of class formula
#' @param data a data.frame or matrix
#' @param minsize a numeric value; a node is only split when both children
#'   would contain strictly more than `minsize` observations
#'
#' @return a list with elements
#' \itemize{
#'   \item tree - data.frame with one row per node: NODE, NOBS, FILTER (the
#'     cumulative split rule), TERMINAL ("PARENT" or "LEAF"), IMP_GINI (the
#'     weighted variance reduction of the node's best candidate split) and
#'     SPLIT (the split variable, parents only)
#'   \item fit - fitted values (leaf means), aligned by position with the
#'     rows of `data`; rows matched by no leaf rule are `NA`
#'   \item formula - the underlying terms object
#'   \item importance - data.frame with FEATURES and IMPORTANCE, each
#'     feature's share of the total variance reduction of the splits made
#'   \item data - the underlying data
#' }
#' @export
#'
#' @examples # Complete runthrough see: www.github.com/andrebleier/cheapml
reg_tree_imp <- function(formula, data, minsize) {
  data <- as.data.frame(data)
  formula <- terms.formula(formula)
  # the response is looked up by the same name for the full data and nodes
  response <- all.vars(formula)[1]
  y <- data[, response]
  n_total <- nrow(data)

  # Split-candidate columns of the design matrix for `df`. The intercept is
  # constant, so it never yields a real split; kept, it would win SSE ties
  # (e.g. in a node with constant response) and produce the unevaluable
  # rule "(Intercept) >= 1".
  split_candidates <- function(df) {
    mm <- model.matrix(formula, df)
    mm[, colnames(mm) != "(Intercept)", drop = FALSE]
  }

  # Positions of the rows of `df` that satisfy a textual rule (NA counts as
  # FALSE, as in subset()). The rule string is passed in rather than looked
  # up inside the data mask, so data columns named like locals cannot
  # shadow it.
  rows_matching <- function(rule, df) {
    which(eval(parse(text = rule), envir = df, enclos = parent.frame()))
  }

  # Text for a threshold that parses back to exactly the same double, so a
  # stored rule assigns observations exactly as sse_var() scored them.
  # Values that do not survive the 15-digit as.character() form are
  # written with 17 significant digits.
  format_split <- function(value) {
    txt <- as.character(value)
    if (!isTRUE(as.numeric(txt) == value)) {
      txt <- sprintf("%.17g", value)
    }
    txt
  }

  X_root <- split_candidates(data)
  if (ncol(X_root) == 0) {
    stop("`formula` must contain at least one predictor to split on",
         call. = FALSE)
  }

  # one row per node; "SPLIT" marks nodes that still have to be processed
  tree_info <- data.frame(NODE = 1, NOBS = n_total, FILTER = NA,
                          TERMINAL = "SPLIT", IMP_GINI = NA, SPLIT = NA,
                          stringsAsFactors = FALSE)
  do_splits <- TRUE
  # keep splitting until no node is marked "SPLIT"
  while (do_splits) {
    for (j in which(tree_info$TERMINAL == "SPLIT")) {
      parent_rule <- tree_info[j, "FILTER"]
      if (is.na(parent_rule)) {
        # root node: all observations
        this_data <- data
        X <- X_root
      } else {
        this_data <- data[rows_matching(parent_rule, data), , drop = FALSE]
        X <- split_candidates(this_data)
      }
      node_y <- this_data[, response]

      # best SSE split per candidate column (rows "sse" and "split")
      splitting <- apply(X, MARGIN = 2, FUN = sse_var, y = node_y)
      best <- which.min(splitting["sse", ])
      # colnames() rather than names(best): with a single candidate column
      # splitting["sse", ] loses its names
      split_var <- colnames(splitting)[best]
      split_txt <- format_split(splitting["split", best])
      mn <- max(tree_info$NODE)
      current_filter <- c(paste(split_var, ">=", split_txt),
                          paste(split_var, "<", split_txt))

      # do not re-use a rule that already occurs in the tree; fixed = TRUE
      # so "." in a threshold or "(" in a column name match literally
      split_here <- vapply(
        current_filter,
        function(rule) !any(grepl(rule, tree_info$FILTER, fixed = TRUE)),
        logical(1),
        USE.NAMES = FALSE
      )

      # child rules are conjunctions with the parent's rule
      if (!is.na(parent_rule)) {
        current_filter <- paste(parent_rule, current_filter, sep = " & ")
      }

      # size and weighted impurity (variance x share of all data) per child
      metr <- lapply(current_filter, function(rule) {
        child_y <- node_y[rows_matching(rule, this_data)]
        nobs <- length(child_y)
        imp <- mean((child_y - mean(child_y, na.rm = TRUE))^2)
        c(nobs, nobs / n_total * imp)
      })
      current_nobs <- vapply(metr, function(m) m[[1]], numeric(1))
      imp_sum_child <- sum(vapply(metr, function(m) m[[2]], numeric(1)))
      imp_parent <- nrow(this_data) / n_total *
        mean((node_y - mean(node_y))^2)
      imp_gini <- imp_parent - imp_sum_child

      # every child must hold more than minsize observations
      if (any(current_nobs <= minsize)) {
        split_here <- c(FALSE, FALSE)
      }

      children <- data.frame(NODE = c(mn + 1, mn + 2),
                             NOBS = current_nobs,
                             FILTER = current_filter,
                             TERMINAL = rep("SPLIT", 2),
                             IMP_GINI = NA,
                             SPLIT = NA,
                             row.names = NULL,
                             stringsAsFactors = FALSE)[split_here, ]

      # update the current node: state, importance and split variable
      is_parent <- any(split_here)
      tree_info[j, "TERMINAL"] <- if (is_parent) "PARENT" else "LEAF"
      tree_info[j, "IMP_GINI"] <- imp_gini
      if (is_parent) {
        tree_info[j, "SPLIT"] <- split_var
      }
      tree_info <- rbind(tree_info, children)
    }
    do_splits <- any(tree_info$TERMINAL == "SPLIT")
  }

  # fitted values: mean response of each leaf, by row position so custom
  # or non-consecutive row names in `data` cannot misalign them with `y`
  leaf_rules <- tree_info$FILTER[tree_info$TERMINAL == "LEAF"]
  fitted <- rep(NA_real_, n_total)
  for (rule in leaf_rules) {
    # a root without a rule (NA) is a leaf holding every observation
    ind <- if (is.na(rule)) seq_len(n_total) else rows_matching(rule, data)
    fitted[ind] <- mean(y[ind])
  }

  # feature importance: each feature's share of the variance reduction of
  # the splits that were made (leaf rows carry the gain of a split that was
  # rejected, so they are left out of numerator and denominator alike)
  imp <- tree_info[, c("SPLIT", "IMP_GINI")]
  if (!all(is.na(imp$SPLIT))) {
    made <- imp[!is.na(imp$SPLIT), ]
    imp <- aggregate(IMP_GINI ~ SPLIT, data = made,
                     FUN = function(x) sum(x, na.rm = TRUE))
    imp$IMP_GINI <- imp$IMP_GINI / sum(made$IMP_GINI, na.rm = TRUE)
  }
  names(imp) <- c("FEATURES", "IMPORTANCE")
  imp <- imp[order(imp$IMPORTANCE, decreasing = TRUE), ]
  list(tree = tree_info, fit = fitted, formula = formula,
       importance = imp, data = data)
}