Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.sysds.hops.AggUnaryOp;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.ReorgOp;
Expand Down Expand Up @@ -84,7 +85,7 @@ public boolean isApplicable(FunctionCallGraph fgraph) {
*/
@Override
public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) {
rewriteStatementBlocks(prog.getStatementBlocks());
rewriteStatementBlocks(prog, prog.getStatementBlocks());
return false;
}

Expand All @@ -93,13 +94,14 @@ public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, Functio
* by setting the federated output value of each hop in the statement blocks.
* The method calls the contained statement blocks recursively.
*
* @param prog dml program
* @param sbs list of statement blocks
* @return list of statement blocks with the federated output value updated for each hop
*/
public ArrayList<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs) {
public ArrayList<StatementBlock> rewriteStatementBlocks(DMLProgram prog, List<StatementBlock> sbs) {
ArrayList<StatementBlock> rewrittenStmBlocks = new ArrayList<>();
for ( StatementBlock stmBlock : sbs )
rewrittenStmBlocks.addAll(rewriteStatementBlock(stmBlock));
rewrittenStmBlocks.addAll(rewriteStatementBlock(prog, stmBlock));
return rewrittenStmBlocks;
}

Expand All @@ -108,66 +110,80 @@ public ArrayList<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs
* by setting the federated output value of each hop in the statement blocks.
* The method calls the contained statement blocks recursively.
*
* @param prog dml program
* @param sb statement block
* @return list of statement blocks with the federated output value updated for each hop
*/
public ArrayList<StatementBlock> rewriteStatementBlock(StatementBlock sb) {
public ArrayList<StatementBlock> rewriteStatementBlock(DMLProgram prog, StatementBlock sb) {
if ( sb instanceof WhileStatementBlock)
return rewriteWhileStatementBlock((WhileStatementBlock) sb);
return rewriteWhileStatementBlock(prog, (WhileStatementBlock) sb);
else if ( sb instanceof IfStatementBlock)
return rewriteIfStatementBlock((IfStatementBlock) sb);
return rewriteIfStatementBlock(prog, (IfStatementBlock) sb);
else if ( sb instanceof ForStatementBlock){
// This also includes ParForStatementBlocks
return rewriteForStatementBlock((ForStatementBlock) sb);
return rewriteForStatementBlock(prog, (ForStatementBlock) sb);
}
else if ( sb instanceof FunctionStatementBlock)
return rewriteFunctionStatementBlock((FunctionStatementBlock) sb);
return rewriteFunctionStatementBlock(prog, (FunctionStatementBlock) sb);
else {
// StatementBlock type (no subclass)
selectFederatedExecutionPlan(sb.getHops());
return rewriteDefaultStatementBlock(prog, sb);
}
return new ArrayList<>(Collections.singletonList(sb));
}

private ArrayList<StatementBlock> rewriteWhileStatementBlock(WhileStatementBlock whileSB){
private ArrayList<StatementBlock> rewriteWhileStatementBlock(DMLProgram prog, WhileStatementBlock whileSB){
Hop whilePredicateHop = whileSB.getPredicateHops();
selectFederatedExecutionPlan(whilePredicateHop);
for ( Statement stm : whileSB.getStatements() ){
WhileStatement whileStm = (WhileStatement) stm;
whileStm.setBody(rewriteStatementBlocks(whileStm.getBody()));
whileStm.setBody(rewriteStatementBlocks(prog, whileStm.getBody()));
}
return new ArrayList<>(Collections.singletonList(whileSB));
}

private ArrayList<StatementBlock> rewriteIfStatementBlock(IfStatementBlock ifSB){
private ArrayList<StatementBlock> rewriteIfStatementBlock(DMLProgram prog, IfStatementBlock ifSB){
selectFederatedExecutionPlan(ifSB.getPredicateHops());
for ( Statement statement : ifSB.getStatements() ){
IfStatement ifStatement = (IfStatement) statement;
ifStatement.setIfBody(rewriteStatementBlocks(ifStatement.getIfBody()));
ifStatement.setElseBody(rewriteStatementBlocks(ifStatement.getElseBody()));
ifStatement.setIfBody(rewriteStatementBlocks(prog, ifStatement.getIfBody()));
ifStatement.setElseBody(rewriteStatementBlocks(prog, ifStatement.getElseBody()));
}
return new ArrayList<>(Collections.singletonList(ifSB));
}

private ArrayList<StatementBlock> rewriteForStatementBlock(ForStatementBlock forSB){
private ArrayList<StatementBlock> rewriteForStatementBlock(DMLProgram prog, ForStatementBlock forSB){
selectFederatedExecutionPlan(forSB.getFromHops());
selectFederatedExecutionPlan(forSB.getToHops());
selectFederatedExecutionPlan(forSB.getIncrementHops());
for ( Statement statement : forSB.getStatements() ){
ForStatement forStatement = ((ForStatement)statement);
forStatement.setBody(rewriteStatementBlocks(forStatement.getBody()));
forStatement.setBody(rewriteStatementBlocks(prog, forStatement.getBody()));
}
return new ArrayList<>(Collections.singletonList(forSB));
}

private ArrayList<StatementBlock> rewriteFunctionStatementBlock(FunctionStatementBlock funcSB){
private ArrayList<StatementBlock> rewriteFunctionStatementBlock(DMLProgram prog, FunctionStatementBlock funcSB){
for ( Statement statement : funcSB.getStatements() ){
FunctionStatement funcStm = (FunctionStatement) statement;
funcStm.setBody(rewriteStatementBlocks(funcStm.getBody()));
funcStm.setBody(rewriteStatementBlocks(prog, funcStm.getBody()));
}
return new ArrayList<>(Collections.singletonList(funcSB));
}

private ArrayList<StatementBlock> rewriteDefaultStatementBlock(DMLProgram prog, StatementBlock sb){
if ( sb.getHops() != null && !sb.getHops().isEmpty() ){
for ( Hop sbHop : sb.getHops() ){
if ( sbHop instanceof FunctionOp ){
String funcName = ((FunctionOp) sbHop).getFunctionName();
FunctionStatementBlock sbFuncBlock = prog.getBuiltinFunctionDictionary().getFunction(funcName);
rewriteStatementBlock(prog, sbFuncBlock);
}
else selectFederatedExecutionPlan(sbHop);
}
}
return new ArrayList<>(Collections.singletonList(sb));
}

/**
* Sets FederatedOutput field of all hops in DAG starting from given root.
* The FederatedOutput chosen for root is the minimum cost HopRel found in memo table for the given root.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,16 @@ public class RewriteFederatedExecution extends HopRewriteRule {

@Override
public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
if ( roots == null )
return null;
for ( Hop root : roots )
visitHop(root);
if ( roots != null )
for ( Hop root : roots )
rewriteHopDAG(root, state);
return roots;
}

@Override public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
return null;
if ( root != null )
visitHop(root);
return root;
}

private void visitHop(Hop hop){
Expand Down
Loading