|
22 | 22 | # Maps bare metric names to their W&B top-level section(s). |
23 | 23 | # Keys appearing in multiple sections (e.g. pg_loss) are emitted under each. |
24 | 24 | _TRAIN_METRIC_GROUPS: dict[str, list[str]] = { |
25 | | - "ppo_kl": ["policy_shift"], |
26 | | - "ois": ["policy_shift"], |
27 | | - "pg_clipfrac": ["policy_shift"], |
28 | | - "pg_loss": ["policy_shift", "optimization"], |
29 | | - "log_probs": ["policy_shift"], # current policy (training forward pass) |
30 | | - "old_log_probs": ["policy_shift"], # old policy (rollout or FSDP rollout) |
31 | | - "ref_kl": ["policy_shift"], |
| 25 | + "ppo_kl": ["policy_shift"], |
| 26 | + "ois": ["policy_shift"], |
| 27 | + "pg_clipfrac": ["policy_shift"], |
| 28 | + "pg_loss": ["policy_shift", "optimization"], |
| 29 | + "log_probs": ["policy_shift"], # current policy (training forward pass) |
| 30 | + "old_log_probs": ["policy_shift"], # old policy (rollout or FSDP rollout) |
| 31 | + "ref_kl": ["policy_shift"], |
32 | 32 | "train_rollout_logprob_abs_diff": ["train_inference_mismatch"], |
33 | | - "train_rollout_logprob_diff": ["train_inference_mismatch"], |
34 | | - "tis": ["train_inference_mismatch"], |
35 | | - "tis_abs": ["train_inference_mismatch"], |
36 | | - "tis_clipfrac": ["train_inference_mismatch"], |
37 | | - "loss": ["optimization"], |
38 | | - "entropy_loss": ["optimization"], |
39 | | - "kl_loss": ["optimization"], |
40 | | - "grad_norm": ["optimization"], |
| 33 | + "train_rollout_logprob_diff": ["train_inference_mismatch"], |
| 34 | + "tis": ["train_inference_mismatch"], |
| 35 | + "tis_abs": ["train_inference_mismatch"], |
| 36 | + "tis_clipfrac": ["train_inference_mismatch"], |
| 37 | + "loss": ["optimization"], |
| 38 | + "entropy_loss": ["optimization"], |
| 39 | + "kl_loss": ["optimization"], |
| 40 | + "grad_norm": ["optimization"], |
41 | 41 | } |
42 | 42 |
|
43 | 43 | # Maps rollout batch field names to their W&B top-level section. |
44 | 44 | _ROLLOUT_DATA_METRIC_GROUPS: dict[str, str] = { |
45 | | - "log_probs": "train_inference_mismatch", # FSDP log probs at rollout time |
| 45 | + "log_probs": "train_inference_mismatch", # FSDP log probs at rollout time |
46 | 46 | "rollout_log_probs": "train_inference_mismatch", # inference engine log probs |
47 | | - "ref_log_probs": "policy_shift", # reference model log probs |
48 | | - "rewards": "reward", |
49 | | - "raw_reward": "reward", |
50 | | - "advantages": "reward", |
51 | | - "returns": "reward", |
| 47 | + "ref_log_probs": "policy_shift", # reference model log probs |
| 48 | + "rewards": "reward", |
| 49 | + "raw_reward": "reward", |
| 50 | + "advantages": "reward", |
| 51 | + "returns": "reward", |
52 | 52 | } |
53 | 53 |
|
54 | 54 |
|
@@ -533,7 +533,7 @@ def log_train_step( |
533 | 533 | for full_key, val in log_dict_out.items(): |
534 | 534 | if not full_key.startswith(prefix): |
535 | 535 | continue |
536 | | - bare_key = full_key[len(prefix):] |
| 536 | + bare_key = full_key[len(prefix) :] |
537 | 537 | if bare_key in _TRAIN_METRIC_GROUPS: |
538 | 538 | for group in _TRAIN_METRIC_GROUPS[bare_key]: |
539 | 539 | grouped_additions[f"{group}/{bare_key}"] = val |
|
0 commit comments