Skip to content

Commit

Permalink
[SYSTEMDS-3018] Conflict handling federated plan enumeration
Browse files Browse the repository at this point in the history
Federated plan enumeration builds a global data flow graph and computes
optimal plans per interesting property (fed-out, local-out). In trees,
we could purely compose optimal plans from optimal plans of inputs, but
in DAGs optimal input plans of n-ary operations might not agree on the
decisions of common subexpressions.

We mitigate this issue of conflicting (fed-out vs local-out) decisions by keeping the
data federated, but additionally spawning an asynchronous prefetch
operation to also bring the data into local memory if at least one
subplan prefers local intermediates.

Closes #1476.

Co-authored-by: arnabp <arnab.phani@tugraz.at>
  • Loading branch information
2 people authored and mboehm7 committed Dec 27, 2021
1 parent bd68831 commit ccd6a36
Show file tree
Hide file tree
Showing 19 changed files with 445 additions and 71 deletions.
2 changes: 1 addition & 1 deletion scripts/builtin/normalize.dml
Expand Up @@ -39,6 +39,6 @@ m_normalize = function(Matrix[Double] X)
# compute feature ranges for transformations
cmin = colMins(X);
cmax = colMaxs(X);
# normalize features to range [0,1]
# normalize features to range [0,1]
Y = normalizeApply(X, cmin, cmax);
}
32 changes: 30 additions & 2 deletions src/main/java/org/apache/sysds/hops/Hop.java
Expand Up @@ -93,6 +93,14 @@ public abstract class Hop implements ParseInfo {
*/
protected FederatedOutput _federatedOutput = FederatedOutput.NONE;
protected FederatedCost _federatedCost = new FederatedCost();

/**
* Field defining if prefetch should be activated for operation.
* When prefetch is activated, the output will be transferred from
* remote federated sites to local before one of the subsequent
* local operations.
*/
protected boolean activatePrefetch;

// Estimated size for the output produced from this Hop in bytes
protected double _outputMemEstimate = OptimizerUtils.INVALID_SIZE;
Expand Down Expand Up @@ -187,6 +195,21 @@ public void setExecType(ExecType execType){
/**
 * Assigns the federated output value of this hop.
 * @param federatedOutput federated output value stored in this hop
 */
public void setFederatedOutput(FederatedOutput federatedOutput){
	this._federatedOutput = federatedOutput;
}

/**
 * Turns on the prefetch flag of this hop. This method only sets the
 * flag to true; it does not provide a way to turn it off again.
 */
public void activatePrefetch(){
	this.activatePrefetch = true;
}

/**
 * Reports the current state of the prefetch flag of this hop.
 * @return value of the prefetch flag, which is true once prefetch has been activated
 */
public boolean prefetchActivated(){
	return this.activatePrefetch;
}

public void resetExecType()
{
Expand Down Expand Up @@ -352,6 +375,8 @@ public void constructAndSetLopsDataFlowProperties() {
//propagate federated output configuration to lops
if( isFederated() )
getLops().setFederatedOutput(_federatedOutput);
if ( prefetchActivated() )
getLops().activatePrefetch();

//Step 1: construct reblock lop if required (output of hop)
constructAndSetReblockLopIfRequired();
Expand Down Expand Up @@ -869,8 +894,11 @@ protected ExecType findExecTypeByMemEstimate() {
* This method only has an effect if FEDERATED_COMPILATION is activated.
* Federated compilation is activated in OptimizerUtils.
*/
protected void updateETFed(){
if ( someInputFederated() || isFederatedDataOp() )
protected void updateETFed() {
	// determined once up front because every per-input check below depends on it
	final boolean producesLocalOutput = hasLocalOutput();

	// an input counts as federated if it has federated output, unless that input
	// is prefetched and this hop itself has local output
	boolean hasFederatedInput = false;
	for( Hop input : getInput() ) {
		if( input.hasFederatedOutput() && (!input.prefetchActivated() || !producesLocalOutput) ) {
			hasFederatedInput = true;
			break; // first match is sufficient
		}
	}

	if( isFederatedDataOp() || hasFederatedInput )
		_etype = ExecType.FED;
}

Expand Down
Expand Up @@ -20,6 +20,7 @@
package org.apache.sysds.hops.cost;

import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.ipa.MemoTable;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
Expand All @@ -33,8 +34,6 @@
import org.apache.sysds.parser.WhileStatementBlock;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
* Cost estimator for federated executions with methods and constants for going through DML programs to estimate costs.
Expand Down Expand Up @@ -200,10 +199,9 @@ public FederatedCost costEstimate(Hop root){
* @param hopRelMemo memo table of HopRels for calculating input costs
* @return cost estimation of Hop DAG starting from given root HopRel
*/
public FederatedCost costEstimate(HopRel root, Map<Long, List<HopRel>> hopRelMemo){
public FederatedCost costEstimate(HopRel root, MemoTable hopRelMemo){
// Check if root is in memo table.
if ( hopRelMemo.containsKey(root.hopRef.getHopID())
&& hopRelMemo.get(root.hopRef.getHopID()).stream().anyMatch(h -> h.fedOut == root.fedOut) ){
if ( hopRelMemo.containsHopRel(root) ){
return root.getCostObject();
}
else {
Expand Down
27 changes: 7 additions & 20 deletions src/main/java/org/apache/sysds/hops/cost/HopRel.java
Expand Up @@ -21,15 +21,14 @@

import org.apache.sysds.api.DMLException;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.ipa.MemoTable;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

Expand All @@ -52,7 +51,7 @@ public class HopRel {
* @param fedOut FederatedOutput value assigned to this HopRel
* @param hopRelMemo memo table storing other HopRels including the inputs of associatedHop
*/
public HopRel(Hop associatedHop, FEDInstruction.FederatedOutput fedOut, Map<Long, List<HopRel>> hopRelMemo){
public HopRel(Hop associatedHop, FEDInstruction.FederatedOutput fedOut, MemoTable hopRelMemo){
hopRef = associatedHop;
this.fedOut = fedOut;
setInputDependency(hopRelMemo);
Expand Down Expand Up @@ -108,27 +107,15 @@ public Hop getHopRef(){
* @param hopRelMemo memo table storing HopRels
* @return FOUT HopRel found in hopRelMemo
*/
private HopRel getFOUTHopRel(Hop hop, Map<Long, List<HopRel>> hopRelMemo){
return hopRelMemo.get(hop.getHopID()).stream().filter(in->in.fedOut==FederatedOutput.FOUT).findFirst().orElse(null);
}

/**
* Get the HopRel with minimum cost for given hop
* @param hopRelMemo memo table storing HopRels
* @param input hop for which minimum cost HopRel is found
* @return HopRel with minimum cost for given hop
*/
private HopRel getMinOfInput(Map<Long, List<HopRel>> hopRelMemo, Hop input){
return hopRelMemo.get(input.getHopID()).stream()
.min(Comparator.comparingDouble(a -> a.cost.getTotal()))
.orElseThrow(() -> new DMLException("No element in Memo Table found for input"));
private HopRel getFOUTHopRel(Hop hop, MemoTable hopRelMemo){
	// lookup is delegated to the memo table; presumably null when no FOUT
	// alternative is stored for the hop (as the "OrNull" name suggests)
	final HopRel foutAlternative = hopRelMemo.getFederatedOutputAlternativeOrNull(hop);
	return foutAlternative;
}

/**
* Set valid and optimal input dependency for this HopRel as a field.
* @param hopRelMemo memo table storing input HopRels
*/
private void setInputDependency(Map<Long, List<HopRel>> hopRelMemo){
private void setInputDependency(MemoTable hopRelMemo){
if (hopRef.getInput() != null && hopRef.getInput().size() > 0) {
if ( fedOut == FederatedOutput.FOUT && !hopRef.isFederatedDataOp() ) {
int lowestFOUTIndex = 0;
Expand All @@ -152,7 +139,7 @@ else if(foutHopRel != null) {
for(int i = 0; i < hopRef.getInput().size(); i++) {
if(i != lowestFOUTIndex) {
Hop input = hopRef.getInput(i);
inputHopRels[i] = getMinOfInput(hopRelMemo, input);
inputHopRels[i] = hopRelMemo.getMinCostAlternative(input);
}
else {
inputHopRels[i] = lowestFOUTHopRel;
Expand All @@ -162,7 +149,7 @@ else if(foutHopRel != null) {
} else {
inputDependency.addAll(
hopRef.getInput().stream()
.map(input -> getMinOfInput(hopRelMemo, input))
.map(hopRelMemo::getMinCostAlternative)
.collect(Collectors.toList()));
}
}
Expand Down
Expand Up @@ -19,7 +19,6 @@

package org.apache.sysds.hops.ipa;

import org.apache.sysds.api.DMLException;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.AggUnaryOp;
import org.apache.sysds.hops.BinaryOp;
Expand All @@ -45,10 +44,9 @@

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
* This rewrite generates a federated execution plan by estimating and setting costs and the FederatedOutput values of
Expand All @@ -57,7 +55,8 @@
*/
public class IPAPassRewriteFederatedPlan extends IPAPass {

private final static Map<Long, List<HopRel>> hopRelMemo = new HashMap<>();
private final static MemoTable hopRelMemo = new MemoTable();
private final static Set<Long> hopRelUpdatedFinal = new HashSet<>();

/**
* Indicates if an IPA pass is applicable for the current configuration.
Expand All @@ -66,7 +65,8 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
* @param fgraph function call graph
* @return true if federated compilation is activated.
*/
@Override public boolean isApplicable(FunctionCallGraph fgraph) {
@Override
public boolean isApplicable(FunctionCallGraph fgraph) {
	// decided solely by the federated compilation flag; fgraph is not consulted
	final boolean federatedCompilation = OptimizerUtils.FEDERATED_COMPILATION;
	return federatedCompilation;
}

Expand All @@ -79,7 +79,8 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
* @param fcallSizes function call size infos
* @return false since the function call graph never has to be rebuilt
*/
@Override public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph,
@Override
public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph,
FunctionCallSizeInfo fcallSizes) {
rewriteStatementBlocks(prog, prog.getStatementBlocks());
return false;
Expand Down Expand Up @@ -189,9 +190,7 @@ private ArrayList<StatementBlock> rewriteDefaultStatementBlock(DMLProgram prog,
* @param root hop for which FederatedOutput needs to be set
*/
private void setFinalFedout(Hop root) {
HopRel optimalRootHopRel = hopRelMemo.get(root.getHopID()).stream()
.min(Comparator.comparingDouble(HopRel::getCost))
.orElseThrow(() -> new DMLException("Hop root " + root + " has no feasible federated output alternatives"));
HopRel optimalRootHopRel = hopRelMemo.getMinCostAlternative(root);
setFinalFedout(root, optimalRootHopRel);
}

Expand All @@ -202,8 +201,21 @@ private void setFinalFedout(Hop root) {
* @param rootHopRel from which FederatedOutput value and cost is retrieved
*/
private void setFinalFedout(Hop root, HopRel rootHopRel) {
updateFederatedOutput(root, rootHopRel);
visitInputDependency(rootHopRel);
if ( hopRelUpdatedFinal.contains(root.getHopID()) ){
if((rootHopRel.hasLocalOutput() ^ root.hasLocalOutput()) && hopRelMemo.hasFederatedOutputAlternative(root)){
// Update with FOUT alternative without visiting inputs
updateFederatedOutput(root, hopRelMemo.getFederatedOutputAlternative(root));
root.activatePrefetch();
}
else {
// Update without visiting inputs
updateFederatedOutput(root, rootHopRel);
}
}
else {
updateFederatedOutput(root, rootHopRel);
visitInputDependency(rootHopRel);
}
}

/**
Expand All @@ -226,6 +238,7 @@ private void visitInputDependency(HopRel rootHopRel) {
private void updateFederatedOutput(Hop root, HopRel updateHopRel) {
	// copy the federated output value and the cost object of the given HopRel onto the hop
	root.setFederatedOutput(updateHopRel.getFederatedOutput());
	root.setFederatedCost(updateHopRel.getCostObject());

	// record the hop as updated so a later visit of the same hop can be detected
	final long updatedHopId = root.getHopID();
	hopRelUpdatedFinal.add(updatedHopId);
}

/**
Expand Down Expand Up @@ -257,7 +270,7 @@ private void selectFederatedExecutionPlan(Hop root) {
*/
private void visitFedPlanHop(Hop currentHop) {
// If the currentHop is in the hopRelMemo table, it means that it has been visited
if(hopRelMemo.containsKey(currentHop.getHopID()))
if(hopRelMemo.containsHop(currentHop))
return;
// If the currentHop has input, then the input should be visited depth-first
if(currentHop.getInput() != null && currentHop.getInput().size() > 0) {
Expand All @@ -273,7 +286,7 @@ private void visitFedPlanHop(Hop currentHop) {
}
if(hopRels.isEmpty())
hopRels.add(new HopRel(currentHop, FEDInstruction.FederatedOutput.NONE, hopRelMemo));
hopRelMemo.put(currentHop.getHopID(), hopRels);
hopRelMemo.put(currentHop, hopRels);
}

/**
Expand Down Expand Up @@ -319,8 +332,8 @@ private boolean isFOUTSupported(Hop associatedHop) {
if(associatedHop instanceof AggUnaryOp && associatedHop.isScalar())
return false;
// It can only be FOUT if at least one of the inputs are FOUT, except if it is a federated DataOp
if(associatedHop.getInput().stream().noneMatch(input -> hopRelMemo.get(input.getHopID()).stream()
.anyMatch(HopRel::hasFederatedOutput)) && !associatedHop.isFederatedDataOp())
if(associatedHop.getInput().stream().noneMatch(hopRelMemo::hasFederatedOutputAlternative)
&& !associatedHop.isFederatedDataOp())
return false;
return true;
}
Expand Down

0 comments on commit ccd6a36

Please sign in to comment.