Skip to content

Commit

Permalink
More code review fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Andy Ferris committed Mar 24, 2017
1 parent a0d6658 commit 5ebaf5c
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 141 deletions.
1 change: 1 addition & 0 deletions src/convert.jl
Expand Up @@ -6,6 +6,7 @@
# this covers most conversions and "statically-sized reshapes"
# StaticArray -> (possibly different) StaticArray type: rebuild the target from the
# source's element tuple.
@inline function convert(::Type{SA}, sa::StaticArray) where {SA<:StaticArray}
    return SA(Tuple(sa))
end

# Already an instance of the requested type: hand the value back untouched.
@inline function convert(::Type{SA}, sa::SA) where {SA<:StaticArray}
    return sa
end

# Tuple -> StaticArray: convert -> constructor. Hopefully no loops...
@inline function convert(::Type{SA}, x::Tuple) where {SA<:StaticArray}
    return SA(x)
end

# A general way of going back to a tuple
@inline function convert(::Type{Tuple}, a::StaticArray)
Expand Down
51 changes: 4 additions & 47 deletions src/linalg.jl
Expand Up @@ -54,34 +54,9 @@ end


# Transpose, conjugate, etc
# TODO different methods for v0.5, v0.6 (due to `RowVector`)
# Conjugate a static array by mapping `conj` over its elements.
@inline function conj(a::StaticArray)
    return map(conj, a)
end

#=
@generated function transpose(v::StaticVector)
n = length(v)
newtype = similar_type(v, Size(1,n))
exprs = [:(v[$j]) for j = 1:n]
return quote
$(Expr(:meta, :inline))
@inbounds return $(Expr(:call, newtype, Expr(:tuple, exprs...)))
end
end
@generated function ctranspose(v::StaticVector)
n = length(v)
newtype = similar_type(v, Size(1,n))
exprs = [:(conj(v[$j])) for j = 1:n]
return quote
$(Expr(:meta, :inline))
@inbounds return $(Expr(:call, newtype, Expr(:tuple, exprs...)))
end
end
=#

# Matrix transpose: forward to `_transpose`, passing `Size(m)` as the first argument
# so the implementation can be selected on the static size.
@inline function transpose(m::StaticMatrix)
    return _transpose(Size(m), m)
end
# note: transpose of StaticVector is a RowVector, handled by Base

@generated function _transpose(::Size{S}, m::StaticMatrix) where {S}
Snew = (S[2], S[1])
Expand Down Expand Up @@ -214,34 +189,16 @@ _cross(::Size{S}, a::StaticVector, b::StaticVector) where {S} = error("Cross pro
@inbounds return a[1]*b[2] - a[2]*b[1]
end
# Cross product of two length-3 static vectors.
#
# The result is built with `similar_type(a, T)`, where `T` is the type of the first
# component, and is constructed from the tuple of the three components.
# All indexing happens inside the single `@inbounds` statement; no separate,
# bounds-checked pre-computation of the element type is needed.
@inline function _cross(::Size{(3,)}, a::StaticVector, b::StaticVector)
    @inbounds return similar_type(a, typeof(a[2]*b[3]-a[3]*b[2]))((
        a[2]*b[3] - a[3]*b[2],
        a[3]*b[1] - a[1]*b[3],
        a[1]*b[2] - a[2]*b[1]
    ))
end
# 2-element "cross product" for unsigned element types: each product is passed
# through `Signed(...)` before the subtraction, so the difference is taken between
# the converted values rather than between the raw unsigned products.
@inline function _cross(::Size{(2,)}, a::StaticVector{<:Unsigned}, b::StaticVector{<:Unsigned})
    @inbounds first_term = Signed(a[1] * b[2])
    @inbounds second_term = Signed(a[2] * b[1])
    return first_term - second_term
end
# Cross product of two length-3 static vectors with unsigned element types.
#
# Every product is passed through `Signed(...)` before subtracting. The result type
# is `similar_type(a, T)`, where `T` is the type of the first (signed) component.
# The element type is derived from the same signed expression that is stored, so no
# separate unsigned pre-computation is kept around.
@inline function _cross(::Size{(3,)}, a::StaticVector{<:Unsigned}, b::StaticVector{<:Unsigned})
    @inbounds return similar_type(a, typeof(Signed(a[2]*b[3])-Signed(a[3]*b[2])))((
        Signed(a[2]*b[3]) - Signed(a[3]*b[2]),
        Signed(a[3]*b[1]) - Signed(a[1]*b[3]),
        Signed(a[1]*b[2]) - Signed(a[2]*b[1])
    ))
end

# Dot product of two static vectors: forward to `_dot`, passing the result of
# `same_size(a, b)` as the first argument.
# NOTE(review): `same_size` presumably checks that both sizes agree - confirm.
@inline function dot(a::StaticVector, b::StaticVector)
    return _dot(same_size(a, b), a, b)
end
# Generated kernel for `dot`: emits a fully unrolled sum of `conj(a[j]) * b[j]`
# for j = 1:S[1], or a typed zero when the static length is 0.
@generated function _dot(::Size{S}, a::StaticVector, b::StaticVector) where {S}
    len = S[1]

    # Nothing to sum: return the zero of the promoted product type.
    if len == 0
        return :(zero(promote_op(*, eltype(a), eltype(b))))
    end

    # One product term per index.
    terms = [:(conj(a[$j]) * b[$j]) for j = 1:len]

    # Chain the terms as ((t1 + t2) + t3) + ..., i.e. strictly left to right.
    total = terms[1]
    for j = 2:len
        total = :($total + $(terms[j]))
    end

    return quote
        @_inline_meta
        @inbounds return $total
    end
end

# `dot` (static vectors) and `vecdot` (any static arrays) both forward to the
# generated `_vecdot`, passing the result of `same_size(a, b)` as the first argument.
@inline function dot(a::StaticVector, b::StaticVector)
    return _vecdot(same_size(a, b), a, b)
end
@inline function vecdot(a::StaticArray, b::StaticArray)
    return _vecdot(same_size(a, b), a, b)
end
@generated function _vecdot(::Size{S}, a::StaticArray, b::StaticArray) where {S}
if prod(S) == 0
Expand Down Expand Up @@ -270,9 +227,9 @@ end
return zero(real(eltype(a)))
end

expr = :(real(conj(a[1]) * a[1]))
expr = :(abs2(a[1]))
for j = 2:prod(S)
expr = :($expr + real(conj(a[$j]) * a[$j]))
expr = :($expr + abs2(a[$j]))
end

return quote
Expand Down
91 changes: 43 additions & 48 deletions src/matrix_multiply.jl
Expand Up @@ -63,6 +63,8 @@ promote_matprod{T1,T2}(::Type{T1}, ::Type{T2}) = typeof(zero(T1)*zero(T2) + zero
end
end

# TODO: I removed StaticMatrix * AbstractVector. Reinstate?

# outer product
@generated function _A_mul_B(::Size{sa}, ::Size{sb}, a::StaticVector{Ta}, b::RowVector{Tb, <:StaticVector}) where {sa, sb, Ta, Tb}
newsize = (sa[1], sb[2])
Expand All @@ -75,58 +77,51 @@ end
end
end

# TODO: I removed StaticMatrix * AbstractVector. Reinstate?

# Matrix-matrix product dispatcher for two static matrices.
#
# Being `@generated`, the branch below is chosen from the static sizes `sa` and `sb`
# only, and the emitted code is a single inlined call to one of three kernels:
#   * `A_mul_B_unrolled`         when sa[1]*sa[2]*sb[2] <= 8*8*8
#   * `A_mul_B_unrolled_chunks`  when sa[1], sa[2] and sb[2] are each <= 14
#   * `A_mul_B_loop`             otherwise
# The element types `Ta` and `Tb` are not consulted here.
@generated function _A_mul_B(Sa::Size{sa}, Sb::Size{sb}, a::StaticMatrix{Ta}, b::StaticMatrix{Tb}) where {sa, sb, Ta, Tb}
    # Heuristic choice for amount of codegen
    if sa[1]*sa[2]*sb[2] <= 8*8*8
        return quote
            @_inline_meta
            return A_mul_B_unrolled(Sa, Sb, a, b)
        end
    elseif sa[1] <= 14 && sa[2] <= 14 && sb[2] <= 14
        return quote
            @_inline_meta
            return A_mul_B_unrolled_chunks(Sa, Sb, a, b)
        end
    else
        return quote
            @_inline_meta
            return A_mul_B_loop(Sa, Sb, a, b)
        end
    end
end

if can_mutate
S = Size(sa[1], sb[2])
@generated function _A_mul_B(Sa::Size{sa}, Sb::Size{sb}, a::Union{SizedMatrix{T}, MMatrix{T}, MArray{T}}, b::Union{SizedMatrix{T}, MMatrix{T}, MArray{T}}) where {sa, sb, T <: BlasFloat}
S = Size(sa[1], sb[2])

# Heuristic choice between BLAS and explicit unrolling (or chunk-based unrolling)
if can_blas && sa[1]*sa[2]*sb[2] >= 14*14*14
return quote
@_inline_meta
T = promote_matprod(Ta, Tb)
C = similar(a, T, $S)
A_mul_B_blas!($S, C, Sa, Sb, a, b)
return C
end
elseif sa[1]*sa[2]*sb[2] < 8*8*8
return quote
@_inline_meta
return A_mul_B_unrolled(Sa, Sb, a, b)
end
elseif sa[1] <= 14 && sa[2] <= 14 && sb[2] <= 14
return quote
@_inline_meta
T = promote_matprod(Ta, Tb)
return similar_type(a, T, $S)(A_mul_B_unrolled_chunks(Sa, Sb, a, b))
end
else
return quote
@_inline_meta
return A_mul_B_loop(Sa, Sb, a, b)
end
# Heuristic choice between BLAS and explicit unrolling (or chunk-based unrolling)
if sa[1]*sa[2]*sb[2] >= 14*14*14
return quote
@_inline_meta
C = similar(a, T, $S)
A_mul_B_blas!($S, C, Sa, Sb, a, b)
return C
end
else # both are isbits type...
# Heuristic choice for amount of codegen
if sa[1]*sa[2]*sb[2] <= 8*8*8
return quote
@_inline_meta
return A_mul_B_unrolled(Sa, Sb, a, b)
end
elseif sa[1] <= 14 && sa[2] <= 14 && sb[2] <= 14
return quote
@_inline_meta
return A_mul_B_unrolled_chunks(Sa, Sb, a, b)
end
else
return quote
@_inline_meta
return A_mul_B_loop(Sa, Sb, a, b)
end
elseif sa[1]*sa[2]*sb[2] < 8*8*8
return quote
@_inline_meta
return A_mul_B_unrolled(Sa, Sb, a, b)
end
elseif sa[1] <= 14 && sa[2] <= 14 && sb[2] <= 14
return quote
@_inline_meta
return similar_type(a, T, $S)(A_mul_B_unrolled_chunks(Sa, Sb, a, b))
end
else
return quote
@_inline_meta
return A_mul_B_loop(Sa, Sb, a, b)
end
end
end
Expand Down
46 changes: 0 additions & 46 deletions src/util.jl
Expand Up @@ -27,49 +27,3 @@ end
$t
end
end


# TODO: the below seems to be type piracy...
#=
# some convenience functions for non-static arrays, generators, etc...
@inline convert{T}(::Type{Tuple}, a::AbstractArray{T}) = (a...)::Tuple{Vararg{T}}
@inline function convert{N,T}(::Type{NTuple{N,Any}}, a::AbstractArray{T})
@boundscheck if length(a) != N
error("Array of length $(length(a)) cannot be converted to a $N-tuple")
end
@inbounds return ntuple(i -> a[i], Val{N})
end
@inline function convert{N,T1,T2}(::Type{NTuple{N,T1}}, a::AbstractArray{T2})
@boundscheck if length(a) != N
error("Array of length $(length(a)) cannot be converted to a $N-tuple")
end
@inbounds return ntuple(i -> convert(T1,a[i]), Val{N})
end
if VERSION < v"0.5+"
# TODO try and make this generate fast code
@inline convert(::Type{Tuple}, g::Base.Generator) = (g...)
@inline function convert{N}(::Type{NTuple{N,Any}}, g::Base.Generator)
@boundscheck if length(g.iter) != N
error("Array of length $(length(a)) cannot be converted to a $N-tuple")
end
@inbounds return ntuple(i -> g.f(g.iter[i]), Val{N})
end
end
=#
#=
@generated function convert{N}(::Type{NTuple{N,Any}}, g::Base.Generator)
exprs = [:(g.f(g.iter[$j])) for j=1:N]
return quote
@boundscheck if length(g.iter) != N
error("Array of length $(length(a)) cannot be converted to a $N-tuple")
end
@inbounds return $(Expr(:tuple, exprs...))
end
end=#
6 changes: 6 additions & 0 deletions test/custom_types.jl
Expand Up @@ -4,4 +4,10 @@
data::NTuple{N, T}
end)
@test (MyType(3, 4) isa MyType{2, Int})

# Issue 110
@eval (struct Polly{N,T}
data::SVector{N,T}
end)
@test (Polly{2,Float64}((1.0, 0.0)) isa Polly)
end
1 change: 1 addition & 0 deletions test/linalg.jl
Expand Up @@ -17,6 +17,7 @@
v3 = [2,4,6,8]
v4 = [4,3,2,1]

# We broke "inferrable" sizes of AbstractVectors for vector+vector, matrix*vector, etc...
@test_broken @inferred(v1 + v4) === @SVector [6, 7, 8, 9]
@test_broken @inferred(v3 + v2) === @SVector [6, 7, 8, 9]
@test_broken @inferred(v1 - v4) === @SVector [-2, 1, 4, 7]
Expand Down

0 comments on commit 5ebaf5c

Please sign in to comment.