Skip to content

Commit

Permalink
mars completed
Browse files Browse the repository at this point in the history
  • Loading branch information
ludoro committed Jul 9, 2020
1 parent ced0ea4 commit 02ef59d
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 29 deletions.
107 changes: 81 additions & 26 deletions src/Earth.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using LinearAlgebra

mutable struct EarthSurrogate{X,Y,L,U,B,C,P,M,N,R,G,I} <: AbstractSurrogate
mutable struct EarthSurrogate{X,Y,L,U,B,C,P,M,N,R,G,I,T} <: AbstractSurrogate
x::X
y::Y
lb::L
Expand All @@ -13,6 +13,7 @@ mutable struct EarthSurrogate{X,Y,L,U,B,C,P,M,N,R,G,I} <: AbstractSurrogate
rel_res_error::R
rel_GCV::G
intercept::I
maxiters::T
end


Expand All @@ -31,14 +32,15 @@ mutable struct EarthSurrogate{X,Y,L,U,B,C,P,M,N,R,G,I} <: AbstractSurrogate
return (X'*X)\(X'*y)
end

function _forward_pass_1d(x,y,n_max_terms,rel_res_error)
function _forward_pass_1d(x,y,n_max_terms,rel_res_error,maxiters)
n = length(x)
basis = Array{Function}(undef,0)
current_sse = +Inf
intercept = sum([y[i] for i =1:length(y)])/length(y)
num_terms = 0
pos_of_knot = 0
while num_terms < n_max_terms
iters = 0
while num_terms < n_max_terms || iters > maxiters
#Look for best addition:
new_addition = false
for i = 1:length(x)
Expand Down Expand Up @@ -79,6 +81,7 @@ mutable struct EarthSurrogate{X,Y,L,U,B,C,P,M,N,R,G,I} <: AbstractSurrogate
new_addition = true
end
end
iters = iters+1
if new_addition
best_hinge1 = z-> _hinge(z,x[pos_of_knot])
best_hinge2 = z-> _hinge_mirror(z,x[pos_of_knot])
Expand All @@ -105,25 +108,27 @@ mutable struct EarthSurrogate{X,Y,L,U,B,C,P,M,N,R,G,I} <: AbstractSurrogate
val_i = sum(coeff[j]*basis[j](x[i]) for j = 1:d) + intercept
sse = sse + (y[i]-val_i)^2
end
current_gcv = sse/(n*(1-d/n)^2)
effect_num_params = d + penalty*(d-1)/2
current_gcv = sse/(n*(1-effect_num_params/n)^2)
num_terms = d
while (num_terms > n_min_terms)
#Basis-> select worst performing element-> eliminate it
if num_terms <= 1
break
end
found_new_to_eliminate = false
for i = 1:num_terms
for i = 1:n
current_basis = copy(basis)
#remove i-esim element from current basis
deleteat!(current_basis,i)
coef = _coeff_1d(x,y,current_basis)
new_sse = zero(y[i])
for i = 1:n
val_i = sum(coeff[j]*basis[j](x[i]) for j = 1:d) + intercept
new_sse = new_sse + (y[i]-val_i)^2
for a = 1:n
val_a = sum(coeff[j]*basis[j](x[a]) for j = 1:num_terms) + intercept
new_sse = new_sse + (y[a]-val_a)^2
end
i_gcv = new_sse/(n*(1-d/n)^2)
effect_num_params = num_terms + penalty*(num_terms-1)/2
i_gcv = new_sse/(n*(1-effect_num_params/n)^2)
if i_gcv < current_gcv
basis_to_remove = i
new_gcv = i_gcv
Expand All @@ -144,12 +149,12 @@ mutable struct EarthSurrogate{X,Y,L,U,B,C,P,M,N,R,G,I} <: AbstractSurrogate
end


function EarthSurrogate(x,y,lb::Number,ub::Number; penalty::Number = 2.0, n_min_terms::Int = 2, n_max_terms::Int = 10, rel_res_error::Number = 1e-2, rel_GCV::Number = 1e-2)
function EarthSurrogate(x,y,lb::Number,ub::Number; penalty::Number = 2.0, n_min_terms::Int = 2, n_max_terms::Int = 10, rel_res_error::Number = 1e-2, rel_GCV::Number = 1e-2, maxiters = 100)
intercept = sum([y[i] for i =1:length(y)])/length(y)
basis_after_forward = _forward_pass_1d(x,y,n_max_terms,rel_res_error)
basis_after_forward = _forward_pass_1d(x,y,n_max_terms,rel_res_error,maxiters)
basis = _backward_pass_1d(x,y,n_min_terms,basis_after_forward,penalty,rel_GCV)
coeff = _coeff_1d(x,y,basis)
return EarthSurrogate(x,y,lb,ub,basis,coeff,penalty,n_min_terms,n_max_terms,rel_res_error,rel_GCV,intercept)
return EarthSurrogate(x,y,lb,ub,basis,coeff,penalty,n_min_terms,n_max_terms,rel_res_error,rel_GCV,intercept,maxiters)
end

function (earth::EarthSurrogate)(val::Number)
Expand All @@ -159,7 +164,7 @@ end

#ND
#Each entry of arr_hing is a one-variable function, e.g. x -> _hinge(x,5.0) or x -> one(x)
_product_hinge(val,arr_hing) = prod([arr_hing[i](val[i]) for i = 1:length(val)])
#_product_hinge(val,arr_hing) = prod([arr_hing[i](val[i]) for i = 1:length(val)])

function _coeff_nd(x,y,basis)
n = length(x)
Expand All @@ -174,7 +179,7 @@ function _coeff_nd(x,y,basis)
return (X'*X)\(X'*y)
end

function _forward_pass_nd(x,y,n_max_terms,rel_res_error)
function _forward_pass_nd(x,y,n_max_terms,rel_res_error,maxiters)
n = length(x)
basis = Array{Array{Function,1}}(undef,0) #ex of basis: push!(x,[x->1.0,x->1.0,x->_hinge(x,knot)]) so basis is of the form arr_hing and then I can pass my val
current_sse = +Inf
Expand All @@ -184,7 +189,8 @@ function _forward_pass_nd(x,y,n_max_terms,rel_res_error)
d = length(x[1])
best_hinge1 = best_hinge2 = [x->one(eltype(x[1])) for j = 1:d]
new_addition = false
while num_terms <= n_max_terms
iters = 0
while num_terms <= n_max_terms || iters > maxiters
current_basis = copy(basis)
new_addition = false
for i = 1:n
Expand Down Expand Up @@ -231,7 +237,7 @@ function _forward_pass_nd(x,y,n_max_terms,rel_res_error)
end
end
end
#exit for
iters = iters+1
if new_addition
push!(basis,best_hinge1)
push!(basis,best_hinge2)
Expand All @@ -247,21 +253,70 @@ function _forward_pass_nd(x,y,n_max_terms,rel_res_error)
end


function _backward_pass_nd(x,y,n_min_terms,basis_after_forward,penalty,rel_GCV)

# Generalized cross-validation (GCV) score of the ND model
#     intercept + sum_b coeff[b] * prod_c basis[b][c](x[a][c])
# evaluated on the samples (x, y). The effective number of parameters is
# n_terms + penalty*(n_terms - 1)/2, with n_terms = length(basis).
# Expects a non-empty basis (the sum over terms is taken over 1:n_terms).
function _gcv_nd(x, y, basis, coeff, intercept, penalty)
    n = length(x)
    d = length(x[1])
    n_terms = length(basis)
    sse = zero(y[1])
    for (xa, ya) in zip(x, y)
        val_a = sum(coeff[b] * prod(basis[b][c](xa[c]) for c in 1:d) for b in 1:n_terms) + intercept
        sse = sse + (ya - val_a)^2
    end
    effect_num_params = n_terms + penalty * (n_terms - 1) / 2
    return sse / (n * (1 - effect_num_params / n)^2)
end

"""
    _backward_pass_nd(x, y, n_min_terms, basis, penalty, rel_GCV)

Prune the ND basis produced by the forward pass.

At every step each remaining term is tentatively removed, the coefficients of
the reduced basis are refitted with `_coeff_nd`, and the GCV score of the
reduced model is computed. The term whose removal gives the lowest GCV is
dropped, provided that this GCV is lower than the current one by at least
`rel_GCV`. Pruning stops when no removal lowers the GCV, when the gain is
smaller than `rel_GCV`, or when only `max(n_min_terms, 1)` terms are left.

# Arguments
- `x`, `y`: sample points (each `x[a]` indexable along `1:length(x[1])`) and values.
- `n_min_terms`: the basis is never pruned below this many terms.
- `basis`: vector of terms, each term being a vector of one-variable functions.
  It is modified in place (terms are deleted from it).
- `penalty`: GCV penalty per knot.
- `rel_GCV`: minimum GCV decrease required to accept a removal.

# Returns
The same `basis` object, after pruning.
"""
function _backward_pass_nd(x,y,n_min_terms,basis,penalty,rel_GCV)
    num_terms = length(basis)
    # Nothing can be pruned: skip the (possibly ill-posed) fit of the full model.
    if num_terms <= max(n_min_terms, 1)
        return basis
    end
    intercept = sum(y)/length(y)
    current_gcv = _gcv_nd(x,y,basis,_coeff_nd(x,y,basis),intercept,penalty)
    while num_terms > n_min_terms && num_terms > 1
        best_gcv = current_gcv
        basis_to_remove = 0
        for i = 1:num_terms
            # Remove the i-th term from a copy of the current basis
            candidate_basis = copy(basis)
            deleteat!(candidate_basis,i)
            # The reduced model must be scored with its own refitted coefficients,
            # not with the coefficients of the larger basis.
            candidate_coeff = _coeff_nd(x,y,candidate_basis)
            i_gcv = _gcv_nd(x,y,candidate_basis,candidate_coeff,intercept,penalty)
            # Keep the best candidate seen so far, not merely the last improving one.
            if i_gcv < best_gcv
                best_gcv = i_gcv
                basis_to_remove = i
            end
        end
        if basis_to_remove == 0
            # No removal lowers the GCV
            break
        elseif abs(current_gcv-best_gcv) < rel_GCV
            # The gain is too small to justify another removal
            break
        end
        deleteat!(basis,basis_to_remove)
        num_terms = num_terms-1
        # The next round must be compared against the pruned model.
        current_gcv = best_gcv
    end
    return basis
end


function EarthSurrogate(x,y,lb,ub; penalty::Number = 2.0, n_min_terms::Int = 2, n_max_terms::Int = 10, rel_res_error::Number = 1e-2, rel_GCV::Number = 1e-2)
function EarthSurrogate(x,y,lb,ub; penalty::Number = 2.0, n_min_terms::Int = 2, n_max_terms::Int = 10, rel_res_error::Number = 1e-2, rel_GCV::Number = 1e-2, maxiters = 100)
intercept = sum([y[i] for i =1:length(y)])/length(y)
basis_after_forward = _forward_pass_nd(x,y,n_max_terms,rel_res_error)
#basis = _backward_pass_nd(x,y,n_min_terms,basis_after_forward,penalty,rel_GCV)
coeff = _coeff_nd(x,y,basis_after_forward)
return EarthSurrogate(x,y,lb,ub,basis_after_forward,coeff,penalty,n_min_terms,n_max_terms,rel_res_error,rel_GCV,intercept)
basis_after_forward = _forward_pass_nd(x,y,n_max_terms,rel_res_error,maxiters)
basis = _backward_pass_nd(x,y,n_min_terms,basis_after_forward,penalty,rel_GCV)
coeff = _coeff_nd(x,y,basis)
return EarthSurrogate(x,y,lb,ub,basis,coeff,penalty,n_min_terms,n_max_terms,rel_res_error,rel_GCV,intercept,maxiters)
end

function (earth::EarthSurrogate)(val)
return sum([earth.coeff[i]*prod([basis[i][j](val[j]) for j = 1:length(val)]) for i = 1:length(earth.coeff)]) + earth.intercept
return sum([earth.coeff[i]*prod([earth.basis[i][j](val[j]) for j = 1:length(val)]) for i = 1:length(earth.coeff)]) + earth.intercept
end


Expand All @@ -274,7 +329,7 @@ function add_point!(earth::EarthSurrogate,x_new,y_new)
earth.x = vcat(earth.x,x_new)
earth.y = vcat(earth.y,y_new)
earth.intercept = sum([earth.y[i] for i =1:length(earth.y)])/length(earth.y)
basis_after_forward = _forward_pass_1d(earth.x,earth.y,earth.n_max_terms,earth.rel_res_error)
basis_after_forward = _forward_pass_1d(earth.x,earth.y,earth.n_max_terms,earth.rel_res_error,earth.maxiters)
earth.basis = _backward_pass_1d(earth.x,earth.y,earth.n_min_terms,basis_after_forward,earth.penalty,earth.rel_GCV)
earth.coeff = _coeff_1d(earth.x,earth.y,earth.basis)
nothing
Expand All @@ -283,7 +338,7 @@ function add_point!(earth::EarthSurrogate,x_new,y_new)
earth.x = vcat(earth.x,x_new)
earth.y = vcat(earth.y,y_new)
earth.intercept = sum([earth.y[i] for i =1:length(earth.y)])/length(earth.y)
basis_after_forward = _forward_pass_nd(earth.x,earth.y,earth.n_max_terms,earth.rel_res_error)
basis_after_forward = _forward_pass_nd(earth.x,earth.y,earth.n_max_terms,earth.rel_res_error,earth.maxiters)
earth.basis = _backward_pass_nd(earth.x,earth.y,earth.n_min_terms,basis_after_forward,earth.penalty,earth.rel_GCV)
earth.coeff = _coeff_nd(earth.x,earth.y,earth.basis)
nothing
Expand Down
4 changes: 2 additions & 2 deletions test/earth.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@ x = sample(n,lb,ub,SobolSample())
f = x -> x[1]*x[2] + x[1]
y = f.(x)
my_earnd = EarthSurrogate(x,y,lb,ub)
#val = my_earnd((2.0,2.0))
#add_point!(my_earnd,(2.0,2.0),6.0)
val = my_earnd((2.0,2.0))
add_point!(my_earnd,(2.0,2.0),6.0)
3 changes: 2 additions & 1 deletion test/optimization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,6 @@ y = obj_ND.(x)
my_second_order_poly_ND = SecondOrderPolynomialSurrogate(x,y,lb,ub)
surrogate_optimize(obj_ND,SRBF(),lb,ub,my_second_order_poly_ND,SobolSample(),maxiters=15)


obj_ND = x -> log(x[1])*exp(x[2])
x = sample(40,lb,ub,UniformSample())
y = obj_ND.(x)
Expand Down Expand Up @@ -233,6 +232,8 @@ surrogate_optimize(objective_function_ND,DYCORS(),lb,ub,my_rad_DYCORSN,UniformSa
my_wend_ND = Wendland(x,y,lb,ub)
surrogate_optimize(objective_function_ND,DYCORS(),lb,ub,my_wend_ND,UniformSample(),maxiters=30)

my_earthND = EarthSurrogate(x,y,lb,ub)
surrogate_optimize(obj_ND,SRBF(),lb,ub,my_earthND,SobolSample(),maxiters = 15)

### SOP ###
# 1D
Expand Down

0 comments on commit 02ef59d

Please sign in to comment.