Commit 185904f: benchmark: add peft fsdp benchmark

Vinkle-hzt committed Apr 10, 2024 · 1 parent 4a769e1
Showing 1 changed file with 170 additions and 0 deletions.

benchmarks/bench_peft_fsdp.py (+170, -0)

@@ -0,0 +1,170 @@
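"""Benchmark PEFT (LoRA) fine-tuning throughput for a Llama model under FSDP.

Launch with torchrun, e.g. (the model name below is only an example):

    torchrun --nproc_per_node=8 benchmarks/bench_peft_fsdp.py \
        --base_model meta-llama/Llama-2-7b-hf
"""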
import os
import random
import time
import argparse

import torch
import torch.optim as optim

from peft import get_peft_model, PeftModelForCausalLM, LoraConfig, TaskType
from dataclasses import dataclass, field
from transformers import LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from llama_recipes.utils.fsdp_utils import fsdp_auto_wrap_policy, hsdp_device_mesh
from llama_recipes.configs.fsdp import fsdp_config as FSDP_CONFIG
from llama_recipes.configs.training import train_config as TRAIN_CONFIG
from llama_recipes.policies import apply_fsdp_checkpointing
from llama_recipes.utils.train_utils import (
setup,
setup_environ_flags,
clear_gpu_cache,
print_model_size,
get_policies,
)

from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
ShardingStrategy
)


@dataclass
class BenchmarkArgs:
batch_size: int = 8
seq_len: int = 1024
accumulation_steps: int = 4
test_steps: int = 100


@dataclass
class PeftArgs:
    rank: int = 16
    alpha: int = 16
    dropout: float = 0.05
    # A mutable default must go through default_factory in a dataclass.
    target_modules: list = field(default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"])


def setup_labels(benchmark_args: BenchmarkArgs) -> torch.Tensor:
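    # Random token ids stand in for real data; the benchmark measures raw
    # throughput, so the same tensor is reused as input_ids and labels.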
batch_input_ids = []
    for _ in range(benchmark_args.batch_size):
        batch_input_ids.append([random.randint(1, 10000)
                                for _ in range(benchmark_args.seq_len)])
return torch.tensor(batch_input_ids, dtype=torch.long)


def create_optimizer(model: FSDP, train_config: TRAIN_CONFIG):
optimizer = optim.AdamW(model.parameters(),
lr=train_config.lr,
weight_decay=train_config.weight_decay)
return optimizer


def setup_lora_adapter(model: LlamaForCausalLM, peft_args: PeftArgs) -> PeftModelForCausalLM:
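    # Only the low-rank adapter weights on the attention projections are
    # trainable; the base model weights stay frozen.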
peft_lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM,
r=peft_args.rank,
lora_alpha=peft_args.alpha,
lora_dropout=peft_args.dropout,
target_modules=peft_args.target_modules,
bias="none",
inference_mode=False)
model = get_peft_model(model, peft_lora_config)

return model


def create_model(rank: int, fsdp_config: FSDP_CONFIG, train_config: TRAIN_CONFIG) -> FSDP:
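    # Load the base model, attach the LoRA adapter, then shard with FSDP.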
model = LlamaForCausalLM.from_pretrained(
train_config.model_name,
use_cache=False
)
print_model_size(model, train_config, rank)

# setup lora
peft_args = PeftArgs()
model = setup_lora_adapter(model, peft_args)
model.print_trainable_parameters()

    hsdp_mesh = None
    if fsdp_config.hsdp and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD:
        hsdp_mesh = hsdp_device_mesh(replica_group_size=fsdp_config.replica_group_size,
                                     sharding_group_size=fsdp_config.sharding_group_size)
        print("HSDP device mesh is ready")

# setup FSDP
device_id = torch.cuda.current_device()
my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
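    # With PEFT enabled, the custom policy wraps each LlamaDecoderLayer and
    # separates trainable LoRA parameters from the frozen base weights.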
model = FSDP(model,
auto_wrap_policy=my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
cpu_offload=None,
mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
sharding_strategy=fsdp_config.sharding_strategy,
                 device_mesh=hsdp_mesh,
device_id=device_id,
limit_all_gathers=True,
sync_module_states=train_config.low_cpu_fsdp,
param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
if train_config.low_cpu_fsdp and rank != 0 else None,)
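    # Optionally recompute activations in backward to trade compute for memory.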
if fsdp_config.fsdp_activation_checkpointing:
apply_fsdp_checkpointing(model)
return model


def init_args():
parser = argparse.ArgumentParser(description='PEFT FSDP benchmarks')
parser.add_argument('--base_model', type=str, required=True,
help='Path to or name of base model')
return parser.parse_args()


def main():
args = init_args()

setup()
local_rank = int(os.environ["LOCAL_RANK"])
rank = int(os.environ["RANK"])
torch.cuda.set_device(local_rank)
clear_gpu_cache(local_rank)
setup_environ_flags(rank)

fsdp_config = FSDP_CONFIG()

train_config = TRAIN_CONFIG()
train_config.model_name = args.base_model
train_config.enable_fsdp = True
train_config.use_peft = True

benchmark_args = BenchmarkArgs()

model = create_model(rank, fsdp_config, train_config)
optimizer = create_optimizer(model, train_config)
labels = setup_labels(benchmark_args)

train(model, optimizer, labels, local_rank, benchmark_args)


def train(model: FSDP,
optimizer: optim.AdamW,
labels: torch.Tensor,
local_rank: int,
benchmark_args: BenchmarkArgs):
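    # Reports end-to-end tokens/s over test_steps micro-batches; timing starts
    # at step 0, so the first (slower) step is included in the average.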
autocast = torch.cuda.amp.autocast
start_time = time.time()
total_tokens = 0
for step in range(benchmark_args.test_steps):
data = labels.to(local_rank)
total_tokens += data.numel()
with autocast():
loss = model(input_ids=data, labels=data).loss
loss.backward()

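        # Step the optimizer once every accumulation_steps micro-batches;
        # without no_sync(), FSDP still reduce-scatters grads each backward.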
if (step + 1) % benchmark_args.accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()

if local_rank == 0:
        print(f'average {total_tokens / (time.time() - start_time):.2f} tokens/s per rank')


if __name__ == "__main__":
    main()
