Skip to content

Commit

Permalink
[SYSTEMML-2446] Fix paramserv model list cleanup for partial updates
Browse files: browse the repository at this point in the history
Closes #802.
  • Loading branch information
EdgarLGB authored and mboehm7 committed Jul 19, 2018
1 parent 9593b7f commit bca1f1c
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 12 deletions.
Expand Up @@ -85,7 +85,7 @@ private void computeEpoch(long dataSize, int totalIter) {
// Update the local model with gradients
if( j < totalIter - 1 )
params = updateModel(params, gradients, i, j, totalIter);
ParamservUtils.cleanupListObject(gradients);
ParamservUtils.cleanupListObject(_ec, gradients);
}

// Push the gradients to ps
Expand Down Expand Up @@ -183,8 +183,8 @@ private ListObject computeGradients(long dataSize, int totalIter, int i, int j)
// Get the gradients
ListObject gradients = (ListObject) _ec.getVariable(_output.getName());

ParamservUtils.cleanupData(bFeatures);
ParamservUtils.cleanupData(bLabels);
ParamservUtils.cleanupData(_ec, bFeatures);
ParamservUtils.cleanupData(_ec, bLabels);
return gradients;
}
}
Expand Up @@ -138,7 +138,7 @@ protected synchronized void updateGlobalModel(int workerID, ListObject gradients
_accGradients, gradients, true);
else
updateGlobalModel(gradients);
ParamservUtils.cleanupListObject(gradients);
ParamservUtils.cleanupListObject(_ec, gradients);

if (allFinished()) {
// Update the global model with accrued gradients
Expand Down Expand Up @@ -192,11 +192,11 @@ protected ListObject updateLocalModel(ExecutionContext ec, ListObject gradients,
// Invoke the aggregate function
_inst.processInstruction(ec);

// Get the output
// Get the new model
ListObject newModel = (ListObject) ec.getVariable(_outputName);

// Update the model with the new output
ParamservUtils.cleanupListObject(ec, Statement.PS_MODEL);
// Clean up the list according to the data referencing status
ParamservUtils.cleanupListObject(ec, Statement.PS_MODEL, newModel.getStatus());
ParamservUtils.cleanupListObject(ec, Statement.PS_GRADIENTS);
return newModel;
}
Expand Down
Expand Up @@ -101,21 +101,45 @@ public static ListObject copyList(ListObject lo) {
return new ListObject(newData, lo.getNames());
}

/**
* Removes the list variable named {@code lName} from the execution context and
* cleans up the removed list object according to its own data status
* (the array returned by {@code lo.getStatus()}).
*
* @param ec execution context the variable is removed from
* @param lName list var name
*/
public static void cleanupListObject(ExecutionContext ec, String lName) {
// Detach the variable from the context; the returned object is cast to ListObject
ListObject lo = (ListObject) ec.removeVariable(lName);
// NOTE(review): both delegating calls below appear in this view of the commit,
// presumably the pre-change and post-change line of the diff -- confirm against the repository
cleanupListObject(lo);
cleanupListObject(ec, lo, lo.getStatus());
}

/**
 * Removes the named list variable from the execution context and cleans up the
 * removed list object as directed by the given data status array
 * (an entry whose status is false must not be removed).
 *
 * @param ec execution context
 * @param lName list var name
 * @param status data status
 */
public static void cleanupListObject(ExecutionContext ec, String lName, boolean[] status) {
    // Detach the variable and hand the removed list straight to the status-aware cleanup
    cleanupListObject(ec, (ListObject) ec.removeVariable(lName), status);
}

public static void cleanupListObject(ListObject lo) {
lo.getData().forEach(ParamservUtils::cleanupData);
// NOTE(review): the signature and statement above and the signature below both appear in this
// view of the commit, presumably as pre-change and post-change diff lines -- confirm against the repository
public static void cleanupListObject(ExecutionContext ec, ListObject lo) {
// Delegate to the status-aware variant, passing the list's own status array
cleanupListObject(ec, lo, lo.getStatus());
}

/**
 * Cleans up the entries of the given list object, honoring the given data status array:
 * if the array is null every entry is cleaned up, otherwise only the entries whose
 * status is true.
 *
 * @param ec execution context
 * @param lo list object
 * @param status data status, may be null
 */
public static void cleanupListObject(ExecutionContext ec, ListObject lo, boolean[] status) {
    int pos = 0;
    while (pos < lo.getLength()) {
        // data ref by other object (status false) must not be cleaned up
        boolean removable = (status == null) || status[pos];
        if (removable)
            ParamservUtils.cleanupData(ec, lo.getData().get(pos));
        pos++;
    }
}

// NOTE(review): two signatures appear back to back in this view of the commit, presumably the
// pre-change and post-change diff lines -- confirm against the repository which body lines belong to which
public static void cleanupData(Data data) {
public static void cleanupData(ExecutionContext ec, Data data) {
// Only CacheableData instances are handled; any other data object is left untouched
if (!(data instanceof CacheableData))
return;
CacheableData<?> cd = (CacheableData<?>) data;
// Mark the object as eligible for cleanup before clearing it
cd.enableCleanup(true);
cd.clearData();
// Hand the object to the execution context for cleanup
ec.cleanupCacheableData(cd);
}

public static MatrixObject newMatrixObject(MatrixBlock mb) {
Expand Down

0 comments on commit bca1f1c

Please sign in to comment.