Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fastpath for generic mul! #51812

Closed
wants to merge 3 commits into from

Conversation

chriselrod
Copy link
Contributor

@chriselrod chriselrod commented Oct 21, 2023

The current _generic_matmul! is bad.

using LinearAlgebra, BenchmarkTools
using LinearAlgebra: @lazy_str
function _generic_matmatmuladd!(C, A, B)
    AxM = axes(A, 1)
    AxK = axes(A, 2) # we use two `axes` calls in case of `AbstractVector`
    BxK = axes(B, 1)
    BxN = axes(B, 2)
    CxM = axes(C, 1)
    CxN = axes(C, 2)
    if AxM != CxM
        throw(DimensionMismatch(lazy"matrix A has axes ($AxM,$AxK), matrix C has axes ($CxM,$CxN)"))
    end
    if AxK != BxK
        throw(DimensionMismatch(lazy"matrix A has axes ($AxM,$AxK), matrix B has axes ($BxK,$CxN)"))
    end
    if BxN != CxN
        throw(DimensionMismatch(lazy"matrix B has axes ($BxK,$BxN), matrix C has axes ($CxM,$CxN)"))
    end
    for n = BxN, k = BxK, m = AxM
        C[m,n] = muladd(A[m,k], B[k,n], C[m,n])
    end
    return C
end
function _generic_matmatmul!(C, A, B)
    _generic_matmatmuladd!(fill!(C, zero(eltype(C))), A, B)
end

for n = 1:12
   N = 1<<n
   A = rand(N,N)
   B = rand(N,N)
   C = similar(A)
   b_new = @benchmark _generic_matmatmul!($C, $A, $B)
   b_old = @benchmark LinearAlgebra._generic_matmatmul!($C, 'N', 'N', $A, $B, LinearAlgebra.MulAddMul())
   println("$N x $N, time ratio old/new: ", mean(b_old).time / mean(b_new).time)
end

Sample results:

2 x 2, time ratio old/new: 118.117764398017
4 x 4, time ratio old/new: 28.336393904569487
8 x 8, time ratio old/new: 5.783047811906496
16 x 16, time ratio old/new: 1.7274327231826396
32 x 32, time ratio old/new: 4.342510579320906
64 x 64, time ratio old/new: 5.244950423742804
128 x 128, time ratio old/new: 5.773512626617563
256 x 256, time ratio old/new: 7.888608978219142
512 x 512, time ratio old/new: 2.829412799661151
1024 x 1024, time ratio old/new: 2.4238409504206704
2048 x 2048, time ratio old/new: 1.5253985486724402
4096 x 4096, time ratio old/new: 1.3086912505915285

This is of course with Float64, which mul! normally will not hit.
The point here is to show that all the extra complexity of _generic_matmatmul! makes the code slower to run, even at fairly large sizes.
The naive tiling there does not help performance.

Worse, the current method also takes longer to compile.
Here is a benchmark of log10 runtimes across a range of size inputs.
For each size input, the benchmark covers all dual-combinations from no duals, to 1...8 partials, to 1...8 x 1...2 partials. It does this 1981 times on multiple threads (1981 == length(0.1:0.005:10.0)).
I also added a C++ implementation to the comparison for good measure.

Here is the performance on Julia master:
size_vs_log10time
Versus this PR:
size_vs_log10time_jdev
The x axis is matrix dimension, and the y axis is log10 time in nanoseconds (so 6 is 1ms).
Note that the big bump above 3x3 on Julia master is because it special cases a few matrix sizes, like 3x3.

In case the log scale obscures it, I'll emphasize: this PR improves performance several-fold in this ForwardDiff.Dual benchmark, just like the Float64 benchmark above.

On this PR, the compile times were in seconds:

julia> println(cts ./ 1e9)
[26.177206808, 171.026078739, 0.123216004]

for the custom implementation below, ExponentialUtilities.exponential!, and the ccall to the C++ version, respectively. The time cmake --build itself reported about 3seconds to actually compile the C++ shared library.

On Julia master, I got

julia> println(cts ./ 1e9)
[60.757913529, 156.906285547, 0.111382896]

I am not sure why ExponentialUtilities.exponential!'s compile time seemed to regress on this PR, but the total compile time of the combination was reduced.

The script:

using LinearAlgebra,
  Statistics, ForwardDiff, BenchmarkTools, ExponentialUtilities

const BENCH_OPNORMS = (66.0, 33.0, 22.0, 11.0, 6.0, 3.0, 2.0, 0.5, 0.03, 0.001)

"""
Generates one random matrix per opnorm.
All generated matrices are scale multiples of one another.
This is meant to exercise all code paths in the `expm` function.
"""
function randmatrices(n)
  A = randn(n, n)
  op = opnorm(A, 1)
  map(BENCH_OPNORMS) do x
    (x / op) .* A
  end
end

function expm(A::AbstractMatrix{S}) where {S}
  # omitted: matrix balancing, i.e., LAPACK.gebal!
  nA = opnorm(A, 1)
  ## For sufficiently small nA, use lower order Padé-Approximations
  if (nA <= 2.1)
    A2 = A * A
    if nA > 0.95
      U = @evalpoly(
        A2,
        S(8821612800) * I,
        S(302702400) * I,
        S(2162160) * I,
        S(3960) * I,
        S(1) * I
      )
      U = A * U
      V = @evalpoly(
        A2,
        S(17643225600) * I,
        S(2075673600) * I,
        S(30270240) * I,
        S(110880) * I,
        S(90) * I
      )
    elseif nA > 0.25
      U = @evalpoly(A2, S(8648640) * I, S(277200) * I, S(1512) * I, S(1) * I)
      U = A * U
      V =
        @evalpoly(A2, S(17297280) * I, S(1995840) * I, S(25200) * I, S(56) * I)
    elseif nA > 0.015
      U = @evalpoly(A2, S(15120) * I, S(420) * I, S(1) * I)
      U = A * U
      V = @evalpoly(A2, S(30240) * I, S(3360) * I, S(30) * I)
    else
      U = @evalpoly(A2, S(60) * I, S(1) * I)
      U = A * U
      V = @evalpoly(A2, S(120) * I, S(12) * I)
    end
    expA = (V - U) \ (V + U)
  else
    s = log2(nA / 5.4)               # power of 2 later reversed by squaring
    if s > 0
      si = ceil(Int, s)
      A = A / S(exp2(si))
    end

    A2 = A * A
    A4 = A2 * A2
    A6 = A2 * A4

    U =
      A6 * (A6 .+ S(16380) .* A4 .+ S(40840800) .* A2) .+ (
        S(33522128640) .* A6 .+ S(10559470521600) .* A4 .+
        S(1187353796428800) .* A2
      ) + S(32382376266240000) * I
    U = A * U
    V =
      A6 * (S(182) .* A6 .+ S(960960) .* A4 .+ S(1323241920) .* A2) .+ (
        S(670442572800) .* A6 .+ S(129060195264000) .* A4 +
        S(7771770303897600) .* A2
      ) + S(64764752532480000) * I
    expA = (V - U) \ (V + U)
    if s > 0            # squaring to reverse dividing by power of 2
      for _ = 1:si
        expA = expA * expA
      end
    end
  end
  expA
end
expm_oop!(B, A) = copyto!(B, expm(A))
function _matevalpoly!(B, _, D, A::AbstractMatrix{T}, t::NTuple{1}) where {T}
  M = size(A, 1)
  te = T(last(t))
  @inbounds for n = 1:M, m = 1:M
    B[m, n] = ifelse(m == n, te, zero(te))
  end
  @inbounds for n = 1:M, k = 1:M, m = 1:M
    B[m, n] = muladd(A[m, k], D[k, n], B[m, n])
  end
  return B
end
function _matevalpoly!(B, C, D, A::AbstractMatrix{T}, t::NTuple) where {T}
  M = size(A, 1)
  te = T(last(t))
  @inbounds for n = 1:M, m = 1:M
    C[m, n] = ifelse(m == n, te, zero(te))
  end
  @inbounds for n = 1:M, k = 1:M, m = 1:M
    C[m, n] = muladd(A[m, k], D[k, n], C[m, n])
  end
  _matevalpoly!(B, D, C, A, Base.front(t))
end
function matevalpoly!(B, C, D, A::AbstractMatrix{T}, t::NTuple) where {T}
  t1 = Base.front(t)
  te = T(last(t))
  tp = T(last(t1))
  @inbounds for j in axes(A, 2), i in axes(A, 1)
    D[i, j] = muladd(A[i, j], te, ifelse(i == j, tp, zero(tp)))
  end
  _matevalpoly!(B, C, D, A, Base.front(t1))
end
function matevalpoly!(B, _, _, A::AbstractMatrix{T}, t::NTuple{2}) where {T}
  t1 = Base.front(t)
  te = T(last(t))
  tp = T(last(t1))
  @inbounds for j in axes(A, 2), i in axes(A, 1)
    B[i, j] = muladd(A[i, j], te, ifelse(i == j, tp, zero(tp)))
  end
  return B
end
ceillog2(x::Float64) =
  (reinterpret(Int, x) - 1) >> Base.significand_bits(Float64) - 1022

_deval(x) = x
_deval(x::ForwardDiff.Dual) = _deval(ForwardDiff.value(x))

function opnorm1(A)
  n = _deval(zero(eltype(A)))
  @inbounds for j in axes(A, 2)
    s = _deval(zero(eltype(A)))
    @fastmath for i in axes(A, 1)
      s += abs(_deval(A[i, j]))
    end
    n = max(n, s)
  end
  return n
end

function expm!(
  Z::AbstractMatrix,
  A::AbstractMatrix,
  matmul! = mul!,
  matmuladd! = (C, A, B) -> mul!(C, A, B, 1.0, 1.0)
)
  # omitted: matrix balancing, i.e., LAPACK.gebal!
  # nA = maximum(sum(abs.(A); dims=Val(1)))    # marginally more performant than norm(A, 1)
  nA = opnorm1(A)
  N = LinearAlgebra.checksquare(A)
  # B and C are temporaries
  ## For sufficiently small nA, use lower order Padé-Approximations
  A2 = similar(A)
  if nA <= 2.1
    matmul!(A2, A, A)
    U = Z
    V = similar(A)
    if nA <= 0.015
      matevalpoly!(V, nothing, nothing, A2, (60, 1))
      matmul!(U, A, V)
      matevalpoly!(V, nothing, nothing, A2, (120, 12))
    else
      B = similar(A)
      if nA <= 0.25
        matevalpoly!(V, nothing, U, A2, (15120, 420, 1))
        matmul!(U, A, V)
        matevalpoly!(V, nothing, B, A2, (30240, 3360, 30))
      else
        C = similar(A)
        if nA <= 0.95
          matevalpoly!(V, C, U, A2, (8648640, 277200, 1512, 1))
          matmul!(U, A, V)
          matevalpoly!(V, B, C, A2, (17297280, 1995840, 25200, 56))
        else
          matevalpoly!(V, C, U, A2, (8821612800, 302702400, 2162160, 3960, 1))
          matmul!(U, A, V)
          matevalpoly!(
            V,
            B,
            C,
            A2,
            (17643225600, 2075673600, 30270240, 110880, 90)
          )
        end
      end
    end
    @inbounds for m = 1:N*N
      u = U[m]
      v = V[m]
      U[m] = v + u
      V[m] = v - u
    end
    ldiv!(lu!(V), U)
    expA = U
    # expA = (V - U) \ (V + U)
  else
    si = ceillog2(nA / 5.4)               # power of 2 later reversed by squaring
    if si > 0
      A = A / exp2(si)
    end
    matmul!(A2, A, A)
    A4 = similar(A)
    A6 = similar(A)
    matmul!(A4, A2, A2)
    matmul!(A6, A2, A4)

    U = Z
    B = zero(A)
    @inbounds for m = 1:N
      B[m, m] = 32382376266240000
    end
    @inbounds for m = 1:N*N
      a6 = A6[m]
      a4 = A4[m]
      a2 = A2[m]
      B[m] = muladd(
        33522128640,
        a6,
        muladd(10559470521600, a4, muladd(1187353796428800, a2, B[m]))
      )
      U[m] = muladd(16380, a4, muladd(40840800, a2, a6))
    end
    matmuladd!(B, A6, U)
    matmul!(U, A, B)

    V = si > 0 ? fill!(A, 0) : zero(A)
    @inbounds for m = 1:N
      V[m, m] = 64764752532480000
    end
    @inbounds for m = 1:N*N
      a6 = A6[m]
      a4 = A4[m]
      a2 = A2[m]
      B[m] = muladd(182, a6, muladd(960960, a4, 1323241920 * a2))
      V[m] = muladd(
        670442572800,
        a6,
        muladd(129060195264000, a4, muladd(7771770303897600, a2, V[m]))
      )
    end
    matmuladd!(V, A6, B)

    @inbounds for m = 1:N*N
      u = U[m]
      v = V[m]
      U[m] = v + u
      V[m] = v - u
    end
    ldiv!(lu!(V), U)
    expA = U
    # expA = (V - U) \ (V + U)

    if si > 0            # squaring to reverse dividing by power of 2
      for _ = 1:si
        matmul!(V, expA, expA)
        expA, V = V, expA
      end
      if Z !== expA
        copyto!(Z, expA)
        expA = Z
      end
    end
  end
  expA
end
naive_matmul!(C, A, B) = @inbounds for n in axes(C, 2), m in axes(C, 1)
  Cmn = zero(eltype(C))
  for k in axes(A, 2)
    Cmn = muladd(A[m, k], B[k, n], Cmn)
  end
  C[m, n] = Cmn
end
naive_matmuladd!(C, A, B) = @inbounds for n in axes(C, 2), m in axes(C, 1)
  Cmn = zero(eltype(C))
  for k in axes(A, 2)
    Cmn = muladd(A[m, k], B[k, n], Cmn)
  end
  C[m, n] += Cmn
end
expm_naivematmul!(Z, A) = expm!(Z, A, naive_matmul!, naive_matmuladd!)
d(x, n) = ForwardDiff.Dual(x, ntuple(_ -> randn(), n))
function dualify(A, n, j)
  if n > 0
    A = d.(A, n)
    if (j > 0)
      A = d.(A, j)
    end
  end
  A
end
struct ForEach{A,B,F}
  f::F
  a::A
  b::B
end
ForEach(f, b) = ForEach(f, nothing, b)
(f::ForEach)() = foreach(Base.Fix1(f.f, f.a), f.b)
(f::ForEach{Nothing})() = foreach(f.f, f.b)

function localwork(
  f::F,
  As::NTuple{<:Any,<:AbstractMatrix{T}},
  r,
  i,
  nt
) where {F,T}
  N = length(r)
  start = div(N * (i - 1), nt) + 1
  stop = div(N * i, nt)
  x::T = zero(T)
  B = similar(first(As))
  C = similar(B)
  for A in As
    for j = start:stop
      B .= r[j] .* A
      f(C, B)
      x += sum(C)
    end
  end
  return x
end
function localwork!(
  f::F,
  acc::Ptr{T},
  As::NTuple{<:Any,<:AbstractMatrix{T}},
  r,
  i,
  nt
) where {F,T}
  x = localwork(f, As, r, i, nt)
  unsafe_store!(acc, x)
end
function threadedwork!(
  f::F,
  As::NTuple{<:Any,<:AbstractMatrix{T}},
  r::AbstractArray
) where {T,F}
  nt = min(Threads.nthreads(), length(r))
  if nt <= 1
    return localwork(f, As, r, 1, 1)
  end
  x = cld(64, sizeof(T))
  acc = Matrix{T}(undef, x, nt)
  GC.@preserve acc begin
    p = pointer(acc)
    xst = sizeof(T) * x
    Threads.@threads for i = 1:nt
      localwork!(f, p + xst * (i - 1), As, r, i, nt)
    end
  end
  return sum(@view(acc[1, :]))::T
end
function isoutofplace(f, ::Val{NTuple{N,A}}) where {N,A}
  Base.promote_op(f, A) !== Union{}
end
struct ThreadedForEach{A,F}
  f::F
  a::A
end
ThreadedForEach(f, _, b) = ThreadedForEach(f, b)
function (f::ThreadedForEach{A})() where {A}
  r = 0.1:0.005:10.0
  if isoutofplace(f.f, Val(A))
    threadedwork!((x, y) -> copyto!(x, f.f(y)), f.a, r)
  else
    threadedwork!(f.f, f.a, r)
  end
end

const libMatrixExp = joinpath(@__DIR__, "buildgcc/libMatrixExp.so")
const libMatrixExpClang = joinpath(@__DIR__, "buildclang/libMatrixExp.so")
for (lib, cc) in ((:libMatrixExp, :gcc), (:libMatrixExpClang, :clang))
  j = Symbol(cc, :expm!)
  @eval $j(A::Matrix{Float64}, reps::Int) =
    @ccall $lib.food(A::Ptr{Float64}, size(A, 1)::Clong, reps::Clong)::Float64

  @eval $j(B::Matrix{Float64}, A::Matrix{Float64}) = @ccall $lib.expmf64(
    B::Ptr{Float64},
    A::Ptr{Float64},
    size(A, 1)::Clong
  )::Nothing
  for n = 1:8
    sym = Symbol(:expmf64d, n)
    @eval $j(
      B::Matrix{ForwardDiff.Dual{T,Float64,$n}},
      A::Matrix{ForwardDiff.Dual{T,Float64,$n}}
    ) where {T} = @ccall $lib.$sym(
      B::Ptr{Float64},
      A::Ptr{Float64},
      size(A, 1)::Clong
    )::Nothing
    for i = 1:2
      sym = Symbol(:expmf64d, n, :d, i)
      @eval $j(
        B::Matrix{ForwardDiff.Dual{T1,ForwardDiff.Dual{T0,Float64,$n},$i}},
        A::Matrix{ForwardDiff.Dual{T1,ForwardDiff.Dual{T0,Float64,$n},$i}}
      ) where {T0,T1} = @ccall $lib.$sym(
        B::Ptr{Float64},
        A::Ptr{Float64},
        size(A, 1)::Clong
      )::Nothing
    end
  end
end

struct BMean
  t::Float64
  m::Float64
  a::Float64
end
BMean() = BMean(0.0, 0.0, 0.0)
Base.zero(::BMean) = BMean()
BMean(b::BenchmarkTools.Trial) = BMean(BenchmarkTools.mean(b))
function BMean(b::BenchmarkTools.TrialEstimate)
  a = BenchmarkTools.allocs(b)
  BMean(BenchmarkTools.time(b), BenchmarkTools.memory(b), a)
end
Base.:(+)(x::BMean, y::BMean) = BMean(x.t+y.t,x.m+y.m,x.a+y.a)
Base.:(/)(x::BMean, y::Number) = BMean(x.t/y,x.m/y,x.a/y)
get_time(b::BMean) = b.t
function Base.show(io::IO, ::MIME"text/plain", b::BMean)
  (; t, m, a) = b
  println(
    io,
    "  ",
    BenchmarkTools.prettytime(t),
    " (",
    a,
    " allocation",
    (a == 1 ? "" : "s"),
    ": ",
    BenchmarkTools.prettymemory(m),
    ")"
  )
end

# macros are too awkward to work with, so we use a function
# mean times are much better for benchmarking than minimum
# whenever you have a function that allocates
function bmean(f)
  b = @benchmark $f()
  BMean(b)
end
function exputils!(B, A)
  exponential!(copyto!(B, A))
  return B
end

function run_benchmarks(funs, sizes = 2:8, D0 = 0:8, D1 = 0:2)
  num_funs = length(funs)
  compile_times = zeros(Int, num_funs)
  brs = Array{BMean}(undef, num_funs, length(sizes), length(D0), length(D1))
  counter = 0
  max_count = length(sizes) * (length(D0) * length(D1) - (length(D1) - 1))
  for (i, dim) in enumerate(sizes)
    for (j, d1) in enumerate(D1)
      for (k, d0) in enumerate(D0)
        if d0 == 0 && d1 != 0
          fill!(@view(brs[:, i, k, j]), BMean())
          continue
        end
        As = map(x -> dualify(x, d0, d1), randmatrices(dim))
        Bs = [similar(first(As)) for _ in eachindex(funs)]
        for A in As
          for l in eachindex(funs)
            fun! = funs[l]
            B = Bs[l]
            Base.cumulative_compile_timing(true)
            tstart = Base.cumulative_compile_time_ns()
            fun!(B, A)
            Base.cumulative_compile_timing(false)
            compile_times[l] += Base.cumulative_compile_time_ns()[1] - tstart[1]
          end
          for l = 2:num_funs
            if reinterpret(Float64, Bs[1])  reinterpret(Float64, Bs[l])
              throw("Funs 1 and $l disagree with dim=$dim, d0=$d0, d1=$d1.")
            end
          end
        end
        FE = Threads.nthreads() > 1 ? ThreadedForEach : ForEach
        for (l, fun) in enumerate(funs)
          brs[l, i, k, j] = bmean(FE(fun, As))
        end
        if (counter += 1) != max_count
          println(round(100counter / max_count; digits = 2), "% complete")
        end
      end
    end
  end
  return brs, compile_times
end

#=
# compare all
funs = [expm_oop!, expm!, expm_naivematmul!, exputils!, gccexpm!, clangexpm!];
fun_names = ["Out Of Place", "In Place", "In Place+Naive matmul!", "ExponentialUtilities.exponential!", "GCC", "Clang"]

# comparison used for this PR's reports
funs = [expm!, exputils!, clangexpm!];
fun_names = ["In Place", "ExponentialUtilities.exponential!", "Clang"]
brs, cts = run_benchmarks(funs);
println(cts ./ 1e9)

using CairoMakie, Statistics
t_vs_sz = mean(brs, dims = (3,4));
logtime(x) = log10(get_time(x))
f, ax, l1 = lines(2:8, logtime.(t_vs_sz[1,:,1,1]), label = fun_names[1]);
for l = 2:size(t_vs_sz,1)
  lines!(2:8, logtime.(t_vs_sz[l,:,1,1]), label=fun_names[l]);
end
axislegend(position=:rb); f
save("size_vs_log10time.png", f);
=#

Anyway, it is annoying that "write your own triple loop instead of calling mul!" is an important performance hack whenever dealing with generic code.

@chriselrod chriselrod added the performance Must go faster label Oct 21, 2023
@chriselrod
Copy link
Contributor Author

Also, this implementation does not require one-based-indexing, as long as the axes match.

@oscardssmith
Copy link
Member

If this implementation is faster, shouldn't we be able to remove the current fallback implementation?

@chriselrod
Copy link
Contributor Author

chriselrod commented Oct 21, 2023

If this implementation is faster, shouldn't we be able to remove the current fallback implementation?

The current fallback implementation takes tA, tB, and arbitrary MulAddMuls as arguments.
Arbitrary MulAddMuls is easy enough to support, but tA and tB are annoying.

I agree, ideally, we'd remove that generic implementation, and only go the unwrap_char route for matrices that actually do ultimately dispatch to BLAS/LAPACK.

I'd rather leave the reworking of the internal dispatch hierarchy to someone else, if anyone would like to take it on. ;)
This is roughly the minimal change that supports a few use cases of interest, while code cleanups, architecture, and technical debt are problems for the maintainers that I will thus leave to them.

@dkarrasch
Copy link
Member

If this is supposed to handle (for now) only the case tA == tB == 'N'), then perhaps we could replace the code in that branch in the most generic matmatmul function? As for the other cases, isn't it "just" about potentially reordering some loops and spelling out getindex for transpose and adjoint? Like, given the code from the 'N' path, replace A[i,j] by transpose(A[j,i]) etc.? I remember @andreasnoack was actually complaining once that the "generic" matmatmul is actually not so generic, because it's not just a triple loop, but something more "specific". So I think this effort here is welcome (not even taking into account the performance benefit).

@chriselrod
Copy link
Contributor Author

chriselrod commented Oct 25, 2023

I think we should only have a BLAS-like API for actually forwarding to BLAS.

Having a Julian API that gets processed into a BLAS-like one (meaning using N vs C vs T instead of types) so that we can then manually apply the transforms the wrappers were applying -- i.e. swapping dims for our size comparisons and selectively applying adjoint or transpose to our getindex calls -- sounds like the wrong approach.

Additionally, it makes the function 4x bigger. Why compile all code paths when we could compile the only one (in most circumstances) that we need?

I think the current _generic_matmul! function should never be called, and that dispatch leads either to the Julian API or BLAS up front. Only when we take the BLAS route should we end up looking at chars like N vs C, and we should rarely need to process them ourselves aside from forwarding to BLAS.

We may want to dispatch based on adjoint vs transpose to reorder loops -- the order I'm using here is optimal only N,N -- but I don't think we should get too carried away with half-hearted optimization. Else we might end up with a situation like _generic_matmul!: lots of code, slow to compile, slow to run.

@dkarrasch
Copy link
Member

What you're describing is the status of Julia prior to v1.10. What you see on master, i.e., the character processing and the BLAS-like interface is the new status of Julia v1.10+. That change was actually "celebrated", because it allowed the reduction of the number of mul! methods dramatically, and hence reduce load times significantly, for SparseArrays.jl and for a couple of GPU-related packages that have linear algebra included.

And the reason for that reduction is very simple: if you wanted to distinguish between the different storage types (dense, sparse, GPU arrays), you needed to dispatch over type parameters (à la A::Adjoint{<:Any,<:AbstractSparseMatrixCSC}). With the BLAS-like behaviour, we split outer wrappers from the storage types, and can dispatch on the storage type directly (yes, by overloading LinearAlgebra._generic_matmatmul!). Things like "BLAS up front" is easier said than done: which BLAS?! OpenBLAS, SparseArray"BLAS", metalBLAS, oneBLAS, CUDABLAS, CUDASparse"BLAS"...? I was hoping that constant propagation would actually make the character processing "almost type-dispatch-y", and hence not all branches need to be compiled for specific situations. But I'm unfortunately not competent enough to check if that currently works, or what would be needed to get it to work.

@chriselrod
Copy link
Contributor Author

These dispatches are internal and thus do not need to add to mul!.
For example, in this PR I did not change or add any mul! signature.

The terrible function _generic_matmul! is old, and using N, etc there predates 1.10.

@chriselrod
Copy link
Contributor Author

Anyway, assuming Cthulhu isn't lying

using ForwardDiff, LinearAlgebra
d(x, n) = ForwardDiff.Dual(x, ntuple(_ -> randn(), n))
function dualify(A, n, j)
  if n > 0
    A = d.(A, n)
    if (j > 0)
      A = d.(A, j)
    end
  end
  A
end

A = dualify.(randn(5,5), 7, 2);
B = dualify.(randn(5,5), 7, 2);
C = similar(A);

@time @eval mul!(C, A, B);
@time @eval mul!(C, A', B);
@time @eval mul!(C, A, B');
@time @eval mul!(C, A', B');
@time @eval mul!(C, transpose(A), B);
@time @eval mul!(C, A, transpose(B));
@time @eval mul!(C, transpose(A), transpose(B));

On Julia master, I get

julia> @time @eval mul!(C, A, B);
  1.121381 seconds (2.98 M allocations: 202.741 MiB, 0.97% gc time, 99.36% compilation time)

julia> @time @eval mul!(C, A', B);
  0.006050 seconds (4.22 k allocations: 294.500 KiB, 96.86% compilation time)

julia> @time @eval mul!(C, A, B');
  0.003332 seconds (2.67 k allocations: 204.047 KiB, 95.07% compilation time)

julia> @time @eval mul!(C, A', B');
  0.003324 seconds (2.72 k allocations: 206.984 KiB, 95.31% compilation time)

julia> @time @eval mul!(C, transpose(A), B);
  0.006030 seconds (4.20 k allocations: 293.125 KiB, 97.28% compilation time)

julia> @time @eval mul!(C, A, transpose(B));
  0.003468 seconds (2.67 k allocations: 203.719 KiB, 95.28% compilation time)

julia> @time @eval mul!(C, transpose(A), transpose(B));
  0.003527 seconds (2.72 k allocations: 207.172 KiB, 94.68% compilation time)

while on this PR, I get

julia> @time @eval mul!(C, A, B);
  0.114683 seconds (137.27 k allocations: 9.591 MiB, 93.84% compilation time)

julia> @time @eval mul!(C, A', B);
  0.077514 seconds (37.96 k allocations: 2.566 MiB, 99.74% compilation time)

julia> @time @eval mul!(C, A, B');
  0.048722 seconds (29.82 k allocations: 2.013 MiB, 99.61% compilation time)

julia> @time @eval mul!(C, A', B');
  0.063893 seconds (29.93 k allocations: 2.019 MiB, 99.72% compilation time)

julia> @time @eval mul!(C, transpose(A), B);
  0.076696 seconds (35.31 k allocations: 2.390 MiB, 99.74% compilation time)

julia> @time @eval mul!(C, A, transpose(B));
  0.048606 seconds (29.81 k allocations: 2.013 MiB, 99.63% compilation time)

julia> @time @eval mul!(C, transpose(A), transpose(B));
  0.073229 seconds (29.92 k allocations: 2.018 MiB, 11.88% gc time, 99.75% compilation time)

The thing to point out here is that on this PR, we have close to the same compile time each time, as we're compiling the entire mul! function.

However, on master, we only need to compile _generic_matmul! once, as all calls are there, and constant prop is NOT leading to new specializations.
(I also checked using Cthulhu.jl, where the optimzed typed code showed it was branching on the char's values, but I don't know when it does vs doesn't lie, so I'm not sure how much I can trust it.)

The single compile on Julia master was, however, significantly slower than the sum of all 7 separate compile times on this PR.
Real code is likely to benefit even further, from probably only needing one or maybe two of these specializations.

Here is another benchmark

using ForwardDiff, LinearAlgebra
d(x, n) = ForwardDiff.Dual(x, ntuple(_ -> randn(), n))
function dualify(A, n, j)
  if n > 0
    A = d.(A, n)
    if (j > 0)
      A = d.(A, j)
    end
  end
  A
end

@time for n = 0:8, j = (n!=0):4
  A = dualify.(randn(5,5), n, j);
  B = dualify.(randn(5,5), n, j);
  C = similar(A);
  mul!(C, A, B);
  mul!(C, A', B);
  mul!(C, A, B');
  mul!(C, A', B');
  mul!(C, transpose(A), B);
  mul!(C, A, transpose(B));
  mul!(C, transpose(A), transpose(B));
end

On Julia master, I get

31.304752 seconds (56.77 M allocations: 3.749 GiB, 1.15% gc time, 99.93% compilation time)

versus on this PR:

18.514421 seconds (12.76 M allocations: 860.920 MiB, 0.49% gc time, 99.88% compilation time)           

If we comment out all the mul! except mul!(C, A, B), on Julia master I get

30.290741 seconds (56.12 M allocations: 3.703 GiB, 1.14% gc time, 99.95% compilation time)

versus this PR

5.230921 seconds (6.99 M allocations: 468.793 MiB, 0.99% gc time, 99.76% compilation time)

So I suggest:

  1. Get rid of the buffers. These add tons of LOC while worsening compile times and runtimes; they're just bad.
  2. Separate small specializations like we have here, so that you don't have to pay for the compile times of what you don't use.

This PR does both by avoiding _generic_matmul! while also not adding or changing any mul! signatures.
I was aware of the problems those caused, which is why I didn't touch them and instead branched based on checks inside the method.
Alternatively, you could handle it with dispatches in the internal processing chain.

If you have strong opinions about the implementation, I suggest we close this PR and you implement the fixes.
I just don't want the combination of bad compile times and bad runtimes that we have now. That these are associated with a ton of extra complexity is painful; we have a ton of extra lines of code and complexity in the implementation for the sake of being bad.
Getting rid of it would be better, but in the short term it is easier to just bypass it by adding more code and complexity on top.

Comment on lines 268 to 270
if (!((eltype(C) <: BlasFloat && eltype(C) === eltype(A) === eltype(B)))) ||
(!(C isa StridedVecOrMat && A isa StridedVecOrMat && B isa StridedVecOrMat)) &&
α == true && (β == true || β == false)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, then this condition applies, for instance, to CUDA, CUSparse and sparse matrices (and/or their transposes and adjoints), right? That would massively slow down SparseArrays.jl and even break a few GPUArrays-related packages, which specifically overload LinearAlgebra.generic_matmatmul! (with the character arguments).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can make it Array-only.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait, I might have misread the logic. Suppose I want to multiply non-transposed dense GPUArrays with BlasFloat elements, then this returns ... true? And directs away from generic_matmatmul!? Too many conditions for my little brain.

Copy link
Contributor Author

@chriselrod chriselrod Nov 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The check is supposed to be for BLAS compatibility of types.
BLAS requires both of
1.The eltypes match and they're BlasFloat
2. They're all strided

The idea is, if either of these are the case, call the new overload.
Otherwise, call the BLAS dispatcher.

But because packages like CUDA are relying on extending the non-Julian and non-exported generic_matmatmul!, it's suddenly a breaking change.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, they do because that significantly reduced load times of those packages, both SparseArrays and all of the GPUArrays-related packages. Otherwise people had to overload tons of mul! methods, for all combinations of plain/transpose/adjoint factors, possibly dispatching on type parameters. JuliaGPU/CUDA.jl#1904

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, ideally there would have been another way to avoid ``mul!`.

I went with requiring Matrix for the first two arguments and Array for the third.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, ideally there would have been another way to avoid mul!.

Which one? Since this is all internal, it can be changed to whatever yields the correct dispatch to the underlying *BLAS methods and doesn't increase load times. Method insertion at package loading was the big big issue.

@chriselrod
Copy link
Contributor Author

Closing in favor of #52038

@chriselrod chriselrod closed this Nov 5, 2023
oscardssmith pushed a commit that referenced this pull request Nov 14, 2023
This is another attempt at improving the compile time issue with generic
matmatmul, hopefully improving runtime performance also.

@chriselrod @jishnub

There seems to be a little typo/oversight somewhere, but it shows how it
could work. Locally, this reduces benchmark times from
#51812 (comment) by
more than 50%.

---------

Co-authored-by: Chris Elrod <elrodc@gmail.com>
@chriselrod chriselrod deleted the mulfastpath branch May 1, 2024 18:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
domain:linear algebra Linear algebra performance Must go faster
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants