Skip to content
This repository has been archived by the owner on Nov 8, 2022. It is now read-only.

Commit

Permalink
Pattern match on broadcast! calls with first argument as Base.identit…
Browse files Browse the repository at this point in the history
…y. If found and the arguments to broadcast have correlations that are identical then convert to a DomainNode for mmap! equivalent to copy! and then process that new domain node normally.
  • Loading branch information
DrTodd13 committed May 27, 2017
1 parent d273447 commit f8298c1
Showing 1 changed file with 30 additions and 4 deletions.
34 changes: 30 additions & 4 deletions src/parallel-ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1122,7 +1122,7 @@ Process a :lambda Expr.
function from_lambda(lambda :: Expr, depth, state)
# :lambda expression
assert(lambda.head == :lambda)
@dprintln(4,"from_lambda starting")
@dprintln(3,"from_lambda starting")

# Save the current LambdaVarInfo away so we can restore it later.
save_LambdaVarInfo = state.LambdaVarInfo
Expand Down Expand Up @@ -2654,13 +2654,39 @@ end
"""
Process a call AST node. Note that it takes an Expr as input because it can be either :call or :invoke.
"""
function from_call(ast::Expr, depth, state)
function from_call(head, ast::Expr, depth, state)
fun = getCallFunction(ast)
args = getCallArguments(ast)
@dprintln(2,"from_call fun = ", fun, " typeof fun = ", typeof(fun))
if length(args) > 0
@dprintln(2,"first arg = ",args[1], " type = ", typeof(args[1]))
end

if isBaseFunc(fun, :broadcast!) && args[1] == GlobalRef(Base, :identity)
@dprintln(3,"Detected call to broadcast! with identity argument.")
if length(args) == 3
arr1 = args[2]
arr2 = args[3]
arrtyp1 = CompilerTools.LambdaHandling.getType(arr1, state.LambdaVarInfo)
arrtyp2 = CompilerTools.LambdaHandling.getType(arr2, state.LambdaVarInfo)
eltyp1 = eltype(arrtyp1)
eltyp2 = eltype(arrtyp2)
@dprintln(3,"arr1 = ", arr1, " arr2 = ", arr2, " hasCorrelation1 = ", haskey(state.array_length_correlation, arr1), " hasCorrelation2 = ", haskey(state.array_length_correlation, arr2))
if haskey(state.array_length_correlation, arr1) &&
haskey(state.array_length_correlation, arr2) &&
state.array_length_correlation[arr1] == state.array_length_correlation[arr2]
@dprintln(3,"Arrays are equivalent in length. Switch to copy! here.")
new_domain_expr = ParallelAccelerator.DomainIR.mk_mmap!(args[2:3], ParallelAccelerator.DomainIR.DomainLambda(Type[eltyp1,eltyp2], Type[eltyp1], params->Any[Expr(:tuple, params[2])], state.LambdaVarInfo))
@dprintln(3,"New mmap! = ", new_domain_expr)

head = :parfor
domain_oprs = [DomainOperation(:mmap!, args)]
args = mk_parfor_args_from_mmap!(new_domain_expr.args[1], new_domain_expr.args[2], false, domain_oprs, state)
return head, args
end
end
end

# We don't need to translate Function Symbols but potentially other call targets we do.
if typeof(fun) != Symbol
fun = from_expr(fun, depth, state, false)
Expand All @@ -2671,7 +2697,7 @@ function from_call(ast::Expr, depth, state)
# Recursively process the arguments to the call.
args = from_exprs(args, depth+1, state)

return ast.head == :invoke ? [ast.args[1]; fun; args] : [fun; args]
return head, (ast.head == :invoke ? [ast.args[1]; fun; args] : [fun; args])
end

"""
Expand Down Expand Up @@ -3696,7 +3722,7 @@ function from_expr(ast ::Expr, depth, state :: expr_state, top_level)
elseif head == :return
args = from_exprs(args, depth, state)
elseif head == :invoke || head == :call || head == :call1
args = from_call(ast, depth, state)
head, args = from_call(head, ast, depth, state)
# TODO: catch domain IR result here
elseif head == :foreigncall
args = from_foreigncall(ast, depth, state)
Expand Down

0 comments on commit f8298c1

Please sign in to comment.