Skip to content

Commit

Permalink
Fix blas tests (#627)
Browse files Browse the repository at this point in the history
  • Loading branch information
pxl-th committed May 11, 2024
1 parent d3bf4c1 commit 70ac953
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 12 deletions.
48 changes: 44 additions & 4 deletions src/blas/highlevel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,15 @@ end
if VERSION v"1.10-"
# multiplication
LinearAlgebra.generic_trimatmul!(
c::ROCVector{T}, uploc, isunitc, tfun::Function,
A::ROCMatrix{T}, b::AbstractVector{T},
c::StridedROCVector{T}, uploc, isunitc, tfun::Function,
A::StridedROCMatrix{T}, b::StridedROCVector{T},
) where T <: ROCBLASFloat = trmv!(
uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C',
isunitc, A, c === b ? c : copyto!(c, b))
# division
LinearAlgebra.generic_trimatdiv!(
C::ROCVector{T}, uploc, isunitc, tfun::Function,
A::ROCMatrix{T}, B::AbstractVector{T},
C::StridedROCVector{T}, uploc, isunitc, tfun::Function,
A::StridedROCMatrix{T}, B::StridedROCVector{T},
) where T <: ROCBLASFloat = trsv!(
uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C',
isunitc, A, C === B ? C : copyto!(C, B))
Expand Down Expand Up @@ -410,3 +410,43 @@ else
end
end
end

# Matrix inversion.

for (t, uploc, isunitc) in (
(:LowerTriangular, 'U', 'N'),
(:UnitLowerTriangular, 'U', 'U'),
(:UpperTriangular, 'L', 'N'),
(:UnitUpperTriangular, 'L', 'U'),
)
@eval function LinearAlgebra.inv(x::$t{T, <: ROCMatrix{T}}) where T <: ROCBLASFloat
out = ROCArray{T}(I(size(x, 1)))
$t(LinearAlgebra.ldiv!(x, out))
end
end

# Diagonal matrix.

Base.Array(D::Diagonal{T, <: ROCArray{T}}) where T = Diagonal(Array(D.diag))

ROCArray(D::Diagonal{T, <: Vector{T}}) where T = Diagonal(ROCArray(D.diag))

function LinearAlgebra.inv(D::Diagonal{T, <: ROCArray{T}}) where T
Di = map(inv, D.diag)
any(isinf, Di) && error("Singular Exception $Di")
Diagonal(Di)
end

function Base.:/(A::ROCArray, D::Diagonal)
B = similar(A, typeof(oneunit(eltype(A)) / oneunit(eltype(D))))
_rdiv!(B, A, D)
end

function _rdiv!(B::ROCArray, A::ROCArray, D::Diagonal)
m, n = size(A, 1), size(A, 2)
(k = length(D.diag)) != n && throw(DimensionMismatch(
"left hand side has $n columns but D is $k by $k"))

B .= A * inv(D)
return B
end
16 changes: 8 additions & 8 deletions test/rocarray/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ m = 20
n = 35
k = 13

handle = rocBLAS.handle()

@testset "Build Information" begin
ver = rocBLAS.version()
@test ver isa VersionNumber
Expand Down Expand Up @@ -103,13 +101,13 @@ end

A = rand(T, m, m)
x = rand(T, m)
@testset "Triangular mul/lmul!" for TR in (
@testset "Triangular mul/lmul!" for TR in (
UpperTriangular, LowerTriangular,
), f in (
identity, adjoint, transpose,
)
@test testf((a, b) -> f(TR(A)) * x, A, x)
@test testf((a, b) -> lmul!(f(TR(A)), b), A, copy(x))
@test testf((a, b) -> f(TR(a)) * b, A, x)
@test testf((a, b) -> lmul!(f(TR(a)), b), A, copy(x))
end

A, x = rand(T, m, m), rand(T, m)
Expand All @@ -118,15 +116,17 @@ end
), f in (
identity, adjoint, transpose,
)
@test testf((a, b) -> f(TR(A)) \ x, A, x)
@test testf((a, b) -> ldiv!(f(TR(A)), b), A, copy(x))
@test testf((a, b) -> f(TR(a)) \ b, A, x)
@test testf((a, b) -> ldiv!(f(TR(a)), b), A, copy(x))
end

x = rand(T, m, m)
@testset "inv($TR)" for TR in (
UpperTriangular, LowerTriangular,
UnitUpperTriangular, UnitLowerTriangular,
Diagonal,
)
@test testf(x -> inv(TR(x)), rand(T, m, m))
@test testf(a -> inv(TR(a)), x)
end
end
end
Expand Down

0 comments on commit 70ac953

Please sign in to comment.