From 03694c718edc1f6ee65059344586ec7dbcd1044e Mon Sep 17 00:00:00 2001 From: ywcb00 Date: Fri, 13 May 2022 17:08:42 +0200 Subject: [PATCH 1/2] feat(FederatedLeftIndexTest.java): support a scalar rhs for left indexing chore(FederationMap.java): override the execute method to directly call it with single items instead of arrays feat(FederatedLeftIndexScalarTest): add test for scalar left indexing --- .../federated/FederationMap.java | 5 + .../fed/IndexingFEDInstruction.java | 137 ++++++++++++------ .../primitives/FederatedLeftIndexTest.java | 28 +++- .../FederatedLeftIndexScalarTest.dml | 44 ++++++ .../FederatedLeftIndexScalarTestReference.dml | 40 +++++ 5 files changed, 204 insertions(+), 50 deletions(-) create mode 100644 src/test/scripts/functions/federated/FederatedLeftIndexScalarTest.dml create mode 100644 src/test/scripts/functions/federated/FederatedLeftIndexScalarTestReference.dml diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java index 0053a8b2fe2..fcef0d7984e 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java @@ -351,6 +351,11 @@ public Future[] execute(long tid, boolean wait, FederatedRequ return ret.toArray(new Future[0]); } + public Future[] execute(long tid, boolean wait, FederatedRange[] fedRange1, + FederatedRequest elseFr, FederatedRequest frSlice1, FederatedRequest frSlice2, FederatedRequest fr) { + return execute(tid, wait, fedRange1, elseFr, new FederatedRequest[]{frSlice1}, new FederatedRequest[]{frSlice2}, fr); + } + @SuppressWarnings("unchecked") public Future[] execute(long tid, boolean wait, FederatedRange[] fedRange1, FederatedRequest elseFr, FederatedRequest[] frSlices1, FederatedRequest[] frSlices2, FederatedRequest... fr) { // executes step1[] - step 2 - ... step4 (only first step federated-data-specific) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java index bc70b398f94..4697252d71e 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java @@ -27,6 +27,7 @@ import org.apache.sysds.api.DMLScript; import org.apache.sysds.common.Types; +import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.hops.fedplanner.FTypes.FType; import org.apache.sysds.lops.LeftIndex; @@ -44,6 +45,7 @@ import org.apache.sysds.runtime.controlprogram.federated.FederationUtils; import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.instructions.cp.CPOperand; +import org.apache.sysds.runtime.instructions.cp.ScalarObject; import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction; import org.apache.sysds.runtime.meta.MatrixCharacteristics; import org.apache.sysds.runtime.util.IndexRange; @@ -204,7 +206,8 @@ private void leftIndexing(ExecutionContext ec) { //get input and requested index range CacheableData in1 = ec.getCacheableData(input1); - CacheableData in2 = ec.getCacheableData(input2); + CacheableData in2 = null; // either in2 or scalar is set + ScalarObject scalar = null; IndexRange ixrange = getIndexRange(ec); //check bounds @@ -213,11 +216,21 @@ private void leftIndexing(ExecutionContext ec) throw new DMLRuntimeException("Invalid values for matrix indexing: ["+(ixrange.rowStart+1)+":"+(ixrange.rowEnd+1)+"," + (ixrange.colStart+1)+":"+(ixrange.colEnd+1)+"] " + "must be within matrix dimensions ["+in1.getNumRows()+","+in1.getNumColumns()+"]."); } - if( (ixrange.rowEnd-ixrange.rowStart+1) != in2.getNumRows() || (ixrange.colEnd-ixrange.colStart+1) != in2.getNumColumns()) { - throw new DMLRuntimeException("Invalid values for matrix indexing: " + - "dimensions of the source matrix ["+in2.getNumRows()+"x" + in2.getNumColumns() + "] " + - "do not match the shape of the matrix specified by indices [" + - (ixrange.rowStart+1) +":" + (ixrange.rowEnd+1) + ", " + (ixrange.colStart+1) + ":" + (ixrange.colEnd+1) + "]."); + + if(input2.getDataType() == DataType.SCALAR) { + if(!ixrange.isScalar()) + throw new DMLRuntimeException("Invalid index range for leftindexing with scalar: " + ixrange.toString() + "."); + + scalar = ec.getScalarInput(input2); + } + else { + in2 = ec.getCacheableData(input2); + if( (ixrange.rowEnd-ixrange.rowStart+1) != in2.getNumRows() || (ixrange.colEnd-ixrange.colStart+1) != in2.getNumColumns()) { + throw new DMLRuntimeException("Invalid values for matrix indexing: " + + "dimensions of the source matrix ["+in2.getNumRows()+"x" + in2.getNumColumns() + "] " + + "do not match the shape of the matrix specified by indices [" + + (ixrange.rowStart+1) +":" + (ixrange.rowEnd+1) + ", " + (ixrange.colStart+1) + ":" + (ixrange.colEnd+1) + "]."); + } } FederationMap fedMap = in1.getFedMapping(); @@ -226,6 +239,10 @@ private void leftIndexing(ExecutionContext ec) int[][] sliceIxs = new int[fedMap.getSize()][]; FederatedRange[] ranges = new FederatedRange[fedMap.getSize()]; + // instruction string for copying a partition at the federated site + int cpVarInstIx = fedMap.getSize(); + String cpVarInstString = createCopyInstString(); + // replace old reshape values for each worker int i = 0, prev = 0, from = fedMap.getSize(); for(org.apache.commons.lang3.tuple.Pair e : fedMap.getMap()) { @@ -239,29 +256,46 @@ private void leftIndexing(ExecutionContext ec) long[] newIx = new long[]{(int) rsn, (int) ren, (int) csn, (int) cen}; - // find ranges where to apply leftIndex - long to; - if(in1.isFederated(FType.ROW) && (to = (prev + ren - rsn)) >= 0 && - to < in2.getNumRows() && ixrange.rowStart <= re) { - sliceIxs[i] = new int[] { prev, (int) to, 0, (int) in2.getNumColumns()-1}; - prev = (int) (to + 1); - - instStrings[i] = modifyIndices(newIx, 4, 8); - ranges[i] = range; - from = Math.min(i, from); + if(in2 != null) { // matrix, frame + // find ranges where to apply leftIndex + long to; + if(in1.isFederated(FType.ROW) && (to = (prev + ren - rsn)) >= 0 && + to < in2.getNumRows() && ixrange.rowStart <= re) { + sliceIxs[i] = new int[] { prev, (int) to, 0, (int) in2.getNumColumns()-1}; + prev = (int) (to + 1); + + instStrings[i] = modifyIndices(newIx, 4, 8); + ranges[i] = range; + from = Math.min(i, from); + } + else if(in1.isFederated(FType.COL) && (to = (prev + cen - csn)) >= 0 && + to < in2.getNumColumns() && ixrange.colStart <= ce) { + sliceIxs[i] = new int[] {0, (int) in2.getNumRows() - 1, prev, (int) to}; + prev = (int) (to + 1); + + instStrings[i] = modifyIndices(newIx, 4, 8); + ranges[i] = range; + from = Math.min(i, from); + } + else { + // TODO shallow copy, add more advanced update in place for federated + cpVarInstIx = Math.min(i, cpVarInstIx); + instStrings[i] = cpVarInstString; + } } - else if(in1.isFederated(FType.COL) && (to = (prev + cen - csn)) >= 0 && - to < in2.getNumColumns() && ixrange.colStart <= ce) { - sliceIxs[i] = new int[] {0, (int) in2.getNumRows() - 1, prev, (int) to}; - prev = (int) (to + 1); - - instStrings[i] = modifyIndices(newIx, 4, 8); - ranges[i] = range; - from = Math.min(i, from); + else { // scalar + if(ixrange.rowStart >= rs && ixrange.rowEnd < re + && ixrange.colStart >= cs && ixrange.colEnd < ce) { + instStrings[i] = modifyIndices(newIx, 4, 8); + instStrings[i] = changeScalarLiteralFlag(instStrings[i], 3); + ranges[i] = range; + from = Math.min(i, from); + } + else { + cpVarInstIx = Math.min(i, cpVarInstIx); + instStrings[i] = cpVarInstString; + } } - else - // TODO shallow copy, add more advanced update in place for federated - instStrings[i] = createCopyInstString(); i++; } @@ -272,32 +306,40 @@ else if(in1.isFederated(FType.COL) && (to = (prev + cen - csn)) >= 0 && FederatedRequest tmp = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, id, new MatrixCharacteristics(-1, -1), in1.getDataType()); fedMap.execute(getTID(), true, tmp); - FederatedRequest[] fr1 = fedMap.broadcastSliced(in2, DMLScript.LINEAGE ? ec.getLineageItem(input2) : null, - input2.isFrame(), sliceIxs); - FederatedRequest[] fr2 = FederationUtils.callInstruction(instStrings, output, id, new CPOperand[]{input1, input2}, - new long[]{fedMap.getID(), fr1[0].getID()}, null); - FederatedRequest fr3 = fedMap.cleanup(getTID(), fr1[0].getID()); + if(in2 != null) { // matrix, frame + FederatedRequest[] fr1 = fedMap.broadcastSliced(in2, DMLScript.LINEAGE ? ec.getLineageItem(input2) : null, + input2.isFrame(), sliceIxs); + FederatedRequest[] fr2 = FederationUtils.callInstruction(instStrings, output, id, new CPOperand[]{input1, input2}, + new long[]{fedMap.getID(), fr1[0].getID()}, null); + FederatedRequest fr3 = fedMap.cleanup(getTID(), fr1[0].getID()); - //execute federated instruction and cleanup intermediates - if(sliceIxs.length == fedMap.getSize()) - fedMap.execute(getTID(), true, fr2, fr1, fr3); - else { - // get index of cpvar request - for(i = 0; i < fr2.length; i++) - if(i < from || i >= from + sliceIxs.length) - break; - fedMap.execute(getTID(), true, ranges, (fr2[i]), Arrays.copyOfRange(fr2, from, from + sliceIxs.length), fr1, fr3); + //execute federated instruction and cleanup intermediates + if(sliceIxs.length == fedMap.getSize()) + fedMap.execute(getTID(), true, fr2, fr1, fr3); + else + fedMap.execute(getTID(), true, ranges, fr2[cpVarInstIx], Arrays.copyOfRange(fr2, from, from + sliceIxs.length), fr1, fr3); + } + else { // scalar + FederatedRequest fr1 = fedMap.broadcast(scalar); + FederatedRequest[] fr2 = FederationUtils.callInstruction(instStrings, output, id, new CPOperand[]{input1, input2}, + new long[]{fedMap.getID(), fr1.getID()}, null); + FederatedRequest fr3 = fedMap.cleanup(getTID(), fr1.getID()); + + if(fr2.length == 1) + fedMap.execute(getTID(), true, fr2, fr1, fr3); + else + fedMap.execute(getTID(), true, ranges, fr2[cpVarInstIx], fr2[from], fr1, fr3); } if(input1.isFrame()) { FrameObject out = ec.getFrameObject(output); out.setSchema(((FrameObject) in1).getSchema()); out.getDataCharacteristics().set(in1.getDataCharacteristics()); - out.setFedMapping(fedMap.copyWithNewID(fr2[0].getID())); + out.setFedMapping(fedMap.copyWithNewID(id)); } else { MatrixObject out = ec.getMatrixObject(output); - out.getDataCharacteristics().set(in1.getDataCharacteristics());; - out.setFedMapping(fedMap.copyWithNewID(fr2[0].getID())); + out.getDataCharacteristics().set(in1.getDataCharacteristics()); + out.setFedMapping(fedMap.copyWithNewID(id)); } } @@ -309,6 +351,13 @@ private String modifyIndices(long[] newIx, int from, int to) { return String.join(Lop.OPERAND_DELIMITOR, instParts); } + private String changeScalarLiteralFlag(String inst, int partIx) { + // change the literal flag of the broadcast scalar + String[] instParts = inst.split(Lop.OPERAND_DELIMITOR); + instParts[partIx] = instParts[partIx].replace("true", "false"); + return String.join(Lop.OPERAND_DELIMITOR, instParts); + } + private String createCopyInstString() { String[] instParts = instString.split(Lop.OPERAND_DELIMITOR); return VariableCPInstruction.prepareCopyInstruction(instParts[2], instParts[8]).toString(); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLeftIndexTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLeftIndexTest.java index 3c337286a3a..a2e816fa92c 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLeftIndexTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLeftIndexTest.java @@ -40,6 +40,7 @@ public class FederatedLeftIndexTest extends AutomatedTestBase { private final static String TEST_NAME1 = "FederatedLeftIndexFullTest"; private final static String TEST_NAME2 = "FederatedLeftIndexFrameFullTest"; + private final static String TEST_NAME3 = "FederatedLeftIndexScalarTest"; private final static String TEST_DIR = "functions/federated/"; private static final String TEST_CLASS_DIR = TEST_DIR + FederatedLeftIndexTest.class.getSimpleName() + "/"; @@ -81,7 +82,7 @@ public static Collection data() { } private enum DataType { - MATRIX, FRAME + MATRIX, FRAME, SCALAR } @Override @@ -89,6 +90,7 @@ public void setUp() { TestUtils.clearAssertionInformation(); addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S"})); addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"S"})); + addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"S"})); } @Test @@ -109,6 +111,16 @@ public void testLeftIndexFullDenseFrameSP() { runAggregateOperationTest(DataType.FRAME, ExecMode.SPARK); } + @Test + public void testLeftIndexScalarCP() { + runAggregateOperationTest(DataType.SCALAR, ExecMode.SINGLE_NODE); + } + + @Test + public void testLeftIndexScalarSP() { + runAggregateOperationTest(DataType.SCALAR, ExecMode.SPARK); + } + private void runAggregateOperationTest(DataType dataType, ExecMode execMode) { setExecMode(execMode); @@ -116,8 +128,10 @@ private void runAggregateOperationTest(DataType dataType, ExecMode execMode) { if(dataType == DataType.MATRIX) TEST_NAME = TEST_NAME1; - else + else if(dataType == DataType.FRAME) TEST_NAME = TEST_NAME2; + else + TEST_NAME = TEST_NAME3; getAndLoadTestConfiguration(TEST_NAME); @@ -142,10 +156,12 @@ private void runAggregateOperationTest(DataType dataType, ExecMode execMode) { writeInputMatrixWithMTD("X3", X3, false, mc); writeInputMatrixWithMTD("X4", X4, false, mc); - double[][] Y = getRandomMatrix(rows2, cols2, 1, 5, 1, 3); + if(dataType != DataType.SCALAR) { + double[][] Y = getRandomMatrix(rows2, cols2, 1, 5, 1, 3); - MatrixCharacteristics mc2 = new MatrixCharacteristics(rows2, cols2, blocksize, rows2 * cols2); - writeInputMatrixWithMTD("Y", Y, false, mc2); + MatrixCharacteristics mc2 = new MatrixCharacteristics(rows2, cols2, blocksize, rows2 * cols2); + writeInputMatrixWithMTD("Y", Y, false, mc2); + } // empty script name because we don't execute any script, just start the worker fullDMLScriptName = ""; @@ -173,7 +189,7 @@ private void runAggregateOperationTest(DataType dataType, ExecMode execMode) { // Run reference dml script with normal matrix fullDMLScriptName = HOME + TEST_NAME + "Reference.dml"; - programArgs = new String[] {"-explain", "-args", input("X1"), input("X2"), input("X3"), input("X4"), + programArgs = new String[] {"-args", input("X1"), input("X2"), input("X3"), input("X4"), input("Y"), String.valueOf(from), String.valueOf(to), String.valueOf(from2), String.valueOf(to2), Boolean.toString(rowPartitioned).toUpperCase(), expected("S")}; diff --git a/src/test/scripts/functions/federated/FederatedLeftIndexScalarTest.dml b/src/test/scripts/functions/federated/FederatedLeftIndexScalarTest.dml new file mode 100644 index 00000000000..71a9f934908 --- /dev/null +++ b/src/test/scripts/functions/federated/FederatedLeftIndexScalarTest.dml @@ -0,0 +1,44 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +row1 = $from; +row2 = $to; +col1 = $from2; +col2 = $to2; + +if ($rP) { + A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4), + ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols), + list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols))); +} else { + A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4), + ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), list($rows, $cols/2), + list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), list($rows, $cols))); +} + +b = 13; +c = as.scalar(rand(rows=1, cols=1, seed=456)); + +A[row1, col1] = b; +A[row2, col2] = c; + +write(A, $out_S); + diff --git a/src/test/scripts/functions/federated/FederatedLeftIndexScalarTestReference.dml b/src/test/scripts/functions/federated/FederatedLeftIndexScalarTestReference.dml new file mode 100644 index 00000000000..14ea17fbda2 --- /dev/null +++ b/src/test/scripts/functions/federated/FederatedLeftIndexScalarTestReference.dml @@ -0,0 +1,40 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +row1 = $6; +row2 = $7; +col1 = $8; +col2 = $9; +if($10) { + A = rbind(read($1), read($2), read($3), read($4)); +} +else { + A = cbind(read($1), read($2), read($3), read($4)); +} + +b = 13; +c = as.scalar(rand(rows=1, cols=1, seed=456)); + +A[row1, col1] = b; +A[row2, col2] = c; + +write(A, $11); + From 3e2707695fd5ac0e45c74294b26ede8b0e6bcc4b Mon Sep 17 00:00:00 2001 From: ywcb00 Date: Sun, 22 May 2022 11:03:47 +0200 Subject: [PATCH 2/2] chore(FederatedLeftIndexTest): remove the prints of the dml scripts --- .../functions/federated/FederatedLeftIndexFrameFullTest.dml | 2 -- .../federated/FederatedLeftIndexFrameFullTestReference.dml | 2 -- .../scripts/functions/federated/FederatedLeftIndexFullTest.dml | 2 -- .../functions/federated/FederatedLeftIndexFullTestReference.dml | 2 -- 4 files changed, 8 deletions(-) diff --git a/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTest.dml b/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTest.dml index ca9fe81f40a..a10bb72f77c 100644 --- a/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTest.dml +++ b/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTest.dml @@ -41,5 +41,3 @@ A = as.frame(A) A[from:to, from2:to2] = B; write(A, $out_S); - -print(toString(A)) diff --git a/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTestReference.dml b/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTestReference.dml index 4b5a85234cb..6589134273c 100644 --- a/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTestReference.dml +++ b/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTestReference.dml @@ -37,5 +37,3 @@ A = as.frame(A) A[from:to, from2:to2] = B; write(A, $11); - -print(toString(A)) diff --git a/src/test/scripts/functions/federated/FederatedLeftIndexFullTest.dml b/src/test/scripts/functions/federated/FederatedLeftIndexFullTest.dml index a201f7bfe3c..c048cb77c25 100644 --- a/src/test/scripts/functions/federated/FederatedLeftIndexFullTest.dml +++ b/src/test/scripts/functions/federated/FederatedLeftIndexFullTest.dml @@ -38,5 +38,3 @@ B = read($in_Y) A[from:to, from2:to2] = B; write(A, $out_S); - -print(toString(A)) diff --git a/src/test/scripts/functions/federated/FederatedLeftIndexFullTestReference.dml b/src/test/scripts/functions/federated/FederatedLeftIndexFullTestReference.dml index 2cc29f7ca80..ecd123254ef 100644 --- a/src/test/scripts/functions/federated/FederatedLeftIndexFullTestReference.dml +++ b/src/test/scripts/functions/federated/FederatedLeftIndexFullTestReference.dml @@ -34,5 +34,3 @@ B = read($5) A[from:to, from2:to2] = B; write(A, $11); - -print(toString(A))