Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optimized hcat of onehot vectors and matrices #1595

Merged
merged 9 commits into from
May 13, 2021
Merged

Conversation

racinmat
Copy link
Contributor

@racinmat racinmat commented May 11, 2021

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable
  • API changes require approval from a committer (different from the author, if applicable)

Fixes #1594 , adds tests for it and also optimized reduce(hcat, xs).
No new features are added, only performance optimization.

Copy link
Member

@darsnack darsnack left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this! I left some suggestions that make sure we hit reshaped arrays as well.

src/onehot.jl Outdated Show resolved Hide resolved
src/onehot.jl Outdated Show resolved Hide resolved
src/onehot.jl Outdated Show resolved Hide resolved
hitting reshaped arrays

Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>
@johnnychen94
Copy link
Contributor

johnnychen94 commented May 11, 2021

I believe the root issue is that we don't have similar defined for OneHotArray; once it's defined, it's very likely that many of these overly-verbose optimizations will be unnecessary.

Copy link
Member

@darsnack darsnack left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's go back to the previously suggested changes (I have commented them again with some updates). If the constructor is causing issues, then why not just change the constructor to:

OneHotArray(indices::I, L::Integer) where {T, N, I<: AbstractArray{T, N}} = OneHotArray{T, L, N, I}(indices)

Then much more than hcat can see the performance benefit.

src/onehot.jl Outdated Show resolved Hide resolved
src/onehot.jl Outdated Show resolved Hide resolved
@darsnack
Copy link
Member

darsnack commented May 12, 2021

I believe the root is is that we don't have similar defined for OneHotArray; once it's defined, it's very likely that many of these overly-verbose optimizations will be unnecessary.

Good point! Though this means we would need to pair similar with setindex! which would require checks to stop someone from breaking the one-hot behavior. I think directly operating on the index arrays will still see a performance benefit.

We might want to define similar anyways to guarantee more type stability in fallback cases (e.g. the reduce(hcat, ...) one).

@DhairyaLGandhi
Copy link
Member

Ideally we would fix it generally rather than for special cases. That's the flux way.

@darsnack
Copy link
Member

darsnack commented May 13, 2021

Here is an implementation of similar:

function Base.similar(x::OneHotArray, ::Type{Bool}, dims::Dims)
  indices = similar(_indices(x), Base.tail(dims))

  return OneHotArray(indices, first(dims))
end

This would need to be paired with Base.setindex! to address the reduce(hcat, ...) type instability (too late for me to figure out the index combos tonight 😅). Though I am guessing that since setindex! will need to perform checks to ensure the one hot property is maintained, this will not be as fast as an optimized custom reduce(::typeof(hcat), ...). I think we should define similar no matter what though.

racinmat and others added 2 commits May 13, 2021 10:08
making hcat more generic for onehotlike

Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>
@racinmat
Copy link
Contributor Author

racinmat commented May 13, 2021

Ok, the performance is now good even for reshaped arrays. Should I now include the Base.similar in this PR, or not, because of pairing with Base.setindex!?
Is it ready to merge?
I agree that generic fix would be the best solution, but I think this is good as an intermediate thing before the proper implementation.

@darsnack
Copy link
Member

I think the best move is to remove the reduce(hcat, ...) paths from this PR. Leave #1596 open, and we address it later with similar and setindex!.

Copy link
Member

@darsnack darsnack left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just some readability changes in addition to the request to remove the reduce(hcat, ...) paths for now.

src/onehot.jl Outdated Show resolved Hide resolved
src/onehot.jl Outdated Show resolved Hide resolved
racinmat and others added 3 commits May 13, 2021 17:52
more readable code

Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>
test/onehot.jl Show resolved Hide resolved
test/onehot.jl Show resolved Hide resolved
@racinmat
Copy link
Contributor Author

Done, reduce and tests for it are removed.

Copy link
Member

@darsnack darsnack left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, thanks!

@DhairyaLGandhi
Copy link
Member

bors r+

@bors
Copy link
Contributor

bors bot commented May 13, 2021

Build succeeded:

@bors bors bot merged commit 4bd20c7 into FluxML:master May 13, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Too slow hcat of OneHotMatrix.
4 participants