[SYSTEMDS-2550] Extended parameter server (validation function, stats)
Closes #1154.
Tobias Rieger authored and mboehm7 committed Jan 30, 2021
1 parent 11a2334 commit b6640d9
Showing 13 changed files with 372 additions and 135 deletions.
@@ -288,7 +288,7 @@ private void validateParamserv(DataIdentifier output, boolean conditional) {
//check for invalid parameters
Set<String> valid = CollectionUtils.asSet(Statement.PS_MODEL, Statement.PS_FEATURES, Statement.PS_LABELS,
Statement.PS_VAL_FEATURES, Statement.PS_VAL_LABELS, Statement.PS_UPDATE_FUN, Statement.PS_AGGREGATION_FUN,
Statement.PS_MODE, Statement.PS_UPDATE_TYPE, Statement.PS_FREQUENCY, Statement.PS_EPOCHS,
Statement.PS_VAL_FUN, Statement.PS_MODE, Statement.PS_UPDATE_TYPE, Statement.PS_FREQUENCY, Statement.PS_EPOCHS,
Statement.PS_BATCH_SIZE, Statement.PS_PARALLELISM, Statement.PS_SCHEME, Statement.PS_FED_RUNTIME_BALANCING,
Statement.PS_FED_WEIGHING, Statement.PS_HYPER_PARAMS, Statement.PS_CHECKPOINTING, Statement.PS_SEED);
checkInvalidParameters(getOpCode(), getVarParams(), valid);
@@ -301,6 +301,7 @@ private void validateParamserv(DataIdentifier output, boolean conditional) {
checkDataValueType(true, fname, Statement.PS_VAL_LABELS, DataType.MATRIX, ValueType.FP64, conditional);
checkDataValueType(false, fname, Statement.PS_UPDATE_FUN, DataType.SCALAR, ValueType.STRING, conditional);
checkDataValueType(false, fname, Statement.PS_AGGREGATION_FUN, DataType.SCALAR, ValueType.STRING, conditional);
checkDataValueType(true, fname, Statement.PS_VAL_FUN, DataType.SCALAR, ValueType.STRING, conditional);
checkStringParam(true, fname, Statement.PS_MODE, conditional);
checkStringParam(true, fname, Statement.PS_UPDATE_TYPE, conditional);
checkStringParam(true, fname, Statement.PS_FREQUENCY, conditional);
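
The hunk above only whitelists the new optional `val` entry and type-checks it as a scalar string naming the validation function. As a rough illustration of that whitelist-and-type-check pattern, here is a standalone sketch with made-up parameter names; it is not SystemDS's actual validator:

import java.util.Map;
import java.util.Set;

public class ParamWhitelistSketch {
    // Hypothetical stand-in for the paramserv parameter check:
    // unknown keys are rejected, optional keys may be absent.
    static final Set<String> VALID = Set.of("model", "features", "labels",
        "val_features", "val_labels", "upd", "agg", "val", "mode", "epochs");

    static void validate(Map<String, Object> params) {
        for (String key : params.keySet())
            if (!VALID.contains(key))
                throw new IllegalArgumentException("Invalid paramserv parameter: " + key);
        // 'val' is optional, but if present it must be a string naming the validation function
        Object valFun = params.get("val");
        if (valFun != null && !(valFun instanceof String))
            throw new IllegalArgumentException("'val' must name a validation function");
    }

    public static void main(String[] args) {
        validate(Map.of("model", "list-of-matrices", "upd", "gradients",
            "agg", "aggregation", "val", "validate"));
        System.out.println("parameters accepted");
    }
}
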
2 changes: 1 addition & 1 deletion src/main/java/org/apache/sysds/parser/Statement.java
@@ -66,6 +66,7 @@ public abstract class Statement implements ParseInfo
public static final String PS_LABELS = "labels";
public static final String PS_VAL_FEATURES = "val_features";
public static final String PS_VAL_LABELS = "val_labels";
public static final String PS_VAL_FUN = "val";
public static final String PS_UPDATE_FUN = "upd";
public static final String PS_AGGREGATION_FUN = "agg";
public static final String PS_MODE = "mode";
@@ -117,7 +118,6 @@ public enum PSCheckpointing {
public static final String PS_FED_NAMESPACE = "1701-NCC-namespace";
public static final String PS_FED_GRADIENTS_FNAME = "1701-NCC-gradients_fname";
public static final String PS_FED_AGGREGATION_FNAME = "1701-NCC-aggregation_fname";
public static final String PS_FED_BATCHCOUNTER_VARID = "1701-NCC-batchcounter_varid";
public static final String PS_FED_MODEL_VARID = "1701-NCC-model_varid";


@@ -23,6 +23,7 @@
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.parser.Statement.PSFrequency;
@@ -45,6 +46,7 @@
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
import org.apache.sysds.runtime.instructions.cp.IntObject;
import org.apache.sysds.runtime.instructions.cp.ListObject;
@@ -53,9 +55,10 @@
import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.util.ProgramConverter;
import org.apache.sysds.utils.Statistics;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.concurrent.Callable;
import java.util.concurrent.Future;
@@ -69,16 +72,15 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>

private FederatedData _featuresData;
private FederatedData _labelsData;
private final long _localStartBatchNumVarID;
private final long _modelVarID;

// runtime balancing
private PSRuntimeBalancing _runtimeBalancing;
private final PSRuntimeBalancing _runtimeBalancing;
private int _numBatchesPerEpoch;
private int _possibleBatchesPerLocalEpoch;
private boolean _weighing;
private final boolean _weighing;
private double _weighingFactor = 1;
private boolean _cycleStartAt0 = false;
private final boolean _cycleStartAt0 = false;

public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFrequency freq,
PSRuntimeBalancing runtimeBalancing, boolean weighing, int epochs, long batchSize,
@@ -89,8 +91,7 @@ public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFreque
_numBatchesPerEpoch = numBatchesPerGlobalEpoch;
_runtimeBalancing = runtimeBalancing;
_weighing = weighing;
// generate the IDs for model and batch counter. These get overwritten on the federated worker each time
_localStartBatchNumVarID = FederationUtils.getNextFedDataID();
// generate the ID for the model
_modelVarID = FederationUtils.getNextFedDataID();
}

@@ -100,6 +101,8 @@ public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFreque
* @param weighingFactor Gradients from this worker will be multiplied by this factor if weighing is enabled
*/
public void setup(double weighingFactor) {
incWorkerNumber();

// prepare features and labels
_featuresData = (FederatedData) _features.getFedMapping().getMap().values().toArray()[0];
_labelsData = (FederatedData) _labels.getFedMapping().getMap().values().toArray()[0];
@@ -125,22 +128,24 @@ public void setup(double weighingFactor) {
_numBatchesPerEpoch = _possibleBatchesPerLocalEpoch;
}

LOG.info("Setup config for worker " + this.getWorkerName());
LOG.info("Batch size: " + _batchSize + " possible batches: " + _possibleBatchesPerLocalEpoch
+ " batches to run: " + _numBatchesPerEpoch + " weighing factor: " + _weighingFactor);
if( LOG.isInfoEnabled() ) {
LOG.info("Setup config for worker " + this.getWorkerName());
LOG.info("Batch size: " + _batchSize + " possible batches: " + _possibleBatchesPerLocalEpoch
+ " batches to run: " + _numBatchesPerEpoch + " weighing factor: " + _weighingFactor);
}

// serialize program
// create program blocks for the instruction filtering
String programSerialized;
ArrayList<ProgramBlock> pbs = new ArrayList<>();

BasicProgramBlock gradientProgramBlock = new BasicProgramBlock(_ec.getProgram());
gradientProgramBlock.setInstructions(new ArrayList<>(Arrays.asList(_inst)));
gradientProgramBlock.setInstructions(new ArrayList<>(Collections.singletonList(_inst)));
pbs.add(gradientProgramBlock);

if(_freq == PSFrequency.EPOCH) {
BasicProgramBlock aggProgramBlock = new BasicProgramBlock(_ec.getProgram());
aggProgramBlock.setInstructions(new ArrayList<>(Arrays.asList(_ps.getAggInst())));
aggProgramBlock.setInstructions(new ArrayList<>(Collections.singletonList(_ps.getAggInst())));
pbs.add(aggProgramBlock);
}

@@ -160,7 +165,6 @@ public void setup(double weighingFactor) {
_inst.getFunctionName(),
_ps.getAggInst().getFunctionName(),
_ec.getListObject("hyperparams"),
_localStartBatchNumVarID,
_modelVarID
)
));
@@ -188,12 +192,11 @@ private static class SetupFederatedWorker extends FederatedUDF {
private final String _gradientsFunctionName;
private final String _aggregationFunctionName;
private final ListObject _hyperParams;
private final long _batchCounterVarID;
private final long _modelVarID;

protected SetupFederatedWorker(long batchSize, long dataSize, int possibleBatchesPerLocalEpoch,
String programString, String namespace, String gradientsFunctionName, String aggregationFunctionName,
ListObject hyperParams, long batchCounterVarID, long modelVarID)
ListObject hyperParams, long modelVarID)
{
super(new long[]{});
_batchSize = batchSize;
@@ -204,7 +207,6 @@ protected SetupFederatedWorker(long batchSize, long dataSize, int possibleBatche
_gradientsFunctionName = gradientsFunctionName;
_aggregationFunctionName = aggregationFunctionName;
_hyperParams = hyperParams;
_batchCounterVarID = batchCounterVarID;
_modelVarID = modelVarID;
}

@@ -221,7 +223,6 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) {
ec.setVariable(Statement.PS_FED_GRADIENTS_FNAME, new StringObject(_gradientsFunctionName));
ec.setVariable(Statement.PS_FED_AGGREGATION_FNAME, new StringObject(_aggregationFunctionName));
ec.setVariable(Statement.PS_HYPER_PARAMS, _hyperParams);
ec.setVariable(Statement.PS_FED_BATCHCOUNTER_VARID, new IntObject(_batchCounterVarID));
ec.setVariable(Statement.PS_FED_MODEL_VARID, new IntObject(_modelVarID));

return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
@@ -272,7 +273,6 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) {
ec.removeVariable(Statement.PS_FED_NAMESPACE);
ec.removeVariable(Statement.PS_FED_GRADIENTS_FNAME);
ec.removeVariable(Statement.PS_FED_AGGREGATION_FNAME);
ec.removeVariable(Statement.PS_FED_BATCHCOUNTER_VARID);
ec.removeVariable(Statement.PS_FED_MODEL_VARID);
ParamservUtils.cleanupListObject(ec, Statement.PS_HYPER_PARAMS);

@@ -319,16 +319,18 @@ protected ListObject pullModel() {
return _ps.pull(_workerID);
}

protected void scaleAndPushGradients(ListObject gradients) {
protected void weighAndPushGradients(ListObject gradients) {
// scale gradients - must only include MatrixObjects
if(_weighing && _weighingFactor != 1) {
Timing tWeighing = DMLScript.STATISTICS ? new Timing(true) : null;
gradients.getData().parallelStream().forEach((matrix) -> {
MatrixObject matrixObject = (MatrixObject) matrix;
MatrixBlock input = matrixObject.acquireReadAndRelease().scalarOperations(
new RightScalarOperator(Multiply.getMultiplyFnObject(), _weighingFactor), new MatrixBlock());
matrixObject.acquireModify(input);
matrixObject.release();
});
accFedPSGradientWeighingTime(tWeighing);
}

// Push the gradients to ps
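
In the hunk above, weighAndPushGradients multiplies every gradient matrix in the list by the worker's weighing factor (for example, its share of the overall data) before pushing, and only starts a timer when statistics are enabled. A toy sketch of the weighting step on plain double arrays, illustrative only and not the MatrixBlock-based code above:

import java.util.Arrays;
import java.util.List;

public class GradientWeighingSketch {
    // Multiply every gradient "matrix" (here: a flat double[]) by the worker's weighing factor.
    static void weighGradients(List<double[]> gradients, double weighingFactor) {
        if (weighingFactor == 1)
            return;                                   // nothing to do, mirrors the guard above
        gradients.parallelStream().forEach(g -> {
            for (int i = 0; i < g.length; i++)
                g[i] *= weighingFactor;
        });
    }

    public static void main(String[] args) {
        List<double[]> grads = List.of(new double[]{1, 2, 3}, new double[]{4, 5});
        weighGradients(grads, 0.25);                  // e.g., this worker holds 25% of the rows
        grads.forEach(g -> System.out.println(Arrays.toString(g)));
    }
}
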
@@ -350,12 +352,10 @@ protected void computeWithBatchUpdates() {
int localStartBatchNum = getNextLocalBatchNum(currentLocalBatchNumber++, _possibleBatchesPerLocalEpoch);
ListObject model = pullModel();
ListObject gradients = computeGradientsForNBatches(model, 1, localStartBatchNum);
scaleAndPushGradients(gradients);
weighAndPushGradients(gradients);
ParamservUtils.cleanupListObject(model);
ParamservUtils.cleanupListObject(gradients);
LOG.info("[+] " + this.getWorkerName() + " completed BATCH " + localStartBatchNum);
}
LOG.info("[+] " + this.getWorkerName() + " --- completed EPOCH " + epochCounter);
}
}

@@ -376,9 +376,7 @@ protected void computeWithEpochUpdates() {
// Pull the global parameters from ps
ListObject model = pullModel();
ListObject gradients = computeGradientsForNBatches(model, _numBatchesPerEpoch, localStartBatchNum, true);
scaleAndPushGradients(gradients);

LOG.info("[+] " + this.getWorkerName() + " --- completed EPOCH " + epochCounter);
weighAndPushGradients(gradients);
ParamservUtils.cleanupListObject(model);
ParamservUtils.cleanupListObject(gradients);
}
@@ -401,15 +399,13 @@ protected ListObject computeGradientsForNBatches(ListObject model, int numBatche
protected ListObject computeGradientsForNBatches(ListObject model,
int numBatchesToCompute, int localStartBatchNum, boolean localUpdate)
{
// put local start batch num on federated worker
Future<FederatedResponse> putBatchCounterResponse = _featuresData.executeFederatedOperation(
new FederatedRequest(RequestType.PUT_VAR, _localStartBatchNumVarID, new IntObject(localStartBatchNum)));
Timing tFedCommunication = DMLScript.STATISTICS ? new Timing(true) : null;
// put current model on federated worker
Future<FederatedResponse> putParamsResponse = _featuresData.executeFederatedOperation(
new FederatedRequest(RequestType.PUT_VAR, _modelVarID, model));

try {
if(!putParamsResponse.get().isSuccessful() || !putBatchCounterResponse.get().isSuccessful())
if(!putParamsResponse.get().isSuccessful())
throw new DMLRuntimeException("FederatedLocalPSThread: put was not successful");
}
catch(Exception e) {
@@ -420,14 +416,22 @@ protected ListObject computeGradientsForNBatches(ListObject model,
Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation(
new FederatedRequest(RequestType.EXEC_UDF, _featuresData.getVarID(),
new federatedComputeGradientsForNBatches(new long[]{_featuresData.getVarID(), _labelsData.getVarID(),
_localStartBatchNumVarID, _modelVarID}, numBatchesToCompute,localUpdate)
_modelVarID}, numBatchesToCompute, localUpdate, localStartBatchNum)
));

try {
Object[] responseData = udfResponse.get().getData();
if(DMLScript.STATISTICS) {
long total = (long) tFedCommunication.stop();
long workerComputing = ((DoubleObject) responseData[1]).getLongValue();
Statistics.accFedPSWorkerComputing(workerComputing);
Statistics.accFedPSCommunicationTime(total - workerComputing);
}
return (ListObject) responseData[0];
}
catch(Exception e) {
if(DMLScript.STATISTICS)
tFedCommunication.stop();
throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute UDF" + e.getMessage());
}
}
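
The statistics logic in this hunk splits the federated round trip: the control thread times the whole request, the worker reports its own compute time as the second element of the response, and the remainder is booked as communication time. A simplified standalone sketch of that accounting idea, with illustrative names rather than the SystemDS Timing/Statistics APIs:

public class FedTimingSketch {
    // Simulated federated call: returns {result, workerComputeMillis}
    static Object[] executeRemotely() throws InterruptedException {
        long start = System.nanoTime();
        Thread.sleep(30);                       // pretend gradient computation on the worker
        double computeMs = (System.nanoTime() - start) / 1e6;
        return new Object[]{"gradients", computeMs};
    }

    public static void main(String[] args) throws Exception {
        long t0 = System.nanoTime();
        Thread.sleep(10);                       // pretend request serialization and transfer
        Object[] response = executeRemotely();
        Thread.sleep(10);                       // pretend response transfer
        double totalMs = (System.nanoTime() - t0) / 1e6;
        double workerMs = (double) response[1];
        double commMs = totalMs - workerMs;     // communication = round trip minus worker compute
        System.out.printf("total=%.1fms worker=%.1fms communication=%.1fms%n",
            totalMs, workerMs, commMs);
    }
}
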
@@ -439,20 +443,22 @@ private static class federatedComputeGradientsForNBatches extends FederatedUDF {
private static final long serialVersionUID = -3075901536748794832L;
int _numBatchesToCompute;
boolean _localUpdate;
int _localStartBatchNum;

protected federatedComputeGradientsForNBatches(long[] inIDs, int numBatchesToCompute, boolean localUpdate) {
protected federatedComputeGradientsForNBatches(long[] inIDs, int numBatchesToCompute, boolean localUpdate, int localStartBatchNum) {
super(inIDs);
_numBatchesToCompute = numBatchesToCompute;
_localUpdate = localUpdate;
_localStartBatchNum = localStartBatchNum;
}

@Override
public FederatedResponse execute(ExecutionContext ec, Data... data) {
Timing tGradients = new Timing(true);
// read in data by varid
MatrixObject features = (MatrixObject) data[0];
MatrixObject labels = (MatrixObject) data[1];
int localStartBatchNum = (int) ((IntObject) data[2]).getLongValue();
ListObject model = (ListObject) data[3];
ListObject model = (ListObject) data[2];

// get data from execution context
long batchSize = ((IntObject) ec.getVariable(Statement.PS_FED_BATCH_SIZE)).getLongValue();
@@ -493,7 +499,7 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) {
}

ListObject accGradients = null;
int currentLocalBatchNumber = localStartBatchNum;
int currentLocalBatchNumber = _localStartBatchNum;
// prepare execution context
ec.setVariable(Statement.PS_MODEL, model);
for (int batchCounter = 0; batchCounter < _numBatchesToCompute; batchCounter++) {
@@ -534,14 +540,14 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) {
ParamservUtils.cleanupListObject(ec, gradientsOutput.getName());
ParamservUtils.cleanupData(ec, Statement.PS_FEATURES);
ParamservUtils.cleanupData(ec, Statement.PS_LABELS);
ec.removeVariable(ec.getVariable(Statement.PS_FED_BATCHCOUNTER_VARID).toString());
}

// model clean up
ParamservUtils.cleanupListObject(ec, ec.getVariable(Statement.PS_FED_MODEL_VARID).toString());
ParamservUtils.cleanupListObject(ec, Statement.PS_MODEL);

return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, accGradients);
// stop timing
DoubleObject gradientsTime = new DoubleObject(tGradients.stop());
return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new Object[]{accGradients, gradientsTime});
}

@Override
@@ -551,28 +557,34 @@ public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
}

// Statistics methods
protected void accFedPSGradientWeighingTime(Timing time) {
if (DMLScript.STATISTICS && time != null)
Statistics.accFedPSGradientWeighingTime((long) time.stop());
}

@Override
public String getWorkerName() {
return String.format("Federated worker_%d", _workerID);
}

@Override
protected void incWorkerNumber() {

if (DMLScript.STATISTICS)
Statistics.incWorkerNumber();
}

@Override
protected void accLocalModelUpdateTime(Timing time) {

throw new NotImplementedException();
}

@Override
protected void accBatchIndexingTime(Timing time) {

throw new NotImplementedException();
}

@Override
protected void accGradientComputeTime(Timing time) {

throw new NotImplementedException();
}
}
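
The Statistics calls used above (incWorkerNumber, accFedPSGradientWeighingTime, and the accumulators in the earlier hunk) presumably add into process-wide counters that are reported with the usual -stats output. A minimal sketch of such a thread-safe accumulator under that assumption; the names are illustrative and this is not the actual org.apache.sysds.utils.Statistics class:

import java.util.concurrent.atomic.LongAdder;

public class PSStatsSketch {
    private static final LongAdder FED_PS_COMM_TIME = new LongAdder();      // milliseconds
    private static final LongAdder FED_PS_WORKER_COMPUTE = new LongAdder(); // milliseconds
    private static final LongAdder PS_WORKER_COUNT = new LongAdder();

    public static void incWorkerNumber()                 { PS_WORKER_COUNT.increment(); }
    public static void accFedPSCommunicationTime(long t) { FED_PS_COMM_TIME.add(t); }
    public static void accFedPSWorkerComputing(long t)   { FED_PS_WORKER_COMPUTE.add(t); }

    public static String display() {
        return "ParamServ workers:          " + PS_WORKER_COUNT.sum() + "\n"
             + "Fed PS communication (ms):  " + FED_PS_COMM_TIME.sum() + "\n"
             + "Fed PS worker compute (ms): " + FED_PS_WORKER_COMPUTE.sum();
    }

    public static void main(String[] args) {
        incWorkerNumber();
        accFedPSWorkerComputing(120);
        accFedPSCommunicationTime(35);
        System.out.println(display());
    }
}
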
@@ -21,6 +21,7 @@

import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.cp.ListObject;

@@ -30,12 +31,19 @@ public LocalParamServer() {
super();
}

public static LocalParamServer create(ListObject model, String aggFunc, Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) {
return new LocalParamServer(model, aggFunc, updateType, ec, workerNum);
public static LocalParamServer create(ListObject model, String aggFunc, Statement.PSUpdateType updateType,
Statement.PSFrequency freq, ExecutionContext ec, int workerNum, String valFunc, int numBatchesPerEpoch,
MatrixObject valFeatures, MatrixObject valLabels)
{
return new LocalParamServer(model, aggFunc, updateType, freq, ec,
workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels);
}

private LocalParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) {
super(model, aggFunc, updateType, ec, workerNum);
private LocalParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType,
Statement.PSFrequency freq, ExecutionContext ec, int workerNum, String valFunc, int numBatchesPerEpoch,
MatrixObject valFeatures, MatrixObject valLabels)
{
super(model, aggFunc, updateType, freq, ec, workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels);
}

@Override
@@ -35,7 +35,6 @@
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;

// TODO use the validate features and labels to calculate the model precision when training
public abstract class PSWorker implements Serializable
{
private static final long serialVersionUID = -3510485051178200118L;
