Here I present a correct (but poor) implementation of backpropagation (BP) for SVD. It changes the original `svd` interface a bit; I hope someone can help improve it.
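In formula form, the `svd_back` function in the code that follows evaluates the usual reverse-mode rule for the real thin SVD $A = U S V^\top$ with $S = \operatorname{diag}(s)$, assuming distinct, nonzero singular values. Here $\bar U$, $\bar s$, $\bar V$ stand for `dU`, `dS`, `dV`, and $\circ$ is the elementwise product:

$$
\bar A = U\Big[\big(F\circ(U^\top\bar U - \bar U^\top U)\big)S + S\big(F\circ(V^\top\bar V - \bar V^\top V)\big) + \operatorname{diag}(\bar s)\Big]V^\top + (I - UU^\top)\,\bar U S^{-1} V^\top + U S^{-1}\,\bar V^\top (I - VV^\top)
$$

$$
F_{ij} = \frac{s_j^2 - s_i^2}{(s_j^2 - s_i^2)^2 + 10^{-12}}
$$

For $i \neq j$ this gives $F_{ij} \approx 1/(s_j^2 - s_i^2)$; on the diagonal it is exactly $0$, and the $10^{-12}$ term keeps $F$ finite when two singular values nearly coincide, which is presumably what "stabilized" in the docstring refers to.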
```julia
using LinearAlgebra
using Flux
using Flux.Tracker: @grad, data, track, TrackedTuple
import Flux.Tracker: _forward
import LinearAlgebra: svd

"""stabilized back propagation function for svd"""
function svd_back(U, S, V, dU, dS, dV)
    NS = length(S)
    S2 = S.^2
    Sinv = 1 ./ S
    F = S2' .- S2
    @. F = F/(F^2 + 1e-12)
    UdU = U'*dU
    VdV = V'*dV
    Su = (F.*(UdU-UdU'))*Diagonal(S)
    Sv = Diagonal(S) * (F.*(VdV-VdV'))
    U * (Su + Sv + Diagonal(dS)) * V' +
    (I - U*U') * dU*Diagonal(Sinv) * V' +
    U*Diagonal(Sinv) * dV' * (I - V*V')
end

svd(a::TrackedArray) = track(svd, a)

# I suspect the `@grad` macro interface is less intuitive than `_forward`
function _forward(::typeof(svd), a)
    # `svd` returns an `SVD` object rather than a tuple, which makes Julians' lives harder.
    U, S, V = svd(data(a))
    # Returning a list won't work; one would get zero gradient:
    # [U|>param, S|>param, V|>param], -> (svd_back(U, S, V, dU, dS, dV),)
    (U, S, Matrix(V)), Δ -> (svd_back(U, S, V, Δ...),)
end

# This is a use case
M, N = 4, 6
K = min(M, N)
A = param(randn(M, N))
res = svd(A)
# implementing `Base.iterate(res::TrackedTuple)` would make this prettier
U, S, V = res[1], res[2], res[3]
dU, dS, dV = randn(M, K), randn(K), randn(N, K)
Tracker.back!(res, (dU, dS, dV))
Tracker.grad(A)
```
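The formula in `svd_back` can also be sanity-checked without Tracker. The sketch below is a minimal check, assuming only the `LinearAlgebra` and `Random` standard libraries: it copies `svd_back` from above, builds a test loss that is invariant under the sign ambiguity of singular vectors (the `loss` function and the arrays `B`, `c`, `d`, `E` are hypothetical, introduced only for this check), and compares the directional derivative from `svd_back` against a central finite difference.

```julia
using LinearAlgebra
using Random

# Copy of `svd_back` from the issue (no Tracker needed).
function svd_back(U, S, V, dU, dS, dV)
    S2 = S .^ 2
    Sinv = 1 ./ S
    F = S2' .- S2
    @. F = F / (F^2 + 1e-12)
    UdU = U' * dU
    VdV = V' * dV
    Su = (F .* (UdU - UdU')) * Diagonal(S)
    Sv = Diagonal(S) * (F .* (VdV - VdV'))
    U * (Su + Sv + Diagonal(dS)) * V' +
        (I - U * U') * dU * Diagonal(Sinv) * V' +
        U * Diagonal(Sinv) * dV' * (I - V * V')
end

# Hypothetical test loss: sum(c .* S) + tr(B' * U * D * V') with D = Diagonal(d).
# U * D * V' is unchanged when a pair (u_i, v_i) flips sign, so the loss is
# well defined even though the SVD's sign choice is arbitrary.
function loss(A, B, c, d)
    U, S, V = svd(A)
    sum(c .* S) + tr(B' * U * Diagonal(d) * V')
end

Random.seed!(0)
M, N = 4, 6
K = min(M, N)
A = randn(M, N)
B = randn(M, N)
c = randn(K)
d = randn(K)

U, S, V = svd(A)
# Gradients of `loss` with respect to U, S and V, derived by hand.
dU = B * V * Diagonal(d)
dS = c
dV = B' * U * Diagonal(d)
G = svd_back(U, S, V, dU, dS, dV)

# Central finite difference along a random direction E.
E = randn(M, N)
h = 1e-6
fd = (loss(A + h * E, B, c, d) - loss(A - h * E, B, c, d)) / (2h)
println("finite difference: ", fd, "  svd_back: ", sum(G .* E))
```

If the two printed numbers agree to several digits, `svd_back` is consistent with the loss at this random point. A gauge-invariant loss is used on purpose: a loss such as `sum(dU .* U)` would depend on the arbitrary signs LAPACK picks, so finite differences could be meaningless.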
Why do we use `Matrix(V)` here?
We see that this line in file `src/tracker/scalar.jl` is called.
I forgot to label this with "move to tracker". I labeled a few Tracker-related issues, but since I'm not authorized to move them to Tracker.jl, I just closed them. Ideally, they should all be moved there and reopened (if we still care about Tracker; I honestly don't).
Regarding `Matrix(V)`: one should notice that the function `zero` can change type sometimes! Here, `svd` returns `V` as an `Adjoint`, and `zero` of an `Adjoint` gives an `Array`. Gotcha!
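The type change can be inspected with `LinearAlgebra` alone. This is a minimal sketch; the exact printed type names depend on the Julia version, so no output is shown here.

```julia
using LinearAlgebra

F = svd(randn(4, 6))
V = F.V   # `SVD` stores Vt; `F.V` is a lazy `Adjoint` wrapper around it

# The issue reports that `zero` on an `Adjoint` matrix returns a plain `Array`,
# i.e. the result does not have the same type as its input.
println("typeof(V)       = ", typeof(V))
println("typeof(zero(V)) = ", typeof(zero(V)))

# Materializing first (as `_forward` does with `Matrix(V)`) gives a plain
# `Matrix`, for which `zero` returns the same type.
Vm = Matrix(V)
println("typeof(Vm)       = ", typeof(Vm))
println("typeof(zero(Vm)) = ", typeof(zero(Vm)))
```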
Some aspects can be improved:

- `_forward` is necessary, so that a readable error message can be thrown.
- `@grad` is an arguably useful interface.
- `svd` returning an `SVD` object: I didn't see many benefits of such a design.
- `zero` and `one` should never change type; here, it should be considered a bug.