OPENNLP-1442: Sentence transformers (#523)
* OPENNLP-1442: Adding sentence transformer model support via ONNX Runtime.
jzonthemtn committed Mar 31, 2023
1 parent 14e0e47 commit 9b62d58
Showing 7 changed files with 278 additions and 116 deletions.
56 changes: 31 additions & 25 deletions opennlp-dl/README.md
@@ -4,44 +4,50 @@ This module provides OpenNLP interface implementations for ONNX models using the

**Important**: This does not provide the ability to train models. Model training is done outside of OpenNLP. This code provides the ability to use ONNX models from OpenNLP.

-To build with example models, download the models to the `/src/test/resources` directory. (These are the exported models described below.)
-
-```bash
-export OPENNLP_DATA=/tmp/
-mkdir /tmp/dl-doccat /tmp/dl-namefinder
-
-# Document categorizer model
-wget https://www.dropbox.com/s/n9uzs8r4xm9rhxb/model.onnx?dl=0 -O $OPENNLP_DATA/dl-doccat/model.onnx
-wget https://www.dropbox.com/s/aw6yjc68jw0jts6/vocab.txt?dl=0 -O $OPENNLP_DATA/dl-doccat/vocab.txt
-
-# Namefinder model
-wget https://www.dropbox.com/s/zgogq65gs9tyfm1/model.onnx?dl=0 -O $OPENNLP_DATA/dl-namefinder/model.onnx
-wget https://www.dropbox.com/s/3byt1jggly1dg98/vocab.txt?dl=0 -O $OPENNLP_DATA/dl-namefinder/vocab.txt
-```
-
-## TokenNameFinder
-
-* Export a Huggingface NER model to ONNX, e.g.:
-
-```
-python -m transformers.onnx --model=dslim/bert-base-NER --feature token-classification exported
-```
-
-* Copy the exported model to `src/test/resources/namefinder/model.onnx`.
-* Copy the model's [vocab.txt](https://huggingface.co/dslim/bert-base-NER/tree/main) to `src/test/resources/namefinder/vocab.txt`.
-
-Now you can run the tests in `NameFinderDLTest`.
-
-## DocumentCategorizer
-
-* Export a Huggingface classification (e.g. sentiment) model to ONNX, e.g.:
-
-```
-python -m transformers.onnx --model=nlptown/bert-base-multilingual-uncased-sentiment --feature sequence-classification exported
-```
-
-* Copy the exported model to `src/test/resources/doccat/model.onnx`.
-* Copy the model's [vocab.txt](https://huggingface.co/nlptown/bert-base-multilingual-uncased-sentiment/tree/main) to `src/test/resources/namefinder/vocab.txt`.
-
-Now you can run the tests in `DocumentCategorizerDLTest`.
+Models used in the tests are available in the opennlp evaluation test data.
+
+## NameFinderDL
+
+Export a Huggingface NER model to ONNX, e.g.:
+
+```bash
+python -m transformers.onnx --model=dslim/bert-base-NER --feature token-classification exported
+```
+
+## DocumentCategorizerDL
+
+Export a Huggingface classification (e.g. sentiment) model to ONNX, e.g.:
+
+```bash
+python -m transformers.onnx --model=nlptown/bert-base-multilingual-uncased-sentiment --feature sequence-classification exported
+```
+
+## SentenceVectors
+
+Convert a sentence vectors model to ONNX, e.g.:
+
+Install dependencies:
+
+```bash
+python3 -m pip install optimum onnx onnxruntime
+```
+
+Convert the model:
+
+```python
+from optimum.onnxruntime import ORTModelForFeatureExtraction
+from transformers import AutoTokenizer
+from pathlib import Path
+
+model_id = "sentence-transformers/all-MiniLM-L6-v2"
+onnx_path = Path("onnx")
+
+# load vanilla transformers and convert to onnx
+model = ORTModelForFeatureExtraction.from_pretrained(model_id, from_transformers=True)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+# save onnx checkpoint and tokenizer
+model.save_pretrained(onnx_path)
+tokenizer.save_pretrained(onnx_path)
+```
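For orientation, here is a rough sketch of consuming such an exported sentence-vectors model from Java with ONNX Runtime. This is not code from this commit: the model path, the hard-coded token IDs, and the `last_hidden_state` shape are assumptions based on the `all-MiniLM-L6-v2` export above.

```java
import java.nio.LongBuffer;
import java.util.Map;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;

public class SentenceVectorsSketch {

  public static void main(String[] args) throws Exception {

    final OrtEnvironment env = OrtEnvironment.getEnvironment();

    // Load the model produced by the optimum export script above.
    try (OrtSession session = env.createSession("onnx/model.onnx", new OrtSession.SessionOptions())) {

      // Token IDs for "[CLS] hello world [SEP]" under the BERT uncased vocabulary;
      // real code would produce these with the tokenizer saved next to the model.
      final long[] ids = {101L, 7592L, 2088L, 102L};
      final long[] mask = {1L, 1L, 1L, 1L};
      final long[] types = {0L, 0L, 0L, 0L};
      final long[] shape = {1L, ids.length};

      final Map<String, OnnxTensor> inputs = Map.of(
          "input_ids", OnnxTensor.createTensor(env, LongBuffer.wrap(ids), shape),
          "attention_mask", OnnxTensor.createTensor(env, LongBuffer.wrap(mask), shape),
          "token_type_ids", OnnxTensor.createTensor(env, LongBuffer.wrap(types), shape));

      try (OrtSession.Result result = session.run(inputs)) {

        // all-MiniLM-L6-v2 emits last_hidden_state of shape [1, tokens, 384];
        // mean-pooling over the token axis yields one fixed-size sentence vector.
        final float[][][] hidden = (float[][][]) result.get(0).getValue();
        final float[] vector = new float[hidden[0][0].length];

        for (final float[] token : hidden[0]) {
          for (int i = 0; i < vector.length; i++) {
            vector[i] += token[i] / hidden[0].length;
          }
        }

        System.out.println("Sentence vector dimensions: " + vector.length);
      }
    }
  }
}
```

Mean pooling is the pooling that sentence-transformers models conventionally apply; other exports may need a different pooling or may expose a pooled output directly.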
72 changes: 72 additions & 0 deletions opennlp-dl/src/main/java/opennlp/dl/AbstractDL.java
@@ -0,0 +1,72 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package opennlp.dl;

import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Stream;

import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;

import opennlp.tools.tokenize.Tokenizer;

/**
* Base class for OpenNLP deep-learning classes using ONNX Runtime.
*/
public abstract class AbstractDL {

  public static final String INPUT_IDS = "input_ids";
  public static final String ATTENTION_MASK = "attention_mask";
  public static final String TOKEN_TYPE_IDS = "token_type_ids";

  protected OrtEnvironment env;
  protected OrtSession session;
  protected Tokenizer tokenizer;
  protected Map<String, Integer> vocab;

  /**
   * Loads a vocabulary file from disk.
   * @param vocabFile The vocabulary file.
   * @return A map of vocabulary words to integer IDs.
   * @throws IOException Thrown if the vocabulary file cannot be opened and read.
   */
  public Map<String, Integer> loadVocab(final File vocabFile) throws IOException {

    final Map<String, Integer> vocab = new HashMap<>();

    final AtomicInteger counter = new AtomicInteger(0);

    try (Stream<String> lines = Files.lines(Path.of(vocabFile.getPath()))) {

      lines.forEach(line -> {
        vocab.put(line, counter.getAndIncrement());
      });

    }

    return vocab;

  }

}
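To make the vocabulary format concrete: a BERT-style `vocab.txt` holds one wordpiece token per line, and `loadVocab` assigns each token its zero-based line number as its ID. A minimal usage sketch from inside a hypothetical subclass (the file name and tokens are invented for illustration):

```java
// A vocab.txt whose first three lines are "[PAD]", "hello" and "##ing"
// yields {"[PAD]"=0, "hello"=1, "##ing"=2}.
final Map<String, Integer> vocab = loadVocab(new File("vocab.txt"));

// The same vocabulary then backs the wordpiece tokenizer, as the
// subclasses below do in their constructors.
final Tokenizer tokenizer = new WordpieceTokenizer(vocab.keySet());
```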
opennlp-dl/src/main/java/opennlp/dl/doccat/DocumentCategorizerDL.java
@@ -17,9 +17,7 @@

package opennlp.dl.doccat;

-import java.io.BufferedReader;
import java.io.File;
-import java.io.FileReader;
import java.io.IOException;
import java.nio.LongBuffer;
import java.util.Arrays;
@@ -40,43 +38,36 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

+import opennlp.dl.AbstractDL;
import opennlp.dl.InferenceOptions;
import opennlp.dl.Tokens;
import opennlp.dl.doccat.scoring.ClassificationScoringStrategy;
import opennlp.tools.doccat.DocumentCategorizer;
-import opennlp.tools.tokenize.Tokenizer;
import opennlp.tools.tokenize.WordpieceTokenizer;

/**
* An implementation of {@link DocumentCategorizer} that performs document classification
* using ONNX models.
*/
-public class DocumentCategorizerDL implements DocumentCategorizer {
+public class DocumentCategorizerDL extends AbstractDL implements DocumentCategorizer {

private static final Logger logger = LoggerFactory.getLogger(DocumentCategorizerDL.class);
-  public static final String INPUT_IDS = "input_ids";
-  public static final String ATTENTION_MASK = "attention_mask";
-  public static final String TOKEN_TYPE_IDS = "token_type_ids";

-  private final Tokenizer tokenizer;
-  private final Map<String, Integer> vocabulary;
private final Map<Integer, String> categories;
private final ClassificationScoringStrategy classificationScoringStrategy;
private final InferenceOptions inferenceOptions;
-  protected final OrtEnvironment env;
-  protected final OrtSession session;

/**
* Creates a new document categorizer using ONNX models.
-   * @param model The ONNX model file.
-   * @param vocab The model's vocabulary file.
+   * @param modelFile The ONNX model file.
+   * @param vocabFile The model's vocabulary file.
* @param categories The categories.
* @param classificationScoringStrategy Implementation of {@link ClassificationScoringStrategy} used
* to calculate the classification scores given the score of each
* individual document part.
* @param inferenceOptions {@link InferenceOptions} to control the inference.
*/
-  public DocumentCategorizerDL(File model, File vocab, Map<Integer, String> categories,
+  public DocumentCategorizerDL(File modelFile, File vocabFile, Map<Integer, String> categories,
ClassificationScoringStrategy classificationScoringStrategy,
InferenceOptions inferenceOptions)
throws IOException, OrtException {
@@ -88,9 +79,9 @@ public DocumentCategorizerDL(File model, File vocab, Map<Integer, String> catego
sessionOptions.addCUDA(inferenceOptions.getGpuDeviceId());
}

-    this.session = env.createSession(model.getPath(), sessionOptions);
-    this.vocabulary = loadVocab(vocab);
-    this.tokenizer = new WordpieceTokenizer(vocabulary.keySet());
+    this.session = env.createSession(modelFile.getPath(), sessionOptions);
+    this.vocab = loadVocab(vocabFile);
+    this.tokenizer = new WordpieceTokenizer(vocab.keySet());
this.categories = categories;
this.classificationScoringStrategy = classificationScoringStrategy;
this.inferenceOptions = inferenceOptions;
@@ -223,41 +214,14 @@ private int getKey(String value) {

}

-  /**
-   * Loads a vocabulary file from disk.
-   * @param vocab The vocabulary file.
-   * @return A map of vocabulary words to integer IDs.
-   * @throws IOException Thrown if the vocabulary file cannot be opened and read.
-   */
-  private Map<String, Integer> loadVocab(File vocab) throws IOException {
-
-    final Map<String, Integer> v = new HashMap<>();
-
-    BufferedReader br = new BufferedReader(new FileReader(vocab.getPath()));
-    String line = br.readLine();
-    int x = 0;
-
-    while (line != null) {
-
-      line = br.readLine();
-      x++;
-
-      v.put(line, x);
-
-    }
-
-    return v;
-
-  }

private Tokens oldTokenize(String text) {

final String[] tokens = tokenizer.tokenize(text);

final int[] ids = new int[tokens.length];

for (int x = 0; x < tokens.length; x++) {
-      ids[x] = vocabulary.get(tokens[x]);
+      ids[x] = vocab.get(tokens[x]);
}

final long[] lids = Arrays.stream(ids).mapToLong(i -> i).toArray();
@@ -306,7 +270,7 @@ private List<Tokens> tokenize(final String text) {
final int[] ids = new int[tokens.length];

for (int x = 0; x < tokens.length; x++) {
-      ids[x] = vocabulary.get(tokens[x]);
+      ids[x] = vocab.get(tokens[x]);
}

final long[] lids = Arrays.stream(ids).mapToLong(i -> i).toArray();
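With the refactoring above, wiring the categorizer up end to end looks roughly like this. A hypothetical sketch only: the file paths, the five-category mapping for the nlptown sentiment model, and the use of `AverageClassificationScoringStrategy` with a default `InferenceOptions` are assumptions rather than code from this commit.

```java
import java.io.File;
import java.util.Map;

import opennlp.dl.InferenceOptions;
import opennlp.dl.doccat.DocumentCategorizerDL;
import opennlp.dl.doccat.scoring.AverageClassificationScoringStrategy;

public class DoccatSketch {

  public static void main(String[] args) throws Exception {

    // Category IDs follow the nlptown sentiment model's five output classes.
    final Map<Integer, String> categories = Map.of(
        0, "very negative", 1, "negative", 2, "neutral", 3, "positive", 4, "very positive");

    final DocumentCategorizerDL categorizer = new DocumentCategorizerDL(
        new File("/tmp/dl-doccat/model.onnx"),
        new File("/tmp/dl-doccat/vocab.txt"),
        categories,
        new AverageClassificationScoringStrategy(),
        new InferenceOptions());

    // Score the text against each category and print the winner.
    final double[] scores = categorizer.categorize(new String[] {"I loved this film."});
    System.out.println(categorizer.getBestCategory(scores));
  }
}
```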
46 changes: 2 additions & 44 deletions opennlp-dl/src/main/java/opennlp/dl/namefinder/NameFinderDL.java
@@ -17,10 +17,7 @@

package opennlp.dl.namefinder;

-import java.io.BufferedReader;
import java.io.File;
-import java.io.FileReader;
-import java.io.IOException;
import java.nio.LongBuffer;
import java.util.Arrays;
import java.util.HashMap;
@@ -35,38 +32,29 @@
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;

+import opennlp.dl.AbstractDL;
import opennlp.dl.InferenceOptions;
import opennlp.dl.SpanEnd;
import opennlp.dl.Tokens;
import opennlp.tools.namefind.TokenNameFinder;
import opennlp.tools.sentdetect.SentenceDetector;
-import opennlp.tools.tokenize.Tokenizer;
import opennlp.tools.tokenize.WordpieceTokenizer;
import opennlp.tools.util.Span;

/**
* An implementation of {@link TokenNameFinder} that uses ONNX models.
*/
-public class NameFinderDL implements TokenNameFinder {
-
-  public static final String INPUT_IDS = "input_ids";
-  public static final String ATTENTION_MASK = "attention_mask";
-  public static final String TOKEN_TYPE_IDS = "token_type_ids";
+public class NameFinderDL extends AbstractDL implements TokenNameFinder {

public static final String I_PER = "I-PER";
public static final String B_PER = "B-PER";
public static final String SEPARATOR = "[SEP]";

private static final String CHARS_TO_REPLACE = "##";

-  protected final OrtSession session;

private final SentenceDetector sentenceDetector;
private final Map<Integer, String> ids2Labels;
-  private final Tokenizer tokenizer;
-  private final Map<String, Integer> vocab;
  private final InferenceOptions inferenceOptions;
-  protected final OrtEnvironment env;

public NameFinderDL(File model, File vocabulary, Map<Integer, String> ids2Labels,
SentenceDetector sentenceDetector) throws Exception {
@@ -384,34 +372,4 @@ private List<Tokens> tokenize(final String text) {

}

-  /**
-   * Loads a vocabulary file from disk.
-   * @param vocab The vocabulary file.
-   * @return A map of vocabulary words to integer IDs.
-   * @throws IOException Thrown if the vocabulary file cannot be opened and read.
-   */
-  private Map<String, Integer> loadVocab(File vocab) throws IOException {
-
-    final Map<String, Integer> v = new HashMap<>();
-
-    try (final BufferedReader br = new BufferedReader(new FileReader(vocab.getPath()))) {
-
-      String line = br.readLine();
-      int x = 0;
-
-      while (line != null) {
-
-        line = br.readLine();
-        x++;
-
-        v.put(line, x);
-
-      }
-
-    }
-
-    return v;
-
-  }

}
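And the corresponding sketch for the name finder, again hypothetical: the model paths, the dslim/bert-base-NER label mapping, and the `en-sent.bin` sentence model file are assumptions, not code from this commit.

```java
import java.io.File;
import java.util.Map;

import opennlp.dl.namefinder.NameFinderDL;
import opennlp.tools.sentdetect.SentenceDetectorME;
import opennlp.tools.sentdetect.SentenceModel;
import opennlp.tools.util.Span;

public class NameFinderSketch {

  public static void main(String[] args) throws Exception {

    // Label IDs as published in dslim/bert-base-NER's config.json.
    final Map<Integer, String> ids2Labels = Map.of(
        0, "O", 1, "B-MISC", 2, "I-MISC", 3, "B-PER", 4, "I-PER",
        5, "B-ORG", 6, "I-ORG", 7, "B-LOC", 8, "I-LOC");

    // Any OpenNLP sentence model will do; en-sent.bin is a placeholder.
    final SentenceDetectorME sentenceDetector =
        new SentenceDetectorME(new SentenceModel(new File("en-sent.bin")));

    final NameFinderDL nameFinder = new NameFinderDL(
        new File("/tmp/dl-namefinder/model.onnx"),
        new File("/tmp/dl-namefinder/vocab.txt"),
        ids2Labels,
        sentenceDetector);

    // Find person/org/location spans in the input.
    final Span[] spans = nameFinder.find(new String[] {"George Washington was president."});
    for (final Span span : spans) {
      System.out.println(span);
    }
  }
}
```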
