cleanup SDRNN and rnn ops (#238)
Signed-off-by: Ryan Nett <rnett@skymind.io>
Ryan Nett authored and AlexDBlack committed Sep 5, 2019
1 parent 7d85775 commit 79867f5
Showing 23 changed files with 931 additions and 265 deletions.
@@ -34,7 +34,7 @@ CUSTOM_OP_IMPL(sruCell, 4, 2, false, 0, 0) {
auto xt = INPUT_VARIABLE(0); // input [bS x inSize], bS - batch size, inSize - number of features
auto ct_1 = INPUT_VARIABLE(1); // previous cell state ct [bS x inSize], that is at previous time step t-1
auto w = INPUT_VARIABLE(2); // weights [inSize x 3*inSize]
auto b = INPUT_VARIABLE(3); // biases [1 x 2*inSize]
auto b = INPUT_VARIABLE(3); // biases [2*inSize]

auto ht = OUTPUT_VARIABLE(0); // current cell output [bS x inSize], that is at current time step t
auto ct = OUTPUT_VARIABLE(1); // current cell state [bS x inSize], that is at current time step t
@@ -6511,4 +6511,22 @@ public String generateNewVarName(String base, int argIndex, boolean existingOp)
public String generateNewVarName(String base, int argIndex) {
return generateNewVarName(base, argIndex, true);
}

/**
* Returns an unused variable name of the format &lt;base&gt;_#.
*
* Intended to be used for custom variables (like weights); arguments and op outputs should use {@link #generateNewVarName(String, int)}.
*/
public String generateDistinctCustomVariableName(String base){
if(!variables.containsKey(base))
return base;

int inc = 1;

while(variables.containsKey(base + "_" + inc)){
inc++;
}

return base + "_" + inc;
}
}
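
For illustration, a minimal usage sketch of the new helper (variable names are hypothetical; assumes a fresh SameDiff instance):

SameDiff sd = SameDiff.create();
sd.var("weights", DataType.FLOAT, 3, 4);                      // "weights" is now taken
String a = sd.generateDistinctCustomVariableName("weights");  // "weights" exists -> returns "weights_1"
String b = sd.generateDistinctCustomVariableName("bias");     // "bias" is unused -> returned as-is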
@@ -16,13 +16,23 @@

package org.nd4j.autodiff.samediff.ops;

import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.*;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.*;

import java.util.Arrays;
import java.util.List;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.GRUCellOutputs;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMCellOutputs;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMLayerOutputs;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.SRUCellOutputs;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.SRULayerOutputs;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights;
import org.nd4j.linalg.primitives.Pair;

/**
* SameDiff Recurrent Neural Network operations<br>
@@ -39,90 +49,163 @@ public SDRNN(SameDiff sameDiff) {


/**
* The gru cell
*
* @param configuration the configuration to use
* @return
* See {@link #gru(String, SDVariable, SDVariable, GRUWeights)}.
*/
public List<SDVariable> gru(GRUCellConfiguration configuration) {
GRUCell c = new GRUCell(sd, configuration);
return Arrays.asList(c.outputVariables());
public GRUCellOutputs gru(@NonNull SDVariable x, @NonNull SDVariable hLast, @NonNull GRUWeights weights) {
GRUCell c = new GRUCell(sd, x, hLast, weights);
return new GRUCellOutputs(c.outputVariables());
}

/**
* The gru cell
* The GRU cell. Does a single time step operation.
*
* @param baseName the base name for the gru cell
* @param configuration the configuration to use
* @return
* @param baseName The base name for the gru cell
* @param x Input, with shape [batchSize, inSize]
* @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits]
* @param weights The cell's weights.
* @return The cell's outputs.
*/
public List<SDVariable> gru(String baseName, GRUCellConfiguration configuration) {
GRUCell c = new GRUCell(sd, configuration);
return Arrays.asList(c.outputVariables(baseName));
public GRUCellOutputs gru(String baseName, @NonNull SDVariable x, @NonNull SDVariable hLast, @NonNull GRUWeights weights) {
GRUCell c = new GRUCell(sd, x, hLast, weights);
return new GRUCellOutputs(c.outputVariables(baseName));
}
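
For reference, a minimal sketch of the new gru signature (a sketch only: shapes follow the javadoc above, but the sd.rnn() entry point and the GRUWeights builder field names are assumptions, as they are not shown in this diff):

int batchSize = 32, inSize = 10, numUnits = 16;
SameDiff sd = SameDiff.create();
SDVariable x = sd.placeHolder("x", DataType.FLOAT, batchSize, inSize);
SDVariable hLast = sd.placeHolder("hLast", DataType.FLOAT, batchSize, numUnits);
GRUWeights weights = GRUWeights.builder()          // field names assumed
        .ruWeight(sd.var("ruW", DataType.FLOAT, inSize + numUnits, 2 * numUnits))
        .cWeight(sd.var("cW", DataType.FLOAT, inSize + numUnits, numUnits))
        .ruBias(sd.var("ruB", DataType.FLOAT, 2 * numUnits))
        .cBias(sd.var("cB", DataType.FLOAT, numUnits))
        .build();
GRUCellOutputs out = sd.rnn().gru("gru", x, hLast, weights);   // single time step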

/**
* See {@link #lstmCell(String, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)}.
*/
public LSTMCellOutputs lstmCell(@NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
LSTMWeights weights, LSTMConfiguration config){
LSTMBlockCell c = new LSTMBlockCell(sd, x, cLast, yLast, weights, config);
return new LSTMCellOutputs(c.outputVariables());
}

/**
* LSTM unit
* The LSTM cell. Does a single time step operation.
*
* @param baseName the base name for outputs
* @param configuration the configuration to use
* @return
* @param baseName The base name for the lstm cell
* @param x Input, with shape [batchSize, inSize]
* @param cLast Previous cell state, with shape [batchSize, numUnits]
* @param yLast Previous cell output, with shape [batchSize, numUnits]
* @param weights The cell's weights.
* @param config The cell's config.
* @return The cell's outputs.
*/
public SDVariable lstmCell(String baseName, LSTMCellConfiguration configuration) {
return new LSTMCell(sd, configuration).outputVariables(baseName)[0];
public LSTMCellOutputs lstmCell(String baseName, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
@NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){
LSTMBlockCell c = new LSTMBlockCell(sd, x, cLast, yLast, weights, config);
return new LSTMCellOutputs(c.outputVariables(baseName));
}
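
A corresponding sketch for the cell-level LSTM call. The LSTMConfiguration builder calls mirror the initFromTensorFlow code further down in this diff; the LSTMWeights builder field names are assumptions:

int batchSize = 32, inSize = 10, numUnits = 16;
SameDiff sd = SameDiff.create();
SDVariable x = sd.placeHolder("x", DataType.FLOAT, batchSize, inSize);
SDVariable cLast = sd.placeHolder("cLast", DataType.FLOAT, batchSize, numUnits);
SDVariable yLast = sd.placeHolder("yLast", DataType.FLOAT, batchSize, numUnits);
LSTMWeights weights = LSTMWeights.builder()        // field names assumed
        .weights(sd.var("w", DataType.FLOAT, inSize + numUnits, 4 * numUnits))
        .bias(sd.var("b", DataType.FLOAT, 4 * numUnits))
        .inputPeepholeWeights(sd.var("wci", DataType.FLOAT, numUnits))   // still required as op inputs,
        .forgetPeepholeWeights(sd.var("wcf", DataType.FLOAT, numUnits))  // even with peepHole(false)
        .outputPeepholeWeights(sd.var("wco", DataType.FLOAT, numUnits))
        .build();
LSTMConfiguration config = LSTMConfiguration.builder()
        .peepHole(false).forgetBias(1.0).clippingCellValue(0.0).build();
LSTMCellOutputs out = sd.rnn().lstmCell("lstm", x, cLast, yLast, weights, config);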

public List<SDVariable> lstmBlockCell(String name, LSTMBlockCellConfiguration configuration){
SDVariable[] v = new LSTMBlockCell(sd, configuration).outputVariables(name);
return Arrays.asList(v);
/**
* See {@link #lstmLayer(String, SDVariable, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)}
*/
public LSTMLayerOutputs lstmLayer(@NonNull SDVariable maxTSLength,
@NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
@NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){
LSTMLayer c = new LSTMLayer(sd, maxTSLength, x, cLast, yLast, weights, config);
return new LSTMLayerOutputs(c.outputVariables(), config.getDataFormat());
}

public List<SDVariable> lstmLayer(String name, LSTMConfiguration configuration){
SDVariable[] v = new LSTMLayer(sd, configuration).outputVariables(name);
return Arrays.asList(v);
/**
* See {@link #lstmLayer(String, SDVariable, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)}
*/
public LSTMLayerOutputs lstmLayer(int maxTSLength, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
@NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){
return lstmLayer(
sd.scalar("lstm_max_ts_length", maxTSLength),
x, cLast, yLast, weights, config);
}

/**
* Simple recurrent unit
*
* @param configuration the configuration for the sru
* @return
* See {@link #lstmLayer(String, SDVariable, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)}
*/
public SDVariable sru(SRUConfiguration configuration) {
return new SRU(sd, configuration).outputVariables()[0];
public LSTMLayerOutputs lstmLayer(String baseName, int maxTSLength, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
@NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){
if(baseName != null) {
return lstmLayer(baseName,
sd.scalar(sd.generateDistinctCustomVariableName(baseName + "_max_ts_length"), maxTSLength),
x, cLast, yLast, weights, config);
} else {
return lstmLayer(maxTSLength, x, cLast, yLast, weights, config);
}
}

/**
* Simple recurrent unit
* The LSTM layer. Does multiple time steps.
*
* Input shape depends on data format (in config):<br>
* TNS -> [timeSteps, batchSize, inSize]<br>
* NST -> [batchSize, inSize, timeSteps]<br>
* NTS -> [batchSize, timeSteps, inSize]<br>
*
* @param baseName the base name to use for output variables
* @param configuration the configuration for the sru
* @return
* @param baseName The base name for the lstm layer
* @param x Input, with shape dependent on the data format (in config).
* @param cLast Previous/initial cell state, with shape [batchSize, numUnits]
* @param yLast Previous/initial cell output, with shape [batchSize, numUnits]
* @param weights The layer's weights.
* @param config The layer's config.
* @return The layer's outputs.
*/
public LSTMLayerOutputs lstmLayer(String baseName, @NonNull SDVariable maxTSLength,
@NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
@NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){
LSTMLayer c = new LSTMLayer(sd, maxTSLength, x, cLast, yLast, weights, config);
return new LSTMLayerOutputs(c.outputVariables(baseName), config.getDataFormat());
}
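
A layer-level sketch over a full sequence, here using the TNS data format (the RnnDataFormat enum name and the LSTMWeights builder fields are assumptions; note that the int overload above materializes maxTSLength as a scalar variable, using generateDistinctCustomVariableName to avoid name clashes):

int timeSteps = 20, batchSize = 32, inSize = 10, numUnits = 16;
SameDiff sd = SameDiff.create();
SDVariable x = sd.placeHolder("x", DataType.FLOAT, timeSteps, batchSize, inSize);  // TNS layout
SDVariable cLast = sd.placeHolder("cLast", DataType.FLOAT, batchSize, numUnits);
SDVariable yLast = sd.placeHolder("yLast", DataType.FLOAT, batchSize, numUnits);
LSTMWeights weights = LSTMWeights.builder()        // field names assumed
        .weights(sd.var("w", DataType.FLOAT, inSize + numUnits, 4 * numUnits))
        .bias(sd.var("b", DataType.FLOAT, 4 * numUnits))
        .inputPeepholeWeights(sd.var("wci", DataType.FLOAT, numUnits))
        .forgetPeepholeWeights(sd.var("wcf", DataType.FLOAT, numUnits))
        .outputPeepholeWeights(sd.var("wco", DataType.FLOAT, numUnits))
        .build();
LSTMConfiguration config = LSTMConfiguration.builder()
        .dataFormat(RnnDataFormat.TNS)             // enum name assumed
        .peepHole(false).forgetBias(1.0).clippingCellValue(0.0)
        .build();
LSTMLayerOutputs out = sd.rnn().lstmLayer("lstm", timeSteps, x, cLast, yLast, weights, config);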

/**
* See {@link #sruCell(String, SDVariable, SDVariable, SRUWeights)}.
*/
public SDVariable sru(String baseName, SRUConfiguration configuration) {
return new SRU(sd, configuration).outputVariables(baseName)[0];
public SRUCellOutputs sruCell(@NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SRUWeights weights) {
return new SRUCellOutputs(new SRUCell(sd, x, cLast, weights).outputVariables());
}

/**
* An sru cell
* The SRU cell. Does a single time step operation.
*
* @param configuration the configuration for the sru cell
* @return
* @param baseName The base name for the sru cell
* @param x Input, with shape [batchSize, inSize]
* @param cLast Previous cell state, with shape [batchSize, inSize]
* @param weights The cell's weights.
* @return The cell's outputs.
*/
public SRUCellOutputs sruCell(String baseName, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SRUWeights weights) {
return new SRUCellOutputs(new SRUCell(sd, x, cLast, weights).outputVariables(baseName));
}
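
A sketch for the SRU cell, using the weight shapes documented in the sruCell op at the top of this diff (w: [inSize, 3*inSize], b: [2*inSize]); the SRUWeights builder field names are assumed:

int batchSize = 32, inSize = 10;
SameDiff sd = SameDiff.create();
SDVariable x = sd.placeHolder("x", DataType.FLOAT, batchSize, inSize);
SDVariable cLast = sd.placeHolder("cLast", DataType.FLOAT, batchSize, inSize);
SRUWeights weights = SRUWeights.builder()          // field names assumed
        .weights(sd.var("w", DataType.FLOAT, inSize, 3 * inSize))
        .bias(sd.var("b", DataType.FLOAT, 2 * inSize))
        .build();
SRUCellOutputs out = sd.rnn().sruCell("sru", x, cLast, weights);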

/**
* See {@link #sru(String, SDVariable, SDVariable, SDVariable, SRUWeights)}
*/
public SRULayerOutputs sru(@NonNull SDVariable x, @NonNull SDVariable initialC, @NonNull SRUWeights weights) {
return sru(x, initialC, null, weights);
}

/**
* See {@link #sru(String, SDVariable, SDVariable, SDVariable, SRUWeights)}
*/
public SRULayerOutputs sru(String baseName, @NonNull SDVariable x, @NonNull SDVariable initialC, @NonNull SRUWeights weights) {
return sru(baseName, x, initialC, null, weights);
}

/**
* See {@link #sru(String, SDVariable, SDVariable, SDVariable, SRUWeights)}
*/
public SDVariable sruCell(SRUCellConfiguration configuration) {
return new SRUCell(sd, configuration).outputVariables()[0];
public SRULayerOutputs sru(@NonNull SDVariable x, @NonNull SDVariable initialC, SDVariable mask, @NonNull SRUWeights weights) {
return new SRULayerOutputs(new SRU(sd, x, initialC, mask, weights).outputVariables());
}

/**
* An sru cell
* The SRU layer. Does multiple time steps.
*
* @param baseName the base name to use for the output variables
* @param configuration the configuration for the sru cell
* @return
* @param baseName The base name for the sru layer
* @param x Input, with shape [batchSize, inSize, timeSeriesLength]
* @param initialC Initial cell state, with shape [batchSize, inSize]
* @param mask An optional dropout mask, with shape [batchSize, inSize]
* @param weights The layer's weights.
* @return The layer's outputs.
*/
public SDVariable sruCell(String baseName, SRUCellConfiguration configuration) {
return new SRUCell(sd, configuration).outputVariables(baseName)[0];
public SRULayerOutputs sru(String baseName, @NonNull SDVariable x, @NonNull SDVariable initialC, SDVariable mask, @NonNull SRUWeights weights) {
return new SRULayerOutputs(new SRU(sd, x, initialC, mask, weights).outputVariables(baseName));
}
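
Finally, a layer-level SRU sketch over a sequence; note the time dimension is last for this op, and the mask-free overload above simply passes null through (SRUWeights builder field names assumed, as before):

int batchSize = 32, inSize = 10, timeSteps = 20;
SameDiff sd = SameDiff.create();
SDVariable x = sd.placeHolder("x", DataType.FLOAT, batchSize, inSize, timeSteps);
SDVariable c0 = sd.placeHolder("c0", DataType.FLOAT, batchSize, inSize);
SRUWeights weights = SRUWeights.builder()          // field names assumed
        .weights(sd.var("w", DataType.FLOAT, inSize, 3 * inSize))
        .bias(sd.var("b", DataType.FLOAT, 2 * inSize))
        .build();
SRULayerOutputs out = sd.rnn().sru("sru", x, c0, weights);   // no dropout mask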

}
@@ -16,13 +16,15 @@

package org.nd4j.linalg.api.ops.impl.layers.recurrent;

import lombok.Getter;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.GRUCellConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
@@ -39,14 +41,15 @@
*/
public class GRUCell extends DynamicCustomOp {

private GRUCellConfiguration configuration;
@Getter
private GRUWeights weights;

public GRUCell() {
}

public GRUCell(SameDiff sameDiff, GRUCellConfiguration configuration) {
super(null, sameDiff, configuration.args());
this.configuration = configuration;
public GRUCell(SameDiff sameDiff, SDVariable x, SDVariable hLast, GRUWeights weights) {
super(null, sameDiff, weights.argsWithInputs(x, hLast));
this.weights = weights;
}

@Override
@@ -16,12 +16,15 @@

package org.nd4j.linalg.api.ops.impl.layers.recurrent;

import lombok.Getter;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMBlockCellConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights;
import org.nd4j.linalg.primitives.Pair;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
@@ -49,10 +52,12 @@
* 6: weights - cell peephole (t) connections to output gate, [numUnits]<br>
* 7: biases, shape [4*numUnits]<br>
* <br>
* Input integer arguments: set via {@link LSTMBlockCellConfiguration}<br>
* Weights are set via {@link LSTMWeights}.<br>
* <br>
* Input integer arguments: set via {@link LSTMConfiguration}<br>
* 0: if not zero, provide peephole connections<br>
* <br>
* Input float arguments: set via {@link LSTMBlockCellConfiguration}<br>
* Input float arguments: set via {@link LSTMConfiguration}<br>
* 0: the bias added to forget gates in order to reduce the scale of forgetting in the beginning of the training<br>
* 1: clipping value for cell state, if it is not equal to zero, then cell state is clipped<br>
* <br>
@@ -69,15 +74,19 @@
*/
public class LSTMBlockCell extends DynamicCustomOp {

private LSTMBlockCellConfiguration configuration;
private LSTMConfiguration configuration;

@Getter
private LSTMWeights weights;

public LSTMBlockCell() {
}

public LSTMBlockCell(SameDiff sameDiff, LSTMBlockCellConfiguration configuration) {
super(null, sameDiff, configuration.args());
public LSTMBlockCell(SameDiff sameDiff, SDVariable x, SDVariable cLast, SDVariable yLast, LSTMWeights weights, LSTMConfiguration configuration) {
super(null, sameDiff, weights.argsWithInputs(x, cLast, yLast));
this.configuration = configuration;
addIArgument(configuration.iArgs());
this.weights = weights;
addIArgument(configuration.iArgs(false));
addTArgument(configuration.tArgs());
}

@@ -97,12 +106,12 @@ public List<SDVariable> doDiff(List<SDVariable> grads) {

@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
configuration = LSTMBlockCellConfiguration.builder()
configuration = LSTMConfiguration.builder()
.forgetBias(attributesForNode.get("forget_bias").getF())
.clippingCellValue(attributesForNode.get("cell_clip").getF())
.peepHole(attributesForNode.get("use_peephole").getB())
.build();
addIArgument(configuration.iArgs());
addIArgument(configuration.iArgs(false));
addTArgument(configuration.tArgs());
}

@@ -113,7 +122,7 @@ public String opName() {

@Override
public Map<String, Object> propertiesForFunction() {
return configuration.toProperties();
return configuration.toProperties(false);
}

@Override
