diff --git a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java index 5171f219ba1..05bfc4884bf 100644 --- a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java @@ -289,8 +289,8 @@ private void validateParamserv(DataIdentifier output, boolean conditional) { Set valid = CollectionUtils.asSet(Statement.PS_MODEL, Statement.PS_FEATURES, Statement.PS_LABELS, Statement.PS_VAL_FEATURES, Statement.PS_VAL_LABELS, Statement.PS_UPDATE_FUN, Statement.PS_AGGREGATION_FUN, Statement.PS_MODE, Statement.PS_UPDATE_TYPE, Statement.PS_FREQUENCY, Statement.PS_EPOCHS, - Statement.PS_BATCH_SIZE, Statement.PS_PARALLELISM, Statement.PS_SCHEME, Statement.PS_RUNTIME_BALANCING, - Statement.PS_HYPER_PARAMS, Statement.PS_CHECKPOINTING); + Statement.PS_BATCH_SIZE, Statement.PS_PARALLELISM, Statement.PS_SCHEME, Statement.PS_FED_RUNTIME_BALANCING, + Statement.PS_FED_WEIGHING, Statement.PS_HYPER_PARAMS, Statement.PS_CHECKPOINTING, Statement.PS_SEED); checkInvalidParameters(getOpCode(), getVarParams(), valid); // check existence and correctness of parameters @@ -308,9 +308,11 @@ private void validateParamserv(DataIdentifier output, boolean conditional) { checkDataValueType(true, fname, Statement.PS_BATCH_SIZE, DataType.SCALAR, ValueType.INT64, conditional); checkDataValueType(true, fname, Statement.PS_PARALLELISM, DataType.SCALAR, ValueType.INT64, conditional); checkStringParam(true, fname, Statement.PS_SCHEME, conditional); - checkStringParam(true, fname, Statement.PS_RUNTIME_BALANCING, conditional); + checkStringParam(true, fname, Statement.PS_FED_RUNTIME_BALANCING, conditional); + checkStringParam(true, fname, Statement.PS_FED_WEIGHING, conditional); checkDataValueType(true, fname, Statement.PS_HYPER_PARAMS, DataType.LIST, ValueType.UNKNOWN, conditional); 
checkStringParam(true, fname, Statement.PS_CHECKPOINTING, conditional); + checkDataValueType(true, fname, Statement.PS_SEED, DataType.SCALAR, ValueType.INT64, conditional); // set output characteristics output.setDataType(DataType.LIST); diff --git a/src/main/java/org/apache/sysds/parser/Statement.java b/src/main/java/org/apache/sysds/parser/Statement.java index 6767d857e35..9104246b8f7 100644 --- a/src/main/java/org/apache/sysds/parser/Statement.java +++ b/src/main/java/org/apache/sysds/parser/Statement.java @@ -70,6 +70,7 @@ public abstract class Statement implements ParseInfo public static final String PS_AGGREGATION_FUN = "agg"; public static final String PS_MODE = "mode"; public static final String PS_GRADIENTS = "gradients"; + public static final String PS_SEED = "seed"; public enum PSModeType { FEDERATED, LOCAL, REMOTE_SPARK } @@ -87,9 +88,10 @@ public boolean isASP() { public enum PSFrequency { BATCH, EPOCH } - public static final String PS_RUNTIME_BALANCING = "runtime_balancing"; + public static final String PS_FED_WEIGHING = "weighing"; + public static final String PS_FED_RUNTIME_BALANCING = "runtime_balancing"; public enum PSRuntimeBalancing { - NONE, RUN_MIN, CYCLE_AVG, CYCLE_MAX, SCALE_BATCH, SCALE_BATCH_AND_WEIGH + NONE, RUN_MIN, CYCLE_AVG, CYCLE_MAX, SCALE_BATCH } public static final String PS_EPOCHS = "epochs"; public static final String PS_BATCH_SIZE = "batchsize"; diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java index 54a45d0026c..f09b5c2ce0b 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java @@ -228,7 +228,7 @@ protected final double sumValues(int valIx, double[] b, double[] dictVals, int o return val; } - protected final double sumValuesSparse(int valIx, SparseRow[] rows, double[] dictVals, int rowsIndex) { + 
protected static double sumValuesSparse(int valIx, SparseRow[] rows, double[] dictVals, int rowsIndex) { throw new NotImplementedException("This Method was implemented incorrectly"); // final int numCols = getNumCols(); // final int valOff = valIx * numCols; diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java index 393b131a518..48249db65df 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java @@ -24,6 +24,8 @@ import org.apache.commons.logging.LogFactory; import org.apache.sysds.parser.DataIdentifier; import org.apache.sysds.parser.Statement; +import org.apache.sysds.parser.Statement.PSFrequency; +import org.apache.sysds.parser.Statement.PSRuntimeBalancing; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.BasicProgramBlock; import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock; @@ -37,13 +39,17 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF; import org.apache.sysds.runtime.controlprogram.federated.FederationUtils; import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing; +import org.apache.sysds.runtime.functionobjects.Multiply; import org.apache.sysds.runtime.instructions.Instruction; +import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.instructions.cp.CPOperand; import org.apache.sysds.runtime.instructions.cp.Data; import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction; import org.apache.sysds.runtime.instructions.cp.IntObject; import org.apache.sysds.runtime.instructions.cp.ListObject; import org.apache.sysds.runtime.instructions.cp.StringObject; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; 
+import org.apache.sysds.runtime.matrix.operators.RightScalarOperator; import org.apache.sysds.runtime.util.ProgramConverter; import java.util.ArrayList; @@ -58,21 +64,29 @@ public class FederatedPSControlThread extends PSWorker implements Callable { private static final long serialVersionUID = 6846648059569648791L; protected static final Log LOG = LogFactory.getLog(ParamServer.class.getName()); - - Statement.PSRuntimeBalancing _runtimeBalancing; - FederatedData _featuresData; - FederatedData _labelsData; - final long _localStartBatchNumVarID; - final long _modelVarID; - int _numBatchesPerGlobalEpoch; - int _possibleBatchesPerLocalEpoch; - boolean _cycleStartAt0 = false; - - public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFrequency freq, Statement.PSRuntimeBalancing runtimeBalancing, int epochs, long batchSize, int numBatchesPerGlobalEpoch, ExecutionContext ec, ParamServer ps) { + + private FederatedData _featuresData; + private FederatedData _labelsData; + private final long _localStartBatchNumVarID; + private final long _modelVarID; + + // runtime balancing + private PSRuntimeBalancing _runtimeBalancing; + private int _numBatchesPerEpoch; + private int _possibleBatchesPerLocalEpoch; + private boolean _weighing; + private double _weighingFactor = 1; + private boolean _cycleStartAt0 = false; + + public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFrequency freq, + PSRuntimeBalancing runtimeBalancing, boolean weighing, int epochs, long batchSize, + int numBatchesPerGlobalEpoch, ExecutionContext ec, ParamServer ps) + { super(workerID, updFunc, freq, epochs, batchSize, ec, ps); - _numBatchesPerGlobalEpoch = numBatchesPerGlobalEpoch; + _numBatchesPerEpoch = numBatchesPerGlobalEpoch; _runtimeBalancing = runtimeBalancing; + _weighing = weighing; // generate the IDs for model and batch counter. 
These get overwritten on the federated worker each time _localStartBatchNumVarID = FederationUtils.getNextFedDataID(); _modelVarID = FederationUtils.getNextFedDataID(); @@ -80,65 +94,72 @@ public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFreque /** * Sets up the federated worker and control thread + * + * @param weighingFactor Gradients from this worker will be multiplied by this factor if weighing is enabled */ - public void setup() { + public void setup(double weighingFactor) { // prepare features and labels _featuresData = (FederatedData) _features.getFedMapping().getMap().values().toArray()[0]; _labelsData = (FederatedData) _labels.getFedMapping().getMap().values().toArray()[0]; - // calculate number of batches and get data size + // weighing factor is always set, but only used when weighing is specified + _weighingFactor = weighingFactor; + + // different runtime balancing calculations long dataSize = _features.getNumRows(); - _possibleBatchesPerLocalEpoch = (int) Math.ceil((double) dataSize / _batchSize); - if(!(_runtimeBalancing == Statement.PSRuntimeBalancing.RUN_MIN - || _runtimeBalancing == Statement.PSRuntimeBalancing.CYCLE_AVG - || _runtimeBalancing == Statement.PSRuntimeBalancing.CYCLE_MAX)) { - _numBatchesPerGlobalEpoch = _possibleBatchesPerLocalEpoch; + + // calculate scaled batch size if balancing via batch size. 
+ // In some cases there will be some cycling + if(_runtimeBalancing == PSRuntimeBalancing.SCALE_BATCH) { + _batchSize = (int) Math.ceil((double) dataSize / _numBatchesPerEpoch); } - if(_runtimeBalancing == Statement.PSRuntimeBalancing.SCALE_BATCH - || _runtimeBalancing == Statement.PSRuntimeBalancing.SCALE_BATCH_AND_WEIGH) { - throw new NotImplementedException(); + // Calculate possible batches with batch size + _possibleBatchesPerLocalEpoch = (int) Math.ceil((double) dataSize / _batchSize); + + // If no runtime balancing is specified, just run possible number of batches + // WARNING: Will get stuck on miss match + if(_runtimeBalancing == PSRuntimeBalancing.NONE) { + _numBatchesPerEpoch = _possibleBatchesPerLocalEpoch; } + LOG.info("Setup config for worker " + this.getWorkerName()); + LOG.info("Batch size: " + _batchSize + " possible batches: " + _possibleBatchesPerLocalEpoch + + " batches to run: " + _numBatchesPerEpoch + " weighing factor: " + _weighingFactor); + // serialize program // create program blocks for the instruction filtering String programSerialized; - ArrayList programBlocks = new ArrayList<>(); + ArrayList pbs = new ArrayList<>(); BasicProgramBlock gradientProgramBlock = new BasicProgramBlock(_ec.getProgram()); gradientProgramBlock.setInstructions(new ArrayList<>(Arrays.asList(_inst))); - programBlocks.add(gradientProgramBlock); + pbs.add(gradientProgramBlock); - if(_freq == Statement.PSFrequency.EPOCH) { + if(_freq == PSFrequency.EPOCH) { BasicProgramBlock aggProgramBlock = new BasicProgramBlock(_ec.getProgram()); aggProgramBlock.setInstructions(new ArrayList<>(Arrays.asList(_ps.getAggInst()))); - programBlocks.add(aggProgramBlock); + pbs.add(aggProgramBlock); } - StringBuilder sb = new StringBuilder(); - sb.append(PROG_BEGIN); - sb.append( NEWLINE ); - sb.append(ProgramConverter.serializeProgram(_ec.getProgram(), - programBlocks, - new HashMap<>(), - false - )); - sb.append(PROG_END); - programSerialized = sb.toString(); + programSerialized = 
InstructionUtils.concatStrings( + PROG_BEGIN, NEWLINE, + ProgramConverter.serializeProgram(_ec.getProgram(), pbs, new HashMap<>(), false), + PROG_END); // write program and meta data to worker Future udfResponse = _featuresData.executeFederatedOperation( new FederatedRequest(RequestType.EXEC_UDF, _featuresData.getVarID(), new SetupFederatedWorker(_batchSize, - dataSize, - _possibleBatchesPerLocalEpoch, - programSerialized, - _inst.getNamespace(), - _inst.getFunctionName(), - _ps.getAggInst().getFunctionName(), - _ec.getListObject("hyperparams"), - _localStartBatchNumVarID, - _modelVarID + dataSize, + _possibleBatchesPerLocalEpoch, + programSerialized, + _inst.getNamespace(), + _inst.getFunctionName(), + _ps.getAggInst().getFunctionName(), + _ec.getListObject("hyperparams"), + _localStartBatchNumVarID, + _modelVarID ) )); @@ -286,12 +307,23 @@ protected ListObject pullModel() { return _ps.pull(_workerID); } - protected void pushGradients(ListObject gradients) { + protected void scaleAndPushGradients(ListObject gradients) { + // scale gradients - must only include MatrixObjects + if(_weighing && _weighingFactor != 1) { + gradients.getData().parallelStream().forEach((matrix) -> { + MatrixObject matrixObject = (MatrixObject) matrix; + MatrixBlock input = matrixObject.acquireReadAndRelease().scalarOperations( + new RightScalarOperator(Multiply.getMultiplyFnObject(), _weighingFactor), new MatrixBlock()); + matrixObject.acquireModify(input); + matrixObject.release(); + }); + } + // Push the gradients to ps _ps.push(_workerID, gradients); } - static protected int getNextLocalBatchNum(int currentLocalBatchNumber, int possibleBatchesPerLocalEpoch) { + protected static int getNextLocalBatchNum(int currentLocalBatchNumber, int possibleBatchesPerLocalEpoch) { return currentLocalBatchNumber % possibleBatchesPerLocalEpoch; } @@ -300,18 +332,18 @@ static protected int getNextLocalBatchNum(int currentLocalBatchNumber, int possi */ protected void computeWithBatchUpdates() { for (int 
epochCounter = 0; epochCounter < _epochs; epochCounter++) { - int currentLocalBatchNumber = (_cycleStartAt0) ? 0 : _numBatchesPerGlobalEpoch * epochCounter % _possibleBatchesPerLocalEpoch; + int currentLocalBatchNumber = (_cycleStartAt0) ? 0 : _numBatchesPerEpoch * epochCounter % _possibleBatchesPerLocalEpoch; - for (int batchCounter = 0; batchCounter < _numBatchesPerGlobalEpoch; batchCounter++) { + for (int batchCounter = 0; batchCounter < _numBatchesPerEpoch; batchCounter++) { int localStartBatchNum = getNextLocalBatchNum(currentLocalBatchNumber++, _possibleBatchesPerLocalEpoch); ListObject model = pullModel(); ListObject gradients = computeGradientsForNBatches(model, 1, localStartBatchNum); - pushGradients(gradients); + scaleAndPushGradients(gradients); ParamservUtils.cleanupListObject(model); ParamservUtils.cleanupListObject(gradients); + LOG.info("[+] " + this.getWorkerName() + " completed BATCH " + localStartBatchNum); } - if( LOG.isInfoEnabled() ) - LOG.info("[+] " + this.getWorkerName() + " completed epoch " + epochCounter); + LOG.info("[+] " + this.getWorkerName() + " --- completed EPOCH " + epochCounter); } } @@ -327,15 +359,14 @@ protected void computeWithNBatchUpdates() { */ protected void computeWithEpochUpdates() { for (int epochCounter = 0; epochCounter < _epochs; epochCounter++) { - int localStartBatchNum = (_cycleStartAt0) ? 0 : _numBatchesPerGlobalEpoch * epochCounter % _possibleBatchesPerLocalEpoch; + int localStartBatchNum = (_cycleStartAt0) ? 
0 : _numBatchesPerEpoch * epochCounter % _possibleBatchesPerLocalEpoch; // Pull the global parameters from ps ListObject model = pullModel(); - ListObject gradients = computeGradientsForNBatches(model, _numBatchesPerGlobalEpoch, localStartBatchNum, true); - pushGradients(gradients); - - if( LOG.isInfoEnabled() ) - LOG.info("[+] " + this.getWorkerName() + " completed epoch " + epochCounter); + ListObject gradients = computeGradientsForNBatches(model, _numBatchesPerEpoch, localStartBatchNum, true); + scaleAndPushGradients(gradients); + + LOG.info("[+] " + this.getWorkerName() + " --- completed EPOCH " + epochCounter); ParamservUtils.cleanupListObject(model); ParamservUtils.cleanupListObject(gradients); } @@ -424,12 +455,12 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) { ArrayList inputs = func.getInputParams(); ArrayList outputs = func.getOutputParams(); CPOperand[] boundInputs = inputs.stream() - .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType())) - .toArray(CPOperand[]::new); + .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType())) + .toArray(CPOperand[]::new); ArrayList outputNames = outputs.stream().map(DataIdentifier::getName) - .collect(Collectors.toCollection(ArrayList::new)); + .collect(Collectors.toCollection(ArrayList::new)); Instruction gradientsInstruction = new FunctionCallCPInstruction(namespace, gradientsFunctionName, false, boundInputs, - func.getInputParamNames(), outputNames, "gradient function"); + func.getInputParamNames(), outputNames, "gradient function"); DataIdentifier gradientsOutput = outputs.get(0); // recreate aggregation instruction and output if needed @@ -440,12 +471,12 @@ public FederatedResponse execute(ExecutionContext ec, Data... 
data) { inputs = func.getInputParams(); outputs = func.getOutputParams(); boundInputs = inputs.stream() - .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType())) - .toArray(CPOperand[]::new); + .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType())) + .toArray(CPOperand[]::new); outputNames = outputs.stream().map(DataIdentifier::getName) - .collect(Collectors.toCollection(ArrayList::new)); + .collect(Collectors.toCollection(ArrayList::new)); aggregationInstruction = new FunctionCallCPInstruction(namespace, aggregationFuctionName, false, boundInputs, - func.getInputParamNames(), outputNames, "aggregation function"); + func.getInputParamNames(), outputNames, "aggregation function"); aggregationOutput = outputs.get(0); } @@ -492,8 +523,6 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) { ParamservUtils.cleanupData(ec, Statement.PS_FEATURES); ParamservUtils.cleanupData(ec, Statement.PS_LABELS); ec.removeVariable(ec.getVariable(Statement.PS_FED_BATCHCOUNTER_VARID).toString()); - if( LOG.isInfoEnabled() ) - LOG.info("[+]" + " completed batch " + localBatchNum); } // model clean up diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java index 460fabaa217..34e94f06a44 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java @@ -20,7 +20,6 @@ package org.apache.sysds.runtime.controlprogram.paramserv.dp; import org.apache.sysds.runtime.DMLRuntimeException; -import org.apache.sysds.runtime.controlprogram.caching.CacheableData; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import 
org.apache.sysds.runtime.controlprogram.federated.FederatedData; @@ -35,13 +34,25 @@ import java.util.List; import java.util.concurrent.Future; +/** + * Balance to Avg Federated scheme + * + * When the parameter server runs in federated mode it cannot pull in the data which is already on the workers. + * Therefore, a UDF is sent to manipulate the data locally. In this case the global average number of examples is taken + * and the worker subsamples or replicates data to match that number of examples. See the other federated schemes. + * + * Then all entries in the federation map of the input matrix are separated into MatrixObjects and returned as a list. + * Only supports row federated matrices atm. + */ public class BalanceToAvgFederatedScheme extends DataPartitionFederatedScheme { @Override - public Result doPartitioning(MatrixObject features, MatrixObject labels) { + public Result partition(MatrixObject features, MatrixObject labels, int seed) { List pFeatures = sliceFederatedMatrix(features); List pLabels = sliceFederatedMatrix(labels); + BalanceMetrics balanceMetricsBefore = getBalanceMetrics(pFeatures); + List weighingFactors = getWeighingFactors(pFeatures, balanceMetricsBefore); - int average_num_rows = (int) Math.round(pFeatures.stream().map(CacheableData::getNumRows).mapToInt(Long::intValue).average().orElse(Double.NaN)); + int average_num_rows = (int) balanceMetricsBefore._avgRows; for(int i = 0; i < pFeatures.size(); i++) { // Works, because the map contains a single entry @@ -49,7 +60,7 @@ public Result doPartitioning(MatrixObject features, MatrixObject labels) { FederatedData labelsData = (FederatedData) pLabels.get(i).getFedMapping().getMap().values().toArray()[0]; Future udfResponse = featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, - featuresData.getVarID(), new balanceDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, average_num_rows))); + featuresData.getVarID(), new 
balanceDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, seed, average_num_rows))); try { FederatedResponse response = udfResponse.get(); @@ -66,7 +77,7 @@ public Result doPartitioning(MatrixObject features, MatrixObject labels) { pLabels.get(i).updateDataCharacteristics(update); } - return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures)); + return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures), weighingFactors); } /** @@ -74,10 +85,12 @@ public Result doPartitioning(MatrixObject features, MatrixObject labels) { */ private static class balanceDataOnFederatedWorker extends FederatedUDF { private static final long serialVersionUID = 6631958250346625546L; + private final int _seed; private final int _average_num_rows; - - protected balanceDataOnFederatedWorker(long[] inIDs, int average_num_rows) { + + protected balanceDataOnFederatedWorker(long[] inIDs, int seed, int average_num_rows) { super(inIDs); + _seed = seed; _average_num_rows = average_num_rows; } @@ -88,14 +101,14 @@ public FederatedResponse execute(ExecutionContext ec, Data... 
data) { if(features.getNumRows() > _average_num_rows) { // generate subsampling matrix - MatrixBlock subsampleMatrixBlock = ParamservUtils.generateSubsampleMatrix(_average_num_rows, Math.toIntExact(features.getNumRows()), System.currentTimeMillis()); + MatrixBlock subsampleMatrixBlock = ParamservUtils.generateSubsampleMatrix(_average_num_rows, Math.toIntExact(features.getNumRows()), _seed); subsampleTo(features, subsampleMatrixBlock); subsampleTo(labels, subsampleMatrixBlock); } else if(features.getNumRows() < _average_num_rows) { int num_rows_needed = _average_num_rows - Math.toIntExact(features.getNumRows()); // generate replication matrix - MatrixBlock replicateMatrixBlock = ParamservUtils.generateReplicationMatrix(num_rows_needed, Math.toIntExact(features.getNumRows()), System.currentTimeMillis()); + MatrixBlock replicateMatrixBlock = ParamservUtils.generateReplicationMatrix(num_rows_needed, Math.toIntExact(features.getNumRows()), _seed); replicateTo(features, replicateMatrixBlock); replicateTo(labels, replicateMatrixBlock); } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java index f5c963873aa..e00923e8ced 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java @@ -45,16 +45,31 @@ public static final class Result { public final List _pLabels; public final int _workerNum; public final BalanceMetrics _balanceMetrics; + public final List _weighingFactors; - public Result(List pFeatures, List pLabels, int workerNum, BalanceMetrics balanceMetrics) { - this._pFeatures = pFeatures; - this._pLabels = pLabels; - this._workerNum = workerNum; - this._balanceMetrics = balanceMetrics; + + public Result(List pFeatures, List pLabels, int workerNum, 
BalanceMetrics balanceMetrics, List weighingFactors) { + _pFeatures = pFeatures; + _pLabels = pLabels; + _workerNum = workerNum; + _balanceMetrics = balanceMetrics; + _weighingFactors = weighingFactors; } } - public abstract Result doPartitioning(MatrixObject features, MatrixObject labels); + public static final class BalanceMetrics { + public final long _minRows; + public final long _avgRows; + public final long _maxRows; + + public BalanceMetrics(long minRows, long avgRows, long maxRows) { + _minRows = minRows; + _avgRows = avgRows; + _maxRows = maxRows; + } + } + + public abstract Result partition(MatrixObject features, MatrixObject labels, int seed); /** * Takes a row federated Matrix and slices it into a matrix for each worker @@ -110,16 +125,12 @@ else if (slice.getNumRows() > maxRows) return new BalanceMetrics(minRows, sum / slices.size(), maxRows); } - public static final class BalanceMetrics { - public final long _minRows; - public final long _avgRows; - public final long _maxRows; - - public BalanceMetrics(long minRows, long avgRows, long maxRows) { - this._minRows = minRows; - this._avgRows = avgRows; - this._maxRows = maxRows; - } + static List getWeighingFactors(List pFeatures, BalanceMetrics balanceMetrics) { + List weighingFactors = new ArrayList<>(); + pFeatures.forEach((feature) -> { + weighingFactors.add((double) feature.getNumRows() / balanceMetrics._avgRows); + }); + return weighingFactors; } /** diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java index d1ebb6cac5d..ce2f954ca45 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java @@ -24,10 +24,11 @@ import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; public class 
FederatedDataPartitioner { - private final DataPartitionFederatedScheme _scheme; + private final int _seed; - public FederatedDataPartitioner(Statement.FederatedPSScheme scheme) { + public FederatedDataPartitioner(Statement.FederatedPSScheme scheme, int seed) { + _seed = seed; switch (scheme) { case KEEP_DATA_ON_WORKER: _scheme = new KeepDataOnWorkerFederatedScheme(); @@ -50,6 +51,6 @@ public FederatedDataPartitioner(Statement.FederatedPSScheme scheme) { } public DataPartitionFederatedScheme.Result doPartitioning(MatrixObject features, MatrixObject labels) { - return _scheme.doPartitioning(features, labels); + return _scheme.partition(features, labels, _seed); } } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java index e306f25d29c..afbaf4db637 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java @@ -22,11 +22,20 @@ import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import java.util.List; +/** + * Keep Data on Worker Federated scheme + * + * When the parameter server runs in federated mode it cannot pull in the data which is already on the workers. + * All entries in the federation map of the input matrix are separated into MatrixObjects and returned as a list. + * Only supports row federated matrices atm. 
+ */ public class KeepDataOnWorkerFederatedScheme extends DataPartitionFederatedScheme { @Override - public Result doPartitioning(MatrixObject features, MatrixObject labels) { + public Result partition(MatrixObject features, MatrixObject labels, int seed) { List pFeatures = sliceFederatedMatrix(features); List pLabels = sliceFederatedMatrix(labels); - return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures)); + BalanceMetrics balanceMetrics = getBalanceMetrics(pFeatures); + List weighingFactors = getWeighingFactors(pFeatures, balanceMetrics); + return new Result(pFeatures, pLabels, pFeatures.size(), balanceMetrics, weighingFactors); } } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java index 068cfa9d7a8..a1b8f6c9d62 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java @@ -34,11 +34,23 @@ import java.util.List; import java.util.concurrent.Future; +/** + * Replicate to Max Federated scheme + * + * When the parameter server runs in federated mode it cannot pull in the data which is already on the workers. + * Therefore, a UDF is sent to manipulate the data locally. In this case the global maximum number of examples is taken + * and the worker replicates data to match that number of examples. The generation is done by multiplying with a + * Permutation Matrix with a global seed. These selected examples are appended to the original data. + * + * Then all entries in the federation map of the input matrix are separated into MatrixObjects and returned as a list. + * Only supports row federated matrices atm. 
+ */ public class ReplicateToMaxFederatedScheme extends DataPartitionFederatedScheme { @Override - public Result doPartitioning(MatrixObject features, MatrixObject labels) { + public Result partition(MatrixObject features, MatrixObject labels, int seed) { List pFeatures = sliceFederatedMatrix(features); List pLabels = sliceFederatedMatrix(labels); + List weighingFactors = getWeighingFactors(pFeatures, getBalanceMetrics(pFeatures)); int max_rows = 0; for (MatrixObject pFeature : pFeatures) { @@ -51,7 +63,7 @@ public Result doPartitioning(MatrixObject features, MatrixObject labels) { FederatedData labelsData = (FederatedData) pLabels.get(i).getFedMapping().getMap().values().toArray()[0]; Future udfResponse = featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, - featuresData.getVarID(), new replicateDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, max_rows))); + featuresData.getVarID(), new replicateDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, seed, max_rows))); try { FederatedResponse response = udfResponse.get(); @@ -68,7 +80,7 @@ public Result doPartitioning(MatrixObject features, MatrixObject labels) { pLabels.get(i).updateDataCharacteristics(update); } - return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures)); + return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures), weighingFactors); } /** @@ -76,10 +88,12 @@ public Result doPartitioning(MatrixObject features, MatrixObject labels) { */ private static class replicateDataOnFederatedWorker extends FederatedUDF { private static final long serialVersionUID = -6930898456315100587L; + private final int _seed; private final int _max_rows; - - protected replicateDataOnFederatedWorker(long[] inIDs, int max_rows) { + + protected replicateDataOnFederatedWorker(long[] inIDs, int seed, int max_rows) { super(inIDs); + _seed = seed; _max_rows = max_rows; 
} @@ -92,7 +106,7 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) { if(features.getNumRows() < _max_rows) { int num_rows_needed = _max_rows - Math.toIntExact(features.getNumRows()); // generate replication matrix - MatrixBlock replicateMatrixBlock = ParamservUtils.generateReplicationMatrix(num_rows_needed, Math.toIntExact(features.getNumRows()), System.currentTimeMillis()); + MatrixBlock replicateMatrixBlock = ParamservUtils.generateReplicationMatrix(num_rows_needed, Math.toIntExact(features.getNumRows()), _seed); replicateTo(features, replicateMatrixBlock); replicateTo(labels, replicateMatrixBlock); } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java index 65ef69d83c9..1920593cbb0 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java @@ -33,11 +33,23 @@ import java.util.List; import java.util.concurrent.Future; +/** + * Shuffle Federated scheme + * + * When the parameter server runs in federated mode it cannot pull in the data which is already on the workers. + * Therefore, a UDF is sent to manipulate the data locally. In this case it is shuffled by generating a permutation + * matrix with a global seed and doing a mat mult. + * + * Then all entries in the federation map of the input matrix are separated into MatrixObjects and returned as a list. + * Only supports row federated matrices atm. 
+ */ public class ShuffleFederatedScheme extends DataPartitionFederatedScheme { @Override - public Result doPartitioning(MatrixObject features, MatrixObject labels) { + public Result partition(MatrixObject features, MatrixObject labels, int seed) { List pFeatures = sliceFederatedMatrix(features); List pLabels = sliceFederatedMatrix(labels); + BalanceMetrics balanceMetrics = getBalanceMetrics(pFeatures); + List weighingFactors = getWeighingFactors(pFeatures, balanceMetrics); for(int i = 0; i < pFeatures.size(); i++) { // Works, because the map contains a single entry @@ -45,7 +57,7 @@ public Result doPartitioning(MatrixObject features, MatrixObject labels) { FederatedData labelsData = (FederatedData) pLabels.get(i).getFedMapping().getMap().values().toArray()[0]; Future udfResponse = featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, - featuresData.getVarID(), new shuffleDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}))); + featuresData.getVarID(), new shuffleDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, seed))); try { FederatedResponse response = udfResponse.get(); @@ -57,7 +69,7 @@ public Result doPartitioning(MatrixObject features, MatrixObject labels) { } } - return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures)); + return new Result(pFeatures, pLabels, pFeatures.size(), balanceMetrics, weighingFactors); } /** @@ -65,9 +77,11 @@ public Result doPartitioning(MatrixObject features, MatrixObject labels) { */ private static class shuffleDataOnFederatedWorker extends FederatedUDF { private static final long serialVersionUID = 3228664618781333325L; + private final int _seed; - protected shuffleDataOnFederatedWorker(long[] inIDs) { + protected shuffleDataOnFederatedWorker(long[] inIDs, int seed) { super(inIDs); + _seed = seed; } @Override @@ -76,7 +90,7 @@ public FederatedResponse execute(ExecutionContext ec, Data... 
data) { MatrixObject labels = (MatrixObject) data[1]; // generate permutation matrix - MatrixBlock permutationMatrixBlock = ParamservUtils.generatePermutation(Math.toIntExact(features.getNumRows()), System.currentTimeMillis()); + MatrixBlock permutationMatrixBlock = ParamservUtils.generatePermutation(Math.toIntExact(features.getNumRows()), _seed); shuffle(features, permutationMatrixBlock); shuffle(labels, permutationMatrixBlock); return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS); diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java index 9b62cc80fb8..937c37e4098 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java @@ -34,11 +34,23 @@ import java.util.List; import java.util.concurrent.Future; +/** + * Subsample to Min Federated scheme + * + * When the parameter server runs in federated mode it cannot pull in the data which is already on the workers. + * Therefore, a UDF is sent to manipulate the data locally. In this case the global minimum number of examples is taken + * and the worker subsamples data to match that number of examples. The subsampling is done by multiplying with a + * Permutation Matrix with a global seed. + * + * Then all entries in the federation map of the input matrix are separated into MatrixObjects and returned as a list. + * Only supports row federated matrices atm. 
+ */ public class SubsampleToMinFederatedScheme extends DataPartitionFederatedScheme { @Override - public Result doPartitioning(MatrixObject features, MatrixObject labels) { + public Result partition(MatrixObject features, MatrixObject labels, int seed) { List pFeatures = sliceFederatedMatrix(features); List pLabels = sliceFederatedMatrix(labels); + List weighingFactors = getWeighingFactors(pFeatures, getBalanceMetrics(pFeatures)); int min_rows = Integer.MAX_VALUE; for (MatrixObject pFeature : pFeatures) { @@ -51,7 +63,7 @@ public Result doPartitioning(MatrixObject features, MatrixObject labels) { FederatedData labelsData = (FederatedData) pLabels.get(i).getFedMapping().getMap().values().toArray()[0]; Future udfResponse = featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, - featuresData.getVarID(), new subsampleDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, min_rows))); + featuresData.getVarID(), new subsampleDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, seed, min_rows))); try { FederatedResponse response = udfResponse.get(); @@ -68,7 +80,7 @@ public Result doPartitioning(MatrixObject features, MatrixObject labels) { pLabels.get(i).updateDataCharacteristics(update); } - return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures)); + return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures), weighingFactors); } /** @@ -76,10 +88,12 @@ public Result doPartitioning(MatrixObject features, MatrixObject labels) { */ private static class subsampleDataOnFederatedWorker extends FederatedUDF { private static final long serialVersionUID = 2213790859544004286L; + private final int _seed; private final int _min_rows; - - protected subsampleDataOnFederatedWorker(long[] inIDs, int min_rows) { + + protected subsampleDataOnFederatedWorker(long[] inIDs, int seed, int min_rows) { super(inIDs); + _seed = seed; 
_min_rows = min_rows; } @@ -91,7 +105,7 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) { // subsample down to minimum if(features.getNumRows() > _min_rows) { // generate subsampling matrix - MatrixBlock subsampleMatrixBlock = ParamservUtils.generateSubsampleMatrix(_min_rows, Math.toIntExact(features.getNumRows()), System.currentTimeMillis()); + MatrixBlock subsampleMatrixBlock = ParamservUtils.generateSubsampleMatrix(_min_rows, Math.toIntExact(features.getNumRows()), _seed); subsampleTo(features, subsampleMatrixBlock); subsampleTo(labels, subsampleMatrixBlock); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java index a2b8d9fbc22..a66e03992d6 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java @@ -19,6 +19,17 @@ package org.apache.sysds.runtime.instructions.cp; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + import static org.apache.sysds.parser.Statement.PS_AGGREGATION_FUN; import static org.apache.sysds.parser.Statement.PS_BATCH_SIZE; import static org.apache.sysds.parser.Statement.PS_EPOCHS; @@ -32,18 +43,9 @@ import static org.apache.sysds.parser.Statement.PS_SCHEME; import static org.apache.sysds.parser.Statement.PS_UPDATE_FUN; import static org.apache.sysds.parser.Statement.PS_UPDATE_TYPE; -import static org.apache.sysds.parser.Statement.PS_RUNTIME_BALANCING; - -import java.util.HashMap; -import java.util.HashSet; -import 
java.util.LinkedHashMap; -import java.util.List; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.Future; -import java.util.stream.Collectors; -import java.util.stream.IntStream; +import static org.apache.sysds.parser.Statement.PS_FED_RUNTIME_BALANCING; +import static org.apache.sysds.parser.Statement.PS_FED_WEIGHING; +import static org.apache.sysds.parser.Statement.PS_SEED; import org.apache.commons.lang3.concurrent.BasicThreadFactory; import org.apache.commons.logging.Log; @@ -121,37 +123,36 @@ public void processInstruction(ExecutionContext ec) { } private void runFederated(ExecutionContext ec) { - System.out.println("PARAMETER SERVER"); - System.out.println("[+] Running in federated mode"); + LOG.info("PARAMETER SERVER"); + LOG.info("[+] Running in federated mode"); // get inputs - PSFrequency freq = getFrequency(); - PSUpdateType updateType = getUpdateType(); - PSRuntimeBalancing runtimeBalancing = getRuntimeBalancing(); - FederatedPSScheme federatedPSScheme = getFederatedScheme(); String updFunc = getParam(PS_UPDATE_FUN); String aggFunc = getParam(PS_AGGREGATION_FUN); - + PSUpdateType updateType = getUpdateType(); + PSFrequency freq = getFrequency(); + FederatedPSScheme federatedPSScheme = getFederatedScheme(); + PSRuntimeBalancing runtimeBalancing = getRuntimeBalancing(); + boolean weighing = getWeighing(); + int seed = getSeed(); + + if( LOG.isInfoEnabled() ) { + LOG.info("[+] Update Type: " + updateType); + LOG.info("[+] Frequency: " + freq); + LOG.info("[+] Data Partitioning: " + federatedPSScheme); + LOG.info("[+] Runtime Balancing: " + runtimeBalancing); + LOG.info("[+] Weighing: " + weighing); + LOG.info("[+] Seed: " + seed); + } + // partition federated data - DataPartitionFederatedScheme.Result result = new FederatedDataPartitioner(federatedPSScheme) - .doPartitioning(ec.getMatrixObject(getParam(PS_FEATURES)), 
ec.getMatrixObject(getParam(PS_LABELS))); - List pFeatures = result._pFeatures; - List pLabels = result._pLabels; + DataPartitionFederatedScheme.Result result = new FederatedDataPartitioner(federatedPSScheme, seed) + .doPartitioning(ec.getMatrixObject(getParam(PS_FEATURES)), ec.getMatrixObject(getParam(PS_LABELS))); int workerNum = result._workerNum; - // calculate runtime balancing - int numBatchesPerEpoch = 0; - if(runtimeBalancing == PSRuntimeBalancing.RUN_MIN) { - numBatchesPerEpoch = (int) Math.ceil(result._balanceMetrics._minRows / (float) getBatchSize()); - } else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_AVG) { - numBatchesPerEpoch = (int) Math.ceil(result._balanceMetrics._avgRows / (float) getBatchSize()); - } else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_MAX) { - numBatchesPerEpoch = (int) Math.ceil(result._balanceMetrics._maxRows / (float) getBatchSize()); - } - // setup threading BasicThreadFactory factory = new BasicThreadFactory.Builder() - .namingPattern("workers-pool-thread-%d").build(); + .namingPattern("workers-pool-thread-%d").build(); ExecutorService es = Executors.newFixedThreadPool(workerNum, factory); // Get the compiled execution context @@ -166,10 +167,11 @@ private void runFederated(ExecutionContext ec) { ListObject model = ec.getListObject(getParam(PS_MODEL)); ParamServer ps = createPS(PSModeType.FEDERATED, aggFunc, updateType, workerNum, model, aggServiceEC); // Create the local workers - int finalNumBatchesPerEpoch = numBatchesPerEpoch; + int finalNumBatchesPerEpoch = getNumBatchesPerEpoch(runtimeBalancing, result._balanceMetrics); List threads = IntStream.range(0, workerNum) - .mapToObj(i -> new FederatedPSControlThread(i, updFunc, freq, runtimeBalancing, getEpochs(), getBatchSize(), finalNumBatchesPerEpoch, federatedWorkerECs.get(i), ps)) - .collect(Collectors.toList()); + .mapToObj(i -> new FederatedPSControlThread(i, updFunc, freq, runtimeBalancing, weighing, + getEpochs(), getBatchSize(), finalNumBatchesPerEpoch, 
federatedWorkerECs.get(i), ps)) + .collect(Collectors.toList()); if(workerNum != threads.size()) { throw new DMLRuntimeException("ParamservBuiltinCPInstruction: Federated data partitioning does not match threads!"); @@ -177,9 +179,9 @@ // Set features and labels for the control threads and write the program and instructions and hyperparams to the federated workers for (int i = 0; i < threads.size(); i++) { - threads.get(i).setFeatures(pFeatures.get(i)); - threads.get(i).setLabels(pLabels.get(i)); - threads.get(i).setup(); + threads.get(i).setFeatures(result._pFeatures.get(i)); + threads.get(i).setLabels(result._pLabels.get(i)); + threads.get(i).setup(result._weighingFactors.get(i)); } try { @@ -395,14 +397,14 @@ private PSFrequency getFrequency() { } private PSRuntimeBalancing getRuntimeBalancing() { - if (!getParameterMap().containsKey(PS_RUNTIME_BALANCING)) { + if (!getParameterMap().containsKey(PS_FED_RUNTIME_BALANCING)) { return DEFAULT_RUNTIME_BALANCING; } try { - return PSRuntimeBalancing.valueOf(getParam(PS_RUNTIME_BALANCING)); + return PSRuntimeBalancing.valueOf(getParam(PS_FED_RUNTIME_BALANCING)); } catch (IllegalArgumentException e) { throw new DMLRuntimeException(String.format("Paramserv function: " - + "not support '%s' runtime balancing.", getParam(PS_RUNTIME_BALANCING))); + + "not support '%s' runtime balancing.", getParam(PS_FED_RUNTIME_BALANCING))); } } @@ -507,4 +509,32 @@ private FederatedPSScheme getFederatedScheme() { } return federated_scheme; } + + /** + * Calculates the number of batches per epoch depending on the balance metrics and the runtime balancing + * + * @param runtimeBalancing the runtime balancing + * @param balanceMetrics the balance metrics calculated during data partitioning + * @return numBatchesPerEpoch + */ + private int getNumBatchesPerEpoch(PSRuntimeBalancing runtimeBalancing, DataPartitionFederatedScheme.BalanceMetrics balanceMetrics) { + int numBatchesPerEpoch = 0; + 
if(runtimeBalancing == PSRuntimeBalancing.RUN_MIN) { + numBatchesPerEpoch = (int) Math.ceil(balanceMetrics._minRows / (float) getBatchSize()); + } else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_AVG + || runtimeBalancing == PSRuntimeBalancing.SCALE_BATCH) { + numBatchesPerEpoch = (int) Math.ceil(balanceMetrics._avgRows / (float) getBatchSize()); + } else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_MAX) { + numBatchesPerEpoch = (int) Math.ceil(balanceMetrics._maxRows / (float) getBatchSize()); + } + return numBatchesPerEpoch; + } + + private boolean getWeighing() { + return getParameterMap().containsKey(PS_FED_WEIGHING) && Boolean.parseBoolean(getParam(PS_FED_WEIGHING)); + } + + private int getSeed() { + return (getParameterMap().containsKey(PS_SEED)) ? Integer.parseInt(getParam(PS_SEED)) : (int) System.currentTimeMillis(); + } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java index 6a52fc45ce9..a00e8dc613e 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java @@ -54,44 +54,45 @@ public class FederatedParamservTest extends AutomatedTestBase { private final String _freq; private final String _scheme; private final String _runtime_balancing; + private final String _weighing; private final String _data_distribution; + private final int _seed; // parameters @Parameterized.Parameters public static Collection parameters() { return Arrays.asList(new Object[][] { // Network type, number of federated workers, data set size, batch size, epochs, learning rate, update type, update frequency - // basic functionality - {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG", "IMBALANCED"}, - {"CNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "SHUFFLE", 
"NONE" , "IMBALANCED"}, - {"CNN", 2, 4, 1, 4, 0.01, "ASP", "BATCH", "REPLICATE_TO_MAX", "RUN_MIN" , "IMBALANCED"}, - {"TwoNN", 2, 4, 1, 4, 0.01, "ASP", "EPOCH", "BALANCE_TO_AVG", "CYCLE_MAX", "IMBALANCED"}, - {"TwoNN", 5, 1000, 100, 2, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "NONE" , "BALANCED"}, - - /* - // runtime balancing - {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "RUN_MIN" , "IMBALANCED"}, - {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "RUN_MIN" , "IMBALANCED"}, - {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG" , "IMBALANCED"}, - {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG" , "IMBALANCED"}, - {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_MAX" , "IMBALANCED"}, - {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_MAX" , "IMBALANCED"}, - - // data partitioning - {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "SHUFFLE", "CYCLE_AVG" , "IMBALANCED"}, - {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "REPLICATE_TO_MAX", "NONE" , "IMBALANCED"}, - {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "SUBSAMPLE_TO_MIN", "NONE" , "IMBALANCED"}, - {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "BALANCE_TO_AVG", "NONE" , "IMBALANCED"}, - - // balanced tests - {"CNN", 5, 1000, 100, 2, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "NONE" , "BALANCED"} - */ + {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "RUN_MIN" , "true", "IMBALANCED", 200}, + {"CNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "SHUFFLE", "NONE" , "true", "IMBALANCED", 200}, + {"CNN", 2, 4, 1, 4, 0.01, "ASP", "BATCH", "REPLICATE_TO_MAX", "RUN_MIN" , "true", "IMBALANCED", 200}, + {"TwoNN", 2, 4, 1, 4, 0.01, "ASP", "EPOCH", "BALANCE_TO_AVG", "CYCLE_MAX" , "true", "IMBALANCED", 200}, + {"TwoNN", 5, 1000, 100, 2, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "NONE" , "true", "BALANCED", 200}, + + /* // runtime balancing + {"TwoNN", 2, 4, 1, 4, 0.01, 
"BSP", "BATCH", "KEEP_DATA_ON_WORKER", "RUN_MIN" , "true", "IMBALANCED", 200}, + {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "RUN_MIN" , "true", "IMBALANCED", 200}, + {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG" , "true", "IMBALANCED", 200}, + {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG" , "true", "IMBALANCED", 200}, + {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_MAX" , "true", "IMBALANCED", 200}, + {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_MAX" , "true", "IMBALANCED", 200}, + + // data partitioning + {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "SHUFFLE", "CYCLE_AVG" , "true", "IMBALANCED", 200}, + {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "REPLICATE_TO_MAX", "NONE" , "true", "IMBALANCED", 200}, + {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "SUBSAMPLE_TO_MIN", "NONE" , "true", "IMBALANCED", 200}, + {"TwoNN", 2, 4, 1, 1, 0.01, "BSP", "BATCH", "BALANCE_TO_AVG", "NONE" , "true", "IMBALANCED", 200}, + + // balanced tests + {"CNN", 5, 1000, 100, 2, 0.01, "BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "NONE" , "true", "BALANCED", 200} */ + }); } public FederatedParamservTest(String networkType, int numFederatedWorkers, int dataSetSize, int batch_size, - int epochs, double eta, String utype, String freq, String scheme, String runtime_balancing, String data_distribution) { + int epochs, double eta, String utype, String freq, String scheme, String runtime_balancing, String weighing, String data_distribution, int seed) { + _networkType = networkType; _numFederatedWorkers = numFederatedWorkers; _dataSetSize = dataSetSize; @@ -102,7 +103,9 @@ public FederatedParamservTest(String networkType, int numFederatedWorkers, int d _freq = freq; _scheme = scheme; _runtime_balancing = runtime_balancing; + _weighing = weighing; _data_distribution = data_distribution; + _seed = seed; } @Override @@ -185,11 +188,12 @@ private void 
federatedParamserv(ExecMode mode) { "freq=" + _freq, "scheme=" + _scheme, "runtime_balancing=" + _runtime_balancing, + "weighing=" + _weighing, "network_type=" + _networkType, "channels=" + C, "hin=" + Hin, "win=" + Win, - "seed=" + 25)); + "seed=" + _seed)); programArgs = programArgsList.toArray(new String[0]); LOG.debug(runTest(null)); diff --git a/src/test/scripts/functions/federated/paramserv/CNN.dml b/src/test/scripts/functions/federated/paramserv/CNN.dml index 69c7e760442..0f9ae6305dc 100644 --- a/src/test/scripts/functions/federated/paramserv/CNN.dml +++ b/src/test/scripts/functions/federated/paramserv/CNN.dml @@ -163,7 +163,7 @@ train = function(matrix[double] X, matrix[double] y, */ train_paramserv = function(matrix[double] X, matrix[double] y, matrix[double] X_val, matrix[double] y_val, - int num_workers, int epochs, string utype, string freq, int batch_size, string scheme, string runtime_balancing, + int num_workers, int epochs, string utype, string freq, int batch_size, string scheme, string runtime_balancing, string weighing, double eta, int C, int Hin, int Win, int seed = -1) return (list[unknown] model) { @@ -211,7 +211,7 @@ train_paramserv = function(matrix[double] X, matrix[double] y, upd="./src/test/scripts/functions/federated/paramserv/CNN.dml::gradients", agg="./src/test/scripts/functions/federated/paramserv/CNN.dml::aggregation", k=num_workers, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size, - scheme=scheme, runtime_balancing=runtime_balancing, hyperparams=hyperparams) + scheme=scheme, runtime_balancing=runtime_balancing, weighing=weighing, hyperparams=hyperparams, seed=seed) } /* diff --git a/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml index 10d2cc7f028..5176ccab8f3 100644 --- a/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml +++ 
b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml @@ -26,10 +26,12 @@ source("src/test/scripts/functions/federated/paramserv/CNN.dml") as CNN features = read($features) labels = read($labels) +print($weighing) + if($network_type == "TwoNN") { - model = TwoNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $eta, $seed) + model = TwoNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $weighing, $eta, $seed) } else { - model = CNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $eta, $channels, $hin, $win, $seed) + model = CNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $weighing, $eta, $channels, $hin, $win, $seed) } print(toString(model)) \ No newline at end of file diff --git a/src/test/scripts/functions/federated/paramserv/TwoNN.dml b/src/test/scripts/functions/federated/paramserv/TwoNN.dml index 9bd49d85917..a6dc6f2dc9e 100644 --- a/src/test/scripts/functions/federated/paramserv/TwoNN.dml +++ b/src/test/scripts/functions/federated/paramserv/TwoNN.dml @@ -125,7 +125,7 @@ train = function(matrix[double] X, matrix[double] y, */ train_paramserv = function(matrix[double] X, matrix[double] y, matrix[double] X_val, matrix[double] y_val, - int num_workers, int epochs, string utype, string freq, int batch_size, string scheme, string runtime_balancing, + int num_workers, int epochs, string utype, string freq, int batch_size, string scheme, string runtime_balancing, string weighing, double eta, int seed = -1) return (list[unknown] model) { @@ -155,7 +155,7 @@ train_paramserv = function(matrix[double] X, 
matrix[double] y, upd="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::gradients", agg="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::aggregation", k=num_workers, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size, - scheme=scheme, runtime_balancing=runtime_balancing, hyperparams=hyperparams) + scheme=scheme, runtime_balancing=runtime_balancing, weighing=weighing, hyperparams=hyperparams, seed=seed) } /*