Skip to content

Commit

Permalink
Merge e40102d into 33971f3
Browse files Browse the repository at this point in the history
  • Loading branch information
ludoro committed May 19, 2020
2 parents 33971f3 + e40102d commit 05f8dc5
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 6 deletions.
6 changes: 3 additions & 3 deletions src/RandomForestSurrogate.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using XGBoost
using XGBoost
mutable struct RandomForestSurrogate{X,Y,B,L,U,N} <: AbstractSurrogate
x::X
y::Y
Expand All @@ -8,7 +8,7 @@ mutable struct RandomForestSurrogate{X,Y,B,L,U,N} <: AbstractSurrogate
num_round::N
end

function RandomForestSurrogate(x,y,lb::Number,ub::Number,num_round)
function RandomForestSurrogate(x,y,lb::Number,ub::Number; num_round::Int = 1)
bst = xgboost(reshape(x,length(x),1), num_round, label = y)
RandomForestSurrogate(x,y,bst,lb,ub,num_round)
end
Expand All @@ -23,7 +23,7 @@ RandomForestSurrogate(x,y,lb,ub,num_round)
Build Random forest surrogate. num_round is the number of trees.
"""
function RandomForestSurrogate(x,y,lb,ub,num_round)
function RandomForestSurrogate(x,y,lb,ub;num_round::Int = 1)
X = Array{Float64,2}(undef,length(x),length(x[1]))
for j = 1:length(x)
X[j,:] = vec(collect(x[j]))
Expand Down
2 changes: 1 addition & 1 deletion test/optimization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ x = sample(5,lb,ub,SobolSample())
objective_function_ND = z -> 3*norm(z)+1
y = objective_function_ND.(x)
num_round = 2
my_forest_ND_SRBF = RandomForestSurrogate(x,y,lb,ub,num_round)
my_forest_ND_SRBF = RandomForestSurrogate(x,y,lb,ub,num_round=2)
surrogate_optimize(objective_function_ND,SRBF(),lb,ub,my_forest_ND_SRBF,SobolSample(),maxiters=15)

#Inverse distance surrogate
Expand Down
6 changes: 4 additions & 2 deletions test/random_forest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ y = obj_1D.(x)
a = 0.0
b = 10.0
num_round = 2
my_forest_1D = RandomForestSurrogate(x,y,a,b,num_round)
my_forest_1D = RandomForestSurrogate(x,y,a,b,num_round = 2)
my_forest_kwarg = RandomForestSurrogate(x,y,a,b)
val = my_forest_1D(3.5)
add_point!(my_forest_1D,6.0,19.0)
add_point!(my_forest_1D,[7.0,8.0],obj_1D.([7.0,8.0]))
Expand All @@ -18,7 +19,8 @@ ub = [10.0,10.0,10.0]
x = sample(5,lb,ub,SobolSample())
obj_ND = x -> x[1] * x[2]^2 * x[3]
y = obj_ND.(x)
my_forest_ND = RandomForestSurrogate(x,y,lb,ub,num_round)
my_forest_ND = RandomForestSurrogate(x,y,lb,ub,num_round = 2)
my_forest_kwarg = RandomForestSurrogate(x,y,lb,ub)
val = my_forest_ND((1.0,1.0,1.0))
add_point!(my_forest_ND,(1.0,1.0,1.0),1.0)
add_point!(my_forest_ND,[(1.2,1.2,1.0),(1.5,1.5,1.0)],[1.728,3.375])

0 comments on commit 05f8dc5

Please sign in to comment.