-
Notifications
You must be signed in to change notification settings - Fork 718
Expand file tree
/
Copy pathmodel_args.py
More file actions
562 lines (506 loc) · 17.7 KB
/
model_args.py
File metadata and controls
562 lines (506 loc) · 17.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
import json
import os
import sys
from dataclasses import asdict, dataclass, field, fields
from multiprocessing import cpu_count
import warnings
from typing import Union
from torch.utils.data import Dataset
def get_default_process_count():
    """
    Return the default number of worker processes.

    Leaves two CPUs free (but never returns fewer than one process) and, on
    Windows, caps the result at 61.
    """
    available = cpu_count()
    count = max(available - 2, 1)
    if sys.platform == "win32":
        # NOTE(review): presumably a Windows limit on multiprocessing workers; confirm.
        count = min(count, 61)
    return count
def get_special_tokens():
    """Return a new list holding the default special tokens."""
    default_tokens = ("<s>", "<pad>", "</s>", "<unk>", "<mask>")
    # A fresh list per call so it is safe to use as a dataclass default_factory.
    return list(default_tokens)
@dataclass
class ModelArgs:
    """
    Arguments shared by every model-specific args class.

    Instances can be updated from a dict (``update_from_dict``), written to
    ``<output_dir>/model_args.json`` (``save``) and restored from that file
    (``load``).
    """

    adafactor_beta1: float = None
    adafactor_clip_threshold: float = 1.0
    adafactor_decay_rate: float = -0.8
    adafactor_eps: tuple = field(default_factory=lambda: (1e-30, 1e-3))
    adafactor_relative_step: bool = True
    adafactor_scale_parameter: bool = True
    adafactor_warmup_init: bool = True
    adam_betas: tuple = field(default_factory=lambda: (0.9, 0.999))
    adam_epsilon: float = 1e-8
    best_model_dir: str = "outputs/best_model"
    cache_dir: str = "cache_dir/"
    config: dict = field(default_factory=dict)
    cosine_schedule_num_cycles: float = 0.5
    custom_layer_parameters: list = field(default_factory=list)
    custom_parameter_groups: list = field(default_factory=list)
    dataloader_num_workers: int = 0
    dataset_cache_dir: str = None
    do_lower_case: bool = False
    dynamic_quantize: bool = False
    early_stopping_consider_epochs: bool = False
    early_stopping_delta: float = 0
    early_stopping_metric: str = "eval_loss"
    early_stopping_metric_minimize: bool = True
    early_stopping_patience: int = 3
    encoding: str = None
    eval_batch_size: int = 100
    evaluate_during_training: bool = False
    evaluate_during_training_silent: bool = True
    evaluate_during_training_steps: int = 2000
    evaluate_during_training_verbose: bool = False
    evaluate_each_epoch: bool = True
    fp16: bool = True
    gradient_accumulation_steps: int = 1
    learning_rate: float = 4e-5
    local_rank: int = -1
    logging_steps: int = 50
    loss_type: str = None
    loss_args: dict = field(default_factory=dict)
    manual_seed: int = None
    max_grad_norm: float = 1.0
    max_seq_length: int = 128
    model_name: str = None
    model_type: str = None
    multiprocessing_chunksize: int = -1
    n_gpu: int = 1
    no_cache: bool = False
    no_save: bool = False
    not_saved_args: list = field(default_factory=list)
    num_train_epochs: int = 1
    optimizer: str = "AdamW"
    output_dir: str = "outputs/"
    overwrite_output_dir: bool = False
    polynomial_decay_schedule_lr_end: float = 1e-7
    polynomial_decay_schedule_power: float = 1.0
    process_count: int = field(default_factory=get_default_process_count)
    quantized_model: bool = False
    reprocess_input_data: bool = True
    save_best_model: bool = True
    save_eval_checkpoints: bool = True
    save_model_every_epoch: bool = True
    save_optimizer_and_scheduler: bool = True
    save_steps: int = 2000
    scheduler: str = "linear_schedule_with_warmup"
    silent: bool = False
    skip_special_tokens: bool = True
    tensorboard_dir: str = None
    thread_count: int = None
    tokenizer_name: str = None
    tokenizer_type: str = None
    train_batch_size: int = 8
    train_custom_parameters_only: bool = False
    trust_remote_code: bool = False
    use_cached_eval_features: bool = False
    use_early_stopping: bool = False
    use_hf_datasets: bool = False
    use_multiprocessing: bool = True
    use_multiprocessing_for_evaluation: bool = True
    wandb_kwargs: dict = field(default_factory=dict)
    wandb_project: str = None
    warmup_ratio: float = 0.06
    warmup_steps: int = 0
    weight_decay: float = 0.0

    def update_from_dict(self, new_values):
        """
        Set each key/value pair of ``new_values`` as an attribute on this instance.

        Args:
            new_values: dict of attribute name -> value. Keys are not validated,
                so unknown keys simply become new attributes.

        Raises:
            TypeError: if ``new_values`` is not a dict.
        """
        if not isinstance(new_values, dict):
            raise TypeError(f"{new_values} is not a Python dict.")
        for key, value in new_values.items():
            setattr(self, key, value)

    def get_args_for_saving(self):
        """
        Return a dict copy of these args suitable for saving.

        Fields named in ``not_saved_args`` are omitted, and any ``settings``
        entry inside ``wandb_kwargs`` is dropped.
        """
        args_for_saving = {
            key: value
            for key, value in asdict(self).items()
            if key not in self.not_saved_args
        }
        # asdict() builds new containers, so popping here leaves
        # self.wandb_kwargs untouched. .get() also tolerates wandb_kwargs
        # being excluded via not_saved_args or set to a non-dict value.
        wandb_kwargs = args_for_saving.get("wandb_kwargs")
        if isinstance(wandb_kwargs, dict):
            wandb_kwargs.pop("settings", None)
        return args_for_saving

    def save(self, output_dir):
        """
        Write these args as JSON to ``<output_dir>/model_args.json``.

        The directory is created if needed. A non-string ``tokenizer_type``
        is stored by name: a class by its own ``__name__``, any other object
        by its class name.

        Args:
            output_dir: directory to write ``model_args.json`` into.
        """
        os.makedirs(output_dir, exist_ok=True)
        args_dict = self.get_args_for_saving()
        tokenizer_type = args_dict.get("tokenizer_type")
        if tokenizer_type is not None and not isinstance(tokenizer_type, str):
            args_dict["tokenizer_type"] = (
                tokenizer_type.__name__
                if isinstance(tokenizer_type, type)
                else type(tokenizer_type).__name__
            )
        # Serialize before opening the file so a non-serializable value does
        # not truncate an existing model_args.json.
        serialized = json.dumps(args_dict)
        with open(os.path.join(output_dir, "model_args.json"), "w") as f:
            f.write(serialized)

    def load(self, input_dir):
        """
        Update this instance from ``<input_dir>/model_args.json``.

        Does nothing if ``input_dir`` is falsy or the file does not exist.

        Args:
            input_dir: directory that may contain ``model_args.json``.
        """
        if not input_dir:
            return
        model_args_file = os.path.join(input_dir, "model_args.json")
        if not os.path.isfile(model_args_file):
            return
        with open(model_args_file, "r") as f:
            model_args = json.load(f)
        self.update_from_dict(model_args)
@dataclass
class ClassificationArgs(ModelArgs):
    """
    Model args for a ClassificationModel.

    Extends ModelArgs with classification-specific options (labels, lazy
    loading of delimited files, sliding window, reranking).
    """

    model_class: str = "ClassificationModel"
    as_reranker: bool = False
    batch_chunk_size: int = None
    labels_list: list = field(default_factory=list)
    labels_map: dict = field(default_factory=dict)
    lazy_delimiter: str = "\t"
    lazy_labels_column: int = 1
    lazy_loading: bool = False
    lazy_loading_start_line: int = 1
    # Presumably column indices like lazy_text_column / lazy_labels_column,
    # so annotated int (previously mis-annotated as bool).
    lazy_text_a_column: int = None
    lazy_text_b_column: int = None
    lazy_text_column: int = 0
    onnx: bool = False
    pairwise_reranking_format: str = "repeat_query"
    regression: bool = False
    sliding_window: bool = False
    special_tokens_list: list = field(default_factory=list)
    stride: float = 0.8
    tie_value: int = 1
    tourney_mode: bool = False
@dataclass
class MultiLabelClassificationArgs(ClassificationArgs):
    """
    Model args for a MultiLabelClassificationModel.

    Identical to ClassificationArgs apart from its own ``model_class`` and
    the extra ``threshold`` field.
    """

    model_class: str = "MultiLabelClassificationModel"
    # sliding_window, stride, tie_value, labels_list, labels_map, lazy_loading
    # and special_tokens_list are inherited unchanged from ClassificationArgs
    # (same defaults, same field positions).
    threshold: float = 0.5
@dataclass
class NERArgs(ModelArgs):
    """
    Model args for a NERModel.

    Adds NER-specific options on top of the shared ModelArgs fields.
    """

    model_class: str = "NERModel"
    classification_report: bool = False
    labels_list: list = field(default_factory=list)
    lazy_loading: bool = False
    lazy_loading_start_line: int = 0
    onnx: bool = False
    special_tokens_list: list = field(default_factory=list)
@dataclass
class QuestionAnsweringArgs(ModelArgs):
    """
    Model args for a QuestionAnsweringModel.

    Adds question-answering options and changes the early-stopping defaults
    inherited from ModelArgs.
    """

    model_class: str = "QuestionAnsweringModel"
    doc_stride: int = 384
    # Overrides of ModelArgs defaults ("eval_loss" / minimize=True there).
    early_stopping_metric: str = "correct"
    early_stopping_metric_minimize: bool = False
    lazy_loading: bool = False
    max_answer_length: int = 100
    max_query_length: int = 64
    n_best_size: int = 20
    null_score_diff_threshold: float = 0.0
    special_tokens_list: list = field(default_factory=list)
@dataclass
class T5Args(ModelArgs):
    """
    Model args for a T5Model.

    Overrides several ModelArgs optimizer/scheduler defaults (Adafactor with
    a constant schedule) and adds generation options. A custom
    ``dataset_class`` is saved by name only and is not restored by ``load``.
    """

    model_class: str = "T5Model"
    add_prefix: bool = True
    as_reranker: bool = False
    batch_chunk_size: int = None
    dataset_class: Dataset = None
    do_sample: bool = False
    early_stopping: bool = True
    evaluate_generated_text: bool = False
    evaluate_before_training: bool = False
    length_penalty: float = 2.0
    max_length: int = 20
    max_steps: int = -1
    num_beams: int = 1
    num_return_sequences: int = 1
    preprocess_inputs: bool = True
    repetition_penalty: float = 1.0
    # Overrides of ModelArgs defaults.
    scheduler: str = "constant_schedule_with_warmup"
    adafactor_relative_step: bool = False
    adafactor_scale_parameter: bool = False
    adafactor_warmup_init: bool = False
    learning_rate: float = 1e-3
    optimizer: str = "Adafactor"
    special_tokens_list: list = field(default_factory=list)
    top_k: float = None
    top_p: float = None
    use_multiprocessed_decoding: bool = True

    def save(self, output_dir):
        """
        Write these args as JSON to ``<output_dir>/model_args.json``.

        Non-string ``tokenizer_type`` / ``dataset_class`` values (which JSON
        cannot encode) are replaced by a name: a class by its own
        ``__name__``, any other object by its class name.

        Args:
            output_dir: directory to write ``model_args.json`` into.
        """
        os.makedirs(output_dir, exist_ok=True)
        args_dict = self.get_args_for_saving()
        for key in ("tokenizer_type", "dataset_class"):
            value = args_dict.get(key)
            if value is not None and not isinstance(value, str):
                args_dict[key] = (
                    value.__name__ if isinstance(value, type) else type(value).__name__
                )
        # Serialize before opening so a failure cannot truncate an existing file.
        serialized = json.dumps(args_dict)
        with open(os.path.join(output_dir, "model_args.json"), "w") as f:
            f.write(serialized)

    def load(self, input_dir):
        """
        Update this instance from ``<input_dir>/model_args.json`` if it exists.

        A saved ``dataset_class`` is only a name, so it triggers a warning and
        is not applied; the caller must pass the class again.

        Args:
            input_dir: directory that may contain ``model_args.json``.
        """
        if not input_dir:
            return
        model_args_file = os.path.join(input_dir, "model_args.json")
        if not os.path.isfile(model_args_file):
            return
        with open(model_args_file, "r") as f:
            model_args = json.load(f)
        if model_args.get("dataset_class"):
            warnings.warn(
                "This model was trained using a custom dataset_class. "
                "This cannot be loaded automatically and must be specified in the model args "
                "when loading the model."
            )
            # Keep the current dataset_class instead of a non-callable name string.
            model_args.pop("dataset_class")
        self.update_from_dict(model_args)
@dataclass
class GenerationArgs:
    """
    Settings for language generation.

    ``get_dict`` exports them as a plain dict, leaving out every setting
    whose value is None.
    """

    max_length: int = 20
    max_new_tokens: int = None
    min_length: int = 0
    min_new_tokens: int = None
    early_stopping: bool = False
    max_time: float = None
    do_sample: bool = False
    num_beams: int = 1
    num_beam_groups: int = 1
    penalty_alpha: float = None
    use_cache: bool = True
    temperature: float = 1.0
    top_k: int = 50
    top_p: float = 1.0
    repetition_penalty: float = 1.0
    diversity_penalty: float = 0.0

    def get_dict(self):
        """Return the settings as a dict (via ``asdict``), skipping None values."""
        result = {}
        for name, value in asdict(self).items():
            if value is None:
                continue
            result[name] = value
        return result
@dataclass
class LanguageModelingArgs(ModelArgs):
    """
    Model args for a LanguageModelingModel.

    A custom ``dataset_class`` is saved by name only and is not restored by
    ``load``; it must be supplied again when reloading the model.
    """

    model_class: str = "LanguageModelingModel"
    block_size: int = -1
    chunk_text: bool = True
    config_name: str = None
    dataset_class: Dataset = None
    dataset_type: str = "None"
    data_format: str = "text"
    discriminator_config: dict = field(default_factory=dict)
    discriminator_loss_weight: float = 50.0
    generator_config: dict = field(default_factory=dict)
    max_steps: int = -1
    min_frequency: int = 2
    mlm: bool = True
    mlm_probability: float = 0.15
    sliding_window: bool = False
    special_tokens: list = field(default_factory=get_special_tokens)
    stride: float = 0.8
    tie_generator_and_discriminator_embeddings: bool = True
    tokenizer_name: str = None
    vocab_size: int = None
    clean_text: bool = True
    handle_chinese_chars: bool = True
    special_tokens_list: list = field(default_factory=list)
    strip_accents: bool = True
    local_rank: int = -1
    loftq_bits: int = 4
    loftq_config: dict = field(default_factory=dict)
    lora_config: dict = field(default_factory=dict)
    peft: bool = False
    qlora: bool = False
    rag: bool = False
    rag_replace_method: str = "prepend"
    nf4: bool = False
    use_autoencoder: bool = False
    stream_hf_datasets: bool = False

    def save(self, output_dir):
        """
        Write these args as JSON to ``<output_dir>/model_args.json``.

        Non-string ``tokenizer_type`` / ``dataset_class`` values (which JSON
        cannot encode) are replaced by a name: a class by its own
        ``__name__``, any other object by its class name.

        Args:
            output_dir: directory to write ``model_args.json`` into.
        """
        os.makedirs(output_dir, exist_ok=True)
        args_dict = self.get_args_for_saving()
        for key in ("tokenizer_type", "dataset_class"):
            value = args_dict.get(key)
            if value is not None and not isinstance(value, str):
                args_dict[key] = (
                    value.__name__ if isinstance(value, type) else type(value).__name__
                )
        # Dump the *modified* dict (the old code re-fetched a fresh one and
        # lost the replacement). Serialize first so a failure cannot truncate
        # an existing file.
        serialized = json.dumps(args_dict)
        with open(os.path.join(output_dir, "model_args.json"), "w") as f:
            f.write(serialized)

    def load(self, input_dir):
        """
        Update this instance from ``<input_dir>/model_args.json`` if it exists.

        A saved ``dataset_class`` is only a name, so it triggers a warning and
        is not applied; the caller must pass the class again.

        Args:
            input_dir: directory that may contain ``model_args.json``.
        """
        if not input_dir:
            return
        model_args_file = os.path.join(input_dir, "model_args.json")
        if not os.path.isfile(model_args_file):
            return
        with open(model_args_file, "r") as f:
            model_args = json.load(f)
        # .get(): the key is absent when dataset_class is in not_saved_args.
        if model_args.get("dataset_class"):
            warnings.warn(
                "This model was trained using a custom dataset_class. "
                "This cannot be loaded automatically and must be specified in the model args "
                "when loading the model."
            )
            # Keep the current dataset_class instead of a non-callable name string.
            model_args.pop("dataset_class")
        self.update_from_dict(model_args)
@dataclass
class Seq2SeqArgs(ModelArgs):
    """
    Model args for a Seq2SeqModel.

    A custom ``dataset_class`` is saved by name only and is not restored by
    ``load``; it must be supplied again when reloading the model.
    """

    model_class: str = "Seq2SeqModel"
    base_marian_model_name: str = None
    dataset_class: Dataset = None
    do_sample: bool = False
    early_stopping: bool = True
    evaluate_generated_text: bool = False
    faiss_d: int = 768
    faiss_m: int = 128
    faiss_index_type: str = "IndexFlatIP"
    include_title_in_corpus: bool = True
    length_penalty: float = 2.0
    max_length: int = 20
    max_steps: int = -1
    num_beams: int = 1
    num_return_sequences: int = 1
    rag_embed_batch_size: int = 16
    repetition_penalty: float = 1.0
    save_knowledge_dataset: bool = True
    save_knowledge_dataset_with_checkpoints: bool = False
    split_text_character: str = " "
    split_text_n: int = 100
    src_lang: str = "en_XX"
    tgt_lang: str = "ro_RO"
    top_k: float = None
    top_p: float = None
    use_multiprocessed_decoding: bool = False

    def save(self, output_dir):
        """
        Write these args as JSON to ``<output_dir>/model_args.json``.

        Non-string ``tokenizer_type`` / ``dataset_class`` values (which JSON
        cannot encode) are replaced by a name: a class by its own
        ``__name__``, any other object by its class name.

        Args:
            output_dir: directory to write ``model_args.json`` into.
        """
        os.makedirs(output_dir, exist_ok=True)
        args_dict = self.get_args_for_saving()
        for key in ("tokenizer_type", "dataset_class"):
            value = args_dict.get(key)
            if value is not None and not isinstance(value, str):
                args_dict[key] = (
                    value.__name__ if isinstance(value, type) else type(value).__name__
                )
        # Dump the *modified* dict (the old code re-fetched a fresh one and
        # lost the replacement). Serialize first so a failure cannot truncate
        # an existing file.
        serialized = json.dumps(args_dict)
        with open(os.path.join(output_dir, "model_args.json"), "w") as f:
            f.write(serialized)

    def load(self, input_dir):
        """
        Update this instance from ``<input_dir>/model_args.json`` if it exists.

        A saved ``dataset_class`` is only a name, so it triggers a warning and
        is not applied; the caller must pass the class again.

        Args:
            input_dir: directory that may contain ``model_args.json``.
        """
        if not input_dir:
            return
        model_args_file = os.path.join(input_dir, "model_args.json")
        if not os.path.isfile(model_args_file):
            return
        with open(model_args_file, "r") as f:
            model_args = json.load(f)
        # .get(): the key is absent when dataset_class is in not_saved_args.
        if model_args.get("dataset_class"):
            warnings.warn(
                "This model was trained using a custom dataset_class. "
                "This cannot be loaded automatically and must be specified in the model args "
                "when loading the model."
            )
            # Keep the current dataset_class instead of a non-callable name string.
            model_args.pop("dataset_class")
        self.update_from_dict(model_args)
@dataclass
class RetrievalArgs(Seq2SeqArgs):
    """
    Model args for a RetrievalModel.

    Builds on Seq2SeqArgs (and so inherits its save/load behaviour) and adds
    options for encoders, losses, hard negatives, clustering and evaluation.
    """

    model_class: str = "RetrievalModel"
    ance_refresh_n_epochs: int = 1
    ance_training: bool = False
    batch_chunk_size: int = None
    cluster_concatenated: bool = False
    cluster_every_n_epochs: int = 1
    cluster_queries: bool = False
    cluster_train_size: Union[int, float] = None
    context_config: dict = field(default_factory=dict)
    curriculum_clustering: bool = False
    data_format: str = "st"
    ddp_training: bool = False
    disable_datasets_caching: bool = False
    embed_batch_size: int = 128
    evaluate_with_beir: bool = False
    external_embeddings: bool = False
    extra_cls_token_count: int = 0
    extra_mask_token_count: int = 0
    faiss_clustering: bool = True
    # Same default as in Seq2SeqArgs.
    faiss_index_type: str = "IndexFlatIP"
    gradient_caching: bool = False
    gradient_caching_steps: int = 16
    hard_negatives: bool = False
    hard_negatives_in_eval: bool = False
    include_bce_loss: bool = False
    include_hard_negatives_for_triplets_only: bool = False
    include_margin_mse_loss: bool = False
    include_nll_loss: bool = True
    include_title: bool = True
    include_triplet_loss: bool = False
    reranking_kl_div_loss: bool = False
    include_kl_div_loss: bool = False
    kl_div_lambda: float = 1.0
    kmeans_k: int = -1
    larger_representations: bool = False
    margin_mse_lambda: float = 1
    mse_loss: bool = False
    moving_average_loss_count: int = 10
    multi_negatives: bool = False
    multi_head_vectors: bool = False
    multi_head_vector_strategy: str = "maxsim"
    minmax_multi_head_vectors: bool = False
    multi_vector_query: bool = False
    query_vector_count: int = 50
    nll_lambda: float = 1.0
    nll_lambda_start_decay: int = None
    nll_lambda_min: float = None
    n_hard_negatives: int = 1
    output_dropout: float = 0.1
    pytrec_eval_metrics: list = field(
        default_factory=lambda: ["recip_rank", "recall_100", "ndcg_cut_10", "ndcg"]
    )
    query_config: dict = field(default_factory=dict)
    remove_duplicates_from_eval_passages: bool = False
    relevance_level: int = 1
    repeat_high_loss_n: int = 0
    rerank_batch_size: int = 256
    retrieval_batch_size: int = 2048
    retrieve_n_docs: int = 10
    save_clustering_idx: bool = False
    save_passage_dataset: bool = True
    skip_hard_negatives_for_nll: bool = False
    tas_clustering: bool = False
    teacher_type: str = "colbert"
    tie_encoders: bool = False
    train_context_encoder: bool = True
    train_query_encoder: bool = True
    triplet_lambda: float = 1.0
    triplet_margin: float = 1.0
    unified_rr: bool = False
    unified_cross_rr: bool = False
    use_autoencoder: bool = False
    # Overrides the ModelArgs default of False.
    use_hf_datasets: bool = True
    use_pooler_output: bool = False
    reranking_config: dict = field(default_factory=dict)
    autoencoder_mse_loss: bool = True
    autoencoder_kl_div_loss: bool = False
    mean_pooling: bool = False
    include_quartet_loss: bool = False
    include_multi_negatives_loss: bool = False
    quartet_training_format: bool = False
    quartet_lambda: float = 1.0
    similarity_function: str = "dot_product"
@dataclass
class LanguageGenerationArgs(ModelArgs):
    """
    Model args for a LanguageGenerationModel.

    Generation is sampled by default (``do_sample=True`` with top-k/top-p).
    """

    model_class: str = "LanguageGenerationModel"
    do_sample: bool = True
    early_stopping: bool = True
    evaluate_generated_text: bool = False
    length_penalty: float = 2.0
    max_length: int = 20
    max_steps: int = -1
    num_beams: int = 1
    num_return_sequences: int = 1
    repetition_penalty: float = 1.0
    top_k: float = 50
    top_p: float = 0.95
    prompt: str = ""
    stop_token: str = None
    temperature: float = 1.0
    padding_text: str = ""
    xlm_language: str = ""
    config_name: str = None
    # Same default as ModelArgs.tokenizer_name.
    tokenizer_name: str = None
    special_tokens_list: list = field(default_factory=list)
@dataclass
class ConvAIArgs(ModelArgs):
    """
    Model args for a ConvAIModel.

    Adds loss-weighting, dialogue-history and sampling options to ModelArgs.
    """

    model_class: str = "ConvAIModel"
    do_sample: bool = True
    lm_coef: float = 2.0
    max_history: int = 2
    max_length: int = 20
    mc_coef: float = 1.0
    min_length: int = 1
    num_candidates: int = 2
    personality_permutations: int = 1
    temperature: float = 0.7
    top_k: float = 0
    top_p: float = 0.9
@dataclass
class MultiModalClassificationArgs(ModelArgs):
    """
    Model args for a MultiModalClassificationModel.

    Adds the column/label names and file-extension settings used for
    text + image inputs.
    """

    model_class: str = "MultiModalClassificationModel"
    regression: bool = False
    num_image_embeds: int = 1
    text_label: str = "text"
    labels_label: str = "labels"
    images_label: str = "images"
    image_type_extension: str = ""
    data_type_extension: str = ""
    special_tokens_list: list = field(default_factory=list)