Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add llama2_70B-Megatron pretraining #389

Merged
merged 9 commits into from
Jan 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions training/benchmarks/llama2_70B/megatron/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
## 模型信息
- Introduction

Llama 2, a collection of pretrained and fine-tuned large language models (LLMs) ranging in scale from 7 billion to 70 billion parameters. Meta's fine-tuned LLMs, called Llama 2-Chat, are optimized for dialogue use cases. Llama2 outperform open-source chat models on most benchmarks meta's researchers tested, and based on their human evaluations for helpfulness and safety, may be a suitable substitute for closedsource models. Meta provide a detailed description of their approach to fine-tuning and safety improvements of Llama 2-Chat in order to enable the community to build on their work and contribute to the responsible development of LLMs.

- Paper
[LLAMA2](https://arxiv.org/pdf/2307.09288.pdf)

- 模型代码来源

This case includes code from the LLAMA 2 COMMUNITY LICENSE AGREEMENT License open source project at:https://github.com/facebookresearch/llama-recipes/tree/main


## 数据准备

### 模型配置及tokenizer准备

本测试样例为预训练case,需要下载tokenizer,下载链接为 https://github.com/FlagOpen/FlagScale/tree/main/examples/llama2/tokenizer

在data_dir下创建tokenizer目录,将上述链接中的tokenizer.model文件下载到此目录中


### 数据集准备

本测试样例数据使用FlagScale-llama2预处理好的数据集,下载链接为

https://model.ks3-cn-beijing.ksyuncs.com/nlpdata/pile_wikipedia_demo.bin

https://model.ks3-cn-beijing.ksyuncs.com/nlpdata/pile_wikipedia_demo.idx

将上述两个文件放置于data_dir下。

This case includes datasets from the MIT License open source project at https://github.com/EleutherAI/the-pile

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

### 数据集引用

```
@article{pile,
title={The {P}ile: An 800GB Dataset of Diverse Text for Language Modeling},
author={Gao, Leo and Biderman, Stella and Black, Sid and Golding, Laurence and Hoppe, Travis and Foster, Charles and Phang, Jason and He, Horace and Thite, Anish and Nabeshima, Noa and Presser, Shawn and Leahy, Connor},
journal={arXiv preprint arXiv:2101.00027},
year={2020}
}
```

### 框架与芯片支持情况
| | Pytorch |
| ---------- | ------- |
| Nvidia GPU | ✅ |
| 昆仑芯 XPU | N/A |
| 天数智芯 | N/A |
145 changes: 145 additions & 0 deletions training/benchmarks/llama2_70B/megatron/megatron_main.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
export PYTHONPATH=$PYTHONPATH:/workspace/FlagScale
export CUDA_DEVICE_MAX_CONNECTIONS=1

DATA_DIR=$1
GPUS_PER_NODE=$2
NNODES=$3
NODE_RANK=$4
MASTER_ADDR=$5
MASTER_PORT=$6
TRAIN_SAMPLES=$7
TP=$8
PP=$9
M_BATCHSIZE=${10}
G_BATCHSIZE=${11}
SEQLENGTH=${12}
FLASH_ATTN=${13}
RECOMPUTE=${14}
VENDOR_SHELL=${15}

echo $DATA_DIR
echo $GPUS_PER_NODE
echo $NNODES
echo $NODE_RANK
echo $MASTER_ADDR
echo $MASTER_PORT
echo $TRAIN_SAMPLES
echo $TP
echo $PP
echo $M_BATCHSIZE
echo $G_BATCHSIZE
echo $SEQLENGTH
echo $FLASH_ATTN
echo $RECOMPUTE
echo $VENDOR_SHELL

DATA_PATH=$DATA_DIR/llama_00_text_document
TOKENIZER_PATH=$DATA_DIR/tokenizer/tokenizer.model

DISTRIBUTED_ARGS="
--nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT
"

if [ "$FLASH_ATTN" = "True" ]; then
TRAINING_ARGS="
--train-samples $TRAIN_SAMPLES \
--eval-iters 0 \
--tensor-model-parallel-size $TP \
--pipeline-model-parallel-size $PP \
--micro-batch-size $M_BATCHSIZE \
--global-batch-size $G_BATCHSIZE \
--disable-bias-linear \
--use-distributed-optimizer \
--use-flash-attn
"
else
TRAINING_ARGS="
--train-samples $TRAIN_SAMPLES \
--eval-iters 0 \
--tensor-model-parallel-size $TP \
--pipeline-model-parallel-size $PP \
--micro-batch-size $M_BATCHSIZE \
--global-batch-size $G_BATCHSIZE \
--disable-bias-linear \
--use-distributed-optimizer
"
fi

MIXED_PRECISION_ARGS="
--bf16
"

DATA_ARGS="
--data-path $DATA_PATH \
--tokenizer-type Llama2Tokenizer \
--tokenizer-model $TOKENIZER_PATH \
--split 1
"

NETWORK_ARGS="
--num-layers 80 \
--hidden-size 8192 \
--num-attention-heads 64 \
--ffn-hidden-size 28672 \
--seq-length $SEQLENGTH \
--max-position-embeddings $SEQLENGTH \
--normalization RMSNorm \
--group-query-attention \
--num-query-groups 8 \
--use-rotary-position-embeddings \
--no-position-embedding \
--swiglu \
--multiple-of 4096 \
--sequence-parallel \
--untie-embeddings-and-output-weights
"

if [ "$RECOMPUTE" = "True" ]; then
RECOMPUTE_ARGS="
--recompute-activations
"

INITIALIZATION_ARGS="
--init-method-std 0.02 \
--seed 1234
"

REGULARIZATION_ARGS="
--attention-dropout 0.0 \
--hidden-dropout 0.0 \
--weight-decay 1e-2 \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--clip-grad 1.0
"

LEARNING_RATE_ARGS="
--lr 0.00015 \
--min-lr 1.0e-5 \
--lr-decay-style cosine \
--lr-warmup-fraction .01
"

LOGGING_ARGS="
--log-interval 1
"

source $VENDOR_SHELL
cmd="torchrun $DISTRIBUTED_ARGS /workspace/FlagScale/pretrain_llama.py \
$TRAINING_ARGS \
$MIXED_PRECISION_ARGS \
$DATA_ARGS \
$NETWORK_ARGS \
$INITIALIZATION_ARGS \
$REGULARIZATION_ARGS \
$LEARNING_RATE_ARGS \
$CHECKPOINTING_ARGS \
$RECOMPUTE_ARGS \
$LOGGING_ARGS
"
echo $cmd
eval $cmd
92 changes: 92 additions & 0 deletions training/benchmarks/llama2_70B/megatron/run_pretraining.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import subprocess
from argparse import ArgumentParser
import os
import sys
from importlib import import_module


def parse_args():
'''we parse ddp related args, check system config args, and running env
args such as --data_dir_xxx. Then pass all useful args to the real
training script.
'''
parser = ArgumentParser(description="flagscale main python")
parser.add_argument("--nproc_per_node", type=int, required=True)
parser.add_argument("--nnodes", type=int, required=True)
parser.add_argument("--node_rank", type=int, required=True)
parser.add_argument("--master_addr", type=str, required=True)
parser.add_argument("--master_port", type=int, required=True)
parser.add_argument("--vendor", type=str, required=True)
parser.add_argument("--data_dir", type=str, required=True)
parser.add_argument("--log_dir", type=str, required=True)
parser.add_argument("--flagperf_config_file", type=str, required=True)
args, unknown_args = parser.parse_known_args()
args.unknown_args = unknown_args
return args


if __name__ == "__main__":
args = parse_args()
print(args)

sys.path.append(os.path.dirname(args.flagperf_config_file))
config_file = os.path.basename(args.flagperf_config_file).split('.')[0]
config_dir_path = os.path.dirname(args.flagperf_config_file)

module = import_module(config_file)

seqlength = getattr(module, 'seqlength')
batchsize = getattr(module, 'batchsize')
accumulate_steps = getattr(module, 'accumulate_steps')
train_tokens = getattr(module, 'train_tokens')
theoryflops = getattr(module, 'theoryflops')
epochs = getattr(module, 'epochs')
flashattn = getattr(module, 'flashattn')
recompute = getattr(module, 'recompute')
tensor_parallel = getattr(module, 'tensor_parallel')
pipeline_parallel = getattr(module, 'pipeline_parallel')

train_samples = int((train_tokens * epochs) // seqlength)
mbs = batchsize
gbs = batchsize * args.nproc_per_node * args.nnodes * accumulate_steps // (tensor_parallel *
pipeline_parallel)

task_log_file = os.path.join(args.log_dir, "megatron.log.txt")

exec_cmd = "bash megatron_main.sh"
exec_cmd = exec_cmd + " " + args.data_dir
exec_cmd = exec_cmd + " " + str(args.nproc_per_node)
exec_cmd = exec_cmd + " " + str(args.nnodes)
exec_cmd = exec_cmd + " " + str(args.node_rank)
exec_cmd = exec_cmd + " " + args.master_addr
exec_cmd = exec_cmd + " " + str(args.master_port)
exec_cmd = exec_cmd + " " + str(train_samples)
exec_cmd = exec_cmd + " " + str(tensor_parallel)
exec_cmd = exec_cmd + " " + str(pipeline_parallel)
exec_cmd = exec_cmd + " " + str(mbs)
exec_cmd = exec_cmd + " " + str(gbs)
exec_cmd = exec_cmd + " " + str(seqlength)
exec_cmd = exec_cmd + " " + str(flashattn)
exec_cmd = exec_cmd + " " + str(recompute)
exec_cmd = exec_cmd + " " + os.path.join(config_dir_path, "training_adapter.sh")

with open(task_log_file, "w") as f:
p = subprocess.Popen(exec_cmd,
shell=True,
stdout=f,
stderr=subprocess.STDOUT)
p.wait()

time_per_step = -1.0
with open(task_log_file) as f:
for line in f.readlines():
if "elapsed time per iteration (ms): " in line:
info = line.split("|")[2]
steptime = info.split(":")[1]
time_per_step = float(steptime) / 1000

whole_tps = gbs * seqlength / time_per_step
chip_tps = whole_tps / (args.nproc_per_node * args.nnodes)
print("System tokens per second: ", whole_tps)
print("Tokens/p/s: ", chip_tps)
print("MFU: ", chip_tps * 70000000000.0 * 6 / theoryflops)
4 changes: 4 additions & 0 deletions training/nvidia/docker_image/megatron/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
FROM nvcr.io/nvidia/pytorch:23.09-py3
RUN /bin/bash -c "pip config set global.index-url https://mirror.baidu.com/pypi/simple"
RUN /bin/bash -c "uname -a"
RUN /bin/bash -c alias python3=python
6 changes: 6 additions & 0 deletions training/nvidia/docker_image/megatron/megatron_install.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#!/bin/bash
# using github mirrors to avoid github TTL
git clone https://githubfast.com/FlagOpen/FlagScale
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里需要指定一下版本,可以指定commit

git checkout 26cd6643c472f853e077779abaa51bb6a1c140bf
echo 'export PYTHONPATH=$PYTHONPATH:/workspace/FlagScale' >> /root/.bashrc
source /root/.bashrc
60 changes: 60 additions & 0 deletions training/nvidia/llama2_70B-megatron/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@

### Nvidia GPU配置与运行信息参考
#### 环境配置
- ##### 硬件环境
- 机器型号: NVIDIA H800(80G)
- 加速卡型号: NVIDIA_H800-80GB
- CPU型号: Intel(R) Xeon(R) Platinum 8462Y+
- 多机网络类型、带宽: InfiniBand, 200Gb/s

- ##### 软件环境
- OS版本:Ubuntu 22.04 LTS
- OS kernel版本: 5.15.0-25-generic
- 加速卡驱动版本:535.129.03
- Docker 版本:24.0.7
- 训练框架版本:FlagScale.git@26cd664
- 依赖软件版本:sentencepiece

- ##### 并行策略

- 并行技术:张量、流水、数据混合并行,具体并行方案见“运行情况”章节
- 实施者:FlagScale
- 实施细节:/

- ##### 优化策略

- flash attention 2

### 运行情况

* 输入批尺寸
1. local_batchsize(micro_batchsize),简写为LBS,即实际进入模型的张量批尺寸,为config_H100x4x8.py中所写,在本case中默认为1
2. seqlength(max_position_embedding),简写为MPE,即实际进入模型的序列长度,为config_H100x4x8.py中所写,在本case中默认为4096
3. gradient_accumulate_steps,简写为GAS,即梯度累加步数,为ds_config.json中所写,在本case中默认为44
4. global_batchsize恒等于local_batchsize\*gradient_accumulate_steps\*data_parallel_size。在本case中,data_parallel_size=world_size/TPsize/PPsize。

* 通用指标

| 指标名称 | 指标值 | 特殊说明 |
| ------------ | -------------------------- | ---------------------------------- |
| 任务类别 | 自然语言理解 | |
| 模型 | llama2_70b | |
| 数据集 | pile wikipedia | |
| 数据精度 | precision,见“性能指标” | 可选fp32/amp/fp16/bf16 |
| 超参修改 | parallel,见“性能指标” | 格式为TPxPPyDPz,例如TP2PP1DP4 |
| 超参修改 | fix_hp,见“性能指标” | 跑满硬件设备评测吞吐量所需特殊超参 |
| 硬件设备简称 | nvidia H800 | |
| 硬件存储使用 | mem,见“性能指标” | 通常称为“显存”,单位为GiB |
| 计算使用率 | MFU,见“性能指标” | 参见PaLM论文定义 |
| **吞吐量** | **token/p/s,见“性能指标”** | 平均单卡每秒处理的token数 |

* 性能指标

值得注意的是,下列第4组实验的global_batchsize与llama2原始论文相同, 训练100 step,此项实验也将作为精度对齐所用实验。

| 配置 | precision | parallel | fix_hp | token/p/s | loss | mem | MFU |
| ------------------ | -------- | --------- | ---------------- | ------ | ------- | --------- | --------- |
| H800四机32卡(4x8) | fp32 | TP8PP4DP1 | recompute=True,(theoryflops=495T) | 253.61 | 0.94 | 77/80 | 21.5% |
| H800四机32卡(4x8) | amp | TP8PP4DP1 | / | 641.93 | 5.7 | 62/80 | 27.2% |
| H800四机32卡(4x8) | amp | TP4PP8DP1 | / | 791.37 | 5.6 | 74/80 | 33.6% |
| H800四机32卡(4x8) | amp | TP4PP8DP1 | GAS=1024(GBS=1024=4M tokens) | 908.29 | 7.1 | 74/80 | 38.6% |
10 changes: 10 additions & 0 deletions training/nvidia/llama2_70B-megatron/config/config_H800x4x8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
seqlength = 4096
batchsize = 1
accumulate_steps = 44
train_tokens = 100000000
theoryflops = 989000000000000.0
epochs = 1
flashattn = True
recompute = False
tensor_parallel = 8
pipeline_parallel = 4
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
sentencepiece
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
echo "[Prompt] nvidia adaption is NULL, for other Vendors"
Loading