main.py
import copy
import logging
import os
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence, List

import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM

from dataset import LMDataset, LMSortDataset, LMPackDataset
from trainer import TrainerNoShuffle
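
# This script fine-tunes a causal LM on long-context data. Three batching strategies are
# selected via --batch_method (see make_supervised_data_module below):
#   * "naive": per-sample padded tensors, batched and trimmed to the longest real sequence,
#   * "sort":  same collation, with samples apparently pre-ordered by length (LMSortDataset),
#   * "pack":  several samples concatenated into one row, with sequence boundaries and
#              per-sequence loss weights (LMPackDataset, optionally with --pack_loss).
# The dataset-class descriptions above are inferred from how they are used in this file.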


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="THUDM/LongAlign-6B-64k-base")
    pack_loss: bool = field(default=False)


@dataclass
class DataArguments:
    train_file: str = field(default=None, metadata={"help": "Path to the training data."})
    validation_file: str = field(default=None, metadata={"help": "Path to the validation data."})
    preprocessing_num_workers: Optional[int] = field(
        default=1,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    prompt_column: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
    )
    response_column: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
    )
    batch_method: str = field(default="naive")


@dataclass
class TrainingArguments(transformers.Seq2SeqTrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
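

# Collator for the "naive" and "sort" strategies: stacks per-sample tensors and trims the
# batch to the longest non-padded sequence. The pad token id is assumed to be 0, which is
# what the argmin-based trimming below relies on.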
@dataclass
class DataCollatorForLMDataset(object):
    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key].unsqueeze(0) for instance in instances] for key in ("input_ids", "labels"))
        input_ids = torch.cat(input_ids, dim=0)
        labels = torch.cat(labels, dim=0)
        # With pad id 0, argmin finds the first pad in each row, so eos_indices is the
        # index of the last real token per row.
        eos_indices = input_ids.argmin(dim=1) - 1
        max_position = eos_indices.max()
        if max_position < 0:
            # Degenerate case: no valid cut point found, return the batch unchanged.
            return dict(
                input_ids=input_ids,
                labels=labels
            )
        return dict(
            input_ids=input_ids[:, :max_position+1],
            labels=labels[:, :max_position+1]
        )
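

# Collator for the "pack" strategy: each dataset item is already a packed row of several
# sequences, so the collator concatenates rows while tracking sequence boundaries and
# per-sequence loss weights instead of an ordinary 0/1 attention mask.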
@dataclass
class DataCollatorForLMPackDataset(object):
    def __call__(self, instances):
        input_ids, attention_masks = tuple([instance[key].unsqueeze(0) for instance in instances] for key in ["input_ids", "attention_mask"])
        # "labels" from LMPackDataset appears to be (label ids, per-sequence loss weights,
        # number of sequences in the batch).
        batch_seq_num = instances[0]["labels"][2]
        labels = ([instance["labels"][0].unsqueeze(0) for instance in instances], [instance["labels"][1].unsqueeze(0) for instance in instances])
        input_ids = torch.cat(input_ids, dim=0)
        labels = (torch.cat(labels[0], dim=0), torch.cat(labels[1], dim=0))
        # Rescale the loss weights by (#GPUs / #sequences) so the weighted loss averages
        # correctly over all sequences across data-parallel ranks.
        labels = (labels[0], labels[1] * torch.cuda.device_count() / batch_seq_num)
        max_length = input_ids.shape[1]
        # The per-sample attention_mask holds cumulative sequence-boundary offsets rather
        # than a 0/1 mask; drop each subsequent sample's leading boundary and shift its
        # offsets by the accumulated length so boundaries stay valid after packing rows.
        attention_mask = attention_masks[0].squeeze()
        acc_length = max_length
        for new_attention_mask in attention_masks[1:]:
            new_attention_mask = new_attention_mask.squeeze()
            attention_mask = torch.cat([attention_mask, new_attention_mask[1:] + acc_length], dim=0)
            acc_length += max_length
        return dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
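

# Build the training dataset and the matching collator for the selected batch_method.
# ("sort" reuses the naive collator because sorted batches are still padded per batch.)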
def make_supervised_data_module(data_args) -> Dict:
    if data_args.batch_method == "naive":
        train_dataset = LMDataset(data_args.train_file)
        data_collator = DataCollatorForLMDataset()
    elif data_args.batch_method == "pack":
        train_dataset = LMPackDataset(data_args.train_file)
        data_collator = DataCollatorForLMPackDataset()
    elif data_args.batch_method == "sort":
        train_dataset = LMSortDataset(data_args.train_file)
        data_collator = DataCollatorForLMDataset()
    else:
        raise ValueError(f"Unknown batch_method: {data_args.batch_method}")
    return dict(train_dataset=train_dataset, data_collator=data_collator)
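

# Parse the three argument dataclasses, load the model and tokenizer, and run training with
# TrainerNoShuffle (which presumably preserves the ordering produced by the sort/pack
# datasets; see trainer.py).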
def train():
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    if "chatglm" in model_args.model_name_or_path.lower() or "longalign-6b" in model_args.model_name_or_path.lower():
        # ChatGLM-based checkpoints (including LongAlign-6B) ship their own modeling code,
        # hence trust_remote_code=True and the ChatGLM-specific empty_init=False argument.
        model = AutoModelForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True, empty_init=False
        )
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            trust_remote_code=True
        )
    else:
        # Llama-based checkpoints use the local modeling_llama implementation.
        from modeling_llama import LlamaForCausalLM
        model = LlamaForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True
        )
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            trust_remote_code=True
        )
    if model_args.pack_loss:
        # Enable the packed-loss variant in the model (it consumes the per-sequence
        # weights produced by DataCollatorForLMPackDataset).
        model.pack_loss = True
    data_module = make_supervised_data_module(data_args=data_args)
    trainer = TrainerNoShuffle(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        **data_module
    )
    trainer.train(resume_from_checkpoint=False)
    trainer.save_model()


if __name__ == "__main__":
    train()
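

# Example launch (illustrative only; the exact flags, data path, and distributed launcher
# depend on the rest of the repository):
#   torchrun --nproc_per_node=8 main.py \
#       --model_name_or_path THUDM/LongAlign-6B-64k-base \
#       --train_file ./data/train_packed \
#       --batch_method pack --pack_loss True \
#       --output_dir ./output --bf16 True \
#       --per_device_train_batch_size 1 --learning_rate 1e-5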