SameDiff: add activation gradient checking support for debugging (#19)
* SameDiff gradient checker: first pass on activation gradient checks

* Fixes + tests for activation gradient checking

* Javadoc
AlexDBlack committed Jun 20, 2019
1 parent f75ffa9 commit c1db0e8
Showing 3 changed files with 342 additions and 0 deletions.
org/nd4j/autodiff/validation/ActivationGradientCheckListener.java (new file)
@@ -0,0 +1,52 @@
package org.nd4j.autodiff.validation;

import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.util.List;

/**
* A listener used for debugging and testing purposes - specifically for gradient checking activations internally in
* {@link GradCheckUtil}. It is unlikely to be useful outside of that context.
*
* @author Alex Black
*/
@NoArgsConstructor
public class ActivationGradientCheckListener extends BaseListener {

@Getter @Setter
private String variableName;
@Getter @Setter
private long[] idx;
@Getter @Setter
private double eps;

@Override
public void opExecution(SameDiff sd, At at, boolean training, SameDiffOp op, INDArray[] outputs) {
Preconditions.checkState(variableName != null, "No variable name has been set yet. Variable name must be set before using this listener");
Preconditions.checkState(eps != 0.0, "Epsilon has not been set");


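// Locate the target activation among this op's outputs and perturb it in place by eps
// (GradCheckUtil alternates +eps and -eps to form a central-difference estimate)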
List<String> outs = op.getOutputsOfOp();
int i = 0;
for(String s : outs){
if(variableName.equals(s)){
Preconditions.checkState(idx != null || outputs[i].isScalar(),
"No index to modify has been set yet. Index must be set before using this listener");

double orig = outputs[i].getDouble(idx);
outputs[i].putScalar(idx, orig + eps);
return;
}
i++;
}
}

}
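For orientation, a minimal sketch of how checkActivationGradients (in the next file) drives this listener; sd, placeholderValues, lossVariables, the variable name "tanh" and the index are illustrative stand-ins, not part of the commit:

ActivationGradientCheckListener listener = new ActivationGradientCheckListener();
sd.setListeners(listener);                 // the listener nudges the chosen activation on every forward pass
listener.setVariableName("tanh");          // name of the ARRAY-type activation to perturb (illustrative)
listener.setIdx(new long[]{0, 1});         // position within that activation (illustrative)
double eps = 1e-5;

listener.setEps(eps);                      // forward pass with activation[idx] + eps
double scorePlus = 0.0;
for (INDArray arr : sd.exec(placeholderValues, lossVariables).values())
    scorePlus += arr.sumNumber().doubleValue();

listener.setEps(-eps);                     // forward pass with activation[idx] - eps
double scoreMinus = 0.0;
for (INDArray arr : sd.exec(placeholderValues, lossVariables).values())
    scoreMinus += arr.sumNumber().doubleValue();

// central-difference estimate, compared against the analytic gradient from execBackwards
double numericalGrad = (scorePlus - scoreMinus) / (2 * eps);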
org/nd4j/autodiff/validation/GradCheckUtil.java
@@ -16,6 +16,8 @@

package org.nd4j.autodiff.validation;

import lombok.Builder;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
@@ -24,6 +26,7 @@
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -329,6 +332,225 @@ public static boolean checkGradients(SameDiff sd, Map<String,INDArray> placehold
}


/**
* Gradient check the ACTIVATIONS (i.e., ARRAY type SDVariables) as opposed to the parameters of a network (as
* are tested in {@link #checkGradients(SameDiff, Map, double, double, double, boolean, boolean, boolean, boolean, Set, Map, int, Subset)}).
*
* @param config Configuration for gradient check
* @return True if gradient checks pass
*/
public static boolean checkActivationGradients(ActGradConfig config){
SameDiff sd = config.getSd();
List<String> actGrads = config.getActivationGradsToCheck();
double maxRelError = config.getMaxRelError();
double minAbsError = config.getMinAbsError();

Preconditions.checkState(sd != null, "SameDiff instance was not set in configuration");
Preconditions.checkState(actGrads != null && !actGrads.isEmpty(), "No activation gradients were specified to gradient check");
Preconditions.checkState(config.getEps() > 0.0, "Epsilon has not been set");
Preconditions.checkState(maxRelError > 0.0, "Max relative error must be set (is 0.0)");

for(String s : actGrads){
SDVariable v = sd.getVariables().get(s).getVariable();
Preconditions.checkState(v != null, "No variable with name \"%s\" was found", s);
Preconditions.checkState(v.getVariableType() == VariableType.ARRAY, "Only variables with type ARRAY may be " +
"gradient checked using this method. Variable \"%s\" has type %s", s, v.getVariableType());
Preconditions.checkState(v.dataType().isFPType(), "Cannot gradient check activation variable \"%s\": must be floating point type. Is type: %s", s, v.dataType());
if(v.dataType() != DataType.DOUBLE){
log.warn("Floating point variable {} is not double precision - this may result in spurious failures due to limited precision. Variable is type: {}", s, v.dataType());
}
}

boolean debugBefore = sd.isDebugMode();
if(config.isDebugMode()){
sd.enableDebugMode();
}

//Validation sanity checks:
if(!config.isSkipValidation()){
validateInternalState(sd, true);
}

//Loss function variables
List<String> lossFnVariables = sd.getLossVariables();
Preconditions.checkState(lossFnVariables != null && !lossFnVariables.isEmpty(), "Expected 1 or more loss function variables for gradient check, got %s", lossFnVariables);

//TODO also check that all inputs are non-zero (otherwise: consider out = sum(x * y) with all x and y being 0
// in this case, gradients of x and y are all 0 too

//Collect names of variables to get gradients for - i.e., the names of the GRADIENT variables for the specified activations
sd.createGradFunction();
Set<String> gradVarNames = new HashSet<>();
for(String s : actGrads){
SDVariable grad = sd.getVariable(s).gradient();
Preconditions.checkState(grad != null, "Could not get gradient for activation \"%s\": gradient variable is null", s);
gradVarNames.add(grad.getVarName());
}

//Calculate analytical gradients
sd.execBackwards(config.getPlaceholderValues(), new ArrayList<>(gradVarNames));
Map<String,INDArray> gradientsForAct = new HashMap<>();
for(String s : actGrads){
INDArray arr = sd.getVariable(s).gradient().getArr();
Preconditions.checkState(arr != null, "No activation gradient array for variable \"%s\"", s);
gradientsForAct.put(s, arr.dup());
}


//Now, check gradients
int totalNFailures = 0;
int totalCount = 0;
double maxError = 0.0;
ActivationGradientCheckListener listener = new ActivationGradientCheckListener();
sd.setListeners(listener);
Random r = new Random(12345);
int maxPerParam = config.getMaxPerParam();
for(String s : actGrads){

long n = gradientsForAct.get(s).length();
if(config.isPrint()){
log.info("Starting test for variable \"{}\" with {} values", s, n);
}

Iterator<long[]> iter;
if(maxPerParam > 0 && config.getSubset() != null && maxPerParam < n){
//Subset case
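// Subset.RANDOM samples maxPerParam distinct flat indices to check; the other mode checks every (n / maxPerParam)-th position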
long[] shape = gradientsForAct.get(s).shape();
List<long[]> l = new ArrayList<>();
if(config.getSubset() == Subset.RANDOM){
Set<Integer> set = new HashSet<>();
while(set.size() < maxPerParam){
int next = r.nextInt((int)n);
set.add(next);
}
List<Integer> sorted = new ArrayList<>(set);
Collections.sort(sorted);

for(Integer i : sorted){
long[] pos = Shape.ind2subC(shape, i);
l.add(pos);
}
} else {
//Every N
long everyN = n / maxPerParam;
long curr = 0;
while(curr < n){
long[] pos = Shape.ind2subC(shape, curr);
l.add(pos);
curr += everyN;
}
}
iter = l.iterator();
} else {
//Standard case: do all parameters
iter = new NdIndexIterator('c',gradientsForAct.get(s).shape());
}

INDArray varMask = (config.getGradCheckMask() == null ? null : config.getGradCheckMask().get(s));

listener.setVariableName(s);

int i=0;
while(iter.hasNext()){
long[] idx = iter.next();

String strIdx = null;
if(config.isPrint()){
strIdx = Arrays.toString(idx).replaceAll(" ","");
}

boolean maskValue = (varMask == null || (varMask.getDouble(idx) != 0));
if(!maskValue){
//Skip this specific entry (masked out)
continue;
}

//Set listener to apply eps, then do forward pass:
listener.setIdx(idx);
listener.setEps(config.getEps());
double scorePlus = 0.0;
Map<String,INDArray> m = sd.exec(config.getPlaceholderValues(), lossFnVariables);
for(INDArray arr : m.values()){
scorePlus += arr.sumNumber().doubleValue();
}
listener.setEps(-config.getEps());
m = sd.exec(config.getPlaceholderValues(), lossFnVariables);
double scoreMinus = 0.0;
for(INDArray arr : m.values()){
scoreMinus += arr.sumNumber().doubleValue();
}

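// Central difference estimate of d(loss)/d(activation[idx]) from the two perturbed forward passes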
double numericalGrad = (scorePlus - scoreMinus) / (2 * config.getEps());
double analyticGrad = gradientsForAct.get(s).getDouble(idx);

if (Double.isInfinite(numericalGrad) || Double.isNaN(numericalGrad)) {
throw new IllegalStateException("Numerical gradient was " + numericalGrad + " for variable \"" + s
+ "\", parameter " + i + " of " + n + " (position: " + strIdx + ")");
}
if (Double.isInfinite(analyticGrad) || Double.isNaN(analyticGrad)) {
throw new IllegalStateException("Analytic (SameDiff) gradient was " + analyticGrad + " for variable \"" + s
+ "\", parameter " + i + " of " + n + " (position: " + strIdx + ")");
}

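// Symmetric relative error: |analytic - numerical| / (|analytic| + |numerical|); treated as 0 when both gradients are exactly 0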
double relError;
if(numericalGrad == 0.0 && analyticGrad == 0.0){
relError = 0.0;
} else {
relError = Math.abs(analyticGrad - numericalGrad) / (Math.abs(analyticGrad) + Math.abs(numericalGrad));
}

if (relError > maxError)
maxError = relError;

if (relError > maxRelError || Double.isNaN(relError)) {
double absError = Math.abs(analyticGrad - numericalGrad);
if (absError < minAbsError) {
if(config.isPrint()) {
log.info("Param " + i + " (" + s + strIdx + ") passed: grad= " + analyticGrad
+ ", numericalGrad= " + numericalGrad + ", relError= " + relError
+ "; absolute error = " + absError + " < minAbsoluteError = " + minAbsError);
}
} else {
if (config.isPrint())
log.info("Param " + i + " (" + s + strIdx + ") FAILED: grad= " + analyticGrad
+ ", numericalGrad= " + numericalGrad + ", relError= " + relError
+ ", absError=" + absError
+ ", scorePlus=" + scorePlus + ", scoreMinus= " + scoreMinus);
if (config.isExitOnFirstFailure())
return false;
totalNFailures++;
}
} else if (config.isPrint()) {
log.info("Param " + i + " (" + s + strIdx + ") passed: grad= " + analyticGrad + ", numericalGrad= "
+ numericalGrad + ", relError= " + relError);
}
i++;

}
}

return totalNFailures == 0;
}

@Builder
@Data
public static class ActGradConfig {
private SameDiff sd;
private Map<String,INDArray> placeholderValues;
private List<String> activationGradsToCheck;
@Builder.Default private double eps = DEFAULT_EPS;
@Builder.Default private double maxRelError = DEFAULT_MAX_REL_ERROR;
@Builder.Default private double minAbsError = DEFAULT_MIN_ABS_ERROR;
@Builder.Default private boolean print = DEFAULT_PRINT;
@Builder.Default boolean exitOnFirstFailure = DEFAULT_EXIT_FIRST_FAILURE;
@Builder.Default private boolean skipValidation = false;
@Builder.Default private boolean debugMode = DEFAULT_DEBUG_MODE;
private Set<String> skipVariables;
private Map<String,INDArray> gradCheckMask;
int maxPerParam;
private Subset subset;
}


public static void validateInternalState(SameDiff sd, boolean generateAndCheckGradFn){

/*
org/nd4j/autodiff/opvalidation/ActivationGradChecks.java (new file)
@@ -0,0 +1,68 @@
package org.nd4j.autodiff.opvalidation;

import org.junit.Test;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.validation.GradCheckUtil;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import static org.junit.Assert.assertTrue;

public class ActivationGradChecks extends BaseOpValidation {

public ActivationGradChecks(Nd4jBackend backend) {
super(backend);
}

@Test
public void testActivationGradientCheck1(){
Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create();
SDVariable in = sd.var("x", Nd4j.rand(DataType.DOUBLE, 3, 4));
SDVariable tanh = sd.math().tanh("tanh", in);
SDVariable loss = tanh.std(true);

GradCheckUtil.ActGradConfig c = GradCheckUtil.ActGradConfig.builder()
.sd(sd)
.activationGradsToCheck(Collections.singletonList("tanh"))
.build();

boolean ok = GradCheckUtil.checkActivationGradients(c);

assertTrue(ok);
}

@Test
public void testActivationGradientCheck2(){
Nd4j.getRandom().setSeed(12345);
SameDiff sd = SameDiff.create();
SDVariable x = sd.placeHolder("x", DataType.DOUBLE, 3, 4);
SDVariable y = sd.var("y", Nd4j.rand(DataType.DOUBLE, 4, 5));
SDVariable mmul = x.mmul("mmul", y);
SDVariable sigmoid = sd.math().tanh("sigmoid", mmul);
SDVariable loss = sigmoid.std(true);

Map<String, INDArray> m = new HashMap<>();
m.put("x", Nd4j.rand(DataType.DOUBLE, 3, 4));

GradCheckUtil.ActGradConfig c = GradCheckUtil.ActGradConfig.builder()
.sd(sd)
.placeholderValues(m)
.activationGradsToCheck(Arrays.asList("sigmoid", "mmul"))
.subset(GradCheckUtil.Subset.RANDOM)
.maxPerParam(10)
.build();

boolean ok = GradCheckUtil.checkActivationGradients(c);

assertTrue(ok);
}
}
