# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0
"""Train HuggingFace LayoutLM on Amazon Textract results
(This script also allows for training non-layout-aware models e.g. T5 on seq2seq conditional text
generation task)
"""
# Python Built-Ins:
from inspect import signature
import os
import shutil
from typing import Optional, Tuple

# External Dependencies:
from torch import distributed as dist
from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoModelForMaskedLM,
    AutoModelForTokenClassification,
    AutoProcessor,
    AutoTokenizer,
    EarlyStoppingCallback,
    LayoutLMv2Config,
    LayoutXLMProcessor,
    LayoutXLMTokenizerFast,
    PretrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    PreTrainedTokenizerFast,
    ProcessorMixin,
    set_seed,
    Trainer,
)
from transformers.file_utils import EntryNotFoundError
from transformers.trainer_utils import get_last_checkpoint

# Local Dependencies:
from . import config
from . import data
from . import logging_utils
from .smddpfix import Trainer  # noqa: F811 - patched Trainer intentionally shadows the HF import
from .models.layoutlmv2 import LayoutLMv2ForPretraining

logger = logging_utils.getLogger("main")

def get_model(
    model_args: config.ModelArguments, data_args: config.DataTrainingArguments
) -> Tuple[PretrainedConfig, PreTrainedModel, PreTrainedTokenizerFast, Optional[ProcessorMixin]]:
    """Load pre-trained Config, Model, Tokenizer, and Processor if one exists"""
    # (Note this assignment shadows the imported `config` module within this function)
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=data_args.num_labels,
        label2id={str(i): i for i in range(data_args.num_labels)},
        id2label={i: str(i) for i in range(data_args.num_labels)},
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
        # For LayoutLMv1 we used to explicitly set:
        #     max_position_embeddings=data_args.max_seq_length,
        #     max_2d_position_embeddings=2 * data_args.max_seq_length,
        # ...But the LayoutLMv2 tokenizer has max_position_embeddings=514 (+2) and 2d=1024, so
        # rather than handle the inconsistency, we ignore these settings: basic usage won't need
        # to override them.
    )
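    # NOTE: Labels are configured by numeric index only ("0".."num_labels-1"); mapping real
    # entity class names to these IDs is assumed to be handled by the `data` module, so the
    # saved model config will carry generic label names.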
    tokenizer_name_or_path = (
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path
    )
    try:
        # AutoProcessor doesn't detect XLM and instantiates a LayoutLMv2Processor instead!
        ProcessorClass = (
            LayoutXLMProcessor if "xlm" in tokenizer_name_or_path.lower() else AutoProcessor
        )
        processor = ProcessorClass.from_pretrained(
            model_args.model_name_or_path,
            # Feature Extractor overrides:
            apply_ocr=False,  # We use Amazon Textract
            do_resize=False,  # External thumbnailing service handles this
            # Tokenizer overrides:
            only_label_first_subword=False,  # We aggregate over all word token labels
            use_fast=True,
            # Download settings:
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            use_auth_token=True if model_args.use_auth_token else None,
        )
        if hasattr(processor, "tokenizer"):
            tokenizer = processor.tokenizer
        elif isinstance(processor, PreTrainedTokenizerBase):
            # AutoProcessor loaded something, but it's just a standard tokenizer.
            # This happens e.g. with t5-base model as at HF transformers==4.25.1
            tokenizer = processor
            processor = None
        else:
            tokenizer = None
    except (EntryNotFoundError, OSError):
        processor = None
        tokenizer = None
    except ValueError as ve:
        if "unrecognized processor" in str(ve).lower():
            processor = None
            tokenizer = None
        else:
            raise ve
    if not processor:
        logger.info(
            "This model type does not have a Processor: %s",
            model_args.model_name_or_path,
        )
    if not tokenizer:
        if "xlm" in tokenizer_name_or_path.lower():
            # AutoTokenizer doesn't detect XLM and instantiates a LayoutLMv2Tokenizer instead!
            tokenizer = LayoutXLMTokenizerFast.from_pretrained(
                tokenizer_name_or_path,
                only_label_first_subword=False,
                cache_dir=model_args.cache_dir,
                revision=model_args.model_revision,
                use_auth_token=True if model_args.use_auth_token else None,
            )
        elif config.model_type in {"gpt2", "roberta"}:
            # GPT-2/RoBERTa-style BPE tokenizers need add_prefix_space=True to handle
            # pre-split word inputs:
            tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_name_or_path,
                only_label_first_subword=False,
                cache_dir=model_args.cache_dir,
                use_fast=True,
                revision=model_args.model_revision,
                use_auth_token=True if model_args.use_auth_token else None,
                add_prefix_space=True,
            )
        else:
            tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_name_or_path,
                only_label_first_subword=False,
                cache_dir=model_args.cache_dir,
                use_fast=True,
                revision=model_args.model_revision,
                use_auth_token=True if model_args.use_auth_token else None,
            )
    if data_args.task_name == "ner":
        model = AutoModelForTokenClassification.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    elif data_args.task_name == "mlm":
        if isinstance(config, LayoutLMv2Config):
            logger.info(
                "As of v4.18, HF transformers does not bundle a variant of LayoutLMv2/XLM for "
                "pre-training. Using a custom implementation which may not exactly align to the "
                "published research."
            )
            model = LayoutLMv2ForPretraining.from_pretrained(
                model_args.model_name_or_path,
                from_tf=bool(".ckpt" in model_args.model_name_or_path),
                config=config,
                cache_dir=model_args.cache_dir,
                revision=model_args.model_revision,
                use_auth_token=True if model_args.use_auth_token else None,
            )
        else:
            model = AutoModelForMaskedLM.from_pretrained(
                model_args.model_name_or_path,
                from_tf=bool(".ckpt" in model_args.model_name_or_path),
                config=config,
                cache_dir=model_args.cache_dir,
                revision=model_args.model_revision,
                use_auth_token=True if model_args.use_auth_token else None,
            )
    elif data_args.task_name == "seq2seq":
        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        raise ValueError(
            f"Unknown data_args.task_name '{data_args.task_name}' not in "
            "('mlm', 'ner', 'seq2seq')"
        )
    return config, model, tokenizer, processor
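
# How `get_model` is typically consumed (see `train()` below). For example, with
# model_name_or_path="microsoft/layoutxlm-base" the returned `processor` is a LayoutXLMProcessor
# wrapping the LayoutXLMTokenizerFast, while plain-text models like t5-base come back with
# processor=None:
#
#   _, model, tokenizer, processor = get_model(model_args, data_args)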

def train(
    model_args: config.ModelArguments,
    data_args: config.DataTrainingArguments,
    training_args: config.SageMakerTrainingArguments,
) -> Trainer:
    training_args._setup_devices  # Force distributed setup if applicable and not already done
    logger.info("Started with local_rank %s", training_args.local_rank)

    # Don't strictly need this around the model setup too, but it keeps logs more understandable.
    # Using the HF decorator rather than torch.distributed.barrier() to try and keep a bit more
    # environment-agnostic:
    with training_args.main_process_first(desc="Waiting for main process to load model and data"):
        logger.info("Creating config and model")
        _, model, tokenizer, processor = get_model(model_args, data_args)
        if hasattr(model, "layoutlmv2") and training_args.n_gpu > 1:
            if dist.is_initialized():
                logger.info("Synchronizing LayoutLMv2 visual batch norm for distributed training")
                model.layoutlmv2.visual.synchronize_batch_norm()
            else:
                raise ValueError(
                    "For multi-GPU training, LayoutLMv2/XLM must be run in Distributed Data "
                    "Parallel mode (PyTorch native or SageMaker Distributed). Consider using SM "
                    "DDP on a supported instance type (e.g. ml.p3.16xlarge), or launching "
                    "natively via PyTorch DDP with the ddp_launcher.py entrypoint."
                )
            # For more information, see:
            # https://github.com/NielsRogge/Transformers-Tutorials/issues/30
            # https://github.com/huggingface/transformers/issues/14110
            # https://sagemaker.readthedocs.io/en/stable/api/training/smd_data_parallel_use_sm_pysdk.html
            # (For SM Distributed, ddp_launcher.py is not necessary - point straight to train.py)
        # Tokenizer check: Our MLM/NER data prep requires a fast tokenizer.
        if data_args.task_name in ("mlm", "ner") and not isinstance(
            tokenizer, PreTrainedTokenizerFast
        ):
            raise ValueError(
                "This example script only works for models that have a fast tokenizer. See the "
                "list at https://huggingface.co/transformers/index.html#supported-frameworks "
                "for details."
            )
        # Detecting last checkpoint:
        last_checkpoint = None
        if os.path.isdir(training_args.output_dir) and training_args.do_train:
            last_checkpoint = get_last_checkpoint(training_args.output_dir)
            if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
                logger.warning("No previous checkpoint found: training from scratch")
            elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
                logger.info(
                    f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this "
                    "behavior, create the training job with an empty `checkpoint_s3_uri` or none."
                )

        # Was going to close the old main_process_first context and start a separate one here to
        # maximize the time available for dataset prep until the default DDP 30min timeout kicks
        # in... But doing so seems to cause datasets to deadlock/freeze after splitting pages for
        # the training dataset. I think the same happens if using torch.distributed.barrier too?
logger.info("Loading datasets")
datasets = data.get_datasets(
data_args,
tokenizer,
processor,
model_param_names=set(signature(model).parameters),
n_workers=training_args.dataproc_num_workers,
cache_dir=model_args.cache_dir,
)
if datasets.train_dataset:
logger.info(f"train dataset has {len(datasets.train_dataset)} samples")
else:
logger.info("No training dataset provided")
if datasets.eval_dataset:
logger.info(f"validation dataset has {len(datasets.eval_dataset)} samples")
else:
logger.info("No validation dataset provided")
logger.info("Setting up trainer")
trainer = Trainer(
model=model,
args=training_args,
train_dataset=datasets.train_dataset,
eval_dataset=datasets.eval_dataset,
# No `tokenizer`, as either the dataset or the data_collator does it for us
data_collator=datasets.data_collator,
callbacks=[
EarlyStoppingCallback(
early_stopping_patience=training_args.early_stopping_patience,
early_stopping_threshold=training_args.early_stopping_threshold,
)
]
if (
training_args.early_stopping_patience is not None
or training_args.early_stopping_threshold is not None
)
else [],
compute_metrics=datasets.metric_computer,
)
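    # NOTE: early_stopping_patience / early_stopping_threshold are custom fields assumed to be
    # defined on config.SageMakerTrainingArguments. HF's EarlyStoppingCallback also expects
    # `metric_for_best_model` and `load_best_model_at_end` to be set for early stopping to work.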
    if not training_args.do_train:
        logger.warning(f"Training skipped (args.do_train={training_args.do_train})")
    else:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        metrics = train_result.metrics
        trainer.save_model()
        max_train_samples = (
            data_args.max_train_samples
            if data_args.max_train_samples is not None
            else len(datasets.train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(datasets.train_dataset))
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    logger.info(f"Saving model to {training_args.model_dir}")
    trainer.save_model(training_args.model_dir)
    if processor:
        # (processor saves tokenizer anyway)
        processor.save_pretrained(training_args.model_dir)
    else:
        tokenizer.save_pretrained(training_args.model_dir)
    # To enable directly deploying this model via SageMaker SDK's Estimator.deploy() (rather than
    # needing to create a PyTorchModel with entry_point / source_dir args), we need to save any
    # inference handler function code to model_dir/code. Here we compromise efficiency to the
    # benefit of usage simplicity, by just copying the contents of this training code folder to
    # the model/code folder for inference:
    code_path = os.path.join(training_args.model_dir, "code")
    if not os.path.abspath(".").startswith("/opt/ml/"):
        logger.warning(
            "Skipping output code copy step: Seems not to be running inside SageMaker job"
        )
        # If you try to recursively copy '.' in, for example, a SMStudio environment where '.' is
        # the notebooks/ folder (not notebooks/src) and notebooks/data is populated - you could
        # be waiting a very... long... time... Just create an empty folder to demonstrate:
        os.makedirs(code_path, exist_ok=True)
    else:
        logger.info(f"Copying code to {code_path} for inference")
        for currpath, _, files in os.walk("."):
            for file in files:
                # Skip any filenames starting with dot:
                if file.startswith("."):
                    continue
                filepath = os.path.join(currpath, file)
                # Skip any pycache or dot folders:
                if ((os.path.sep + ".") in filepath) or ("__pycache__" in filepath):
                    continue
                relpath = filepath[len("."):]
                if relpath.startswith(os.path.sep):
                    relpath = relpath[1:]
                outpath = os.path.join(code_path, relpath)
                logger.info(f"Copying {filepath} to {outpath}")
                os.makedirs(outpath.rpartition(os.path.sep)[0], exist_ok=True)
                shutil.copy2(filepath, outpath)
    return trainer

def main() -> None:
    """CLI script entry point to parse arguments and run training"""
    model_args, data_args, training_args = config.parse_args()

    # Logging setup:
    log_level = training_args.get_process_log_level()
    logging_utils.setLevel(log_level)

    logger.info("Loaded arguments:\n%s\n%s\n%s", model_args, data_args, training_args)
    logger.info("Starting!")
    if training_args.seed:
        set_seed(training_args.seed)
    else:
        logger.info("Random seed not set - results will be non-deterministic")

    # Start training:
    train(model_args, data_args, training_args)
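
# Script entrypoint guard (assumed, so that `main()` runs when this file is executed directly,
# e.g. as a SageMaker training job entrypoint):
if __name__ == "__main__":
    main()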