feat: add runtime diag (#297)
* feat: add runtime diag

* add diag_outlier_ratio

---------

Co-authored-by: yingtongxiong <974106207@qq.com>
sunpengsdu and yingtongxiong committed Sep 8, 2023
1 parent 06807a6 commit 1ee31ff
Showing 8 changed files with 155 additions and 13 deletions.
2 changes: 2 additions & 0 deletions configs/7B_sft.py
@@ -56,6 +56,8 @@
min_length=50,
# train_folder=TRAIN_FOLDER,
# valid_folder=VALID_FOLDER,
empty_cache_and_diag_interval=10,
diag_outlier_ratio=1.1,
)

grad_scaler = dict(
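A minimal sketch of how the two new keys sit in the data config (the other values are assumed, not taken from this commit): empty_cache_and_diag_interval sets how often the cache flush and diagnostics run, and diag_outlier_ratio is the slowness threshold relative to the cross-rank average.

```python
# Hypothetical data config excerpt; only the two new keys come from this commit.
data = dict(
    seq_len=2048,                      # assumed placeholder value
    empty_cache_and_diag_interval=10,  # run empty_cache + GPU/NCCL/timer diagnosis every 10 batches
    diag_outlier_ratio=1.1,            # flag a rank as slow if it exceeds 1.1x the reference average
)
```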
7 changes: 7 additions & 0 deletions internlm/initialize/launch.py
@@ -98,6 +98,13 @@ def args_sanity_check():
if "valid_every" not in data:
data._add_item("valid_every", 0)

if "empty_cache_and_diag_interval" not in data:
data._add_item("empty_cache_and_diag_interval", 50)

if "diag_outlier_ratio" not in data:
data._add_item("diag_outlier_ratio", 1.1)
data.diag_outlier_ratio = max(1, data.diag_outlier_ratio)

if gpc.is_rank_for_log():
logger.info("+" * 15 + " Data Info " + "+" * 15) # pylint: disable=W1201
logger.info(f"seq_len: {data.seq_len}")
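The sanity check fills in defaults when the keys are missing and clamps diag_outlier_ratio to at least 1, since a ratio below 1 would flag faster-than-average ranks. A rough stand-alone sketch of that behavior, using a plain dict instead of the real gpc config object:

```python
# Stand-alone sketch of the defaulting/clamping above (plain dict, not the gpc config).
def apply_diag_defaults(data: dict) -> dict:
    data.setdefault("empty_cache_and_diag_interval", 50)
    data.setdefault("diag_outlier_ratio", 1.1)
    data["diag_outlier_ratio"] = max(1, data["diag_outlier_ratio"])
    return data

assert apply_diag_defaults({})["diag_outlier_ratio"] == 1.1
assert apply_diag_defaults({"diag_outlier_ratio": 0.8})["diag_outlier_ratio"] == 1
```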
16 changes: 16 additions & 0 deletions internlm/solver/optimizer/hybrid_zero_optim.py
@@ -570,6 +570,7 @@ def _step(self, closure=None, norms=None):

# check for overflow
found_inf = False
found_nan = False
# if there are INF values in the grads, compute_norm would also return -1;
# thus, we try to avoid calling _check_overflow here
# found_inf = self._check_overflow()
@@ -578,9 +579,13 @@
if -1 in norms.values():
found_inf = True

if -2 in norms.values():
found_nan = True

loss_scale = float(self.loss_scale.item()) # backup
if gpc.config.model.dtype is not torch.float32:
self.grad_scaler.update(found_inf)

# update loss scale if overflow occurs
if found_inf:
if gpc.is_rank_for_log():
@@ -593,6 +598,17 @@
self.zero_grad()
return False, norms

if found_nan:
if gpc.is_rank_for_log():
logger.warning("Nan grad norm occurs, please check it.")
send_alert_message(
address=gpc.config.monitor.alert.feishu_alert_address,
message="Nan grad norm occurs, please check it.",
)
self._grad_store._averaged_gradients = dict()
self.zero_grad()
return False, norms

# copy the grad of fp16 param to fp32 param
single_grad_partition_groups = []
for group_id in range(self.num_param_groups):
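The NaN branch mirrors the existing overflow branch: the norms returned by compute_norm act as sentinels (-1 for inf, -2 for NaN), and in both cases the averaged gradients are dropped and the step is skipped. A reduced sketch of that decision, assuming norms maps group ids to those values:

```python
from typing import Optional

# Reduced sketch of the skip-step decision; the real _step also updates the loss scale
# and sends a Feishu alert in the NaN case.
def skip_reason(norms: dict) -> Optional[str]:
    if -1 in norms.values():
        return "inf grad norm: update loss scale and skip this step"
    if -2 in norms.values():
        return "nan grad norm: skip this step and alert"
    return None  # safe to copy grads to fp32 params and apply the update
```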
3 changes: 3 additions & 0 deletions internlm/solver/optimizer/utils.py
@@ -311,6 +311,9 @@ def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, no
if total_norm == float("inf") or total_norm == -float("inf"):
total_norm = -1

if math.isnan(total_norm):
total_norm = -2

return total_norm


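In isolation, the sentinel convention of compute_norm looks like the sketch below (the real function also reduces the norm across ranks before this check):

```python
import math

# Minimal illustration of the sentinel encoding used by compute_norm.
def encode_total_norm(total_norm: float) -> float:
    if total_norm in (float("inf"), float("-inf")):
        return -1  # overflow: the optimizer updates the loss scale and skips the step
    if math.isnan(total_norm):
        return -2  # NaN: the optimizer skips the step and raises an alert
    return total_norm
```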
1 change: 1 addition & 0 deletions internlm/train/training_internlm.py
@@ -354,6 +354,7 @@ def record_current_batch_training_metrics(

set_env_var(key="LAST_ACTIVE_TIMESTAMP", value=int(time.time()))

timer.store_last_timers()
if success_update in (0, True):
train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA)
if is_no_pp_or_last_stage():
109 changes: 101 additions & 8 deletions internlm/utils/gputest.py
@@ -9,7 +9,9 @@
from flash_attn.modules.mha import FlashSelfAttention, SelfAttention
from torch.utils import benchmark

from internlm.monitor import send_alert_message
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer

try:
import GPUtil
@@ -24,6 +26,23 @@
logger = get_logger(__file__)


def empty_cache_and_diag(batch_count, interval=50):
"""empty cuda cache and run diag bench or tests."""
if interval <= 0:
interval = 50
if batch_count % int(interval) == 0:
# there is no need to do diag on the first batch
if batch_count > 0:
if gpc.is_rank_for_log():
logger.info("Empty Cache and Diagnosis GPU/NCCL/Timer ...")
with torch.no_grad():
timer_diagnosis()
bench_gpu()
bench_net()
# do empty_cache after the bench
torch.cuda.empty_cache()
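Only the interval gating matters for when the diagnostics fire: the cache is emptied on every multiple of the interval (including batch 0), while the GPU/NCCL/timer diagnosis is skipped on batch 0. A small sketch of that gating, with the benchmarks themselves left out:

```python
# Sketch of the gating in empty_cache_and_diag; the benchmarks themselves are omitted.
def runs_diag(batch_count: int, interval: int = 50) -> bool:
    interval = interval if interval > 0 else 50  # non-positive intervals fall back to 50
    return batch_count % interval == 0 and batch_count > 0

assert [b for b in range(31) if runs_diag(b, interval=10)] == [10, 20, 30]
```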


def benchmark_forward(
test_fn,
*inputs,
@@ -81,14 +100,78 @@ def get_cpu_temperature():
return cpu_temperature


def timer_diagnosis():
"""Diagnosis running time"""

if len(timer.names) == 0 or len(timer.times) == 0:
return

world_size = gpc.get_world_size(ParallelMode.DATA)
if world_size < 2:
return

# if gpc.is_rank_for_log():
# logger.info("Diagnosis running timers ...")

# detect slow rank compared to other ranks in the same DP group
running_time = torch.Tensor(timer.times).to(device=get_current_device())
avg_time = running_time.detach().clone()
if world_size <= 4:
dist.all_reduce(avg_time, op=torch.distributed.ReduceOp.AVG, group=gpc.get_group(ParallelMode.DATA))
else:
running_time_max = avg_time.detach().clone()
running_time_min = avg_time.detach().clone()
dist.all_reduce(running_time_max, op=torch.distributed.ReduceOp.MAX, group=gpc.get_group(ParallelMode.DATA))
dist.all_reduce(running_time_min, op=torch.distributed.ReduceOp.MIN, group=gpc.get_group(ParallelMode.DATA))
dist.all_reduce(avg_time, op=torch.distributed.ReduceOp.SUM, group=gpc.get_group(ParallelMode.DATA))
avg_time = (avg_time - running_time_max - running_time_min) / (world_size - 2)

diag_result = running_time > avg_time * gpc.config.data.diag_outlier_ratio
diag_result = diag_result.tolist()
avg_time = avg_time.tolist()

for slow, name, time, avg in zip(diag_result, timer.names, timer.times, avg_time):
if slow is False or avg < 0.5:
continue
msg = (
f"Rank {gpc.get_local_rank(ParallelMode.GLOBAL)} is slower than avg on {name}, "
f"Hostname {socket.gethostname()}, "
f"its time {time:.2f}, avg {avg:.2f}, "
f"CPU temp {get_cpu_temperature()}, GPU temp { get_gpu_temperature()}"
)
logger.warning(msg)
send_alert_message(
address=gpc.config.monitor.alert.feishu_alert_address,
message=msg,
)

# detect slow rank compared to historical timer data
for name, time in zip(timer.names, timer.times):
if name not in timer.hist or len(timer.hist[name]) < 5:
continue
hist_avg = sum(timer.hist[name]) / len(timer.hist[name])
if time > hist_avg * gpc.config.data.diag_outlier_ratio and time > 0.5:
msg = (
f"Rank {gpc.get_local_rank(ParallelMode.GLOBAL)} is slower than hist avg on {name}, "
f"Hostname {socket.gethostname()}, "
f"its time {time:.2f}, hist_avg {hist_avg:.2f}, "
f"CPU temp {get_cpu_temperature()}, GPU temp { get_gpu_temperature()}"
)
logger.warning(msg)
send_alert_message(
address=gpc.config.monitor.alert.feishu_alert_address,
message=msg,
)
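For DP groups larger than 4 ranks the reference time is a trimmed mean: the fastest and slowest rank are dropped before averaging, avg = (sum - max - min) / (world_size - 2), and a rank is reported when its own time exceeds avg * diag_outlier_ratio and the phase averages at least 0.5 s. A toy numeric example with made-up timings:

```python
# Toy example of the trimmed-mean rule with made-up per-rank timings for one phase.
times = [1.00, 1.02, 0.98, 1.01, 1.60]  # seconds, one entry per rank
trimmed_avg = (sum(times) - max(times) - min(times)) / (len(times) - 2)  # ~1.01
outliers = [t for t in times if t > trimmed_avg * 1.1 and trimmed_avg >= 0.5]
print(outliers)  # [1.6] -- the slow rank would log a warning and send an alert
```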


def bench_net():
"""Benchmark nccl performance for slow node detection."""

if gpc.get_world_size(ParallelMode.GLOBAL) <= 1:
return

if gpc.is_rank_for_log():
logger.info("benchmarking network speed ...")
# if gpc.is_rank_for_log():
# logger.info("benchmarking network speed ...")

repeats = 100
input_data = torch.randn(
@@ -113,20 +196,25 @@ def allreduce_fn(inputs):
allreduce_time_avg = allreduce_time / gpc.get_world_size(ParallelMode.GLOBAL)
allreduce_time_avg = float(allreduce_time_avg.item())

if allreduce_time_this >= allreduce_time_avg * 1.05:
logger.warning(
if allreduce_time_this >= allreduce_time_avg * gpc.config.data.diag_outlier_ratio:
msg = (
f"Rank {gpc.get_local_rank(ParallelMode.GLOBAL)} NCCL test is slower than avg, "
f"Hostname {socket.gethostname()}, "
f"allreduce_time {allreduce_time_this:.2f}, avg {allreduce_time_avg:.2f}, "
f"CPU temp {get_cpu_temperature()}, GPU temp { get_gpu_temperature()}"
)
logger.warning(msg)
send_alert_message(
address=gpc.config.monitor.alert.feishu_alert_address,
message=msg,
)


def bench_gpu(use_flash_attn=True):
"""Benchmark single GPU performance for slow node detection."""

if gpc.is_rank_for_log():
logger.info("benchmarking gpu speed ...")
# if gpc.is_rank_for_log():
# logger.info("benchmarking gpu speed ...")

headdim = 64
dim = 2048
@@ -154,10 +242,15 @@ def bench_gpu(use_flash_attn=True):
speed_avg = speed / gpc.get_world_size(ParallelMode.GLOBAL)
speed_avg = float(speed_avg.item())

if speed_this <= speed_avg * 0.95:
logger.warning(
if speed_this <= speed_avg / gpc.config.data.diag_outlier_ratio:
msg = (
f"Rank {gpc.get_local_rank(ParallelMode.GLOBAL)} GPU is slower than avg, "
f"Hostname {socket.gethostname()}, "
f"tflops {speed_this:.2f}, avg {speed_avg:.2f}, "
f"CPU temp {get_cpu_temperature()}, GPU temp { get_gpu_temperature()}"
)
logger.warning(msg)
send_alert_message(
address=gpc.config.monitor.alert.feishu_alert_address,
message=msg,
)
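Both benchmarks now share the configurable threshold, but the comparison direction differs: the NCCL check measures time (slower means larger), while the GPU check measures throughput in TFLOPS (slower means smaller), so the latter divides by the ratio instead of multiplying. A small sketch of the two directions with illustrative numbers:

```python
# The two threshold directions used above (numbers are illustrative only).
ratio = 1.1  # gpc.config.data.diag_outlier_ratio

def slow_by_time(t_this: float, t_avg: float) -> bool:
    return t_this >= t_avg * ratio      # used for the NCCL allreduce timing

def slow_by_throughput(s_this: float, s_avg: float) -> bool:
    return s_this <= s_avg / ratio      # used for the flash-attention TFLOPS benchmark

assert slow_by_time(2.3, 2.0) and slow_by_throughput(90.0, 100.0)
```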
25 changes: 23 additions & 2 deletions internlm/utils/megatron_timers.py
@@ -16,8 +16,12 @@ def __init__(self, name):
self.start_time = time.time()
self.stream = torch.cuda.current_stream()

def start(self):
def start(self, reset_all=True):
"""Start the timer."""
# need to reset all timers in a new batch
if self.name_ == "one-batch" and reset_all is True:
megatron_timer.reset()

assert not self.started_, "timer has already been started"
self.stream.synchronize()
self.start_time = time.time()
@@ -48,7 +52,7 @@ def elapsed(self, reset=True):
self.reset()
# If timing was in progress, set it back.
if started_:
self.start()
self.start(reset_all=False)
return elapsed_


@@ -57,12 +61,29 @@ class Timers:

def __init__(self):
self.timers = {}
self.hist = {}
self.names = []
self.times = []

def __call__(self, name):
if name not in self.timers:
self.timers[name] = _Timer(name)
return self.timers[name]

def store_last_timers(self):
"""Store timers to two list"""
self.names = []
self.times = []
for key, value in self.timers.items():
seconds = round(float(value.elapsed(reset=False)), 4)
self.names.append(key)
self.times.append(seconds)
if key not in self.hist:
self.hist[key] = []
self.hist[key].append(seconds)
if len(self.hist[key]) > 10:
self.hist[key].pop(0)

def write(self, names, writer, iteration, normalizer=1.0, reset=False):
"""Write timers to a tensorboard writer"""
# currently when using add_scalars,
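store_last_timers snapshots the current timer values into the parallel names/times lists that timer_diagnosis reads, and keeps a rolling window of the last 10 samples per timer for the historical comparison. A reduced sketch of that bookkeeping, using a deque in place of the manual list trimming:

```python
from collections import defaultdict, deque

# Reduced sketch of the rolling history; deque(maxlen=10) replaces the manual pop(0).
hist = defaultdict(lambda: deque(maxlen=10))

def record(last_times):
    """last_times: mapping of timer name -> elapsed seconds for the latest batch."""
    for name, seconds in last_times.items():
        hist[name].append(round(float(seconds), 4))

def hist_avg(name):
    return sum(hist[name]) / len(hist[name])
```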
5 changes: 2 additions & 3 deletions train.py
@@ -35,6 +35,7 @@
parse_args,
)
from internlm.utils.evaluation import evaluate_on_val_dls
from internlm.utils.gputest import empty_cache_and_diag
from internlm.utils.logger import get_logger, initialize_uniscale_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.model_checkpoint import CheckpointManager
@@ -193,9 +194,7 @@ def main(args):
with initialize_llm_profile(profiling=args.profiling, start_time=current_time) as prof:
# start iterating the train data and begin training
for batch_count in range(train_state.batch_count, total_steps):
if batch_count % 50 == 0:
torch.cuda.empty_cache()

empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval)
start_time = time.time()
timer("one-batch").start()

