Chain(Parallel(...), ...) #2100

Closed
cstjean opened this issue Nov 7, 2022 · 4 comments

Comments

@cstjean

cstjean commented Nov 7, 2022

Describe the potential feature

When combining multiple inputs (or an embedding + inputs), one can combine them at the end with Parallel(+, Chain(...), Chain(...)), but one can't combine them earlier, as in Chain(Parallel(...), Dense(...)). That model can be constructed, but it cannot be called with several inputs, because there is no (::Chain)(args...) method, only the unary one.

In other words, if Chain is function composition, it should support

julia> (sin ∘ cos ∘ +)(1, 2)
-0.8360218615377305
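A minimal sketch of the requested semantics, using a hypothetical `VarargChain` type (not Flux's `Chain`, which also supports named layers and uses a generated function, see #1809). The first layer receives every argument, and each later layer receives the single output of the previous one. Note the ordering: a chain applies its layers left to right, while `∘` composes right to left, so `VarargChain(+, cos, sin)` corresponds to `sin ∘ cos ∘ +`.

```julia
# Hypothetical stand-in for Flux's Chain, for illustration only.
struct VarargChain{T<:Tuple}
    layers::T
end
VarargChain(layers...) = VarargChain(layers)

# The first layer receives all arguments; each later layer gets the previous output.
applyall(layers::Tuple, args...) = applyall(Base.tail(layers), first(layers)(args...))
applyall(::Tuple{}, x) = x   # no layers left: return the result so far

# Variadic call method: forwards every argument into the recursion.
(c::VarargChain)(args...) = applyall(c.layers, args...)

VarargChain(+, cos, sin)(1, 2)   # computes sin(cos(1 + 2)), like (sin ∘ cos ∘ +)(1, 2)
```

The recursion peels off one layer per call with `Base.tail`, so only the first step ever sees more than one argument.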

Motivation

Helps with multi-modal inputs / embeddings.

Possible Implementation

_applychain(layers::Tuple, args...) = _applychain(layers[2:end], layers[1](args...))
_applychain(::Tuple{}, x) = x

Would something like this work? #1809 has some background on why a generated function was needed for _applychain. If the definition above isn't Zygote-friendly (I really don't know), I could modify the generated function instead.

@DhairyaLGandhi
Member

The recursive definition was removed to improve compile times for "larger" chains and to support named layers, but yes, the generated function might be avoidable. The best approach is to try removing the unary method on Chain and allowing vararg inputs.

@cstjean
Author

cstjean commented Nov 7, 2022

I started a PR. The new behaviour is somewhat at odds with supporting Chain(), since Chain()(x1, x2) is nonsensical. We could special-case an empty chain called with a single argument to return that argument, though it's a bit ugly. Is there a good reason for supporting Chain()? FWIW, composition in Base Julia doesn't support zero functions, perhaps for this exact reason:

julia> ∘()
ERROR: MethodError: no method matching ∘()

I don't mind either way, just let me know.
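One way to read the proposed special case, sketched with a hypothetical `StrictChain` type (not the code from the PR): an empty chain called with exactly one argument returns it unchanged, so it still works as an identity, while calling it with several arguments has no matching method and throws a `MethodError`, which covers the nonsensical `Chain()(x1, x2)` case.

```julia
# Hypothetical chain type illustrating the special case; not Flux's implementation.
struct StrictChain{T<:Tuple}
    layers::T
end
StrictChain(layers...) = StrictChain(layers)

# Non-empty layer tuple: the first layer takes all arguments.
applystrict(layers::Tuple{Any,Vararg{Any}}, args...) =
    applystrict(Base.tail(layers), first(layers)(args...))
# Empty layer tuple: only a single argument is accepted, and it is returned unchanged.
applystrict(::Tuple{}, x) = x

(c::StrictChain)(args...) = applystrict(c.layers, args...)

StrictChain()(5)            # identity on a single argument
StrictChain(+, abs)(1, -4)  # abs(1 + (-4))
# StrictChain()(1, 2) throws a MethodError: there is no applystrict((), 1, 2) method
```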

cstjean added a commit to cstjean/Flux.jl that referenced this issue Nov 7, 2022
Closes FluxML#2100

As mentioned in a comment on FluxML#2100,
this will break any code using `Chain()` as the identity function.
@ToucheSir
Member

Is there a good reason for supporting Chain()?

Because Chain acts like an indexable collection, we didn't want Chain(...)[i:i] to throw: that would be unergonomic and would deviate from what other collection types do. As a bonus, people don't have to scatter innumerable isempty checks before constructing Chains from splatted varargs. Whereas ∘ is usually invoked with a known, non-zero number of functions, the Chain constructor is often nested deep within a stack of model-building functions (see e.g. Metalhead), and a bit more flexibility helps there.
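The collection argument can be sketched with a hypothetical `IndexableChain` type that supports `length` and range indexing (an imitation of the behaviour described, not Flux's code). An empty range such as `c[2:1]` must produce some chain, and because the empty chain is allowed and acts as the identity, neither slicing nor splatting an empty list of layers needs an `isempty` check.

```julia
# Hypothetical collection-like chain; names and methods are illustrative assumptions.
struct IndexableChain{T<:Tuple}
    layers::T
end
IndexableChain(layers...) = IndexableChain(layers)

# Unary application, layer by layer; an empty chain returns its input.
applyeach(layers::Tuple, x) = applyeach(Base.tail(layers), first(layers)(x))
applyeach(::Tuple{}, x) = x
(c::IndexableChain)(x) = applyeach(c.layers, x)

# Collection-like interface: range indexing returns another (possibly empty) chain.
Base.length(c::IndexableChain) = length(c.layers)
Base.getindex(c::IndexableChain, r::AbstractUnitRange) = IndexableChain(c.layers[r]...)

c = IndexableChain(x -> x + 1, x -> 2x, abs)
c[2:2](5)              # a one-layer chain: 2 * 5
c[2:1](7)              # an empty slice is an empty chain, which returns 7 unchanged
extra = Function[]     # e.g. optional layers computed by a model builder
IndexableChain(extra...)(4)   # splatting an empty list needs no isempty check
```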

@cstjean
Author

cstjean commented Nov 9, 2022

Following the discussion in #2101, closing: Chain(Parallel(...), ...) is already supported by passing the inputs as a single tuple.
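Why a tuple is enough can be sketched with hypothetical `ToyParallel` and `UnaryChain` types. As far as I know, Flux's `Parallel` unpacks a tuple input into separate inputs; the toy below assumes that behaviour and is not Flux's implementation. The chain only ever needs its unary call: the tuple travels through it as one value and is split apart inside the parallel layer.

```julia
# Hypothetical stand-ins for Flux's Parallel and Chain; behaviour is an assumption.
struct ToyParallel{F,T<:Tuple}
    connection::F
    layers::T
end
ToyParallel(connection, layers...) = ToyParallel(connection, layers)

# Several inputs: the i-th layer gets the i-th input, then `connection` combines the results.
(p::ToyParallel)(xs...) = p.connection(map((layer, x) -> layer(x), p.layers, xs)...)
# A single tuple input is unpacked into separate inputs.
(p::ToyParallel)(xs::Tuple) = p(xs...)

struct UnaryChain{T<:Tuple}
    layers::T
end
UnaryChain(layers...) = UnaryChain(layers)
# Only a one-argument call: fold the input through the layers, left to right.
(c::UnaryChain)(x) = foldl((acc, layer) -> layer(acc), c.layers; init = x)

model = UnaryChain(ToyParallel(-, x -> 2x, x -> x^2), abs)
model((1, -3))   # abs(2 * 1 - (-3)^2)
```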

@cstjean cstjean closed this as completed Nov 9, 2022