diff --git a/src/classification/main.jl b/src/classification/main.jl
index 3c2bdfae..07bf3e1f 100644
--- a/src/classification/main.jl
+++ b/src/classification/main.jl
@@ -93,7 +93,7 @@ function update_pruned_impurity!(
     feature_importance::Vector{Float64},
     ntt::Int,
     loss::Function=mean_squared_error,
-) where {S,T<:Float64}
+) where {S,T<:AbstractFloat}
     μl = mean(tree.left.values)
     nl = length(tree.left.values)
     μr = mean(tree.right.values)
@@ -220,7 +220,7 @@ See also [`build_tree`](@ref).
 function prune_tree(
     tree::Union{Root{S,T},LeafOrNode{S,T}},
     purity_thresh=1.0,
-    loss::Function=T <: Float64 ? mean_squared_error : util.entropy,
+    loss::Function=T <: AbstractFloat ? mean_squared_error : util.entropy,
 ) where {S,T}
     if purity_thresh >= 1.0
         return tree
@@ -293,11 +293,7 @@ function apply_tree(tree::LeafOrNode{S,T}, features::AbstractMatrix{S}) where {S
     for i in 1:N
         predictions[i] = apply_tree(tree, features[i, :])
     end
-    if T <: Float64
-        return Float64.(predictions)
-    else
-        return predictions
-    end
+    return predictions
 end

 """
@@ -343,7 +339,7 @@ end
 Train a random forest model, built on standard CART decision trees, using the
 specified `labels` (target) and `features` (patterns). Here:

-- `labels` is any `AbstractVector`. If the element type is `Float64`, regression is
+- `labels` is any `AbstractVector`. If the element type is `AbstractFloat`, regression is
   applied, and otherwise classification is applied.

 - `features` is any `AbstractMatrix{T}` where `T` supports ordering with `<` (unordered
@@ -619,7 +615,7 @@ function apply_forest(forest::Ensemble{S,T}, features::AbstractVector{S}) where
         votes[i] = apply_tree(forest.trees[i], features)
     end

-    if T <: Float64
+    if T <: AbstractFloat
         return mean(votes)
     else
         return majority_vote(votes)
diff --git a/src/measures.jl b/src/measures.jl
index f24653cd..6f3b3498 100644
--- a/src/measures.jl
+++ b/src/measures.jl
@@ -269,7 +269,7 @@ function _nfoldCV(
     args...;
     verbose,
     rng,
-) where {T<:Float64}
+) where {T<:AbstractFloat}
     _rng = mk_rng(rng)::Random.AbstractRNG
     nfolds = args[1]
     if nfolds < 2
@@ -361,7 +361,7 @@ function nfoldCV_tree(
     min_purity_increase::Float64=0.0;
     verbose::Bool=true,
     rng=Random.GLOBAL_RNG,
-) where {S,T<:Float64}
+) where {S,T<:AbstractFloat}
     _nfoldCV(
         :tree,
         labels,
@@ -389,7 +389,7 @@ function nfoldCV_forest(
     min_purity_increase::Float64=0.0;
     verbose::Bool=true,
     rng=Random.GLOBAL_RNG,
-) where {S,T<:Float64}
+) where {S,T<:AbstractFloat}
     _nfoldCV(
         :forest,
         labels,
diff --git a/src/regression/main.jl b/src/regression/main.jl
index 5e4e89f7..1eaf7a4f 100644
--- a/src/regression/main.jl
+++ b/src/regression/main.jl
@@ -1,6 +1,6 @@
 include("tree.jl")

-function _convert(node::treeregressor.NodeMeta{S}, labels::Array{T}) where {S,T<:Float64}
+function _convert(node::treeregressor.NodeMeta{S}, labels::Array{T}) where {S,T<:AbstractFloat}
     if node.is_leaf
         return Leaf{T}(node.label, labels[node.region])
     else
@@ -27,7 +27,7 @@ function build_stump(
     features::AbstractMatrix{S};
     rng=Random.GLOBAL_RNG,
     impurity_importance::Bool=true,
-) where {S,T<:Float64}
+) where {S,T<:AbstractFloat}
     return build_tree(labels, features, 0, 1; rng, impurity_importance)
 end

@@ -41,7 +41,7 @@ function build_tree(
     min_purity_increase=0.0;
     rng=Random.GLOBAL_RNG,
     impurity_importance::Bool=true,
-) where {S,T<:Float64}
+) where {S,T<:AbstractFloat}
     if max_depth == -1
         max_depth = typemax(Int)
     end
@@ -85,7 +85,7 @@ function build_forest(
     min_purity_increase=0.0;
     rng::Union{Integer,AbstractRNG}=Random.GLOBAL_RNG,
     impurity_importance::Bool=true,
-) where {S,T<:Float64}
+) where {S,T<:AbstractFloat}
     if n_trees < 1
         throw("the number of trees must be >= 1")
     end
diff --git a/src/regression/tree.jl b/src/regression/tree.jl
index 5a9ae4c9..9fd39339 100644
--- a/src/regression/tree.jl
+++ b/src/regression/tree.jl
@@ -47,7 +47,7 @@ end
 # (max_depth, min_samples_split, min_purity_increase)
 function _split!(
     X::AbstractMatrix{S}, # the feature array
-    Y::AbstractVector{Float64}, # the label array
+    Y::AbstractVector{T}, # the label array
     W::AbstractVector{U},
     node::NodeMeta{S}, # the node to split
     max_features::Int, # number of features to consider
@@ -59,10 +59,10 @@ function _split!(
     # we split using samples in indX[node.region]
     # the two arrays below are given for optimization purposes
     Xf::AbstractVector{S},
-    Yf::AbstractVector{Float64},
+    Yf::AbstractVector{T},
     Wf::AbstractVector{U},
     rng::Random.AbstractRNG,
-) where {S,U}
+) where {S,T<:AbstractFloat,U}
     region = node.region
     n_samples = length(region)
     r_start = region.start - 1
@@ -245,7 +245,7 @@ end

 function _fit(
     X::AbstractMatrix{S},
-    Y::AbstractVector{Float64},
+    Y::AbstractVector{T},
     W::AbstractVector{U},
     max_features::Int,
     max_depth::Int,
@@ -253,10 +253,10 @@ function _fit(
     min_samples_split::Int,
     min_purity_increase::Float64,
     rng=Random.GLOBAL_RNG::Random.AbstractRNG,
-) where {S,U}
+) where {S,T<:AbstractFloat,U}
     n_samples, n_features = size(X)

-    Yf = Array{Float64}(undef, n_samples)
+    Yf = Array{T}(undef, n_samples)
     Xf = Array{S}(undef, n_samples)
     Wf = Array{U}(undef, n_samples)

@@ -293,7 +293,7 @@ end

 function fit(;
     X::AbstractMatrix{S},
-    Y::AbstractVector{Float64},
+    Y::AbstractVector{<:AbstractFloat},
     W::Union{Nothing,AbstractVector{U}},
     max_features::Int,
     max_depth::Int,
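---

A minimal usage sketch of what the relaxed `AbstractFloat` signatures permit,
assuming DecisionTree.jl's exported `build_tree`/`apply_tree` API (the data and
variable names below are illustrative, not taken from the diff):

    using DecisionTree

    # Float32 labels: regression still dispatches on eltype(labels) <: AbstractFloat,
    # and apply_tree no longer force-converts its predictions to Float64.
    features = rand(Float32, 100, 3)
    labels = vec(sum(features; dims=2))

    model = build_tree(labels, features)
    preds = apply_tree(model, features)
    eltype(preds)  # expected to be Float32, matching the label eltype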