-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add text similarity task for Taskflow #1345
Changes from 4 commits
478cf12
15d90bf
78a9649
0924f23
abbc852
6643178
b5efa9d
3eea9bd
e34d519
8a0ace8
35069d6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,7 @@ | |
- [文本纠错](#文本纠错) | ||
- [句法分析](#句法分析) | ||
- [情感分析](#情感分析) | ||
- [文本匹配](#文本匹配) | ||
- [知识挖掘-词类知识标注](#知识挖掘-词类知识标注) | ||
- [知识挖掘-名词短语标注](#知识挖掘-名词短语标注) | ||
- [生成式问答](#生成式问答) | ||
|
@@ -31,6 +32,7 @@ | |
| 文本纠错 | 开放域对话(TODO) | | ||
| 句法分析 | 自动对联(TODO) | | ||
| 情感分析 | | | ||
| 文本匹配 | | | ||
| 知识挖掘-词类知识标注 | | | ||
| 知识挖掘-名词短语标注 | | | ||
|
||
|
@@ -174,6 +176,20 @@ senta("作为老的四星酒店,房间依然很整洁,相当不错。机场 | |
>>> [{'text': '作为老的四星酒店,房间依然很整洁,相当不错。机场接机服务很好,可以在车上办理入住手续,节省时间。', 'label': 'positive', 'score': 0.984320878982544}] | ||
``` | ||
|
||
### 文本匹配 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 文本相似度计算更为直观 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改 |
||
|
||
```python | ||
from paddlenlp import Taskflow | ||
|
||
matcher = Taskflow("text_matching") | ||
matcher([["世界上什么东西最小", "世界上什么东西最小?"]]) | ||
>>> [{'query': '世界上什么东西最小', 'title': '世界上什么东西最小?', 'similarity': 0.992725}] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 输入的key可能采用text1,text2 更加准确。如果用query和title会被倾向于认为是短文本与长文本匹配 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改 |
||
|
||
matcher = Taskflow("text_matching", batch_size=2) | ||
matcher([["光眼睛大就好看吗", "眼睛好看吗?"], ["小蝌蚪找妈妈怎么样", "小蝌蚪找妈妈是谁画的"]]) | ||
>>> [{'query': '光眼睛大就好看吗', 'title': '眼睛好看吗?', 'similarity': 0.7450271}, {'query': '小蝌蚪找妈妈怎么样', 'title': '小蝌蚪找妈妈是谁画的', 'similarity': 0.8192149}] | ||
``` | ||
|
||
### 知识挖掘-词类知识标注 | ||
|
||
```python | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import paddle | ||
from paddlenlp.transformers import BertModel, BertTokenizer | ||
|
||
from ..data import Pad, Tuple | ||
from .utils import static_mode_guard | ||
from .task import Task | ||
|
||
usage = r""" | ||
from paddlenlp import Taskflow | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. taskname改为text similarity会不会更为表意? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改 |
||
matcher = Taskflow("text_matching") | ||
matcher([["世界上什么东西最小", "世界上什么东西最小?"]]) | ||
''' | ||
[{'query': '世界上什么东西最小', 'title': '世界上什么东西最小?', 'similarity': 0.992725}] | ||
''' | ||
|
||
matcher = Taskflow("text_matching", batch_size=2) | ||
matcher([["光眼睛大就好看吗", "眼睛好看吗?"], ["小蝌蚪找妈妈怎么样", "小蝌蚪找妈妈是谁画的"]]) | ||
''' | ||
[{'query': '光眼睛大就好看吗', 'title': '眼睛好看吗?', 'similarity': 0.7450271}, {'query': '小蝌蚪找妈妈怎么样', 'title': '小蝌蚪找妈妈是谁画的', 'similarity': 0.8192149}] | ||
''' | ||
""" | ||
|
||
class SimBERTTask(Task): | ||
""" | ||
Text matching task using SimBERT to predict the similarity of sentence pair. | ||
Args: | ||
task(string): The name of task. | ||
model(string): The model name in the task. | ||
kwargs (dict, optional): Additional keyword arguments passed along to the specific task. | ||
""" | ||
|
||
def __init__(self, | ||
task, | ||
model, | ||
batch_size=1, | ||
max_seq_len=128, | ||
**kwargs): | ||
super().__init__(task=task, model=model, **kwargs) | ||
self._static_mode = True | ||
self._construct_tokenizer(model) | ||
self._get_inference_model() | ||
self._batch_size = batch_size | ||
self._max_seq_len = max_seq_len | ||
self._usage = usage | ||
|
||
def _construct_input_spec(self): | ||
""" | ||
Construct the input spec for the convert dygraph model to static model. | ||
""" | ||
self._input_spec = [ | ||
paddle.static.InputSpec( | ||
shape=[None, None], dtype="int64", name='input_ids'), | ||
paddle.static.InputSpec( | ||
shape=[None], dtype="int64", name='token_type_ids'), | ||
] | ||
|
||
def _construct_model(self, model): | ||
""" | ||
Construct the inference model for the predictor. | ||
""" | ||
self._model = BertModel.from_pretrained(model, pool_act='linear') | ||
self._model.eval() | ||
|
||
def _construct_tokenizer(self, model): | ||
""" | ||
Construct the tokenizer for the predictor. | ||
""" | ||
self._tokenizer = BertTokenizer.from_pretrained(model) | ||
|
||
def _check_input_text(self, inputs): | ||
inputs = inputs[0] | ||
if not all([isinstance(i, list) and i \ | ||
and all(i) and len(i) == 2 for i in inputs]): | ||
raise TypeError( | ||
"Invalid input format.") | ||
return inputs | ||
|
||
def _preprocess(self, inputs): | ||
""" | ||
Transform the raw text to the model inputs, two steps involved: | ||
1) Transform the raw text to token ids. | ||
2) Generate the other model inputs from the raw text and token ids. | ||
""" | ||
inputs = self._check_input_text(inputs) | ||
num_workers = self.kwargs[ | ||
'num_workers'] if 'num_workers' in self.kwargs else 0 | ||
lazy_load = self.kwargs[ | ||
'lazy_load'] if 'lazy_load' in self.kwargs else False | ||
|
||
examples = [] | ||
|
||
for data in inputs: | ||
query, title = data[0], data[1] | ||
|
||
query_encoded_inputs = self._tokenizer( | ||
text=query, max_seq_len=self._max_seq_len) | ||
query_input_ids = query_encoded_inputs["input_ids"] | ||
query_token_type_ids = query_encoded_inputs["token_type_ids"] | ||
|
||
title_encoded_inputs = self._tokenizer( | ||
text=title, max_seq_len=self._max_seq_len) | ||
title_input_ids = title_encoded_inputs["input_ids"] | ||
title_token_type_ids = title_encoded_inputs["token_type_ids"] | ||
|
||
examples.append((query_input_ids, query_token_type_ids, | ||
title_input_ids, title_token_type_ids)) | ||
|
||
batches = [ | ||
examples[idx:idx + self._batch_size] | ||
for idx in range(0, len(examples), self._batch_size) | ||
] | ||
|
||
batchify_fn = lambda samples, fn=Tuple( | ||
Pad(axis=0, pad_val=self._tokenizer.pad_token_id), # query_input | ||
Pad(axis=0, pad_val=self._tokenizer.pad_token_type_id), # query_segment | ||
Pad(axis=0, pad_val=self._tokenizer.pad_token_id), # title_input | ||
Pad(axis=0, pad_val=self._tokenizer.pad_token_type_id), # tilte_segment | ||
): [data for data in fn(samples)] | ||
|
||
outputs = {} | ||
outputs['data_loader'] = batches | ||
outputs['text'] = inputs | ||
self._batchify_fn = batchify_fn | ||
return outputs | ||
|
||
def _run_model(self, inputs): | ||
""" | ||
Run the task model from the outputs of the `_tokenize` function. | ||
""" | ||
results = [] | ||
with static_mode_guard(): | ||
for batch in inputs['data_loader']: | ||
q_ids, q_segment_ids, t_ids, t_segment_ids = self._batchify_fn(batch) | ||
self.input_handles[0].copy_from_cpu(q_ids) | ||
self.input_handles[1].copy_from_cpu(q_segment_ids) | ||
self.predictor.run() | ||
vecs_query = self.output_handle[1].copy_to_cpu() | ||
|
||
self.input_handles[0].copy_from_cpu(t_ids) | ||
self.input_handles[1].copy_from_cpu(t_segment_ids) | ||
self.predictor.run() | ||
vecs_title = self.output_handle[1].copy_to_cpu() | ||
|
||
vecs_query = vecs_query / (vecs_query**2).sum(axis=1, | ||
keepdims=True)**0.5 | ||
vecs_title = vecs_title / (vecs_title**2).sum(axis=1, | ||
keepdims=True)**0.5 | ||
similarity = (vecs_query * vecs_title).sum(axis=1) | ||
results.extend(similarity) | ||
inputs['result'] = results | ||
return inputs | ||
|
||
def _postprocess(self, inputs): | ||
""" | ||
The model output is tag ids, this function will convert the model output to raw text. | ||
""" | ||
final_results = [] | ||
for text, similarity in zip(inputs['text'], inputs['result']): | ||
result = {} | ||
result['query'] = text[0] | ||
result['title'] = text[1] | ||
result['similarity'] = similarity | ||
final_results.append(result) | ||
return final_results |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
文本相似度
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed