Skip to content

Commit

Permalink
[SYSTEMDS-3149] Additional RSS impurity measure for regression trees
Browse files Browse the repository at this point in the history
This patch adds an additional impurity measure, besides gini and entropy,
to the decisionTree and randomForest builtin functions. The new
measure is rss (residual sum of squares) for regression, in order to
properly learn the tree with regard to the final accuracy metrics.
  • Loading branch information
mboehm7 committed Apr 14, 2023
1 parent 40f2f7a commit 4c97c1c
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 11 deletions.
11 changes: 9 additions & 2 deletions scripts/builtin/decisionTree.dml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
# candidates at tree nodes: m = ceil(num_features^max_features)
# max_values Parameter controlling the number of values per feature used
# as split candidates: nb = ceil(num_values^max_values)
# impurity Impurity measure: entropy, gini (default)
# impurity Impurity measure: entropy, gini (default), rss (regression)
# seed Fixed seed for randomization of samples and split candidates
# verbose Flag indicating verbose debug output
# ------------------------------------------------------------------------------
Expand Down Expand Up @@ -72,7 +72,9 @@ m_decisionTree = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] cty
if( max_depth > 32 )
stop("decisionTree: invalid max_depth > 32: "+max_depth);
if( sum(X<=0) != 0 )
stop("decisionTree: feature matrix X is not properly recoded/binned: "+sum(X<=0));
stop("decisionTree: feature matrix X is not properly recoded/binned (values <= 0): "+sum(X<=0));
if( sum(abs(X-round(X))>1e-14) != 0 )
stop("decisionTree: feature matrix X is not properly recoded/binned (non-integer): "+sum(abs(X-round(X))>1e-14));
if( sum(y<=0) != 0 )
stop("decisionTree: label vector y is not properly recoded/binned: "+sum(y<=0));

Expand Down Expand Up @@ -230,6 +232,11 @@ computeImpurity = function(Matrix[Double] y2, Matrix[Double] I, String impurity)
score = 1 - rowSums(f^2); # sum(f*(1-f));
else if( impurity == "entropy" )
score = rowSums(-f * log(f));
else if( impurity == "rss" ) { # residual sum of squares
yhat = f %*% seq(1,ncol(f)); # yhat
res = outer(yhat, t(rowIndexMax(y2)), "-"); # yhat-y
score = rowSums((I * res)^2); # sum((yhat-y)^2)
}
else
stop("decisionTree: unsupported impurity measure: "+impurity);
}
Expand Down
2 changes: 1 addition & 1 deletion scripts/builtin/lmPredict.dml
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,5 @@ m_lmPredict = function(Matrix[Double] X, Matrix[Double] B,
yhat = X %*% B[1:ncol(X),] + intercept;

if( verbose )
lmPredictStats(yhat, ytest);
lmPredictStats(yhat, ytest, TRUE);
}
8 changes: 6 additions & 2 deletions scripts/builtin/lmPredictStats.dml
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,25 @@
# ------------------------------------------------------------------------------
# yhat column vector of predicted response values y
# ytest column vector of actual response values y
# lm indicator if used for linear regression model
# ------------------------------------------------------------------------------
#
# OUTPUT:
# ------------------------------------------------------------------------------
# R column vector holding avg_res, ss_avg_res, and R2
# ------------------------------------------------------------------------------

m_lmPredictStats = function(Matrix[Double] yhat, Matrix[Double] ytest)
m_lmPredictStats = function(Matrix[Double] yhat, Matrix[Double] ytest, Boolean lm)
return (Matrix[Double] R)
{
y_residual = ytest - yhat;
avg_res = sum(y_residual) / nrow(ytest);
ss_res = sum(y_residual^2);
ss_avg_res = ss_res - nrow(ytest) * avg_res^2;
R2 = 1 - ss_res / (sum(ytest^2) - nrow(ytest) * (sum(ytest)/nrow(ytest))^2);
if( lm )
R2 = 1 - ss_res / (sum(ytest^2) - nrow(ytest) * (sum(ytest)/nrow(ytest))^2);
else
R2 = sum((yhat - mean(ytest))^2) / sum((ytest - mean(ytest))^2)
print("\nAccuracy:" +
"\n--sum(ytest) = " + sum(ytest) +
"\n--sum(yhat) = " + sum(yhat) +
Expand Down
2 changes: 1 addition & 1 deletion scripts/builtin/randomForest.dml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
# candidates at tree nodes: m = ceil(num_features^max_features)
# max_values Parameter controlling the number of values per feature used
# as split candidates: nb = ceil(num_values^max_values)
# impurity Impurity measure: entropy, gini (default)
# impurity Impurity measure: entropy, gini (default), rss (regression)
# seed Fixed seed for randomization of samples and split candidates
# verbose Flag indicating verbose debug output
# ------------------------------------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion scripts/builtin/randomForestPredict.dml
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ m_randomForestPredict = function(Matrix[Double] X, Matrix[Double] y = matrix(0,0
if( classify )
print("Accuracy (%): " + (sum(yhat == y) / nrow(y) * 100));
else
lmPredictStats(yhat, y);
lmPredictStats(yhat, y, FALSE);
}

if(verbose) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,9 @@ public class BuiltinDecisionTreeRealDataTest extends AutomatedTestBase {
private final static String WINE_DATA = DATASET_DIR + "wine/winequality-red-white.csv";
private final static String WINE_TFSPEC = DATASET_DIR + "wine/tfspec.json";


@Override
public void setUp() {
for(int i=1; i<=2; i++)
for(int i=1; i<=3; i++)
addTestConfiguration(TEST_NAME+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"}));
}

Expand Down Expand Up @@ -86,9 +85,20 @@ public void testDecisionTreeWine_MaxV1() {

@Test
public void testRandomForestWine_MaxV1() {
//one tree with sample_frac=1 should be equivalent to decision tree
runDecisionTree(2, WINE_DATA, WINE_TFSPEC, 0.989, 2, 1.0, ExecType.CP);
}

@Test
public void testDecisionTreeWineReg_MaxV1() {
//for regression we compare R2 and use rss to optimize
runDecisionTree(3, WINE_DATA, WINE_TFSPEC, 0.369, 1, 1.0, ExecType.CP);
}

@Test
public void testRandomForestWineReg_MaxV1() {
//for regression we compare R2 and use rss to optimize
runDecisionTree(3, WINE_DATA, WINE_TFSPEC, 0.369, 2, 1.0, ExecType.CP);
}

private void runDecisionTree(int test, String data, String tfspec, double minAcc, int dt, double maxV, ExecType instType) {
Types.ExecMode platformOld = setExecMode(instType);
Expand Down
2 changes: 1 addition & 1 deletion src/test/resources/datasets/wine/tfspec.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@
{"id":8, "method":"equi-width", "numbins":10},
{"id":9, "method":"equi-width", "numbins":10},
{"id":10, "method":"equi-width", "numbins":10},
{"id":11, "method":"equi-width", "numbins":10},
{"id":11, "method":"equi-width", "numbins":50},
{"id":12, "method":"equi-width", "numbins":10},]
}
50 changes: 50 additions & 0 deletions src/test/scripts/functions/builtin/decisionTreeRealData3.dml
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

# Test script: train a regression decision tree or random forest with the
# new rss (residual sum of squares) impurity on the wine-quality dataset,
# and write the resulting R2 accuracy to a file.
#
# Inputs:
#   $1  CSV data file (frame, no header)
#   $2  transformencode specification (JSON string)
#   $3  model selector: 1 = single decision tree, k>1 = random forest with k-1 trees
#   $4  max_values parameter for split-candidate selection
#   $5  output file for the scalar accuracy (R2)

F = read($1, data_type="frame", format="csv", header=FALSE);
tfspec = read($2, data_type="scalar", value_type="string");

# column-type vector for decisionTree/randomForest over the 13 encoded columns;
# position 12 is marked 2 — presumably the continuous/regression column type,
# matching the label taken from the second-to-last column below (TODO confirm
# against the decisionTree ctypes documentation)
R = matrix("1 1 1 1 1 1 1 1 1 1 1 2 1", rows=1, cols=13)

# recode/bin all columns per tfspec; label Y is the second-to-last encoded
# column, the feature matrix X is all remaining columns
[X, meta] = transformencode(target=F, spec=tfspec);
Y = X[,ncol(X)-1];
X = cbind(X[,1:ncol(X)-2], X[,ncol(X)]);
X = replace(target=X, pattern=NaN, replacement=5); # 1 val

if( $3==1 ) {
# single decision tree trained with rss impurity for regression
M = decisionTree(X=X, y=Y, ctypes=R, max_features=1, max_values=$4,
impurity="rss", min_split=10, min_leaf=4, seed=7, verbose=TRUE);
yhat = decisionTreePredict(X=X, ctypes=R, M=M)
}
else {
# random forest with $3-1 trees; with $3==2 this yields one tree at
# sample_frac=1, which should match the single decision tree
sf = 1.0/($3-1);
M = randomForest(X=X, y=Y, ctypes=R, sample_frac=sf, num_trees=$3-1,
impurity="rss", max_features=1, max_values=$4,
min_split=10, min_leaf=4, seed=7, verbose=TRUE);
yhat = randomForestPredict(X=X, ctypes=R, M=M)
}

# decode the binned predictions back into the original value domain using the
# label column's bin metadata (meta[,12]);
# NOTE(review): jspec says numbins:10 while the updated tfspec uses 50 bins for
# id 11 — presumably decode takes bin boundaries from meta, not the spec; verify
jspec="{ids:true,bin:[{id:1,method:equi-width,numbins:10}]}"
yhat2 = as.matrix(transformdecode(target=yhat, spec=jspec, meta=meta[,12]));

# compare decoded predictions against the original (un-encoded) labels;
# lmPredictStats with lm=FALSE returns regression stats, R[3,] is the R2 score
R = lmPredictStats(yhat2, as.matrix(F[,ncol(F)-1]), FALSE)
acc = R[3,]
write(acc, $5);

0 comments on commit 4c97c1c

Please sign in to comment.