Skip to content

Commit

Permalink
[SYSTEMDS-2550] Federated parameter server scaling and weight handling
Browse files Browse the repository at this point in the history
Closes #1141.
  • Loading branch information
Tobias Rieger authored and mboehm7 committed Jan 9, 2021
1 parent cbdfb92 commit 460c394
Show file tree
Hide file tree
Showing 16 changed files with 342 additions and 197 deletions.
Expand Up @@ -289,8 +289,8 @@ private void validateParamserv(DataIdentifier output, boolean conditional) {
Set<String> valid = CollectionUtils.asSet(Statement.PS_MODEL, Statement.PS_FEATURES, Statement.PS_LABELS,
Statement.PS_VAL_FEATURES, Statement.PS_VAL_LABELS, Statement.PS_UPDATE_FUN, Statement.PS_AGGREGATION_FUN,
Statement.PS_MODE, Statement.PS_UPDATE_TYPE, Statement.PS_FREQUENCY, Statement.PS_EPOCHS,
Statement.PS_BATCH_SIZE, Statement.PS_PARALLELISM, Statement.PS_SCHEME, Statement.PS_RUNTIME_BALANCING,
Statement.PS_HYPER_PARAMS, Statement.PS_CHECKPOINTING);
Statement.PS_BATCH_SIZE, Statement.PS_PARALLELISM, Statement.PS_SCHEME, Statement.PS_FED_RUNTIME_BALANCING,
Statement.PS_FED_WEIGHING, Statement.PS_HYPER_PARAMS, Statement.PS_CHECKPOINTING, Statement.PS_SEED);
checkInvalidParameters(getOpCode(), getVarParams(), valid);

// check existence and correctness of parameters
Expand All @@ -308,9 +308,11 @@ private void validateParamserv(DataIdentifier output, boolean conditional) {
checkDataValueType(true, fname, Statement.PS_BATCH_SIZE, DataType.SCALAR, ValueType.INT64, conditional);
checkDataValueType(true, fname, Statement.PS_PARALLELISM, DataType.SCALAR, ValueType.INT64, conditional);
checkStringParam(true, fname, Statement.PS_SCHEME, conditional);
checkStringParam(true, fname, Statement.PS_RUNTIME_BALANCING, conditional);
checkStringParam(true, fname, Statement.PS_FED_RUNTIME_BALANCING, conditional);
checkStringParam(true, fname, Statement.PS_FED_WEIGHING, conditional);
checkDataValueType(true, fname, Statement.PS_HYPER_PARAMS, DataType.LIST, ValueType.UNKNOWN, conditional);
checkStringParam(true, fname, Statement.PS_CHECKPOINTING, conditional);
checkDataValueType(true, fname, Statement.PS_SEED, DataType.SCALAR, ValueType.INT64, conditional);

// set output characteristics
output.setDataType(DataType.LIST);
Expand Down
6 changes: 4 additions & 2 deletions src/main/java/org/apache/sysds/parser/Statement.java
Expand Up @@ -70,6 +70,7 @@ public abstract class Statement implements ParseInfo
public static final String PS_AGGREGATION_FUN = "agg";
public static final String PS_MODE = "mode";
public static final String PS_GRADIENTS = "gradients";
public static final String PS_SEED = "seed";
public enum PSModeType {
FEDERATED, LOCAL, REMOTE_SPARK
}
Expand All @@ -87,9 +88,10 @@ public boolean isASP() {
public enum PSFrequency {
BATCH, EPOCH
}
public static final String PS_RUNTIME_BALANCING = "runtime_balancing";
public static final String PS_FED_WEIGHING = "weighing";
public static final String PS_FED_RUNTIME_BALANCING = "runtime_balancing";
public enum PSRuntimeBalancing {
NONE, RUN_MIN, CYCLE_AVG, CYCLE_MAX, SCALE_BATCH, SCALE_BATCH_AND_WEIGH
NONE, RUN_MIN, CYCLE_AVG, CYCLE_MAX, SCALE_BATCH
}
public static final String PS_EPOCHS = "epochs";
public static final String PS_BATCH_SIZE = "batchsize";
Expand Down
Expand Up @@ -228,7 +228,7 @@ protected final double sumValues(int valIx, double[] b, double[] dictVals, int o
return val;
}

protected final double sumValuesSparse(int valIx, SparseRow[] rows, double[] dictVals, int rowsIndex) {
protected static double sumValuesSparse(int valIx, SparseRow[] rows, double[] dictVals, int rowsIndex) {
throw new NotImplementedException("This Method was implemented incorrectly");
// final int numCols = getNumCols();
// final int valOff = valIx * numCols;
Expand Down
Expand Up @@ -24,6 +24,8 @@
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.parser.Statement.PSFrequency;
import org.apache.sysds.parser.Statement.PSRuntimeBalancing;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
Expand All @@ -37,13 +39,17 @@
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
import org.apache.sysds.runtime.instructions.cp.IntObject;
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.instructions.cp.StringObject;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
import org.apache.sysds.runtime.util.ProgramConverter;

import java.util.ArrayList;
Expand All @@ -58,87 +64,102 @@
public class FederatedPSControlThread extends PSWorker implements Callable<Void> {
private static final long serialVersionUID = 6846648059569648791L;
protected static final Log LOG = LogFactory.getLog(ParamServer.class.getName());

Statement.PSRuntimeBalancing _runtimeBalancing;
FederatedData _featuresData;
FederatedData _labelsData;
final long _localStartBatchNumVarID;
final long _modelVarID;
int _numBatchesPerGlobalEpoch;
int _possibleBatchesPerLocalEpoch;
boolean _cycleStartAt0 = false;

public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFrequency freq, Statement.PSRuntimeBalancing runtimeBalancing, int epochs, long batchSize, int numBatchesPerGlobalEpoch, ExecutionContext ec, ParamServer ps) {

private FederatedData _featuresData;
private FederatedData _labelsData;
private final long _localStartBatchNumVarID;
private final long _modelVarID;

// runtime balancing
private PSRuntimeBalancing _runtimeBalancing;
private int _numBatchesPerEpoch;
private int _possibleBatchesPerLocalEpoch;
private boolean _weighing;
private double _weighingFactor = 1;
private boolean _cycleStartAt0 = false;

public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFrequency freq,
PSRuntimeBalancing runtimeBalancing, boolean weighing, int epochs, long batchSize,
int numBatchesPerGlobalEpoch, ExecutionContext ec, ParamServer ps)
{
super(workerID, updFunc, freq, epochs, batchSize, ec, ps);

_numBatchesPerGlobalEpoch = numBatchesPerGlobalEpoch;
_numBatchesPerEpoch = numBatchesPerGlobalEpoch;
_runtimeBalancing = runtimeBalancing;
_weighing = weighing;
// generate the IDs for model and batch counter. These get overwritten on the federated worker each time
_localStartBatchNumVarID = FederationUtils.getNextFedDataID();
_modelVarID = FederationUtils.getNextFedDataID();
}

/**
* Sets up the federated worker and control thread
*
* @param weighingFactor Gradients from this worker will be multiplied by this factor if weighing is enabled
*/
public void setup() {
public void setup(double weighingFactor) {
// prepare features and labels
_featuresData = (FederatedData) _features.getFedMapping().getMap().values().toArray()[0];
_labelsData = (FederatedData) _labels.getFedMapping().getMap().values().toArray()[0];

// calculate number of batches and get data size
// weighing factor is always set, but only used when weighing is specified
_weighingFactor = weighingFactor;

// different runtime balancing calculations
long dataSize = _features.getNumRows();
_possibleBatchesPerLocalEpoch = (int) Math.ceil((double) dataSize / _batchSize);
if(!(_runtimeBalancing == Statement.PSRuntimeBalancing.RUN_MIN
|| _runtimeBalancing == Statement.PSRuntimeBalancing.CYCLE_AVG
|| _runtimeBalancing == Statement.PSRuntimeBalancing.CYCLE_MAX)) {
_numBatchesPerGlobalEpoch = _possibleBatchesPerLocalEpoch;

// calculate scaled batch size if balancing via batch size.
// In some cases there will be some cycling
if(_runtimeBalancing == PSRuntimeBalancing.SCALE_BATCH) {
_batchSize = (int) Math.ceil((double) dataSize / _numBatchesPerEpoch);
}

if(_runtimeBalancing == Statement.PSRuntimeBalancing.SCALE_BATCH
|| _runtimeBalancing == Statement.PSRuntimeBalancing.SCALE_BATCH_AND_WEIGH) {
throw new NotImplementedException();
// Calculate possible batches with batch size
_possibleBatchesPerLocalEpoch = (int) Math.ceil((double) dataSize / _batchSize);

// If no runtime balancing is specified, just run possible number of batches
// WARNING: Will get stuck on miss match
if(_runtimeBalancing == PSRuntimeBalancing.NONE) {
_numBatchesPerEpoch = _possibleBatchesPerLocalEpoch;
}

LOG.info("Setup config for worker " + this.getWorkerName());
LOG.info("Batch size: " + _batchSize + " possible batches: " + _possibleBatchesPerLocalEpoch
+ " batches to run: " + _numBatchesPerEpoch + " weighing factor: " + _weighingFactor);

// serialize program
// create program blocks for the instruction filtering
String programSerialized;
ArrayList<ProgramBlock> programBlocks = new ArrayList<>();
ArrayList<ProgramBlock> pbs = new ArrayList<>();

BasicProgramBlock gradientProgramBlock = new BasicProgramBlock(_ec.getProgram());
gradientProgramBlock.setInstructions(new ArrayList<>(Arrays.asList(_inst)));
programBlocks.add(gradientProgramBlock);
pbs.add(gradientProgramBlock);

if(_freq == Statement.PSFrequency.EPOCH) {
if(_freq == PSFrequency.EPOCH) {
BasicProgramBlock aggProgramBlock = new BasicProgramBlock(_ec.getProgram());
aggProgramBlock.setInstructions(new ArrayList<>(Arrays.asList(_ps.getAggInst())));
programBlocks.add(aggProgramBlock);
pbs.add(aggProgramBlock);
}

StringBuilder sb = new StringBuilder();
sb.append(PROG_BEGIN);
sb.append( NEWLINE );
sb.append(ProgramConverter.serializeProgram(_ec.getProgram(),
programBlocks,
new HashMap<>(),
false
));
sb.append(PROG_END);
programSerialized = sb.toString();
programSerialized = InstructionUtils.concatStrings(
PROG_BEGIN, NEWLINE,
ProgramConverter.serializeProgram(_ec.getProgram(), pbs, new HashMap<>(), false),
PROG_END);

// write program and meta data to worker
Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation(
new FederatedRequest(RequestType.EXEC_UDF, _featuresData.getVarID(),
new SetupFederatedWorker(_batchSize,
dataSize,
_possibleBatchesPerLocalEpoch,
programSerialized,
_inst.getNamespace(),
_inst.getFunctionName(),
_ps.getAggInst().getFunctionName(),
_ec.getListObject("hyperparams"),
_localStartBatchNumVarID,
_modelVarID
dataSize,
_possibleBatchesPerLocalEpoch,
programSerialized,
_inst.getNamespace(),
_inst.getFunctionName(),
_ps.getAggInst().getFunctionName(),
_ec.getListObject("hyperparams"),
_localStartBatchNumVarID,
_modelVarID
)
));

Expand Down Expand Up @@ -286,12 +307,23 @@ protected ListObject pullModel() {
return _ps.pull(_workerID);
}

protected void pushGradients(ListObject gradients) {
protected void scaleAndPushGradients(ListObject gradients) {
// scale gradients - must only include MatrixObjects
if(_weighing && _weighingFactor != 1) {
gradients.getData().parallelStream().forEach((matrix) -> {
MatrixObject matrixObject = (MatrixObject) matrix;
MatrixBlock input = matrixObject.acquireReadAndRelease().scalarOperations(
new RightScalarOperator(Multiply.getMultiplyFnObject(), _weighingFactor), new MatrixBlock());
matrixObject.acquireModify(input);
matrixObject.release();
});
}

// Push the gradients to ps
_ps.push(_workerID, gradients);
}

static protected int getNextLocalBatchNum(int currentLocalBatchNumber, int possibleBatchesPerLocalEpoch) {
protected static int getNextLocalBatchNum(int currentLocalBatchNumber, int possibleBatchesPerLocalEpoch) {
return currentLocalBatchNumber % possibleBatchesPerLocalEpoch;
}

Expand All @@ -300,18 +332,18 @@ static protected int getNextLocalBatchNum(int currentLocalBatchNumber, int possi
*/
protected void computeWithBatchUpdates() {
for (int epochCounter = 0; epochCounter < _epochs; epochCounter++) {
int currentLocalBatchNumber = (_cycleStartAt0) ? 0 : _numBatchesPerGlobalEpoch * epochCounter % _possibleBatchesPerLocalEpoch;
int currentLocalBatchNumber = (_cycleStartAt0) ? 0 : _numBatchesPerEpoch * epochCounter % _possibleBatchesPerLocalEpoch;

for (int batchCounter = 0; batchCounter < _numBatchesPerGlobalEpoch; batchCounter++) {
for (int batchCounter = 0; batchCounter < _numBatchesPerEpoch; batchCounter++) {
int localStartBatchNum = getNextLocalBatchNum(currentLocalBatchNumber++, _possibleBatchesPerLocalEpoch);
ListObject model = pullModel();
ListObject gradients = computeGradientsForNBatches(model, 1, localStartBatchNum);
pushGradients(gradients);
scaleAndPushGradients(gradients);
ParamservUtils.cleanupListObject(model);
ParamservUtils.cleanupListObject(gradients);
LOG.info("[+] " + this.getWorkerName() + " completed BATCH " + localStartBatchNum);
}
if( LOG.isInfoEnabled() )
LOG.info("[+] " + this.getWorkerName() + " completed epoch " + epochCounter);
LOG.info("[+] " + this.getWorkerName() + " --- completed EPOCH " + epochCounter);
}
}

Expand All @@ -327,15 +359,14 @@ protected void computeWithNBatchUpdates() {
*/
protected void computeWithEpochUpdates() {
for (int epochCounter = 0; epochCounter < _epochs; epochCounter++) {
int localStartBatchNum = (_cycleStartAt0) ? 0 : _numBatchesPerGlobalEpoch * epochCounter % _possibleBatchesPerLocalEpoch;
int localStartBatchNum = (_cycleStartAt0) ? 0 : _numBatchesPerEpoch * epochCounter % _possibleBatchesPerLocalEpoch;

// Pull the global parameters from ps
ListObject model = pullModel();
ListObject gradients = computeGradientsForNBatches(model, _numBatchesPerGlobalEpoch, localStartBatchNum, true);
pushGradients(gradients);

if( LOG.isInfoEnabled() )
LOG.info("[+] " + this.getWorkerName() + " completed epoch " + epochCounter);
ListObject gradients = computeGradientsForNBatches(model, _numBatchesPerEpoch, localStartBatchNum, true);
scaleAndPushGradients(gradients);

LOG.info("[+] " + this.getWorkerName() + " --- completed EPOCH " + epochCounter);
ParamservUtils.cleanupListObject(model);
ParamservUtils.cleanupListObject(gradients);
}
Expand Down Expand Up @@ -424,12 +455,12 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) {
ArrayList<DataIdentifier> inputs = func.getInputParams();
ArrayList<DataIdentifier> outputs = func.getOutputParams();
CPOperand[] boundInputs = inputs.stream()
.map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
.toArray(CPOperand[]::new);
.map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
.toArray(CPOperand[]::new);
ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName)
.collect(Collectors.toCollection(ArrayList::new));
.collect(Collectors.toCollection(ArrayList::new));
Instruction gradientsInstruction = new FunctionCallCPInstruction(namespace, gradientsFunctionName, false, boundInputs,
func.getInputParamNames(), outputNames, "gradient function");
func.getInputParamNames(), outputNames, "gradient function");
DataIdentifier gradientsOutput = outputs.get(0);

// recreate aggregation instruction and output if needed
Expand All @@ -440,12 +471,12 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) {
inputs = func.getInputParams();
outputs = func.getOutputParams();
boundInputs = inputs.stream()
.map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
.toArray(CPOperand[]::new);
.map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
.toArray(CPOperand[]::new);
outputNames = outputs.stream().map(DataIdentifier::getName)
.collect(Collectors.toCollection(ArrayList::new));
.collect(Collectors.toCollection(ArrayList::new));
aggregationInstruction = new FunctionCallCPInstruction(namespace, aggregationFuctionName, false, boundInputs,
func.getInputParamNames(), outputNames, "aggregation function");
func.getInputParamNames(), outputNames, "aggregation function");
aggregationOutput = outputs.get(0);
}

Expand Down Expand Up @@ -492,8 +523,6 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) {
ParamservUtils.cleanupData(ec, Statement.PS_FEATURES);
ParamservUtils.cleanupData(ec, Statement.PS_LABELS);
ec.removeVariable(ec.getVariable(Statement.PS_FED_BATCHCOUNTER_VARID).toString());
if( LOG.isInfoEnabled() )
LOG.info("[+]" + " completed batch " + localBatchNum);
}

// model clean up
Expand Down

0 comments on commit 460c394

Please sign in to comment.