Skip to content

Commit

Permalink
allow more types for _broadcast + fix round
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonDanisch committed Apr 6, 2017
1 parent 6b6f519 commit 2433157
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
_broadcast(f, broadcast_sizes(a...), a...)
end

@propagate_inbounds function broadcast{T}(f::Function, a::Type{T}, x::StaticArray)
_broadcast(f, (Size(), Size(x)), T, x)
end

@inline broadcast_sizes(a...) = _broadcast_sizes((), a...)
@inline _broadcast_sizes(t::Tuple) = t
@inline _broadcast_sizes(t::Tuple, a::StaticArray, as...) = _broadcast_sizes((t..., Size(a)), as...)
Expand All @@ -23,7 +27,7 @@ function broadcasted_index(oldsize, newindex)
return sub2ind(oldsize, index...)
end

@generated function _broadcast(f, s::Tuple{Vararg{Size}}, a::Union{Number, StaticArray}...)
@generated function _broadcast(f, s::Tuple{Vararg{Size}}, a...)
first_staticarray = 0
for i = 1:length(a)
if a[i] <: StaticArray
Expand Down Expand Up @@ -57,7 +61,7 @@ end
current_ind = ones(Int, length(newsize))

while more
exprs_vals = [(a[i] <: Number ? :(a[$i]) : :(a[$i][$(broadcasted_index(sizes[i], current_ind))])) for i = 1:length(sizes)]
exprs_vals = [(!(a[i] <: AbstractArray) ? :(a[$i]) : :(a[$i][$(broadcasted_index(sizes[i], current_ind))])) for i = 1:length(sizes)]
exprs[current_ind...] = :(f($(exprs_vals...)))

# increment current_ind (maybe use CartesianRange?)
Expand All @@ -77,7 +81,7 @@ end
end
end

eltype_exprs = [:(eltype($t)) for t a]
eltype_exprs = [t <: AbstractArray ? :($(eltype(t))) : :($t) for t a]
newtype_expr = :(Core.Inference.return_type(f, Tuple{$(eltype_exprs...)}))

return quote
Expand Down

0 comments on commit 2433157

Please sign in to comment.