Skip to content

Commit

Permalink
Optimize schema construction for UIE Task (#2170)
Browse files Browse the repository at this point in the history
- remove import but unused os lib
- use `_schema_tree` attribute to displace multiple times of schema tree construction
- remove `schema_tree` arg in `_multi_stage_predict`
- add early check & return in `_multi_stage_predict`
- rename `id` var in `_multi_stage_predict` to `idx` to avoid mixing with python built-in funcs
- change `_build_tree` func to python `classmethod`

Co-authored-by: Linjie Chen <40840292+linjieccc@users.noreply.github.com>
  • Loading branch information
Spico197 and linjieccc committed May 16, 2022
1 parent 8294fb4 commit 9765848
Showing 1 changed file with 34 additions and 24 deletions.
58 changes: 34 additions & 24 deletions paddlenlp/taskflow/information_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import numpy as np
import paddle
from ..datasets import load_dataset
Expand Down Expand Up @@ -121,6 +119,7 @@ class UIETask(Task):

def __init__(self, task, model, schema, **kwargs):
super().__init__(task=task, model=model, **kwargs)
self._schema_tree = None
self.set_schema(schema)
if model not in self.encoding_model_map.keys():
raise ValueError(
Expand All @@ -142,8 +141,7 @@ def __init__(self, task, model, schema, **kwargs):
def set_schema(self, schema):
if isinstance(schema, dict) or isinstance(schema, str):
schema = [schema]
self._schema = schema
self._build_tree(self._schema)
self._schema_tree = self._build_tree(schema)

def _construct_input_spec(self):
"""
Expand Down Expand Up @@ -330,31 +328,42 @@ def _auto_joiner(self, short_results, short_inputs, input_mapping):

def _run_model(self, inputs):
raw_inputs = inputs['text']
schema_tree = self._build_tree(self._schema)
results = self._multi_stage_predict(raw_inputs, schema_tree)
results = self._multi_stage_predict(raw_inputs)
inputs['result'] = results
return inputs

def _multi_stage_predict(self, datas, schema_tree):
def _multi_stage_predict(self, datas):
"""
Traversal the schema tree and do multi-stage prediction.
Args:
datas (list): a list of strings
Returns:
list: a list of predictions, where the list's length
equals to the length of `datas`
"""
results = [{} for i in range(len(datas))]
schema_list = schema_tree.children
results = [{} for _ in range(len(datas))]
# input check to early return
if len(datas) < 1 or self._schema_tree is None:
return results

# copy to stay `self._schema_tree` unchanged
schema_list = self._schema_tree.children[:]
while len(schema_list) > 0:
node = schema_list.pop(0)
examples = []
input_map = {}
cnt = 0
id = 0
idx = 0
if not node.prefix:
for data in datas:
examples.append({
"text": data,
"prompt": dbc2sbc(node.name)
})
input_map[cnt] = [id]
id += 1
input_map[cnt] = [idx]
idx += 1
cnt += 1
else:
for pre, data in zip(node.prefix, datas):
Expand All @@ -366,8 +375,8 @@ def _multi_stage_predict(self, datas, schema_tree):
"text": data,
"prompt": dbc2sbc(p + node.name)
})
input_map[cnt] = [i + id for i in range(len(pre))]
id += len(pre)
input_map[cnt] = [i + idx for i in range(len(pre))]
idx += len(pre)
cnt += 1
if len(examples) == 0:
result_list = []
Expand All @@ -377,13 +386,13 @@ def _multi_stage_predict(self, datas, schema_tree):
if not node.parent_relations:
relations = [[] for i in range(len(datas))]
for k, v in input_map.items():
for id in v:
if len(result_list[id]) == 0:
for idx in v:
if len(result_list[idx]) == 0:
continue
if node.name not in results[k].keys():
results[k][node.name] = result_list[id]
results[k][node.name] = result_list[idx]
else:
results[k][node.name].extend(result_list[id])
results[k][node.name].extend(result_list[idx])
if node.name in results[k].keys():
relations[k].extend(results[k][node.name])
else:
Expand Down Expand Up @@ -415,11 +424,11 @@ def _multi_stage_predict(self, datas, schema_tree):
"relations"][node.name][k])
relations = new_relations

prefix = [[] for i in range(len(datas))]
prefix = [[] for _ in range(len(datas))]
for k, v in input_map.items():
for id in v:
for i in range(len(result_list[id])):
prefix[k].append(result_list[id][i]["text"] + "็š„")
for idx in v:
for i in range(len(result_list[idx])):
prefix[k].append(result_list[idx][i]["text"] + "็š„")

for child in node.children:
child.prefix = prefix
Expand Down Expand Up @@ -459,7 +468,8 @@ def _convert_ids_to_results(self, examples, sentence_ids, probs):
results.append(result_list)
return results

def _build_tree(self, schema, name='root'):
@classmethod
def _build_tree(cls, schema, name='root'):
"""
Build the schema tree.
"""
Expand All @@ -477,7 +487,7 @@ def _build_tree(self, schema, name='root'):
raise TypeError(
"Invalid schema, value for each key:value pairs should be list or string"
"but {} received".format(type(v)))
schema_tree.add_child(self._build_tree(child, name=k))
schema_tree.add_child(cls._build_tree(child, name=k))
else:
raise TypeError(
"Invalid schema, element should be string or dict, "
Expand Down

0 comments on commit 9765848

Please sign in to comment.