FunctionTerm is dead, long live FunctionTerm #183

Merged on Jan 24, 2023 (55 commits)

@kleinschmidt (Member) commented Jul 11, 2020

This is a pretty substantial change in how non-special function calls are
represented. Instead of generating an anonymous function, the new
FunctionTerm simply wraps the called function, its arguments (wrapped as
Terms), and the original expression. When evaluated with modelcols, it uses
Base.Broadcast.broadcasted to lazily fuse nested function calls and then, at
the top level, calls materialize to run them. I hope that this will provide
performance that's comparable to the anonymous function, while being both more
run-time friendly and much simpler.
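
A minimal sketch of the idea, using made-up toy types (ToyTerm, ToyFunctionTerm, lazy_cols and toy_modelcols are hypothetical names for illustration, not the StatsModels definitions): nested calls only build lazy Broadcasted objects, and the top level materializes the fused result.

```julia
# Hypothetical toy types; these are NOT the StatsModels definitions.
struct ToyTerm
    sym::Symbol
end

struct ToyFunctionTerm{F,Args}
    f::F            # the called function
    args::Args      # its arguments, wrapped as terms (or plain numbers)
    exorig::Expr    # the original expression, kept for display
end

# Leaf term: pull the column out of a NamedTuple of vectors.
lazy_cols(t::ToyTerm, data) = getproperty(data, t.sym)
# Numeric literals pass through unchanged.
lazy_cols(x::Number, data) = x
# Nested call: build a lazy Broadcasted object without running anything yet.
lazy_cols(t::ToyFunctionTerm, data) =
    Base.Broadcast.broadcasted(t.f, map(a -> lazy_cols(a, data), t.args)...)

# Top level: materialize, which runs the whole nested call as one fused loop.
toy_modelcols(t::ToyFunctionTerm, data) =
    Base.Broadcast.materialize(lazy_cols(t, data))

data = (a = [1.0, 2.0, 3.0], b = [4.0, 5.0, 6.0])
inner = ToyFunctionTerm(+, (ToyTerm(:a), 1), :(a + 1))
outer = ToyFunctionTerm(log, (inner,), :(log(a + 1)))
cols = toy_modelcols(outer, data)   # same values as log.(data.a .+ 1)
```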

As a side effect, all of the macro-time parsing has been removed, and all the
special syntax now applies at run time via method overloading of +, &, ~,
and * (I know, I'm sorry, this is just for continuity's sake right now).

Also, inspired by (that is, shamelessly stealing from) @oxinabox's suggestions for how to
handle nested special and non-special functions (#117), I've added some
additional special syntax: protect and unprotect. protect says to treat
every call below it as non-special, while unprotect (in an otherwise protected
context) says to treat everything below it as potentially special (e.g., as if
it occurred at the top level of a formula). This isn't perfect at the moment,
but it goes a long way towards being able to write something like 1 - unprotect(poly(x, 3))
and have it do something sensible. (Actually, that might work; what doesn't work
is 1 - unprotect(a*b), since that will generate a + b + a&b, which doesn't
get fused into a single matrix, and that messes with the broadcasting.)
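
The protect/unprotect rule can be sketched as a walk over an expression. This is a hypothetical toy (classify, classify! and SPECIALS are made-up names, and it operates on Expr objects rather than on the PR's term types). It also assumes that any non-special call implicitly protects everything beneath it, which is how the 1 - unprotect(...) example above reads.

```julia
# Hypothetical sketch, not the PR's implementation.
const SPECIALS = (:+, :&, :*, :~)

# Record, for each call in `ex`, whether it would be treated as special
# formula syntax (true) or as an ordinary function call (false).
function classify!(out, ex, protected::Bool)
    ex isa Expr && ex.head === :call || return out
    f = ex.args[1]
    if f === :protect
        # Everything below protect is non-special.
        foreach(a -> classify!(out, a, true), ex.args[2:end])
    elseif f === :unprotect
        # Everything below unprotect is potentially special again.
        foreach(a -> classify!(out, a, false), ex.args[2:end])
    else
        special = !protected && f in SPECIALS
        push!(out, f => special)
        # Assumption: a non-special call protects its arguments.
        foreach(a -> classify!(out, a, !special), ex.args[2:end])
    end
    return out
end

classify(ex) = classify!(Pair{Any,Bool}[], ex, false)

plain = classify(:(a + log(b + c)))            # outer + special, inner + not
guarded = classify(:(protect(a + b)))          # + is not special
reopened = classify(:(1 - unprotect(a * b)))   # - not special, * special
```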

I also added on a whim an @support_unprotect op macro which will designate a function as
special in unprotected contexts. This simply adds a method like

apply_schema(t::FunctionTerm{typeof($op)}, sch::Schema, Mod::Type) =
    apply_schema($op(t.args...), sch, Mod)

So, if you have something like FunctionTerm{typeof(+)} in a non-protected
context, then it will be converted into a call to +(args[1], args[2], ...)
when you do apply_schema. This is taking advantage of the fact that all the
special handling of special syntax is done via method overloading of the
corresponding functions, so it should work fine at runtime. At the moment it's
basically only useful for internal purposes but a two-argument form (function
and context type) might be useful for packages.

Tests do pass with this change, but it probably needs more tests to specifically
cover the protect/unprotect stuff. Docs aren't updated because that seems
premature at this point before folks have had a chance to weigh in.

@kleinschmidt (Member Author) commented Jul 11, 2020

The major problem with this proposal, now that I've messed with it a bit, is that these kinds of methods really hit the compiler hard:

# + concatenates terms
Base.:+(terms::AbstractTerm...) = (unique(reduce(+, terms))..., )
Base.:+(a::AbstractTerm) = a
Base.:+(a::AbstractTerm, b::AbstractTerm) = (a,b)

# associative rule for +
Base.:+(as::TupleTerm, b::AbstractTerm) = (as..., b)
Base.:+(a::AbstractTerm, bs::TupleTerm) = (a, bs...)
Base.:+(as::TupleTerm, bs::TupleTerm) = (as..., bs...)

(TupleTerm is just NTuple{N, AbstractTerm} where N.)

The reason is that a new method has to be compiled for every combination of number and types of terms being added together. This is related to #165: all the type parameters in the term types create a lot of compiler overhead. I can think of a few ways around this.

  1. we could go back to doing these transformations at parse time, which would require that FunctionTerm continue to hold onto two copies of the args (one parsed and the other not). It would also require some duplication of functionality if we want to continue to support run-time construction as a first class interface, since we still need all these methods for the same syntax to apply as inside the macro.

  2. we could use some other non-parametrized version of the +-ed terms, like a Vector, or a custom wrapper. This seems potentially the least disruptive (and could even convert back to a tuple if needed for performance at some point).

  3. we could represent ALL calls, even to special syntax like +, &, and *, as FunctionTerms, and only parse them at apply_schema time, or do some fanciness during construction. I think this is the most radical/involved and I'm not sure it's really worth it.
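
To make the contrast behind option 2 concrete, here is a small sketch with made-up toy types (ToyAbstractTerm, ToyTerm and ToyConstant are hypothetical). A tuple's concrete type encodes both its length and each element type, so every new combination is a new type to specialize on, while a Vector with an abstract element type has a single concrete type.

```julia
# Hypothetical toy types for illustration only.
abstract type ToyAbstractTerm end
struct ToyTerm <: ToyAbstractTerm
    sym::Symbol
end
struct ToyConstant <: ToyAbstractTerm
    n::Int
end

a, b, c = ToyTerm(:a), ToyTerm(:b), ToyTerm(:c)
k = ToyConstant(1)

# Tuple storage: length and element types are part of the type.
tuple_types = (typeof((a, b)), typeof((a, b, c)), typeof((k, a, b)))

# Vector storage: one concrete type regardless of length or contents.
vector_types = (typeof(ToyAbstractTerm[a, b]),
                typeof(ToyAbstractTerm[a, b, c]),
                typeof(ToyAbstractTerm[k, a, b]))

n_tuple_types = length(unique(tuple_types))
n_vector_types = length(unique(vector_types))
```

This only shows the type-level difference; as the follow-up comments note, a Vararg method can still specialize on the number of arguments.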

Edit: I did try wrapping all of these in a @nospecialize block, but it didn't seem to have any effect. Then again, I don't really know what I'm doing, so...

@kleinschmidt (Member Author)

Now I'm really confused. I've actually gone through the exercise of replacing the tuple-based representation for +ed terms with a Vector-based storage (in this branch), and STILL get much longer first-call timings for every new number of terms, which suggests that some new methods are getting compiled, even though all the types are the same (no longer depend on the number of terms).

@nalimilan (Member)

Base.:+(terms::AbstractTerm...) is still going to be specialized on the number of arguments, right?

@kleinschmidt (Member Author) commented Jul 14, 2020

Yeah, I'm afraid so. Makes me wonder how stuff like this is handled in Base...

edit, it's easy enough to find out :)

https://github.com/JuliaLang/julia/blob/24f033c9517cd186448acc3bededa1a64eaad09f/base/operators.jl#L521-L543

So I think defining that method is actually not necessary, except that we also want to call unique on the output...
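
A sketch of that observation with a made-up toy type (ToyTerm and ToyTuple are hypothetical): only binary methods are defined, and Base's generic n-ary + turns +(a, b, c) into repeated binary calls. Deduplication is not part of that fallback, so it is shown as a separate step.

```julia
# Hypothetical toy type; not the StatsModels term types.
struct ToyTerm
    sym::Symbol
end
const ToyTuple = NTuple{N,ToyTerm} where N

# Only binary methods are defined here.
Base.:+(a::ToyTerm, b::ToyTerm) = (a, b)
Base.:+(as::ToyTuple, b::ToyTerm) = (as..., b)

a, b, c = ToyTerm(:a), ToyTerm(:b), ToyTerm(:c)

# a + b + c parses as +(a, b, c); Base's generic n-ary method folds it
# into +(+(a, b), c), so no Vararg method of our own is needed.
abc = a + b + c

# Removing duplicates has to happen separately.
dedup = Tuple(unique(a + b + a))
```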

@oxinabox (Contributor) commented Jul 25, 2020

Specialisation around splatting is weird in general, and it's not documented what it will do.
IIRC foo(x::Vararg) and foo(x...) specialize differently.

But in any case, if you don't want specialization, just use the @nospecialize macro,
which is applied either to a section of code via:

@nospecialize

foo(x)=1
bar(x)=2

@specialize

or for an argument:

foobar(x, @nospecialize(y)) = x + y

If @nospecialize isn't working, then that would be a bug in Base, and you should open an issue
rather than redo the whole design to work around the fact that you can't turn off specialization.
I have had @nospecialize work before.

@kleinschmidt (Member Author)

I'm more and more concerned that just putting a few @nospecialize around the base methods we're overloading isn't going to be enough, since every time you hit a constructor for any term that can have children (formula, interaction, matrix, function) you're going to trigger compilation for every number and type of children, and same thing for every other method involving those types. At that point you're going to have to put @nospecialize on just about everything that can take an AbstractTerm which seems like a sign that something has gone wrong.

At this point, my plan is to do a bit of benchmarking to see whether formula creation gets appreciably WORSE with this PR. If not, we can leave fixing the performance to another PR and focus on whether these changes do enough justice to @oxinabox's vision to merge.

@kleinschmidt (Member Author)

Okay I've done a bit of timing (using the script from
https://gist.github.com/kleinschmidt/f51d305d56a590030c4f8688cbf18929). On
master, these are the timings (first and second run times, since I'm mostly
interested in compilation time itself here):

y ~ 1 + x
  0.000033 seconds (24 allocations: 1.453 KiB)
  0.000032 seconds (24 allocations: 1.453 KiB)
y ~ a + b
  0.008989 seconds (3.57 k allocations: 222.652 KiB)
  0.000014 seconds (11 allocations: 720 bytes)
y ~ a + b + c
  0.009176 seconds (3.58 k allocations: 223.059 KiB)
  0.000017 seconds (12 allocations: 752 bytes)
y ~ a + b + c + d
  0.011082 seconds (3.58 k allocations: 222.840 KiB)
  0.000017 seconds (13 allocations: 784 bytes)
y ~ a + b + c + d + e
  0.009841 seconds (3.58 k allocations: 223.012 KiB)
  0.000023 seconds (15 allocations: 896 bytes)
y ~ a + b + c + d + e + f
  0.010660 seconds (3.58 k allocations: 223.105 KiB)
  0.000029 seconds (16 allocations: 928 bytes)
y ~ log(a)
  0.065474 seconds (96.09 k allocations: 5.480 MiB)
  0.029364 seconds (9.05 k allocations: 520.255 KiB)
y ~ log(a) + log(b)
  0.941035 seconds (1.19 M allocations: 59.954 MiB, 10.87% gc time)
  0.746309 seconds (1.02 M allocations: 51.098 MiB, 12.56% gc time)
y ~ log(a) + log(b) + log(c)
  0.793438 seconds (1.43 M allocations: 72.068 MiB, 2.55% gc time)
  0.647390 seconds (901.92 k allocations: 45.123 MiB, 1.24% gc time)
y ~ log(a) + log(b) + log(c) + log(d)
  0.777564 seconds (1.01 M allocations: 50.078 MiB, 1.10% gc time)
  0.809190 seconds (1.01 M allocations: 50.081 MiB, 1.07% gc time)
y ~ log(a) + log(b) + log(c) + log(d) + log(e)
  0.914824 seconds (1.11 M allocations: 55.022 MiB, 1.15% gc time)
  0.918323 seconds (1.11 M allocations: 55.028 MiB, 1.02% gc time)
y ~ log(a) + log(b) + log(c) + log(d) + log(e) + log(f)
  1.036401 seconds (1.21 M allocations: 59.980 MiB, 1.71% gc time)
  1.017207 seconds (1.21 M allocations: 59.978 MiB, 0.86% gc time)
y ~ exp(a)
  0.027674 seconds (9.39 k allocations: 547.880 KiB)
  0.030770 seconds (9.05 k allocations: 520.177 KiB)
y ~ exp(a) + exp(b)
  0.778712 seconds (1.20 M allocations: 59.997 MiB, 2.29% gc time)
  0.684971 seconds (1.02 M allocations: 51.120 MiB, 1.34% gc time)
y ~ exp(a) + exp(b) + exp(c)
  0.742065 seconds (1.24 M allocations: 62.374 MiB, 1.24% gc time)
  0.655097 seconds (901.92 k allocations: 45.129 MiB, 1.56% gc time)
y ~ exp(a) + exp(b) + exp(c) + exp(d)
  0.773946 seconds (1.01 M allocations: 50.081 MiB, 1.32% gc time)
  0.778134 seconds (1.01 M allocations: 50.084 MiB, 2.41% gc time)
y ~ exp(a) + exp(b) + exp(c) + exp(d) + exp(e)
  0.891833 seconds (1.11 M allocations: 55.043 MiB, 1.04% gc time)
  0.873386 seconds (1.11 M allocations: 55.022 MiB, 1.05% gc time)
y ~ exp(a) + exp(b) + exp(c) + exp(d) + exp(e) + exp(f)
  0.982716 seconds (1.21 M allocations: 59.970 MiB, 1.11% gc time)
  1.017105 seconds (1.21 M allocations: 59.977 MiB, 1.86% gc time)

And on this PR (commit eaec084):

y ~ 1 + x
  0.000100 seconds (25 allocations: 1.406 KiB)
  0.000093 seconds (25 allocations: 1.406 KiB)
y ~ a + b
  0.007148 seconds (2.17 k allocations: 135.030 KiB)
  0.000016 seconds (11 allocations: 768 bytes)
y ~ a + b + c
  0.042739 seconds (53.78 k allocations: 3.048 MiB)
  0.000026 seconds (23 allocations: 1.484 KiB)
y ~ a + b + c + d
  0.036280 seconds (53.91 k allocations: 3.052 MiB)
  0.000020 seconds (25 allocations: 1.547 KiB)
y ~ a + b + c + d + e
  0.047632 seconds (54.05 k allocations: 3.059 MiB)
  0.000022 seconds (28 allocations: 1.688 KiB)
y ~ a + b + c + d + e + f
  0.039089 seconds (54.18 k allocations: 3.064 MiB)
  0.000020 seconds (30 allocations: 1.750 KiB)
y ~ log(a)
  0.027760 seconds (44.74 k allocations: 2.699 MiB)
  0.000014 seconds (11 allocations: 480 bytes)
y ~ log(a) + log(b)
  0.277548 seconds (332.92 k allocations: 17.035 MiB)
  0.000028 seconds (27 allocations: 1.703 KiB)
y ~ log(a) + log(b) + log(c)
  0.080824 seconds (121.33 k allocations: 6.716 MiB)
  0.000031 seconds (42 allocations: 2.906 KiB)
y ~ log(a) + log(b) + log(c) + log(d)
  0.113681 seconds (114.74 k allocations: 6.305 MiB, 10.12% gc time)
  0.000034 seconds (49 allocations: 3.297 KiB)
y ~ log(a) + log(b) + log(c) + log(d) + log(e)
  0.075690 seconds (118.20 k allocations: 6.460 MiB)
  0.000034 seconds (57 allocations: 3.828 KiB)
y ~ log(a) + log(b) + log(c) + log(d) + log(e) + log(f)
  0.075736 seconds (121.95 k allocations: 6.627 MiB)
  0.000037 seconds (64 allocations: 4.219 KiB)
y ~ exp(a)
  0.013628 seconds (6.69 k allocations: 411.996 KiB)
  0.000013 seconds (11 allocations: 480 bytes)
y ~ exp(a) + exp(b)
  0.251576 seconds (332.96 k allocations: 17.036 MiB)
  0.000035 seconds (27 allocations: 1.703 KiB)
y ~ exp(a) + exp(b) + exp(c)
  0.083678 seconds (121.34 k allocations: 6.716 MiB)
  0.000037 seconds (42 allocations: 2.906 KiB)
y ~ exp(a) + exp(b) + exp(c) + exp(d)
  0.103966 seconds (114.75 k allocations: 6.304 MiB, 9.53% gc time)
  0.000035 seconds (49 allocations: 3.297 KiB)
y ~ exp(a) + exp(b) + exp(c) + exp(d) + exp(e)
  0.087459 seconds (118.21 k allocations: 6.461 MiB)
  0.000039 seconds (57 allocations: 3.828 KiB)
y ~ exp(a) + exp(b) + exp(c) + exp(d) + exp(e) + exp(f)
  0.080457 seconds (121.96 k allocations: 6.629 MiB)
  0.000038 seconds (64 allocations: 4.219 KiB)

Bottom line is, this PR is MUCH faster for anything involving a custom function,
both first and second run (noticeably so, going from ~1s every time you create a
formula with a function call in it to ~100ms first run and <1ms after that).
For formulas with only plain terms (e.g. y ~ a + b + c), it's slower on first run
(~50ms vs. ~10ms) but comparable after that.
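
For reference, first-call versus second-call timings of this kind can be collected with something like the sketch below. It is not the script from the gist: time_twice and workload are hypothetical names, and the workload is a stand-in for constructing a formula.

```julia
# Hypothetical helper (not the gist's script): time the first and second
# call of a zero-argument function.
function time_twice(f)
    t1 = @elapsed f()   # includes compiling anything not yet compiled
    t2 = @elapsed f()   # mostly pure run time
    return (first_run = t1, second_run = t2)
end

# Stand-in workload; a real comparison would build a formula here instead.
workload() = sum(abs2, rand(100))
timings = time_twice(workload)
```

The gap between first_run and second_run is a rough proxy for compilation cost, which is what the comparison above is mostly about.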

There was, I remember, a good reason why it made sense to move the parsing rules
out of the macro and into run-time but at the moment I can't recall exactly
why. Only that it had something to do with the protect/unprotect stuff. I
don't think it was just that I had started to dislike keeping both the
"parsed" and "non-parsed" copies of the arguments in the function term, but that
is a definite bonus in my view.

Project.toml Outdated
@@ -15,6 +16,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[compat]
CategoricalArrays = "0.8"
Compat = "2.2, 3"
Member commented:

This line isn't in the diff anymore. Is this intended?

src/terms.jl Outdated
Comment on lines 384 to 390
Base.:&(::ConstantTerm, b::AbstractTerm) = b
Base.:&(a::AbstractTerm, ::ConstantTerm) = a

# Avoid method ambiguities
Base.:&(::ConstantTerm, b::InteractionTerm) = b
Base.:&(a::InteractionTerm, ::ConstantTerm) = a
Base.:&(a::ConstantTerm, ::ConstantTerm) = a
Member commented:

OK. It seems safer to check that the value is 1 then.
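
One way such a check could look, sketched with made-up toy types (ToyAbstractTerm, ToyTerm, ToyConstantTerm and check_one are hypothetical; this is not the code that ended up in the PR): the constant is dropped from the interaction only when its value is 1, and anything else raises an error.

```julia
# Hypothetical toy types; not the StatsModels definitions.
abstract type ToyAbstractTerm end
struct ToyTerm <: ToyAbstractTerm
    sym::Symbol
end
struct ToyConstantTerm{T<:Number} <: ToyAbstractTerm
    n::T
end

# Hypothetical helper: only the constant 1 is a no-op in an interaction.
function check_one(c::ToyConstantTerm)
    c.n == 1 || throw(ArgumentError(
        "only the constant 1 can be dropped from an interaction, got $(c.n)"))
    return nothing
end

Base.:&(c::ToyConstantTerm, b::ToyAbstractTerm) = (check_one(c); b)
Base.:&(a::ToyAbstractTerm, c::ToyConstantTerm) = (check_one(c); a)
# Resolve the ambiguity between the two methods above.
Base.:&(a::ToyConstantTerm, c::ToyConstantTerm) = (check_one(a); check_one(c); a)

x = ToyTerm(:x)
kept = ToyConstantTerm(1) & x   # returns x
rejected = try
    ToyConstantTerm(0) & x
    false
catch err
    err isa ArgumentError
end
```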

Project.toml Outdated
@@ -1,8 +1,9 @@
name = "StatsModels"
uuid = "3eaba693-59b7-5ba5-a881-562e759f1c8d"
-version = "0.6.33"
+version = "0.7.0-DEV"
Member commented:

Usually we just set the version that will be released (I'm not even sure Pkg accepts letters, which may explain the docs build failure).

Suggested change
-version = "0.7.0-DEV"
+version = "0.7.0"

Member commented:

Actually we should even do this:

Suggested change
-version = "0.7.0-DEV"
+version = "1.0.0"

Member Author commented:

ehhh I'm on the fence about that... there are a few other potentially breaking changes that I think we should at least CONSIDER before releasing 1.0 (ones that don't have PRs/WIP already or are pretty complex, like using Tables.Columns, removing all the type parameters from the AbstractTerms, handling the missing masks, etc.), and I'd like to get this released while I still have the spoons to work on it, rather than waiting on all the other breaking changes we might want to consider before 1.0.

src/terms.jl Outdated
Comment on lines 384 to 390
Base.:&(::ConstantTerm, b::AbstractTerm) = b
Base.:&(a::AbstractTerm, ::ConstantTerm) = a

# Avoid method ambiguities
Base.:&(::ConstantTerm, b::InteractionTerm) = b
Base.:&(a::InteractionTerm, ::ConstantTerm) = a
Base.:&(a::ConstantTerm, ::ConstantTerm) = a
@nalimilan (Member) commented Jan 24, 2023

Referring to #183 (comment):
Yes, better to throw errors in absurd cases rather than risk hiding bugs. So with the new state of the PR, an error is thrown in both cases above, right?

kleinschmidt and others added 3 commits January 24, 2023 10:42
Co-authored-by: Milan Bouchet-Valat <nalimilan@club.fr>
@kleinschmidt kleinschmidt mentioned this pull request Jan 24, 2023
5 tasks
@kleinschmidt kleinschmidt merged commit c4b68cf into master Jan 24, 2023
@kleinschmidt kleinschmidt deleted the dfk/syntax branch January 24, 2023 16:28
kleinschmidt added a commit that referenced this pull request Mar 13, 2023
* functionterm news

* run-time

* Update NEWS.md

Co-authored-by: Phillip Alday <palday@users.noreply.github.com>

* Update NEWS.md

Co-authored-by: Phillip Alday <palday@users.noreply.github.com>

---------

Co-authored-by: Phillip Alday <palday@users.noreply.github.com>
3 participants