Derivative in loss function error #1464

@stelmo

Hi, I am trying to implement a physics-informed neural network (PINN) as described here using Flux. Essentially, I am trying to train a neural network whose loss function includes the derivative of the network output with respect to time (time is one of the network's inputs). Below is a minimal example:

using Flux

m = Chain(Dense(3, 10, relu), Dense(10, 10, relu), Dense(10, 1)) # [u0, k, t] -> u(t)
ps = Flux.params(m)

function loss(x, y)
    fitloss = Flux.Losses.mse(m(x), y) # typical loss function

    derivativeloss = abs2(gradient(a -> m(a)[1], x)[1][3]) # problem source: nested gradient w.r.t. the 3rd input (time)

    return fitloss + derivativeloss
end

xt = rand(3)
yt = rand(1)

gs = gradient(ps) do
    loss(xt, yt)
end # this generates a foreigncall exception

This issue seems to be pervasive: see here, #1338, #1257, and here (the last one is my own post on Discourse). I have tried all the suggestions in those links, but none of them work. Is there a workaround, or is this a built-in limitation of Flux/Zygote?
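
For reference, one possible workaround is sketched below. It is an assumption on my part, not something taken from the linked threads: the time derivative is approximated with a central finite difference, so only a single level of reverse-mode differentiation is needed. The names `loss_fd`, `ϵ` and `dt` are hypothetical, the step size is an untuned guess, and the result only approximates ∂u/∂t rather than computing it exactly.

```julia
using Flux

m = Chain(Dense(3, 10, relu), Dense(10, 10, relu), Dense(10, 1)) # [u0, k, t] -> u(t)
ps = Flux.params(m)

ϵ = 1f-3                  # assumed finite-difference step (untuned)
dt = Float32[0, 0, ϵ]     # perturbs only the 3rd input (time)

# Hypothetical variant of `loss`: du/dt is approximated by a central
# difference instead of a nested `gradient` call.
function loss_fd(x, y)
    fitloss = Flux.Losses.mse(m(x), y)                # typical loss function
    dudt = (m(x .+ dt)[1] - m(x .- dt)[1]) / (2ϵ)     # ≈ ∂u/∂t at x
    return fitloss + abs2(dudt)
end

xt = rand(Float32, 3)
yt = rand(Float32, 1)

# Only one level of differentiation: Zygote differentiates through the
# three forward passes of `m` inside `loss_fd`.
gs = gradient(ps) do
    loss_fd(xt, yt)
end
```

The trade-off is accuracy: the derivative term now carries a truncation error that depends on `ϵ`, so this is a stopgap rather than a replacement for true nested differentiation.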