From 41a6e2923d5f7a26cdd1f19fd2d17ad88d907e87 Mon Sep 17 00:00:00 2001 From: pdeffebach <23196228+pdeffebach@users.noreply.github.com> Date: Thu, 6 Jan 2022 14:47:00 -0500 Subject: [PATCH] Composition of nested functions (#319) * transform nested functions * add checking method * test with col * fix macro hygiene * add tests for correctness * simplify logic of make_composed * Apply suggestions from code review Co-authored-by: Milan Bouchet-Valat Co-authored-by: Milan Bouchet-Valat --- src/parsing.jl | 47 ++++++++++++++++++++++++++++++++++++ test/function_compilation.jl | 40 ++++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+) diff --git a/src/parsing.jl b/src/parsing.jl index 768b7889..2a5adbd0 100644 --- a/src/parsing.jl +++ b/src/parsing.jl @@ -67,6 +67,48 @@ function composed_or_symbol(x::Expr) all(composed_or_symbol, x.args[2:end]) end +is_call(x) = false +is_call(x::Expr) = x.head === :call + +is_nested_fun(x) = false +function is_nested_fun(x::Expr) + x.head === :call && + length(x.args) == 2 && + is_call(x.args[2]) && + # AsTable(:x) or `$(:x)` + return get_column_expr(x.args[2]) === nothing +end + +is_nested_fun_recursive(x, nested_once) = false +function is_nested_fun_recursive(x::Expr, nested_once) + if is_nested_fun(x) + return is_nested_fun_recursive(x.args[2], true) + elseif is_simple_non_broadcast_call(x) + return nested_once + else + return false + end +end +make_composed(x) = x +function make_composed(x::Expr) + funs = Any[] + x_orig = x + nested_once = false + while true + if is_nested_fun(x) + push!(funs, x.args[1]) + x = x.args[2] + nested_once = true + elseif is_simple_non_broadcast_call(x) && nested_once + push!(funs, x.args[1]) + # ∘(f, g, h)(:x, :y, :z) + return Expr(:call, Expr(:call, ∘, funs...), x.args[2:end]...) + else + throw(ArgumentError("Not eligible for function composition")) + end + end +end + is_simple_non_broadcast_call(x) = false function is_simple_non_broadcast_call(expr::Expr) expr.head == :call && @@ -214,6 +256,11 @@ function get_source_fun(function_expr; exprflags = deepcopy(DEFAULT_FLAGS)) source = args_to_selectors(function_expr.args[2].args) fun_t = function_expr.args[1] fun = :(DataFrames.ByRow($fun_t)) + elseif is_nested_fun_recursive(function_expr, false) + composed_expr = make_composed(function_expr) + # Repeat clean up from simple non-broadcast above + source = args_to_selectors(composed_expr.args[2:end]) + fun = composed_expr.args[1] else membernames = Dict{Any, Symbol}() diff --git a/test/function_compilation.jl b/test/function_compilation.jl index b6d60143..471f07f9 100644 --- a/test/function_compilation.jl +++ b/test/function_compilation.jl @@ -218,11 +218,51 @@ end slowtime = @timed select(df_wide, AsTable(:) => ByRow(t -> (sum ∘ skipmissing)(t)) => :y) (slowtime[2] > fasttime[2]) || @warn("Slow compilation") + + @test @select(df, :y = f(g(:a, :b))).y == [3] + + fasttime = @timed @select(df, :y = f(g(:a, :b))) + slowtime = @timed select(df, [:a, :b] => ((a, b) -> f(g(a, b))) => :y ) + (slowtime[2] > fasttime[2]) || @warn("Slow compilation") + + fasttime = @timed @rselect df_wide :y = sum(skipmissing(AsTable(:))) + slowtime = @timed select(df_wide, AsTable(:) => ByRow(t -> sum(skipmissing(t))) => :y) + + (slowtime[2] > fasttime[2]) || @warn("Slow compilation") end end + # Tests for correctness + t = @select df :y = sum(skipmissing(AsTable([:a, :b]))) + @test t == DataFrame(y = [3]) + + t = @select df :y = f(g(:a, :b)) + @test t == DataFrame(y = [3]) + + a_str = "a" + b_str = "b" + + t = @select df :y = f(g($a_str, $(b_str))) + @test t == DataFrame(y = [3]) + + t = @select df :y = f(g(:a, $b_str)) + @test t == DataFrame(y = [3]) + + t = @select df :y = sum(skipmissing(AsTable(Cols(:)))) + @test t == DataFrame(y = [3]) + + t = @select df :y = sum(skipmissing(AsTable(Cols(:)))) + @test t == DataFrame(y = [3]) + t = DataFramesMeta.@col :y = (identity ∘ first)(:x) @test t == ([:x] => (identity ∘ first => :y)) + + t = DataFramesMeta.@col :y = identity(first(:x)) + @test t == ([:x] => (identity ∘ first => :y)) + + t = DataFramesMeta.@col :y = identity(first(AsTable(Cols(:)))) + + @test t == (AsTable(Cols(:)) => ((identity ∘ first) => :y)) end end # module \ No newline at end of file