Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,21 @@ public String toString() {
return Arrays.toString(_beginDims) + " - " + Arrays.toString(_endDims);
}

@Override public boolean equals(Object o) {
if(this == o)
return true;
if(o == null || getClass() != o.getClass())
return false;
FederatedRange range = (FederatedRange) o;
return Arrays.equals(_beginDims, range._beginDims) && Arrays.equals(_endDims, range._endDims);
}

@Override public int hashCode() {
int result = Arrays.hashCode(_beginDims);
result = 31 * result + Arrays.hashCode(_endDims);
return result;
}

public FederatedRange shift(long rshift, long cshift) {
//row shift
_beginDims[0] += rshift;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,10 @@ public FederationMap copyWithNewID() {
public FederationMap copyWithNewID(long id) {
Map<FederatedRange, FederatedData> map = new TreeMap<>();
//TODO handling of file path, but no danger as never written
for( Entry<FederatedRange, FederatedData> e : _fedMap.entrySet() )
map.put(new FederatedRange(e.getKey()), e.getValue().copyWithNewID(id));
for( Entry<FederatedRange, FederatedData> e : _fedMap.entrySet() ) {
if(e.getKey().getSize() != 0)
map.put(new FederatedRange(e.getKey()), e.getValue().copyWithNewID(id));
}
return new FederationMap(id, map, _type);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ public enum FEDType {
Tsmm,
MMChain,
Reorg,
MatrixIndexing
}

protected final FEDType _fedType;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.MMChainCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MMTSJCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MatrixIndexingCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MultiReturnParameterizedBuiltinCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction;
Expand Down Expand Up @@ -127,6 +128,15 @@ else if(inst instanceof ReorgCPInstruction && inst.getOpcode().equals("r'")) {
if( mo.isFederated() )
fedinst = ReorgFEDInstruction.parseInstruction(rinst.getInstructionString());
}
else if(inst instanceof MatrixIndexingCPInstruction && inst.getOpcode().equalsIgnoreCase("rightIndex")) {
// matrix indexing
MatrixIndexingCPInstruction minst = (MatrixIndexingCPInstruction) inst;
if(minst.input1.isMatrix()) {
CacheableData<?> fo = ec.getCacheableData(minst.input1);
if(fo.isFederated())
fedinst = MatrixIndexingFEDInstruction.parseInstruction(minst.getInstructionString());
}
}
else if(inst instanceof VariableCPInstruction ){
VariableCPInstruction ins = (VariableCPInstruction) inst;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.runtime.instructions.fed;

import org.apache.sysds.common.Types;
import org.apache.sysds.lops.LeftIndex;
import org.apache.sysds.lops.RightIndex;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.util.IndexRange;

public abstract class IndexingFEDInstruction extends UnaryFEDInstruction {
protected final CPOperand rowLower, rowUpper, colLower, colUpper;

protected IndexingFEDInstruction(CPOperand in, CPOperand rl, CPOperand ru, CPOperand cl, CPOperand cu,
CPOperand out, String opcode, String istr) {
super(FEDInstruction.FEDType.MatrixIndexing, null, in, out, opcode, istr);
rowLower = rl;
rowUpper = ru;
colLower = cl;
colUpper = cu;
}

protected IndexingFEDInstruction(CPOperand lhsInput, CPOperand rhsInput, CPOperand rl, CPOperand ru, CPOperand cl,
CPOperand cu, CPOperand out, String opcode, String istr) {
super(FEDInstruction.FEDType.MatrixIndexing, null, lhsInput, rhsInput, out, opcode, istr);
rowLower = rl;
rowUpper = ru;
colLower = cl;
colUpper = cu;
}

protected IndexRange getIndexRange(ExecutionContext ec) {
return new IndexRange( // rl, ru, cl, ru
(int) (ec.getScalarInput(rowLower).getLongValue() - 1),
(int) (ec.getScalarInput(rowUpper).getLongValue() - 1),
(int) (ec.getScalarInput(colLower).getLongValue() - 1),
(int) (ec.getScalarInput(colUpper).getLongValue() - 1));
}

public static IndexingFEDInstruction parseInstruction(String str) {
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
String opcode = parts[0];

if(opcode.equalsIgnoreCase(RightIndex.OPCODE)) {
if(parts.length == 7) {
CPOperand in, rl, ru, cl, cu, out;
in = new CPOperand(parts[1]);
rl = new CPOperand(parts[2]);
ru = new CPOperand(parts[3]);
cl = new CPOperand(parts[4]);
cu = new CPOperand(parts[5]);
out = new CPOperand(parts[6]);
if(in.getDataType() == Types.DataType.MATRIX)
return new MatrixIndexingFEDInstruction(in, rl, ru, cl, cu, out, opcode, str);
else
throw new DMLRuntimeException("Can index only on matrices, frames, and lists in federated.");
}
else {
throw new DMLRuntimeException("Invalid number of operands in instruction: " + str);
}
}
else if(opcode.equalsIgnoreCase(LeftIndex.OPCODE)) {
throw new DMLRuntimeException("Left indexing not implemented for federated operations.");
}
else {
throw new DMLRuntimeException("Unknown opcode while parsing a MatrixIndexingFEDInstruction: " + str);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.runtime.instructions.fed;

import java.util.HashMap;
import java.util.Map;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.IndexRange;

public final class MatrixIndexingFEDInstruction extends IndexingFEDInstruction {
private static final Log LOG = LogFactory.getLog(MatrixIndexingFEDInstruction.class.getName());

public MatrixIndexingFEDInstruction(CPOperand in, CPOperand rl, CPOperand ru, CPOperand cl, CPOperand cu,
CPOperand out, String opcode, String istr) {
super(in, rl, ru, cl, cu, out, opcode, istr);
}

@Override
public void processInstruction(ExecutionContext ec) {
rightIndexing(ec);
}

private void rightIndexing(ExecutionContext ec) {
MatrixObject in = ec.getMatrixObject(input1);
FederationMap fedMapping = in.getFedMapping();
IndexRange ixrange = getIndexRange(ec);
// FederationMap.FType fedType;
Map<FederatedRange, IndexRange> ixs = new HashMap<>();

for(int i = 0; i < fedMapping.getFederatedRanges().length; i++) {
FederatedRange curFedRange = fedMapping.getFederatedRanges()[i];
long rs = curFedRange.getBeginDims()[0], re = curFedRange.getEndDims()[0],
cs = curFedRange.getBeginDims()[1], ce = curFedRange.getEndDims()[1];

if((ixrange.colStart <= ce) && (ixrange.colEnd >= cs) && (ixrange.rowStart <= re) && (ixrange.rowEnd >= rs)) {
// If the indexing range contains values that are within the specific federated range.
// change the range.
long rsn = (ixrange.rowStart >= rs) ? (ixrange.rowStart - rs) : 0;
long ren = (ixrange.rowEnd >= rs && ixrange.rowEnd < re) ? (ixrange.rowEnd - rs) : (re - rs - 1);
long csn = (ixrange.colStart >= cs) ? (ixrange.colStart - cs) : 0;
long cen = (ixrange.colEnd >= cs && ixrange.colEnd < ce) ? (ixrange.colEnd - cs) : (ce - cs - 1);
if(LOG.isDebugEnabled()) {
LOG.debug("Ranges for fed location: " + rsn + " " + ren + " " + csn + " " + cen);
LOG.debug("ixRange : " + ixrange);
LOG.debug("Fed Mapping : " + curFedRange);
}
curFedRange.setBeginDim(0, Math.max(rs - ixrange.rowStart, 0));
curFedRange.setBeginDim(1, Math.max(cs - ixrange.colStart, 0));
curFedRange.setEndDim(0,
(ixrange.rowEnd > re ? re - ixrange.rowStart : ixrange.rowEnd - ixrange.rowStart + 1));
curFedRange.setEndDim(1,
(ixrange.colEnd > ce ? ce - ixrange.colStart : ixrange.colEnd - ixrange.colStart + 1));
if(LOG.isDebugEnabled()) {
LOG.debug("Fed Mapping After : " + curFedRange);
}
ixs.put(curFedRange, new IndexRange(rsn, ren, csn, cen));
}
else {
// If not within the range, change the range to become an 0 times 0 big range.
// by setting the end dimensions to the same as the beginning dimensions.
curFedRange.setBeginDim(0, 0);
curFedRange.setBeginDim(1, 0);
curFedRange.setEndDim(0, 0);
curFedRange.setEndDim(1, 0);
}

}

long varID = FederationUtils.getNextFedDataID();
FederationMap slicedMapping = fedMapping.mapParallel(varID, (range, data) -> {
try {
FederatedResponse response = data.executeFederatedOperation(new FederatedRequest(
FederatedRequest.RequestType.EXEC_UDF, -1,
new SliceMatrix(data.getVarID(), varID, ixs.getOrDefault(range, new IndexRange(-1, -1, -1, -1)))))
.get();
if(!response.isSuccessful())
response.throwExceptionFromResponse();
}
catch(Exception e) {
throw new DMLRuntimeException(e);
}
return null;
});

MatrixObject sliced = ec.getMatrixObject(output);
sliced.getDataCharacteristics()
.set(fedMapping.getMaxIndexInRange(0), fedMapping.getMaxIndexInRange(1), (int) in.getBlocksize());
if(ixrange.rowEnd - ixrange.rowStart == 0) {
slicedMapping.setType(FederationMap.FType.COL);
}
else if(ixrange.colEnd - ixrange.colStart == 0) {
slicedMapping.setType(FederationMap.FType.ROW);
}
sliced.setFedMapping(slicedMapping);
LOG.debug(slicedMapping);
LOG.debug(sliced);
}

private static class SliceMatrix extends FederatedUDF {

private static final long serialVersionUID = 5956832933333848772L;
private final long _outputID;
private final IndexRange _ixrange;

private SliceMatrix(long input, long outputID, IndexRange ixrange) {
super(new long[] {input});
_outputID = outputID;
_ixrange = ixrange;
}

@Override
public FederatedResponse execute(ExecutionContext ec, Data... data) {
MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease();
MatrixBlock res;
if(_ixrange.rowStart != -1)
res = mb.slice(_ixrange, new MatrixBlock());
else
res = new MatrixBlock();
MatrixObject mout = ExecutionContext.createMatrixObject(res);
ec.setVariable(String.valueOf(_outputID), mout);

return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS_EMPTY);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,22 @@
import java.io.File;
import java.util.HashMap;

import org.junit.Assert;
import org.junit.Test;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.LopProperties.ExecType;
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Assert;
import org.junit.Test;

public class CellwiseTmplTest extends AutomatedTestBase
{
private static final Log LOG = LogFactory.getLog(CellwiseTmplTest.class.getName());

private static final String TEST_NAME = "cellwisetmpl";
private static final String TEST_NAME1 = TEST_NAME+1;
private static final String TEST_NAME2 = TEST_NAME+2;
Expand Down Expand Up @@ -539,7 +543,7 @@ else if( testname.equals(TEST_NAME23) || testname.equals(TEST_NAME24) )
protected File getConfigTemplateFile() {
// Instrumentation in this test's output log to show custom configuration file used for template.
File TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, TEST_CONF);
System.out.println("This test case overrides default configuration with " + TEST_CONF_FILE.getPath());
LOG.info("This test case overrides default configuration with " + TEST_CONF_FILE.getPath());
return TEST_CONF_FILE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,23 @@
import java.io.File;
import java.util.HashMap;

import org.junit.Assert;
import org.junit.Test;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.LopProperties.ExecType;
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Assert;
import org.junit.Test;

public class DAGCellwiseTmplTest extends AutomatedTestBase
{

private static final Log LOG = LogFactory.getLog(DAGCellwiseTmplTest.class.getName());

private static final String TEST_NAME1 = "DAGcellwisetmpl1";
private static final String TEST_NAME2 = "DAGcellwisetmpl2";
private static final String TEST_NAME3 = "DAGcellwisetmpl3";
Expand Down Expand Up @@ -160,7 +165,7 @@ private void testCodegenIntegration( String testname, boolean rewrites, boolean
@Override
protected File getConfigTemplateFile() {
// Instrumentation in this test's output log to show custom configuration file used for template.
System.out.println("This test case overrides default configuration with " + TEST_CONF_FILE.getPath());
LOG.info("This test case overrides default configuration with " + TEST_CONF_FILE.getPath());
return TEST_CONF_FILE;
}
}
Loading