Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

43 changes: 22 additions & 21 deletions src/main/java/org/apache/sysds/parser/Statement.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand All @@ -27,26 +27,26 @@
public abstract class Statement implements ParseInfo
{
protected static final Log LOG = LogFactory.getLog(Statement.class.getName());

public static final String OUTPUTSTATEMENT = "WRITE";

// parameter names for seq()
public static final String SEQ_FROM = "from";
public static final String SEQ_FROM = "from";
public static final String SEQ_TO = "to";
public static final String SEQ_INCR = "incr";

public static final String SOURCE = "source";
public static final String SETWD = "setwd";

public static final String MATRIX_DATA_TYPE = "matrix";
public static final String FRAME_DATA_TYPE = "frame";
public static final String SCALAR_DATA_TYPE = "scalar";

public static final String DOUBLE_VALUE_TYPE = "double";
public static final String BOOLEAN_VALUE_TYPE = "boolean";
public static final String INT_VALUE_TYPE = "int";
public static final String STRING_VALUE_TYPE = "string";

// String constants related to Grouped Aggregate parameters
public static final String GAGG_TARGET = "target";
public static final String GAGG_GROUPS = "groups";
Expand All @@ -72,6 +72,7 @@ public abstract class Statement implements ParseInfo
public static final String PS_MODE = "mode";
public static final String PS_GRADIENTS = "gradients";
public static final String PS_SEED = "seed";
public static final String PS_NBATCHES = "nbatches";
/** Execution backend for the parameter server: federated workers, single-node local, or remote Spark. */
public enum PSModeType {
	FEDERATED, LOCAL, REMOTE_SPARK
}
Expand All @@ -87,7 +88,7 @@ public boolean isASP() {
}
public static final String PS_FREQUENCY = "freq";
/**
 * Model-update frequency for parameter-server workers: after every batch,
 * after every epoch, or after every group of N batches (N supplied via the
 * {@code nbatches} parameter, see {@code PS_NBATCHES}).
 */
public enum PSFrequency {
	BATCH, EPOCH, NBATCHES
}
public static final String PS_FED_WEIGHTING = "weighting";
public static final String PS_FED_RUNTIME_BALANCING = "runtime_balancing";
Expand Down Expand Up @@ -122,34 +123,34 @@ public enum PSCheckpointing {


public abstract boolean controlStatement();

public abstract VariableSet variablesRead();
public abstract VariableSet variablesUpdated();

public abstract void initializeforwardLV(VariableSet activeIn);
public abstract VariableSet initializebackwardLV(VariableSet lo);

public abstract Statement rewriteStatement(String prefix);

// Used only insider python parser to allow for ignoring newline logic
private boolean isEmptyNewLineStatement = false;
public boolean isEmptyNewLineStatement() {
return isEmptyNewLineStatement;
}
}
public void setEmptyNewLineStatement(boolean isEmptyNewLineStatement) {
this.isEmptyNewLineStatement = isEmptyNewLineStatement;
}

///////////////////////////////////////////////////////////////////////////
// store exception info + position information for statements
///////////////////////////////////////////////////////////////////////////


private String _filename;
private int _beginLine, _beginColumn;
private int _endLine, _endColumn;
private String _text;

@Override
public void setFilename(String passed) { _filename = passed; }
@Override
Expand All @@ -175,10 +176,10 @@ public void setCtxValues(ParserRuleContext ctx) {
setEndColumn(ctx.stop.getCharPositionInLine());
// preserve whitespace if possible
if ((ctx.start != null) && (ctx.stop != null) && (ctx.start.getStartIndex() != -1)
&& (ctx.stop.getStopIndex() != -1) && (ctx.start.getStartIndex() <= ctx.stop.getStopIndex())
&& (ctx.start.getInputStream() != null)) {
&& (ctx.stop.getStopIndex() != -1) && (ctx.start.getStartIndex() <= ctx.stop.getStopIndex())
&& (ctx.start.getInputStream() != null)) {
String text = ctx.start.getInputStream()
.getText(Interval.of(ctx.start.getStartIndex(), ctx.stop.getStopIndex()));
.getText(Interval.of(ctx.start.getStartIndex(), ctx.stop.getStopIndex()));
if (text != null) {
text = text.trim();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,20 +77,22 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
// runtime balancing
private final PSRuntimeBalancing _runtimeBalancing;
private int _numBatchesPerEpoch;
private int _numBatchesPerNbatch ;
private int _possibleBatchesPerLocalEpoch;
private final boolean _weighting;
private double _weightingFactor = 1;
private boolean _cycleStartAt0 = false;

/**
 * Creates a control thread for one federated parameter-server worker.
 *
 * @param workerID worker identifier
 * @param updFunc name of the gradient/update function
 * @param freq update frequency (BATCH, EPOCH, or NBATCHES)
 * @param runtimeBalancing runtime balancing strategy
 * @param weighting whether gradients are scaled by a weighting factor
 * @param epochs number of training epochs
 * @param batchSize batch size
 * @param numBatchesPerGlobalEpoch number of batches per global epoch
 * @param ec execution context
 * @param ps parameter server handle
 * @param nbatches number of batches per update when freq == NBATCHES
 */
public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFrequency freq,
	PSRuntimeBalancing runtimeBalancing, boolean weighting, int epochs, long batchSize,
	int numBatchesPerGlobalEpoch, ExecutionContext ec, ParamServer ps, int nbatches)
{
	super(workerID, updFunc, freq, epochs, batchSize, ec, ps, nbatches);

	_numBatchesPerEpoch = numBatchesPerGlobalEpoch;
	_runtimeBalancing = runtimeBalancing;
	_weighting = weighting;
	_numBatchesPerNbatch = nbatches;
	// generate the ID for the model
	_modelVarID = FederationUtils.getNextFedDataID();
}
Expand Down Expand Up @@ -150,7 +152,7 @@ public void setup(double weightingFactor) {
aggProgramBlock.setInstructions(new ArrayList<>(Collections.singletonList(_ps.getAggInst())));
pbs.add(aggProgramBlock);
}

programSerialized = InstructionUtils.concatStrings(
PROG_BEGIN, NEWLINE,
ProgramConverter.serializeProgram(_ec.getProgram(), pbs, new HashMap<>()),
Expand All @@ -167,7 +169,8 @@ public void setup(double weightingFactor) {
_inst.getFunctionName(),
_ps.getAggInst().getFunctionName(),
_ec.getListObject("hyperparams"),
_modelVarID
_modelVarID,
_nbatches
)
));

Expand Down Expand Up @@ -195,10 +198,11 @@ private static class SetupFederatedWorker extends FederatedUDF {
private final String _aggregationFunctionName;
private final ListObject _hyperParams;
private final long _modelVarID;
private final int _nbatches;

protected SetupFederatedWorker(long batchSize, long dataSize, int possibleBatchesPerLocalEpoch,
String programString, String namespace, String gradientsFunctionName, String aggregationFunctionName,
ListObject hyperParams, long modelVarID)
ListObject hyperParams, long modelVarID, int nbatches)
{
super(new long[]{});
_batchSize = batchSize;
Expand All @@ -210,6 +214,7 @@ protected SetupFederatedWorker(long batchSize, long dataSize, int possibleBatche
_aggregationFunctionName = aggregationFunctionName;
_hyperParams = hyperParams;
_modelVarID = modelVarID;
_nbatches = nbatches;
}

@Override
Expand All @@ -226,6 +231,7 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) {
ec.setVariable(Statement.PS_FED_AGGREGATION_FNAME, new StringObject(_aggregationFunctionName));
ec.setVariable(Statement.PS_HYPER_PARAMS, _hyperParams);
ec.setVariable(Statement.PS_FED_MODEL_VARID, new IntObject(_modelVarID));
ec.setVariable(Statement.PS_NBATCHES, new IntObject(_nbatches));

return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
}
Expand Down Expand Up @@ -277,7 +283,7 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) {
ec.removeVariable(Statement.PS_FED_AGGREGATION_FNAME);
ec.removeVariable(Statement.PS_FED_MODEL_VARID);
ParamservUtils.cleanupListObject(ec, Statement.PS_HYPER_PARAMS);

return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
}

Expand All @@ -300,9 +306,9 @@ public Void call() throws Exception {
case BATCH:
computeWithBatchUpdates();
break;
/*case NBATCH:
case NBATCHES:
computeWithNBatchUpdates();
break; */
break;
case EPOCH:
computeWithEpochUpdates();
break;
Expand Down Expand Up @@ -344,7 +350,7 @@ protected static int getNextLocalBatchNum(int currentLocalBatchNumber, int possi
}

/**
* Computes all epochs and updates after each batch
* Computes all epochs and updates after each batch
*/
protected void computeWithBatchUpdates() {
for (int epochCounter = 0; epochCounter < _epochs; epochCounter++) {
Expand All @@ -365,7 +371,21 @@ protected void computeWithBatchUpdates() {
* Computes all epochs and updates after N batches
*/
/**
 * Computes all epochs, pulling the model and pushing weighted gradients once
 * per group of N batches (N = _numBatchesPerNbatch) instead of after every batch.
 */
protected void computeWithNBatchUpdates() {
	// number of N-batch groups processed per epoch; cast to double so that
	// Math.ceil actually rounds up (integer division would truncate first)
	// NOTE(review): dividing _batchSize by the per-group batch count looks
	// suspicious — conceptually this should be the number of batches per
	// epoch divided by N; confirm against the caller.
	int numSetsPerEpoch = (int) Math.ceil((double) _batchSize / _numBatchesPerNbatch);

	for (int epochCounter = 0; epochCounter < _epochs; epochCounter++) {
		// either restart the cycle at 0 or continue from where the last epoch stopped
		int currentLocalBatchNumber = (_cycleStartAt0) ? 0
			: _numBatchesPerEpoch * epochCounter % _possibleBatchesPerLocalEpoch;

		for (int batchCounter = 0; batchCounter < numSetsPerEpoch; batchCounter++) {
			int localStartBatchNum = getNextLocalBatchNum(currentLocalBatchNumber, numSetsPerEpoch);
			currentLocalBatchNumber = currentLocalBatchNumber + _numBatchesPerNbatch;
			// pull model, compute gradients over the N-batch group, push, clean up
			ListObject model = pullModel();
			ListObject gradients = computeGradientsForNBatches(model, _numBatchesPerNbatch, localStartBatchNum);
			weightAndPushGradients(gradients);
			ParamservUtils.cleanupListObject(model);
			ParamservUtils.cleanupListObject(gradients);
		}
	}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,16 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
protected LocalPSWorker() {}

/**
 * Creates a local parameter-server worker.
 *
 * @param workerID worker identifier
 * @param updFunc name of the gradient/update function
 * @param freq update frequency (BATCH, EPOCH, or NBATCHES)
 * @param epochs number of training epochs
 * @param batchSize batch size
 * @param ec execution context
 * @param ps parameter server handle
 * @param nbatches number of batches per update when freq == NBATCHES
 */
public LocalPSWorker(int workerID, String updFunc, Statement.PSFrequency freq,
	int epochs, long batchSize, ExecutionContext ec, ParamServer ps, int nbatches)
{
	super(workerID, updFunc, freq, epochs, batchSize, ec, ps, nbatches);
}

@Override
public String getWorkerName() {
	// identify this worker by its numeric id, e.g. "Local worker_3"
	return "Local worker_" + _workerID;
}

@Override
public Void call() throws Exception {
incWorkerNumber();
Expand All @@ -67,6 +67,9 @@ public Void call() throws Exception {
case EPOCH:
computeEpoch(dataSize, batchIter);
break;
case NBATCHES:
computeNBatches(dataSize, batchIter);
break;
default:
throw new DMLRuntimeException(String.format("%s not support update frequency %s", getWorkerName(), _freq));
}
Expand All @@ -89,19 +92,19 @@ private void computeEpoch(long dataSize, int batchIter) {
try {
for (int j = 0; j < batchIter; j++) {
ListObject gradients = computeGradients(params, dataSize, batchIter, i, j);

boolean localUpdate = j < batchIter - 1;
// Accumulate the intermediate gradients (async for overlap w/ model updates

// Accumulate the intermediate gradients (async for overlap w/ model updates
// and gradient computation, sequential over gradient matrices to avoid deadlocks)
ListObject accGradientsPrev = accGradients.get();
accGradients = _tpool.submit(() -> ParamservUtils.accrueGradients(
accGradientsPrev, gradients, false, !localUpdate));

// Update the local model with gradients
if(localUpdate)
params = updateModel(params, gradients, i, j, batchIter);

accNumBatches(1);
}

Expand All @@ -112,7 +115,7 @@ private void computeEpoch(long dataSize, int batchIter) {
catch(ExecutionException | InterruptedException ex) {
throw new DMLRuntimeException(ex);
}

accNumEpochs(1);
if (LOG.isDebugEnabled()) {
LOG.debug(String.format("%s: finished %d epoch.", getWorkerName(), i + 1));
Expand All @@ -126,7 +129,7 @@ private ListObject updateModel(ListObject globalParams, ListObject gradients, in
globalParams = _ps.updateLocalModel(_ec, gradients, globalParams);

accLocalModelUpdateTime(tUpd);

if (LOG.isDebugEnabled()) {
LOG.debug(String.format("%s: local global parameter [size:%d kb] updated. "
+ "[Epoch:%d Total epoch:%d Iteration:%d Total iteration:%d]",
Expand All @@ -145,10 +148,10 @@ private void computeBatch(long dataSize, int totalIter) {
// Push the gradients to ps
pushGradients(gradients);
ParamservUtils.cleanupListObject(_ec, Statement.PS_MODEL);

accNumBatches(1);
}

accNumEpochs(1);
if (LOG.isDebugEnabled()) {
LOG.debug(String.format("%s: finished %d epoch.", getWorkerName(), i + 1));
Expand Down Expand Up @@ -208,7 +211,56 @@ private ListObject computeGradients(ListObject params, long dataSize, int batchI
ParamservUtils.cleanupData(_ec, Statement.PS_LABELS);
return gradients;
}


/**
 * Computes all epochs, accumulating gradients locally and pushing them to the
 * parameter server once per group of N batches (N = _nbatches) rather than
 * after every batch or every epoch.
 *
 * @param dataSize number of rows of training data available to this worker
 * @param batchIter number of batches per epoch
 */
private void computeNBatches(long dataSize, int batchIter) {
	Future<ListObject> accGradients = ConcurrentUtils.constantFuture(null);

	for(int i = 0; i < _epochs; i++) {
		ListObject params = null;
		boolean needPull = true; // pull a fresh global model at the start of each N-batch group
		int step = 0;
		try {
			for(int j = 0; j < batchIter; j++) {
				// NOTE(review): j < batchIter is always true for j in [0, batchIter),
				// so the model is updated locally after every batch and accrueGradients
				// never blocks; confirm whether j < batchIter - 1 (as in computeEpoch)
				// was intended.
				boolean localUpdate = j < batchIter;
				// group boundary when step is a multiple of the pull interval;
				// Math.max guards _nbatches == 1 against division by zero
				// NOTE(review): the original interval was (_nbatches - 1); plain
				// _nbatches may have been intended — confirm.
				boolean groupBoundary = (step % Math.max(1, _nbatches - 1) == 0);
				if(groupBoundary && needPull) {
					// Pull the global parameters from ps
					params = pullModel();
					needPull = false;
					groupBoundary = false;
				}
				ListObject gradients = computeGradients(params, dataSize, batchIter, i, j);
				// Accumulate the intermediate gradients (async for overlap w/ model updates
				// and gradient computation, sequential over gradient matrices to avoid deadlocks)
				ListObject accGradientsPrev = accGradients.get();
				accGradients = _tpool
					.submit(() -> ParamservUtils.accrueGradients(accGradientsPrev, gradients, false, !localUpdate));
				// Update the local model with gradients
				if(localUpdate) {
					params = updateModel(params, gradients, i, j, batchIter);
				}
				step++;
				if((groupBoundary && j > 0) || (j == batchIter - 1)) {
					// Push the accumulated gradients to ps at the end of a group or epoch
					pushGradients(accGradients.get());
					accGradients = ConcurrentUtils.constantFuture(null);
					needPull = true;
					step = 0;
				}
				// count each processed batch exactly once
				// (the original incremented twice per iteration, double-counting batches)
				accNumBatches(1);
			}
		}
		catch(ExecutionException | InterruptedException ex) {
			throw new DMLRuntimeException(ex);
		}
		accNumEpochs(1);
		if(LOG.isDebugEnabled()) {
			LOG.debug(String.format("%s: finished %d epoch.", getWorkerName(), i + 1));
		}
	}
}

@Override
protected void incWorkerNumber() {
if (DMLScript.STATISTICS)
Expand Down
Loading