Skip to content

Commit

Permalink
Merge df91694 into 93a6764
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Feb 13, 2023
2 parents 93a6764 + df91694 commit 16124a6
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 8 deletions.
8 changes: 3 additions & 5 deletions src/CheckConstraints.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module CheckConstraintsModule

import DynamicExpressions: Node
import DynamicExpressions: Node, count_depth
import ..UtilsModule: vals
import ..CoreModule: Options
import ..ComplexityModule: compute_complexity
Expand Down Expand Up @@ -140,10 +140,8 @@ end

"""Check if user-passed constraints are violated or not"""
function check_constraints(tree::Node, options::Options, maxsize::Int)::Bool
size = compute_complexity(tree, options)
if size > maxsize
return false
end
compute_complexity(tree, options) > maxsize && return false
count_depth(tree) > options.maxdepth && return false
for i in 1:(options.nbin)
if options.bin_constraints[i] == (-1, -1)
continue
Expand Down
5 changes: 2 additions & 3 deletions src/Mutate.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module MutateModule

import DynamicExpressions:
Node, copy_node, count_constants, count_depth, simplify_tree, combine_operators
Node, copy_node, count_constants, simplify_tree, combine_operators
import ..CoreModule: Options, Dataset, RecordType, sample_mutation
import ..ComplexityModule: compute_complexity
import ..LossFunctionsModule: score_func, score_func_batch
Expand Down Expand Up @@ -53,10 +53,9 @@ function next_generation(
#More constants => more likely to do constant mutation
weights.mutate_constant *= min(8, count_constants(prev)) / 8.0
n = compute_complexity(prev, options)
depth = count_depth(prev)

# If equation too big, don't add new operators
if n >= curmaxsize || depth >= options.maxdepth
if n >= curmaxsize
weights.add_node = 0.0
weights.insert_node = 0.0
end
Expand Down
22 changes: 22 additions & 0 deletions test/test_constraints.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import DynamicExpressions: count_depth
using SymbolicRegression
using SymbolicRegression: check_constraints
using Test
Expand Down Expand Up @@ -32,3 +33,24 @@ options = Options(; binary_operators=(+, *), maxsize=5, complexity_of_operators=
@test check_constraints(tree, options) == false
options = Options(; binary_operators=(+, *), maxsize=5, complexity_of_operators=[(*) => 0])
@test check_constraints(violating_tree, options) == true

# Test for depth constraints:
options = Options(; binary_operators=(+, *), unary_operators=(cos,), maxsize=100, maxdepth=3)
@extend_operators options
x1, x2, x3 = [Node(; feature=i) for i in 1:3]

tree = (x1 + x2) + (x3 + x1)
@test count_depth(tree) == 3
@test check_constraints(tree, options) == true

tree = (x1 + x2) + (x3 + x1) * x1
@test count_depth(tree) == 4
@test check_constraints(tree, options) == false

tree = cos(cos(x1))
@test count_depth(tree) == 3
@test check_constraints(tree, options) == true

tree = cos(cos(cos(x1)))
@test count_depth(tree) == 4
@test check_constraints(tree, options) == false

0 comments on commit 16124a6

Please sign in to comment.