Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
Expand Down Expand Up @@ -59,58 +60,60 @@ public void processInstruction(ExecutionContext ec)
MatrixObject U = ec.getMatrixObject(input2);
MatrixObject V = ec.getMatrixObject(input3);
ScalarObject eps = null;

if(qop.hasFourInputs()) {
eps = (_input4.getDataType() == DataType.SCALAR) ?
ec.getScalarInput(_input4) :
new DoubleObject(ec.getMatrixInput(_input4.getName()).quickGetValue(0, 0));
}

if(!(X.isFederated() && !U.isFederated() && !V.isFederated()))
throw new DMLRuntimeException("Unsupported federated inputs (X, U, V) = ("
+X.isFederated()+", "+U.isFederated()+", "+V.isFederated()+")");

FederationMap fedMap = X.getFedMapping();
FederatedRequest[] fr1 = fedMap.broadcastSliced(U, false);
FederatedRequest fr2 = fedMap.broadcast(V);
FederatedRequest fr3 = null;
FederatedRequest frComp = null;
if(X.isFederated(FType.ROW) && !U.isFederated() && !V.isFederated()) {
FederationMap fedMap = X.getFedMapping();
FederatedRequest[] fr1 = fedMap.broadcastSliced(U, false);
FederatedRequest fr2 = fedMap.broadcast(V);
FederatedRequest fr3 = null;
FederatedRequest frComp = null;

// broadcast the scalar epsilon if there are four inputs
if(eps != null) {
fr3 = fedMap.broadcast(eps);
// change the is_literal flag from true to false because when broadcasted it is no literal anymore
instString = instString.replace("true", "false");
frComp = FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1, input2, input3, _input4},
new long[]{fedMap.getID(), fr1[0].getID(), fr2.getID(), fr3.getID()});
}
else {
frComp = FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1, input2, input3},
new long[]{fedMap.getID(), fr1[0].getID(), fr2.getID()});
}
FederatedRequest frGet = new FederatedRequest(RequestType.GET_VAR, frComp.getID());
FederatedRequest frClean1 = fedMap.cleanup(getTID(), frComp.getID());
FederatedRequest frClean2 = fedMap.cleanup(getTID(), fr1[0].getID());
FederatedRequest frClean3 = fedMap.cleanup(getTID(), fr2.getID());
// broadcast the scalar epsilon if there are four inputs
if(eps != null) {
fr3 = fedMap.broadcast(eps);
// change the is_literal flag from true to false because when broadcasted it is no literal anymore
instString = instString.replace("true", "false");
frComp = FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1, input2, input3, _input4},
new long[]{fedMap.getID(), fr1[0].getID(), fr2.getID(), fr3.getID()});
}
else {
frComp = FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1, input2, input3},
new long[]{fedMap.getID(), fr1[0].getID(), fr2.getID()});
}

FederatedRequest frGet = new FederatedRequest(RequestType.GET_VAR, frComp.getID());
FederatedRequest frClean1 = fedMap.cleanup(getTID(), frComp.getID());
FederatedRequest frClean2 = fedMap.cleanup(getTID(), fr1[0].getID());
FederatedRequest frClean3 = fedMap.cleanup(getTID(), fr2.getID());

Future<FederatedResponse>[] response;
if(fr3 != null) {
FederatedRequest frClean4 = fedMap.cleanup(getTID(), fr3.getID());
// execute federated instructions
response = fedMap.execute(getTID(), true, fr1, fr2, fr3,
frComp, frGet, frClean1, frClean2, frClean3, frClean4);
Future<FederatedResponse>[] response;
if(fr3 != null) {
FederatedRequest frClean4 = fedMap.cleanup(getTID(), fr3.getID());
// execute federated instructions
response = fedMap.execute(getTID(), true, fr1, fr2, fr3,
frComp, frGet, frClean1, frClean2, frClean3, frClean4);
}
else {
// execute federated instructions
response = fedMap.execute(getTID(), true, fr1, fr2,
frComp, frGet, frClean1, frClean2, frClean3);
}

//aggregate partial results from federated responses
AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
ec.setVariable(output.getName(), FederationUtils.aggScalar(aop, response));
}
else {
// execute federated instructions
response = fedMap.execute(getTID(), true, fr1, fr2,
frComp, frGet, frClean1, frClean2, frClean3);
throw new DMLRuntimeException("Unsupported federated inputs (X, U, V) = ("
+X.isFederated()+", "+U.isFederated()+", "+V.isFederated()+")");
}

//aggregate partial results from federated responses
AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
ec.setVariable(output.getName(), FederationUtils.aggScalar(aop, response));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,79 +86,82 @@ public void processInstruction(ExecutionContext ec)
}
}

if(!(X.isFederated(FType.ROW) && !U.isFederated() && !V.isFederated()))
throw new DMLRuntimeException("Unsupported federated inputs (X, U, V) = ("
+X.isFederated()+", "+U.isFederated()+", "+V.isFederated() + ")");

FederationMap fedMap = X.getFedMapping();
FederatedRequest[] frInit1 = fedMap.broadcastSliced(U, false);
FederatedRequest frInit2 = fedMap.broadcast(V);
if(X.isFederated(FType.ROW) && !U.isFederated() && !V.isFederated()) {
FederationMap fedMap = X.getFedMapping();
FederatedRequest[] frInit1 = fedMap.broadcastSliced(U, false);
FederatedRequest frInit2 = fedMap.broadcast(V);

FederatedRequest frInit3 = null;
FederatedRequest frInit3Arr[] = null;
FederatedRequest frCompute1 = null;
// broadcast scalar epsilon if there are four inputs
if(eps != null) {
frInit3 = fedMap.broadcast(eps);
// change the is_literal flag from true to false because when broadcasted it is no literal anymore
instString = instString.replace("true", "false");
frCompute1 = FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1, input2, input3, _input4},
new long[]{fedMap.getID(), frInit1[0].getID(), frInit2.getID(), frInit3.getID()});
}
else if(MX != null) {
frInit3Arr = fedMap.broadcastSliced(MX, false);
frCompute1 = FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1, input2, input3, _input4},
new long[]{fedMap.getID(), frInit1[0].getID(), frInit2.getID(), frInit3Arr[0].getID()});
}
else {
frCompute1 = FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1, input2, input3},
new long[]{fedMap.getID(), frInit1[0].getID(), frInit2.getID()});
}
FederatedRequest frInit3 = null;
FederatedRequest frInit3Arr[] = null;
FederatedRequest frCompute1 = null;
// broadcast scalar epsilon if there are four inputs
if(eps != null) {
frInit3 = fedMap.broadcast(eps);
// change the is_literal flag from true to false because when broadcasted it is no literal anymore
instString = instString.replace("true", "false");
frCompute1 = FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1, input2, input3, _input4},
new long[]{fedMap.getID(), frInit1[0].getID(), frInit2.getID(), frInit3.getID()});
}
else if(MX != null) {
frInit3Arr = fedMap.broadcastSliced(MX, false);
frCompute1 = FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1, input2, input3, _input4},
new long[]{fedMap.getID(), frInit1[0].getID(), frInit2.getID(), frInit3Arr[0].getID()});
}
else {
frCompute1 = FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1, input2, input3},
new long[]{fedMap.getID(), frInit1[0].getID(), frInit2.getID()});
}

// get partial results from federated workers
FederatedRequest frGet1 = new FederatedRequest(RequestType.GET_VAR, frCompute1.getID());
// get partial results from federated workers
FederatedRequest frGet1 = new FederatedRequest(RequestType.GET_VAR, frCompute1.getID());

FederatedRequest frCleanup1 = fedMap.cleanup(getTID(), frCompute1.getID());
FederatedRequest frCleanup2 = fedMap.cleanup(getTID(), frInit1[0].getID());
FederatedRequest frCleanup3 = fedMap.cleanup(getTID(), frInit2.getID());
FederatedRequest frCleanup1 = fedMap.cleanup(getTID(), frCompute1.getID());
FederatedRequest frCleanup2 = fedMap.cleanup(getTID(), frInit1[0].getID());
FederatedRequest frCleanup3 = fedMap.cleanup(getTID(), frInit2.getID());

// execute federated instructions
Future<FederatedResponse>[] response;
if(frInit3 != null) {
FederatedRequest frCleanup4 = fedMap.cleanup(getTID(), frInit3.getID());
response = fedMap.execute(getTID(), true,
frInit1, frInit2, frInit3,
frCompute1, frGet1,
frCleanup1, frCleanup2, frCleanup3, frCleanup4);
}
else if(frInit3Arr != null) {
FederatedRequest frCleanup4 = fedMap.cleanup(getTID(), frInit3Arr[0].getID());
fedMap.execute(getTID(), true, frInit1, frInit2);
response = fedMap.execute(getTID(), true, frInit3Arr,
frCompute1, frGet1,
frCleanup1, frCleanup2, frCleanup3, frCleanup4);
}
else {
response = fedMap.execute(getTID(), true,
frInit1, frInit2,
frCompute1, frGet1,
frCleanup1, frCleanup2, frCleanup3);
}
// execute federated instructions
Future<FederatedResponse>[] response;
if(frInit3 != null) {
FederatedRequest frCleanup4 = fedMap.cleanup(getTID(), frInit3.getID());
response = fedMap.execute(getTID(), true,
frInit1, frInit2, frInit3,
frCompute1, frGet1,
frCleanup1, frCleanup2, frCleanup3, frCleanup4);
}
else if(frInit3Arr != null) {
FederatedRequest frCleanup4 = fedMap.cleanup(getTID(), frInit3Arr[0].getID());
fedMap.execute(getTID(), true, frInit1, frInit2);
response = fedMap.execute(getTID(), true, frInit3Arr,
frCompute1, frGet1,
frCleanup1, frCleanup2, frCleanup3, frCleanup4);
}
else {
response = fedMap.execute(getTID(), true,
frInit1, frInit2,
frCompute1, frGet1,
frCleanup1, frCleanup2, frCleanup3);
}

if(wdivmm_type.isLeft()) {
// aggregate partial results from federated responses
AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
ec.setMatrixOutput(output.getName(), FederationUtils.aggMatrix(aop, response, fedMap));
}
else if(wdivmm_type.isRight() || wdivmm_type.isBasic()) {
// bind partial results from federated responses
ec.setMatrixOutput(output.getName(), FederationUtils.bind(response, false));
if(wdivmm_type.isLeft()) {
// aggregate partial results from federated responses
AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
ec.setMatrixOutput(output.getName(), FederationUtils.aggMatrix(aop, response, fedMap));
}
else if(wdivmm_type.isRight() || wdivmm_type.isBasic()) {
// bind partial results from federated responses
ec.setMatrixOutput(output.getName(), FederationUtils.bind(response, false));
}
else {
throw new DMLRuntimeException("Federated WDivMM only supported for BASIC, LEFT or RIGHT variants.");
}
}
else {
throw new DMLRuntimeException("Federated WDivMM only supported for BASIC, LEFT or RIGHT variants.");
throw new DMLRuntimeException("Unsupported federated inputs (X, U, V) = ("
+X.isFederated()+", "+U.isFederated()+", "+V.isFederated() + ")");
}
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
Expand Down Expand Up @@ -69,51 +70,53 @@ public void processInstruction(ExecutionContext ec) {
W = ec.getMatrixObject(_input4);
}

if(!(X.isFederated() && !U.isFederated() && !V.isFederated() && (W == null || !W.isFederated())))
throw new DMLRuntimeException("Unsupported federated inputs (X, U, V, W) = (" + X.isFederated() + ", "
+ U.isFederated() + ", " + V.isFederated() + ", " + (W != null ? W.isFederated() : "none") + ")");
if(X.isFederated(FType.ROW) && !U.isFederated() && !V.isFederated() && (W == null || !W.isFederated())) {
FederationMap fedMap = X.getFedMapping();
FederatedRequest[] frInit1 = fedMap.broadcastSliced(U, false);
FederatedRequest frInit2 = fedMap.broadcast(V);

FederationMap fedMap = X.getFedMapping();
FederatedRequest[] frInit1 = fedMap.broadcastSliced(U, false);
FederatedRequest frInit2 = fedMap.broadcast(V);
FederatedRequest[] frInit3 = null;
FederatedRequest frCompute1 = null;
if(W != null) {
frInit3 = fedMap.broadcastSliced(W, false);
frCompute1 = FederationUtils.callInstruction(instString,
output,
new CPOperand[] {input1, input2, input3, _input4},
new long[] {fedMap.getID(), frInit1[0].getID(), frInit2.getID(), frInit3[0].getID()});
}
else {
frCompute1 = FederationUtils.callInstruction(instString,
output,
new CPOperand[] {input1, input2, input3},
new long[] {fedMap.getID(), frInit1[0].getID(), frInit2.getID()});
}

FederatedRequest[] frInit3 = null;
FederatedRequest frCompute1 = null;
if(W != null) {
frInit3 = fedMap.broadcastSliced(W, false);
frCompute1 = FederationUtils.callInstruction(instString,
output,
new CPOperand[] {input1, input2, input3, _input4},
new long[] {fedMap.getID(), frInit1[0].getID(), frInit2.getID(), frInit3[0].getID()});
}
else {
frCompute1 = FederationUtils.callInstruction(instString,
output,
new CPOperand[] {input1, input2, input3},
new long[] {fedMap.getID(), frInit1[0].getID(), frInit2.getID()});
}
FederatedRequest frGet1 = new FederatedRequest(RequestType.GET_VAR, frCompute1.getID());
FederatedRequest frCleanup1 = fedMap.cleanup(getTID(), frCompute1.getID());
FederatedRequest frCleanup2 = fedMap.cleanup(getTID(), frInit1[0].getID());
FederatedRequest frCleanup3 = fedMap.cleanup(getTID(), frInit2.getID());

FederatedRequest frGet1 = new FederatedRequest(RequestType.GET_VAR, frCompute1.getID());
FederatedRequest frCleanup1 = fedMap.cleanup(getTID(), frCompute1.getID());
FederatedRequest frCleanup2 = fedMap.cleanup(getTID(), frInit1[0].getID());
FederatedRequest frCleanup3 = fedMap.cleanup(getTID(), frInit2.getID());
Future<FederatedResponse>[] response;
if(frInit3 != null) {
FederatedRequest frCleanup4 = fedMap.cleanup(getTID(), frInit3[0].getID());
// execute federated instructions
fedMap.execute(getTID(), true, frInit1, frInit2);
response = fedMap
.execute(getTID(), true, frInit3, frCompute1, frGet1, frCleanup1, frCleanup2, frCleanup3, frCleanup4);
}
else {
// execute federated instructions
response = fedMap
.execute(getTID(), true, frInit1, frInit2, frCompute1, frGet1, frCleanup1, frCleanup2, frCleanup3);
}

Future<FederatedResponse>[] response;
if(frInit3 != null) {
FederatedRequest frCleanup4 = fedMap.cleanup(getTID(), frInit3[0].getID());
// execute federated instructions
fedMap.execute(getTID(), true, frInit1, frInit2);
response = fedMap
.execute(getTID(), true, frInit3, frCompute1, frGet1, frCleanup1, frCleanup2, frCleanup3, frCleanup4);
// aggregate partial results from federated responses
AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
ec.setVariable(output.getName(), FederationUtils.aggScalar(aop, response));
}
else {
// execute federated instructions
response = fedMap
.execute(getTID(), true, frInit1, frInit2, frCompute1, frGet1, frCleanup1, frCleanup2, frCleanup3);
throw new DMLRuntimeException("Unsupported federated inputs (X, U, V, W) = (" + X.isFederated() + ", "
+ U.isFederated() + ", " + V.isFederated() + ", " + (W != null ? W.isFederated() : "none") + ")");
}

// aggregate partial results from federated responses
AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
ec.setVariable(output.getName(), FederationUtils.aggScalar(aop, response));
}
}
Loading