AD does not follow through views in structs #598

Closed
clintonTE opened this issue Apr 13, 2020 · 2 comments

Comments

@clintonTE

using Revise, Zygote, Flux

struct mwestruct
  A
  sA
end

function mwestruct(K::Int)
  A = rand(K,K)
  sA = view(A, rand(1:K, K^2), :)

  return mwestruct(A,sA)
end

(m::mwestruct)(x) = sum(m.sA * x)

Flux.@functor mwestruct
Flux.trainable(a::mwestruct) = (a.A,)

function zygotemwe(N=5)

  f = mwestruct(N)
  x = rand(N)
  @info "test f: $(f(x))"

  gs = gradient(()->f(x), Flux.params(f))
  @info "check gradient: $(gs[f.A])"
end

zygotemwe()

Output:

[ Info: test f: 33.298669413682056
[ Info: check gradient: nothing

Not sure why this doesn't work. This came up in the context of a panel of data with two-way indexing. I can provide more context on the use case if desired.

@clintonTE
Author

Workaround: create a no-op function and give it the adjoint of view. Not ideal due to the performance hit, but it works.

using Revise, Zygote, Flux

zygote_pls_notice_me(A, sA) = sA
Zygote.@adjoint zygote_pls_notice_me(A, sA) = sA, Zygote.∇getindex(A, parentindices(sA))
Zygote.refresh()

struct mwestruct
  A
  sA
end

function mwestruct(K::Int)
  A = rand(K,K)
  sA = view(A, rand(1:K, K^2), :)

  return mwestruct(A,sA)
end

(m::mwestruct)(x) = sum(zygote_pls_notice_me(m.A, m.sA) * x)

Flux.@functor mwestruct
Flux.trainable(a::mwestruct) = (a.A,)

function zygotemwe(N=3)

  f = mwestruct(N)
  x = rand(N)
  @info "test f: $(f(x))"

  gs = gradient(()->f(x), Flux.params(f))
  @info "check gradient: $(gs[f.A])"
end

zygotemwe()

Output:

[ Info: test f: 4.73753755344574
[ Info: check gradient: [2.2529505116724446 3.3550898998799275 4.398882538836311; 0.9011802046689779 1.342035959951971 1.7595530155345243; 0.9011802046689779 1.342035959951971 1.7595530155345243]

@ToucheSir
Member

Since the creation of the view happens outside of the gradient context, Zygote (or any AD, really) has no idea A and sA are related. That said, the nice thing about a view is that you can update it in-place. Just change one line:

Flux.trainable(a::mwestruct) = (a.sA,)

And A will be updated accordingly during the optimizer step.
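
To illustrate the point about in-place updates, here is a minimal sketch in plain Julia, with no Zygote or Flux involved. The row selection `rows` and the update `delta` are hypothetical stand-ins, not values from the MWE; `delta` plays the role of an optimizer step applied to the view only.

```julia
# A view shares memory with its parent, so writing through the view
# changes the parent array as well.
A = zeros(3, 3)
rows = [3, 1]                  # hypothetical row selection, no repeated indices
sA = view(A, rows, :)

# Stand-in for an optimizer step: an in-place update applied to the view.
delta = fill(0.5, size(sA))    # hypothetical update, not a real gradient
sA .-= delta

# The selected rows of the parent changed; row 2 was never touched.
@assert parent(sA) === A
@assert A[1, :] == fill(-0.5, 3)
@assert A[3, :] == fill(-0.5, 3)
@assert A[2, :] == zeros(3)
```

Note that the MWE draws its rows with `rand(1:K, K^2)`, so several rows of the view can alias the same row of `A`. The sketch above avoids repeated indices on purpose and does not show how such aliased rows combine under an in-place update.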
