Skip to content

Commit

Permalink
remove custom _strides and _indmax
Browse files Browse the repository at this point in the history
  • Loading branch information
Jutho committed Sep 12, 2017
1 parent 0a19cf7 commit e34eb54
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 41 deletions.
38 changes: 20 additions & 18 deletions src/auxiliary/meta.jl
Expand Up @@ -2,27 +2,29 @@
#
# A bunch of auxiliary metaprogramming tools and generated functions

@generated function _strides{T,N}(A::StridedArray{T,N})
meta = Expr(:meta,:inline)
ex = Expr(:tuple,[:(stride(A,$d)) for d = 1:N]...)
Expr(:block, meta, ex)
end
import Base.tail

@generated function _indmax{N,T}(values::NTuple{N,T})
meta = Expr(:meta,:inline)
Expr(:block, meta, :(dmax = 1), :(max = values[1]), [:(values[$d] > max && (dmax = $d; max = values[$d])) for d = 2:N]..., :(return dmax))
end
# OK up to N=15
_permute(t::NTuple{N,T} where {N,T}, p) = __permute((), t, p)
@inline __permute(tdst::NTuple{N,T}, tsrc::NTuple{N,T}, p) where {N,T} = tdst
@inline __permute(tdst::NTuple{N1,T}, tsrc::NTuple{N2,T}, p) where {N1,N2,T} = __permute(tuple(tdst..., tsrc[p[N1+1]]), tsrc, p)

@generated function _permute{T,N}(t::NTuple{N,T}, p)
meta = Expr(:meta,:inline)
ex = Expr(:tuple,[:(t[p[$d]]) for d = 1:N]...)
Expr(:block, meta, ex)
end
# OK up to N=14
_memjumps(dims::Tuple{}, strides::Tuple{}) = ()
_memjumps(dims::NTuple{N,Int}, strides::NTuple{N,Int}) where {N} = tuple((dims[1]-1)*strides[1], _memjumps(tail(dims), tail(strides))...)

@generated function _memjumps{N}(dims::NTuple{N,Int},strides::NTuple{N,Int})
meta = Expr(:meta,:inline)
ex = Expr(:tuple,[:((dims[$d]-1)*strides[$d]) for d = 1:N]...)
Expr(:block, meta, ex)
# inferrable and fast up to N = 14, slow afterwards
function _invperm(p::NTuple{N,T}) where {N,T<:Integer}
ip = ntuple(n->T(n),Val{N})
__swapsort(ip, p, 1)
end
@inline __swapsort(ip::Tuple{}, p::Tuple{}, k) = ()
@inline function __swapsort(ip::NTuple{N,Integer}, p::NTuple{N,Integer}, k) where {N}
while p[1] != k
p = tuple(tail(p)..., p[1])
ip = tuple(tail(ip)..., ip[1])
end
tuple(ip[1], __swapsort(tail(ip), tail(p), k+1)...)
end

# Based on Tim Holy's Cartesian
Expand Down
35 changes: 24 additions & 11 deletions src/auxiliary/stridedarray.jl
Expand Up @@ -3,32 +3,45 @@
# Simple auxiliary methods to interface with StridedArray from Julia Base.


"""`numind(A)`
"""
numind(A)
Returns the number of indices of a tensor-like object `A`, i.e. for a multidimensional array (`<:AbstractArray`) we have `numind(A) = ndims(A)`. Also works in type domain.
"""
numind(A::AbstractArray) = ndims(A)
numind{T<:AbstractArray}(::Type{T}) = ndims(T)

"""`similar_from_indices(T, indices, A, conjA=Val{:N})`
"""
similar_from_indices(T, indices, A, conjA=Val{:N})
Returns an object similar to `A` which has an `eltype` given by `T` and dimensions/sizes corresponding to a selection of those of `op(A)`, where the selection is specified by `indices` (which contains integer between `1` and `numind(A)`) and `op` is `conj` if `conjA=Val{:C}` or does nothing if `conjA=Val{:N}` (default).
Returns an object similar to `A` which has an `eltype` given by `T` and
dimensions/sizes corresponding to a selection of those of `op(A)`,
where the selection is specified by `indices` (which contains integer between
`1` and `numind(A)`) and `op` is `conj` if `conjA=Val{:C}` or does nothing if
`conjA=Val{:N}` (default).
"""
function similar_from_indices{T,CA}(::Type{T}, indices, A::StridedArray, ::Type{Val{CA}}=Val{:N})
function similar_from_indices(::Type{T}, indices, A::StridedArray, ::Type{Val{CA}}=Val{:N}) where {T,CA}
dims = size(A)
return similar(A,T,dims[indices])
return similar(A, T, dims[indices])
end

"""`similar_from_indices(T, indices, A, B, conjA=Val{:N}, conjB={:N})`
Returns an object similar to `A` which has an `eltype` given by `T` and dimensions/sizes corresponding to a selection of those of `op(A)` and `op(B)` concatenated, where the selection is specified by `indices` (which contains integers between `1` and `numind(A)+numind(B)` and `op` is `conj` if `conjA` or `conjB` equal `Val{:C}` or does nothing if `conjA` or `conjB` equal `Val{:N}` (default).
"""
function similar_from_indices{T,CA,CB}(::Type{T}, indices, A::StridedArray, B::StridedArray, ::Type{Val{CA}}=Val{:N}, ::Type{Val{CB}}=Val{:N})
similar_from_indices(T, indices, A, B, conjA=Val{:N}, conjB={:N})
Returns an object similar to `A` which has an `eltype` given by `T` and
dimensions/sizes corresponding to a selection of those of `op(A)` and `op(B)`
concatenated, where the selection is specified by `indices`
(which contains integers between `1` and `numind(A)+numind(B)` and `op` is
`conj` if `conjA` or `conjB` equal `Val{:C}` or does nothing if `conjA` or
`conjB` equal `Val{:N}` (default).
"""
function similar_from_indices(::Type{T}, indices, A::StridedArray, B::StridedArray, ::Type{Val{CA}}=Val{:N}, ::Type{Val{CB}}=Val{:N}) where {T,CA,CB}
dims = tuple(size(A)...,size(B)...)
return similar(A,T,dims[indices])
return similar(A, T, dims[indices])
end

"""`scalar(C)`
"""
scalar(C)
Returns the single element of a tensor-like object with zero dimensions, i.e. if `numind(C)==0`.
"""
Expand Down
4 changes: 2 additions & 2 deletions src/auxiliary/strideddata.jl
Expand Up @@ -14,9 +14,9 @@ end

@compat const StridedSubArray{T,N,A<:Array,I<:Tuple{Vararg{Union{Colon,Range{Int64},Int64}}},LD} = SubArray{T,N,A,I,LD}

StridedData{N,T,C}(a::Array{T}, strides::NTuple{N,Int} = _strides(a), ::Type{Val{C}} = Val{:N}) =
StridedData{N,T,C}(a::Array{T}, strides::NTuple{N,Int} = strides(a), ::Type{Val{C}} = Val{:N}) =
StridedData{N,T,C}(vec(a), strides, 1)
StridedData{N,T,C}(a::StridedSubArray{T}, strides::NTuple{N,Int} = _strides(a), ::Type{Val{C}} = Val{:N}) =
StridedData{N,T,C}(a::StridedSubArray{T}, strides::NTuple{N,Int} = strides(a), ::Type{Val{C}} = Val{:N}) =
StridedData{N,T,C}(vec(a.parent), strides, Base.first_index(a))

Base.getindex(a::NormalStridedData,i) = a.data[i]
Expand Down
10 changes: 5 additions & 5 deletions src/implementation/recursive.jl
Expand Up @@ -12,7 +12,7 @@ const BASELENGTH=2048
if 2*prod(dims) <= BASELENGTH
add_micro!(α, A, β, C, dims, offsetA, offsetC)
else
dmax = _indmax(_memjumps(dims, minstrides))
dmax = indmax(_memjumps(dims, minstrides))
@dividebody $N dmax dims offsetA A offsetC C begin
add_rec!(α, A, β, C, dims, offsetA, offsetC, minstrides)
end begin
Expand All @@ -28,7 +28,7 @@ end
if prod(dims) + prod(_filterdims(dims,C)) <= BASELENGTH
trace_micro!(α, A, β, C, dims, offsetA, offsetC)
else
dmax = _indmax(_memjumps(dims, minstrides))
dmax = indmax(_memjumps(dims, minstrides))
@dividebody $N dmax dims offsetA A offsetC C begin
trace_rec!(α, A, β, C, dims, offsetA, offsetC, minstrides)
end begin
Expand Down Expand Up @@ -58,11 +58,11 @@ end
contract_micro!(α, A, B, β, C, dims, offsetA, offsetB, offsetC)
else
if clength > oAlength && clength > oBlength
dmax = _indmax(_memjumps(cdims, minstrides))
dmax = indmax(_memjumps(cdims, minstrides))
elseif oAlength > oBlength
dmax = _indmax(_memjumps(odimsA, minstrides))
dmax = indmax(_memjumps(odimsA, minstrides))
else
dmax = _indmax(_memjumps(odimsB, minstrides))
dmax = indmax(_memjumps(odimsB, minstrides))
end
@dividebody $N dmax dims offsetA A offsetB B offsetC C begin
contract_rec!(α, A, B, β, C, dims, offsetA, offsetB, offsetC, minstrides)
Expand Down
10 changes: 5 additions & 5 deletions src/implementation/stridedarray.jl
Expand Up @@ -13,7 +13,7 @@ function add!{CA}(α, A::StridedArray, ::Type{Val{CA}}, β, C::StridedArray, ind
size(A,indCinA[i]) == size(C,i) || throw(DimensionMismatch())
end

dims, stridesA, stridesC, minstrides = add_strides(size(C), _permute(_strides(A),indCinA), _strides(C))
dims, stridesA, stridesC, minstrides = add_strides(size(C), _permute(strides(A),indCinA), strides(C))
dataA = StridedData(A, stridesA, Val{CA})
offsetA = 0
dataC = StridedData(C, stridesC)
Expand Down Expand Up @@ -51,7 +51,7 @@ function trace!{CA}(α, A::StridedArray, ::Type{Val{CA}}, β, C::StridedArray, i
end

pA = vcat(indCinA, cindA1, cindA2)
dims, stridesA, stridesC, minstrides = trace_strides(_permute(size(A),pA), _permute(_strides(A),pA), _strides(C))
dims, stridesA, stridesC, minstrides = trace_strides(_permute(size(A),pA), _permute(strides(A),pA), strides(C))
dataA = StridedData(A, stridesA, Val{CA})
offsetA = 0
dataC = StridedData(C, stridesC)
Expand Down Expand Up @@ -198,9 +198,9 @@ function contract!{CA,CB}(α, A::StridedArray, ::Type{Val{CA}}, B::StridedArray,
# Perform contraction
pA = vcat(oindA, cindA)
pB = vcat(oindB, cindB)
sA = _permute(_strides(A), pA)
sB = _permute(_strides(B), pB)
sC = _permute(_strides(C), invperm(indCinoAB))
sA = _permute(strides(A), pA)
sB = _permute(strides(B), pB)
sC = _permute(strides(C), invperm(indCinoAB))

dimsA = _permute(size(A), pA)
dimsB = _permute(size(B), pB)
Expand Down

0 comments on commit e34eb54

Please sign in to comment.