50 changes: 30 additions & 20 deletions nemo_rl/algorithms/distillation.py
@@ -90,13 +90,15 @@ class DistillationSaveState(TypedDict):
float
] # Can be any metric. Set to 'accuracy' by default in validation.
consumed_samples: int
total_valid_tokens: int # Track total number of non-padding tokens during training


def _default_distillation_save_state() -> DistillationSaveState:
return {
"step": 0,
"val_reward": -99999999.0, # Aligned with GRPO
"consumed_samples": 0,
"total_valid_tokens": 0,
}


@@ -491,6 +493,9 @@ def distillation_train(
# common config/state items
step = distillation_save_state["step"]
consumed_samples = distillation_save_state["consumed_samples"]
total_valid_tokens = distillation_save_state.get(
"total_valid_tokens", 0
) # Default to 0 for backward compatibility with older checkpoints
val_period = master_config["distillation"]["val_period"]
val_at_start = master_config["distillation"]["val_at_start"]
colocated_inference = master_config["policy"]["generation"]["colocated"]["enabled"]
@@ -677,6 +682,27 @@ def distillation_train(
)
logger.log_metrics(val_metrics, step + 1, prefix="validation")

metrics = {
"loss": train_results["loss"].numpy(),
"grad_norm": train_results["grad_norm"].numpy(),
"mean_prompt_length": repeated_batch["length"].numpy(),
"total_num_tokens": input_lengths.numpy(),
}
metrics.update(train_results["all_mb_metrics"])
for k, v in metrics.items():
if k in {
"lr",
"wd",
"global_valid_seqs",
"global_valid_toks",
"mean_prompt_length",
}:
metrics[k] = np.mean(v).item()
else:
metrics[k] = np.sum(v).item()
metrics.update(rollout_metrics)
total_valid_tokens += metrics["global_valid_toks"]

## Checkpointing
consumed_samples += master_config["distillation"][
"num_prompts_per_step"
@@ -697,6 +723,7 @@ def distillation_train(
student_policy.prepare_for_training()

distillation_save_state["step"] = step + 1
distillation_save_state["total_valid_tokens"] = total_valid_tokens
if val_metrics is not None:
distillation_save_state["val_reward"] = val_metrics["accuracy"]
elif "val_reward" in distillation_save_state:
@@ -744,26 +771,6 @@ def distillation_train(
log_data["input_lengths"] = input_lengths.tolist()
logger.log_batched_dict_as_jsonl(log_data, f"train_data_step{step}.jsonl")

metrics = {
"loss": train_results["loss"].numpy(),
"grad_norm": train_results["grad_norm"].numpy(),
"mean_prompt_length": repeated_batch["length"].numpy(),
"total_num_tokens": input_lengths.numpy(),
}
metrics.update(train_results["all_mb_metrics"])
for k, v in metrics.items():
if k in {
"lr",
"wd",
"global_valid_seqs",
"global_valid_toks",
"mean_prompt_length",
}:
metrics[k] = np.mean(v).item()
else:
metrics[k] = np.sum(v).item()
metrics.update(rollout_metrics)

timing_metrics: dict[str, float] = timer.get_timing_metrics(
reduction_op="sum"
) # type: ignore
Expand Down Expand Up @@ -817,6 +824,9 @@ def distillation_train(
percent = (v / total_time * 100) if total_time > 0 else 0
print(f" • {k}: {v:.2f}s ({percent:.1f}%)")

timing_metrics["valid_tokens_per_sec_per_gpu"] = (
metrics["global_valid_toks"] / total_time / total_num_gpus
)
logger.log_metrics(metrics, step + 1, prefix="train")
logger.log_metrics(timing_metrics, step + 1, prefix="timing/train")

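A note on the distillation hunks above (the same pattern repeats in dpo.py and grpo.py below): the per-step metrics aggregation now runs before checkpointing, so the running `total_valid_tokens` counter is current when it is written into `distillation_save_state`. The reduction rule treats rate-like or already-global keys (`lr`, `wd`, `global_valid_seqs`, `global_valid_toks`, `mean_prompt_length`) as means over microbatches, presumably because each microbatch reports the same all-reduced value, and sums everything else. A self-contained sketch of that pattern follows; the metric names mirror the diff, but the values are illustrative, not taken from the PR.

```python
import numpy as np

# Keys whose values are rates or already-global figures: average them across
# microbatches. Everything else (losses, token counts) is summed.
MEAN_KEYS = {"lr", "wd", "global_valid_seqs", "global_valid_toks", "mean_prompt_length"}


def reduce_step_metrics(mb_metrics: dict) -> dict:
    """Collapse per-microbatch metric lists into one scalar per key."""
    reduced = {}
    for key, values in mb_metrics.items():
        arr = np.asarray(values)
        reduced[key] = (np.mean(arr) if key in MEAN_KEYS else np.sum(arr)).item()
    return reduced


# Illustrative values: two microbatches from one optimizer step.
step_metrics = reduce_step_metrics(
    {"lr": [3e-5, 3e-5], "loss": [0.41, 0.38], "global_valid_toks": [8192.0, 8192.0]}
)
total_valid_tokens = 0
# Accumulate once per step; the per-step value was deduplicated by the mean.
total_valid_tokens += step_metrics["global_valid_toks"]
```

Because `global_valid_toks` is averaged rather than summed, the per-step figure is added to the running total exactly once, regardless of how many microbatches reported it.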
36 changes: 25 additions & 11 deletions nemo_rl/algorithms/dpo.py
@@ -45,6 +45,7 @@ class DPOSaveState(TypedDict):
step: int # Track step within current epoch
total_steps: int # Track total number of steps across all epochs
consumed_samples: int
total_valid_tokens: int # Track total number of non-padding tokens during training


def _default_dpo_save_state() -> DPOSaveState:
@@ -53,6 +54,7 @@ def _default_dpo_save_state() -> DPOSaveState:
"step": 0,
"total_steps": 0,
"consumed_samples": 0,
"total_valid_tokens": 0,
}


@@ -503,10 +505,14 @@ def dpo_train(
current_epoch = 0
current_step = 0
total_steps = 0
total_valid_tokens = 0
else:
current_epoch = dpo_save_state["epoch"]
current_step = dpo_save_state["step"]
total_steps = dpo_save_state["total_steps"]
total_valid_tokens = dpo_save_state.get(
"total_valid_tokens", 0
) # Default to 0 for backward compatibility with older checkpoints

dpo_config = master_config["dpo"]
# Validation configuration
@@ -587,6 +593,17 @@ def dpo_train(
val_metrics, validation_timings = validation_result
else:
val_metrics, validation_timings = None, None
metrics = {
"loss": train_results["loss"].numpy(),
"grad_norm": train_results["grad_norm"].numpy(),
}
metrics.update(train_results["all_mb_metrics"])
for k, v in metrics.items():
if k in {"lr", "wd", "global_valid_seqs", "global_valid_toks"}:
metrics[k] = np.mean(v).item()
else:
metrics[k] = np.sum(v).item()
total_valid_tokens += metrics["global_valid_toks"]

## Checkpointing
dpo_save_state["consumed_samples"] += master_config["policy"][
@@ -609,6 +626,7 @@ def dpo_train(
dpo_save_state["step"] = (current_step + 1) % len(train_dataloader)
dpo_save_state["total_steps"] = total_steps + 1
dpo_save_state["epoch"] = current_epoch
dpo_save_state["total_valid_tokens"] = total_valid_tokens
# Remove outdated validation metrics
for key in list(dpo_save_state):
if (
@@ -659,17 +677,6 @@ def dpo_train(
)
checkpointer.finalize_checkpoint(checkpoint_path)

losses = train_results["loss"]
metrics = {
"loss": train_results["loss"].numpy(),
"grad_norm": train_results["grad_norm"].numpy(),
}
metrics.update(train_results["all_mb_metrics"])
for k, v in metrics.items():
if k in {"lr", "wd", "global_valid_seqs", "global_valid_toks"}:
metrics[k] = np.mean(v).item()
else:
metrics[k] = np.sum(v).item()
timing_metrics = timer.get_timing_metrics(reduction_op="sum")

print("\n📊 Training Results:")
@@ -704,6 +711,13 @@ def dpo_train(
percent = (v / total_time * 100) if total_time > 0 else 0
print(f" • {k}: {v:.2f}s ({percent:.1f}%)")

total_num_gpus = (
master_config["cluster"]["num_nodes"]
* master_config["cluster"]["gpus_per_node"]
)
timing_metrics["valid_tokens_per_sec_per_gpu"] = (
metrics["global_valid_toks"] / total_time / total_num_gpus
)
logger.log_metrics(metrics, total_steps + 1, prefix="train")
logger.log_metrics(timing_metrics, total_steps + 1, prefix="timing/train")

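Both save-state structs so far (and `GRPOSaveState` below) gain a `total_valid_tokens` field, and every resume path reads it with `dict.get(..., 0)` so that checkpoints written before this change still load. A minimal sketch of that backward-compatibility pattern; `SaveState` here is a trimmed-down stand-in, not one of the actual classes.

```python
from typing import TypedDict


class SaveState(TypedDict, total=False):
    """Trimmed-down stand-in for the DPO/GRPO/distillation save states."""

    step: int
    total_steps: int
    consumed_samples: int
    total_valid_tokens: int  # new in this PR; absent from older checkpoints


def resume_total_valid_tokens(save_state: SaveState) -> int:
    # Older checkpoints predate `total_valid_tokens`, so fall back to 0
    # instead of raising KeyError on resume.
    return save_state.get("total_valid_tokens", 0)


# A checkpoint written before this change still resumes cleanly:
old_state: SaveState = {"step": 12, "total_steps": 12, "consumed_samples": 384}
assert resume_total_valid_tokens(old_state) == 0
```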
106 changes: 63 additions & 43 deletions nemo_rl/algorithms/grpo.py
@@ -104,6 +104,7 @@ class GRPOSaveState(TypedDict):
current_step: int
current_epoch: int
total_steps: int
total_valid_tokens: int # Track total number of non-padding tokens during training
val_reward: NotRequired[
float
] # Optional field - may not be present during training
@@ -115,6 +116,7 @@ def _default_grpo_save_state() -> GRPOSaveState:
"current_step": 0,
"current_epoch": 0,
"total_steps": 0,
"total_valid_tokens": 0,
"val_reward": -99999999.0,
}

@@ -606,6 +608,9 @@ def grpo_train(
consumed_samples = grpo_save_state[
"consumed_samples"
] # total samples consumed across all epochs
total_valid_tokens = grpo_save_state.get(
"total_valid_tokens", 0
) # total valid tokens processed across all epochs; default to 0 for backward compatibility with older checkpoints
val_at_start = master_config["grpo"]["val_at_start"]
val_period = master_config["grpo"]["val_period"]
colocated_inference = master_config["policy"]["generation"]["colocated"]["enabled"]
@@ -846,6 +851,28 @@ def grpo_train(
logger.log_metrics(
val_metrics, total_steps + 1, prefix="validation"
)
metrics = {
"loss": train_results["loss"].numpy(),
"reward": rewards.numpy(),
"grad_norm": train_results["grad_norm"].numpy(),
"mean_prompt_length": repeated_batch["length"].numpy(),
"total_num_tokens": input_lengths.numpy(),
}
metrics.update(train_results["all_mb_metrics"])
for k, v in metrics.items():
if k in {
"lr",
"wd",
"reward",
"global_valid_seqs",
"global_valid_toks",
"mean_prompt_length",
}:
metrics[k] = np.mean(v).item()
else:
metrics[k] = np.sum(v).item()
metrics.update(rollout_metrics)
total_valid_tokens += metrics["global_valid_toks"]

## Checkpointing
consumed_samples += master_config["grpo"]["num_prompts_per_step"]
@@ -869,6 +896,7 @@ def grpo_train(
grpo_save_state["current_step"] = current_step + 1
grpo_save_state["total_steps"] = total_steps + 1
grpo_save_state["current_epoch"] = current_epoch
grpo_save_state["total_valid_tokens"] = total_valid_tokens
if val_metrics is not None:
grpo_save_state["val_reward"] = val_metrics["accuracy"]
elif "val_reward" in grpo_save_state:
@@ -922,28 +950,6 @@ def grpo_train(
log_data, f"train_data_step{total_steps}.jsonl"
)

metrics = {
"loss": train_results["loss"].numpy(),
"reward": rewards.numpy(),
"grad_norm": train_results["grad_norm"].numpy(),
"mean_prompt_length": repeated_batch["length"].numpy(),
"total_num_tokens": input_lengths.numpy(),
}
metrics.update(train_results["all_mb_metrics"])
for k, v in metrics.items():
if k in {
"lr",
"wd",
"reward",
"global_valid_seqs",
"global_valid_toks",
"mean_prompt_length",
}:
metrics[k] = np.mean(v).item()
else:
metrics[k] = np.sum(v).item()
metrics.update(rollout_metrics)

timing_metrics: dict[str, float] = timer.get_timing_metrics(
reduction_op="sum"
) # type: ignore
@@ -993,6 +999,9 @@ def grpo_train(
percent = (v / total_time * 100) if total_time > 0 else 0
print(f" • {k}: {v:.2f}s ({percent:.1f}%)", flush=True)

timing_metrics["valid_tokens_per_sec_per_gpu"] = (
metrics["global_valid_toks"] / total_time / total_num_gpus
)
performance_metrics = print_performance_metrics(
train_results, metrics, timing_metrics, master_config
)
@@ -1190,6 +1199,9 @@ def async_grpo_train(
step = grpo_save_state["current_step"]
weight_version = step # Tracks refitted weight versions
consumed_samples = grpo_save_state["consumed_samples"]
total_valid_tokens = grpo_save_state.get(
"total_valid_tokens", 0
) # Default to 0 for backward compatibility with older checkpoints
val_period = master_config["grpo"]["val_period"]
val_at_start = master_config["grpo"]["val_at_start"]
colocated_inference = master_config["policy"]["generation"]["colocated"]["enabled"]
@@ -1639,6 +1651,27 @@ def async_grpo_train(

# Resume trajectory collection after validation
trajectory_collector.resume.remote()
metrics = {
"loss": train_results["loss"].numpy(),
"reward": rewards.numpy(),
"grad_norm": train_results["grad_norm"].numpy(),
"mean_prompt_length": repeated_batch["length"].numpy(),
"total_num_tokens": input_lengths.numpy(),
}
metrics.update(train_results["all_mb_metrics"])
for k, v in metrics.items():
if k in {
"lr",
"wd",
"reward",
"global_valid_seqs",
"global_valid_toks",
}:
metrics[k] = np.mean(v).item()
else:
metrics[k] = np.sum(v).item()
metrics.update(rollout_metrics)
total_valid_tokens += metrics["global_valid_toks"]

# Checkpointing (same as sync version)
consumed_samples += master_config["grpo"]["num_prompts_per_step"]
@@ -1649,6 +1682,7 @@ def async_grpo_train(
policy.prepare_for_training()

grpo_save_state["current_step"] = step + 1
grpo_save_state["total_valid_tokens"] = total_valid_tokens
if val_metrics is not None:
grpo_save_state["val_reward"] = val_metrics["accuracy"]
elif "val_reward" in grpo_save_state:
@@ -1700,27 +1734,6 @@ def async_grpo_train(
log_data["input_lengths"] = input_lengths.tolist()
logger.log_batched_dict_as_jsonl(log_data, f"train_data_step{step}.jsonl")

metrics = {
"loss": train_results["loss"].numpy(),
"reward": rewards.numpy(),
"grad_norm": train_results["grad_norm"].numpy(),
"mean_prompt_length": repeated_batch["length"].numpy(),
"total_num_tokens": input_lengths.numpy(),
}
metrics.update(train_results["all_mb_metrics"])
for k, v in metrics.items():
if k in {
"lr",
"wd",
"reward",
"global_valid_seqs",
"global_valid_toks",
}:
metrics[k] = np.mean(v).item()
else:
metrics[k] = np.sum(v).item()
metrics.update(rollout_metrics)

timing_metrics: dict[str, float] = timer.get_timing_metrics(
reduction_op="sum"
)
Expand All @@ -1746,6 +1759,13 @@ def async_grpo_train(
percent = (v / total_time * 100) if total_time > 0 else 0
print(f" • {k}: {v:.2f}s ({percent:.1f}%)")

total_num_gpus = (
master_config["cluster"]["num_nodes"]
* master_config["cluster"]["gpus_per_node"]
)
timing_metrics["valid_tokens_per_sec_per_gpu"] = (
metrics["global_valid_toks"] / total_time / total_num_gpus
)
performance_metrics = print_performance_metrics(
train_results, metrics, timing_metrics, master_config
)
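All four training loops touched by this PR (`distillation_train`, `dpo_train`, `grpo_train`, `async_grpo_train`) also gain a throughput metric, `valid_tokens_per_sec_per_gpu`, computed as `global_valid_toks / total_time / total_num_gpus`, where `total_time` is the timer's summed step time and `total_num_gpus = num_nodes * gpus_per_node` comes from the cluster config. A standalone sketch of that computation; the zero-division guard is an addition here, not something the diff includes.

```python
def valid_tokens_per_sec_per_gpu(
    global_valid_toks: float,
    total_time_s: float,
    num_nodes: int,
    gpus_per_node: int,
) -> float:
    """Per-GPU throughput of non-padding tokens for one training step."""
    total_num_gpus = num_nodes * gpus_per_node
    if total_time_s <= 0 or total_num_gpus == 0:
        return 0.0  # defensive guard; the diff divides unconditionally
    return global_valid_toks / total_time_s / total_num_gpus


# Illustrative numbers: 16,384 valid tokens in a 2 s step on 2 nodes x 8 GPUs.
print(valid_tokens_per_sec_per_gpu(16_384, 2.0, 2, 8))  # -> 512.0
```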