Merge c9af26c into 6a7f042
MilesCranmer committed Oct 24, 2022
2 parents 6a7f042 + c9af26c commit e4f9e57
Showing 20 changed files with 365 additions and 223 deletions.
2 changes: 1 addition & 1 deletion docs/src/api.md
@@ -18,6 +18,7 @@ EquationSearch(X::AbstractMatrix{T}, y::AbstractMatrix{T};

```@docs
Options(;)
MutationWeights(;)
```

## Printing
@@ -56,7 +57,6 @@ node_to_symbolic(tree::Node, options::Options;

## Pareto frontier


```@docs
calculate_pareto_frontier(X::AbstractMatrix{T}, y::AbstractVector{T},
hallOfFame::HallOfFame{T}, options::Options;
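The `MutationWeights(;)` entry added to the API docs above documents the options struct that the rest of this commit threads through the search. As a rough usage sketch (not taken from the docs themselves): the field names are inferred from references later in this diff (`weights.mutate_constant`, `weights.add_node`, `weights.insert_node`, and the `:mutate_operator`, `:delete_node`, `:simplify`, `:randomize`, `:do_nothing` branches), the `mutation_weights` keyword on `Options` is inferred from `options.mutation_weights`, and the numbers are purely illustrative.

```julia
using SymbolicRegression: Options, MutationWeights

# Relative, unnormalized weights for each mutation type (illustrative values only).
weights = MutationWeights(;
    mutate_constant=1.0,
    mutate_operator=1.0,
    add_node=2.0,
    insert_node=2.0,
    delete_node=1.0,
    simplify=0.1,
    randomize=0.05,
    do_nothing=0.2,
)

# Assumed keyword, based on `options.mutation_weights` in src/Mutate.jl below.
options = Options(; mutation_weights=weights)
```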
4 changes: 2 additions & 2 deletions src/Configure.jl
@@ -123,10 +123,10 @@ function move_functions_to_workers(procs, options::Options, dataset::Dataset{T})
ops = (options.loss,)
nargs = dataset.weighted ? 3 : 2
elseif function_set == :early_stop_condition
if !(typeof(options.earlyStopCondition) <: Function)
if !(typeof(options.early_stop_condition) <: Function)
continue
end
ops = (options.earlyStopCondition,)
ops = (options.early_stop_condition,)
nargs = 2
else
error("Invalid function set: $function_set")
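The rename above touches the path that ships the early-stop condition to worker processes. The `<: Function` guard suggests `early_stop_condition` may hold either a callable or a non-function value, and only callables need to be moved to workers. A hedged sketch of such a callable — the `(loss, complexity)` signature is an assumption, not confirmed by this diff:

```julia
# Hypothetical early-stop callable: stop once a simple, accurate expression is found.
early_stop = (loss, complexity) -> loss < 1e-6 && complexity < 10
```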
2 changes: 1 addition & 1 deletion src/Core.jl
@@ -17,7 +17,7 @@ import .ProgramConstantsModule:
SRThreaded,
SRDistributed
import .DatasetModule: Dataset
import .OptionsStructModule: Options
import .OptionsStructModule: Options, MutationWeights, sample_mutation
import .OptionsModule: Options
import .OperatorsModule:
plus,
2 changes: 1 addition & 1 deletion src/LossFunctions.jl
@@ -74,7 +74,7 @@ function score_func_batch(
dataset::Dataset{T}, tree::Node{T}, options::Options
)::Tuple{T,T} where {T<:Real}
# TODO: Use StatsBase.sample here.
batch_idx = randperm(dataset.n)[1:(options.batchSize)]
batch_idx = randperm(dataset.n)[1:(options.batch_size)]
batch_X = dataset.X[:, batch_idx]
batch_y = dataset.y[batch_idx]
(prediction, completion) = eval_tree_array(tree, batch_X, options.operators)
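For context on the `batch_size` rename above: `score_func_batch` scores a tree on a random mini-batch of rows, and the `TODO` suggests `StatsBase.sample` as an alternative to the full permutation. A standalone sketch of both index-sampling approaches (the sizes are made up for illustration):

```julia
using Random: randperm
using StatsBase: sample

n, batch_size = 1_000, 50   # illustrative dataset size and batch size

# Approach in the diff: permute all row indices and keep the first batch_size.
batch_idx = randperm(n)[1:batch_size]

# The TODO's suggestion: draw batch_size distinct indices directly,
# which avoids materializing a full length-n permutation.
batch_idx = sample(1:n, batch_size; replace=false)
```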
54 changes: 25 additions & 29 deletions src/Mutate.jl
@@ -2,7 +2,7 @@ module MutateModule

import DynamicExpressions:
Node, copy_node, count_constants, count_depth, simplify_tree, combine_operators
import ..CoreModule: Options, Dataset, RecordType
import ..CoreModule: Options, Dataset, RecordType, sample_mutation
import ..ComplexityModule: compute_complexity
import ..LossFunctionsModule: score_func, score_func_batch
import ..CheckConstraintsModule: check_constraints
@@ -39,29 +39,28 @@ function next_generation(
#TODO - reconsider this
if options.batching
beforeScore, beforeLoss = score_func_batch(dataset, prev, options)
num_evals += (options.batchSize / dataset.n)
num_evals += (options.batch_size / dataset.n)
else
beforeScore = member.score
beforeLoss = member.loss
end

nfeatures = dataset.nfeatures

mutationChoice = rand()
weights = copy(options.mutation_weights)

#More constants => more likely to do constant mutation
weightAdjustmentMutateConstant = min(8, count_constants(prev)) / 8.0
cur_weights = copy(options.mutationWeights) .* 1.0
cur_weights[1] *= weightAdjustmentMutateConstant
weights.mutate_constant *= min(8, count_constants(prev)) / 8.0
n = compute_complexity(prev, options)
depth = count_depth(prev)

# If equation too big, don't add new operators
if n >= curmaxsize || depth >= options.maxdepth
cur_weights[3] = 0.0
cur_weights[4] = 0.0
weights.add_node = 0.0
weights.insert_node = 0.0
end
cur_weights /= sum(cur_weights)
cweights = cumsum(cur_weights)

mutation_choice = sample_mutation(weights)

successful_mutation = false
#TODO: Currently we dont take this \/ into account
@@ -75,22 +74,18 @@
while (!successful_mutation) && attempts < max_attempts
tree = copy_node(prev)
successful_mutation = true
if mutationChoice < cweights[1]
if mutation_choice == :mutate_constant
tree = mutate_constant(tree, temperature, options)
@recorder tmp_recorder["type"] = "constant"

is_success_always_possible = true
# Mutating a constant shouldn't invalidate an already-valid function

elseif mutationChoice < cweights[2]
elseif mutation_choice == :mutate_operator
tree = mutate_operator(tree, options)

@recorder tmp_recorder["type"] = "operator"

is_success_always_possible = true
# Can always mutate to the same operator

elseif mutationChoice < cweights[3]
elseif mutation_choice == :add_node
if rand() < 0.5
tree = append_random_op(tree, options, nfeatures)
@recorder tmp_recorder["type"] = "append_op"
@@ -100,17 +95,17 @@
end
is_success_always_possible = false
# Can potentially have a situation without success
elseif mutationChoice < cweights[4]
elseif mutation_choice == :insert_node
tree = insert_random_op(tree, options, nfeatures)
@recorder tmp_recorder["type"] = "insert_op"
is_success_always_possible = false
elseif mutationChoice < cweights[5]
elseif mutation_choice == :delete_node
tree = delete_random_op(tree, options, nfeatures)
@recorder tmp_recorder["type"] = "delete_op"
is_success_always_possible = true
elseif mutationChoice < cweights[6]
tree = simplify_tree(tree, options.operators) # Sometimes we simplify tree
tree = combine_operators(tree, options.operators) # See if repeated constants at outer levels
elseif mutation_choice == :simplify
tree = simplify_tree(tree, options.operators)
tree = combine_operators(tree, options.operators)
@recorder tmp_recorder["type"] = "partial_simplify"
mutation_accepted = true
return (
@@ -129,16 +124,15 @@
# Simplification shouldn't hurt complexity; unless some non-symmetric constraint
# to commutative operator...

elseif mutationChoice < cweights[7]
# Sometimes we generate a new tree completely tree
elseif mutation_choice == :randomize
# We select a random size, though the generated tree
# may have fewer nodes than we request.
tree_size_to_generate = rand(1:curmaxsize)
tree = gen_random_tree_fixed_size(tree_size_to_generate, options, nfeatures, T)
@recorder tmp_recorder["type"] = "regenerate"

is_success_always_possible = true
else # no mutation applied
elseif mutation_choice == :do_nothing
@recorder begin
tmp_recorder["type"] = "identity"
tmp_recorder["result"] = "accept"
@@ -156,6 +150,8 @@
mutation_accepted,
num_evals,
)
else
error("Unknown mutation choice: $mutation_choice")
end

successful_mutation =
@@ -186,7 +182,7 @@

if options.batching
afterScore, afterLoss = score_func_batch(dataset, tree, options)
num_evals += (options.batchSize / dataset.n)
num_evals += (options.batch_size / dataset.n)
else
afterScore, afterLoss = score_func(dataset, tree, options)
num_evals += 1
@@ -216,7 +212,7 @@
delta = afterScore - beforeScore
probChange *= exp(-delta / (temperature * options.alpha))
end
if options.useFrequency
if options.use_frequency
oldSize = compute_complexity(prev, options)
newSize = compute_complexity(tree, options)
old_frequency = if (oldSize <= options.maxsize)
@@ -302,11 +298,11 @@ function crossover_generation(
if options.batching
afterScore1, afterLoss1 = score_func_batch(dataset, child_tree1, options)
afterScore2, afterLoss2 = score_func_batch(dataset, child_tree2, options)
num_evals += 2 * (options.batchSize / dataset.n)
num_evals += 2 * (options.batch_size / dataset.n)
else
afterScore1, afterLoss1 = score_func(dataset, child_tree1, options)
afterScore2, afterLoss2 = score_func(dataset, child_tree2, options)
num_evals += options.batchSize / dataset.n
num_evals += options.batch_size / dataset.n
end

baby1 = PopMember(
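The `next_generation` changes above replace the hand-rolled `cur_weights` / `cweights` cumulative-sum selection with a `MutationWeights` struct and a `sample_mutation(weights)` call that returns the chosen mutation as a `Symbol`. The real sampler lives in `OptionsStructModule` and is not shown in this diff; the sketch below uses deliberately different names and only illustrates one way such a weighted draw could work:

```julia
# Sketch only: field names mirror the branches used in next_generation above.
mutable struct MutationWeightsSketch
    mutate_constant::Float64
    mutate_operator::Float64
    add_node::Float64
    insert_node::Float64
    delete_node::Float64
    simplify::Float64
    randomize::Float64
    do_nothing::Float64
end

function sample_mutation_sketch(w::MutationWeightsSketch)::Symbol
    names = fieldnames(MutationWeightsSketch)
    weights = [getfield(w, name) for name in names]
    # Weighted categorical draw: return the first field whose cumulative
    # weight exceeds a uniform sample in [0, total).
    r = rand() * sum(weights)
    cumulative = 0.0
    for (name, weight) in zip(names, weights)
        cumulative += weight
        r < cumulative && return name
    end
    return last(names)  # guard against floating-point round-off
end

w = MutationWeightsSketch(1.0, 1.0, 2.0, 2.0, 1.0, 0.1, 0.05, 0.2)
sample_mutation_sketch(w)  # e.g. :insert_node
```

Returning a `Symbol` lets `next_generation` dispatch on named branches (`:mutate_constant`, `:add_node`, ...) rather than positional indices into a weight vector.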
4 changes: 2 additions & 2 deletions src/MutationFunctions.jl
@@ -61,7 +61,7 @@ function mutate_constant(
end

bottom = 1//10
maxChange = T(options.perturbationFactor) * temperature + T(1 + bottom)
maxChange = T(options.perturbation_factor) * temperature + T(1 + bottom)
factor = maxChange^rand(T)
makeConstBigger = rand() > 0.5

@@ -71,7 +71,7 @@
node.val::T /= factor
end

if rand() > options.probNegate
if rand() > options.probability_negate_constant
node.val::T *= -1
end

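The two renamed options above (`perturbation_factor` and `probability_negate_constant`) control how `mutate_constant` nudges a leaf constant. A self-contained sketch of that update rule, mirroring the logic in the hunks above (the argument values in the final call are illustrative, not package defaults):

```julia
# Multiplicative perturbation of a constant, following the steps shown in mutate_constant.
function perturb_constant(
    val::T, temperature::T, perturbation_factor::T, probability_negate_constant::T
) where {T<:AbstractFloat}
    bottom = 1 // 10
    max_change = perturbation_factor * temperature + T(1 + bottom)
    factor = max_change^rand(T)   # multiplicative step drawn from [1, max_change)
    if rand() > 0.5
        val *= factor             # sometimes grow the constant...
    else
        val /= factor             # ...and sometimes shrink it
    end
    if rand() > probability_negate_constant
        val *= -1                 # sign flip, using the same condition as the diff
    end
    return val
end

perturb_constant(3.0, 0.5, 0.1, 0.01)
```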