Skip to content

Commit

Permalink
[SYSTEMDS-418] Performance improvements lineage reuse probing/spilling
Browse files Browse the repository at this point in the history
This patch makes some minor performance improvements to the lineage
reuse probing and cache put operations. Specifically, we now avoid
unnecessary lineage hashing and comparisons by using lists instead of
hash maps, move the time computations into the reuse path (to not affect
the code path without lineage reuse), avoid unnecessary branching, and
materialize the score of cache entries to avoid repeated computation
for the log N comparisons per add/remove/contains operation.

For 100K iterations and ~40 ops per iteration, lineage tracing w/ reuse
improved from 41.9s to 38.8s (pure lineage tracing: 27.9s).
  • Loading branch information
mboehm7 committed Jul 2, 2020
1 parent 0e369bd commit 1abe9cb
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ private void executeSingleInstruction( Instruction currInst, ExecutionContext ec
tmp.processInstruction(ec);

// cache result
LineageCache.putValue(tmp, ec, System.nanoTime()-et0);
LineageCache.putValue(tmp, ec, et0);

// post-process instruction (debug)
tmp.postprocessInstruction( ec );
Expand Down
48 changes: 28 additions & 20 deletions src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

package org.apache.sysds.runtime.lineage;

import org.apache.commons.lang3.tuple.MutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.FileFormat;
Expand All @@ -45,6 +47,8 @@
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.meta.MetaDataFormat;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -81,39 +85,40 @@ public static boolean reuse(Instruction inst, ExecutionContext ec) {
if (LineageCacheConfig.isReusable(inst, ec)) {
ComputationCPInstruction cinst = (ComputationCPInstruction) inst;
LineageItem instLI = cinst.getLineageItem(ec).getValue();
HashMap<LineageItem, LineageCacheEntry> liMap = new HashMap<>();
List<MutablePair<LineageItem, LineageCacheEntry>> liList = null;
if (inst instanceof MultiReturnBuiltinCPInstruction) {
liList = new ArrayList<>();
MultiReturnBuiltinCPInstruction mrInst = (MultiReturnBuiltinCPInstruction)inst;
for (int i=0; i<mrInst.getNumOutputs(); i++) {
String opcode = instLI.getOpcode() + String.valueOf(i);
liMap.put(new LineageItem(opcode, instLI.getInputs()), null);
liList.add(MutablePair.of(new LineageItem(opcode, instLI.getInputs()), null));
}
}
else
liMap.put(instLI, null);
liList = Arrays.asList(MutablePair.of(instLI, null));

//atomic try reuse full/partial and set placeholder, without
//obtaining value to avoid blocking in critical section
LineageCacheEntry e = null;
boolean reuseAll = true;
synchronized( _cache ) {
//try to reuse full or partial intermediates
for (LineageItem item : liMap.keySet()) {
for (MutablePair<LineageItem,LineageCacheEntry> item : liList) {
if (LineageCacheConfig.getCacheType().isFullReuse())
e = LineageCache.probe(item) ? getIntern(item) : null;
e = LineageCache.probe(item.getKey()) ? getIntern(item.getKey()) : null;
//TODO need to also move execution of compensation plan out of here
//(create lazily evaluated entry)
if (e == null && LineageCacheConfig.getCacheType().isPartialReuse())
if( LineageRewriteReuse.executeRewrites(inst, ec) )
e = getIntern(item);
e = getIntern(item.getKey());
//TODO: MultiReturnBuiltin and partial rewrites
reuseAll &= (e != null);
liMap.put(item, e);
item.setValue(e);

//create a placeholder if no reuse to avoid redundancy
//(e.g., concurrent threads that try to start the computation)
if(e == null && isMarkedForCaching(inst, ec)) {
putIntern(item, cinst.output.getDataType(), null, null, 0);
putIntern(item.getKey(), cinst.output.getDataType(), null, null, 0);
//FIXME: different o/p datatypes for MultiReturnBuiltins.
}
}
Expand All @@ -122,7 +127,7 @@ public static boolean reuse(Instruction inst, ExecutionContext ec) {

if(reuse) { //reuse
//put reuse value into symbol table (w/ blocking on placeholders)
for (Map.Entry<LineageItem, LineageCacheEntry> entry : liMap.entrySet()) {
for (MutablePair<LineageItem, LineageCacheEntry> entry : liList) {
e = entry.getValue();
String outName = null;
if (inst instanceof MultiReturnBuiltinCPInstruction)
Expand Down Expand Up @@ -243,39 +248,42 @@ public static void putMatrix(Instruction inst, ExecutionContext ec, long compute
}
}

public static void putValue(Instruction inst, ExecutionContext ec, long computetime) {
public static void putValue(Instruction inst, ExecutionContext ec, long starttime) {
if (ReuseCacheType.isNone())
return;
long computetime = System.nanoTime() - starttime;
if (LineageCacheConfig.isReusable(inst, ec) ) {
//if (!isMarkedForCaching(inst, ec)) return;
HashMap<LineageItem, Data> liDataMap = new HashMap<>();
List<Pair<LineageItem, Data>> liData = null;
LineageItem instLI = ((LineageTraceable) inst).getLineageItem(ec).getValue();
if (inst instanceof MultiReturnBuiltinCPInstruction) {
liData = new ArrayList<>();
MultiReturnBuiltinCPInstruction mrInst = (MultiReturnBuiltinCPInstruction)inst;
for (int i=0; i<mrInst.getNumOutputs(); i++) {
String opcode = instLI.getOpcode() + String.valueOf(i);
LineageItem li = new LineageItem(opcode, instLI.getInputs());
Data value = ec.getVariable(mrInst.getOutput(i));
liDataMap.put(li, value);
liData.add(Pair.of(li, value));
}
}
else
liDataMap.put(instLI, ec.getVariable(((ComputationCPInstruction) inst).output));
liData = Arrays.asList(Pair.of(instLI, ec.getVariable(((ComputationCPInstruction) inst).output)));
synchronized( _cache ) {
for (Map.Entry<LineageItem, Data> entry : liDataMap.entrySet()) {
for (Pair<LineageItem, Data> entry : liData) {
LineageItem item = entry.getKey();
Data data = entry.getValue();
LineageCacheEntry centry = _cache.get(item);
if (data instanceof MatrixObject)
_cache.get(item).setValue(((MatrixObject)data).acquireReadAndRelease(), computetime);
centry.setValue(((MatrixObject)data).acquireReadAndRelease(), computetime);
else if (data instanceof ScalarObject)
_cache.get(item).setValue((ScalarObject)data, computetime);
centry.setValue((ScalarObject)data, computetime);
else
throw new DMLRuntimeException("Lineage Cache: unsupported data: "+data.getDataType());

//maintain order for eviction
LineageCacheEviction.addEntry(_cache.get(item));
LineageCacheEviction.addEntry(centry);

long size = _cache.get(item).getSize();
long size = centry.getSize();
if (!LineageCacheEviction.isBelowThreshold(size))
LineageCacheEviction.makeSpace(_cache, size);
LineageCacheEviction.updateSize(size, true);
Expand All @@ -284,8 +292,8 @@ else if (data instanceof ScalarObject)
}
}

public static void putValue(List<DataIdentifier> outputs, LineageItem[] liInputs,
String name, ExecutionContext ec, long computetime)
public static void putValue(List<DataIdentifier> outputs,
LineageItem[] liInputs, String name, ExecutionContext ec, long computetime)
{
if (!LineageCacheConfig.isMultiLevelReuse())
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ private enum CachedItemTail {

private static LineageCachePolicy _cachepolicy = null;
// Weights for scoring components (computeTime/size, LRU timestamp)
private static double[] WEIGHTS = {0, 1};
protected static double[] WEIGHTS = {0, 1};

protected enum LineageCacheStatus {
EMPTY, //Placeholder with no data. Cannot be evicted.
Expand All @@ -121,14 +121,9 @@ public enum LineageCachePolicy {
}

protected static Comparator<LineageCacheEntry> LineageCacheComparator = (e1, e2) -> {
// Gather the weights for scoring components
double w1 = LineageCacheConfig.WEIGHTS[0];
double w2 = LineageCacheConfig.WEIGHTS[1];
// Generate scores
double score1 = w1*(((double)e1._computeTime)/e1.getSize()) + w2*e1.getTimestamp();
double score2 = w1*((double)e2._computeTime)/e2.getSize() + w2*e2.getTimestamp();
// Generate order. If scores are same, order by LineageItem ID.
return score1 == score2 ? Long.compare(e1._key.getId(), e2._key.getId()) : score1 < score2 ? -1 : 1;
return e1.score == e2.score ?
Long.compare(e1._key.getId(), e2._key.getId()) :
e1.score < e2.score ? -1 : 1;
};

//----------------------------------------------------------------//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public class LineageCacheEntry {
protected LineageCacheStatus _status;
protected LineageCacheEntry _nextEntry;
protected LineageItem _origItem;
protected double score;

public LineageCacheEntry(LineageItem key, DataType dt, MatrixBlock Mval, ScalarObject Sval, long computetime) {
_key = key;
Expand Down Expand Up @@ -123,9 +124,18 @@ protected synchronized void setNullValues() {

protected synchronized void setTimestamp() {
_timestamp = System.currentTimeMillis();
recomputeScore();
}

protected synchronized long getTimestamp() {
return _timestamp;
}

// Materializes the eviction score so the TreeSet comparator does not have to
// recompute it for every comparison. The score is a weighted sum of the
// compute-cost density (compute time per byte) and the LRU timestamp, with
// weights taken from LineageCacheConfig.WEIGHTS.
private void recomputeScore() {
	double costDensity = ((double) _computeTime) / getSize();
	score = LineageCacheConfig.WEIGHTS[0] * costDensity
		+ LineageCacheConfig.WEIGHTS[1] * getTimestamp();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

import org.apache.sysds.api.DMLScript;
Expand All @@ -36,7 +37,7 @@ public class LineageCacheEviction
{
private static long _cachesize = 0;
private static long CACHE_LIMIT; //limit in bytes
protected static final HashSet<LineageItem> _removelist = new HashSet<>();
protected static final Set<LineageItem> _removelist = new HashSet<>();
private static final Map<LineageItem, SpilledItem> _spillList = new HashMap<>();
private static String _outdir = null;
private static TreeSet<LineageCacheEntry> weightedQueue = new TreeSet<>(LineageCacheConfig.LineageCacheComparator);
Expand Down Expand Up @@ -202,8 +203,6 @@ protected static void makeSpace(Map<LineageItem, LineageCacheEntry> cache, long
//TODO: Graceful handling of status.
}

double exectime = ((double) e._computeTime) / 1000000; // in milliseconds

if (!e.isMatrixValue()) {
// No spilling for scalar entries. Just delete those.
// Note: scalar entries with higher computation time are pinned.
Expand All @@ -213,6 +212,7 @@ protected static void makeSpace(Map<LineageItem, LineageCacheEntry> cache, long

// Estimate time to write to FS + read from FS.
double spilltime = getDiskSpillEstimate(e) * 1000; // in milliseconds
double exectime = ((double) e._computeTime) / 1000000; // in milliseconds

if (LineageCache.DEBUG) {
if (exectime > LineageCacheConfig.MIN_SPILL_TIME_ESTIMATE) {
Expand All @@ -227,19 +227,13 @@ protected static void makeSpace(Map<LineageItem, LineageCacheEntry> cache, long
if (spilltime < LineageCacheConfig.MIN_SPILL_TIME_ESTIMATE) {
// Can't trust the estimate if less than 100ms.
// Spill if it takes longer to recompute.
if (exectime >= LineageCacheConfig.MIN_SPILL_TIME_ESTIMATE)
//spillToLocalFS(e);
removeOrSpillEntry(cache, e, true); //spill
else
removeOrSpillEntry(cache, e, false); //delete
removeOrSpillEntry(cache, e, //spill or delete
exectime >= LineageCacheConfig.MIN_SPILL_TIME_ESTIMATE);
}
else {
// Spill if it takes longer to recompute than spilling.
if (exectime > spilltime)
//spillToLocalFS(e);
removeOrSpillEntry(cache, e, true); //spill
else
removeOrSpillEntry(cache, e, false); //delete
removeOrSpillEntry(cache, e, //spill or delete
exectime > spilltime);
}
}
}
Expand Down

0 comments on commit 1abe9cb

Please sign in to comment.