diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java index d1fa57829ad..4db4f2b8b2d 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java @@ -658,6 +658,14 @@ public FederationMap filter(IndexRange ixrange) { if(!overlap) iter.remove(); } + + boolean rowPartitioned = this.getType().isType(FType.ROW) + || Arrays.stream(ret.getFederatedRanges()).allMatch(range -> range.getSize(1) == ret.getMaxIndexInRange(1)); + boolean colPartitioned = this.getType().isType(FType.COL) + || Arrays.stream(ret.getFederatedRanges()).allMatch(range -> range.getSize(0) == ret.getMaxIndexInRange(0)); + if(rowPartitioned && colPartitioned) + ret.setType(FType.FULL); + return ret; } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java index cabf4887a6b..8afaf9db01f 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java @@ -566,8 +566,8 @@ public static MatrixBlock bindResponses(List[] ffr = mo1.getFedMapping().execute(getTID(), true, fr1, fr2); + setOutputFedMapping(mo1.getFedMapping(), mo1, mo2, + FederationUtils.sumNonZeros(ffr), fr2.getID(), ec); } else { boolean isDoubleBroadcast = (mo1.isFederated(FType.BROADCAST) && mo2.isFederated(FType.BROADCAST)); @@ -183,7 +184,8 @@ private void writeInfoLog(MatrixLineagePair mo1, MatrixLineagePair mo2){ private void setPartialOutput(FederationMap federationMap, MatrixLineagePair mo1, MatrixLineagePair mo2, long outputID, ExecutionContext ec){ MatrixObject out = ec.getMatrixObject(output); - 
out.getDataCharacteristics().set(mo1.getNumRows(), mo2.getNumColumns(), (int)mo1.getBlocksize()); + out.getDataCharacteristics().setDimension(mo1.getNumRows(), mo2.getNumColumns()) + .setBlocksize(mo1.getBlocksize()); FederationMap outputFedMap = federationMap .copyWithNewIDAndRange(mo1.getNumRows(), mo2.getNumColumns(), outputID); out.setFedMapping(outputFedMap); @@ -194,13 +196,16 @@ private void setPartialOutput(FederationMap federationMap, MatrixLineagePair mo1 * @param federationMap federation map to be set in output * @param mo1 matrix object with number of rows used to set the number of rows of the output * @param mo2 matrix object with number of columns used to set the number of columns of the output + * @param nnz the number of non-zeros of the output * @param outputID ID of the output * @param ec execution context */ - private void setOutputFedMapping(FederationMap federationMap, MatrixLineagePair mo1, MatrixLineagePair mo2, - long outputID, ExecutionContext ec){ + private void setOutputFedMapping(FederationMap federationMap, MatrixLineagePair mo1, + MatrixLineagePair mo2, long nnz, long outputID, ExecutionContext ec){ MatrixObject out = ec.getMatrixObject(output); - out.getDataCharacteristics().set(mo1.getNumRows(), mo2.getNumColumns(), (int)mo1.getBlocksize()); + out.getDataCharacteristics() + .setDimension(mo1.getNumRows(), mo2.getNumColumns()) + .setBlocksize(mo1.getBlocksize()).setNonZeros(nnz); out.setFedMapping(federationMap.copyWithNewID(outputID, mo2.getNumColumns())); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java index d126cae1dcc..948d7f3443d 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java @@ -124,19 +124,21 @@ else if(!_cbind && mo1.getNumColumns() != mo2.getNumColumns()) { new 
CPOperand[]{input1, input2}, new long[]{mo1.getFedMapping().getID(), mo2.getFedMapping().getID()}); + Future[] ffr = null; if(isSpark) { FederatedRequest frTmp = new FederatedRequest(RequestType.PUT_VAR, fr2.getID(), new MatrixCharacteristics(-1, -1), mo1.getDataType()); - mo1.getFedMapping().execute(getTID(), true, frTmp, fr2); + ffr = mo1.getFedMapping().execute(getTID(), true, frTmp, fr2); } else { - mo1.getFedMapping().execute(getTID(), true, fr2); + ffr = mo1.getFedMapping().execute(getTID(), true, fr2); } int dim = (_cbind ? 1 : 0); FederationMap newFedMap = mo1.getFedMapping().copyWithNewID(fr2.getID()) .modifyFedRanges(mo1.getDim(dim) + mo2.getDim(dim), dim); out.setFedMapping(newFedMap); + out.getDataCharacteristics().setNonZeros(FederationUtils.sumNonZeros(ffr)); } // federated/federated misaligned, federated/local, local/federated bind else if( ((mo1.isFederated(FType.ROW) || mo2.isFederated(FType.ROW)) && !_cbind) @@ -152,6 +154,8 @@ else if( ((mo1.isFederated(FType.ROW) || mo2.isFederated(FType.ROW)) && !_cbind) out.setFedMapping(fed1.identCopy(getTID(), id) .bind(roff, coff, fed2.identCopy(getTID(), id))); + if(mo1.getNnz() != -1 && mo2.getNnz() != -1) + out.getDataCharacteristics().setNonZeros(mo1.getNnz() + mo2.getNnz()); } // federated/local, local/federated bind else if( ((mo1.isFederated(FType.ROW) || mo2.isFederated(FType.ROW)) && _cbind) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java index e11092785ba..1f49496cf76 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java @@ -19,12 +19,15 @@ package org.apache.sysds.runtime.instructions.fed; +import java.util.concurrent.Future; + import org.apache.sysds.hops.fedplanner.FTypes.AlignType; import 
org.apache.sysds.hops.fedplanner.FTypes.FType; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest; +import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; import org.apache.sysds.runtime.controlprogram.federated.FederationMap; import org.apache.sysds.runtime.controlprogram.federated.FederationUtils; import org.apache.sysds.runtime.controlprogram.federated.MatrixLineagePair; @@ -71,20 +74,21 @@ public void processInstruction(ExecutionContext ec) { //execute federated operation on mo1 or mo2 FederatedRequest fr2 = null; + Future[] ffr = null; if( mo2.isFederatedExcept(FType.BROADCAST) ) { if(mo1.isFederated() && mo1.getFedMapping().isAligned(mo2.getFedMapping(), - mo1.isFederated(FType.ROW) ? AlignType.ROW : AlignType.COL)) { + mo1.isFederated(FType.ROW) ? 
AlignType.ROW : AlignType.COL)) { fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2}, new long[]{mo1.getFedMapping().getID(), mo2.getFedMapping().getID()}, true); - mo2.getFedMapping().execute(getTID(), true, fr2); + ffr = mo2.getFedMapping().execute(getTID(), true, fr2); } else { FederatedRequest[] fr1 = mo2.getFedMapping().broadcastSliced(mo1, false); fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2}, new long[]{fr1[0].getID(), mo2.getFedMapping().getID()}, true); - mo2.getFedMapping().execute(getTID(), true, fr1, fr2); + ffr = mo2.getFedMapping().execute(getTID(), true, fr1, fr2); } fedMo = mo2.getMO(); // for setting the output federated mapping afterwards } @@ -92,7 +96,7 @@ else if ( mo2.isFederated(FType.BROADCAST) && !mo1.isFederated() ){ FederatedRequest fr1 = mo2.getFedMapping().broadcast(mo1); fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2}, new long[]{mo2.getFedMapping().getID(), fr1.getID()}, true); - mo2.getFedMapping().execute(getTID(), true, fr1, fr2); + ffr = mo2.getFedMapping().execute(getTID(), true, fr1, fr2); fedMo = mo2.getMO(); } else { // matrix-matrix binary operations -> lhs fed input -> fed output @@ -103,7 +107,7 @@ else if ( mo2.isFederated(FType.BROADCAST) && !mo1.isFederated() ){ FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2); fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2}, new long[]{mo1.getFedMapping().getID(), fr1.getID()}, true); - mo1.getFedMapping().execute(getTID(), true, fr1, fr2); + ffr = mo1.getFedMapping().execute(getTID(), true, fr1, fr2); } else { throw new DMLRuntimeException("Matrix-matrix binary operations with a full partitioned federated input with multiple partitions are not supported yet."); @@ -115,7 +119,7 @@ else if((mo1.isFederated(FType.ROW) && mo2.getNumRows() == 1) //matrix-rowV FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2); 
fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2}, new long[]{mo1.getFedMapping().getID(), fr1.getID()}, true); - mo1.getFedMapping().execute(getTID(), true, fr1, fr2); + ffr = mo1.getFedMapping().execute(getTID(), true, fr1, fr2); } else if((mo1.isFederated(FType.ROW) ^ mo1.isFederated(FType.COL)) || (mo1.isFederated(FType.FULL) && mo1.getFedMapping().getSize() == 1)) { @@ -123,13 +127,13 @@ else if((mo1.isFederated(FType.ROW) ^ mo1.isFederated(FType.COL)) FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, false); fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2}, new long[]{mo1.getFedMapping().getID(), fr1[0].getID()}, true); - mo1.getFedMapping().execute(getTID(), true, fr1, fr2); + ffr = mo1.getFedMapping().execute(getTID(), true, fr1, fr2); } else if ( mo1.isFederated(FType.PART) && !mo2.isFederated() ){ FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2); fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2}, new long[]{mo1.getFedMapping().getID(), fr1.getID()}, true); - mo1.getFedMapping().execute(getTID(), true, fr1, fr2); + ffr = mo1.getFedMapping().execute(getTID(), true, fr1, fr2); } else { throw new DMLRuntimeException("Matrix-matrix binary operations are only supported with a row partitioned or column partitioned federated input yet."); @@ -137,11 +141,12 @@ else if ( mo1.isFederated(FType.PART) && !mo2.isFederated() ){ fedMo = mo1.getMO(); // for setting the output federated mapping afterwards } + long nnz = FederationUtils.sumNonZeros(ffr); if ( mo1.isFederated(FType.PART) && !mo2.isFederated() ) - setOutputFedMappingPart(mo1.getMO(), mo2.getMO(), fr2.getID(), ec); + setOutputFedMappingPart(mo1.getMO(), mo2.getMO(), nnz, fr2.getID(), ec); else if ( fedMo.isFederated() ) setOutputFedMapping(fedMo, Math.max(mo1.getNumRows(), mo2.getNumRows()), - Math.max(mo1.getNumColumns(), mo2.getNumColumns()), fr2.getID(), ec); + 
Math.max(mo1.getNumColumns(), mo2.getNumColumns()), nnz, fr2.getID(), ec); else throw new DMLRuntimeException("Input is not federated, so the output FedMapping cannot be set!"); } @@ -152,9 +157,11 @@ else if ( fedMo.isFederated() ) * @param outputID ID of output * @param ec execution context */ - private void setOutputFedMappingPart(MatrixObject mo1, MatrixObject mo2, long outputID, ExecutionContext ec){ + private void setOutputFedMappingPart(MatrixObject mo1, MatrixObject mo2, long nnz, + long outputID, ExecutionContext ec){ MatrixObject out = ec.getMatrixObject(output); - out.getDataCharacteristics().set(mo1.getNumRows(), mo2.getNumColumns(), mo1.getBlocksize()); + out.getDataCharacteristics().setDimension(mo1.getNumRows(), mo2.getNumColumns()) + .setBlocksize(mo1.getBlocksize()).setNonZeros(nnz); FederationMap outputFedMap = mo1.getFedMapping() .copyWithNewIDAndRange(mo1.getNumRows(), mo2.getNumColumns(), outputID); out.setFedMapping(outputFedMap); @@ -167,7 +174,7 @@ private void setOutputFedMappingPart(MatrixObject mo1, MatrixObject mo2, long ou * @param ec execution context */ private void setOutputFedMapping(MatrixObject moFederated, long rowNum, long colNum, - long outputFedmappingID, ExecutionContext ec){ + long nnz, long outputFedmappingID, ExecutionContext ec){ MatrixObject out = ec.getMatrixObject(output); FederationMap fedMap = moFederated.getFedMapping().copyWithNewID(outputFedmappingID); if(moFederated.getNumRows() != rowNum || moFederated.getNumColumns() != colNum) { @@ -175,7 +182,7 @@ private void setOutputFedMapping(MatrixObject moFederated, long rowNum, long col fedMap.modifyFedRanges((dim == 0) ? 
rowNum : colNum, dim); } out.getDataCharacteristics().set(moFederated.getDataCharacteristics()) - .setRows(rowNum).setCols(colNum); + .setDimension(rowNum, colNum).setNonZeros(nnz); out.setFedMapping(fedMap); } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java index 9330f6dfd29..e0aed7be117 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java @@ -19,9 +19,12 @@ package org.apache.sysds.runtime.instructions.fed; +import java.util.concurrent.Future; + import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest; +import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; import org.apache.sysds.runtime.controlprogram.federated.FederationUtils; import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.instructions.cp.BinaryMatrixScalarCPInstruction; @@ -62,16 +65,19 @@ public void processInstruction(ExecutionContext ec) { new long[]{mo.getFedMapping().getID(), (fr1 != null)?fr1.getID():-1}, true); //execute federated matrix-scalar operation and cleanups + Future[] ffr = null; if( fr1 != null ) { FederatedRequest fr3 = mo.getFedMapping().cleanup(getTID(), fr1.getID()); - mo.getFedMapping().execute(getTID(), true, fr1, fr2, fr3); + ffr = mo.getFedMapping().execute(getTID(), true, fr1, fr2, fr3); + } + else { + ffr = mo.getFedMapping().execute(getTID(), true, fr2); } - else - mo.getFedMapping().execute(getTID(), true, fr2); //derive new fed mapping for output MatrixObject out = ec.getMatrixObject(output); - 
out.getDataCharacteristics().set(mo.getDataCharacteristics()); + out.getDataCharacteristics().set(mo.getDataCharacteristics()) + .setNonZeros(FederationUtils.sumNonZeros(ffr)); out.setFedMapping(mo.getFedMapping().copyWithNewID(fr2.getID())); } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java index 517a01562a2..e953aa543af 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java @@ -210,14 +210,15 @@ else if(reversed && !reversedWeights) if(fedOutput) { if(fr2 != null) // broadcasted mo3 - fedMap.execute(getTID(), true, fr1, fr2, fr3); + ffr = fedMap.execute(getTID(), true, fr1, fr2, fr3); else - fedMap.execute(getTID(), true, fr1, fr3); + ffr = fedMap.execute(getTID(), true, fr1, fr3); MatrixObject out = ec.getMatrixObject(output); FederationMap newFedMap = modifyFedRanges(fedMap.copyWithNewID(fr3.getID()), staticDim, dims2, reversed); - setFedOutput(mo1.getMO(), out, newFedMap, staticDim, dims2, reversed); + long nnz = FederationUtils.sumNonZeros(ffr); + setFedOutput(mo1.getMO(), out, newFedMap, staticDim, dims2, nnz, reversed); } else { fr4 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr3.getID()); fr5 = fedMap.cleanup(getTID(), fr3.getID()); @@ -280,16 +281,18 @@ private static boolean isFedOutput(FederationMap fedMap, MatrixObject mo2) { * @param fedMap the federation map of the federated matrix input mo1 * @param staticDim static non-partitioned dimension of the output * @param dims2 dimensions of the partial outputs along the federated partitioning + * @param nnz the number of non-zeros of the resulting federated output * @param reversed boolean indicating if inputs mo1 and mo2 are reversed */ private static void setFedOutput(MatrixObject mo1, MatrixObject out, FederationMap fedMap, - long 
staticDim, Long[] dims2, boolean reversed) { + long staticDim, Long[] dims2, long nnz, boolean reversed) { // get the final output dimensions final long d1 = (reversed ? Collections.max(Arrays.asList(dims2)) : staticDim); final long d2 = (reversed ? staticDim : Collections.max(Arrays.asList(dims2))); // set output - out.getDataCharacteristics().set(d1, d2, mo1.getBlocksize(), mo1.getNnz()); + out.getDataCharacteristics().setDimension(d1, d2) + .setBlocksize(mo1.getBlocksize()).setNonZeros(nnz); out.setFedMapping(fedMap); long varID = FederationUtils.getNextFedDataID(); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java index b9763aaf2b2..c10ca272593 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java @@ -130,12 +130,14 @@ public void processInstruction(ExecutionContext ec) { FederatedRequest fr1 = FederationUtils.callInstruction(instString, output, id, new CPOperand[] {input1}, new long[] {mo1.getFedMapping().getID()}, isSpark ? Types.ExecType.SPARK : Types.ExecType.CP, true); - mo1.getFedMapping().execute(getTID(), true, fr, fr1); + Future[] ffr = mo1.getFedMapping().execute(getTID(), true, fr, fr1); if (_fedOut != null && !_fedOut.isForcedLocal()){ //drive output federated mapping MatrixObject out = ec.getMatrixObject(output); - out.getDataCharacteristics().set(mo1.getNumColumns(), mo1.getNumRows(), mo1.getBlocksize(), mo1.getNnz()); + long nnz = (mo1.getNnz() != -1) ? 
mo1.getNnz() : FederationUtils.sumNonZeros(ffr); + out.getDataCharacteristics().setDimension(mo1.getNumColumns(), mo1.getNumRows()) + .setBlocksize(mo1.getBlocksize()).setNonZeros(nnz); out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()).transpose()); } else { FederatedRequest getRequest = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID()); @@ -153,14 +155,16 @@ else if(instOpcode.equalsIgnoreCase("rev")) { //execute transpose at federated site FederatedRequest fr1 = FederationUtils.callInstruction(instString, output, id, new CPOperand[] {input1}, new long[] {mo1.getFedMapping().getID()}, isSpark ? Types.ExecType.SPARK : Types.ExecType.CP, true); - mo1.getFedMapping().execute(getTID(), true, fr, fr1); + Future[] ffr = mo1.getFedMapping().execute(getTID(), true, fr, fr1); if(mo1.isFederated(FType.ROW)) mo1.getFedMapping().reverseFedMap(); //derive output federated mapping MatrixObject out = ec.getMatrixObject(output); - out.getDataCharacteristics().set(mo1.getNumRows(), mo1.getNumColumns(), mo1.getBlocksize(), mo1.getNnz()); + long nnz = (mo1.getNnz() != -1) ? 
mo1.getNnz() : FederationUtils.sumNonZeros(ffr); + out.getDataCharacteristics().setDimension(mo1.getNumRows(), mo1.getNumColumns()) + .setBlocksize(mo1.getBlocksize()).setNonZeros(nnz); out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID())); optionalForceLocal(out); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReshapeFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReshapeFEDInstruction.java index 73de9261a31..3e0281af7f0 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReshapeFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReshapeFEDInstruction.java @@ -20,6 +20,7 @@ package org.apache.sysds.runtime.instructions.fed; import java.util.Arrays; +import java.util.concurrent.Future; import java.util.stream.Collectors; import org.apache.commons.lang3.ArrayUtils; @@ -30,6 +31,7 @@ import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.controlprogram.federated.FederatedRange; import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest; +import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; import org.apache.sysds.runtime.controlprogram.federated.FederationMap; import org.apache.sysds.runtime.controlprogram.federated.FederationUtils; import org.apache.sysds.runtime.instructions.InstructionUtils; @@ -116,7 +118,7 @@ public void processInstruction(ExecutionContext ec) { FederatedRequest[] fr1 = FederationUtils.callInstruction(newInstString, output, id, new CPOperand[] {input1}, new long[] {mo1.getFedMapping().getID()}, InstructionUtils.getExecType(instString)); mo1.getFedMapping().execute(getTID(), true, tmp); - mo1.getFedMapping().execute(getTID(), true, fr1, new FederatedRequest[0]); + Future[] ffr = mo1.getFedMapping().execute(getTID(), true, fr1, new FederatedRequest[0]); // set new fed map FederationMap reshapedFedMap = mo1.getFedMapping().copyWithNewID(fr1[0].getID()); 
@@ -139,7 +141,9 @@ public void processInstruction(ExecutionContext ec) { //derive output federated mapping MatrixObject out = ec.getMatrixObject(output); - out.getDataCharacteristics().set(rows, cols, mo1.getBlocksize(), mo1.getNnz()); + long nnz = (mo1.getNnz() != -1) ? mo1.getNnz() : FederationUtils.sumNonZeros(ffr); + out.getDataCharacteristics().setDimension(rows, cols) + .setBlocksize(mo1.getBlocksize()).setNonZeros(nnz); out.setFedMapping(reshapedFedMap); } else { diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java index 096107228df..08d7478f07a 100644 --- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java +++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java @@ -883,6 +883,10 @@ public HashMap readRMatrixFromExpectedDir(String fileName) { return TestUtils.readRMatrixFromFS(baseDirectory + EXPECTED_DIR + cacheDir + fileName); } + protected static HashMap readDMLScalarFromExpectedDir(String fileName) { + return TestUtils.readDMLScalarFromHDFS(baseDirectory + EXPECTED_DIR + fileName); + } + protected static HashMap readDMLScalarFromOutputDir(String fileName) { return TestUtils.readDMLScalarFromHDFS(baseDirectory + OUTPUT_DIR + fileName); } @@ -973,9 +977,17 @@ public static void checkDMLMetaDataFile(String fileName, MatrixCharacteristics m Assert.assertEquals(mc.getBlocksize(), rmc.getBlocksize()); } + public static MatrixCharacteristics readDMLMetaDataFileFromExpectedDir(String fileName) { + return readDMLMetaDataFile(fileName, EXPECTED_DIR); + } + public static MatrixCharacteristics readDMLMetaDataFile(String fileName) { + return readDMLMetaDataFile(fileName, OUTPUT_DIR); + } + + public static MatrixCharacteristics readDMLMetaDataFile(String fileName, String outputDir) { try { - MetaDataAll meta = getMetaData(fileName); + MetaDataAll meta = getMetaData(fileName, outputDir); return new MatrixCharacteristics( meta.getDim1(), meta.getDim2(), 
meta.getBlocksize(), -1); } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSparsityPropagationTest.java b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSparsityPropagationTest.java new file mode 100644 index 00000000000..ef61cb7904c --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSparsityPropagationTest.java @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sysds.test.functions.federated.io; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; + +import org.apache.sysds.api.DMLOptions; +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.conf.ConfigurationManager; +import org.apache.sysds.parser.DMLProgram; +import org.apache.sysds.parser.DMLTranslator; +import org.apache.sysds.parser.ParserFactory; +import org.apache.sysds.parser.ParserWrapper; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory; +import org.apache.sysds.runtime.controlprogram.Program; +import org.apache.sysds.runtime.controlprogram.ProgramBlock; +import org.apache.sysds.common.Types.ExecMode; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(value = Parameterized.class) +@net.jcip.annotations.NotThreadSafe +public class FederatedSparsityPropagationTest extends AutomatedTestBase { + + private final static String TEST_DIR = "functions/federated/io/"; + private final static String TEST_NAME = "FederatedSparsityPropagationTest"; + private final static int NUM_MATRICES = 15; + private final static String TEST_CLASS_DIR = TEST_DIR + FederatedSparsityPropagationTest.class.getSimpleName() + "/"; + + private final static int blocksize = 1024; + @Parameterized.Parameter() + public int rows; + @Parameterized.Parameter(1) + public int cols; + @Parameterized.Parameter(2) + public boolean rowPartitioned; + + @Override + public void setUp() { + 
TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME)); + } + + @Parameterized.Parameters + public static Collection data() { + // number of rows or cols has to be >= number of federated workers. + return Arrays.asList(new Object[][] {{100, 130, true}}); + } + + @Test + public void federatedGetSparseSingleNode() { + federatedGet(ExecMode.SINGLE_NODE, 0.01); + } + + @Test + public void federatedGetDenseSingleNode() { + federatedGet(ExecMode.SINGLE_NODE, 0.5); + } + + public void federatedGet(ExecMode execMode, double sparsity) { + ExecMode platform_old = setExecMode(execMode); + String HOME = SCRIPT_DIR + TEST_DIR; + getAndLoadTestConfiguration(TEST_NAME); + + // write input matrices + int fed_rows = rows / 2; + int fed_cols = cols; + + MatrixCharacteristics mc = new MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols); + double[][] X1 = getRandomMatrix(fed_rows, fed_cols, 1, 3, sparsity, 3); + double[][] X2 = getRandomMatrix(fed_rows, fed_cols, 1, 3, sparsity, 7); + writeInputMatrixWithMTD("X1", X1, false, mc); + writeInputMatrixWithMTD("X2", X2, false, mc); + + // empty script name because we don't execute any script, just start the worker + fullDMLScriptName = ""; + int port1 = getRandomAvailablePort(); + int port2 = getRandomAvailablePort(); + Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); + Thread t2 = startLocalFedWorkerThread(port2); + + getAndLoadTestConfiguration(TEST_NAME); + + fullDMLScriptName = HOME + TEST_NAME + "Reference.dml"; + programArgs = new String[] {"-nvargs", "in_X1=" + input("X1"), "in_X2=" + input("X2"), + "sparsity=" + Double.toString(sparsity), "out_Dir=" + expectedDir()}; + runTest(true, false, null, -1); + + Map refNNZ = getRefNNZ(); + + // Obtain nnz from actual dml script with federated matrix + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + Map argVals = new HashMap<>(); + argVals.put("$in_X1", 
TestUtils.federatedAddress(port1, input("X1"))); + argVals.put("$in_X2", TestUtils.federatedAddress(port2, input("X2"))); + argVals.put("$rows", Integer.toString(fed_rows)); + argVals.put("$cols", Integer.toString(fed_cols)); + argVals.put("$sparsity", Double.toString(sparsity)); + + Map fedNNZ = null; + try { + fedNNZ = executeFedAndGetNNZ(fullDMLScriptName, argVals); + } catch(IOException ioe) { + DMLScript.errorPrint(ioe); + Assert.fail("IOException when executing federated test script."); + } + + System.out.println("RefNNZ: " + refNNZ); + System.out.println("FedNNZ: " + fedNNZ); + + compareNNZ(refNNZ, fedNNZ); + + TestUtils.shutdownThreads(t1, t2); + + resetExecMode(platform_old); + } + + // NOTE: the body of this function is copied from DMLScript.execute + private Map executeFedAndGetNNZ(String dmlScriptPath, Map argVals) + throws IOException { + String dmlScriptStr = ""; + String DML_FILE_PATH_ANTLR_PARSER = DMLOptions.defaultOptions.filePath; + dmlScriptStr = DMLScript.readDMLScript(true, fullDMLScriptName); + + ParserWrapper parser = ParserFactory.createParser(); + DMLProgram prog = parser.parse(DML_FILE_PATH_ANTLR_PARSER, dmlScriptStr, argVals); + + DMLTranslator dmlt = new DMLTranslator(prog); + dmlt.liveVariableAnalysis(prog); + dmlt.validateParseTree(prog); + dmlt.constructHops(prog); + dmlt.constructLops(prog); + Program rtprog = dmlt.getRuntimeProgram(prog, ConfigurationManager.getDMLConfig()); + ArrayList progBlocks = rtprog.getProgramBlocks(); + + ExecutionContext ec = ExecutionContextFactory.createContext(rtprog); + + // execute the first program block and obtain the nnz from the federation maps + progBlocks.get(0).execute(ec); + Map fedNNZ = getFedNNZ(ec); + // no need to execute the remaining program blocks + + return fedNNZ; + } + + private Map getRefNNZ() { + Map refNNZ = new HashMap<>(); + for(int counter = 0; counter < NUM_MATRICES; counter++) { + String varName = "NNZ_M" + Integer.toString(counter+1); + refNNZ.put(varName, 
readDMLScalarFromExpectedDir(varName) + .entrySet().stream().findAny().get().getValue().longValue()); + } + return refNNZ; + } + + private Map getFedNNZ(ExecutionContext ec) { + Map fedNNZ = new HashMap<>(); + for(String varName : ec.getVariables().keySet()) { + if(ec.isMatrixObject(varName)) { + MatrixObject mo = ec.getMatrixObject(varName); + fedNNZ.put("NNZ_" + varName, mo.getNnz()); + } + } + return fedNNZ; + } + + private void compareNNZ(Map ref, Map fed) { + for(Map.Entry re : ref.entrySet()) { + Assert.assertEquals("NNZs of " + re.getKey() + " differ.", re.getValue(), fed.get(re.getKey())); + } + } +} diff --git a/src/test/scripts/functions/federated/io/FederatedSparsityPropagationTest.dml b/src/test/scripts/functions/federated/io/FederatedSparsityPropagationTest.dml new file mode 100644 index 00000000000..515f5d564b7 --- /dev/null +++ b/src/test/scripts/functions/federated/io/FederatedSparsityPropagationTest.dml @@ -0,0 +1,91 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+#
+#-------------------------------------------------------------
+
+# create the federated matrix X from two row ranges of $rows x $cols cells, one per address
+X = federated(addresses=list($in_X1, $in_X2),
+ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2, $cols)));
+# construct additional matrices:
+# - X2 (federated, aligned with X)
+# - Y (federated, transpose aligned with X)
+# - L (local, same dimension as X)
+X2 = X^2;
+L = rand(rows=nrow(X), cols=ncol(X), sparsity=$sparsity, seed=13);
+Y = t(X + (0.1 * L));
+
+# right indexing (first 10 columns)
+M1 = X[ , 1:10];
+
+# matrix multiplication
+M2 = X %*% Y;
+
+# elementwise minus
+M3 = X - X2;
+
+# elementwise multiplication (first 20 rows of X and L)
+M4 = X[1:20, ] * L[1:20, ];
+
+# elementwise division
+M5 = X / L;
+
+# matrix vector addition (first row of L as row vector)
+M6 = X + L[1, ];
+
+# column bind, federated row partitioned / local
+M7 = cbind(X, L);
+
+# row bind, federated row partitioned / local
+M8 = rbind(M3, t(L) %*% L);
+
+# column bind, federated row partitioned / federated row partitioned
+M9 = cbind(X2, X);
+
+# binary matrix scalar (literal) multiplication
+M10 = X * 0.2;
+
+# binary matrix scalar subtraction (scalar computed from X)
+M11 = X - max(X);
+
+# ctable
+TMP12 = floor(matrix(seq(0, (nrow(X)*ncol(X))-1), nrow(X), ncol(X)) / (nrow(X)*ncol(X)/4)) + 1;
+M12 = table(floor(abs(X) + 1), TMP12);
+
+# transpose
+M13 = t(X);
+
+# rev
+M14 = rev(X);
+
+# reshape into a single row
+M15 = matrix(X, rows=1, cols=(ncol(X)*nrow(X)), byrow=FALSE);
+
+# FIXME: wrong nnz from cumsum instruction (test disabled until fixed)
+# # cumulative sum
+# M16 = cumsum(X);
+
+# FIXME: wrong nnz from ternary instruction (test disabled until fixed)
+# # ternary ifelse
+# M17 = ifelse(X > X2, X, X2);
+
+while(FALSE) {} # NOTE(review): presumably a statement-block cut, as the java test class executes only the first program block - confirm
+Z = sum(M1) + sum(M2) + sum(M3) + sum(M4) + sum(M5) + sum(M6) + sum(M7) + sum(M8)
+	+ sum(M9) + sum(M10) + sum(M11) + sum(M12) + sum(M13) + sum(M14) + sum(M15)
+	/*+ sum(M16) + sum(M17)*/;
+# NOTE: when adding tests, please remember to increment the number of matrices (NUM_MATRICES) in the java test class
diff --git a/src/test/scripts/functions/federated/io/FederatedSparsityPropagationTestReference.dml
b/src/test/scripts/functions/federated/io/FederatedSparsityPropagationTestReference.dml
new file mode 100644
index 00000000000..aca684dce7b
--- /dev/null
+++ b/src/test/scripts/functions/federated/io/FederatedSparsityPropagationTestReference.dml
@@ -0,0 +1,120 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($in_X1), read($in_X2)); # reference input: both input parts stacked row-wise
+
+X2 = X^2;
+L = rand(rows=nrow(X), cols=ncol(X), sparsity=$sparsity, seed=13);
+Y = t(X + (0.1 * L));
+
+# right indexing (first 10 columns)
+M1 = X[ , 1:10];
+
+# matrix multiplication
+M2 = X %*% Y;
+
+# elementwise minus
+M3 = X - X2;
+
+# elementwise multiplication (first 20 rows of X and L)
+M4 = X[1:20, ] * L[1:20, ];
+
+# elementwise division
+M5 = X / L;
+
+# matrix vector addition (first row of L as row vector)
+M6 = X + L[1, ];
+
+# column bind of X and L
+M7 = cbind(X, L);
+
+# row bind, federated row partitioned / local
+M8 = rbind(M3, t(L) %*% L);
+
+# column bind, federated row partitioned / federated row partitioned
+M9 = cbind(X2, X);
+
+# binary matrix scalar (literal) multiplication
+M10 = X * 0.2;
+
+# binary matrix scalar subtraction (scalar computed from X)
+M11 = X - max(X);
+
+# ctable
+TMP12 = floor(matrix(seq(0, (nrow(X)*ncol(X))-1), nrow(X), ncol(X)) / (nrow(X)*ncol(X)/4)) + 1;
+M12 = table(floor(abs(X) + 1), TMP12);
+
+# transpose
+M13 = t(X);
+
+# rev
+M14 = rev(X);
+
+# reshape into a single row
+M15 = matrix(X, rows=1, cols=(ncol(X)*nrow(X)), byrow=FALSE);
+
+# FIXME: wrong nnz from cumsum instruction (test disabled until fixed)
+# # cumulative sum
+# M16 = cumsum(X);
+
+# FIXME: wrong nnz from ternary instruction (test disabled until fixed)
+# # ternary ifelse
+# M17 = ifelse(X > X2, X, X2);
+
+while(FALSE) { } # NOTE(review): presumably a statement-block cut between the operations above and the nnz computation below - confirm
+
+NNZ_M1 = sum(M1 != 0);
+write(NNZ_M1, $out_Dir + "NNZ_M1");
+NNZ_M2 = sum(M2 != 0);
+write(NNZ_M2, $out_Dir + "NNZ_M2");
+NNZ_M3 = sum(M3 != 0);
+write(NNZ_M3, $out_Dir + "NNZ_M3");
+NNZ_M4 = sum(M4 != 0);
+write(NNZ_M4, $out_Dir + "NNZ_M4");
+NNZ_M5 = sum(M5 != 0);
+write(NNZ_M5, $out_Dir + "NNZ_M5");
+NNZ_M6 = sum(M6 != 0);
+write(NNZ_M6, $out_Dir + "NNZ_M6");
+# FIXME: the not equal operation returns incorrect results, because DMLTranslator.rewriteHopsDAG
+# rewrites the instruction sum(X != s); workaround: nnz = total number of cells - number of zeros
+NNZ_M7 = (nrow(M7) * ncol(M7)) - sum(M7 == 0);
+write(NNZ_M7, $out_Dir + "NNZ_M7");
+NNZ_M8 = sum(M8 != 0);
+write(NNZ_M8, $out_Dir + "NNZ_M8");
+NNZ_M9 = (nrow(M9) * ncol(M9)) - sum(M9 == 0); # FIXME: the not equal operation returns incorrect results (workaround: cells - zeros)
+write(NNZ_M9, $out_Dir + "NNZ_M9");
+NNZ_M10 = sum(M10 != 0);
+write(NNZ_M10, $out_Dir + "NNZ_M10");
+NNZ_M11 = sum(M11 != 0);
+write(NNZ_M11, $out_Dir + "NNZ_M11");
+NNZ_M12 = sum(M12 != 0);
+write(NNZ_M12, $out_Dir + "NNZ_M12");
+NNZ_M13 = (nrow(M13) * ncol(M13)) - sum(M13 == 0); # FIXME: the not equal operation returns incorrect results (workaround: cells - zeros)
+write(NNZ_M13, $out_Dir + "NNZ_M13");
+NNZ_M14 = (nrow(M14) * ncol(M14)) - sum(M14 == 0); # FIXME: the not equal operation returns incorrect results (workaround: cells - zeros)
+write(NNZ_M14, $out_Dir + "NNZ_M14");
+NNZ_M15 = (nrow(M15) * ncol(M15)) - sum(M15 == 0); # FIXME: the not equal operation returns incorrect results (workaround: cells - zeros)
+write(NNZ_M15, $out_Dir + "NNZ_M15");
+# NNZ_M16 = sum(M16 != 0);
+# write(NNZ_M16, $out_Dir + "NNZ_M16");
+# NNZ_M17 = sum(M17 != 0);
+# write(NNZ_M17, $out_Dir + "NNZ_M17");
+# NOTE: when adding tests, please remember to increment the number of matrices (NUM_MATRICES) in the java test class