Skip to content

Commit

Permalink
[SYSTEMDS-2944] Extended Sparsity Propagation in Federated Backend
Browse files Browse the repository at this point in the history
AMLS project SoSe'23, part 2
Closes #1863.
  • Loading branch information
ywcb00 authored and mboehm7 committed Jul 25, 2023
1 parent d3d3911 commit 25b7351
Show file tree
Hide file tree
Showing 14 changed files with 509 additions and 40 deletions.
Expand Up @@ -658,6 +658,14 @@ public FederationMap filter(IndexRange ixrange) {
if(!overlap)
iter.remove();
}

boolean rowPartitioned = this.getType().isType(FType.ROW)
|| Arrays.stream(ret.getFederatedRanges()).allMatch(range -> range.getSize(1) == ret.getMaxIndexInRange(1));
boolean colPartitioned = this.getType().isType(FType.COL)
|| Arrays.stream(ret.getFederatedRanges()).allMatch(range -> range.getSize(0) == ret.getMaxIndexInRange(0));
if(rowPartitioned && colPartitioned)
ret.setType(FType.FULL);

return ret;
}

Expand Down
Expand Up @@ -566,8 +566,8 @@ public static MatrixBlock bindResponses(List<Pair<FederatedRange, Future<Federat
int[] endDimsInt = range.getEndDimsInt();
MatrixBlock multRes = (MatrixBlock) response.getData()[0];
ret.copy(beginDimsInt[0], endDimsInt[0] - 1, beginDimsInt[1], endDimsInt[1] - 1, multRes, false);
ret.setNonZeros(ret.getNonZeros() + multRes.getNonZeros());
}
ret.setNonZeros(totalNNZ);
return ret;
}

Expand Down
Expand Up @@ -85,10 +85,14 @@ public long getNumColumns() {
return left.getNumColumns();
}

public long getBlocksize() {
public int getBlocksize() {
return left.getBlocksize();
}

public long getNnz() {
return left.getNnz();
}

public DataType getDataType() {
return left.getDataType();
}
Expand Down
Expand Up @@ -120,8 +120,9 @@ else if(mo1.isFederated(FType.ROW)) { // MV + MM
}
if((_fedOut.isForcedFederated() || (!isVector && !_fedOut.isForcedLocal()))
&& !isPartOut) { // not creating federated output in the MV case for reasons of performance
mo1.getFedMapping().execute(getTID(), true, fr1, fr2);
setOutputFedMapping(mo1.getFedMapping(), mo1, mo2, fr2.getID(), ec);
Future<FederatedResponse>[] ffr = mo1.getFedMapping().execute(getTID(), true, fr1, fr2);
setOutputFedMapping(mo1.getFedMapping(), mo1, mo2,
FederationUtils.sumNonZeros(ffr), fr2.getID(), ec);
}
else {
boolean isDoubleBroadcast = (mo1.isFederated(FType.BROADCAST) && mo2.isFederated(FType.BROADCAST));
Expand Down Expand Up @@ -183,7 +184,8 @@ private void writeInfoLog(MatrixLineagePair mo1, MatrixLineagePair mo2){
private void setPartialOutput(FederationMap federationMap, MatrixLineagePair mo1, MatrixLineagePair mo2,
long outputID, ExecutionContext ec){
MatrixObject out = ec.getMatrixObject(output);
out.getDataCharacteristics().set(mo1.getNumRows(), mo2.getNumColumns(), (int)mo1.getBlocksize());
out.getDataCharacteristics().setDimension(mo1.getNumRows(), mo2.getNumColumns())
.setBlocksize(mo1.getBlocksize());
FederationMap outputFedMap = federationMap
.copyWithNewIDAndRange(mo1.getNumRows(), mo2.getNumColumns(), outputID);
out.setFedMapping(outputFedMap);
Expand All @@ -194,13 +196,16 @@ private void setPartialOutput(FederationMap federationMap, MatrixLineagePair mo1
* @param federationMap federation map to be set in output
* @param mo1 matrix object with number of rows used to set the number of rows of the output
* @param mo2 matrix object with number of columns used to set the number of columns of the output
* @param nnz the number of non-zeros of the output
* @param outputID ID of the output
* @param ec execution context
*/
private void setOutputFedMapping(FederationMap federationMap, MatrixLineagePair mo1, MatrixLineagePair mo2,
long outputID, ExecutionContext ec){
private void setOutputFedMapping(FederationMap federationMap, MatrixLineagePair mo1,
MatrixLineagePair mo2, long nnz, long outputID, ExecutionContext ec){
MatrixObject out = ec.getMatrixObject(output);
out.getDataCharacteristics().set(mo1.getNumRows(), mo2.getNumColumns(), (int)mo1.getBlocksize());
out.getDataCharacteristics()
.setDimension(mo1.getNumRows(), mo2.getNumColumns())
.setBlocksize(mo1.getBlocksize()).setNonZeros(nnz);
out.setFedMapping(federationMap.copyWithNewID(outputID, mo2.getNumColumns()));
}

Expand Down
Expand Up @@ -124,19 +124,21 @@ else if(!_cbind && mo1.getNumColumns() != mo2.getNumColumns()) {
new CPOperand[]{input1, input2},
new long[]{mo1.getFedMapping().getID(), mo2.getFedMapping().getID()});

Future<FederatedResponse>[] ffr = null;
if(isSpark) {
FederatedRequest frTmp = new FederatedRequest(RequestType.PUT_VAR,
fr2.getID(), new MatrixCharacteristics(-1, -1), mo1.getDataType());
mo1.getFedMapping().execute(getTID(), true, frTmp, fr2);
ffr = mo1.getFedMapping().execute(getTID(), true, frTmp, fr2);
}
else {
mo1.getFedMapping().execute(getTID(), true, fr2);
ffr = mo1.getFedMapping().execute(getTID(), true, fr2);
}

int dim = (_cbind ? 1 : 0);
FederationMap newFedMap = mo1.getFedMapping().copyWithNewID(fr2.getID())
.modifyFedRanges(mo1.getDim(dim) + mo2.getDim(dim), dim);
out.setFedMapping(newFedMap);
out.getDataCharacteristics().setNonZeros(FederationUtils.sumNonZeros(ffr));
}
// federated/federated misaligned, federated/local, local/federated bind
else if( ((mo1.isFederated(FType.ROW) || mo2.isFederated(FType.ROW)) && !_cbind)
Expand All @@ -152,6 +154,8 @@ else if( ((mo1.isFederated(FType.ROW) || mo2.isFederated(FType.ROW)) && !_cbind)

out.setFedMapping(fed1.identCopy(getTID(), id)
.bind(roff, coff, fed2.identCopy(getTID(), id)));
if(mo1.getNnz() != -1 && mo2.getNnz() != -1)
out.getDataCharacteristics().setNonZeros(mo1.getNnz() + mo2.getNnz());
}
// federated/local, local/federated bind
else if( ((mo1.isFederated(FType.ROW) || mo2.isFederated(FType.ROW)) && _cbind)
Expand Down
Expand Up @@ -19,12 +19,15 @@

package org.apache.sysds.runtime.instructions.fed;

import java.util.concurrent.Future;

import org.apache.sysds.hops.fedplanner.FTypes.AlignType;
import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.controlprogram.federated.MatrixLineagePair;
Expand Down Expand Up @@ -71,28 +74,29 @@ public void processInstruction(ExecutionContext ec) {

//execute federated operation on mo1 or mo2
FederatedRequest fr2 = null;
Future<FederatedResponse>[] ffr = null;
if( mo2.isFederatedExcept(FType.BROADCAST) ) {
if(mo1.isFederated() && mo1.getFedMapping().isAligned(mo2.getFedMapping(),
mo1.isFederated(FType.ROW) ? AlignType.ROW : AlignType.COL)) {
mo1.isFederated(FType.ROW) ? AlignType.ROW : AlignType.COL)) {
fr2 = FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1, input2},
new long[]{mo1.getFedMapping().getID(), mo2.getFedMapping().getID()}, true);
mo2.getFedMapping().execute(getTID(), true, fr2);
ffr = mo2.getFedMapping().execute(getTID(), true, fr2);
}
else {
FederatedRequest[] fr1 = mo2.getFedMapping().broadcastSliced(mo1, false);
fr2 = FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1, input2},
new long[]{fr1[0].getID(), mo2.getFedMapping().getID()}, true);
mo2.getFedMapping().execute(getTID(), true, fr1, fr2);
ffr = mo2.getFedMapping().execute(getTID(), true, fr1, fr2);
}
fedMo = mo2.getMO(); // for setting the output federated mapping afterwards
}
else if ( mo2.isFederated(FType.BROADCAST) && !mo1.isFederated() ){
FederatedRequest fr1 = mo2.getFedMapping().broadcast(mo1);
fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2},
new long[]{mo2.getFedMapping().getID(), fr1.getID()}, true);
mo2.getFedMapping().execute(getTID(), true, fr1, fr2);
ffr = mo2.getFedMapping().execute(getTID(), true, fr1, fr2);
fedMo = mo2.getMO();
}
else { // matrix-matrix binary operations -> lhs fed input -> fed output
Expand All @@ -103,7 +107,7 @@ else if ( mo2.isFederated(FType.BROADCAST) && !mo1.isFederated() ){
FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2},
new long[]{mo1.getFedMapping().getID(), fr1.getID()}, true);
mo1.getFedMapping().execute(getTID(), true, fr1, fr2);
ffr = mo1.getFedMapping().execute(getTID(), true, fr1, fr2);
}
else {
throw new DMLRuntimeException("Matrix-matrix binary operations with a full partitioned federated input with multiple partitions are not supported yet.");
Expand All @@ -115,33 +119,34 @@ else if((mo1.isFederated(FType.ROW) && mo2.getNumRows() == 1) //matrix-rowV
FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2},
new long[]{mo1.getFedMapping().getID(), fr1.getID()}, true);
mo1.getFedMapping().execute(getTID(), true, fr1, fr2);
ffr = mo1.getFedMapping().execute(getTID(), true, fr1, fr2);
}
else if((mo1.isFederated(FType.ROW) ^ mo1.isFederated(FType.COL))
|| (mo1.isFederated(FType.FULL) && mo1.getFedMapping().getSize() == 1)) {
// row partitioned MM or col partitioned MM
FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, false);
fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2},
new long[]{mo1.getFedMapping().getID(), fr1[0].getID()}, true);
mo1.getFedMapping().execute(getTID(), true, fr1, fr2);
ffr = mo1.getFedMapping().execute(getTID(), true, fr1, fr2);
}
else if ( mo1.isFederated(FType.PART) && !mo2.isFederated() ){
FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2},
new long[]{mo1.getFedMapping().getID(), fr1.getID()}, true);
mo1.getFedMapping().execute(getTID(), true, fr1, fr2);
ffr = mo1.getFedMapping().execute(getTID(), true, fr1, fr2);
}
else {
throw new DMLRuntimeException("Matrix-matrix binary operations are only supported with a row partitioned or column partitioned federated input yet.");
}
fedMo = mo1.getMO(); // for setting the output federated mapping afterwards
}

long nnz = FederationUtils.sumNonZeros(ffr);
if ( mo1.isFederated(FType.PART) && !mo2.isFederated() )
setOutputFedMappingPart(mo1.getMO(), mo2.getMO(), fr2.getID(), ec);
setOutputFedMappingPart(mo1.getMO(), mo2.getMO(), nnz, fr2.getID(), ec);
else if ( fedMo.isFederated() )
setOutputFedMapping(fedMo, Math.max(mo1.getNumRows(), mo2.getNumRows()),
Math.max(mo1.getNumColumns(), mo2.getNumColumns()), fr2.getID(), ec);
Math.max(mo1.getNumColumns(), mo2.getNumColumns()), nnz, fr2.getID(), ec);
else throw new DMLRuntimeException("Input is not federated, so the output FedMapping cannot be set!");
}

Expand All @@ -152,9 +157,11 @@ else if ( fedMo.isFederated() )
* @param outputID ID of output
* @param ec execution context
*/
private void setOutputFedMappingPart(MatrixObject mo1, MatrixObject mo2, long outputID, ExecutionContext ec){
private void setOutputFedMappingPart(MatrixObject mo1, MatrixObject mo2, long nnz,
long outputID, ExecutionContext ec){
MatrixObject out = ec.getMatrixObject(output);
out.getDataCharacteristics().set(mo1.getNumRows(), mo2.getNumColumns(), mo1.getBlocksize());
out.getDataCharacteristics().setDimension(mo1.getNumRows(), mo2.getNumColumns())
.setBlocksize(mo1.getBlocksize()).setNonZeros(nnz);
FederationMap outputFedMap = mo1.getFedMapping()
.copyWithNewIDAndRange(mo1.getNumRows(), mo2.getNumColumns(), outputID);
out.setFedMapping(outputFedMap);
Expand All @@ -167,15 +174,15 @@ private void setOutputFedMappingPart(MatrixObject mo1, MatrixObject mo2, long ou
* @param ec execution context
*/
private void setOutputFedMapping(MatrixObject moFederated, long rowNum, long colNum,
long outputFedmappingID, ExecutionContext ec){
long nnz, long outputFedmappingID, ExecutionContext ec){
MatrixObject out = ec.getMatrixObject(output);
FederationMap fedMap = moFederated.getFedMapping().copyWithNewID(outputFedmappingID);
if(moFederated.getNumRows() != rowNum || moFederated.getNumColumns() != colNum) {
int dim = moFederated.isFederated(FType.COL) ? 0 : 1;
fedMap.modifyFedRanges((dim == 0) ? rowNum : colNum, dim);
}
out.getDataCharacteristics().set(moFederated.getDataCharacteristics())
.setRows(rowNum).setCols(colNum);
.setDimension(rowNum, colNum).setNonZeros(nnz);
out.setFedMapping(fedMap);
}
}
Expand Up @@ -19,9 +19,12 @@

package org.apache.sysds.runtime.instructions.fed;

import java.util.concurrent.Future;

import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.BinaryMatrixScalarCPInstruction;
Expand Down Expand Up @@ -62,16 +65,19 @@ public void processInstruction(ExecutionContext ec) {
new long[]{mo.getFedMapping().getID(), (fr1 != null)?fr1.getID():-1}, true);

//execute federated matrix-scalar operation and cleanups
Future<FederatedResponse>[] ffr = null;
if( fr1 != null ) {
FederatedRequest fr3 = mo.getFedMapping().cleanup(getTID(), fr1.getID());
mo.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
ffr = mo.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
}
else {
ffr = mo.getFedMapping().execute(getTID(), true, fr2);
}
else
mo.getFedMapping().execute(getTID(), true, fr2);

//derive new fed mapping for output
MatrixObject out = ec.getMatrixObject(output);
out.getDataCharacteristics().set(mo.getDataCharacteristics());
out.getDataCharacteristics().set(mo.getDataCharacteristics())
.setNonZeros(FederationUtils.sumNonZeros(ffr));
out.setFedMapping(mo.getFedMapping().copyWithNewID(fr2.getID()));
}
}
Expand Up @@ -210,14 +210,15 @@ else if(reversed && !reversedWeights)

if(fedOutput) {
if(fr2 != null) // broadcasted mo3
fedMap.execute(getTID(), true, fr1, fr2, fr3);
ffr = fedMap.execute(getTID(), true, fr1, fr2, fr3);
else
fedMap.execute(getTID(), true, fr1, fr3);
ffr = fedMap.execute(getTID(), true, fr1, fr3);

MatrixObject out = ec.getMatrixObject(output);
FederationMap newFedMap = modifyFedRanges(fedMap.copyWithNewID(fr3.getID()),
staticDim, dims2, reversed);
setFedOutput(mo1.getMO(), out, newFedMap, staticDim, dims2, reversed);
long nnz = FederationUtils.sumNonZeros(ffr);
setFedOutput(mo1.getMO(), out, newFedMap, staticDim, dims2, nnz, reversed);
} else {
fr4 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr3.getID());
fr5 = fedMap.cleanup(getTID(), fr3.getID());
Expand Down Expand Up @@ -280,16 +281,18 @@ private static boolean isFedOutput(FederationMap fedMap, MatrixObject mo2) {
* @param fedMap the federation map of the federated matrix input mo1
* @param staticDim static non-partitioned dimension of the output
* @param dims2 dimensions of the partial outputs along the federated partitioning
* @param nnz the number of non-zeros of the resulting federated output
* @param reversed boolean indicating if inputs mo1 and mo2 are reversed
*/
private static void setFedOutput(MatrixObject mo1, MatrixObject out, FederationMap fedMap,
long staticDim, Long[] dims2, boolean reversed) {
long staticDim, Long[] dims2, long nnz, boolean reversed) {
// get the final output dimensions
final long d1 = (reversed ? Collections.max(Arrays.asList(dims2)) : staticDim);
final long d2 = (reversed ? staticDim : Collections.max(Arrays.asList(dims2)));

// set output
out.getDataCharacteristics().set(d1, d2, mo1.getBlocksize(), mo1.getNnz());
out.getDataCharacteristics().setDimension(d1, d2)
.setBlocksize(mo1.getBlocksize()).setNonZeros(nnz);
out.setFedMapping(fedMap);

long varID = FederationUtils.getNextFedDataID();
Expand Down
Expand Up @@ -130,12 +130,14 @@ public void processInstruction(ExecutionContext ec) {

FederatedRequest fr1 = FederationUtils.callInstruction(instString, output, id, new CPOperand[] {input1},
new long[] {mo1.getFedMapping().getID()}, isSpark ? Types.ExecType.SPARK : Types.ExecType.CP, true);
mo1.getFedMapping().execute(getTID(), true, fr, fr1);
Future<FederatedResponse>[] ffr = mo1.getFedMapping().execute(getTID(), true, fr, fr1);

if (_fedOut != null && !_fedOut.isForcedLocal()){
//drive output federated mapping
MatrixObject out = ec.getMatrixObject(output);
out.getDataCharacteristics().set(mo1.getNumColumns(), mo1.getNumRows(), mo1.getBlocksize(), mo1.getNnz());
long nnz = (mo1.getNnz() != -1) ? mo1.getNnz() : FederationUtils.sumNonZeros(ffr);
out.getDataCharacteristics().setDimension(mo1.getNumColumns(), mo1.getNumRows())
.setBlocksize(mo1.getBlocksize()).setNonZeros(nnz);
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()).transpose());
} else {
FederatedRequest getRequest = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID());
Expand All @@ -153,14 +155,16 @@ else if(instOpcode.equalsIgnoreCase("rev")) {
//execute transpose at federated site
FederatedRequest fr1 = FederationUtils.callInstruction(instString, output, id, new CPOperand[] {input1},
new long[] {mo1.getFedMapping().getID()}, isSpark ? Types.ExecType.SPARK : Types.ExecType.CP, true);
mo1.getFedMapping().execute(getTID(), true, fr, fr1);
Future<FederatedResponse>[] ffr = mo1.getFedMapping().execute(getTID(), true, fr, fr1);

if(mo1.isFederated(FType.ROW))
mo1.getFedMapping().reverseFedMap();

//derive output federated mapping
MatrixObject out = ec.getMatrixObject(output);
out.getDataCharacteristics().set(mo1.getNumRows(), mo1.getNumColumns(), mo1.getBlocksize(), mo1.getNnz());
long nnz = (mo1.getNnz() != -1) ? mo1.getNnz() : FederationUtils.sumNonZeros(ffr);
out.getDataCharacteristics().setDimension(mo1.getNumRows(), mo1.getNumColumns())
.setBlocksize(mo1.getBlocksize()).setNonZeros(nnz);
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()));

optionalForceLocal(out);
Expand Down

0 comments on commit 25b7351

Please sign in to comment.