Commit
SameDiff: add activation gradient checking support for debugging (#19)
* SameDiff gradient checker: first pass on activation gradient checks
* Fixes + tests for activation gradient checking
* Javadoc
1 parent: f75ffa9
commit: c1db0e8
Showing 3 changed files with 342 additions and 0 deletions.
52 changes: 52 additions & 0 deletions
.../nd4j-api/src/main/java/org/nd4j/autodiff/validation/ActivationGradientCheckListener.java
@@ -0,0 +1,52 @@
package org.nd4j.autodiff.validation;

import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.util.List;

/**
 * A listener used for debugging and testing purposes - specifically for gradient checking activations internally in
 * {@link GradCheckUtil}. It probably isn't useful for anything outside of this.
 *
 * @author Alex Black
 */
@NoArgsConstructor
public class ActivationGradientCheckListener extends BaseListener {

    @Getter @Setter
    private String variableName;
    @Getter @Setter
    private long[] idx;
    @Getter @Setter
    private double eps;

    @Override
    public void opExecution(SameDiff sd, At at, boolean training, SameDiffOp op, INDArray[] outputs) {
        Preconditions.checkState(variableName != null, "No variable name has been set yet. Variable name must be set before using this listener");
        Preconditions.checkState(eps != 0.0, "Epsilon has not been set");

        List<String> outs = op.getOutputsOfOp();
        int i = 0;
        for (String s : outs) {
            if (variableName.equals(s)) {
                Preconditions.checkState(idx != null || outputs[i].isScalar(),
                        "No index to modify has been set yet. Index must be set before using this listener");

                double orig = outputs[i].getDouble(idx);
                outputs[i].putScalar(idx, orig + eps);
                return;
            }
            i++;
        }
    }
}
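The perturbation hook above is what enables numerical checking of activation gradients: shifting a single activation element by +eps and then -eps during two forward passes gives a central-difference estimate of dLoss/dActivation for that element. The following is a minimal sketch of driving the listener by hand; it is an illustration, not GradCheckUtil's actual internals. The class name ListenerUsageSketch, the loss variable name "loss", the index [0, 0], and the eps value are illustrative choices, and the sketch assumes SameDiff#setListeners and SameDiff#output(Map, String...) behave as in the API generation used by the tests below.

import java.util.Collections;
import java.util.Map;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.validation.ActivationGradientCheckListener;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class ListenerUsageSketch {
    public static void main(String[] args) {
        // Same small graph as testActivationGradientCheck1 below, with a named loss
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.var("x", Nd4j.rand(DataType.DOUBLE, 3, 4));
        SDVariable tanh = sd.math().tanh("tanh", in);
        SDVariable loss = tanh.std("loss", true);

        // Perturb element [0, 0] of the "tanh" activation during each forward pass
        ActivationGradientCheckListener listener = new ActivationGradientCheckListener();
        listener.setVariableName("tanh");
        listener.setIdx(new long[]{0, 0});
        sd.setListeners(listener);

        double eps = 1e-5;
        Map<String, INDArray> noPlaceholders = Collections.emptyMap();

        // Loss with the activation element shifted by +eps
        listener.setEps(eps);
        double scorePlus = sd.output(noPlaceholders, "loss").get("loss").getDouble(0);

        // Loss with the activation element shifted by -eps
        listener.setEps(-eps);
        double scoreMinus = sd.output(noPlaceholders, "loss").get("loss").getDouble(0);

        // Central-difference estimate of dLoss/dActivation at [0, 0]; a gradient
        // check compares this against the backprop value for the same element
        double numericalGrad = (scorePlus - scoreMinus) / (2 * eps);
        System.out.println("Numerical dL/da[0,0] = " + numericalGrad);
    }
}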
68 changes: 68 additions & 0 deletions
...ackends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ActivationGradChecks.java
@@ -0,0 +1,68 @@
package org.nd4j.autodiff.opvalidation;

import org.junit.Test;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.validation.GradCheckUtil;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import static org.junit.Assert.assertTrue;

public class ActivationGradChecks extends BaseOpValidation {

    public ActivationGradChecks(Nd4jBackend backend) {
        super(backend);
    }

    @Test
    public void testActivationGradientCheck1() {
        Nd4j.getRandom().setSeed(12345);
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.var("x", Nd4j.rand(DataType.DOUBLE, 3, 4));
        SDVariable tanh = sd.math().tanh("tanh", in);
        SDVariable loss = tanh.std(true);

        GradCheckUtil.ActGradConfig c = GradCheckUtil.ActGradConfig.builder()
                .sd(sd)
                .activationGradsToCheck(Collections.singletonList("tanh"))
                .build();

        boolean ok = GradCheckUtil.checkActivationGradients(c);

        assertTrue(ok);
    }

    @Test
    public void testActivationGradientCheck2() {
        Nd4j.getRandom().setSeed(12345);
        SameDiff sd = SameDiff.create();
        SDVariable x = sd.placeHolder("x", DataType.DOUBLE, 3, 4);
        SDVariable y = sd.var("y", Nd4j.rand(DataType.DOUBLE, 4, 5));
        SDVariable mmul = x.mmul("mmul", y);
        SDVariable sigmoid = sd.math().tanh("sigmoid", mmul);
        SDVariable loss = sigmoid.std(true);

        Map<String, INDArray> m = new HashMap<>();
        m.put("x", Nd4j.rand(DataType.DOUBLE, 3, 4));

        GradCheckUtil.ActGradConfig c = GradCheckUtil.ActGradConfig.builder()
                .sd(sd)
                .placeholderValues(m)
                .activationGradsToCheck(Arrays.asList("sigmoid", "mmul"))
                .subset(GradCheckUtil.Subset.RANDOM)
                .maxPerParam(10)
                .build();

        boolean ok = GradCheckUtil.checkActivationGradients(c);

        assertTrue(ok);
    }
}
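For reference, the quantity these tests validate is the standard central-difference estimate, which is what the +/-eps perturbation in ActivationGradientCheckListener suggests: for each checked activation element a, the numerical gradient (L(a + eps) - L(a - eps)) / (2 * eps) is compared against the activation gradient computed by backpropagation. In the second test, Subset.RANDOM together with maxPerParam(10) presumably restricts the check to at most 10 randomly selected elements per activation rather than every element, which keeps checking large activations tractable; the exact semantics live in GradCheckUtil.ActGradConfig, which is not part of this diff.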