Hi, I am trying to implement a physics-informed neural network (PINN), as described here, using Flux. Essentially, I am training a neural network whose loss function includes the time derivative of the network's output (time is one of its inputs). Below is a minimal example:
```julia
using Flux

m = Chain(Dense(3, 10, relu), Dense(10, 10, relu), Dense(10, 1))  # [u0, k, t] -> u(t)
ps = Flux.params(m)

function loss(x, y)
    fitloss = Flux.Losses.mse(m(x), y)                       # typical loss function
    derivativeloss = abs2(gradient(a -> m(a)[1], x)[1][3])   # problem source (3rd input is time)
    return fitloss + derivativeloss
end

xt = rand(3)
yt = rand(1)

gs = gradient(ps) do
    loss(xt, yt)
end  # this generates a foreigncall exception
```
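For reference, here is the loss the minimal example computes, written out. The input is x = (u0, k, t), the network output is u_θ(x), and `mse` on a one-element vector reduces to a single squared error. The second line applies the chain rule to the derivative term. Its last factor is a mixed second derivative, so the outer `gradient(ps)` call must differentiate through the inner `gradient` call. That nested reverse-over-reverse step is the line marked as the problem source above.

$$
\mathcal{L}(\theta) = \big(u_\theta(x) - y\big)^2 + \left(\frac{\partial u_\theta}{\partial t}(x)\right)^2, \qquad x = (u_0, k, t)
$$

$$
\nabla_\theta \mathcal{L} = 2\big(u_\theta(x) - y\big)\,\nabla_\theta u_\theta(x) + 2\,\frac{\partial u_\theta}{\partial t}(x)\,\nabla_\theta\!\left(\frac{\partial u_\theta}{\partial t}(x)\right)
$$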
This issue seems to be pervasive; see here, #1338, #1257, and here (the last one is my own post on the Discourse forum). I have tried all the suggestions in those links, but none of them work. Is there a workaround, or is this a built-in limitation of Flux/Zygote?
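One possible workaround avoids nested AD entirely by approximating ∂u/∂t with a central finite difference, so only a single level of `gradient` is needed. This is a sketch under stated assumptions, not an official Flux fix. Several pieces are illustrative choices: the helper `dudt_fd`, the constant `E3`, and the step size `h` are hypothetical. The sketch also uses Flux's explicit-gradient style (`gradient(model -> ..., m)`) instead of `Flux.params`, along with the `Dense(in => out, σ)` constructor, which assumes Flux 0.13 or later.

```julia
using Flux

# Same architecture as in the issue: [u0, k, t] -> u(t)
m = Chain(Dense(3 => 10, relu), Dense(10 => 10, relu), Dense(10 => 1))

# Hypothetical constant: unit vector selecting the time input (3rd entry)
const E3 = Float32[0, 0, 1]

# Hypothetical helper: central finite-difference estimate of du/dt.
# The default step size h is an illustrative assumption; tune it for your problem.
function dudt_fd(model, x; h = 1f-3)
    up = model(x .+ h .* E3)[1]     # u at t + h
    down = model(x .- h .* E3)[1]   # u at t - h
    return (up - down) / (2h)
end

function loss_fd(model, x, y)
    fitloss = Flux.Losses.mse(model(x), y)     # data-fit term, as in the issue
    derivativeloss = abs2(dudt_fd(model, x))   # derivative term, no inner `gradient` call
    return fitloss + derivativeloss
end

xt = rand(Float32, 3)
yt = rand(Float32, 1)

# A single reverse-mode pass, taken with respect to the model itself
gs = gradient(model -> loss_fd(model, xt, yt), m)
```

The trade-off is that the derivative becomes an approximation. Its accuracy depends on `h`: a large step adds truncation error, while a very small step amplifies Float32 round-off.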