Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expunging @ngenerate and @nsplat #9098

Merged
merged 9 commits into from
Feb 7, 2015
64 changes: 37 additions & 27 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,38 +239,48 @@ broadcast!_function(f::Function) = (B, As...) -> broadcast!(f, B, As...)
broadcast_function(f::Function) = (As...) -> broadcast(f, As...)

broadcast_getindex(src::AbstractArray, I::AbstractArray...) = broadcast_getindex!(Array(eltype(src), broadcast_shape(I...)), src, I...)
@ngenerate N typeof(dest) function broadcast_getindex!(dest::AbstractArray, src::AbstractArray, I::NTuple{N, AbstractArray}...)
check_broadcast_shape(size(dest), I...) # unnecessary if this function is never called directly
checkbounds(src, I...)
@nloops N i dest d->(@nexprs N k->(j_d_k = size(I_k, d) == 1 ? 1 : i_d)) begin
@nexprs N k->(@inbounds J_k = @nref N I_k d->j_d_k)
@inbounds (@nref N dest i) = (@nref N src J)
stagedfunction broadcast_getindex!(dest::AbstractArray, src::AbstractArray, I::AbstractArray...)
N = length(I)
Isplat = Expr[:(I[$d]) for d = 1:N]
quote
@nexprs $N d->(I_d = I[d])
check_broadcast_shape(size(dest), $(Isplat...)) # unnecessary if this function is never called directly
checkbounds(src, $(Isplat...))
@nloops $N i dest d->(@nexprs $N k->(j_d_k = size(I_k, d) == 1 ? 1 : i_d)) begin
@nexprs $N k->(@inbounds J_k = @nref $N I_k d->j_d_k)
@inbounds (@nref $N dest i) = (@nref $N src J)
end
dest
end
dest
end

@ngenerate N typeof(A) function broadcast_setindex!(A::AbstractArray, x, I::NTuple{N, AbstractArray}...)
checkbounds(A, I...)
shape = broadcast_shape(I...)
@nextract N shape d->(length(shape) < d ? 1 : shape[d])
if !isa(x, AbstractArray)
@nloops N i d->(1:shape_d) d->(@nexprs N k->(j_d_k = size(I_k, d) == 1 ? 1 : i_d)) begin
@nexprs N k->(@inbounds J_k = @nref N I_k d->j_d_k)
@inbounds (@nref N A J) = x
end
else
X = x
# To call setindex_shape_check, we need to create fake 1-d indexes of the proper size
@nexprs N d->(fakeI_d = 1:shape_d)
Base.setindex_shape_check(X, (@ntuple N fakeI)...)
k = 1
@nloops N i d->(1:shape_d) d->(@nexprs N k->(j_d_k = size(I_k, d) == 1 ? 1 : i_d)) begin
@nexprs N k->(@inbounds J_k = @nref N I_k d->j_d_k)
@inbounds (@nref N A J) = X[k]
k += 1
stagedfunction broadcast_setindex!(A::AbstractArray, x, I::AbstractArray...)
N = length(I)
Isplat = Expr[:(I[$d]) for d = 1:N]
quote
@nexprs $N d->(I_d = I[d])
checkbounds(A, $(Isplat...))
shape = broadcast_shape($(Isplat...))
@nextract $N shape d->(length(shape) < d ? 1 : shape[d])
if !isa(x, AbstractArray)
@nloops $N i d->(1:shape_d) d->(@nexprs $N k->(j_d_k = size(I_k, d) == 1 ? 1 : i_d)) begin
@nexprs $N k->(@inbounds J_k = @nref $N I_k d->j_d_k)
@inbounds (@nref $N A J) = x
end
else
X = x
# To call setindex_shape_check, we need to create fake 1-d indexes of the proper size
@nexprs $N d->(fakeI_d = 1:shape_d)
Base.setindex_shape_check(X, (@ntuple $N fakeI)...)
k = 1
@nloops $N i d->(1:shape_d) d->(@nexprs $N k->(j_d_k = size(I_k, d) == 1 ? 1 : i_d)) begin
@nexprs $N k->(@inbounds J_k = @nref $N I_k d->j_d_k)
@inbounds (@nref $N A J) = X[k]
k += 1
end
end
A
end
A
end

## elementwise operators ##
Expand Down
271 changes: 1 addition & 270 deletions base/cartesian.jl
Original file line number Diff line number Diff line change
@@ -1,275 +1,6 @@
module Cartesian

export @ngenerate, @nsplat, @nloops, @nref, @ncall, @nexprs, @nextract, @nall, @ntuple, @nif, ngenerate

const CARTESIAN_DIMS = 4

### @ngenerate, for auto-generation of separate versions of functions for different dimensionalities
# Examples (deliberately trivial):
# @ngenerate N returntype myndims{T,N}(A::Array{T,N}) = N
# or alternatively
# function gen_body(N::Int)
# quote
# return $N
# end
# end
# eval(ngenerate(:N, returntypeexpr, :(myndims{T,N}(A::Array{T,N})), gen_body))
# The latter allows you to use a single gen_body function for both ngenerate and
# when your function maintains its own method cache (e.g., reduction or broadcasting).
#
# Special syntax for function prototypes:
# @ngenerate N returntype function myfunction(A::AbstractArray, I::NTuple{N, Int}...)
# for N = 3 translates to
# function myfunction(A::AbstractArray, I_1::Int, I_2::Int, I_3::Int)
# and for the generic (cached) case as
# function myfunction(A::AbstractArray, I::Int...)
# @nextract N I I
# with N = length(I). N should _not_ be listed as a parameter of the function unless
# earlier arguments use it that way.
# To avoid ambiguity, it would be preferable to have some specific syntax for this, such as
# myfunction(A::AbstractArray, I::Int...N)
# where N can be an integer or symbol. Currently T...N generates a parser error.
macro ngenerate(itersym, returntypeexpr, funcexpr)
if isa(funcexpr, Expr) && funcexpr.head == :macrocall && funcexpr.args[1] == symbol("@inline")
funcexpr = Base._inline(funcexpr.args[2])
end
isfuncexpr(funcexpr) || throw(ArgumentError("requires a function expression"))
esc(ngenerate(itersym, returntypeexpr, funcexpr.args[1], N->sreplace!(copy(funcexpr.args[2]), itersym, N)))
end

# @nsplat takes an expression like
# @nsplat N 2:3 myfunction(A, I::NTuple{N,Real}...) = getindex(A, I...)
# and generates
# myfunction(A, I_1::Real, I_2::Real) = getindex(A, I_1, I_2)
# myfunction(A, I_1::Real, I_2::Real, I_3::Real) = getindex(A, I_1, I_2, I_3)
# myfunction(A, I::Real...) = getindex(A, I...)
# An @nsplat function _cannot_ have any other Cartesian macros in it.
# If you omit the range, it uses 1:CARTESIAN_DIMS.
macro nsplat(itersym, args...)
local rng
if length(args) == 1
rng = 1:CARTESIAN_DIMS
funcexpr = args[1]
elseif length(args) == 2
rangeexpr = args[1]
funcexpr = args[2]
if !isa(rangeexpr, Expr) || rangeexpr.head != :(:) || length(rangeexpr.args) != 2
throw(ArgumentError("first argument must be a from:to expression"))
end
rng = rangeexpr.args[1]:rangeexpr.args[2]
else
throw(ArgumentError("wrong number of arguments"))
end
if isa(funcexpr, Expr) && funcexpr.head == :macrocall && funcexpr.args[1] == symbol("@inline")
funcexpr = Base._inline(funcexpr.args[2])
end
isfuncexpr(funcexpr) || throw(ArgumentError("second argument must be a function expression"))
prototype = funcexpr.args[1]
body = funcexpr.args[2]
varname, T = get_splatinfo(prototype, itersym)
isempty(varname) && throw(ArgumentError("last argument must be a splat"))
explicit = [Expr(:function, resolvesplat!(copy(prototype), varname, T, N),
resolvesplats!(copy(body), varname, N)) for N in rng]
protosplat = resolvesplat!(copy(prototype), varname, T, 0)
protosplat.args[end] = Expr(:..., protosplat.args[end])
splat = Expr(:function, protosplat, body)
esc(Expr(:block, explicit..., splat))
end

generate1(itersym, prototype, bodyfunc, N::Int, varname, T) =
Expr(:function, spliceint!(sreplace!(resolvesplat!(copy(prototype), varname, T, N), itersym, N)),
resolvesplats!(bodyfunc(N), varname, N))

function ngenerate(itersym, returntypeexpr, prototype, bodyfunc, dims=1:CARTESIAN_DIMS, makecached::Bool = true)
varname, T = get_splatinfo(prototype, itersym)
# Generate versions for specific dimensions
fdim = [generate1(itersym, prototype, bodyfunc, N, varname, T) for N in dims]
if !makecached
return Expr(:block, fdim...)
end
# Generate the generic cache-based version
if isempty(varname)
setitersym, extractvarargs = :(), N -> nothing
else
s = symbol(varname)
setitersym = hasparameter(prototype, itersym) ? (:(@assert $itersym == length($s))) : (:($itersym = length($s)))
extractvarargs = N -> Expr(:block, map(popescape, _nextract(N, s, s).args)...)
end
fsym = funcsym(prototype)
dictname = symbol(fsym,"_cache")
fargs = funcargs(prototype)
if !isempty(varname)
fargs[end] = Expr(:..., fargs[end].args[1])
end
flocal = funcrename(copy(prototype), :_F_)
F = Expr(:function, resolvesplat!(prototype, varname, T), quote
$setitersym
if !haskey($dictname, $itersym)
gen1 = Base.Cartesian.generate1($(symbol(itersym)), $(Expr(:quote, flocal)), $bodyfunc, $itersym, $varname, $T)
$(dictname)[$itersym] = eval(quote
local _F_
$gen1
_F_
end)
end
($(dictname)[$itersym]($(fargs...)))::$returntypeexpr
end)
Expr(:block, fdim..., quote
let $dictname = Dict{Int,Function}()
$F
end
end)
end

isfuncexpr(ex::Expr) =
ex.head == :function || (ex.head == :(=) && typeof(ex.args[1]) == Expr && ex.args[1].head == :call)
isfuncexpr(arg) = false

sreplace!(arg, sym, val) = arg
function sreplace!(ex::Expr, sym, val)
for i = 1:length(ex.args)
ex.args[i] = sreplace!(ex.args[i], sym, val)
end
ex
end
sreplace!(s::Symbol, sym, val) = s == sym ? val : s

# If using the syntax that will need "desplatting",
# myfunction(A::AbstractArray, I::NTuple{N, Int}...)
# return the variable name (as a string) and type
function get_splatinfo(ex::Expr, itersym::Symbol)
if ex.head == :call
a = ex.args[end]
if isa(a, Expr) && a.head == :... && length(a.args) == 1
b = a.args[1]
if isa(b, Expr) && b.head == :(::)
varname = string(b.args[1])
c = b.args[2]
if isa(c, Expr) && c.head == :curly && c.args[1] == :NTuple && c.args[2] == itersym
T = c.args[3]
return varname, T
end
end
end
end
"", Void
end

# Replace splatted with desplatted for a specific number of arguments
function resolvesplat!(prototype, varname, T::Union(Type,Symbol,Expr), N::Int)
if !isempty(varname)
prototype.args[end] = N > 0 ? Expr(:(::), symbol(varname, "_1"), T) :
Expr(:(::), symbol(varname), T)
for i = 2:N
push!(prototype.args, Expr(:(::), symbol(varname, "_", i), T))
end
end
prototype
end

# Return the generic splatting form, e.g.,
# myfunction(A::AbstractArray, I::Int...)
function resolvesplat!(prototype, varname, T::Union(Type,Symbol,Expr))
if !isempty(varname)
svarname = symbol(varname)
prototype.args[end] = Expr(:..., :($svarname::$T))
end
prototype
end

# Desplatting function calls: replace func(a, b, I...) with func(a, b, I_1, I_2, I_3)
resolvesplats!(arg, varname, N) = arg
function resolvesplats!(ex::Expr, varname, N::Int)
if ex.head == :call
for i = 2:length(ex.args)-1
resolvesplats!(ex.args[i], varname, N)
end
a = ex.args[end]
if isa(a, Expr) && a.head == :... && a.args[1] == symbol(varname)
ex.args[end] = symbol(varname, "_1")
for i = 2:N
push!(ex.args, symbol(varname, "_", i))
end
else
resolvesplats!(a, varname, N)
end
else
for i = 1:length(ex.args)
resolvesplats!(ex.args[i], varname, N)
end
end
ex
end

# Remove any function parameters that are integers
function spliceint!(ex::Expr)
if ex.head == :escape
return esc(spliceint!(ex.args[1]))
end
ex.head == :call || throw(ArgumentError("$ex must be a call"))
if isa(ex.args[1], Expr) && ex.args[1].head == :curly
args = ex.args[1].args
for i = length(args):-1:1
if isa(args[i], Int)
deleteat!(args, i)
end
end
end
ex
end

function popescape(ex::Expr)
while ex.head == :escape
ex = ex.args[1]
end
ex
end

# Extract the "function name"
function funcsym(prototype::Expr)
prototype = popescape(prototype)
prototype.head == :call || throw(ArgumentError("$prototype must be a call"))
tmp = prototype.args[1]
if isa(tmp, Expr) && tmp.head == :curly
tmp = tmp.args[1]
end
return tmp
end

function funcrename(prototype::Expr, name::Symbol)
prototype = popescape(prototype)
prototype.head == :call || throw(ArgumentError("$prototype must be a call"))
tmp = prototype.args[1]
if isa(tmp, Expr) && tmp.head == :curly
tmp.args[1] = name
else
prototype.args[1] = name
end
return prototype
end

function hasparameter(prototype::Expr, sym::Symbol)
prototype = popescape(prototype)
prototype.head == :call || throw(ArgumentError("$prototype must be a call"))
tmp = prototype.args[1]
if isa(tmp, Expr) && tmp.head == :curly
for i = 2:length(tmp.args)
if tmp.args[i] == sym
return true
end
end
end
false
end

# Extract the symbols of the function arguments
funcarg(s::Symbol) = s
funcarg(ex::Expr) = ex.args[1]
function funcargs(prototype::Expr)
prototype = popescape(prototype)
prototype.head == :call || throw(ArgumentError("$prototype must be a call"))
map(a->funcarg(a), prototype.args[2:end])
end
export @nloops, @nref, @ncall, @nexprs, @nextract, @nall, @ntuple, @nif

### Cartesian-specific macros

Expand Down
Loading