Poor performance relative to PyTorch #886
Comments
As with debugging the NaN issue, it would be helpful to strip away as much of this code as possible (e.g. removing the dense layer, loss function) in ways that still show the perf bug. If we can narrow it down to e.g. just one function it'll probably be very easy to fix. I would definitely like to track this down, although the numerical issue should probably be the priority.
@MikeInnes re stripping away as much as possible: yep. These performance and numerical concerns showed up while doing an assignment, so I just wanted to include the circumstances that surfaced the issues as they came up in the real world. I can look into where the performance difference shows up later. One thing I'll look at is whether …
With #1031, this gives a warning: "Debug: Chain(...) has output of eltype Float32 but receives gradient of eltype Float64". Changing …
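That warning is about silent Float32→Float64 promotion: a Float64 constant anywhere in the loss makes the pullback return Float64 gradients for a Float32 model, which costs both speed and memory. NumPy follows the same promotion rule, which makes for a quick illustration (a hedged analogy only, not Flux itself):

```python
import numpy as np

# A Float32 model output...
logits = np.ones(4, dtype=np.float32)

# ...combined with a Float64 target silently promotes the result,
# analogous to the gradient-eltype mismatch Flux warns about.
target = np.zeros(4)                  # NumPy's default dtype is float64
print((logits - target).dtype)        # float64

# Keeping every constant and array in float32 avoids the promotion.
target32 = np.zeros(4, dtype=np.float32)
print((logits - target32).dtype)      # float32
```

The same reasoning is why dropping the stray Float64 literal from the loss restores Float32 gradients end to end.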
This can be reproduced on CPU but not on GPU.

```julia
using Flux
using Statistics: mean
using MLDataUtils

# dummy data
x = rand(Float32, 113, 100000) |> gpu
y = sum(x.^2, dims = 1) |> gpu
dataset = batchview((x, y), size = 256)

model = Chain(Dense(113, 1000, relu), Dense(1000, 1)) |> gpu
criterion(logits, y) = mean(Flux.logitbinarycrossentropy.(logits, y))
optimizer = Flux.ADAMW(1e-4, (0.9, 0.999), 1e-5)

for (x, y) in dataset
    θ = params(model)
    loss, back = Flux.Zygote.pullback(θ) do
        criterion(model(x), y)
    end
    println(loss)
    grads = back(1f0)
    Flux.Optimise.update!(optimizer, θ, grads)
end
```
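Since the reproduction goes through `Flux.logitbinarycrossentropy`, which is also where the NaN discussion comes in, it may help to spell out the numerically stable formulation that logit-space BCE losses typically use: `max(x, 0) - x*y + log1p(exp(-|x|))`. The naive `-(y*log(σ(x)) + (1-y)*log(1-σ(x)))` blows up for large logits. A pure-Python sketch (my own hedged illustration, not Flux's source):

```python
import math

def logit_bce(logit, y):
    """Numerically stable binary cross-entropy on a raw logit.

    Algebraically equal to -(y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))),
    rewritten so no intermediate sigmoid or log can overflow or hit log(0).
    """
    return max(logit, 0.0) - logit * y + math.log1p(math.exp(-abs(logit)))

def naive_bce(logit, y):
    # Direct transcription of the textbook formula; fails for large |logit|.
    s = 1.0 / (1.0 + math.exp(-logit))
    return -(y * math.log(s) + (1.0 - y) * math.log(1.0 - s))

print(logit_bce(2.0, 1.0))     # ~0.1269, matches the naive form
print(logit_bce(1000.0, 1.0))  # ~0.0; the naive form hits log(0) here
```

Whether the perf gap lives in the loss, the broadcast, or the pullback is exactly what stripping the example down should reveal.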
Effect of FluxML/Zygote.jl#1044 on this, on CPU:
All after warming up; quite noisy times.
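For anyone re-running these numbers: "after warming up, noisy times" is the usual discipline for benchmarking JIT-compiled code, i.e. discard the first call (which pays compilation) and report a robust statistic over many repeats. A minimal stdlib sketch of that measurement loop, with a hypothetical `work()` standing in for the training step:

```python
import statistics
import time

def work():
    # Hypothetical stand-in for one training step.
    return sum(i * i for i in range(10_000))

work()  # warm-up: the first call pays one-time costs (compilation, caches)

samples = []
for _ in range(20):
    t0 = time.perf_counter()
    work()
    samples.append(time.perf_counter() - t0)

# The median is less sensitive to noisy outliers than the mean.
print(f"median: {statistics.median(samples) * 1e6:.1f} us")
```

In Julia itself, BenchmarkTools.jl's `@benchmark` automates this same warm-up-and-repeat pattern.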
In addition to the numerical stability differences between Tracker and Zygote described in #876, Zygote is performing considerably worse than the equivalent PyTorch code for that example.
Here is the PyTorch code:
Here is the Flux code:
The Flux code produces `NaN`s, where the PyTorch code does not. Possibly the same issue as #876 (Model optimization fails (NaNs) with Zygote.pullback but works with Tracker.forward).