Skip to content

Commit

Permalink
Composition of nested functions (#319)
Browse files Browse the repository at this point in the history
* transform nested functions

* add checking method

* test with col

* fix macro hygiene

* add tests for correctness

* simplify logic of make_composed

* Apply suggestions from code review

Co-authored-by: Milan Bouchet-Valat <nalimilan@club.fr>

Co-authored-by: Milan Bouchet-Valat <nalimilan@club.fr>
  • Loading branch information
pdeffebach and nalimilan committed Jan 6, 2022
1 parent 3ecdb78 commit 41a6e29
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 0 deletions.
47 changes: 47 additions & 0 deletions src/parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,48 @@ function composed_or_symbol(x::Expr)
all(composed_or_symbol, x.args[2:end])
end

is_call(x) = false
is_call(x::Expr) = x.head === :call

is_nested_fun(x) = false
function is_nested_fun(x::Expr)
x.head === :call &&
length(x.args) == 2 &&
is_call(x.args[2]) &&
# AsTable(:x) or `$(:x)`
return get_column_expr(x.args[2]) === nothing
end

is_nested_fun_recursive(x, nested_once) = false
function is_nested_fun_recursive(x::Expr, nested_once)
if is_nested_fun(x)
return is_nested_fun_recursive(x.args[2], true)
elseif is_simple_non_broadcast_call(x)
return nested_once
else
return false
end
end
make_composed(x) = x
function make_composed(x::Expr)
funs = Any[]
x_orig = x
nested_once = false
while true
if is_nested_fun(x)
push!(funs, x.args[1])
x = x.args[2]
nested_once = true
elseif is_simple_non_broadcast_call(x) && nested_once
push!(funs, x.args[1])
# ∘(f, g, h)(:x, :y, :z)
return Expr(:call, Expr(:call, , funs...), x.args[2:end]...)
else
throw(ArgumentError("Not eligible for function composition"))
end
end
end

is_simple_non_broadcast_call(x) = false
function is_simple_non_broadcast_call(expr::Expr)
expr.head == :call &&
Expand Down Expand Up @@ -214,6 +256,11 @@ function get_source_fun(function_expr; exprflags = deepcopy(DEFAULT_FLAGS))
source = args_to_selectors(function_expr.args[2].args)
fun_t = function_expr.args[1]
fun = :(DataFrames.ByRow($fun_t))
elseif is_nested_fun_recursive(function_expr, false)
composed_expr = make_composed(function_expr)
# Repeat clean up from simple non-broadcast above
source = args_to_selectors(composed_expr.args[2:end])
fun = composed_expr.args[1]
else
membernames = Dict{Any, Symbol}()

Expand Down
40 changes: 40 additions & 0 deletions test/function_compilation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,11 +218,51 @@ end
slowtime = @timed select(df_wide, AsTable(:) => ByRow(t -> (sum skipmissing)(t)) => :y)

(slowtime[2] > fasttime[2]) || @warn("Slow compilation")

@test @select(df, :y = f(g(:a, :b))).y == [3]

fasttime = @timed @select(df, :y = f(g(:a, :b)))
slowtime = @timed select(df, [:a, :b] => ((a, b) -> f(g(a, b))) => :y )
(slowtime[2] > fasttime[2]) || @warn("Slow compilation")

fasttime = @timed @rselect df_wide :y = sum(skipmissing(AsTable(:)))
slowtime = @timed select(df_wide, AsTable(:) => ByRow(t -> sum(skipmissing(t))) => :y)

(slowtime[2] > fasttime[2]) || @warn("Slow compilation")
end
end

# Tests for correctness
t = @select df :y = sum(skipmissing(AsTable([:a, :b])))
@test t == DataFrame(y = [3])

t = @select df :y = f(g(:a, :b))
@test t == DataFrame(y = [3])

a_str = "a"
b_str = "b"

t = @select df :y = f(g($a_str, $(b_str)))
@test t == DataFrame(y = [3])

t = @select df :y = f(g(:a, $b_str))
@test t == DataFrame(y = [3])

t = @select df :y = sum(skipmissing(AsTable(Cols(:))))
@test t == DataFrame(y = [3])

t = @select df :y = sum(skipmissing(AsTable(Cols(:))))
@test t == DataFrame(y = [3])

t = DataFramesMeta.@col :y = (identity first)(:x)
@test t == ([:x] => (identity first => :y))

t = DataFramesMeta.@col :y = identity(first(:x))
@test t == ([:x] => (identity first => :y))

t = DataFramesMeta.@col :y = identity(first(AsTable(Cols(:))))

@test t == (AsTable(Cols(:)) => ((identity first) => :y))
end

end # module

0 comments on commit 41a6e29

Please sign in to comment.