Skip to content

Commit

Permalink
Merge e45d1ba into c744cb3
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Sep 13, 2022
2 parents c744cb3 + e45d1ba commit e5c4a23
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 21 deletions.
1 change: 1 addition & 0 deletions docs/src/types.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ HallOfFame(options::Options, ::Type{T}) where {T<:Real}
## Dataset

```@docs
Dataset{T<:Real}
Dataset(X::AbstractMatrix{T},
y::AbstractVector{T};
weights::Union{AbstractVector{T}, Nothing}=nothing,
Expand Down
15 changes: 15 additions & 0 deletions src/Dataset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,21 @@ module DatasetModule

import ..ProgramConstantsModule: BATCH_DIM, FEATURE_DIM

"""
Dataset{T<:Real}
# Fields
- `X::AbstractMatrix{T}`: The input features, with shape `(nfeatures, n)`.
- `y::AbstractVector{T}`: The desired output values, with shape `(n,)`.
- `n::Int`: The number of samples.
- `nfeatures::Int`: The number of features.
- `weighted::Bool`: Whether the dataset is non-uniformly weighted.
- `weights::Union{AbstractVector{T},Nothing}`: If the dataset is weighted,
these specify the per-sample weight (with shape `(n,)`).
- `varMap::Array{String,1}`: The names of the features,
with shape `(nfeatures,)`.
"""
struct Dataset{T<:Real}
X::AbstractMatrix{T}
y::AbstractVector{T}
Expand Down
16 changes: 16 additions & 0 deletions src/EquationUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,22 @@ function count_constants(tree::Node)::Int
end
end

"""
is_constant(tree::Node)::Bool
Check if an expression is a constant numerical value, or
whether it depends on input features.
"""
function is_constant(tree::Node)::Bool
if tree.degree == 0
return tree.constant
elseif tree.degree == 1
return is_constant(tree.l)
else
return is_constant(tree.l) && is_constant(tree.r)
end
end

"""
Compute the complexity of a tree.
Expand Down
88 changes: 67 additions & 21 deletions src/EvaluateEquation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module EvaluateEquationModule

import ..CoreModule: Node, Options
import ..UtilsModule: @return_on_false, is_bad_array, debug
import ..EquationUtilsModule: is_constant

macro return_on_check(val, T, n)
# This will generate the following code:
Expand Down Expand Up @@ -78,38 +79,40 @@ end
function _eval_tree_array(
tree::Node{T}, cX::AbstractMatrix{T}, options::Options
)::Tuple{AbstractVector{T},Bool} where {T<:Real}
# First, we see if there are only constants in the tree - meaning
# we can just return the constant result.
if tree.degree == 0
deg0_eval(tree, cX, options)
return deg0_eval(tree, cX, options)
elseif is_constant(tree)
# Speed hack for constant trees.
result, flag = _eval_constant_tree(tree, options)
!flag && return Array{T,1}(undef, size(cX, 2)), false
return fill(result, size(cX, 2)), true
elseif tree.degree == 1
# TODO: We could all do Val(tree.l.degree) here, instead of having
# different kernels for const vs data.

# We fuse (and compile) the following:
# - op(op2(x, y)), where x, y, z are constants or variables.
# - op(op2(x)), where x is a constant or variable.
# - op(x), for any x.
if tree.l.degree == 2 && tree.l.l.degree == 0 && tree.l.r.degree == 0
deg1_l2_ll0_lr0_eval(tree, cX, Val(tree.op), Val(tree.l.op), options)
# op(op2(x, y)), where x, y, z are constants or variables.
return deg1_l2_ll0_lr0_eval(tree, cX, Val(tree.op), Val(tree.l.op), options)
elseif tree.l.degree == 1 && tree.l.l.degree == 0
deg1_l1_ll0_eval(tree, cX, Val(tree.op), Val(tree.l.op), options)
# op(op2(x)), where x is a constant or variable.
return deg1_l1_ll0_eval(tree, cX, Val(tree.op), Val(tree.l.op), options)
else
deg1_eval(tree, cX, Val(tree.op), options)
# op(x), for any x.
return deg1_eval(tree, cX, Val(tree.op), options)
end
else
# We fuse (and compile) the following:
# - op(x, y), where x, y are constants or variables.
# - op(x, y), where x is a constant or variable but y is not.
# - op(x, y), where y is a constant or variable but x is not.
# - op(x, y), for any x or y
elseif tree.degree == 2
# TODO - add op(op2(x, y), z) and op(x, op2(y, z))
if tree.l.degree == 0 && tree.r.degree == 0
deg2_l0_r0_eval(tree, cX, Val(tree.op), options)
# op(x, y), where x, y are constants or variables.
return deg2_l0_r0_eval(tree, cX, Val(tree.op), options)
elseif tree.l.degree == 0
deg2_l0_eval(tree, cX, Val(tree.op), options)
# op(x, y), where x is a constant or variable but y is not.
return deg2_l0_eval(tree, cX, Val(tree.op), options)
elseif tree.r.degree == 0
deg2_r0_eval(tree, cX, Val(tree.op), options)
# op(x, y), where y is a constant or variable but x is not.
return deg2_r0_eval(tree, cX, Val(tree.op), options)
else
deg2_eval(tree, cX, Val(tree.op), options)
# op(x, y), for any x or y
return deg2_eval(tree, cX, Val(tree.op), options)
end
end
end
Expand Down Expand Up @@ -332,6 +335,49 @@ function deg2_r0_eval(
return (cumulator, true)
end

"""
_eval_constant_tree(tree::Node{T}, options::Options)::Tuple{T,Bool} where {T<:Real}
Evaluate a tree which is assumed to not contain any variable nodes. This
gives better performance, as we do not need to perform computation
over an entire array when the values are all the same.
"""
function _eval_constant_tree(tree::Node{T}, options::Options)::Tuple{T,Bool} where {T<:Real}
if tree.degree == 0
return deg0_eval_constant(tree)
elseif tree.degree == 1
return deg1_eval_constant(tree, Val(tree.op), options)
else
return deg2_eval_constant(tree, Val(tree.op), options)
end
end

@inline function deg0_eval_constant(tree::Node{T})::Tuple{T,Bool} where {T<:Real}
return tree.val, true
end

function deg1_eval_constant(
tree::Node{T}, ::Val{op_idx}, options::Options
)::Tuple{T,Bool} where {T<:Real,op_idx}
op = options.unaops[op_idx]
(cumulator, complete) = _eval_constant_tree(tree.l, options)
!complete && return zero(T), false
output = op(cumulator)::T
return output, isfinite(output)
end

function deg2_eval_constant(
tree::Node{T}, ::Val{op_idx}, options::Options
)::Tuple{T,Bool} where {T<:Real,op_idx}
op = options.binops[op_idx]
(cumulator, complete) = _eval_constant_tree(tree.l, options)
!complete && return zero(T), false
(cumulator2, complete2) = _eval_constant_tree(tree.r, options)
!complete2 && return zero(T), false
output = op(cumulator, cumulator2)::T
return output, isfinite(output)
end

# Evaluate an equation over an array of datapoints
# This one is just for reference. The fused one should be faster.
function differentiable_eval_tree_array(
Expand Down

0 comments on commit e5c4a23

Please sign in to comment.