Open
Labels: second order, zygote over zygote, or otherwise
Description
I am trying to implement WGAN-GP using Flux and Zygote. My implementation works fine on CPU but fails on GPU with the error `LoadError: this intrinsic must be compiled to be called.` I have read that taking nested gradients in Zygote is a mess right now and that one needs to use ForwardDiff for that. I have spent hours surfing through relevant issues, but I can't figure out how to do that in my case.
Here is my brief implementation, which is almost line by line identical to the pseudocode in the original paper.
```julia
using Flux
using Flux: update!, params
using Zygote
using StatsBase

"""
WGAN with gradient penalty. See algorithm 1 in
https://proceedings.neurips.cc/paper/2017/file/892c3b1c6dccd52936e27cbd0ff683d6-Paper.pdf.
The following code is almost line by line identical.
"""
function train_WGAN_GP(𝐺, 𝐷, 𝐗::Array{Float32, N}, latent_size, num_iters, device_fn;
                       m=32, λ=10f0, ncritic=5, α=0.0001, β₁=0, β₂=0.9) where N
    n = size(𝐗)[end]  # length of dataset
    𝐺, 𝐷 = device_fn(deepcopy(𝐺)), device_fn(deepcopy(𝐷))
    θ, 𝑤 = params(𝐺), params(𝐷)
    adamθ, adam𝑤 = ADAM(α, (β₁, β₂)), ADAM(α, (β₁, β₂))
    for iter in 1:num_iters
        for t in 1:ncritic
            # Sample a minibatch of real data x, latent variables z, and random numbers ϵ ∼ U[0, 1]
            𝐱 = 𝐗[repeat([:], N-1)..., rand(1:n, m)]
            𝐳 = randn(Float32, latent_size..., m)
            𝛜 = rand(Float32, repeat([1], N-1)..., m)
            𝐱, 𝐳, 𝛜 = device_fn(𝐱), device_fn(𝐳), device_fn(𝛜)
            𝐱̃ = 𝐺(𝐳)
            𝐱̂ = 𝛜 .* 𝐱 + (1f0 .- 𝛜) .* 𝐱̃
            ∇𝑤L = gradient(𝑤) do
                ∇𝐱̂𝐷, = gradient(𝐱̂ -> sum(𝐷(𝐱̂)), 𝐱̂)
                mean(𝐷(𝐱̃)) - mean(𝐷(𝐱)) + λ * mean((sqrt.(sum(∇𝐱̂𝐷.^2, dims=1) .+ 1f-12) .- 1f0).^2)
            end
            update!(adam𝑤, 𝑤, ∇𝑤L)
        end
        𝐳 = device_fn(randn(Float32, latent_size..., m))
        ∇θ𝐷 = gradient(θ) do
            -mean(𝐷(𝐺(𝐳)))
        end
        update!(adamθ, θ, ∇θ𝐷)
    end
    return 𝐺, 𝐷
end

𝐗 = rand(Float32, 50, 10000)  # dummy dataset
z = 16  # latent size
𝐺 = Chain(Dense(z, 32, leakyrelu), Dense(32, 50))  # Generator
𝐷 = Chain(Dense(50, 32, leakyrelu), Dense(32, 1))  # Critic
𝐺, 𝐷 = train_WGAN_GP(𝐺, 𝐷, 𝐗, (z, ), 1, cpu)  # works
𝐺, 𝐷 = train_WGAN_GP(𝐺, 𝐷, 𝐗, (z, ), 1, gpu)  # fails
```

This fails at the line `∇𝐱̂𝐷, = gradient(𝐱̂ -> sum(𝐷(𝐱̂)), 𝐱̂)` on GPU with the error `LoadError: this intrinsic must be compiled to be called.`
Can anyone help me?
Here is a code snippet that isolates the problem:
```julia
using Flux
using Statistics  # [edit: does not need StatsBase]

function run_isolated_code_on(device_fn)
    D = Chain(Dense(5, 3, leakyrelu), Dense(3, 1)) |> device_fn  # Critic [edit: size was 50 => 32]
    w = Flux.params(D)  # [edit]
    x = rand(Float32, 5, 3) |> device_fn  # Dummy minibatch
    ∇wL = gradient(w) do
        ∇xD, = gradient(x -> sum(D(x)), x)  # The problematic line
        mean((sqrt.(sum(∇xD.^2, dims=1) .+ 1f-12) .- 1f0).^2)  # gradient penalty
    end
end

run_isolated_code_on(cpu)  # works
run_isolated_code_on(gpu)  # fails
```
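In case it helps anyone reproducing this: while I look for a proper ForwardDiff route, one first-order workaround I have considered is approximating the gradient norm with a finite difference along a random direction. This avoids nested AD entirely (only one `gradient` call, so it should not hit the second-order code path), but note it is an approximation of the penalty, not the exact term from the paper: the difference quotient estimates the directional derivative of `D` along `δ`, which lower-bounds the true gradient norm. A minimal sketch, untested on my end on GPU:

```julia
using Flux
using Statistics

D = Chain(Dense(5, 3, leakyrelu), Dense(3, 1))  # same toy critic as the snippet above
w = Flux.params(D)
x̂ = rand(Float32, 5, 3)                          # stand-in for the interpolated samples
δ = 1f-2 .* randn(Float32, size(x̂))              # small random perturbation direction

∇wL = gradient(w) do
    # (D(x̂ + δ) - D(x̂)) / ‖δ‖ ≈ directional derivative of D along the unit vector δ/‖δ‖
    ĝ = (D(x̂ .+ δ) .- D(x̂)) ./ sqrt(sum(δ .^ 2))
    mean((abs.(ĝ) .- 1f0) .^ 2)  # push the (approximate) gradient norm toward 1
end
```

Since everything in the closure is plain first-order broadcasting, the same code with `|> gpu` on `D`, `x̂`, and `δ` ought to avoid the `this intrinsic must be compiled` path, though I have not verified the training quality against the exact penalty.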