Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add llama2_70B-Megatron pretraining #389

Merged
merged 9 commits into from
Jan 15, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions training/benchmarks/llama2_70B/megatron/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
## 模型信息
- Introduction

Llama 2, a collection of pretrained and fine-tuned large language models (LLMs) ranging in scale from 7 billion to 70 billion parameters. Meta's fine-tuned LLMs, called Llama 2-Chat, are optimized for dialogue use cases. Llama2 outperform open-source chat models on most benchmarks meta's researchers tested, and based on their human evaluations for helpfulness and safety, may be a suitable substitute for closedsource models. Meta provide a detailed description of their approach to fine-tuning and safety improvements of Llama 2-Chat in order to enable the community to build on their work and contribute to the responsible development of LLMs.

- Paper
[LLAMA2](https://arxiv.org/pdf/2307.09288.pdf)

- 模型代码来源

This case includes code from the LLAMA 2 COMMUNITY LICENSE AGREEMENT License open source project at:https://github.com/facebookresearch/llama-recipes/tree/main


## 数据准备

### 模型配置及tokenizer准备

本测试样例为预训练case,需要下载tokenizer,下载链接为 https://github.com/FlagOpen/FlagScale/tree/main/examples/llama2/tokenizer

在data_dir下创建tokenizer目录,将上述链接中的tokenizer.model文件下载到此目录中


### 数据集准备

本测试样例数据使用FlagScale-llama2预处理好的数据集,下载链接为

https://model.ks3-cn-beijing.ksyuncs.com/nlpdata/pile_wikipedia_demo.bin

https://model.ks3-cn-beijing.ksyuncs.com/nlpdata/pile_wikipedia_demo.idx

将上述两个文件放置于data_dir下。

This case includes datasets from the MIT License open source project at https://github.com/EleutherAI/the-pile

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

### 数据集引用

```
@article{pile,
title={The {P}ile: An 800GB Dataset of Diverse Text for Language Modeling},
author={Gao, Leo and Biderman, Stella and Black, Sid and Golding, Laurence and Hoppe, Travis and Foster, Charles and Phang, Jason and He, Horace and Thite, Anish and Nabeshima, Noa and Presser, Shawn and Leahy, Connor},
journal={arXiv preprint arXiv:2101.00027},
year={2020}
}
```

### 框架与芯片支持情况
| | Pytorch |
| ---------- | ------- |
| Nvidia GPU | ✅ |
| 昆仑芯 XPU | N/A |
| 天数智芯 | N/A |
133 changes: 133 additions & 0 deletions training/benchmarks/llama2_70B/megatron/megatron_main.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
export PYTHONPATH=$PYTHONPATH:/workspace/FlagScale
export CUDA_DEVICE_MAX_CONNECTIONS=1

DATA_DIR=$1
GPUS_PER_NODE=$2
NNODES=$3
NODE_RANK=$4
MASTER_ADDR=$5
MASTER_PORT=$6
TRAIN_SAMPLES=$7
TP=$8
PP=$9
M_BATCHSIZE=${10}
G_BATCHSIZE=${11}
SEQLENGTH=${12}
FLASH_ATTN=${13}
VENDOR_SHELL=${14}

echo $DATA_DIR
echo $GPUS_PER_NODE
echo $NNODES
echo $NODE_RANK
echo $MASTER_ADDR
echo $MASTER_PORT
echo $TRAIN_SAMPLES
echo $TP
echo $PP
echo $M_BATCHSIZE
echo $G_BATCHSIZE
echo $SEQLENGTH
echo $FLASH_ATTN
echo $VENDOR_SHELL

DATA_PATH=$DATA_DIR/llama_00_text_document
TOKENIZER_PATH=$DATA_DIR/tokenizer/tokenizer.model

DISTRIBUTED_ARGS="
--nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT
"

if [ "$FLASH_ATTN" = "True" ]; then
TRAINING_ARGS="
--train-samples $TRAIN_SAMPLES \
--eval-iters 0 \
--tensor-model-parallel-size $TP \
--pipeline-model-parallel-size $PP \
--micro-batch-size $M_BATCHSIZE \
--global-batch-size $G_BATCHSIZE \
--disable-bias-linear \
--use-distributed-optimizer \
--use-flash-attn
"
else
TRAINING_ARGS="
--train-samples $TRAIN_SAMPLES \
--eval-iters 0 \
--tensor-model-parallel-size $TP \
--pipeline-model-parallel-size $PP \
--micro-batch-size $M_BATCHSIZE \
--global-batch-size $G_BATCHSIZE \
--disable-bias-linear \
--use-distributed-optimizer
"
fi

MIXED_PRECISION_ARGS="
--bf16
"

DATA_ARGS="
--data-path $DATA_PATH \
--tokenizer-type Llama2Tokenizer \
--tokenizer-model $TOKENIZER_PATH \
--split 1
"

NETWORK_ARGS="
--num-layers 80 \
--hidden-size 8192 \
--num-attention-heads 64 \
--ffn-hidden-size 28672 \
--seq-length $SEQLENGTH \
--max-position-embeddings $SEQLENGTH \
--normalization RMSNorm \
--group-query-attention \
--num-query-groups 8 \
--use-rotary-position-embeddings \
--no-position-embedding \
--swiglu \
--multiple-of 4096 \
--sequence-parallel \
--untie-embeddings-and-output-weights
"

INITIALIZATION_ARGS="
--init-method-std 0.02 \
--seed 1234
"

REGULARIZATION_ARGS="
--attention-dropout 0.0 \
--hidden-dropout 0.0 \
--weight-decay 1e-2 \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--clip-grad 1.0
"

LEARNING_RATE_ARGS="
--lr 0.00015 \
--min-lr 1.0e-5 \
--lr-decay-style cosine \
--lr-warmup-fraction .01
"

source $VENDOR_SHELL
cmd="torchrun $DISTRIBUTED_ARGS pretrain_llama.py \
$TRAINING_ARGS \
$MIXED_PRECISION_ARGS \
$DATA_ARGS \
$NETWORK_ARGS \
$INITIALIZATION_ARGS \
$REGULARIZATION_ARGS \
$LEARNING_RATE_ARGS \
$CHECKPOINTING_ARGS \
$LOGGING_ARGS
"
echo $cmd
eval $cmd
189 changes: 189 additions & 0 deletions training/benchmarks/llama2_70B/megatron/pretrain_llama.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个文件在perf里面需要改动吗?能否直接使用FlagScale仓库的pretrain_llama

Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Pretrain LLaMA."""

import os
import torch
from torch import Tensor
from functools import partial
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import get_tokenizer
from megatron.core import mpu, tensor_parallel
from megatron.core.enums import ModelType
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from megatron.core.datasets.gpt_dataset import GPTDatasetConfig
from megatron.core.datasets.gpt_dataset import GPTDataset
from megatron.model import LLaMAModel
from megatron.training import pretrain
from megatron.utils import (
get_ltor_masks_and_position_ids,
get_batch_on_this_cp_rank,
average_losses_across_data_parallel_group
)
from megatron.arguments import core_transformer_config_from_args


def model_provider(pre_process=True, post_process=True):
"""Build the model."""
args = get_args()
config = core_transformer_config_from_args(args)
print_rank_0('building LLaMA model ...')
model = LLaMAModel(
config=config,
num_tokentypes=0,
parallel_output=True,
pre_process=pre_process,
post_process=post_process
)
return model


def get_batch(data_iterator):
"""Generate a batch."""

# TODO: this is pretty hacky, find a better way
if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
return None, None, None, None, None

args = get_args()
tokenizer = get_tokenizer()

# Items and their type.
keys = ['text']
datatype = torch.int64

# Broadcast data.
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
data_b = tensor_parallel.broadcast_data(keys, data, datatype)

# Unpack.
tokens_ = data_b['text'].long()
labels = tokens_[:, 1:].contiguous()
tokens = tokens_[:, :-1].contiguous()

# Get the masks and postition ids.
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
tokens,
tokenizer.eod,
args.reset_position_ids,
args.reset_attention_mask,
args.eod_mask_loss)

batch = {
'tokens': tokens,
'labels': labels,
'loss_mask': loss_mask,
'attention_mask': attention_mask,
'position_ids': position_ids
}
# slice batch along sequence dimension for context parallelism
batch = get_batch_on_this_cp_rank(batch)

return batch.values()

def loss_func(loss_mask: Tensor, output_tensor: Tensor):
"""Loss function.

Args:
loss_mask (Tensor): Used to mask out some portions of the loss
output_tensor (Tensor): The tensor with the losses
"""
args = get_args()

losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
if args.context_parallel_size > 1:
loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), loss_mask.sum().view(1)])
torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group())
loss = loss[0] / loss[1]
else:
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

# Check individual rank losses are not NaN prior to DP all-reduce.
if args.check_for_nan_in_loss_and_grad:
global_rank = torch.distributed.get_rank()
assert not loss.isnan(), (
f'Rank {global_rank}: found NaN in local forward loss calculation. '
f'Device: {torch.cuda.current_device()}, node: {os.uname()[1]}'
)

# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])

return loss * args.context_parallel_size, {'lm loss': averaged_loss[0]}


def forward_step(data_iterator, model: LLaMAModel):
"""Forward training step.

Args:
data_iterator : Input data iterator
model (LLaMAModel): The LLaMA Model
"""
args = get_args()
timers = get_timers()

# Get the batch.
timers('batch-generator', log_level=2).start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
data_iterator)
timers('batch-generator').stop()

output_tensor = model(tokens, position_ids, attention_mask,
labels=labels)

return output_tensor, partial(loss_func, loss_mask)


def is_dataset_built_on_rank():
return (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()) and mpu.get_tensor_model_parallel_rank() == 0


def core_llama_dataset_config_from_args(args):
return GPTDatasetConfig(
is_built_on_rank=is_dataset_built_on_rank,
random_seed=args.seed,
sequence_length=args.seq_length,
blend=args.data_path,
blend_per_split=[args.train_data_path, args.valid_data_path, args.test_data_path],
split=args.split,
path_to_cache=args.data_cache_path,
)


def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build the train test and validation datasets.

Args:
train_val_test_num_samples : A list containing the number of samples in train test and validation.
"""
args = get_args()

print_rank_0("> building train, validation, and test datasets for GPT ...")

train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
GPTDataset,
train_val_test_num_samples,
core_llama_dataset_config_from_args(args)
).build()

print_rank_0("> finished creating LLaMA datasets ...")

return train_ds, valid_ds, test_ds


if __name__ == "__main__":

# Temporary for transition to core datasets
train_valid_test_datasets_provider.is_distributed = True

pretrain(train_valid_test_datasets_provider,
model_provider,
ModelType.encoder_or_decoder,
forward_step,
args_defaults={'tokenizer_type': 'Llama2Tokenizer'},
get_batch_fn=get_batch)
Loading