run.py
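"""Fine-tune a bi-encoder embedding model.

Parses model/data/training arguments with HfArgumentParser, builds the
tokenizer, config, and BiEncoderModel, then trains with BiTrainer and saves
the model and tokenizer to the output directory.
"""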
import logging
import os
from pathlib import Path

from transformers import (
    AutoConfig,
    AutoTokenizer,
    HfArgumentParser,
    set_seed,
)

from .arguments import ModelArguments, DataArguments, \
    RetrieverTrainingArguments as TrainingArguments
from .data import TrainDatasetForEmbedding, EmbedCollator
from .modeling import BiEncoderModel
from .trainer import BiTrainer

logger = logging.getLogger(__name__)


def main():
    # Parse command-line arguments into the three dataclasses.
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    # Re-annotate for IDE/type-checker support.
    model_args: ModelArguments
    data_args: DataArguments
    training_args: TrainingArguments
    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overwrite it."
        )
    # Set up logging: INFO on the main process, WARNING on other ranks.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARNING,
    )

logger.warning(
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
training_args.local_rank,
training_args.device,
training_args.n_gpu,
bool(training_args.local_rank != -1),
training_args.fp16,
)
logger.info("Training/evaluation parameters %s", training_args)
logger.info("Model parameters %s", model_args)
logger.info("Data parameters %s", data_args)

    # Set seed
    set_seed(training_args.seed)

    num_labels = 1
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
use_fast=False,
)
config = AutoConfig.from_pretrained(
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
num_labels=num_labels,
cache_dir=model_args.cache_dir,
)
logger.info('Config: %s', config)

    # Note: "normlized" (sic) matches the argument name defined elsewhere in
    # this codebase, so the spelling is preserved here.
    model = BiEncoderModel(model_name=model_args.model_name_or_path,
                           normlized=training_args.normlized,
                           sentence_pooling_method=training_args.sentence_pooling_method,
                           negatives_cross_device=training_args.negatives_cross_device,
                           temperature=training_args.temperature,
                           use_inbatch_neg=training_args.use_inbatch_neg,
                           )

    # Optionally freeze position embeddings during fine-tuning.
    if training_args.fix_position_embedding:
        for k, v in model.named_parameters():
            if "position_embeddings" in k:
                logger.info("Freeze the parameters for %s", k)
                v.requires_grad = False

    train_dataset = TrainDatasetForEmbedding(args=data_args, tokenizer=tokenizer)

    trainer = BiTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=EmbedCollator(
            tokenizer,
            query_max_len=data_args.query_max_len,
            passage_max_len=data_args.passage_max_len,
        ),
        tokenizer=tokenizer,
    )

    Path(training_args.output_dir).mkdir(parents=True, exist_ok=True)

    # Training
    trainer.train()
    trainer.save_model()

    # For convenience, we also re-save the tokenizer to the same directory,
    # so that you can share your model easily on huggingface.co/models =)
    if trainer.is_world_process_zero():
        tokenizer.save_pretrained(training_args.output_dir)


if __name__ == "__main__":
    main()
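
# Usage sketch. The module path and argument values below are illustrative
# assumptions, not taken from this file; the relative imports mean the script
# must be launched as a module. Only flags this script actually consumes are
# shown.
#
#   torchrun --nproc_per_node 2 -m finetune.run \
#       --model_name_or_path BAAI/bge-base-en \
#       --output_dir ./output \
#       --do_train \
#       --query_max_len 64 \
#       --passage_max_len 256 \
#       --temperature 0.02 \
#       --use_inbatch_neg True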