-
Notifications
You must be signed in to change notification settings - Fork 3k
/
Copy pathtest_information_extraction.py
139 lines (124 loc) · 5.63 KB
/
test_information_extraction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import pytest
from paddlenlp import Taskflow
from ..testing_utils import get_tests_dir
class TestUIETask(unittest.TestCase):
    """Structural checks for the ``information_extraction`` Taskflow.

    Three pipelines are built once per class: ``uie`` (loaded by model name),
    and ``uie_m`` / ``uie_x`` (loaded from the HF hub via ``task_path``).

    The assertions are structural only: they validate the keys of whatever
    results are returned and do not require any particular entity to be found.
    """

    @classmethod
    def setUpClass(cls):
        """Create the shared Taskflow pipelines used by every test."""
        cls.uie = Taskflow(
            task="information_extraction",
            model="__internal_testing__/tiny-random-uie",
        )
        cls.uie_m = Taskflow(
            task="information_extraction",
            task_path="PaddleCI/tiny-random-uie-m",
            from_hf_hub=True,
            convert_from_torch=False,
        )
        cls.uie_x = Taskflow(
            task="information_extraction",
            task_path="PaddleCI/tiny-random-uie-x",
            from_hf_hub=True,
            convert_from_torch=False,
        )

    def _assert_span(self, span):
        """Assert that an extracted span carries ``text`` and ``probability``."""
        self.assertIn("text", span)
        self.assertIn("probability", span)

    def _assert_entity_outputs(self, outputs, schema, *, require_bbox=False):
        """Validate the result structure of a flat (entity-only) schema.

        Args:
            outputs: Value returned by the Taskflow call.
            schema: Field names that are allowed to appear in each output.
            require_bbox: When true, every entity must also have a ``bbox`` key.
        """
        self.assertIsNotNone(outputs)
        for output in outputs:
            for field in output:
                self.assertIn(field, schema)
                for entity in output[field]:
                    self._assert_span(entity)
                    if require_bbox:
                        self.assertIn("bbox", entity)

    def _assert_relation_outputs(self, outputs, entity_type, relation_types):
        """Validate the result structure of a one-level relation schema.

        Args:
            outputs: Value returned by the Taskflow call.
            entity_type: Key that every output must contain.
            relation_types: Relation names allowed under each entity.
        """
        self.assertIsNotNone(outputs)
        for output in outputs:
            self.assertIn(entity_type, output)
            for entity in output[entity_type]:
                self._assert_span(entity)
                self.assertIn("relations", entity)
                for relation_type, relations in entity["relations"].items():
                    self.assertIn(relation_type, relation_types)
                    for relation in relations:
                        self._assert_span(relation)

    def test_entity_extraction(self):
        """Entity schema: run the same sentence through ``uie`` then ``uie_m``."""
        schema = ["时间", "选手", "赛事名称"]
        text = "2月8日上午北京冬奥会自由式滑雪女子大跳台决赛中中国选手谷爱凌以188.25分获得金牌!"

        for taskflow in (self.uie, self.uie_m):
            taskflow.set_schema(schema)
            self._assert_entity_outputs(taskflow(text), schema)

    def test_relation_extraction(self):
        """Relation schema: a subject entity with two relation types."""
        entity_type = "歌曲名称"
        relation_types = ["歌手", "所属专辑"]
        # Build the schema from the names asserted below so they cannot drift.
        schema = [{entity_type: relation_types}]

        self.uie.set_schema(schema)
        outputs = self.uie("《告别了》是孙耀威在专辑爱的故事里面的歌曲")
        self._assert_relation_outputs(outputs, entity_type, relation_types)

    def test_opinion_extraction(self):
        """Opinion schema: an aspect entity with opinion and sentiment relations."""
        entity_type = "评价维度"
        relation_types = ["观点词", "情感倾向[正向,负向]"]
        # Build the schema from the names asserted below so they cannot drift.
        schema = [{entity_type: relation_types}]

        self.uie.set_schema(schema)
        outputs = self.uie("店面干净,很清静,服务员服务热情,性价比很高,发现收银台有排队")
        self._assert_relation_outputs(outputs, entity_type, relation_types)

    @pytest.mark.skip(reason="todo, fix it")
    def test_doc_entity_extraction(self):
        """Document schema: run ``uie_x`` on an image and a text input.

        The same inputs are evaluated twice, first with the pipeline's current
        arguments and then after turning on ``layout_analysis``.
        """
        doc_path = get_tests_dir("fixtures/tests_samples/OCR/custom.jpeg")
        schema = ["进口日期", "申报日期"]
        inputs = (
            {"doc": doc_path},
            {"text": "进口日期: 2023年3月2日, 申报日期: 2023年3月2日"},
        )

        self.uie_x.set_schema(schema)
        self._assert_entity_outputs(self.uie_x(*inputs), schema, require_bbox=True)

        # Enable layout analysis
        self.uie_x.set_argument({"layout_analysis": True})
        # TODO(ZHUI): "bbox" is not asserted with layout analysis enabled;
        # switch to require_bbox=True once that is fixed.
        self._assert_entity_outputs(self.uie_x(*inputs), schema, require_bbox=False)