Skip to content

Commit

Permalink
Merge branch 'main' into patch-2
Browse files Browse the repository at this point in the history
  • Loading branch information
minlu1021 committed May 1, 2024
2 parents 96b5564 + ff7bcc6 commit dae061f
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def train(
throughput = sample_processed / step_time
throughputs.append(throughput)

tflops_per_gpu = compute_tflops(throughput, num_params, world_size, batch_seqlen)
tflops_per_gpu = compute_tflops(args, sample_processed, step_time, world_size)

if not total_steps % args.logging_freq and args.log_reduced_training_loss > 0:
loss_scalar = reduce_loss(loss)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def compute_tflops(args, global_batch_size, step_time, world_size):
# Based on
# https://github.com/NVIDIA/Megatron-LM/blob/ba773259dbe5735fbd91ca41e7f4ded60b335c52/megatron/training/training.py#L65
num_experts_routed_to = 1 if args.moe > 1 else args.num_experts_per_tok
if args.num_key_value_heads is None:
args.num_key_value_heads = args.num_heads
num_flops = (
12
* global_batch_size
Expand Down

0 comments on commit dae061f

Please sign in to comment.