import logging
from dataclasses import dataclass
from typing import Any, Optional, Union

import torch
from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy

from inference.ICL import TASK_PROMT, Constrained_PROMPT

logger = logging.getLogger(__name__)
@dataclass
class DataCollator:
    tokenizer: PreTrainedTokenizerBase
    model: Optional[Any] = None
    padding: Union[bool, str, PaddingStrategy] = True  # "longest"
    max_prompt_len: Optional[int] = None
    max_ans_len: Optional[int] = None
    pad_to_multiple_of: Optional[int] = 1
    label_pad_token_id: int = -100
    return_tensors: str = "pt"
    inference: bool = False
    demonstrations: Optional[Any] = None
    task: Optional[str] = None
    def __call__(self, batch, return_tensors=None):
        if return_tensors is None:
            return_tensors = self.return_tensors
        model_inputs = self.decoder_call(batch, return_tensors)
        return model_inputs
    # only supports left padding for now
    def tokenize(self, sentence, cutoff_len, add_bos_token=True, add_eos_token=True):
        # This could likely be done through tokenizer settings alone, but
        # handling BOS/EOS manually keeps the length bookkeeping explicit.
        result = self.tokenizer(
            sentence,
            truncation=True,
            max_length=cutoff_len,
            add_special_tokens=False,
            padding=False,
            return_tensors=None,
        )
        if (
            len(result["input_ids"]) < cutoff_len
            and add_eos_token
        ):
            result["input_ids"].append(self.tokenizer.eos_token_id)
            result["attention_mask"].append(1)
        if (
            len(result["input_ids"]) < cutoff_len
            and add_bos_token
        ):
            result["input_ids"] = [self.tokenizer.bos_token_id] + result["input_ids"]
            result["attention_mask"] = [1] + result["attention_mask"]
        result["labels"] = result["input_ids"].copy()
        return result
    # supports decoder-only models with left padding
    def decoder_call(self, batch, return_tensors):
        sources = []
        gts = []
        tokenized_sources = []
        label_lens = []  # length of each tokenized label
        actual_max_len = 0  # actual maximum sequence length in the batch
        limit_len = (self.max_prompt_len + self.max_ans_len) if not self.inference else self.max_prompt_len
        for instance in batch:
            instruction = instance['prompt']
            label = instance['answer']
            sources.append(instruction)
            gts.append(label)
            if not self.inference:
                tokenized_label = self.tokenize(label, limit_len, add_bos_token=False, add_eos_token=True)
                tokenize_source = self.tokenize(instruction + label, limit_len, add_bos_token=True, add_eos_token=True)
                label_lens.append(len(tokenized_label["input_ids"]))
                tokenized_sources.append(tokenize_source)
            else:
                if self.demonstrations is not None:
                    task_prompt = ""
                    task_prompt += TASK_PROMT[self.task]
                    if self.task != "MeetingBank":  # no Constrained_PROMPT for MeetingBank
                        task_prompt += Constrained_PROMPT
                    for demonstration in self.demonstrations:
                        if self.task == "Py150":
                            task_prompt += "Code:\n"
                        task_prompt += demonstration["prompt"]
                        task_prompt += demonstration["answer"] + "\n\n"
                    if self.task == "Py150":
                        task_prompt += "Code:\n"
                    # task_prompt += Constrained_PROMPT
                    if self.task != "Py150":  # Py150 prompts carry no task prompt, so only strip it for other tasks
                        instruction = instruction[len(TASK_PROMT[self.task]):]
                    instruction = task_prompt + instruction
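                    # Assembled few-shot prompt (illustrative structure):
                    #   TASK_PROMT[task] (+ Constrained_PROMPT unless MeetingBank)
                    #   + per demonstration: ("Code:\n" for Py150) + prompt + answer + "\n\n"
                    #   + ("Code:\n" for Py150) + instruction (task prompt stripped for non-Py150)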
                tokenize_source = self.tokenize(instruction, limit_len, add_bos_token=True, add_eos_token=False)
                tokenized_sources.append(tokenize_source)
            if len(tokenize_source["input_ids"]) > actual_max_len:
                actual_max_len = len(tokenize_source["input_ids"])
        # Pad to the batch's actual max length (inputs are already capped at
        # limit_len), rounded up to a multiple of pad_to_multiple_of.
        actual_pad_len = (
            (actual_max_len + self.pad_to_multiple_of - 1) // self.pad_to_multiple_of * self.pad_to_multiple_of)
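        # e.g. actual_max_len=37 with pad_to_multiple_of=8 -> actual_pad_len=40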
        # Left-pad each example and mask out the prompt portion of the labels.
        for idx in range(len(tokenized_sources)):
            pad_len = actual_pad_len - len(tokenized_sources[idx]["input_ids"])
            assert sum(tokenized_sources[idx]["attention_mask"]) == len(tokenized_sources[idx]["input_ids"])
            tokenized_sources[idx]["input_ids"] = [self.tokenizer.pad_token_id] * pad_len + tokenized_sources[idx][
                "input_ids"]
            tokenized_sources[idx]["attention_mask"] = [0] * pad_len + tokenized_sources[idx]["attention_mask"]
            if not self.inference:
                label_len = label_lens[idx]
                label_mask_len = actual_pad_len - label_len
                tokenized_sources[idx]["labels"] = [self.label_pad_token_id] * label_mask_len + tokenized_sources[idx][
                    "labels"][-label_len:]
                assert len(tokenized_sources[idx]["input_ids"]) == len(tokenized_sources[idx]["attention_mask"]) == len(
                    tokenized_sources[idx]["labels"]) == actual_pad_len
        model_inputs = {'input_ids': torch.tensor([source["input_ids"] for source in tokenized_sources]),
                        'attention_mask': torch.tensor([source["attention_mask"] for source in tokenized_sources])}
        if not self.inference:
            model_inputs['labels'] = torch.tensor([source["labels"] for source in tokenized_sources])
        model_inputs['sources'] = sources
        if self.inference:
            model_inputs['gts'] = gts
        return model_inputs
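

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the collator itself). It assumes the
# "gpt2" checkpoint is reachable and that the script runs from the repo root
# so `inference.ICL` resolves; any decoder-only tokenizer with BOS/EOS tokens
# would behave the same way.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    tok.pad_token = tok.eos_token  # GPT-2 ships without a dedicated pad token

    collator = DataCollator(
        tokenizer=tok,
        max_prompt_len=64,
        max_ans_len=16,
        inference=False,  # training mode: labels are built and prompt tokens masked
    )
    batch = [
        {"prompt": "Translate to French: hello", "answer": " bonjour"},
        {"prompt": "Translate to French: goodbye", "answer": " au revoir"},
    ]
    out = collator(batch)
    # Both tensors share the left-padded batch shape (2, padded_len).
    print(out["input_ids"].shape, out["labels"].shape)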