Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,11 @@ public Future<FederatedResponse>[] execute(long tid, boolean wait, FederatedRequ
return ret.toArray(new Future[0]);
}

public Future<FederatedResponse>[] execute(long tid, boolean wait, FederatedRange[] fedRange1,
FederatedRequest elseFr, FederatedRequest frSlice1, FederatedRequest frSlice2, FederatedRequest fr) {
return execute(tid, wait, fedRange1, elseFr, new FederatedRequest[]{frSlice1}, new FederatedRequest[]{frSlice2}, fr);
}

@SuppressWarnings("unchecked")
public Future<FederatedResponse>[] execute(long tid, boolean wait, FederatedRange[] fedRange1, FederatedRequest elseFr, FederatedRequest[] frSlices1, FederatedRequest[] frSlices2, FederatedRequest... fr) {
// executes step1[] - step 2 - ... step4 (only first step federated-data-specific)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.lops.LeftIndex;
Expand All @@ -44,6 +45,7 @@
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.IndexRange;
Expand Down Expand Up @@ -204,7 +206,8 @@ private void leftIndexing(ExecutionContext ec)
{
//get input and requested index range
CacheableData<?> in1 = ec.getCacheableData(input1);
CacheableData<?> in2 = ec.getCacheableData(input2);
CacheableData<?> in2 = null; // either in2 or scalar is set
ScalarObject scalar = null;
IndexRange ixrange = getIndexRange(ec);

//check bounds
Expand All @@ -213,11 +216,21 @@ private void leftIndexing(ExecutionContext ec)
throw new DMLRuntimeException("Invalid values for matrix indexing: ["+(ixrange.rowStart+1)+":"+(ixrange.rowEnd+1)+","
+ (ixrange.colStart+1)+":"+(ixrange.colEnd+1)+"] " + "must be within matrix dimensions ["+in1.getNumRows()+","+in1.getNumColumns()+"].");
}
if( (ixrange.rowEnd-ixrange.rowStart+1) != in2.getNumRows() || (ixrange.colEnd-ixrange.colStart+1) != in2.getNumColumns()) {
throw new DMLRuntimeException("Invalid values for matrix indexing: " +
"dimensions of the source matrix ["+in2.getNumRows()+"x" + in2.getNumColumns() + "] " +
"do not match the shape of the matrix specified by indices [" +
(ixrange.rowStart+1) +":" + (ixrange.rowEnd+1) + ", " + (ixrange.colStart+1) + ":" + (ixrange.colEnd+1) + "].");

if(input2.getDataType() == DataType.SCALAR) {
if(!ixrange.isScalar())
throw new DMLRuntimeException("Invalid index range for leftindexing with scalar: " + ixrange.toString() + ".");

scalar = ec.getScalarInput(input2);
}
else {
in2 = ec.getCacheableData(input2);
if( (ixrange.rowEnd-ixrange.rowStart+1) != in2.getNumRows() || (ixrange.colEnd-ixrange.colStart+1) != in2.getNumColumns()) {
throw new DMLRuntimeException("Invalid values for matrix indexing: " +
"dimensions of the source matrix ["+in2.getNumRows()+"x" + in2.getNumColumns() + "] " +
"do not match the shape of the matrix specified by indices [" +
(ixrange.rowStart+1) +":" + (ixrange.rowEnd+1) + ", " + (ixrange.colStart+1) + ":" + (ixrange.colEnd+1) + "].");
}
}

FederationMap fedMap = in1.getFedMapping();
Expand All @@ -226,6 +239,10 @@ private void leftIndexing(ExecutionContext ec)
int[][] sliceIxs = new int[fedMap.getSize()][];
FederatedRange[] ranges = new FederatedRange[fedMap.getSize()];

// instruction string for copying a partition at the federated site
int cpVarInstIx = fedMap.getSize();
String cpVarInstString = createCopyInstString();

// replace old reshape values for each worker
int i = 0, prev = 0, from = fedMap.getSize();
for(org.apache.commons.lang3.tuple.Pair<FederatedRange, FederatedData> e : fedMap.getMap()) {
Expand All @@ -239,29 +256,46 @@ private void leftIndexing(ExecutionContext ec)

long[] newIx = new long[]{(int) rsn, (int) ren, (int) csn, (int) cen};

// find ranges where to apply leftIndex
long to;
if(in1.isFederated(FType.ROW) && (to = (prev + ren - rsn)) >= 0 &&
to < in2.getNumRows() && ixrange.rowStart <= re) {
sliceIxs[i] = new int[] { prev, (int) to, 0, (int) in2.getNumColumns()-1};
prev = (int) (to + 1);

instStrings[i] = modifyIndices(newIx, 4, 8);
ranges[i] = range;
from = Math.min(i, from);
if(in2 != null) { // matrix, frame
// find ranges where to apply leftIndex
long to;
if(in1.isFederated(FType.ROW) && (to = (prev + ren - rsn)) >= 0 &&
to < in2.getNumRows() && ixrange.rowStart <= re) {
sliceIxs[i] = new int[] { prev, (int) to, 0, (int) in2.getNumColumns()-1};
prev = (int) (to + 1);

instStrings[i] = modifyIndices(newIx, 4, 8);
ranges[i] = range;
from = Math.min(i, from);
}
else if(in1.isFederated(FType.COL) && (to = (prev + cen - csn)) >= 0 &&
to < in2.getNumColumns() && ixrange.colStart <= ce) {
sliceIxs[i] = new int[] {0, (int) in2.getNumRows() - 1, prev, (int) to};
prev = (int) (to + 1);

instStrings[i] = modifyIndices(newIx, 4, 8);
ranges[i] = range;
from = Math.min(i, from);
}
else {
// TODO shallow copy, add more advanced update in place for federated
cpVarInstIx = Math.min(i, cpVarInstIx);
instStrings[i] = cpVarInstString;
}
}
else if(in1.isFederated(FType.COL) && (to = (prev + cen - csn)) >= 0 &&
to < in2.getNumColumns() && ixrange.colStart <= ce) {
sliceIxs[i] = new int[] {0, (int) in2.getNumRows() - 1, prev, (int) to};
prev = (int) (to + 1);

instStrings[i] = modifyIndices(newIx, 4, 8);
ranges[i] = range;
from = Math.min(i, from);
else { // scalar
if(ixrange.rowStart >= rs && ixrange.rowEnd < re
&& ixrange.colStart >= cs && ixrange.colEnd < ce) {
instStrings[i] = modifyIndices(newIx, 4, 8);
instStrings[i] = changeScalarLiteralFlag(instStrings[i], 3);
ranges[i] = range;
from = Math.min(i, from);
}
else {
cpVarInstIx = Math.min(i, cpVarInstIx);
instStrings[i] = cpVarInstString;
}
}
else
// TODO shallow copy, add more advanced update in place for federated
instStrings[i] = createCopyInstString();

i++;
}
Expand All @@ -272,32 +306,40 @@ else if(in1.isFederated(FType.COL) && (to = (prev + cen - csn)) >= 0 &&
FederatedRequest tmp = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, id, new MatrixCharacteristics(-1, -1), in1.getDataType());
fedMap.execute(getTID(), true, tmp);

FederatedRequest[] fr1 = fedMap.broadcastSliced(in2, DMLScript.LINEAGE ? ec.getLineageItem(input2) : null,
input2.isFrame(), sliceIxs);
FederatedRequest[] fr2 = FederationUtils.callInstruction(instStrings, output, id, new CPOperand[]{input1, input2},
new long[]{fedMap.getID(), fr1[0].getID()}, null);
FederatedRequest fr3 = fedMap.cleanup(getTID(), fr1[0].getID());
if(in2 != null) { // matrix, frame
FederatedRequest[] fr1 = fedMap.broadcastSliced(in2, DMLScript.LINEAGE ? ec.getLineageItem(input2) : null,
input2.isFrame(), sliceIxs);
FederatedRequest[] fr2 = FederationUtils.callInstruction(instStrings, output, id, new CPOperand[]{input1, input2},
new long[]{fedMap.getID(), fr1[0].getID()}, null);
FederatedRequest fr3 = fedMap.cleanup(getTID(), fr1[0].getID());

//execute federated instruction and cleanup intermediates
if(sliceIxs.length == fedMap.getSize())
fedMap.execute(getTID(), true, fr2, fr1, fr3);
else {
// get index of cpvar request
for(i = 0; i < fr2.length; i++)
if(i < from || i >= from + sliceIxs.length)
break;
fedMap.execute(getTID(), true, ranges, (fr2[i]), Arrays.copyOfRange(fr2, from, from + sliceIxs.length), fr1, fr3);
//execute federated instruction and cleanup intermediates
if(sliceIxs.length == fedMap.getSize())
fedMap.execute(getTID(), true, fr2, fr1, fr3);
else
fedMap.execute(getTID(), true, ranges, fr2[cpVarInstIx], Arrays.copyOfRange(fr2, from, from + sliceIxs.length), fr1, fr3);
}
else { // scalar
FederatedRequest fr1 = fedMap.broadcast(scalar);
FederatedRequest[] fr2 = FederationUtils.callInstruction(instStrings, output, id, new CPOperand[]{input1, input2},
new long[]{fedMap.getID(), fr1.getID()}, null);
FederatedRequest fr3 = fedMap.cleanup(getTID(), fr1.getID());

if(fr2.length == 1)
fedMap.execute(getTID(), true, fr2, fr1, fr3);
else
fedMap.execute(getTID(), true, ranges, fr2[cpVarInstIx], fr2[from], fr1, fr3);
}

if(input1.isFrame()) {
FrameObject out = ec.getFrameObject(output);
out.setSchema(((FrameObject) in1).getSchema());
out.getDataCharacteristics().set(in1.getDataCharacteristics());
out.setFedMapping(fedMap.copyWithNewID(fr2[0].getID()));
out.setFedMapping(fedMap.copyWithNewID(id));
} else {
MatrixObject out = ec.getMatrixObject(output);
out.getDataCharacteristics().set(in1.getDataCharacteristics());;
out.setFedMapping(fedMap.copyWithNewID(fr2[0].getID()));
out.getDataCharacteristics().set(in1.getDataCharacteristics());
out.setFedMapping(fedMap.copyWithNewID(id));
}
}

Expand All @@ -309,6 +351,13 @@ private String modifyIndices(long[] newIx, int from, int to) {
return String.join(Lop.OPERAND_DELIMITOR, instParts);
}

private String changeScalarLiteralFlag(String inst, int partIx) {
// change the literal flag of the broadcast scalar
String[] instParts = inst.split(Lop.OPERAND_DELIMITOR);
instParts[partIx] = instParts[partIx].replace("true", "false");
return String.join(Lop.OPERAND_DELIMITOR, instParts);
}

private String createCopyInstString() {
String[] instParts = instString.split(Lop.OPERAND_DELIMITOR);
return VariableCPInstruction.prepareCopyInstruction(instParts[2], instParts[8]).toString();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ public class FederatedLeftIndexTest extends AutomatedTestBase {

private final static String TEST_NAME1 = "FederatedLeftIndexFullTest";
private final static String TEST_NAME2 = "FederatedLeftIndexFrameFullTest";
private final static String TEST_NAME3 = "FederatedLeftIndexScalarTest";

private final static String TEST_DIR = "functions/federated/";
private static final String TEST_CLASS_DIR = TEST_DIR + FederatedLeftIndexTest.class.getSimpleName() + "/";
Expand Down Expand Up @@ -81,14 +82,15 @@ public static Collection<Object[]> data() {
}

private enum DataType {
MATRIX, FRAME
MATRIX, FRAME, SCALAR
}

@Override
public void setUp() {
TestUtils.clearAssertionInformation();
addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S"}));
addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"S"}));
addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"S"}));
}

@Test
Expand All @@ -109,15 +111,27 @@ public void testLeftIndexFullDenseFrameSP() {
runAggregateOperationTest(DataType.FRAME, ExecMode.SPARK);
}

@Test
public void testLeftIndexScalarCP() {
runAggregateOperationTest(DataType.SCALAR, ExecMode.SINGLE_NODE);
}

@Test
public void testLeftIndexScalarSP() {
runAggregateOperationTest(DataType.SCALAR, ExecMode.SPARK);
}

private void runAggregateOperationTest(DataType dataType, ExecMode execMode) {
setExecMode(execMode);

String TEST_NAME = null;

if(dataType == DataType.MATRIX)
TEST_NAME = TEST_NAME1;
else
else if(dataType == DataType.FRAME)
TEST_NAME = TEST_NAME2;
else
TEST_NAME = TEST_NAME3;


getAndLoadTestConfiguration(TEST_NAME);
Expand All @@ -142,10 +156,12 @@ private void runAggregateOperationTest(DataType dataType, ExecMode execMode) {
writeInputMatrixWithMTD("X3", X3, false, mc);
writeInputMatrixWithMTD("X4", X4, false, mc);

double[][] Y = getRandomMatrix(rows2, cols2, 1, 5, 1, 3);
if(dataType != DataType.SCALAR) {
double[][] Y = getRandomMatrix(rows2, cols2, 1, 5, 1, 3);

MatrixCharacteristics mc2 = new MatrixCharacteristics(rows2, cols2, blocksize, rows2 * cols2);
writeInputMatrixWithMTD("Y", Y, false, mc2);
MatrixCharacteristics mc2 = new MatrixCharacteristics(rows2, cols2, blocksize, rows2 * cols2);
writeInputMatrixWithMTD("Y", Y, false, mc2);
}

// empty script name because we don't execute any script, just start the worker
fullDMLScriptName = "";
Expand Down Expand Up @@ -173,7 +189,7 @@ private void runAggregateOperationTest(DataType dataType, ExecMode execMode) {

// Run reference dml script with normal matrix
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
programArgs = new String[] {"-explain", "-args", input("X1"), input("X2"), input("X3"), input("X4"),
programArgs = new String[] {"-args", input("X1"), input("X2"), input("X3"), input("X4"),
input("Y"), String.valueOf(from), String.valueOf(to),
String.valueOf(from2), String.valueOf(to2),
Boolean.toString(rowPartitioned).toUpperCase(), expected("S")};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,3 @@ A = as.frame(A)

A[from:to, from2:to2] = B;
write(A, $out_S);

print(toString(A))
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,3 @@ A = as.frame(A)

A[from:to, from2:to2] = B;
write(A, $11);

print(toString(A))
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,3 @@ B = read($in_Y)

A[from:to, from2:to2] = B;
write(A, $out_S);

print(toString(A))
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,3 @@ B = read($5)

A[from:to, from2:to2] = B;
write(A, $11);

print(toString(A))
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

row1 = $from;
row2 = $to;
col1 = $from2;
col2 = $to2;

if ($rP) {
A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
} else {
A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), list($rows, $cols/2),
list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), list($rows, $cols)));
}

b = 13;
c = as.scalar(rand(rows=1, cols=1, seed=456));

A[row1, col1] = b;
A[row2, col2] = c;

write(A, $out_S);

Loading