Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split codebase in two: DynamicExpressions.jl and SymbolicRegression.jl #147

Merged
merged 53 commits into from Oct 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
9a17001
Remove deprecated functions
MilesCranmer Oct 21, 2022
0382d43
Complete split to DynamicExpressions.jl
MilesCranmer Oct 21, 2022
03499c4
Include other missing functions
MilesCranmer Oct 21, 2022
ff17ca2
Fix formatting issues
MilesCranmer Oct 21, 2022
731aa10
Fix implementation of `node_to_symbolic`
MilesCranmer Oct 21, 2022
4e6a356
Remove test modifications
MilesCranmer Oct 21, 2022
0a21f5d
Extend `print_tree` and `string_tree` to varMap
MilesCranmer Oct 21, 2022
098e0bc
Fix test for `eval_diff_tree_array`
MilesCranmer Oct 21, 2022
8301f36
Remove unused libraries
MilesCranmer Oct 21, 2022
a95596d
Fix removed imports in test
MilesCranmer Oct 21, 2022
c6598df
Allow user to turn off user operators/helper functions
MilesCranmer Oct 22, 2022
71d9009
Bump DynamicExpressions version with new options
MilesCranmer Oct 22, 2022
c3d58df
Clean up tournament sampling
MilesCranmer Oct 22, 2022
df936de
Add note to options
MilesCranmer Oct 22, 2022
c1a1114
Use separate function for weighted loss
MilesCranmer Oct 22, 2022
221c379
Bump backend version with type stability
MilesCranmer Oct 22, 2022
317a9a9
Fix potential for bug in adaptive parsimony
MilesCranmer Oct 22, 2022
20c788a
Clean up type instability
MilesCranmer Oct 22, 2022
83c0a44
Fix other type instability
MilesCranmer Oct 22, 2022
ffdf103
Faster sampling of population
MilesCranmer Oct 22, 2022
49c9d7f
Speed up other parts of evaluation
MilesCranmer Oct 22, 2022
fa0b656
Remove unused parameters
MilesCranmer Oct 22, 2022
46bd4eb
Fix parameter name
MilesCranmer Oct 22, 2022
f27a3f4
Use `parallelism` parameter instead of `multithreading`
MilesCranmer Oct 22, 2022
db2dd23
Update tests to use `parallelism` argument
MilesCranmer Oct 22, 2022
bd4374d
Clean up formatting
MilesCranmer Oct 23, 2022
c3b0987
Update example with new interface
MilesCranmer Oct 23, 2022
df7e9fd
Set up multithreading in CI test
MilesCranmer Oct 23, 2022
397d8a2
Update AdaptiveParsimony.jl
MilesCranmer Oct 23, 2022
67abc49
Clean up `move_functions_to_workers`
MilesCranmer Oct 23, 2022
2ebe97e
Clean up `LossFunctions.jl`
MilesCranmer Oct 23, 2022
621b5ca
Fix formatting issue
MilesCranmer Oct 23, 2022
2ceb369
Fix loss tests
MilesCranmer Oct 23, 2022
497d469
Add error for end of if statement
MilesCranmer Oct 23, 2022
d2a2f11
Allow `binary_operators` to be vector instead of tuple
MilesCranmer Oct 23, 2022
cc153dd
Add docs for helper function argument
MilesCranmer Oct 23, 2022
644dea1
Clean up unit tests
MilesCranmer Oct 23, 2022
b7f0527
Fix error in `move_functions_to_workers`
MilesCranmer Oct 23, 2022
97632f9
Bump version with codebase splitting
MilesCranmer Oct 23, 2022
5531e10
Lower number of processes to pass windows tests
MilesCranmer Oct 23, 2022
c842879
Always check bounds in testing
MilesCranmer Oct 23, 2022
27bbf40
Remove unneccessary inbounds
MilesCranmer Oct 23, 2022
eee62ef
Clean up full pipeline test
MilesCranmer Oct 23, 2022
d90134f
Greater guarantee that we can find multi-output equations
MilesCranmer Oct 23, 2022
6afc628
Clean up state saving
MilesCranmer Oct 23, 2022
a9761f6
Clean up pop member copies
MilesCranmer Oct 23, 2022
29d14bd
Clean up order of options
MilesCranmer Oct 23, 2022
4520189
Bump DynamicExpressions.jl with equality operator
MilesCranmer Oct 23, 2022
b754ed1
Refactor migration into module
MilesCranmer Oct 23, 2022
ed9079d
Add compat for StatsBase
MilesCranmer Oct 23, 2022
e55a6af
Allow errors on windows
MilesCranmer Oct 23, 2022
4f3fa1d
Bump julia versions
MilesCranmer Oct 23, 2022
c48d033
Update actions
MilesCranmer Oct 23, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 19 additions & 7 deletions .github/workflows/CI.yml
Expand Up @@ -27,22 +27,22 @@ jobs:
fail-fast: false
matrix:
julia-version:
- '1.6.0'
- '1.7.1'
- '1.8.0'
- '1.6.7'
- '1.7.3'
- '1.8.2'
os:
- ubuntu-latest
- windows-latest
- macOS-latest

steps:
- uses: actions/checkout@v1.0.0
- uses: actions/checkout@v3
- name: "Set up Julia"
uses: julia-actions/setup-julia@v1.6.0
uses: julia-actions/setup-julia@v1
with:
version: ${{ matrix.julia-version }}
- name: Cache dependencies
uses: actions/cache@v1 # Thanks FromFile.jl
uses: actions/cache@v3
env:
cache-name: cache-artifacts
with:
Expand All @@ -57,9 +57,21 @@ jobs:
- name: "Run tests"
run: |
julia --color=yes --project=. -e 'import Pkg; Pkg.add("Coverage")'
julia --color=yes --inline=yes --depwarn=yes --code-coverage=user --project=. -e 'import Pkg; Pkg.test(coverage=true)'
julia --color=yes --threads=auto --check-bounds=yes --depwarn=yes --code-coverage=user --project=. -e 'import Pkg; Pkg.test(coverage=true)'
julia --color=yes --project=. coverage.jl
shell: bash
if: ${{ matrix.os != 'windows-latest' }}
- name: "Run tests, skipping errors."
run: |
julia --color=yes --project=. -e 'import Pkg; Pkg.add("Coverage")'
{
julia --color=yes --threads=auto --check-bounds=yes --depwarn=yes --code-coverage=user --project=. -e 'import Pkg; Pkg.test(coverage=true)'
} || {
echo "Tests failed, but continuing anyway."
}
julia --color=yes --project=. coverage.jl
shell: bash
if: ${{ matrix.os == 'windows-latest' }}
- name: Coveralls
uses: coverallsapp/github-action@master
with:
Expand Down
15 changes: 8 additions & 7 deletions Project.toml
@@ -1,42 +1,43 @@
name = "SymbolicRegression"
uuid = "8254be44-1295-4e6a-a16d-46603ac705cb"
authors = ["MilesCranmer <miles.cranmer@gmail.com>"]
version = "0.12.6"
version = "0.13.0"

[deps]
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
DynamicExpressions = "0.3.0"
JSON3 = "1"
LineSearches = "7"
LossFunctions = "0.6, 0.7, 0.8"
Optim = "0.19, 1.1"
Pkg = "1"
PreallocationTools = "< 0.4.2"
Reexport = "1"
SpecialFunctions = "0.10.1, 1, 2"
StatsBase = "0.33"
SymbolicUtils = "0.19"
Zygote = "0.6"
julia = "1.6"

[extras]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "SafeTestsets", "ForwardDiff"]
test = ["Test", "SafeTestsets", "ForwardDiff", "LinearAlgebra", "Zygote"]
7 changes: 5 additions & 2 deletions README.md
Expand Up @@ -33,7 +33,7 @@ a 2D array (shape [features, rows]) and attempts
to model a 1D array (shape [rows])
using analytic functional forms.

Run distributed on four processes with:
Run with:
```julia
using SymbolicRegression

Expand All @@ -46,7 +46,10 @@ options = SymbolicRegression.Options(
npopulations=20
)

hall_of_fame = EquationSearch(X, y, niterations=40, options=options, numprocs=4)
hall_of_fame = EquationSearch(
X, y, niterations=40, options=options,
parallelism=:multithreading
)
```
You can view the resultant equations in the dominating Pareto front (best expression
seen at each complexity) with:
Expand Down
4 changes: 3 additions & 1 deletion example.jl
Expand Up @@ -7,7 +7,9 @@ options = SymbolicRegression.Options(;
binary_operators=(+, *, /, -), unary_operators=(cos, exp), npopulations=20
)

hall_of_fame = EquationSearch(X, y; niterations=40, options=options, numprocs=4)
hall_of_fame = EquationSearch(
X, y; niterations=40, options=options, parallelism=:multithreading
)

dominating = calculate_pareto_frontier(X, y, hall_of_fame, options)

Expand Down
5 changes: 3 additions & 2 deletions src/CheckConstraints.jl
@@ -1,8 +1,9 @@
module CheckConstraintsModule

import DynamicExpressions: Node
import ..UtilsModule: vals
import ..CoreModule: Node, Options
import ..EquationUtilsModule: compute_complexity
import ..CoreModule: Options
import ..ComplexityModule: compute_complexity

# Check if any binary operator are overly complex
function flag_bin_operator_complexity(
Expand Down
44 changes: 44 additions & 0 deletions src/Complexity.jl
@@ -0,0 +1,44 @@
module ComplexityModule

import DynamicExpressions: Node, count_nodes
import ..CoreModule: Options

"""
Compute the complexity of a tree.

By default, this is the number of nodes in a tree.
However, it could use the custom settings in options.complexity_mapping
if these are defined.
"""
function compute_complexity(tree::Node, options::Options)::Int
if options.complexity_mapping.use
return round(Int, _compute_complexity(tree, options))
else
return count_nodes(tree)
end
end

function _compute_complexity(
tree::Node, options::Options{C,complexity_type}
)::complexity_type where {C,complexity_type<:Real}
if tree.degree == 0
if tree.constant
return options.complexity_mapping.constant_complexity
else
return options.complexity_mapping.variable_complexity
end
elseif tree.degree == 1
return (
options.complexity_mapping.unaop_complexities[tree.op] +
_compute_complexity(tree.l, options)
)
else # tree.degree == 2
return (
options.complexity_mapping.binop_complexities[tree.op] +
_compute_complexity(tree.l, options) +
_compute_complexity(tree.r, options)
)
end
end

end
66 changes: 41 additions & 25 deletions src/Configure.jl
@@ -1,15 +1,17 @@
const TEST_TYPE = Float32

function assert_operators_defined_over_reals(T, options::Options)
test_input = map(x -> convert(T, x), LinRange(-100, 100, 99))
cur_op = nothing
try
for left in test_input
for right in test_input
for binop in options.binops
for binop in options.operators.binops
cur_op = binop
test_output = binop.(left, right)
end
end
for unaop in options.unaops
for unaop in options.operators.unaops
cur_op = unaop
test_output = unaop.(left)
end
Expand All @@ -25,7 +27,7 @@ end

# Check for errors before they happen
function test_option_configuration(T, options::Options)
for op in (options.binops..., options.unaops...)
for op in (options.operators.binops..., options.operators.unaops...)
if is_anonymous_function(op)
throw(
AssertionError(
Expand All @@ -37,14 +39,13 @@ function test_option_configuration(T, options::Options)

assert_operators_defined_over_reals(T, options)

for binop in options.binops
if binop in options.unaops
throw(
AssertionError(
"Your configuration is invalid - one operator ($binop) appears in both the binary operators and unary operators.",
),
)
end
operator_intersection = intersect(options.operators.binops, options.operators.unaops)
if length(operator_intersection) > 0
throw(
AssertionError(
"Your configuration is invalid - $(operator_intersection) appear in both the binary operators and unary operators.",
),
)
end
end

Expand Down Expand Up @@ -83,37 +84,52 @@ end

""" Move custom operators and loss functions to workers, if undefined """
function move_functions_to_workers(procs, options::Options, dataset::Dataset{T}) where {T}
for function_set in 1:6
if function_set == 1
ops = options.unaops
enable_autodiff =
:diff_binops in fieldnames(typeof(options.operators)) &&
:diff_unaops in fieldnames(typeof(options.operators)) &&
(
options.operators.diff_binops !== nothing ||
options.operators.diff_unaops !== nothing
)

# All the types of functions we need to move to workers:
function_sets = (
:unaops, :binops, :diff_unaops, :diff_binops, :loss, :early_stop_condition
)

for function_set in function_sets
if function_set == :unaops
ops = options.operators.unaops
nargs = 1
elseif function_set == 2
ops = options.binops
elseif function_set == :binops
ops = options.operators.binops
nargs = 2
elseif function_set == 3
if !options.enable_autodiff
elseif function_set == :diff_unaops
if !enable_autodiff
continue
end
ops = options.diff_unaops
ops = options.operators.diff_unaops
nargs = 1
elseif function_set == 4
if !options.enable_autodiff
elseif function_set == :diff_binops
if !enable_autodiff
continue
end
ops = options.diff_binops
ops = options.operators.diff_binops
nargs = 2
elseif function_set == 5
elseif function_set == :loss
if typeof(options.loss) <: SupervisedLoss
continue
end
ops = (options.loss,)
nargs = dataset.weighted ? 3 : 2
elseif function_set == 6
elseif function_set == :early_stop_condition
if !(typeof(options.earlyStopCondition) <: Function)
continue
end
ops = (options.earlyStopCondition,)
nargs = 2
else
error("Invalid function set: $function_set")
end
for op in ops
try
Expand Down Expand Up @@ -211,7 +227,7 @@ function test_module_on_workers(procs, options::Options)
for proc in procs
push!(
futures,
@spawnat proc SymbolicRegression.gen_random_tree(3, options, 5, CONST_TYPE)
@spawnat proc SymbolicRegression.gen_random_tree(3, options, 5, TEST_TYPE)
)
end
for future in futures
Expand Down
4 changes: 2 additions & 2 deletions src/ConstantOptimization.jl
Expand Up @@ -2,9 +2,9 @@ module ConstantOptimizationModule

using LineSearches: LineSearches
using Optim: Optim
import ..CoreModule: Node, Options, Dataset
import DynamicExpressions: Node, get_constants, set_constants, count_constants
import ..CoreModule: Options, Dataset
import ..UtilsModule: get_birth_order
import ..EquationUtilsModule: get_constants, set_constants, count_constants
import ..LossFunctionsModule: score_func, eval_loss
import ..PopMemberModule: PopMember

Expand Down
3 changes: 0 additions & 3 deletions src/Core.jl
Expand Up @@ -4,12 +4,10 @@ include("Utils.jl")
include("ProgramConstants.jl")
include("Dataset.jl")
include("OptionsStruct.jl")
include("Equation.jl")
include("Operators.jl")
include("Options.jl")

import .ProgramConstantsModule:
CONST_TYPE,
MAX_DEGREE,
BATCH_DIM,
FEATURE_DIM,
Expand All @@ -20,7 +18,6 @@ import .ProgramConstantsModule:
SRDistributed
import .DatasetModule: Dataset
import .OptionsStructModule: Options
import .EquationModule: Node, copy_node, set_node!, string_tree, print_tree
import .OptionsModule: Options
import .OperatorsModule:
plus,
Expand Down