[SYSTEMML-2420,2422] New distributed paramserv spark workers and rpc
Closes #805.
EdgarLGB authored and mboehm7 committed Jul 22, 2018
1 parent 54dbe9b commit 15ecb72
Showing 31 changed files with 915 additions and 497 deletions.
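For orientation: the classes below implement the parameter-server pattern this commit extends to Spark. Each worker repeatedly pulls the current global model from the parameter server, computes gradients on a local batch, and pushes them back for aggregation (see pullModel, pushGradients, and computeGradients in LocalPSWorker below). A minimal sketch of that contract, using simplified hypothetical types in place of SystemML's ListObject and ExecutionContext:

// Illustrative sketch only: simplified stand-ins for SystemML's
// ParamServer/PSWorker contract; the real classes carry ListObject models,
// DML update functions, and execution contexts.
interface SimpleParamServer {
    void push(int workerID, double[] gradients); // worker -> server
    double[] pull(int workerID);                 // blocks until a fresh model exists
}

class SimpleWorker {
    private final int workerID;
    private final SimpleParamServer ps;

    SimpleWorker(int workerID, SimpleParamServer ps) {
        this.workerID = workerID;
        this.ps = ps;
    }

    // One epoch: pull the model, compute batch gradients, push them back.
    void runEpoch(int numBatches) {
        for (int j = 0; j < numBatches; j++) {
            double[] model = ps.pull(workerID);
            double[] grads = computeGradients(model, j);
            ps.push(workerID, grads);
        }
    }

    private double[] computeGradients(double[] model, int batch) {
        return model.clone(); // placeholder for the DML update function call
    }
}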
LocalPSWorker.java
@@ -35,12 +35,20 @@
 public class LocalPSWorker extends PSWorker implements Callable<Void> {

     protected static final Log LOG = LogFactory.getLog(LocalPSWorker.class.getName());
+    private static final long serialVersionUID = 5195390748495357295L;
+
+    protected LocalPSWorker() {}

     public LocalPSWorker(int workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize,
         MatrixObject valFeatures, MatrixObject valLabels, ExecutionContext ec, ParamServer ps) {
         super(workerID, updFunc, freq, epochs, batchSize, valFeatures, valLabels, ec, ps);
     }

+    @Override
+    public String getWorkerName() {
+        return String.format("Local worker_%d", _workerID);
+    }
+
     @Override
     public Void call() throws Exception {
         if (DMLScript.STATISTICS)
@@ -60,10 +68,10 @@ public Void call() throws Exception {
             }

             if (LOG.isDebugEnabled()) {
-                LOG.debug(String.format("Local worker_%d: Job finished.", _workerID));
+                LOG.debug(String.format("%s: job finished.", getWorkerName()));
             }
         } catch (Exception e) {
-            throw new DMLRuntimeException(String.format("Local worker_%d failed", _workerID), e);
+            throw new DMLRuntimeException(String.format("%s failed", getWorkerName()), e);
         }
         return null;
     }
@@ -93,7 +101,7 @@ private void computeEpoch(long dataSize, int totalIter) {
             ParamservUtils.cleanupListObject(_ec, Statement.PS_MODEL);

             if (LOG.isDebugEnabled()) {
-                LOG.debug(String.format("Local worker_%d: Finished %d epoch.", _workerID, i + 1));
+                LOG.debug(String.format("%s: finished %d epoch.", getWorkerName(), i + 1));
             }
         }
@@ -108,9 +116,9 @@ private ListObject updateModel(ListObject globalParams, ListObject gradients, in
         Statistics.accPSLocalModelUpdateTime((long) tUpd.stop());

         if (LOG.isDebugEnabled()) {
-            LOG.debug(String.format("Local worker_%d: Local global parameter [size:%d kb] updated. "
+            LOG.debug(String.format("%s: local global parameter [size:%d kb] updated. "
                 + "[Epoch:%d Total epoch:%d Iteration:%d Total iteration:%d]",
-                _workerID, globalParams.getDataSize(), i + 1, _epochs, j + 1, totalIter));
+                getWorkerName(), globalParams.getDataSize(), i + 1, _epochs, j + 1, totalIter));
         }
         return globalParams;
     }
@@ -129,17 +137,17 @@ private void computeBatch(long dataSize, int totalIter) {
                 ParamservUtils.cleanupListObject(_ec, Statement.PS_MODEL);
             }
             if (LOG.isDebugEnabled()) {
-                LOG.debug(String.format("Local worker_%d: Finished %d epoch.", _workerID, i + 1));
+                LOG.debug(String.format("%s: finished %d epoch.", getWorkerName(), i + 1));
             }
         }
     }

     private ListObject pullModel() {
         // Pull the global parameters from ps
-        ListObject globalParams = (ListObject)_ps.pull(_workerID);
+        ListObject globalParams = _ps.pull(_workerID);
         if (LOG.isDebugEnabled()) {
-            LOG.debug(String.format("Local worker_%d: Successfully pull the global parameters "
-                + "[size:%d kb] from ps.", _workerID, globalParams.getDataSize() / 1024));
+            LOG.debug(String.format("%s: successfully pull the global parameters "
+                + "[size:%d kb] from ps.", getWorkerName(), globalParams.getDataSize() / 1024));
         }
         return globalParams;
     }
@@ -148,8 +156,8 @@ private void pushGradients(ListObject gradients) {
         // Push the gradients to ps
         _ps.push(_workerID, gradients);
         if (LOG.isDebugEnabled()) {
-            LOG.debug(String.format("Local worker_%d: Successfully push the gradients "
-                + "[size:%d kb] to ps.", _workerID, gradients.getDataSize() / 1024));
+            LOG.debug(String.format("%s: successfully push the gradients "
+                + "[size:%d kb] to ps.", getWorkerName(), gradients.getDataSize() / 1024));
         }
     }
@@ -168,8 +176,8 @@ private ListObject computeGradients(long dataSize, int totalIter, int i, int j)
         _ec.setVariable(Statement.PS_LABELS, bLabels);

         if (LOG.isDebugEnabled()) {
-            LOG.debug(String.format("Local worker_%d: Got batch data [size:%d kb] of index from %d to %d [last index: %d]. "
-                + "[Epoch:%d Total epoch:%d Iteration:%d Total iteration:%d]", _workerID,
+            LOG.debug(String.format("%s: got batch data [size:%d kb] of index from %d to %d [last index: %d]. "
+                + "[Epoch:%d Total epoch:%d Iteration:%d Total iteration:%d]", getWorkerName(),
                 bFeatures.getDataSize() / 1024 + bLabels.getDataSize() / 1024, begin, end, dataSize, i + 1, _epochs,
                 j + 1, totalIter));
         }
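The new getWorkerName() hook replaces the hard-coded "Local worker_%d" prefixes in the log and error messages above, so the Spark workers this commit adds (not part of this excerpt) can reuse the same code paths. A hypothetical Spark-side override might look like:

// Hypothetical illustration only; the actual SparkPSWorker class added by
// this commit is not shown in this excerpt. Assumes same-package access to
// the protected no-arg constructor and _workerID field.
class HypotheticalSparkWorker extends LocalPSWorker {
    @Override
    public String getWorkerName() {
        return String.format("Spark worker_%d", _workerID);
    }
}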
LocalParamServer.java
@@ -22,11 +22,14 @@
 import org.apache.sysml.parser.Statement;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysml.runtime.instructions.cp.Data;
 import org.apache.sysml.runtime.instructions.cp.ListObject;

 public class LocalParamServer extends ParamServer {

+    public LocalParamServer() {
+        super();
+    }
+
     public LocalParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) {
         super(model, aggFunc, updateType, ec, workerNum);
     }
@@ -37,7 +40,7 @@ public void push(int workerID, ListObject gradients) {
     }

     @Override
-    public Data pull(int workerID) {
+    public ListObject pull(int workerID) {
         ListObject model;
         try {
             model = _modelMap.get(workerID).take();
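pull() now returns ListObject instead of Data, which removes the cast in LocalPSWorker.pullModel(). Under the hood it takes from a per-worker BlockingQueue (_modelMap), so a worker blocks until the aggregation service publishes an updated model for it. A self-contained sketch of that handoff pattern, with String standing in for the model type and all names hypothetical:

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;

// Minimal illustration of the per-worker queue handoff used by the server:
// pull() blocks on take() until the aggregator put()s a fresh model.
public class QueueHandoffDemo {
    private final Map<Integer, BlockingQueue<String>> modelMap = new HashMap<>();

    public QueueHandoffDemo(int workerNum) {
        for (int i = 0; i < workerNum; i++)
            modelMap.put(i, new ArrayBlockingQueue<>(1)); // capacity 1: latest model only
    }

    public void broadcast(String model) throws InterruptedException {
        for (BlockingQueue<String> q : modelMap.values())
            q.put(model); // wake up every waiting worker
    }

    public String pull(int workerID) throws InterruptedException {
        return modelMap.get(workerID).take(); // blocks until a model arrives
    }

    public static void main(String[] args) throws InterruptedException {
        QueueHandoffDemo demo = new QueueHandoffDemo(2);
        demo.broadcast("model-v1");
        System.out.println(demo.pull(0)); // prints model-v1
    }
}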
PSWorker.java
@@ -21,6 +21,7 @@

 import static org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils.PS_FUNC_PREFIX;

+import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.stream.Collectors;

@@ -34,7 +35,10 @@
 import org.apache.sysml.runtime.instructions.cp.CPOperand;
 import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction;

-public abstract class PSWorker {
+public abstract class PSWorker implements Serializable {
+
+    private static final long serialVersionUID = -3510485051178200118L;
+
     protected int _workerID;
     protected int _epochs;
     protected long _batchSize;
@@ -50,10 +54,8 @@ public abstract class PSWorker {
     protected String _updFunc;
     protected Statement.PSFrequency _freq;

-    protected PSWorker() {
-
-    }
+    protected PSWorker() {}

     protected PSWorker(int workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize,
         MatrixObject valFeatures, MatrixObject valLabels, ExecutionContext ec, ParamServer ps) {
         _workerID = workerID;
@@ -65,7 +67,10 @@ protected PSWorker(int workerID, String updFunc, Statement.PSFrequency freq, int
         _valLabels = valLabels;
         _ec = ec;
         _ps = ps;
+        setupUpdateFunction(updFunc, ec);
+    }

+    protected void setupUpdateFunction(String updFunc, ExecutionContext ec) {
         // Get the update function
         String[] cfn = ParamservUtils.getCompleteFuncName(updFunc, PS_FUNC_PREFIX);
         String ns = cfn[0];
@@ -125,4 +130,6 @@ public MatrixObject getFeatures() {
     public MatrixObject getLabels() {
         return _labels;
     }
+
+    public abstract String getWorkerName();
 }
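PSWorker now implements Serializable with a pinned serialVersionUID, gains an empty protected constructor, and has its update-function wiring split into setupUpdateFunction(); all three changes support shipping workers to Spark executors, where the object is rebuilt by deserialization and re-initialized against a fresh context. A minimal round-trip showing the serialization step such a worker must survive (hypothetical demo class):

import java.io.*;

// Demonstrates the round-trip a Spark task closure performs; an explicit
// serialVersionUID keeps driver- and executor-side class versions compatible.
public class SerializationDemo implements Serializable {
    private static final long serialVersionUID = 1L; // pin the version explicitly
    int workerID = 7;

    public static void main(String[] args) throws Exception {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        try (ObjectOutputStream oos = new ObjectOutputStream(bos)) {
            oos.writeObject(new SerializationDemo()); // serialize (driver side)
        }
        try (ObjectInputStream ois = new ObjectInputStream(
                new ByteArrayInputStream(bos.toByteArray()))) {
            SerializationDemo copy = (SerializationDemo) ois.readObject(); // deserialize (executor side)
            System.out.println(copy.workerID); // prints 7
        }
    }
}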
ParamServer.java
@@ -42,7 +42,6 @@
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysml.runtime.instructions.cp.CPOperand;
-import org.apache.sysml.runtime.instructions.cp.Data;
 import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction;
 import org.apache.sysml.runtime.instructions.cp.ListObject;
 import org.apache.sysml.utils.Statistics;
@@ -53,17 +52,19 @@ public abstract class ParamServer
     protected static final boolean ACCRUE_BSP_GRADIENTS = true;

     // worker input queues and global model
-    protected final Map<Integer, BlockingQueue<ListObject>> _modelMap;
+    protected Map<Integer, BlockingQueue<ListObject>> _modelMap;
     private ListObject _model;

     //aggregation service
-    protected final ExecutionContext _ec;
-    private final Statement.PSUpdateType _updateType;
-    private final FunctionCallCPInstruction _inst;
-    private final String _outputName;
-    private final boolean[] _finishedStates; // Workers' finished states
+    protected ExecutionContext _ec;
+    private Statement.PSUpdateType _updateType;
+    private FunctionCallCPInstruction _inst;
+    private String _outputName;
+    private boolean[] _finishedStates; // Workers' finished states
     private ListObject _accGradients = null;

+    protected ParamServer() {}
+
     protected ParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) {
         // init worker queues and global model
         _modelMap = new HashMap<>(workerNum);
@@ -77,10 +78,22 @@ protected ParamServer(ListObject model, String aggFunc, Statement.PSUpdateType u
         _ec = ec;
         _updateType = updateType;
         _finishedStates = new boolean[workerNum];
+        setupAggFunc(_ec, aggFunc);
+
+        // broadcast initial model
+        try {
+            broadcastModel();
+        }
+        catch (InterruptedException e) {
+            throw new DMLRuntimeException("Param server: failed to broadcast the initial model.", e);
+        }
+    }
+
+    public void setupAggFunc(ExecutionContext ec, String aggFunc) {
         String[] cfn = ParamservUtils.getCompleteFuncName(aggFunc, PS_FUNC_PREFIX);
         String ns = cfn[0];
         String fname = cfn[1];
-        FunctionProgramBlock func = _ec.getProgram().getFunctionProgramBlock(ns, fname);
+        FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(ns, fname);
         ArrayList<DataIdentifier> inputs = func.getInputParams();
         ArrayList<DataIdentifier> outputs = func.getOutputParams();

@@ -101,19 +114,11 @@ protected ParamServer(ListObject model, String aggFunc, Statement.PSUpdateType u
         ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName)
             .collect(Collectors.toCollection(ArrayList::new));
         _inst = new FunctionCallCPInstruction(ns, fname, boundInputs, inputNames, outputNames, "aggregate function");
-
-        // broadcast initial model
-        try {
-            broadcastModel();
-        }
-        catch (InterruptedException e) {
-            throw new DMLRuntimeException("Param server: failed to broadcast the initial model.", e);
-        }
     }

     public abstract void push(int workerID, ListObject value);

-    public abstract Data pull(int workerID);
+    public abstract ListObject pull(int workerID);

     public ListObject getResult() {
         // All the model updating work has terminated,
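ParamServer gets the same treatment as PSWorker: the fields lose final, a protected no-arg constructor appears, and the aggregation-function wiring moves out of the constructor into the public setupAggFunc(ec, aggFunc), so a server object reconstructed on a remote JVM can be re-initialized against a freshly rebuilt ExecutionContext. The two-phase pattern in miniature (all names hypothetical):

// Two-phase initialization: a no-arg constructor for the deserialization
// path, plus a setup method that can be re-run once the remote context exists.
public class TwoPhaseInit {
    private String aggFunc; // stand-in for the FunctionCallCPInstruction wiring

    protected TwoPhaseInit() {}           // used after deserialization

    public TwoPhaseInit(String aggFunc) { // used on the driver
        setupAggFunc(aggFunc);
    }

    public void setupAggFunc(String aggFunc) {
        this.aggFunc = aggFunc;           // safe to call again on the remote side
    }

    public static void main(String[] args) {
        TwoPhaseInit onDriver = new TwoPhaseInit("agg");
        TwoPhaseInit onExecutor = new TwoPhaseInit(); // e.g. fresh from deserialization
        onExecutor.setupAggFunc("agg");               // re-attach the function wiring
        System.out.println(onDriver.aggFunc.equals(onExecutor.aggFunc)); // true
    }
}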
Expand Up @@ -28,8 +28,11 @@
import java.util.stream.IntStream;

import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.Partitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.MultiThreadedHop;
@@ -57,6 +60,7 @@
 import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysml.runtime.controlprogram.paramserv.spark.DataPartitionerSparkAggregator;
 import org.apache.sysml.runtime.controlprogram.paramserv.spark.DataPartitionerSparkMapper;
+import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysml.runtime.functionobjects.Plus;
 import org.apache.sysml.runtime.instructions.cp.Data;
 import org.apache.sysml.runtime.instructions.cp.ListObject;
@@ -68,13 +72,14 @@
 import org.apache.sysml.runtime.matrix.data.OutputInfo;
 import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
 import org.apache.sysml.runtime.util.ProgramConverter;
+import org.apache.sysml.utils.Statistics;

 import scala.Tuple2;

 public class ParamservUtils {

+    protected static final Log LOG = LogFactory.getLog(ParamservUtils.class.getName());
     public static final String PS_FUNC_PREFIX = "_ps_";
-
     public static long SEED = -1; // Used for generating permutation

     /**
@@ -140,6 +145,14 @@ public static void cleanupData(ExecutionContext ec, Data data) {
         CacheableData<?> cd = (CacheableData<?>) data;
         cd.enableCleanup(true);
         ec.cleanupCacheableData(cd);
+        if (LOG.isDebugEnabled()) {
+            LOG.debug(String.format("%s has been deleted.", cd.getFileName()));
+        }
+    }
+
+    public static void cleanupMatrixObject(ExecutionContext ec, MatrixObject mo) {
+        mo.enableCleanup(true);
+        ec.cleanupCacheableData(mo);
     }

     public static MatrixObject newMatrixObject(MatrixBlock mb) {
@@ -365,40 +378,42 @@ public int numPartitions() {

     @SuppressWarnings("unchecked")
     public static JavaPairRDD<Integer, Tuple2<MatrixBlock, MatrixBlock>> doPartitionOnSpark(SparkExecutionContext sec, MatrixObject features, MatrixObject labels, Statement.PSScheme scheme, int workerNum) {
+        Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null;
         // Get input RDD
         JavaPairRDD<MatrixIndexes, MatrixBlock> featuresRDD = (JavaPairRDD<MatrixIndexes, MatrixBlock>)
             sec.getRDDHandleForMatrixObject(features, InputInfo.BinaryBlockInputInfo);
         JavaPairRDD<MatrixIndexes, MatrixBlock> labelsRDD = (JavaPairRDD<MatrixIndexes, MatrixBlock>)
             sec.getRDDHandleForMatrixObject(labels, InputInfo.BinaryBlockInputInfo);

         DataPartitionerSparkMapper mapper = new DataPartitionerSparkMapper(scheme, workerNum, sec, (int) features.getNumRows());
-        return ParamservUtils.assembleTrainingData(features.getNumRows(), featuresRDD, labelsRDD) // Combine features and labels into a pair (rowBlockID => (features, labels))
+        JavaPairRDD<Integer, Tuple2<MatrixBlock, MatrixBlock>> result = ParamservUtils
+            .assembleTrainingData(features.getNumRows(), featuresRDD, labelsRDD) // Combine features and labels into a pair (rowBlockID => (features, labels))
             .flatMapToPair(mapper) // Do the data partitioning on spark (workerID => (rowBlockID, (single row features, single row labels))
             // Aggregate the partitioned matrix according to rowID for each worker
             // i.e. (workerID => ordered list[(rowBlockID, (single row features, single row labels)]
-            .aggregateByKey(new LinkedList<Tuple2<Long, Tuple2<MatrixBlock, MatrixBlock>>>(),
-                new Partitioner() {
-                    private static final long serialVersionUID = -7937781374718031224L;
-                    @Override
-                    public int getPartition(Object workerID) {
-                        return (int) workerID;
-                    }
-                    @Override
-                    public int numPartitions() {
-                        return workerNum;
-                    }
-                },
-                (list, input) -> {
-                    list.add(input);
-                    return list;
-                },
-                (l1, l2) -> {
-                    l1.addAll(l2);
-                    l1.sort((o1, o2) -> o1._1.compareTo(o2._1));
-                    return l1;
-                })
-            .mapToPair(new DataPartitionerSparkAggregator(
-                features.getNumColumns(), labels.getNumColumns()));
+            .aggregateByKey(new LinkedList<Tuple2<Long, Tuple2<MatrixBlock, MatrixBlock>>>(), new Partitioner() {
+                private static final long serialVersionUID = -7937781374718031224L;
+                @Override
+                public int getPartition(Object workerID) {
+                    return (int) workerID;
+                }
+                @Override
+                public int numPartitions() {
+                    return workerNum;
+                }
+            }, (list, input) -> {
+                list.add(input);
+                return list;
+            }, (l1, l2) -> {
+                l1.addAll(l2);
+                l1.sort((o1, o2) -> o1._1.compareTo(o2._1));
+                return l1;
+            })
+            .mapToPair(new DataPartitionerSparkAggregator(features.getNumColumns(), labels.getNumColumns()));

+        if (DMLScript.STATISTICS)
+            Statistics.accPSSetupTime((long) tSetup.stop());
+        return result;
     }

     public static ListObject accrueGradients(ListObject accGradients, ListObject gradients) {
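doPartitionOnSpark() is now timed via the new accPSSetupTime statistic, and it still routes every record to its worker with an identity Partitioner (the workerID key doubles as the partition index) before aggregateByKey collects an ordered per-worker list. A runnable, stripped-down version of that shuffle, replacing the matrix blocks with plain longs (class and app names are hypothetical):

import java.util.Arrays;
import java.util.LinkedList;

import org.apache.spark.Partitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;

import scala.Tuple2;

// Reduced illustration of the per-worker shuffle: an identity Partitioner
// sends each record to the partition of its workerID key, and aggregateByKey
// collects a sorted list per worker.
public class WorkerPartitionDemo {
    public static void main(String[] args) {
        int workerNum = 2;
        try (JavaSparkContext sc = new JavaSparkContext("local[2]", "demo")) {
            JavaPairRDD<Integer, Long> rows = sc.parallelizePairs(Arrays.asList(
                new Tuple2<>(0, 3L), new Tuple2<>(1, 1L), new Tuple2<>(0, 2L)));
            JavaPairRDD<Integer, LinkedList<Long>> perWorker = rows.aggregateByKey(
                new LinkedList<Long>(),
                new Partitioner() { // identity partitioner: workerID == partition index
                    private static final long serialVersionUID = 1L;
                    @Override public int getPartition(Object workerID) { return (int) workerID; }
                    @Override public int numPartitions() { return workerNum; }
                },
                (list, v) -> { list.add(v); return list; },                          // merge a value into the list
                (l1, l2) -> { l1.addAll(l2); l1.sort(Long::compareTo); return l1; }); // merge and order lists
            perWorker.collect().forEach(t -> System.out.println(t._1 + " -> " + t._2));
        }
    }
}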
SparkPSBody.java
@@ -28,12 +28,10 @@ public class SparkPSBody {

     private ExecutionContext _ec;

-    public SparkPSBody() {
-
-    }
+    public SparkPSBody() {}

     public SparkPSBody(ExecutionContext ec) {
-        this._ec = ec;
+        _ec = ec;
     }

     public ExecutionContext getEc() {
