Skip to content

Commit

Permalink
add support for user-defined POS for all OOV handlers
Browse files Browse the repository at this point in the history
The feature is gated by the userPos setting of OOV handlers,
which must be set to "allow" (the default is "forbid") for this feature to work.
  • Loading branch information
eiennohito committed Jun 6, 2022
1 parent 9bd6a41 commit 229e3c3
Show file tree
Hide file tree
Showing 8 changed files with 159 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import com.worksap.nlp.sudachi.dictionary.CategoryType;
import com.worksap.nlp.sudachi.dictionary.Grammar;
import com.worksap.nlp.sudachi.dictionary.POS;
import com.worksap.nlp.sudachi.dictionary.WordInfo;

import java.io.IOException;
Expand Down Expand Up @@ -76,7 +77,7 @@ public void setUp(Grammar grammar) throws IOException {
Config.Resource<Object> charDef = settings.getResource("charDef");
readCharacterProperty(charDef);
Config.Resource<Object> unkDef = settings.getResource("unkDef");
readOOV(unkDef, grammar);
readOOV(unkDef, grammar, settings.getString("userPos", "forbid"));
}

@Override
Expand Down Expand Up @@ -167,7 +168,7 @@ <T> void readCharacterProperty(Config.Resource<T> charDef) throws IOException {
}
}

<T> void readOOV(Config.Resource<T> unkDef, Grammar grammar) throws IOException {
<T> void readOOV(Config.Resource<T> unkDef, Grammar grammar, String userPosType) throws IOException {
if (unkDef == null) {
unkDef = settings.base.toResource(Paths.get("unk.def"));
}
Expand All @@ -194,8 +195,8 @@ <T> void readOOV(Config.Resource<T> unkDef, Grammar grammar) throws IOException
oov.leftId = Short.parseShort(cols[1]);
oov.rightId = Short.parseShort(cols[2]);
oov.cost = Short.parseShort(cols[3]);
List<String> pos = Arrays.asList(cols[4], cols[5], cols[6], cols[7], cols[8], cols[9]);
oov.posId = grammar.getPartOfSpeechId(pos);
POS pos = new POS(cols[4], cols[5], cols[6], cols[7], cols[8], cols[9]);
oov.posId = posIdOf(grammar, pos, userPosType);

oovList.computeIfAbsent(type, t -> new ArrayList<>()).add(oov);
}
Expand Down
23 changes: 23 additions & 0 deletions src/main/java/com/worksap/nlp/sudachi/OovProviderPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,12 @@

import java.io.IOException;
import java.util.List;
import java.util.Locale;
import java.util.Objects;

import com.worksap.nlp.sudachi.dictionary.Grammar;
import com.worksap.nlp.sudachi.dictionary.GrammarImpl;
import com.worksap.nlp.sudachi.dictionary.POS;

/**
* A plugin that provides the nodes of out-of-vocabulary morphemes.
Expand Down Expand Up @@ -98,4 +102,23 @@ protected LatticeNodeImpl createNode() {
node.setOOV();
return node;
}

/**
 * Resolves the id of a POS tag on behalf of this OOV plugin.
 *
 * <p>
 * The mode is compared case-insensitively (it is lower-cased with
 * {@link Locale#ROOT}) against two values:
 * <ul>
 * <li>{@code "forbid"}: the tag must already be known to the grammar; its
 * existing id is returned.</li>
 * <li>{@code "allow"}: the tag is registered through
 * {@link GrammarImpl#registerPOS(POS)} and the id returned by that call is
 * used.</li>
 * </ul>
 *
 * @param grammar
 *            grammar that is queried for the tag and, in {@code "allow"} mode,
 *            extended with it
 * @param pos
 *            POS tag to resolve
 * @param userPosMode
 *            either {@code "forbid"} or {@code "allow"}, in any letter case
 * @return id of the POS tag
 * @throws IllegalArgumentException
 *             if the mode is unknown, if the mode is {@code "forbid"} and the
 *             grammar reports a negative id for the tag, or if the mode is
 *             {@code "allow"} and the grammar is not a {@link GrammarImpl}
 */
protected short posIdOf(Grammar grammar, POS pos, String userPosMode) {
    short posIdPresent = grammar.getPartOfSpeechId(pos);
    String mode = userPosMode.toLowerCase(Locale.ROOT);

    if (Objects.equals(mode, "forbid")) {
        // a non-negative id means the grammar already knows this tag
        if (posIdPresent >= 0) {
            return posIdPresent;
        }
        throw new IllegalArgumentException(String.format(
                "POS %s WAS NOT present in dictionary and OOV Plugin %s is forbidden to add new POS tags", pos,
                this));
    }
    if (!Objects.equals(mode, "allow")) {
        throw new IllegalArgumentException(
                "Unknown user POS mode: " + mode + " allowed values are: forbid, allow");
    }
    // Registration is only offered by GrammarImpl; report other implementations
    // as a configuration error instead of failing with a bare ClassCastException.
    if (!(grammar instanceof GrammarImpl)) {
        throw new IllegalArgumentException(String.format(
                "OOV Plugin %s can not register POS %s: grammar of type %s does not support adding POS tags",
                this, pos, grammar.getClass().getName()));
    }
    return ((GrammarImpl) grammar).registerPOS(pos);
}
}
3 changes: 2 additions & 1 deletion src/main/java/com/worksap/nlp/sudachi/RegexOovProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ public class RegexOovProvider extends OovProviderPlugin {
public void setUp(Grammar grammar) throws IOException {
super.setUp(grammar);
POS stringPos = new POS(settings.getStringList("pos"));
posId = grammar.getPartOfSpeechId(stringPos);
String userPosType = settings.getString("userPos", "forbid");
posId = posIdOf(grammar, stringPos, userPosType);
if (posId == -1) {
throw new IllegalArgumentException("POS " + stringPos + " was not present in the dictionary");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.worksap.nlp.sudachi;

import com.worksap.nlp.sudachi.dictionary.Grammar;
import com.worksap.nlp.sudachi.dictionary.POS;
import com.worksap.nlp.sudachi.dictionary.WordInfo;

import java.util.List;
Expand Down Expand Up @@ -54,18 +55,12 @@ class SimpleOovProviderPlugin extends OovProviderPlugin {

@Override
public void setUp(Grammar grammar) {
List<String> oovPOSStrings = settings.getStringList("oovPOS");
POS pos = new POS(settings.getStringList("oovPOS"));
leftId = (short) settings.getInt("leftId");
rightId = (short) settings.getInt("rightId");
cost = (short) settings.getInt("cost");

if (oovPOSStrings.isEmpty()) {
throw new IllegalArgumentException("oovPOS is not specified");
}
oovPOSId = grammar.getPartOfSpeechId(oovPOSStrings);
if (oovPOSId < 0) {
throw new IllegalArgumentException("oovPOS is invalid: " + String.join(",", oovPOSStrings));
}
String userPosType = settings.getString("userPos", "forbid");
oovPOSId = posIdOf(grammar, pos, userPosType);
}

@Override
Expand Down
19 changes: 19 additions & 0 deletions src/main/java/com/worksap/nlp/sudachi/dictionary/GrammarImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -138,4 +138,23 @@ private String bufferToString(int offset) {
public Connection getConnection() {
return this.matrix;
}

/**
 * Registers a POS tag in the grammar definition and returns its id. If the POS
 * tag was already present, return its id. Should be called only during the
 * setup phase.
 *
 * @param pos
 *            tag to register
 * @return id of registered tag or id of existing tag
 * @throws IllegalArgumentException
 *             if the tag is new and its id would not fit into a {@code short}
 */
public short registerPOS(POS pos) {
    int existing = posList.indexOf(pos);
    if (existing != -1) {
        return (short) existing;
    }
    // The id of a new tag is its index in the list, i.e. the current size.
    int newId = posList.size();
    if (newId > Short.MAX_VALUE) {
        // Narrowing to short would wrap around to a negative id, which callers
        // interpret as "POS not present"; fail loudly instead.
        throw new IllegalArgumentException(String.format(
                "can not register POS %s: maximum number of POS tags (%d) exceeded", pos,
                Short.MAX_VALUE + 1));
    }
    posList.add(pos);
    return (short) newId;
}
}
3 changes: 2 additions & 1 deletion src/main/java/com/worksap/nlp/sudachi/dictionary/POS.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ public POS(String... elems) {
throw new IllegalArgumentException("pos must not be null");
}
if (elems.length != DEPTH) {
throw new IllegalArgumentException("pos must have exactly six elements");
throw new IllegalArgumentException(String.format("pos must have exactly 6 elements, was %s: %s",
elems.length, String.join(",", elems)));
}
for (String e : elems) {
if (e == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ public void readOOV() throws IOException {
Lines oovConfig = new Lines("DEFAULT,1,2,3,補助記号,一般,*,*,*,*", "DEFAULT,3,4,5,補助記号,一般,*,*,*,*");
MeCabOovProviderPlugin plugin = new MeCabOovProviderPlugin();
plugin.categories.put(CategoryType.DEFAULT, new CategoryInfo());
plugin.readOOV(oovConfig, new MockGrammar());
plugin.readOOV(oovConfig, new MockGrammar(), "forbid");
assertThat(plugin.oovList.size(), is(1));
assertThat(plugin.oovList.get(CategoryType.DEFAULT).size(), is(2));
assertThat(plugin.oovList.get(CategoryType.DEFAULT).get(0).leftId, is((short) 1));
Expand All @@ -416,20 +416,20 @@ public void readOOV() throws IOException {
public void readOOVWithTooFewColumns() throws IOException {
MeCabOovProviderPlugin plugin = new MeCabOovProviderPlugin();
plugin.categories.put(CategoryType.DEFAULT, new CategoryInfo());
plugin.readOOV(new Lines("DEFAULT,1,2,3"), new MockGrammar());
plugin.readOOV(new Lines("DEFAULT,1,2,3"), new MockGrammar(), "forbid");
}

@Test(expected = IllegalArgumentException.class)
public void readOOVWithUndefinedType() throws IOException {
MeCabOovProviderPlugin plugin = new MeCabOovProviderPlugin();
plugin.categories.put(CategoryType.DEFAULT, new CategoryInfo());
plugin.readOOV(new Lines("FOO,1,2,3,補助記号,一般,*,*,*,*"), new MockGrammar());
plugin.readOOV(new Lines("FOO,1,2,3,補助記号,一般,*,*,*,*"), new MockGrammar(), "forbid");
}

@Test(expected = IllegalArgumentException.class)
public void readOOVWithCategoryNotInCharacterProperty() throws IOException {
MeCabOovProviderPlugin plugin = new MeCabOovProviderPlugin();
plugin.categories.put(CategoryType.DEFAULT, new CategoryInfo());
plugin.readOOV(new Lines("FOO,1,2,3,補助記号,一般,*,*,*,*"), new MockGrammar());
plugin.readOOV(new Lines("FOO,1,2,3,補助記号,一般,*,*,*,*"), new MockGrammar(), "forbid");
}
}
100 changes: 100 additions & 0 deletions src/test/java/com/worksap/nlp/sudachi/OovProviderPluginTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
* Copyright (c) 2022 Works Applications Co., Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.worksap.nlp.sudachi

import com.worksap.nlp.sudachi.dictionary.Grammar
import com.worksap.nlp.sudachi.dictionary.POS
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFails
import kotlin.test.assertIs

/**
 * Exercises `posIdOf` of [OovProviderPlugin] through dictionaries built from the
 * `user0` test configuration with one or more [FakeOovProvider] plugins attached.
 */
class OovProviderPluginTest {
    /**
     * OOV provider stub whose only job is to resolve a POS id during [setUp].
     *
     * It reads the POS tag from the `pos` setting and the user POS mode from the
     * `posKind` setting (falling back to `forbid`), and stores the outcome in [posId].
     */
    class FakeOovProvider : OovProviderPlugin() {
        /** Id produced by [setUp]; remains `-1` until [setUp] has completed. */
        var posId: Short = -1

        override fun provideOOV(
            inputText: InputText?,
            offset: Int,
            otherWords: Long,
            result: MutableList<LatticeNodeImpl>?
        ): Int = throw NotImplementedError("not implemented")

        override fun setUp(grammar: Grammar?) {
            val mode = settings.getString("posKind", "forbid")
            val tag = POS(settings.getStringList("pos"))
            posId = posIdOf(grammar, tag, mode)
        }
    }

    /** Returns the last registered OOV plugin of [dictionary], asserting that it is a [FakeOovProvider]. */
    private fun lastFakeProvider(dictionary: JapaneseDictionary): FakeOovProvider =
        assertIs<FakeOovProvider>(dictionary.oovProviderPlugins.last())

    /** A POS tag that is looked up with the default mode resolves to its existing id. */
    @Test
    fun posIdOfWorks() {
        val config = TestDictionary.user0Cfg().also {
            it.addOovProviderPlugin(FakeOovProvider::class.java)
                .addList("pos", "名詞", "普通名詞", "一般", "*", "*", "*")
        }
        val dictionary = DictionaryFactory().create(config) as JapaneseDictionary
        assertEquals(4, lastFakeProvider(dictionary).posId)
    }

    /** With `posKind` set to `allow`, a previously unseen POS tag receives a fresh id. */
    @Test
    fun posIdOfWorksNewPos() {
        val config = TestDictionary.user0Cfg().also {
            it.addOovProviderPlugin(FakeOovProvider::class.java)
                .addList("pos", "名詞", "普通名詞", "一般", "*", "*", "new")
                .add("posKind", "allow")
        }
        val dictionary = DictionaryFactory().create(config) as JapaneseDictionary
        assertEquals(8, lastFakeProvider(dictionary).posId)
    }

    /** A `posKind` value other than the supported ones makes dictionary creation fail. */
    @Test
    fun failInvalidName() {
        val config = TestDictionary.user0Cfg().also {
            it.addOovProviderPlugin(FakeOovProvider::class.java)
                .addList("pos", "名詞", "普通名詞", "一般", "*", "*", "*")
                .add("posKind", "test")
        }
        assertFails { DictionaryFactory().create(config) }
    }

    /** An unseen POS tag combined with the default mode makes dictionary creation fail. */
    @Test
    fun failInvalidPos() {
        val config = TestDictionary.user0Cfg().also {
            it.addOovProviderPlugin(FakeOovProvider::class.java)
                .addList("pos", "名詞", "普通名詞", "一般", "*", "*", "test")
        }
        assertFails { DictionaryFactory().create(config) }
    }

    /** Two plugins that register the same new POS tag both end up with the same id. */
    @Test
    fun doubleRegisterReturnsSamePosId() {
        val config = TestDictionary.user0Cfg().also { cfg ->
            // attach the identical plugin definition twice
            repeat(2) {
                cfg.addOovProviderPlugin(FakeOovProvider::class.java)
                    .addList("pos", "名詞", "普通名詞", "一般", "*", "*", "new")
                    .add("posKind", "allow")
            }
        }
        val plugins = (DictionaryFactory().create(config) as JapaneseDictionary).oovProviderPlugins
        val secondToLast = assertIs<FakeOovProvider>(plugins[plugins.lastIndex - 1])
        assertEquals(8, secondToLast.posId)
        val last = assertIs<FakeOovProvider>(plugins.last())
        assertEquals(8, last.posId)
    }
}

0 comments on commit 229e3c3

Please sign in to comment.