Skip to content

Commit

Permalink
ND forward tested
Browse files Browse the repository at this point in the history
  • Loading branch information
ludoro committed Jul 9, 2020
1 parent ddf63fe commit ced0ea4
Showing 1 changed file with 75 additions and 18 deletions.
93 changes: 75 additions & 18 deletions src/Earth.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ mutable struct EarthSurrogate{X,Y,L,U,B,C,P,M,N,R,G,I} <: AbstractSurrogate
best_hinge2 = z-> _hinge_mirror(z,x[pos_of_knot])
push!(basis,best_hinge1)
push!(basis,best_hinge2)
num_terms = num_terms + 1
else
break
end
Expand Down Expand Up @@ -162,49 +163,105 @@ _product_hinge(val,arr_hing) = prod([arr_hing[i](val[i]) for i = 1:length(val)])

function _coeff_nd(x,y,basis)
n = length(x)
d = length(basis)
X = zeros(eltype(x[1]),n,d)
@inbounds for i = 1:n
for j = 1:d
X[i,j] = basis[j](x[i])
base_len = length(basis)
d = length(x[1])
X = zeros(eltype(x[1]),n,base_len)
@inbounds for a = 1:n
for b = 1:base_len
X[a,b] = prod([basis[b][c](x[a][c]) for c = 1:d])
end
end
return (X'*X)\(X'*y)
end

function _forward_pass_nd(x,y,n_max_terms,rel_res_error)
n = length(x)
max_interactions = 2
basis = Array{Array{Function,1}}(undef,0) #ex of basis: push!(x,[x->1.0,x->1.0,x->_hinge(x,knot)])

basis = Array{Array{Function,1}}(undef,0) #ex of basis: push!(x,[x->1.0,x->1.0,x->_hinge(x,knot)]) so basis is of the form arr_hing and then I can pass my val
current_sse = +Inf
const_1 = x->1.0
intercept = sum([y[i] for i =1:length(y)])/length(y)
num_terms = 0
#3D: x y z, fix knot of x, check if you want to add it alone, or in combination with one node of y or one node of z, do this for every node
#So first add all nodes that are alone
d = length(x[1])
best_hinge1 = best_hinge2 = [x->one(eltype(x[1])) for j = 1:d]
new_addition = false
while num_terms <= n_max_terms
current_basis = copy(basis)
new_addition = false
for i = 1:n
for j = 1:d
for k = 1:n
for l = 1:d
new_basis = copy(basis)
#model with interaction between x[i][j] and x[k][l]
new_hinge1 = [w != j ? x->one(eltype(x[1])) : z->_hinge(z,x[i][j]) for w = 1:d]
new_hinge2 = [w != l ? x->one(eltype(x[1])) : z->_hinge_mirror(z,x[k][l]) for w = 1:d]
push!(new_basis,new_hinge1)
push!(new_basis,new_hinge2)

#build the model, find sse and check if it's better than current best, if so save the params
bas_len = length(new_basis)
X = zeros(eltype(x[1]),n,bas_len)
@inbounds for a = 1:n
for b = 1:bas_len
X[a,b] = prod([new_basis[b][c](x[a][c]) for c = 1:d])
end
end
if (cond(X'*X) > 1e8)
condition_number = false
new_sse = +Inf
else
condition_number = true
coeff = (X'*X)\(X'*y)
new_sse = zero(y[1])
for a = 1:n
val_a = sum(coeff[b]*prod([new_basis[b][c](x[a][c]) for c = 1:d]) for b = 1:bas_len) + intercept
new_sse = new_sse + (y[a]-val_a)^2
end
end
#is the i-esim the best?
if ( (new_sse < current_sse) && (abs(current_sse - new_sse) >= rel_res_error) && condition_number)
#Add the hinge function to the basis
best_hinge1 = new_hinge1
best_hinge2 = new_hinge2
current_sse = new_sse
new_addition = true
end
end

#Then check 2 by 2 interactions
end
end
end
#exit for
if new_addition
push!(basis,best_hinge1)
push!(basis,best_hinge2)
num_terms = num_terms + 1
else
break
end
end
if (length(basis) ==0)
throw("Earth surrogate did not add any term, just the intercept. It is advised to double check the parameters.")
end
return basis
end


function _backward_pass_nd()
#remove similarly to 1d

function _backward_pass_nd(x,y,n_min_terms,basis_after_forward,penalty,rel_GCV)

end


function EarthSurrogate(x,y,lb,ub; penalty::Number = 2.0, n_min_terms::Int = 2, n_max_terms::Int = 10, rel_res_error::Number = 1e-2, rel_GCV::Number = 1e-2)
intercept = sum([y[i] for i =1:length(y)])/length(y)
basis_after_forward = _forward_pass_nd(x,y,n_max_terms,rel_res_error)
basis = _backward_pass_nd(x,y,n_min_terms,basis_after_forward,penalty,rel_GCV)
coeff = _coeff_nd(x,y,basis)
return EarthSurrogate(x,y,lb,ub,basis,coeff,penalty,n_min_terms,n_max_terms,rel_res_error,rel_GCV,intercept)
#basis = _backward_pass_nd(x,y,n_min_terms,basis_after_forward,penalty,rel_GCV)
coeff = _coeff_nd(x,y,basis_after_forward)
return EarthSurrogate(x,y,lb,ub,basis_after_forward,coeff,penalty,n_min_terms,n_max_terms,rel_res_error,rel_GCV,intercept)
end

function (earth::EarthSurrogate)(val)
return sum([earth.coeff[i]*earth.basis[i](val) for i = 1:length(earth.coeff)]) + earth.intercept
return sum([earth.coeff[i]*prod([basis[i][j](val[j]) for j = 1:length(val)]) for i = 1:length(earth.coeff)]) + earth.intercept
end


Expand Down

0 comments on commit ced0ea4

Please sign in to comment.