Relax element-type constraints on the matmul kernels and split the complex
methods by container kind.

* src/block_sizes.jl: matmul_params(::Val{T}) no longer requires
  T <: Base.HWReal.
* src/complex_matmul.jl: the six complex _matmul!/_matmul_serial! methods are
  generated by an @eval loop over AbstractVector and AbstractMatrix instead of
  being declared once on AbstractVecOrMat (per the in-code comment, to avoid an
  ambiguity error). The first generated _matmul! keeps @inline, as it had
  before this patch and as its five generated siblings still do. The file ends
  with a trailing newline.
* src/macrokernels.jl: incrementp/increment2/increment1 accept strided
  pointers of any element type instead of only Base.HWReal.
* src/matmul.jl: the T<:Real bound is dropped from four methods.

NOTE(review): this patch was recovered from a whitespace-collapsed copy.
Indentation and the position of blank context lines were reconstructed from
the hunk line counts and are assumptions; verify against the target tree or
apply with whitespace tolerance. The post-image hash on the complex_matmul.jl
index line predates the @inline / trailing-newline adjustments.

diff --git a/src/block_sizes.jl b/src/block_sizes.jl
index a4a9bfc..4d7756f 100644
--- a/src/block_sizes.jl
+++ b/src/block_sizes.jl
@@ -1,5 +1,5 @@
 
-matmul_params(::Val{T}) where {T <: Base.HWReal} = LoopVectorization.matmul_params()
+matmul_params(::Val{T}) where {T} = LoopVectorization.matmul_params()
 
 function block_sizes(::Val{T}, _α, _β, R₁, R₂) where {T}
     W = pick_vector_width(T)
diff --git a/src/complex_matmul.jl b/src/complex_matmul.jl
index 42c7432..6ca0f7f 100644
--- a/src/complex_matmul.jl
+++ b/src/complex_matmul.jl
@@ -1,129 +1,133 @@
 real_rep(a::AbstractArray{Complex{T}, N}) where {T, N} = reinterpret(reshape, T, a)
 #PtrArray(Ptr{T}(pointer(a)), (StaticInt(2), size(a)...))
 
-@inline function _matmul!(_C::AbstractVecOrMat{Complex{T}}, _A::AbstractMatrix{Complex{U}}, _B::AbstractVecOrMat{Complex{V}},
-                          α=One(), β=Zero(), nthread::Nothing=nothing, MKN=nothing, contig_axis=nothing) where {T,U,V}
-    C, A, B = map(real_rep, (_C, _A, _B))
-
-    η = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1))
-    θ = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1))
-    (+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
-    ηθ = η*θ
-
-    @tturbo for n ∈ indices((C, B), 3), m ∈ indices((C, A), 2)
-        Cmn_re = zero(T)
-        Cmn_im = zero(T)
-        for k ∈ indices((A, B), (3, 2))
-            Cmn_re += A[1, m, k] * B[1, k, n] - ηθ * A[2, m, k] * B[2, k, n]
-            Cmn_im += θ * A[1, m, k] * B[2, k, n] + η * A[2, m, k] * B[1, k, n]
+for AT in [:AbstractVector, :AbstractMatrix] # to avoid ambiguity error
+    @eval begin
+        @inline function _matmul!(_C::$AT{Complex{T}}, _A::AbstractMatrix{Complex{U}}, _B::$AT{Complex{V}},
+                                  α=One(), β=Zero(), nthread::Nothing=nothing, MKN=nothing, contig_axis=nothing) where {T,U,V}
+            C, A, B = map(real_rep, (_C, _A, _B))
+
+            η = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1))
+            θ = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1))
+            (+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
+            ηθ = η*θ
+
+            @tturbo for n ∈ indices((C, B), 3), m ∈ indices((C, A), 2)
+                Cmn_re = zero(T)
+                Cmn_im = zero(T)
+                for k ∈ indices((A, B), (3, 2))
+                    Cmn_re += A[1, m, k] * B[1, k, n] - ηθ * A[2, m, k] * B[2, k, n]
+                    Cmn_im += θ * A[1, m, k] * B[2, k, n] + η * A[2, m, k] * B[1, k, n]
+                end
+                C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n])
+                C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n])
+            end
+            _C
         end
-    C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n])
-    C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n])
-    end
-    _C
-end
-
-@inline function _matmul!(_C::AbstractVecOrMat{Complex{T}}, A::AbstractMatrix{U}, _B::AbstractVecOrMat{Complex{V}},
-                          α=One(), β=Zero(), nthread::Nothing=nothing, MKN=nothing, contig_axis=nothing) where {T,U,V}
-    C, B = map(real_rep, (_C, _B))
-
-    θ = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1))
-    (+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
-
-    @tturbo for n ∈ indices((C, B), 3), m ∈ indices((C, A), (2, 1))
-        Cmn_re = zero(T)
-        Cmn_im = zero(T)
-        for k ∈ indices((A, B), (2, 2))
-            Cmn_re += A[m, k] * B[1, k, n]
-            Cmn_im += θ * A[m, k] * B[2, k, n]
+
+        @inline function _matmul!(_C::$AT{Complex{T}}, A::AbstractMatrix{U}, _B::$AT{Complex{V}},
+                                  α=One(), β=Zero(), nthread::Nothing=nothing, MKN=nothing, contig_axis=nothing) where {T,U,V}
+            C, B = map(real_rep, (_C, _B))
+
+            θ = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1))
+            (+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
+
+            @tturbo for n ∈ indices((C, B), 3), m ∈ indices((C, A), (2, 1))
+                Cmn_re = zero(T)
+                Cmn_im = zero(T)
+                for k ∈ indices((A, B), (2, 2))
+                    Cmn_re += A[m, k] * B[1, k, n]
+                    Cmn_im += θ * A[m, k] * B[2, k, n]
+                end
+                C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n])
+                C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n])
+            end
+            _C
         end
-    C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n])
-    C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n])
-    end
-    _C
-end
-
-@inline function _matmul!(_C::AbstractVecOrMat{Complex{T}}, _A::AbstractMatrix{Complex{U}}, B::AbstractVecOrMat{V},
-                          α=One(), β=Zero(), nthread::Nothing=nothing, MKN=nothing, contig_axis=nothing) where {T,U,V}
-    C, A = map(real_rep, (_C, _A))
-
-    η = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1))
-    (+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
-
-    @tturbo for n ∈ indices((C, B), (3, 2)), m ∈ indices((C, A), 2)
-        Cmn_re = zero(T)
-        Cmn_im = zero(T)
-        for k ∈ indices((A, B), (3, 1))
-            Cmn_re += A[1, m, k] * B[k, n]
-            Cmn_im += η * A[2, m, k] * B[k, n]
+
+        @inline function _matmul!(_C::$AT{Complex{T}}, _A::AbstractMatrix{Complex{U}}, B::$AT{V},
+                                  α=One(), β=Zero(), nthread::Nothing=nothing, MKN=nothing, contig_axis=nothing) where {T,U,V}
+            C, A = map(real_rep, (_C, _A))
+
+            η = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1))
+            (+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
+
+            @tturbo for n ∈ indices((C, B), (3, 2)), m ∈ indices((C, A), 2)
+                Cmn_re = zero(T)
+                Cmn_im = zero(T)
+                for k ∈ indices((A, B), (3, 1))
+                    Cmn_re += A[1, m, k] * B[k, n]
+                    Cmn_im += η * A[2, m, k] * B[k, n]
+                end
+                C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n])
+                C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n])
+            end
+            _C
         end
-    C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n])
-    C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n])
-    end
-    _C
-end
 
 
 
 
 
-@inline function _matmul_serial!(_C::AbstractVecOrMat{Complex{T}}, _A::AbstractMatrix{Complex{U}}, _B::AbstractVecOrMat{Complex{V}},
-                                 α=One(), β=Zero(), MKN=nothing, contig_axis=nothing) where {T,U,V}
-    C, A, B = map(real_rep, (_C, _A, _B))
+        @inline function _matmul_serial!(_C::$AT{Complex{T}}, _A::AbstractMatrix{Complex{U}}, _B::$AT{Complex{V}},
+                                         α=One(), β=Zero(), MKN=nothing, contig_axis=nothing) where {T,U,V}
+            C, A, B = map(real_rep, (_C, _A, _B))
 
-    η = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1))
-    θ = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1))
-    (+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
-    ηθ = η*θ
-    @turbo for n ∈ indices((C, B), 3), m ∈ indices((C, A), 2)
-        Cmn_re = zero(T)
-        Cmn_im = zero(T)
-        for k ∈ indices((A, B), (3, 2))
-            Cmn_re += A[1, m, k] * B[1, k, n] - ηθ * A[2, m, k] * B[2, k, n]
-            Cmn_im += θ * A[1, m, k] * B[2, k, n] + η * A[2, m, k] * B[1, k, n]
+            η = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1))
+            θ = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1))
+            (+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
+            ηθ = η*θ
+            @turbo for n ∈ indices((C, B), 3), m ∈ indices((C, A), 2)
+                Cmn_re = zero(T)
+                Cmn_im = zero(T)
+                for k ∈ indices((A, B), (3, 2))
+                    Cmn_re += A[1, m, k] * B[1, k, n] - ηθ * A[2, m, k] * B[2, k, n]
+                    Cmn_im += θ * A[1, m, k] * B[2, k, n] + η * A[2, m, k] * B[1, k, n]
+                end
+                C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n])
+                C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n])
+            end
+            _C
         end
-    C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n])
-    C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n])
-    end
-    _C
-end
-
-@inline function _matmul_serial!(_C::AbstractVecOrMat{Complex{T}}, A::AbstractMatrix{U}, _B::AbstractVecOrMat{Complex{V}},
-                                 α=One(), β=Zero(), MKN=nothing, contig_axis=nothing) where {T,U,V}
-    C, B = map(real_rep, (_C, _B))
-
-    θ = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1))
-    (+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
-
-    @turbo for n ∈ indices((C, B), 3), m ∈ indices((C, A), (2, 1))
-        Cmn_re = zero(T)
-        Cmn_im = zero(T)
-        for k ∈ indices((A, B), (2, 2))
-            Cmn_re += A[m, k] * B[1, k, n]
-            Cmn_im += θ * A[m, k] * B[2, k, n]
+
+        @inline function _matmul_serial!(_C::$AT{Complex{T}}, A::AbstractMatrix{U}, _B::$AT{Complex{V}},
+                                         α=One(), β=Zero(), MKN=nothing, contig_axis=nothing) where {T,U,V}
+            C, B = map(real_rep, (_C, _B))
+
+            θ = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1))
+            (+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
+
+            @turbo for n ∈ indices((C, B), 3), m ∈ indices((C, A), (2, 1))
+                Cmn_re = zero(T)
+                Cmn_im = zero(T)
+                for k ∈ indices((A, B), (2, 2))
+                    Cmn_re += A[m, k] * B[1, k, n]
+                    Cmn_im += θ * A[m, k] * B[2, k, n]
+                end
+                C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n])
+                C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n])
+            end
+            _C
         end
-    C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n])
-    C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n])
-    end
-    _C
-end
-
-@inline function _matmul_serial!(_C::AbstractVecOrMat{Complex{T}}, _A::AbstractMatrix{Complex{U}}, B::AbstractVecOrMat{V},
-                                 α=One(), β=Zero(), MKN=nothing, contig_axis=nothing) where {T,U,V}
-    C, A = map(real_rep, (_C, _A))
-
-    η = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1))
-    (+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
-
-    @turbo for n ∈ indices((C, B), (3, 2)), m ∈ indices((C, A), 2)
-        Cmn_re = zero(T)
-        Cmn_im = zero(T)
-        for k ∈ indices((A, B), (3, 1))
-            Cmn_re += A[1, m, k] * B[k, n]
-            Cmn_im += η * A[2, m, k] * B[k, n]
+
+        @inline function _matmul_serial!(_C::$AT{Complex{T}}, _A::AbstractMatrix{Complex{U}}, B::$AT{V},
+                                         α=One(), β=Zero(), MKN=nothing, contig_axis=nothing) where {T,U,V}
+            C, A = map(real_rep, (_C, _A))
+
+            η = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1))
+            (+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
+
+            @turbo for n ∈ indices((C, B), (3, 2)), m ∈ indices((C, A), 2)
+                Cmn_re = zero(T)
+                Cmn_im = zero(T)
+                for k ∈ indices((A, B), (3, 1))
+                    Cmn_re += A[1, m, k] * B[k, n]
+                    Cmn_im += η * A[2, m, k] * B[k, n]
+                end
+                C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n])
+                C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n])
+            end
+            _C
         end
-    C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n])
-    C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n])
     end
-    _C
 end
diff --git a/src/macrokernels.jl b/src/macrokernels.jl
index aa3b48a..1b7568f 100644
--- a/src/macrokernels.jl
+++ b/src/macrokernels.jl
@@ -1,7 +1,7 @@
 
-@inline incrementp(A::AbstractStridedPointer{<:Base.HWReal,3}, a::Ptr) = VectorizationBase.increment_ptr(A, a, (Zero(), Zero(), One()))
-@inline increment2(B::AbstractStridedPointer{<:Base.HWReal,2}, b::Ptr, ::StaticInt{nᵣ}) where {nᵣ} = VectorizationBase.increment_ptr(B, b, (Zero(), StaticInt{nᵣ}()))
-@inline increment1(C::AbstractStridedPointer{<:Base.HWReal,2}, c::Ptr, ::StaticInt{mᵣW}) where {mᵣW} = VectorizationBase.increment_ptr(C, c, (StaticInt{mᵣW}(), Zero()))
+@inline incrementp(A::AbstractStridedPointer{T,3} where T, a::Ptr) = VectorizationBase.increment_ptr(A, a, (Zero(), Zero(), One()))
+@inline increment2(B::AbstractStridedPointer{T,2} where T, b::Ptr, ::StaticInt{nᵣ}) where {nᵣ} = VectorizationBase.increment_ptr(B, b, (Zero(), StaticInt{nᵣ}()))
+@inline increment1(C::AbstractStridedPointer{T,2} where T, c::Ptr, ::StaticInt{mᵣW}) where {mᵣW} = VectorizationBase.increment_ptr(C, c, (StaticInt{mᵣW}(), Zero()))
 macro kernel(pack::Bool, ex::Expr)
     ex.head === :for || throw(ArgumentError("Must be a matmul for loop."))
     mincrements = Expr[:(c = increment1(C, c, mᵣW)), :(ãₚ = incrementp(Ãₚ, ãₚ)), :(m = vsub_nsw(m, mᵣW))]
diff --git a/src/matmul.jl b/src/matmul.jl
index 7d4d746..f413dcc 100644
--- a/src/matmul.jl
+++ b/src/matmul.jl
@@ -165,7 +165,7 @@ Otherwise, based on the array's size, whether they are transposed, and whether t
 """
 @inline function _matmul_serial!(
     C::AbstractMatrix{T}, A::AbstractMatrix, B::AbstractMatrix, α, β, MKN
-) where {T<:Real}
+) where {T}
     M, K, N = MKN === nothing ? matmul_sizes(C, A, B) : MKN
     if M * N == 0
         return
@@ -263,7 +263,7 @@ end
 end
 
 # passing MKN directly would let osmeone skip the size check.
-@inline function _matmul!(C::AbstractMatrix{T}, A, B, α, β, nthread, MKN) where {T<:Real}
+@inline function _matmul!(C::AbstractMatrix{T}, A, B, α, β, nthread, MKN) where {T}
     M, K, N = MKN === nothing ? matmul_sizes(C, A, B) : MKN
     if M * N == 0
         return
@@ -504,7 +504,7 @@ function sync_mul!(
     nothing
 end
 
-function _matmul!(y::AbstractVector{T}, A::AbstractMatrix, x::AbstractVector, α, β, MKN, contig_axis) where {T<:Real}
+function _matmul!(y::AbstractVector{T}, A::AbstractMatrix, x::AbstractVector, α, β, MKN, contig_axis) where {T}
     @tturbo for m ∈ indices((A,y),1)
         yₘ = zero(T)
         for n ∈ indices((A,x),(2,1))
@@ -514,7 +514,7 @@ function _matmul!(y::AbstractVector{T}, A::AbstractMatrix, x::AbstractVector, α
     end
     return y
 end
-function _matmul_serial!(y::AbstractVector{T}, A::AbstractMatrix, x::AbstractVector, α, β, MKN) where {T<:Real}
+function _matmul_serial!(y::AbstractVector{T}, A::AbstractMatrix, x::AbstractVector, α, β, MKN) where {T}
     @turbo for m ∈ indices((A,y),1)
         yₘ = zero(T)
         for n ∈ indices((A,x),(2,1))