In [10]:
using Enzyme

"""
    mymul!(R, A, B)

Overwrite `R` with the matrix product `A * B` using explicit loops; return `nothing`.

`R` is first filled with zeros and the products `A[i, k] * B[k, j]` are then
accumulated into `R[i, j]`, so whatever `R` held on entry is discarded.

# Throws
- `DimensionMismatch`: if `axes(A, 2) != axes(B, 1)`, or if the first two axes of `R`
  are not `axes(A, 1)` and `axes(B, 2)`.
"""
function mymul!(R, A, B)
    # Every loop below runs under `@inbounds`, so the shapes are validated with real
    # exceptions instead of `@assert`, which is not guaranteed to stay enabled.
    axes(A, 2) == axes(B, 1) ||
        throw(DimensionMismatch("second axis of A must match first axis of B"))
    (axes(R, 1) == axes(A, 1) && axes(R, 2) == axes(B, 2)) ||
        throw(DimensionMismatch("axes of R must be axes(A, 1) and axes(B, 2)"))

    @inbounds @simd for i in eachindex(R)
        R[i] = 0
    end
    # `i` is the innermost loop so that R and A are walked down a column
    # (contiguous in memory); each R[i, j] still accumulates over k in order.
    @inbounds for j in axes(B, 2), k in axes(A, 2)
        b = B[k, j]
        @simd for i in axes(A, 1)
            R[i, j] += A[i, k] * b
        end
    end
    return nothing
end


# Random operands for the product: A is 5×3 and B is 3×7.
A = rand(5, 3)
B = rand(3, 7)

# Output buffer sized (rows of A) × (columns of B); mymul! overwrites it in place.
R = zeros(size(A,1), size(B,2))
∂z_∂R = rand(size(R)...)  # Some gradient/tangent passed to us (random stand-in, same shape as R)

# Zero-initialised arrays shaped like A and B, handed to Enzyme paired with A and B.
∂z_∂A = zero(A)
∂z_∂B = zero(B)

# Each Duplicated(x, ∂x) pairs an array with its companion array; `Const` is passed
# as the second argument. NOTE(review): presumably this runs mymul! and accumulates
# the gradients w.r.t. A and B into ∂z_∂A and ∂z_∂B — confirm against the Enzyme
# version in use.
Enzyme.autodiff(mymul!, Const, Duplicated(R, ∂z_∂R), Duplicated(A, ∂z_∂A), Duplicated(B, ∂z_∂B))


()

In [11]:
# Sanity checks. Each comparison is asserted on its own: written as a bare `&&` chain
# followed by another expression, the Boolean result is computed but never shown.
@assert R ≈ A * B
@assert ∂z_∂A ≈ ∂z_∂R * B'  # equivalent to Zygote.pullback(*, A, B)[2](∂z_∂R)[1]
@assert ∂z_∂B ≈ A' * ∂z_∂R  # equivalent to Zygote.pullback(*, A, B)[2](∂z_∂R)[2]

∂z_∂R

5×7 Matrix{Float64}:
 0.123779   0.329524  0.460547  0.367101  0.275423  0.514569  0.731231
 0.95949    0.57559   0.209324  0.519637  0.276184  0.743882  0.626408
 0.12998    0.747305  0.400288  0.50996   0.25594   0.562621  0.103458
 0.0886988  0.525744  0.315356  0.1598    0.73415   0.175436  0.4385
 0.963093   0.476419  0.271699  0.25318   0.577619  0.144234  0.639175

In [12]:
# Display ∂z_∂A, the array paired with A in the Enzyme.autodiff call.
∂z_∂A

5×3 Matrix{Float64}:
 1.64251  1.17322  1.33067
 2.49763  1.96165  1.82145
 1.65402  1.1845   0.961165
 1.51159  1.07727  0.930032
 2.3609   1.63768  1.40276

In [13]:
# Display ∂z_∂B, the array paired with B in the Enzyme.autodiff call.
∂z_∂B

3×7 Matrix{Float64}:
 0.773711  1.27368   0.89502   0.976215  0.884031  1.27082   1.27129
 0.801969  1.4395    0.905708  0.938414  1.07071   1.09065   1.09024
 0.541953  0.800244  0.424186  0.433145  0.808946  0.556889  0.714365

In [14]:
# Display R, the output buffer that mymul! overwrites.
R

5×7 Matrix{Float64}:
 1.24336   0.719504  0.97784   0.741751  0.851008  0.58853   0.581777
 0.967855  0.518017  0.71092   0.536223  0.662391  0.700023  0.693226
 1.20892   0.638486  0.822161  0.523279  0.886068  0.853798  0.534743
 1.31441   0.634505  0.841057  0.556395  0.943817  1.29484   1.04883
 0.253377  0.114004  0.129103  0.041746  0.206314  0.268417  0.0918177

In [15]:
# Display ∂z_∂R, the array paired with R in the Enzyme.autodiff call.
∂z_∂R

5×7 Matrix{Float64}:
 0.123779   0.329524  0.460547  0.367101  0.275423  0.514569  0.731231
 0.95949    0.57559   0.209324  0.519637  0.276184  0.743882  0.626408
 0.12998    0.747305  0.400288  0.50996   0.25594   0.562621  0.103458
 0.0886988  0.525744  0.315356  0.1598    0.73415   0.175436  0.4385
 0.963093   0.476419  0.271699  0.25318   0.577619  0.144234  0.639175

In [18]:
using Zygote
# Cross-check with Zygote: take the pullback of `*` at (A, B), apply it to ∂z_∂R and
# keep the second component. The matrix printed below equals the ∂z_∂B shown earlier.
Zygote.pullback(*, A, B)[2](∂z_∂R)[2]

3×7 Matrix{Float64}:
 0.773711  1.27368   0.89502   0.976215  0.884031  1.27082   1.27129
 0.801969  1.4395    0.905708  0.938414  1.07071   1.09065   1.09024
 0.541953  0.800244  0.424186  0.433145  0.808946  0.556889  0.714365