Skip to content

Commit

Permalink
Merge 318d8fa into 05c90f5
Browse files Browse the repository at this point in the history
  • Loading branch information
JulienPascal authored Jan 5, 2019
2 parents 05c90f5 + 318d8fa commit 84cd335
Show file tree
Hide file tree
Showing 6 changed files with 178 additions and 191 deletions.
2 changes: 1 addition & 1 deletion REQUIRE
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
julia 0.6
BlackBoxOptim
Optim
Optim 0.14.1 0.15.0
JLD2
Plots
CSV
Expand Down
2 changes: 1 addition & 1 deletion src/SMM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ module SMM
export set_priors!, set_empirical_moments!, set_bbSetup!, generate_bbSearchRange
export create_lower_bound, create_upper_bound
export set_global_optimizer!
export create_grid, cartesian_grid, create_grid_stochastic, generate_std
export create_grid, create_grid_stochastic, generate_std

# Functions and types in save_load.jl
#------------------------------------
Expand Down
81 changes: 0 additions & 81 deletions src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,87 +217,6 @@ function create_upper_bound(sMMProblem::SMMProblem)
[sMMProblem.priors[k][3] for k in keys(sMMProblem.priors)]
end


"""
cartesian_grid(a::Array{Float64,1}, b::Array{Float64,1}, nums::Int64)
Function to create a regular cartesian grid. a is a vector of lower
bounds, b a vector of upper bounds and nums is the number of points
along each dimension. This function works for up to 15 dimensions.
Returns a NTuple.
"""
function cartesian_grid(a::Array{Float64,1}, b::Array{Float64,1}, nums::Int64)

#Safety checks
#-------------
if nums < 2
Base.error("The input nums should be >= 2. nums = $(nums).")
end

if length(a) != length(b)
Base.error("length(a) != length(b)")
end

nodes = [collect(linspace(a[i], b[i], nums)) for i in 1:length(a)]

# There is probably a better way of doing this:
#----------------------------------------------
if length(a) == 1
points = collect(Iterators.product(nodes[1]))
elseif length(a) == 2
points = collect(Iterators.product(nodes[1], nodes[2]))
elseif length(a) == 3
points = collect(Iterators.product(nodes[1], nodes[2], nodes[3]))
elseif length(a) == 4
points = collect(Iterators.product(nodes[1], nodes[2], nodes[3], nodes[4]))
elseif length(a) == 5
points = collect(Iterators.product(nodes[1], nodes[2], nodes[3], nodes[4],
nodes[5]))
elseif length(a) == 6
points = collect(Iterators.product(nodes[1], nodes[2], nodes[3], nodes[4],
nodes[5], nodes[6]))
elseif length(a) == 7
points = collect(Iterators.product(nodes[1], nodes[2], nodes[3], nodes[4],
nodes[5], nodes[6], nodes[7]))
elseif length(a) == 8
points = collect(Iterators.product(nodes[1], nodes[2], nodes[3], nodes[4],
nodes[5], nodes[6], nodes[7], nodes[8]))
elseif length(a) == 9
points = collect(Iterators.product(nodes[1], nodes[2], nodes[3], nodes[4],
nodes[5], nodes[6], nodes[7], nodes[8], nodes[9]))
elseif length(a) == 10
points = collect(Iterators.product(nodes[1], nodes[2], nodes[3], nodes[4],
nodes[5], nodes[6], nodes[7], nodes[8], nodes[9], nodes[10]))
elseif length(a) == 11
points = collect(Iterators.product(nodes[1], nodes[2], nodes[3], nodes[4],
nodes[5], nodes[6], nodes[7], nodes[8], nodes[9], nodes[10],
nodes[11]))
elseif length(a) == 12
points = collect(Iterators.product(nodes[1], nodes[2], nodes[3], nodes[4],
nodes[5], nodes[6], nodes[7], nodes[8], nodes[9], nodes[10],
nodes[11], nodes[12]))
elseif length(a) == 13
points = collect(Iterators.product(nodes[1], nodes[2], nodes[3], nodes[4],
nodes[5], nodes[6], nodes[7], nodes[8], nodes[9], nodes[10],
nodes[11], nodes[12], nodes[13]))
elseif length(a) == 14
points = collect(Iterators.product(nodes[1], nodes[2], nodes[3], nodes[4],
nodes[5], nodes[6], nodes[7], nodes[8], nodes[9], nodes[10],
nodes[11], nodes[12], nodes[13], nodes[14]))
elseif length(a) == 15
points = collect(Iterators.product(nodes[1], nodes[2], nodes[3], nodes[4],
nodes[5], nodes[6], nodes[7], nodes[8], nodes[9], nodes[10],
nodes[11], nodes[12], nodes[13], nodes[14], nodes[15]))
else

Base.error("cartesian_grid can take up to 15 dimensions")

end

return points

end

"""
create_grid(a::Array{Float64,1}, b::Array{Float64,1}, nums::Int64)
Expand Down
2 changes: 1 addition & 1 deletion src/save_load.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ file should have the following columns "name", "value", "upper_bound", "lower_bo
# store the number of rows in the dataframe:
#------------------------------------------
number_rows = size(dataFrame, 1)
info(string(number_rows, " prior(s) values found"))
info(string(number_rows, " prior value(s) found"))

# Append the dictionary
#----------------------
Expand Down
96 changes: 5 additions & 91 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ struct SMMOptions
saveSteps::Int64 #maximum number of steps
saveName::String #name under which the optimization should be saved
showDistance::Bool #show the distance, everytime the objective function is calculated?
minBox::Bool #When looking for a local maximum, use Fminbox ?
minBox::Bool #When looking for a local maximum, use Fminbox ?
end

function SMMOptions( ;globalOptimizer::Symbol=:dxnes,
Expand Down Expand Up @@ -215,58 +215,13 @@ function is_optim_optimizer(s::Symbol)

end


"""
convert_to_optim_algo(s::Symbol)
function to convert local optimizer (of type Symbol) to an Optim algo.
"""
function convert_to_optim_algo(s::Symbol)

if s == :NelderMead

output = NelderMead()

elseif s == :SimulatedAnnealing

output = SimulatedAnnealing()

elseif s == :ParticleSwarm

output = ParticleSwarm()

elseif s == :BFGS

output = BFGS()

elseif s == :LBFGS

output = LBFGS()

elseif s == :ConjugateGradient

output = ConjugateGradient()

elseif s == :GradientDescent

output = GradientDescent()

elseif s == :MomentumGradientDescent

output = MomentumGradientDescent()

elseif s == :AcceleratedGradientDescent

output = AcceleratedGradientDescent()

else

Base.error("$(s) is not in the list of algorithm supported by Optim.")

end

return output

eval(Meta.parse("$(s)()"))
end

"""
Expand All @@ -277,49 +232,8 @@ by Optim.
"""
function convert_to_fminbox(s::Symbol)


if s == :NelderMead

output = Fminbox{NelderMead}()

elseif s == :SimulatedAnnealing

output = Fminbox{SimulatedAnnealing}()

elseif s == :ParticleSwarm

output = Fminbox{ParticleSwarm}()

elseif s == :BFGS

output = Fminbox{BFGS}()

elseif s == :LBFGS

output = Fminbox{LBFGS}()

elseif s == :ConjugateGradient

output = Fminbox{ConjugateGradient}()

elseif s == :GradientDescent

output = Fminbox{GradientDescent}()

elseif s == :MomentumGradientDescent

output = Fminbox{MomentumGradientDescent}()

elseif s == :AcceleratedGradientDescent

output = Fminbox{AcceleratedGradientDescent}()

else

Base.error("$(s) is not in the list of algorithm supported by Optim.")

end

return output
# Old API (before v0.15.0)
# To be changed when switching to Julia v0.7
eval(Meta.parse("Fminbox{$(s)}()"))

end
Loading

0 comments on commit 84cd335

Please sign in to comment.