From 17f46a60706332ae0ab2c0ec401db5f967f059e4 Mon Sep 17 00:00:00 2001 From: ywcb00 Date: Sun, 24 Jan 2021 16:40:42 +0100 Subject: [PATCH 1/7] feat(FEDInstructionUtils.java): add rewrite statement for BinaryFEDInstruction in checkAndReplaceSP() fix(BinaryMatrixMatrixFEDInstruction.java): change broadcast of mo2 to broadcast sliced feat(FedLogical): add tests for federated logical MatrixScalar and MatrixMatrix instructions fix(alsCG): add check for heavy hitter "fed_!=" - now supported for SPARK too :) --- .../instructions/fed/BinaryMatrixMatrixFEDInstruction.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java index 0ba19357980..ba35a01dd73 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java @@ -75,8 +75,8 @@ public void processInstruction(ExecutionContext ec) { else if(mo2.getNumRows() == 1 && mo2.getNumColumns() > 1) { //MV row vector FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2); fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2}, - new long[]{mo1.getFedMapping().getID(), fr1.getID()}); - FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1.getID()); + new long[]{mo1.getFedMapping().getID(), fr1[0].getID()}); + FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1[0].getID()); //execute federated instruction and cleanup intermediates mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3); } From f15052ad2125e58ed833a029d04f197956ebc2f0 Mon Sep 17 00:00:00 2001 From: ywcb00 Date: Tue, 26 Jan 2021 14:09:57 +0100 Subject: [PATCH 2/7] revert(BinaryMatrixMatrixFEDInstruction.java): revert changes of broadcasting sliced chore(fedLogical): ignore the MatrixMatrix tests - keep the MatrixScalar Tests --- .../fed/BinaryMatrixMatrixFEDInstruction.java | 4 ++-- .../federated/primitives/FederatedLogicalTest.java | 13 +++++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java index ba35a01dd73..0ba19357980 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java @@ -75,8 +75,8 @@ public void processInstruction(ExecutionContext ec) { else if(mo2.getNumRows() == 1 && mo2.getNumColumns() > 1) { //MV row vector FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2); fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2}, - new long[]{mo1.getFedMapping().getID(), fr1[0].getID()}); - FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1[0].getID()); + new long[]{mo1.getFedMapping().getID(), fr1.getID()}); + FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1.getID()); //execute federated instruction and cleanup intermediates mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3); } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLogicalTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLogicalTest.java index 53dfb2e5a42..424acb473c1 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLogicalTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLogicalTest.java @@ -28,6 +28,7 @@ import org.apache.sysds.test.TestUtils; import org.junit.Assert; import org.junit.BeforeClass; +import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -151,61 +152,73 @@ public void federatedLogicalScalarLessEqualsSpark() { //---------------------------MATRIX MATRIX-------------------------- @Test + @Ignore public void federatedLogicalMatrixGreaterSingleNode() { federatedLogicalTest(MATRIX_TEST_NAME, Type.GREATER, ExecMode.SINGLE_NODE); } @Test + @Ignore public void federatedLogicalMatrixGreaterSpark() { federatedLogicalTest(MATRIX_TEST_NAME, Type.GREATER, ExecMode.SPARK); } @Test + @Ignore public void federatedLogicalMatrixLessSingleNode() { federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS, ExecMode.SINGLE_NODE); } @Test + @Ignore public void federatedLogicalMatrixLessSpark() { federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS, ExecMode.SPARK); } @Test + @Ignore public void federatedLogicalMatrixEqualsSingleNode() { federatedLogicalTest(MATRIX_TEST_NAME, Type.EQUALS, ExecMode.SINGLE_NODE); } @Test + @Ignore public void federatedLogicalMatrixEqualsSpark() { federatedLogicalTest(MATRIX_TEST_NAME, Type.EQUALS, ExecMode.SPARK); } @Test + @Ignore public void federatedLogicalMatrixNotEqualsSingleNode() { federatedLogicalTest(MATRIX_TEST_NAME, Type.NOT_EQUALS, ExecMode.SINGLE_NODE); } @Test + @Ignore public void federatedLogicalMatrixNotEqualsSpark() { federatedLogicalTest(MATRIX_TEST_NAME, Type.NOT_EQUALS, ExecMode.SPARK); } @Test + @Ignore public void federatedLogicalMatrixGreaterEqualsSingleNode() { federatedLogicalTest(MATRIX_TEST_NAME, Type.GREATER_EQUALS, ExecMode.SINGLE_NODE); } @Test + @Ignore public void federatedLogicalMatrixGreaterEqualsSpark() { federatedLogicalTest(MATRIX_TEST_NAME, Type.GREATER_EQUALS, ExecMode.SPARK); } @Test + @Ignore public void federatedLogicalMatrixLessEqualsSingleNode() { federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS_EQUALS, ExecMode.SINGLE_NODE); } @Test + @Ignore public void federatedLogicalMatrixLessEqualsSpark() { federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS_EQUALS, ExecMode.SPARK); } From 0889676750e976facce252676c41ba620be1ff0e Mon Sep 17 00:00:00 2001 From: ywcb00 Date: Wed, 27 Jan 2021 18:01:35 +0100 Subject: [PATCH 3/7] fix(BinaryMatrixMatrixFEDInstruction.java): broadcast mo2 sliced if it is a matrix and row partitioned distinguish the case where mo2 is a column vector --> don't broadcast slice even if it is row partioned chore(FedLogical): remove ignores of MatrixMatrix tests --- .../federated/primitives/FederatedLogicalTest.java | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLogicalTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLogicalTest.java index 424acb473c1..53dfb2e5a42 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLogicalTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLogicalTest.java @@ -28,7 +28,6 @@ import org.apache.sysds.test.TestUtils; import org.junit.Assert; import org.junit.BeforeClass; -import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -152,73 +151,61 @@ public void federatedLogicalScalarLessEqualsSpark() { //---------------------------MATRIX MATRIX-------------------------- @Test - @Ignore public void federatedLogicalMatrixGreaterSingleNode() { federatedLogicalTest(MATRIX_TEST_NAME, Type.GREATER, ExecMode.SINGLE_NODE); } @Test - @Ignore public void federatedLogicalMatrixGreaterSpark() { federatedLogicalTest(MATRIX_TEST_NAME, Type.GREATER, ExecMode.SPARK); } @Test - @Ignore public void federatedLogicalMatrixLessSingleNode() { federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS, ExecMode.SINGLE_NODE); } @Test - @Ignore public void federatedLogicalMatrixLessSpark() { federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS, ExecMode.SPARK); } @Test - @Ignore public void federatedLogicalMatrixEqualsSingleNode() { federatedLogicalTest(MATRIX_TEST_NAME, Type.EQUALS, ExecMode.SINGLE_NODE); } @Test - @Ignore public void federatedLogicalMatrixEqualsSpark() { federatedLogicalTest(MATRIX_TEST_NAME, Type.EQUALS, ExecMode.SPARK); } @Test - @Ignore public void federatedLogicalMatrixNotEqualsSingleNode() { federatedLogicalTest(MATRIX_TEST_NAME, Type.NOT_EQUALS, ExecMode.SINGLE_NODE); } @Test - @Ignore public void federatedLogicalMatrixNotEqualsSpark() { federatedLogicalTest(MATRIX_TEST_NAME, Type.NOT_EQUALS, ExecMode.SPARK); } @Test - @Ignore public void federatedLogicalMatrixGreaterEqualsSingleNode() { federatedLogicalTest(MATRIX_TEST_NAME, Type.GREATER_EQUALS, ExecMode.SINGLE_NODE); } @Test - @Ignore public void federatedLogicalMatrixGreaterEqualsSpark() { federatedLogicalTest(MATRIX_TEST_NAME, Type.GREATER_EQUALS, ExecMode.SPARK); } @Test - @Ignore public void federatedLogicalMatrixLessEqualsSingleNode() { federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS_EQUALS, ExecMode.SINGLE_NODE); } @Test - @Ignore public void federatedLogicalMatrixLessEqualsSpark() { federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS_EQUALS, ExecMode.SPARK); } From 4d35bd12a1332dd02406fab8224186dae6ac8c4d Mon Sep 17 00:00:00 2001 From: ywcb00 Date: Tue, 26 Jan 2021 15:40:01 +0100 Subject: [PATCH 4/7] test(FedPNMF): add junit java test and dml test scripts --- .../algorithms/FederatedPNMFTest.java | 156 ++++++++++++++++++ .../functions/federated/FederatedPNMFTest.dml | 32 ++++ .../federated/FederatedPNMFTestReference.dml | 31 ++++ 3 files changed, 219 insertions(+) create mode 100644 src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPNMFTest.java create mode 100644 src/test/scripts/functions/federated/FederatedPNMFTest.dml create mode 100644 src/test/scripts/functions/federated/FederatedPNMFTestReference.dml diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPNMFTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPNMFTest.java new file mode 100644 index 00000000000..951e6c8f714 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPNMFTest.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.federated.algorithms; + +import org.apache.sysds.common.Types.ExecMode; +import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.util.HDFSTool; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; + +@RunWith(value = Parameterized.class) +@net.jcip.annotations.NotThreadSafe +public class FederatedPNMFTest extends AutomatedTestBase +{ + private final static String TEST_NAME = "FederatedPNMFTest"; + private final static String TEST_DIR = "functions/federated/"; + private final static String TEST_CLASS_DIR = TEST_DIR + FederatedPNMFTest.class.getSimpleName() + "/"; + + private final static String OUTPUT_NAME = "Z"; + private final static double TOLERANCE = 0.2; + private final static int blocksize = 1024; + + @Parameterized.Parameter() + public int rows; + @Parameterized.Parameter(1) + public int cols; + @Parameterized.Parameter(2) + public int rank; + @Parameterized.Parameter(3) + public int max_iter; + @Parameterized.Parameter(4) + public double sparsity; + + @Override + public void setUp() { + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{OUTPUT_NAME})); + } + + @Parameterized.Parameters + public static Collection data() { + // rows must be even + return Arrays.asList(new Object[][] { + // {rows, cols, rank, max_iter, sparsity} + {1000, 750, 420, 10, 1} + }); + } + + @BeforeClass + public static void init() { + TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR); + } + + @Test + public void federatedPNMFSingleNode() { + federatedPNMF(TEST_NAME, ExecMode.SINGLE_NODE); + } + + @Test + public void federatedPNMFSpark() { + federatedPNMF(TEST_NAME, ExecMode.SPARK); + } + +// ----------------------------------------------------------------------------- + + public void federatedPNMF(String testname, ExecMode execMode) + { + // store the previous platform config to restore it after the test + ExecMode platform_old = setExecMode(execMode); + + getAndLoadTestConfiguration(testname); + String HOME = SCRIPT_DIR + TEST_DIR; + + int fed_rows = rows / 2; + int fed_cols = cols; + + // generate dataset + // matrix handled by two federated workers + double[][] X1 = getRandomMatrix(fed_rows, fed_cols, 1, 2, 1, 13); + double[][] X2 = getRandomMatrix(fed_rows, fed_cols, 1, 2, 1, 2); + + writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols)); + writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols)); + + // empty script name because we don't execute any script, just start the worker + fullDMLScriptName = ""; + int port1 = getRandomAvailablePort(); + int port2 = getRandomAvailablePort(); + Thread thread1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); + Thread thread2 = startLocalFedWorkerThread(port2); + + getAndLoadTestConfiguration(testname); + + // Run reference dml script with normal matrix + fullDMLScriptName = HOME + testname + "Reference.dml"; + programArgs = new String[] {"-stats", "-nvargs", + "in_X1=" + input("X1"), "in_X2=" + input("X2"), "in_rank=" + Integer.toString(rank), "in_max_iter=" + Integer.toString(max_iter), + "out_Z=" + expected(OUTPUT_NAME)}; + runTest(true, false, null, -1); + + // Run actual dml script with federated matrix + fullDMLScriptName = HOME + testname + ".dml"; + programArgs = new String[] {"-stats", "-nvargs", + "in_X1=" + TestUtils.federatedAddress(port1, input("X1")), + "in_X2=" + TestUtils.federatedAddress(port2, input("X2")), + "in_rank=" + Integer.toString(rank), + "in_max_iter=" + Integer.toString(max_iter), + "rows=" + fed_rows, "cols=" + fed_cols, + "out_Z=" + output(OUTPUT_NAME)}; + runTest(true, false, null, -1); + + // compare the results via files + HashMap refResults = readDMLMatrixFromExpectedDir(OUTPUT_NAME); + HashMap fedResults = readDMLMatrixFromOutputDir(OUTPUT_NAME); + TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed", "Ref"); + + TestUtils.shutdownThreads(thread1, thread2); + + // check for federated operations + Assert.assertTrue(heavyHittersContainsString("fed_wcemm")); + Assert.assertTrue(heavyHittersContainsString("fed_wdivmm")); + Assert.assertTrue(heavyHittersContainsString("fed_fedinit")); + + // check that federated input files are still existing + Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1"))); + Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2"))); + resetExecMode(platform_old); + } +} diff --git a/src/test/scripts/functions/federated/FederatedPNMFTest.dml b/src/test/scripts/functions/federated/FederatedPNMFTest.dml new file mode 100644 index 00000000000..e8b01c93ac1 --- /dev/null +++ b/src/test/scripts/functions/federated/FederatedPNMFTest.dml @@ -0,0 +1,32 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = federated(addresses=list($in_X1, $in_X2), + ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2, $cols))); + +rank = $in_rank; +max_iter = $in_max_iter; + +[W, H] = pnmf(X = X, rnk = rank, maxi = max_iter); + +Z = W %*% H; + +write(Z, $out_Z); diff --git a/src/test/scripts/functions/federated/FederatedPNMFTestReference.dml b/src/test/scripts/functions/federated/FederatedPNMFTestReference.dml new file mode 100644 index 00000000000..b501cf924c4 --- /dev/null +++ b/src/test/scripts/functions/federated/FederatedPNMFTestReference.dml @@ -0,0 +1,31 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = rbind(read($in_X1), read($in_X2)); + +rank = $in_rank; +max_iter = $in_max_iter; + +[W, H] = pnmf(X = X, rnk = rank, maxi = max_iter); + +Z = W %*% H; + +write(Z, $out_Z); From 3faf002b89fb35d26971d89063ac9ea8494dfc3c Mon Sep 17 00:00:00 2001 From: ywcb00 Date: Sun, 31 Jan 2021 11:00:27 +0100 Subject: [PATCH 5/7] chore(Quaternary**FEDInstruction): add check for row partitioned federated data --- .../fed/QuaternaryWCeMMFEDInstruction.java | 11 ++++++----- .../fed/QuaternaryWSLossFEDInstruction.java | 3 ++- .../fed/QuaternaryWSigmoidFEDInstruction.java | 3 ++- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWCeMMFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWCeMMFEDInstruction.java index 8566b39572c..cca7350e9c5 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWCeMMFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWCeMMFEDInstruction.java @@ -29,6 +29,7 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest; import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; import org.apache.sysds.runtime.controlprogram.federated.FederationMap; +import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType; import org.apache.sysds.runtime.controlprogram.federated.FederationUtils; import org.apache.sysds.runtime.instructions.cp.CPOperand; import org.apache.sysds.runtime.instructions.cp.DoubleObject; @@ -59,17 +60,17 @@ public void processInstruction(ExecutionContext ec) MatrixObject U = ec.getMatrixObject(input2); MatrixObject V = ec.getMatrixObject(input3); ScalarObject eps = null; - + if(qop.hasFourInputs()) { eps = (_input4.getDataType() == DataType.SCALAR) ? ec.getScalarInput(_input4) : new DoubleObject(ec.getMatrixInput(_input4.getName()).quickGetValue(0, 0)); } - if(!(X.isFederated() && !U.isFederated() && !V.isFederated())) + if(!(X.isFederated(FType.ROW) && !U.isFederated() && !V.isFederated())) throw new DMLRuntimeException("Unsupported federated inputs (X, U, V) = (" +X.isFederated()+", "+U.isFederated()+", "+V.isFederated()+")"); - + FederationMap fedMap = X.getFedMapping(); FederatedRequest[] fr1 = fedMap.broadcastSliced(U, false); FederatedRequest fr2 = fedMap.broadcast(V); @@ -90,7 +91,7 @@ public void processInstruction(ExecutionContext ec) new CPOperand[]{input1, input2, input3}, new long[]{fedMap.getID(), fr1[0].getID(), fr2.getID()}); } - + FederatedRequest frGet = new FederatedRequest(RequestType.GET_VAR, frComp.getID()); FederatedRequest frClean1 = fedMap.cleanup(getTID(), frComp.getID()); FederatedRequest frClean2 = fedMap.cleanup(getTID(), fr1[0].getID()); @@ -108,7 +109,7 @@ public void processInstruction(ExecutionContext ec) response = fedMap.execute(getTID(), true, fr1, fr2, frComp, frGet, frClean1, frClean2, frClean3); } - + //aggregate partial results from federated responses AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+"); ec.setVariable(output.getName(), FederationUtils.aggScalar(aop, response)); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSLossFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSLossFEDInstruction.java index 664fbdcd5d8..090a329c433 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSLossFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSLossFEDInstruction.java @@ -25,6 +25,7 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType; import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; import org.apache.sysds.runtime.controlprogram.federated.FederationMap; +import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType; import org.apache.sysds.runtime.controlprogram.federated.FederationUtils; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.instructions.cp.CPOperand; @@ -69,7 +70,7 @@ public void processInstruction(ExecutionContext ec) { W = ec.getMatrixObject(_input4); } - if(!(X.isFederated() && !U.isFederated() && !V.isFederated() && (W == null || !W.isFederated()))) + if(!(X.isFederated(FType.ROW) && !U.isFederated() && !V.isFederated() && (W == null || !W.isFederated()))) throw new DMLRuntimeException("Unsupported federated inputs (X, U, V, W) = (" + X.isFederated() + ", " + U.isFederated() + ", " + V.isFederated() + ", " + (W != null ? W.isFederated() : "none") + ")"); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSigmoidFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSigmoidFEDInstruction.java index 9884e3bec73..e7fa5ad63b8 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSigmoidFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSigmoidFEDInstruction.java @@ -28,6 +28,7 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType; import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; import org.apache.sysds.runtime.controlprogram.federated.FederationMap; +import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType; import org.apache.sysds.runtime.controlprogram.federated.FederationUtils; import org.apache.sysds.runtime.instructions.cp.CPOperand; import org.apache.sysds.runtime.matrix.operators.Operator; @@ -58,7 +59,7 @@ public void processInstruction(ExecutionContext ec) { MatrixObject U = ec.getMatrixObject(input2); MatrixObject V = ec.getMatrixObject(input3); - if(!(X.isFederated() && !U.isFederated() && !V.isFederated())) + if(!(X.isFederated(FType.ROW) && !U.isFederated() && !V.isFederated())) throw new DMLRuntimeException("Unsupported federated inputs (X, U, V) = (" + X.isFederated() + ", " + U.isFederated() + ", " + V.isFederated() + ")"); From 4b447def5610b0501979fdb0b3b51b81c5b89e47 Mon Sep 17 00:00:00 2001 From: ywcb00 Date: Tue, 2 Feb 2021 17:00:18 +0100 Subject: [PATCH 6/7] refactor(Quaternary**FEDInstruction.java): move data check if over the code and create else for error chore(FedPNMF): use parameter sparsity for generating random matrices remove parameter test_name (==> only one test name for pnmf, not needed as parameter here) --- .../fed/QuaternaryWCeMMFEDInstruction.java | 80 ++++++----- .../fed/QuaternaryWDivMMFEDInstruction.java | 135 +++++++++--------- .../fed/QuaternaryWSLossFEDInstruction.java | 80 ++++++----- .../fed/QuaternaryWSigmoidFEDInstruction.java | 45 +++--- .../fed/QuaternaryWUMMFEDInstruction.java | 42 +++--- .../algorithms/FederatedPNMFTest.java | 18 +-- 6 files changed, 205 insertions(+), 195 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWCeMMFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWCeMMFEDInstruction.java index cca7350e9c5..3603929288b 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWCeMMFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWCeMMFEDInstruction.java @@ -67,51 +67,53 @@ public void processInstruction(ExecutionContext ec) new DoubleObject(ec.getMatrixInput(_input4.getName()).quickGetValue(0, 0)); } - if(!(X.isFederated(FType.ROW) && !U.isFederated() && !V.isFederated())) - throw new DMLRuntimeException("Unsupported federated inputs (X, U, V) = (" - +X.isFederated()+", "+U.isFederated()+", "+V.isFederated()+")"); + if(X.isFederated(FType.ROW) && !U.isFederated() && !V.isFederated()) { + FederationMap fedMap = X.getFedMapping(); + FederatedRequest[] fr1 = fedMap.broadcastSliced(U, false); + FederatedRequest fr2 = fedMap.broadcast(V); + FederatedRequest fr3 = null; + FederatedRequest frComp = null; - FederationMap fedMap = X.getFedMapping(); - FederatedRequest[] fr1 = fedMap.broadcastSliced(U, false); - FederatedRequest fr2 = fedMap.broadcast(V); - FederatedRequest fr3 = null; - FederatedRequest frComp = null; + // broadcast the scalar epsilon if there are four inputs + if(eps != null) { + fr3 = fedMap.broadcast(eps); + // change the is_literal flag from true to false because when broadcasted it is no literal anymore + instString = instString.replace("true", "false"); + frComp = FederationUtils.callInstruction(instString, output, + new CPOperand[]{input1, input2, input3, _input4}, + new long[]{fedMap.getID(), fr1[0].getID(), fr2.getID(), fr3.getID()}); + } + else { + frComp = FederationUtils.callInstruction(instString, output, + new CPOperand[]{input1, input2, input3}, + new long[]{fedMap.getID(), fr1[0].getID(), fr2.getID()}); + } - // broadcast the scalar epsilon if there are four inputs - if(eps != null) { - fr3 = fedMap.broadcast(eps); - // change the is_literal flag from true to false because when broadcasted it is no literal anymore - instString = instString.replace("true", "false"); - frComp = FederationUtils.callInstruction(instString, output, - new CPOperand[]{input1, input2, input3, _input4}, - new long[]{fedMap.getID(), fr1[0].getID(), fr2.getID(), fr3.getID()}); - } - else { - frComp = FederationUtils.callInstruction(instString, output, - new CPOperand[]{input1, input2, input3}, - new long[]{fedMap.getID(), fr1[0].getID(), fr2.getID()}); - } + FederatedRequest frGet = new FederatedRequest(RequestType.GET_VAR, frComp.getID()); + FederatedRequest frClean1 = fedMap.cleanup(getTID(), frComp.getID()); + FederatedRequest frClean2 = fedMap.cleanup(getTID(), fr1[0].getID()); + FederatedRequest frClean3 = fedMap.cleanup(getTID(), fr2.getID()); - FederatedRequest frGet = new FederatedRequest(RequestType.GET_VAR, frComp.getID()); - FederatedRequest frClean1 = fedMap.cleanup(getTID(), frComp.getID()); - FederatedRequest frClean2 = fedMap.cleanup(getTID(), fr1[0].getID()); - FederatedRequest frClean3 = fedMap.cleanup(getTID(), fr2.getID()); + Future[] response; + if(fr3 != null) { + FederatedRequest frClean4 = fedMap.cleanup(getTID(), fr3.getID()); + // execute federated instructions + response = fedMap.execute(getTID(), true, fr1, fr2, fr3, + frComp, frGet, frClean1, frClean2, frClean3, frClean4); + } + else { + // execute federated instructions + response = fedMap.execute(getTID(), true, fr1, fr2, + frComp, frGet, frClean1, frClean2, frClean3); + } - Future[] response; - if(fr3 != null) { - FederatedRequest frClean4 = fedMap.cleanup(getTID(), fr3.getID()); - // execute federated instructions - response = fedMap.execute(getTID(), true, fr1, fr2, fr3, - frComp, frGet, frClean1, frClean2, frClean3, frClean4); + //aggregate partial results from federated responses + AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+"); + ec.setVariable(output.getName(), FederationUtils.aggScalar(aop, response)); } else { - // execute federated instructions - response = fedMap.execute(getTID(), true, fr1, fr2, - frComp, frGet, frClean1, frClean2, frClean3); + throw new DMLRuntimeException("Unsupported federated inputs (X, U, V) = (" + +X.isFederated()+", "+U.isFederated()+", "+V.isFederated()+")"); } - - //aggregate partial results from federated responses - AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+"); - ec.setVariable(output.getName(), FederationUtils.aggScalar(aop, response)); } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWDivMMFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWDivMMFEDInstruction.java index 5ba2b594dcc..29a785bac73 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWDivMMFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWDivMMFEDInstruction.java @@ -86,79 +86,82 @@ public void processInstruction(ExecutionContext ec) } } - if(!(X.isFederated(FType.ROW) && !U.isFederated() && !V.isFederated())) - throw new DMLRuntimeException("Unsupported federated inputs (X, U, V) = (" - +X.isFederated()+", "+U.isFederated()+", "+V.isFederated() + ")"); - - FederationMap fedMap = X.getFedMapping(); - FederatedRequest[] frInit1 = fedMap.broadcastSliced(U, false); - FederatedRequest frInit2 = fedMap.broadcast(V); + if(X.isFederated(FType.ROW) && !U.isFederated() && !V.isFederated()) { + FederationMap fedMap = X.getFedMapping(); + FederatedRequest[] frInit1 = fedMap.broadcastSliced(U, false); + FederatedRequest frInit2 = fedMap.broadcast(V); - FederatedRequest frInit3 = null; - FederatedRequest frInit3Arr[] = null; - FederatedRequest frCompute1 = null; - // broadcast scalar epsilon if there are four inputs - if(eps != null) { - frInit3 = fedMap.broadcast(eps); - // change the is_literal flag from true to false because when broadcasted it is no literal anymore - instString = instString.replace("true", "false"); - frCompute1 = FederationUtils.callInstruction(instString, output, - new CPOperand[]{input1, input2, input3, _input4}, - new long[]{fedMap.getID(), frInit1[0].getID(), frInit2.getID(), frInit3.getID()}); - } - else if(MX != null) { - frInit3Arr = fedMap.broadcastSliced(MX, false); - frCompute1 = FederationUtils.callInstruction(instString, output, - new CPOperand[]{input1, input2, input3, _input4}, - new long[]{fedMap.getID(), frInit1[0].getID(), frInit2.getID(), frInit3Arr[0].getID()}); - } - else { - frCompute1 = FederationUtils.callInstruction(instString, output, - new CPOperand[]{input1, input2, input3}, - new long[]{fedMap.getID(), frInit1[0].getID(), frInit2.getID()}); - } + FederatedRequest frInit3 = null; + FederatedRequest frInit3Arr[] = null; + FederatedRequest frCompute1 = null; + // broadcast scalar epsilon if there are four inputs + if(eps != null) { + frInit3 = fedMap.broadcast(eps); + // change the is_literal flag from true to false because when broadcasted it is no literal anymore + instString = instString.replace("true", "false"); + frCompute1 = FederationUtils.callInstruction(instString, output, + new CPOperand[]{input1, input2, input3, _input4}, + new long[]{fedMap.getID(), frInit1[0].getID(), frInit2.getID(), frInit3.getID()}); + } + else if(MX != null) { + frInit3Arr = fedMap.broadcastSliced(MX, false); + frCompute1 = FederationUtils.callInstruction(instString, output, + new CPOperand[]{input1, input2, input3, _input4}, + new long[]{fedMap.getID(), frInit1[0].getID(), frInit2.getID(), frInit3Arr[0].getID()}); + } + else { + frCompute1 = FederationUtils.callInstruction(instString, output, + new CPOperand[]{input1, input2, input3}, + new long[]{fedMap.getID(), frInit1[0].getID(), frInit2.getID()}); + } - // get partial results from federated workers - FederatedRequest frGet1 = new FederatedRequest(RequestType.GET_VAR, frCompute1.getID()); + // get partial results from federated workers + FederatedRequest frGet1 = new FederatedRequest(RequestType.GET_VAR, frCompute1.getID()); - FederatedRequest frCleanup1 = fedMap.cleanup(getTID(), frCompute1.getID()); - FederatedRequest frCleanup2 = fedMap.cleanup(getTID(), frInit1[0].getID()); - FederatedRequest frCleanup3 = fedMap.cleanup(getTID(), frInit2.getID()); + FederatedRequest frCleanup1 = fedMap.cleanup(getTID(), frCompute1.getID()); + FederatedRequest frCleanup2 = fedMap.cleanup(getTID(), frInit1[0].getID()); + FederatedRequest frCleanup3 = fedMap.cleanup(getTID(), frInit2.getID()); - // execute federated instructions - Future[] response; - if(frInit3 != null) { - FederatedRequest frCleanup4 = fedMap.cleanup(getTID(), frInit3.getID()); - response = fedMap.execute(getTID(), true, - frInit1, frInit2, frInit3, - frCompute1, frGet1, - frCleanup1, frCleanup2, frCleanup3, frCleanup4); - } - else if(frInit3Arr != null) { - FederatedRequest frCleanup4 = fedMap.cleanup(getTID(), frInit3Arr[0].getID()); - fedMap.execute(getTID(), true, frInit1, frInit2); - response = fedMap.execute(getTID(), true, frInit3Arr, - frCompute1, frGet1, - frCleanup1, frCleanup2, frCleanup3, frCleanup4); - } - else { - response = fedMap.execute(getTID(), true, - frInit1, frInit2, - frCompute1, frGet1, - frCleanup1, frCleanup2, frCleanup3); - } + // execute federated instructions + Future[] response; + if(frInit3 != null) { + FederatedRequest frCleanup4 = fedMap.cleanup(getTID(), frInit3.getID()); + response = fedMap.execute(getTID(), true, + frInit1, frInit2, frInit3, + frCompute1, frGet1, + frCleanup1, frCleanup2, frCleanup3, frCleanup4); + } + else if(frInit3Arr != null) { + FederatedRequest frCleanup4 = fedMap.cleanup(getTID(), frInit3Arr[0].getID()); + fedMap.execute(getTID(), true, frInit1, frInit2); + response = fedMap.execute(getTID(), true, frInit3Arr, + frCompute1, frGet1, + frCleanup1, frCleanup2, frCleanup3, frCleanup4); + } + else { + response = fedMap.execute(getTID(), true, + frInit1, frInit2, + frCompute1, frGet1, + frCleanup1, frCleanup2, frCleanup3); + } - if(wdivmm_type.isLeft()) { - // aggregate partial results from federated responses - AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+"); - ec.setMatrixOutput(output.getName(), FederationUtils.aggMatrix(aop, response, fedMap)); - } - else if(wdivmm_type.isRight() || wdivmm_type.isBasic()) { - // bind partial results from federated responses - ec.setMatrixOutput(output.getName(), FederationUtils.bind(response, false)); + if(wdivmm_type.isLeft()) { + // aggregate partial results from federated responses + AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+"); + ec.setMatrixOutput(output.getName(), FederationUtils.aggMatrix(aop, response, fedMap)); + } + else if(wdivmm_type.isRight() || wdivmm_type.isBasic()) { + // bind partial results from federated responses + ec.setMatrixOutput(output.getName(), FederationUtils.bind(response, false)); + } + else { + throw new DMLRuntimeException("Federated WDivMM only supported for BASIC, LEFT or RIGHT variants."); + } } else { - throw new DMLRuntimeException("Federated WDivMM only supported for BASIC, LEFT or RIGHT variants."); + throw new DMLRuntimeException("Unsupported federated inputs (X, U, V) = (" + +X.isFederated()+", "+U.isFederated()+", "+V.isFederated() + ")"); } } } + diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSLossFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSLossFEDInstruction.java index 090a329c433..70ba16c83fc 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSLossFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSLossFEDInstruction.java @@ -70,51 +70,53 @@ public void processInstruction(ExecutionContext ec) { W = ec.getMatrixObject(_input4); } - if(!(X.isFederated(FType.ROW) && !U.isFederated() && !V.isFederated() && (W == null || !W.isFederated()))) - throw new DMLRuntimeException("Unsupported federated inputs (X, U, V, W) = (" + X.isFederated() + ", " - + U.isFederated() + ", " + V.isFederated() + ", " + (W != null ? W.isFederated() : "none") + ")"); + if(X.isFederated(FType.ROW) && !U.isFederated() && !V.isFederated() && (W == null || !W.isFederated())) { + FederationMap fedMap = X.getFedMapping(); + FederatedRequest[] frInit1 = fedMap.broadcastSliced(U, false); + FederatedRequest frInit2 = fedMap.broadcast(V); - FederationMap fedMap = X.getFedMapping(); - FederatedRequest[] frInit1 = fedMap.broadcastSliced(U, false); - FederatedRequest frInit2 = fedMap.broadcast(V); + FederatedRequest[] frInit3 = null; + FederatedRequest frCompute1 = null; + if(W != null) { + frInit3 = fedMap.broadcastSliced(W, false); + frCompute1 = FederationUtils.callInstruction(instString, + output, + new CPOperand[] {input1, input2, input3, _input4}, + new long[] {fedMap.getID(), frInit1[0].getID(), frInit2.getID(), frInit3[0].getID()}); + } + else { + frCompute1 = FederationUtils.callInstruction(instString, + output, + new CPOperand[] {input1, input2, input3}, + new long[] {fedMap.getID(), frInit1[0].getID(), frInit2.getID()}); + } - FederatedRequest[] frInit3 = null; - FederatedRequest frCompute1 = null; - if(W != null) { - frInit3 = fedMap.broadcastSliced(W, false); - frCompute1 = FederationUtils.callInstruction(instString, - output, - new CPOperand[] {input1, input2, input3, _input4}, - new long[] {fedMap.getID(), frInit1[0].getID(), frInit2.getID(), frInit3[0].getID()}); - } - else { - frCompute1 = FederationUtils.callInstruction(instString, - output, - new CPOperand[] {input1, input2, input3}, - new long[] {fedMap.getID(), frInit1[0].getID(), frInit2.getID()}); - } + FederatedRequest frGet1 = new FederatedRequest(RequestType.GET_VAR, frCompute1.getID()); + FederatedRequest frCleanup1 = fedMap.cleanup(getTID(), frCompute1.getID()); + FederatedRequest frCleanup2 = fedMap.cleanup(getTID(), frInit1[0].getID()); + FederatedRequest frCleanup3 = fedMap.cleanup(getTID(), frInit2.getID()); - FederatedRequest frGet1 = new FederatedRequest(RequestType.GET_VAR, frCompute1.getID()); - FederatedRequest frCleanup1 = fedMap.cleanup(getTID(), frCompute1.getID()); - FederatedRequest frCleanup2 = fedMap.cleanup(getTID(), frInit1[0].getID()); - FederatedRequest frCleanup3 = fedMap.cleanup(getTID(), frInit2.getID()); + Future[] response; + if(frInit3 != null) { + FederatedRequest frCleanup4 = fedMap.cleanup(getTID(), frInit3[0].getID()); + // execute federated instructions + fedMap.execute(getTID(), true, frInit1, frInit2); + response = fedMap + .execute(getTID(), true, frInit3, frCompute1, frGet1, frCleanup1, frCleanup2, frCleanup3, frCleanup4); + } + else { + // execute federated instructions + response = fedMap + .execute(getTID(), true, frInit1, frInit2, frCompute1, frGet1, frCleanup1, frCleanup2, frCleanup3); + } - Future[] response; - if(frInit3 != null) { - FederatedRequest frCleanup4 = fedMap.cleanup(getTID(), frInit3[0].getID()); - // execute federated instructions - fedMap.execute(getTID(), true, frInit1, frInit2); - response = fedMap - .execute(getTID(), true, frInit3, frCompute1, frGet1, frCleanup1, frCleanup2, frCleanup3, frCleanup4); + // aggregate partial results from federated responses + AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+"); + ec.setVariable(output.getName(), FederationUtils.aggScalar(aop, response)); } else { - // execute federated instructions - response = fedMap - .execute(getTID(), true, frInit1, frInit2, frCompute1, frGet1, frCleanup1, frCleanup2, frCleanup3); + throw new DMLRuntimeException("Unsupported federated inputs (X, U, V, W) = (" + X.isFederated() + ", " + + U.isFederated() + ", " + V.isFederated() + ", " + (W != null ? W.isFederated() : "none") + ")"); } - - // aggregate partial results from federated responses - AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+"); - ec.setVariable(output.getName(), FederationUtils.aggScalar(aop, response)); } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSigmoidFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSigmoidFEDInstruction.java index e7fa5ad63b8..b3fa44c7d2d 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSigmoidFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSigmoidFEDInstruction.java @@ -59,32 +59,33 @@ public void processInstruction(ExecutionContext ec) { MatrixObject U = ec.getMatrixObject(input2); MatrixObject V = ec.getMatrixObject(input3); - if(!(X.isFederated(FType.ROW) && !U.isFederated() && !V.isFederated())) - throw new DMLRuntimeException("Unsupported federated inputs (X, U, V) = (" + X.isFederated() + ", " - + U.isFederated() + ", " + V.isFederated() + ")"); - - FederationMap fedMap = X.getFedMapping(); - FederatedRequest[] frInit1 = fedMap.broadcastSliced(U, false); - FederatedRequest frInit2 = fedMap.broadcast(V); + if(X.isFederated(FType.ROW) && !U.isFederated() && !V.isFederated()) { + FederationMap fedMap = X.getFedMapping(); + FederatedRequest[] frInit1 = fedMap.broadcastSliced(U, false); + FederatedRequest frInit2 = fedMap.broadcast(V); - FederatedRequest frCompute1 = FederationUtils.callInstruction(instString, - output, - new CPOperand[] {input1, input2, input3}, - new long[] {fedMap.getID(), frInit1[0].getID(), frInit2.getID()}); + FederatedRequest frCompute1 = FederationUtils.callInstruction(instString, + output, + new CPOperand[] {input1, input2, input3}, + new long[] {fedMap.getID(), frInit1[0].getID(), frInit2.getID()}); - // get partial results from federated workers - FederatedRequest frGet1 = new FederatedRequest(RequestType.GET_VAR, frCompute1.getID()); + // get partial results from federated workers + FederatedRequest frGet1 = new FederatedRequest(RequestType.GET_VAR, frCompute1.getID()); - FederatedRequest frCleanup1 = fedMap.cleanup(getTID(), frCompute1.getID()); - FederatedRequest frCleanup2 = fedMap.cleanup(getTID(), frInit1[0].getID()); - FederatedRequest frCleanup3 = fedMap.cleanup(getTID(), frInit2.getID()); + FederatedRequest frCleanup1 = fedMap.cleanup(getTID(), frCompute1.getID()); + FederatedRequest frCleanup2 = fedMap.cleanup(getTID(), frInit1[0].getID()); + FederatedRequest frCleanup3 = fedMap.cleanup(getTID(), frInit2.getID()); - // execute federated instructions - Future[] response = fedMap - .execute(getTID(), true, frInit1, frInit2, frCompute1, frGet1, frCleanup1, frCleanup2, frCleanup3); - - // bind partial results from federated responses - ec.setMatrixOutput(output.getName(), FederationUtils.bind(response, false)); + // execute federated instructions + Future[] response = fedMap + .execute(getTID(), true, frInit1, frInit2, frCompute1, frGet1, frCleanup1, frCleanup2, frCleanup3); + // bind partial results from federated responses + ec.setMatrixOutput(output.getName(), FederationUtils.bind(response, false)); + } + else { + throw new DMLRuntimeException("Unsupported federated inputs (X, U, V) = (" + X.isFederated() + ", " + + U.isFederated() + ", " + V.isFederated() + ")"); + } } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWUMMFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWUMMFEDInstruction.java index 82bc9e2371a..b5a8b081d34 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWUMMFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWUMMFEDInstruction.java @@ -60,30 +60,32 @@ public void processInstruction(ExecutionContext ec) { MatrixObject U = ec.getMatrixObject(input2); MatrixObject V = ec.getMatrixObject(input3); - if(!(X.isFederated(FType.ROW) && !U.isFederated() && !V.isFederated())) - throw new DMLRuntimeException("Unsupported federated inputs (X, U, V) = (" + X.isFederated() + ", " - + U.isFederated() + ", " + V.isFederated() + ")"); - - FederationMap fedMap = X.getFedMapping(); - FederatedRequest[] frInit1 = fedMap.broadcastSliced(U, false); - FederatedRequest frInit2 = fedMap.broadcast(V); + if(X.isFederated(FType.ROW) && !U.isFederated() && !V.isFederated()) { + FederationMap fedMap = X.getFedMapping(); + FederatedRequest[] frInit1 = fedMap.broadcastSliced(U, false); + FederatedRequest frInit2 = fedMap.broadcast(V); - FederatedRequest frCompute1 = FederationUtils.callInstruction(instString, - output, new CPOperand[] {input1, input2, input3}, - new long[] {fedMap.getID(), frInit1[0].getID(), frInit2.getID()}); + FederatedRequest frCompute1 = FederationUtils.callInstruction(instString, + output, new CPOperand[] {input1, input2, input3}, + new long[] {fedMap.getID(), frInit1[0].getID(), frInit2.getID()}); - // get partial results from federated workers - FederatedRequest frGet1 = new FederatedRequest(RequestType.GET_VAR, frCompute1.getID()); + // get partial results from federated workers + FederatedRequest frGet1 = new FederatedRequest(RequestType.GET_VAR, frCompute1.getID()); - FederatedRequest frCleanup1 = fedMap.cleanup(getTID(), frCompute1.getID()); - FederatedRequest frCleanup2 = fedMap.cleanup(getTID(), frInit1[0].getID()); - FederatedRequest frCleanup3 = fedMap.cleanup(getTID(), frInit2.getID()); + FederatedRequest frCleanup1 = fedMap.cleanup(getTID(), frCompute1.getID()); + FederatedRequest frCleanup2 = fedMap.cleanup(getTID(), frInit1[0].getID()); + FederatedRequest frCleanup3 = fedMap.cleanup(getTID(), frInit2.getID()); - // execute federated instructions - Future[] response = fedMap - .execute(getTID(), true, frInit1, frInit2, frCompute1, frGet1, frCleanup1, frCleanup2, frCleanup3); + // execute federated instructions + Future[] response = fedMap + .execute(getTID(), true, frInit1, frInit2, frCompute1, frGet1, frCleanup1, frCleanup2, frCleanup3); - // bind partial results from federated responses - ec.setMatrixOutput(output.getName(), FederationUtils.bind(response, false)); + // bind partial results from federated responses + ec.setMatrixOutput(output.getName(), FederationUtils.bind(response, false)); + } + else { + throw new DMLRuntimeException("Unsupported federated inputs (X, U, V) = (" + X.isFederated() + ", " + + U.isFederated() + ", " + V.isFederated() + ")"); + } } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPNMFTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPNMFTest.java index 951e6c8f714..3c28a9a6a9a 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPNMFTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPNMFTest.java @@ -80,22 +80,22 @@ public static void init() { @Test public void federatedPNMFSingleNode() { - federatedPNMF(TEST_NAME, ExecMode.SINGLE_NODE); + federatedPNMF(ExecMode.SINGLE_NODE); } @Test public void federatedPNMFSpark() { - federatedPNMF(TEST_NAME, ExecMode.SPARK); + federatedPNMF(ExecMode.SPARK); } // ----------------------------------------------------------------------------- - public void federatedPNMF(String testname, ExecMode execMode) + public void federatedPNMF(ExecMode execMode) { // store the previous platform config to restore it after the test ExecMode platform_old = setExecMode(execMode); - getAndLoadTestConfiguration(testname); + getAndLoadTestConfiguration(TEST_NAME); String HOME = SCRIPT_DIR + TEST_DIR; int fed_rows = rows / 2; @@ -103,8 +103,8 @@ public void federatedPNMF(String testname, ExecMode execMode) // generate dataset // matrix handled by two federated workers - double[][] X1 = getRandomMatrix(fed_rows, fed_cols, 1, 2, 1, 13); - double[][] X2 = getRandomMatrix(fed_rows, fed_cols, 1, 2, 1, 2); + double[][] X1 = getRandomMatrix(fed_rows, fed_cols, 1, 2, sparsity, 13); + double[][] X2 = getRandomMatrix(fed_rows, fed_cols, 1, 2, sparsity, 2); writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols)); writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols)); @@ -116,17 +116,17 @@ public void federatedPNMF(String testname, ExecMode execMode) Thread thread1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); Thread thread2 = startLocalFedWorkerThread(port2); - getAndLoadTestConfiguration(testname); + getAndLoadTestConfiguration(TEST_NAME); // Run reference dml script with normal matrix - fullDMLScriptName = HOME + testname + "Reference.dml"; + fullDMLScriptName = HOME + TEST_NAME + "Reference.dml"; programArgs = new String[] {"-stats", "-nvargs", "in_X1=" + input("X1"), "in_X2=" + input("X2"), "in_rank=" + Integer.toString(rank), "in_max_iter=" + Integer.toString(max_iter), "out_Z=" + expected(OUTPUT_NAME)}; runTest(true, false, null, -1); // Run actual dml script with federated matrix - fullDMLScriptName = HOME + testname + ".dml"; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; programArgs = new String[] {"-stats", "-nvargs", "in_X1=" + TestUtils.federatedAddress(port1, input("X1")), "in_X2=" + TestUtils.federatedAddress(port2, input("X2")), From 125ffb815e6880b52fbcfb2d70d3e2d05bbadefc Mon Sep 17 00:00:00 2001 From: ywcb00 Date: Tue, 9 Feb 2021 16:06:06 +0100 Subject: [PATCH 7/7] refactor(**): change blocksize constant to uppercase name in FedAlsCG test, FedPNMF test, and all quaternary primitive tests --- .../algorithms/FederatedAlsCGTest.java | 6 ++--- .../algorithms/FederatedPNMFTest.java | 6 ++--- .../FederatedWeightedCrossEntropyTest.java | 6 ++--- .../FederatedWeightedDivMatrixMultTest.java | 22 +++++++++---------- .../FederatedWeightedSigmoidTest.java | 6 ++--- .../FederatedWeightedSquaredLossTest.java | 6 ++--- .../FederatedWeightedUnaryMatrixMultTest.java | 20 ++++++++--------- 7 files changed, 36 insertions(+), 36 deletions(-) diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java index 4909f7c497c..9263bebef1f 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java @@ -46,7 +46,7 @@ public class FederatedAlsCGTest extends AutomatedTestBase private final static String OUTPUT_NAME = "Z"; private final static double TOLERANCE = 0.01; - private final static int blocksize = 1024; + private final static int BLOCKSIZE = 1024; @Parameterized.Parameter() public int rows; @@ -112,9 +112,9 @@ public void federatedAlsCG(String testname, ExecMode execMode) double[][] X2 = getRandomMatrix(fed_rows, fed_cols, 1, 2, sparsity, 2); writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics( - fed_rows, fed_cols, blocksize, fed_rows * fed_cols)); + fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols)); writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics( - fed_rows, fed_cols, blocksize, fed_rows * fed_cols)); + fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols)); // empty script name because we don't execute any script, just start the worker fullDMLScriptName = ""; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPNMFTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPNMFTest.java index 3c28a9a6a9a..f877f1ff3a9 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPNMFTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPNMFTest.java @@ -46,7 +46,7 @@ public class FederatedPNMFTest extends AutomatedTestBase private final static String OUTPUT_NAME = "Z"; private final static double TOLERANCE = 0.2; - private final static int blocksize = 1024; + private final static int BLOCKSIZE = 1024; @Parameterized.Parameter() public int rows; @@ -106,8 +106,8 @@ public void federatedPNMF(ExecMode execMode) double[][] X1 = getRandomMatrix(fed_rows, fed_cols, 1, 2, sparsity, 13); double[][] X2 = getRandomMatrix(fed_rows, fed_cols, 1, 2, sparsity, 2); - writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols)); - writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols)); + writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols)); + writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols)); // empty script name because we don't execute any script, just start the worker fullDMLScriptName = ""; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedCrossEntropyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedCrossEntropyTest.java index bf676a38ad0..655124dc31d 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedCrossEntropyTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedCrossEntropyTest.java @@ -47,7 +47,7 @@ public class FederatedWeightedCrossEntropyTest extends AutomatedTestBase private final static String OUTPUT_NAME = "Z"; private final static double TOLERANCE = 1e-9; - private final static int blocksize = 1024; + private final static int BLOCKSIZE = 1024; @Parameterized.Parameter() public int rows; @@ -124,8 +124,8 @@ public void federatedWeightedCrossEntropy(String testname, ExecMode execMode) double[][] U = getRandomMatrix(rows, rank, 0, 1, 1, 512); double[][] V = getRandomMatrix(cols, rank, 0, 1, 1, 5040); - writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols)); - writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols)); + writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols)); + writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols)); writeInputMatrixWithMTD("U", U, true); writeInputMatrixWithMTD("V", V, true); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedDivMatrixMultTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedDivMatrixMultTest.java index 39a79bb6407..15c192b63f0 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedDivMatrixMultTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedDivMatrixMultTest.java @@ -60,7 +60,7 @@ public class FederatedWeightedDivMatrixMultTest extends AutomatedTestBase private final static double TOLERANCE = 1e-9; - private final static int blocksize = 1024; + private final static int BLOCKSIZE = 1024; @Parameterized.Parameter() public int rows; @@ -256,11 +256,11 @@ public void federatedWeightedDivMatrixMult(String test_name, ExecMode exec_mode) double[][] U = getRandomMatrix(rows, rank, 0, 1, 1, 512); double[][] V = getRandomMatrix(cols, rank, 0, 1, 1, 5040); - writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols)); - writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols)); + writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols)); + writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols)); - writeInputMatrixWithMTD("U", U, true, new MatrixCharacteristics(rows, rank, blocksize, rows * rank)); - writeInputMatrixWithMTD("V", V, true, new MatrixCharacteristics(cols, rank, blocksize, rows * rank)); + writeInputMatrixWithMTD("U", U, true, new MatrixCharacteristics(rows, rank, BLOCKSIZE, rows * rank)); + writeInputMatrixWithMTD("V", V, true, new MatrixCharacteristics(cols, rank, BLOCKSIZE, rows * rank)); // empty script name because we don't execute any script, just start the worker fullDMLScriptName = ""; @@ -270,7 +270,7 @@ public void federatedWeightedDivMatrixMult(String test_name, ExecMode exec_mode) Thread thread2 = startLocalFedWorkerThread(port2); getAndLoadTestConfiguration(test_name); - + try { // Run reference dml script with normal matrix fullDMLScriptName = HOME + test_name + "Reference.dml"; @@ -278,7 +278,7 @@ public void federatedWeightedDivMatrixMult(String test_name, ExecMode exec_mode) "in_U=" + input("U"), "in_V=" + input("V"), "in_W=" + Double.toString(epsilon), "out_Z=" + expected(OUTPUT_NAME)}; runTest(true, false, null, -1); - + // Run actual dml script with federated matrix fullDMLScriptName = HOME + test_name + ".dml"; programArgs = new String[] {"-stats", "-nvargs", @@ -289,22 +289,22 @@ public void federatedWeightedDivMatrixMult(String test_name, ExecMode exec_mode) "in_W=" + Double.toString(epsilon), "rows=" + fed_rows, "cols=" + fed_cols, "out_Z=" + output(OUTPUT_NAME)}; runTest(true, false, null, -1); - + // compare the results via files HashMap refResults = readDMLMatrixFromExpectedDir(OUTPUT_NAME); HashMap fedResults = readDMLMatrixFromOutputDir(OUTPUT_NAME); TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed", "Ref"); - + // check for federated operations Assert.assertTrue(heavyHittersContainsString("fed_wdivmm")); - + // check that federated input files are still existing Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1"))); Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2"))); } finally { TestUtils.shutdownThreads(thread1, thread2); - + resetExecMode(platform_old); } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedSigmoidTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedSigmoidTest.java index e73ce826ea6..ec800b090d2 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedSigmoidTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedSigmoidTest.java @@ -50,7 +50,7 @@ public class FederatedWeightedSigmoidTest extends AutomatedTestBase { private final static double TOLERANCE = 0; - private final static int blocksize = 1024; + private final static int BLOCKSIZE = 1024; @Parameterized.Parameter() public int rows; @@ -151,11 +151,11 @@ public void federatedWeightedSigmoid(String test_name, ExecMode exec_mode) { writeInputMatrixWithMTD("X1", X1, false, - new MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols)); + new MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols)); writeInputMatrixWithMTD("X2", X2, false, - new MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols)); + new MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols)); writeInputMatrixWithMTD("U", U, true); writeInputMatrixWithMTD("V", V, true); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedSquaredLossTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedSquaredLossTest.java index 9b0f7a7adae..782891c9673 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedSquaredLossTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedSquaredLossTest.java @@ -50,7 +50,7 @@ public class FederatedWeightedSquaredLossTest extends AutomatedTestBase { private final static double TOLERANCE = 1e-8; - private final static int blocksize = 1024; + private final static int BLOCKSIZE = 1024; @Parameterized.Parameter() public int rows; @@ -138,11 +138,11 @@ public void federatedWeightedSquaredLoss(String test_name, ExecMode exec_mode) { writeInputMatrixWithMTD("X1", X1, false, - new MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols)); + new MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols)); writeInputMatrixWithMTD("X2", X2, false, - new MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols)); + new MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols)); writeInputMatrixWithMTD("U", U, true); writeInputMatrixWithMTD("V", V, true); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedUnaryMatrixMultTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedUnaryMatrixMultTest.java index 581d27dfb87..8cc582a6d36 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedUnaryMatrixMultTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedUnaryMatrixMultTest.java @@ -51,7 +51,7 @@ public class FederatedWeightedUnaryMatrixMultTest extends AutomatedTestBase private final static double TOLERANCE = 0; - private final static int blocksize = 1024; + private final static int BLOCKSIZE = 1024; @Parameterized.Parameter() public int rows; @@ -147,11 +147,11 @@ public void federatedWeightedUnaryMatrixMult(String test_name, ExecMode exec_mod double[][] U = getRandomMatrix(rows, rank, 0, 1, 1, 512); double[][] V = getRandomMatrix(cols, rank, 0, 1, 1, 5040); - writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols)); - writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols)); + writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols)); + writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols)); - writeInputMatrixWithMTD("U", U, false, new MatrixCharacteristics(rows, rank, blocksize, rows * rank)); - writeInputMatrixWithMTD("V", V, false, new MatrixCharacteristics(cols, rank, blocksize, rows * rank)); + writeInputMatrixWithMTD("U", U, false, new MatrixCharacteristics(rows, rank, BLOCKSIZE, rows * rank)); + writeInputMatrixWithMTD("V", V, false, new MatrixCharacteristics(cols, rank, BLOCKSIZE, rows * rank)); // empty script name because we don't execute any script, just start the worker fullDMLScriptName = ""; @@ -169,7 +169,7 @@ public void federatedWeightedUnaryMatrixMult(String test_name, ExecMode exec_mod "in_U=" + input("U"), "in_V=" + input("V"), "out_Z=" + expected(OUTPUT_NAME)}; runTest(true, false, null, -1); - + // Run actual dml script with federated matrix fullDMLScriptName = HOME + test_name + ".dml"; programArgs = new String[] {"-stats", "-nvargs", @@ -179,22 +179,22 @@ public void federatedWeightedUnaryMatrixMult(String test_name, ExecMode exec_mod "in_V=" + input("V"), "rows=" + fed_rows, "cols=" + fed_cols, "out_Z=" + output(OUTPUT_NAME)}; runTest(true, false, null, -1); - + // compare the results via files HashMap refResults = readDMLMatrixFromExpectedDir(OUTPUT_NAME); HashMap fedResults = readDMLMatrixFromOutputDir(OUTPUT_NAME); TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed", "Ref"); - + // check for federated operations Assert.assertTrue(heavyHittersContainsString("fed_wumm")); - + // check that federated input files are still existing Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1"))); Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2"))); } finally { TestUtils.shutdownThreads(thread1, thread2); - + resetExecMode(platform_old); } }