Skip to content

Commit

Permalink
Improve Space a bit (#40)
Browse files Browse the repository at this point in the history
* Add tensorsize for Space and use it intead of Dims

* Add tensororder for Space and use it intead of ndims

* Add dot and double_contraction for Space
  • Loading branch information
KeitaNakamura committed Feb 17, 2021
1 parent 0bbefe3 commit 5afe3e5
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 17 deletions.
6 changes: 3 additions & 3 deletions src/AbstractTensor.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
abstract type AbstractTensor{S <: Tuple, T, N} <: AbstractArray{T, N} end

Base.size(::Type{TT}) where {S, TT <: AbstractTensor{S}} = Dims(Space(S))
Base.size(::Type{TT}) where {S, TT <: AbstractTensor{S}} = tensorsize(Space(S))
Base.size(x::AbstractTensor) = size(typeof(x))

# indices
Expand All @@ -23,11 +23,11 @@ end
# to SArray
@generated function convert_to_SArray(x::AbstractTensor)
S = Space(x)
NewS = Space(Dims(S)) # remove Symmetry
NewS = Space(tensorsize(S)) # remove Symmetry
exps = [getindex_expr(:x, x, i) for i in indices(NewS)]
quote
@_inline_meta
@inbounds SArray{Tuple{$(Dims(NewS)...)}}(tuple($(exps...)))
@inbounds SArray{Tuple{$(tensorsize(NewS)...)}}(tuple($(exps...)))
end
end

Expand Down
21 changes: 14 additions & 7 deletions src/Space.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,15 @@ _ncomponents(x::Symmetry) = ncomponents(x)
@pure Base.Dims(::Space{S}) where {S} = flatten_tuple(map(Dims, S))
@pure Base.Tuple(::Space{S}) where {S} = S

@pure Base.ndims(s::Space) = length(Dims(s))
@pure Base.length(s::Space) = length(Tuple(s))
Base.getindex(s::Space, i::Int) = Tuple(s)[i]

@pure tensorsize(s::Space) = Dims(s)
@pure tensororder(s::Space) = length(tensorsize(s))
# don't allow to use `size` and `ndims` because their names are confusing.
Base.size(s::Space) = throw(ArgumentError("use `tensorsize` to get size of a tensor instead of `size`"))
Base.ndims(s::Space) = throw(ArgumentError("use `tensororder` to get order of a tensor instead of `ndims`"))

function Base.show(io::IO, ::Space{S}) where {S}
print(io, "Space", S)
end
Expand All @@ -46,21 +51,23 @@ for op in (:dropfirst, :droplast)
end
end

# otimes/contraction
@pure otimes(x::Space, y::Space) = Space(Tuple(x)..., Tuple(y)...)
# contractions
@pure function contraction(x::Space, y::Space, ::Val{N}) where {N}
if !(0 N ndims(x) && 0 N ndims(y) && Dims(x)[end-N+1:end] === Dims(y)[1:N])
if !(0 N tensororder(x) && 0 N tensororder(y) && tensorsize(x)[end-N+1:end] === tensorsize(y)[1:N])
throw(DimensionMismatch("dimensions must match"))
end
otimes(droplast(x, Val(N)), dropfirst(y, Val(N)))
end
@pure otimes(x::Space, y::Space) = Space(Tuple(x)..., Tuple(y)...)
@pure dot(x::Space, y::Space) = contraction(x, y, Val(1))
@pure double_contraction(x::Space, y::Space) = contraction(x, y, Val(2))

# promote_space
promote_space(x::Space) = x
@generated function promote_space(x::Space{S1}, y::Space{S2}) where {S1, S2}
S = _promote_space(S1, S2, ())
quote
Dims(x) == Dims(y) || throw(DimensionMismatch("dimensions must match"))
tensorsize(x) == tensorsize(y) || throw(DimensionMismatch("dimensions must match"))
Space($S)
end
end
Expand Down Expand Up @@ -97,10 +104,10 @@ end
_typeof(x::Int) = x
_typeof(x::Symmetry) = typeof(x)
@pure function tensortype(x::Space)
Tensor{Tuple{map(_typeof, Tuple(x))...}, T, ndims(x), ncomponents(x)} where {T}
Tensor{Tuple{map(_typeof, Tuple(x))...}, T, tensororder(x), ncomponents(x)} where {T}
end

# LinearIndices/CartesianIndices
for IndicesType in (LinearIndices, CartesianIndices)
@eval (::Type{$IndicesType})(x::Space) = $IndicesType(Dims(x))
@eval (::Type{$IndicesType})(x::Space) = $IndicesType(tensorsize(x))
end
8 changes: 4 additions & 4 deletions src/Tensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ end

@generated function check_tensor_parameters(::Type{S}, ::Type{T}, ::Val{N}, ::Val{L}) where {S, T, N, L}
check_size_parameters(S)
if ndims(Space(S)) != N
return :(throw(ArgumentError("Number of dimensions must be $(ndims(Space(S))) for $S size, got $N.")))
if tensororder(Space(S)) != N
return :(throw(ArgumentError("Number of dimensions must be $(tensororder(Space(S))) for $S size, got $N.")))
end
if ncomponents(Space(S)) != L
return :(throw(ArgumentError("Length of tuple data must be $(ncomponents(Space(S))) for $S size, got $L.")))
Expand All @@ -26,11 +26,11 @@ const Vec{dim, T} = Tensor{Tuple{dim}, T, 1, dim}

# constructors
@inline function Tensor{S, T}(data::Tuple{Vararg{Any, L}}) where {S, T, L}
N = ndims(Space(S))
N = tensororder(Space(S))
Tensor{S, T, N, L}(data)
end
@inline function Tensor{S}(data::Tuple{Vararg{Any, L}}) where {S, L}
N = ndims(Space(S))
N = tensororder(Space(S))
T = promote_ntuple_eltype(data)
Tensor{S, T, N, L}(data)
end
Expand Down
4 changes: 2 additions & 2 deletions src/simd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ const SIMDTypes = Union{Float16, Float32, Float64}
@generated function contraction(x::Tensor{<: Any, T, order1}, y::Tensor{<: Any, T, order2}, ::Val{N}) where {T <: SIMDTypes, N, order1, order2}
S1 = Space(x)
S2 = Space(y)
S_Inner = Space((Dims(S2)[i] for i in 1:N)...)
S_Inner = Space((tensorsize(S2)[i] for i in 1:N)...)
S1 = otimes(droplast(S1, Val(N)), S_Inner)
S2 = otimes(S_Inner, dropfirst(S2, Val(N)))
s1 = [:(Tuple(x)[$i]) for i in 1:ncomponents(S1)]
s2 = [:(Tuple(y)[$i]) for i in 1:ncomponents(S2)]
K = prod(Dims(S_Inner))
K = prod(tensorsize(S_Inner))
I = length(s1) ÷ K
J = length(s2) ÷ K
s1′ = reshape(s1, I, K)
Expand Down
6 changes: 5 additions & 1 deletion test/Space.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@
# basic
@test (@inferred length(Space(3,3))) == 2
@test (@inferred length(Space(Symmetry(3,3),3))) == 2
@test (@inferred ndims(Space(Symmetry(3,3),3))) == 3
@test (@inferred Tensorial.tensorsize(Space(Symmetry(3,3),3))) == (3,3,3)
@test (@inferred Tensorial.tensororder(Space(Symmetry(3,3),3))) == 3
@test (@inferred Tuple(Space(Symmetry(3,3),3))) == (Symmetry(3,3),3)
@test Space(3,2)[1] == 3
@test Space(3,2)[2] == 2
# prohibited
@test_throws Exception size(Space(Symmetry(3,3),3))
@test_throws Exception ndims(Space(Symmetry(3,3),3))
# promotion
@test (@inferred Tensorial.promote_space(Space(3,2), Space(3,2))) == Space(3,2)
@test (@inferred Tensorial.promote_space(Space(3,3,3), Space(Symmetry(3,3),3))) == Space(3,3,3)
Expand Down

0 comments on commit 5afe3e5

Please sign in to comment.