Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 64 additions & 40 deletions src/exp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,59 +120,83 @@ function naivemul!(C, A, B)
mstep = step(Maxis)
# I don't want to deal with axes having non-unit step
if nstep == mstep == 1
naivemul!(C, A, B, Maxis, Naxis)
if sizeof(eltype(C)) > 256
naivemul!(C, A, B, Maxis, Naxis, Val(1), Val(1))
elseif sizeof(eltype(C)) > 128
naivemul!(C, A, B, Maxis, Naxis, Val(2), Val(1))
elseif sizeof(eltype(C)) > 96
naivemul!(C, A, B, Maxis, Naxis, Val(4), Val(1))
elseif sizeof(eltype(C)) > 64
naivemul!(C, A, B, Maxis, Naxis, Val(4), Val(2))
else
naivemul!(C, A, B, Maxis, Naxis, Val(4), Val(3))
end
else
mul!(C,A,B)
end
end
"""
    _const(A)

Return the object the multiplication kernels index into when reading an operand.

A plain `Array` is wrapped in `Base.Experimental.Const`; any other argument is
returned unchanged. The kernels call this on `A` and `B` inside a
`Base.Experimental.@aliasscope` block.
"""
function _const(A)
    return A
end

function _const(A::Array)
    return Base.Experimental.Const(A)
end
# Separated to make it easier to test.
# NOTE(review): from here down, two versions of the kernel appear interleaved line by line, as in
# a rendered diff without +/- markers: a fixed-size one (unrolled 4 along m, 2 along n) and an
# `@generated` one (unrolled MU along m, NU along n). Presumably only one of them exists in the
# real file — confirm against the repository before editing any code here.
function naivemul!(C::AbstractMatrix{T}, A, B, Maxis, Naxis) where {T}
N = last(Naxis)
M = last(Maxis)
Kaxis = axes(B,1)
Base.Experimental.@aliasscope begin
n = first(Naxis)-1
@inbounds begin
while n < N - 1
m = first(Maxis)-1
while m < M - 3
Base.Cartesian.@nexprs 2 j -> Base.Cartesian.@nexprs 4 i -> Cmn_i_j = zero(T)
for k ∈ Kaxis
Base.Cartesian.@nexprs 2 j -> Base.Cartesian.@nexprs 4 i -> Cmn_i_j = muladd(_const(A)[m+i,k],_const(B)[k,n+j],Cmn_i_j)
end
Base.Cartesian.@nexprs 2 j -> Base.Cartesian.@nexprs 4 i -> C[m+i,n+j] = Cmn_i_j
m += 4
end
for mm ∈ 1+m:M
Base.Cartesian.@nexprs 2 j -> Cmn_j = zero(T)
for k ∈ Kaxis
Base.Cartesian.@nexprs 2 j -> Cmn_j = muladd(_const(A)[mm,k],_const(B)[k,n+j],Cmn_j)
end
Base.Cartesian.@nexprs 2 j -> C[mm,n+j] = Cmn_j
end
n += 2
# Generated method: MU and NU arrive as `Val` type parameters, so the `$MU` / `$NU` splices
# below become literal unroll counts in the emitted code.
@generated function naivemul!(C::AbstractMatrix{T}, A, B, Maxis, Naxis, ::Val{MU}, ::Val{NU}) where {T,MU,NU}
# Body that fills one leftover column `nn`: rows are handled MU at a time, then one at a time.
nrem_body = quote
m = first(Maxis)-1
while m < M - $(MU-1)
Base.Cartesian.@nexprs $MU i -> Cmn_i = zero(T)
for k ∈ Kaxis
Base.Cartesian.@nexprs $MU i -> Cmn_i = muladd(_const(A)[m+i,k],_const(B)[k,nn],Cmn_i)
end
m = first(Maxis)-1
while m < M - 3
Base.Cartesian.@nexprs 4 i -> Cmn_i = zero(T)
for k ∈ Kaxis
Base.Cartesian.@nexprs 4 i -> Cmn_i = muladd(_const(A)[m+i,k],_const(B)[k,N],Cmn_i)
end
Base.Cartesian.@nexprs 4 i -> C[m+i,N] = Cmn_i
m += 4
Base.Cartesian.@nexprs $MU i -> C[m+i,nn] = Cmn_i
m += $MU
end
for mm ∈ 1+m:M
Cmn = zero(T)
for k ∈ Kaxis
Cmn = muladd(_const(A)[mm,k], _const(B)[k,nn], Cmn)
end
for mm ∈ 1+m:M
Cmn = zero(T)
for k ∈ Kaxis
Cmn = muladd(_const(A)[mm,k], _const(B)[k,N], Cmn)
C[mm,nn] = Cmn
end
end
# With NU > 2 the remainder body is wrapped in a loop over columns 1+n:N; otherwise it is run
# once with nn = N.
nrem_quote = if NU > 2
:(for nn ∈ 1+n:N; $nrem_body; end)
else
:(let nn = N; $nrem_body; end)
end
quote
N = last(Naxis)
M = last(Maxis)
Kaxis = axes(B,1)
Base.Experimental.@aliasscope begin
n = first(Naxis)-1
@inbounds begin
# Main loop: advance n by NU columns for as long as NU whole columns remain.
while n < N - $(NU-1)
m = first(Maxis)-1
while m < M - $(MU-1)
Base.Cartesian.@nexprs $NU j -> Base.Cartesian.@nexprs $MU i -> Cmn_i_j = zero(T)
for k ∈ Kaxis
# Read the MU entries of A and each entry of B once per k, then reuse them for all
# MU×NU accumulators.
Base.Cartesian.@nexprs $MU i -> Ak_i = _const(A)[m+i,k]
Base.Cartesian.@nexprs $NU j -> begin
Bk_j = _const(B)[k,n+j]
Base.Cartesian.@nexprs $MU i -> Cmn_i_j = muladd(Ak_i, Bk_j, Cmn_i_j)
end
end
Base.Cartesian.@nexprs $NU j -> Base.Cartesian.@nexprs $MU i -> C[m+i,n+j] = Cmn_i_j
m += $MU
end
# Leftover rows (fewer than MU) for the current NU columns, one row at a time.
for mm ∈ 1+m:M
Base.Cartesian.@nexprs $NU j -> Cmn_j = zero(T)
for k ∈ Kaxis
Base.Cartesian.@nexprs $NU j -> Cmn_j = muladd(_const(A)[mm,k],_const(B)[k,n+j],Cmn_j)
end
Base.Cartesian.@nexprs $NU j -> C[mm,n+j] = Cmn_j
end
n += $NU
end
C[mm,N] = Cmn
# The column remainder is spliced in only when NU > 1; with NU == 1 the main loop steps
# one column at a time up to N.
$(NU > 1 ? nrem_quote : nothing)
end
end
C
end
C
end

"""
Expand Down
9 changes: 7 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,17 @@ end
A = rand(n,n);
B = rand(n,n);
C = similar(A);
@test ExponentialUtilities.naivemul!(C, A, B, axes(C)...) ≈ A*B
AB = A*B;
@test ExponentialUtilities.naivemul!(C, A, B, axes(C,1), axes(C,2), Val(2), Val(1)) ≈ AB
@test ExponentialUtilities.naivemul!(C, A, B, axes(C,1), axes(C,2), Val(4), Val(2)) ≈ AB
@test ExponentialUtilities.naivemul!(C, A, B, axes(C,1), axes(C,2), Val(4), Val(3)) ≈ AB
if n ≤ 16
Am = MMatrix{n,n}(A)
Bm = MMatrix{n,n}(B)
Cm = MMatrix{n,n}(A)
@test ExponentialUtilities.naivemul!(Cm, Am, Bm, axes(Cm)...) ≈ A*B
@test ExponentialUtilities.naivemul!(Cm, Am, Bm, axes(Cm,1), axes(Cm,2), Val(2), Val(2)) ≈ AB
@test ExponentialUtilities.naivemul!(Cm, Am, Bm, axes(Cm,1), axes(Cm,2), Val(4), Val(2)) ≈ AB
@test ExponentialUtilities.naivemul!(Cm, Am, Bm, axes(Cm,1), axes(Cm,2), Val(4), Val(3)) ≈ AB
end
end
A = @SMatrix rand(7,7);
Expand Down