[NDTensors] CuArray linear algebra in NDTensors (#1146)
kmp5VT committed Jul 20, 2023
1 parent 708220f commit 70eb783
Showing 14 changed files with 126 additions and 88 deletions.
1 change: 0 additions & 1 deletion ITensorGPU/Project.toml
@@ -33,4 +33,3 @@ Strided = "1.1.2, 2"
 TimerOutputs = "0.5.13"
 cuTENSOR = "1.1.0"
 julia = "1.6"
-
1 change: 1 addition & 0 deletions NDTensors/ext/NDTensorCUDA/NDTensorCUDA.jl
@@ -19,4 +19,5 @@ end
 include("imports.jl")
 include("set_types.jl")
 include("adapt.jl")
+include("linearalgebra.jl")
 end
13 changes: 13 additions & 0 deletions NDTensors/ext/NDTensorCUDA/linearalgebra.jl
@@ -0,0 +1,13 @@
+function NDTensors.svd_catch_error(A::CuMatrix; alg="JacobiAlgorithm")
+  if alg == "JacobiAlgorithm"
+    alg = CUDA.CUSOLVER.JacobiAlgorithm()
+  else
+    alg = CUDA.CUSOLVER.QRAlgorithm()
+  end
+  USV = try
+    svd(A; alg=alg)
+  catch
+    return nothing
+  end
+  return USV
+end
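
For reference, a minimal usage sketch of the new `svd_catch_error` overload; the matrix shape, element type, and fallback logic below are illustrative, not part of the commit:

```julia
using CUDA, LinearAlgebra
using NDTensors

# Try the Jacobi-based CUSOLVER SVD first; fall back to the QR-based one.
A = CUDA.randn(Float32, 8, 4)
F = NDTensors.svd_catch_error(A; alg="JacobiAlgorithm")
if isnothing(F)
  F = NDTensors.svd_catch_error(A; alg="QRAlgorithm")
end
isnothing(F) || @assert F.U * Diagonal(F.S) * F.V' ≈ A
```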
10 changes: 5 additions & 5 deletions NDTensors/src/dense/tensoralgebra/contract.jl
@@ -145,23 +145,23 @@ function _contract_scalar_noperm!(
       fill!(Rᵈ, 0)
     else
       # Rᵈ .= α .* T₂ᵈ
-      BLAS.axpby!(α, Tᵈ, β, Rᵈ)
+      LinearAlgebra.axpby!(α, Tᵈ, β, Rᵈ)
     end
   elseif isone(β)
     if iszero(α)
       # No-op
       # Rᵈ .= Rᵈ
     else
       # Rᵈ .= α .* Tᵈ .+ Rᵈ
-      BLAS.axpy!(α, Tᵈ, Rᵈ)
+      LinearAlgebra.axpy!(α, Tᵈ, Rᵈ)
     end
   else
     if iszero(α)
       # Rᵈ .= β .* Rᵈ
-      BLAS.scal!(length(Rᵈ), β, Rᵈ, 1)
+      LinearAlgebra.scal!(length(Rᵈ), β, Rᵈ, 1)
     else
       # Rᵈ .= α .* Tᵈ .+ β .* Rᵈ
-      BLAS.axpby!(α, Tᵈ, β, Rᵈ)
+      LinearAlgebra.axpby!(α, Tᵈ, β, Rᵈ)
     end
   end
   return R
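
The `BLAS.*` → `LinearAlgebra.*` swap keeps the same in-place semantics while dispatching generically, which is what lets `CuArray` data flow through these branches. A quick CPU illustration of the `axpby!` semantics:

```julia
using LinearAlgebra

R = ones(4)
T = fill(2.0, 4)
axpby!(3.0, T, 0.5, R)  # R .= 3.0 .* T .+ 0.5 .* R
@assert R == fill(6.5, 4)
```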
@@ -187,7 +187,7 @@ function _contract_scalar_perm!(
   else
     if iszero(α)
       # Rᵃ .= β .* Rᵃ
-      BLAS.scal!(length(Rᵃ), β, Rᵃ, 1)
+      LinearAlgebra.scal!(length(Rᵃ), β, Rᵃ, 1)
     else
       Rᵃ .= α .* permutedims(Tᵃ, perm) .+ β .* Rᵃ
     end
24 changes: 17 additions & 7 deletions NDTensors/src/linearalgebra/linearalgebra.jl
@@ -81,7 +81,9 @@ function lapack_svd_error_message(alg)
     " To get an `svd` of a matrix `A`, an eigendecomposition of\n" *
     " ``A^{\\dagger} A`` is used to compute `U` and then a `qr` of\n" *
     " ``A^{\\dagger} U`` is used to compute `V`. This is performed\n" *
-    " recursively to compute small singular values.\n\n" *
+    " recursively to compute small singular values.\n" *
+    " - `\"QRAlgorithm\"` is a QR-based SVD algorithm implemented in CUDA.jl.\n" *
+    " - `\"JacobiAlgorithm\"` is a Jacobi-based SVD algorithm implemented in CUDA.jl.\n\n" *
     "Returning `nothing`. For an output `F = svd(A, ...)` you can check if\n" *
     "`isnothing(F)` in your code and try a different algorithm.\n\n" *
     "To suppress this message in the future, you can wrap the `svd` call in the\n" *
@@ -146,6 +148,8 @@ function LinearAlgebra.svd(T::DenseTensor{ElT,2,IndsT}; kwargs...) where {ElT,IndsT}
     end
   elseif alg == "recursive"
     MUSV = svd_recursive(matrix(T))
+  elseif alg == "QRAlgorithm" || alg == "JacobiAlgorithm"
+    MUSV = svd_catch_error(matrix(T); alg=alg)
   else
     error(
       "svd algorithm $alg is not currently supported. Please see the documentation for currently supported algorithms.",
@@ -270,12 +274,13 @@ function random_unitary(::Type{ElT}, n::Int, m::Int) where {ElT<:Number}
   return random_unitary(Random.default_rng(), ElT, n, m)
 end
 
-function random_unitary(rng::AbstractRNG, ::Type{ElT}, n::Int, m::Int) where {ElT<:Number}
+function random_unitary(rng::AbstractRNG, DataT::Type{<:AbstractArray}, n::Int, m::Int)
+  ElT = eltype(DataT)
   if n < m
-    return Matrix(random_unitary(rng, ElT, m, n)')
+    return DataT(random_unitary(rng, ElT, m, n)')
   end
   F = qr(randn(rng, ElT, n, m))
-  Q = Matrix(F.Q)
+  Q = DataT(F.Q)
   # The upper triangle of F.factors
   # are the elements of R.
   # Multiply cols of Q by the signs
@@ -287,6 +292,10 @@ function random_unitary(rng::AbstractRNG, ::Type{ElT}, n::Int, m::Int) where {ElT<:Number}
   return Q
 end
 
+function random_unitary(rng::AbstractRNG, ::Type{ElT}, n::Int, m::Int) where {ElT<:Number}
+  return random_unitary(rng, set_ndims(default_datatype(ElT), 2), n, m)
+end
+
 random_unitary(n::Int, m::Int) = random_unitary(ComplexF64, n, m)
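
A minimal sketch of the new datatype-based `random_unitary` method (`CuMatrix{Float32}` as the `DataT` is illustrative); the scalar-eltype method keeps its old CPU behavior by routing through `default_datatype`:

```julia
using CUDA, LinearAlgebra, Random
using NDTensors: random_unitary

Q = random_unitary(Random.default_rng(), CuMatrix{Float32}, 6, 4)
@assert Q isa CuMatrix{Float32}
Qc = Array(Q)            # move to host to check orthonormality
@assert Qc' * Qc ≈ I
```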

"""
@@ -390,7 +399,8 @@ function qx(qx::Function, T::DenseTensor{<:Any,2}; kwargs...)
   IndsT = indstype(T) #get the index type
   Qinds = IndsT((ind(T, 1), q))
   Xinds = IndsT((q, ind(T, 2)))
-  Q = tensor(Dense(vec(Matrix(QM))), Qinds) #Q was strided
+  QM = convert(typeof(XM), QM)
+  Q = tensor(Dense(vec(QM)), Qinds) #Q was strided
   X = tensor(Dense(vec(XM)), Xinds)
   return Q, X
 end
@@ -409,7 +419,7 @@ matrix is unique. Returns a tuple (Q,R).
 """
 function qr_positive(M::AbstractMatrix)
   sparseQ, R = qr(M)
-  Q = convert(Matrix, sparseQ)
+  Q = convert(typeof(R), sparseQ)
   nc = size(Q, 2)
   for c in 1:nc
     if R[c, c] != 0.0 #sign(0.0)==0.0 so we don't want to zero out a column of Q.
@@ -433,7 +443,7 @@ matrix is unique. Returns a tuple (Q,L).
 """
 function ql_positive(M::AbstractMatrix)
   sparseQ, L = ql(M)
-  Q = convert(Matrix, sparseQ)
+  Q = convert(typeof(L), sparseQ)
   nr, nc = size(L)
   dc = nc > nr ? nc - nr : 0 #diag is shifted over by dc if nc>nr
   for c in 1:(nc - dc)
4 changes: 2 additions & 2 deletions NDTensors/src/linearalgebra/svd.jl
@@ -1,5 +1,5 @@
 
-function checkSVDDone(S::Vector, thresh::Float64)
+function checkSVDDone(S::AbstractArray, thresh::Float64)
   N = length(S)
   (N <= 1 || thresh < 0.0) && return (true, 1)
   S1t = S[1] * thresh
@@ -25,7 +25,7 @@ function svd_recursive(M::AbstractMatrix; thresh::Float64=1E-3, north_pass::Int=
 
   #rho = BLAS.gemm('N','T',-1.0,M,M) #negative to sort eigenvalues greatest to smallest
   rho = -M * M' #negative to sort eigenvalues in decreasing order
-  D, U = eigen(Hermitian(rho), 1:size(rho, 1))
+  D, U = eigen(Hermitian(rho))
 
   Nd = length(D)
 
2 changes: 1 addition & 1 deletion NDTensors/src/truncate.jl
@@ -1,6 +1,6 @@
 export truncate!
 
-function truncate!(P::Vector{ElT}; kwargs...)::Tuple{ElT,ElT} where {ElT}
+function truncate!(P::AbstractVector{ElT}; kwargs...)::Tuple{ElT,ElT} where {ElT}
   cutoff::Union{Nothing,ElT} = get(kwargs, :cutoff, zero(ElT))
   if isnothing(cutoff)
     cutoff = typemin(ElT)
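
A hedged sketch of what the widened `truncate!` signature admits; reading the returned tuple as (truncation error, cutoff used) is an assumption based on the `Tuple{ElT,ElT}` annotation above, not something this diff confirms:

```julia
using NDTensors: truncate!

# Truncate a sorted spectrum in place, dropping weights below the cutoff.
P = [0.5, 0.3, 0.15, 0.04, 0.009, 0.001]
err, cut = truncate!(P; cutoff=1e-2)
@show length(P) err cut
```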
6 changes: 3 additions & 3 deletions NDTensors/test/dense.jl
@@ -58,7 +58,7 @@ end
   @test eltype(Asim) == elt
   @test length(Asim) == 10
 
-  B = Tensor(undef, (3, 4))
+  B = dev(Tensor(undef, (3, 4)))
   randn!(B)
 
   C = A + B
@@ -112,8 +112,8 @@ end
   @test dim(I) == 1000
   @test Array(I) == I_arr
 
-  J = Tensor((2, 2))
-  K = Tensor((2, 2))
+  J = dev(Tensor((2, 2)))
+  K = dev(Tensor((2, 2)))
   @test Array(J * K) ≈ Array(J) * Array(K)
 end

7 changes: 7 additions & 0 deletions NDTensors/test/device_list.jl
@@ -1,3 +1,10 @@
+if "cuda" in ARGS || "all" in ARGS
+  using CUDA
+end
+if "metal" in ARGS || "all" in ARGS
+  using Metal
+end
+
 function devices_list(test_args)
   devs = Vector{Function}(undef, 0)
   if isempty(test_args) || "base" in test_args
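
For context, a sketch of how these device flags might be passed in; the exact plumbing of `test_args` into `ARGS` is assumed from the checks above:

```julia
using Pkg

# Run the NDTensors test suite with the CUDA device list enabled.
Pkg.test("NDTensors"; test_args=["cuda"])
```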
7 changes: 1 addition & 6 deletions NDTensors/test/diag.jl
@@ -1,11 +1,6 @@
 using NDTensors
 using Test
-if "cuda" in ARGS || "all" in ARGS
-  using CUDA
-end
-if "metal" in ARGS || "all" in ARGS
-  using Metal
-end
 
 @testset "DiagTensor basic functionality" begin
+  include("device_list.jl")
   devs = devices_list(copy(ARGS))
6 changes: 0 additions & 6 deletions NDTensors/test/emptystorage.jl
@@ -1,11 +1,5 @@
 using NDTensors
 using Test
-if "cuda" in ARGS || "all" in ARGS
-  using CUDA
-end
-if "metal" in ARGS || "all" in ARGS
-  using Metal
-end
 
 @testset "EmptyStorage test" begin
   include("device_list.jl")
105 changes: 57 additions & 48 deletions NDTensors/test/linearalgebra.jl
@@ -26,57 +26,66 @@ end
 @test norm(U2 * U2' - Diagonal(fill(1.0, m))) < 1E-14
 end
 
-@testset "Dense $qx decomposition, elt=$elt, positive=$positive, singular=$singular" for qx in
-  [
-    qr, ql
-  ],
-  elt in [Float64, ComplexF64, Float32, ComplexF32],
-  positive in [false, true],
-  singular in [false, true]
+include("device_list.jl")
+devs = devices_list(copy(ARGS))
+@testset "QX testing" begin
+  @testset "Dense $qx decomposition, elt=$elt, positive=$positive, singular=$singular, device=$dev" for qx in
+    [
+      qr, ql
+    ],
+    elt in [Float64, ComplexF64, Float32, ComplexF32],
+    positive in [false, true],
+    singular in [false, true],
+    dev in devs
 
-  eps = Base.eps(real(elt)) * 100 #this is set rather tight, so if you increase/change m,n you may have to open up the tolerance on eps.
-  n, m = 4, 8
-  Id = Diagonal(fill(1.0, min(n, m)))
-  #
-  # Wide matrix (more columns than rows)
-  #
-  A = randomTensor(elt, (n, m))
-  # We want to test 0.0 on the diagonal. We need to make all rows equal to guarantee this with numerical roundoff.
-  if singular
-    for i in 2:n
-      A[i, :] = A[1, :]
-    end
-  end
-  Q, X = qx(A; positive=positive) #X is R or L.
-  @test A ≈ Q * X atol = eps
-  @test array(Q)' * array(Q) ≈ Id atol = eps
-  @test array(Q) * array(Q)' ≈ Id atol = eps
-  if positive
-    nr, nc = size(X)
-    dr = qx == ql ? Base.max(0, nc - nr) : 0
-    diagX = diag(X[:, (1 + dr):end]) #location of diag(L) is shifted dr columns over to the right.
-    @test all(real(diagX) .>= 0.0)
-    @test all(imag(diagX) .== 0.0)
-  end
-  #
-  # Tall matrix (more rows than cols)
-  #
-  A = randomTensor(elt, (m, n)) #Tall array
-  # We want to test 0.0 on the diagonal. We need to make all rows equal to guarantee this with numerical roundoff.
-  if singular
-    for i in 2:m
-      A[i, :] = A[1, :]
-    end
-  end
-  Q, X = qx(A; positive=positive)
-  @test A ≈ Q * X atol = eps
-  @test array(Q)' * array(Q) ≈ Id atol = eps
-  if positive
-    nr, nc = size(X)
-    dr = qx == ql ? Base.max(0, nc - nr) : 0
-    diagX = diag(X[:, (1 + dr):end]) #location of diag(L) is shifted dr columns over to the right.
-    @test all(real(diagX) .>= 0.0)
-    @test all(imag(diagX) .== 0.0)
-  end
-end
+    eps = Base.eps(real(elt)) * 100 #this is set rather tight, so if you increase/change m,n you may have to open up the tolerance on eps.
+    n, m = 4, 8
+    Id = Diagonal(fill(1.0, min(n, m)))
+    #
+    # Wide matrix (more columns than rows)
+    #
+    A = dev(randomTensor(elt, (n, m)))
+    # We want to test 0.0 on the diagonal. We need to make all rows equal to guarantee this with numerical roundoff.
+    if singular
+      for i in 2:n
+        A[i, :] = A[1, :]
+      end
+    end
+    if qx == ql && dev != NDTensors.cpu
+      @test_broken qx(A; positive=positive)
+      continue
+    end
+    Q, X = qx(A; positive=positive) #X is R or L.
+    @test A ≈ Q * X atol = eps
+    @test array(Q)' * array(Q) ≈ Id atol = eps
+    @test array(Q) * array(Q)' ≈ Id atol = eps
+    if positive
+      nr, nc = size(X)
+      dr = qx == ql ? Base.max(0, nc - nr) : 0
+      diagX = diag(X[:, (1 + dr):end]) #location of diag(L) is shifted dr columns over to the right.
+      @test all(real(diagX) .>= 0.0)
+      @test all(imag(diagX) .== 0.0)
+    end
+    #
+    # Tall matrix (more rows than cols)
+    #
+    A = dev(randomTensor(elt, (m, n))) #Tall array
+    # We want to test 0.0 on the diagonal. We need to make all rows equal to guarantee this with numerical roundoff.
+    if singular
+      for i in 2:m
+        A[i, :] = A[1, :]
+      end
+    end
+    Q, X = qx(A; positive=positive)
+    @test A ≈ Q * X atol = eps
+    @test array(Q)' * array(Q) ≈ Id atol = eps
+    if positive
+      nr, nc = size(X)
+      dr = qx == ql ? Base.max(0, nc - nr) : 0
+      diagX = diag(X[:, (1 + dr):end]) #location of diag(L) is shifted dr columns over to the right.
+      @test all(real(diagX) .>= 0.0)
+      @test all(imag(diagX) .== 0.0)
+    end
+  end
+end

22 changes: 16 additions & 6 deletions src/itensor.jl
@@ -564,15 +564,25 @@ B = onehot(i=>1,j=>3)
 # B[i=>1,j=>3] == 1, all other element zero
 ```
 """
+function onehot(datatype::Type{<:AbstractArray}, ivs::Pair{<:Index}...)
+  A = ITensor(eltype(datatype), ind.(ivs)...)
+  A[val.(ivs)...] = one(eltype(datatype))
+  return Adapt.adapt(datatype, A)
+end
+
 function onehot(eltype::Type{<:Number}, ivs::Pair{<:Index}...)
-  A = ITensor(eltype, ind.(ivs)...)
-  A[val.(ivs)...] = one(eltype)
-  return A
+  return onehot(NDTensors.default_datatype(eltype), ivs...)
+end
+function onehot(eltype::Type{<:Number}, ivs::Vector{<:Pair{<:Index}})
+  return onehot(NDTensors.default_datatype(eltype), ivs...)
+end
+function setelt(eltype::Type{<:Number}, ivs::Pair{<:Index}...)
+  return onehot(NDTensors.default_datatype(eltype), ivs...)
 end
-onehot(eltype::Type{<:Number}, ivs::Vector{<:Pair{<:Index}}) = onehot(eltype, ivs...)
-setelt(eltype::Type{<:Number}, ivs::Pair{<:Index}...) = onehot(eltype, ivs...)
 
-onehot(ivs::Pair{<:Index}...) = onehot(Float64, ivs...)
+function onehot(ivs::Pair{<:Index}...)
+  return onehot(NDTensors.default_datatype(NDTensors.default_eltype()), ivs...)
+end
 onehot(ivs::Vector{<:Pair{<:Index}}) = onehot(ivs...)
 setelt(ivs::Pair{<:Index}...) = onehot(ivs...)

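
A minimal sketch of the datatype-aware `onehot`; `CuVector{Float64}` is illustrative, and the device adaptation relies on the `Adapt.adapt` call shown above:

```julia
using CUDA, ITensors

i, j = Index(2, "i"), Index(3, "j")
A = onehot(CuVector{Float64}, i => 1, j => 3)  # storage adapted to the GPU
B = onehot(i => 1, j => 3)                     # default: Float64 CPU storage
```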
6 changes: 3 additions & 3 deletions src/tensor_operations/matrix_decomposition.jl
@@ -122,13 +122,13 @@ function svd(A::ITensor, Linds...; kwargs...)
   Ris_original = Ris
   if isempty(Lis_original)
     α = trivial_index(Ris)
-    vLα = onehot(eltype(A), α => 1)
+    vLα = onehot(datatype(A), α => 1)
     A *= vLα
     Lis = [α]
   end
   if isempty(Ris_original)
     α = trivial_index(Lis)
-    vRα = onehot(eltype(A), α => 1)
+    vRα = onehot(datatype(A), α => 1)
     A *= vRα
     Ris = [α]
   end
@@ -370,7 +370,7 @@ end
 #
 function add_trivial_index(A::ITensor, Ainds)
   α = trivial_index(Ainds) #If Ainds[1] has no QNs makes Index(1), otherwise Index(QN()=>1)
-  vα = onehot(eltype(A), α => 1)
+  vα = onehot(datatype(A), α => 1)
   A *= vα
   return A, vα, [α]
 end
