Skip to content

Commit

Permalink
[SYSTEMML-2403] Fix accuracy issue paramserv BSP batch updates
Browse files Browse the repository at this point in the history
Closes #791.
  • Loading branch information
EdgarLGB authored and mboehm7 committed Jul 8, 2018
1 parent eb179b1 commit 63a1e2a
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 25 deletions.
Expand Up @@ -20,7 +20,6 @@
package org.apache.sysml.runtime.controlprogram.paramserv;

import java.util.concurrent.Callable;
import java.util.stream.IntStream;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
Expand All @@ -30,10 +29,7 @@
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysml.runtime.functionobjects.Plus;
import org.apache.sysml.runtime.instructions.cp.ListObject;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
import org.apache.sysml.utils.Statistics;

public class LocalPSWorker extends PSWorker implements Callable<Void> {
Expand Down Expand Up @@ -84,13 +80,12 @@ private void computeEpoch(long dataSize, int totalIter) {
ListObject gradients = computeGradients(dataSize, totalIter, i, j);

// Accumulate the intermediate gradients
accGradients = (accGradients==null) ?
ParamservUtils.copyList(gradients) :
accrueGradients(accGradients, gradients);
accGradients = ParamservUtils.accrueGradients(accGradients, gradients);

// Update the local model with gradients
if( j < totalIter - 1 )
params = updateModel(params, gradients, i, j, totalIter);
ParamservUtils.cleanupListObject(gradients);
}

// Push the gradients to ps
Expand Down Expand Up @@ -193,14 +188,4 @@ private ListObject computeGradients(long dataSize, int totalIter, int i, int j)
return gradients;
}

private ListObject accrueGradients(ListObject accGradients, ListObject gradients) {
IntStream.range(0, accGradients.getLength()).forEach(i -> {
MatrixBlock mb1 = ((MatrixObject) accGradients.getData().get(i)).acquireRead();
MatrixBlock mb2 = ((MatrixObject) gradients.getData().get(i)).acquireRead();
mb1.binaryOperationsInPlace(new BinaryOperator(Plus.getPlusFnObject()), mb2);
((MatrixObject) accGradients.getData().get(i)).release();
((MatrixObject) gradients.getData().get(i)).release();
});
return accGradients;
}
}
Expand Up @@ -49,7 +49,8 @@

public abstract class ParamServer
{
protected final Log LOG = LogFactory.getLog(ParamServer.class.getName());
protected static final Log LOG = LogFactory.getLog(ParamServer.class.getName());
protected static final boolean ACCRUE_BSP_GRADIENTS = true;

// worker input queues and global model
protected final Map<Integer, BlockingQueue<ListObject>> _modelMap;
Expand All @@ -61,6 +62,7 @@ public abstract class ParamServer
private final FunctionCallCPInstruction _inst;
private final String _outputName;
private final boolean[] _finishedStates; // Workers' finished states
private ListObject _accGradients = null;

protected ParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) {
// init worker queues and global model
Expand Down Expand Up @@ -126,17 +128,25 @@ protected synchronized void updateGlobalModel(int workerID, ListObject gradients
gradients.getDataSize() / 1024, workerID));
}

// Update and redistribute the model
Timing tAgg = DMLScript.STATISTICS ? new Timing(true) : null;
_model = updateLocalModel(_ec, gradients, _model);
if (DMLScript.STATISTICS)
Statistics.accPSAggregationTime((long) tAgg.stop());

// Redistribute model according to update type
switch(_updateType) {
case BSP: {
setFinishedState(workerID);

// Accumulate the intermediate gradients
if( ACCRUE_BSP_GRADIENTS )
_accGradients = ParamservUtils.accrueGradients(
_accGradients, gradients, true);
else
updateGlobalModel(gradients);
ParamservUtils.cleanupListObject(gradients);

if (allFinished()) {
// Update the global model with accrued gradients
if( ACCRUE_BSP_GRADIENTS ) {
updateGlobalModel(_accGradients);
_accGradients = null;
}

// Broadcast the updated model
resetFinishedStates();
broadcastModel();
Expand All @@ -146,6 +156,7 @@ protected synchronized void updateGlobalModel(int workerID, ListObject gradients
break;
}
case ASP: {
updateGlobalModel(gradients);
broadcastModel(workerID);
break;
}
Expand All @@ -158,6 +169,13 @@ protected synchronized void updateGlobalModel(int workerID, ListObject gradients
}
}

private void updateGlobalModel(ListObject gradients) {
Timing tAgg = DMLScript.STATISTICS ? new Timing(true) : null;
_model = updateLocalModel(_ec, gradients, _model);
if (DMLScript.STATISTICS)
Statistics.accPSAggregationTime((long) tAgg.stop());
}

/**
* A service method for updating model with gradients
*
Expand Down
Expand Up @@ -50,13 +50,15 @@
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter;
import org.apache.sysml.runtime.functionobjects.Plus;
import org.apache.sysml.runtime.instructions.cp.Data;
import org.apache.sysml.runtime.instructions.cp.ListObject;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.MetaDataFormat;
import org.apache.sysml.runtime.matrix.data.InputInfo;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.OutputInfo;
import org.apache.sysml.runtime.matrix.operators.BinaryOperator;

public class ParamservUtils {

Expand Down Expand Up @@ -88,6 +90,10 @@ public static ListObject copyList(ListObject lo) {

public static void cleanupListObject(ExecutionContext ec, String lName) {
ListObject lo = (ListObject) ec.removeVariable(lName);
cleanupListObject(lo);
}

public static void cleanupListObject(ListObject lo) {
lo.getData().forEach(ParamservUtils::cleanupData);
}

Expand Down Expand Up @@ -258,4 +264,22 @@ private static FunctionProgramBlock getFunctionBlock(ExecutionContext ec, String
String fname = cfn[1];
return ec.getProgram().getFunctionProgramBlock(ns, fname);
}

public static ListObject accrueGradients(ListObject accGradients, ListObject gradients) {
return accrueGradients(accGradients, gradients, false);
}

public static ListObject accrueGradients(ListObject accGradients, ListObject gradients, boolean par) {
if (accGradients == null)
return ParamservUtils.copyList(gradients);
IntStream range = IntStream.range(0, accGradients.getLength());
(par ? range.parallel() : range).forEach(i -> {
MatrixBlock mb1 = ((MatrixObject) accGradients.getData().get(i)).acquireRead();
MatrixBlock mb2 = ((MatrixObject) gradients.getData().get(i)).acquireRead();
mb1.binaryOperationsInPlace(new BinaryOperator(Plus.getPlusFnObject()), mb2);
((MatrixObject) accGradients.getData().get(i)).release();
((MatrixObject) gradients.getData().get(i)).release();
});
return accGradients;
}
}

0 comments on commit 63a1e2a

Please sign in to comment.