Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYSTEMML-2416] Use synchronized method instead of single thread pool #790

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -19,8 +19,6 @@

package org.apache.sysml.runtime.controlprogram.paramserv;

import java.util.concurrent.ExecutionException;

import org.apache.sysml.parser.Statement;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
Expand All @@ -35,16 +33,7 @@ public LocalParamServer(ListObject model, String aggFunc, Statement.PSUpdateType

@Override
public void push(int workerID, ListObject gradients) {
try {
_gradientsQueue.put(new Gradient(workerID, gradients));
} catch (InterruptedException e) {
throw new DMLRuntimeException(e);
}
try {
launchService();
} catch (ExecutionException | InterruptedException e) {
throw new DMLRuntimeException("Aggregate service: some error occurred: ", e);
}
launchService(new Gradient(workerID, gradients));
}

@Override
Expand Down
Expand Up @@ -27,16 +27,10 @@
import java.util.Map;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.concurrent.BasicThreadFactory;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.api.DMLScript;
Expand All @@ -55,14 +49,11 @@

public abstract class ParamServer {

final BlockingQueue<Gradient> _gradientsQueue;
final Map<Integer, BlockingQueue<ListObject>> _modelMap;
private final AggregationService _aggService;
private final ExecutorService _es;
private ListObject _model;

ParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) {
_gradientsQueue = new LinkedBlockingDeque<>();
_modelMap = new HashMap<>(workerNum);
IntStream.range(0, workerNum).forEach(i -> {
// Create a single element blocking queue for workers to receive the broadcasted model
Expand All @@ -76,21 +67,14 @@ public abstract class ParamServer {
catch (InterruptedException e) {
throw new DMLRuntimeException("Param server: failed to broadcast the initial model.", e);
}
BasicThreadFactory factory = new BasicThreadFactory.Builder()
.namingPattern("agg-service-pool-thread-%d").build();
_es = Executors.newSingleThreadExecutor(factory);
}

public abstract void push(int workerID, ListObject value);

public abstract Data pull(int workerID);

void launchService() throws ExecutionException, InterruptedException {
_es.submit(_aggService).get();
}

public void shutdown() {
_es.shutdownNow();
void launchService(Gradient gradient) {
_aggService.run(gradient);
}

public ListObject getResult() {
Expand All @@ -116,7 +100,7 @@ public Gradient(int workerID, ListObject gradients) {
/**
* Inner aggregation service which is for updating the model
*/
private class AggregationService implements Callable<Void> {
private class AggregationService {

protected final Log LOG = LogFactory.getLog(AggregationService.class.getName());

Expand Down Expand Up @@ -191,15 +175,8 @@ private void broadcastModel(int workerID) throws InterruptedException {
Statistics.accPSModelBroadcastTime((long) tBroad.stop());
}

@Override
public Void call() throws Exception {
public synchronized void run(Gradient grad) {
try {
Gradient grad;
try {
grad = _gradientsQueue.take();
} catch (InterruptedException e) {
throw new DMLRuntimeException("Aggregation service: error when waiting for the coming gradients.", e);
}
if (LOG.isDebugEnabled()) {
LOG.debug(String.format("Successfully pulled the gradients [size:%d kb] of worker_%d.",
grad._gradients.getDataSize() / 1024, grad._workerID));
Expand Down Expand Up @@ -235,7 +212,6 @@ public Void call() throws Exception {
catch (Exception e) {
throw new DMLRuntimeException("Aggregation service failed: ", e);
}
return null;
}

private ListObject updateModel(ListObject gradients, ListObject model) {
Expand All @@ -244,6 +220,11 @@ private ListObject updateModel(ListObject gradients, ListObject model) {

/**
* A service method for updating model with gradients
*
* @param ec execution context
* @param gradients list of gradients
* @param model old model
* @return new model
*/
private ListObject updateModel(ExecutionContext ec, ListObject gradients, ListObject model) {
// Populate the variables table with the gradients and model
Expand Down
Expand Up @@ -160,8 +160,6 @@ public void processInstruction(ExecutionContext ec) {
throw new DMLRuntimeException("ParamservBuiltinCPInstruction: some error occurred: ", e);
} finally {
es.shutdownNow();
// Should shutdown the thread pool in param server
ps.shutdown();
}
}

Expand Down