From f66fb13c24dd128d3994387e816a560910c5d098 Mon Sep 17 00:00:00 2001 From: koji Date: Tue, 2 May 2017 17:11:34 +0900 Subject: [PATCH] OPENNLP-1044: Add validate() which checks validity of parameters in the process of the framework --- .../ml/AbstractEventModelSequenceTrainer.java | 5 +-- .../tools/ml/AbstractEventTrainer.java | 17 +++++----- .../tools/ml/AbstractSequenceTrainer.java | 5 +-- .../opennlp/tools/ml/AbstractTrainer.java | 26 +++++++++++--- .../ml/maxent/quasinewton/QNTrainer.java | 34 ++++++++++++------- .../ml/naivebayes/NaiveBayesTrainer.java | 4 --- .../ml/perceptron/PerceptronTrainer.java | 20 ++++++++--- .../SimplePerceptronSequenceTrainer.java | 26 ++++++++++---- 8 files changed, 88 insertions(+), 49 deletions(-) diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventModelSequenceTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventModelSequenceTrainer.java index fdcb4b65e..362a0d699 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventModelSequenceTrainer.java +++ b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventModelSequenceTrainer.java @@ -32,10 +32,7 @@ public abstract MaxentModel doTrain(SequenceStream events) throws IOException; public final MaxentModel train(SequenceStream events) throws IOException { - - if (!isValid()) { - throw new IllegalArgumentException("trainParams are not valid!"); - } + validate(); MaxentModel model = doTrain(events); addToReport(AbstractTrainer.TRAINER_TYPE_PARAM, diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventTrainer.java index 330307a73..dc75ffe27 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventTrainer.java +++ b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractEventTrainer.java @@ -42,7 +42,13 @@ public AbstractEventTrainer() { public AbstractEventTrainer(TrainingParameters parameters) { super(parameters); } - + + @Override + public void validate() { + super.validate(); + } + + @Deprecated @Override public boolean isValid() { return super.isValid(); @@ -66,9 +72,7 @@ public DataIndexer getDataIndexer(ObjectStream events) throws IOException public abstract MaxentModel doTrain(DataIndexer indexer) throws IOException; public final MaxentModel train(DataIndexer indexer) throws IOException { - if (!isValid()) { - throw new IllegalArgumentException("trainParams are not valid!"); - } + validate(); if (indexer.getOutcomeLabels().length <= 1) { throw new InsufficientTrainingDataException("Training data must contain more than one outcome"); @@ -80,10 +84,7 @@ public final MaxentModel train(DataIndexer indexer) throws IOException { } public final MaxentModel train(ObjectStream events) throws IOException { - - if (!isValid()) { - throw new IllegalArgumentException("trainParams are not valid!"); - } + validate(); HashSumEventStream hses = new HashSumEventStream(events); DataIndexer indexer = getDataIndexer(hses); diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractSequenceTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractSequenceTrainer.java index 2d4862415..19ecc4b79 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractSequenceTrainer.java +++ b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractSequenceTrainer.java @@ -32,10 +32,7 @@ public abstract SequenceClassificationModel doTrain(SequenceStream event throws IOException; public final SequenceClassificationModel train(SequenceStream events) throws IOException { - - if (!isValid()) { - throw new IllegalArgumentException("trainParams are not valid!"); - } + validate(); SequenceClassificationModel model = doTrain(events); addToReport(AbstractTrainer.TRAINER_TYPE_PARAM, SequenceTrainer.SEQUENCE_VALUE); diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractTrainer.java index 070b96c0a..32c5df68a 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractTrainer.java +++ b/opennlp-tools/src/main/java/opennlp/tools/ml/AbstractTrainer.java @@ -74,20 +74,36 @@ public int getIterations() { return trainingParameters.getIntParameter(ITERATIONS_PARAM, ITERATIONS_DEFAULT); } - public boolean isValid() { - + /** + * Check parameters. If subclass overrides this, it should call super.validate(); + * + * @throws java.lang.IllegalArgumentException + */ + public void validate() { // TODO: Need to validate all parameters correctly ... error prone?! - // should validate if algorithm is set? What about the Parser? try { trainingParameters.getIntParameter(CUTOFF_PARAM, CUTOFF_DEFAULT); trainingParameters.getIntParameter(ITERATIONS_PARAM, ITERATIONS_DEFAULT); } catch (NumberFormatException e) { + throw new IllegalArgumentException(e); + } + } + + /** + * @deprecated Use {@link #validate()} instead. + * @return + */ + @Deprecated + public boolean isValid() { + try { + validate(); + return true; + } + catch (IllegalArgumentException e) { return false; } - - return true; } /** diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNTrainer.java index 7a1a74f7b..daa90a420 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNTrainer.java +++ b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNTrainer.java @@ -115,42 +115,52 @@ public void init(Map trainParams, Map reportMap) init(new TrainingParameters(trainParams),reportMap); } - public boolean isValid() { - - if (!super.isValid()) { - return false; - } + @Override + public void validate() { + super.validate(); String algorithmName = getAlgorithm(); if (algorithmName != null && !(MAXENT_QN_VALUE.equals(algorithmName))) { - return false; + throw new IllegalArgumentException("algorithmName must be MAXENT_QN"); } // Number of Hessian updates to remember if (m < 0) { - return false; + throw new IllegalArgumentException( + "Number of Hessian updates to remember must be >= 0"); } // Maximum number of function evaluations if (maxFctEval < 0) { - return false; + throw new IllegalArgumentException( + "Maximum number of function evaluations must be >= 0"); } // Number of threads must be >= 1 if (threads < 1) { - return false; + throw new IllegalArgumentException("Number of threads must be >= 1"); } // Regularization costs must be >= 0 if (l1Cost < 0) { - return false; + throw new IllegalArgumentException("Regularization costs must be >= 0"); } if (l2Cost < 0) { - return false; + throw new IllegalArgumentException("Regularization costs must be >= 0"); } + } - return true; + @Deprecated + @Override + public boolean isValid() { + try { + validate(); + return true; + } + catch (IllegalArgumentException e) { + return false; + } } public boolean isSortAndMerge() { diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesTrainer.java index 629c2225f..69ef44e13 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesTrainer.java +++ b/opennlp-tools/src/main/java/opennlp/tools/ml/naivebayes/NaiveBayesTrainer.java @@ -102,10 +102,6 @@ public boolean isSortAndMerge() { } public AbstractModel doTrain(DataIndexer indexer) throws IOException { - if (!isValid()) { - throw new IllegalArgumentException("trainParams are not valid!"); - } - return this.trainModel(indexer); } diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronTrainer.java index 129c57613..b73eacaf7 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronTrainer.java +++ b/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/PerceptronTrainer.java @@ -84,7 +84,21 @@ public PerceptronTrainer() { public PerceptronTrainer(TrainingParameters parameters) { super(parameters); } - + + @Override + public void validate() { + super.validate(); + + String algorithmName = getAlgorithm(); + if (algorithmName != null) { + if (!PERCEPTRON_VALUE.equals(algorithmName)) { + throw new IllegalArgumentException("algorithmName must be PERCEPTRON"); + } + } + } + + @Deprecated + @Override public boolean isValid() { if (!super.isValid()) { return false; @@ -104,10 +118,6 @@ public boolean isSortAndMerge() { } public AbstractModel doTrain(DataIndexer indexer) throws IOException { - if (!isValid()) { - throw new IllegalArgumentException("trainParams are not valid!"); - } - int iterations = getIterations(); int cutoff = getCutoff(); diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/SimplePerceptronSequenceTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/SimplePerceptronSequenceTrainer.java index 5fc4bbe13..a9ac51680 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/SimplePerceptronSequenceTrainer.java +++ b/opennlp-tools/src/main/java/opennlp/tools/ml/perceptron/SimplePerceptronSequenceTrainer.java @@ -83,16 +83,28 @@ public class SimplePerceptronSequenceTrainer extends AbstractEventModelSequenceT public SimplePerceptronSequenceTrainer() { } - public boolean isValid() { - - if (!super.isValid()) { - return false; - } + @Override + public void validate() { + super.validate(); String algorithmName = getAlgorithm(); + if (algorithmName != null) { + if (!PERCEPTRON_SEQUENCE_VALUE.equals(algorithmName)) { + throw new IllegalArgumentException("algorithmName must be PERCEPTRON_SEQUENCE"); + } + } + } - return !(algorithmName != null - && !(PERCEPTRON_SEQUENCE_VALUE.equals(algorithmName))); + @Deprecated + @Override + public boolean isValid() { + try { + validate(); + return true; + } + catch (IllegalArgumentException e) { + return false; + } } public AbstractModel doTrain(SequenceStream events) throws IOException {