LinearAlgebra: improve type-inference in Symmetric/Hermitian matmul #54303
Conversation
Should we even backport this to v1.10?
So, the main idea is not to confuse the compiler unnecessarily with uppercase/lowercase H or S when, for the result type, that distinction is irrelevant anyway, right? Because that distinction is just the value of a field, and not encoded into the type?
Yes, that's the idea. Backporting to v1.10 might require some manual intervention, but it should be a good idea.
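To illustrate the point about the field vs. the type (a minimal sketch, not code from the PR): both cases of a Symmetric wrapper share one concrete type, and only the uplo field differs, so the 'U'/'L' distinction is runtime data that cannot aid inference of the result type.

```julia
using LinearAlgebra

A = [1.0 2.0; 2.0 3.0]
SU = Symmetric(A, :U)
SL = Symmetric(A, :L)

# The :U vs :L distinction lives in the uplo field (a Char),
# not in the type, so it is invisible to type inference.
println(typeof(SU) == typeof(SL))  # true
println((SU.uplo, SL.uplo))        # ('U', 'L')
```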
The last few commits improve type-stability and ensure constant propagation in various checks in the matmul functions. Fixes #53951 after the recent set of commits. After this:
julia> using LinearAlgebra
julia> using BenchmarkTools
julia> A = Hermitian([1.0 2.0; 2.0 3.0])
2×2 Hermitian{Float64, Matrix{Float64}}:
1.0 2.0
2.0 3.0
julia> B = [4.0 5.0; 6.0 7.0]
2×2 Matrix{Float64}:
4.0 5.0
6.0 7.0
julia> Y = similar(B)
2×2 Matrix{Float64}:
6.0e-323 6.4e-323
NaN 0.0
julia> @btime mul!($Y, $A, $B)
127.843 ns (0 allocations: 0 bytes)
2×2 Matrix{Float64}:
16.0 19.0
26.0 31.0
julia> Badj = B'
2×2 adjoint(::Matrix{Float64}) with eltype Float64:
4.0 6.0
5.0 7.0
julia> @btime mul!($Y, $A, $Badj)
44.311 ns (0 allocations: 0 bytes)
2×2 Matrix{Float64}:
14.0 20.0
23.0 33.0
Force-pushed from ae7fc6c to 541a76d
# We convert the chars to uppercase to potentially unwrap a WrapperChar,
# and extract the char corresponding to the wrapper type
tA_uc, tB_uc = uppercase(tA), uppercase(tB)
# the map in all ensures constprop by acting on tA and tB individually, instead of looping over them.
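As an illustration of that comment (a hedged sketch with hypothetical helper names, not the actual stdlib code): `map` over a tuple applies the function to each element separately, so when the characters are compile-time constants the whole check can fold to a constant, whereas an ordinary iteration over the tuple may not fold as readily.

```julia
# Hypothetical helper for illustration only.
is_blas_trans_char(c::Char) = c in ('N', 'T', 'C')

# Constprop-friendly: map over a tuple unrolls per element.
check_mapped(tA::Char, tB::Char) = all(map(is_blas_trans_char, (tA, tB)))

# The loop variant computes the same value, but may not fold as readily.
check_looped(tA::Char, tB::Char) = all(is_blas_trans_char(t) for t in (tA, tB))

println(check_mapped('N', 'T'))  # true
println(check_mapped('x', 'T'))  # false
```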
If this is true, this should be a big contribution to reducing compile times, right? If we land in this branch, then we don't need to compile symm and hemm, or, in the other case, syrk/herk/gemm_wrapper.
There is some compile-time improvement indeed, although it's not dramatic.
Each execution is in a separate session in the following:
julia> A = rand(2,2); B = rand(2,2); C = zeros(2,2);
julia> @time mul!(C, A, B);
0.847057 seconds (3.39 M allocations: 171.963 MiB, 24.97% gc time, 100.00% compilation time) # nightly
0.757433 seconds (3.94 M allocations: 202.922 MiB, 4.65% gc time, 100.00% compilation time) # This PR
julia> A = rand(2,2); B = Symmetric(rand(2,2)); C = zeros(2,2);
julia> @time mul!(C, A, B);
1.098831 seconds (3.68 M allocations: 189.159 MiB, 24.52% gc time, 99.99% compilation time) # nightly
0.687847 seconds (4.72 M allocations: 238.864 MiB, 7.04% gc time, 99.99% compilation time) # This PR
Descending into generic_matmatmul! using Cthulhu does seem to indicate that unused branches are eliminated; e.g., in the first case, only gemm_wrapper! is being compiled, and in the second, only BLAS.symm! is compiled.
The code_typed for the first case (gemm) is identical between this PR and nightly:
julia> A = rand(2,2); B = rand(2,2); C = zeros(2,2);
julia> @code_typed mul!(C, A, B)
CodeInfo(
1 ─ %1 = invoke LinearAlgebra.gemm_wrapper!(C::Matrix{Float64}, 'N'::Char, 'N'::Char, A::Matrix{Float64}, B::Matrix{Float64}, $(QuoteNode(LinearAlgebra.MulAddMul{true, true, Bool, Bool}(true, false)))::LinearAlgebra.MulAddMul{true, true, Bool, Bool})::Matrix{Float64}
└── return %1
) => Matrix{Float64}
I'm not certain why there's a compile-time improvement here (perhaps noise?). In this case, the all is already being folded (despite the loop over the characters). I suspect the loop is being unrolled entirely, as the characters are all Chars that are fully known at compile time.
The second case (symm) is where the major improvement comes in:
julia> A = rand(2,2); B = Symmetric(rand(2,2)); C = zeros(2,2);
julia> @code_typed mul!(C, A, B)
CodeInfo(
1 ── %1 = Base.getfield(B, :uplo)::Char
│ %2 = Base.bitcast(Base.UInt32, %1)::UInt32
│ %3 = Base.bitcast(Base.UInt32, 'U')::UInt32
│ %4 = (%2 === %3)::Bool
│ %5 = Base.getfield(B, :data)::Matrix{Float64}
└─── goto #3 if not %4
2 ── goto #4
3 ── goto #4
4 ┄─ %9 = φ (#2 => 'S', #3 => 's')::Char
│ %10 = Base.bitcast(Base.UInt32, %9)::UInt32
│ %11 = Base.bitcast(Base.UInt32, 'S')::UInt32
│ %12 = (%10 === %11)::Bool
└─── goto #5
5 ── goto #7 if not %12
6 ── goto #8
7 ── nothing::Nothing
8 ┄─ %17 = φ (#6 => 'U', #7 => 'L')::Char
│ %18 = invoke LinearAlgebra.BLAS.symm!('R'::Char, %17::Char, 1.0::Float64, %5::Matrix{Float64}, A::Matrix{Float64}, 0.0::Float64, C::Matrix{Float64})::Matrix{Float64}
└─── goto #9
9 ── goto #10
10 ─ goto #11
11 ─ return %18
) => Matrix{Float64}
The BLAS.symm! branch that is being followed is "inlined" now. This is the case where the loop is not unrolled ordinarily, but using the all(map(..)) combination permits constant propagation.
I love it. I always dreamt of the day when that character stuff would be inferred, or constant-propagated far enough. Is this ready to go now? I think we should first merge this, and then the "stabilize MulAddMul strategically" PR, to give this one a chance for a backport to v1.10, though I'm not sure if this is a bit too ambitious.
With this, `isuppercase`/`islowercase` are evaluated at compile-time for `Char` arguments:

```julia
julia> @code_typed (() -> isuppercase('A'))()
CodeInfo(
1 ─ return true
) => Bool

julia> @code_typed (() -> islowercase('A'))()
CodeInfo(
1 ─ return false
) => Bool
```

This would be useful in #54303, where the case of the character indicates which triangular half of a matrix is filled, and may be constant-propagated downstream.

Co-authored-by: Shuhei Kadowaki <40514306+aviatesk@users.noreply.github.com>
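For instance (a hypothetical sketch mirroring the 'S'/'s' characters seen in the IR earlier in this thread, not the actual stdlib encoding): if an uppercase wrapper character encodes uplo == 'U' and a lowercase one encodes uplo == 'L', then isuppercase on a compile-time-known character folds away entirely.

```julia
# Hypothetical encoding: uppercase 'S' means uplo == 'U', lowercase 's' means uplo == 'L'.
wrapper_char(uplo::Char) = uplo == 'U' ? 'S' : 's'
uplo_of(c::Char) = isuppercase(c) ? 'U' : 'L'

println(uplo_of(wrapper_char('U')))  # U
println(uplo_of(wrapper_char('L')))  # L
```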
Yes, this is ready from my side.
Matrix multiplication for wrapper types such as Hermitian currently uses the unwrapping mechanism that assigns a character based on the type of the wrapper. However, this isn't always unique: for Hermitian/Symmetric types, this also looks at the uplo field, which isn't usually known at compile time.

julia/stdlib/LinearAlgebra/src/LinearAlgebra.jl
Lines 523 to 525 in 6023ad6

An example of a badly inferred function call because of this:

The output type is inferred as a large union, which complicates further type-inference downstream. Often, the impact of the runtime dispatch is minimal due to function barriers. However, we may avoid the runtime dispatch altogether.
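One way to observe the kind of inference described here (a sketch; the exact inferred types depend on the Julia version, and on current releases the result may already be concrete):

```julia
using LinearAlgebra

S = Symmetric(rand(2, 2))
B = rand(2, 2)

# Base.return_types reports the statically inferred return type(s).
# For Symmetric/Hermitian matmul, a wide Union here is the symptom the
# PR addresses; a concrete Matrix{Float64} is the desired outcome.
rts = Base.return_types(*, Tuple{typeof(S), typeof(B)})
println(rts)
```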
This PR separates the uplo character from that for the type, storing them both in a newly defined struct. Using this approach, the type information may be constant-propagated even if the uplo isn't, and the return type may be concretely inferred.

This change should be compatible with existing code, as the new struct subtypes AbstractChar, and it may be converted and compared to a Char like before.

Fixes #53951.
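A minimal sketch of the idea (assumed names; not the actual LinearAlgebra implementation): store the wrapper's identifying character alongside a case flag, and make the struct behave like an AbstractChar by defining codepoint, which is enough for conversion to and comparison with Char.

```julia
# Illustrative only; the real struct in the PR differs in detail.
struct MyWrapperChar <: AbstractChar
    wrapperchar::Char  # uppercase character identifying the wrapper type
    isupper::Bool      # the case, encoding the runtime uplo field
end

# Defining codepoint makes generic AbstractChar machinery (conversion,
# comparison, case predicates) work on the new type.
Base.codepoint(c::MyWrapperChar) =
    codepoint(c.isupper ? c.wrapperchar : lowercase(c.wrapperchar))

c = MyWrapperChar('S', false)
println(Char(c))         # s
println(c == 's')        # true
println(isuppercase(c))  # false
```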