/
decision_tree.jl
240 lines (212 loc) · 6.53 KB
/
decision_tree.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
"""
    Node{T}

A node of a binary classification tree whose split thresholds have type `T`.

A split node routes a sample to `left` when `sample[feature_idx] < feature_val` and to
`right` otherwise. A terminal node (`is_terminal == true`) stores its predicted class in
`value`. Every constructor starts from `new{T}()` and fills only the fields relevant to
that kind of node, so e.g. `left`/`right` stay undefined until they are assigned.
"""
mutable struct Node{T}
    feature_idx::Int
    feature_val::T
    value::Int
    left::Node{T}
    right::Node{T}
    is_terminal::Bool

    # Split node on `feature_idx`/`feature_val`; children are attached afterwards.
    function Node(feature_idx, feature_val::T) where {T}
        self = new{T}()
        self.is_terminal = false
        self.feature_val = feature_val
        self.feature_idx = feature_idx
        return self
    end

    # Leaf predicting class `value`.
    function Node{T}(value) where {T}
        self = new{T}()
        self.is_terminal = true
        self.value = value
        return self
    end

    # Node with all scalar fields given explicitly (a split node unless `is_terminal`).
    function Node(feature_idx, feature_val::T, value, is_terminal = false) where {T}
        self = new{T}()
        self.feature_val = feature_val
        self.feature_idx = feature_idx
        self.value = value
        self.is_terminal = is_terminal
        return self
    end

    # Empty, non-terminal placeholder to be filled in later.
    function Node{T}() where {T}
        self = new{T}()
        self.is_terminal = false
        return self
    end
end
"""
    DecisionTreeContainer{T}

Immutable bundle of a tree's root node and the hyper-parameters that `process_node`
reads while growing the tree. `T` is the threshold type stored in the tree's nodes.
"""
struct DecisionTreeContainer{T}
    # root of the tree being grown (create_tree passes a fresh `Node{T}()`)
    root::Node{T}
    # how many distinct feature columns are sampled as split candidates at each node
    n_features_per_node::Int
    # length of the per-class count vectors; class labels are used as indices into them
    n_classes::Int
    # a node whose depth exceeds this becomes a leaf (the root is processed at depth 1)
    max_depth::Int
    # a node holding at most this many records becomes a leaf
    min_node_records::Int
end
"""
    feature_best_split(containers, X, y, n_classes, feature)

Search the best threshold for column `feature` of `X`.

Rows are visited in ascending order of `X[:, feature]`. At every boundary between two
distinct values `v_prev < v`, the candidate split `X[:, feature] .< v` is scored with
`gini_impurity`, and the highest-scoring `v` is kept. The score is maximised, so it is
presumably a gain (parent Gini minus weighted child Gini); confirm against
`gini_impurity`.

# Arguments
- `containers`: scratch buffers as built by `create_containers`: `left`/`right` class-count
  vectors (overwritten here), `gini_before`, and the record count `lt`.
- `X`: feature matrix, one row per record.
- `y`: class labels, used directly as indices into `left`/`right`.
- `n_classes`: not used by this function; kept for signature compatibility.
- `feature`: column of `X` to search.

# Returns
A named tuple `(val, impurity)`. If no boundary scores above `0.0` (e.g. a constant
column), `val` is the column minimum and `impurity` is `0.0`. In that case
`X[:, feature] .< val` selects no rows.

# Throws
`DimensionMismatch` if `length(y) != size(X, 1)`.
"""
function feature_best_split(containers, X, y, n_classes, feature)
    # The `@inbounds` loop below reads `sort_idx` (one entry per row of X) at positions
    # up to `length(y)`, so a longer `y` would read out of bounds unchecked.
    length(y) == size(X, 1) || throw(DimensionMismatch(
        "X has $(size(X, 1)) rows but y has $(length(y)) labels"))
    gini_before = containers.gini_before
    left = containers.left
    right = containers.right
    lt = containers.lt
    # Start with every record on the right. This pass is bounds-checked, so it also
    # validates every label as an index before the `@inbounds` loop relies on it
    # (`left` and `right` have equal length when built by `create_containers`).
    left .= 0
    right .= 0
    for i in axes(X, 1)
        right[y[i]] += 1
    end
    # TODO: I leave it for now. In the future, sorting should be done before feature split,
    # so allocations or unsafe arrays would be of no importance.
    sort_idx = sortperm(@view X[:, feature])
    # Move the record with the smallest value to the left side.
    ll = 1
    lr = length(y) - 1
    i1 = sort_idx[1]
    left[y[i1]] = 1
    right[y[i1]] -= 1
    prev_val = X[i1, feature]
    best_val = prev_val
    best_impurity = 0.0
    @inbounds for idx in 2:length(y)
        i = sort_idx[idx]
        if X[i, feature] != prev_val
            # `left` now holds exactly the records whose value is < X[i, feature].
            prev_val = X[i, feature]
            impurity = gini_impurity(gini_before, left, right, ll, lr, lt)
            if impurity > best_impurity
                best_impurity = impurity
                best_val = prev_val
            end
        end
        ll += 1
        lr -= 1
        left[y[i]] += 1
        right[y[i]] -= 1
    end
    return (val = best_val, impurity = best_impurity)
end
"""
    create_containers(n_classes, y)

Allocate the scratch buffers shared by the `feature_best_split` calls for one node.

Returns a named tuple `(left, right, gini_before, lt)`:
- `left`: per-class counts of the labels in `y` (labels index the vector directly);
- `right`: a class-count vector of the same length, allocated but uninitialized;
- `gini_before`: `gini_index` of the full label counts;
- `lt`: the number of labels.
"""
function create_containers(n_classes, y)
    total = length(y)
    counts = zeros(Int, n_classes)
    for k in 1:total
        counts[y[k]] += 1
    end
    scratch = Vector{Int}(undef, n_classes)
    parent_gini = gini_index(counts, total)
    return (left = counts, right = scratch, gini_before = parent_gini, lt = total)
end
"""
    best_split(X, target, n_classes, features)

Choose, among the candidate columns `features`, the column whose best threshold (as
found by `feature_best_split`) has the highest score. On ties the earliest candidate
wins.

Returns `(feature, val)`. If `features` is empty, this is `(feature = 0, val = -Inf)`.
"""
function best_split(X, target, n_classes, features)
    buffers = create_containers(n_classes, target)
    chosen_feature, chosen_val, chosen_score = 0, -Inf, -Inf
    for candidate in features
        threshold, score = feature_best_split(buffers, X, target, n_classes, candidate)
        # Only a strictly better score replaces the current choice.
        score > chosen_score || continue
        chosen_feature, chosen_val, chosen_score = candidate, threshold, score
    end
    return (feature = chosen_feature, val = chosen_val)
end
"""
    split_value(X, target, n_classes)

Return the majority class among `target[i]` for the rows `i in axes(X, 1)`.

Labels must be integers in `1:n_classes`. Ties resolve to the smallest label. With no
rows (and `n_classes >= 1`) the result is `1`.
"""
function split_value(X, target, n_classes)
    counts = zeros(Int, n_classes)
    for row in axes(X, 1)
        counts[target[row]] += 1
    end
    _, majority = findmax(counts)
    return majority
end
"""
    get_split_indices(X, feature_idx, feature_val)

Return two `BitVector` masks over the rows of `X`:
- rows with `X[i, feature_idx] < feature_val` (left side);
- rows with `X[i, feature_idx] >= feature_val` (right side).

Both comparisons are evaluated separately, so a row whose value compares false both
ways (e.g. `NaN`) appears in neither mask.
"""
function get_split_indices(X, feature_idx, feature_val)
    column = view(X, :, feature_idx)
    goes_left = column .< feature_val
    goes_right = column .>= feature_val
    return goes_left, goes_right
end
"""
    is_pure(target)

Return `true` if every label in `target` equals (`==`) the first one. An empty `target`
is vacuously pure and also returns `true` (previously it threw a `BoundsError`).
"""
function is_pure(target)
    isempty(target) && return true
    first_label = first(target)
    # `all` with a predicate stops at the first mismatch and allocates nothing, unlike
    # broadcasting `target[1] .== target` into a temporary BitVector.
    return all(==(first_label), target)
end
###############################
# Node functions
###############################
"""
    process_node(dtc, node, X, target, rng=Random.GLOBAL_RNG, features=<random subset>,
                 depth=1)

Grow, in place, the subtree rooted at `node` from the records `X` (one per row) and their
class labels `target`.

`node` becomes a leaf in any of these cases:
- `depth > dtc.max_depth`;
- it holds at most `dtc.min_node_records` records;
- all of its labels agree;
- the best split over `features` would send every record to one side.

Otherwise `node` becomes a split on `X[:, feature_idx] < feature_val`, and both children
are processed recursively at `depth + 1`. The two children share one freshly drawn
subset of `dtc.n_features_per_node` distinct columns.

`features` defaults to such a subset drawn with `rng`. `sample` is presumably
`StatsBase.sample`, brought into scope elsewhere; confirm.
"""
function process_node(dtc::DecisionTreeContainer{T}, node, X, target,
                      rng = Random.GLOBAL_RNG,
                      features = sample(rng, 1:size(X, 2), dtc.n_features_per_node,
                                        replace = false),
                      depth = 1) where T
    if depth > dtc.max_depth || length(target) <= dtc.min_node_records
        # Stopping rules: predict the majority class of the records that reached here.
        node.is_terminal = true
        node.value = split_value(X, target, dtc.n_classes)
    elseif is_pure(target)
        node.is_terminal = true
        node.value = target[1]
    else
        feature_idx, feature_val = best_split(X, target, dtc.n_classes, features)
        left_ids, right_ids = get_split_indices(X, feature_idx, feature_val)
        if !any(left_ids) || !any(right_ids)
            # No candidate column separates these records. For example, every sampled
            # column may be constant here, and best_split then returns a column minimum.
            # Splitting would hand all records to one child and recurse on identical data
            # until max_depth, so stop with a majority-class leaf instead.
            node.is_terminal = true
            node.value = split_value(X, target, dtc.n_classes)
        else
            node.feature_idx = feature_idx
            node.feature_val = feature_val
            left = Node{T}()
            right = Node{T}()
            node.left = left
            node.right = right
            new_features = sample(rng, 1:size(X, 2), dtc.n_features_per_node,
                                  replace = false)
            process_node(dtc, left, X[left_ids, :], target[left_ids], rng,
                         new_features, depth + 1)
            process_node(dtc, right, X[right_ids, :], target[right_ids], rng,
                         new_features, depth + 1)
        end
    end
end
"""
    create_tree(X, y; rng=Random.GLOBAL_RNG, max_depth=10, min_node_records=1,
                n_features=size(X, 2))

Fit a classification tree to the rows of `X` with labels `y` and return its root `Node`.

# Arguments
- `X`: feature matrix, one row per record.
- `y`: integer class labels in `1:K`; they index per-class count vectors directly.

# Keywords
- `rng`: random source for the per-node feature sampling.
- `max_depth`: nodes deeper than this become leaves (the root has depth 1).
- `min_node_records`: nodes holding at most this many records become leaves.
- `n_features`: number of candidate columns sampled at each node.

# Throws
- `ArgumentError` if `y` is empty or contains a label below 1.
- `DimensionMismatch` if `length(y) != size(X, 1)`.
"""
function create_tree(X, y; rng = Random.GLOBAL_RNG, max_depth = 10, min_node_records = 1,
                     n_features = size(X, 2))
    isempty(y) && throw(ArgumentError("y must contain at least one label"))
    length(y) == size(X, 1) || throw(DimensionMismatch(
        "X has $(size(X, 1)) rows but y has $(length(y)) labels"))
    lo, hi = extrema(y)
    lo >= 1 || throw(ArgumentError("class labels must be integers >= 1, got $lo"))
    T = eltype(X)
    root = Node{T}()
    # Labels are used directly as indices into per-class count vectors, so size them by
    # the largest label. Counting distinct labels (`length(Set(y))`) is too small
    # whenever some class in 1:K is absent from `y` (e.g. y = [1, 3, 3]).
    n_classes = hi
    dtc = DecisionTreeContainer(root, n_features, n_classes, max_depth, min_node_records)
    process_node(dtc, root, X, y, rng)
    return root
end
"""
    predict(node::Node, row)

Return the `value` of the leaf reached by routing `row` down from `node`. At each split
node, go to `left` when `row[feature_idx] < feature_val` and to `right` otherwise.
"""
function predict(node::Node, row)
    current = node
    while !current.is_terminal
        goes_left = row[current.feature_idx] < current.feature_val
        current = goes_left ? current.left : current.right
    end
    return current.value
end
"""
    predict(node::Node, X, i)

Return the `value` of the leaf reached by routing row `i` of `X` down from `node`. It
reads `X[i, feature_idx]` at each split instead of materializing the row, going `left`
when that value is `< feature_val` and `right` otherwise.
"""
function predict(node::Node, X, i)
    at = node
    while true
        at.is_terminal && return at.value
        if X[i, at.feature_idx] < at.feature_val
            at = at.left
        else
            at = at.right
        end
    end
end
# Return `prefix` with its last character replaced by `marker`; an empty prefix is
# returned unchanged.
_marked_prefix(prefix, marker) = isempty(prefix) ? prefix : chop(prefix) * marker

"""
    show(io::IO, node::Node, prefix = "")

Print the subtree rooted at `node` as an outline, one node per line.

- Split nodes print as `[X<feature_idx> < <feature_val>]`.
- Leaves print as `[<value>]`.
- A non-terminal node without children prints as `[]`.

`prefix` holds the drawing characters for the current depth. The left subtree is drawn
under an extra `⎸`, the right subtree under an extra space. On each printed line the
last prefix character is replaced by `↦` (split) or `↳` (leaf).
"""
function Base.:show(io::IO, node::Node, prefix = "")
    if node.is_terminal
        return write(io, _marked_prefix(prefix, '↳'), "[$(node.value)]", "\n")
    end
    isdefined(node, :left) || return write(io, "[]\n")
    header = "[X$(node.feature_idx) < $(node.feature_val)]"
    write(io, _marked_prefix(prefix, '↦'), header, "\n")
    show(io, node.left, prefix * "⎸")
    return show(io, node.right, prefix * " ")
end