FFJORD for batched data #342

Closed
SebastianCallh opened this issue Jul 9, 2020 · 8 comments · Fixed by #415

Comments

@SebastianCallh
Contributor

I looked into the FFJORD implementation and tests and noticed that it does not seem to work with batched training. I started looking into how to do this, but since it is not done in this implementation perhaps the problem is harder than I realize. Could someone shed some light on the situation? @d-netto perhaps?

@ChrisRackauckas
Member

Oh yes, that is worth looking into. We hadn't tried it yet, so it is just an omission.

@d-netto
Contributor

d-netto commented Jul 9, 2020

I just discussed this with @ChrisRackauckas and it would only be a matter of making sure that passing the data as a matrix instead of an array of vectors doesn't break it.
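To picture that layout change, here is a minimal sketch in plain Julia (no DiffEqFlux). It stacks an array of vectors into a matrix with one sample per column, then appends the zero row that a batched FFJORD would integrate alongside the state as the change in log-density. The sample values are made up for illustration.

```julia
# Illustrative data: three 2-dimensional samples stored as an array of vectors
xs = [Float32[1, 2], Float32[3, 4], Float32[5, 6]]

# Batched layout: a d×batch matrix with one sample per column
X = reduce(hcat, xs)                          # 2×3 Matrix{Float32}

# Augmented state: an extra row, initialised to zero, for the change in log-density
u0 = vcat(X, zeros(Float32, 1, size(X, 2)))   # 3×3

@show size(X) size(u0)
```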

@SebastianCallh
Contributor Author

I agree it seems like a fairly simple change, though I think we would have to use DistributionsAD to differentiate through the `logpdf` function. I also wonder whether it makes sense (or is even possible) to use `ForwardDiff.jacobian` to compute the Jacobian. It is currently done with `Zygote.pullback` and some indexing, but at the moment I do not see how that generalizes to multiple dimensions.
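If the base distribution is the default standard normal, one way to sidestep DistributionsAD is to write the batched log-density directly with array operations: for each column, log N(z; 0, I) = -(d log 2π + ‖z‖²)/2. A minimal sketch under that assumption (the function name `std_normal_logpdf` is hypothetical). It only uses broadcasting and `sum`, so it should be differentiable by Zygote, though that is not checked here.

```julia
# Batched log-density of the standard normal N(0, I) in d dimensions.
# z is d×b with one sample per column; returns a length-b vector.
# Assumption: the base distribution is the default MvNormal(zeros(d), I).
function std_normal_logpdf(z::AbstractMatrix{T}) where {T<:AbstractFloat}
    d = size(z, 1)
    sq = sum(abs2, z; dims=1)                   # 1×b row of squared norms ‖z_k‖²
    return vec(T(-0.5) .* (d * log(T(2π)) .+ sq))
end

# Usage on a 2×3 batch
z = [0.0 1.0 -2.0;
     0.0 0.0  0.5]
logpz = std_normal_logpdf(z)
```

A non-isotropic `MvNormal` would need its mean and covariance in the formula, so this is not a general drop-in replacement for `logpdf(pz, z)`.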

@ChrisRackauckas
Member

`pullback` generalizes to multiple dimensions, but the indexing scheme might need some work. The algorithm itself should be fine, so it's just a matter of working out the implementation details there.
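One possible indexing scheme for the exact trace, sketched under the assumption that samples in a batch do not interact (column k of the output depends only on column k of the input, as with `Dense` layers): seed row i of a d×b cotangent with ones, and row i of the returned gradient then holds ∂y_i/∂x_i for every sample at once. That takes d reverse passes per batch rather than one per output entry per sample. To keep the sketch runnable without Zygote, `dense_tanh_pullback` is a hypothetical hand-written stand-in for the `back` closure returned by `Zygote.pullback(m, z)`, using the same call convention (`back(ȳ)` returns a 1-tuple).

```julia
using LinearAlgebra

# Hypothetical layer y = tanh.(W*x .+ c), applied column-wise to a d×b batch,
# plus a hand-written pullback: back(ȳ) returns (x̄,), like Zygote's closure.
function dense_tanh_pullback(W, c, x)
    y = tanh.(W * x .+ c)
    back(ȳ) = (W' * (ȳ .* (1 .- y .^ 2)),)
    return y, back
end

# Exact per-sample trace of ∂y/∂x for a batch, using d reverse passes.
# Seeding row i with ones returns ∂y_i/∂x_i for every column in row i of x̄.
# Relies on samples being independent of each other within the batch.
function batched_trace(back, d::Integer, b::Integer; T=Float64)
    traces = zeros(T, b)
    for i in 1:d
        seed = zeros(T, d, b)
        seed[i, :] .= one(T)
        x̄ = back(seed)[1]
        traces .+= x̄[i, :]
    end
    return traces
end

# Usage: compare with explicit per-sample Jacobians Diagonal(1 - y_k²) * W
W = [1.0 2.0; 3.0 4.0]
c = [0.1, -0.2]
x = [0.5 -1.0 0.0; 0.2 0.3 -0.4]            # d = 2, batch of 3
y, back = dense_tanh_pullback(W, c, x)
traces = batched_trace(back, size(x)...)
reference = [tr(Diagonal(1 .- y[:, k] .^ 2) * W) for k in 1:size(x, 2)]
```

Layers that mix samples (batch normalization, for example) break the independence assumption, and then this shortcut no longer gives per-sample traces.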

@SebastianCallh
Contributor Author

Yeah, it was the indexing scheme I was referring to. Cool, then I'll look into it.

@nirmal-suthar
Contributor

Hello, I have tried tweaking FFJORD to support batched data. The forward pass runs fine, but Zygote throws a size-related error that I have not been able to figure out. I will post an MWE and the exact error shortly.

```julia
struct FFJORD{M,P,RE,Distribution,Bool,T,A,K} <: CNFLayer
    model::M
    p::P
    re::RE
    basedist::Distribution
    monte_carlo::Bool
    tspan::T
    args::A
    kwargs::K

    function FFJORD(model,tspan,args...;p = nothing,basedist=nothing,monte_carlo=false,kwargs...)
        _p,re = Flux.destructure(model)
        if p === nothing
            p = _p
        end
        if basedist === nothing
            # default base distribution: standard normal in the input dimension of the first layer
            size_input = size(model[1].W)[2]
            basedist = MvNormal(zeros(size_input), I + zeros(size_input,size_input))
        end
        new{typeof(model),typeof(p),typeof(re),typeof(basedist),
            typeof(monte_carlo),typeof(tspan),typeof(args),typeof(kwargs)}(
            model,p,re,basedist,monte_carlo,tspan,args,kwargs)
    end
end

# Builds the Jacobian from one reverse pass per output row i, seeded with the one-hot vector ȳ(i)
function jacobian_fn(f, x::AbstractArray)
   y::AbstractArray, back = Zygote.pullback(f, x)
   ȳ(i) = [i == j for j = 1:size(y,1)]
   _J = [back(ȳ(i))[1] for i = 1:size(y,1)]
   cat(map(x->reshape(x,1,size(x)...), _J)...; dims=1)
end

# u is (d+1)×batch: rows 1:d hold the state z, the last row accumulates -∫tr(∂f/∂z)dt per sample
function ffjord(du,u,p,t,re,monte_carlo,e)
    z = @view u[1:end-1,:]
    m = re(p)
    if monte_carlo
        _, back = Zygote.pullback(m,z)
        eJ = back(e)[1]
        # Hutchinson estimate eᵀJe per sample: sum each column of eJ .* e over its rows
        trace_jac = vec(sum(eJ .* e, dims=1))
    else
        J = jacobian_fn(m, z)
        trace_jac = sum(reshape(J,:,size(J,3))[diagind(size(J,1),size(J,2)),:], dims=1)
    end
    du[1:end-1,:] = m(z)
    du[end,:] = -trace_jac
end

function (n::FFJORD)(x,p=n.p,monte_carlo=n.monte_carlo)
    e = monte_carlo ? randn(Float32,size(x)) : nothing
    ffjord_ = (du,u,p,t)->ffjord(du,u,p,t,n.re,monte_carlo,e)
    # append a zero row for the log-density change and keep only the final time point
    prob = ODEProblem{true}(ffjord_,vcat(x,fill(0f0,1,size(x,2))),n.tspan,p)
    pred = solve(prob,n.args...;n.kwargs...)[:,:,end]
    pz = n.basedist
    z = pred[1:end-1,:]
    delta_logp = pred[end,:]
    # maybe use DistributionsAD
    logpz = [logpdf(pz, z[:,i]) for i in 1:size(z,2)]
    # logpz = log.(pdf(pz, z))
    logpx = logpz .- delta_logp
    return logpx
end
```
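In the Monte Carlo branch, each sample's Hutchinson estimate of tr(J) is eᵀJe. With a d×batch noise matrix `e` and `eJ = back(e)[1]`, and again assuming samples do not interact, that is the sum over the rows of each column of `eJ .* e`; indexing with `[1, :]` would keep only the first row's term. The sketch below checks the reduction on a toy linear map y = W*x (Jacobian W for every sample, pullback ȳ ↦ Wᵀȳ) with Rademacher noise. `hutchinson_batch` is a hypothetical helper name.

```julia
using LinearAlgebra, Random, Statistics

# Per-sample Hutchinson estimates: column k of eJ .* e sums to e_kᵀ J_k e_k
hutchinson_batch(eJ, e) = vec(sum(eJ .* e; dims=1))

rng = MersenneTwister(0)
W = [2.0 1.0; 0.0 3.0]                 # toy Jacobian, tr(W) = 5
d, b = size(W, 1), 10_000
e = rand(rng, [-1.0, 1.0], d, b)       # Rademacher noise, one column per sample
eJ = W' * e                            # what back(e)[1] gives for y = W*x

est = hutchinson_batch(eJ, e)          # length-b vector of estimates
row_only = (eJ .* e)[1, :]             # first row only: 2e₁² = 2 for every sample here

println("mean estimate: ", mean(est), ", exact trace: ", tr(W))
```

For this W each estimate is exactly 5 + e₁e₂, so 4 or 6, and the batch mean settles near tr(W) = 5, while the first-row-only reduction returns 2 for every sample.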

@SebastianCallh
Contributor Author

Glad there is some action on this one! The implementation I'm working on also fails with some internal Zygote errors, and I have yet to figure out why.

@avik-pal
Member

avik-pal commented Jul 29, 2020

The code snippet posted by @nirmal-suthar works if the `sensealg` is changed (definitely not a good long-term solution, though):

```julia
function (n::FFJORD)(x,p=n.p,monte_carlo=n.monte_carlo)
    # one Gaussian noise column shared by every sample, tiled to size d×batch
    e = monte_carlo ? repeat(randn(Float32,(size(x,1),1)),1,size(x,2)) : nothing
    initial_cond = cat(x,zeros(Float32,1,size(x, 2)),dims=1)
    ffjord_ = (du,u,p,t)->ffjord(du,u,p,t,n.re,monte_carlo,e)
    prob = ODEProblem{true}(ffjord_,initial_cond,n.tspan,p)
    # autojacvec = false: the adjoint builds the full Jacobian instead of using a reverse-mode vjp
    sense = InterpolatingAdjoint(autojacvec = false)
    pred = solve(prob,n.args...;sensealg=sense,n.kwargs...)[:,:,end]
    pz = n.basedist
    z = pred[1:end-1,:]
    delta_logp = pred[end,:]
    logpz = [logpdf(pz,z[:,i]) for i in 1:size(x)[end]]
    return logpz .- delta_logp
end
```
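When one noise column is shared by the whole batch, its shape matters: for a matrix, `repeat(v, b)` repeats along the first dimension only and gives a (d·b)×1 array, while `repeat(v, 1, b)` tiles it into d×b. A quick shape check in plain Julia, with arbitrary sizes:

```julia
d, b = 3, 4
v = randn(Float32, d, 1)        # one d×1 noise column

stacked = repeat(v, b)          # repeats along dim 1 only: (d*b)×1
shared  = repeat(v, 1, b)       # repeats along dim 2: d×b, every column equals v

@show size(stacked) size(shared)
```

Sharing one column and drawing independent noise with `randn(Float32, size(x))` both give an unbiased estimate for each sample; sharing only correlates the estimates across the batch.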


5 participants