using Revise, Zygote, Flux

struct mwestruct
    A
    sA
end

function mwestruct(K::Int)
    A = rand(K, K)
    sA = view(A, rand(1:K, K^2), :)
    return mwestruct(A, sA)
end

(m::mwestruct)(x) = sum(m.sA * x)

Flux.@functor mwestruct
Flux.trainable(a::mwestruct) = (a.A,)

function zygotemwe(N=5)
    f = mwestruct(N)
    x = rand(N)
    @info "test f: $(f(x))"
    gs = gradient(() -> f(x), Flux.params(f))
    @info "check gradient: $(gs[f.A])"
end

zygotemwe()
Output:
[ Info: test f: 33.298669413682056
[ Info: check gradient: nothing
Not sure why this doesn't work. This came up in the context of a panel of data with two-way indexing; I can provide more context on the use case if desired.
Since the view is created outside of the gradient context, Zygote (or any AD system, really) has no way of knowing that A and sA are related, so the gradient with respect to A is nothing. That said, the nice thing about a view is that you can update it in place. Just change one line:
Flux.trainable(a::mwestruct) = (a.sA,)
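Because sA shares memory with A, the in-place update the optimizer applies to sA writes through to A, so A will be updated accordingly during the optimizer step.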
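To see why updating sA in place also changes A, here is a small Base-Julia sketch (no Flux or Zygote involved). The index vector `idx`, the fake gradient `g`, and the step size `η` are made up for illustration; the line `sA .-= η .* g` only stands in for what a plain gradient-descent step does to a trainable array, it is not Flux's own code.

```julia
# Sketch, Base Julia only: a view shares memory with its parent array,
# so an in-place update of the view writes through to the parent.
A = zeros(3, 3)
idx = [3, 1]             # hypothetical row indices (unique here)
sA = view(A, idx, :)     # 2x3 view onto rows 3 and 1 of A

g = ones(2, 3)           # hypothetical stand-in for a gradient w.r.t. sA
η = 0.1                  # hypothetical learning rate

sA .-= η .* g            # in-place descent step on the view

@show A[3, :]            # changed through the view
@show A[1, :]            # changed through the view
@show A[2, :]            # not in idx, so untouched

# By contrast, rebinding a name to a new array does not touch A:
sB = view(A, [2], :)
sB = sB .- 1.0           # allocates a fresh Matrix; A[2, :] stays zero
```

The sketch uses unique indices so the write-through is easy to check; with repeated indices (as `rand(1:K, K^2)` produces in the original), several entries of sA alias the same row of A.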