Update Taskflow word_segmentation and ner tasks #1666

Merged · 40 commits · Mar 15, 2022

Commits
23cc16d
Add AutoSplitter & AutoJoiner
linjieccc Feb 8, 2022
0ffb746
Merge branch 'develop' into add_autosplitter
linjieccc Feb 8, 2022
95fcf50
codestyle fix
linjieccc Feb 8, 2022
f984f84
unify auto joiner
linjieccc Feb 14, 2022
32ace14
add comments
linjieccc Feb 14, 2022
dd149f7
Merge branch 'develop' into add_autosplitter
linjieccc Feb 14, 2022
af7d758
add sentence split mode
linjieccc Feb 17, 2022
e243ad9
Merge branch 'add_autosplitter' of https://github.com/linjieccc/Paddl…
linjieccc Feb 17, 2022
e31c3f8
update params
linjieccc Feb 21, 2022
18c2bfe
add paddle version check
linjieccc Feb 22, 2022
5c95b71
Merge branch 'develop' into add_autosplitter
linjieccc Feb 28, 2022
8f63909
add wordtag for word_segmentation
linjieccc Mar 1, 2022
1636a8f
add wordtag for word_segmentation
linjieccc Mar 1, 2022
6549aa4
Merge branch 'develop' into add_autosplitter
linjieccc Mar 1, 2022
a35fb2c
add ner-lac and word_segmentation-jieba
linjieccc Mar 7, 2022
643c901
add return entities only for ner
linjieccc Mar 8, 2022
4cc6389
fix ci
linjieccc Mar 9, 2022
6a840dc
fix ci
linjieccc Mar 9, 2022
e202e5d
fix ci
linjieccc Mar 9, 2022
7c9672a
fix ci
linjieccc Mar 9, 2022
2931053
fix ci
linjieccc Mar 9, 2022
dcab431
Update README.md
linjieccc Mar 9, 2022
6182ca0
Update README.md
linjieccc Mar 9, 2022
b5f9e2b
Update README.md
linjieccc Mar 9, 2022
5850a1a
Update README.md
linjieccc Mar 9, 2022
3f87c52
Update README.md
linjieccc Mar 9, 2022
3a932dc
Update README.md
linjieccc Mar 9, 2022
a09abfe
Update README.md
linjieccc Mar 10, 2022
2f62846
Update README.md
linjieccc Mar 10, 2022
8a7caab
Update README.md
linjieccc Mar 10, 2022
72ab3c8
Update README.md
linjieccc Mar 10, 2022
75b3a2c
Update README.md
linjieccc Mar 10, 2022
cb2e174
Update README.md
linjieccc Mar 11, 2022
0fe6977
Merge branch 'develop' into add_autosplitter
linjieccc Mar 11, 2022
7fbebb4
fix bugs of dataloader
linjieccc Mar 13, 2022
d29fefa
remove guard
linjieccc Mar 15, 2022
7a51b43
Merge branch 'develop' into add_autosplitter
linjieccc Mar 15, 2022
5346e88
use fast mode for rnn example
linjieccc Mar 15, 2022
b978773
Update README.md
linjieccc Mar 15, 2022
aa91294
Update README.md
linjieccc Mar 15, 2022
Files changed
606 changes: 253 additions & 353 deletions docs/model_zoo/taskflow.md

Large diffs are not rendered by default.
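Since the `taskflow.md` diff is not rendered here, the following is a sketch of the usage surface this PR appears to document, reconstructed from the commit messages and the diffs below. Only `mode="fast"` is confirmed by the rnn example change in this PR; the other mode names and the `entity_only` flag follow the commit messages ("add wordtag for word_segmentation", "add ner-lac and word_segmentation-jieba", "add return entities only for ner") and are assumptions about the documented API:

```python
from paddlenlp import Taskflow

# "fast" mode is confirmed by the rnn example diff below; "accurate" and
# entity_only are inferred from commit messages, not verified against the docs.
seg_fast = Taskflow("word_segmentation", mode="fast")          # jieba-backed
seg_accurate = Taskflow("word_segmentation", mode="accurate")  # WordTag-backed
print(seg_fast("第十四届全运会在西安举办"))

ner = Taskflow("ner", mode="fast", entity_only=True)           # LAC-backed NER
print(ner("《孤女》是2010年九州出版社出版的小说"))
```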

3 changes: 1 addition & 2 deletions examples/text_classification/rnn/utils.py
@@ -15,10 +15,9 @@
 import re

 from paddlenlp import Taskflow
-import jieba
 import numpy as np

-word_segmenter = Taskflow("word_segmentation")
+word_segmenter = Taskflow("word_segmentation", mode="fast")


 def convert_example(example, tokenizer, is_test=False):
2 changes: 1 addition & 1 deletion examples/text_to_knowledge/ernie-ctm/README.md
@@ -129,7 +129,7 @@ custom_task_path/
 └── tags.txt
 ```

-```shell
+```python
 from paddlenlp import Taskflow

 my_wordtag = Taskflow("knowledge_mining", task_path="./custom_task_path/")
148 changes: 32 additions & 116 deletions paddlenlp/taskflow/knowledge_mining.py
@@ -25,13 +25,14 @@
 import numpy as np
 import paddle
 import paddle.nn as nn
-try:
-    from paddle.text import ViterbiDecoder
-except:
-    raise ImportError(
-        "Taskflow requires paddle version >= 2.2.0, but current paddle version is {}".
-        format(paddle.version.full_version))
 from paddlenlp.layers.crf import LinearChainCrf
+from paddlenlp.utils.tools import compare_version
+if compare_version(paddle.version.full_version, "2.2.0") >= 0:
+    # paddle.text.ViterbiDecoder is supported by paddle after version 2.2.0
+    from paddle.text import ViterbiDecoder
+else:
+    from paddlenlp.layers.crf import ViterbiDecoder

 from ..datasets import MapDataset, load_dataset
 from ..data import Stack, Pad, Tuple
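For context on the guarded import: whichever branch is taken, the decoder exposes the same interface. A minimal sketch of how it is typically called, with shapes and the constructor defaults assumed from the paddle 2.2 API rather than taken from this diff:

```python
import paddle
from paddle.text import ViterbiDecoder

# Emission scores for a batch of 2 sequences, 8 steps, 5 tags, plus a
# [num_tags, num_tags] transition matrix (random here, learned in practice).
emissions = paddle.rand([2, 8, 5])
transitions = paddle.rand([5, 5])
lengths = paddle.to_tensor([8, 6], dtype='int64')

decoder = ViterbiDecoder(transitions)
scores, best_paths = decoder(emissions, lengths)  # best tag ids per sample
```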
@@ -217,6 +218,16 @@ def __init__(self,
             self._custom.load_customization(self._user_dict)
         else:
             self._custom = None
+        self._num_workers = self.kwargs[
+            'num_workers'] if 'num_workers' in self.kwargs else 0
+        self._batch_size = self.kwargs[
+            'batch_size'] if 'batch_size' in self.kwargs else 1
+        self._lazy_load = self.kwargs[
+            'lazy_load'] if 'lazy_load' in self.kwargs else False
+        self._max_seq_len = self.kwargs[
+            'max_seq_len'] if 'max_seq_len' in self.kwargs else 512
+        self._split_sentence = self.kwargs[
+            'split_sentence'] if 'split_sentence' in self.kwargs else False
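These defaults could equivalently be pulled with `dict.get`; a compact sketch of the same behavior (not part of the diff):

```python
self._num_workers = self.kwargs.get('num_workers', 0)
self._batch_size = self.kwargs.get('batch_size', 1)
self._lazy_load = self.kwargs.get('lazy_load', False)
self._max_seq_len = self.kwargs.get('max_seq_len', 512)
self._split_sentence = self.kwargs.get('split_sentence', False)
```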

@property
def summary_num(self):
@@ -263,127 +274,34 @@ def _load_task_resources(self):
         self._termtree = TermTree.from_dir(
             self._term_schema_path, self._term_data_path, self._linking)

-    def _split_long_text_input(self, input_texts, max_text_len):
-        """
-        Split the long text to list of short text, the max_seq_len of input text is 512,
-        if the text length greater than 512, will this function that spliting the long text.
-        """
-        short_input_texts = []
-        for text in input_texts:
-            if len(text) <= max_text_len:
-                short_input_texts.append(text)
-            else:
-                lens = len(text)
-                temp_text_list = text.split("?。!")
-                temp_text_list = [
-                    temp_text for temp_text in temp_text_list
-                    if len(temp_text) > 0
-                ]
-                if len(temp_text_list) <= 1:
-                    temp_text_list = [
-                        text[i:i + max_text_len]
-                        for i in range(0, len(text), max_text_len)
-                    ]
-                    short_input_texts.extend(temp_text_list)
-                else:
-                    list_len = len(temp_text_list)
-                    start = 0
-                    end = 0
-                    for i in range(0, list_len):
-                        if len(temp_text_list[i]) + 1 >= max_text_len:
-                            if start != end:
-                                short_input_texts.extend(
-                                    self._split_long_text_input(
-                                        [text[start:end]], max_text_len))
-                            short_input_texts.extend(
-                                self._split_long_text_input([
-                                    text[end:end + len(temp_text_list[i]) + 1]
-                                ], max_text_len))
-                            start = end + len(temp_text_list[i]) + 1
-                            end = start
-                        else:
-                            if start + len(temp_text_list[
-                                    i]) + 1 > max_text_len:
-                                short_input_texts.extend(
-                                    self._split_long_text_input(
-                                        [text[start:end]], max_text_len))
-                                start = end
-                                end = end + len(temp_text_list[i]) + 1
-                            else:
-                                end = len(temp_text_list[i]) + 1
-                    if start != end:
-                        short_input_texts.extend(
-                            self._split_long_text_input([text[start:end]],
-                                                        max_text_len))
-        return short_input_texts
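Worth noting about the deleted helper: `text.split("?。!")` splits on the literal three-character string, not on any single punctuation mark, so the sentence-splitting branch rarely fired — one plausible motivation for centralizing this logic. A minimal sketch of the sentence-aware splitting that the shared `_auto_splitter` presumably performs; the helper name, regex, and mapping shape are assumptions, not the actual implementation:

```python
import re

def auto_split(texts, max_len, split_sentence=False):
    """Split overlong texts; return short texts plus an index mapping
    so results can be stitched back together later (sketch only)."""
    short_texts, input_mapping = [], {}
    for idx, text in enumerate(texts):
        # Split after each sentence-ending mark, keeping the mark itself.
        pieces = re.split(r'(?<=[。！？])', text) if split_sentence else [text]
        chunks = []
        for piece in pieces:
            # Hard-wrap any piece that is still longer than max_len.
            chunks.extend(piece[i:i + max_len]
                          for i in range(0, len(piece), max_len))
        input_mapping[idx] = list(
            range(len(short_texts), len(short_texts) + len(chunks)))
        short_texts.extend(chunks)
    return short_texts, input_mapping
```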

-    def _concat_short_text_reuslts(self, input_texts, results):
-        """
-        Concat the model output of short texts to the total result of long text.
-        """
-        long_text_lens = [len(text) for text in input_texts]
-        concat_results = []
-        single_results = {}
-        count = 0
-        for text in input_texts:
-            text_len = len(text)
-            while True:
-                if len(single_results) == 0 or len(single_results[
-                        "text"]) < text_len:
-                    if len(single_results) == 0:
-                        single_results = copy.deepcopy(results[count])
-                    else:
-                        single_results["text"] += results[count]["text"]
-                        single_results["items"].extend(results[count]["items"])
-                    count += 1
-                elif len(single_results["text"]) == text_len:
-                    concat_results.append(single_results)
-                    single_results = {}
-                    break
-                else:
-                    raise Exception(
-                        "The length of input text and raw text is not equal.")
-        for result in concat_results:
-            pred_words = result['items']
-            pred_words = self._reset_offset(pred_words)
-            result['items'] = pred_words
-        return concat_results
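The replacement `_auto_joiner` presumably inverts the mapping recorded at split time rather than re-deriving chunk boundaries from accumulated text lengths, as the deleted code did. A sketch of that idea for the `is_dict=True` case (assumed, not the actual helper):

```python
def auto_join(short_results, input_mapping):
    """Merge per-chunk dict results back into one result per original
    text, using the mapping recorded by the splitter (sketch only)."""
    joined = []
    for idx, chunk_ids in input_mapping.items():
        merged = {'text': '', 'items': []}
        for cid in chunk_ids:
            merged['text'] += short_results[cid]['text']
            merged['items'].extend(short_results[cid]['items'])
        joined.append(merged)
    return joined
```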

     def _preprocess_text(self, input_texts):
         """
         Create the dataset and dataloader for prediction.
         """
-        batch_size = self.kwargs[
-            'batch_size'] if 'batch_size' in self.kwargs else 1
-        num_workers = self.kwargs[
-            'num_workers'] if 'num_workers' in self.kwargs else 0
-
-        max_seq_length = 512
-        if 'max_seq_length' in self.kwargs:
-            max_seq_length = self.kwargs['max_seq_length']
         infer_data = []
-        max_predict_len = max_seq_length - self.summary_num - 1
+        max_predict_len = self._max_seq_len - self.summary_num - 1
         filter_input_texts = []
         for input_text in input_texts:
             if not (isinstance(input_text, str) and len(input_text) > 0):
                 continue
             filter_input_texts.append(input_text)
         input_texts = filter_input_texts

-        short_input_texts = self._split_long_text_input(input_texts,
-                                                        max_predict_len)
+        short_input_texts, self.input_mapping = self._auto_splitter(
+            input_texts, max_predict_len, split_sentence=self._split_sentence)

         def read(inputs):
             for text in inputs:
                 tokenized_output = self._tokenizer(
                     list(text),
                     return_length=True,
                     is_split_into_words=True,
-                    max_seq_len=max_seq_length)
+                    max_seq_len=self._max_seq_len)
                 yield tokenized_output['input_ids'], tokenized_output[
                     'token_type_ids'], tokenized_output['seq_len']

-        infer_ds = load_dataset(read, inputs=short_input_texts, lazy=False)
+        infer_ds = load_dataset(
+            read, inputs=short_input_texts, lazy=self._lazy_load)
         batchify_fn = lambda samples, fn=Tuple(
             Pad(axis=0, pad_val=self._tokenizer.pad_token_id, dtype='int64'
                 ),  # input_ids
@@ -396,15 +314,14 @@ def read(inputs):
         infer_data_loader = paddle.io.DataLoader(
             infer_ds,
             collate_fn=batchify_fn,
-            num_workers=num_workers,
-            batch_size=batch_size,
+            num_workers=self._num_workers,
+            batch_size=self._batch_size,
             shuffle=False,
             return_list=True)

         outputs = {}
         outputs['data_loader'] = infer_data_loader
         outputs['short_input_texts'] = short_input_texts
-        outputs['inputs'] = input_texts
         return outputs

def _reset_offset(self, pred_words):
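`_reset_offset` is collapsed in this view, but its role matters after joining: item offsets computed on short segments must be rebased onto the joined text. A sketch of that idea (assumed; the actual body is not shown in this diff):

```python
def reset_offset(pred_words):
    # Rebase each item's offset so positions are cumulative over the
    # joined text rather than relative to each short segment (sketch).
    offset = 0
    for word in pred_words:
        word['offset'] = offset
        offset += len(word['item'])
    return pred_words
```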
@@ -419,11 +336,9 @@ def _decode(self, batch_texts, batch_pred_tags):
         batch_results = []
         for sent_index in range(len(batch_texts)):
             sent = batch_texts[sent_index]
-            tags = [
-                self._index_to_tags[index]
-                for index in batch_pred_tags[sent_index][self.summary_num:len(
-                    sent) + self.summary_num]
-            ]
+            indexes = batch_pred_tags[sent_index][self.summary_num:len(sent) +
+                                                  self.summary_num]
+            tags = [self._index_to_tags[index] for index in indexes]
             if self._custom:
                 self._custom.parse_customization(sent, tags, prefix=True)
             sent_out = []
@@ -543,7 +458,6 @@ def _run_model(self, inputs):
         Run the task model from the outputs of the `_tokenize` function.
         """
         all_pred_tags = []
-
         for batch in inputs['data_loader']:
             input_ids, token_type_ids, seq_len = batch
             self.input_handles[0].copy_from_cpu(input_ids.numpy())
@@ -561,7 +475,11 @@ def _postprocess(self, inputs):
"""
results = self._decode(inputs['short_input_texts'],
inputs['all_pred_tags'])
results = self._concat_short_text_reuslts(inputs['inputs'], results)
results = self._auto_joiner(results, self.input_mapping, is_dict=True)
for result in results:
pred_words = result['items']
pred_words = self._reset_offset(pred_words)
result['items'] = pred_words
if self.linking is True:
for res in results:
self._term_linking(res)
Expand Down Expand Up @@ -804,7 +722,6 @@ def _run_model(self, inputs):
         all_scores_can = []
         all_preds_can = []
         pred_ids = []
-
         for batch in inputs['data_loader']:
             input_ids, token_type_ids, label_indices = batch
             self.input_handles[0].copy_from_cpu(input_ids.numpy())
@@ -819,7 +736,6 @@ def _run_model(self, inputs):
         all_scores_can.extend([score_can.tolist()])
         all_preds_can.extend([pred_id_can.tolist()])
         pred_ids.extend([pred_id_can[:, 0].tolist()])
-
         inputs['all_scores_can'] = all_scores_can
         inputs['all_preds_can'] = all_preds_can
         inputs['pred_ids'] = pred_ids
26 changes: 19 additions & 7 deletions paddlenlp/taskflow/lexical_analysis.py
@@ -26,7 +26,7 @@
 import paddle.nn.functional as F
 from ..datasets import load_dataset, MapDataset
 from ..data import Stack, Pad, Tuple, Vocab, JiebaTokenizer
-from .utils import download_file, add_docstrings, dygraph_mode_guard
+from .utils import download_file, add_docstrings, static_mode_guard, dygraph_mode_guard
 from .utils import Customization
 from .task import Task
 from .models import BiGruCrf
@@ -81,6 +81,7 @@ class LacTask(Task):
     Args:
         task(string): The name of the task.
         model(string): The model name in the task.
+        user_dict(string): The user-defined dictionary, defaults to None.
         kwargs (dict, optional): Additional keyword arguments passed along to the specific task.
     """

@@ -118,6 +119,7 @@ def __init__(self, task, model, user_dict=None, **kwargs):
         self._check_task_files()
         self._construct_vocabs()
         self._get_inference_model()
+        self._max_seq_len = 512
         if self._user_dict:
             self._custom = Customization()
             self._custom.load_customization(self._user_dict)
@@ -179,17 +181,24 @@ def _preprocess(self, inputs, padding=True, add_special_tokens=True):
             'batch_size'] if 'batch_size' in self.kwargs else 1
         num_workers = self.kwargs[
             'num_workers'] if 'num_workers' in self.kwargs else 0
+        self._split_sentence = self.kwargs[
+            'split_sentence'] if 'split_sentence' in self.kwargs else False
         infer_data = []
         oov_token_id = self._word_vocab.get("OOV")

         filter_inputs = []
+        for input in inputs:
+            if not (isinstance(input, str) and len(input.strip()) > 0):
+                continue
+            filter_inputs.append(input)
+
+        short_input_texts, self.input_mapping = self._auto_splitter(
+            filter_inputs,
+            self._max_seq_len,
+            split_sentence=self._split_sentence)

         def read(inputs):
             for input_tokens in inputs:
-                if not (isinstance(input_tokens, str) and
-                        len(input_tokens.strip()) > 0):
-                    continue
-                filter_inputs.append(input_tokens)
                 ids = []
                 for token in input_tokens:
                     token = self._q2b_vocab.get(token, token)
@@ -198,7 +207,7 @@ def read(inputs):
             lens = len(ids)
             yield ids, lens

-        infer_ds = load_dataset(read, inputs=inputs, lazy=False)
+        infer_ds = load_dataset(read, inputs=short_input_texts, lazy=False)
         batchify_fn = lambda samples, fn=Tuple(
             Pad(axis=0, pad_val=0, dtype="int64"),  # input_ids
             Stack(dtype='int64'),  # seq_len
@@ -211,7 +220,7 @@ def read(inputs):
             shuffle=False,
             return_list=True)
         outputs = {}
-        outputs['text'] = filter_inputs
+        outputs['text'] = short_input_texts
         outputs['data_loader'] = infer_data_loader
         return outputs
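Taken together, `_preprocess` now filters, splits, and records the chunk mapping before batching. A usage sketch of how these kwargs might reach the task — the parameter plumbing follows this diff, but the task name and values are illustrative:

```python
from paddlenlp import Taskflow

# batch_size and num_workers already existed; split_sentence is new in this PR.
ner = Taskflow("ner", batch_size=2, num_workers=0, split_sentence=True)
print(ner(["雅生活在管面积达1.12亿平方米", "《孤女》是2010年九州出版社出版的小说"]))
```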

@@ -229,6 +238,7 @@ def _run_model(self, inputs):
             tags_ids = self.output_handle[0].copy_to_cpu()
             results.extend(tags_ids.tolist())
             lens.extend(seq_len.tolist())
+
         inputs['result'] = results
         inputs['lens'] = lens
         return inputs
@@ -273,4 +283,6 @@ def _postprocess(self, inputs):
             single_result['segs'] = sent_out
             single_result['tags'] = tags_out
             final_results.append(single_result)
+        final_results = self._auto_joiner(
+            final_results, self.input_mapping, is_dict=True)
         return final_results
9 changes: 5 additions & 4 deletions paddlenlp/taskflow/models/lexical_analysis_model.py
@@ -19,11 +19,12 @@
 from paddlenlp.layers.crf import LinearChainCrf, LinearChainCrfLoss
 from paddlenlp.utils.tools import compare_version

-if compare_version(paddle.version.full_version, "2.2.0") >= 0:
-    # paddle.text.ViterbiDecoder is supported by paddle after version 2.2.0
+try:
     from paddle.text import ViterbiDecoder
-else:
-    from paddlenlp.layers.crf import ViterbiDecoder
+except:
+    raise ImportError(
+        "Taskflow requires paddle version >= 2.2.0, but current paddle version is {}".
+        format(paddle.version.full_version))
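One caveat worth noting on the new guard (an observation, not part of the diff): a bare `except:` also converts unrelated failures inside `paddle.text` into the version message. Catching `ImportError` explicitly would be safer; a minimal sketch:

```python
try:
    from paddle.text import ViterbiDecoder
except ImportError:
    # Only import failures indicate an old paddle; other errors propagate.
    raise ImportError(
        "Taskflow requires paddle >= 2.2.0, but the current paddle version is {}"
        .format(paddle.version.full_version))
```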


class BiGruCrf(nn.Layer):