-
Notifications
You must be signed in to change notification settings - Fork 5
/
tp_zero1_llama2_7b_hf_pretrain.py
690 lines (623 loc) · 25 KB
/
tp_zero1_llama2_7b_hf_pretrain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
# coding=utf-8
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Modifications Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import math
import json
import os
import sys
import time
from collections import namedtuple
from datetime import datetime, timezone
from typing import Any, Dict, List
import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
from logger import Logger
from modeling_llama_nxd import CoreAttention, LlamaForCausalLM
from training_utils import Throughput, create_llama_pretraining_dataset
from transformers import AdamW, LlamaConfig, set_seed
from transformers.optimization import get_linear_schedule_with_warmup
import neuronx_distributed as nxd
from neuronx_distributed.parallel_layers import (
checkpointing,
grads,
layers,
parallel_state,
)
from neuronx_distributed.parallel_layers.utils import requires_init_pg_override
from neuronx_distributed.utils.adamw_fp32_optim_params import AdamW_FP32OptimParams
# For PT autocast.
torch.cuda.is_bf16_supported = lambda: True
# Workaround for NaNs seen with transformers version >= 4.21.0
# https://github.com/aws-neuron/aws-neuron-sdk/issues/593
import transformers.modeling_utils as modeling_utils
if os.environ.get("XLA_USE_BF16") or os.environ.get("XLA_DOWNCAST_BF16"):
modeling_utils.get_parameter_dtype = lambda x: torch.bfloat16
datetime_str = str(datetime.now())
results = {"inference_success": 1}
Metric = namedtuple("Metric", ["name", "value", "units", "additional_data"])
class TrainingMetrics:
"""
This class is used for logging metrics to a json file. One can provide a
dictionary of metrics that needs to be stored, and it wpuld get
written to the file.
Arguments:
json_file: File used for logging. If no file exists, new file would be created.
"""
def __init__(self, json_file):
self.json_file = json_file
def read_modify_write_file(self, data, key: str = "metrics") -> None:
"""
data (dict of training parameters or list of metrics): Data to update in the file.
key (str): the dictionary key under which data is to be recorded
"""
result_dict = {}
print(f"Writing data to the provided results file: {self.json_file}")
if os.path.exists(self.json_file):
with open(self.json_file) as json_file:
result_dict = json.loads(json_file.read()) or result_dict
print(f"Updating with {key} data: {data}")
if result_dict:
try:
# handle internal named entity if present
results = result_dict[next(iter(result_dict))]
except Exception:
results = result_dict
current = results.get(key)
if not current:
results[key] = data
else:
if isinstance(current, list):
current.extend(data)
elif isinstance(current, dict):
current.update(data)
else:
result_dict["results"] = {key: data}
with open(self.json_file, "w") as json_file:
json.dump(result_dict, json_file)
def store_metrics(self, metrics: List[Metric]) -> None:
"""
Writes collected metrics to the file.
"""
data = [
{
"MetricName": metric.name,
"MeasuredValue": metric.value,
"Units": metric.units,
"Timestamp": datetime.now(timezone.utc).isoformat(),
"AdditionalData": metric.additional_data,
}
for metric in metrics
]
self.update(data=data, key="metrics")
def store_parameters(self, parameters: Dict[str, Any]) -> None:
"""
Writes specified model and configuration parameters to the file.
"""
self.update(data=parameters, key="parameters")
def update(self, **kwargs: Any) -> None:
"""
Write specified data to the output file.
"""
self.read_modify_write_file(**kwargs)
def get_model(flags):
model_path, seq_len = flags.model_path, flags.seq_len
config = LlamaConfig.from_pretrained(model_path)
config.use_cache = False
config.kv_shared_group_size = flags.kv_replicator
config.qkv_linear = flags.qkv_linear
config.max_position_embeddings = max(config.max_position_embeddings, seq_len)
if flags.num_layers > 0:
config.num_hidden_layers = flags.num_layers
if flags.sequence_parallel_enabled:
config.sequence_parallel_enabled = True
if flags.selective_checkpoint_enabled:
config.selective_checkpoint_enabled = True
xm.master_print(config)
model = LlamaForCausalLM(config)
def get_sin_cos_matrix(config):
head_dim = config.hidden_size // config.num_attention_heads
base = 10000
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
t = torch.arange(config.max_position_embeddings, dtype=inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
return emb.cos()[None, None, :, :].to(torch.float32), emb.sin()[None, None, :, :].to(torch.float32)
# Here we make sure we use the same sine and cosine matrices for all layers.
# Making use of same tensors would make the CSE algorithm eliminate the lookup call
# from layers, keeping only lookup from first layer.
with torch.no_grad():
cos, sin = get_sin_cos_matrix(config)
for layer in model.model.layers:
layer.self_attn.rotary_emb.cos_cached = cos
layer.self_attn.rotary_emb.sin_cached = sin
xm.master_print(model)
return model
def get_dtype(model) -> str:
"""
Reference: https://pytorch.org/xla/release/1.12/index.html#xla-tensors-and-bfloat16
"""
if "XLA_USE_BF16" in os.environ:
return "torch.bfloat16"
if "XLA_DOWNCAST_BF16" in os.environ:
if "torch.float" in str(model.dtype):
return "torch.bfloat16"
if "torch.double" in str(model.dtype):
return "torch.float32"
return str(model.dtype)
def allreduce_sequence_parallel_gradients(optimizer):
"""All-reduce layernorm parameters across model parallel nodes when sequence parallelism is used.
Modified from megatron-lm:
https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/blob/3f91f09bb2ab32f9904b47f46f19d2fc3f518ed8/megatron/training.py#L425
"""
from neuronx_distributed.parallel_layers.mappings import (
reduce_from_tensor_model_parallel_region,
)
grads = []
for param_group in optimizer.__getstate__()["param_groups"]:
for group, params in param_group.items():
if group == "params":
for p in params:
if isinstance(p, torch.Tensor) and p.grad is not None:
sequence_parallel_param = getattr(p, "sequence_parallel_enabled", False)
if sequence_parallel_param:
grads.append(p.grad.data)
for grad in grads:
# sum v.s. average: sum
reduce_from_tensor_model_parallel_region(grad)
def train_llama(flags):
set_seed(flags.seed)
if flags.use_meta_device_init:
model_init_config = {
"meta_device_init": True,
"param_init_fn": init_weights,
}
else:
model_init_config = None
# Setting up NxD config
nxd_config = nxd.neuronx_distributed_config(
tensor_parallel_size=flags.tensor_parallel_size,
optimizer_config={"zero_one_enabled": flags.use_zero_1, "grad_clipping": True, "max_grad_norm": 1.0},
sequence_parallel=flags.sequence_parallel_enabled,
activation_checkpoint_config=CoreAttention if flags.selective_checkpoint_enabled else "full",
model_init_config=model_init_config,
)
# Creating NxD model
model = nxd.initialize_parallel_model(nxd_config, get_model, flags)
world_size = parallel_state.get_data_parallel_size()
is_root = xm.is_master_ordinal(local=False)
extract_graphs_only = os.environ.get("NEURON_EXTRACT_GRAPHS_ONLY", None)
device = xm.xla_device()
model_dtype = get_dtype(model)
running_loss = torch.zeros(1, dtype=torch.double).to(device)
param_optimizer = list(model.named_parameters())
no_decay = ["bias", "LayerNorm"] # gamma/beta are in LayerNorm.weight
optimizer_grouped_parameters = [
{
"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
"weight_decay": 0.01,
},
{
"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
if flags.use_mix_precision:
optimizer_cls = AdamW_FP32OptimParams
else:
optimizer_cls = AdamW
# Creating NxD Optimizer
optimizer = nxd.initialize_parallel_optimizer(nxd_config, optimizer_cls, optimizer_grouped_parameters, lr=flags.lr)
optimizer.zero_grad()
if is_root:
if not os.path.exists(flags.output_dir):
os.makedirs(flags.output_dir, exist_ok=True)
if not extract_graphs_only:
logger = Logger(flags, world_size, model_dtype)
metric_writer = TrainingMetrics(flags.metrics_file)
throughput = Throughput(
flags.batch_size, world_size, flags.grad_accum_usteps, logging_interval=flags.logging_interval
)
print("--------TRAINING CONFIG----------")
print(flags)
print("--------MODEL CONFIG----------")
print(model.config)
print("---------------------------------")
metric_writer.store_parameters(
{
"Model": model.config.model_type,
"Model configuration": str(model.config),
"World size": xm.xrt_world_size(),
"Data parallel degree": world_size,
"Batch size": flags.batch_size,
"Total steps": flags.steps_this_run,
"Seed": flags.seed,
"Optimizer": str(optimizer),
"Data type": model_dtype,
"Gradient accumulation microsteps": flags.grad_accum_usteps,
"Warmup steps": flags.warmup_steps,
"Dataset": os.path.basename(os.path.normpath(flags.data_dir)),
"Environment variables": {
variable: value
for variable, value in os.environ.items()
if variable.startswith("NEURON") or variable.startswith("XLA")
},
}
)
def train_loop_fn(model, optimizer, train_loader, epoch, global_step, training_ustep, running_loss, use_zero_1):
for _, data in enumerate(train_loader):
training_ustep += 1
input_ids = data["input_ids"]
attention_mask = data["attention_mask"]
labels = data["labels"]
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
)
loss = outputs.loss / flags.grad_accum_usteps
loss.backward()
running_loss += loss.detach()
if training_ustep % flags.grad_accum_usteps == 0:
xm.mark_step()
# loss averaging
running_loss_div = running_loss / world_size
# Collecting loss across all data-parallel ranks
running_loss_reduced = xm.all_reduce(
xm.REDUCE_SUM,
running_loss_div,
groups=parallel_state.get_data_parallel_group(as_list=True),
)
running_loss_reduced_detached = running_loss_reduced.detach()
running_loss.zero_()
optimizer.step()
total_norm = optimizer.grad_norm # Global norm before clipping
optimizer.zero_grad()
scheduler.step()
global_step += 1
def _print_logs(running_loss_reduced_detached, total_norm):
if is_root and not extract_graphs_only:
total_norm_cpu = None
if flags.print_grad_norm:
total_norm_cpu = total_norm.cpu().item()
# NOTE: The running_loss is the loss of the global_step
logger.log(
epoch,
global_step,
running_loss_reduced_detached.cpu().item(),
optimizer.param_groups[0]["lr"],
throughput.get_throughput(),
total_norm_cpu,
)
if global_step % flags.logging_interval == 0:
# Printing the loss inside the step closure. This won't block
# the tracing for next step. Also, we are logging every N steps.
# This is done to reduce the overhead of copying tensors to CPU.
# Tensor copy is expensive since it prevents the next step to start.
xm.add_step_closure(_print_logs, (running_loss_reduced_detached, total_norm.detach()))
# Save checkpoint using checkpoint API
if (flags.checkpoint_freq > 0) and (global_step % flags.checkpoint_freq == 0):
xm.add_step_closure(
nxd.save_checkpoint,
(
flags.checkpoint_dir, # checkpoint directory
f"step_{global_step}", # tag
model, # model
optimizer, # optimizer
scheduler, # scheduler
{"epoch": epoch, "global_step": global_step, "cli_args": flags.__dict__}, # user content
8, # num_workers
True, # use_xser
flags.num_kept_checkpoint, # num_kept_ckpts
),
)
if global_step >= flags.steps_this_run:
# NOTE: Prevent runtime "Call to recv failed : Broken pipe" issue
xm.mark_step()
break
return (
global_step,
training_ustep,
running_loss,
running_loss_reduced_detached.cpu().item(),
)
train_start = time.time()
training_ustep = 0
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=flags.warmup_steps,
num_training_steps=flags.max_steps,
last_epoch=-1,
)
def get_loading_tag(flags):
if flags.loading_step == "-1":
return "-1"
if flags.loading_step == "latest_if_exists":
return None if nxd.has_checkpoint(flags.checkpoint_dir) else "-1"
return f"step_{flags.loading_step}"
tag = get_loading_tag(flags)
if tag != "-1":
user_content = nxd.load_checkpoint(
flags.checkpoint_dir,
tag=tag,
model=model,
optimizer=optimizer,
scheduler=scheduler,
)
if user_content is not None:
epoch = user_content["epoch"]
global_step = user_content["global_step"]
elif flags.pretrained_weight: # 1 (exist) or 0
user_content = nxd.load_checkpoint(
flags.checkpoint_dir,
tag="pretrained_weight",
model=model,
optimizer=None,
scheduler=None,
)
global_step = 0
epoch = 0
if flags.use_zero_1:
# We need to do init_zero1 here since after loading model weights, we
# need to sync the new params with base optimzier params.
optimizer.optimizer.init_zero()
else:
global_step = 0
epoch = 0
assert os.path.exists(os.path.expanduser(flags.data_dir)), "ERROR: Data directory {} doesn't exist!".format(
flags.data_dir
)
mini_batch_size = flags.batch_size
train_dataloader, _ = create_llama_pretraining_dataset(
flags.data_dir,
mini_batch_size,
parallel_state.get_data_parallel_size(),
parallel_state.get_data_parallel_rank(),
flags.seed,
)
# We wrap the dataloader with MpDeviceLoader. This dataloader should take
# care of copying the tensors to device and also inserting the mark_step at
# iteration end.
train_device_loader = pl.MpDeviceLoader(train_dataloader, device)
while True:
xm.master_print(
"Epoch {} begin {}".format(epoch, time.asctime()),
flush=True,
)
global_step, training_ustep, running_loss, final_loss = train_loop_fn(
model,
optimizer,
train_device_loader,
epoch,
global_step,
training_ustep,
running_loss,
flags.use_zero_1,
)
if is_root and not extract_graphs_only:
final_time = time.time()
time_diff = final_time - train_start
print(
"Epoch {} step {} end {} loss {} perf {} seq/sec (at train microstep {} time {} from beginning time {})".format(
epoch,
global_step,
time.asctime(),
final_loss,
logger.throughputs[-1],
training_ustep,
final_time,
train_start,
),
flush=True,
)
additional_data = {
"Epoch": epoch,
"Global step": global_step,
"Microstep": training_ustep,
}
metric_data = [
Metric("Loss", final_loss, "", additional_data),
Metric("Throughput", logger.throughputs[-1], "seq/s", additional_data),
]
metric_writer.store_metrics(metric_data)
if global_step >= flags.steps_this_run:
if is_root and not extract_graphs_only:
# record aggregate & final statistics in the metrics file
additional_data = {
"Epoch": epoch,
"Global step": global_step,
"Microstep": training_ustep,
}
min_throughput_index = math.ceil(10 / args.logging_interval)
if len(logger.throughputs) > min_throughput_index:
throughputs_to_average = logger.throughputs[min_throughput_index:]
else:
throughputs_to_average = logger.throughputs
average_throughput = round(
sum(throughputs_to_average) / len(throughputs_to_average), 4
) if len(logger.throughputs) > 0 else None
metric_data = [
Metric("Final loss", final_loss, "", additional_data),
Metric(
"Time to train",
round(time_diff / 60, 4),
"minutes",
additional_data,
),
Metric(
"Average throughput",
average_throughput,
"seq/s",
additional_data,
),
Metric(
"Peak throughput",
max(logger.throughputs),
"seq/s",
additional_data,
),
]
metric_writer.store_metrics(metric_data)
return
epoch += 1
def _mp_fn(index, flags):
torch.set_default_tensor_type("torch.FloatTensor")
train_llama(flags)
xm.rendezvous("_mp_fn finished")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_path",
type=str,
help="Model weight and config path.",
)
parser.add_argument(
"--data_dir",
type=str,
help="Pre-tokenized dataset directory.",
)
parser.add_argument(
"--output_dir",
type=str,
default="./output",
help="Directory for checkpoints and logs.",
)
parser.add_argument(
"--metrics_file",
type=str,
default="results.json",
help="training metrics results file",
)
parser.add_argument("--batch_size", type=int, default=8, help="Worker batch size.")
parser.add_argument(
"--max_steps",
type=int,
help="Maximum total accumulation-steps to run.",
)
parser.add_argument(
"--steps_this_run",
type=int,
help="Exit early at <value> steps and not go to max_steps. -1 to mean no early exit.",
)
parser.add_argument(
"--seed",
type=int,
default=12349,
help="Random seed. Worker seed is this value + worker rank.",
)
parser.add_argument("--lr", type=float, default=4e-4, help="Learning rate.")
parser.add_argument(
"--warmup_steps",
type=int,
default=2000,
help="Number of warmup accumulation-steps for learning rate .",
)
parser.add_argument(
"--grad_accum_usteps",
type=int,
default=64,
help="Gradient accumulation micro-steps (an accumulation-step has <value> micro-steps.",
)
parser.add_argument(
"--print_grad_norm",
default=False,
action="store_true",
help="Whether to print grad norm",
)
parser.add_argument("--tensor_parallel_size", default=2, type=int, help="Tensor parallel size")
parser.add_argument("--seq_len", default=2048, type=int, help="Sequence length")
parser.add_argument("--use_mix_precision", action="store_true", help="Use mix precision.")
parser.add_argument("--use_zero_1", action="store_true", help="Use ZeRO-1.")
parser.add_argument(
"--num_layers",
type=int,
default=-1,
help="Override number of layers for this LLaMA model",
)
parser.add_argument(
"--sequence_parallel_enabled",
default=False,
action="store_true",
help="Enable sequence parallel",
)
parser.add_argument(
"--selective_checkpoint_enabled",
default=False,
action="store_true",
help="Enable selective checkpoint",
)
parser.add_argument(
"--use_meta_device_init",
default=False,
action="store_true",
help="use meta device initialization",
)
parser.add_argument(
"--logging_interval",
default=1,
type=int,
help="logging every N steps",
)
parser.add_argument(
"--qkv_linear",
default=False,
action="store_true",
help="Whether to use the QKV Module",
)
parser.add_argument(
"--kv_replicator",
default=1,
type=int,
help="KV replication number",
)
# Checkpointing
parser.add_argument("--checkpoint_freq", type=int, default=100000, help="save checkpoint freq")
parser.add_argument("--checkpoint_dir", type=str, default=None)
parser.add_argument("--pretrained_weight", default=False, action="store_true")
parser.add_argument(
"--loading_step",
type=str,
default="-1",
help='resume from checkpoint generated by the step, "-1" means no load, "latest_if_exists" is a valid option to resume from latest',
)
parser.add_argument(
"--num_kept_checkpoint",
type=int,
default=-1,
help="number of checkpoints kept, old checkpoint will get deleted",
)
args = parser.parse_args(sys.argv[1:])
if args.steps_this_run < 0:
args.steps_this_run = args.max_steps
os.environ["NEURON_RT_STOCHASTIC_ROUNDING_EN"] = "1"
if args.use_mix_precision:
os.environ["XLA_DOWNCAST_BF16"] = "1"
else:
os.environ["XLA_USE_BF16"] = "1"
# WORLD_SIZE is set by torchrun
if os.environ.get("WORLD_SIZE"):
if requires_init_pg_override():
pass
import torch_xla.experimental.pjrt_backend
dist.init_process_group("xla", init_method="pjrt://")
else:
dist.init_process_group("xla")
_mp_fn(0, args)
else:
xmp.spawn(_mp_fn, args=(args,))