@@ -286,7 +286,7 @@ private void validateParamserv(DataIdentifier output, boolean conditional) {
raiseValidateError("Should provide more arguments for function " + fname, false, LanguageErrorCodes.INVALID_PARAMETERS);
}
//check for invalid parameters
Set<String> valid = CollectionUtils.asSet(Statement.PS_MODEL, Statement.PS_FEATURES, Statement.PS_LABELS, Statement.PS_VAL_FEATURES, Statement.PS_VAL_LABELS, Statement.PS_UPDATE_FUN, Statement.PS_AGGREGATION_FUN, Statement.PS_MODE, Statement.PS_UPDATE_TYPE, Statement.PS_FREQUENCY, Statement.PS_EPOCHS, Statement.PS_BATCH_SIZE, Statement.PS_PARALLELISM, Statement.PS_SCHEME, Statement.PS_HYPER_PARAMS, Statement.PS_CHECKPOINTING);
Set<String> valid = CollectionUtils.asSet(Statement.PS_MODEL, Statement.PS_FEATURES, Statement.PS_LABELS, Statement.PS_VAL_FEATURES, Statement.PS_VAL_LABELS, Statement.PS_UPDATE_FUN, Statement.PS_AGGREGATION_FUN, Statement.PS_MODE, Statement.PS_UPDATE_TYPE, Statement.PS_FREQUENCY, Statement.PS_EPOCHS, Statement.PS_BATCH_SIZE, Statement.PS_PARALLELISM, Statement.PS_SCHEME, Statement.PS_RUNTIME_BALANCING, Statement.PS_HYPER_PARAMS, Statement.PS_CHECKPOINTING);
checkInvalidParameters(getOpCode(), getVarParams(), valid);

// check existence and correctness of parameters
@@ -304,6 +304,7 @@ private void validateParamserv(DataIdentifier output, boolean conditional) {
checkDataValueType(true, fname, Statement.PS_BATCH_SIZE, DataType.SCALAR, ValueType.INT64, conditional);
checkDataValueType(true, fname, Statement.PS_PARALLELISM, DataType.SCALAR, ValueType.INT64, conditional);
checkStringParam(true, fname, Statement.PS_SCHEME, conditional);
checkStringParam(true, fname, Statement.PS_RUNTIME_BALANCING, conditional);
checkDataValueType(true, fname, Statement.PS_HYPER_PARAMS, DataType.LIST, ValueType.UNKNOWN, conditional);
checkStringParam(true, fname, Statement.PS_CHECKPOINTING, conditional);

8 changes: 6 additions & 2 deletions src/main/java/org/apache/sysds/parser/Statement.java
@@ -87,6 +87,10 @@ public boolean isASP() {
public enum PSFrequency {
BATCH, EPOCH
}
public static final String PS_RUNTIME_BALANCING = "runtime_balancing";
public enum PSRuntimeBalancing {
NONE, RUN_MIN, CYCLE_AVG, CYCLE_MAX, SCALE_BATCH, SCALE_BATCH_AND_WEIGH
}
public static final String PS_EPOCHS = "epochs";
public static final String PS_BATCH_SIZE = "batchsize";
public static final String PS_PARALLELISM = "k";
@@ -95,7 +99,7 @@ public enum PSScheme {
DISJOINT_CONTIGUOUS, DISJOINT_ROUND_ROBIN, DISJOINT_RANDOM, OVERLAP_RESHUFFLE
}
public enum FederatedPSScheme {
KEEP_DATA_ON_WORKER, SHUFFLE
KEEP_DATA_ON_WORKER, SHUFFLE, REPLICATE_TO_MAX, SUBSAMPLE_TO_MIN, BALANCE_TO_AVG
}
public static final String PS_HYPER_PARAMS = "hyperparams";
public static final String PS_CHECKPOINTING = "checkpointing";
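For context, a minimal sketch of how the new runtime_balancing argument and the PSRuntimeBalancing enum above would typically be consumed together; the params map (a Map<String, String> of the named paramserv arguments) and the defaulting to NONE are assumptions for illustration, not part of this diff:

// Hedged sketch: 'params' and the NONE default are assumed; only PS_RUNTIME_BALANCING
// and PSRuntimeBalancing come from this patch.
String rb = params.getOrDefault(Statement.PS_RUNTIME_BALANCING, Statement.PSRuntimeBalancing.NONE.name());
Statement.PSRuntimeBalancing balancing = Statement.PSRuntimeBalancing.valueOf(rb.toUpperCase());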
@@ -107,7 +111,7 @@ public enum PSCheckpointing {
// prefixed with code: "1701-NCC-" to not overwrite anything
public static final String PS_FED_BATCH_SIZE = "1701-NCC-batch_size";
public static final String PS_FED_DATA_SIZE = "1701-NCC-data_size";
public static final String PS_FED_NUM_BATCHES = "1701-NCC-num_batches";
public static final String PS_FED_POSS_BATCHES_LOCAL = "1701-NCC-poss_batches_local";
public static final String PS_FED_NAMESPACE = "1701-NCC-namespace";
public static final String PS_FED_GRADIENTS_FNAME = "1701-NCC-gradients_fname";
public static final String PS_FED_AGGREGATION_FNAME = "1701-NCC-aggregation_fname";
@@ -376,7 +376,7 @@ private static void checkFederated( Instruction lastInst, LocalVariableMap vars

CacheableData<?> mo = (CacheableData<?>)dat;
if( mo.isFederated() ) {
if( mo.getFedMapping().getFedMapping().isEmpty() )
if( mo.getFedMapping().getFRangeFDataMap().isEmpty() )
throw new DMLRuntimeException("Invalid empty FederationMap for: "+mo);
}
}
@@ -114,7 +114,7 @@ public FederatedRange[] getFederatedRanges() {
return _fedMap.keySet().toArray(new FederatedRange[0]);
}

public Map<FederatedRange, FederatedData> getFedMapping() {
public Map<FederatedRange, FederatedData> getFRangeFDataMap() {
return _fedMap;
}


Large diffs are not rendered by default.

@@ -214,6 +214,36 @@ public static MatrixBlock generatePermutation(int numEntries, long seed) {
new MatrixBlock(numEntries, numEntries, true));
}

/**
* Generates a selection matrix which, when left-multiplied with the input matrix, subsamples its rows without replacement
* @param nsamples number of rows to sample
* @param nrows number of rows in the input matrix
* @param seed seed for the random sample generation
* @return subsample (selection) matrix
*/
public static MatrixBlock generateSubsampleMatrix(int nsamples, int nrows, long seed) {
MatrixBlock seq = new MatrixBlock(nsamples, nrows, false);
// No replacement to preserve as much of the original data as possible
MatrixBlock sample = MatrixBlock.sampleOperations(nrows, nsamples, false, seed);
return seq.ctableSeqOperations(sample, 1.0,
new MatrixBlock(nsamples, nrows, true), false);
}

/**
* Generates a selection matrix which, when left-multiplied with the input matrix, yields nsamples randomly replicated rows (sampling with replacement)
* @param nsamples number of rows to replicate
* @param nrows number of rows in the input matrix
* @param seed seed for the random sample generation
* @return replication matrix
*/
public static MatrixBlock generateReplicationMatrix(int nsamples, int nrows, long seed) {
MatrixBlock seq = new MatrixBlock(nsamples, nrows, false);
// Replacement set to true to provide random replication
MatrixBlock sample = MatrixBlock.sampleOperations(nrows, nsamples, true, seed);
return seq.ctableSeqOperations(sample, 1.0,
new MatrixBlock(nsamples, nrows, true), false);
}
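To make the ctable-based construction above concrete, a small worked sketch; the input matrix, the seed, and the concretely sampled indices are illustrative assumptions (imports as in the surrounding file):

// Hedged example: X is an arbitrary 4x3 input (randOperations arguments: rows, cols, sparsity, min, max, pdf, seed).
MatrixBlock X = MatrixBlock.randOperations(4, 3, 1.0, 0, 1, "uniform", 7);
// Request 2 of the 4 rows. If the seeded sample happens to draw rows {3, 1}, the 2x4 selection matrix S is
//   0 0 1 0
//   1 0 0 0
// so the product S %*% X consists of exactly rows 3 and 1 of X, in that order.
MatrixBlock S = ParamservUtils.generateSubsampleMatrix(2, 4, 42);
MatrixBlock picked = S.aggregateBinaryOperations(S, X, new MatrixBlock(),
	new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), new AggregateOperator(0, Plus.getPlusFnObject())));
// generateReplicationMatrix is built the same way but samples with replacement,
// so individual rows of X may appear several times in the product.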

public static ExecutionContext createExecutionContext(ExecutionContext ec,
LocalVariableMap varsMap, String updFunc, String aggFunc, int k)
{
@@ -0,0 +1,104 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.runtime.controlprogram.paramserv.dp;

import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.meta.DataCharacteristics;

import java.util.List;
import java.util.concurrent.Future;

public class BalanceToAvgFederatedScheme extends DataPartitionFederatedScheme {
@Override
public Result doPartitioning(MatrixObject features, MatrixObject labels) {
List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
List<MatrixObject> pLabels = sliceFederatedMatrix(labels);

int average_num_rows = (int) Math.round(pFeatures.stream().map(CacheableData::getNumRows).mapToInt(Long::intValue).average().orElse(Double.NaN));

for(int i = 0; i < pFeatures.size(); i++) {
// Safe because each sliced matrix's federation map contains exactly one entry
FederatedData featuresData = (FederatedData) pFeatures.get(i).getFedMapping().getFRangeFDataMap().values().toArray()[0];
FederatedData labelsData = (FederatedData) pLabels.get(i).getFedMapping().getFRangeFDataMap().values().toArray()[0];

Future<FederatedResponse> udfResponse = featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
featuresData.getVarID(), new balanceDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, average_num_rows)));

try {
FederatedResponse response = udfResponse.get();
if(!response.isSuccessful())
throw new DMLRuntimeException("FederatedDataPartitioner BalanceFederatedScheme: balance UDF returned fail");
}
catch(Exception e) {
throw new DMLRuntimeException("FederatedDataPartitioner BalanceFederatedScheme: executing balance UDF failed" + e.getMessage());
}

DataCharacteristics update = pFeatures.get(i).getDataCharacteristics().setRows(average_num_rows);
pFeatures.get(i).updateDataCharacteristics(update);
update = pLabels.get(i).getDataCharacteristics().setRows(average_num_rows);
pLabels.get(i).updateDataCharacteristics(update);
}

return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures));
}

/**
* Balance UDF executed on the federated worker
*/
private static class balanceDataOnFederatedWorker extends FederatedUDF {
int _average_num_rows;
protected balanceDataOnFederatedWorker(long[] inIDs, int average_num_rows) {
super(inIDs);
_average_num_rows = average_num_rows;
}

@Override
public FederatedResponse execute(ExecutionContext ec, Data... data) {
MatrixObject features = (MatrixObject) data[0];
MatrixObject labels = (MatrixObject) data[1];

if(features.getNumRows() > _average_num_rows) {
// generate subsampling matrix
MatrixBlock subsampleMatrixBlock = ParamservUtils.generateSubsampleMatrix(_average_num_rows, Math.toIntExact(features.getNumRows()), System.currentTimeMillis());
subsampleTo(features, subsampleMatrixBlock);
subsampleTo(labels, subsampleMatrixBlock);
}
else if(features.getNumRows() < _average_num_rows) {
int num_rows_needed = _average_num_rows - Math.toIntExact(features.getNumRows());
// generate replication matrix
MatrixBlock replicateMatrixBlock = ParamservUtils.generateReplicationMatrix(num_rows_needed, Math.toIntExact(features.getNumRows()), System.currentTimeMillis());
replicateTo(features, replicateMatrixBlock);
replicateTo(labels, replicateMatrixBlock);
}

return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
}
}
}
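A worked trace of this scheme under illustrative worker sizes; the numbers are made up and only serve to show how the average is enforced:

// Hedged example: three row-federated workers hold 100, 300 and 800 feature/label rows.
// average_num_rows = Math.round((100 + 300 + 800) / 3.0) = 400
// worker 1: 100 < 400 -> generateReplicationMatrix(300, 100, seed); 300 randomly replicated rows appended -> 400 rows
// worker 2: 300 < 400 -> generateReplicationMatrix(100, 300, seed); 100 randomly replicated rows appended -> 400 rows
// worker 3: 800 > 400 -> generateSubsampleMatrix(400, 800, seed); 400 randomly chosen rows kept -> 400 rows
// Afterwards the coordinator updates each slice's DataCharacteristics to 400 rows.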
@@ -26,6 +26,11 @@
import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.meta.MetaDataFormat;

@@ -37,14 +42,17 @@
public abstract class DataPartitionFederatedScheme {

public static final class Result {
public final List<MatrixObject> pFeatures;
public final List<MatrixObject> pLabels;
public final int workerNum;

public Result(List<MatrixObject> pFeatures, List<MatrixObject> pLabels, int workerNum) {
this.pFeatures = pFeatures;
this.pLabels = pLabels;
this.workerNum = workerNum;
public final List<MatrixObject> _pFeatures;
public final List<MatrixObject> _pLabels;
public final int _workerNum;
public final BalanceMetrics _balanceMetrics;


public Result(List<MatrixObject> pFeatures, List<MatrixObject> pLabels, int workerNum, BalanceMetrics balanceMetrics) {
this._pFeatures = pFeatures;
this._pLabels = pLabels;
this._workerNum = workerNum;
this._balanceMetrics = balanceMetrics;
}
}

@@ -57,12 +65,10 @@ public Result(List<MatrixObject> pFeatures, List<MatrixObject> pLabels, int workerNum) {
*/
static List<MatrixObject> sliceFederatedMatrix(MatrixObject fedMatrix) {
if (fedMatrix.isFederated(FederationMap.FType.ROW)) {

List<MatrixObject> slices = Collections.synchronizedList(new ArrayList<>());
fedMatrix.getFedMapping().forEachParallel((range, data) -> {
// Create sliced matrix object
MatrixObject slice = new MatrixObject(fedMatrix.getValueType(), Dag.getNextUniqueVarname(Types.DataType.MATRIX));
// Warning needs MetaDataFormat instead of MetaData
slice.setMetaData(new MetaDataFormat(
new MatrixCharacteristics(range.getSize(0), range.getSize(1)),
Types.FileFormat.BINARY)
@@ -85,4 +91,81 @@ static List<MatrixObject> sliceFederatedMatrix(MatrixObject fedMatrix) {
"currently only supports row federated data");
}
}

static BalanceMetrics getBalanceMetrics(List<MatrixObject> slices) {
if (slices == null || slices.size() == 0)
return new BalanceMetrics(0, 0, 0);

long minRows = slices.get(0).getNumRows();
long maxRows = minRows;
long sum = 0;

for (MatrixObject slice : slices) {
if (slice.getNumRows() < minRows)
minRows = slice.getNumRows();
else if (slice.getNumRows() > maxRows)
maxRows = slice.getNumRows();

sum += slice.getNumRows();
}

return new BalanceMetrics(minRows, sum / slices.size(), maxRows);
}

public static final class BalanceMetrics {
public final long _minRows;
public final long _avgRows;
public final long _maxRows;

public BalanceMetrics(long minRows, long avgRows, long maxRows) {
this._minRows = minRows;
this._avgRows = avgRows;
this._maxRows = maxRows;
}
}
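A hedged illustration of the metric with invented row counts:

// slices with 100, 250 and 700 rows:
// getBalanceMetrics(slices) -> BalanceMetrics(_minRows=100, _avgRows=350, _maxRows=700)
// where _avgRows is the integer division sum / slices.size() = 1050 / 3 = 350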

/**
* Shuffles the rows of a MatrixObject by left-multiplying it with a provided permutation MatrixBlock
*
* @param m the input matrix object
* @param permutationMatrixBlock the permutation matrix block
*/
static void shuffle(MatrixObject m, MatrixBlock permutationMatrixBlock) {
// matrix multiply
m.acquireModify(permutationMatrixBlock.aggregateBinaryOperations(
permutationMatrixBlock, m.acquireReadAndRelease(), new MatrixBlock(),
new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), new AggregateOperator(0, Plus.getPlusFnObject()))
));
m.release();
}

/**
* Extends a MatrixObject to the chosen number of rows by appending randomly replicated rows
*
* @param m the input matrix object
* @param replicateMatrixBlock the replication matrix block
*/
static void replicateTo(MatrixObject m, MatrixBlock replicateMatrixBlock) {
// matrix multiply and append
MatrixBlock replicatedFeatures = replicateMatrixBlock.aggregateBinaryOperations(
replicateMatrixBlock, m.acquireReadAndRelease(), new MatrixBlock(),
new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), new AggregateOperator(0, Plus.getPlusFnObject())));

m.acquireModify(m.acquireReadAndRelease().append(replicatedFeatures, new MatrixBlock(), false));
m.release();
}

/**
* Subsamples the rows of a MatrixObject by left-multiplying it with a provided subsample MatrixBlock
*
* @param m the input matrix object
* @param subsampleMatrixBlock the subsample matrix block
*/
static void subsampleTo(MatrixObject m, MatrixBlock subsampleMatrixBlock) {
// matrix multiply
m.acquireModify(subsampleMatrixBlock.aggregateBinaryOperations(
subsampleMatrixBlock, m.acquireReadAndRelease(), new MatrixBlock(),
new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), new AggregateOperator(0, Plus.getPlusFnObject()))
));
m.release();
}
}
@@ -35,6 +35,15 @@ public FederatedDataPartitioner(Statement.FederatedPSScheme scheme) {
case SHUFFLE:
_scheme = new ShuffleFederatedScheme();
break;
case REPLICATE_TO_MAX:
_scheme = new ReplicateToMaxFederatedScheme();
break;
case SUBSAMPLE_TO_MIN:
_scheme = new SubsampleToMinFederatedScheme();
break;
case BALANCE_TO_AVG:
_scheme = new BalanceToAvgFederatedScheme();
break;
default:
throw new DMLRuntimeException(String.format("FederatedDataPartitioner: not support data partition scheme '%s'", scheme));
}
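A minimal usage sketch of the new cases; the doPartitioning entry point on FederatedDataPartitioner and the variable names are assumptions based on the surrounding code rather than part of this diff:

// Hedged sketch: 'features' and 'labels' are row-federated MatrixObjects.
DataPartitionFederatedScheme.Result result =
	new FederatedDataPartitioner(Statement.FederatedPSScheme.BALANCE_TO_AVG)
		.doPartitioning(features, labels);
// result._pFeatures / result._pLabels hold one balanced slice per federated worker;
// result._balanceMetrics reports the min/avg/max row counts across those slices.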
@@ -27,6 +27,6 @@ public class KeepDataOnWorkerFederatedScheme extends DataPartitionFederatedScheme {
public Result doPartitioning(MatrixObject features, MatrixObject labels) {
List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
return new Result(pFeatures, pLabels, pFeatures.size());
return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures));
}
}