Skip to content

Commit 58b84d7

Browse files
committed
Merge branch 'develop' of https://github.com/PaddlePaddle/docs into docsfix
2 parents c15848a + 6beb88b commit 58b84d7

File tree

5 files changed

+1775
-429
lines changed

5 files changed

+1775
-429
lines changed
Lines changed: 386 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,386 @@
1+
import argparse
2+
import os
3+
import re
4+
import sys
5+
6+
# Resolve every location from the absolute path of this script so the paths
# do not depend on how the interpreter populated ``__file__``.
script_path = os.path.abspath(__file__)
tools_dir = os.path.dirname(script_path)
# The project helper module imported below is looked up through this entry.
sys.path.append(tools_dir)

# Parent of the tools directory; the index file and the mapping docs are
# located relative to it.
cfp_basedir = os.path.join(tools_dir, "..")
11+
12+
from validate_mapping_files import (
13+
DiffMeta,
14+
IndexParserState,
15+
discover_all_metas,
16+
process_mapping_index as reference_mapping_item,
17+
)
18+
19+
# Parser states in which the line callbacks of this file accept a table line;
# a table line seen in any other state makes them return False.
accept_index_parser_state_set = {
    IndexParserState.normal,
    IndexParserState.table_sep_ignore,
    IndexParserState.table_sep,
    IndexParserState.table_row_ignore,
    IndexParserState.table_row,
}
26+
27+
28+
def mapping_type_to_description(mapping_type):
    """Translate a mapping type into its description and diff-link flag.

    Args:
        mapping_type: Mapping type string of an API mapping document.

    Returns:
        A ``(description, show_diff_url)`` tuple. ``show_diff_url`` tells the
        caller whether a link to the detailed comparison should be shown.

    Raises:
        ValueError: If ``mapping_type`` is not one of the known types.
    """
    # Types whose description is the type itself prefixed with "功能一致,".
    # "参数不一致" is currently disabled and is therefore rejected as unknown.
    same_function_types = (
        "无参数",
        "参数完全一致",
        "仅参数名不一致",
        "paddle 参数更多",
        "参数默认值不一致",
        "torch 参数更多",
        "返回参数类型不一致",
        "输入参数类型不一致",
        "输入参数用法不一致",
        "涉及上下文修改",
    )
    if mapping_type in same_function_types:
        return "功能一致," + mapping_type, True

    # Types with a fixed description: (type, description, show_diff_url).
    fixed_descriptions = (
        ("组合替代实现", "组合替代实现", True),
        ("对应 API 不在主框架", "对应 API 不在主框架【占位】", False),
        ("功能缺失", "功能缺失", False),
        ("可删除", "无对应 API,可以直接删除,对网络一般无影响", False),
    )
    for known_type, description, show_diff_url in fixed_descriptions:
        if mapping_type == known_type:
            return description, show_diff_url

    raise ValueError(
        f"Unexpected PyTorch-PaddlePaddle api mapping type {mapping_type}, please check "
    )
77+
78+
79+
# REFERENCE-ITEM rows no longer need to be maintained; everything is generated
# from the api_difference/ directory.
_REFERENCE_ITEM_PATTERN = re.compile(
    r"^\| *REFERENCE-MAPPING-ITEM\( *(?P<src_api>[^,]+) *, *(?P<diff_url>.+) *\) *\|$"
)
# Table row that names an API prefix and an optional `max_depth=N` argument.
REFERENCE_TABLE_PATTERN = re.compile(
    r"^\| *REFERENCE-MAPPING-TABLE\( *(?P<api_prefix>[^,]+) *(, *max_depth *= *(?P<max_depth>\d+) *)?\) *\|$"
)
# Table row that lists `alias_name` as an alias of `src_api`.
ALIAS_PATTERN = re.compile(
    r"^\| *ALIAS-REFERENCE-ITEM\( *(?P<alias_name>[^,]+) *, *(?P<src_api>[^,]+) *\) *\|$"
)
# Table row for a not-implemented API: source API, its URL and a remark.
NOT_IMPLEMENTED_PATTERN = re.compile(
    r"^\| *NOT-IMPLEMENTED-ITEM\( *(?P<src_api>[^,]+) *, *(?P<src_api_url>.+), *(?P<remark>.+) *\) *\|$"
)
# Table row for an API whose mapping is still in development: source API and
# its URL.
IN_DEVELOPMENT_PATTERN = re.compile(
    r"^\| *IN-DEVELOPMENT-PATTERN\( *(?P<src_api>[^,]+) *, *(?P<src_api_url>.+) *\) *\|$"
)

# URL prefix that docs_url_to_relative_page() removes from mapping doc URLs.
DOCS_REPO_BASEURL = "https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/model_convert/convert_from_pytorch/"
97+
98+
99+
def docs_url_to_relative_page(url):
    """Convert a PaddlePaddle/docs URL of a mapping document into a page path.

    URLs that do not start with ``DOCS_REPO_BASEURL`` are returned unchanged.
    Otherwise the base URL is removed and a trailing ``.md`` is replaced by
    ``.html``.
    """
    base = DOCS_REPO_BASEURL
    if not url.startswith(base):
        return url

    page = url[len(base) :]
    markdown_suffix = ".md"
    if not page.endswith(markdown_suffix):
        return page
    return page[: -len(markdown_suffix)] + ".html"
108+
109+
110+
def doc_path_to_relative_page(path):
    """Convert the local path of a mapping document into a relative page path.

    The path is made relative to ``cfp_basedir`` and its ``.md`` suffix is
    replaced by ``.html``. A path that does not end in ``.md`` fails the
    assertion.
    """
    relative_md = os.path.relpath(path, cfp_basedir)
    assert relative_md.endswith(".md"), f"Unexpected mapping doc path: {path}"

    page_stem = relative_md[: -len(".md")]
    return page_stem + ".html"
117+
118+
119+
def reference_table_match_to_condition(m):
    """Turn a REFERENCE-MAPPING-TABLE match into an ``(api_prefix, max_depth)`` pair.

    Args:
        m: Match object (or any mapping) providing the ``api_prefix`` and
            ``max_depth`` groups.

    Returns:
        A tuple of the prefix without surrounding whitespace and backticks,
        and the maximum depth as an int (255 when the group is absent).
    """
    # The regex group is greedy and can capture trailing spaces, so trim
    # whitespace before removing the surrounding backticks.
    api_prefix = m["api_prefix"].strip().strip("`")
    raw_depth = m["max_depth"]
    max_depth = 255 if raw_depth is None else int(raw_depth)
    return api_prefix, max_depth
127+
128+
129+
def get_referenced_api_columns(src_api, metadata_dict, alias=None):
    """Build the Markdown table cells describing one mapped API.

    Args:
        src_api: Source API name; must be a key of ``metadata_dict``.
        metadata_dict: Maps source API names to their mapping metadata.
        alias: Optional name displayed instead of ``src_api``.

    Returns:
        A list of four cells: source API link, target API link (empty when
        the metadata has no ``dst_api``), mapping type and remark.
    """
    assert src_api in metadata_dict, (
        f'Error: cannot find mapping doc of api "{src_api}"'
    )
    meta: DiffMeta = metadata_dict[src_api]

    detail_page = doc_path_to_relative_page(meta["source_file"])

    source_url = meta["src_api_url"]
    shown_name = alias if alias is not None else src_api
    source_cell = f"[`{shown_name}`]({source_url})"

    mapping_type = meta["mapping_type"]
    # Only the flag is needed here; the description text is not displayed.
    show_diff_url = mapping_type_to_description(mapping_type)[1]

    remark_cell = f"[详细对比]({detail_page})" if show_diff_url else ""
    if alias is not None:
        remark_cell = f"`{src_api}` 别名,{remark_cell}"

    if "dst_api" in meta:
        target_name = meta["dst_api"]
        target_url = meta["dst_api_url"]
        target_cell = f"[`{target_name}`]({target_url})"
    else:
        # These mapping types may have no target API; anything else is
        # reported.
        if mapping_type not in ["组合替代实现", "可删除", "功能缺失"]:
            print(f"Error: cannot find dst_api for src_api: {src_api}")
        target_cell = ""

    return [source_cell, target_cell, mapping_type, remark_cell]
168+
169+
170+
def _unescape_api_name(raw_name):
    """Trim whitespace and backticks and undo escaped underscores in an API name."""
    return raw_name.strip().strip("`").replace(r"\_", "_")


def _format_table_row(cells):
    """Join already formatted cells into one Markdown table line."""
    return "| " + " | ".join(cells) + " |\n"


def apply_reference_to_row_ex(line, metadata_dict, context, line_idx):
    """Expand one directive row of the index table into Markdown rows.

    Args:
        line: Table row holding one of the supported directives.
        metadata_dict: Maps source API names to their mapping metadata.
        context: Shared state; reads ``table_row_idx`` and ``c2a_dict``,
            appends to ``api_used_src`` and may update ``table_row_idx``.
        line_idx: Line number recorded for duplicate reporting.

    Returns:
        The list of generated Markdown table lines.

    Raises:
        ValueError: If the row matches none of the supported directives.
    """
    line = line.rstrip()
    row_idx_s = str(context["table_row_idx"])

    def record_api(api):
        # Remember every row an API appears in so duplicates can be reported.
        context["api_used_src"].setdefault(api, []).append((line_idx, line))

    reference_table_match = REFERENCE_TABLE_PATTERN.match(line)
    if reference_table_match:
        condition = reference_table_match_to_condition(reference_table_match)
        # This key must exist; otherwise an earlier step went wrong.
        api_list = context["c2a_dict"][condition]
        output_lines = []
        cur_row_idx = context["table_row_idx"]
        for api in api_list:
            record_api(api)
            columns = get_referenced_api_columns(api, metadata_dict)
            output_lines.append(
                _format_table_row([str(cur_row_idx), *columns])
            )
            cur_row_idx += 1
        # The caller adds 1 to table_row_idx by itself, so subtract 1 here.
        context["table_row_idx"] = cur_row_idx - 1
        return output_lines

    alias_match = ALIAS_PATTERN.match(line)
    if alias_match:
        alias_name = _unescape_api_name(alias_match["alias_name"])
        record_api(alias_name)
        src_api = _unescape_api_name(alias_match["src_api"])
        columns = get_referenced_api_columns(
            src_api, metadata_dict, alias=alias_name
        )
        return [_format_table_row([row_idx_s, *columns])]

    not_implemented_match = NOT_IMPLEMENTED_PATTERN.match(line)
    if not_implemented_match:
        src_api = _unescape_api_name(not_implemented_match["src_api"])
        record_api(src_api)
        src_api_url = not_implemented_match["src_api_url"].strip()
        remark = not_implemented_match["remark"].strip()
        cells = [
            row_idx_s,
            f"[`{src_api}`]({src_api_url})",
            "",
            "功能缺失",
            remark,
        ]
        return [_format_table_row(cells)]

    in_development_match = IN_DEVELOPMENT_PATTERN.match(line)
    if in_development_match:
        src_api = _unescape_api_name(in_development_match["src_api"])
        record_api(src_api)
        src_api_url = in_development_match["src_api_url"].strip()
        cells = [
            row_idx_s,
            f"[`{src_api}`]({src_api_url})",
            "",
            "映射关系开发中",
            "",
        ]
        return [_format_table_row(cells)]

    # Every table row must be generated from a directive.
    raise ValueError(
        f"found manual-maintaining row at line [{line_idx}]: {line}"
    )
266+
267+
268+
def reference_mapping_item_processer(line, line_idx, state, output, context):
    """Line callback of the processing pass: build the generated index.

    Lines outside tables are copied to ``output`` unchanged. Table rows are
    expanded through ``apply_reference_to_row_ex``; table lines in any other
    accepted state are copied as they are.

    Returns:
        True when the line was handled, False (after printing the state)
        when a table line arrives in a state that is not accepted.
    """
    if not line.startswith("|"):
        output.append(line)
        return True

    metadata_dict = context.get("metadata_dict", {})

    if state == IndexParserState.table_row:
        # The row content itself is validated by the expansion helper.
        generated_rows = apply_reference_to_row_ex(
            line, metadata_dict, context, line_idx + 1
        )
        output.extend(generated_rows)
        return True

    if state in accept_index_parser_state_set:
        output.append(line)
        return True

    print(state)
    return False
289+
290+
291+
def reference_table_scanner(line, _line_idx, state, output, context):
    """Line callback of the pre-scan pass: collect table conditions.

    Every REFERENCE-MAPPING-TABLE row found in a table is converted into an
    ``(api_prefix, max_depth)`` condition and appended to
    ``context["table_conditions"]``. ``output`` is never written.

    Returns:
        True when the line was handled, False when a table line arrives in a
        state that is not accepted.
    """
    if not line.startswith("|"):
        return True

    if state == IndexParserState.table_row:
        # Match the right-stripped line exactly like the processing pass
        # does; otherwise a row with trailing whitespace is skipped here but
        # matched later, and its condition is missing from c2a_dict.
        rtm = REFERENCE_TABLE_PATTERN.match(line.rstrip())
        if rtm:
            condition = reference_table_match_to_condition(rtm)
            context["table_conditions"].append(condition)
        return True

    return state in accept_index_parser_state_set
306+
307+
308+
def get_c2a_dict(conditions, meta_dict):
    """Assign every API to the first table condition that accepts it.

    Args:
        conditions: List of ``(api_prefix, max_depth)`` tuples. It is not
            modified.
        meta_dict: Iterable of API names (the keys when a dict is given).

    Returns:
        A dict mapping each condition to the list of API names assigned to
        it. An API accepted by no condition is reported with a warning and
        left out.
    """
    c2a_dict = {condition: [] for condition in conditions}
    # Longest prefix first, then smaller max depth first. Sorting a copy
    # keeps the caller's list in its original order.
    ordered_conditions = sorted(conditions, key=lambda c: (-len(c[0]), c[1]))

    for api in meta_dict:
        # Depth is the number of dots in the API name.
        depth = api.count(".")
        for condition in ordered_conditions:
            api_prefix, max_depth = condition
            if api.startswith(api_prefix) and depth <= max_depth:
                c2a_dict[condition].append(api)
                break
        else:
            print(f"Warning: cannot find a suitable condition for api {api}")

    return c2a_dict
326+
327+
328+
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--check_only",
        action="store_true",
        help=(
            "Only check: write the generated index to a temporary file "
            "instead of overwriting the source file"
        ),
    )
    args = parser.parse_args()

    CHECK_ONLY = args.check_only

    # pysrc_api_mapping_cn
    mapping_index_file = os.path.join(cfp_basedir, "pytorch_api_mapping_cn.md")

    metas = discover_all_metas(cfp_basedir)

    # Key the metadata by API name with escaped underscores undone.
    meta_dict = {m["src_api"].replace(r"\_", "_"): m for m in metas}

    reference_context = {
        "metadata_dict": meta_dict,
        "ret_code": 0,
        "output": [],
        "table_conditions": [],
        "api_used_src": {},
    }

    # First pass: pre-scan to find which tables and match conditions exist.
    ret_code = reference_mapping_item(
        mapping_index_file, reference_table_scanner, reference_context
    )
    if ret_code != 0:
        sys.exit(
            f"Error: scanning {mapping_index_file} failed with code {ret_code}"
        )
    reference_context["output"] = []

    # c2a_dict now holds the API list that belongs to each condition.
    c2a_dict = get_c2a_dict(reference_context["table_conditions"], meta_dict)
    reference_context["c2a_dict"] = c2a_dict

    # Second pass: read the index again and generate the output.
    ret_code = reference_mapping_item(
        mapping_index_file, reference_mapping_item_processer, reference_context
    )
    if ret_code != 0:
        # Do not overwrite the index with the output of a failed run.
        sys.exit(
            f"Error: processing {mapping_index_file} failed with code {ret_code}"
        )

    # Report APIs that appear in more than one row (reported only).
    for api, rows in reference_context["api_used_src"].items():
        if len(rows) > 1:
            row_ids = [r[0] for r in rows]
            print(f"Error: {api} used in multiple rows: {row_ids}")
            for row_id, line in rows:
                print(f"  - row [{row_id}]: {line}")

    # In check-only mode, write to a temporary file instead.
    output_path = mapping_index_file
    if CHECK_ONLY:
        output_path = os.path.join(tools_dir, "generated.tmp.md")

    with open(output_path, "w", encoding="utf-8") as f:
        f.writelines(reference_context["output"])

    # NOTE(review): the original comment here says the saving of the mapping
    # relation file was moved to `validate_mapping_in_api_difference.py`;
    # confirm that it is still accurate.

0 commit comments

Comments
 (0)