Float64 replaced by AbstractFloat for regression
xinadi committed Oct 13, 2023
1 parent 605e4d4 commit 7de046f
Showing 4 changed files with 19 additions and 23 deletions.
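
In effect, every regression entry point that previously demanded `Float64` labels now accepts any `AbstractFloat` eltype, such as `Float32`. As a hedged illustration (random data; `build_tree`/`apply_tree` follow DecisionTree.jl's documented API):

using DecisionTree

# Illustration only. Before this commit the regression path required
# T <: Float64; Float32 labels now qualify via T <: AbstractFloat.
features = rand(Float32, 100, 3)
labels   = rand(Float32, 100)

model = build_tree(labels, features)  # Float32 eltype selects regression
preds = apply_tree(model, features)   # no longer converted to Float64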
14 changes: 5 additions & 9 deletions src/classification/main.jl
@@ -93,7 +93,7 @@ function update_pruned_impurity!(
feature_importance::Vector{Float64},
ntt::Int,
loss::Function=mean_squared_error,
-) where {S,T<:Float64}
+) where {S,T<:AbstractFloat}
μl = mean(tree.left.values)
nl = length(tree.left.values)
μr = mean(tree.right.values)
@@ -220,7 +220,7 @@ See also [`build_tree`](@ref).
function prune_tree(
tree::Union{Root{S,T},LeafOrNode{S,T}},
purity_thresh=1.0,
-loss::Function=T <: Float64 ? mean_squared_error : util.entropy,
+loss::Function=T <: AbstractFloat ? mean_squared_error : util.entropy,
) where {S,T}
if purity_thresh >= 1.0
return tree
@@ -293,11 +293,7 @@ function apply_tree(tree::LeafOrNode{S,T}, features::AbstractMatrix{S}) where {S
for i in 1:N
predictions[i] = apply_tree(tree, features[i, :])
end
-if T <: Float64
-    return Float64.(predictions)
-else
-    return predictions
-end
+return predictions
end

"""
@@ -343,7 +339,7 @@ end
Train a random forest model, built on standard CART decision trees, using the specified
`labels` (target) and `features` (patterns). Here:
-- `labels` is any `AbstractVector`. If the element type is `Float64`, regression is
+- `labels` is any `AbstractVector`. If the element type is `AbstractFloat`, regression is
applied, and otherwise classification is applied.
- `features` is any `AbstractMatrix{T}` where `T` supports ordering with `<` (unordered
@@ -619,7 +615,7 @@ function apply_forest(forest::Ensemble{S,T}, features::AbstractVector{S}) where
votes[i] = apply_tree(forest.trees[i], features)
end

-if T <: Float64
+if T <: AbstractFloat
return mean(votes)
else
return majority_vote(votes)
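For context, the `apply_forest` hunk above is the dispatch point between regression and classification: floating-point labels average the per-tree predictions, anything else takes a majority vote. A self-contained sketch of that rule, using a hypothetical `combine_votes` helper that is not part of the package:

using Statistics: mean

# Hypothetical helper mirroring apply_forest's type-based dispatch.
function combine_votes(votes::AbstractVector{T}) where {T}
    if T <: AbstractFloat               # regression: average the trees
        return mean(votes)
    else                                # classification: most frequent label
        counts = Dict{T,Int}()
        for v in votes
            counts[v] = get(counts, v, 0) + 1
        end
        best, best_count = first(counts)
        for (label, c) in counts
            if c > best_count
                best, best_count = label, c
            end
        end
        return best
    end
end

combine_votes(Float32[0.2, 0.4, 0.6])  # ≈ 0.4f0; Float32 now hits the mean branch
combine_votes(["cat", "dog", "dog"])   # "dog"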
6 changes: 3 additions & 3 deletions src/measures.jl
@@ -269,7 +269,7 @@ function _nfoldCV(
args...;
verbose,
rng,
-) where {T<:Float64}
+) where {T<:AbstractFloat}
_rng = mk_rng(rng)::Random.AbstractRNG
nfolds = args[1]
if nfolds < 2
@@ -361,7 +361,7 @@ function nfoldCV_tree(
min_purity_increase::Float64=0.0;
verbose::Bool=true,
rng=Random.GLOBAL_RNG,
-) where {S,T<:Float64}
+) where {S,T<:AbstractFloat}
_nfoldCV(
:tree,
labels,
@@ -389,7 +389,7 @@ function nfoldCV_forest(
min_purity_increase::Float64=0.0;
verbose::Bool=true,
rng=Random.GLOBAL_RNG,
-) where {S,T<:Float64}
+) where {S,T<:AbstractFloat}
_nfoldCV(
:forest,
labels,
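With `_nfoldCV` relaxed, the regression variants of `nfoldCV_tree` and `nfoldCV_forest` also accept non-`Float64` labels. A hedged usage sketch with random data (the positional `n_folds` argument follows the package's documented signature; regression folds are scored by the coefficient of determination):

using DecisionTree

features = rand(Float32, 120, 4)
labels   = rand(Float32, 120)

scores = nfoldCV_tree(labels, features, 3)  # 3-fold CV; per-fold R² scores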
8 changes: 4 additions & 4 deletions src/regression/main.jl
@@ -1,6 +1,6 @@
include("tree.jl")

-function _convert(node::treeregressor.NodeMeta{S}, labels::Array{T}) where {S,T<:Float64}
+function _convert(node::treeregressor.NodeMeta{S}, labels::Array{T}) where {S,T<:AbstractFloat}
if node.is_leaf
return Leaf{T}(node.label, labels[node.region])
else
@@ -27,7 +27,7 @@ function build_stump(
features::AbstractMatrix{S};
rng=Random.GLOBAL_RNG,
impurity_importance::Bool=true,
-) where {S,T<:Float64}
+) where {S,T<:AbstractFloat}
return build_tree(labels, features, 0, 1; rng, impurity_importance)
end

@@ -41,7 +41,7 @@ function build_tree(
min_purity_increase=0.0;
rng=Random.GLOBAL_RNG,
impurity_importance::Bool=true,
-) where {S,T<:Float64}
+) where {S,T<:AbstractFloat}
if max_depth == -1
max_depth = typemax(Int)
end
@@ -85,7 +85,7 @@ function build_forest(
min_purity_increase=0.0;
rng::Union{Integer,AbstractRNG}=Random.GLOBAL_RNG,
impurity_importance::Bool=true,
-) where {S,T<:Float64}
+) where {S,T<:AbstractFloat}
if n_trees < 1
throw("the number of trees must be >= 1")
end
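The forest builder gets the same treatment, so a `Float32` target now trains a regression forest directly. Illustrative values only; the positional arguments are `n_subfeatures` and `n_trees` in DecisionTree.jl's API:

using DecisionTree

X = rand(Float32, 200, 5)
y = 2 .* X[:, 1] .+ 0.1f0 .* randn(Float32, 200)  # noisy linear target, Float32

forest = build_forest(y, X, 2, 10)  # 2 candidate features per split, 10 trees
yhat   = apply_forest(forest, X)    # averaged tree outputs (regression branch)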
14 changes: 7 additions & 7 deletions src/regression/tree.jl
@@ -47,7 +47,7 @@ end
# (max_depth, min_samples_split, min_purity_increase)
function _split!(
X::AbstractMatrix{S}, # the feature array
-Y::AbstractVector{Float64}, # the label array
+Y::AbstractVector{T}, # the label array
W::AbstractVector{U},
node::NodeMeta{S}, # the node to split
max_features::Int, # number of features to consider
@@ -59,10 +59,10 @@ function _split!(
# we split using samples in indX[node.region]
# the two arrays below are given for optimization purposes
Xf::AbstractVector{S},
-Yf::AbstractVector{Float64},
+Yf::AbstractVector{T},
Wf::AbstractVector{U},
rng::Random.AbstractRNG,
-) where {S,U}
+) where {S,T<:AbstractFloat,U}
region = node.region
n_samples = length(region)
r_start = region.start - 1
@@ -245,18 +245,18 @@ end

function _fit(
X::AbstractMatrix{S},
-Y::AbstractVector{Float64},
+Y::AbstractVector{T},
W::AbstractVector{U},
max_features::Int,
max_depth::Int,
min_samples_leaf::Int,
min_samples_split::Int,
min_purity_increase::Float64,
rng=Random.GLOBAL_RNG::Random.AbstractRNG,
-) where {S,U}
+) where {S,T<:AbstractFloat,U}
n_samples, n_features = size(X)

-Yf = Array{Float64}(undef, n_samples)
+Yf = Array{T}(undef, n_samples)
Xf = Array{S}(undef, n_samples)
Wf = Array{U}(undef, n_samples)

@@ -293,7 +293,7 @@ end

function fit(;
X::AbstractMatrix{S},
-Y::AbstractVector{Float64},
+Y::AbstractVector{<:AbstractFloat},
W::Union{Nothing,AbstractVector{U}},
max_features::Int,
max_depth::Int,
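Beyond the signatures, the `_fit` hunk reallocates the scratch buffer `Yf` with the label eltype instead of hard-coding `Float64`, so labels are never silently widened while being shuffled into splits. A minimal sketch of that allocation pattern, with a hypothetical function name:

# Hypothetical sketch of the buffer change in _fit: work arrays now follow
# the label eltype T rather than always being Vector{Float64}.
function label_buffer(Y::AbstractVector{T}) where {T<:AbstractFloat}
    return Vector{T}(undef, length(Y))  # Float32 labels get a Float32 buffer
end

label_buffer(rand(Float32, 8))  # 8-element Vector{Float32}
label_buffer(rand(Float64, 8))  # 8-element Vector{Float64}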
