Skip to content

Commit

Permalink
further changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Jutho committed Sep 12, 2017
1 parent 2822721 commit 1a92478
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 38 deletions.
21 changes: 7 additions & 14 deletions src/auxiliary/meta.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,21 @@
#
# A bunch of auxiliary metaprogramming tools and generated functions

@generated function _strides(A::StridedArray{T,N}) where {T,N}
meta = Expr(:meta,:inline)
ex = Expr(:tuple,[:(stride(A,$d)) for d = 1:N]...)
Expr(:block, meta, ex)
end

@generated function _indmax(values::NTuple{N,T}) where {N,T}
@generated function _indmax(values::NTuple{N}) where {N}
meta = Expr(:meta,:inline)
Expr(:block, meta, :(dmax = 1), :(max = values[1]), [:(values[$d] > max && (dmax = $d; max = values[$d])) for d = 2:N]..., :(return dmax))
end

@generated function _permute(t::NTuple{N,T}, p) where {T,N}
@generated function _permute(t::NTuple{N}, p) where {N}
meta = Expr(:meta,:inline)
ex = Expr(:tuple,[:(t[p[$d]]) for d = 1:N]...)
Expr(:block, meta, ex)
end

@generated function _memjumps(dims::NTuple{N,Int},strides::NTuple{N,Int}) where N
@inline memjumps(dims::Tuple{}, strides::Tuple{}) = ()
@inline memjumps(dims::NTuple{N,Int}, strides::NTuple{N,Int}) where {N} = ((dims[1]-1)*strides[1], memjumps(tail(dims),tail(strides))...)

@generated function _memjumps(dims::NTuple{N,Int}, strides::NTuple{N,Int}) where N
meta = Expr(:meta,:inline)
ex = Expr(:tuple,[:((dims[$d]-1)*strides[$d]) for d = 1:N]...)
Expr(:block, meta, ex)
Expand Down Expand Up @@ -80,11 +77,7 @@ function _stridedloops(N::Int, dims::Symbol, args...)
forex = Expr(:(=), gensym(), rangeex)
ex = Expr(:for, forex, ex)
if d==1
if VERSION < v"0.7-"
ex = Expr(:macrocall, Symbol("@simd"), ex)
else
ex = Expr(:macrocall, Symbol("@simd"), LineNumberNode(@__LINE__), ex)
end
ex = :(@simd $ex)
end
end
pre = [Expr(:(=),Symbol(args[i],N),args[i+1]) for i in argiter]
Expand Down
22 changes: 15 additions & 7 deletions src/auxiliary/stridedarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,37 @@
"""
numind(A)
Returns the number of indices of a tensor-like object `A`, i.e. for a multidimensional array (`<:AbstractArray`) we have `numind(A) = ndims(A)`. Also works in type domain.
Returns the number of indices of a tensor-like object `A`, i.e. for a multidimensional
array (`<:AbstractArray`) we have `numind(A) = ndims(A)`. Also works in type domain.
"""
numind(A::AbstractArray) = ndims(A)
numind(::Type{T}) where {T<:AbstractArray} = ndims(T)

"""
similar_from_indices(T, indices, A, conjA=Val{:N})
Returns an object similar to `A` which has an `eltype` given by `T` and dimensions/sizes corresponding to a selection of those of `op(A)`, where the selection is specified by `indices` (which contains integer between `1` and `numind(A)`) and `op` is `conj` if `conjA=Val{:C}` or does nothing if `conjA=Val{:N}` (default).
Returns an object similar to `A` which has an `eltype` given by `T` and dimensions/sizes
corresponding to a selection of those of `op(A)`, where the selection is specified by
`indices` (which contains integer between `1` and `numind(A)`) and `op` is `conj` if
`conjA=Val{:C}` or does nothing if `conjA=Val{:N}` (default).
"""
function similar_from_indices(::Type{T}, indices, A::StridedArray, ::Type{Val{CA}}=Val{:N}) where {T,CA}
dims = size(A)
return similar(A,T,dims[indices])
srcdims = size(A)
return similar(A, T, srcdims[indices])
end

"""
similar_from_indices(T, indices, A, B, conjA=Val{:N}, conjB={:N})
Returns an object similar to `A` which has an `eltype` given by `T` and dimensions/sizes corresponding to a selection of those of `op(A)` and `op(B)` concatenated, where the selection is specified by `indices` (which contains integers between `1` and `numind(A)+numind(B)` and `op` is `conj` if `conjA` or `conjB` equal `Val{:C}` or does nothing if `conjA` or `conjB` equal `Val{:N}` (default).
Returns an object similar to `A` which has an `eltype` given by `T` and dimensions/sizes
corresponding to a selection of those of `op(A)` and `op(B)` concatenated, where the
selection is specified by `indices` (which contains integers between `1` and
`numind(A)+numind(B)` and `op` is `conj` if `conjA` or `conjB` equal `Val{:C}`
or does nothing if `conjA` or `conjB` equal `Val{:N}` (default).
"""
function similar_from_indices(::Type{T}, indices, A::StridedArray, B::StridedArray, ::Type{Val{CA}}=Val{:N}, ::Type{Val{CB}}=Val{:N}) where {T,CA,CB}
dims = tuple(size(A)...,size(B)...)
return similar(A,T,dims[indices])
srcdims = tuple(size(A)...,size(B)...)
return similar(A, T, srcdims[indices])
end

"""
Expand Down
18 changes: 9 additions & 9 deletions src/auxiliary/strideddata.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,19 @@ ConjugatedStridedData{N,T} = StridedData{N,T,:C}

StridedSubArray{T,N,A<:Array,I<:Tuple{Vararg{Union{Colon,Range{Int64},Int64}}},LD} = SubArray{T,N,A,I,LD}

StridedData(a::Array{T}, strides::NTuple{N,Int} = _strides(a), ::Type{Val{C}} = Val{:N}) where {N,T,C} =
StridedData(a::Array{T}, strides::NTuple{N,Int} = strides(a), ::Type{Val{C}} = Val{:N}) where {N,T,C} =
StridedData{N,T,C}(vec(a), strides, 1)
StridedData(a::StridedSubArray{T}, strides::NTuple{N,Int} = _strides(a), ::Type{Val{C}} = Val{:N}) where {N,T,C} =
StridedData(a::StridedSubArray{T}, strides::NTuple{N,Int} = strides(a), ::Type{Val{C}} = Val{:N}) where {N,T,C} =
StridedData{N,T,C}(vec(a.parent), strides, Base.first_index(a))

Base.getindex(a::NormalStridedData,i) = a.data[i]
Base.getindex(a::ConjugatedStridedData,i) = conj(a.data[i])
Base.getindex(a::NormalStridedData, i) = a.data[i]
Base.getindex(a::ConjugatedStridedData, i) = conj(a.data[i])

Base.setindex!(a::NormalStridedData,v,i) = (@inbounds a.data[i] = v)
Base.setindex!(a::ConjugatedStridedData,v,i) = (@inbounds a.data[i] = conj(v))
Base.setindex!(a::NormalStridedData, v, i) = (@inbounds a.data[i] = v)
Base.setindex!(a::ConjugatedStridedData, v, i) = (@inbounds a.data[i] = conj(v))

# set dimensions dims[d]==1 for all d where a.strides[d] == 0.
@generated function _filterdims(dims::NTuple{N,Int}, a::StridedData{N}) where N
@generated function _filterdims(dims::NTuple{N,Int}, a::StridedData{N}) where {N}
meta = Expr(:meta,:inline)
ex = Expr(:tuple,[:(a.strides[$d]==0 ? 1 : dims[$d]) for d=1:N]...)
Expr(:block,meta,ex)
Expand All @@ -36,7 +36,7 @@ end
# initial scaling of a block specified by dims
_scale!(C::StridedData{N}, β::One, dims::NTuple{N,Int}, offset::Int=0) where {N} = C

@generated function _scale!(C::StridedData{N}, β::Zero, dims::NTuple{N,Int}, offset::Int=0) where N
@generated function _scale!(C::StridedData{N}, β::Zero, dims::NTuple{N,Int}, offset::Int=0) where {N}
meta = Expr(:meta,:inline)
quote
$meta
Expand All @@ -48,7 +48,7 @@ _scale!(C::StridedData{N}, β::One, dims::NTuple{N,Int}, offset::Int=0) where {N
end
end

@generated function _scale!(C::StridedData{N}, β::Number, dims::NTuple{N,Int}, offset::Int=0) where N
@generated function _scale!(C::StridedData{N}, β::Number, dims::NTuple{N,Int}, offset::Int=0) where {N}
meta = Expr(:meta,:inline)
quote
$meta
Expand Down
10 changes: 5 additions & 5 deletions src/implementation/stridedarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ function add!(α, A::StridedArray, ::Type{Val{CA}}, β, C::StridedArray, indCinA
size(A,indCinA[i]) == size(C,i) || throw(DimensionMismatch())
end

dims, stridesA, stridesC, minstrides = add_strides(size(C), _permute(_strides(A),indCinA), _strides(C))
dims, stridesA, stridesC, minstrides = add_strides(size(C), _permute(strides(A),indCinA), strides(C))
dataA = StridedData(A, stridesA, Val{CA})
offsetA = 0
dataC = StridedData(C, stridesC)
Expand Down Expand Up @@ -61,7 +61,7 @@ function trace!(α, A::StridedArray, ::Type{Val{CA}}, β, C::StridedArray, indCi
end

pA = vcat(indCinA, cindA1, cindA2)
dims, stridesA, stridesC, minstrides = trace_strides(_permute(size(A),pA), _permute(_strides(A),pA), _strides(C))
dims, stridesA, stridesC, minstrides = trace_strides(_permute(size(A),pA), _permute(strides(A),pA), strides(C))
dataA = StridedData(A, stridesA, Val{CA})
offsetA = 0
dataC = StridedData(C, stridesC)
Expand Down Expand Up @@ -218,9 +218,9 @@ function contract!(α, A::StridedArray, ::Type{Val{CA}}, B::StridedArray, ::Type
# Perform contraction
pA = vcat(oindA, cindA)
pB = vcat(oindB, cindB)
sA = _permute(_strides(A), pA)
sB = _permute(_strides(B), pB)
sC = _permute(_strides(C), invperm(indCinoAB))
sA = _permute(strides(A), pA)
sB = _permute(strides(B), pB)
sC = _permute(strides(C), invperm(indCinoAB))

dimsA = _permute(size(A), pA)
dimsB = _permute(size(B), pB)
Expand Down
4 changes: 2 additions & 2 deletions src/indexnotation/indexedobject.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ end
length(Jdst) == length(Idst) || throw(IndexError("left-hand side cannot have partial trace: $Idst"))
if length(Isrc) == length(Jdst)
indCinA = add_indices(Isrc, Idst)
:($meta;add!(src.α, src.object, Val{C}, β, dst, $indCinA))
:($meta; add!(src.α, src.object, Val{C}, β, dst, $indCinA))
else
indCinA, cindA1, cindA2 = trace_indices(Isrc, Idst)
return :($meta;trace!(src.α, src.object, Val{C}, β, dst, $indCinA, $cindA1, $cindA2))
return :($meta; trace!(src.α, src.object, Val{C}, β, dst, $indCinA, $cindA1, $cindA2))
end
end
2 changes: 1 addition & 1 deletion src/indexnotation/product.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,6 @@ end
meta = Expr(:meta,:inline)
quote
$meta
contract!(P.A.α*P.B.α,P.A.object,$conjA,P.B.object,$conjB,β,dst,$oindA,$cindA,$oindB,$cindB,$indCinoAB)
contract!(P.A.α*P.B.α, P.A.object, $conjA, P.B.object, $conjB, β, dst, $oindA, $cindA, $oindB, $cindB, $indCinoAB)
end
end

0 comments on commit 1a92478

Please sign in to comment.