Skip to content

Commit

Permalink
Merge aa7d08b into 9857863
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Nov 28, 2022
2 parents 9857863 + aa7d08b commit b0368c8
Show file tree
Hide file tree
Showing 18 changed files with 184 additions and 102 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ build
*.csv.out*
pysr_recorder.json
docs/src/index.md
*.code-workspace
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SymbolicRegression"
uuid = "8254be44-1295-4e6a-a16d-46603ac705cb"
authors = ["MilesCranmer <miles.cranmer@gmail.com>"]
version = "0.14.4"
version = "0.14.5"

[deps]
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Expand All @@ -15,20 +15,22 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressBars = "49802e3a-d2f1-5c88-81d8-b72133a6f568"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SnoopPrecompile = "66db9d55-30c0-4569-8b51-7e840670fc0c"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"

[compat]
DynamicExpressions = "0.4"
DynamicExpressions = "0.4.2"
JSON3 = "1"
LineSearches = "7"
LossFunctions = "0.6, 0.7, 0.8"
Optim = "0.19, 1.1"
Pkg = "1"
ProgressBars = "1.4"
Reexport = "1"
SnoopPrecompile = "1"
SpecialFunctions = "0.10.1, 1, 2"
StatsBase = "0.33"
SymbolicUtils = "0.19"
Expand Down
34 changes: 18 additions & 16 deletions example.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,29 @@ X = randn(Float32, 5, 100)
y = 2 * cos.(X[4, :]) + X[1, :] .^ 2 .- 2

options = SymbolicRegression.Options(;
binary_operators=(+, *, /, -), unary_operators=(cos, exp), npopulations=20
binary_operators=(+, *, /, -),
unary_operators=(cos, exp),
npopulations=20,
seed=0,
deterministic=true,
)

hall_of_fame = EquationSearch(
X, y; niterations=40, options=options, parallelism=:multithreading
)
hall_of_fame = EquationSearch(X, y; niterations=80, options=options, parallelism=:serial)

dominating = calculate_pareto_frontier(X, y, hall_of_fame, options)
# dominating = calculate_pareto_frontier(X, y, hall_of_fame, options)

trees = [member.tree for member in dominating]
# trees = [member.tree for member in dominating]

tree = trees[end]
output, did_succeed = eval_tree_array(tree, X, options)
# tree = trees[end]
# output, did_succeed = eval_tree_array(tree, X, options)

eqn = node_to_symbolic(dominating[end].tree, options)
println("Complexity\tMSE\tEquation")
# eqn = node_to_symbolic(dominating[end].tree, options)
# println("Complexity\tMSE\tEquation")

for member in dominating
complexity = compute_complexity(member.tree, options)
loss = member.loss
string = string_tree(member.tree, options)
# for member in dominating
# complexity = compute_complexity(member.tree, options)
# loss = member.loss
# string = string_tree(member.tree, options)

println("$(complexity)\t$(loss)\t$(string)")
end
# println("$(complexity)\t$(loss)\t$(string)")
# end
2 changes: 1 addition & 1 deletion src/AdaptiveParsimony.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ for an equation at size `size`.
@inline function update_frequencies!(
running_search_statistics::RunningSearchStatistics; size=nothing
)
if size <= length(running_search_statistics.frequencies)
if 0 < size <= length(running_search_statistics.frequencies)
running_search_statistics.frequencies[size] += 1
end
return nothing
Expand Down
3 changes: 2 additions & 1 deletion src/CheckConstraints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,8 @@ end

"""Check if user-passed constraints are violated or not"""
function check_constraints(tree::Node, options::Options, maxsize::Int)::Bool
if compute_complexity(tree, options) > maxsize
# Reject any tree whose complexity exceeds the size budget.
# NOTE: the chained comparison `0 > size > maxsize` is evaluated as
# `0 > size && size > maxsize`, which can never be true for a non-negative
# `maxsize`, so it silently disabled the size limit. Use the plain upper-bound
# check instead.
size = compute_complexity(tree, options)
if size > maxsize
    return false
end
for i in 1:(options.nbin)
Expand Down
4 changes: 1 addition & 3 deletions src/Complexity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@ function compute_complexity(tree::Node, options::Options)::Int
end
end

function _compute_complexity(
tree::Node, options::Options{C,complexity_type}
)::complexity_type where {C,complexity_type<:Real}
function _compute_complexity(tree::Node, options::Options{CT})::CT where {CT<:Real}
if tree.degree == 0
if tree.constant
return options.complexity_mapping.constant_complexity
Expand Down
10 changes: 1 addition & 9 deletions src/Core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,7 @@ include("OptionsStruct.jl")
include("Operators.jl")
include("Options.jl")

import .ProgramConstantsModule:
MAX_DEGREE,
BATCH_DIM,
FEATURE_DIM,
RecordType,
SRConcurrency,
SRSerial,
SRThreaded,
SRDistributed
import .ProgramConstantsModule: MAX_DEGREE, BATCH_DIM, FEATURE_DIM, RecordType
import .DatasetModule: Dataset
import .OptionsStructModule: Options, MutationWeights, sample_mutation
import .OptionsModule: Options
Expand Down
4 changes: 2 additions & 2 deletions src/Mutate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,12 +231,12 @@ function next_generation(
if options.use_frequency
oldSize = compute_complexity(prev, options)
newSize = compute_complexity(tree, options)
old_frequency = if (oldSize <= options.maxsize)
old_frequency = if (0 < oldSize <= options.maxsize)
running_search_statistics.normalized_frequencies[oldSize]
else
1e-6
end
new_frequency = if (newSize <= options.maxsize)
new_frequency = if (0 < newSize <= options.maxsize)
running_search_statistics.normalized_frequencies[newSize]
else
1e-6
Expand Down
9 changes: 3 additions & 6 deletions src/Options.jl
Original file line number Diff line number Diff line change
Expand Up @@ -600,17 +600,13 @@ function Options(;
end
end

options = Options{
typeof(loss),
eltype(complexity_mapping),
tournament_selection_p,
tournament_selection_n,
}(
options = Options{eltype(complexity_mapping)}(
operators,
bin_constraints,
una_constraints,
complexity_mapping,
tournament_selection_n,
tournament_selection_p,
parsimony,
alpha,
maxsize,
Expand Down Expand Up @@ -659,6 +655,7 @@ function Options(;
skip_mutation_failures,
nested_constraints,
deterministic,
define_helper_functions,
)

return options
Expand Down
8 changes: 5 additions & 3 deletions src/OptionsStruct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,13 @@ function ComplexityMapping(;
)
end

struct Options{LossType<:Union{SupervisedLoss,Function},ComplexityType,_prob_pick_first,_ns}
struct Options{CT}
operators::AbstractOperatorEnum
bin_constraints::Vector{Tuple{Int,Int}}
una_constraints::Vector{Int}
complexity_mapping::ComplexityMapping{ComplexityType}
complexity_mapping::ComplexityMapping{CT}
tournament_selection_n::Int
tournament_selection_p::Float32
parsimony::Float32
alpha::Float32
maxsize::Int
Expand Down Expand Up @@ -140,7 +141,7 @@ struct Options{LossType<:Union{SupervisedLoss,Function},ComplexityType,_prob_pic
nuna::Int
nbin::Int
seed::Union{Int,Nothing}
loss::LossType
loss::Union{SupervisedLoss,Function}
progress::Bool
terminal_width::Union{Int,Nothing}
optimizer_algorithm::String
Expand All @@ -157,6 +158,7 @@ struct Options{LossType<:Union{SupervisedLoss,Function},ComplexityType,_prob_pic
skip_mutation_failures::Bool
nested_constraints::Union{Vector{Tuple{Int,Int,Vector{Tuple{Int,Int,Int}}}},Nothing}
deterministic::Bool
define_helper_functions::Bool
end

function Base.print(io::IO, options::Options)
Expand Down
9 changes: 6 additions & 3 deletions src/Population.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,13 @@ end
function best_of_sample(
pop::Population{T},
running_search_statistics::RunningSearchStatistics,
options::Options{A,B,p,tournament_selection_n},
)::PopMember where {T<:Real,A,B,p,tournament_selection_n}
options::Options{CT},
)::PopMember where {T<:Real,CT}
sample = sample_pop(pop, options)

p = options.tournament_selection_p
tournament_selection_n = options.tournament_selection_n

if options.use_frequency_in_tournament
# Score based on frequency of that size occurring.
# In the end, all sizes should be just as common in the population.
Expand All @@ -92,7 +95,7 @@ function best_of_sample(
scores = Vector{T}(undef, tournament_selection_n)
for (i, member) in enumerate(sample.members)
size = compute_complexity(member.tree, options)
frequency = if (size <= options.maxsize)
frequency = if (0 < size <= options.maxsize)
running_search_statistics.normalized_frequencies[size]
else
T(0)
Expand Down
6 changes: 0 additions & 6 deletions src/ProgramConstants.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,4 @@ const BATCH_DIM = 2
const FEATURE_DIM = 1
const RecordType = Dict{String,Any}

"""Enum for concurrency type (to get function specialization)"""
abstract type SRConcurrency end
struct SRSerial <: SRConcurrency end
struct SRThreaded <: SRConcurrency end
struct SRDistributed <: SRConcurrency end

end
20 changes: 10 additions & 10 deletions src/SearchUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import Printf: @printf, @sprintf
using Distributed
import StatsBase: mean

import ..CoreModule: SRThreaded, SRSerial, SRDistributed, Dataset, Options
import ..CoreModule: Dataset, Options
import ..ComplexityModule: compute_complexity
import ..PopulationModule: Population, copy_population
import ..HallOfFameModule:
Expand All @@ -32,11 +32,11 @@ end

macro sr_spawner(parallel, p, expr)
quote
if $(esc(parallel)) == SRSerial
if $(esc(parallel)) == :serial
$(esc(expr))
elseif $(esc(parallel)) == SRDistributed
elseif $(esc(parallel)) == :multiprocessing
@spawnat($(esc(p)), $(esc(expr)))
elseif $(esc(parallel)) == SRThreaded
elseif $(esc(parallel)) == :multithreading
Threads.@spawn($(esc(expr)))
else
error("Invalid parallel type.")
Expand Down Expand Up @@ -197,8 +197,8 @@ function estimate_work_fraction(monitor::ResourceMonitor)::Float64
return mean(work_intervals) / (mean(work_intervals) + mean(rest_intervals))
end

function get_load_string(; head_node_occupation::Float64, ConcurrencyType=SRSerial)
ConcurrencyType == SRSerial && return ""
function get_load_string(; head_node_occupation::Float64, parallelism=:serial)
parallelism == :serial && return ""
out = @sprintf("Head worker occupation: %.1f%%", head_node_occupation * 100)

raise_usage_warning = head_node_occupation > 0.2
Expand All @@ -218,11 +218,11 @@ function update_progress_bar!(
dataset::Dataset{T},
options::Options,
head_node_occupation::Float64,
ConcurrencyType=SRSerial,
parallelism=:serial,
) where {T}
equation_strings = string_dominating_pareto_curve(hall_of_fame, dataset, options)
# TODO - include command about "q" here.
load_string = get_load_string(; head_node_occupation, ConcurrencyType)
load_string = get_load_string(; head_node_occupation, parallelism)
load_string *= @sprintf("Press 'q' and then <enter> to stop execution early.\n")
equation_strings = load_string * equation_strings
set_multiline_postfix!(progress_bar, equation_strings)
Expand All @@ -238,14 +238,14 @@ function print_search_state(
total_cycles::Int,
cycles_remaining::Vector{Int},
head_node_occupation::Float64,
ConcurrencyType=SRSerial,
parallelism=:serial,
) where {T}
nout = length(datasets)
average_speed = sum(equation_speed) / length(equation_speed)

@printf("\n")
@printf("Cycles per second: %.3e\n", round(average_speed, sigdigits=3))
load_string = get_load_string(; head_node_occupation, ConcurrencyType)
load_string = get_load_string(; head_node_occupation, parallelism)
print(load_string)
cycles_elapsed = total_cycles * nout - sum(cycles_remaining)
@printf(
Expand Down
2 changes: 1 addition & 1 deletion src/SingleIteration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ function s_r_cycle(
for member in pop.members
size = compute_complexity(member.tree, options)
score = member.score
if size <= options.maxsize && (
if 0 < size <= options.maxsize && (
!best_examples_seen.exists[size] ||
score < best_examples_seen.members[size].score
)
Expand Down

0 comments on commit b0368c8

Please sign in to comment.