From 3d7c483d380bf54b39312cb2f66892fad901ddff Mon Sep 17 00:00:00 2001 From: smarthi Date: Sun, 15 Jan 2017 16:52:03 -0500 Subject: [PATCH] OPENNLP-123: Feature cutoff should only be done by data indexers --- .../java/opennlp/tools/ml/maxent/GIS.java | 37 +- .../opennlp/tools/ml/maxent/GISTrainer.java | 328 ++++++++---------- .../ml/maxent/quasinewton/QNTrainer.java | 4 +- .../tools/ml/maxent/GISTestIndexing.java | 14 +- .../tools/ml/maxent/MaxentPrepAttachTest.java | 30 +- .../tools/ml/maxent/RealValueModelTest.java | 9 +- 6 files changed, 191 insertions(+), 231 deletions(-) diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GIS.java b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GIS.java index e1aa08cf7..997246515 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GIS.java +++ b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GIS.java @@ -50,10 +50,10 @@ public class GIS extends AbstractEventTrainer { * the trainer to imagine that it saw a feature that it actually didn't see. * Defaulted to 0.1. */ - public static double SMOOTHING_OBSERVATION = 0.1; + private static final double SMOOTHING_OBSERVATION = 0.1; - public static final String SMOOTHING_PARAM = "smoothing"; - public static final boolean SMOOTHING_DEFAULT = false; + private static final String SMOOTHING_PARAM = "smoothing"; + private static final boolean SMOOTHING_DEFAULT = false; public GIS() { } @@ -80,10 +80,9 @@ public AbstractModel doTrain(DataIndexer indexer) throws IOException { boolean printMessages = parameters.getBooleanParam(VERBOSE_PARAM, VERBOSE_DEFAULT); boolean smoothing = parameters.getBooleanParam(SMOOTHING_PARAM, SMOOTHING_DEFAULT); - int cutoff = getCutoff(); int threads = parameters.getIntParam(TrainingParameters.THREADS_PARAM, 1); - model = trainModel(iterations, indexer, printMessages, smoothing, null, cutoff, threads); + model = trainModel(iterations, indexer, printMessages, smoothing, null, threads); return model; } @@ -188,8 +187,9 @@ public static GISModel trainModel(ObjectStream eventStream, int iteration public static GISModel trainModel(ObjectStream eventStream, int iterations, int cutoff, double sigma) throws IOException { GISTrainer trainer = new GISTrainer(PRINT_MESSAGES); - if (sigma > 0) + if (sigma > 0) { trainer.setGaussianSigma(sigma); + } return trainer.trainModel(eventStream, iterations, cutoff); } @@ -206,9 +206,8 @@ public static GISModel trainModel(ObjectStream eventStream, int iteration * @return The newly trained model, which can be used immediately or saved to * disk using an opennlp.tools.ml.maxent.io.GISModelWriter object. */ - public static GISModel trainModel(int iterations, DataIndexer indexer, - boolean smoothing) { - return trainModel(iterations, indexer, true, smoothing, null, 0); + public static GISModel trainModel(int iterations, DataIndexer indexer, boolean smoothing) { + return trainModel(iterations, indexer, true, smoothing, null, 1); } /** @@ -222,7 +221,7 @@ public static GISModel trainModel(int iterations, DataIndexer indexer, * disk using an opennlp.tools.ml.maxent.io.GISModelWriter object. */ public static GISModel trainModel(int iterations, DataIndexer indexer) { - return trainModel(iterations, indexer, true, false, null, 0); + return trainModel(iterations, indexer, true, false, null, 1); } /** @@ -257,16 +256,13 @@ public static GISModel trainModel(int iterations, DataIndexer indexer, * training the model. * @param modelPrior * The prior distribution for the model. - * @param cutoff - * The number of times a predicate must occur to be used in a model. * @return The newly trained model, which can be used immediately or saved to * disk using an opennlp.tools.ml.maxent.io.GISModelWriter object. */ public static GISModel trainModel(int iterations, DataIndexer indexer, - boolean printMessagesWhileTraining, boolean smoothing, Prior modelPrior, - int cutoff) { - return trainModel(iterations, indexer, printMessagesWhileTraining, - smoothing, modelPrior, cutoff, 1); + boolean printMessagesWhileTraining, boolean smoothing, + Prior modelPrior) { + return trainModel(iterations, indexer, printMessagesWhileTraining, smoothing, modelPrior, 1); } /** @@ -283,22 +279,19 @@ public static GISModel trainModel(int iterations, DataIndexer indexer, * training the model. * @param modelPrior * The prior distribution for the model. - * @param cutoff - * The number of times a predicate must occur to be used in a model. * @return The newly trained model, which can be used immediately or saved to * disk using an opennlp.tools.ml.maxent.io.GISModelWriter object. */ public static GISModel trainModel(int iterations, DataIndexer indexer, - boolean printMessagesWhileTraining, boolean smoothing, Prior modelPrior, - int cutoff, int threads) { + boolean printMessagesWhileTraining, boolean smoothing, + Prior modelPrior, int threads) { GISTrainer trainer = new GISTrainer(printMessagesWhileTraining); trainer.setSmoothing(smoothing); trainer.setSmoothingObservation(SMOOTHING_OBSERVATION); if (modelPrior == null) { modelPrior = new UniformPrior(); } - - return trainer.trainModel(iterations, indexer, modelPrior, cutoff, threads); + return trainer.trainModel(iterations, indexer, modelPrior, threads); } } diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java index b19870510..19ea58e8e 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java +++ b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/GISTrainer.java @@ -44,118 +44,92 @@ * for this implementation was Adwait Ratnaparkhi's tech report at the * University of Pennsylvania's Institute for Research in Cognitive Science, * and is available at ftp://ftp.cis.upenn.edu/pub/ircs/tr/97-08.ps.Z. - * + *

* The slack parameter used in the above implementation has been removed by default * from the computation and a method for updating with Gaussian smoothing has been * added per Investigating GIS and Smoothing for Maximum Entropy Taggers, Clark and Curran (2002). * http://acl.ldc.upenn.edu/E/E03/E03-1071.pdf * The slack parameter can be used by setting useSlackParameter to true. * Gaussian smoothing can be used by setting useGaussianSmoothing to true. - * + *

* A prior can be used to train models which converge to the distribution which minimizes the * relative entropy between the distribution specified by the empirical constraints of the training * data and the specified prior. By default, the uniform distribution is used as the prior. */ class GISTrainer { + private static final double LLThreshold = 0.0001; + private final boolean printMessages; /** * Specifies whether unseen context/outcome pairs should be estimated as occur very infrequently. */ private boolean useSimpleSmoothing = false; - /** * Specified whether parameter updates should prefer a distribution of parameters which * is gaussian. */ private boolean useGaussianSmoothing = false; - private double sigma = 2.0; - // If we are using smoothing, this is used as the "number" of // times we want the trainer to imagine that it saw a feature that it // actually didn't see. Defaulted to 0.1. private double _smoothingObservation = 0.1; - - private final boolean printMessages; - /** * Number of unique events which occured in the event set. */ private int numUniqueEvents; - /** * Number of predicates. */ private int numPreds; - /** * Number of outcomes. */ private int numOutcomes; - /** * Records the array of predicates seen in each event. */ private int[][] contexts; - /** * The value associated with each context. If null then context values are assumes to be 1. */ private float[][] values; - /** * List of outcomes for each event i, in context[i]. */ private int[] outcomeList; - /** * Records the num of times an event has been seen for each event i, in context[i]. */ private int[] numTimesEventsSeen; - - /** - * The number of times a predicate occured in the training data. - */ - private int[] predicateCounts; - - private int cutoff; - /** * Stores the String names of the outcomes. The GIS only tracks outcomes as * ints, and so this array is needed to save the model to disk and thereby * allow users to know what the outcome was in human understandable terms. */ private String[] outcomeLabels; - /** * Stores the String names of the predicates. The GIS only tracks predicates * as ints, and so this array is needed to save the model to disk and thereby * allow users to know what the outcome was in human understandable terms. */ private String[] predLabels; - /** * Stores the observed expected values of the features based on training data. */ private MutableContext[] observedExpects; - /** * Stores the estimated parameter value of each predicate during iteration */ private MutableContext[] params; - /** * Stores the expected values of the features based on the current models */ private MutableContext[][] modelExpects; - /** * This is the prior distribution that the model uses for training. */ private Prior prior; - - private static final double LLThreshold = 0.0001; - /** * Initial probability for all outcomes. */ @@ -164,7 +138,6 @@ class GISTrainer { /** * Creates a new GISTrainer instance which does not print * progress messages about training to STDOUT. - * */ GISTrainer() { printMessages = false; @@ -207,7 +180,6 @@ public void setSmoothingObservation(double timesSeen) { * Sets whether this trainer will use smoothing while training the model. * This can improve model accuracy, though training will potentially take * longer and use more memory. Model size will also be larger. - * */ public void setGaussianSigma(double sigmaValue) { useGaussianSmoothing = true; @@ -217,43 +189,45 @@ public void setGaussianSigma(double sigmaValue) { /** * Trains a GIS model on the event in the specified event stream, using the specified number * of iterations and the specified count cutoff. + * * @param eventStream A stream of all events. - * @param iterations The number of iterations to use for GIS. - * @param cutoff The number of times a feature must occur to be included. + * @param iterations The number of iterations to use for GIS. + * @param cutoff The number of times a feature must occur to be included. * @return A GIS model trained with specified */ - public GISModel trainModel(ObjectStream eventStream, int iterations, int cutoff) throws IOException { + public GISModel trainModel(ObjectStream eventStream, int iterations, + int cutoff) throws IOException { DataIndexer indexer = new OnePassDataIndexer(); Map params = new HashMap<>(); params.put(GIS.ITERATIONS_PARAM, Integer.toString(iterations)); params.put(GIS.CUTOFF_PARAM, Integer.toString(cutoff)); indexer.init(params, new HashMap<>()); indexer.index(eventStream); - return trainModel(iterations, indexer, cutoff); + return trainModel(iterations, indexer); } /** * Train a model using the GIS algorithm. * - * @param iterations The number of GIS iterations to perform. - * @param di The data indexer used to compress events in memory. + * @param iterations The number of GIS iterations to perform. + * @param di The data indexer used to compress events in memory. * @return The newly trained model, which can be used immediately or saved - * to disk using an opennlp.tools.ml.maxent.io.GISModelWriter object. + * to disk using an opennlp.tools.ml.maxent.io.GISModelWriter object. */ - public GISModel trainModel(int iterations, DataIndexer di, int cutoff) { - return trainModel(iterations,di,new UniformPrior(),cutoff,1); + public GISModel trainModel(int iterations, DataIndexer di) { + return trainModel(iterations, di, new UniformPrior(), 1); } /** * Train a model using the GIS algorithm. * - * @param iterations The number of GIS iterations to perform. - * @param di The data indexer used to compress events in memory. + * @param iterations The number of GIS iterations to perform. + * @param di The data indexer used to compress events in memory. * @param modelPrior The prior distribution used to train this model. * @return The newly trained model, which can be used immediately or saved - * to disk using an opennlp.tools.ml.maxent.io.GISModelWriter object. + * to disk using an opennlp.tools.ml.maxent.io.GISModelWriter object. */ - public GISModel trainModel(int iterations, DataIndexer di, Prior modelPrior, int cutoff, int threads) { + public GISModel trainModel(int iterations, DataIndexer di, Prior modelPrior, int threads) { if (threads <= 0) { throw new IllegalArgumentException("threads must be at least one or greater but is " + threads + "!"); @@ -265,8 +239,10 @@ public GISModel trainModel(int iterations, DataIndexer di, Prior modelPrior, int display("Incorporating indexed data for training... \n"); contexts = di.getContexts(); values = di.getValues(); - this.cutoff = cutoff; - predicateCounts = di.getPredCounts(); + /* + The number of times a predicate occured in the training data. + */ + int[] predicateCounts = di.getPredCounts(); numTimesEventsSeen = di.getNumTimesEventsSeen(); numUniqueEvents = contexts.length; this.prior = modelPrior; @@ -279,8 +255,7 @@ public GISModel trainModel(int iterations, DataIndexer di, Prior modelPrior, int if (contexts[ci].length > correctionConstant) { correctionConstant = contexts[ci].length; } - } - else { + } else { float cl = values[ci][0]; for (int vi = 1; vi < values[ci].length; vi++) { cl += values[ci][vi]; @@ -298,7 +273,7 @@ public GISModel trainModel(int iterations, DataIndexer di, Prior modelPrior, int numOutcomes = outcomeLabels.length; predLabels = di.getPredLabels(); - prior.setLabels(outcomeLabels,predLabels); + prior.setLabels(outcomeLabels, predLabels); numPreds = predLabels.length; display("\tNumber of Event Tokens: " + numUniqueEvents + "\n"); @@ -311,16 +286,12 @@ public GISModel trainModel(int iterations, DataIndexer di, Prior modelPrior, int for (int j = 0; j < contexts[ti].length; j++) { if (values != null && values[ti] != null) { predCount[contexts[ti][j]][outcomeList[ti]] += numTimesEventsSeen[ti] * values[ti][j]; - } - else { + } else { predCount[contexts[ti][j]][outcomeList[ti]] += numTimesEventsSeen[ti]; } } } - //printTable(predCount); - di = null; // don't need it anymore - // A fake "observation" to cover features which are not detected in // the data. The default is to assume that we observed "1/10th" of a // feature during training. @@ -332,15 +303,16 @@ public GISModel trainModel(int iterations, DataIndexer di, Prior modelPrior, int // implementation, this is cancelled out when we compute the next // iteration of a parameter, making the extra divisions wasteful. params = new MutableContext[numPreds]; - for (int i = 0; i < modelExpects.length; i++) + for (int i = 0; i < modelExpects.length; i++) { modelExpects[i] = new MutableContext[numPreds]; + } observedExpects = new MutableContext[numPreds]; // The model does need the correction constant and the correction feature. The correction constant // is only needed during training, and the correction feature is not necessary. // For compatibility reasons the model contains form now on a correction constant of 1, // and a correction param 0. - evalParams = new EvalParameters(params,0,1,numOutcomes); + evalParams = new EvalParameters(params, 0, 1, numOutcomes); int[] activeOutcomes = new int[numOutcomes]; int[] outcomePattern; int[] allOutcomesPattern = new int[numOutcomes]; @@ -353,27 +325,26 @@ public GISModel trainModel(int iterations, DataIndexer di, Prior modelPrior, int if (useSimpleSmoothing) { numActiveOutcomes = numOutcomes; outcomePattern = allOutcomesPattern; - } - else { //determine active outcomes + } else { //determine active outcomes for (int oi = 0; oi < numOutcomes; oi++) { - if (predCount[pi][oi] > 0 && predicateCounts[pi] >= cutoff) { + if (predCount[pi][oi] > 0) { activeOutcomes[numActiveOutcomes] = oi; numActiveOutcomes++; } } if (numActiveOutcomes == numOutcomes) { outcomePattern = allOutcomesPattern; - } - else { + } else { outcomePattern = new int[numActiveOutcomes]; System.arraycopy(activeOutcomes, 0, outcomePattern, 0, numActiveOutcomes); } } - params[pi] = new MutableContext(outcomePattern,new double[numActiveOutcomes]); - for (int i = 0; i < modelExpects.length; i++) - modelExpects[i][pi] = new MutableContext(outcomePattern,new double[numActiveOutcomes]); - observedExpects[pi] = new MutableContext(outcomePattern,new double[numActiveOutcomes]); - for (int aoi = 0;aoi < numActiveOutcomes; aoi++) { + params[pi] = new MutableContext(outcomePattern, new double[numActiveOutcomes]); + for (int i = 0; i < modelExpects.length; i++) { + modelExpects[i][pi] = new MutableContext(outcomePattern, new double[numActiveOutcomes]); + } + observedExpects[pi] = new MutableContext(outcomePattern, new double[numActiveOutcomes]); + for (int aoi = 0; aoi < numActiveOutcomes; aoi++) { int oi = outcomePattern[aoi]; params[pi].setParameter(aoi, 0.0); for (MutableContext[] modelExpect : modelExpects) { @@ -381,22 +352,20 @@ public GISModel trainModel(int iterations, DataIndexer di, Prior modelPrior, int } if (predCount[pi][oi] > 0) { observedExpects[pi].setParameter(aoi, predCount[pi][oi]); - } - else if (useSimpleSmoothing) { - observedExpects[pi].setParameter(aoi,smoothingObservation); + } else if (useSimpleSmoothing) { + observedExpects[pi].setParameter(aoi, smoothingObservation); } } } - predCount = null; // don't need it anymore - display("...done.\n"); /* Find the parameters *****/ - if (threads == 1) + if (threads == 1) { display("Computing model parameters ...\n"); - else + } else { display("Computing model parameters in " + threads + " threads...\n"); + } findParameters(iterations, correctionConstant); @@ -411,19 +380,20 @@ else if (useSimpleSmoothing) { private void findParameters(int iterations, double correctionConstant) { int threads = modelExpects.length; ExecutorService executor = Executors.newFixedThreadPool(threads); - CompletionService completionService = + CompletionService completionService = new ExecutorCompletionService<>(executor); double prevLL = 0.0; double currLL; display("Performing " + iterations + " iterations.\n"); for (int i = 1; i <= iterations; i++) { - if (i < 10) + if (i < 10) { display(" " + i + ": "); - else if (i < 100) + } else if (i < 100) { display(" " + i + ": "); - else + } else { display(i + ": "); - currLL = nextIteration(correctionConstant,completionService); + } + currLL = nextIteration(correctionConstant, completionService); if (i > 1) { if (prevLL > currLL) { System.err.println("Model Diverging: loglikelihood decreased"); @@ -445,7 +415,7 @@ else if (i < 100) } //modeled on implementation in Zhang Le's maxent kit - private double gaussianUpdate(int predicate, int oid, int n, double correctionConstant) { + private double gaussianUpdate(int predicate, int oid, double correctionConstant) { double param = params[predicate].getParameters()[oid]; double x0 = 0.0; double modelValue = modelExpects[0][predicate].getParameters()[oid]; @@ -467,98 +437,9 @@ private double gaussianUpdate(int predicate, int oid, int n, double correctionCo return x0; } - private class ModelExpactationComputeTask implements Callable { - - private final int startIndex; - private final int length; - - private double loglikelihood = 0; - - private int numEvents = 0; - private int numCorrect = 0; - - final private int threadIndex; - - // startIndex to compute, number of events to compute - ModelExpactationComputeTask(int threadIndex, int startIndex, int length) { - this.startIndex = startIndex; - this.length = length; - this.threadIndex = threadIndex; - } - - public ModelExpactationComputeTask call() { - - final double[] modelDistribution = new double[numOutcomes]; - - - for (int ei = startIndex; ei < startIndex + length; ei++) { - - // TODO: check interruption status here, if interrupted set a poisoned flag and return - - if (values != null) { - prior.logPrior(modelDistribution, contexts[ei], values[ei]); - GISModel.eval(contexts[ei], values[ei], modelDistribution, evalParams); - } - else { - prior.logPrior(modelDistribution,contexts[ei]); - GISModel.eval(contexts[ei], modelDistribution, evalParams); - } - for (int j = 0; j < contexts[ei].length; j++) { - int pi = contexts[ei][j]; - if (predicateCounts[pi] >= cutoff) { - int[] activeOutcomes = modelExpects[threadIndex][pi].getOutcomes(); - for (int aoi = 0;aoi < activeOutcomes.length; aoi++) { - int oi = activeOutcomes[aoi]; - - // numTimesEventsSeen must also be thread safe - if (values != null && values[ei] != null) { - modelExpects[threadIndex][pi].updateParameter(aoi,modelDistribution[oi] - * values[ei][j] * numTimesEventsSeen[ei]); - } - else { - modelExpects[threadIndex][pi].updateParameter(aoi,modelDistribution[oi] - * numTimesEventsSeen[ei]); - } - } - } - } - - loglikelihood += Math.log(modelDistribution[outcomeList[ei]]) * numTimesEventsSeen[ei]; - - numEvents += numTimesEventsSeen[ei]; - if (printMessages) { - int max = 0; - for (int oi = 1; oi < numOutcomes; oi++) { - if (modelDistribution[oi] > modelDistribution[max]) { - max = oi; - } - } - if (max == outcomeList[ei]) { - numCorrect += numTimesEventsSeen[ei]; - } - } - - } - - return this; - } - - synchronized int getNumEvents() { - return numEvents; - } - - synchronized int getNumCorrect() { - return numCorrect; - } - - synchronized double getLoglikelihood() { - return loglikelihood; - } - } - /* Compute one iteration of GIS and retutn log-likelihood.*/ private double nextIteration(double correctionConstant, - CompletionService completionService) { + CompletionService completionService) { // compute contribution of p(a|b_i) for each feature and the new // correction parameter double loglikelihood = 0.0; @@ -566,7 +447,7 @@ private double nextIteration(double correctionConstant, int numCorrect = 0; // Each thread gets equal number of tasks, if the number of tasks - // is not divisible by the number of threads, the first "leftOver" + // is not divisible by the number of threads, the first "leftOver" // threads have one extra task. int numberOfThreads = modelExpects.length; int taskSize = numUniqueEvents / numberOfThreads; @@ -574,14 +455,17 @@ private double nextIteration(double correctionConstant, // submit all tasks to the completion service. for (int i = 0; i < numberOfThreads; i++) { - if (i < leftOver) - completionService.submit(new ModelExpactationComputeTask(i, i * taskSize + i, taskSize + 1)); - else - completionService.submit(new ModelExpactationComputeTask(i, i * taskSize + leftOver, taskSize)); + if (i < leftOver) { + completionService.submit(new ModelExpectationComputeTask(i, i * taskSize + i, + taskSize + 1)); + } else { + completionService.submit(new ModelExpectationComputeTask(i, + i * taskSize + leftOver, taskSize)); + } } for (int i = 0; i < numberOfThreads; i++) { - ModelExpactationComputeTask finishedTask; + ModelExpectationComputeTask finishedTask; try { finishedTask = completionService.take().get(); } catch (InterruptedException e) { @@ -625,14 +509,13 @@ private double nextIteration(double correctionConstant, int[] activeOutcomes = params[pi].getOutcomes(); for (int aoi = 0; aoi < activeOutcomes.length; aoi++) { if (useGaussianSmoothing) { - params[pi].updateParameter(aoi,gaussianUpdate(pi,aoi,numEvents,correctionConstant)); - } - else { + params[pi].updateParameter(aoi, gaussianUpdate(pi, aoi, correctionConstant)); + } else { if (model[aoi] == 0) { System.err.println("Model expects == 0 for " + predLabels[pi] + " " + outcomeLabels[aoi]); } //params[pi].updateParameter(aoi,(Math.log(observed[aoi]) - Math.log(model[aoi]))); - params[pi].updateParameter(aoi,((Math.log(observed[aoi]) - Math.log(model[aoi])) + params[pi].updateParameter(aoi, ((Math.log(observed[aoi]) - Math.log(model[aoi])) / correctionConstant)); } @@ -649,7 +532,90 @@ private double nextIteration(double correctionConstant, } private void display(String s) { - if (printMessages) + if (printMessages) { System.out.print(s); + } + } + + private class ModelExpectationComputeTask implements Callable { + + private final int startIndex; + private final int length; + final private int threadIndex; + private double loglikelihood = 0; + private int numEvents = 0; + private int numCorrect = 0; + + // startIndex to compute, number of events to compute + ModelExpectationComputeTask(int threadIndex, int startIndex, int length) { + this.startIndex = startIndex; + this.length = length; + this.threadIndex = threadIndex; + } + + public ModelExpectationComputeTask call() { + + final double[] modelDistribution = new double[numOutcomes]; + + + for (int ei = startIndex; ei < startIndex + length; ei++) { + + // TODO: check interruption status here, if interrupted set a poisoned flag and return + + if (values != null) { + prior.logPrior(modelDistribution, contexts[ei], values[ei]); + GISModel.eval(contexts[ei], values[ei], modelDistribution, evalParams); + } else { + prior.logPrior(modelDistribution, contexts[ei]); + GISModel.eval(contexts[ei], modelDistribution, evalParams); + } + for (int j = 0; j < contexts[ei].length; j++) { + int pi = contexts[ei][j]; + int[] activeOutcomes = modelExpects[threadIndex][pi].getOutcomes(); + for (int aoi = 0; aoi < activeOutcomes.length; aoi++) { + int oi = activeOutcomes[aoi]; + + // numTimesEventsSeen must also be thread safe + if (values != null && values[ei] != null) { + modelExpects[threadIndex][pi].updateParameter(aoi, modelDistribution[oi] + * values[ei][j] * numTimesEventsSeen[ei]); + } else { + modelExpects[threadIndex][pi].updateParameter(aoi, modelDistribution[oi] + * numTimesEventsSeen[ei]); + } + } + } + + loglikelihood += Math.log(modelDistribution[outcomeList[ei]]) * numTimesEventsSeen[ei]; + + numEvents += numTimesEventsSeen[ei]; + if (printMessages) { + int max = 0; + for (int oi = 1; oi < numOutcomes; oi++) { + if (modelDistribution[oi] > modelDistribution[max]) { + max = oi; + } + } + if (max == outcomeList[ei]) { + numCorrect += numTimesEventsSeen[ei]; + } + } + + } + + return this; + } + + synchronized int getNumEvents() { + return numEvents; + } + + synchronized int getNumCorrect() { + return numCorrect; + } + + synchronized double getLoglikelihood() { + return loglikelihood; + } } } diff --git a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNTrainer.java b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNTrainer.java index cc4745ab3..44998d7a8 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNTrainer.java +++ b/opennlp-tools/src/main/java/opennlp/tools/ml/maxent/quasinewton/QNTrainer.java @@ -182,8 +182,8 @@ public QNModel trainModel(int iterations, DataIndexer indexer) { Context[] params = new Context[nPredLabels]; for (int ci = 0; ci < params.length; ci++) { - List outcomePattern = new ArrayList(nOutcomes); - List alpha = new ArrayList(nOutcomes); + List outcomePattern = new ArrayList<>(nOutcomes); + List alpha = new ArrayList<>(nOutcomes); for (int oi = 0; oi < nOutcomes; oi++) { double val = parameters[oi * nPredLabels + ci]; outcomePattern.add(oi); diff --git a/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/GISTestIndexing.java b/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/GISTestIndexing.java index 509f3ddf9..80a94c3a9 100644 --- a/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/GISTestIndexing.java +++ b/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/GISTestIndexing.java @@ -29,12 +29,13 @@ public class GISTestIndexing { - static String[][] cntx = new String[][]{ + private static String[][] cntx = new String[][]{ {"dog","cat","mouse"}, {"text", "print", "mouse"}, {"dog", "pig", "cat", "mouse"} }; - static String[] outputs = new String[]{"A","B","A"}; + + private static String[] outputs = new String[]{"A","B","A"}; /* * Test the GIS.trainModel(ObjectStream eventStream) method @@ -74,12 +75,13 @@ public void testGISTrainSignature3() throws Exception { events.add(new Event(outputs[i], cntx[i])); } ObjectStream eventStream = ObjectStreamUtils.createObjectStream(events); - Assert.assertNotNull(GIS.trainModel(eventStream,10,1)); + Assert.assertNotNull(GIS.trainModel(eventStream, 10, 1)); eventStream.close(); } /* - * Test the GIS.trainModel(ObjectStream eventStream, int iterations, int cutoff, double sigma) method + * Test the GIS.trainModel(ObjectStream eventStream, int iterations, int cutoff, + * double sigma) method */ @Test public void testGISTrainSignature4() throws Exception { @@ -88,7 +90,7 @@ public void testGISTrainSignature4() throws Exception { events.add(new Event(outputs[i], cntx[i])); } ObjectStream eventStream = ObjectStreamUtils.createObjectStream(events); - Assert.assertNotNull(GIS.trainModel(eventStream,10,1,0.01)); + Assert.assertNotNull(GIS.trainModel(eventStream, 10, 1, 0.01)); eventStream.close(); } @@ -103,7 +105,7 @@ public void testGISTrainSignature5() throws Exception { events.add(new Event(outputs[i], cntx[i])); } ObjectStream eventStream = ObjectStreamUtils.createObjectStream(events); - Assert.assertNotNull(GIS.trainModel(eventStream,10,1,false,false)); + Assert.assertNotNull(GIS.trainModel(eventStream, 10, 1, false, false)); eventStream.close(); } } diff --git a/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/MaxentPrepAttachTest.java b/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/MaxentPrepAttachTest.java index a969ede4c..090c165bc 100644 --- a/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/MaxentPrepAttachTest.java +++ b/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/MaxentPrepAttachTest.java @@ -17,69 +17,67 @@ package opennlp.tools.ml.maxent; -import static opennlp.tools.ml.PrepAttachDataUtil.createTrainingStream; -import static opennlp.tools.ml.PrepAttachDataUtil.testModel; - import java.io.IOException; import java.util.HashMap; import java.util.Map; +import org.junit.Test; + import opennlp.tools.ml.AbstractEventTrainer; import opennlp.tools.ml.AbstractTrainer; import opennlp.tools.ml.EventTrainer; +import opennlp.tools.ml.PrepAttachDataUtil; import opennlp.tools.ml.TrainerFactory; import opennlp.tools.ml.model.AbstractModel; import opennlp.tools.ml.model.MaxentModel; import opennlp.tools.ml.model.TwoPassDataIndexer; import opennlp.tools.ml.model.UniformPrior; -import org.junit.Test; - public class MaxentPrepAttachTest { @Test public void testMaxentOnPrepAttachData() throws IOException { AbstractModel model = new GISTrainer(true).trainModel(100, - new TwoPassDataIndexer(createTrainingStream(), 1), 1); + new TwoPassDataIndexer(PrepAttachDataUtil.createTrainingStream(), 1)); - testModel(model, 0.7997028967566229); + PrepAttachDataUtil.testModel(model, 0.7997028967566229); } @Test public void testMaxentOnPrepAttachData2Threads() throws IOException { AbstractModel model = new GISTrainer(true).trainModel(100, - new TwoPassDataIndexer(createTrainingStream(), 1), - new UniformPrior(), 1, 2); + new TwoPassDataIndexer(PrepAttachDataUtil.createTrainingStream(), 1), + new UniformPrior(), 2); - testModel(model, 0.7997028967566229); + PrepAttachDataUtil.testModel(model, 0.7997028967566229); } @Test public void testMaxentOnPrepAttachDataWithParams() throws IOException { - Map trainParams = new HashMap(); + Map trainParams = new HashMap<>(); trainParams.put(AbstractTrainer.ALGORITHM_PARAM, GIS.MAXENT_VALUE); trainParams.put(AbstractEventTrainer.DATA_INDEXER_PARAM, AbstractEventTrainer.DATA_INDEXER_TWO_PASS_VALUE); trainParams.put(AbstractTrainer.CUTOFF_PARAM, Integer.toString(1)); EventTrainer trainer = TrainerFactory.getEventTrainer(trainParams, null); - MaxentModel model = trainer.train(createTrainingStream()); + MaxentModel model = trainer.train(PrepAttachDataUtil.createTrainingStream()); - testModel(model, 0.7997028967566229); + PrepAttachDataUtil.testModel(model, 0.7997028967566229); } @Test public void testMaxentOnPrepAttachDataWithParamsDefault() throws IOException { - Map trainParams = new HashMap(); + Map trainParams = new HashMap<>(); trainParams.put(AbstractTrainer.ALGORITHM_PARAM, GIS.MAXENT_VALUE); EventTrainer trainer = TrainerFactory.getEventTrainer(trainParams, null); - MaxentModel model = trainer.train(createTrainingStream()); + MaxentModel model = trainer.train(PrepAttachDataUtil.createTrainingStream()); - testModel(model, 0.8086159940579352 ); + PrepAttachDataUtil.testModel(model, 0.8086159940579352 ); } } diff --git a/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/RealValueModelTest.java b/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/RealValueModelTest.java index a8e96c3c8..dbd65c1df 100644 --- a/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/RealValueModelTest.java +++ b/opennlp-tools/src/test/java/opennlp/tools/ml/maxent/RealValueModelTest.java @@ -17,13 +17,14 @@ package opennlp.tools.ml.maxent; -import opennlp.tools.ml.model.FileEventStream; -import opennlp.tools.ml.model.OnePassRealValueDataIndexer; -import opennlp.tools.ml.model.RealValueFileEventStream; +import java.io.IOException; + import org.junit.Assert; import org.junit.Test; -import java.io.IOException; +import opennlp.tools.ml.model.FileEventStream; +import opennlp.tools.ml.model.OnePassRealValueDataIndexer; +import opennlp.tools.ml.model.RealValueFileEventStream; public class RealValueModelTest {