[SYSTEMDS-3621] Adaptive delayed lineage caching
This patch introduces a new feature to control admission into the
local lineage cache. If enabled, we delay the caching of large
matrix blocks that have never been cached or evicted before.
To avoid unnecessarily restricting entries, we only delay
matrices that are larger than 5% of the available cache memory.
This way, we guarantee no delay for larger caches.

Closes #1917
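
The core of the patch is a size-adaptive admission check in LineageCache.putValueCPU: with delayed caching enabled, a large matrix block that has never been evicted is only marked TOCACHE on its first occurrence and is stored once it reoccurs. The following is a minimal, illustrative sketch of that decision with hypothetical class and parameter names (the real check operates on LineageCacheEntry, MatrixObject, and LineageCacheEviction.getAvailableSpace()):

    // Illustrative sketch of the delayed-caching admission rule; not the actual SystemDS code.
    public class DelayedCachingSketch {
        // Decide whether to defer caching of a first-time intermediate (status EMPTY -> TOCACHE).
        static boolean delayCaching(boolean delayedCachingEnabled, boolean isMatrix,
                boolean evictedBefore, long sizeBytes, long availableCacheBytes) {
            if (!delayedCachingEnabled || !isMatrix || evictedBefore)
                return false;                               // cache immediately
            return sizeBytes > 0.05 * availableCacheBytes;  // defer only large first-time blocks
        }
    }

On the reuse path, a TOCACHE placeholder yields no value (reuse returns false), but the output's lineage item is replaced with the cached key, so the lineage trace is still reused; the next putValueCPU call then stores the value and sets the status to CACHED.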
phaniarnab committed Sep 21, 2023
1 parent f7e98e3 commit 0d78859
Showing 9 changed files with 73 additions and 51 deletions.
48 changes: 31 additions & 17 deletions src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
@@ -140,15 +140,21 @@ public static boolean reuse(Instruction inst, ExecutionContext ec) {
MatrixBlock mb = e.getMBValue(); //wait if another thread is executing the same inst.
if (mb == null && e.getCacheStatus() == LineageCacheStatus.NOTCACHED)
return false; //the executing thread removed this entry from cache
else
ec.setMatrixOutput(outName, mb);
if (e.getCacheStatus() == LineageCacheStatus.TOCACHE) { //not cached yet
ec.replaceLineageItem(outName, e._key); //reuse the lineage trace
return false;
}
ec.setMatrixOutput(outName, mb);
}
else if (e.isScalarValue()) {
ScalarObject so = e.getSOValue(); //wait if another thread is executing the same inst.
if (so == null && e.getCacheStatus() == LineageCacheStatus.NOTCACHED)
return false; //the executing thread removed this entry from cache
else
ec.setScalarOutput(outName, so);
if (e.getCacheStatus() == LineageCacheStatus.TOCACHE) { //not cached yet
ec.replaceLineageItem(outName, e._key); //reuse the lineage trace
return false;
}
ec.setScalarOutput(outName, so);
}
else if (e.isRDDPersist()) {
RDDObject rdd = e.getRDDObject();
@@ -254,6 +260,8 @@ public static boolean reuse(List<String> outNames, List<DataIdentifier> outParam
MatrixBlock mb = e.getMBValue();
if (mb == null && e.getCacheStatus() == LineageCacheStatus.NOTCACHED)
return false; //the executing thread removed this entry from cache
if (e.getCacheStatus() == LineageCacheStatus.TOCACHE) //not cached yet
return false;
MetaDataFormat md = new MetaDataFormat(
e.getMBValue().getDataCharacteristics(),FileFormat.BINARY);
md.getDataCharacteristics().setBlocksize(ConfigurationManager.getBlocksize());
@@ -286,6 +294,8 @@ else if (e.isScalarValue()) {
boundValue = e.getSOValue();
if (boundValue == null && e.getCacheStatus() == LineageCacheStatus.NOTCACHED)
return false; //the executing thread removed this entry from cache
if (e.getCacheStatus() == LineageCacheStatus.TOCACHE) //not cached yet
return false;
}
//TODO: support reusing RDD output of functions

@@ -514,12 +524,7 @@ public static byte[] reuseSerialization(LineageItem objLI) {

public static boolean probe(LineageItem key) {
//TODO problematic as after probe the matrix might be kicked out of cache
boolean p = _cache.containsKey(key); // in cache or in disk
if (!p && DMLScript.STATISTICS && LineageCacheEviction._removelist.containsKey(key))
// The sought entry was in cache but removed later
LineageCacheStatistics.incrementDelHits();

return p;
return _cache.containsKey(key);
}

private static boolean probeRDDDistributed(LineageItem key) {
@@ -715,6 +720,16 @@ private static void putValueCPU(Instruction inst, List<Pair<LineageItem, Data>>
continue;
}

//delay caching of large matrix blocks if the feature is enabled
if (centry.getCacheStatus() == LineageCacheStatus.EMPTY && LineageCacheConfig.isDelayedCaching()) {
if (data instanceof MatrixObject //no delayed caching for scalars
&& !LineageCacheEviction._removelist.containsKey(centry._key) //evicted before
&& size > 0.05 * LineageCacheEviction.getAvailableSpace()) { //size adaptive
centry.setCacheStatus(LineageCacheStatus.TOCACHE);
continue;
}
}

//make space for the data
if (!LineageCacheEviction.isBelowThreshold(size))
LineageCacheEviction.makeSpace(_cache, size);
@@ -725,6 +740,7 @@ private static void putValueCPU(Instruction inst, List<Pair<LineageItem, Data>>
centry.setValue(mb, computetime);
else if (data instanceof ScalarObject)
centry.setValue((ScalarObject)data, computetime);
centry.setCacheStatus(LineageCacheStatus.CACHED);

if (DMLScript.STATISTICS && LineageCacheEviction._removelist.containsKey(centry._key)) {
// Add to missed compute time
@@ -785,8 +801,7 @@ private static void putValueRDD(Instruction inst, LineageItem instLI, ExecutionC
}
boolean opToPersist = LineageCacheConfig.isReusableRDDType(inst);
// Return if the intermediate is not to be persisted in the executors
// and the local only RDD caching is disabled
if (!opToPersist && !LineageCacheConfig.ENABLE_LOCAL_ONLY_RDD_CACHING) {
if (!opToPersist) {
removePlaceholder(instLI);
return;
}
@@ -1064,8 +1079,6 @@ private static void putIntern(LineageItem key, DataType dt, MatrixBlock Mval, Sc
LineageCacheEviction.addEntry(newItem);

_cache.put(key, newItem);
if (DMLScript.STATISTICS)
LineageCacheStatistics.incrementMemWrites();
}

private static LineageCacheEntry getIntern(LineageItem key) {
@@ -1105,8 +1118,8 @@ private static void mvIntern(LineageItem item, LineageItem probeItem, long compu
LineageCacheEntry e = _cache.get(item);
boolean exists = !e.isNullVal();
e.copyValueFrom(oe, computetime);
if (e.isNullVal())
throw new DMLRuntimeException("Lineage Cache: Original item is empty: "+oe._key);
//if (e.isNullVal())
// throw new DMLRuntimeException("Lineage Cache: Original item is empty: "+oe._key);

e._origItem = probeItem;
// Add itself as original item to navigate the list.
@@ -1205,7 +1218,8 @@ private static void persistRDDIntern(LineageCacheEntry centry, long estimatedSiz
rdd = rdd.persist(StorageLevel.MEMORY_AND_DISK());
//cut-off RDD lineage & broadcasts to prevent errors on
// task closure serialization with destroyed broadcasts
rdd.checkpoint();
//rdd.checkpoint();
rdd.rdd().localCheckpoint();
rddObj.setRDD(rdd);
rddObj.setCheckpointRDD(true);

@@ -28,6 +28,7 @@
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.cp.BinaryMatrixMatrixCPInstruction;
import org.apache.sysds.runtime.instructions.cp.BinaryScalarScalarCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.DataGenCPInstruction;
@@ -53,24 +54,24 @@ public class LineageCacheConfig
"uamean", "max", "min", "ifelse", "-", "sqrt", "<", ">", "uak+", "<=",
"^", "uamax", "uark+", "uacmean", "eigen","ctable", "ctableexpand", "replace",
"^2", "*2", "uack+", "tak+*", "uacsqk+", "uark+", "n+", "uarimax", "qsort",
"qpick", "transformapply", "uarmax", "n+", "-*", "castdtm", "lowertri",
"qpick", "transformapply", "uarmax", "n+", "-*", "castdtm", "lowertri", "1-*",
"prefetch", "mapmm", "contains", "mmchain", "mapmmchain", "+*", "==", "rmempty"
//TODO: Reuse everything.
};

// Relatively expensive instructions. Most include shuffles.
private static final String[] PERSIST_OPCODES1 = new String[] {
"cpmm", "rmm", "pmm", "rev", "rshape", "rsort", "-", "*", "+",
"cpmm", "rmm", "pmm", "zipmm", "rev", "rshape", "rsort", "-", "*", "+",
"/", "%%", "%/%", "1-*", "^", "^2", "*2", "==", "!=", "<", ">",
"<=", ">=", "&&", "||", "xor", "max", "min", "rmempty", "rappend",
"gappend", "galignedappend", "rbind", "cbind", "nmin", "nmax",
"n+", "ctable", "ucumack+", "ucumac*", "ucumacmin", "ucumacmax",
"qsort", "qpick", "replace"
"qsort", "qpick"
};

// Relatively inexpensive instructions.
private static final String[] PERSIST_OPCODES2 = new String[] {
"mapmm", "isna", "leftIndex", "rightIndex"
"mapmm", "isna", "leftIndex"
};

private static String[] REUSE_OPCODES = new String[] {};
Expand Down Expand Up @@ -105,6 +106,7 @@ public static boolean isNone() {
private static boolean _compilerAssistedRW = false;
private static boolean _onlyEstimate = false;
private static boolean _reuseLineageTraces = true;
private static boolean DELAYED_CACHING = false;

//-------------DISK SPILLING RELATED CONFIGURATIONS--------------//

Expand Down Expand Up @@ -148,6 +150,7 @@ private enum CachedItemTail {
protected enum LineageCacheStatus {
EMPTY, //Placeholder with no data. Cannot be evicted.
NOTCACHED, //Placeholder removed from the cache
TOCACHE, //To be cached in memory if reoccur
CACHED, //General cached data. Can be evicted.
SPILLED, //Data is in disk. Empty value. Cannot be evicted.
RELOADED, //Reloaded from disk. Can be evicted.
Expand Down Expand Up @@ -240,7 +243,8 @@ public static boolean isReusable (Instruction inst, ExecutionContext ec) {
|| inst instanceof ComputationFEDInstruction
|| inst instanceof GPUInstruction
|| inst instanceof ComputationSPInstruction)
&& !(inst instanceof ListIndexingCPInstruction);
&& !(inst instanceof ListIndexingCPInstruction)
&& !(inst instanceof BinaryScalarScalarCPInstruction);
boolean rightCPOp = (ArrayUtils.contains(REUSE_OPCODES, inst.getOpcode())
|| (inst.getOpcode().equals("append") && isVectorAppend(inst, ec))
|| (inst.getOpcode().startsWith("spoof"))
Expand Down Expand Up @@ -378,6 +382,10 @@ public static boolean isLineageTraceReuse() {
return _reuseLineageTraces;
}

public static boolean isDelayedCaching() {
return DELAYED_CACHING;
}

public static void setCachePolicy(LineageCachePolicy policy) {
// TODO: Automatic tuning of weights.
switch(policy) {
@@ -22,6 +22,7 @@
import java.util.Map;

import jcuda.Pointer;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
@@ -208,6 +209,8 @@ public synchronized void setValue(MatrixBlock val, long computetime) {
_status = isNullVal() ? LineageCacheStatus.EMPTY : LineageCacheStatus.CACHED;
//resume all threads waiting for val
notifyAll();
if (DMLScript.STATISTICS && val != null)
LineageCacheStatistics.incrementMemWrites();
}

public synchronized void setValue(MatrixBlock val) {
@@ -221,6 +224,8 @@ public synchronized void setValue(ScalarObject val, long computetime) {
_status = isNullVal() ? LineageCacheStatus.EMPTY : LineageCacheStatus.CACHED;
//resume all threads waiting for val
notifyAll();
if (DMLScript.STATISTICS && val != null)
LineageCacheStatistics.incrementMemWrites();
}

public synchronized void setGPUValue(Pointer ptr, long size, MetaData md, long computetime) {
@@ -229,6 +234,8 @@ public synchronized void setGPUValue(Pointer ptr, long size, MetaData md, long c
_status = isNullVal() ? LineageCacheStatus.EMPTY : LineageCacheStatus.GPUCACHED;
//resume all threads waiting for val
notifyAll();
if (DMLScript.STATISTICS && ptr != null)
LineageCacheStatistics.incrementMemWrites();
}

public synchronized void setRDDValue(RDDObject rdd, long computetime) {
@@ -238,13 +245,17 @@ public synchronized void setRDDValue(RDDObject rdd, long computetime) {
_status = isNullVal() ? LineageCacheStatus.EMPTY : LineageCacheStatus.TOPERSISTRDD;
//resume all threads waiting for val
notifyAll();
if (DMLScript.STATISTICS && rdd != null)
LineageCacheStatistics.incrementMemWrites();
}

public synchronized void setRDDValue(RDDObject rdd) {
_rddObject = rdd;
_status = isNullVal() ? LineageCacheStatus.EMPTY : LineageCacheStatus.TOPERSISTRDD;
//resume all threads waiting for val
notifyAll();
if (DMLScript.STATISTICS && rdd != null)
LineageCacheStatistics.incrementMemWrites();
}

public synchronized void setValue(byte[] serialBytes, long computetime) {
@@ -253,6 +264,8 @@ public synchronized void setValue(byte[] serialBytes, long computetime) {
_status = isNullVal() ? LineageCacheStatus.EMPTY : LineageCacheStatus.CACHED;
// resume all threads waiting for val
notifyAll();
if (DMLScript.STATISTICS && serialBytes != null)
LineageCacheStatistics.incrementMemWrites();
}

public synchronized void copyValueFrom(LineageCacheEntry src, long computetime) {
@@ -198,6 +198,10 @@ protected static void setCacheLimit(double fraction) {
public static long getCacheLimit() {
return CACHE_LIMIT;
}

public static long getAvailableSpace() {
return CACHE_LIMIT - _cachesize;
}

protected static void updateSize(long space, boolean addspace) {
if (addspace)
@@ -364,6 +364,6 @@ public static String displaySparkPersist() {

public static boolean ifSparkStats() {
return (_numHitsSparkActions.longValue() + _numHitsRdd.longValue()
+ _numHitsRddPersist.longValue() + _numRddUnpersist.longValue()) != 0;
+ _numHitsRddPersist.longValue() + _numRddPersist.longValue()) != 0;
}
}
@@ -72,7 +72,7 @@ public void testlmdsRDD() {

@Test
public void testL2svm() {
runTest(TEST_NAME+"3", ExecMode.HYBRID, ReuseCacheType.REUSE_FULL, 3);
runTest(TEST_NAME+"3", ExecMode.SPARK, ReuseCacheType.REUSE_FULL, 3);
}

@Test
@@ -91,12 +91,10 @@ public void testEnsemble() {
//public void testHyperband() {
// runTest(TEST_NAME+"6", ExecMode.HYBRID, ReuseCacheType.REUSE_FULL, 6);
//}

@Test
public void testBroadcastBug() {
runTest(TEST_NAME+"7", ExecMode.HYBRID, ReuseCacheType.REUSE_FULL, 7);
}

@Test
public void testTopKClean() {
// Multiple cleaning pipelines with real dataset (Nashville accident)
18 changes: 0 additions & 18 deletions src/test/scripts/functions/async/LineageReuseSpark4.dml
@@ -50,24 +50,6 @@ while (lamda < lim)
i = i + 1;
}

/*[A, b] = SimlinRegDS(X, y);
A_diag = A + diag(matrix(lamda, rows=N, cols=1));
beta = solve(A_diag, b);
R[,1] = beta;
lamda = lamda + stp;

# Reuse function call
[A, b] = SimlinRegDS(X, y);
A_diag = A + diag(matrix(lamda, rows=N, cols=1));
beta = solve(A_diag, b);
R[,2] = beta;
lamda = lamda + stp;

[A, b] = SimlinRegDS(X, y);
A_diag = A + diag(matrix(lamda, rows=N, cols=1));
beta = solve(A_diag, b);
R[,3] = beta;*/

R = sum(R);
write(R, $1, format="text");

12 changes: 6 additions & 6 deletions src/test/scripts/functions/async/LineageReuseSpark6.dml
@@ -34,13 +34,13 @@ return (Double accuracy) {
M = 10000;
N = 200;
sp = 1.0; #1.0
no_bracket = 2; #5
no_bracket = 1; #5

X = rand(rows=M, cols=N, sparsity=sp, seed=42);
y = rand(rows=M, cols=1, min=0, max=2, seed=42);
y = ceil(y);

no_lamda = 25; #starting combintaions = 25 * 4 = 100 HPs
no_lamda = 3; #starting combintaions = 25 * 4 = 100 HPs
stp = (0.1 - 0.0001)/no_lamda;
HPlamdas = seq(0.0001, 0.1, stp);
maxIter = 10; #starting interation count = 100 * 10 = 1k
@@ -55,10 +55,10 @@ for (r in 1:no_bracket) {
{
#print("lamda = "+as.scalar(HPlamdas[i,1])+", maxIterations = "+maxIter);
#Run L2svm with intercept true
beta = l2svm(X=X, Y=y, intercept=TRUE, epsilon=1e-12,
/*beta = l2svm(X=X, Y=y, intercept=TRUE, epsilon=1e-12,
reg = as.scalar(HPlamdas[i,1]), maxIterations=maxIter, verbose=FALSE);
svmModels[i,1] = l2norm(X, y, beta); #1st column
svmModels[i,2:nrow(beta)+1] = t(beta);
svmModels[i,2:nrow(beta)+1] = t(beta);*/

#Run L2svm with intercept false
beta = l2svm(X=X, Y=y, intercept=FALSE, epsilon=1e-12,
@@ -67,7 +67,7 @@ for (r in 1:no_bracket) {
svmModels[i,2:nrow(beta)+1] = t(beta);

#Run multilogreg with intercept true
beta = multiLogReg(X=X, Y=y, icpt=2, tol=1e-6, reg=as.scalar(HPlamdas[i,1]),
/*beta = multiLogReg(X=X, Y=y, icpt=2, tol=1e-6, reg=as.scalar(HPlamdas[i,1]),
maxi=maxIter, maxii=20, verbose=FALSE);
[prob_mlr, Y_mlr, acc] = multiLogRegPredict(X=X, B=beta, Y=y, verbose=FALSE);
mlrModels[i,1] = acc; #1st column
Expand All @@ -78,7 +78,7 @@ for (r in 1:no_bracket) {
maxi=maxIter, maxii=20, verbose=FALSE);
[prob_mlr, Y_mlr, acc] = multiLogRegPredict(X=X, B=beta, Y=y, verbose=FALSE);
mlrModels[i,1] = acc; #1st column
mlrModels[i,2:nrow(beta)+1] = t(beta);
mlrModels[i,2:nrow(beta)+1] = t(beta);*/

i = i + 1;
}
5 changes: 4 additions & 1 deletion src/test/scripts/functions/lineage/FunctionFullReuse1.dml
@@ -36,12 +36,15 @@ c = 10

X = rand(rows=r, cols=c, seed=42);
y = rand(rows=r, cols=1, seed=43);
R = matrix(0, 1, 2);
R = matrix(0, 1, 3);

beta1 = SimLM(X, y, 0.0001);
R[,1] = beta1;

beta2 = SimLM(X, y, 0.0001);
R[,2] = beta2;

beta2 = SimLM(X, y, 0.0001);
R[,3] = beta2;

write(R, $1, format="text");
