FFJORD for batched data #342
Comments
Oh yes, that is worth looking into. We hadn't tried it yet, so it is just an omission.
I just discussed this with @ChrisRackauckas, and it would only be a matter of making sure that passing the data as a matrix instead of an array of vectors doesn't break it.
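For reference, here is a minimal sketch of what "a matrix instead of an array of vectors" could look like on the data side. The sample values and the helper name `to_batch_matrix` are made up for illustration; only Base Julia is used.

```julia
# Hypothetical data: a batch stored as an array of vectors, one sample per entry.
samples = [Float32[1, 2], Float32[3, 4], Float32[5, 6]]

# Stack the samples as columns, giving a (features × batch) matrix.
# `to_batch_matrix` is an illustrative helper, not part of any package.
to_batch_matrix(xs) = reduce(hcat, xs)

X = to_batch_matrix(samples)   # 2×3 Matrix{Float32}

# Append one zero row to hold the accumulated log-density change,
# so every column carries its own augmented state.
u0 = vcat(X, zeros(Float32, 1, size(X, 2)))   # 3×3 Matrix{Float32}
```

With this layout, column `k` of `u0` is the augmented state of sample `k`, which is the convention the batched snippets in this thread assume.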
I agree it seems like a fairly simple change. Though I think we would have to use DistributionsAD to differentiate through the |
Yeah, it was the indexing scheme I was referring to. Cool, then I'll look into it.
Hello, I have tried tweaking FFJORD to support batched data. The forward pass runs fine, but Zygote throws a size-related error that I have not been able to figure out. I will post the MWE and the exact error shortly.

    struct FFJORD{M,P,RE,Distribution,Bool,T,A,K} <: CNFLayer
        model::M
        p::P
        re::RE
        basedist::Distribution
        monte_carlo::Bool
        tspan::T
        args::A
        kwargs::K

        function FFJORD(model,tspan,args...;p = nothing,basedist=nothing,monte_carlo=false,kwargs...)
            _p,re = Flux.destructure(model)
            if p === nothing
                p = _p
            end
            if basedist === nothing
                # default base distribution: standard normal over the input dimension
                size_input = size(model[1].W)[2]
                basedist = MvNormal(zeros(size_input), I + zeros(size_input,size_input))
            end
            new{typeof(model),typeof(p),typeof(re),typeof(basedist),
                typeof(monte_carlo),typeof(tspan),typeof(args),typeof(kwargs)}(
                model,p,re,basedist,monte_carlo,tspan,args,kwargs)
        end
    end

    function jacobian_fn(f, x::AbstractArray)
        y::AbstractArray, back = Zygote.pullback(f, x)
        ȳ(i) = [i == j for j = 1:size(y,1)]
        _J = [back(ȳ(i))[1] for i = 1:size(y,1)]
        cat(map(x->reshape(x,1,size(x)...), _J)...; dims=1)
    end

    function ffjord(du,u,p,t,re,monte_carlo,e)
        # last row of u holds the log-density change; the rest is the batch of states
        z = @view u[1:end-1,:]
        m = re(p)
        if monte_carlo
            _, back = Zygote.pullback(m,z)
            eJ = back(e)[1]
            # Hutchinson estimate e' * J * e per column: sum over the feature dimension
            # (taking only row 1 would drop all but the first component)
            trace_jac = sum(eJ.*e, dims=1)
        else
            J = jacobian_fn(m, z)
            trace_jac = sum(reshape(J,:,size(J,3))[diagind(size(J,1),size(J,2)),:], dims=1)
        end
        du[1:end-1,:] = m(z)
        du[end,:] = -trace_jac
    end

    function (n::FFJORD)(x,p=n.p,monte_carlo=n.monte_carlo)
        e = monte_carlo ? randn(Float32,size(x)) : nothing
        ffjord_ = (du,u,p,t)->ffjord(du,u,p,t,n.re,monte_carlo,e)
        prob = ODEProblem{true}(ffjord_,vcat(x,fill(0f0,1,size(x,2))),n.tspan,p)
        pred = solve(prob,n.args...;n.kwargs...)[:,:,end]
        pz = n.basedist
        z = pred[1:end-1,:]
        delta_logp = pred[end,:]
        # maybe use DistributionsAD
        logpz = [logpdf(pz, z[:,i]) for i in 1:size(z,2)]
        # logpz = log.(pdf(pz, z))
        logpx = logpz .- delta_logp
        return logpx
    end
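A small sanity check on the Monte Carlo branch may help here. In the batched case the trace estimate has to be one scalar per column, namely `e' * J * e` for that sample. The sketch below avoids Zygote by assuming hypothetical linear dynamics `f(z) = W * z`, whose vector-Jacobian product is known in closed form. The names `hutchinson_trace`, `W`, and `vjp` are illustrative only.

```julia
using LinearAlgebra

# Per-column trace estimate: entry k is e_k' * J_k * e_k.
# `eJ` holds the vector-Jacobian products as columns (d × N),
# and `e` holds the probe vectors (d × N).
hutchinson_trace(eJ, e) = vec(sum(eJ .* e, dims=1))

# Hypothetical linear dynamics f(z) = W * z: the Jacobian is W for every sample,
# so the vector-Jacobian product can be written without an AD package.
W = Float32[1 2 0; 0 3 1; 4 0 5]
vjp(e) = W' * e            # column k is (e_k' * W)'

# With the standard basis vectors as probes, each estimate is one diagonal
# entry of W, so the estimates sum to the exact trace.
E = Matrix{Float32}(I, 3, 3)
est = hutchinson_trace(vjp(E), E)

# With Gaussian probes, the average over many columns approximates tr(W).
E_rand = randn(Float32, 3, 1000)
approx = sum(hutchinson_trace(vjp(E_rand), E_rand)) / size(E_rand, 2)
```

Comparing `sum(est)` with `tr(W)` gives an exact check, while `approx` only matches `tr(W)` up to sampling noise.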
Glad there is some action on this one! The implementation I'm working on also fails with some internal Zygote errors, and I've yet to figure out why.
The code snippet posted by @nirmal-suthar works if the sensealg is changed (definitely not a good long-term solution, though):

    function (n::FFJORD)(x,p=n.p,monte_carlo=n.monte_carlo)
        # one noise vector shared by the whole batch: repeat along the batch (column) dimension
        e = monte_carlo ? repeat(randn(Float32,size(x,1),1),1,size(x,2)) : nothing
        initial_cond = cat(x,zeros(Float32,1,size(x, 2)),dims=1)
        ffjord_ = (du,u,p,t)->ffjord(du,u,p,t,n.re,monte_carlo,e)
        prob = ODEProblem{true}(ffjord_,initial_cond,n.tspan,p)
        # workaround: disable the automatic vector-Jacobian product in the adjoint
        sense = InterpolatingAdjoint(autojacvec = false)
        pred = solve(prob,n.args...;sensealg=sense,n.kwargs...)[:,:,end]
        pz = n.basedist
        z = pred[1:end-1,:]
        delta_logp = pred[end,:]
        logpz = [logpdf(pz,z[:,i]) for i in 1:size(x)[end]]
        return logpz .- delta_logp
    end
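One possible way to drop the per-column comprehension for `logpz` is to evaluate the base density on the whole matrix at once. This is only a sketch and assumes the default base distribution from the constructor above, i.e. a standard normal. A custom `basedist` would need its own formula or a differentiable `logpdf`. The name `std_normal_logpdf` is hypothetical.

```julia
# Log-density of a standard normal N(0, I) at every column of z (d × N).
# Assumes the default base distribution; not valid for a custom `basedist`.
function std_normal_logpdf(z::AbstractMatrix)
    d = size(z, 1)
    T = eltype(z)
    # -||z_k||^2 / 2 - (d / 2) * log(2π) for each column k
    return -vec(sum(abs2, z; dims=1)) ./ 2 .- T(d * log(2π) / 2)
end

z = Float32[0 1; 0 2]
logpz = std_normal_logpdf(z)   # one value per column
```

Because it uses only broadcasting and a reduction, there is no indexing over columns left in this step.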
I looked into the FFJORD implementation and tests and noticed that it does not seem to work with batched training. I started looking into how to do this, but since it is not done in this implementation, perhaps the problem is harder than I realize. Could someone shed some light on the situation? @d-netto perhaps?