Taking nested gradient for implementing Wasserstein GAN with gradient penalty (WGAN-GP) on GPU #1262


Description

@bhatiaabhinav

I am trying to implement WGAN-GP using Flux and Zygote. My implementation works fine on the CPU but fails on the GPU with `LoadError: this intrinsic must be compiled to be called`. I have read that taking nested gradients in Zygote is currently a mess and that one needs to use ForwardDiff for the inner gradient, but after hours of surfing through the relevant issues I still can't figure out how to do that in my case.
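For context, the nested-gradient pattern itself works on the CPU; as a sanity check, here is a scalar toy case (second derivative of x³):

```julia
using Zygote

# d²/dx² of x³ is 6x, so it should be 12 at x = 2
ddf, = gradient(2.0) do x
    dfdx, = gradient(y -> y^3, x)   # inner (nested) gradient
    dfdx
end
@assert ddf ≈ 12.0
```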

Here is my implementation in brief; it is almost line-by-line identical to the pseudocode (Algorithm 1) in the original paper.

```julia
using Flux
using Flux: update!, params
using Statistics  # for `mean` (StatsBase is not needed)
using Zygote

"""
WGAN with gradient penalty. See Algorithm 1 in
https://proceedings.neurips.cc/paper/2017/file/892c3b1c6dccd52936e27cbd0ff683d6-Paper.pdf.
The following code is almost line-by-line identical.
"""
function train_WGAN_GP(𝐺, 𝐷, 𝐗::Array{Float32, N}, latent_size, num_iters, device_fn; m=32, λ=10f0, ncritic=5, α=0.0001, β₁=0, β₂=0.9) where N
    n = size(𝐗)[end]    # length of the dataset
    𝐺, 𝐷 = device_fn(deepcopy(𝐺)), device_fn(deepcopy(𝐷))
    θ, 𝑤 = params(𝐺), params(𝐷)
    adamθ, adam𝑤 = ADAM(α, (β₁, β₂)), ADAM(α, (β₁, β₂))

    for iter in 1:num_iters
        for t in 1:ncritic
            # Sample a minibatch of real data x, latent variables z, and random numbers ϵ ∼ U[0, 1]
            𝐱 = 𝐗[repeat([:], N-1)..., rand(1:n, m)]
            𝐳 = randn(Float32, latent_size..., m)
            𝛜 = rand(Float32, repeat([1], N-1)..., m)
            𝐱, 𝐳, 𝛜 = device_fn(𝐱), device_fn(𝐳), device_fn(𝛜)
            𝐱̃ = 𝐺(𝐳)                           # fake samples
            𝐱̂ = 𝛜 .* 𝐱 + (1f0 .- 𝛜) .* 𝐱̃       # random interpolates between real and fake
            ∇𝑤L = gradient(𝑤) do
                ∇𝐱̂𝐷, = gradient(𝐱̂ -> sum(𝐷(𝐱̂)), 𝐱̂)   # inner (nested) gradient for the penalty
                L = mean(𝐷(𝐱̃)) - mean(𝐷(𝐱)) + λ * mean((sqrt.(sum(∇𝐱̂𝐷.^2, dims=1) .+ 1f-12) .- 1f0).^2)
            end
            update!(adam𝑤, 𝑤, ∇𝑤L)
        end

        𝐳 = device_fn(randn(Float32, latent_size..., m))
        ∇θ𝐷 = gradient(θ) do
            -mean(𝐷(𝐺(𝐳)))
        end
        update!(adamθ, θ, ∇θ𝐷)
    end

    return 𝐺, 𝐷
end

𝐗 = rand(Float32, 50, 10000)  # dummy dataset
z = 16                        # latent size
𝐺 = Chain(Dense(z, 32, leakyrelu), Dense(32, 50))   # Generator
𝐷 = Chain(Dense(50, 32, leakyrelu), Dense(32, 1))   # Critic

𝐺, 𝐷 = train_WGAN_GP(𝐺, 𝐷, 𝐗, (z,), 1, cpu)  # works
𝐺, 𝐷 = train_WGAN_GP(𝐺, 𝐷, 𝐗, (z,), 1, gpu)  # fails
```
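For reference, the critic loss computed in the inner loop (line 7 of Algorithm 1 in the paper) is

$$L = \frac{1}{m}\sum_{i=1}^{m}\left[D(\tilde{\mathbf{x}}^{(i)}) - D(\mathbf{x}^{(i)}) + \lambda\left(\lVert\nabla_{\hat{\mathbf{x}}}D(\hat{\mathbf{x}}^{(i)})\rVert_2 - 1\right)^2\right]$$

and it is the $\nabla_{\hat{\mathbf{x}}} D$ term inside the loss that requires the nested gradient.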

On the GPU, this fails at the nested-gradient line `∇𝐱̂𝐷, = gradient(𝐱̂ -> sum(𝐷(𝐱̂)), 𝐱̂)` with `LoadError: this intrinsic must be compiled to be called`.

Can anyone help me?

Here is a code snippet that isolates the problem:

```julia
using Flux
using Statistics  # [edit: does not need StatsBase]

function run_isolated_code_on(device_fn)
    D = Chain(Dense(5, 3, leakyrelu), Dense(3, 1)) |> device_fn  # Critic [edit: size was 50 => 32]
    w = Flux.params(D)  # [edit]
    x = rand(Float32, 5, 3) |> device_fn                         # Dummy minibatch
    ∇wL = gradient(w) do
        ∇xD, = gradient(x -> sum(D(x)), x)                       # The problematic (nested) line
        L = mean((sqrt.(sum(∇xD.^2, dims=1) .+ 1f-12) .- 1f0).^2)  # gradient penalty
    end
end

run_isolated_code_on(cpu)  # works
run_isolated_code_on(gpu)  # fails
```
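[edit] In case it helps the discussion: the only stopgap I can think of is to replace the exact gradient-norm penalty with a finite-difference estimate of the critic's slope between real and fake samples. This is an approximation, not the penalty from the paper, but it needs only a first-order gradient, so it sidesteps nested AD entirely:

```julia
using Statistics

# Hedged stopgap sketch, NOT the paper's penalty: penalize a finite-difference
# slope of D along the chord between real samples 𝐱 and fake samples 𝐱̃.
# Only a first-order gradient of this loss w.r.t. params(D) is needed.
function critic_loss_fd(D, 𝐱, 𝐱̃; λ=10f0)
    Δ = 𝐱 .- 𝐱̃
    chord = sqrt.(sum(Δ .^ 2, dims=1) .+ 1f-12)   # ‖x - x̃‖ per sample
    slope = abs.(D(𝐱) .- D(𝐱̃)) ./ chord          # |D(x) - D(x̃)| / ‖x - x̃‖
    mean(D(𝐱̃)) - mean(D(𝐱)) + λ * mean((slope .- 1f0) .^ 2)
end
```

I would still prefer the exact penalty, so any pointer on how to make the nested gradient work on the GPU (with ForwardDiff or otherwise) would be appreciated.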
