From 17c4db9bc104f929d92ffb1a72b6d9faae97d89a Mon Sep 17 00:00:00 2001 From: Chris Elrod Date: Mon, 12 Apr 2021 14:47:26 -0400 Subject: [PATCH 1/3] Unroll based off of sizeof(eltype(C)) --- src/exp.jl | 104 ++++++++++++++++++++++++++++++++--------------------- 1 file changed, 64 insertions(+), 40 deletions(-) diff --git a/src/exp.jl b/src/exp.jl index 0f573c1..0097c60 100644 --- a/src/exp.jl +++ b/src/exp.jl @@ -120,7 +120,17 @@ function naivemul!(C, A, B) mstep = step(Maxis) # I don't want to deal with axes having non-unit step if nstep == mstep == 1 - naivemul!(C, A, B, Maxis, Naxis) + if sizeof(eltype(C)) > 256 + naivemul!(C, A, B, Maxis, Naxis, Val(1), Val(1)) + elseif sizeof(eltype(C)) > 128 + naivemul!(C, A, B, Maxis, Naxis, Val(2), Val(1)) + elseif sizeof(eltype(C)) > 96 + naivemul!(C, A, B, Maxis, Naxis, Val(4), Val(1)) + elseif sizeof(eltype(C)) > 64 + naivemul!(C, A, B, Maxis, Naxis, Val(4), Val(2)) + else + naivemul!(C, A, B, Maxis, Naxis, Val(4), Val(3)) + end else mul!(C,A,B) end @@ -128,51 +138,65 @@ end _const(A) = A _const(A::Array) = Base.Experimental.Const(A) # Separated to make it easier to test. -function naivemul!(C::AbstractMatrix{T}, A, B, Maxis, Naxis) where {T} - N = last(Naxis) - M = last(Maxis) - Kaxis = axes(B,1) - Base.Experimental.@aliasscope begin - n = first(Naxis)-1 - @inbounds begin - while n < N - 1 - m = first(Maxis)-1 - while m < M - 3 - Base.Cartesian.@nexprs 2 j -> Base.Cartesian.@nexprs 4 i -> Cmn_i_j = zero(T) - for k ∈ Kaxis - Base.Cartesian.@nexprs 2 j -> Base.Cartesian.@nexprs 4 i -> Cmn_i_j = muladd(_const(A)[m+i,k],_const(B)[k,n+j],Cmn_i_j) - end - Base.Cartesian.@nexprs 2 j -> Base.Cartesian.@nexprs 4 i -> C[m+i,n+j] = Cmn_i_j - m += 4 - end - for mm ∈ 1+m:M - Base.Cartesian.@nexprs 2 j -> Cmn_j = zero(T) - for k ∈ Kaxis - Base.Cartesian.@nexprs 2 j -> Cmn_j = muladd(_const(A)[mm,k],_const(B)[k,n+j],Cmn_j) - end - Base.Cartesian.@nexprs 2 j -> C[mm,n+j] = Cmn_j - end - n += 2 +@generated function naivemul!(C::AbstractMatrix{T}, A, B, Maxis, Naxis, ::Val{MU}, ::Val{NU}) where {T,MU,NU} + nrem_body = quote + m = first(Maxis)-1 + while m < M - $(MU-1) + Base.Cartesian.@nexprs $MU i -> Cmn_i = zero(T) + for k ∈ Kaxis + Base.Cartesian.@nexprs $MU i -> Cmn_i = muladd(_const(A)[m+i,k],_const(B)[k,nn],Cmn_i) end - m = first(Maxis)-1 - while m < M - 3 - Base.Cartesian.@nexprs 4 i -> Cmn_i = zero(T) - for k ∈ Kaxis - Base.Cartesian.@nexprs 4 i -> Cmn_i = muladd(_const(A)[m+i,k],_const(B)[k,N],Cmn_i) - end - Base.Cartesian.@nexprs 4 i -> C[m+i,N] = Cmn_i - m += 4 + Base.Cartesian.@nexprs $MU i -> C[m+i,nn] = Cmn_i + m += $MU + end + for mm ∈ 1+m:M + Cmn = zero(T) + for k ∈ Kaxis + Cmn = muladd(_const(A)[mm,k], _const(B)[k,nn], Cmn) end - for mm ∈ 1+m:M - Cmn = zero(T) - for k ∈ Kaxis - Cmn = muladd(_const(A)[mm,k], _const(B)[k,N], Cmn) + C[mm,nn] = Cmn + end + end + nrem_quote = if NU > 2 + :(for nn ∈ 1+n:N; $nrem_body; end) + else + :(let nn = N; $nrem_body; end) + end + quote + N = last(Naxis) + M = last(Maxis) + Kaxis = axes(B,1) + Base.Experimental.@aliasscope begin + n = first(Naxis)-1 + @inbounds begin + while n < N - $(NU-1) + m = first(Maxis)-1 + while m < M - $(MU-1) + Base.Cartesian.@nexprs $NU j -> Base.Cartesian.@nexprs $MU i -> Cmn_i_j = zero(T) + for k ∈ Kaxis + Base.Cartesian.@nexprs $MU i -> Ak_i = _const(A)[m+i,k] + Base.Cartesian.@nexprs $NU j -> begin + Bk_j = _const(B)[k,n+j] + Base.Cartesian.@nexprs $MU i -> Cmn_i_j = muladd(Ak_i, Bk_j, Cmn_i_j) + end + end + Base.Cartesian.@nexprs $NU j -> Base.Cartesian.@nexprs $MU i -> C[m+i,n+j] = Cmn_i_j + m += $MU + end + for mm ∈ 1+m:M + Base.Cartesian.@nexprs $NU j -> Cmn_j = zero(T) + for k ∈ Kaxis + Base.Cartesian.@nexprs $NU j -> Cmn_j = muladd(_const(A)[mm,k],_const(B)[k,n+j],Cmn_j) + end + Base.Cartesian.@nexprs $NU j -> C[mm,n+j] = Cmn_j + end + n += $NU end - C[mm,N] = Cmn + $(NU > 1 ? nrem_quote : nothing) end end + C end - C end """ From 7831b5d256da5d7046f7fa127f20a3e5c9b9723e Mon Sep 17 00:00:00 2001 From: Chris Elrod Date: Mon, 12 Apr 2021 14:55:36 -0400 Subject: [PATCH 2/3] update naivemul tests --- test/runtests.jl | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index c8cc7e5..af0a9b3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -82,12 +82,17 @@ end A = rand(n,n); B = rand(n,n); C = similar(A); - @test ExponentialUtilities.naivemul!(C, A, B, axes(C)...) ≈ A*B + AB = A*B; + @test ExponentialUtilities.naivemul!(C, A, B, axes(C,1), axes(C,2), Val(2), Val(1)) ≈ AB + @test ExponentialUtilities.naivemul!(C, A, B, axes(C,1), axes(C,2), Val(4), Val(2)) ≈ AB + @test ExponentialUtilities.naivemul!(C, A, B, axes(C,1), axes(C,2), Val(4), Val(3)) ≈ AB if n ≤ 16 Am = MMatrix{n,n}(A) Bm = MMatrix{n,n}(B) Cm = MMatrix{n,n}(A) - @test ExponentialUtilities.naivemul!(Cm, Am, Bm, axes(Cm)...) ≈ A*B + @test ExponentialUtilities.naivemul!(Cm, Am, Bm, axes(Cm,1), axes(cm,2), Val(2), Val(2)) ≈ AB + @test ExponentialUtilities.naivemul!(Cm, Am, Bm, axes(Cm,1), axes(cm,2), Val(4), Val(2)) ≈ AB + @test ExponentialUtilities.naivemul!(Cm, Am, Bm, axes(Cm,1), axes(cm,2), Val(4), Val(3)) ≈ AB end end A = @SMatrix rand(7,7); From 46add0790e64f0b284cdcc7be5d842e0887d6164 Mon Sep 17 00:00:00 2001 From: Chris Elrod Date: Mon, 12 Apr 2021 15:09:53 -0400 Subject: [PATCH 3/3] Fix capitalization --- test/runtests.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index af0a9b3..e537691 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -90,9 +90,9 @@ end Am = MMatrix{n,n}(A) Bm = MMatrix{n,n}(B) Cm = MMatrix{n,n}(A) - @test ExponentialUtilities.naivemul!(Cm, Am, Bm, axes(Cm,1), axes(cm,2), Val(2), Val(2)) ≈ AB - @test ExponentialUtilities.naivemul!(Cm, Am, Bm, axes(Cm,1), axes(cm,2), Val(4), Val(2)) ≈ AB - @test ExponentialUtilities.naivemul!(Cm, Am, Bm, axes(Cm,1), axes(cm,2), Val(4), Val(3)) ≈ AB + @test ExponentialUtilities.naivemul!(Cm, Am, Bm, axes(Cm,1), axes(Cm,2), Val(2), Val(2)) ≈ AB + @test ExponentialUtilities.naivemul!(Cm, Am, Bm, axes(Cm,1), axes(Cm,2), Val(4), Val(2)) ≈ AB + @test ExponentialUtilities.naivemul!(Cm, Am, Bm, axes(Cm,1), axes(Cm,2), Val(4), Val(3)) ≈ AB end end A = @SMatrix rand(7,7);