Skip to content

Commit

Permalink
Optimize RDRPosTagger to use integer comparisons instead of string comparisons.
Browse files Browse the repository at this point in the history

Also reduce the cache-thrashing by deconstructing the tree's nodes into arrays.
  • Loading branch information
vlofgren committed Jun 19, 2023
1 parent 6f2a797 commit 186a02a
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 50 deletions.
1 change: 1 addition & 0 deletions third-party/rdrpostagger/build.gradle
Expand Up @@ -9,6 +9,7 @@ java {
}

dependencies {
// Trove supplies gnu.trove.map.hash.TObjectIntHashMap, which RDRPOSTagger
// imports for its String -> int id dictionary.
implementation libs.trove
}

test {
Expand Down
Expand Up @@ -13,6 +13,9 @@
public class FWObject
{
public String[] context;

int[] objectCtxI = new int[13];

private final static String[] contextPrototype;
static {
contextPrototype = new String[13];
Expand Down
Expand Up @@ -40,26 +40,12 @@ public void setFatherNode(Node node)
this.fatherNode = node;
}

/**
 * Counts the nodes in the subtree rooted at this node: this node itself plus
 * every node reachable through its except and if-not children.
 *
 * @return the subtree size; always at least 1
 */
public int countNodes()
{
    // Absent children contribute nothing; the except branch is visited first.
    int exceptCount = (exceptNode == null) ? 0 : exceptNode.countNodes();
    int ifnotCount = (ifnotNode == null) ? 0 : ifnotNode.countNodes();
    return 1 + exceptCount + ifnotCount;
}

public boolean satisfy(FWObject object)
{
for (int i = 0; i < 13; i++) {
String key = condition.context[i];
if (key != null) {
if (!key.equals(object.context[i])) {
return false;
}
if (key != null && !key.equals(object.context[i])) { // this is not equivalent to Objects.equals(a,b)
return false;
}
}
return true;
Expand Down
@@ -1,8 +1,11 @@
package com.github.datquocnguyen;

import gnu.trove.map.hash.TObjectIntHashMap;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.HashMap;

/**
Expand All @@ -12,22 +15,59 @@
public class RDRPOSTagger
{
private final HashMap<String, String> FREQDICT;
public final Node root;
// Fixed number of node slots allocated in each of the per-node arrays below.
final int OUGHT_TO_BE_ENOUGH = 5000;
// Number of context slots stored per node condition.
final int CONTEXT_SIZE = 13;

// Use dense array representation to reduce the level of indirection
// and improve the performance of the tagger
//
// All arrays are indexed by node index as handed out by addNode(); -1 is the
// "nothing here" marker throughout.

// Flattened conditions: slot i of node idx lives at idx * CONTEXT_SIZE + i and
// holds the tagDict id of the condition string, or -1 when the slot is
// unconstrained (satisfy() skips negative entries).
int[] conditions = new int[OUGHT_TO_BE_ENOUGH * CONTEXT_SIZE];
// Conclusion string of each node; findFiredNode() returns the fired node's entry.
String[] conclusions = new String[OUGHT_TO_BE_ENOUGH];
// Node to try next when this node's condition is satisfied, or -1.
short[] exceptIdx = new short[OUGHT_TO_BE_ENOUGH];
// Node to try next when this node's condition is not satisfied, or -1.
short[] ifNotIdx = new short[OUGHT_TO_BE_ENOUGH];
// Node this one was linked from while loading the rules file, or -1; the
// loader follows it to walk back up to a node of a given depth.
short[] fatherIdx = new short[OUGHT_TO_BE_ENOUGH];
// Depth recorded for each node when it was added.
byte[] depthL = new byte[OUGHT_TO_BE_ENOUGH];

// Number of node slots handed out so far; also the next index addNode() returns.
short size = 0;

// Maps each distinct condition context string to a dense int id (assigned in
// addNode() as the map's size at first insertion). The third constructor
// argument, -1, is Trove's no-entry value, i.e. what get() yields for a
// string that has no id.
private final TObjectIntHashMap<String> tagDict = new TObjectIntHashMap<>(10000, 0.75f, -1);

/**
 * Appends a rule node to the flattened tree arrays and returns its index.
 *
 * <p>Every non-null string in {@code condition.context} is registered in
 * {@code tagDict} (receiving the next free id the first time it is seen) and
 * stored in {@code conditions} as that id; null slots are stored as -1. The
 * new node starts with its except, if-not and father links all set to -1;
 * the caller is responsible for wiring them up.
 *
 * @param condition  condition whose {@code context} array supplies the
 *                   {@code CONTEXT_SIZE} slots of the node
 * @param conclusion conclusion stored for the node
 * @param d          depth recorded for the node
 * @return index of the newly added node
 * @throws IllegalStateException if all {@code OUGHT_TO_BE_ENOUGH} node slots
 *                               are already in use
 */
private short addNode(FWObject condition, String conclusion, byte d) {
    // Fail with an explicit message, and before any state is modified, instead
    // of an ArrayIndexOutOfBoundsException raised after size was already bumped.
    if (size >= OUGHT_TO_BE_ENOUGH) {
        throw new IllegalStateException(
                "Rule tree exceeds the fixed capacity of " + OUGHT_TO_BE_ENOUGH + " nodes");
    }

    short idx = size++;
    // First slot of this node's row in the flattened conditions array.
    int base = idx * CONTEXT_SIZE;

    for (int i = 0; i < CONTEXT_SIZE; i++) {
        String context = condition.context[i];
        if (context != null) {
            // Assign the next dense id on first sight, then store whichever id is registered.
            tagDict.putIfAbsent(context, tagDict.size());

            conditions[base + i] = tagDict.get(context);
        }
        else {
            // -1 marks an unconstrained slot.
            conditions[base + i] = -1;
        }
    }

    conclusions[idx] = conclusion;
    exceptIdx[idx] = -1;
    ifNotIdx[idx] = -1;
    fatherIdx[idx] = -1;
    depthL[idx] = d;

    return idx;
}

public RDRPOSTagger(Path dictPath, Path rulesFilePath) throws IOException {
this.FREQDICT = Utils.getDictionary(dictPath.toString());
Arrays.fill(conditions, -1);

BufferedReader buffer = new BufferedReader(new InputStreamReader(
new FileInputStream(rulesFilePath.toFile()), StandardCharsets.UTF_8));
String line = buffer.readLine();

this.root = new Node(new FWObject(false), "NN", null, null, null, 0);

Node currentNode = this.root;
int currentDepth = 0;
short currentIdx = addNode(new FWObject(false), "NN", (byte) 0);
byte currentDepth = 0;

while ((line = buffer.readLine()) != null) {
int depth = 0;
byte depth = 0;
for (int i = 0; i <= 6; i++) { // Supposed that the maximum
// exception level is up to 6.
if (line.charAt(i) == '\t')
Expand All @@ -48,53 +88,72 @@ public RDRPOSTagger(Path dictPath, Path rulesFilePath) throws IOException {
String conclusion = Utils.getConcreteValue(line.split(" : ")[1]
.trim());

Node node = new Node(condition, conclusion, null, null, null, depth);
short newIdx = addNode(condition, conclusion, depth);

if (depth > currentDepth) {
currentNode.setExceptNode(node);
exceptIdx[currentIdx] = newIdx;
}
else if (depth == currentDepth) {
currentNode.setIfnotNode(node);
ifNotIdx[currentIdx] = newIdx;
}
else {
while (currentNode.depth != depth)
currentNode = currentNode.fatherNode;
currentNode.setIfnotNode(node);
while (depthL[currentIdx] != depth) {
currentIdx = fatherIdx[currentIdx];
}
ifNotIdx[currentIdx] = newIdx;
}
node.setFatherNode(currentNode);

currentNode = node;
fatherIdx[newIdx] = currentIdx;

currentIdx = newIdx;
currentDepth = depth;
}
buffer.close();
}

public Node findFiredNode(FWObject object)
public String findFiredNode(FWObject object)
{
Node currentN = root;
Node firedN = null;
while (true) {
if (currentN.satisfy(object)) {
firedN = currentN;
if (currentN.exceptNode == null) {
break;
}
else {
currentN = currentN.exceptNode;
}
int currentIdx = 0;
int firedIdx = -1;

int[] objCtxI = object.objectCtxI;

for (int i = 0; i < CONTEXT_SIZE; i++) {
objCtxI[i] = tagDict.get(object.context[i]);
}

int[] conditionsL = conditions;
short[] exceptIdxL = exceptIdx;
short[] ifNotIdxL = ifNotIdx;

while (currentIdx >= 0) {
if (satisfy(objCtxI, conditionsL, currentIdx)) {
firedIdx = currentIdx;
currentIdx = exceptIdxL[currentIdx];
}
else {
if (currentN.ifnotNode == null) {
break;
}
else {
currentN = currentN.ifnotNode;
}
currentIdx = ifNotIdxL[currentIdx];
}
}

if (firedIdx >= 0) {
return conclusions[firedIdx];
}
else {
return "";
}
}

return firedN;
/**
 * Tests whether an object's context ids match the condition row of one node.
 *
 * @param objectCtxI context ids of the object being tagged, one per slot
 * @param conditions flattened condition table, {@code CONTEXT_SIZE} entries per node
 * @param contextIdx index of the node whose condition row is checked
 * @return {@code true} if every non-negative entry in the row equals the
 *         object's id in the same slot; negative entries match anything
 */
public boolean satisfy(int[] objectCtxI, int[] conditions, int contextIdx)
{
    // Start of this node's row in the flattened table.
    final int rowStart = CONTEXT_SIZE * contextIdx;

    // Straight-line compare over a fixed-width row; a candidate for vectorization.
    for (int slot = 0; slot < CONTEXT_SIZE; slot++) {
        final int expected = conditions[rowStart + slot];
        if (expected < 0) {
            continue; // unconstrained slot
        }
        if (expected != objectCtxI[slot]) {
            return false;
        }
    }
    return true;
}

public String[] tagsForEnSentence(String[] sentence)
Expand All @@ -107,7 +166,7 @@ public String[] tagsForEnSentence(String[] sentence)

for (int i = 0; i < initialTags.length; i++) {
Utils.getObject(object, sentence, initialTags, initialTags.length, i);
tags[i] = findFiredNode(object).conclusion;
tags[i] = findFiredNode(object);
}

return tags;
Expand Down
Expand Up @@ -9,6 +9,8 @@
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

/**
* @author DatQuocNguyen
Expand Down Expand Up @@ -69,6 +71,7 @@ public static boolean isAbbre(String word)
return true;
}

static Map<String, String> conditionInstancePool = new HashMap<>();
public static FWObject getCondition(String strCondition)
{
FWObject condition = new FWObject(false);
Expand Down Expand Up @@ -120,6 +123,16 @@ else if (key.equals("suffixL4")) {
}
}

// pool the conditions to increase the chances the data is in cache
// when comparing later

for (var i = 0; i < condition.context.length; i++) {
if (condition.context[i] != null) {
condition.context[i] = conditionInstancePool
.computeIfAbsent(condition.context[i], Function.identity());
}
}

return condition;
}

Expand Down

0 comments on commit 186a02a

Please sign in to comment.