Skip to content

Commit

Permalink
[SYSTEMML-2344,48,49,52] Various improvements local paramserv backend
Browse files Browse the repository at this point in the history
Closes #777.
  • Loading branch information
EdgarLGB authored and mboehm7 committed Jun 4, 2018
1 parent 2b86a4d commit d44b328
Show file tree
Hide file tree
Showing 16 changed files with 1,117 additions and 319 deletions.
8 changes: 7 additions & 1 deletion src/main/java/org/apache/sysml/parser/Statement.java
Expand Up @@ -77,7 +77,13 @@ public enum PSModeType {
}
public static final String PS_UPDATE_TYPE = "utype";
/**
 * Update types of the parameter server: {@code BSP}, {@code ASP} and {@code SSP}.
 * The helper methods are plain identity checks against the corresponding constant.
 *
 * NOTE(review): this listing is a diff rendered without +/- markers, so the constant list
 * shows up twice below (without and with the terminating semicolon); only the second form
 * can be followed by method declarations -- confirm against the repository.
 */
public enum PSUpdateType {
BSP, ASP, SSP
BSP, ASP, SSP;
/** @return {@code true} only when this constant is {@code BSP} */
public boolean isBSP() {
return this == BSP;
}
/** @return {@code true} only when this constant is {@code ASP} */
public boolean isASP() {
return this == ASP;
}
}
public static final String PS_FREQUENCY = "freq";
public enum PSFrequency {
Expand Down
Expand Up @@ -19,79 +19,82 @@

package org.apache.sysml.runtime.controlprogram.paramserv;

import java.util.concurrent.Callable;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.parser.Statement;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.instructions.cp.ListObject;

/**
 * Worker that trains against a {@link ParamServer} by repeatedly pulling the model, running the
 * update function on one batch and pushing the resulting gradients back.
 *
 * NOTE(review): this listing is a diff rendered without +/- markers. Two variants of most lines
 * are interleaved (a Runnable/run() variant and a Callable&lt;Void&gt;/call() variant, two
 * constructors, and two copies of the loop body) -- confirm against the repository which lines
 * are live before editing.
 */
public class LocalPSWorker extends PSWorker implements Runnable {
public class LocalPSWorker extends PSWorker implements Callable<Void> {

protected static final Log LOG = LogFactory.getLog(LocalPSWorker.class.getName());

// Both constructor variants only forward their arguments to the PSWorker constructor; the
// variant without hyperParams takes the worker id as int instead of long.
public LocalPSWorker(long workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize,
ListObject hyperParams, ExecutionContext ec, ParamServer ps) {
super(workerID, updFunc, freq, epochs, batchSize, hyperParams, ec, ps);
public LocalPSWorker(int workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize,
ExecutionContext ec, ParamServer ps) {
super(workerID, updFunc, freq, epochs, batchSize, ec, ps);
}

/**
 * Training loop of this worker. For every epoch and every batch it pulls the model from the
 * parameter server, binds the model and the sliced batch features/labels as variables in
 * {@code _ec}, executes the update-function instruction, reads the gradient list from
 * {@code _ec}, pushes it to the parameter server and cleans up the per-batch objects.
 *
 * In the {@code call()} variant the return value is always {@code null}. Any exception raised
 * inside the try block is rethrown as a {@code DMLRuntimeException} that carries the worker id
 * in its message and the original exception as cause.
 */
@Override
public void run() {
public Void call() throws Exception {
try {
long dataSize = _features.getNumRows();
for (int i = 0; i < _epochs; i++) {
// NOTE(review): dataSize and _batchSize are both long, so the division truncates before
// Math.ceil is applied; when dataSize is not a multiple of _batchSize the trailing partial
// batch is not counted. The same expression occurs a second time further down.
int totalIter = (int) Math.ceil(dataSize / _batchSize);
for (int j = 0; j < totalIter; j++) {
// Pull the global parameters from ps
ListObject globalParams = (ListObject)_ps.pull(_workerID);
if (LOG.isDebugEnabled()) {
LOG.debug(String.format("Local worker_%d: Successfully pull the global parameters "
+ "[size:%d kb] from ps.", _workerID, globalParams.getDataSize() / 1024));
}
_ec.setVariable(Statement.PS_MODEL, globalParams);

long dataSize = _features.getNumRows();
// Batch window starts at row j * _batchSize + 1, i.e. at 1 for the first batch.
// NOTE(review): end is begin + _batchSize capped at dataSize; whether that reaches one row
// into the next batch depends on whether ParamservUtils.sliceMatrix treats 'end' as
// inclusive -- confirm. The same two lines occur a second time further down.
long begin = j * _batchSize + 1;
long end = Math.min(begin + _batchSize, dataSize);

for (int i = 0; i < _epochs; i++) {
int totalIter = (int) Math.ceil(dataSize / _batchSize);
for (int j = 0; j < totalIter; j++) {
// Pull the global parameters from ps
// Need to copy the global parameter
// NOTE(review): this variant passes the pulled list through ParamservUtils.copyList,
// while the variant above uses the pulled object directly.
ListObject globalParams = ParamservUtils.copyList((ListObject) _ps.pull(_workerID));
if (LOG.isDebugEnabled()) {
LOG.debug(String.format(
"Local worker_%d: Successfully pull the global parameters [size:%d kb] from ps.", _workerID,
globalParams.getDataSize() / 1024));
}
_ec.setVariable(Statement.PS_MODEL, globalParams);
// Get batch features and labels
MatrixObject bFeatures = ParamservUtils.sliceMatrix(_features, begin, end);
MatrixObject bLabels = ParamservUtils.sliceMatrix(_labels, begin, end);
_ec.setVariable(Statement.PS_FEATURES, bFeatures);
_ec.setVariable(Statement.PS_LABELS, bLabels);

long begin = j * _batchSize + 1;
long end = Math.min(begin + _batchSize, dataSize);
if (LOG.isDebugEnabled()) {
LOG.debug(String.format("Local worker_%d: Got batch data [size:%d kb] of index from %d to %d. "
+ "[Epoch:%d Total epoch:%d Iteration:%d Total iteration:%d]", _workerID, bFeatures.getDataSize()
/ 1024 + bLabels.getDataSize() / 1024, begin, end, i + 1, _epochs, j + 1, totalIter));
}

// Get batch features and labels
MatrixObject bFeatures = ParamservUtils.sliceMatrix(_features, begin, end);
MatrixObject bLabels = ParamservUtils.sliceMatrix(_labels, begin, end);
_ec.setVariable(Statement.PS_FEATURES, bFeatures);
_ec.setVariable(Statement.PS_LABELS, bLabels);
// Invoke the update function
_inst.processInstruction(_ec);

if (LOG.isDebugEnabled()) {
LOG.debug(String.format(
"Local worker_%d: Got batch data [size:%d kb] of index from %d to %d. [Epoch:%d Total epoch:%d Iteration:%d Total iteration:%d]",
_workerID, bFeatures.getDataSize() / 1024 + bLabels.getDataSize() / 1024, begin, end, i + 1,
_epochs, j + 1, totalIter));
}
// Get the gradients
// (this variant reads the single output identifier kept in _output)
ListObject gradients = (ListObject) _ec.getVariable(_output.getName());

// Invoke the update function
_inst.processInstruction(_ec);
// Push the gradients to ps
_ps.push(_workerID, gradients);
if (LOG.isDebugEnabled()) {
LOG.debug(String.format("Local worker_%d: Successfully push the gradients "
+ "[size:%d kb] to ps.", _workerID, gradients.getDataSize() / 1024));
}

// Get the gradients
// (this variant reads the first entry of the _outputs list)
ListObject gradients = (ListObject) _ec.getVariable(_outputs.get(0).getName());

// Push the gradients to ps
_ps.push(_workerID, gradients);
// Release the per-batch model copy and batch slices once the gradients were pushed.
ParamservUtils.cleanupListObject(_ec, globalParams);
ParamservUtils.cleanupData(bFeatures);
ParamservUtils.cleanupData(bLabels);
}
if (LOG.isDebugEnabled()) {
LOG.debug(String.format("Local worker_%d: Successfully push the gradients [size:%d kb] to ps.",
_workerID, gradients.getDataSize() / 1024));
LOG.debug(String.format("Local worker_%d: Finished %d epoch.", _workerID, i + 1));
}

ParamservUtils.cleanupListObject(_ec, globalParams);
ParamservUtils.cleanupData(bFeatures);
ParamservUtils.cleanupData(bLabels);
}
if (LOG.isDebugEnabled()) {
LOG.debug(String.format("Local worker_%d: Finished %d epoch.", _workerID, i + 1));
LOG.debug(String.format("Local worker_%d: Job finished.", _workerID));
}
// Any failure in the loop is reported with the worker id and the original cause attached.
} catch (Exception e) {
throw new DMLRuntimeException(String.format("Local worker_%d failed", _workerID), e);
}
if (LOG.isDebugEnabled()) {
LOG.debug(String.format("Local worker_%d: Job finished.", _workerID));
}
// Callable<Void> has no meaningful result.
return null;
}
}
Expand Up @@ -19,6 +19,8 @@

package org.apache.sysml.runtime.controlprogram.paramserv;

import java.util.concurrent.ExecutionException;

import org.apache.sysml.parser.Statement;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
Expand All @@ -28,32 +30,32 @@
/**
 * {@link ParamServer} implementation whose push and pull operations are served from in-memory
 * structures inherited from the base class.
 *
 * NOTE(review): this listing is a diff rendered without +/- markers. Each member appears in two
 * interleaved variants: one that synchronizes on {@code _lock} with wait/notifyAll, and one that
 * uses {@code _gradientsQueue} and {@code _modelMap} -- confirm against the repository which
 * lines are live before editing.
 */
public class LocalParamServer extends ParamServer {

// Both constructor variants only forward their arguments to the ParamServer constructor.
public LocalParamServer(ListObject model, String aggFunc, Statement.PSFrequency freq,
Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum,
ListObject hyperParams) {
super(model, aggFunc, freq, updateType, ec, workerNum, hyperParams);
Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) {
super(model, aggFunc, freq, updateType, ec, workerNum);
}

/**
 * Hands the gradients of one worker to the server, wrapped in a {@code Gradient} together with
 * the worker id. The lock-based variant adds it to {@code _queue} and notifies all waiters; the
 * other variant puts it into {@code _gradientsQueue} and then calls {@code launchService()}.
 *
 * Failures ({@code InterruptedException}, {@code ExecutionException}) are rethrown as
 * {@code DMLRuntimeException}.
 */
@Override
public void push(long workerID, ListObject gradients) {
synchronized (_lock) {
_queue.add(new Gradient(workerID, gradients));
_lock.notifyAll();
public void push(int workerID, ListObject gradients) {
try {
_gradientsQueue.put(new Gradient(workerID, gradients));
// NOTE(review): the InterruptedException is wrapped without restoring the thread's interrupt
// flag (no Thread.currentThread().interrupt() call); the same holds for the other catch
// blocks in this class.
} catch (InterruptedException e) {
throw new DMLRuntimeException(e);
}
try {
launchService();
} catch (ExecutionException | InterruptedException e) {
throw new DMLRuntimeException("Aggregate service: some error occurred: ", e);
}
}

/**
 * Returns the model for the given worker. The lock-based variant waits on {@code _lock} while
 * {@code getPulledState(workerID)} is true, then marks the worker as pulled and returns
 * {@code getResult()}; the other variant returns whatever {@code take()} yields on the entry of
 * {@code _modelMap} stored under the worker id (presumably a blocking take -- the declared type
 * of {@code _modelMap} is not visible here, confirm).
 */
@Override
public Data pull(long workerID) {
synchronized (_lock) {
while (getPulledState((int) workerID)) {
try {
_lock.wait();
} catch (InterruptedException e) {
throw new DMLRuntimeException(
String.format("Local worker_%d: failed to pull the global parameters.", workerID), e);
}
}
setPulledState((int) workerID, true);
public Data pull(int workerID) {
ListObject model;
try {
// NOTE(review): with the int-typed parameter of this variant the (int) cast is redundant.
model = _modelMap.get((int) workerID).take();
} catch (InterruptedException e) {
throw new DMLRuntimeException(e);
}
return getResult();
return model;
}
}
Expand Up @@ -30,102 +30,99 @@
import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction;
import org.apache.sysml.runtime.instructions.cp.ListObject;

/**
 * Base class of parameter-server workers. It stores the worker configuration, resolves the update
 * function in the program of the given execution context, validates the function's signature and
 * prepares the {@code FunctionCallCPInstruction} that subclasses execute per batch.
 *
 * NOTE(review): this listing is a diff rendered without +/- markers. Fields, the constructor and
 * most statements appear in two interleaved variants (package-private mutable fields vs.
 * protected final fields, with and without hyperParams) -- confirm against the repository which
 * lines are live before editing.
 */
@SuppressWarnings("unused")
public abstract class PSWorker {

long _workerID = -1;
int _epochs;
long _batchSize;
MatrixObject _features;
MatrixObject _labels;
ExecutionContext _ec;
ParamServer _ps;
private String _updFunc;
private Statement.PSFrequency _freq;
protected final int _workerID;
protected final int _epochs;
protected final long _batchSize;
protected final ExecutionContext _ec;
protected final ParamServer _ps;
// Single output identifier of the update function (the gradient list), set at the end of the constructor.
protected final DataIdentifier _output;
// Instruction that calls the update function.
protected final FunctionCallCPInstruction _inst;
protected MatrixObject _features;
protected MatrixObject _labels;

// Validation data; within this class these fields are only written by their setters.
private MatrixObject _valFeatures;
private MatrixObject _valLabels;

ArrayList<DataIdentifier> _outputs;
FunctionCallCPInstruction _inst;

/**
 * Stores the configuration, looks up the update function {@code updFunc} in the program of
 * {@code ec}, builds the call instruction for it and validates its inputs and outputs.
 *
 * The hyperParams variant creates a fresh execution context from {@code ec.getProgram()} and,
 * when hyperParams is non-null, binds it under {@code Statement.PS_HYPER_PARAMS}; the other
 * variant keeps the passed {@code ec} as is.
 *
 * A {@code DMLRuntimeException} is thrown when a required input is missing or of the wrong
 * type, when the function does not have exactly one output, or when that output is not a list.
 */
public PSWorker(long workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize,
ListObject hyperParams, ExecutionContext ec, ParamServer ps) {
this._workerID = workerID;
this._updFunc = updFunc;
this._freq = freq;
this._epochs = epochs;
this._batchSize = batchSize;
this._ec = ExecutionContextFactory.createContext(ec.getProgram());
if (hyperParams != null) {
this._ec.setVariable(Statement.PS_HYPER_PARAMS, hyperParams);
}
this._ps = ps;
private final String _updFunc;
private final Statement.PSFrequency _freq;

protected PSWorker(int workerID, String updFunc, Statement.PSFrequency freq,
int epochs, long batchSize, ExecutionContext ec, ParamServer ps) {
_workerID = workerID;
_updFunc = updFunc;
_freq = freq;
_epochs = epochs;
_batchSize = batchSize;
_ec = ec;
_ps = ps;

// Get the update function
// A key that splits into two parts is read as namespace followed by function name;
// otherwise the first part is the function name and the namespace stays null.
String[] keys = DMLProgram.splitFunctionKey(updFunc);
String _funcName = keys[0];
String _funcNS = null;
String funcName = keys[0];
String funcNS = null;
if (keys.length == 2) {
_funcNS = keys[0];
_funcName = keys[1];
funcNS = keys[0];
funcName = keys[1];
}
FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(_funcNS, _funcName);
ArrayList<DataIdentifier> _inputs = func.getInputParams();
_outputs = func.getOutputParams();
CPOperand[] _boundInputs = _inputs.stream()
FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(funcNS, funcName);
ArrayList<DataIdentifier> inputs = func.getInputParams();
ArrayList<DataIdentifier> outputs = func.getOutputParams();
// Each declared input becomes an operand with the same name, value type and data type.
CPOperand[] boundInputs = inputs.stream()
.map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
.toArray(CPOperand[]::new);
ArrayList<String> _inputNames = _inputs.stream().map(DataIdentifier::getName)
ArrayList<String> inputNames = inputs.stream().map(DataIdentifier::getName)
.collect(Collectors.toCollection(ArrayList::new));
ArrayList<String> _outputNames = _outputs.stream().map(DataIdentifier::getName)
ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName)
.collect(Collectors.toCollection(ArrayList::new));
_inst = new FunctionCallCPInstruction(_funcNS, _funcName, _boundInputs, _inputNames, _outputNames,
_inst = new FunctionCallCPInstruction(funcNS, funcName, boundInputs, inputNames, outputNames,
"update function");

// Check the inputs of the update function
checkInput(_inputs, Expression.DataType.MATRIX, Statement.PS_FEATURES);
checkInput(_inputs, Expression.DataType.MATRIX, Statement.PS_LABELS);
checkInput(_inputs, Expression.DataType.LIST, Statement.PS_MODEL);
if (hyperParams != null) {
checkInput(_inputs, Expression.DataType.LIST, Statement.PS_HYPER_PARAMS);
}
// In the variant with the leading boolean, only the hyper-parameter input is optional.
checkInput(false, inputs, Expression.DataType.MATRIX, Statement.PS_FEATURES);
checkInput(false, inputs, Expression.DataType.MATRIX, Statement.PS_LABELS);
checkInput(false, inputs, Expression.DataType.LIST, Statement.PS_MODEL);
checkInput(true, inputs, Expression.DataType.LIST, Statement.PS_HYPER_PARAMS);

// Check the output of the update function
if (_outputs.size() != 1) {
throw new DMLRuntimeException(
String.format("The output of the '%s' function should provide one list containing the gradients.", updFunc));
if (outputs.size() != 1) {
throw new DMLRuntimeException(String.format("The output of the '%s' function "
+ "should provide one list containing the gradients.", updFunc));
}
if (_outputs.get(0).getDataType() != Expression.DataType.LIST) {
throw new DMLRuntimeException(
String.format("The output of the '%s' function should be type of list.", updFunc));
if (outputs.get(0).getDataType() != Expression.DataType.LIST) {
throw new DMLRuntimeException(String.format("The output of the '%s' function should be type of list.", updFunc));
}
_output = outputs.get(0);
}

/**
 * Verifies that exactly one input of the update function has both the data type {@code dt} and
 * the name {@code pname}; otherwise a {@code DMLRuntimeException} naming the function, type and
 * parameter is thrown. In the variant with the {@code optional} flag, an optional input that is
 * not declared under that name at all is accepted without further checks.
 */
private void checkInput(ArrayList<DataIdentifier> _inputs, Expression.DataType dt, String pname) {
if (_inputs.stream().filter(input -> input.getDataType() == dt && pname.equals(input.getName())).count() != 1) {
throw new DMLRuntimeException(
String.format("The '%s' function should provide an input of '%s' type named '%s'.", _updFunc, dt, pname));
private void checkInput(boolean optional, ArrayList<DataIdentifier> inputs, Expression.DataType dt, String pname) {
if (optional && inputs.stream().noneMatch(input -> pname.equals(input.getName()))) {
// We do not need to check if the input is optional and is not provided
return;
}
if (inputs.stream().filter(input -> input.getDataType() == dt && pname.equals(input.getName())).count() != 1) {
throw new DMLRuntimeException(String.format("The '%s' function should provide "
+ "an input of '%s' type named '%s'.", _updFunc, dt, pname));
}
}

/** Sets the training features this worker slices its batches from. */
public void setFeatures(MatrixObject features) {
this._features = features;
_features = features;
}

/** Sets the training labels this worker slices its batches from. */
public void setLabels(MatrixObject labels) {
this._labels = labels;
_labels = labels;
}

/** Stores the validation features; the field is not read anywhere in this class. */
public void setValFeatures(MatrixObject valFeatures) {
this._valFeatures = valFeatures;
_valFeatures = valFeatures;
}

/** Stores the validation labels; the field is not read anywhere in this class. */
public void setValLabels(MatrixObject valLabels) {
this._valLabels = valLabels;
_valLabels = valLabels;
}
}

0 comments on commit d44b328

Please sign in to comment.