Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,21 @@ public String toString() {
return Arrays.toString(_beginDims) + " - " + Arrays.toString(_endDims);
}

@Override public boolean equals(Object o) {
if(this == o)
return true;
if(o == null || getClass() != o.getClass())
return false;
FederatedRange range = (FederatedRange) o;
return Arrays.equals(_beginDims, range._beginDims) && Arrays.equals(_endDims, range._endDims);
}

@Override public int hashCode() {
int result = Arrays.hashCode(_beginDims);
result = 31 * result + Arrays.hashCode(_endDims);
return result;
}

public FederatedRange shift(long rshift, long cshift) {
//row shift
_beginDims[0] += rshift;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,10 @@ public FederationMap copyWithNewID() {
public FederationMap copyWithNewID(long id) {
Map<FederatedRange, FederatedData> map = new TreeMap<>();
//TODO handling of file path, but no danger as never written
for( Entry<FederatedRange, FederatedData> e : _fedMap.entrySet() )
map.put(new FederatedRange(e.getKey()), e.getValue().copyWithNewID(id));
for( Entry<FederatedRange, FederatedData> e : _fedMap.entrySet() ) {
if(e.getKey().getSize() != 0)
map.put(new FederatedRange(e.getKey()), e.getValue().copyWithNewID(id));
}
return new FederationMap(id, map, _type);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ public enum FEDType {
Tsmm,
MMChain,
Reorg,
MatrixIndexing
}

protected final FEDType _fedType;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.MMChainCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MMTSJCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MatrixIndexingCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MultiReturnParameterizedBuiltinCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction;
Expand Down Expand Up @@ -127,6 +128,15 @@ else if(inst instanceof ReorgCPInstruction && inst.getOpcode().equals("r'")) {
if( mo.isFederated() )
fedinst = ReorgFEDInstruction.parseInstruction(rinst.getInstructionString());
}
else if(inst instanceof MatrixIndexingCPInstruction && inst.getOpcode().equalsIgnoreCase("rightIndex")) {
// matrix indexing
MatrixIndexingCPInstruction minst = (MatrixIndexingCPInstruction) inst;
if(minst.input1.isMatrix()) {
CacheableData<?> fo = ec.getCacheableData(minst.input1);
if(fo.isFederated())
fedinst = MatrixIndexingFEDInstruction.parseInstruction(minst.getInstructionString());
}
}
else if(inst instanceof VariableCPInstruction ){
VariableCPInstruction ins = (VariableCPInstruction) inst;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.runtime.instructions.fed;

import org.apache.sysds.common.Types;
import org.apache.sysds.lops.LeftIndex;
import org.apache.sysds.lops.RightIndex;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.util.IndexRange;

public abstract class IndexingFEDInstruction extends UnaryFEDInstruction {
protected final CPOperand rowLower, rowUpper, colLower, colUpper;

protected IndexingFEDInstruction(CPOperand in, CPOperand rl, CPOperand ru, CPOperand cl, CPOperand cu,
CPOperand out, String opcode, String istr) {
super(FEDInstruction.FEDType.MatrixIndexing, null, in, out, opcode, istr);
rowLower = rl;
rowUpper = ru;
colLower = cl;
colUpper = cu;
}

protected IndexingFEDInstruction(CPOperand lhsInput, CPOperand rhsInput, CPOperand rl, CPOperand ru, CPOperand cl,
CPOperand cu, CPOperand out, String opcode, String istr) {
super(FEDInstruction.FEDType.MatrixIndexing, null, lhsInput, rhsInput, out, opcode, istr);
rowLower = rl;
rowUpper = ru;
colLower = cl;
colUpper = cu;
}

protected IndexRange getIndexRange(ExecutionContext ec) {
return new IndexRange( //rl, ru, cl, ru
(int) (ec.getScalarInput(rowLower).getLongValue() - 1),
(int) (ec.getScalarInput(rowUpper).getLongValue() - 1),
(int) (ec.getScalarInput(colLower).getLongValue() - 1),
(int) (ec.getScalarInput(colUpper).getLongValue() - 1));
}

public static IndexingFEDInstruction parseInstruction(String str) {
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
String opcode = parts[0];

if(opcode.equalsIgnoreCase(RightIndex.OPCODE)) {
if(parts.length == 7) {
CPOperand in, rl, ru, cl, cu, out;
in = new CPOperand(parts[1]);
rl = new CPOperand(parts[2]);
ru = new CPOperand(parts[3]);
cl = new CPOperand(parts[4]);
cu = new CPOperand(parts[5]);
out = new CPOperand(parts[6]);
if(in.getDataType() == Types.DataType.MATRIX)
return new MatrixIndexingFEDInstruction(in, rl, ru, cl, cu, out, opcode, str);
// else if( in.getDataType() == Types.DataType.FRAME )
// return new FrameIndexingCPInstruction(in, rl, ru, cl, cu, out, opcode, str);
// else if( in.getDataType() == Types.DataType.LIST )
// return new ListIndexingCPInstruction(in, rl, ru, cl, cu, out, opcode, str);
else
throw new DMLRuntimeException("Can index only on matrices, frames, and lists.");
}
else {
throw new DMLRuntimeException("Invalid number of operands in instruction: " + str);
}
}
// else if ( opcode.equalsIgnoreCase(LeftIndex.OPCODE)) {
// if ( parts.length == 8 ) {
// CPOperand lhsInput, rhsInput, rl, ru, cl, cu, out;
// lhsInput = new CPOperand(parts[1]);
// rhsInput = new CPOperand(parts[2]);
// rl = new CPOperand(parts[3]);
// ru = new CPOperand(parts[4]);
// cl = new CPOperand(parts[5]);
// cu = new CPOperand(parts[6]);
// out = new CPOperand(parts[7]);
// if( lhsInput.getDataType()== Types.DataType.MATRIX )
// return new MatrixIndexingFEDInstruction(lhsInput, rhsInput, rl, ru, cl, cu, out, opcode, str);
// else if (lhsInput.getDataType() == Types.DataType.FRAME)
// return new FrameIndexingFEDInstruction(lhsInput, rhsInput, rl, ru, cl, cu, out, opcode, str);
// else if( lhsInput.getDataType() == Types.DataType.LIST )
// return new ListIndexingFEDInstruction(lhsInput, rhsInput, rl, ru, cl, cu, out, opcode, str);
// else
// throw new DMLRuntimeException("Can index only on matrices, frames, and lists.");
// }
// else {
// throw new DMLRuntimeException("Invalid number of operands in instruction: " + str);
// }
// }
else {
throw new DMLRuntimeException("Unknown opcode while parsing a MatrixIndexingFEDInstruction: " + str);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.runtime.instructions.fed;

import java.util.HashMap;
import java.util.Map;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.IndexRange;

public final class MatrixIndexingFEDInstruction extends IndexingFEDInstruction {
private static final Log LOG = LogFactory.getLog(MatrixIndexingFEDInstruction.class.getName());

public MatrixIndexingFEDInstruction(CPOperand in, CPOperand rl, CPOperand ru, CPOperand cl, CPOperand cu,
CPOperand out, String opcode, String istr) {
super(in, rl, ru, cl, cu, out, opcode, istr);
}

@Override
public void processInstruction(ExecutionContext ec) {
rightIndexing(ec);
}


private void rightIndexing (ExecutionContext ec) {
MatrixObject in = ec.getMatrixObject(input1);
FederationMap fedMapping = in.getFedMapping();
IndexRange ixrange = getIndexRange(ec);
FederationMap.FType fedType;
Map <FederatedRange, IndexRange> ixs = new HashMap<>();

FederatedRange nextDim = new FederatedRange(new long[]{0, 0}, new long[]{0, 0});

for (int i = 0; i < fedMapping.getFederatedRanges().length; i++) {
long rs = fedMapping.getFederatedRanges()[i].getBeginDims()[0], re = fedMapping.getFederatedRanges()[i]
.getEndDims()[0], cs = fedMapping.getFederatedRanges()[i].getBeginDims()[1], ce = fedMapping.getFederatedRanges()[i].getEndDims()[1];

// for OTHER
fedType = ((i + 1) < fedMapping.getFederatedRanges().length &&
fedMapping.getFederatedRanges()[i].getEndDims()[0] == fedMapping.getFederatedRanges()[i+1].getBeginDims()[0]) ?
FederationMap.FType.ROW : FederationMap.FType.COL;

long rsn = 0, ren = 0, csn = 0, cen = 0;

rsn = (ixrange.rowStart >= rs && ixrange.rowStart < re) ? (ixrange.rowStart - rs) : 0;
ren = (ixrange.rowEnd >= rs && ixrange.rowEnd < re) ? (ixrange.rowEnd - rs) : (re - rs - 1);
csn = (ixrange.colStart >= cs && ixrange.colStart < ce) ? (ixrange.colStart - cs) : 0;
cen = (ixrange.colEnd >= cs && ixrange.colEnd < ce) ? (ixrange.colEnd - cs) : (ce - cs - 1);

fedMapping.getFederatedRanges()[i].setBeginDim(0, i != 0 ? nextDim.getBeginDims()[0] : 0);
fedMapping.getFederatedRanges()[i].setBeginDim(1, i != 0 ? nextDim.getBeginDims()[1] : 0);
if((ixrange.colStart < ce) && (ixrange.colEnd >= cs) && (ixrange.rowStart < re) && (ixrange.rowEnd >= rs)) {
fedMapping.getFederatedRanges()[i].setEndDim(0, ren - rsn + 1 + nextDim.getBeginDims()[0]);
fedMapping.getFederatedRanges()[i].setEndDim(1, cen - csn + 1 + nextDim.getBeginDims()[1]);

ixs.put(fedMapping.getFederatedRanges()[i], new IndexRange(rsn, ren, csn, cen));
} else {
fedMapping.getFederatedRanges()[i].setEndDim(0, i != 0 ? nextDim.getBeginDims()[0] : 0);
fedMapping.getFederatedRanges()[i].setEndDim(1, i != 0 ? nextDim.getBeginDims()[1] : 0);
}

if(fedType == FederationMap.FType.ROW) {
nextDim.setBeginDim(0,fedMapping.getFederatedRanges()[i].getEndDims()[0]);
nextDim.setBeginDim(1, fedMapping.getFederatedRanges()[i].getBeginDims()[1]);
} else if(fedType == FederationMap.FType.COL) {
nextDim.setBeginDim(1,fedMapping.getFederatedRanges()[i].getEndDims()[1]);
nextDim.setBeginDim(0, fedMapping.getFederatedRanges()[i].getBeginDims()[0]);
}
}

long varID = FederationUtils.getNextFedDataID();
FederationMap slicedMapping = fedMapping.mapParallel(varID, (range, data) -> {
try {
FederatedResponse response = data.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
-1, new SliceMatrix(data.getVarID(), varID, ixs.getOrDefault(range, new IndexRange(-1, -1, -1, -1))))).get();
if(!response.isSuccessful())
response.throwExceptionFromResponse();
}
catch(Exception e) {
throw new DMLRuntimeException(e);
}
return null;
});

MatrixObject sliced = ec.getMatrixObject(output);
sliced.getDataCharacteristics().set(fedMapping.getMaxIndexInRange(0), fedMapping.getMaxIndexInRange(1), (int) in.getBlocksize());
sliced.setFedMapping(slicedMapping);
}

private static class SliceMatrix extends FederatedUDF {

private static final long serialVersionUID = 5956832933333848772L;
private final long _outputID;
private final IndexRange _ixrange;

private SliceMatrix(long input, long outputID, IndexRange ixrange) {
super(new long[] {input});
_outputID = outputID;
_ixrange = ixrange;
}


@Override public FederatedResponse execute(ExecutionContext ec, Data... data) {
MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease();
MatrixBlock res;
if(_ixrange.rowStart != -1)
res = mb.slice(_ixrange, new MatrixBlock());
else res = new MatrixBlock();
MatrixObject mout = ExecutionContext.createMatrixObject(res);
ec.setVariable(String.valueOf(_outputID), mout);

return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS_EMPTY);
}
}
}
Loading