Status: Closed
Labels: up for grabs (Looking for a volunteer to implement this feature)
Description
using Interpolations, Zygote
y = [[0 0; 0 0], [0 1; 0 0]]                # vector of 2×2 matrices
itp_m = interpolate(y, BSpline(Linear()))   # matrix-valued linear interpolant
Zygote.gradient((x)->sum(itp_m(x)), 1)
ERROR: DimensionMismatch("matrix A has dimensions (2,2), vector B has length 1")
Stacktrace:
[1] generic_matvecmul!(C::Vector{Matrix{Int64}}, tA::Char, A::FillArrays.Fill{Int64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, B::StaticArrays.SVector{1, Matrix{Int64}}, _add::LinearAlgebra.MulAddMul{true, true, Bool, Bool})
@ LinearAlgebra /usr/local/stow/julia-1.7.2/share/julia/stdlib/v1.7/LinearAlgebra/src/matmul.jl:713
[2] mul!
@ /usr/local/stow/julia-1.7.2/share/julia/stdlib/v1.7/LinearAlgebra/src/matmul.jl:81 [inlined]
[3] mul!
@ /usr/local/stow/julia-1.7.2/share/julia/stdlib/v1.7/LinearAlgebra/src/matmul.jl:275 [inlined]
[4] *
@ /usr/local/stow/julia-1.7.2/share/julia/stdlib/v1.7/LinearAlgebra/src/matmul.jl:51 [inlined]
[5] interpolate_pullback
@ ~/.julia/dev/Interpolations/src/chainrules/chainrules.jl:13 [inlined]
[6] ZBack
@ ~/.julia/packages/Zygote/H6vD3/src/compiler/chainrules.jl:204 [inlined]
[7] Pullback
@ ./REPL[33]:1 [inlined]
[8] (::typeof(∂(#23)))(Δ::Int64)
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
[9] (::Zygote.var"#56#57"{typeof(∂(#23))})(Δ::Int64)
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:41
[10] gradient(f::Function, args::Int64)
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:76
[11] top-level scope
@ REPL[33]:1
[12] top-level scope
@ ~/.julia/packages/CUDA/Uurn4/src/initialization.jl:52
I tried this fix to the pullback in src/chainrules/chainrules.jl:
(nope, Iterators.flatten(SVector((Δy,)) .* Interpolations.gradient(itp, x...))...)
This fixes the shape mismatch, but returns a gradient of 0 for this case.
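For reference, the expected value can be worked out by hand. On the interval [1, 2] the linear B-spline interpolant of y = [y_1, y_2] is a straight blend of the two matrices, so the derivative of the sum of its entries is the sum of the entries of y_2 - y_1, which is 1, not 0. At the knot x = 1 this is the right-hand derivative, assuming the gradient is evaluated on the cell [1, 2]. The last line sketches what the pullback presumably needs for matrix-valued data: a Frobenius inner product of the incoming cotangent (here a Fill of ones, as the stacktrace shows) with the gradient. Broadcasting `*` instead multiplies the two 2×2 matrices, and flattening that product spreads its entries across the returned tuple, which would plausibly explain the 0.

```latex
\mathrm{itp}(x) = (2 - x)\,y_1 + (x - 1)\,y_2, \qquad 1 \le x \le 2

\frac{d}{dx}\sum_{i,j}\mathrm{itp}(x)_{ij}
  = \sum_{i,j}(y_2 - y_1)_{ij}
  = \sum_{i,j}\begin{pmatrix}0 & 1\\ 0 & 0\end{pmatrix}_{ij}
  = 1

\bar{x} = \big\langle \bar{y},\, \partial_x \mathrm{itp}(x) \big\rangle_F
  = \sum_{i,j} \bar{y}_{ij}\,(y_2 - y_1)_{ij},
  \qquad \bar{y}_{ij} = 1 \;\Rightarrow\; \bar{x} = 1
```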