-
Notifications
You must be signed in to change notification settings - Fork 4
/
utils.py
59 lines (44 loc) · 2.23 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import re
def construct_prompt(
    data: dict,
    language: str = "python",
    tokenizer=None,
    max_token_nums: int = 15800
) -> str:
    """
    Construct the prompt for next line prediction.

    The prompt is the cross-file context (a repo-name header followed by one
    commented path header and body per snippet) concatenated with the in-file
    context (path header, import statements and cropped code). When both a
    tokenizer and a token limit are given, trailing lines of the cross-file
    context are dropped until cross-file tokens plus in-file tokens fit the
    limit; the in-file context is never shortened, so the result can still
    exceed the limit when the in-file context alone is too long.

    :param data: data point from the dataset; the keys ``repo_name``,
        ``context`` (an iterable of mappings with ``path`` and ``snippet``),
        ``file_path``, ``import_statement`` and ``cropped_code`` are read
    :param language: the language of the code; ``"python"`` selects ``#`` as
        the comment symbol, any other value selects ``//``
    :param tokenizer: the tokenizer of the evaluation model; only
        ``tokenizer.encode(text)`` is called and the length of its result is
        taken as the token count. ``None`` disables truncation
    :param max_token_nums: the maximum number of tokens constraint for the
        prompt. ``None`` disables truncation
    :return: the constructed prompt
    """
    # comment symbol for different languages
    comment_symbol = "#" if language == "python" else "//"

    # cross-file prompt: repo header, then every snippet under its path header
    cross_file_prompt = f"{comment_symbol} Repo Name: {data['repo_name']}\n" + "".join(
        f"{comment_symbol} Path: {snippet['path']}\n{snippet['snippet']}\n\n"
        for snippet in data['context']
    )

    # in-file prompt
    in_file_prompt = (
        f"{comment_symbol} Path: {data['file_path']}\n"
        f"{data['import_statement']}\n{data['cropped_code']}\n"
    )

    # only truncate when both the tokenizer and the limit are available
    if tokenizer is not None and max_token_nums is not None:
        cross_file_token_nums = len(tokenizer.encode(cross_file_prompt))
        in_file_token_nums = len(tokenizer.encode(in_file_prompt))
        # whatever the in-file prompt does not use is left for the cross-file prompt
        cross_file_budget = max_token_nums - in_file_token_nums
        if cross_file_token_nums > cross_file_budget:
            cross_file_prompt = _truncate_cross_file_prompt(
                cross_file_prompt,
                tokenizer,
                cross_file_budget,
                cross_file_token_nums,
            )

    # combine the cross-file prompt and in-file prompt
    prompt = cross_file_prompt + in_file_prompt

    # normalize some empty lines: runs of 4+ newlines collapse to one blank line
    prompt = re.sub(r'\n{4,}', '\n\n', prompt)

    return prompt


def _truncate_cross_file_prompt(cross_file_prompt, tokenizer, token_budget, token_count):
    """Drop trailing lines of the cross-file prompt until it fits ``token_budget``.

    :param cross_file_prompt: the untruncated cross-file prompt
    :param tokenizer: object whose ``encode(text)`` result length is the token count
    :param token_budget: number of tokens the cross-file prompt may use
    :param token_count: token count of ``cross_file_prompt`` as a whole
    :return: the prompt cut at a line boundary and terminated by a blank line,
        or just the blank line when no amount of cutting fits the budget
    """
    lines = cross_file_prompt.split("\n")
    keep = len(lines)  # lines[:keep] is what survives
    overflow = token_count - token_budget

    while overflow > 0 and keep > 0:
        # Cheap estimate: drop lines from the end, crediting each with its own
        # token count, until the overflow is (apparently) paid off.
        while keep > 0:
            keep -= 1
            overflow -= len(tokenizer.encode(lines[keep]))
            if overflow < 0:
                break

        cross_file_prompt = "\n".join(lines[:keep]) + "\n\n"

        # Per-line counts need not add up to the count of the rejoined text
        # (e.g. an encode() that adds tokens on every call, tokens that span
        # newlines, the blank line re-appended above), so measure the real
        # result and go round again if it still does not fit.
        overflow = len(tokenizer.encode(cross_file_prompt)) - token_budget

    return cross_file_prompt