Skip to content

Commit

Permalink
[NDTensors] [Enchancements] Reorganize NDTensors (#1061)
Browse files Browse the repository at this point in the history
  • Loading branch information
kmp5VT committed Mar 13, 2023
1 parent 15edb84 commit 4606546
Show file tree
Hide file tree
Showing 36 changed files with 1,951 additions and 1,866 deletions.
4 changes: 2 additions & 2 deletions ITensorGPU/src/tensor/cudense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ function Base.permute!(B::CuDenseTensor, A::CuDenseTensor)
@assert isperm(perm)
permutedims!(reshapeBdata, reshapeAdata, invperm(perm))
end
return Tensor(inds(B), Dense(vec(reshapeBdata)))
return Tensor(Dense(vec(reshapeBdata)), inds(B))
end

function Base.permute!(B::CuDense, Bis::IndexSet, A::CuDense, Ais::IndexSet)
Expand Down Expand Up @@ -567,7 +567,7 @@ function Base.permute!(B::CuDense, Bis::IndexSet, A::CuDense, Ais::IndexSet)
reshapeBdata,
Vector{Char}(ctbinds),
)
return Tensor(Bis, Dense(vec(reshapeBdata)))
return Tensor(Dense(reshapeBdata), Tuple(Bis))
end

Base.:/(A::CuDenseTensor, x::Number) = A * inv(x)
38 changes: 25 additions & 13 deletions NDTensors/src/NDTensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,32 +42,42 @@ include("abstractarray/set_types.jl")
include("abstractarray/to_shape.jl")
include("abstractarray/similar.jl")
include("abstractarray/ndims.jl")
include("abstractarray/fill.jl")
include("array/set_types.jl")
include("tupletools.jl")
include("dims.jl")
include("tensorstorage.jl")
include("default_storage.jl")
include("tensorstorage/tensorstorage.jl")
include("tensorstorage/default_storage.jl")
include("tensorstorage/similar.jl")
include("tensor.jl")
include("tensor/tensor.jl")
include("dims.jl")
include("tensor/set_types.jl")
include("tensor/similar.jl")
include("adapt.jl")
include("generic_tensor_operations.jl")
include("contraction_logic.jl")
include("tensoralgebra/generic_tensor_operations.jl")
include("tensoralgebra/contraction_logic.jl")

#####################################
# DenseTensor and DiagTensor
#
include("dense/dense.jl")
#include("dense/adapt.jl")
include("fill.jl")
include("symmetric.jl")
include("linearalgebra.jl")
include("dense/densetensor.jl")
include("dense/tensoralgebra/contract.jl")
include("dense/linearalgebra/decompositions.jl")
include("dense/tensoralgebra/outer.jl")
include("dense/set_types.jl")
include("dense/fill.jl")
include("linearalgebra/symmetric.jl")
include("linearalgebra/linearalgebra.jl")
include("diag/diag.jl")
include("diag/set_types.jl")
include("diag/diagtensor.jl")
include("diag/similar.jl")
include("diag/tensoralgebra/contract.jl")
include("diag/tensoralgebra/outer.jl")
include("combiner/combiner.jl")
include("combiner/contract.jl")
include("truncate.jl")
include("svd.jl")
include("linearalgebra/svd.jl")

#####################################
# BlockSparseTensor
Expand All @@ -92,6 +102,8 @@ include("blocksparse/linearalgebra.jl")
# Empty
#
include("empty/empty.jl")
include("empty/EmptyTensor.jl")
include("empty/tensoralgebra/contract.jl")
include("empty/adapt.jl")

#####################################
Expand Down Expand Up @@ -211,10 +223,10 @@ end
function __init__()
@require TBLIS = "48530278-0828-4a49-9772-0f3830dfa1e9" begin
enable_tblis()
include("tblis.jl")
include("tensoralgebra/tblis.jl")
end
@require Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4" begin
include("octavian.jl")
include("linearalgebra/octavian.jl")
end
end

Expand Down
15 changes: 15 additions & 0 deletions NDTensors/src/abstractarray/fill.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
function generic_randn(DataT::Type{<:AbstractArray}, dim::Integer=0)
DataT = set_eltype_if_unspecified(DataT)
data = similar(DataT, dim)
ElT = eltype(DataT)
for i in 1:length(data)
data[i] = randn(ElT)
end
return data
end

function generic_zeros(DataT::Type{<:AbstractArray}, dim::Integer=0)
DataT = set_eltype_if_unspecified(DataT)
ElT = eltype(DataT)
return fill!(similar(DataT, dim), zero(ElT))
end
29 changes: 29 additions & 0 deletions NDTensors/src/abstractarray/set_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,32 @@ like `OffsetArrays` or named indices
function set_indstype(arraytype::Type{<:AbstractArray}, dims::Tuple)
return set_ndims(arraytype, length(dims))
end

function set_eltype_if_unspecified(
arraytype::Type{<:AbstractArray{T}}, eltype::Type=default_eltype()
) where {T}
return arraytype
end

#TODO transition to set_eltype when working for wrapped types
function set_eltype_if_unspecified(
arraytype::Type{<:AbstractArray}, eltype::Type=default_eltype()
)
return similartype(arraytype, eltype)
end

function set_parameter_if_unspecified(
arraytype::Type{<:AbstractArray{ElT,N}}, eltype::Type=default_eltype(), ndims::Integer=1
) where {ElT,N}
return arraytype
end
function set_parameter_if_unspecified(
arraytype::Type{<:AbstractArray{ElT}}, eltype::Type=default_eltype(), ndims::Integer=1
) where {ElT}
return set_ndims(arraytype, ndims)
end
function set_parameter_if_unspecified(
arraytype::Type{<:AbstractArray}, eltype::Type=default_eltype(), ndims::Integer=1
)
return set_eltype(set_ndims(arraytype, ndims), eltype)
end
13 changes: 0 additions & 13 deletions NDTensors/src/adapt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,3 @@ to_vector_type(arraytype::Type{<:AbstractVector}) = arraytype

to_vector_type(arraytype::Type{Array}) = Vector
to_vector_type(arraytype::Type{Array{T}}) where {T} = Vector{T}

function set_eltype_if_unspecified(
arraytype::Type{<:AbstractArray{T}}, eltype::Type=default_eltype()
) where {T}
return arraytype
end

#TODO transition to set_eltype when working for wrapped types
function set_eltype_if_unspecified(
arraytype::Type{<:AbstractArray}, eltype::Type=default_eltype()
)
return similartype(arraytype, eltype)
end
104 changes: 0 additions & 104 deletions NDTensors/src/combiner/combiner.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,110 +70,6 @@ end
blockperm(C::CombinerTensor) = blockperm(storage(C))
blockcomb(C::CombinerTensor) = blockcomb(storage(C))

function contraction_output(
::TensorT1, ::TensorT2, indsR::Tuple
) where {TensorT1<:CombinerTensor,TensorT2<:DenseTensor}
TensorR = contraction_output_type(TensorT1, TensorT2, indsR)
return similar(TensorR, indsR)
end

function contraction_output(
T1::TensorT1, T2::TensorT2, indsR
) where {TensorT1<:DenseTensor,TensorT2<:CombinerTensor}
return contraction_output(T2, T1, indsR)
end

function contract!!(
output_tensor::Tensor,
output_tensor_labels,
combiner_tensor::CombinerTensor,
combiner_tensor_labels,
tensor::Tensor,
tensor_labels,
)
if ndims(combiner_tensor) 1
# Empty combiner, acts as multiplying by 1
output_tensor = permutedims!!(
output_tensor, tensor, getperm(output_tensor_labels, tensor_labels)
)
return output_tensor
end
if is_index_replacement(tensor, tensor_labels, combiner_tensor, combiner_tensor_labels)
ui = setdiff(combiner_tensor_labels, tensor_labels)[]
newind = inds(combiner_tensor)[findfirst(==(ui), combiner_tensor_labels)]
cpos1, cpos2 = intersect_positions(combiner_tensor_labels, tensor_labels)
output_tensor_storage = copy(storage(tensor))
output_tensor_inds = setindex(inds(tensor), newind, cpos2)
return NDTensors.tensor(output_tensor_storage, output_tensor_inds)
end
is_combining_contraction = is_combining(
tensor, tensor_labels, combiner_tensor, combiner_tensor_labels
)
if is_combining_contraction
Alabels, Blabels = tensor_labels, combiner_tensor_labels
final_labels = contract_labels(Blabels, Alabels)
final_labels_n = contract_labels(combiner_tensor_labels, tensor_labels)
output_tensor_inds = inds(output_tensor)
if final_labels != final_labels_n
perm = getperm(final_labels_n, final_labels)
output_tensor_inds = permute(inds(output_tensor), perm)
output_tensor_labels = permute(output_tensor_labels, perm)
end
cpos1, output_tensor_cpos = intersect_positions(
combiner_tensor_labels, output_tensor_labels
)
labels_comb = deleteat(combiner_tensor_labels, cpos1)
output_tensor_vl = [output_tensor_labels...]
for (ii, li) in enumerate(labels_comb)
insert!(output_tensor_vl, output_tensor_cpos + ii, li)
end
deleteat!(output_tensor_vl, output_tensor_cpos)
labels_perm = tuple(output_tensor_vl...)
perm = getperm(labels_perm, tensor_labels)
tensorp = reshape(output_tensor, permute(inds(tensor), perm))
permutedims!(tensorp, tensor, perm)
return reshape(tensorp, output_tensor_inds)
else # Uncombining
cpos1, cpos2 = intersect_positions(combiner_tensor_labels, tensor_labels)
output_tensor_storage = copy(storage(tensor))
indsC = deleteat(inds(combiner_tensor), cpos1)
output_tensor_inds = insertat(inds(tensor), indsC, cpos2)
return NDTensors.tensor(output_tensor_storage, output_tensor_inds)
end
return invalid_combiner_contraction_error(
tensor, tensor_labels, combiner_tensor, combiner_tensor_labels
)
end

function contract!!(
output_tensor::Tensor,
output_tensor_labels,
tensor::Tensor,
tensor_labels,
combiner_tensor::CombinerTensor,
combiner_tensor_labels,
)
return contract!!(
output_tensor,
output_tensor_labels,
combiner_tensor,
combiner_tensor_labels,
tensor,
tensor_labels,
)
end

function contract(
diag_tensor::DiagTensor,
diag_tensor_labels,
combiner_tensor::CombinerTensor,
combiner_tensor_labels,
)
return contract(
dense(diag_tensor), diag_tensor_labels, combiner_tensor, combiner_tensor_labels
)
end

function is_index_replacement(
tensor::Tensor, tensor_labels, combiner_tensor::CombinerTensor, combiner_tensor_labels
)
Expand Down
103 changes: 103 additions & 0 deletions NDTensors/src/combiner/contract.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
function contraction_output(
::TensorT1, ::TensorT2, indsR::Tuple
) where {TensorT1<:CombinerTensor,TensorT2<:DenseTensor}
TensorR = contraction_output_type(TensorT1, TensorT2, indsR)
return similar(TensorR, indsR)
end

function contraction_output(
T1::TensorT1, T2::TensorT2, indsR
) where {TensorT1<:DenseTensor,TensorT2<:CombinerTensor}
return contraction_output(T2, T1, indsR)
end

function contract!!(
output_tensor::Tensor,
output_tensor_labels,
combiner_tensor::CombinerTensor,
combiner_tensor_labels,
tensor::Tensor,
tensor_labels,
)
if ndims(combiner_tensor) 1
# Empty combiner, acts as multiplying by 1
output_tensor = permutedims!!(
output_tensor, tensor, getperm(output_tensor_labels, tensor_labels)
)
return output_tensor
end
if is_index_replacement(tensor, tensor_labels, combiner_tensor, combiner_tensor_labels)
ui = setdiff(combiner_tensor_labels, tensor_labels)[]
newind = inds(combiner_tensor)[findfirst(==(ui), combiner_tensor_labels)]
cpos1, cpos2 = intersect_positions(combiner_tensor_labels, tensor_labels)
output_tensor_storage = copy(storage(tensor))
output_tensor_inds = setindex(inds(tensor), newind, cpos2)
return NDTensors.tensor(output_tensor_storage, output_tensor_inds)
end
is_combining_contraction = is_combining(
tensor, tensor_labels, combiner_tensor, combiner_tensor_labels
)
if is_combining_contraction
Alabels, Blabels = tensor_labels, combiner_tensor_labels
final_labels = contract_labels(Blabels, Alabels)
final_labels_n = contract_labels(combiner_tensor_labels, tensor_labels)
output_tensor_inds = inds(output_tensor)
if final_labels != final_labels_n
perm = getperm(final_labels_n, final_labels)
output_tensor_inds = permute(inds(output_tensor), perm)
output_tensor_labels = permute(output_tensor_labels, perm)
end
cpos1, output_tensor_cpos = intersect_positions(
combiner_tensor_labels, output_tensor_labels
)
labels_comb = deleteat(combiner_tensor_labels, cpos1)
output_tensor_vl = [output_tensor_labels...]
for (ii, li) in enumerate(labels_comb)
insert!(output_tensor_vl, output_tensor_cpos + ii, li)
end
deleteat!(output_tensor_vl, output_tensor_cpos)
labels_perm = tuple(output_tensor_vl...)
perm = getperm(labels_perm, tensor_labels)
tensorp = reshape(output_tensor, permute(inds(tensor), perm))
permutedims!(tensorp, tensor, perm)
return reshape(tensorp, output_tensor_inds)
else # Uncombining
cpos1, cpos2 = intersect_positions(combiner_tensor_labels, tensor_labels)
output_tensor_storage = copy(storage(tensor))
indsC = deleteat(inds(combiner_tensor), cpos1)
output_tensor_inds = insertat(inds(tensor), indsC, cpos2)
return NDTensors.tensor(output_tensor_storage, output_tensor_inds)
end
return invalid_combiner_contraction_error(
tensor, tensor_labels, combiner_tensor, combiner_tensor_labels
)
end

function contract!!(
output_tensor::Tensor,
output_tensor_labels,
tensor::Tensor,
tensor_labels,
combiner_tensor::CombinerTensor,
combiner_tensor_labels,
)
return contract!!(
output_tensor,
output_tensor_labels,
combiner_tensor,
combiner_tensor_labels,
tensor,
tensor_labels,
)
end

function contract(
diag_tensor::DiagTensor,
diag_tensor_labels,
combiner_tensor::CombinerTensor,
combiner_tensor_labels,
)
return contract(
dense(diag_tensor), diag_tensor_labels, combiner_tensor, combiner_tensor_labels
)
end

0 comments on commit 4606546

Please sign in to comment.