Very slow loss aggregation #172

Closed
MilesCranmer opened this issue Aug 24, 2023 · 5 comments · Fixed by #173


@MilesCranmer
Contributor

The `mean` and `sum` implementations in this package are very slow because they rely on generators. Could they be changed to a faster implementation? For example, here is how the current implementation would compute the sum of squared errors:

```julia
square(x) = x ^ 2

# Example of the generator-based approach:
function sse(outputs, targets)
    sum(square(ŷ - y) for (ŷ, y) in zip(outputs, targets))
end
```

This mirrors the pattern the package currently uses, and gives the following timing:

```julia
julia> @btime sse(outputs, targets) setup=(outputs=randn(100_000); targets=randn(100_000))
  92.833 μs (0 allocations: 0 bytes)
```

But if we instead use the `sum(<function>, <indices>)` form, it is much faster:

```julia
# Index-based approach: sum a function over the shared indices
function sse2(outputs, targets)
    sum(i -> square(outputs[i] - targets[i]), eachindex(outputs, targets))
end
```

```julia
julia> @btime sse2(outputs, targets) setup=(outputs=randn(100_000); targets=randn(100_000))
  26.708 μs (0 allocations: 0 bytes)
```

That is roughly a 3.5x speedup.
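As a quick sanity check, the standalone sketch below repeats both definitions and confirms that they agree up to floating-point rounding (the summation order differs, so exact equality is not expected). A plausible explanation for the speed gap, not verified by profiling here, is that `eachindex` returns an `AbstractUnitRange`, so `sum` can take Base's array reduction path (blocked pairwise summation), while a generator over `zip` is not an array and likely goes through the generic sequential fallback.

```julia
square(x) = x ^ 2

# Generator-based: a generator over zip(...) is not an AbstractArray,
# so sum likely uses the generic sequential reduction.
sse(outputs, targets) = sum(square(ŷ - y) for (ŷ, y) in zip(outputs, targets))

# Index-based: eachindex(...) returns an AbstractUnitRange, so sum can use
# Base's array reduction path.
sse2(outputs, targets) = sum(i -> square(outputs[i] - targets[i]), eachindex(outputs, targets))

outputs = randn(100_000)
targets = randn(100_000)

# Different summation order means results may differ in the last few bits,
# so compare with ≈ rather than ==.
@show sse(outputs, targets) ≈ sse2(outputs, targets)
```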

Could this be made the default way losses are aggregated? I thought this was the method used previously; perhaps it changed in the recent refactoring?

@juliohm
Member

juliohm commented Aug 24, 2023 via email

@MilesCranmer
Contributor Author

> I don't remember using this feature anywhere else in downstream packages.

Sorry, I don't think I was clear. This is the main `sum` function in LossFunctions.jl, which is the package's primary interface. It uses a generator (e.g., `sum(f(x) for x in X)`) for summing, which is very slow:

```julia
function sum(loss::SupervisedLoss, outputs, targets)
    sum(loss(ŷ, y) for (ŷ, y) in zip(outputs, targets))
end
```
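A rewrite of this method could follow the same index-based pattern. The standalone sketch below uses hypothetical names (`SquaredLoss`, `AbsoluteLoss`, `loss_sum`, `loss_mean`) rather than the package's actual loss types or the code merged in #173, and includes a `mean` counterpart built the same way.

```julia
# Hypothetical loss types for illustration only (not LossFunctions.jl types).
struct SquaredLoss end
(::SquaredLoss)(ŷ, y) = abs2(ŷ - y)

struct AbsoluteLoss end
(::AbsoluteLoss)(ŷ, y) = abs(ŷ - y)

# Hypothetical index-based sum: applies `loss` elementwise over shared indices.
function loss_sum(loss, outputs::AbstractArray, targets::AbstractArray)
    return sum(i -> loss(outputs[i], targets[i]), eachindex(outputs, targets))
end

# Hypothetical index-based mean: same reduction, divided by the element count.
function loss_mean(loss, outputs::AbstractArray, targets::AbstractArray)
    idx = eachindex(outputs, targets)
    return sum(i -> loss(outputs[i], targets[i]), idx) / length(idx)
end

outputs = [1.0, 2.0, 3.0]
targets = [1.0, 0.0, 0.0]
@show loss_sum(SquaredLoss(), outputs, targets)
@show loss_mean(AbsoluteLoss(), outputs, targets)
```

A side benefit of `eachindex(outputs, targets)` is that it throws a `DimensionMismatch` when the two arrays have different indices, whereas `zip` silently stops at the shorter input.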

> If you can submit a PR dropping support for generators, we can review and merge it.

Sure!

@MilesCranmer
Contributor Author

Implemented in #173

@juliohm
Member

juliohm commented Aug 24, 2023 via email

@MilesCranmer
Contributor Author

Got it, thanks!
