diff --git a/src/parallel-ir.jl b/src/parallel-ir.jl index 7adf740..fd8dfeb 100644 --- a/src/parallel-ir.jl +++ b/src/parallel-ir.jl @@ -3245,41 +3245,63 @@ function rm_allocs_cb(ast::Expr, state::rm_allocs_state, top_level_number, is_to state.removed_arrs[arr] = shape return ast elseif head==:call - if args[1]==TopNode(:arraysize) && in(args[2], keys(state.removed_arrs)) - shape = state.removed_arrs[args[2]] - return shape[args[3]] - elseif args[1]==GlobalRef(Base,:arraylen) && in(args[2], keys(state.removed_arrs)) - shape = state.removed_arrs[args[2]] - dim = length(shape) - dprintln(3, "arraylen found") - if dim==1 - ast = shape[1] - else - mul = foldl((a,b)->"$a*$b", "", shape) - ast = eval(parse(mul)) - end - return ast - elseif args[1]==TopNode(:unsafe_arrayref) && in(args[2], keys(state.removed_arrs)) - return 0 + if length(args)>=2 + return rm_allocs_cb_call(state, args[1], args[2], args[3:end]) end + # remove extra arrays from parfor data structures elseif head==:parfor - parfor = ast.args[1] - if in(parfor.first_input, keys(state.removed_arrs)) - #TODO parfor.first_input = NoArrayInput + rm_allocs_cb_parfor(state, args[1]) + end + return CompilerTools.AstWalker.ASTWALK_RECURSE +end + +function rm_allocs_cb_call(state::rm_allocs_state, func::TopNode, arr::SymAllGen, rest_args::Array{Any,1}) + if func.name==:arraysize && in(arr, keys(state.removed_arrs)) + shape = state.removed_arrs[arr] + return shape[rest_args[1]] + elseif func.name==:unsafe_arrayref && in(arr, keys(state.removed_arrs)) + return 0 + end + return CompilerTools.AstWalker.ASTWALK_RECURSE +end + +function rm_allocs_cb_call(state::rm_allocs_state, func::GlobalRef, arr::SymAllGen, rest_args::Array{Any,1}) + if func==GlobalRef(Base,:arraylen) && in(arr, keys(state.removed_arrs)) + shape = state.removed_arrs[arr] + dim = length(shape) + dprintln(3, "arraylen found") + if dim==1 + ast = shape[1] + else + mul = foldl((a,b)->"$a*$b", "", shape) + ast = eval(parse(mul)) end - for arr in keys(parfor.rws.readSet.arrays) - if in(arr, keys(state.removed_arrs)) - delete!(parfor.rws.readSet.arrays, arr) - end + return ast + end + return CompilerTools.AstWalker.ASTWALK_RECURSE +end + +function rm_allocs_cb_call(state::rm_allocs_state, func::ANY, arr::ANY, rest_args::Array{Any,1}) + return CompilerTools.AstWalker.ASTWALK_RECURSE +end + + + +function rm_allocs_cb_parfor(state::rm_allocs_state, parfor::PIRParForAst) + if in(parfor.first_input, keys(state.removed_arrs)) + #TODO parfor.first_input = NoArrayInput + end + for arr in keys(parfor.rws.readSet.arrays) + if in(arr, keys(state.removed_arrs)) + delete!(parfor.rws.readSet.arrays, arr) end - for arr in keys(parfor.rws.writeSet.arrays) - if in(arr, keys(state.removed_arrs)) - delete!(parfor.rws.writeSet.arrays, arr) - end + end + for arr in keys(parfor.rws.writeSet.arrays) + if in(arr, keys(state.removed_arrs)) + delete!(parfor.rws.writeSet.arrays, arr) end end - return CompilerTools.AstWalker.ASTWALK_RECURSE end function updateLambdaType(arr::Symbol, dim::Int, lambdaInfo) diff --git a/test/rand.jl b/test/rand.jl index 8a3b041..d90c054 100644 --- a/test/rand.jl +++ b/test/rand.jl @@ -53,6 +53,7 @@ end end +using Base.Test println("Testing rand()...") @test all(RandTest.test1() .<= ones(2,3).*2.0) && all(RandTest.test1() .>= zeros(2,3)) @test size(RandTest.test2())==(2,3)