From 41174e1641faff2763624c178874bdd2f8d51bba Mon Sep 17 00:00:00 2001 From: jamesruio <1428173426@qq.com> Date: Tue, 9 Jan 2024 17:43:46 +0800 Subject: [PATCH 1/8] Add llama2_70B-Megatron pretraining --- .../benchmarks/llama2_70B/megatron/README.md | 21 ++ .../llama2_70B/megatron/megatron_main.sh | 130 ++++++++++++ .../llama2_70B/megatron/pretrain_llama.py | 189 ++++++++++++++++++ .../llama2_70B/megatron/run_pretraining.py | 88 ++++++++ .../nvidia/docker_image/megatron/Dockerfile | 4 + .../docker_image/megatron/megatron_install.sh | 5 + training/nvidia/llama2_70B-megatron/README.md | 57 ++++++ .../config/config_H800x4x8.py | 9 + .../config/requirements.txt | 1 + .../run_benchmarks/config/cluster_conf.py | 2 +- training/run_benchmarks/config/test_conf.py | 2 +- .../megatron/start_megatron_task.py | 145 ++++++++++++++ 12 files changed, 651 insertions(+), 2 deletions(-) create mode 100644 training/benchmarks/llama2_70B/megatron/README.md create mode 100644 training/benchmarks/llama2_70B/megatron/megatron_main.sh create mode 100644 training/benchmarks/llama2_70B/megatron/pretrain_llama.py create mode 100644 training/benchmarks/llama2_70B/megatron/run_pretraining.py create mode 100644 training/nvidia/docker_image/megatron/Dockerfile create mode 100644 training/nvidia/docker_image/megatron/megatron_install.sh create mode 100644 training/nvidia/llama2_70B-megatron/README.md create mode 100644 training/nvidia/llama2_70B-megatron/config/config_H800x4x8.py create mode 100644 training/nvidia/llama2_70B-megatron/config/requirements.txt create mode 100644 training/run_benchmarks/megatron/start_megatron_task.py diff --git a/training/benchmarks/llama2_70B/megatron/README.md b/training/benchmarks/llama2_70B/megatron/README.md new file mode 100644 index 000000000..0d04addcf --- /dev/null +++ b/training/benchmarks/llama2_70B/megatron/README.md @@ -0,0 +1,21 @@ +## 模型信息 + +Llama 2, a collection of pretrained and fine-tuned large language models (LLMs) ranging in scale from 7 billion to 70 billion parameters. Meta's fine-tuned LLMs, called Llama 2-Chat, are optimized for dialogue use cases. Llama2 outperform open-source chat models on most benchmarks meta's researchers tested, and based on their human evaluations for helpfulness and safety, may be a suitable substitute for closedsource models. Meta provide a detailed description of their approach to fine-tuning and safety improvements of Llama 2-Chat in order to enable the community to build on their work and contribute to the responsible development of LLMs. + +代码来源: https://github.com/facebookresearch/llama-recipes/tree/main + +## 模型配置及tokenizer准备 + +本测试样例为预训练case,需要下载tokenizer,下载链接为 https://github.com/FlagOpen/FlagScale/tree/main/examples/llama2/tokenizer + +在data_dir下创建tokenizer目录,将上述链接中的tokenizer.model文件下载到此目录中 + +## 数据准备 + +本测试样例数据使用FlagScale-llama2预处理好的数据集,下载链接为 + + + +将上述两个文件放置于data_dir下。 \ No newline at end of file diff --git a/training/benchmarks/llama2_70B/megatron/megatron_main.sh b/training/benchmarks/llama2_70B/megatron/megatron_main.sh new file mode 100644 index 000000000..142514545 --- /dev/null +++ b/training/benchmarks/llama2_70B/megatron/megatron_main.sh @@ -0,0 +1,130 @@ +export PYTHONPATH=$PYTHONPATH:/workspace/FlagScale +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +DATA_DIR=$1 +GPUS_PER_NODE=$2 +NNODES=$3 +NODE_RANK=$4 +MASTER_ADDR=$5 +MASTER_PORT=$6 +TRAIN_SAMPLES=$7 +TP=$8 +PP=$9 +M_BATCHSIZE=${10} +G_BATCHSIZE=${11} +SEQLENGTH=${12} +FLASH_ATTN=${13} + +echo $DATA_DIR +echo $GPUS_PER_NODE +echo $NNODES +echo $NODE_RANK +echo $MASTER_ADDR +echo $MASTER_PORT +echo $TRAIN_SAMPLES +echo $TP +echo $PP +echo $M_BATCHSIZE +echo $G_BATCHSIZE +echo $SEQLENGTH +echo $FLASH_ATTN + +DATA_PATH=$DATA_DIR/llama_00_text_document +TOKENIZER_PATH=$DATA_DIR/tokenizer/tokenizer.model + +DISTRIBUTED_ARGS=" + --nproc_per_node $GPUS_PER_NODE \ + --nnodes $NNODES \ + --node_rank $NODE_RANK \ + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT \ +" + +if [ "$FLASH_ATTN" = "True" ]; then + TRAINING_ARGS=" + --train-samples $TRAIN_SAMPLES \ + --eval-iters 0 \ + --tensor-model-parallel-size $TP \ + --pipeline-model-parallel-size $PP \ + --micro-batch-size $M_BATCHSIZE \ + --global-batch-size $G_BATCHSIZE \ + --disable-bias-linear \ + --use-distributed-optimizer \ + --use-flash-attn + " +else + TRAINING_ARGS=" + --train-samples $TRAIN_SAMPLES \ + --eval-iters 0 \ + --tensor-model-parallel-size $TP \ + --pipeline-model-parallel-size $PP \ + --micro-batch-size $M_BATCHSIZE \ + --global-batch-size $G_BATCHSIZE \ + --disable-bias-linear \ + --use-distributed-optimizer + " +fi + +MIXED_PRECISION_ARGS=" + --fp16 +" + +DATA_ARGS=" + --data-path $DATA_PATH \ + --tokenizer-type Llama2Tokenizer \ + --tokenizer-model $TOKENIZER_PATH \ + --split 949,50,1 \ +" + +NETWORK_ARGS=" + --num-layers 80 \ + --hidden-size 8192 \ + --num-attention-heads 64 \ + --ffn-hidden-size 28672 \ + --seq-length $SEQLENGTH \ + --max-position-embeddings $SEQLENGTH \ + --normalization RMSNorm \ + --group-query-attention \ + --num-query-groups 8 \ + --use-rotary-position-embeddings \ + --no-position-embedding \ + --swiglu \ + --multiple-of 4096 \ + --sequence-parallel \ + --untie-embeddings-and-output-weights +" + +INITIALIZATION_ARGS=" + --init-method-std 0.02 \ + --seed 1234 +" + +REGULARIZATION_ARGS=" + --attention-dropout 0.0 \ + --hidden-dropout 0.0 \ + --weight-decay 1e-2 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --clip-grad 1.0 +" + +LEARNING_RATE_ARGS=" + --lr 0.00015 \ + --min-lr 1.0e-5 \ + --lr-decay-style cosine \ + --lr-warmup-fraction .01 \ +" + +cmd="torchrun $DISTRIBUTED_ARGS pretrain_llama.py \ + $TRAINING_ARGS \ + $MIXED_PRECISION_ARGS \ + $DATA_ARGS \ + $NETWORK_ARGS \ + $INITIALIZATION_ARGS \ + $REGULARIZATION_ARGS \ + $LEARNING_RATE_ARGS \ + $CHECKPOINTING_ARGS \ + $LOGGING_ARGS + " +echo $cmd +eval $cmd \ No newline at end of file diff --git a/training/benchmarks/llama2_70B/megatron/pretrain_llama.py b/training/benchmarks/llama2_70B/megatron/pretrain_llama.py new file mode 100644 index 000000000..b4e5142af --- /dev/null +++ b/training/benchmarks/llama2_70B/megatron/pretrain_llama.py @@ -0,0 +1,189 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +"""Pretrain LLaMA.""" + +import os +import torch +from torch import Tensor +from functools import partial +from megatron import get_args +from megatron import print_rank_0 +from megatron import get_timers +from megatron import get_tokenizer +from megatron.core import mpu, tensor_parallel +from megatron.core.enums import ModelType +from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder +from megatron.core.datasets.gpt_dataset import GPTDatasetConfig +from megatron.core.datasets.gpt_dataset import GPTDataset +from megatron.model import LLaMAModel +from megatron.training import pretrain +from megatron.utils import ( + get_ltor_masks_and_position_ids, + get_batch_on_this_cp_rank, + average_losses_across_data_parallel_group +) +from megatron.arguments import core_transformer_config_from_args + + +def model_provider(pre_process=True, post_process=True): + """Build the model.""" + args = get_args() + config = core_transformer_config_from_args(args) + print_rank_0('building LLaMA model ...') + model = LLaMAModel( + config=config, + num_tokentypes=0, + parallel_output=True, + pre_process=pre_process, + post_process=post_process + ) + return model + + +def get_batch(data_iterator): + """Generate a batch.""" + + # TODO: this is pretty hacky, find a better way + if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()): + return None, None, None, None, None + + args = get_args() + tokenizer = get_tokenizer() + + # Items and their type. + keys = ['text'] + datatype = torch.int64 + + # Broadcast data. + if data_iterator is not None: + data = next(data_iterator) + else: + data = None + data_b = tensor_parallel.broadcast_data(keys, data, datatype) + + # Unpack. + tokens_ = data_b['text'].long() + labels = tokens_[:, 1:].contiguous() + tokens = tokens_[:, :-1].contiguous() + + # Get the masks and postition ids. + attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + tokens, + tokenizer.eod, + args.reset_position_ids, + args.reset_attention_mask, + args.eod_mask_loss) + + batch = { + 'tokens': tokens, + 'labels': labels, + 'loss_mask': loss_mask, + 'attention_mask': attention_mask, + 'position_ids': position_ids + } + # slice batch along sequence dimension for context parallelism + batch = get_batch_on_this_cp_rank(batch) + + return batch.values() + +def loss_func(loss_mask: Tensor, output_tensor: Tensor): + """Loss function. + + Args: + loss_mask (Tensor): Used to mask out some portions of the loss + output_tensor (Tensor): The tensor with the losses + """ + args = get_args() + + losses = output_tensor.float() + loss_mask = loss_mask.view(-1).float() + if args.context_parallel_size > 1: + loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), loss_mask.sum().view(1)]) + torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group()) + loss = loss[0] / loss[1] + else: + loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() + + # Check individual rank losses are not NaN prior to DP all-reduce. + if args.check_for_nan_in_loss_and_grad: + global_rank = torch.distributed.get_rank() + assert not loss.isnan(), ( + f'Rank {global_rank}: found NaN in local forward loss calculation. ' + f'Device: {torch.cuda.current_device()}, node: {os.uname()[1]}' + ) + + # Reduce loss for logging. + averaged_loss = average_losses_across_data_parallel_group([loss]) + + return loss * args.context_parallel_size, {'lm loss': averaged_loss[0]} + + +def forward_step(data_iterator, model: LLaMAModel): + """Forward training step. + + Args: + data_iterator : Input data iterator + model (LLaMAModel): The LLaMA Model + """ + args = get_args() + timers = get_timers() + + # Get the batch. + timers('batch-generator', log_level=2).start() + tokens, labels, loss_mask, attention_mask, position_ids = get_batch( + data_iterator) + timers('batch-generator').stop() + + output_tensor = model(tokens, position_ids, attention_mask, + labels=labels) + + return output_tensor, partial(loss_func, loss_mask) + + +def is_dataset_built_on_rank(): + return (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()) and mpu.get_tensor_model_parallel_rank() == 0 + + +def core_llama_dataset_config_from_args(args): + return GPTDatasetConfig( + is_built_on_rank=is_dataset_built_on_rank, + random_seed=args.seed, + sequence_length=args.seq_length, + blend=args.data_path, + blend_per_split=[args.train_data_path, args.valid_data_path, args.test_data_path], + split=args.split, + path_to_cache=args.data_cache_path, + ) + + +def train_valid_test_datasets_provider(train_val_test_num_samples): + """Build the train test and validation datasets. + + Args: + train_val_test_num_samples : A list containing the number of samples in train test and validation. + """ + args = get_args() + + print_rank_0("> building train, validation, and test datasets for GPT ...") + + train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder( + GPTDataset, + train_val_test_num_samples, + core_llama_dataset_config_from_args(args) + ).build() + + print_rank_0("> finished creating LLaMA datasets ...") + + return train_ds, valid_ds, test_ds + + +if __name__ == "__main__": + + # Temporary for transition to core datasets + train_valid_test_datasets_provider.is_distributed = True + + pretrain(train_valid_test_datasets_provider, + model_provider, + ModelType.encoder_or_decoder, + forward_step, + args_defaults={'tokenizer_type': 'Llama2Tokenizer'}, + get_batch_fn=get_batch) diff --git a/training/benchmarks/llama2_70B/megatron/run_pretraining.py b/training/benchmarks/llama2_70B/megatron/run_pretraining.py new file mode 100644 index 000000000..b0bde8aa8 --- /dev/null +++ b/training/benchmarks/llama2_70B/megatron/run_pretraining.py @@ -0,0 +1,88 @@ +import subprocess +from argparse import ArgumentParser +import os +import sys +from importlib import import_module + + +def parse_args(): + '''we parse ddp related args, check system config args, and running env + args such as --data_dir_xxx. Then pass all useful args to the real + training script. + ''' + parser = ArgumentParser(description="flagscale main python") + parser.add_argument("--nproc_per_node", type=int, required=True) + parser.add_argument("--nnodes", type=int, required=True) + parser.add_argument("--node_rank", type=int, required=True) + parser.add_argument("--master_addr", type=str, required=True) + parser.add_argument("--master_port", type=int, required=True) + parser.add_argument("--vendor", type=str, required=True) + parser.add_argument("--data_dir", type=str, required=True) + parser.add_argument("--log_dir", type=str, required=True) + parser.add_argument("--flagperf_config_file", type=str, required=True) + args, unknown_args = parser.parse_known_args() + args.unknown_args = unknown_args + return args + + +if __name__ == "__main__": + args = parse_args() + print(args) + + sys.path.append(os.path.dirname(args.flagperf_config_file)) + config_file = os.path.basename(args.flagperf_config_file).split('.')[0] + + module = import_module(config_file) + + seqlength = getattr(module, 'seqlength') + batchsize = getattr(module, 'batchsize') + accumulate_steps = getattr(module, 'accumulate_steps') + train_tokens = getattr(module, 'train_tokens') + theoryflops = getattr(module, 'theoryflops') + epochs = getattr(module, 'epochs') + flashattn = getattr(module, 'flashattn') + tensor_parallel = getattr(module, 'tensor_parallel') + pipeline_parallel = getattr(module, 'pipeline_parallel') + + train_samples = int((train_tokens * epochs) // seqlength) + mbs = batchsize + gbs = batchsize * args.nproc_per_node * args.nnodes * accumulate_steps // (tensor_parallel * + pipeline_parallel) + + task_log_file = os.path.join(args.log_dir, "megatron.log.txt") + + exec_cmd = "bash megatron_main.sh" + exec_cmd = exec_cmd + " " + args.data_dir + exec_cmd = exec_cmd + " " + str(args.nproc_per_node) + exec_cmd = exec_cmd + " " + str(args.nnodes) + exec_cmd = exec_cmd + " " + str(args.node_rank) + exec_cmd = exec_cmd + " " + args.master_addr + exec_cmd = exec_cmd + " " + str(args.master_port) + exec_cmd = exec_cmd + " " + str(train_samples) + exec_cmd = exec_cmd + " " + str(tensor_parallel) + exec_cmd = exec_cmd + " " + str(pipeline_parallel) + exec_cmd = exec_cmd + " " + str(mbs) + exec_cmd = exec_cmd + " " + str(gbs) + exec_cmd = exec_cmd + " " + str(seqlength) + exec_cmd = exec_cmd + " " + str(flashattn) + + with open(task_log_file, "w") as f: + p = subprocess.Popen(exec_cmd, + shell=True, + stdout=f, + stderr=subprocess.STDOUT) + p.wait() + + time_per_step = -1.0 + with open(task_log_file) as f: + for line in f.readlines(): + if "elapsed time per iteration (ms): " in line: + info = line.split("|")[2] + steptime = info.split(":")[1] + time_per_step = float(steptime) / 1000 + + whole_tps = gbs * seqlength / time_per_step + chip_tps = whole_tps / (args.nproc_per_node * args.nnodes) + print("System tokens per second: ", whole_tps) + print("Tokens/p/s: ", chip_tps) + print("MFU: ", chip_tps * 7000000000.0 * 6 / theoryflops) \ No newline at end of file diff --git a/training/nvidia/docker_image/megatron/Dockerfile b/training/nvidia/docker_image/megatron/Dockerfile new file mode 100644 index 000000000..c776c9da7 --- /dev/null +++ b/training/nvidia/docker_image/megatron/Dockerfile @@ -0,0 +1,4 @@ +FROM nvcr.io/nvidia/pytorch:23.09-py3 +RUN /bin/bash -c "pip config set global.index-url https://mirror.baidu.com/pypi/simple" +RUN /bin/bash -c "uname -a" +RUN /bin/bash -c alias python3=python \ No newline at end of file diff --git a/training/nvidia/docker_image/megatron/megatron_install.sh b/training/nvidia/docker_image/megatron/megatron_install.sh new file mode 100644 index 000000000..ba8f157a5 --- /dev/null +++ b/training/nvidia/docker_image/megatron/megatron_install.sh @@ -0,0 +1,5 @@ +#!/bin/bash +# using github mirrors to avoid github TTL +git clone -b kunlunxin_llama70B https://github.com/jamesruio/FlagScale.git +echo 'export PYTHONPATH=$PYTHONPATH:/workspace/FlagScale' >> /root/.bashrc +source /root/.bashrc \ No newline at end of file diff --git a/training/nvidia/llama2_70B-megatron/README.md b/training/nvidia/llama2_70B-megatron/README.md new file mode 100644 index 000000000..a2a6fac47 --- /dev/null +++ b/training/nvidia/llama2_70B-megatron/README.md @@ -0,0 +1,57 @@ + +### Nvidia GPU配置与运行信息参考 +#### 环境配置 +- ##### 硬件环境 + - 机器型号: NVIDIA H800(80G) + - 加速卡型号: NVIDIA_H800-80GB + - CPU型号: Intel(R) Xeon(R) Platinum 8462Y+ + - 多机网络类型、带宽: InfiniBand, 200Gb/s + +- ##### 软件环境 + - OS版本:Ubuntu 22.04 LTS + - OS kernel版本: 5.15.0-25-generic + - 加速卡驱动版本:535.129.03 + - Docker 版本:24.0.7 + - 训练框架版本:FlagScale.git@6fc099c + - 依赖软件版本:sentencepiece + +- ##### 并行策略 + + - 并行技术:张量、流水、数据混合并行,具体并行方案见“运行情况”章节 + - 实施者:FlagScale + - 实施细节:/ + +- ##### 优化策略 + + - flash attention 2 + +### 运行情况 + +* 输入批尺寸 + 1. local_batchsize(micro_batchsize),简写为LBS,即实际进入模型的张量批尺寸,为config_H100x4x8.py中所写,在本case中默认为1 + 2. seqlength(max_position_embedding),简写为MPE,即实际进入模型的序列长度,为config_H100x4x8.py中所写,在本case中默认为4096 + 3. gradient_accumulate_steps,简写为GAS,即梯度累加步数,为ds_config.json中所写,在本case中默认为44 + 4. global_batchsize恒等于local_batchsize\*gradient_accumulate_steps\*data_parallel_size。在本case中,data_parallel_size=world_size/TPsize/PPsize。 + +* 通用指标 + +| 指标名称 | 指标值 | 特殊说明 | +| ------------ | -------------------------- | ---------------------------------- | +| 任务类别 | 自然语言理解 | | +| 模型 | llama2_70b | | +| 数据集 | pile wikipedia | | +| 数据精度 | amp | | +| 超参修改 | parallel,见“性能指标” | 格式为TPxPPyDPz,例如TP2PP1DP4 | +| 超参修改 | fix_hp,见“性能指标” | 跑满硬件设备评测吞吐量所需特殊超参 | +| 硬件设备简称 | nvidia H800 | | +| 硬件存储使用 | mem,见“性能指标” | 通常称为“显存”,单位为GiB | +| 计算使用率 | MFU,见“性能指标” | 参见PaLM论文定义 | +| **吞吐量** | **token/p/s,见“性能指标”** | 平均单卡每秒处理的token数 | + +* 性能指标 + +| 配置 | parallel | fix_hp | token/p/s | loss | mem | MFU | +| ------------------- | ------ | ---------------- | ------ | ------- | --------- | --------- | +| H800单机8卡(4x8) | TP8PP4DP1 | / | 3936.8 | 4.70 | 74/80 | 53.0% | +| H800单机8卡(4x8) | TP4PP8DP1 | LBS=8 | 2866.3 | 4.71 | 55/80 | 38.6% | +| H800单机8卡(4x8) | TP4PP4DP2 | LBS=4 | 2093.5 | 4.02 | 57/80 | 28.2% | \ No newline at end of file diff --git a/training/nvidia/llama2_70B-megatron/config/config_H800x4x8.py b/training/nvidia/llama2_70B-megatron/config/config_H800x4x8.py new file mode 100644 index 000000000..63044c545 --- /dev/null +++ b/training/nvidia/llama2_70B-megatron/config/config_H800x4x8.py @@ -0,0 +1,9 @@ +seqlength = 4096 +batchsize = 1 +accumulate_steps = 44 +train_tokens = 100000000 +theoryflops = 989000000000000.0 +epochs = 1 +flashattn = True +tensor_parallel = 8 +pipeline_parallel = 4 \ No newline at end of file diff --git a/training/nvidia/llama2_70B-megatron/config/requirements.txt b/training/nvidia/llama2_70B-megatron/config/requirements.txt new file mode 100644 index 000000000..ad213956e --- /dev/null +++ b/training/nvidia/llama2_70B-megatron/config/requirements.txt @@ -0,0 +1 @@ +sentencepiece \ No newline at end of file diff --git a/training/run_benchmarks/config/cluster_conf.py b/training/run_benchmarks/config/cluster_conf.py index be628e197..0c184df36 100644 --- a/training/run_benchmarks/config/cluster_conf.py +++ b/training/run_benchmarks/config/cluster_conf.py @@ -1,7 +1,7 @@ '''Cluster configs''' # Hosts to run the benchmark. Each item is an IP address or a hostname. -HOSTS = ["10.1.2.2", "10.1.2.3", "10.1.2.4"] +HOSTS = ["192.2.32.13", "192.2.32.14", "192.2.32.2", "192.2.32.4"] # Hosts port to run the tensorflow distribution_strategy = 'multi_worker_mirrored' HOSTS_PORTS = ["2222"] diff --git a/training/run_benchmarks/config/test_conf.py b/training/run_benchmarks/config/test_conf.py index 2af948a72..c48b5a981 100644 --- a/training/run_benchmarks/config/test_conf.py +++ b/training/run_benchmarks/config/test_conf.py @@ -91,7 +91,7 @@ # "llama2_7b:deepspeed:A100:1:8:1": "/raid/dataset/llama2_7b_pretrain", # "aquila2_7b:flagscale:A100:1:8:1": "/raid/dataset/aquila2_7b_pretrain", - + # "llama2_70B:megatron:H800:4:8:1": "/raid/dataset/llama2_70B_pretrain", # "llama1_7B:paddle_2.5.1:TP1PP1SH2SP8A10040G:1:8:1":"/raid/dataset/llama/" # "llama1_7B:paddle_2.5.1:TP2PP1SH1SP4A10040G:1:8:1":"/raid/dataset/llama/" # "llama1_7B:paddle_2.5.1:TP2PP1SH2SP4A10040G:1:8:1":"/raid/dataset/llama/" diff --git a/training/run_benchmarks/megatron/start_megatron_task.py b/training/run_benchmarks/megatron/start_megatron_task.py new file mode 100644 index 000000000..77f3a0c35 --- /dev/null +++ b/training/run_benchmarks/megatron/start_megatron_task.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python3 +# -*- coding: UTF-8 -*- +'''This script is called in container to execute the real training task. + Support pytorch DDP only. +''' +import os +import sys +import subprocess +from argparse import ArgumentParser + +CURR_PATH = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.abspath(os.path.join(CURR_PATH, "../../"))) +from utils import flagperf_logger +from utils import start_task_helper as helper + +START_LOGGER = flagperf_logger.FlagPerfLogger() + + +def parse_args(): + '''we parse ddp related args, check system config args, and running env + args such as --data_dir_xxx. Then pass all useful args to the real + training script. + ''' + parser = ArgumentParser(description="Start pytorch training process. ") + parser.add_argument("--node_rank", + type=int, + default=0, + help="The rank of the node for multi-node distributed " + "training") + parser.add_argument("--master_addr", + default="127.0.0.1", + type=str, + help="Master node (rank 0)'s address, should be either" + "the IP address or the hostname of node 0, for " + "single node multi-proc training, the " + "--master_addr can simply be 127.0.0.1") + parser.add_argument("--master_port", + default=29501, + type=int, + help="Master node (rank 0)'s free port that needs to " + "be used for communication during distributed " + "training") + parser.add_argument("--nnodes", + type=int, + required=True, + help="how many hosts to run the testcase.") + parser.add_argument("--nproc", + type=int, + required=True, + help="how many processes will run on each host.") + parser.add_argument("--vendor", + type=str, + required=True, + help="The accelerator vendor that run the located.") + parser.add_argument("--visible_dev_env", + type=str, + default=None, + help="The accelerator XXX_VISIBLE_DEVICE env name.") + parser.add_argument("--case_name", + type=str, + required=True, + help="Name of testcase.") + parser.add_argument("--round", + type=int, + required=True, + help="round of testcase, for repeating test.") + parser.add_argument("--model_name", + type=str, + required=True, + help="The model name of testcase.") + parser.add_argument("--host_addr", + type=str, + required=True, + help="The host address that start task.") + parser.add_argument("--train_script", + type=str, + required=True, + help="The training script to start by this launcher.") + parser.add_argument("--enable_extern_config", + action="store_true", + help="Sets to enable non-standard config parameters.") + parser.add_argument("--extern_config_file", + type=str, + required=True, + help="The testcase config file.") + parser.add_argument("--data_dir", + type=str, + default="/mnt/dataset/", + help="Data directory.") + parser.add_argument("--log_dir", + type=str, + default="/workspace/flagperf/training/result/", + help="Log directory in container.") + parser.add_argument("--log_level", + type=str, + default="debug", + help="Log level.") + + args, unknown_args = parser.parse_known_args() + args.unknown_args = unknown_args + return args + + +def main(): + '''Parse args and start the training task. Support DDP. + ''' + task_args = parse_args() + task_args.framework = "megatron" + + task_log_dir = helper.init_flagperf_logger(START_LOGGER, task_args) + helper.write_pid_file(task_args.log_dir, "start_megatron_task.pid") + + train_script_path = helper.get_train_script_path(task_args) + config_dir, config_file = helper.get_config_dir_file(task_args) + config_file = os.path.join(config_dir, config_file) + + START_LOGGER.info("Hello Flagscale") + print(train_script_path) + # START_LOGGER.info("Hello Flagscale") + # os.system("git clone httpsxxxx") + + exec_cmd = "cd " + os.path.dirname(train_script_path) + ";" + exec_cmd = exec_cmd + "python run_pretraining.py" + exec_cmd = exec_cmd + " --nproc_per_node=" + str(task_args.nproc) + exec_cmd = exec_cmd + " --nnodes=" + str(task_args.nnodes) + exec_cmd = exec_cmd + " --node_rank=" + str(task_args.node_rank) + exec_cmd = exec_cmd + " --master_addr=" + task_args.master_addr + exec_cmd = exec_cmd + " --master_port=" + str(task_args.master_port) + exec_cmd = exec_cmd + " --vendor=" + task_args.vendor + exec_cmd = exec_cmd + " --data_dir=" + task_args.data_dir + exec_cmd = exec_cmd + " --log_dir=" + task_log_dir + exec_cmd = exec_cmd + " --flagperf_config_file=" + config_file + + task_log_file = os.path.join(task_log_dir, "rank0.log.txt") + + with open(task_log_file, "w") as f: + p = subprocess.Popen(exec_cmd, + shell=True, + stdout=f, + stderr=subprocess.STDOUT) + p.wait() + + +if __name__ == '__main__': + main() \ No newline at end of file From 3f93915386b6950a6e8bf7f13671f2cc67d9b264 Mon Sep 17 00:00:00 2001 From: jamesruio <1428173426@qq.com> Date: Wed, 10 Jan 2024 19:12:59 +0800 Subject: [PATCH 2/8] support vendor shell and update readme --- .../benchmarks/llama2_70B/megatron/README.md | 52 ++++++++++++++++--- .../llama2_70B/megatron/megatron_main.sh | 11 ++-- .../llama2_70B/megatron/run_pretraining.py | 2 + training/nvidia/llama2_70B-megatron/README.md | 5 +- .../config/training_adapter.sh | 1 + .../megatron/start_megatron_task.py | 3 -- 6 files changed, 58 insertions(+), 16 deletions(-) create mode 100644 training/nvidia/llama2_70B-megatron/config/training_adapter.sh diff --git a/training/benchmarks/llama2_70B/megatron/README.md b/training/benchmarks/llama2_70B/megatron/README.md index 0d04addcf..a7a8bd3a1 100644 --- a/training/benchmarks/llama2_70B/megatron/README.md +++ b/training/benchmarks/llama2_70B/megatron/README.md @@ -1,21 +1,61 @@ ## 模型信息 +- Introduction Llama 2, a collection of pretrained and fine-tuned large language models (LLMs) ranging in scale from 7 billion to 70 billion parameters. Meta's fine-tuned LLMs, called Llama 2-Chat, are optimized for dialogue use cases. Llama2 outperform open-source chat models on most benchmarks meta's researchers tested, and based on their human evaluations for helpfulness and safety, may be a suitable substitute for closedsource models. Meta provide a detailed description of their approach to fine-tuning and safety improvements of Llama 2-Chat in order to enable the community to build on their work and contribute to the responsible development of LLMs. -代码来源: https://github.com/facebookresearch/llama-recipes/tree/main +- Paper +[LLAMA2](https://arxiv.org/pdf/2307.09288.pdf) -## 模型配置及tokenizer准备 +- 模型代码来源 + +This case includes code from the LLAMA 2 COMMUNITY LICENSE AGREEMENT License open source project at:https://github.com/facebookresearch/llama-recipes/tree/main + + +## 数据准备 + +### 模型配置及tokenizer准备 本测试样例为预训练case,需要下载tokenizer,下载链接为 https://github.com/FlagOpen/FlagScale/tree/main/examples/llama2/tokenizer 在data_dir下创建tokenizer目录,将上述链接中的tokenizer.model文件下载到此目录中 -## 数据准备 + +### 数据集准备 本测试样例数据使用FlagScale-llama2预处理好的数据集,下载链接为 - +``` +@article{pile, + title={The {P}ile: An 800GB Dataset of Diverse Text for Language Modeling}, + author={Gao, Leo and Biderman, Stella and Black, Sid and Golding, Laurence and Hoppe, Travis and Foster, Charles and Phang, Jason and He, Horace and Thite, Anish and Nabeshima, Noa and Presser, Shawn and Leahy, Connor}, + journal={arXiv preprint arXiv:2101.00027}, + year={2020} +} +``` -将上述两个文件放置于data_dir下。 \ No newline at end of file +### 框架与芯片支持情况 +| | Pytorch | +| ---------- | ------- | +| Nvidia GPU | ✅ | +| 昆仑芯 XPU | N/A | +| 天数智芯 | N/A | \ No newline at end of file diff --git a/training/benchmarks/llama2_70B/megatron/megatron_main.sh b/training/benchmarks/llama2_70B/megatron/megatron_main.sh index 142514545..99ceb3591 100644 --- a/training/benchmarks/llama2_70B/megatron/megatron_main.sh +++ b/training/benchmarks/llama2_70B/megatron/megatron_main.sh @@ -14,6 +14,7 @@ M_BATCHSIZE=${10} G_BATCHSIZE=${11} SEQLENGTH=${12} FLASH_ATTN=${13} +VENDOR_SHELL=${14} echo $DATA_DIR echo $GPUS_PER_NODE @@ -28,6 +29,7 @@ echo $M_BATCHSIZE echo $G_BATCHSIZE echo $SEQLENGTH echo $FLASH_ATTN +echo $VENDOR_SHELL DATA_PATH=$DATA_DIR/llama_00_text_document TOKENIZER_PATH=$DATA_DIR/tokenizer/tokenizer.model @@ -37,7 +39,7 @@ DISTRIBUTED_ARGS=" --nnodes $NNODES \ --node_rank $NODE_RANK \ --master_addr $MASTER_ADDR \ - --master_port $MASTER_PORT \ + --master_port $MASTER_PORT " if [ "$FLASH_ATTN" = "True" ]; then @@ -66,14 +68,14 @@ else fi MIXED_PRECISION_ARGS=" - --fp16 + --bf16 " DATA_ARGS=" --data-path $DATA_PATH \ --tokenizer-type Llama2Tokenizer \ --tokenizer-model $TOKENIZER_PATH \ - --split 949,50,1 \ + --split 1 " NETWORK_ARGS=" @@ -112,9 +114,10 @@ LEARNING_RATE_ARGS=" --lr 0.00015 \ --min-lr 1.0e-5 \ --lr-decay-style cosine \ - --lr-warmup-fraction .01 \ + --lr-warmup-fraction .01 " +source $VENDOR_SHELL cmd="torchrun $DISTRIBUTED_ARGS pretrain_llama.py \ $TRAINING_ARGS \ $MIXED_PRECISION_ARGS \ diff --git a/training/benchmarks/llama2_70B/megatron/run_pretraining.py b/training/benchmarks/llama2_70B/megatron/run_pretraining.py index b0bde8aa8..0501e2319 100644 --- a/training/benchmarks/llama2_70B/megatron/run_pretraining.py +++ b/training/benchmarks/llama2_70B/megatron/run_pretraining.py @@ -31,6 +31,7 @@ def parse_args(): sys.path.append(os.path.dirname(args.flagperf_config_file)) config_file = os.path.basename(args.flagperf_config_file).split('.')[0] + config_dir_path = os.path.dirname(args.flagperf_config_file) module = import_module(config_file) @@ -65,6 +66,7 @@ def parse_args(): exec_cmd = exec_cmd + " " + str(gbs) exec_cmd = exec_cmd + " " + str(seqlength) exec_cmd = exec_cmd + " " + str(flashattn) + exec_cmd = exec_cmd + " " + os.path.join(config_dir_path, "training_adapter.sh") with open(task_log_file, "w") as f: p = subprocess.Popen(exec_cmd, diff --git a/training/nvidia/llama2_70B-megatron/README.md b/training/nvidia/llama2_70B-megatron/README.md index a2a6fac47..4fc9965f3 100644 --- a/training/nvidia/llama2_70B-megatron/README.md +++ b/training/nvidia/llama2_70B-megatron/README.md @@ -52,6 +52,5 @@ | 配置 | parallel | fix_hp | token/p/s | loss | mem | MFU | | ------------------- | ------ | ---------------- | ------ | ------- | --------- | --------- | -| H800单机8卡(4x8) | TP8PP4DP1 | / | 3936.8 | 4.70 | 74/80 | 53.0% | -| H800单机8卡(4x8) | TP4PP8DP1 | LBS=8 | 2866.3 | 4.71 | 55/80 | 38.6% | -| H800单机8卡(4x8) | TP4PP4DP2 | LBS=4 | 2093.5 | 4.02 | 57/80 | 28.2% | \ No newline at end of file +| H800单机8卡(4x8) | TP8PP4DP1 | / | 641.93 | 5.7 | 62/80 | 27.2% | +| H800单机8卡(4x8) | TP4PP8DP1 | / | 791.37 | 5.6 | 74/80 | 33.6% | \ No newline at end of file diff --git a/training/nvidia/llama2_70B-megatron/config/training_adapter.sh b/training/nvidia/llama2_70B-megatron/config/training_adapter.sh new file mode 100644 index 000000000..13cdbd889 --- /dev/null +++ b/training/nvidia/llama2_70B-megatron/config/training_adapter.sh @@ -0,0 +1 @@ +echo "[Prompt] nvidia adaption is NULL, for other Vendors" \ No newline at end of file diff --git a/training/run_benchmarks/megatron/start_megatron_task.py b/training/run_benchmarks/megatron/start_megatron_task.py index 77f3a0c35..9a71ebac4 100644 --- a/training/run_benchmarks/megatron/start_megatron_task.py +++ b/training/run_benchmarks/megatron/start_megatron_task.py @@ -115,9 +115,6 @@ def main(): config_file = os.path.join(config_dir, config_file) START_LOGGER.info("Hello Flagscale") - print(train_script_path) - # START_LOGGER.info("Hello Flagscale") - # os.system("git clone httpsxxxx") exec_cmd = "cd " + os.path.dirname(train_script_path) + ";" exec_cmd = exec_cmd + "python run_pretraining.py" From 0171acaa18136b79b43dedf479f6287552377c9a Mon Sep 17 00:00:00 2001 From: jamesruio <1428173426@qq.com> Date: Thu, 11 Jan 2024 22:00:57 +0800 Subject: [PATCH 3/8] add fp32 training performance data and update readme --- .../llama2_70B/megatron/megatron_main.sh | 16 +- .../llama2_70B/megatron/pretrain_llama.py | 189 ------------------ .../llama2_70B/megatron/run_pretraining.py | 4 +- .../docker_image/megatron/megatron_install.sh | 2 +- training/nvidia/llama2_70B-megatron/README.md | 11 +- .../config/config_H800x4x8.py | 1 + .../run_benchmarks/config/cluster_conf.py | 2 +- 7 files changed, 26 insertions(+), 199 deletions(-) delete mode 100644 training/benchmarks/llama2_70B/megatron/pretrain_llama.py diff --git a/training/benchmarks/llama2_70B/megatron/megatron_main.sh b/training/benchmarks/llama2_70B/megatron/megatron_main.sh index 99ceb3591..a5f01b8ee 100644 --- a/training/benchmarks/llama2_70B/megatron/megatron_main.sh +++ b/training/benchmarks/llama2_70B/megatron/megatron_main.sh @@ -14,7 +14,8 @@ M_BATCHSIZE=${10} G_BATCHSIZE=${11} SEQLENGTH=${12} FLASH_ATTN=${13} -VENDOR_SHELL=${14} +RECOMPUTE=${14} +VENDOR_SHELL=${15} echo $DATA_DIR echo $GPUS_PER_NODE @@ -29,6 +30,7 @@ echo $M_BATCHSIZE echo $G_BATCHSIZE echo $SEQLENGTH echo $FLASH_ATTN +echo $RECOMPUTE echo $VENDOR_SHELL DATA_PATH=$DATA_DIR/llama_00_text_document @@ -96,6 +98,11 @@ NETWORK_ARGS=" --untie-embeddings-and-output-weights " +if [ "$RECOMPUTE" = "True" ]; then + RECOMPUTE_ARGS=" + --recompute-activations + " + INITIALIZATION_ARGS=" --init-method-std 0.02 \ --seed 1234 @@ -117,8 +124,12 @@ LEARNING_RATE_ARGS=" --lr-warmup-fraction .01 " +LOGGING_ARGS=" + --log-interval 1 +" + source $VENDOR_SHELL -cmd="torchrun $DISTRIBUTED_ARGS pretrain_llama.py \ +cmd="torchrun $DISTRIBUTED_ARGS /workspace/FlagScale/pretrain_llama.py \ $TRAINING_ARGS \ $MIXED_PRECISION_ARGS \ $DATA_ARGS \ @@ -127,6 +138,7 @@ cmd="torchrun $DISTRIBUTED_ARGS pretrain_llama.py \ $REGULARIZATION_ARGS \ $LEARNING_RATE_ARGS \ $CHECKPOINTING_ARGS \ + $RECOMPUTE_ARGS \ $LOGGING_ARGS " echo $cmd diff --git a/training/benchmarks/llama2_70B/megatron/pretrain_llama.py b/training/benchmarks/llama2_70B/megatron/pretrain_llama.py deleted file mode 100644 index b4e5142af..000000000 --- a/training/benchmarks/llama2_70B/megatron/pretrain_llama.py +++ /dev/null @@ -1,189 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -"""Pretrain LLaMA.""" - -import os -import torch -from torch import Tensor -from functools import partial -from megatron import get_args -from megatron import print_rank_0 -from megatron import get_timers -from megatron import get_tokenizer -from megatron.core import mpu, tensor_parallel -from megatron.core.enums import ModelType -from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder -from megatron.core.datasets.gpt_dataset import GPTDatasetConfig -from megatron.core.datasets.gpt_dataset import GPTDataset -from megatron.model import LLaMAModel -from megatron.training import pretrain -from megatron.utils import ( - get_ltor_masks_and_position_ids, - get_batch_on_this_cp_rank, - average_losses_across_data_parallel_group -) -from megatron.arguments import core_transformer_config_from_args - - -def model_provider(pre_process=True, post_process=True): - """Build the model.""" - args = get_args() - config = core_transformer_config_from_args(args) - print_rank_0('building LLaMA model ...') - model = LLaMAModel( - config=config, - num_tokentypes=0, - parallel_output=True, - pre_process=pre_process, - post_process=post_process - ) - return model - - -def get_batch(data_iterator): - """Generate a batch.""" - - # TODO: this is pretty hacky, find a better way - if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()): - return None, None, None, None, None - - args = get_args() - tokenizer = get_tokenizer() - - # Items and their type. - keys = ['text'] - datatype = torch.int64 - - # Broadcast data. - if data_iterator is not None: - data = next(data_iterator) - else: - data = None - data_b = tensor_parallel.broadcast_data(keys, data, datatype) - - # Unpack. - tokens_ = data_b['text'].long() - labels = tokens_[:, 1:].contiguous() - tokens = tokens_[:, :-1].contiguous() - - # Get the masks and postition ids. - attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( - tokens, - tokenizer.eod, - args.reset_position_ids, - args.reset_attention_mask, - args.eod_mask_loss) - - batch = { - 'tokens': tokens, - 'labels': labels, - 'loss_mask': loss_mask, - 'attention_mask': attention_mask, - 'position_ids': position_ids - } - # slice batch along sequence dimension for context parallelism - batch = get_batch_on_this_cp_rank(batch) - - return batch.values() - -def loss_func(loss_mask: Tensor, output_tensor: Tensor): - """Loss function. - - Args: - loss_mask (Tensor): Used to mask out some portions of the loss - output_tensor (Tensor): The tensor with the losses - """ - args = get_args() - - losses = output_tensor.float() - loss_mask = loss_mask.view(-1).float() - if args.context_parallel_size > 1: - loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), loss_mask.sum().view(1)]) - torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group()) - loss = loss[0] / loss[1] - else: - loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() - - # Check individual rank losses are not NaN prior to DP all-reduce. - if args.check_for_nan_in_loss_and_grad: - global_rank = torch.distributed.get_rank() - assert not loss.isnan(), ( - f'Rank {global_rank}: found NaN in local forward loss calculation. ' - f'Device: {torch.cuda.current_device()}, node: {os.uname()[1]}' - ) - - # Reduce loss for logging. - averaged_loss = average_losses_across_data_parallel_group([loss]) - - return loss * args.context_parallel_size, {'lm loss': averaged_loss[0]} - - -def forward_step(data_iterator, model: LLaMAModel): - """Forward training step. - - Args: - data_iterator : Input data iterator - model (LLaMAModel): The LLaMA Model - """ - args = get_args() - timers = get_timers() - - # Get the batch. - timers('batch-generator', log_level=2).start() - tokens, labels, loss_mask, attention_mask, position_ids = get_batch( - data_iterator) - timers('batch-generator').stop() - - output_tensor = model(tokens, position_ids, attention_mask, - labels=labels) - - return output_tensor, partial(loss_func, loss_mask) - - -def is_dataset_built_on_rank(): - return (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()) and mpu.get_tensor_model_parallel_rank() == 0 - - -def core_llama_dataset_config_from_args(args): - return GPTDatasetConfig( - is_built_on_rank=is_dataset_built_on_rank, - random_seed=args.seed, - sequence_length=args.seq_length, - blend=args.data_path, - blend_per_split=[args.train_data_path, args.valid_data_path, args.test_data_path], - split=args.split, - path_to_cache=args.data_cache_path, - ) - - -def train_valid_test_datasets_provider(train_val_test_num_samples): - """Build the train test and validation datasets. - - Args: - train_val_test_num_samples : A list containing the number of samples in train test and validation. - """ - args = get_args() - - print_rank_0("> building train, validation, and test datasets for GPT ...") - - train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder( - GPTDataset, - train_val_test_num_samples, - core_llama_dataset_config_from_args(args) - ).build() - - print_rank_0("> finished creating LLaMA datasets ...") - - return train_ds, valid_ds, test_ds - - -if __name__ == "__main__": - - # Temporary for transition to core datasets - train_valid_test_datasets_provider.is_distributed = True - - pretrain(train_valid_test_datasets_provider, - model_provider, - ModelType.encoder_or_decoder, - forward_step, - args_defaults={'tokenizer_type': 'Llama2Tokenizer'}, - get_batch_fn=get_batch) diff --git a/training/benchmarks/llama2_70B/megatron/run_pretraining.py b/training/benchmarks/llama2_70B/megatron/run_pretraining.py index 0501e2319..5d7cea938 100644 --- a/training/benchmarks/llama2_70B/megatron/run_pretraining.py +++ b/training/benchmarks/llama2_70B/megatron/run_pretraining.py @@ -42,6 +42,7 @@ def parse_args(): theoryflops = getattr(module, 'theoryflops') epochs = getattr(module, 'epochs') flashattn = getattr(module, 'flashattn') + recompute = getattr(module, 'recompute') tensor_parallel = getattr(module, 'tensor_parallel') pipeline_parallel = getattr(module, 'pipeline_parallel') @@ -66,6 +67,7 @@ def parse_args(): exec_cmd = exec_cmd + " " + str(gbs) exec_cmd = exec_cmd + " " + str(seqlength) exec_cmd = exec_cmd + " " + str(flashattn) + exec_cmd = exec_cmd + " " + str(recompute) exec_cmd = exec_cmd + " " + os.path.join(config_dir_path, "training_adapter.sh") with open(task_log_file, "w") as f: @@ -87,4 +89,4 @@ def parse_args(): chip_tps = whole_tps / (args.nproc_per_node * args.nnodes) print("System tokens per second: ", whole_tps) print("Tokens/p/s: ", chip_tps) - print("MFU: ", chip_tps * 7000000000.0 * 6 / theoryflops) \ No newline at end of file + print("MFU: ", chip_tps * 70000000000.0 * 6 / theoryflops) \ No newline at end of file diff --git a/training/nvidia/docker_image/megatron/megatron_install.sh b/training/nvidia/docker_image/megatron/megatron_install.sh index ba8f157a5..073708b43 100644 --- a/training/nvidia/docker_image/megatron/megatron_install.sh +++ b/training/nvidia/docker_image/megatron/megatron_install.sh @@ -1,5 +1,5 @@ #!/bin/bash # using github mirrors to avoid github TTL -git clone -b kunlunxin_llama70B https://github.com/jamesruio/FlagScale.git +git clone https://githubfast.com/FlagOpen/FlagScale echo 'export PYTHONPATH=$PYTHONPATH:/workspace/FlagScale' >> /root/.bashrc source /root/.bashrc \ No newline at end of file diff --git a/training/nvidia/llama2_70B-megatron/README.md b/training/nvidia/llama2_70B-megatron/README.md index 4fc9965f3..185746635 100644 --- a/training/nvidia/llama2_70B-megatron/README.md +++ b/training/nvidia/llama2_70B-megatron/README.md @@ -40,7 +40,7 @@ | 任务类别 | 自然语言理解 | | | 模型 | llama2_70b | | | 数据集 | pile wikipedia | | -| 数据精度 | amp | | +| 数据精度 | precision,见“性能指标” | 可选fp32/amp/fp16/bf16 | | 超参修改 | parallel,见“性能指标” | 格式为TPxPPyDPz,例如TP2PP1DP4 | | 超参修改 | fix_hp,见“性能指标” | 跑满硬件设备评测吞吐量所需特殊超参 | | 硬件设备简称 | nvidia H800 | | @@ -50,7 +50,8 @@ * 性能指标 -| 配置 | parallel | fix_hp | token/p/s | loss | mem | MFU | -| ------------------- | ------ | ---------------- | ------ | ------- | --------- | --------- | -| H800单机8卡(4x8) | TP8PP4DP1 | / | 641.93 | 5.7 | 62/80 | 27.2% | -| H800单机8卡(4x8) | TP4PP8DP1 | / | 791.37 | 5.6 | 74/80 | 33.6% | \ No newline at end of file +| 配置 | precision | parallel | fix_hp | token/p/s | loss | mem | MFU | +| ------------------ | -------- | --------- | ---------------- | ------ | ------- | --------- | --------- | +| H800单机8卡(4x8) | fp32 | TP8PP4DP1 | recompute=True | 253.61 | 0.94 | 77/80 | 10.7% | +| H800单机8卡(4x8) | amp | TP8PP4DP1 | / | 641.93 | 5.7 | 62/80 | 27.2% | +| H800单机8卡(4x8) | amp | TP4PP8DP1 | / | 791.37 | 5.6 | 74/80 | 33.6% | diff --git a/training/nvidia/llama2_70B-megatron/config/config_H800x4x8.py b/training/nvidia/llama2_70B-megatron/config/config_H800x4x8.py index 63044c545..fe8af8395 100644 --- a/training/nvidia/llama2_70B-megatron/config/config_H800x4x8.py +++ b/training/nvidia/llama2_70B-megatron/config/config_H800x4x8.py @@ -5,5 +5,6 @@ theoryflops = 989000000000000.0 epochs = 1 flashattn = True +recompute = False tensor_parallel = 8 pipeline_parallel = 4 \ No newline at end of file diff --git a/training/run_benchmarks/config/cluster_conf.py b/training/run_benchmarks/config/cluster_conf.py index 0c184df36..be628e197 100644 --- a/training/run_benchmarks/config/cluster_conf.py +++ b/training/run_benchmarks/config/cluster_conf.py @@ -1,7 +1,7 @@ '''Cluster configs''' # Hosts to run the benchmark. Each item is an IP address or a hostname. -HOSTS = ["192.2.32.13", "192.2.32.14", "192.2.32.2", "192.2.32.4"] +HOSTS = ["10.1.2.2", "10.1.2.3", "10.1.2.4"] # Hosts port to run the tensorflow distribution_strategy = 'multi_worker_mirrored' HOSTS_PORTS = ["2222"] From f1df6ecaf47382aaed7c3da799444a5e09020e2d Mon Sep 17 00:00:00 2001 From: shh2000 <13820618441@163.com> Date: Fri, 12 Jan 2024 11:53:32 +0800 Subject: [PATCH 4/8] fix&add --- training/nvidia/llama2_70B-megatron/README.md | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/training/nvidia/llama2_70B-megatron/README.md b/training/nvidia/llama2_70B-megatron/README.md index 185746635..71b4ae4b5 100644 --- a/training/nvidia/llama2_70B-megatron/README.md +++ b/training/nvidia/llama2_70B-megatron/README.md @@ -50,8 +50,11 @@ * 性能指标 +值得注意的是,下列第4组实验的global_batchsize与llama2原始论文相同,此项实验也将作为精度对齐所用实验。 + | 配置 | precision | parallel | fix_hp | token/p/s | loss | mem | MFU | | ------------------ | -------- | --------- | ---------------- | ------ | ------- | --------- | --------- | -| H800单机8卡(4x8) | fp32 | TP8PP4DP1 | recompute=True | 253.61 | 0.94 | 77/80 | 10.7% | -| H800单机8卡(4x8) | amp | TP8PP4DP1 | / | 641.93 | 5.7 | 62/80 | 27.2% | -| H800单机8卡(4x8) | amp | TP4PP8DP1 | / | 791.37 | 5.6 | 74/80 | 33.6% | +| H800四机32卡(4x8) | fp32 | TP8PP4DP1 | recompute=True | 253.61 | 0.94 | 77/80 | 10.7% | +| H800四机32卡(4x8) | amp | TP8PP4DP1 | / | 641.93 | 5.7 | 62/80 | 27.2% | +| H800四机32卡(4x8) | amp | TP4PP8DP1 | / | 791.37 | 5.6 | 74/80 | 33.6% | +| H800四机32卡(4x8) | amp | TP4PP8DP1 | GAS=1024(GBS=1024=4M tokens) | 791.37 | 5.6 | 74/80 | 33.6% | From a2a01094d64bc39e86efa5f14f675558df6c4742 Mon Sep 17 00:00:00 2001 From: shh2000 <13820618441@163.com> Date: Fri, 12 Jan 2024 13:16:19 +0800 Subject: [PATCH 5/8] fix&add --- training/nvidia/llama2_70B-megatron/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/training/nvidia/llama2_70B-megatron/README.md b/training/nvidia/llama2_70B-megatron/README.md index 71b4ae4b5..02c92a6be 100644 --- a/training/nvidia/llama2_70B-megatron/README.md +++ b/training/nvidia/llama2_70B-megatron/README.md @@ -54,7 +54,7 @@ | 配置 | precision | parallel | fix_hp | token/p/s | loss | mem | MFU | | ------------------ | -------- | --------- | ---------------- | ------ | ------- | --------- | --------- | -| H800四机32卡(4x8) | fp32 | TP8PP4DP1 | recompute=True | 253.61 | 0.94 | 77/80 | 10.7% | +| H800四机32卡(4x8) | fp32 | TP8PP4DP1 | recompute=True,(theoryflops=495T) | 253.61 | 0.94 | 77/80 | 21.4% | | H800四机32卡(4x8) | amp | TP8PP4DP1 | / | 641.93 | 5.7 | 62/80 | 27.2% | | H800四机32卡(4x8) | amp | TP4PP8DP1 | / | 791.37 | 5.6 | 74/80 | 33.6% | | H800四机32卡(4x8) | amp | TP4PP8DP1 | GAS=1024(GBS=1024=4M tokens) | 791.37 | 5.6 | 74/80 | 33.6% | From cf15a873ebb29f5cc4a22e7e7bb72f8f9a9a73a1 Mon Sep 17 00:00:00 2001 From: shh2000 <13820618441@163.com> Date: Fri, 12 Jan 2024 13:20:20 +0800 Subject: [PATCH 6/8] fix&add --- training/nvidia/llama2_70B-megatron/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/training/nvidia/llama2_70B-megatron/README.md b/training/nvidia/llama2_70B-megatron/README.md index 02c92a6be..a5ae3f67c 100644 --- a/training/nvidia/llama2_70B-megatron/README.md +++ b/training/nvidia/llama2_70B-megatron/README.md @@ -54,7 +54,7 @@ | 配置 | precision | parallel | fix_hp | token/p/s | loss | mem | MFU | | ------------------ | -------- | --------- | ---------------- | ------ | ------- | --------- | --------- | -| H800四机32卡(4x8) | fp32 | TP8PP4DP1 | recompute=True,(theoryflops=495T) | 253.61 | 0.94 | 77/80 | 21.4% | +| H800四机32卡(4x8) | fp32 | TP8PP4DP1 | recompute=True,(theoryflops=495T) | 253.61 | 0.94 | 77/80 | 21.5% | | H800四机32卡(4x8) | amp | TP8PP4DP1 | / | 641.93 | 5.7 | 62/80 | 27.2% | | H800四机32卡(4x8) | amp | TP4PP8DP1 | / | 791.37 | 5.6 | 74/80 | 33.6% | | H800四机32卡(4x8) | amp | TP4PP8DP1 | GAS=1024(GBS=1024=4M tokens) | 791.37 | 5.6 | 74/80 | 33.6% | From d97c7914b13033118ab99b35c87d60b1f80a64b6 Mon Sep 17 00:00:00 2001 From: shh2000 <13820618441@163.com> Date: Fri, 12 Jan 2024 15:23:03 +0800 Subject: [PATCH 7/8] fix&add --- training/nvidia/llama2_70B-megatron/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/training/nvidia/llama2_70B-megatron/README.md b/training/nvidia/llama2_70B-megatron/README.md index a5ae3f67c..e6266436e 100644 --- a/training/nvidia/llama2_70B-megatron/README.md +++ b/training/nvidia/llama2_70B-megatron/README.md @@ -50,11 +50,11 @@ * 性能指标 -值得注意的是,下列第4组实验的global_batchsize与llama2原始论文相同,此项实验也将作为精度对齐所用实验。 +值得注意的是,下列第4组实验的global_batchsize与llama2原始论文相同, 训练100 step,此项实验也将作为精度对齐所用实验。 | 配置 | precision | parallel | fix_hp | token/p/s | loss | mem | MFU | | ------------------ | -------- | --------- | ---------------- | ------ | ------- | --------- | --------- | | H800四机32卡(4x8) | fp32 | TP8PP4DP1 | recompute=True,(theoryflops=495T) | 253.61 | 0.94 | 77/80 | 21.5% | | H800四机32卡(4x8) | amp | TP8PP4DP1 | / | 641.93 | 5.7 | 62/80 | 27.2% | | H800四机32卡(4x8) | amp | TP4PP8DP1 | / | 791.37 | 5.6 | 74/80 | 33.6% | -| H800四机32卡(4x8) | amp | TP4PP8DP1 | GAS=1024(GBS=1024=4M tokens) | 791.37 | 5.6 | 74/80 | 33.6% | +| H800四机32卡(4x8) | amp | TP4PP8DP1 | GAS=1024(GBS=1024=4M tokens) | 908.29 | 7.1 | 74/80 | 38.6% | From 4a4d232c974f20b0e24e5706f28ab5132ad005b7 Mon Sep 17 00:00:00 2001 From: jamesruio <1428173426@qq.com> Date: Fri, 12 Jan 2024 18:20:43 +0800 Subject: [PATCH 8/8] update framework commit version and readme --- training/nvidia/docker_image/megatron/megatron_install.sh | 1 + training/nvidia/llama2_70B-megatron/README.md | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/training/nvidia/docker_image/megatron/megatron_install.sh b/training/nvidia/docker_image/megatron/megatron_install.sh index 073708b43..7662a98b6 100644 --- a/training/nvidia/docker_image/megatron/megatron_install.sh +++ b/training/nvidia/docker_image/megatron/megatron_install.sh @@ -1,5 +1,6 @@ #!/bin/bash # using github mirrors to avoid github TTL git clone https://githubfast.com/FlagOpen/FlagScale +git checkout 26cd6643c472f853e077779abaa51bb6a1c140bf echo 'export PYTHONPATH=$PYTHONPATH:/workspace/FlagScale' >> /root/.bashrc source /root/.bashrc \ No newline at end of file diff --git a/training/nvidia/llama2_70B-megatron/README.md b/training/nvidia/llama2_70B-megatron/README.md index e6266436e..882826036 100644 --- a/training/nvidia/llama2_70B-megatron/README.md +++ b/training/nvidia/llama2_70B-megatron/README.md @@ -12,7 +12,7 @@ - OS kernel版本: 5.15.0-25-generic - 加速卡驱动版本:535.129.03 - Docker 版本:24.0.7 - - 训练框架版本:FlagScale.git@6fc099c + - 训练框架版本:FlagScale.git@26cd664 - 依赖软件版本:sentencepiece - ##### 并行策略