Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/block_sizes.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

matmul_params(::Val{T}) where {T <: Base.HWReal} = LoopVectorization.matmul_params()
matmul_params(::Val{T}) where {T} = LoopVectorization.matmul_params()

function block_sizes(::Val{T}, _α, _β, R₁, R₂) where {T}
W = pick_vector_width(T)
Expand Down
230 changes: 117 additions & 113 deletions src/complex_matmul.jl
Original file line number Diff line number Diff line change
@@ -1,129 +1,133 @@
real_rep(a::AbstractArray{Complex{T}, N}) where {T, N} = reinterpret(reshape, T, a)
#PtrArray(Ptr{T}(pointer(a)), (StaticInt(2), size(a)...))

@inline function _matmul!(_C::AbstractVecOrMat{Complex{T}}, _A::AbstractMatrix{Complex{U}}, _B::AbstractVecOrMat{Complex{V}},
α=One(), β=Zero(), nthread::Nothing=nothing, MKN=nothing, contig_axis=nothing) where {T,U,V}
C, A, B = map(real_rep, (_C, _A, _B))

η = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1))
θ = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1))
(+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
ηθ = η*θ

@tturbo for n ∈ indices((C, B), 3), m ∈ indices((C, A), 2)
Cmn_re = zero(T)
Cmn_im = zero(T)
for k ∈ indices((A, B), (3, 2))
Cmn_re += A[1, m, k] * B[1, k, n] - ηθ * A[2, m, k] * B[2, k, n]
Cmn_im += θ * A[1, m, k] * B[2, k, n] + η * A[2, m, k] * B[1, k, n]
for AT in [:AbstractVector, :AbstractMatrix] # to avoid ambiguity error
@eval begin
function _matmul!(_C::$AT{Complex{T}}, _A::AbstractMatrix{Complex{U}}, _B::$AT{Complex{V}},
α=One(), β=Zero(), nthread::Nothing=nothing, MKN=nothing, contig_axis=nothing) where {T,U,V}
C, A, B = map(real_rep, (_C, _A, _B))

η = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1))
θ = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1))
(+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
ηθ = η*θ

@tturbo for n ∈ indices((C, B), 3), m ∈ indices((C, A), 2)
Cmn_re = zero(T)
Cmn_im = zero(T)
for k ∈ indices((A, B), (3, 2))
Cmn_re += A[1, m, k] * B[1, k, n] - ηθ * A[2, m, k] * B[2, k, n]
Cmn_im += θ * A[1, m, k] * B[2, k, n] + η * A[2, m, k] * B[1, k, n]
end
C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n])
C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n])
end
_C
end
C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n])
C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n])
end
_C
end

@inline function _matmul!(_C::AbstractVecOrMat{Complex{T}}, A::AbstractMatrix{U}, _B::AbstractVecOrMat{Complex{V}},
α=One(), β=Zero(), nthread::Nothing=nothing, MKN=nothing, contig_axis=nothing) where {T,U,V}
C, B = map(real_rep, (_C, _B))

θ = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1))
(+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))

@tturbo for n ∈ indices((C, B), 3), m ∈ indices((C, A), (2, 1))
Cmn_re = zero(T)
Cmn_im = zero(T)
for k ∈ indices((A, B), (2, 2))
Cmn_re += A[m, k] * B[1, k, n]
Cmn_im += θ * A[m, k] * B[2, k, n]

@inline function _matmul!(_C::$AT{Complex{T}}, A::AbstractMatrix{U}, _B::$AT{Complex{V}},
α=One(), β=Zero(), nthread::Nothing=nothing, MKN=nothing, contig_axis=nothing) where {T,U,V}
C, B = map(real_rep, (_C, _B))
θ = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1))
(+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))

@tturbo for n ∈ indices((C, B), 3), m ∈ indices((C, A), (2, 1))
Cmn_re = zero(T)
Cmn_im = zero(T)
for k ∈ indices((A, B), (2, 2))
Cmn_re += A[m, k] * B[1, k, n]
Cmn_im += θ * A[m, k] * B[2, k, n]
end
C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n])
C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n])
end
_C
end
C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n])
C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n])
end
_C
end

@inline function _matmul!(_C::AbstractVecOrMat{Complex{T}}, _A::AbstractMatrix{Complex{U}}, B::AbstractVecOrMat{V},
α=One(), β=Zero(), nthread::Nothing=nothing, MKN=nothing, contig_axis=nothing) where {T,U,V}
C, A = map(real_rep, (_C, _A))

η = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1))
(+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))

@tturbo for n ∈ indices((C, B), (3, 2)), m ∈ indices((C, A), 2)
Cmn_re = zero(T)
Cmn_im = zero(T)
for k ∈ indices((A, B), (3, 1))
Cmn_re += A[1, m, k] * B[k, n]
Cmn_im += η * A[2, m, k] * B[k, n]

@inline function _matmul!(_C::$AT{Complex{T}}, _A::AbstractMatrix{Complex{U}}, B::$AT{V},
α=One(), β=Zero(), nthread::Nothing=nothing, MKN=nothing, contig_axis=nothing) where {T,U,V}
C, A = map(real_rep, (_C, _A))

η = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1))
(+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))

@tturbo for n ∈ indices((C, B), (3, 2)), m ∈ indices((C, A), 2)
Cmn_re = zero(T)
Cmn_im = zero(T)
for k ∈ indices((A, B), (3, 1))
Cmn_re += A[1, m, k] * B[k, n]
Cmn_im += η * A[2, m, k] * B[k, n]
end
C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n])
C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n])
end
_C
end
C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n])
C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n])
end
_C
end





@inline function _matmul_serial!(_C::AbstractVecOrMat{Complex{T}}, _A::AbstractMatrix{Complex{U}}, _B::AbstractVecOrMat{Complex{V}},
α=One(), β=Zero(), MKN=nothing, contig_axis=nothing) where {T,U,V}
C, A, B = map(real_rep, (_C, _A, _B))
@inline function _matmul_serial!(_C::$AT{Complex{T}}, _A::AbstractMatrix{Complex{U}}, _B::$AT{Complex{V}},
α=One(), β=Zero(), MKN=nothing, contig_axis=nothing) where {T,U,V}
C, A, B = map(real_rep, (_C, _A, _B))

η = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1))
θ = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1))
(+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
ηθ = η*θ
@turbo for n ∈ indices((C, B), 3), m ∈ indices((C, A), 2)
Cmn_re = zero(T)
Cmn_im = zero(T)
for k ∈ indices((A, B), (3, 2))
Cmn_re += A[1, m, k] * B[1, k, n] - ηθ * A[2, m, k] * B[2, k, n]
Cmn_im += θ * A[1, m, k] * B[2, k, n] + η * A[2, m, k] * B[1, k, n]
η = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1))
θ = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1))
(+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
ηθ = η*θ
@turbo for n ∈ indices((C, B), 3), m ∈ indices((C, A), 2)
Cmn_re = zero(T)
Cmn_im = zero(T)
for k ∈ indices((A, B), (3, 2))
Cmn_re += A[1, m, k] * B[1, k, n] - ηθ * A[2, m, k] * B[2, k, n]
Cmn_im += θ * A[1, m, k] * B[2, k, n] + η * A[2, m, k] * B[1, k, n]
end
C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n])
C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n])
end
_C
end
C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n])
C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n])
end
_C
end

@inline function _matmul_serial!(_C::AbstractVecOrMat{Complex{T}}, A::AbstractMatrix{U}, _B::AbstractVecOrMat{Complex{V}},
α=One(), β=Zero(), MKN=nothing, contig_axis=nothing) where {T,U,V}
C, B = map(real_rep, (_C, _B))

θ = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1))
(+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))

@turbo for n ∈ indices((C, B), 3), m ∈ indices((C, A), (2, 1))
Cmn_re = zero(T)
Cmn_im = zero(T)
for k ∈ indices((A, B), (2, 2))
Cmn_re += A[m, k] * B[1, k, n]
Cmn_im += θ * A[m, k] * B[2, k, n]

@inline function _matmul_serial!(_C::$AT{Complex{T}}, A::AbstractMatrix{U}, _B::$AT{Complex{V}},
α=One(), β=Zero(), MKN=nothing, contig_axis=nothing) where {T,U,V}
C, B = map(real_rep, (_C, _B))

θ = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1))
(+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))

@turbo for n ∈ indices((C, B), 3), m ∈ indices((C, A), (2, 1))
Cmn_re = zero(T)
Cmn_im = zero(T)
for k ∈ indices((A, B), (2, 2))
Cmn_re += A[m, k] * B[1, k, n]
Cmn_im += θ * A[m, k] * B[2, k, n]
end
C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n])
C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n])
end
_C
end
C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n])
C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n])
end
_C
end

@inline function _matmul_serial!(_C::AbstractVecOrMat{Complex{T}}, _A::AbstractMatrix{Complex{U}}, B::AbstractVecOrMat{V},
α=One(), β=Zero(), MKN=nothing, contig_axis=nothing) where {T,U,V}
C, A = map(real_rep, (_C, _A))

η = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1))
(+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))

@turbo for n ∈ indices((C, B), (3, 2)), m ∈ indices((C, A), 2)
Cmn_re = zero(T)
Cmn_im = zero(T)
for k ∈ indices((A, B), (3, 1))
Cmn_re += A[1, m, k] * B[k, n]
Cmn_im += η * A[2, m, k] * B[k, n]

@inline function _matmul_serial!(_C::$AT{Complex{T}}, _A::AbstractMatrix{Complex{U}}, B::$AT{V},
α=One(), β=Zero(), MKN=nothing, contig_axis=nothing) where {T,U,V}
C, A = map(real_rep, (_C, _A))

η = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1))
(+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))

@turbo for n ∈ indices((C, B), (3, 2)), m ∈ indices((C, A), 2)
Cmn_re = zero(T)
Cmn_im = zero(T)
for k ∈ indices((A, B), (3, 1))
Cmn_re += A[1, m, k] * B[k, n]
Cmn_im += η * A[2, m, k] * B[k, n]
end
C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n])
C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n])
end
_C
end
C[1,m,n] = (real(α) * Cmn_re -ᶻ imag(α) * Cmn_im) + (real(β) * C[1,m,n] -ᶻ imag(β) * C[2,m,n])
C[2,m,n] = (imag(α) * Cmn_re +ᶻ real(α) * Cmn_im) + (imag(β) * C[1,m,n] +ᶻ real(β) * C[2,m,n])
end
_C
end
end
6 changes: 3 additions & 3 deletions src/macrokernels.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

@inline incrementp(A::AbstractStridedPointer{<:Base.HWReal,3}, a::Ptr) = VectorizationBase.increment_ptr(A, a, (Zero(), Zero(), One()))
@inline increment2(B::AbstractStridedPointer{<:Base.HWReal,2}, b::Ptr, ::StaticInt{nᵣ}) where {nᵣ} = VectorizationBase.increment_ptr(B, b, (Zero(), StaticInt{nᵣ}()))
@inline increment1(C::AbstractStridedPointer{<:Base.HWReal,2}, c::Ptr, ::StaticInt{mᵣW}) where {mᵣW} = VectorizationBase.increment_ptr(C, c, (StaticInt{mᵣW}(), Zero()))
@inline incrementp(A::AbstractStridedPointer{T,3} where T, a::Ptr) = VectorizationBase.increment_ptr(A, a, (Zero(), Zero(), One()))
@inline increment2(B::AbstractStridedPointer{T,2} where T, b::Ptr, ::StaticInt{nᵣ}) where {nᵣ} = VectorizationBase.increment_ptr(B, b, (Zero(), StaticInt{nᵣ}()))
@inline increment1(C::AbstractStridedPointer{T,2} where T, c::Ptr, ::StaticInt{mᵣW}) where {mᵣW} = VectorizationBase.increment_ptr(C, c, (StaticInt{mᵣW}(), Zero()))
macro kernel(pack::Bool, ex::Expr)
ex.head === :for || throw(ArgumentError("Must be a matmul for loop."))
mincrements = Expr[:(c = increment1(C, c, mᵣW)), :(ãₚ = incrementp(Ãₚ, ãₚ)), :(m = vsub_nsw(m, mᵣW))]
Expand Down
8 changes: 4 additions & 4 deletions src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ Otherwise, based on the array's size, whether they are transposed, and whether t
"""
@inline function _matmul_serial!(
C::AbstractMatrix{T}, A::AbstractMatrix, B::AbstractMatrix, α, β, MKN
) where {T<:Real}
) where {T}
M, K, N = MKN === nothing ? matmul_sizes(C, A, B) : MKN
if M * N == 0
return
Expand Down Expand Up @@ -263,7 +263,7 @@ end
end

# passing MKN directly would let osmeone skip the size check.
@inline function _matmul!(C::AbstractMatrix{T}, A, B, α, β, nthread, MKN) where {T<:Real}
@inline function _matmul!(C::AbstractMatrix{T}, A, B, α, β, nthread, MKN) where {T}
M, K, N = MKN === nothing ? matmul_sizes(C, A, B) : MKN
if M * N == 0
return
Expand Down Expand Up @@ -504,7 +504,7 @@ function sync_mul!(
nothing
end

function _matmul!(y::AbstractVector{T}, A::AbstractMatrix, x::AbstractVector, α, β, MKN, contig_axis) where {T<:Real}
function _matmul!(y::AbstractVector{T}, A::AbstractMatrix, x::AbstractVector, α, β, MKN, contig_axis) where {T}
@tturbo for m ∈ indices((A,y),1)
yₘ = zero(T)
for n ∈ indices((A,x),(2,1))
Expand All @@ -514,7 +514,7 @@ function _matmul!(y::AbstractVector{T}, A::AbstractMatrix, x::AbstractVector, α
end
return y
end
function _matmul_serial!(y::AbstractVector{T}, A::AbstractMatrix, x::AbstractVector, α, β, MKN) where {T<:Real}
function _matmul_serial!(y::AbstractVector{T}, A::AbstractMatrix, x::AbstractVector, α, β, MKN) where {T}
@turbo for m ∈ indices((A,y),1)
yₘ = zero(T)
for n ∈ indices((A,x),(2,1))
Expand Down