[SYSTEMDS-3149] Fix misc issues decisionTree/randomForest training
This patch fixes various issues in the new decisionTree and randomForest
built-in functions and adds new, stricter tests:

* randomForest validation checks and parameters (consistent with DT)
* randomForest correct feature map with feature_frac=1.0
* decisionTree simplification of leaf label computation
* synchronized deep copy of hop-DAGs to avoid race conditions in parfor
* added missing size propagation on spark rev operations
* new tests with randomForest that check equivalent results to DT
  with num_trees=1 and reasonable results with larger ensembles
  (see the sketch below)
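
As a quick illustration of the last point, a minimal sketch of the equivalence check on hypothetical synthetic data (the recoded matrices X and y, the seeds, the column-type encoding in R, and all sizes are made up; decisionTree/randomForest and their predict counterparts are the builtins patched here):

    # hypothetical recoded/binned data: contiguous positive integers per column
    X = round(rand(rows=1000, cols=8, min=0.5, max=5.5, seed=7));
    y = round(rand(rows=1000, cols=1, min=0.5, max=2.5, seed=3));
    R = matrix(1, rows=1, cols=ncol(X)+1); # column-type vector incl. label (assumed encoding)

    M1 = decisionTree(X=X, y=y, ctypes=R, max_depth=5, seed=7);
    p1 = decisionTreePredict(X=X, y=y, ctypes=R, M=M1);

    # one tree over all rows and features degenerates to a single decision tree
    M2 = randomForest(X=X, y=y, ctypes=R, num_trees=1,
      sample_frac=1.0, feature_frac=1.0, max_depth=5, seed=7);
    p2 = randomForestPredict(X=X, y=y, ctypes=R, M=M2);

    print("acc DT: " + mean(p1 == y) + ", acc RF: " + mean(p2 == y));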
mboehm7 committed Apr 11, 2023
1 parent d39f745 commit 75e7e64
Showing 6 changed files with 53 additions and 18 deletions.
4 changes: 2 additions & 2 deletions scripts/builtin/decisionTree.dml
@@ -226,8 +226,8 @@ computeLeafLabel = function(Matrix[Double] y2, Matrix[Double] I, Boolean classify, ...)
return(Double label)
{
f = (I %*% y2) / sum(I);
-  label = ifelse(classify,
-    as.scalar(rowIndexMax(f)), sum(t(f)*seq(1,ncol(f))));
+  label = as.scalar(ifelse(classify,
+    rowIndexMax(f), f %*% seq(1,ncol(f))));
if(verbose)
print("-- leaf node label: " + label +" ("+sum(I)*max(f)+"/"+sum(I)+")");
}
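
The rewrite keeps both branches as 1x1 matrices and applies a single as.scalar cast. A minimal standalone sketch of the two label modes for a hypothetical leaf distribution f (values made up):

    # f: relative class frequencies at a leaf (1 x #classes)
    f = matrix("0.2 0.3 0.5", rows=1, cols=3);
    # classification: majority class index -> 3
    # regression: expected label f %*% seq(1,3) = 0.2*1 + 0.3*2 + 0.5*3 = 2.3
    labelC = as.scalar(ifelse(TRUE,  rowIndexMax(f), f %*% seq(1,ncol(f))));
    labelR = as.scalar(ifelse(FALSE, rowIndexMax(f), f %*% seq(1,ncol(f))));
    print("classify: " + labelC + ", regression: " + labelR);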
20 changes: 13 additions & 7 deletions scripts/builtin/randomForest.dml
@@ -37,6 +37,7 @@
# feature_frac Sample fraction of features for each tree in the forest
# max_depth Maximum depth of the learned tree (stopping criterion)
# min_leaf Minimum number of samples in leaf nodes (stopping criterion)
+# min_split     Minimum number of samples in leaf for attempting a split
# max_features Parameter controlling the number of features used as split
# candidates at tree nodes: m = ceil(num_features^max_features)
# impurity Impurity measure: entropy, gini (default)
@@ -68,7 +69,7 @@

m_randomForest = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] ctypes,
Int num_trees = 16, Double sample_frac = 0.1, Double feature_frac = 1.0,
-  Int max_depth = 10, Int min_leaf = 20, Double max_features = 0.5,
+  Int max_depth = 10, Int min_leaf = 20, Int min_split = 50, Double max_features = 0.5,
String impurity = "gini", Int seed = -1, Boolean verbose = FALSE)
return(Matrix[Double] M)
{
@@ -81,6 +82,8 @@ m_randomForest = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] ctypes, ...)
}
if(ncol(ctypes) != ncol(X)+1)
stop("randomForest: inconsistent num features (incl. label) and col types: "+ncol(X)+" vs "+ncol(ctypes)+".");
+if( sum(X<=0) != 0 )
+  stop("randomForest: feature matrix X is not properly recoded/binned: "+sum(X<=0));
if(sum(y <= 0) != 0)
stop("randomForest: y is not properly recoded and binned (contiguous positive integers).");
if(max(y) == 1)
@@ -91,16 +94,19 @@ m_randomForest = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] ctypes, ...)

# training of num_tree decision trees
M = matrix(0, rows=num_trees, cols=2*(2^max_depth-1));
-F = matrix(0, rows=num_trees, cols=ncol(X));
+F = matrix(1, rows=num_trees, cols=ncol(X));
parfor(i in 1:num_trees) {
if( verbose )
print("randomForest: start training tree "+i+"/"+num_trees+".");

# step 1: sample data
-  si1 = as.integer(as.scalar(randSeeds[3*(i-1)+1,1]));
-  I1 = rand(rows=nrow(X), cols=1, seed=si1) <= sample_frac;
-  Xi = removeEmpty(target=X, margin="rows", select=I1);
-  yi = removeEmpty(target=y, margin="rows", select=I1);
+  Xi = X; yi = y;
+  if( sample_frac < 1.0 ) {
+    si1 = as.integer(as.scalar(randSeeds[3*(i-1)+1,1]));
+    I1 = rand(rows=nrow(X), cols=1, seed=si1) <= sample_frac;
+    Xi = removeEmpty(target=X, margin="rows", select=I1);
+    yi = removeEmpty(target=y, margin="rows", select=I1);
+  }

# step 2: sample features
if( feature_frac < 1.0 ) {
@@ -116,7 +122,7 @@ m_randomForest = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] ctypes, ...)
# step 3: train decision tree
t2 = time();
si3 = as.integer(as.scalar(randSeeds[3*(i-1)+3,1]));
-  Mtemp = decisionTree(X=Xi, y=yi, ctypes=ctypes, max_depth=max_depth,
+  Mtemp = decisionTree(X=Xi, y=yi, ctypes=ctypes, max_depth=max_depth, min_split=min_split,
min_leaf=min_leaf, max_features=max_features, impurity=impurity, seed=si3, verbose=verbose);
M[i,1:length(Mtemp)] = matrix(Mtemp, rows=1, cols=length(Mtemp));
if( verbose )
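
The new validation above rejects any non-positive feature values, i.e., X must be fully recoded/binned before training. A minimal sketch of preparing such an input with transformencode (the file names and spec are hypothetical; the pattern mirrors the Titanic test below):

    # hypothetical CSV + transform spec (recode/bin into positive integer codes)
    F = read("data.csv", data_type="frame", format="csv", header=TRUE);
    jspec = read("tfspec.json", data_type="scalar", value_type="string");
    [X, Meta] = transformencode(target=F, spec=jspec);
    print("min(X): " + min(X));  # must be >= 1 to pass the new check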
(Java diff: Spark reorg instruction — output size propagation, e.g., for rev)
@@ -234,14 +234,17 @@ else if ( getOpcode().equalsIgnoreCase("rsort") ) {
boolean ixret = sec.getScalarInput(_ixret).getBooleanValue();
mcOut.set(mc1.getRows(), ixret?1:mc1.getCols(), mc1.getBlocksize(), mc1.getBlocksize());
}
+  else { //e.g., rev
+    mcOut.set(mc1);
+  }
}

//infer initially unknown nnz from input
if( !mcOut.nnzKnown() && mc1.nnzKnown() ){
boolean sortIx = getOpcode().equalsIgnoreCase("rsort") && sec.getScalarInput(_ixret.getName(), _ixret.getValueType(), _ixret.isLiteral()).getBooleanValue();
if( sortIx )
mcOut.setNonZeros(mc1.getRows());
-  else //default (r', rdiag, rsort data)
+  else //default (r', rdiag, rev, rsort data)
mcOut.setNonZeros(mc1.getNonZeros());
}
}
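
At the script level, the rev case matters whenever a reversed matrix feeds size-dependent downstream operators. A minimal sketch (dimensions and sparsity made up) of a rev whose output now inherits dimensions and nnz from its input:

    # rev reverses the row order; dims and #non-zeros are unchanged,
    # which the instruction above now propagates to the output statistics
    X = rand(rows=100000, cols=100, sparsity=0.1, seed=42);
    R = rev(X);
    print("cells: " + (nrow(R)*ncol(R)) + ", sum: " + sum(R));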
(Java diff: parfor program conversion — createStatementBlockCopy)
@@ -593,7 +593,10 @@ public static StatementBlock createStatementBlockCopy( StatementBlock sb, long pid, ...)
ret.setReadVariables( sb.variablesRead() );

//deep copy hops dag for concurrent recompile
-ArrayList<Hop> hops = Recompiler.deepCopyHopsDag( sb.getHops() );
+ArrayList<Hop> hops = sb.getHops();
+synchronized(hops) { // guard concurrent recompile
+  hops = Recompiler.deepCopyHopsDag( hops );
+}
if( !plain )
Recompiler.updateFunctionNames( hops, pid );
ret.setHops( hops );
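
The guard matters for parfor, where multiple workers concurrently deep-copy and recompile the same statement block (as in the randomForest training loop above). A minimal sketch of such a loop (sizes and degree of parallelism made up):

    # each worker deep-copies the body's hop DAG before recompilation;
    # the synchronized block above protects that copy
    X = rand(rows=10000, cols=10, seed=7);
    R = matrix(0, rows=16, cols=1);
    parfor(i in 1:16) {
      Xi = X * i;  # body recompiled per worker
      R[i,1] = as.matrix(sum(Xi));
    }
    print(sum(R));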
(Java diff: JUnit test for decisionTree/randomForest on real data)
@@ -24,6 +24,7 @@
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.utils.Statistics;
import org.junit.Assert;
import org.junit.Test;

@@ -42,23 +43,36 @@ public void setUp() {

@Test
public void testDecisionTreeTitanic() {
-  runDecisionTree(TITANIC_DATA, TITANIC_TFSPEC, 0.875, ExecType.CP);
+  runDecisionTree(TITANIC_DATA, TITANIC_TFSPEC, 0.875, 1, ExecType.CP);
}

+@Test
+public void testRandomForestTitanic1() {
+  //one tree with sample_frac=1 should be equivalent to decision tree
+  runDecisionTree(TITANIC_DATA, TITANIC_TFSPEC, 0.875, 2, ExecType.CP);
+}
+
+@Test
+public void testRandomForestTitanic8() {
+  //8 trees with sample fraction 0.125 each, accuracy 0.785 due to randomness
+  runDecisionTree(TITANIC_DATA, TITANIC_TFSPEC, 0.793, 9, ExecType.CP);
+}

-private void runDecisionTree(String data, String tfspec, double minAcc, ExecType instType) {
+private void runDecisionTree(String data, String tfspec, double minAcc, int dt, ExecType instType) {
private void runDecisionTree(String data, String tfspec, double minAcc, int dt, ExecType instType) {
Types.ExecMode platformOld = setExecMode(instType);
try {
loadTestConfiguration(getTestConfiguration(TEST_NAME));

String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + TEST_NAME + ".dml";
programArgs = new String[] {"-stats",
"-args", data, tfspec, output("R")};
"-args", data, tfspec, String.valueOf(dt), output("R")};

runTest(true, false, null, -1);

double acc = readDMLMatrixFromOutputDir("R").get(new CellIndex(1,1));
Assert.assertTrue(acc >= minAcc);
+  Assert.assertEquals(0, Statistics.getNoOfExecutedSPInst());
}
finally {
rtplatform = platformOld;
17 changes: 13 additions & 4 deletions src/test/scripts/functions/builtin/decisionTreeRealData.dml
@@ -30,11 +30,20 @@ Y = X[, ncol(X)]
X = X[, 1:ncol(X)-1]
X = imputeByMode(X);

-M = decisionTree(X=X, y=Y, ctypes=R, max_features=1, min_split=8, min_leaf=5, verbose=TRUE);
-yhat = decisionTreePredict(X=X, y=Y, ctypes=R, M=M)
+if( $3==1 ) {
+  M = decisionTree(X=X, y=Y, ctypes=R, max_features=1,
+    min_split=10, min_leaf=4, seed=7, verbose=TRUE);
+  yhat = decisionTreePredict(X=X, y=Y, ctypes=R, M=M)
+}
+else {
+  sf = 1.0/($3-1);
+  M = randomForest(X=X, y=Y, ctypes=R, sample_frac=sf, num_trees=$3-1, max_features=1,
+    min_split=10, min_leaf=4, seed=7, verbose=TRUE);
+  yhat = randomForestPredict(X=X, y=Y, ctypes=R, M=M)
+}

acc = as.matrix(mean(yhat == Y))
err = 1-(acc);
print("accuracy of DT: "+as.scalar(acc))
print("accuracy: "+as.scalar(acc))

-write(acc, $3);
+write(acc, $4);
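
The $3 parameterization lets one script drive both builtins; a minimal sketch of the arithmetic behind the two new JUnit tests above (values mirror those tests):

    # $3 == 2: one tree on the full sample (sf = 1/1 = 1.0) -> equivalent to DT
    # $3 == 9: eight trees, each on an eighth of the rows (sf = 1/8 = 0.125)
    dt = 9;
    sf = 1.0 / (dt - 1);
    print("num_trees=" + (dt-1) + ", sample_frac=" + sf);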
