[SYSTEMDS-2550] Added changes proposed by Sebastian W.
Tobias Rieger committed Jan 5, 2021
1 parent 4553481 commit 87b1824
Showing 7 changed files with 74 additions and 20 deletions.
@@ -523,8 +523,6 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) {
 			ParamservUtils.cleanupData(ec, Statement.PS_FEATURES);
 			ParamservUtils.cleanupData(ec, Statement.PS_LABELS);
 			ec.removeVariable(ec.getVariable(Statement.PS_FED_BATCHCOUNTER_VARID).toString());
-			/*if( LOG.isInfoEnabled() )
-				LOG.info("[+]" + " completed batch " + localBatchNum);*/
 		}
 
 		// model clean up
@@ -34,6 +34,16 @@
 import java.util.List;
 import java.util.concurrent.Future;
 
+/**
+ * Balance to Avg Federated scheme
+ *
+ * When the parameter server runs in federated mode it cannot pull in the data that already resides on the workers.
+ * Therefore, a UDF is sent to manipulate the data locally. In this case the global average number of examples is taken
+ * and each worker subsamples or replicates its data to match that number. See the other federated schemes.
+ *
+ * Then all entries in the federation map of the input matrix are separated into MatrixObjects and returned as a list.
+ * Currently only supports row-federated matrices.
+ */
 public class BalanceToAvgFederatedScheme extends DataPartitionFederatedScheme {
 	@Override
 	public Result partition(MatrixObject features, MatrixObject labels, int seed) {
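Below is a minimal stand-alone sketch of the balancing idea described in the new Javadoc. It uses plain double[][] arrays instead of SystemDS MatrixObjects, and balanceToAvg/permutation are hypothetical helper names, not part of this commit: each worker resizes its local shard to the global average row count, subsampling or replicating under a shared seed.

import java.util.Random;

public class BalanceToAvgSketch {
	// Resize one worker's rows to the global average (hypothetical helper).
	static double[][] balanceToAvg(double[][] data, int globalAvgRows, long seed) {
		Random rnd = new Random(seed); // same global seed on every worker
		double[][] out = new double[globalAvgRows][];
		if (data.length >= globalAvgRows) {
			// subsample: keep the first globalAvgRows rows of a seeded permutation
			int[] perm = permutation(data.length, rnd);
			for (int i = 0; i < globalAvgRows; i++)
				out[i] = data[perm[i]];
		}
		else {
			// replicate: keep all rows, then append seeded random picks
			System.arraycopy(data, 0, out, 0, data.length);
			for (int i = data.length; i < globalAvgRows; i++)
				out[i] = data[rnd.nextInt(data.length)];
		}
		return out;
	}

	// Fisher-Yates shuffle of 0..n-1
	static int[] permutation(int n, Random rnd) {
		int[] p = new int[n];
		for (int i = 0; i < n; i++)
			p[i] = i;
		for (int i = n - 1; i > 0; i--) {
			int j = rnd.nextInt(i + 1);
			int t = p[i]; p[i] = p[j]; p[j] = t;
		}
		return p;
	}
}

Calling the helper twice with the same seed, once for the features and once for the labels, selects the same rows in both, which is why a shared seed matters.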
@@ -22,6 +22,13 @@
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import java.util.List;
 
+/**
+ * Keep Data on Worker Federated scheme
+ *
+ * When the parameter server runs in federated mode it cannot pull in the data that already resides on the workers.
+ * All entries in the federation map of the input matrix are separated into MatrixObjects and returned as a list.
+ * Currently only supports row-federated matrices.
+ */
 public class KeepDataOnWorkerFederatedScheme extends DataPartitionFederatedScheme {
 	@Override
 	public Result partition(MatrixObject features, MatrixObject labels, int seed) {
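For contrast, the identity nature of this scheme in the same sketch style (Shard is a hypothetical stand-in for one worker's feature/label pair, not a SystemDS type): no resampling UDF is needed, the shards are simply passed through.

import java.util.List;

public class KeepDataOnWorkerSketch {
	// Hypothetical stand-in for one worker's (features, labels) pair.
	record Shard(double[][] features, double[][] labels) {}

	// "Keep data on worker" moves nothing: each shard is returned as-is.
	static List<Shard> partition(List<Shard> workerShards) {
		return workerShards;
	}
}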
@@ -34,6 +34,17 @@
 import java.util.List;
 import java.util.concurrent.Future;
 
+/**
+ * Replicate to Max Federated scheme
+ *
+ * When the parameter server runs in federated mode it cannot pull in the data that already resides on the workers.
+ * Therefore, a UDF is sent to manipulate the data locally. In this case the global maximum number of examples is taken
+ * and each worker replicates its data to match that number. The replicas are generated by multiplying with a
+ * permutation matrix created from a global seed; the selected examples are appended to the original data.
+ *
+ * Then all entries in the federation map of the input matrix are separated into MatrixObjects and returned as a list.
+ * Currently only supports row-federated matrices.
+ */
 public class ReplicateToMaxFederatedScheme extends DataPartitionFederatedScheme {
 	@Override
 	public Result partition(MatrixObject features, MatrixObject labels, int seed) {
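A sketch of the replication step under the same assumptions (plain arrays; replicateToMax is a hypothetical name). Strictly, the matrix built here is a selection matrix with one 1 per row rather than a square permutation matrix, but it shows the multiply-and-append mechanics the Javadoc describes:

import java.util.Random;

public class ReplicateToMaxSketch {
	// Grow one worker's shard to the global maximum row count (hypothetical helper).
	static double[][] replicateToMax(double[][] data, int globalMaxRows, long seed) {
		int n = data.length, cols = data[0].length, extra = globalMaxRows - n;
		Random rnd = new Random(seed); // same global seed on every worker
		// P has a single 1 per row at a seeded random column, so P %*% data picks rows
		double[][] P = new double[extra][n];
		for (int i = 0; i < extra; i++)
			P[i][rnd.nextInt(n)] = 1.0;
		// append the selected replicas to the original data
		double[][] out = new double[globalMaxRows][];
		System.arraycopy(data, 0, out, 0, n);
		for (int i = 0; i < extra; i++) {
			out[n + i] = new double[cols];
			for (int j = 0; j < n; j++)
				for (int k = 0; k < cols; k++)
					out[n + i][k] += P[i][j] * data[j][k];
		}
		return out;
	}
}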
@@ -33,6 +33,16 @@
 import java.util.List;
 import java.util.concurrent.Future;
 
+/**
+ * Shuffle Federated scheme
+ *
+ * When the parameter server runs in federated mode it cannot pull in the data that already resides on the workers.
+ * Therefore, a UDF is sent to manipulate the data locally. In this case the data is shuffled by generating a
+ * permutation matrix from a global seed and performing a matrix multiplication.
+ *
+ * Then all entries in the federation map of the input matrix are separated into MatrixObjects and returned as a list.
+ * Currently only supports row-federated matrices.
+ */
 public class ShuffleFederatedScheme extends DataPartitionFederatedScheme {
 	@Override
 	public Result partition(MatrixObject features, MatrixObject labels, int seed) {
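A stand-alone sketch of the shuffle under the same assumptions: row i of a permutation matrix P carries a 1 at column perm[i], so P times the data is exactly a row reordering. The sketch applies that reordering directly instead of materializing P and doing the full matrix multiplication:

import java.util.Random;

public class ShuffleSketch {
	// Shuffle rows via a seeded permutation (hypothetical helper).
	static double[][] shuffle(double[][] data, long globalSeed) {
		int n = data.length;
		int[] perm = new int[n];
		for (int i = 0; i < n; i++)
			perm[i] = i;
		Random rnd = new Random(globalSeed);
		for (int i = n - 1; i > 0; i--) { // Fisher-Yates
			int j = rnd.nextInt(i + 1);
			int t = perm[i]; perm[i] = perm[j]; perm[j] = t;
		}
		// equivalent to P %*% data where P[i][perm[i]] == 1
		double[][] out = new double[n][];
		for (int i = 0; i < n; i++)
			out[i] = data[perm[i]];
		return out;
	}
}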
@@ -34,6 +34,17 @@
 import java.util.List;
 import java.util.concurrent.Future;
 
+/**
+ * Subsample to Min Federated scheme
+ *
+ * When the parameter server runs in federated mode it cannot pull in the data that already resides on the workers.
+ * Therefore, a UDF is sent to manipulate the data locally. In this case the global minimum number of examples is taken
+ * and each worker subsamples its data to match that number. The subsampling is done by multiplying with a
+ * permutation matrix created from a global seed.
+ *
+ * Then all entries in the federation map of the input matrix are separated into MatrixObjects and returned as a list.
+ * Currently only supports row-federated matrices.
+ */
 public class SubsampleToMinFederatedScheme extends DataPartitionFederatedScheme {
 	@Override
 	public Result partition(MatrixObject features, MatrixObject labels, int seed) {
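And the corresponding subsampling sketch (subsampleToMin is again a hypothetical name): a seeded permutation decides which rows survive, and truncating it to the global minimum row count corresponds to multiplying with the first minRows rows of the permutation matrix:

import java.util.Arrays;
import java.util.Collections;
import java.util.Random;

public class SubsampleToMinSketch {
	// Shrink one worker's shard to the global minimum row count (hypothetical helper).
	static double[][] subsampleToMin(double[][] data, int globalMinRows, long seed) {
		Integer[] idx = new Integer[data.length];
		for (int i = 0; i < idx.length; i++)
			idx[i] = i;
		// seeded shuffle; the same seed on every worker keeps features and labels aligned
		Collections.shuffle(Arrays.asList(idx), new Random(seed));
		double[][] out = new double[globalMinRows][];
		for (int i = 0; i < globalMinRows; i++)
			out[i] = data[idx[i]];
		return out;
	}
}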
@@ -123,8 +123,8 @@ public void processInstruction(ExecutionContext ec) {
 	}
 
 	private void runFederated(ExecutionContext ec) {
-		System.out.println("PARAMETER SERVER");
-		System.out.println("[+] Running in federated mode");
+		LOG.info("PARAMETER SERVER");
+		LOG.info("[+] Running in federated mode");
 
 		// get inputs
 		String updFunc = getParam(PS_UPDATE_FUN);
@@ -146,21 +146,8 @@ private void runFederated(ExecutionContext ec) {
 		// partition federated data
 		DataPartitionFederatedScheme.Result result = new FederatedDataPartitioner(federatedPSScheme, seed)
 			.doPartitioning(ec.getMatrixObject(getParam(PS_FEATURES)), ec.getMatrixObject(getParam(PS_LABELS)));
-		List<MatrixObject> pFeatures = result._pFeatures;
-		List<MatrixObject> pLabels = result._pLabels;
 		int workerNum = result._workerNum;
 
-		// calculate runtime balancing
-		int numBatchesPerEpoch = 0;
-		if(runtimeBalancing == PSRuntimeBalancing.RUN_MIN) {
-			numBatchesPerEpoch = (int) Math.ceil(result._balanceMetrics._minRows / (float) getBatchSize());
-		} else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_AVG
-			|| runtimeBalancing == PSRuntimeBalancing.SCALE_BATCH) {
-			numBatchesPerEpoch = (int) Math.ceil(result._balanceMetrics._avgRows / (float) getBatchSize());
-		} else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_MAX) {
-			numBatchesPerEpoch = (int) Math.ceil(result._balanceMetrics._maxRows / (float) getBatchSize());
-		}
-
 		// setup threading
 		BasicThreadFactory factory = new BasicThreadFactory.Builder()
 			.namingPattern("workers-pool-thread-%d").build();
@@ -178,7 +165,7 @@ private void runFederated(ExecutionContext ec) {
 		ListObject model = ec.getListObject(getParam(PS_MODEL));
 		ParamServer ps = createPS(PSModeType.FEDERATED, aggFunc, updateType, workerNum, model, aggServiceEC);
 		// Create the local workers
-		int finalNumBatchesPerEpoch = numBatchesPerEpoch;
+		int finalNumBatchesPerEpoch = getNumBatchesPerEpoch(runtimeBalancing, result._balanceMetrics);
 		List<FederatedPSControlThread> threads = IntStream.range(0, workerNum)
 			.mapToObj(i -> new FederatedPSControlThread(i, updFunc, freq, runtimeBalancing, weighing,
 				getEpochs(), getBatchSize(), finalNumBatchesPerEpoch, federatedWorkerECs.get(i), ps))
@@ -190,8 +177,8 @@
 
 		// Set features and labels for the control threads and write the program, instructions, and hyperparameters to the federated workers
 		for (int i = 0; i < threads.size(); i++) {
-			threads.get(i).setFeatures(pFeatures.get(i));
-			threads.get(i).setLabels(pLabels.get(i));
+			threads.get(i).setFeatures(result._pFeatures.get(i));
+			threads.get(i).setLabels(result._pLabels.get(i));
 			threads.get(i).setup(result._weighingFactors.get(i));
 		}
 
@@ -521,6 +508,26 @@ private FederatedPSScheme getFederatedScheme() {
 		return federated_scheme;
 	}
 
+	/**
+	 * Calculates the number of batches per epoch depending on the runtime balancing strategy and the balance metrics.
+	 *
+	 * @param runtimeBalancing the runtime balancing strategy
+	 * @param balanceMetrics the balance metrics calculated during data partitioning
+	 * @return the number of batches per epoch
+	 */
+	private int getNumBatchesPerEpoch(PSRuntimeBalancing runtimeBalancing, DataPartitionFederatedScheme.BalanceMetrics balanceMetrics) {
+		int numBatchesPerEpoch = 0;
+		if(runtimeBalancing == PSRuntimeBalancing.RUN_MIN) {
+			numBatchesPerEpoch = (int) Math.ceil(balanceMetrics._minRows / (float) getBatchSize());
+		} else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_AVG
+			|| runtimeBalancing == PSRuntimeBalancing.SCALE_BATCH) {
+			numBatchesPerEpoch = (int) Math.ceil(balanceMetrics._avgRows / (float) getBatchSize());
+		} else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_MAX) {
+			numBatchesPerEpoch = (int) Math.ceil(balanceMetrics._maxRows / (float) getBatchSize());
+		}
+		return numBatchesPerEpoch;
+	}
+
 	private boolean getWeighing() {
 		return getParameterMap().containsKey(PS_FED_WEIGHING) && Boolean.parseBoolean(getParam(PS_FED_WEIGHING));
 	}
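To make the extracted helper concrete with hypothetical numbers: under RUN_MIN with balanceMetrics._minRows = 1000 and a batch size of 32, getNumBatchesPerEpoch returns (int) Math.ceil(1000 / (float) 32) = 32 batches per epoch; under CYCLE_MAX with _maxRows = 4096 it returns 128. Any balancing mode not matched by the three branches falls through to the initial value of 0.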
