-
Notifications
You must be signed in to change notification settings - Fork 459
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SYSTEMDS-189] Multiple imputation using chained equation (MICE)
1. Main DML script (mice_linearReg.dml) 2. Java test file (BuiltinMiceLinearRegTest.java) 3. DML test script (mice_linearRegression.dml) Closes #70.
- Loading branch information
1 parent
ec557c1
commit 7a11d1b
Showing
4 changed files
with
212 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
#------------------------------------------------------------- | ||
# | ||
# Copyright 2019 Graz University of Technology | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
#------------------------------------------------------------- | ||
|
||
# Assumptions | ||
# 1. The data is continous/numerical | ||
# 2. The missing values are denoted by zeros | ||
|
||
# Builtin function Implements binary-class SVM with squared slack variables | ||
# | ||
# INPUT PARAMETERS: | ||
# --------------------------------------------------------------------------------------------- | ||
# NAME TYPE DEFAULT MEANING | ||
# --------------------------------------------------------------------------------------------- | ||
# X Double --- matrix X of feature vectors | ||
# iter Integer 3 Number of iteration for multiple imputations | ||
# complete Integer 3 A complete dataset generated though a specific iteration | ||
# --------------------------------------------------------------------------------------------- | ||
|
||
|
||
#Output(s) | ||
# --------------------------------------------------------------------------------------------- | ||
# NAME TYPE DEFAULT MEANING | ||
# --------------------------------------------------------------------------------------------- | ||
# dataset Double --- imputed dataset | ||
# singleSet Double --- A complete dataset generated though a specific iteration | ||
|
||
|
||
|
||
m_mice_lm = function(Matrix[Double] X, Integer iter = 3, Integer complete = 3) | ||
return(Matrix[Double] dataset, Matrix[Double] singleSet) | ||
{ | ||
n = nrow(X) | ||
row = n*complete; | ||
col = ncol(X); | ||
Result = matrix(0, rows = 1, cols = col) | ||
Mask_Result = matrix(0, rows = 1, cols = col) | ||
|
||
# storing the mask/address of missing values | ||
Mask = (X == 0); | ||
# filling the missing data with their means | ||
X2 = X+(Mask*colMeans(X)) | ||
|
||
# slicing non-missing dataset for training columnwise linear regression | ||
inverseMask = 1 - Mask; | ||
|
||
for(k in 1:iter) | ||
{ | ||
Mask_Filled = Mask; | ||
|
||
parfor(i in 1:col) | ||
{ | ||
# construct column selector | ||
sel = cbind(matrix(1,1,i-1), as.matrix(0), matrix(1,1,col-i)); | ||
|
||
# prepare train data set X and Y | ||
slice1 = removeEmpty(target = X2, margin = "rows", select = inverseMask[,i]) | ||
while(FALSE){} | ||
train_X = removeEmpty(target = slice1, margin = "cols", select = sel); | ||
train_Y = slice1[,i] | ||
|
||
# prepare score data set X and Y for imputing Y | ||
slice2 = removeEmpty(target = X2, margin = "rows", select = Mask[,i]) | ||
while(FALSE){} | ||
test_X = removeEmpty(target = slice2, margin = "cols", select = sel); | ||
test_Y = slice2[,i] | ||
|
||
# learning a regression line | ||
beta = lm(X=train_X, y=train_Y, verbose=FALSE); | ||
|
||
# predicting missing values | ||
pred = lmpredict(X=test_X, w=beta) | ||
|
||
# imputing missing column values (assumes Mask_Filled being 0/1-matrix) | ||
R = removeEmpty(target=Mask_Filled[,i] * seq(1,n), margin="rows"); | ||
Mask_Filled[,i] = table(R, 1, pred, n, 1); | ||
} | ||
|
||
# binding results of multiple imputations | ||
Result = rbind(Result, X + Mask_Filled) | ||
Mask_Result = rbind(Mask_Result, Mask_Filled) | ||
Mask_Filled = Mask; | ||
} | ||
# return imputed dataset | ||
Result = Result[2: n*iter+1, ] | ||
Mask_Result = Mask_Result[2: n*iter+1, ] | ||
index = (((complete*n)-n)+1) | ||
|
||
# aggregating the results | ||
Agg_Matrix = Mask_Result[index:row, ] | ||
for(d in 1:(iter-1)) | ||
Agg_Matrix = Agg_Matrix + Mask_Result[(((d-1)*n)+1):(n*d),] | ||
Agg_Matrix =(Agg_Matrix/iter) | ||
|
||
# return imputed data | ||
dataset = X + Agg_Matrix | ||
singleSet = Result[index:row, ] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
78 changes: 78 additions & 0 deletions
78
src/test/java/org/tugraz/sysds/test/functions/builtin/BuiltinMiceLinearRegTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
/* | ||
* Copyright 2019 Graz University of Technology | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.tugraz.sysds.test.functions.builtin; | ||
|
||
import org.junit.Test; | ||
import org.tugraz.sysds.common.Types; | ||
import org.tugraz.sysds.lops.LopProperties; | ||
import org.tugraz.sysds.test.AutomatedTestBase; | ||
import org.tugraz.sysds.test.TestConfiguration; | ||
|
||
public class BuiltinMiceLinearRegTest extends AutomatedTestBase { | ||
private final static String TEST_NAME = "mice_lm"; | ||
private final static String TEST_DIR = "functions/builtin/"; | ||
private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinMiceLinearRegTest.class.getSimpleName() + "/"; | ||
|
||
private final static int rows = 50; | ||
private final static int cols = 30; | ||
private final static int iter = 3; | ||
private final static int com = 2; | ||
|
||
@Override | ||
public void setUp() { | ||
addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"B"})); | ||
} | ||
|
||
@Test | ||
public void testMatrixSparseCP() { | ||
runLmTest(0.7, LopProperties.ExecType.CP); | ||
} | ||
|
||
@Test | ||
public void testMatrixDenseCP() { | ||
runLmTest(0.3, LopProperties.ExecType.CP); | ||
} | ||
|
||
// @Test | ||
// public void testMatrixSparseSpark() { | ||
// runLmTest(0.7, LopProperties.ExecType.SPARK); | ||
// } | ||
// | ||
// @Test | ||
// public void testMatrixDenseSpark() { | ||
// runLmTest(0.3, LopProperties.ExecType.SPARK); | ||
// } | ||
|
||
private void runLmTest(double sparseVal, LopProperties.ExecType instType) { | ||
Types.ExecMode platformOld = setExecMode(instType); | ||
try { | ||
loadTestConfiguration(getTestConfiguration(TEST_NAME)); | ||
String HOME = SCRIPT_DIR + TEST_DIR; | ||
fullDMLScriptName = HOME + TEST_NAME + ".dml"; | ||
programArgs = new String[]{"-nvargs", "X=" + input("A"), "iteration=" + iter, "com=" + com, "data=" + output("B")}; | ||
|
||
//generate actual dataset | ||
double[][] A = getRandomMatrix(rows, cols, 0, 1, sparseVal, 7); | ||
writeInputMatrixWithMTD("A", A, true); | ||
|
||
runTest(true, false, null, -1); | ||
} | ||
finally { | ||
rtplatform = platformOld; | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
#------------------------------------------------------------- | ||
# | ||
# Copyright 2019 Graz University of Technology | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
#------------------------------------------------------------- | ||
|
||
X = read($X) | ||
[dataset, singleSet]= mice_lm(X=X, iter=$iteration, complete=$com) | ||
write(dataset, $data) |