Skip to content

Commit

Permalink
Merge 1075d57 into 2f9c57c
Browse files Browse the repository at this point in the history
  • Loading branch information
ludoro committed Jul 8, 2020
2 parents 2f9c57c + 1075d57 commit 7556d32
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 0 deletions.
174 changes: 174 additions & 0 deletions src/Earth.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
using LinearAlgebra

mutable struct EarthSurrogate{X,Y,L,U,B,C,P,M,N,R,G,I} <: AbstractSurrogate
x::X
y::Y
lb::L
ub::U
basis::B
coeff::C
penalty::P
n_min_terms::M
n_max_terms::N
rel_res_error::R
rel_GCV::G
intercept::I
end

#1D
_hinge(x::Number,knot::Number) = max(0,x-knot)
_hinge_mirror(x::Number,knot::Number) = max(0,knot-x)

#ND
#inside arr_hing I have functions like g(x) = x -> _hinge(x,5.0) or g(x) = one(x)
_product_hinge(val,arr_hing) = prod([arr_hing[i](val[i]) for i = 1:length(val)])

function _coeff_1d(x,y,basis)
n = length(x)
d = length(basis)
X = zeros(eltype(x[1]),n,d)
@inbounds for i = 1:n
for j = 1:d
X[i,j] = basis[j](x[i])
end
end
return (X'*X)\(X'*y)
end

function _forward_pass_1d(x,y,n_max_terms,rel_res_error)
n = length(x)
basis = Array{Function}(undef,0)
current_sse = +Inf
intercept = sum([y[i] for i =1:length(y)])/length(y)
num_terms = 0
for var_i in x
#Add or not add the knot var_i?
new_basis = copy(basis)
#select best new pair
hinge1 = x-> _hinge(x,var_i)
hinge2 = x-> _hinge_mirror(x,var_i)
push!(new_basis,hinge1)
push!(new_basis,hinge2)
#find coefficients
d = length(new_basis)
X = zeros(eltype(x[1]),n,d)
@inbounds for i = 1:n
for j = 1:d
X[i,j] = new_basis[j](x[i])
end
end
if (cond(X'*X) > 1e8)
condition_number = false
new_sse = +Inf
else
condition_number = true
coeff = (X'*X)\(X'*y)
new_sse = zero(y[1])
d = length(new_basis)
for i = 1:n
val_i = sum(coeff[j]*new_basis[j](x[i]) for j = 1:d) + intercept
new_sse = new_sse + (y[i]-val_i)^2
end
end
if ( (new_sse < current_sse) && (abs(current_sse - new_sse) >= rel_res_error) && condition_number)
#Add the hinge function to the basis
num_terms = num_terms+1
push!(basis,hinge1)
push!(basis,hinge2)
current_sse = new_sse
end
if (num_terms > n_max_terms)
break
end
end
return basis
end

function _backward_pass_1d(x,y,n_min_terms,basis,penalty,rel_GCV)
n = length(x)
d = length(basis)
intercept = sum([y[i] for i =1:length(y)])/length(y)
coeff = _coeff_1d(x,y,basis)
sse = zero(y[1])
for i = 1:n
val_i = sum(coeff[j]*basis[j](x[i]) for j = 1:d) + intercept
sse = sse + (y[i]-val_i)^2
end
current_gcv = sse/(n*(1-d/n)^2)
num_terms = d
while (num_terms > n_min_terms)
#Basis-> select worst performing element-> eliminate it
if num_terms <= 1
break
end
found_new_to_eliminate = false
for i = 1:num_terms
current_basis = copy(basis)
#remove i-esim element from current basis
deleteat!(current_basis,i)
coef = _coeff_1d(x,y,current_basis)
new_sse = zero(y[i])
for i = 1:n
val_i = sum(coeff[j]*basis[j](x[i]) for j = 1:d) + intercept
new_sse = new_sse + (y[i]-val_i)^2
end
i_gcv = new_sse/(n*(1-d/n)^2)
if i_gcv < current_gcv
basis_to_remove = i
new_gcv = i_gcv
found_new_to_eliminate = true
end
end
if !found_new_to_eliminate
break
end
if abs(current_gcv-new_gcv) < rel_GCV
break
else
num_terms = num_terms-1
deleteat!(basis,basis_to_remove)
end
end
return basis
end


function EarthSurrogate(x,y,lb::Number,ub::Number; penalty::Number = 2.0, n_min_terms::Int = 2, n_max_terms::Int = 10, rel_res_error::Number = 1e-2, rel_GCV::Number = 1e-2)
intercept = sum([y[i] for i =1:length(y)])/length(y)
basis_after_forward = _forward_pass_1d(x,y,n_max_terms,rel_res_error)
basis = _backward_pass_1d(x,y,n_min_terms,basis_after_forward,penalty,rel_GCV)
coeff = _coeff_1d(x,y,basis)
return EarthSurrogate(x,y,lb,ub,basis,coeff,penalty,n_min_terms,n_max_terms,rel_res_error,rel_GCV,intercept)
end

function (earth::EarthSurrogate)(val::Number)
return sum([earth.coeff[i]*earth.basis[i](val) for i = 1:length(earth.coeff)])+earth.intercept
end

function EarthSurrogate(x,y,lb,ub; penalty::Number = 2.0, n_min_terms::Int = 2, n_max_terms::Int = 10, rel_res_error::Number = 1e-2, rel_GCV::Number = 1e-2)
return EarthSurrogate(x,y,lb,ub,1,2,3,4,5,6,7,10)
end


function add_point!(earth::EarthSurrogate,x_new,y_new)
if length(earth.x[1]) == 1
#1D
earth.x = vcat(earth.x,x_new)
earth.y = vcat(earth.y,y_new)
earth.intercept = sum([earth.y[i] for i =1:length(earth.y)])/length(earth.y)
basis_after_forward = _forward_pass_1d(earth.x,earth.y,earth.n_max_terms,earth.rel_res_error)
earth.basis = _backward_pass_1d(earth.x,earth.y,earth.n_min_terms,basis_after_forward,earth.penalty,earth.rel_GCV)
earth.coeff = _coeff_1d(earth.x,earth.y,earth.basis)
nothing
else
#ND
earth.x = vcat(earth.x,x_new)
earth.y = vcat(earth.y,y_new)
#earth.intercept =
#earth.basis =
#earth.coeff =

end


end
2 changes: 2 additions & 0 deletions src/Surrogates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ include("Wendland.jl")
include("MOE.jl")
include("VariableFidelity.jl")
include("PolynomialChaos.jl")
include("Earth.jl")

current_surrogates = ["Kriging","LinearSurrogate","LobacheskySurrogate","NeuralSurrogate",
"RadialBasis","RandomForestSurrogate","SecondOrderPolynomialSurrogate","Wendland"]
Expand All @@ -44,5 +45,6 @@ export WendlandStructure
export MOE
export VariableFidelitySurrogate
export PolynomialChaosSurrogate
export EarthSurrogate

end
12 changes: 12 additions & 0 deletions test/earth.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
using Surrogates


lb = 0.0
ub = 5.0
n = 20
x = sample(n,lb,ub,SobolSample())
f = x->2*x+x^2
y = f.(x)
my_ear1d = EarthSurrogate(x,y,lb,ub)
val = my_ear1d(3.0)
add_point!(my_ear1d,6.0,48.0)
3 changes: 3 additions & 0 deletions test/optimization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ y = objective_function.(x)
my_poly1d = PolynomialChaosSurrogate(x,y,lb,ub)
surrogate_optimize(objective_function,SRBF(),a,b,my_poly1d,LowDiscrepancySample(2))

my_earth1d = EarthSurrogate(x,y,lb,ub)
surrogate_optimize(objective_function,SRBF(),a,b,my_earth1d,LowDiscrepancySample(2))

##### ND #####
objective_function_ND = z -> 3*norm(z)+1
lb = [1.0,1.0]
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ using Surrogates
@testset "Wendland" begin include("Wendland.jl") end
@testset "MOE" begin include("MOE.jl") end
@testset "VariableFidelity" begin include("VariableFidelity.jl") end
@testset "Earth" begin include("earth.jl") end

0 comments on commit 7556d32

Please sign in to comment.