Skip to content

Commit

Permalink
[SYSTEMDS-189] Multiple imputation using chained equation (MICE)
Browse files Browse the repository at this point in the history
1. Main DML script (mice_linearReg.dml)
2. Java test file (BuiltinMiceLinearRegTest.java)
3. DML test script (mice_linearRegression.dml)

Closes #70.
  • Loading branch information
Shafaq-Siddiqi authored and mboehm7 committed Dec 7, 2019
1 parent ec557c1 commit 7a11d1b
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 0 deletions.
112 changes: 112 additions & 0 deletions scripts/builtin/mice_lm.dml
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
#-------------------------------------------------------------
#
# Copyright 2019 Graz University of Technology
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#-------------------------------------------------------------

# Assumptions
# 1. The data is continuous/numerical
# 2. The missing values are denoted by zeros

# Builtin function that implements multiple imputation using chained equations (MICE) based on column-wise linear regression
#
# INPUT PARAMETERS:
# ---------------------------------------------------------------------------------------------
# NAME TYPE DEFAULT MEANING
# ---------------------------------------------------------------------------------------------
# X Double --- matrix X of feature vectors (zeros denote missing values)
# iter Integer 3 Number of iterations for multiple imputations
# complete Integer 3 Index of the iteration (1..iter) whose complete dataset is returned as singleSet
# ---------------------------------------------------------------------------------------------


#Output(s)
# ---------------------------------------------------------------------------------------------
# NAME TYPE DEFAULT MEANING
# ---------------------------------------------------------------------------------------------
# dataset Double --- imputed dataset, aggregated over the imputation iterations
# singleSet Double --- A complete dataset generated through the specific iteration selected by 'complete'



# Multiple imputation using chained equations (MICE) with column-wise linear regression.
#   X        : numeric matrix in which zeros mark missing values
#   iter     : number of imputation rounds (>= 1)
#   complete : round (1..iter) whose imputed dataset is returned as singleSet
# Returns
#   dataset   : X plus the imputed values averaged over all 'iter' rounds
#   singleSet : X plus the imputed values of round 'complete'
m_mice_lm = function(Matrix[Double] X, Integer iter = 3, Integer complete = 3)
return(Matrix[Double] dataset, Matrix[Double] singleSet)
{
  # fail early with a clear message instead of an index-out-of-bounds error later on
  if( iter < 1 | complete < 1 | complete > iter ) {
    stop("mice_lm: 'iter' must be >= 1 and 'complete' must be in [1, iter]")
  }

  n = nrow(X)
  row = n*complete;
  col = ncol(X);

  # accumulators for the per-round results; the leading all-zero row is only a
  # placeholder for rbind and is stripped after the loop
  Result = matrix(0, rows = 1, cols = col)
  Mask_Result = matrix(0, rows = 1, cols = col)

  # storing the mask/address of missing values (1 = missing, 0 = observed)
  Mask = (X == 0);
  # filling the missing data with the column means
  # (colMeans is taken over all rows, i.e., the zero placeholders are included)
  X2 = X + (Mask * colMeans(X))

  # selector of observed cells, used to slice the training rows per column
  inverseMask = 1 - Mask;

  # NOTE(review): X2 is not updated with the new imputations between rounds, so
  # every round trains on the same mean-filled data - confirm this is intended.
  for(k in 1:iter)
  {
    # start every round from the plain 0/1 mask; the 1s of column i are
    # replaced by the predicted values inside the parfor
    Mask_Filled = Mask;

    parfor(i in 1:col)
    {
      # construct column selector (all columns except column i)
      sel = cbind(matrix(1,1,i-1), as.matrix(0), matrix(1,1,col-i));

      # prepare train data set X and Y: rows where column i is observed
      slice1 = removeEmpty(target = X2, margin = "rows", select = inverseMask[,i])
      # empty loop kept from the original; presumably forces a statement-block
      # boundary before the next removeEmpty - TODO confirm
      while(FALSE){}
      train_X = removeEmpty(target = slice1, margin = "cols", select = sel);
      train_Y = slice1[,i]

      # prepare score data set X for imputing Y: rows where column i is missing
      slice2 = removeEmpty(target = X2, margin = "rows", select = Mask[,i])
      while(FALSE){}
      test_X = removeEmpty(target = slice2, margin = "cols", select = sel);

      # learning a regression line
      beta = lm(X=train_X, y=train_Y, verbose=FALSE);

      # predicting missing values
      pred = lmpredict(X=test_X, w=beta)

      # imputing missing column values (assumes Mask_Filled being 0/1-matrix):
      # R holds the row positions of the missing cells, table() scatters the
      # predictions back to those positions in an n x 1 column
      # NOTE(review): behavior for a column without any missing value is not
      # evident from this script - verify against removeEmpty/table semantics.
      R = removeEmpty(target=Mask_Filled[,i] * seq(1,n), margin="rows");
      Mask_Filled[,i] = table(R, 1, pred, n, 1);
    }

    # binding results of multiple imputations
    Result = rbind(Result, X + Mask_Filled)
    Mask_Result = rbind(Mask_Result, Mask_Filled)
  }

  # drop the placeholder row; n*iter rows remain, n rows per round
  Result = Result[2:(n*iter+1), ]
  Mask_Result = Mask_Result[2:(n*iter+1), ]
  index = (((complete*n)-n)+1)

  # aggregating the results: sum the imputed values of every round exactly once
  # and average over the number of rounds
  Agg_Matrix = matrix(0, rows = n, cols = col)
  for(d in 1:iter) {
    Agg_Matrix = Agg_Matrix + Mask_Result[(((d-1)*n)+1):(n*d), ]
  }
  Agg_Matrix = (Agg_Matrix / iter)

  # return imputed data
  dataset = X + Agg_Matrix
  singleSet = Result[index:row, ]
}
1 change: 1 addition & 0 deletions src/main/java/org/tugraz/sysds/common/Builtins.java
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ public enum Builtins {
MAX_POOL_BACKWARD("max_pool_backward", false),
MEDIAN("median", false),
MOMENT("moment", "centralMoment", false),
MICE_LM("mice_lm", true),
NCOL("ncol", false),
NORMALIZE("normalize", true),
NROW("nrow", false),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Copyright 2019 Graz University of Technology
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.tugraz.sysds.test.functions.builtin;

import org.junit.Test;
import org.tugraz.sysds.common.Types;
import org.tugraz.sysds.lops.LopProperties;
import org.tugraz.sysds.test.AutomatedTestBase;
import org.tugraz.sysds.test.TestConfiguration;

/**
 * Runs the {@code mice_lm} test script on randomly generated input data.
 *
 * <p>The tests only execute the script; no output comparison is performed here.
 */
public class BuiltinMiceLinearRegTest extends AutomatedTestBase {
	private static final String TEST_NAME = "mice_lm";
	private static final String TEST_DIR = "functions/builtin/";
	private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinMiceLinearRegTest.class.getSimpleName() + "/";

	private static final int rows = 50;
	private static final int cols = 30;
	private static final int iter = 3;
	private static final int com = 2;

	// Sparsity values handed to getRandomMatrix; presumably the fraction of
	// non-zero cells, so the smaller value yields the sparser input - TODO confirm.
	private static final double SPARSITY_SPARSE = 0.3;
	private static final double SPARSITY_DENSE = 0.7;

	@Override
	public void setUp() {
		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"B"}));
	}

	@Test
	public void testMatrixSparseCP() {
		runMiceTest(SPARSITY_SPARSE, LopProperties.ExecType.CP);
	}

	@Test
	public void testMatrixDenseCP() {
		runMiceTest(SPARSITY_DENSE, LopProperties.ExecType.CP);
	}

	// Spark variants are commented out in the original; the reason is not stated in this file.
//	@Test
//	public void testMatrixSparseSpark() {
//		runMiceTest(SPARSITY_SPARSE, LopProperties.ExecType.SPARK);
//	}
//
//	@Test
//	public void testMatrixDenseSpark() {
//		runMiceTest(SPARSITY_DENSE, LopProperties.ExecType.SPARK);
//	}

	/**
	 * Generates a random input matrix, writes it as input "A" and runs the test script.
	 *
	 * @param sparsity sparsity argument passed to {@code getRandomMatrix}
	 * @param instType execution type to run the script with
	 */
	private void runMiceTest(double sparsity, LopProperties.ExecType instType) {
		Types.ExecMode platformOld = setExecMode(instType);
		try {
			loadTestConfiguration(getTestConfiguration(TEST_NAME));
			String HOME = SCRIPT_DIR + TEST_DIR;
			fullDMLScriptName = HOME + TEST_NAME + ".dml";
			programArgs = new String[]{"-nvargs", "X=" + input("A"), "iteration=" + iter, "com=" + com, "data=" + output("B")};

			//generate actual dataset
			double[][] A = getRandomMatrix(rows, cols, 0, 1, sparsity, 7);
			writeInputMatrixWithMTD("A", A, true);

			runTest(true, false, null, -1);
		}
		finally {
			// restore the execution mode that was active before this test
			rtplatform = platformOld;
		}
	}
}
21 changes: 21 additions & 0 deletions src/test/scripts/functions/builtin/mice_lm.dml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#-------------------------------------------------------------
#
# Copyright 2019 Graz University of Technology
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#-------------------------------------------------------------

# read the input matrix (path given via named argument X)
input = read($X)

# run the builtin; only the aggregated result is written out
[imputed, single] = mice_lm(X=input, iter=$iteration, complete=$com)

# persist the aggregated imputed dataset to the path given via named argument data
write(imputed, $data)

0 comments on commit 7a11d1b

Please sign in to comment.