Commit 348d9c5

Various fixes (#141)
* #8121 CnnSentenceDataSetIterator fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8120 CnnSentenceDataSetIterator.loadSingleSentence no words UX/exception improvement

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8122 AggregatingSentenceIterator builder - addSentencePreProcessor -> sentencePreProcessor

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8082 Arbiter - fix GridSearchCandidateGenerator search size issue

Signed-off-by: AlexDBlack <blacka101@gmail.com>
AlexDBlack committed Aug 21, 2019
1 parent 0adce9a commit 348d9c5
Showing 11 changed files with 195 additions and 18 deletions.

@@ -19,6 +19,7 @@
 import lombok.AllArgsConstructor;
 import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
 
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 
@@ -37,6 +38,8 @@ public abstract class ParameterSpaceAdapter<F, T> implements ParameterSpace<T> {
 
     protected abstract ParameterSpace<F> underlying();
 
+    protected abstract String underlyingName();
+
 
     @Override
     public T getValue(double[] parameterValues) {
Expand All @@ -50,17 +53,21 @@ public int numParameters() {

@Override
public List<ParameterSpace> collectLeaves() {
ParameterSpace p = underlying();
if(p.isLeaf()){
return Collections.singletonList(p);
}
return underlying().collectLeaves();
}

@Override
public Map<String, ParameterSpace> getNestedSpaces() {
return underlying().getNestedSpaces();
return Collections.singletonMap(underlyingName(), (ParameterSpace)underlying());
}

@Override
public boolean isLeaf() {
return underlying().isLeaf();
return false; //Underlying may be a leaf, however
}

@Override
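
A minimal sketch of the new adapter semantics (illustrative only; the direct adapter constructor below is assumed, not part of this commit). A wrapper such as ActivationParameterSpaceAdapter now reports isLeaf() = false, exposes its wrapped space under a stable name, and returns the wrapped leaf directly from collectLeaves():

    ParameterSpace<Activation> act = new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH);
    ParameterSpace<IActivation> adapted = new ActivationParameterSpaceAdapter(act);

    adapted.isLeaf();          // false: the adapter is a wrapper, even over a leaf
    adapted.getNestedSpaces(); // {"activation" -> act}, keyed by underlyingName()
    adapted.collectLeaves();   // [act]: the underlying leaf itself, exactly once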

@@ -17,10 +17,12 @@
 package org.deeplearning4j.arbiter.optimize.generator;
 
 import lombok.EqualsAndHashCode;
+import lombok.Getter;
 import lombok.extern.slf4j.Slf4j;
 import org.apache.commons.math3.random.RandomAdaptor;
 import org.deeplearning4j.arbiter.optimize.api.Candidate;
 import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
+import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
 import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace;
 import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace;
 import org.deeplearning4j.arbiter.util.LeafUtils;
@@ -65,6 +67,7 @@ public enum Mode {
     private final Mode mode;
 
     private int[] numValuesPerParam;
+    @Getter
     private int totalNumCandidates;
     private Queue<Integer> order;
 
@@ -123,6 +126,8 @@ protected void initialize() {
                 int max = ips.getMax();
                 //Discretize, as some integer ranges are much too large to search (i.e., num. neural network units, between 100 and 1000)
                 numValuesPerParam[i] = Math.min(max - min + 1, discretizationCount);
+            } else if (ps instanceof FixedValue){
+                numValuesPerParam[i] = 1;
             } else {
                 numValuesPerParam[i] = discretizationCount;
             }
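
The practical effect on search size: the candidate count is the product of numValuesPerParam over all leaf spaces, so a FixedValue leaf now contributes a factor of 1 instead of discretizationCount. A sketch of the arithmetic (the exact leaf layout is assumed for illustration, mirroring testIssue8082 further down):

    // Four discrete hyperparameters (4 * 4 * 4 * 2 values) plus fixed settings
    // such as weightInit and l2, which now count as a single value each:
    int[] numValuesPerParam = {4, 4, 4, 2, 1, 1};
    int total = 1;
    for (int n : numValuesPerParam) {
        total *= n; // 128, independent of discretizationCount
    }
    // Before the fix, each FixedValue leaf fell through to the else branch and
    // multiplied the total by discretizationCount, inflating the grid size.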

@@ -44,10 +44,7 @@
 import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
 import org.nd4j.shade.jackson.core.JsonProcessingException;
 
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-import java.util.Map;
+import java.util.*;
 
 /**
  * This is an abstract ParameterSpace for both MultiLayerNetworks (MultiLayerSpace) and ComputationGraph (ComputationGraphSpace)
@@ -212,18 +209,30 @@ protected NeuralNetConfiguration.Builder randomGlobalConf(double[] values) {
     @Override
     public List<ParameterSpace> collectLeaves() {
         Map<String, ParameterSpace> global = getNestedSpaces();
-        List<ParameterSpace> list = new ArrayList<>();
-        list.addAll(global.values());
 
-        //Note: Results on previous line does NOT include the LayerSpaces, therefore we need to add these manually...
-        //This is because the type is a list, not a ParameterSpace
+        LinkedList<ParameterSpace> stack = new LinkedList<>();
+        stack.add(this);
 
         for (LayerConf layerConf : layerSpaces) {
             LayerSpace ls = layerConf.getLayerSpace();
-            list.addAll(ls.collectLeaves());
+            stack.addAll(ls.collectLeaves());
         }
+
+        List<ParameterSpace> out = new ArrayList<>();
+        while (!stack.isEmpty()) {
+            ParameterSpace next = stack.removeLast();
+            if (next.isLeaf()) {
+                out.add(next);
+            } else {
+                Map<String, ParameterSpace> m = next.getNestedSpaces();
+                ParameterSpace[] arr = m.values().toArray(new ParameterSpace[m.size()]);
+                for (int i = arr.length - 1; i >= 0; i--) {
+                    stack.add(arr[i]);
+                }
+            }
+        }
 
-        return list;
+        return out;
     }
 
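
A side note on the traversal above: the LinkedList is used as a LIFO via removeLast(), so nested spaces are pushed in reverse order to come off the stack in their original order. A standalone sketch of that pattern (illustrative only, not code from the commit):

    LinkedList<String> stack = new LinkedList<>();
    String[] children = {"a", "b", "c"};
    for (int i = children.length - 1; i >= 0; i--) {
        stack.add(children[i]); // stack is now [c, b, a]
    }
    while (!stack.isEmpty()) {
        System.out.println(stack.removeLast()); // prints a, b, c
    }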

@@ -84,8 +84,10 @@ protected MultiLayerSpace(Builder builder) {
         List<ParameterSpace> allLeaves = collectLeaves();
         List<ParameterSpace> list = LeafUtils.getUniqueObjects(allLeaves);
 
-        for (ParameterSpace ps : list)
+        for (ParameterSpace ps : list) {
+            int n = ps.numParameters();
             numParameters += ps.numParameters();
+        }
 
         this.trainingWorkspaceMode = builder.trainingWorkspaceMode;
         this.inferenceWorkspaceMode = builder.inferenceWorkspaceMode;

@@ -50,4 +50,9 @@ public IActivation convertValue(Activation from) {
     protected ParameterSpace<Activation> underlying() {
         return activation;
     }
+
+    @Override
+    protected String underlyingName() {
+        return "activation";
+    }
 }

@@ -52,4 +52,9 @@ protected ILossFunction convertValue(LossFunctions.LossFunction from) {
     protected ParameterSpace<LossFunctions.LossFunction> underlying() {
         return lossFunction;
     }
+
+    @Override
+    protected String underlyingName() {
+        return "lossFunction";
+    }
 }

@@ -52,9 +52,11 @@ public int numParameters() {
 
     @Override
     public List<ParameterSpace> collectLeaves() {
-        return Collections.<ParameterSpace>singletonList(dropOut);
+        return dropOut.collectLeaves();
     }
 
+
+
     @Override
     public boolean isLeaf() {
         return false;

@@ -31,6 +31,7 @@
 import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition;
 import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition;
 import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
+import org.deeplearning4j.arbiter.optimize.generator.GridSearchCandidateGenerator;
 import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator;
 import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
 import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace;
@@ -706,4 +707,63 @@ public void testDropout2(){
 
         MultiLayerConfiguration conf = mls.getValue(new double[nParams]).getMultiLayerConfiguration();
     }
+
+
+    @Test
+    public void testIssue8082(){
+        ParameterSpace<Double> learningRateHyperparam = new DiscreteParameterSpace<>(0.003, 0.005, 0.01, 0.05);
+        ParameterSpace<Integer> layerSizeHyperparam1 = new DiscreteParameterSpace<>(32, 64, 96, 128);
+        ParameterSpace<Integer> layerSizeHyperparam2 = new DiscreteParameterSpace<>(32, 64, 96, 128);
+        ParameterSpace<Double> dropoutHyperparam = new DiscreteParameterSpace<>(0.8, 0.9);
+
+        MultiLayerSpace mls = new MultiLayerSpace.Builder()
+                .updater(new AdamSpace(learningRateHyperparam))
+                .weightInit(WeightInit.XAVIER)
+                .l2(0.0001)
+                .addLayer(new DenseLayerSpace.Builder()
+                        .nIn(10)
+                        .nOut(layerSizeHyperparam1)
+                        .build())
+                .addLayer(new BatchNormalizationSpace.Builder()
+                        .nOut(layerSizeHyperparam1)
+                        .activation(Activation.RELU)
+                        .build())
+                .addLayer(new DropoutLayerSpace.Builder()
+                        .dropOut(dropoutHyperparam)
+                        .build())
+                .addLayer(new DenseLayerSpace.Builder()
+                        .nOut(layerSizeHyperparam2)
+                        .build())
+                .addLayer(new BatchNormalizationSpace.Builder()
+                        .nOut(layerSizeHyperparam2)
+                        .activation(Activation.RELU)
+                        .build())
+                .addLayer(new DropoutLayerSpace.Builder()
+                        .dropOut(dropoutHyperparam)
+                        .build())
+                .addLayer(new OutputLayerSpace.Builder()
+                        .nOut(10)
+                        .activation(Activation.SOFTMAX)
+                        .lossFunction(LossFunction.MCXENT)
+                        .build())
+                .build();
+
+        assertEquals(4, mls.getNumParameters());
+
+        for( int discreteCount : new int[]{1, 5}) {
+            GridSearchCandidateGenerator generator = new GridSearchCandidateGenerator(mls, discreteCount, GridSearchCandidateGenerator.Mode.Sequential, null);
+
+            int expCandidates = 4 * 4 * 4 * 2;
+            assertEquals(expCandidates, generator.getTotalNumCandidates());
+
+            int count = 0;
+            while (generator.hasMoreCandidates()) {
+                generator.getCandidate();
+                count++;
+            }
+
+
+            assertEquals(expCandidates, count);
+        }
+    }
 }

@@ -117,25 +117,32 @@ protected CnnSentenceDataSetIterator(Builder builder) {
             List<String> sortedLabels = new ArrayList<>(this.sentenceProvider.allLabels());
             Collections.sort(sortedLabels);
 
-            this.wordVectorSize = wordVectors.getWordVector(wordVectors.vocab().wordAtIndex(0)).length;
-
             for (String s : sortedLabels) {
                 this.labelClassMap.put(s, count++);
             }
             if (unknownWordHandling == UnknownWordHandling.UseUnknownVector) {
                 if (useNormalizedWordVectors) {
-                    wordVectors.getWordVectorMatrixNormalized(wordVectors.getUNK());
+                    unknown = wordVectors.getWordVectorMatrixNormalized(wordVectors.getUNK());
                 } else {
-                    wordVectors.getWordVectorMatrix(wordVectors.getUNK());
+                    unknown = wordVectors.getWordVectorMatrix(wordVectors.getUNK());
                 }
             }
+
+            this.wordVectorSize = wordVectors.getWordVector(wordVectors.vocab().wordAtIndex(0)).length;
+            if(unknown == null){
+                unknown = wordVectors.getWordVectorMatrix(wordVectors.vocab().wordAtIndex(0)).like();
+            }
         }
     }
+
     /**
      * Generally used post training time to load a single sentence for predictions
      */
     public INDArray loadSingleSentence(String sentence) {
         List<String> tokens = tokenizeSentence(sentence);
+        if(tokens.isEmpty())
+            throw new IllegalStateException("No tokens available for input sentence - empty string or no words in vocabulary with RemoveWord unknown handling? Sentence = \"" +
+                    sentence + "\"");
         if(format == Format.CNN1D || format == Format.RNN){
             int[] featuresShape = new int[] {1, wordVectorSize, Math.min(maxSentenceLength, tokens.size())};
             INDArray features = Nd4j.create(featuresShape, (format == Format.CNN1D ? 'c' : 'f'));
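
A hedged usage sketch of the new fail-fast behaviour in loadSingleSentence (it mirrors the test near the end of this commit; "iter" is assumed to be a CnnSentenceDataSetIterator built elsewhere with UnknownWordHandling.RemoveWord):

    INDArray ok = iter.loadSingleSentence("these balance"); // in-vocabulary words: fine
    try {
        iter.loadSingleSentence("NOVALID AlsoNotInVocab");  // every token removed
    } catch (IllegalStateException e) {
        // The message now quotes the offending sentence and points at the
        // RemoveWord handling, instead of failing deeper in feature construction.
    }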

@@ -100,7 +100,15 @@ public Builder addSentenceIterators(@NonNull Collection<SentenceIterator> iterat
             return this;
         }
 
+        /**
+         * @deprecated Use {@link #sentencePreProcessor(SentencePreProcessor)}
+         */
+        @Deprecated
         public Builder addSentencePreProcessor(@NonNull SentencePreProcessor preProcessor) {
+            return sentencePreProcessor(preProcessor);
+        }
+
+        public Builder sentencePreProcessor(@NonNull SentencePreProcessor preProcessor) {
             this.preProcessor = preProcessor;
             return this;
         }
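
Usage sketch for the renamed builder method (the single-iterator addSentenceIterator method and the lambda-compatibility of SentencePreProcessor are assumed here, not shown in this diff):

    AggregatingSentenceIterator iter = new AggregatingSentenceIterator.Builder()
            .addSentenceIterator(underlying)           // some existing SentenceIterator
            .sentencePreProcessor(String::toLowerCase) // was: addSentencePreProcessor(...)
            .build();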

@@ -264,6 +264,21 @@ public void testCnnSentenceDataSetIteratorNoTokensEdgeCase() throws Exception {
         assertEquals(expLabels, ds.getLabels());
         assertEquals(expectedFeatureMask, ds.getFeaturesMaskArray());
         assertNull(ds.getLabelsMaskArray());
+
+
+        //Sanity check on single sentence loading:
+        INDArray allKnownWords = dsi.loadSingleSentence("these balance");
+        INDArray withUnknown = dsi.loadSingleSentence("these NOVALID");
+        assertNotNull(allKnownWords);
+        assertNotNull(withUnknown);
+
+        try {
+            dsi.loadSingleSentence("NOVALID AlsoNotInVocab");
+            fail("Expected exception");
+        } catch (Throwable t){
+            String m = t.getMessage();
+            assertTrue(m, m.contains("RemoveWord") && m.contains("vocabulary"));
+        }
     }
 
     @Test
@@ -324,4 +339,56 @@ public void testCnnSentenceDataSetIteratorNoValidTokensNextEdgeCase() throws Exc
         assertEquals(expectedFeatureMask, ds.getFeaturesMaskArray());
         assertNull(ds.getLabelsMaskArray());
     }
+
+
+    @Test
+    public void testCnnSentenceDataSetIteratorUseUnknownVector() throws Exception {
+
+        WordVectors w2v = WordVectorSerializer
+                .readWord2VecModel(new ClassPathResource("word2vec/googleload/sample_vec.bin").getFile());
+
+        List<String> sentences = new ArrayList<>();
+        sentences.add("these balance Database model");
+        sentences.add("into same THISWORDDOESNTEXIST are");
+        //Last 2 sentences - no valid words
+        sentences.add("NOVALID WORDSHERE");
+        sentences.add("!!!");
+
+        List<String> labelsForSentences = Arrays.asList("Positive", "Negative", "Positive", "Negative");
+
+
+        LabeledSentenceProvider p = new CollectionLabeledSentenceProvider(sentences, labelsForSentences, null);
+        CnnSentenceDataSetIterator dsi = new CnnSentenceDataSetIterator.Builder(CnnSentenceDataSetIterator.Format.CNN1D)
+                .unknownWordHandling(CnnSentenceDataSetIterator.UnknownWordHandling.UseUnknownVector)
+                .sentenceProvider(p).wordVectors(w2v)
+                .useNormalizedWordVectors(true)
+                .maxSentenceLength(256).minibatchSize(4).sentencesAlongHeight(false).build();
+
+        assertTrue(dsi.hasNext());
+        DataSet ds = dsi.next();
+
+        assertFalse(dsi.hasNext());
+
+        INDArray f = ds.getFeatures();
+        assertEquals(4, f.size(0));
+
+        INDArray unknown = w2v.getWordVectorMatrix(w2v.getUNK());
+        if(unknown == null)
+            unknown = Nd4j.create(DataType.FLOAT, f.size(1));
+
+        assertEquals(unknown, f.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.point(0)));
+        assertEquals(unknown, f.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.point(1)));
+        assertEquals(unknown.like(), f.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.point(3)));
+
+        assertEquals(unknown, f.get(NDArrayIndex.point(3), NDArrayIndex.all(), NDArrayIndex.point(0)));
+        assertEquals(unknown.like(), f.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.point(1)));
+
+        //Sanity check on single sentence loading:
+        INDArray allKnownWords = dsi.loadSingleSentence("these balance");
+        INDArray withUnknown = dsi.loadSingleSentence("these NOVALID");
+        INDArray allUnknown = dsi.loadSingleSentence("NOVALID AlsoNotInVocab");
+        assertNotNull(allKnownWords);
+        assertNotNull(withUnknown);
+        assertNotNull(allUnknown);
+    }
 }
