Skip to content

Commit

Permalink
Merge 6464351 into 15f5ba3
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Nov 3, 2022
2 parents 15f5ba3 + 6464351 commit 5eee45f
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 1 deletion.
16 changes: 16 additions & 0 deletions src/Mutate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import ..MutationFunctionsModule:
insert_random_op,
delete_random_op,
crossover_trees
import ..ConstantOptimizationModule: optimize_constants
import ..RecorderModule: @recorder

# Go through one simulated options.annealing mutation cycle
Expand Down Expand Up @@ -131,6 +132,21 @@ function next_generation(
tree = gen_random_tree_fixed_size(tree_size_to_generate, options, nfeatures, T)
@recorder tmp_recorder["type"] = "regenerate"

is_success_always_possible = true
elseif mutation_choice == :optimize
cur_member = PopMember(
tree,
beforeScore,
beforeLoss;
parent=parent_ref,
deterministic=options.deterministic,
)
cur_member, new_num_evals = optimize_constants(dataset, cur_member, options)
num_evals += new_num_evals
@recorder tmp_recorder["type"] = "optimize"
mutation_accepted = true
return (cur_member, mutation_accepted, num_evals)

is_success_always_possible = true
elseif mutation_choice == :do_nothing
@recorder begin
Expand Down
5 changes: 5 additions & 0 deletions src/OptionsStruct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ mutable struct MutationWeights
simplify::Float64
randomize::Float64
do_nothing::Float64
optimize::Float64
end

const mutations = [fieldnames(MutationWeights)...]
Expand All @@ -32,6 +33,9 @@ will be normalized to sum to 1.0 after initialization.
- `simplify::Float64`: How often to simplify the tree.
- `randomize::Float64`: How often to create a random tree.
- `do_nothing::Float64`: How often to do nothing.
- `optimize::Float64`: How often to optimize the constants in the tree, as a mutation.
Note that this is different from `optimizer_probability`, which is
performed at the end of an iteration for all individuals.
"""
@generated function MutationWeights(;
mutate_constant=0.048,
Expand All @@ -42,6 +46,7 @@ will be normalized to sum to 1.0 after initialization.
simplify=0.0020,
randomize=0.00023,
do_nothing=0.21,
optimize=0.0,
)
return :(MutationWeights($(mutations...)))
end
Expand Down
4 changes: 4 additions & 0 deletions src/SymbolicRegression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -926,4 +926,8 @@ function _EquationSearch(
end
end

macro ignore(args...) end
# Hack to get static analysis to work from within tests:
@ignore include("../test/runtests.jl")

end #module SR
2 changes: 1 addition & 1 deletion test/test_deprecation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ options = Options(;
@test options.fraction_replaced_hof == 0.01f0
@test options.should_optimize_constants == true

options = Options(; mutationWeights=[1.0 for i in 1:8])
options = Options(; mutationWeights=[1.0 for i in 1:9])
@test options.mutation_weights.add_node == 1.0
41 changes: 41 additions & 0 deletions test/test_optimizer_mutation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
using SymbolicRegression
using SymbolicRegression: SymbolicRegression
using SymbolicRegression: Dataset, RunningSearchStatistics, RecordType
using Optim: Optim
import SymbolicRegression.MutateModule: next_generation
import DynamicExpressions: get_constants
using Test

mutation_weights = MutationWeights(; optimize=Inf)
options = Options(;
binary_operators=(+, -, *),
unary_operators=(sin,),
mutation_weights=mutation_weights,
optimizer_options=Optim.Options(),
)

X = randn(5, 100)
y = sin.(X[1, :] .* 2.1 .+ 0.8) .+ X[2, :] .^ 2
dataset = Dataset(X, y)

x1 = Node(; feature=1)
x2 = Node(; feature=2)
tree = sin(x1 * 1.9 + 0.2) + x2 * x2

member = PopMember(dataset, tree, options; deterministic=false)
temperature = 1.0
maxsize = 20

new_member, _, _ = next_generation(
dataset,
member,
temperature,
maxsize,
RunningSearchStatistics(; options=options),
options;
tmp_recorder=RecordType(),
)

resultant_constants = get_constants(new_member.tree)
@test resultant_constants[1] 2.1 atol = 1e-3
@test sin(resultant_constants[2]) sin(0.8) atol = 1e-3
4 changes: 4 additions & 0 deletions test/unittest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,7 @@ end
@safetestset "Test deprecated options" begin
include("test_deprecation.jl")
end

@safetestset "Test optimization mutation" begin
include("test_optimizer_mutation.jl")
end

0 comments on commit 5eee45f

Please sign in to comment.