deeplearning4j#8765 CompGraph+MDS fix for SharedTrainingMaster
Signed-off-by: Alex Black <blacka101@gmail.com>
AlexDBlack committed Mar 26, 2020
1 parent 5437022 commit 0bda5e3
Showing 6 changed files with 180 additions and 126 deletions.
@@ -38,16 +38,22 @@
<artifactId>nd4j-aeron</artifactId>
<version>${nd4j.version}</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-parameter-server-node_2.11</artifactId>
<version>${nd4j.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>dl4j-spark_2.11</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-parameter-server-node_2.11</artifactId>
<version>${nd4j.version}</version>
<exclusions>
<exclusion>
<groupId>net.jpountz.lz4</groupId>
<artifactId>lz4</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
@@ -23,6 +23,7 @@

import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

/**
* This MultiDataSetIterator implementation accumulates MultiDataSets from different Spark executors, with respect to thread/device affinity
@@ -32,14 +33,16 @@
public class VirtualMultiDataSetIterator implements ParallelMultiDataSetIterator {

protected final List<Iterator<MultiDataSet>> iterators;
protected final AtomicInteger position;

public VirtualMultiDataSetIterator(@NonNull List<Iterator<MultiDataSet>> iterators) {
this.iterators = iterators;
this.position = new AtomicInteger(0);
}

@Override
public MultiDataSet next(int num) {
return null;
return next();
}

@Override
@@ -59,27 +62,34 @@ public boolean resetSupported() {

@Override
public boolean asyncSupported() {
return false;
return true;
}

@Override
public void reset() {

throw new UnsupportedOperationException();
}

@Override
public boolean hasNext() {
return false;
// Either we're not on the last iterator yet, or we are on the last one and it still has elements
boolean ret = position.get() < iterators.size() - 1
|| (position.get() < iterators.size() && iterators.get(position.get()).hasNext());
return ret;
}

@Override
public MultiDataSet next() {
return null;
// TODO: this solution isn't ideal: it assumes all underlying iterators are non-empty. Handling empty iterators here would be better.
if (!iterators.get(position.get()).hasNext())
position.getAndIncrement();

return iterators.get(position.get()).next();
}

@Override
public void remove() {

// no-op
}

@Override
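As an aside, a minimal sketch of how the fixed iterator is expected to behave: it chains across a list of per-executor iterators instead of returning false/null as the old stubs did. The class and variable names in the sketch are illustrative only, and the package in the VirtualMultiDataSetIterator import is an assumption about the module layout.

import org.deeplearning4j.spark.parameterserver.iterators.VirtualMultiDataSetIterator; // package assumed
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;

import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;

public class VirtualMultiDataSetIteratorSketch {
    public static void main(String[] args) {
        // Two tiny "per-executor" iterators, each contributing one MultiDataSet
        MultiDataSet a = new org.nd4j.linalg.dataset.MultiDataSet(Nd4j.rand(4, 3), Nd4j.rand(4, 2));
        MultiDataSet b = new org.nd4j.linalg.dataset.MultiDataSet(Nd4j.rand(4, 3), Nd4j.rand(4, 2));

        List<Iterator<MultiDataSet>> perExecutor = Arrays.asList(
                Collections.singletonList(a).iterator(),
                Collections.singletonList(b).iterator());

        VirtualMultiDataSetIterator it = new VirtualMultiDataSetIterator(perExecutor);

        // With the fix, hasNext()/next() walk across all attached iterators
        // instead of returning false/null immediately
        int count = 0;
        while (it.hasNext()) {
            it.next();
            count++;
        }
        System.out.println("Consumed " + count + " MultiDataSets"); // expected: 2
    }
}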
@@ -109,6 +109,7 @@ protected void init() {

// now we're creating DataSetIterators, to feed ParallelWrapper
iteratorDS = new VirtualDataSetIterator(iteratorsDS);
iteratorMDS = new VirtualMultiDataSetIterator(iteratorsMDS);
}

public static synchronized SharedTrainingWrapper getInstance(long id) {
@@ -447,17 +448,19 @@ public INDArray getUpdaterParameters() {
throw new DL4JInvalidConfigException("No iterators were defined for training");

try {
while((iteratorDS != null && iteratorDS.hasNext()) || (iteratorMDS != null && iteratorMDS.hasNext())) {
boolean dsNext;
boolean mdsNext;
while((dsNext = iteratorDS != null && iteratorDS.hasNext()) || (mdsNext = iteratorMDS != null && iteratorMDS.hasNext())) {
//Loop as a guard against concurrent modifications and race conditions

if (wrapper != null) {
if (iteratorDS != null)
if (dsNext)
wrapper.fit(iteratorDS);
else
wrapper.fit(iteratorMDS);
} else {
// if wrapper is null, we're fitting a standalone model
if (iteratorDS != null) {
if (dsNext) {
if (model instanceof ComputationGraph) {
((ComputationGraph) originalModel).fit(iteratorDS);
} else if (model instanceof MultiLayerNetwork) {
@@ -472,7 +475,8 @@ public INDArray getUpdaterParameters() {
}
}

consumer.getUpdatesQueue().purge();
if(consumer != null)
consumer.getUpdatesQueue().purge();
}
} catch (Throwable t){
log.warn("Exception encountered during fit operation", t);
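Similarly, a standalone sketch (plain java.util iterators and a hypothetical class name, not DL4J types) of the dispatch pattern this hunk introduces: hasNext() is evaluated once per pass into dsNext/mdsNext, and the branch is chosen by which iterator actually has data rather than by which one is non-null, since init() now creates both iterators.

import java.util.Collections;
import java.util.Iterator;

public class FitDispatchSketch {

    static void fit(Iterator<String> ds, Iterator<String> mds) {
        boolean dsNext;
        boolean mdsNext;
        // Evaluate hasNext() once per pass and branch on the captured result,
        // mirroring the dsNext/mdsNext flags added to SharedTrainingWrapper
        while ((dsNext = ds != null && ds.hasNext()) || (mdsNext = mds != null && mds.hasNext())) {
            if (dsNext) {
                System.out.println("fit on DataSet path: " + ds.next());
            } else {
                System.out.println("fit on MultiDataSet path: " + mds.next());
            }
        }
    }

    public static void main(String[] args) {
        // DataSet path is non-null but empty, MultiDataSet path has data:
        // a plain "ds != null" check would have taken the wrong branch here
        Iterator<String> emptyDs = Collections.<String>emptyList().iterator();
        Iterator<String> mds = Collections.singletonList("mds-batch-0").iterator();
        fit(emptyDs, mds);
    }
}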
