4 changes: 4 additions & 0 deletions src/main/java/org/apache/sysds/conf/ConfigurationManager.java
@@ -247,6 +247,10 @@ public static boolean isBroadcastEnabled() {
return (getDMLConfig().getBooleanValue(DMLConfig.ASYNC_SPARK_BROADCAST)
|| OptimizerUtils.ASYNC_BROADCAST_SPARK);
}
public static boolean isCheckpointEnabled() {
return (getDMLConfig().getBooleanValue(DMLConfig.ASYNC_SPARK_CHECKPOINT)
|| OptimizerUtils.ASYNC_CHECKPOINT_SPARK);
}

public static ILinearize.DagLinearization getLinearizationOrder() {
if (OptimizerUtils.MAX_PARALLELIZE_ORDER)
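The new accessor mirrors the existing prefetch and broadcast switches: asynchronous checkpointing is active when either the sysds.async.checkpoint config entry or the OptimizerUtils.ASYNC_CHECKPOINT_SPARK flag is set. A minimal sketch of how a compiler pass might gate the rewrite on this accessor (the actual call site is not shown in this diff, so the surrounding code is an assumption):

    // Hedged sketch: gate the lop rewrite on the combined config/optimizer flag
    if (ConfigurationManager.isCheckpointEnabled())
        nodes = addAsyncEagerCheckpointLop(nodes); // inserts chkpoint_e lops before Spark actions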
5 changes: 4 additions & 1 deletion src/main/java/org/apache/sysds/conf/DMLConfig.java
@@ -130,6 +130,7 @@ public class DMLConfig
/** Asynchronous triggering of Spark OPs and operator placement **/
public static final String ASYNC_SPARK_PREFETCH = "sysds.async.prefetch"; // boolean: enable asynchronous prefetching of Spark intermediates
public static final String ASYNC_SPARK_BROADCAST = "sysds.async.broadcast"; // boolean: enable asynchronous broadcasting of CP intermediates
public static final String ASYNC_SPARK_CHECKPOINT = "sysds.async.checkpoint"; // boolean: enable asynchronous persisting of Spark intermediates
//internal config
public static final String DEFAULT_SHARED_DIR_PERMISSION = "777"; //for local fs and DFS

@@ -202,6 +203,7 @@ public class DMLConfig
_defaultVals.put(PRIVACY_CONSTRAINT_MOCK, null);
_defaultVals.put(ASYNC_SPARK_PREFETCH, "false" );
_defaultVals.put(ASYNC_SPARK_BROADCAST, "false" );
_defaultVals.put(ASYNC_SPARK_CHECKPOINT, "false" );
}

public DMLConfig() {
@@ -454,7 +456,8 @@ public String getConfigInfo() {
PRINT_GPU_MEMORY_INFO, AVAILABLE_GPUS, SYNCHRONIZE_GPU, EAGER_CUDA_FREE, FLOATING_POINT_PRECISION,
GPU_EVICTION_POLICY, LOCAL_SPARK_NUM_THREADS, EVICTION_SHADOW_BUFFERSIZE, GPU_MEMORY_ALLOCATOR,
GPU_MEMORY_UTILIZATION_FACTOR, USE_SSL_FEDERATED_COMMUNICATION, DEFAULT_FEDERATED_INITIALIZATION_TIMEOUT,
FEDERATED_TIMEOUT, FEDERATED_MONITOR_FREQUENCY, ASYNC_SPARK_PREFETCH, ASYNC_SPARK_BROADCAST
FEDERATED_TIMEOUT, FEDERATED_MONITOR_FREQUENCY, ASYNC_SPARK_PREFETCH, ASYNC_SPARK_BROADCAST,
ASYNC_SPARK_CHECKPOINT
};

StringBuilder sb = new StringBuilder();
1 change: 1 addition & 0 deletions src/main/java/org/apache/sysds/hops/OptimizerUtils.java
@@ -284,6 +284,7 @@ public enum MemoryManager {
*/
public static boolean ASYNC_PREFETCH_SPARK = false;
public static boolean ASYNC_BROADCAST_SPARK = false;
public static boolean ASYNC_CHECKPOINT_SPARK = false;

/**
* Heuristic-based instruction ordering to maximize inter-operator parallelism.
17 changes: 13 additions & 4 deletions src/main/java/org/apache/sysds/lops/Checkpoint.java
@@ -25,6 +25,7 @@
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.instructions.InstructionUtils;

import java.util.Arrays;

/**
* Lop for checkpoint operations. For example, on Spark, the semantics of a checkpoint
@@ -38,13 +39,15 @@
*/
public class Checkpoint extends Lop
{
public static final String OPCODE = "chkpoint";

public static final String DEFAULT_CP_OPCODE = "chkpoint";
public static final String ASYNC_CP_OPCODE = "chkpoint_e";

public static final StorageLevel DEFAULT_STORAGE_LEVEL = StorageLevel.MEMORY_AND_DISK();
public static final StorageLevel SER_STORAGE_LEVEL = StorageLevel.MEMORY_AND_DISK_SER();
public static final boolean CHECKPOINT_SPARSE_CSR = true;

private StorageLevel _storageLevel;
private boolean _async = false;


/**
@@ -55,16 +58,22 @@ public class Checkpoint
* @param dt data type
* @param vt value type
* @param level storage level
* @param isAsync true for an eager, asynchronous checkpoint
*/
public Checkpoint(Lop input, DataType dt, ValueType vt, String level) {
public Checkpoint(Lop input, DataType dt, ValueType vt, String level, boolean isAsync) {
super(Lop.Type.Checkpoint, dt, vt);
addInput(input);
input.addOutput(this);

_storageLevel = StorageLevel.fromString(level);
_async = isAsync;
lps.setProperties(inputs, ExecType.SPARK);
}

public Checkpoint(Lop input, DataType dt, ValueType vt, String level) {
this(input, dt, vt, level, false);
}

public StorageLevel getStorageLevel()
{
return _storageLevel;
@@ -89,7 +98,7 @@ public String getInstructions(String input1, String output) {

return InstructionUtils.concatOperands(
getExecType().name(),
OPCODE,
_async ? ASYNC_CP_OPCODE : DEFAULT_CP_OPCODE,
getInputs().get(0).prepInputOperand(input1),
prepOutputOperand(output),
getStorageLevelString(_storageLevel));
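With the added flag, a single Checkpoint lop class now serves both modes: isAsync=true switches the emitted opcode from the default chkpoint to chkpoint_e, while the retained four-argument constructor delegates with isAsync=false so existing call sites keep their behavior. A minimal construction sketch using only identifiers from this diff (sparkInput is a hypothetical upstream Spark lop):

    // Wrap a Spark lop in an eager, asynchronous checkpoint (emits opcode chkpoint_e)
    Lop chk = new Checkpoint(sparkInput, sparkInput.getDataType(), sparkInput.getValueType(),
        Checkpoint.getDefaultStorageLevelString(), true);

The resulting instruction string is assembled by getInstructions() and carries the storage level as its last operand, roughly of the form SPARK chkpoint_e <input> <output> MEMORY_AND_DISK (operand encoding elided).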
86 changes: 0 additions & 86 deletions src/main/java/org/apache/sysds/lops/compile/Dag.java
@@ -237,56 +237,6 @@ private void prefetchFederated(List<Lop> lops){
}
}

private static List<Lop> addPrefetchLop(List<Lop> nodes) {
List<Lop> nodesWithPrefetch = new ArrayList<>();

//Find the Spark nodes with all CP outputs
for (Lop l : nodes) {
nodesWithPrefetch.add(l);
if (isPrefetchNeeded(l)) {
//TODO: No prefetch if the parent is placed right after the spark OP
//or push the parent further to increase parallelism
List<Lop> oldOuts = new ArrayList<>(l.getOutputs());
//Construct a Prefetch lop that takes this Spark node as an input
UnaryCP prefetch = new UnaryCP(l, OpOp1.PREFETCH, l.getDataType(), l.getValueType(), ExecType.CP);
for (Lop outCP : oldOuts) {
//Rewire l -> outCP to l -> Prefetch -> outCP
prefetch.addOutput(outCP);
outCP.replaceInput(l, prefetch);
l.removeOutput(outCP);
//FIXME: Rewire _inputParams when needed (e.g. GroupedAggregate)
}
//Place it immediately after the Spark lop in the node list
nodesWithPrefetch.add(prefetch);
}
}
return nodesWithPrefetch;
}

private static List<Lop> addBroadcastLop(List<Lop> nodes) {
List<Lop> nodesWithBroadcast = new ArrayList<>();

for (Lop l : nodes) {
nodesWithBroadcast.add(l);
if (isBroadcastNeeded(l)) {
List<Lop> oldOuts = new ArrayList<>(l.getOutputs());
//Construct a Broadcast lop that takes this Spark node as an input
UnaryCP bc = new UnaryCP(l, OpOp1.BROADCAST, l.getDataType(), l.getValueType(), ExecType.CP);
//FIXME: Wire Broadcast only with the necessary outputs
for (Lop outCP : oldOuts) {
//Rewire l -> outCP to l -> Broadcast -> outCP
bc.addOutput(outCP);
outCP.replaceInput(l, bc);
l.removeOutput(outCP);
//FIXME: Rewire _inputParams when needed (e.g. GroupedAggregate)
}
//Place it immediately after the Spark lop in the node list
nodesWithBroadcast.add(bc);
}
}
return nodesWithBroadcast;
}

private ArrayList<Instruction> doPlainInstructionGen(StatementBlock sb, List<Lop> nodes)
{
//prepare basic instruction sets
@@ -319,42 +269,6 @@ private static boolean isTransientWriteRead(Data dnode) {
&& dnode.getOutputParameters().getLabel().equals(input.getOutputParameters().getLabel());
}

private static boolean isPrefetchNeeded(Lop lop) {
// Run Prefetch for a Spark instruction if the instruction is a Transformation
// and the output is consumed by only CP instructions.
boolean transformOP = lop.getExecType() == ExecType.SPARK && lop.getAggType() != SparkAggType.SINGLE_BLOCK
// Always Action operations
&& !(lop.getDataType() == DataType.SCALAR)
&& !(lop instanceof MapMultChain) && !(lop instanceof PickByCount)
&& !(lop instanceof MMZip) && !(lop instanceof CentralMoment)
&& !(lop instanceof CoVariance)
// Not qualified for prefetching
&& !(lop instanceof Checkpoint) && !(lop instanceof ReBlock)
&& !(lop instanceof CSVReBlock)
// Cannot filter Transformation cases from Actions (FIXME)
&& !(lop instanceof MMTSJ) && !(lop instanceof UAggOuterChain)
&& !(lop instanceof ParameterizedBuiltin) && !(lop instanceof SpoofFused);

//FIXME: Rewire _inputParams when needed (e.g. GroupedAggregate)
boolean hasParameterizedOut = lop.getOutputs().stream()
.anyMatch(out -> ((out instanceof ParameterizedBuiltin)
|| (out instanceof GroupedAggregate)
|| (out instanceof GroupedAggregateM)));
//TODO: support non-matrix outputs
return transformOP && !hasParameterizedOut
&& lop.isAllOutputsCP() && lop.getDataType() == DataType.MATRIX;
}

private static boolean isBroadcastNeeded(Lop lop) {
// Asynchronously broadcast a matrix if it is produced by a CP instruction,
// and at least one Spark parent needs to broadcast this intermediate (e.g., mapmm)
boolean isBc = lop.getOutputs().stream()
.anyMatch(out -> (out.getBroadcastInput() == lop));
//TODO: Early broadcast objects that are bigger than a single block
//return isCP && isBc && lop.getDataTypes() == DataType.Matrix;
return isBc && lop.getDataType() == DataType.MATRIX;
}

private static List<Instruction> deleteUpdatedTransientReadVariables(StatementBlock sb, List<Lop> nodeV) {
List<Instruction> insts = new ArrayList<>();
if ( sb == null ) //return modifiable list
@@ -319,6 +319,31 @@ private static List<Lop> addBroadcastLop(List<Lop> nodes) {
return nodesWithBroadcast;
}

private static List<Lop> addAsyncEagerCheckpointLop(List<Lop> nodes) {
List<Lop> nodesWithCheckpoint = new ArrayList<>();
// Find the Spark action nodes
for (Lop l : nodes) {
if (isCheckpointNeeded(l)) {
List<Lop> oldInputs = new ArrayList<>(l.getInputs());
// Place a Checkpoint node just below this node (Spark action)
for (Lop in : oldInputs) {
if (in.getExecType() != ExecType.SPARK)
continue;
// Rewire in -> l to in -> Checkpoint -> l
//UnaryCP checkpoint = new UnaryCP(in, OpOp1.TRIGREMOTE, in.getDataType(), in.getValueType(), ExecType.CP);
Lop checkpoint = new Checkpoint(in, in.getDataType(), in.getValueType(),
Checkpoint.getDefaultStorageLevelString(), true);
checkpoint.addOutput(l);
l.replaceInput(in, checkpoint);
in.removeOutput(l);
nodesWithCheckpoint.add(checkpoint);
}
}
nodesWithCheckpoint.add(l);
}
return nodesWithCheckpoint;
}

private static boolean isPrefetchNeeded(Lop lop) {
// Run Prefetch for a Spark instruction if the instruction is a Transformation
// and the output is consumed by only CP instructions.
@@ -354,4 +379,28 @@ private static boolean isBroadcastNeeded(Lop lop) {
//return isCP && isBc && lop.getDataTypes() == DataType.Matrix;
return isBc && lop.getDataType() == DataType.MATRIX;
}

private static boolean isCheckpointNeeded(Lop lop) {
// Place checkpoint_e just before a Spark action (FIXME)
boolean actionOP = lop.getExecType() == ExecType.SPARK
&& ((lop.getAggType() == SparkAggType.SINGLE_BLOCK)
// Always Action operations
|| (lop.getDataType() == DataType.SCALAR)
|| (lop instanceof MapMultChain) || (lop instanceof PickByCount)
|| (lop instanceof MMZip) || (lop instanceof CentralMoment)
|| (lop instanceof CoVariance) || (lop instanceof MMTSJ))
// Not qualified for Checkpoint
&& !(lop instanceof Checkpoint) && !(lop instanceof ReBlock)
&& !(lop instanceof CSVReBlock)
// Cannot filter Transformation cases from Actions (FIXME)
&& !(lop instanceof UAggOuterChain)
&& !(lop instanceof ParameterizedBuiltin) && !(lop instanceof SpoofFused);

//FIXME: Rewire _inputParams when needed (e.g. GroupedAggregate)
boolean hasParameterizedOut = lop.getOutputs().stream()
.anyMatch(out -> ((out instanceof ParameterizedBuiltin)
|| (out instanceof GroupedAggregate)
|| (out instanceof GroupedAggregateM)));
return actionOP && !hasParameterizedOut;
}
}
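Taken together, isCheckpointNeeded() selects Spark operators that trigger actions (single-block aggregates, scalar outputs, and always-action lops such as MapMultChain or MMZip), and addAsyncEagerCheckpointLop() rewires each of their Spark inputs through an eager checkpoint. A hedged sketch of the net effect on a small, illustrative lop chain (the chain itself is not from this diff):

    // Before:  mapmm (SPARK) -> uak+ (SPARK action, SINGLE_BLOCK aggregate)
    // After:   mapmm (SPARK) -> chkpoint_e -> uak+
    // A background thread can now persist and materialize the mapmm output
    // while the main thread proceeds, hiding part of the Spark latency.
    List<Lop> ordered = addAsyncEagerCheckpointLop(nodes); // nodes: linearized lop list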
@@ -67,7 +67,6 @@
import org.apache.sysds.runtime.instructions.cp.SqlCPInstruction;
import org.apache.sysds.runtime.instructions.cp.StringInitCPInstruction;
import org.apache.sysds.runtime.instructions.cp.TernaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.TriggerRemoteOpsCPInstruction;
import org.apache.sysds.runtime.instructions.cp.UaggOuterChainCPInstruction;
import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
@@ -482,9 +481,6 @@ public static CPInstruction parseSingleInstruction ( CPType cptype, String str )
case Broadcast:
return BroadcastCPInstruction.parseInstruction(str);

case TrigRemote:
return TriggerRemoteOpsCPInstruction.parseInstruction(str);

default:
throw new DMLRuntimeException("Invalid CP Instruction Type: " + cptype );
}
@@ -240,7 +240,8 @@ public class SPInstructionParser extends InstructionParser
String2SPInstructionType.put("libsvmrblk", SPType.LIBSVMReblock);

// Spark-specific instructions
String2SPInstructionType.put( Checkpoint.OPCODE, SPType.Checkpoint);
String2SPInstructionType.put( Checkpoint.DEFAULT_CP_OPCODE, SPType.Checkpoint);
String2SPInstructionType.put( Checkpoint.ASYNC_CP_OPCODE, SPType.Checkpoint);
String2SPInstructionType.put( Compression.OPCODE, SPType.Compression);
String2SPInstructionType.put( DeCompression.OPCODE, SPType.DeCompression);

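Because both opcodes map to SPType.Checkpoint, the existing checkpoint instruction parser is reached for either string; only the opcode text differs. A small sketch of the lookup (assuming the map is queried as elsewhere in this parser):

    SPType t1 = String2SPInstructionType.get("chkpoint");   // SPType.Checkpoint
    SPType t2 = String2SPInstructionType.get("chkpoint_e"); // SPType.Checkpoint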
@@ -25,17 +25,18 @@
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.utils.stats.SparkStatistics;

public class TriggerRemoteOperationsTask implements Runnable {
public class TriggerCheckpointTask implements Runnable {
MatrixObject _remoteOperationsRoot;

public TriggerRemoteOperationsTask(MatrixObject mo) {
public TriggerCheckpointTask(MatrixObject mo) {
_remoteOperationsRoot = mo;
}

@Override
public void run() {
boolean triggered = false;
synchronized (_remoteOperationsRoot) {
// FIXME: Handle double execution
if (_remoteOperationsRoot.isPendingRDDOps()) {
JavaPairRDD<?, ?> rdd = _remoteOperationsRoot.getRDDHandle().getRDD();
rdd.persist(Checkpoint.DEFAULT_STORAGE_LEVEL).count();
@@ -45,6 +46,6 @@ public void run() {
}

if (DMLScript.STATISTICS && triggered)
SparkStatistics.incAsyncTriggerRemoteCount(1);
SparkStatistics.incAsyncTriggerCheckpointCount(1);
}
}
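The renamed task persists the pending RDD at the default storage level and forces materialization with a count() action, updating the Spark statistics on success. A hedged usage sketch (the submitting code path is outside this diff, and the thread-pool choice is an assumption):

    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;

    // Trigger the checkpoint in the background for a MatrixObject `mo`
    // that still has pending RDD operations.
    ExecutorService pool = Executors.newSingleThreadExecutor();
    pool.submit(new TriggerCheckpointTask(mo));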

This file was deleted.
