diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java index 22dea76bd8d..ddefb8e9e78 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java @@ -523,8 +523,6 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) { ParamservUtils.cleanupData(ec, Statement.PS_FEATURES); ParamservUtils.cleanupData(ec, Statement.PS_LABELS); ec.removeVariable(ec.getVariable(Statement.PS_FED_BATCHCOUNTER_VARID).toString()); - /*if( LOG.isInfoEnabled() ) - LOG.info("[+]" + " completed batch " + localBatchNum);*/ } // model clean up diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java index e6a1bdf7dd9..34e94f06a44 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java @@ -34,6 +34,16 @@ import java.util.List; import java.util.concurrent.Future; +/** + * Balance to Avg Federated scheme + * + * When the parameter server runs in federated mode it cannot pull in the data which is already on the workers. + * Therefore, a UDF is sent to manipulate the data locally. In this case the global average number of examples is taken + * and the worker subsamples or replicates data to match that number of examples. See the other federated schemes. + * + * Then all entries in the federation map of the input matrix are separated into MatrixObjects and returned as a list. + * Only supports row federated matrices at the moment. 
+ */ public class BalanceToAvgFederatedScheme extends DataPartitionFederatedScheme { @Override public Result partition(MatrixObject features, MatrixObject labels, int seed) { diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java index b5f34860153..afbaf4db637 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java @@ -22,6 +22,13 @@ import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import java.util.List; +/** + * Keep Data on Worker Federated scheme + * + * When the parameter server runs in federated mode it cannot pull in the data which is already on the workers. + * All entries in the federation map of the input matrix are separated into MatrixObjects and returned as a list. + * Only supports row federated matrices at the moment. + */ public class KeepDataOnWorkerFederatedScheme extends DataPartitionFederatedScheme { @Override public Result partition(MatrixObject features, MatrixObject labels, int seed) { diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java index 4d1dae546ea..a1b8f6c9d62 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java @@ -34,6 +34,17 @@ import java.util.List; import java.util.concurrent.Future; +/** + * Replicate to Max Federated scheme + * + * When the parameter server runs in federated mode it cannot pull in the data which is already on the workers. 
+ * Therefore, a UDF is sent to manipulate the data locally. In this case the global maximum number of examples is taken + * and the worker replicates data to match that number of examples. The generation is done by multiplying with a + * permutation matrix with a global seed. These selected examples are appended to the original data. + * + * Then all entries in the federation map of the input matrix are separated into MatrixObjects and returned as a list. + * Only supports row federated matrices at the moment. + */ public class ReplicateToMaxFederatedScheme extends DataPartitionFederatedScheme { @Override public Result partition(MatrixObject features, MatrixObject labels, int seed) { diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java index 3a0928f6f6a..1920593cbb0 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java @@ -33,6 +33,16 @@ import java.util.List; import java.util.concurrent.Future; +/** + * Shuffle Federated scheme + * + * When the parameter server runs in federated mode it cannot pull in the data which is already on the workers. + * Therefore, a UDF is sent to manipulate the data locally. In this case it is shuffled by generating a permutation + * matrix with a global seed and performing a matrix multiplication. + * + * Then all entries in the federation map of the input matrix are separated into MatrixObjects and returned as a list. + * Only supports row federated matrices at the moment. 
+ */ public class ShuffleFederatedScheme extends DataPartitionFederatedScheme { @Override public Result partition(MatrixObject features, MatrixObject labels, int seed) { diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java index a5ff640e5ae..937c37e4098 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java @@ -34,6 +34,17 @@ import java.util.List; import java.util.concurrent.Future; +/** + * Subsample to Min Federated scheme + * + * When the parameter server runs in federated mode it cannot pull in the data which is already on the workers. + * Therefore, a UDF is sent to manipulate the data locally. In this case the global minimum number of examples is taken + * and the worker subsamples data to match that number of examples. The subsampling is done by multiplying with a + * permutation matrix with a global seed. + * + * Then all entries in the federation map of the input matrix are separated into MatrixObjects and returned as a list. + * Only supports row federated matrices at the moment. 
+ */ public class SubsampleToMinFederatedScheme extends DataPartitionFederatedScheme { @Override public Result partition(MatrixObject features, MatrixObject labels, int seed) { diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java index 817747067f4..dc327e82368 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java @@ -123,8 +123,8 @@ public void processInstruction(ExecutionContext ec) { } private void runFederated(ExecutionContext ec) { - System.out.println("PARAMETER SERVER"); - System.out.println("[+] Running in federated mode"); + LOG.info("PARAMETER SERVER"); + LOG.info("[+] Running in federated mode"); // get inputs String updFunc = getParam(PS_UPDATE_FUN); @@ -146,21 +146,8 @@ private void runFederated(ExecutionContext ec) { // partition federated data DataPartitionFederatedScheme.Result result = new FederatedDataPartitioner(federatedPSScheme, seed) .doPartitioning(ec.getMatrixObject(getParam(PS_FEATURES)), ec.getMatrixObject(getParam(PS_LABELS))); - List pFeatures = result._pFeatures; - List pLabels = result._pLabels; int workerNum = result._workerNum; - // calculate runtime balancing - int numBatchesPerEpoch = 0; - if(runtimeBalancing == PSRuntimeBalancing.RUN_MIN) { - numBatchesPerEpoch = (int) Math.ceil(result._balanceMetrics._minRows / (float) getBatchSize()); - } else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_AVG - || runtimeBalancing == PSRuntimeBalancing.SCALE_BATCH) { - numBatchesPerEpoch = (int) Math.ceil(result._balanceMetrics._avgRows / (float) getBatchSize()); - } else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_MAX) { - numBatchesPerEpoch = (int) Math.ceil(result._balanceMetrics._maxRows / (float) getBatchSize()); - } - // setup threading 
BasicThreadFactory factory = new BasicThreadFactory.Builder() .namingPattern("workers-pool-thread-%d").build(); @@ -178,7 +165,7 @@ private void runFederated(ExecutionContext ec) { ListObject model = ec.getListObject(getParam(PS_MODEL)); ParamServer ps = createPS(PSModeType.FEDERATED, aggFunc, updateType, workerNum, model, aggServiceEC); // Create the local workers - int finalNumBatchesPerEpoch = numBatchesPerEpoch; + int finalNumBatchesPerEpoch = getNumBatchesPerEpoch(runtimeBalancing, result._balanceMetrics); List threads = IntStream.range(0, workerNum) .mapToObj(i -> new FederatedPSControlThread(i, updFunc, freq, runtimeBalancing, weighing, getEpochs(), getBatchSize(), finalNumBatchesPerEpoch, federatedWorkerECs.get(i), ps)) @@ -190,8 +177,8 @@ private void runFederated(ExecutionContext ec) { // Set features and lables for the control threads and write the program and instructions and hyperparams to the federated workers for (int i = 0; i < threads.size(); i++) { - threads.get(i).setFeatures(pFeatures.get(i)); - threads.get(i).setLabels(pLabels.get(i)); + threads.get(i).setFeatures(result._pFeatures.get(i)); + threads.get(i).setLabels(result._pLabels.get(i)); threads.get(i).setup(result._weighingFactors.get(i)); } @@ -521,6 +508,26 @@ private FederatedPSScheme getFederatedScheme() { return federated_scheme; } + /** + * Calculates the number of batches per epoch depending on the balance metrics and the runtime balancing + * + * @param runtimeBalancing the runtime balancing + * @param balanceMetrics the balance metrics calculated during data partitioning + * @return numBatchesPerEpoch + */ + private int getNumBatchesPerEpoch(PSRuntimeBalancing runtimeBalancing, DataPartitionFederatedScheme.BalanceMetrics balanceMetrics) { + int numBatchesPerEpoch = 0; + if(runtimeBalancing == PSRuntimeBalancing.RUN_MIN) { + numBatchesPerEpoch = (int) Math.ceil(balanceMetrics._minRows / (float) getBatchSize()); + } else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_AVG + || 
runtimeBalancing == PSRuntimeBalancing.SCALE_BATCH) { + numBatchesPerEpoch = (int) Math.ceil(balanceMetrics._avgRows / (float) getBatchSize()); + } else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_MAX) { + numBatchesPerEpoch = (int) Math.ceil(balanceMetrics._maxRows / (float) getBatchSize()); + } + return numBatchesPerEpoch; + } + private boolean getWeighing() { return getParameterMap().containsKey(PS_FED_WEIGHING) && Boolean.parseBoolean(getParam(PS_FED_WEIGHING)); }