Skip to content

Commit

Permalink
Added n-best and cost limited search
Browse files Browse the repository at this point in the history
  • Loading branch information
emmanuellegedin committed Jul 21, 2016
1 parent 9ce40b6 commit d0ed0eb
Show file tree
Hide file tree
Showing 6 changed files with 379 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import com.atilika.kuromoji.dict.UserDictionary;
import com.atilika.kuromoji.fst.FST;
import com.atilika.kuromoji.util.ResourceResolver;
import com.atilika.kuromoji.viterbi.MultiSearchResult;
import com.atilika.kuromoji.viterbi.TokenFactory;
import com.atilika.kuromoji.viterbi.ViterbiBuilder;
import com.atilika.kuromoji.viterbi.ViterbiFormatter;
Expand Down Expand Up @@ -115,6 +116,17 @@ public List<? extends TokenBase> tokenize(String text) {
return createTokenList(text);
}

/**
 * Tokenizes the provided text into several alternative tokenizations.
 * <p>
 * Simply forwards to {@code createMultiTokenList(text, maxCount, maxCost)}.
 *
 * @param text  text to tokenize
 * @param maxCount  maximum number of different tokenizations
 * @param maxCost  maximum cost of a tokenization
 * @param <T>  token type
 * @return list of tokenizations, each one being a list of tokens
 */
public <T extends TokenBase> List<List<T>> multiTokenize(String text, int maxCount, int maxCost) {
    List<List<T>> tokenizations = createMultiTokenList(text, maxCount, maxCost);
    return tokenizations;
}

/**
 * Tokenizes the provided text into up to {@code maxCount} alternative tokenizations
 * without restricting their cost: the cost limit passed on is {@link Integer#MAX_VALUE}.
 *
 * @param text  text to tokenize
 * @param maxCount  maximum number of different tokenizations
 * @param <T>  token type
 * @return list of tokenizations, each one being a list of tokens
 */
public <T extends TokenBase> List<List<T>> multiTokenizeAnyCost(String text, int maxCount) {
    final int unlimitedCost = Integer.MAX_VALUE;
    return multiTokenize(text, maxCount, unlimitedCost);
}

/**
 * Tokenizes the provided text into every alternative tokenization whose cost is at most
 * {@code maxCost}: the count limit passed on is {@link Integer#MAX_VALUE}.
 *
 * @param text  text to tokenize
 * @param maxCost  maximum cost of a tokenization
 * @param <T>  token type
 * @return list of tokenizations, each one being a list of tokens
 */
public <T extends TokenBase> List<List<T>> multiTokenizeAll(String text, int maxCost) {
    final int unlimitedCount = Integer.MAX_VALUE;
    return multiTokenize(text, unlimitedCount, maxCost);
}

/**
* Tokenizes the provided text and returns a list of tokens with various feature information
Expand Down Expand Up @@ -153,6 +165,21 @@ protected <T extends TokenBase> List<T> createTokenList(String text) {
return result;
}

/**
 * Tokenizes the provided text and returns up to {@code maxCount} alternative tokenizations,
 * each with cost at most {@code maxCost}. Every tokenization is a list of tokens with
 * various feature information.
 * <p>
 * The text is treated as starting at offset 0 of the original input.
 * <p>
 * NOTE(review): this method was previously documented as thread safe; that depends on the
 * builder and searcher it delegates to and should be confirmed.
 *
 * @param text  text to tokenize
 * @param maxCount  maximum number of different tokenizations
 * @param maxCost  maximum cost of a tokenization
 * @param <T>  token type
 * @return list of tokenizations (each a list of tokens), not null
 */
protected <T extends TokenBase> List<List<T>> createMultiTokenList(String text, int maxCount, int maxCost) {
    final int startOffset = 0;
    return createMultiTokenList(startOffset, text, maxCount, maxCost);
}

/**
* Tokenizes the provided text and outputs the corresponding Viterbi lattice and the Viterbi path to the provided output stream
* <p>
Expand Down Expand Up @@ -258,6 +285,45 @@ private <T extends TokenBase> List<T> createTokenList(int offset, String text) {
return result;
}

/**
 * Tokenizes one sentence into several alternative token sequences.
 * <p>
 * A lattice is built for the sentence and handed to the searcher together with the two
 * limits; every path the searcher returns is converted into a list of tokens, in the
 * order the searcher returned them.
 *
 * @param offset  offset of sentence in original input text; added to every token position
 * @param text  sentence to tokenize
 * @param maxCount  maximum number of paths
 * @param maxCost  maximum cost of a path
 * @param <T>  token type
 * @return one token list per path, never null
 */
private <T extends TokenBase> List<List<T>> createMultiTokenList(int offset, String text, int maxCount, int maxCost) {
    ViterbiLattice sentenceLattice = viterbiBuilder.build(text);
    List<List<ViterbiNode>> candidatePaths =
        viterbiSearcher.searchMultiple(sentenceLattice, maxCount, maxCost).getTokenizedResultsList();

    List<List<T>> tokenizations = new ArrayList<>();
    for (List<ViterbiNode> candidate : candidatePaths) {
        ArrayList<T> sentenceTokens = new ArrayList<>();

        for (ViterbiNode current : candidate) {
            int wordId = current.getWordId();
            // BOS/EOS markers are KNOWN nodes carrying word id -1; they never become tokens
            boolean sentenceBoundary = current.getType() == ViterbiNode.Type.KNOWN && wordId == -1;

            if (!sentenceBoundary) {
                @SuppressWarnings("unchecked")
                T created = (T) tokenFactory.createToken(
                    wordId,
                    current.getSurface(),
                    current.getType(),
                    offset + current.getStartIndex(),
                    dictionaryMap.get(current.getType())
                );
                sentenceTokens.add(created);
            }
        }

        tokenizations.add(sentenceTokens);
    }

    return tokenizations;
}

/**
* Abstract Builder shared by all tokenizers
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/**
* Copyright © 2010-2015 Atilika Inc. and contributors (see CONTRIBUTORS.md)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License. A copy of the
* License is distributed with this work in the LICENSE.md file. You may
* also obtain a copy of the License from
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.atilika.kuromoji.viterbi;

import java.util.ArrayList;
import java.util.List;

/**
 * Accumulates the outcome of a multi-path search: a sequence of node paths and, in
 * parallel, the cost recorded for each of them. Entry {@code i} of the path list and
 * entry {@code i} of the cost list belong together.
 */
public class MultiSearchResult {
    /** Paths in insertion order. */
    private final List<List<ViterbiNode>> paths = new ArrayList<>();
    /** Cost of the path stored at the same index. */
    private final List<Integer> pathCosts = new ArrayList<>();

    /**
     * Creates an empty result.
     */
    public MultiSearchResult() {
    }

    /**
     * Appends a path together with its cost.
     *
     * @param tokenizedResult  nodes making up the path
     * @param cost  cost of the path
     */
    public void add(List<ViterbiNode> tokenizedResult, int cost) {
        paths.add(tokenizedResult);
        pathCosts.add(Integer.valueOf(cost));
    }

    /**
     * Returns the path stored at the given position.
     *
     * @param index  zero-based position
     * @return the path added at that position
     * @throws IndexOutOfBoundsException if no path was stored at {@code index}
     */
    public List<ViterbiNode> getTokenizedResult(int index) {
        return paths.get(index);
    }

    /**
     * Returns all stored paths. This is the internal list itself, not a copy.
     *
     * @return list of paths in insertion order
     */
    public List<List<ViterbiNode>> getTokenizedResultsList() {
        return paths;
    }

    /**
     * Returns the cost stored at the given position.
     *
     * @param index  zero-based position
     * @return cost of the path added at that position
     * @throws IndexOutOfBoundsException if no cost was stored at {@code index}
     */
    public int getCost(int index) {
        return pathCosts.get(index).intValue();
    }

    /**
     * Returns the number of stored entries.
     *
     * @return number of costs (and therefore paths) added so far
     */
    public int size() {
        return pathCosts.size();
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
/**
* Copyright © 2010-2015 Atilika Inc. and contributors (see CONTRIBUTORS.md)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License. A copy of the
* License is distributed with this work in the LICENSE.md file. You may
* also obtain a copy of the License from
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.atilika.kuromoji.viterbi;

import com.atilika.kuromoji.TokenizerBase;
import com.atilika.kuromoji.dict.ConnectionCosts;

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.PriorityQueue;

/**
 * Finds several low-cost paths through a Viterbi lattice.
 * <p>
 * Every alternative path is described relative to the best path as a chain of
 * "sidetracks": edges that leave the best predecessor of a node and enter through
 * another left node instead. Each lattice node gets a tree of such sidetracks, and the
 * tree attached to EOS is explored cheapest-first.
 * <p>
 * This object keeps no per-search state, so one instance can be used for any number of
 * searches. A search does, however, attach sidetrack trees to the nodes of the lattice
 * it is given.
 */
public class MultiSearcher {

    private final ConnectionCosts costs;
    private final TokenizerBase.Mode mode;
    private final ViterbiSearcher viterbiSearcher;

    /**
     * Creates a searcher.
     *
     * @param costs  connection costs between adjacent nodes
     * @param mode  tokenizer mode; in SEARCH and EXTENDED mode the penalty cost is added to sidetracks
     * @param viterbiSearcher  searcher used to obtain the penalty cost of a node
     */
    public MultiSearcher(ConnectionCosts costs, TokenizerBase.Mode mode, ViterbiSearcher viterbiSearcher) {
        this.costs = costs;
        this.mode = mode;
        this.viterbiSearcher = viterbiSearcher;
    }

    /**
     * Get up to maxCount shortest paths with cost at most maxCost. The results are ordered in ascending order by cost.
     *
     * @param lattice  an instance of ViterbiLattice processed by a ViterbiSearcher
     * @param maxCount  the maximum number of results
     * @param maxCost  the maximum cost of a result
     * @return the shortest paths and their costs
     */
    public MultiSearchResult getShortestPaths(ViterbiLattice lattice, int maxCount, int maxCost) {
        MultiSearchResult multiSearchResult = new MultiSearchResult();
        buildSidetrackTrees(lattice);
        ViterbiNode eos = lattice.getEndIndexArr()[0][0];
        int baseCost = eos.getPathCost();

        for (SidetrackTreeNode sidetrack : getPaths(eos.getSidetrackTreeNode(), baseCost, maxCount, maxCost)) {
            // The cost is derived from the node itself rather than from a shared list, so
            // results of an earlier search can never leak into this one.
            multiSearchResult.add(generatePath(eos, sidetrack), baseCost + sidetrack.getCost());
        }
        return multiSearchResult;
    }

    /**
     * Explores the sidetrack tree cheapest-first and collects the accepted nodes.
     *
     * @param root  sidetrack tree node of EOS; stands for the best path
     * @param baseCost  cost of the best path
     * @param maxCount  the maximum number of nodes to return
     * @param maxCost  the maximum total cost (base cost plus sidetrack cost) of a returned node
     * @return accepted nodes in ascending order by cost
     */
    private List<SidetrackTreeNode> getPaths(SidetrackTreeNode root, int baseCost, int maxCount, int maxCost) {
        List<SidetrackTreeNode> result = new ArrayList<>();
        PriorityQueue<SidetrackTreeNode> sidetrackHeap = new PriorityQueue<>();
        sidetrackHeap.add(root);

        while (result.size() < maxCount && !sidetrackHeap.isEmpty()) {
            SidetrackTreeNode node = sidetrackHeap.poll();
            // Compare as long so that a sum beyond Integer.MAX_VALUE cannot wrap around
            // and slip under the limit.
            if ((long) baseCost + node.getCost() > maxCost) {
                break;
            }
            result.add(node);
            for (SidetrackTreeNode child : node.getChildren()) {
                // The same tree node can be reached through different parents, so each
                // visit works on a copy that remembers its own parent and accumulated cost.
                SidetrackTreeNode modifiedChild = new SidetrackTreeNode(child.getSidetrackEdge());
                modifiedChild.addChildren(child.getChildren());
                modifiedChild.setParent(node);
                sidetrackHeap.add(modifiedChild);
            }
        }
        return result;
    }

    /**
     * Materializes the path described by a sidetrack chain.
     * <p>
     * The walk goes from EOS towards BOS, so the sidetrack closest to the root of the
     * chain is met first. The chain is therefore unrolled into root-first order before
     * walking, which lets every sidetrack of a multi-sidetrack path take effect.
     *
     * @param eos  the EOS node of the lattice
     * @param sidetrackNode  last sidetrack of the chain, or the root for the best path
     * @return nodes of the path from the leftmost node to EOS
     */
    private List<ViterbiNode> generatePath(ViterbiNode eos, SidetrackTreeNode sidetrackNode) {
        LinkedList<SidetrackEdge> pendingSidetracks = new LinkedList<>();
        // The root has no parent and carries only a placeholder edge, so it is left out.
        for (SidetrackTreeNode link = sidetrackNode; link != null && link.getParent() != null; link = link.getParent()) {
            pendingSidetracks.addFirst(link.getSidetrackEdge());
        }

        LinkedList<ViterbiNode> result = new LinkedList<>();
        ViterbiNode node = eos;
        result.add(node);
        while (node.getLeftNode() != null) {
            ViterbiNode leftNode = node.getLeftNode();
            if (!pendingSidetracks.isEmpty() && pendingSidetracks.getFirst().getHead() == node) {
                // Leave the best predecessor and enter through the sidetrack's tail instead
                leftNode = pendingSidetracks.removeFirst().getTail();
            }
            node = leftNode;
            result.addFirst(node);
        }
        return result;
    }

    /**
     * Attaches a sidetrack tree to every node that starts at index 1 or later.
     *
     * @param lattice  lattice whose nodes receive sidetrack trees
     */
    private void buildSidetrackTrees(ViterbiLattice lattice) {
        ViterbiNode[][] startIndexArr = lattice.getStartIndexArr();
        ViterbiNode[][] endIndexArr = lattice.getEndIndexArr();

        for (int i = 1; i < startIndexArr.length; i++) {
            if (startIndexArr[i] == null || endIndexArr[i] == null) {
                continue;
            }

            for (ViterbiNode node : startIndexArr[i]) {
                if (node == null) {
                    break;
                }

                buildSidetrackTreeNode(endIndexArr[i], node);
            }
        }
    }

    /**
     * Builds the sidetrack tree of one node from the trees of its left neighbours.
     *
     * @param leftNodes  nodes ending where {@code node} starts; a null entry ends the list
     * @param node  node to build the tree for
     */
    private void buildSidetrackTreeNode(ViterbiNode[] leftNodes, ViterbiNode node) {
        int backwardConnectionId = node.getLeftId();
        int wordCost = node.getWordCost();

        // Placeholder edge: zero cost and no endpoints
        node.setSidetrackTreeNode(new SidetrackTreeNode(new SidetrackEdge(0, null, null)));

        for (ViterbiNode leftNode : leftNodes) {
            if (leftNode == null) {
                return;
            }

            if (leftNode.getType() == ViterbiNode.Type.KNOWN && leftNode.getWordId() == -1) { // Ignore BOS
                continue;
            }

            // Extra cost of entering node through leftNode instead of its best predecessor
            int sideTrackCost = leftNode.getPathCost() - node.getPathCost() + wordCost + costs.get(leftNode.getRightId(), backwardConnectionId);
            if (mode == TokenizerBase.Mode.SEARCH || mode == TokenizerBase.Mode.EXTENDED) {
                sideTrackCost += viterbiSearcher.getPenaltyCost(node);
            }

            if (leftNode == node.getLeftNode()) { // Follow optimal path
                node.getSidetrackTreeNode().addChildren(leftNode.getSidetrackTreeNode().getChildren());
            } else { // Sidetrack
                SidetrackEdge sideTrackEdge = new SidetrackEdge(sideTrackCost, leftNode, node);
                SidetrackTreeNode sideTrackTreeNode = new SidetrackTreeNode(sideTrackEdge);
                sideTrackTreeNode.addChildren(leftNode.getSidetrackTreeNode().getChildren());
                node.getSidetrackTreeNode().addChild(sideTrackTreeNode);
            }
        }
    }

    /**
     * Edge that enters {@code head} from {@code tail} instead of from head's best
     * predecessor, together with the extra cost of doing so.
     */
    private static class SidetrackEdge {
        private final int cost;
        private final ViterbiNode tail;
        private final ViterbiNode head;

        SidetrackEdge(int cost, ViterbiNode tail, ViterbiNode head) {
            this.cost = cost;
            this.tail = tail;
            this.head = head;
        }

        public int getCost() {
            return cost;
        }

        ViterbiNode getTail() {
            return tail;
        }

        ViterbiNode getHead() {
            return head;
        }
    }

    /**
     * Node of a sidetrack tree. Nodes are ordered by cost, which is the cost of the
     * node's own edge until a parent is set and the accumulated cost afterwards.
     */
    static class SidetrackTreeNode implements Comparable<SidetrackTreeNode> {
        private final SidetrackEdge sidetrackEdge;
        private final List<SidetrackTreeNode> children;
        private SidetrackTreeNode parent;
        private int cost;

        SidetrackTreeNode(SidetrackEdge sidetrackEdge) {
            this.sidetrackEdge = sidetrackEdge;
            cost = sidetrackEdge.getCost();
            children = new ArrayList<>();
        }

        SidetrackEdge getSidetrackEdge() {
            return sidetrackEdge;
        }

        void addChild(SidetrackTreeNode child) {
            children.add(child);
        }

        void addChildren(List<SidetrackTreeNode> children) {
            this.children.addAll(children);
        }

        List<SidetrackTreeNode> getChildren() {
            return children;
        }

        public SidetrackTreeNode getParent() {
            return parent;
        }

        /**
         * Links this node below {@code parent} and makes its cost the parent's cost plus
         * the cost of its own edge.
         *
         * @param parent  node this one continues from
         */
        public void setParent(SidetrackTreeNode parent) {
            this.cost = parent.getCost() + sidetrackEdge.getCost();
            this.parent = parent;
        }

        public void setCost(int cost) {
            this.cost = cost;
        }

        public int getCost() {
            return cost;
        }

        @Override
        public int compareTo(SidetrackTreeNode o) {
            // Integer.compare cannot overflow, unlike subtracting the two costs
            return Integer.compare(cost, o.getCost());
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@

public class ViterbiLattice {

private static final String BOS = "BOS";
private static final String EOS = "EOS";
static final String BOS = "BOS";
static final String EOS = "EOS";

private final int dimension;
private final ViterbiNode[][] startIndexArr;
Expand Down
Loading

0 comments on commit d0ed0eb

Please sign in to comment.