awgu committed Sep 1, 2023
1 parent 1437b91 commit f912af0
Showing 4 changed files with 96 additions and 29 deletions.
9 changes: 9 additions & 0 deletions .gitignore
@@ -1,3 +1,12 @@
.DS_Store
__pycache__
.ipynb_checkpoints
*.pickle
profiles/
PATH/
_memory_viz.py
llama-2-13b-hf/
llama-2-7b-hf/
*.out


64 changes: 48 additions & 16 deletions llama_finetuning.py
@@ -2,8 +2,10 @@
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os
import sys

import fire
import functools
import torch
import torch.distributed as dist
import torch.optim as optim
@@ -44,6 +46,12 @@
get_policies
)

sys.path.append("../distx")

from data_parallel import DataParallelMeshInfo
from data_parallel.parallelize import shard
from torch.distributed._tensor import DeviceMesh


def main(**kwargs):
# Update the configuration for the training and sharding process
@@ -64,6 +72,8 @@ def main(**kwargs):
torch.cuda.set_device(local_rank)
clear_gpu_cache(local_rank)
setup_environ_flags(rank)
# Record history *after* distributed ranks have been initialized
torch.cuda.memory._record_memory_history()

# Load the pre-trained model and setup its configuration
if train_config.enable_fsdp and train_config.low_cpu_fsdp:
@@ -135,22 +145,44 @@ def main(**kwargs):

freeze_transformer_layers(train_config.num_freeze_layers)

mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)

model = FSDP(
model,
auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
sharding_strategy=fsdp_config.sharding_strategy,
device_id=torch.cuda.current_device(),
limit_all_gathers=True,
sync_module_states=train_config.low_cpu_fsdp,
param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
if train_config.low_cpu_fsdp and rank != 0 else None,
)
if fsdp_config.fsdp_activation_checkpointing:
policies.apply_fsdp_checkpointing(model)
USE_PER_PARAM = False

if not USE_PER_PARAM:
mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)

model = FSDP(
model,
auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
sharding_strategy=fsdp_config.sharding_strategy,
device_id=torch.cuda.current_device(),
limit_all_gathers=True,
sync_module_states=train_config.low_cpu_fsdp,
param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
if train_config.low_cpu_fsdp and rank != 0 else None,
)
if fsdp_config.fsdp_activation_checkpointing:
policies.apply_fsdp_checkpointing(model)
else:
mesh = DeviceMesh("cuda", torch.arange(dist.get_world_size()))
mesh_info = DataParallelMeshInfo(mesh, shard_mesh_dim=0)
post_forward_mesh_info = mesh_info
fully_shard = functools.partial(
shard, mesh_info=mesh_info, post_forward_mesh_info=post_forward_mesh_info
)
from torch.distributed._composable import checkpoint as checkpoint_activations
for module in model.modules():
if isinstance(module, LlamaDecoderLayer):
checkpoint_activations(module)
fully_shard(module)
fully_shard(model)
# Manually move buffers to CUDA (for now)
for buffer in model.buffers():
buffer.data = buffer.cuda()
if rank == 0:
avoid_record_streams = os.environ.get("TORCH_NCCL_AVOID_RECORD_STREAMS", "0")
print(f"TORCH_NCCL_AVOID_RECORD_STREAMS={avoid_record_streams}")
elif not train_config.quantization and not train_config.enable_fsdp:
model.to("cuda")

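For reference, the USE_PER_PARAM branch above condenses into the following sketch. This is only an illustration of the experimental path: shard and DataParallelMeshInfo come from the external ../distx prototype that the script puts on sys.path, so their exact signatures are assumed from this diff rather than from a public PyTorch API.

import functools
import sys

import torch
import torch.distributed as dist
from torch.distributed._composable import checkpoint as checkpoint_activations
from torch.distributed._tensor import DeviceMesh

sys.path.append("../distx")  # external per-parameter sharding prototype (not a public package)
from data_parallel import DataParallelMeshInfo
from data_parallel.parallelize import shard


def apply_per_param_fsdp(model, decoder_layer_cls):
    # One 1-D mesh over all ranks; parameters are sharded along mesh dim 0.
    mesh = DeviceMesh("cuda", torch.arange(dist.get_world_size()))
    mesh_info = DataParallelMeshInfo(mesh, shard_mesh_dim=0)
    fully_shard = functools.partial(
        shard, mesh_info=mesh_info, post_forward_mesh_info=mesh_info
    )
    # Checkpoint activations and shard each decoder layer, then shard the root module.
    for module in model.modules():
        if isinstance(module, decoder_layer_cls):
            checkpoint_activations(module)
            fully_shard(module)
    fully_shard(model)
    # The prototype does not move buffers, so move them to CUDA manually.
    for buffer in model.buffers():
        buffer.data = buffer.cuda()
    return model

Usage would be apply_per_param_fsdp(model, LlamaDecoderLayer) in place of the FSDP(...) wrapper used in the default branch.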
32 changes: 19 additions & 13 deletions multi_node.slurm
@@ -1,36 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.


#!/bin/bash

#SBATCH --job-name=Nano-2d-trainer-20b-8nodes

#SBATCH --job-name=llama2
#SBATCH --partition=train
#SBATCH --ntasks=2
#SBATCH --nodes=2
#SBATCH --gpus-per-task=4
#SBATCH --partition=train
#SBATCH --gpus-per-task=8
#SBATCH --cpus-per-task=96

nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) )
nodes_array=($nodes)
head_node=${nodes_array[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
# Enable for A100
export FI_PROVIDER="efa"

echo Node IP: $head_node_ip
# Enable for A100
export LOGLEVEL=INFO
export FI_PROVIDER="efa"
export FI_EFA_USE_DEVICE_RDMA=1
export NCCL_ALGO=ring

# debugging flags (optional)
export NCCL_DEBUG=WARN
export NCCL_DEBUG_SUBSYS=WARN
export PYTHONFAULTHANDLER=1

export LD_LIBRARY_PATH=/opt/amazon/efa/lib:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=/usr/local/lib/:$LD_LIBRARY_PATH

echo $LD_LIBRARY_PATH
export CUDA_LAUNCH_BLOCKING=0

# on your cluster you might need these:
# set the network interface
export NCCL_SOCKET_IFNAME="ens"
export FI_EFA_USE_DEVICE_RDMA=1
export NCCL_SOCKET_IFNAME="eth0,en,eth,em,bond"


# export TORCH_NCCL_AVOID_RECORD_STREAMS=1

srun torchrun --nnodes 2 --nproc_per_node 8 --rdzv_id $RANDOM --rdzv_backend c10d --rdzv_endpoint $head_node_ip:29500 llama_finetuning.py --enable_fsdp --model_name llama-2-7b-hf/ --pure_bf16 --use_fast_kernels

srun torchrun --nproc_per_node 4 --rdzv_id $RANDOM --rdzv_backend c10d --rdzv_endpoint $head_node_ip:29500 llama_finetuning.py --enable_fsdp --use_peft --peft_method lora

20 changes: 20 additions & 0 deletions utils/train_utils.py
@@ -8,6 +8,7 @@
import time

import fire
import pickle
import torch
import transformers
from datasets import load_dataset
@@ -86,6 +87,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
total_length = len(train_dataloader)//gradient_accumulation_steps
pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch}", total=total_length)
for step, batch in enumerate(train_dataloader):
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for key in batch.keys():
if train_config.enable_fsdp:
batch[key] = batch[key].to(local_rank)
@@ -110,6 +114,22 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
optimizer.zero_grad()
pbar.update(step//gradient_accumulation_steps)

end_event.record()
torch.cuda.synchronize()

if step == 4 and rank == 0:
snapshot = torch.cuda.memory._snapshot()
with open('snapshot.pickle', 'wb') as f:
pickle.dump(snapshot, f)
if rank == 0:
elapsed_time = start_event.elapsed_time(end_event)
print(f"elapsed time / iteration: {elapsed_time:.3f} ms")
mem_stats = torch.cuda.memory_stats()
peak_active_gb = mem_stats["active_bytes.all.peak"] / (1024 ** 3)
peak_reserved_gb = mem_stats["reserved_bytes.all.peak"] / (1024 ** 3)
num_retries = mem_stats["num_alloc_retries"]
print(f"peak active: {peak_active_gb} GB | peak reserved: {peak_reserved_gb} GB | num_retries: {num_retries}")

pbar.set_description(f"Training Epoch: {epoch}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")

epoch_end_time = time.perf_counter()-epoch_start_time
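Together with the torch.cuda.memory._record_memory_history() call added in llama_finetuning.py, the snapshot.pickle dumped at step 4 can be inspected offline. A minimal sketch, assuming the trace_plot helper in torch.cuda._memory_viz is available (the new .gitignore entry _memory_viz.py suggests a local copy of that script is used here instead):

import pickle

from torch.cuda._memory_viz import trace_plot

with open("snapshot.pickle", "rb") as f:
    snapshot = pickle.load(f)

# trace_plot renders the recorded allocator trace as a self-contained HTML page.
with open("snapshot.html", "w") as f:
    f.write(trace_plot(snapshot))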
