From a3a436d1229e2c766c8c8d53faabb15f48ca8261 Mon Sep 17 00:00:00 2001 From: Rodrigo Agerri Date: Fri, 3 Feb 2017 16:00:38 +0100 Subject: [PATCH] OPENNLP-904 Harmonize lemmatizer API and function to get multiple lemmas OPENNLP-904 add minor correction after PR comment --- .../cmdline/lemmatizer/LemmatizerMETool.java | 4 +- .../lemmatizer/DictionaryLemmatizer.java | 70 ++++++++++++++----- .../lemmatizer/LemmaSampleEventStream.java | 2 +- .../lemmatizer/LemmaSampleSequenceStream.java | 6 +- .../tools/lemmatizer/LemmaSampleStream.java | 4 +- .../opennlp/tools/lemmatizer/Lemmatizer.java | 16 ++++- .../tools/lemmatizer/LemmatizerME.java | 64 +++++++++++++++-- .../tools/lemmatizer/DummyLemmatizer.java | 7 ++ .../tools/lemmatizer/LemmatizerMETest.java | 3 +- 9 files changed, 139 insertions(+), 37 deletions(-) diff --git a/opennlp-tools/src/main/java/opennlp/tools/cmdline/lemmatizer/LemmatizerMETool.java b/opennlp-tools/src/main/java/opennlp/tools/cmdline/lemmatizer/LemmatizerMETool.java index 13f28b20b..9390376f6 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/cmdline/lemmatizer/LemmatizerMETool.java +++ b/opennlp-tools/src/main/java/opennlp/tools/cmdline/lemmatizer/LemmatizerMETool.java @@ -72,10 +72,8 @@ public void run(String[] args) { continue; } - String[] preds = lemmatizer.lemmatize(posSample.getSentence(), + String[] lemmas = lemmatizer.lemmatize(posSample.getSentence(), posSample.getTags()); - String[] lemmas = lemmatizer.decodeLemmas(posSample.getSentence(), - preds); System.out.println(new LemmaSample(posSample.getSentence(), posSample.getTags(), lemmas).toString()); diff --git a/opennlp-tools/src/main/java/opennlp/tools/lemmatizer/DictionaryLemmatizer.java b/opennlp-tools/src/main/java/opennlp/tools/lemmatizer/DictionaryLemmatizer.java index b1b04a19c..9f0b0b012 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/lemmatizer/DictionaryLemmatizer.java +++ b/opennlp-tools/src/main/java/opennlp/tools/lemmatizer/DictionaryLemmatizer.java @@ -37,7 +37,7 @@ 
public class DictionaryLemmatizer implements Lemmatizer { /** * The hashmap containing the dictionary. */ - private final Map, String> dictMap; + private final Map, List> dictMap; /** * Construct a hashmap from the input tab separated dictionary. @@ -47,26 +47,24 @@ public class DictionaryLemmatizer implements Lemmatizer { * @param dictionary * the input dictionary via inputstream */ - public DictionaryLemmatizer(final InputStream dictionary) { + public DictionaryLemmatizer(final InputStream dictionary) throws IOException { this.dictMap = new HashMap<>(); - final BufferedReader breader = new BufferedReader(new InputStreamReader(dictionary)); + final BufferedReader breader = new BufferedReader( + new InputStreamReader(dictionary)); String line; - try { - while ((line = breader.readLine()) != null) { - final String[] elems = line.split("\t"); - this.dictMap.put(Arrays.asList(elems[0], elems[1]), elems[2]); - } - } catch (final IOException e) { - e.printStackTrace(); + while ((line = breader.readLine()) != null) { + final String[] elems = line.split("\t"); + this.dictMap.put(Arrays.asList(elems[0], elems[1]), Arrays.asList(elems[2])); } } + /** * Get the Map containing the dictionary. 
* * @return dictMap the Map */ - public Map, String> getDictMap() { + public Map, List> getDictMap() { return this.dictMap; } @@ -85,31 +83,65 @@ private List getDictKeys(final String word, final String postag) { return keys; } + public String[] lemmatize(final String[] tokens, final String[] postags) { List lemmas = new ArrayList<>(); for (int i = 0; i < tokens.length; i++) { - lemmas.add(this.apply(tokens[i], postags[i])); + lemmas.add(this.lemmatize(tokens[i], postags[i])); } return lemmas.toArray(new String[lemmas.size()]); } + public List> lemmatize(final List tokens, final List posTags) { + List> allLemmas = new ArrayList<>(); + for (int i = 0; i < tokens.size(); i++) { + allLemmas.add(this.getAllLemmas(tokens.get(i), posTags.get(i))); + } + return allLemmas; + } + /** * Lookup lemma in a dictionary. Outputs "O" if not found. - * @param word the token - * @param postag the postag + * + * @param word + * the token + * @param postag + * the postag * @return the lemma */ - public String apply(final String word, final String postag) { + private String lemmatize(final String word, final String postag) { String lemma; final List keys = this.getDictKeys(word, postag); // lookup lemma as value of the map - final String keyValue = this.dictMap.get(keys); - if (keyValue != null) { - lemma = keyValue; + final List keyValues = this.dictMap.get(keys); + if (!keyValues.isEmpty()) { + lemma = keyValues.get(0); } else { lemma = "O"; } return lemma; } -} + /** + * Lookup every lemma for a word,pos tag in a dictionary. Outputs "O" if not + * found. 
+ * + * @param word + * the token + * @param postag + * the postag + * @return every lemma + */ + private List getAllLemmas(final String word, final String postag) { + List lemmasList = new ArrayList<>(); + final List keys = this.getDictKeys(word, postag); + // lookup lemma as value of the map + final List keyValues = this.dictMap.get(keys); + if (!keyValues.isEmpty()) { + lemmasList.addAll(keyValues); + } else { + lemmasList.add("O"); + } + return lemmasList; + } +} diff --git a/opennlp-tools/src/main/java/opennlp/tools/lemmatizer/LemmaSampleEventStream.java b/opennlp-tools/src/main/java/opennlp/tools/lemmatizer/LemmaSampleEventStream.java index fc1a558a4..a8d71e872 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/lemmatizer/LemmaSampleEventStream.java +++ b/opennlp-tools/src/main/java/opennlp/tools/lemmatizer/LemmaSampleEventStream.java @@ -49,7 +49,7 @@ protected Iterator createEvents(LemmaSample sample) { List events = new ArrayList<>(); String[] toksArray = sample.getTokens(); String[] tagsArray = sample.getTags(); - String[] lemmasArray = sample.getLemmas(); + String[] lemmasArray = LemmatizerME.encodeLemmas(toksArray,sample.getLemmas()); for (int ei = 0, el = sample.getTokens().length; ei < el; ei++) { events.add(new Event(lemmasArray[ei], contextGenerator.getContext(ei,toksArray,tagsArray,lemmasArray))); diff --git a/opennlp-tools/src/main/java/opennlp/tools/lemmatizer/LemmaSampleSequenceStream.java b/opennlp-tools/src/main/java/opennlp/tools/lemmatizer/LemmaSampleSequenceStream.java index 70565389d..a4d5c8c0b 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/lemmatizer/LemmaSampleSequenceStream.java +++ b/opennlp-tools/src/main/java/opennlp/tools/lemmatizer/LemmaSampleSequenceStream.java @@ -41,9 +41,9 @@ public Sequence read() throws IOException { LemmaSample sample = samples.read(); if (sample != null) { - String sentence[] = sample.getTokens(); - String tags[] = sample.getTags(); - String preds[] = sample.getLemmas(); + String[] sentence = 
sample.getTokens(); + String[] tags = sample.getTags(); + String[] preds = sample.getLemmas(); Event[] events = new Event[sentence.length]; for (int i = 0; i < sentence.length; i++) { diff --git a/opennlp-tools/src/main/java/opennlp/tools/lemmatizer/LemmaSampleStream.java b/opennlp-tools/src/main/java/opennlp/tools/lemmatizer/LemmaSampleStream.java index 0a133c380..9c661a52c 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/lemmatizer/LemmaSampleStream.java +++ b/opennlp-tools/src/main/java/opennlp/tools/lemmatizer/LemmaSampleStream.java @@ -23,7 +23,6 @@ import opennlp.tools.util.FilterObjectStream; import opennlp.tools.util.ObjectStream; -import opennlp.tools.util.StringUtil; /** @@ -51,8 +50,7 @@ public LemmaSample read() throws IOException { else { toks.add(parts[0]); tags.add(parts[1]); - String ses = StringUtil.getShortestEditScript(parts[0], parts[2]); - preds.add(ses); + preds.add(parts[2]); } } if (toks.size() > 0) { diff --git a/opennlp-tools/src/main/java/opennlp/tools/lemmatizer/Lemmatizer.java b/opennlp-tools/src/main/java/opennlp/tools/lemmatizer/Lemmatizer.java index ddcaa6a2c..f5cf68873 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/lemmatizer/Lemmatizer.java +++ b/opennlp-tools/src/main/java/opennlp/tools/lemmatizer/Lemmatizer.java @@ -17,19 +17,31 @@ package opennlp.tools.lemmatizer; +import java.util.List; + /** * The interface for lemmatizers. */ public interface Lemmatizer { /** - * Generates lemma tags for the word and postag returning the result in an array. + * Generates lemmas for the word and postag returning the result in an array. * * @param toks an array of the tokens * @param tags an array of the pos tags * - * @return an array of lemma classes for each token in the sequence. + * @return an array of possible lemmas for each token in the sequence. 
*/ String[] lemmatize(String[] toks, String tags[]); + /** + * Generates lemmas for the word and postag returning the result in a list + * of every possible lemma for each token and postag. + * + * @param toks a list of the tokens + * @param tags a list of the pos tags + * @return a list of every possible lemma for each token in the sequence. + */ + List> lemmatize(List toks, List tags); + } diff --git a/opennlp-tools/src/main/java/opennlp/tools/lemmatizer/LemmatizerME.java b/opennlp-tools/src/main/java/opennlp/tools/lemmatizer/LemmatizerME.java index 98a19f508..34bfa8751 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/lemmatizer/LemmatizerME.java +++ b/opennlp-tools/src/main/java/opennlp/tools/lemmatizer/LemmatizerME.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -47,6 +48,7 @@ */ public class LemmatizerME implements Lemmatizer { + public static final int LEMMA_NUMBER = 29; public static final int DEFAULT_BEAM_SIZE = 3; protected int beamSize; private Sequence bestSequence; @@ -86,9 +88,52 @@ public LemmatizerME(LemmatizerModel model) { } public String[] lemmatize(String[] toks, String[] tags) { + String[] ses = predictSES(toks, tags); + String[] lemmas = decodeLemmas(toks, ses); + return lemmas; + } + + @Override public List> lemmatize(List toks, + List tags) { + String[] tokens = toks.toArray(new String[toks.size()]); + String[] posTags = tags.toArray(new String[tags.size()]); + String[][] allLemmas = predictLemmas(LEMMA_NUMBER, tokens, posTags); + List> predictedLemmas = new ArrayList<>(); + for (int i = 0; i < allLemmas.length; i++) { + predictedLemmas.add(Arrays.asList(allLemmas[i])); + } + return predictedLemmas; + } + + /** + * Predict Short Edit Script (automatically induced lemma class). 
+ * @param toks the array of tokens + * @param tags the array of pos tags + * @return an array containing the lemma classes + */ + public String[] predictSES(String[] toks, String[] tags) { bestSequence = model.bestSequence(toks, new Object[] {tags}, contextGenerator, sequenceValidator); - List c = bestSequence.getOutcomes(); - return c.toArray(new String[c.size()]); + List ses = bestSequence.getOutcomes(); + return ses.toArray(new String[ses.size()]); + } + + /** + * Predict all possible lemmas (using a default upper bound). + * @param numLemmas the default number of lemmas + * @param toks the tokens + * @param tags the postags + * @return a double array containing all possible lemmas for each token and postag pair + */ + public String[][] predictLemmas(int numLemmas, String[] toks, String[] tags) { + Sequence[] bestSequences = model.bestSequences(numLemmas, toks, new Object[] {tags}, + contextGenerator, sequenceValidator); + String[][] allLemmas = new String[bestSequences.length][]; + for (int i = 0; i < allLemmas.length; i++) { + List ses = bestSequences[i].getOutcomes(); + String[] sesArray = ses.toArray(new String[ses.size()]); + allLemmas[i] = decodeLemmas(toks,sesArray); + } + return allLemmas; } /** @@ -97,11 +142,10 @@ public String[] lemmatize(String[] toks, String[] tags) { * @param preds the predicted lemma classes * @return the array of decoded lemmas */ - public String[] decodeLemmas(String[] toks, String[] preds) { + public static String[] decodeLemmas(String[] toks, String[] preds) { List lemmas = new ArrayList<>(); for (int i = 0; i < toks.length; i++) { String lemma = StringUtil.decodeShortestEditScript(toks[i].toLowerCase(), preds[i]); - //System.err.println("-> DEBUG: " + toks[i].toLowerCase() + " " + preds[i] + " " + lemma); if (lemma.length() == 0) { lemma = "_"; } @@ -110,6 +154,18 @@ public String[] decodeLemmas(String[] toks, String[] preds) { return lemmas.toArray(new String[lemmas.size()]); } + public static String[] encodeLemmas(String[] 
toks, String[] lemmas) { + List sesList = new ArrayList<>(); + for (int i = 0; i < toks.length; i++) { + String ses = StringUtil.getShortestEditScript(toks[i], lemmas[i]); + if (ses.length() == 0) { + ses = "_"; + } + sesList.add(ses); + } + return sesList.toArray(new String[sesList.size()]); + } + public Sequence[] topKSequences(String[] sentence, String[] tags) { return model.bestSequences(DEFAULT_BEAM_SIZE, sentence, new Object[] { tags }, contextGenerator, sequenceValidator); diff --git a/opennlp-tools/src/test/java/opennlp/tools/lemmatizer/DummyLemmatizer.java b/opennlp-tools/src/test/java/opennlp/tools/lemmatizer/DummyLemmatizer.java index 489ba38d2..dcfc883f1 100644 --- a/opennlp-tools/src/test/java/opennlp/tools/lemmatizer/DummyLemmatizer.java +++ b/opennlp-tools/src/test/java/opennlp/tools/lemmatizer/DummyLemmatizer.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.util.Arrays; +import java.util.List; /** * This dummy lemmatizer implementation simulates a LemmatizerME. The file has @@ -56,4 +57,10 @@ public String[] lemmatize(String[] toks, String[] tags) { } } + @Override + public List> lemmatize(List toks, + List tags) { + return null; + } + } diff --git a/opennlp-tools/src/test/java/opennlp/tools/lemmatizer/LemmatizerMETest.java b/opennlp-tools/src/test/java/opennlp/tools/lemmatizer/LemmatizerMETest.java index 76b4cd5e8..97dcc3c8e 100644 --- a/opennlp-tools/src/test/java/opennlp/tools/lemmatizer/LemmatizerMETest.java +++ b/opennlp-tools/src/test/java/opennlp/tools/lemmatizer/LemmatizerMETest.java @@ -82,8 +82,7 @@ public void startup() throws IOException { @Test public void testLemmasAsArray() throws Exception { - String[] preds = lemmatizer.lemmatize(tokens, postags); - String[] lemmas = lemmatizer.decodeLemmas(tokens, preds); + String[] lemmas = lemmatizer.lemmatize(tokens, postags); Assert.assertArrayEquals(expect, lemmas); }