Skip to content

Commit

Permalink
[NDTensors] Some fixes for element type promotion and conversion (#1244)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Nov 10, 2023
1 parent cfc9dc9 commit 7eb2e30
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 26 deletions.
8 changes: 5 additions & 3 deletions NDTensors/ext/NDTensorsMetalExt/linearalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,16 @@ function NDTensors.Unwrap.ql_positive(A::Exposed{<:MtlMatrix})
end

function LinearAlgebra.eigen(A::Exposed{<:MtlMatrix})
D, U = eigen(expose(NDTensors.cpu(A)))
return adapt(set_ndims(unwrap_type(A), ndims(D)), D), adapt(unwrap_type(A), U)
Dcpu, Ucpu = eigen(expose(NDTensors.cpu(A)))
D = adapt(set_ndims(set_eltype(unwrap_type(A), eltype(Dcpu)), ndims(Dcpu)), Dcpu)
U = adapt(unwrap_type(A), Ucpu)
return D, U
end

function LinearAlgebra.svd(A::Exposed{<:MtlMatrix}; kwargs...)
Ucpu, Scpu, Vcpu = svd(expose(NDTensors.cpu(A)); kwargs...)
U = adapt(unwrap_type(A), Ucpu)
S = adapt(set_ndims(unwrap_type(A), ndims(Scpu)), Scpu)
S = adapt(set_ndims(set_eltype(unwrap_type(A), eltype(Scpu)), ndims(Scpu)), Scpu)
V = adapt(unwrap_type(A), Vcpu)
return U, S, V
end
27 changes: 21 additions & 6 deletions NDTensors/src/Unwrap/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ using NDTensors
using LinearAlgebra

include("../../../test/device_list.jl")
@testset "Testing Unwrap" for dev in devices_list(ARGS)
elt = Float32
@testset "Testing Unwrap $dev, $elt" for dev in devices_list(ARGS),
elt in (Float32, ComplexF32)

v = dev(Vector{elt}(undef, 10))
v = dev(randn(elt, 10))
vt = transpose(v)
va = v'

Expand Down Expand Up @@ -39,7 +39,7 @@ include("../../../test/device_list.jl")
@test typeof(Et) == Exposed{m_type,LinearAlgebra.Transpose{e_type,m_type}}
@test typeof(Ea) == Exposed{m_type,LinearAlgebra.Adjoint{e_type,m_type}}

o = dev(Vector{elt})(undef, 1)
o = dev(randn(elt, 1))
expose(o)[] = 2
@test expose(o)[] == 2

Expand All @@ -58,17 +58,32 @@ include("../../../test/device_list.jl")
q, r = Unwrap.qr_positive(expose(mp))
@test q * r ≈ mp

square = dev(rand(elt, (10, 10)))
square = dev(rand(real(elt), (10, 10)))
square = (square + transpose(square)) / 2
## CUDA only supports Hermitian or Symmetric eigen decompositions
## So I symmetrize square and call symmetric here
l, U = eigen(expose(Symmetric(square)))
@test eltype(l) == real(elt)
@test eltype(U) == real(elt)
@test square * U ≈ U * Diagonal(l)

square = dev(rand(elt, (10, 10)))
# Can use `hermitianpart` in Julia 1.10
square = (square + square') / 2
## CUDA only supports Hermitian or Symmetric eigen decompositions
## So I symmetrize square and call symmetric here
l, U = eigen(expose(Hermitian(square)))
@test eltype(l) == real(elt)
@test eltype(U) == elt
@test square * U ≈ U * Diagonal(l)

U, S, V, = svd(expose(mp))
@test eltype(U) == elt
@test eltype(S) == real(elt)
@test eltype(V) == elt
@test U * Diagonal(S) * V' ≈ mp

cm = dev(fill!(Matrix{elt}(undef, (2, 2)), 0.0))
cm = dev(randn(elt, 2, 2))
mul!(expose(cm), expose(mp), expose(mp'), 1.0, 0.0)
@test cm ≈ mp * mp'

Expand Down
16 changes: 11 additions & 5 deletions NDTensors/src/blocksparse/blocksparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,20 +128,26 @@ can_contract(T1::Type{<:BlockSparse}, T2::Type{<:Dense}) = can_contract(T2, T1)
function promote_rule(
::Type{<:BlockSparse{ElT1,VecT1,N}}, ::Type{<:BlockSparse{ElT2,VecT2,N}}
) where {ElT1,ElT2,VecT1,VecT2,N}
return BlockSparse{promote_type(ElT1, ElT2),promote_type(VecT1, VecT2),N}
# Promote the element types properly.
ElT = promote_type(ElT1, ElT2)
VecT = promote_type(set_eltype(VecT1, ElT), set_eltype(VecT2, ElT))
return BlockSparse{ElT,VecT,N}
end

function promote_rule(
::Type{<:BlockSparse{ElT1,VecT1,N1}}, ::Type{<:BlockSparse{ElT2,VecT2,N2}}
) where {ElT1,ElT2,VecT1,VecT2,N1,N2}
return BlockSparse{promote_type(ElT1, ElT2),promote_type(VecT1, VecT2),NR} where {NR}
# Promote the element types properly.
ElT = promote_type(ElT1, ElT2)
VecT = promote_type(set_eltype(VecT1, ElT), set_eltype(VecT2, ElT))
return BlockSparse{ElT,VecT,NR} where {NR}
end

function promote_rule(
::Type{<:BlockSparse{ElT1,Vector{ElT1},N1}}, ::Type{ElT2}
) where {ElT1,ElT2<:Number,N1}
::Type{<:BlockSparse{ElT1,VecT1,N1}}, ::Type{ElT2}
) where {ElT1,VecT1<:AbstractVector{ElT1},ElT2<:Number,N1}
ElR = promote_type(ElT1, ElT2)
VecR = Vector{ElR}
VecR = set_eltype(VecT1, ElR)
return BlockSparse{ElR,VecR,N1}
end

Expand Down
11 changes: 11 additions & 0 deletions NDTensors/test/ITensors/TestITensorDMRG/TestITensorDMRG.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,16 @@ reference_energies = Dict([
])

default_rtol(elt::Type) = 10^(0.75 * log10(eps(real(elt))))

is_supported_eltype(dev, elt::Type) = true
is_supported_eltype(dev::typeof(NDTensors.mtl), elt::Type{Float64}) = false
function is_supported_eltype(dev::typeof(NDTensors.mtl), elt::Type{<:Complex})
return is_supported_eltype(dev, real(elt))
end

is_broken(dev, elt::Type, conserve_qns::Val) = false
is_broken(dev::typeof(NDTensors.cu), elt::Type, conserve_qns::Val{true}) = true

include("dmrg.jl")

end
9 changes: 5 additions & 4 deletions NDTensors/test/ITensors/TestITensorDMRG/dmrg.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
function test_dmrg(elt, N::Integer, dev::Function)
sites = siteinds("S=1/2", N)
function test_dmrg(elt, N::Integer; dev::Function, conserve_qns)
sites = siteinds("S=1/2", N; conserve_qns)

os = OpSum()
for j in 1:(N - 1)
Expand All @@ -9,12 +9,13 @@ function test_dmrg(elt, N::Integer, dev::Function)
end

Random.seed!(1234)
psi0 = dev(randomMPS(Float64, sites; linkdims=4))
init = j -> isodd(j) ? "↑" : "↓"
psi0 = dev(randomMPS(elt, sites, init; linkdims=4))
H = dev(MPO(elt, os, sites))

nsweeps = 3
cutoff = [1e-3, 1e-13]
noise = [1e-12, 0]
noise = [1e-6, 0]
## running these with nsweeps = 100 and no maxdim
## all problems do not have a maxlinkdim > 32
maxdim = 32
Expand Down
18 changes: 13 additions & 5 deletions NDTensors/test/ITensors/TestITensorDMRG/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,19 @@ include("TestITensorDMRG.jl")

include("../../device_list.jl")

@testset "Testing DMRG different backends" begin
for dev in devices_list(ARGS),
N in [4, 10],
elt in (Float32, ComplexF32, Float64, ComplexF64)
@testset "Test DMRG $dev, $conserve_qns, $elt, $N" for dev in devices_list(ARGS),
conserve_qns in [false, true],
elt in (Float32, ComplexF32, Float64, ComplexF64),
N in [4, 10]

TestITensorDMRG.test_dmrg(elt, N, dev)
if !TestITensorDMRG.is_supported_eltype(dev, elt)
continue
end
if TestITensorDMRG.is_broken(dev, elt, Val(conserve_qns))
# TODO: Switch to `@test ... broken=true`, introduced
# in Julia 1.7.
@test_broken TestITensorDMRG.test_dmrg(elt, N; dev, conserve_qns)
else
TestITensorDMRG.test_dmrg(elt, N; dev, conserve_qns)
end
end
8 changes: 5 additions & 3 deletions src/mps/dmrg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -260,10 +260,12 @@ function dmrg(
ortho = ha == 1 ? "left" : "right"

drho = nothing
if noise(sweeps, sw) > 0.0
if noise(sweeps, sw) > 0
@timeit_debug timer "dmrg: noiseterm" begin
# Use noise term when determining new MPS basis
drho = noise(sweeps, sw) * noiseterm(PH, phi, ortho)
# Use noise term when determining new MPS basis.
# This is used to preserve the element type of the MPS.
elt = real(scalartype(psi))
drho = elt(noise(sweeps, sw)) * noiseterm(PH, phi, ortho)
end
end

Expand Down

0 comments on commit 7eb2e30

Please sign in to comment.