diff --git a/src/NDTensors.jl b/src/NDTensors.jl index afb06df..cbc3201 100644 --- a/src/NDTensors.jl +++ b/src/NDTensors.jl @@ -173,6 +173,9 @@ function __init__() enable_tblis() include("tblis.jl") end + @require Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4" begin + include("octavian.jl") + end end end # module NDTensors diff --git a/src/dense.jl b/src/dense.jl index 7c758fd..6e09ffe 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -1,6 +1,7 @@ # # Dense storage # +using LinearAlgebra: BlasFloat struct Dense{ElT, VecT<:AbstractVector} <: TensorStorage{ElT} data::VecT @@ -441,39 +442,61 @@ function outer!(R::DenseTensor{ElR}, return R end -# BLAS matmul +export backend_auto, backend_blas, backend_generic + +@eval struct GemmBackend{T} + (f::Type{<:GemmBackend})() = $(Expr(:new, :f)) +end +GemmBackend(s) = GemmBackend{Symbol(s)}() +macro GemmBackend_str(s) + :(GemmBackend{$(Expr(:quote, Symbol(s)))}) +end + +const gemm_backend = Ref(:Auto) +function backend_auto() + gemm_backend[] = :Auto +end +function backend_blas() + gemm_backend[] = :BLAS +end +function backend_generic() + gemm_backend[] = :Generic +end + +@inline function auto_select_backend(::Type{<:StridedVecOrMat{<:BlasFloat}}, ::Type{<:StridedVecOrMat{<:BlasFloat}}, ::Type{<:StridedVecOrMat{<:BlasFloat}}) + GemmBackend(:BLAS) +end + +@inline function auto_select_backend(::Type{<:AbstractVecOrMat}, ::Type{<:AbstractVecOrMat}, ::Type{<:AbstractVecOrMat}) + GemmBackend(:Generic) +end + function _gemm!(tA, tB, alpha, - A::AbstractVecOrMat{<:LinearAlgebra.BlasFloat}, - B::AbstractVecOrMat{<:LinearAlgebra.BlasFloat}, - beta, C::AbstractVecOrMat{<:LinearAlgebra.BlasFloat}) + A::TA, + B::TB, + beta, C::TC) where {TA<:AbstractVecOrMat, TB<:AbstractVecOrMat, TC<:AbstractVecOrMat} + if gemm_backend[] == :Auto + _gemm!(auto_select_backend(TA, TB, TC), tA, tB, alpha, A, B, beta, C) + else + _gemm!(GemmBackend(gemm_backend[]), tA, tB, alpha, A, B, beta, C) + end +end + +# BLAS matmul +function _gemm!(::GemmBackend{:BLAS}, tA, tB, alpha, + A::AbstractVecOrMat, + B::AbstractVecOrMat, + beta, C::AbstractVecOrMat) #@timeit_debug timer "BLAS.gemm!" begin BLAS.gemm!(tA, tB, alpha, A, B, beta, C) #end # @timeit end # generic matmul -function _gemm!(tA, tB, alpha::AT, +function _gemm!(::GemmBackend{:Generic}, tA, tB, alpha::AT, A::AbstractVecOrMat, B::AbstractVecOrMat, beta::BT, C::AbstractVecOrMat) where {AT, BT} - if tA == 'T' - A = transpose(A) - end - if tB == 'T' - B = transpose(B) - end - if beta == zero(BT) - if alpha == one(AT) - C .= A * B - else - C .= alpha .* (A * B) - end - else - if alpha == one(AT) - C .= (A * B) .+ beta .* C - else - C .= alpha .* (A * B) .+ beta .* C - end - end + mul!(C, tA == 'T' ? transpose(A) : A, tB == 'T' ? transpose(B) : B, alpha, beta) return C end diff --git a/src/octavian.jl b/src/octavian.jl new file mode 100644 index 0000000..4ac73c2 --- /dev/null +++ b/src/octavian.jl @@ -0,0 +1,14 @@ +using .Octavian + +export backend_octavian + +function backend_octavian() + gemm_backend[] = :Octavian +end + +function _gemm!(::GemmBackend{:Octavian}, tA, tB, alpha, + A::AbstractVecOrMat, + B::AbstractVecOrMat, + beta, C::AbstractVecOrMat) + Octavian.matmul!(C, tA == 'T' ? transpose(A) : A, tB == 'T' ? transpose(B) : B, alpha, beta) +end \ No newline at end of file diff --git a/test/dense.jl b/test/dense.jl index 5541bb8..1951e81 100644 --- a/test/dense.jl +++ b/test/dense.jl @@ -1,6 +1,12 @@ using NDTensors, Test +@static if VERSION >= v"1.5" + using Pkg + Pkg.add("Octavian") + using Octavian +end + @testset "Dense Tensors" begin @testset "DenseTensor basic functionality" begin @@ -210,6 +216,39 @@ end end end +@testset "change backends" begin + a, b, c = [randn(5,5) for i=1:3] + backend_auto() + @test NDTensors.gemm_backend[] == :Auto + @test NDTensors.auto_select_backend(typeof.((a, b, c))...) == NDTensors.GemmBackend(:BLAS) + res1 = NDTensors._gemm!('N', 'N', 2.0, a, b, 0.2, copy(c)) + backend_blas() + @test NDTensors.gemm_backend[] == :BLAS + res2 = NDTensors._gemm!('N', 'N', 2.0, a, b, 0.2, copy(c)) + backend_generic() + @test NDTensors.gemm_backend[] == :Generic + res3 = NDTensors._gemm!('N', 'N', 2.0, a, b, 0.2, copy(c)) + @test res1 == res2 + @test res1 ≈ res3 + backend_auto() +end + +@static if VERSION >= v"1.5" + @testset "change backends" begin + a, b, c = [randn(5,5) for i=1:3] + backend_auto() + @test NDTensors.gemm_backend[] == :Auto + @test NDTensors.auto_select_backend(typeof.((a, b, c))...) == NDTensors.GemmBackend(:BLAS) + res1 = NDTensors._gemm!('N', 'N', 2.0, a, b, 0.2, copy(c)) + backend_octavian() + @test NDTensors.gemm_backend[] == :Octavian + res4 = NDTensors._gemm!('N', 'N', 2.0, a, b, 0.2, copy(c)) + @test res1 ≈ res4 + backend_auto() + end +end + end + nothing