Skip to content

Commit

Permalink
update dense/linalg for libblastrampoline
Browse files Browse the repository at this point in the history
  • Loading branch information
Jutho committed Mar 16, 2022
1 parent 46cc2d4 commit ae53a89
Showing 1 changed file with 66 additions and 94 deletions.
160 changes: 66 additions & 94 deletions src/dense/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,13 @@ import LinearAlgebra:
PosDefException,
chkstride1,
checksquare
import LinearAlgebra.BLAS: @blasfunc, libblas, BlasReal, BlasComplex
import LinearAlgebra.LAPACK: liblapack, chklapackerror
import LinearAlgebra.BLAS: @blasfunc, BlasReal, BlasComplex
import LinearAlgebra.LAPACK: chklapackerror
# Select the library handle used as the target of the LAPACK `ccall`s in this file:
# on Julia 1.7 and newer bind `liblapack` to `LinearAlgebra.BLAS.libblastrampoline`,
# on older versions keep using `LinearAlgebra.LAPACK.liblapack`.
@static if VERSION >= v"1.7"
const liblapack = LinearAlgebra.BLAS.libblastrampoline
else
const liblapack = LinearAlgebra.LAPACK.liblapack
end

@static if isdefined(Base, :require_one_based_indexing)
import Base: require_one_based_indexing
Expand Down Expand Up @@ -117,8 +122,8 @@ geneigh!(A::StridedMatrix{T}, B::StridedMatrix{T}) where {T<:BlasFloat} =
# Singular value decomposition of a Bidiagonal matrix
function bidiagsvd!(
B::Bidiagonal{T},
U::StridedMatrix{T} = one(B),
VT::StridedMatrix{T} = one(B)
U::AbstractMatrix{T} = one(B),
VT::AbstractMatrix{T} = one(B)
) where {T<:BlasReal}
s, Vt, U, = LAPACK.bdsqr!(B.uplo, B.dv, B.ev, VT, U, similar(U, (size(B, 1), 0)))
return U, s, Vt
Expand All @@ -144,20 +149,20 @@ function reverserows!(V::AbstractVecOrMat)
end

# Schur factorization of a Hessenberg matrix
hschur!(H::StridedMatrix{T}, Z::StridedMatrix{T} = one(H)) where {T<:BlasFloat} =
# Thin wrapper that forwards `H` and `Z` (defaulting to `one(H)`) to `hseqr!`.
function hschur!(H::AbstractMatrix{T}, Z::AbstractMatrix{T} = one(H)) where {T<:BlasFloat}
    return hseqr!(H, Z)
end

schur2eigvals(T::StridedMatrix{<:BlasFloat}) = schur2eigvals(T, 1:size(T, 1))
# Convenience method: delegate to the two-argument method with `which = 1:size(T, 1)`.
function schur2eigvals(T::AbstractMatrix{<:BlasFloat})
    allindices = 1:size(T, 1)
    return schur2eigvals(T, allindices)
end

function schur2eigvals(T::StridedMatrix{<:BlasComplex}, which::AbstractVector{Int})
# Collect the diagonal entries `T[i, i]` of the square matrix `T` for the indices listed
# in `which`, in the order given.
#
# Throws an `ArgumentError` when `which` contains a repeated index; `checksquare` raises
# its own error when `T` is not square.
function schur2eigvals(T::AbstractMatrix{<:BlasComplex}, which::AbstractVector{Int})
    checksquare(T)  # called only for its squareness check
    selected = unique(which)
    if length(selected) != length(which)
        throw(ArgumentError("which should contain unique values"))
    end
    return [T[k, k] for k in selected]
end

function schur2eigvals(T::StridedMatrix{<:BlasReal}, which::AbstractVector{Int})
function schur2eigvals(T::AbstractMatrix{<:BlasReal}, which::AbstractVector{Int})
n = checksquare(T)
which2 = unique(which)
length(which2) == length(which) ||
Expand Down Expand Up @@ -188,15 +193,15 @@ function _normalizevecs!(V)
end
return V
end
function schur2eigvecs(T::StridedMatrix{<:BlasComplex})
# Fill an `n × n` buffer through `trevc!` (called with side 'R' and howmny 'A', an empty
# selection vector and an `n × 0` placeholder for the other side) and return it after
# passing it through `_normalizevecs!`.
# NOTE(review): presumably these are the right eigenvectors for all eigenvalues of the
# triangular Schur factor `T` — confirm against the `trevc!` wrapper.
function schur2eigvecs(T::AbstractMatrix{<:BlasComplex})
    dim = checksquare(T)
    leftvecs = similar(T, dim, 0)     # zero-width placeholder passed in the VL slot
    rightvecs = similar(T, dim, dim)  # output buffer passed in the VR slot
    noselection = BlasInt[]
    trevc!('R', 'A', noselection, T, leftvecs, rightvecs)
    return _normalizevecs!(rightvecs)
end
function schur2eigvecs(T::StridedMatrix{<:BlasComplex}, which::AbstractVector{Int})
function schur2eigvecs(T::AbstractMatrix{<:BlasComplex}, which::AbstractVector{Int})
n = checksquare(T)
which2 = unique(which)
length(which2) == length(which) ||
Expand Down Expand Up @@ -238,7 +243,7 @@ function schur2eigvecs(T::StridedMatrix{<:BlasReal})
end
return _normalizevecs!(VR)
end
function schur2eigvecs(T::StridedMatrix{<:BlasReal}, which::AbstractVector{Int})
function schur2eigvecs(T::AbstractMatrix{<:BlasReal}, which::AbstractVector{Int})
n = checksquare(T)
which2 = unique(which)
length(which2) == length(which) ||
Expand Down Expand Up @@ -287,8 +292,8 @@ function schur2eigvecs(T::StridedMatrix{<:BlasReal}, which::AbstractVector{Int})
end

function permuteeig!(
D::StridedVector{S},
V::StridedMatrix{S},
D::AbstractVector{S},
V::AbstractMatrix{S},
perm::AbstractVector{Int}
) where {S}
n = checksquare(V)
Expand All @@ -310,7 +315,7 @@ function permuteeig!(
p[i] = i

D[i], D[inext] = D[inext], D[i]
for j in 1:n
@simd for j in 1:n
V[j, i], V[j, inext] = V[j, inext], V[j, i]
end
i = inext
Expand All @@ -319,46 +324,42 @@ function permuteeig!(
return D, V
end

permuteschur!(T::StridedMatrix{<:BlasFloat}, p::AbstractVector{Int}) =
# Two-argument convenience method: use `one(T)` as the accompanying `Q` and forward to
# the three-argument `permuteschur!`.
function permuteschur!(T::AbstractMatrix{<:BlasFloat}, p::AbstractVector{Int})
    Q = one(T)
    return permuteschur!(T, Q, p)
end
function permuteschur!(
T::StridedMatrix{S},
Q::StridedMatrix{S},
perm::AbstractVector{Int}
T::AbstractMatrix{S},
Q::AbstractMatrix{S},
order::AbstractVector{Int}
) where {S<:BlasComplex}
n = checksquare(T)
p = collect(perm) # makes copy cause will be overwritten
isperm(p) && length(p) == n ||
throw(ArgumentError("not a valid permutation of length $n"))
@inbounds for i in 1:n
p = collect(order) # makes copy cause will be overwritten
@inbounds for i in 1:length(p)
ifirst::BlasInt = p[i]
ilast::BlasInt = i
T, Q = LAPACK.trexc!(ifirst, ilast, T, Q)
for k in (i+1):n
for k in (i+1):length(p)
if p[k] < p[i]
p[k] += 1
end
end
end
return T, Q
return T, Q, schur2eigvals(T)
end

function permuteschur!(
T::StridedMatrix{S},
Q::StridedMatrix{S},
perm::AbstractVector{Int}
T::AbstractMatrix{S},
Q::AbstractMatrix{S},
order::AbstractVector{Int}
) where {S<:BlasReal}
n = checksquare(T)
p = collect(perm) # makes copy cause will be overwritten
isperm(p) && length(p) == n ||
throw(ArgumentError("not a valid permutation of length $n"))
p = collect(order) # makes copy cause will be overwritten
i = 1
@inbounds while i <= n
@inbounds while i <= length(p)
ifirst::BlasInt = p[i]
ilast::BlasInt = i
if ifirst == n || iszero(T[ifirst+1, ifirst])
T, Q = LAPACK.trexc!(ifirst, ilast, T, Q)
@inbounds for k in (i+1):n
@inbounds for k in (i+1):length(p)
if p[k] < p[i]
p[k] += 1
end
Expand All @@ -368,15 +369,24 @@ function permuteschur!(
p[i+1] == ifirst + 1 ||
error("cannot split 2x2 blocks when permuting schur decomposition")
T, Q = LAPACK.trexc!(ifirst, ilast, T, Q)
@inbounds for k in (i+2):n
@inbounds for k in (i+2):length(p)
if p[k] < p[i]
p[k] += 2
end
end
i += 2
end
end
return T, Q
return T, Q, schur2eigvals(T)
end

# Call `trsen!` (job 'N', compq 'V') on `T` and `Q`, with the Boolean mask `select`
# converted to a `Vector{BlasInt}` for the wrapper.
#
# Returns `T`, `Q` and the third value produced by `trsen!`; any further outputs are
# dropped.
# NOTE(review): presumably the third value holds the reordered eigenvalues — confirm
# against the `trsen!` methods.
function partitionschur!(
    T::AbstractMatrix{S},
    Q::AbstractMatrix{S},
    select::AbstractVector{Bool}
) where {S<:BlasFloat}
    mask = convert(Vector{BlasInt}, select)
    Tout, Qout, vals = trsen!('N', 'V', mask, T, Q)
    return Tout, Qout, vals
end

# redefine LAPACK interface to tridiagonal eigenvalue problem
Expand Down Expand Up @@ -484,7 +494,7 @@ end

# redefine LAPACK interface to schur
for (hseqr, trevc, trsen, elty) in
((:dhseqr_, :dtrevc_, :dtrsen_, :Float64), (:shseqr_, :strevc_, :stgsen_, :Float32))
((:dhseqr_, :dtrevc_, :dtrsen_, :Float64), (:shseqr_, :strevc_, :strsen_, :Float32))
@eval begin
function hseqr!(H::StridedMatrix{$elty}, Z::StridedMatrix{$elty} = one(H))
require_one_based_indexing(H, Z)
Expand Down Expand Up @@ -621,15 +631,12 @@ for (hseqr, trevc, trsen, elty) in

return VL, VR, m
end
function trsen!(
job::Char,
compq::Char,
select::AbstractMatrix{BlasInt},
T::AbstractMatrix{$elty},
Q::AbstractMatrix{$elty}
)
function trsen!(job::AbstractChar, compq::AbstractChar, select::AbstractVector{BlasInt},
T::AbstractMatrix{$elty}, Q::AbstractMatrix{$elty})
chkstride1(T, Q, select)
n = checksquare(T)
checksquare(Q) == n || throw(DimensionMismatch())
length(select) == n || throw(DimensionMismatch())
ldt = max(1, stride(T, 2))
ldq = max(1, stride(Q, 2))
wr = similar(T, $elty, n)
Expand All @@ -643,62 +650,27 @@ for (hseqr, trevc, trsen, elty) in
select = convert(Array{BlasInt}, select)
s = Ref{$elty}(zero($elty))
sep = Ref{$elty}(zero($elty))
for i in 1:2 # first call returns lwork as work[1] and liwork as iwork[1]
ccall(
(@blasfunc($trsen), liblapack),
Cvoid,
(
Ref{UInt8},
Ref{UInt8},
Ptr{BlasInt},
Ref{BlasInt},
Ptr{$elty},
Ref{BlasInt},
Ptr{$elty},
Ref{BlasInt},
Ptr{$elty},
Ptr{$elty},
Ref{BlasInt},
Ref{$elty},
Ref{$elty},
Ptr{$elty},
Ref{BlasInt},
Ptr{BlasInt},
Ref{BlasInt},
Ptr{BlasInt},
Clong,
Clong
),
job,
compq,
select,
n,
T,
ldt,
Q,
ldq,
wr,
wi,
m,
s,
sep,
work,
lwork,
iwork,
liwork,
info,
1,
1
)
for i = 1:2 # first call returns lwork as work[1] and liwork as iwork[1]
ccall((@blasfunc($trsen), liblapack), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ptr{BlasInt}, Ref{BlasInt},
Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ref{BlasInt}, Ref{$elty}, Ref{$elty},
Ptr{$elty}, Ref{BlasInt}, Ptr{BlasInt}, Ref{BlasInt},
Ptr{BlasInt}, Clong, Clong),
job, compq, select, n,
T, ldt, Q, ldq,
wr, wi, m, s, sep,
work, lwork, iwork, liwork,
info, 1, 1)
chklapackerror(info[])
if i == 1 # only estimated optimal lwork, liwork
lwork = BlasInt(real(work[1]))
lwork = BlasInt(real(work[1]))
resize!(work, lwork)
liwork = BlasInt(real(iwork[1]))
resize!(iwork, liwork)
end
end
return T, Q, complex.(wr, wi), s[], sep[]
T, Q, complex.(wr, wi), s[], sep[]
end
end
end
Expand Down Expand Up @@ -847,9 +819,9 @@ for (hseqr, trevc, trsen, elty, relty) in (
function trsen!(
job::Char,
compq::Char,
select::StridedVector{BlasInt},
T::StridedMatrix{$elty},
Q::StridedMatrix{$elty}
select::AbstractVector{BlasInt},
T::AbstractMatrix{$elty},
Q::AbstractMatrix{$elty}
)
chkstride1(select, T, Q)
n = checksquare(T)
Expand Down

0 comments on commit ae53a89

Please sign in to comment.