Skip to content
This repository has been archived by the owner on Nov 8, 2022. It is now read-only.

Commit

Permalink
DomainIR keep reduction opr info using fake lambda
Browse files Browse the repository at this point in the history
Create a fake lambda for reduction functions so that the backend can
replace with high performance implementation (e.g. std::min, OpenMP reductions
in CGen). The required information is lost with capturing and inlining the
operation lambda.
  • Loading branch information
Ehsan Totoni committed Dec 7, 2016
1 parent cbbd611 commit 38b121a
Showing 1 changed file with 27 additions and 9 deletions.
36 changes: 27 additions & 9 deletions src/domain-ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1025,8 +1025,8 @@ mk_range(state, start, step, final) = mk_range(simplify(state, start), simplify(
function from_lambda(state, env, expr, closure = nothing)
local env_ = nextEnv(env)
linfo, body = lambdaToLambdaVarInfo(expr)
cfg = CompilerTools.CFGs.from_lambda(body)
body = getBody(CompilerTools.CFGs.createFunctionBody(cfg), getReturnType(linfo))
cfg = CompilerTools.CFGs.from_lambda(body)
body = getBody(CompilerTools.CFGs.createFunctionBody(cfg), getReturnType(linfo))
@dprintln(2,"from_lambda typeof(body) = ", typeof(body))
@dprintln(3,"expr = ", expr)

Expand Down Expand Up @@ -1459,7 +1459,7 @@ function translate_call_copy!(state, env, args)
if isArrayType(argtyp1) && isArrayType(argtyp2)
eltyp1 = eltype(argtyp1)
eltyp2 = eltype(argtyp2)
if eltyp1 == eltyp2
if eltyp1 == eltyp2
if idx_to == nothing
expr = mk_mmap!(args, DomainLambda(Type[eltyp1,eltyp2], Type[eltyp1], params->Any[Expr(:tuple, params[2])], state.linfo))
else # range copy
Expand Down Expand Up @@ -1760,7 +1760,7 @@ function translate_call_getsetindex(state, env, typ, fun::Symbol, args::Array{An
lhsname = CompilerTools.LambdaHandling.getVarDef(lhs, linfo).name
params = CompilerTools.LambdaHandling.getInputParameters(linfo)
CompilerTools.LambdaHandling.setInputParameters(Symbol[lhsname, params[2]], linfo)
body.args = [ mk_expr(Type{etyp}, :(=), lookupLHSVarByName(params[1], linfo), etyp);
body.args = [ mk_expr(Type{etyp}, :(=), lookupLHSVarByName(params[1], linfo), etyp);
body.args...]
f = DomainLambda(linfo, body)
else # set to scalar value
Expand All @@ -1770,15 +1770,15 @@ function translate_call_getsetindex(state, env, typ, fun::Symbol, args::Array{An
lhs = addFreshLocalVariable(string("ignored"), etyp, 0, linfo)
lhsname = CompilerTools.LambdaHandling.getVarDef(lhs, linfo).name
params = CompilerTools.LambdaHandling.getInputParametersAsLHSVar(linfo)
if isa(var, RHSVar)
if isa(var, RHSVar)
var = makeCaptured(state, var)
var = toLHSVar(var)
rhs = addToEscapingVariable(var, linfo, state.linfo)
else
rhs = var
end
CompilerTools.LambdaHandling.setInputParameters(Symbol[lhsname], linfo)
body.args = [ mk_expr(Type{etyp}, :(=), params[1], etyp);
body.args = [ mk_expr(Type{etyp}, :(=), params[1], etyp);
mk_expr(vtyp, :(=), params[2], rhs);
body.args...]
f = DomainLambda(linfo, body)
Expand Down Expand Up @@ -2404,7 +2404,8 @@ function translate_call_reduceop(state, env, typ, fun::Symbol, args::Array{Any,1
setInputParameters(params, linfo)
params = [ toRHSVar(x, outtyp, linfo) for x in params ]
setReturnType(outtyp, linfo)
(inner_body, inner_linfo) = get_lambda_for_arg(state, env, opr, [etyp, etyp])
#(inner_body, inner_linfo) = get_lambda_for_arg(state, env, opr, [etyp, etyp])
(inner_body, inner_linfo) = make_fake_reduce_lambda(state, env, opr, etyp)
inner_dl = DomainLambda(inner_linfo, inner_body)
# inner_dl = DomainLambda(Type[etyp, etyp], Type[etyp], params->Any[Expr(:tuple, box_ty(etyp, Expr(:call, opr, params...)))], LambdaVarInfo())
inner_expr = mk_mmap!(params, inner_dl)
Expand All @@ -2415,7 +2416,8 @@ function translate_call_reduceop(state, env, typ, fun::Symbol, args::Array{Any,1
neutral = neutralelt
outtyp = etyp
opr = GlobalRef(Base, fun)
(inner_body, inner_linfo) = get_lambda_for_arg(state, env, opr, [etyp, etyp])
#(inner_body, inner_linfo) = get_lambda_for_arg(state, env, opr, [etyp, etyp])
(inner_body, inner_linfo) = make_fake_reduce_lambda(state, env, opr, etyp)
f = DomainLambda(inner_linfo, inner_body)
end
# turn reduce(z, getindex(a, ...), f) into reduce(z, select(a, ranges(...)), f)
Expand All @@ -2425,6 +2427,22 @@ function translate_call_reduceop(state, env, typ, fun::Symbol, args::Array{Any,1
return expr
end

"""
Create a fake lambda for reduction functions so that the backend can replace
with high performance implementation (e.g. std::min, OpenMP reductions in CGen).
The required information is lost with capturing and inlining the operation lambda.
"""
function make_fake_reduce_lambda(state, env, opr::GlobalRef, etyp::Type)
linfo = LambdaVarInfo()
setInputParameters([:x,:y], linfo)
setReturnType(etyp, linfo)
addLocalVariable(:x,etyp,0,linfo)
addLocalVariable(:y,etyp,0,linfo)
body = mk_expr(etyp,:body, mk_expr(etyp, :tuple,
mk_expr(etyp,:call, opr, toLHSVar(:x,linfo), toLHSVar(:y,linfo))))
return (body, linfo)
end

function translate_call_fill!(state, env, typ, args::Array{Any,1})
args = normalize_args(state, env, args)
@assert length(args)==2 "fill! should have 2 arguments"
Expand All @@ -2446,7 +2464,7 @@ function translate_call_fill!(state, env, typ, args::Array{Any,1})
rhs = ival
end
CompilerTools.LambdaHandling.setInputParameters(Symbol[lhsname], linfo)
body.args = [ mk_expr(Type{etyp}, :(=), params[1], etyp);
body.args = [ mk_expr(Type{etyp}, :(=), params[1], etyp);
mk_expr(ityp, :(=), params[2], rhs);
body.args...]
domF = DomainLambda(linfo, body)
Expand Down

0 comments on commit 38b121a

Please sign in to comment.