SameDiff: Listener changes and training api update (#99)
* example api

Signed-off-by: Ryan Nett <rnett@skymind.io>

* Lambda based evaluation

Signed-off-by: Ryan Nett <rnett@skymind.io>

* lambda test

Signed-off-by: Ryan Nett <rnett@skymind.io>

* fixes

Signed-off-by: Ryan Nett <rnett@skymind.io>

* partial fixes, use get-variable listener framework, example EvaluationListener

Signed-off-by: Ryan Nett <rnett@skymind.io>

* javadoc fix and newInstance implementations

Signed-off-by: Ryan Nett <rnett@skymind.io>

* fit and evaluate methods with validation data (for fit) and listeners

Signed-off-by: Ryan Nett <rnett@skymind.io>

* output method overloads + listener args

Signed-off-by: Ryan Nett <rnett@skymind.io>

* history and evaluation helpers

Signed-off-by: Ryan Nett <rnett@skymind.io>

* fixes

Signed-off-by: Ryan Nett <rnett@skymind.io>

* more fixes

Signed-off-by: Ryan Nett <rnett@skymind.io>

* FitConfig and added getters and setters

Signed-off-by: Ryan Nett <rnett@skymind.io>

* javadocs

Signed-off-by: Ryan Nett <rnett@skymind.io>

* fixes, javadoc, added activations to history, added latest activation listener

Signed-off-by: Ryan Nett <rnett@skymind.io>

* fixes, start of tests

Signed-off-by: Ryan Nett <rnett@skymind.io>

* fixes and updates

Signed-off-by: Ryan Nett <rnett@skymind.io>

* newInstance fixes, tests

Signed-off-by: Ryan Nett <rnett@skymind.io>

* test fixes

Signed-off-by: Ryan Nett <rnett@skymind.io>

* javadocs, getters with SDVariable overrides, CustomEvaluation fix

Signed-off-by: Ryan Nett <rnett@skymind.io>

* more operation config classes (evaluation, output, exec/single batch output), fix custom eval tests

Signed-off-by: Ryan Nett <rnett@skymind.io>

* merge fixes

Signed-off-by: Ryan Nett <rnett@skymind.io>

* fix

Signed-off-by: Ryan Nett <rnett@skymind.io>

* fixes, most old fit/evaluate/output methods use the builders

Signed-off-by: Ryan Nett <rnett@skymind.io>

* fixes

Signed-off-by: Ryan Nett <rnett@skymind.io>

* numerous fixes/cleanup

Signed-off-by: Ryan Nett <rnett@skymind.io>

* fixes

Signed-off-by: Ryan Nett <rnett@skymind.io>

* javadoc

Signed-off-by: Ryan Nett <rnett@skymind.io>

* Polish round 1

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Round 2

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Formatting + round 3

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Round 4

Signed-off-by: AlexDBlack <blacka101@gmail.com>
Ryan Nett authored and AlexDBlack committed Aug 10, 2019
1 parent 6ed0321 commit 11bddb3
Showing 55 changed files with 5,159 additions and 767 deletions.
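Taken together, the commits above reshape the SameDiff training API around builder-style operation config classes: fit, evaluate, and output each get a config object (FitConfig is named explicitly, alongside evaluation/output/single-batch-output configs), fit accepts validation data and listeners, and results come back as history and evaluation helpers. The sketch below illustrates the kind of workflow this implies. It is a minimal sketch only: the fluent method names (train, validate, listeners, exec), the History type and its package, and the evalListener variable are assumptions inferred from the commit messages, not code shown on this page.

import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.listeners.records.History;   // package assumed
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

public class TrainingApiSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        // ... define placeholders, variables, a loss, and a TrainingConfig here ...

        DataSetIterator trainData = /* training set */ null;
        DataSetIterator validData = /* validation set; fit now accepts validation data */ null;
        Listener evalListener = /* e.g. the example EvaluationListener mentioned above */ null;

        History history = sd.fit()         // entry point returning a FitConfig builder (assumed)
                .train(trainData, 10)      // train for 10 epochs
                .validate(validData, 1)    // evaluate on the validation set every epoch
                .listeners(evalListener)   // listener callbacks during training
                .exec();                   // run training and collect history + evaluations
        System.out.println(history);
    }
}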
RegressionScoreFunction.java
@@ -21,12 +21,13 @@
 import org.deeplearning4j.nn.graph.ComputationGraph;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.nd4j.evaluation.regression.RegressionEvaluation;
+import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
 import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
 
 /**
  * Score function for regression (including multi-label regression) for a MultiLayerNetwork or ComputationGraph
- * on a test set. Supports all regression metrics: {@link RegressionEvaluation.Metric}
+ * on a test set. Supports all regression metrics: {@link Metric}
  *
  * @author Alex Black
  */
@@ -35,13 +36,13 @@
 @NoArgsConstructor(access = AccessLevel.PROTECTED) //For JSON
 public class RegressionScoreFunction extends BaseNetScoreFunction {
 
-    protected RegressionEvaluation.Metric metric;
+    protected Metric metric;
 
     public RegressionScoreFunction(@NonNull org.deeplearning4j.eval.RegressionEvaluation.Metric metric) {
         this(metric.toNd4j());
     }
 
-    public RegressionScoreFunction(@NonNull RegressionEvaluation.Metric metric) {
+    public RegressionScoreFunction(@NonNull Metric metric) {
         this.metric = metric;
     }
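The change in this file (and in the parallel files below) is a mechanical readability refactor: RegressionEvaluation.Metric is imported directly as Metric, so declarations and javadoc references shrink while every call site keeps compiling unchanged. A short fragment for illustration, using only the constructor shown above:

import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;

// Exactly the same enum constant as before the refactor; only the reference is shorter.
RegressionScoreFunction scoreFunction = new RegressionScoreFunction(Metric.MSE);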
Early stopping tests (MultiLayerNetwork)
@@ -51,7 +51,7 @@
 import org.junit.rules.TemporaryFolder;
 import org.nd4j.evaluation.classification.Evaluation;
 import org.nd4j.evaluation.classification.ROCBinary;
-import org.nd4j.evaluation.regression.RegressionEvaluation;
+import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
 import org.nd4j.linalg.activations.Activation;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -107,7 +107,7 @@ public void testEarlyStoppingIris() {
                 min = false;
                 break;
             case 3:
-                sc = new RegressionScoreCalculator(RegressionEvaluation.Metric.MSE, irisIter);
+                sc = new RegressionScoreCalculator(Metric.MSE, irisIter);
                 min = true;
                 break;
             case 4:
@@ -561,8 +561,8 @@ public void onCompletion(EarlyStoppingResult esResult) {
     @Test
     public void testRegressionScoreFunctionSimple() throws Exception {
 
-        for(RegressionEvaluation.Metric metric : new RegressionEvaluation.Metric[]{RegressionEvaluation.Metric.MSE,
-                RegressionEvaluation.Metric.MAE}) {
+        for(Metric metric : new Metric[]{Metric.MSE,
+                Metric.MAE}) {
             log.info("Metric: " + metric);
 
             MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
@@ -604,8 +604,8 @@ public void testRegressionScoreFunctionSimple() throws Exception {
     @Test
     public void testAEScoreFunctionSimple() throws Exception {
 
-        for(RegressionEvaluation.Metric metric : new RegressionEvaluation.Metric[]{RegressionEvaluation.Metric.MSE,
-                RegressionEvaluation.Metric.MAE}) {
+        for(Metric metric : new Metric[]{Metric.MSE,
+                Metric.MAE}) {
             log.info("Metric: " + metric);
 
             MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
@@ -647,8 +647,8 @@ public void testAEScoreFunctionSimple() throws Exception {
     @Test
     public void testVAEScoreFunctionSimple() throws Exception {
 
-        for(RegressionEvaluation.Metric metric : new RegressionEvaluation.Metric[]{RegressionEvaluation.Metric.MSE,
-                RegressionEvaluation.Metric.MAE}) {
+        for(Metric metric : new Metric[]{Metric.MSE,
+                Metric.MAE}) {
             log.info("Metric: " + metric);
 
             MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
Early stopping tests (ComputationGraph)
@@ -43,7 +43,7 @@
 import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
 import org.junit.Test;
 import org.nd4j.evaluation.classification.Evaluation;
-import org.nd4j.evaluation.regression.RegressionEvaluation;
+import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
 import org.nd4j.linalg.activations.Activation;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.dataset.DataSet;
@@ -289,8 +289,8 @@ public void onCompletion(EarlyStoppingResult esResult) {
     @Test
     public void testRegressionScoreFunctionSimple() throws Exception {
 
-        for(RegressionEvaluation.Metric metric : new RegressionEvaluation.Metric[]{RegressionEvaluation.Metric.MSE,
-                RegressionEvaluation.Metric.MAE}) {
+        for(Metric metric : new Metric[]{Metric.MSE,
+                Metric.MAE}) {
             log.info("Metric: " + metric);
 
             ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
@@ -335,8 +335,8 @@ public void testRegressionScoreFunctionSimple() throws Exception {
     public void testAEScoreFunctionSimple() throws Exception {
         DataType dt = Nd4j.defaultFloatingPointType();
 
-        for(RegressionEvaluation.Metric metric : new RegressionEvaluation.Metric[]{RegressionEvaluation.Metric.MSE,
-                RegressionEvaluation.Metric.MAE}) {
+        for(Metric metric : new Metric[]{Metric.MSE,
+                Metric.MAE}) {
             log.info("Metric: " + metric);
 
             ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
@@ -380,8 +380,8 @@ public void testAEScoreFunctionSimple() throws Exception {
     @Test
     public void testVAEScoreFunctionSimple() throws Exception {
 
-        for(RegressionEvaluation.Metric metric : new RegressionEvaluation.Metric[]{RegressionEvaluation.Metric.MSE,
-                RegressionEvaluation.Metric.MAE}) {
+        for(Metric metric : new Metric[]{Metric.MSE,
+                Metric.MAE}) {
             log.info("Metric: " + metric);
 
             ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
AutoencoderScoreCalculator.java
@@ -23,23 +23,24 @@
 import org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.nd4j.evaluation.regression.RegressionEvaluation;
+import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
 import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
 
 /**
  * Score function for a MultiLayerNetwork or ComputationGraph with a single
  * {@link org.deeplearning4j.nn.conf.layers.AutoEncoder} layer.
- * Calculates the specified {@link RegressionEvaluation.Metric} on the layer's reconstructions.
+ * Calculates the specified {@link Metric} on the layer's reconstructions.
  *
  * @author Alex Black
  */
 public class AutoencoderScoreCalculator extends BaseScoreCalculator<Model> {
 
-    protected final RegressionEvaluation.Metric metric;
+    protected final Metric metric;
     protected RegressionEvaluation evaluation;
 
-    public AutoencoderScoreCalculator(RegressionEvaluation.Metric metric, DataSetIterator iterator){
+    public AutoencoderScoreCalculator(Metric metric, DataSetIterator iterator){
         super(iterator);
         this.metric = metric;
     }
RegressionScoreCalculator.java
@@ -19,19 +19,20 @@
 import org.deeplearning4j.earlystopping.scorecalc.base.BaseIEvaluationScoreCalculator;
 import org.deeplearning4j.nn.api.Model;
 import org.nd4j.evaluation.regression.RegressionEvaluation;
+import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
 
 /**
  * Calculate the regression score of the network (MultiLayerNetwork or ComputationGraph) on a test set, using the
- * specified regression metric - {@link RegressionEvaluation.Metric}
+ * specified regression metric - {@link Metric}
  *
  * @author Alex Black
 */
 public class RegressionScoreCalculator extends BaseIEvaluationScoreCalculator<Model, RegressionEvaluation> {
 
-    protected final RegressionEvaluation.Metric metric;
+    protected final Metric metric;
 
-    public RegressionScoreCalculator(RegressionEvaluation.Metric metric, DataSetIterator iterator){
+    public RegressionScoreCalculator(Metric metric, DataSetIterator iterator){
         super(iterator);
         this.metric = metric;
     }
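As in the tests earlier in this commit, the calculator takes the metric plus a held-out iterator and plugs into early stopping as the score to track; MSE and MAE are smaller-is-better, so the early stopping condition minimizes the score. A minimal fragment mirroring the test usage (irisIter is the held-out DataSetIterator from testEarlyStoppingIris above):

import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;

// Score the network on held-out data by MSE; lower is better, so minimize.
RegressionScoreCalculator sc = new RegressionScoreCalculator(Metric.MSE, irisIter);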
VAEReconErrorScoreCalculator.java
@@ -23,6 +23,7 @@
 import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.nd4j.evaluation.regression.RegressionEvaluation;
+import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
 import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
@@ -35,7 +36,7 @@
  */
 public class VAEReconErrorScoreCalculator extends BaseScoreCalculator<Model> {
 
-    protected final RegressionEvaluation.Metric metric;
+    protected final Metric metric;
     protected RegressionEvaluation evaluation;
 
     /**
@@ -44,7 +45,7 @@ public class VAEReconErrorScoreCalculator extends BaseScoreCalculator<Model> {
      * @param metric
      * @param iterator
      */
-    public VAEReconErrorScoreCalculator(RegressionEvaluation.Metric metric, DataSetIterator iterator) {
+    public VAEReconErrorScoreCalculator(Metric metric, DataSetIterator iterator) {
         super(iterator);
         this.metric = metric;
     }
At.java
@@ -20,6 +20,22 @@ public class At {
     private int iteration;
     private int trainingThreadNum;
     private long javaThreadNum;
+    private Operation operation;
+
+    /**
+     * @return A new instance with everything set to 0, and operation set to INFERENCE
+     */
+    public static At defaultAt(){
+        return new At(0, 0, 0, 0, Operation.INFERENCE);
+    }
+
+    /**
+     * @param op Operation
+     * @return A new instance with everything set to 0, except for the specified operation
+     */
+    public static At defaultAt(@NonNull Operation op){
+        return new At(0, 0, 0, 0, op);
+    }
 
     /**
      * @return The current training epoch
@@ -48,4 +64,26 @@ public int trainingThreadNum(){
     public long javaThreadNum(){
         return javaThreadNum;
     }
+
+    /**
+     * @return The current operation
+     */
+    public Operation operation(){
+        return operation;
+    }
+
+    /**
+     * @return A copy of the current At instance
+     */
+    public At copy(){
+        return new At(epoch, iteration, trainingThreadNum, javaThreadNum, operation);
+    }
+
+    /**
+     * @param operation Operation to set in the new instance
+     * @return A copy of the current instance, but with the specified operation
+     */
+    public At copy(Operation operation){
+        return new At(epoch, iteration, trainingThreadNum, javaThreadNum, operation);
+    }
 }
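The new At helpers are self-contained enough to exercise directly. This short sketch uses only what the diff above shows; Operation.INFERENCE is the one enum value visible here, and any other Operation values are whatever the enum in this commit defines:

// Zeroed-out default: epoch, iteration, and thread fields are 0, operation is INFERENCE
At at = At.defaultAt();

// Same zeroed fields, but tagged with an explicit operation
At tagged = At.defaultAt(Operation.INFERENCE);

// copy() duplicates every field; copy(Operation) duplicates and swaps the operation
At snapshot = at.copy();
At retagged = at.copy(tagged.operation());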