Update Taskflow word_segmentation and ner tasks #1666
Changes from 35 commits
First changed file:
@@ -25,13 +25,14 @@
 import numpy as np
 import paddle
 import paddle.nn as nn
-try:
-    from paddle.text import ViterbiDecoder
-except:
-    raise ImportError(
-        "Taskflow requires paddle version >= 2.2.0, but current paddle version is {}".
-        format(paddle.version.full_version))
 from paddlenlp.layers.crf import LinearChainCrf
+from paddlenlp.utils.tools import compare_version
+if compare_version(paddle.version.full_version, "2.2.0") >= 0:
+    # paddle.text.ViterbiDecoder is supported by paddle after version 2.2.0
+    from paddle.text import ViterbiDecoder
+else:
+    from paddlenlp.layers.crf import ViterbiDecoder

 from ..datasets import MapDataset, load_dataset
 from ..data import Stack, Pad, Tuple
@@ -217,6 +218,16 @@ def __init__(self,
             self._custom.load_customization(self._user_dict)
         else:
             self._custom = None
+        self._num_workers = self.kwargs[
+            'num_workers'] if 'num_workers' in self.kwargs else 0
+        self._batch_size = self.kwargs[
+            'batch_size'] if 'batch_size' in self.kwargs else 1
+        self._lazy_load = self.kwargs[
+            'lazy_load'] if 'lazy_load' in self.kwargs else False
+        self._max_seq_len = self.kwargs[
+            'max_seq_len'] if 'max_seq_len' in self.kwargs else 512
+        self._split_sentence = self.kwargs[
+            'split_sentence'] if 'split_sentence' in self.kwargs else False

     @property
     def summary_num(self):
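The hunk above caches the optional keyword arguments once in `__init__` instead of re-reading `self.kwargs` on every prediction call. For readers unfamiliar with the `kwargs[key] if key in kwargs else default` pattern, `dict.get` expresses exactly the same fallback; a minimal standalone sketch (not part of the PR):

```python
# Equivalent way to cache optional kwargs with defaults (sketch, not from this PR).
class TaskConfig:
    def __init__(self, **kwargs):
        self.kwargs = kwargs
        # dict.get(key, default) is equivalent to the
        # "kwargs[key] if key in kwargs else default" pattern in the diff.
        self._num_workers = self.kwargs.get('num_workers', 0)
        self._batch_size = self.kwargs.get('batch_size', 1)
        self._lazy_load = self.kwargs.get('lazy_load', False)
        self._max_seq_len = self.kwargs.get('max_seq_len', 512)
        self._split_sentence = self.kwargs.get('split_sentence', False)

config = TaskConfig(batch_size=32)
assert config._batch_size == 32 and config._num_workers == 0
```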
@@ -263,127 +274,34 @@ def _load_task_resources(self):
             self._termtree = TermTree.from_dir(
                 self._term_schema_path, self._term_data_path, self._linking)

-    def _split_long_text_input(self, input_texts, max_text_len):
-        """
-        Split the long text to list of short text, the max_seq_len of input text is 512,
-        if the text length greater than 512, will this function that spliting the long text.
-        """
-        short_input_texts = []
-        for text in input_texts:
-            if len(text) <= max_text_len:
-                short_input_texts.append(text)
-            else:
-                lens = len(text)
-                temp_text_list = text.split("?。!")
-                temp_text_list = [
-                    temp_text for temp_text in temp_text_list
-                    if len(temp_text) > 0
-                ]
-                if len(temp_text_list) <= 1:
-                    temp_text_list = [
-                        text[i:i + max_text_len]
-                        for i in range(0, len(text), max_text_len)
-                    ]
-                    short_input_texts.extend(temp_text_list)
-                else:
-                    list_len = len(temp_text_list)
-                    start = 0
-                    end = 0
-                    for i in range(0, list_len):
-                        if len(temp_text_list[i]) + 1 >= max_text_len:
-                            if start != end:
-                                short_input_texts.extend(
-                                    self._split_long_text_input(
-                                        [text[start:end]], max_text_len))
-                            short_input_texts.extend(
-                                self._split_long_text_input([
-                                    text[end:end + len(temp_text_list[i]) + 1]
-                                ], max_text_len))
-                            start = end + len(temp_text_list[i]) + 1
-                            end = start
-                        else:
-                            if start + len(temp_text_list[
-                                    i]) + 1 > max_text_len:
-                                short_input_texts.extend(
-                                    self._split_long_text_input(
-                                        [text[start:end]], max_text_len))
-                                start = end
-                                end = end + len(temp_text_list[i]) + 1
-                            else:
-                                end = len(temp_text_list[i]) + 1
-                    if start != end:
-                        short_input_texts.extend(
-                            self._split_long_text_input([text[start:end]],
-                                                        max_text_len))
-        return short_input_texts
-
-    def _concat_short_text_reuslts(self, input_texts, results):
-        """
-        Concat the model output of short texts to the total result of long text.
-        """
-        long_text_lens = [len(text) for text in input_texts]
-        concat_results = []
-        single_results = {}
-        count = 0
-        for text in input_texts:
-            text_len = len(text)
-            while True:
-                if len(single_results) == 0 or len(single_results[
-                        "text"]) < text_len:
-                    if len(single_results) == 0:
-                        single_results = copy.deepcopy(results[count])
-                    else:
-                        single_results["text"] += results[count]["text"]
-                        single_results["items"].extend(results[count]["items"])
-                    count += 1
-                elif len(single_results["text"]) == text_len:
-                    concat_results.append(single_results)
-                    single_results = {}
-                    break
-                else:
-                    raise Exception(
-                        "The length of input text and raw text is not equal.")
-        for result in concat_results:
-            pred_words = result['items']
-            pred_words = self._reset_offset(pred_words)
-            result['items'] = pred_words
-        return concat_results
-
     def _preprocess_text(self, input_texts):
         """
         Create the dataset and dataloader for the predict.
         """
-        batch_size = self.kwargs[
-            'batch_size'] if 'batch_size' in self.kwargs else 1
-        num_workers = self.kwargs[
-            'num_workers'] if 'num_workers' in self.kwargs else 0
-
-        max_seq_length = 512
-        if 'max_seq_length' in self.kwargs:
-            max_seq_length = self.kwargs['max_seq_length']
         infer_data = []
-        max_predict_len = max_seq_length - self.summary_num - 1
+        max_predict_len = self._max_seq_len - self.summary_num - 1
         filter_input_texts = []
         for input_text in input_texts:
             if not (isinstance(input_text, str) and len(input_text) > 0):
                 continue
             filter_input_texts.append(input_text)
         input_texts = filter_input_texts

-        short_input_texts = self._split_long_text_input(input_texts,
-                                                        max_predict_len)
+        short_input_texts, self.input_mapping = self._auto_splitter(
+            input_texts, max_predict_len, split_sentence=self._split_sentence)

         def read(inputs):
             for text in inputs:
                 tokenized_output = self._tokenizer(
                     list(text),
                     return_length=True,
                     is_split_into_words=True,
-                    max_seq_len=max_seq_length)
+                    max_seq_len=self._max_seq_len)
                 yield tokenized_output['input_ids'], tokenized_output[
                     'token_type_ids'], tokenized_output['seq_len']

-        infer_ds = load_dataset(read, inputs=short_input_texts, lazy=False)
+        infer_ds = load_dataset(
+            read, inputs=short_input_texts, lazy=self._lazy_load)
         batchify_fn = lambda samples, fn=Tuple(
             Pad(axis=0, pad_val=self._tokenizer.pad_token_id, dtype='int64'
                 ),  # input_ids
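The hand-rolled recursive splitter and its matching concatenation method are dropped here in favor of the shared `_auto_splitter`/`_auto_joiner` helpers on the base task, which also return an `input_mapping` from each original text to its chunk indices. A minimal sketch of that idea (assumptions: result dicts have the `text`/`items` shape seen in this diff, and the real helpers additionally split at sentence boundaries when `split_sentence=True`; this sketch only chunks by length):

```python
def auto_splitter(input_texts, max_text_len):
    """Split long texts into <= max_text_len chunks and remember the mapping."""
    short_input_texts = []
    input_mapping = {}  # original text index -> indices of its chunks
    for idx, text in enumerate(input_texts):
        chunk_ids = []
        for start in range(0, len(text), max_text_len):
            chunk_ids.append(len(short_input_texts))
            short_input_texts.append(text[start:start + max_text_len])
        input_mapping[idx] = chunk_ids
    return short_input_texts, input_mapping

def auto_joiner(short_results, input_mapping):
    """Merge per-chunk results back into one result per original text."""
    joined = []
    for idx in sorted(input_mapping):
        merged = {'text': '', 'items': []}
        for chunk_id in input_mapping[idx]:
            merged['text'] += short_results[chunk_id]['text']
            merged['items'].extend(short_results[chunk_id]['items'])
        joined.append(merged)
    return joined

chunks, mapping = auto_splitter(['a' * 1000], max_text_len=512)
assert len(chunks) == 2 and mapping == {0: [0, 1]}
```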
@@ -396,15 +314,14 @@ def read(inputs):
         infer_data_loader = paddle.io.DataLoader(
             infer_ds,
             collate_fn=batchify_fn,
-            num_workers=num_workers,
-            batch_size=batch_size,
+            num_workers=self._num_workers,
+            batch_size=self._batch_size,
             shuffle=False,
             return_list=True)

         outputs = {}
         outputs['data_loader'] = infer_data_loader
         outputs['short_input_texts'] = short_input_texts
         outputs['inputs'] = input_texts
         return outputs

     def _reset_offset(self, pred_words):
@@ -419,11 +336,9 @@ def _decode(self, batch_texts, batch_pred_tags):
         batch_results = []
         for sent_index in range(len(batch_texts)):
             sent = batch_texts[sent_index]
-            tags = [
-                self._index_to_tags[index]
-                for index in batch_pred_tags[sent_index][self.summary_num:len(
-                    sent) + self.summary_num]
-            ]
+            indexes = batch_pred_tags[sent_index][self.summary_num:len(sent) +
+                                                  self.summary_num]
+            tags = [self._index_to_tags[index] for index in indexes]
             if self._custom:
                 self._custom.parse_customization(sent, tags, prefix=True)
             sent_out = []
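The slice `[self.summary_num : len(sent) + self.summary_num]` exists because the WordTag model prepends `summary_num` special summary tokens to every sequence, so the predicted tag sequence is shifted relative to the raw characters. A toy illustration (hypothetical tag vocabulary, purely for clarity):

```python
# Hypothetical example: 2 summary tokens precede the real characters.
summary_num = 2
sent = "北京"
index_to_tags = {0: 'O', 1: 'LOC-B', 2: 'LOC-I'}
# Model output covers [summary tokens] + characters + [trailing special token].
pred_tag_ids = [0, 0, 1, 2, 0]
indexes = pred_tag_ids[summary_num:len(sent) + summary_num]
tags = [index_to_tags[i] for i in indexes]
assert tags == ['LOC-B', 'LOC-I']  # aligned 1:1 with the characters of sent
```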
@@ -543,15 +458,15 @@ def _run_model(self, inputs):
         Run the task model from the outputs of the `_tokenize` function.
         """
         all_pred_tags = []

-        for batch in inputs['data_loader']:
-            input_ids, token_type_ids, seq_len = batch
-            self.input_handles[0].copy_from_cpu(input_ids.numpy())
-            self.input_handles[1].copy_from_cpu(token_type_ids.numpy())
-            self.input_handles[2].copy_from_cpu(seq_len.numpy())
-            self.predictor.run()
-            pred_tags = self.output_handle[0].copy_to_cpu()
-            all_pred_tags.extend(pred_tags.tolist())
+        with dygraph_mode_guard():
+            for batch in inputs['data_loader']:
+                input_ids, token_type_ids, seq_len = batch
+                self.input_handles[0].copy_from_cpu(input_ids.numpy())
+                self.input_handles[1].copy_from_cpu(token_type_ids.numpy())
+                self.input_handles[2].copy_from_cpu(seq_len.numpy())
+                self.predictor.run()
+                pred_tags = self.output_handle[0].copy_to_cpu()
+                all_pred_tags.extend(pred_tags.tolist())
         inputs['all_pred_tags'] = all_pred_tags
         return inputs
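`dygraph_mode_guard` (and the `static_mode_guard` imported later in this PR) are context managers that temporarily switch Paddle's execution mode around the predictor loop and restore the previous mode afterwards; the review comments at the bottom of this page debate which of the two should wrap these inference loops. A plausible minimal sketch of such a guard, assuming it is built on `paddle.enable_static()` / `paddle.disable_static()` (the PR does not show its body):

```python
from contextlib import contextmanager

import paddle

@contextmanager
def dygraph_mode_guard():
    # Sketch only: switch to dynamic graph mode, then restore the previous mode.
    was_static = not paddle.in_dynamic_mode()
    if was_static:
        paddle.disable_static()
    try:
        yield
    finally:
        if was_static:
            paddle.enable_static()
```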
@@ -561,7 +476,11 @@ def _postprocess(self, inputs):
         """
         results = self._decode(inputs['short_input_texts'],
                                inputs['all_pred_tags'])
-        results = self._concat_short_text_reuslts(inputs['inputs'], results)
+        results = self._auto_joiner(results, self.input_mapping, is_dict=True)
+        for result in results:
+            pred_words = result['items']
+            pred_words = self._reset_offset(pred_words)
+            result['items'] = pred_words
         if self.linking is True:
             for res in results:
                 self._term_linking(res)
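After chunked results are rejoined, each item's character offset is only valid within its own chunk, so `_reset_offset` is applied to recompute offsets cumulatively over the merged item list. A small sketch of that recomputation (assumption: each item carries an `item` text and an `offset` field, matching the `items` entries in this diff):

```python
def reset_offset(pred_words):
    # Recompute each word's offset as the running length of the words before it.
    for i in range(len(pred_words)):
        if i == 0:
            pred_words[i]['offset'] = 0
        else:
            pred_words[i]['offset'] = (pred_words[i - 1]['offset'] +
                                       len(pred_words[i - 1]['item']))
    return pred_words

words = [{'item': '北京', 'offset': 0}, {'item': '天安门', 'offset': 0}]
assert reset_offset(words)[1]['offset'] == 2
```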
@@ -804,22 +723,22 @@ def _run_model(self, inputs):
         all_scores_can = []
         all_preds_can = []
         pred_ids = []

-        for batch in inputs['data_loader']:
-            input_ids, token_type_ids, label_indices = batch
-            self.input_handles[0].copy_from_cpu(input_ids.numpy())
-            self.input_handles[1].copy_from_cpu(token_type_ids.numpy())
-            self.predictor.run()
-            logits = self.output_handle[0].copy_to_cpu()
-            for i, l in zip(label_indices, logits):
-                score = l[i[0]:i[-1] + 1, self._vocab_ids]
-                # Find topk candidates of scores and predicted indices.
-                score_can, pred_id_can = self._find_topk(score, k=4, axis=-1)
-
-                all_scores_can.extend([score_can.tolist()])
-                all_preds_can.extend([pred_id_can.tolist()])
-                pred_ids.extend([pred_id_can[:, 0].tolist()])
+        with dygraph_mode_guard():
+            for batch in inputs['data_loader']:
+                input_ids, token_type_ids, label_indices = batch
+                self.input_handles[0].copy_from_cpu(input_ids.numpy())
+                self.input_handles[1].copy_from_cpu(token_type_ids.numpy())
+                self.predictor.run()
+                logits = self.output_handle[0].copy_to_cpu()
+                for i, l in zip(label_indices, logits):
+                    score = l[i[0]:i[-1] + 1, self._vocab_ids]
+                    # Find topk candidates of scores and predicted indices.
+                    score_can, pred_id_can = self._find_topk(
+                        score, k=4, axis=-1)
+
+                    all_scores_can.extend([score_can.tolist()])
+                    all_preds_can.extend([pred_id_can.tolist()])
+                    pred_ids.extend([pred_id_can[:, 0].tolist()])
         inputs['all_scores_can'] = all_scores_can
         inputs['all_preds_can'] = all_preds_can
         inputs['pred_ids'] = pred_ids

Review comment (on the `with dygraph_mode_guard():` line): Same as above.
Reply: Fixed.
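`_find_topk` returns the top-k scores and their indices along an axis; NumPy has no single top-k call, so helpers like this are typically built from `argpartition` (which is O(n)) plus a sort of only the selected slice. A hedged sketch of one way to implement it (the PR does not show the helper's body):

```python
import numpy as np

def find_topk(a, k, axis=-1):
    # Indices of the k largest values along `axis` via argpartition,
    # then order just those k values from largest to smallest.
    part = np.argpartition(-a, k - 1, axis=axis)
    topk_idx = np.take(part, range(k), axis=axis)
    topk_val = np.take_along_axis(a, topk_idx, axis=axis)
    order = np.argsort(-topk_val, axis=axis)
    topk_idx = np.take_along_axis(topk_idx, order, axis=axis)
    topk_val = np.take_along_axis(topk_val, order, axis=axis)
    return topk_val, topk_idx

scores = np.array([[0.1, 0.7, 0.2, 0.9]])
vals, idx = find_topk(scores, k=2, axis=-1)
assert idx.tolist() == [[3, 1]] and vals.tolist() == [[0.9, 0.7]]
```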
Second changed file:
@@ -26,7 +26,7 @@
 import paddle.nn.functional as F
 from ..datasets import load_dataset, MapDataset
 from ..data import Stack, Pad, Tuple, Vocab, JiebaTokenizer
-from .utils import download_file, add_docstrings, dygraph_mode_guard
+from .utils import download_file, add_docstrings, static_mode_guard, dygraph_mode_guard
 from .utils import Customization
 from .task import Task
 from .models import BiGruCrf
@@ -81,6 +81,7 @@ class LacTask(Task):
     Args:
         task(string): The name of task.
         model(string): The model name in the task.
+        user_dict(string): The user-defined dictionary, default to None.
         kwargs (dict, optional): Additional keyword arguments passed along to the specific task.
     """
@@ -118,6 +119,7 @@ def __init__(self, task, model, user_dict=None, **kwargs):
         self._check_task_files()
         self._construct_vocabs()
         self._get_inference_model()
+        self._max_seq_len = 512
         if self._user_dict:
             self._custom = Customization()
             self._custom.load_customization(self._user_dict)
@@ -179,17 +181,24 @@ def _preprocess(self, inputs, padding=True, add_special_tokens=True):
             'batch_size'] if 'batch_size' in self.kwargs else 1
         num_workers = self.kwargs[
             'num_workers'] if 'num_workers' in self.kwargs else 0
+        self._split_sentence = self.kwargs[
+            'split_sentence'] if 'split_sentence' in self.kwargs else False
         infer_data = []
         oov_token_id = self._word_vocab.get("OOV")

         filter_inputs = []
+        for input in inputs:
+            if not (isinstance(input, str) and len(input.strip()) > 0):
+                continue
+            filter_inputs.append(input)
+
+        short_input_texts, self.input_mapping = self._auto_splitter(
+            filter_inputs,
+            self._max_seq_len,
+            split_sentence=self._split_sentence)

         def read(inputs):
             for input_tokens in inputs:
-                if not (isinstance(input_tokens, str) and
-                        len(input_tokens.strip()) > 0):
-                    continue
-                filter_inputs.append(input_tokens)
                 ids = []
                 for token in input_tokens:
                     token = self._q2b_vocab.get(token, token)
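`self._q2b_vocab` maps full-width ("quanjiao") characters to their half-width ("banjiao") equivalents before vocabulary lookup, so that e.g. 'Ａ' and '，' normalize to 'A' and ','. Where no such vocab file is available, the same normalization can be computed from Unicode code points; a small self-contained sketch (not the PR's implementation, which uses a lookup table):

```python
def q2b(char):
    # Full-width -> half-width normalization by Unicode arithmetic.
    code = ord(char)
    if code == 0x3000:            # ideographic (full-width) space
        return ' '
    if 0xFF01 <= code <= 0xFF5E:  # full-width ASCII variants
        return chr(code - 0xFEE0)
    return char                   # everything else is left unchanged

assert q2b('Ａ') == 'A' and q2b('，') == ',' and q2b('中') == '中'
```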
@@ -198,7 +207,7 @@ def read(inputs):
                 lens = len(ids)
                 yield ids, lens

-        infer_ds = load_dataset(read, inputs=inputs, lazy=False)
+        infer_ds = load_dataset(read, inputs=short_input_texts, lazy=False)
         batchify_fn = lambda samples, fn=Tuple(
             Pad(axis=0, pad_val=0, dtype="int64"),  # input_ids
             Stack(dtype='int64'),  # seq_len
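For reference, `paddlenlp.datasets.load_dataset` accepts a generator-style reader plus keyword arguments that are forwarded to it, and wraps the yielded examples in a `MapDataset` when `lazy=False` (or an iterable `IterDataset` when `lazy=True`, as the `lazy_load` kwarg in the first file exploits). A hedged usage sketch:

```python
from paddlenlp.datasets import load_dataset

def read(inputs):
    # Yield one example per input text; load_dataset forwards `inputs` to us.
    for text in inputs:
        yield {'text': text}

ds = load_dataset(read, inputs=['样例一', '样例二'], lazy=False)
print(len(ds))  # MapDataset supports len(); an IterDataset (lazy=True) does not
```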
@@ -211,7 +220,7 @@ def read(inputs):
             shuffle=False,
             return_list=True)
         outputs = {}
-        outputs['text'] = filter_inputs
+        outputs['text'] = short_input_texts
         outputs['data_loader'] = infer_data_loader
         return outputs
@@ -221,14 +230,15 @@ def _run_model(self, inputs):
         """
         results = []
         lens = []
-        for batch in inputs['data_loader']:
-            input_ids, seq_len = batch
-            self.input_handles[0].copy_from_cpu(input_ids.numpy())
-            self.input_handles[1].copy_from_cpu(seq_len.numpy())
-            self.predictor.run()
-            tags_ids = self.output_handle[0].copy_to_cpu()
-            results.extend(tags_ids.tolist())
-            lens.extend(seq_len.tolist())
+        with dygraph_mode_guard():
+            for batch in inputs['data_loader']:
+                input_ids, seq_len = batch
+                self.input_handles[0].copy_from_cpu(input_ids.numpy())
+                self.input_handles[1].copy_from_cpu(seq_len.numpy())
+                self.predictor.run()
+                tags_ids = self.output_handle[0].copy_to_cpu()
+                results.extend(tags_ids.tolist())
+                lens.extend(seq_len.tolist())
         inputs['result'] = results
         inputs['lens'] = lens
         return inputs

Review comment (on the `with dygraph_mode_guard():` line): Same as above.
Reply: Fixed.
@@ -273,4 +283,6 @@ def _postprocess(self, inputs):
             single_result['segs'] = sent_out
             single_result['tags'] = tags_out
             final_results.append(single_result)
+        final_results = self._auto_joiner(
+            final_results, self.input_mapping, is_dict=True)
         return final_results
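Taken together, these changes let the `word_segmentation` and `ner` Taskflow tasks accept the new keyword arguments and transparently chunk and rejoin texts longer than `max_seq_len`. A hedged usage sketch against the Taskflow API as exposed by this PR (argument names taken from the diff; example sentences are illustrative only):

```python
from paddlenlp import Taskflow

# batch_size / num_workers / split_sentence are read from kwargs in __init__,
# as added in this PR; user_dict would load a custom segmentation dictionary.
seg = Taskflow("word_segmentation", batch_size=2, split_sentence=False)
print(seg("第十四届全运会在西安举办"))

ner = Taskflow("ner", batch_size=2)
print(ner("《孤女》是2010年九州出版社出版的小说"))
```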
Review comment: Why is dynamic graph mode used here? Shouldn't it be static graph mode?
Reply: Fixed.