Skip to content

Commit db437d2

Browse files
committed
prod: apply black drift cleanup
Six files on the prod base had black-non-compliant formatting that pre-commit on PR #25 flagged as failures. Applying `black==24.3.0` (matches .pre-commit-config.yaml) brings them in line so CI passes. Also fixes the single line in train_async.py from this PR that black wants (blank line after the import). No behavioral changes; pure whitespace + line breaks.
1 parent 0431dbf commit db437d2

7 files changed

Lines changed: 30 additions & 34 deletions

File tree

miles/backends/training_utils/log_utils.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,33 +22,33 @@
2222
# Maps bare metric names to their W&B top-level section(s).
2323
# Keys appearing in multiple sections (e.g. pg_loss) are emitted under each.
2424
_TRAIN_METRIC_GROUPS: dict[str, list[str]] = {
25-
"ppo_kl": ["policy_shift"],
26-
"ois": ["policy_shift"],
27-
"pg_clipfrac": ["policy_shift"],
28-
"pg_loss": ["policy_shift", "optimization"],
29-
"log_probs": ["policy_shift"], # current policy (training forward pass)
30-
"old_log_probs": ["policy_shift"], # old policy (rollout or FSDP rollout)
31-
"ref_kl": ["policy_shift"],
25+
"ppo_kl": ["policy_shift"],
26+
"ois": ["policy_shift"],
27+
"pg_clipfrac": ["policy_shift"],
28+
"pg_loss": ["policy_shift", "optimization"],
29+
"log_probs": ["policy_shift"], # current policy (training forward pass)
30+
"old_log_probs": ["policy_shift"], # old policy (rollout or FSDP rollout)
31+
"ref_kl": ["policy_shift"],
3232
"train_rollout_logprob_abs_diff": ["train_inference_mismatch"],
33-
"train_rollout_logprob_diff": ["train_inference_mismatch"],
34-
"tis": ["train_inference_mismatch"],
35-
"tis_abs": ["train_inference_mismatch"],
36-
"tis_clipfrac": ["train_inference_mismatch"],
37-
"loss": ["optimization"],
38-
"entropy_loss": ["optimization"],
39-
"kl_loss": ["optimization"],
40-
"grad_norm": ["optimization"],
33+
"train_rollout_logprob_diff": ["train_inference_mismatch"],
34+
"tis": ["train_inference_mismatch"],
35+
"tis_abs": ["train_inference_mismatch"],
36+
"tis_clipfrac": ["train_inference_mismatch"],
37+
"loss": ["optimization"],
38+
"entropy_loss": ["optimization"],
39+
"kl_loss": ["optimization"],
40+
"grad_norm": ["optimization"],
4141
}
4242

4343
# Maps rollout batch field names to their W&B top-level section.
4444
_ROLLOUT_DATA_METRIC_GROUPS: dict[str, str] = {
45-
"log_probs": "train_inference_mismatch", # FSDP log probs at rollout time
45+
"log_probs": "train_inference_mismatch", # FSDP log probs at rollout time
4646
"rollout_log_probs": "train_inference_mismatch", # inference engine log probs
47-
"ref_log_probs": "policy_shift", # reference model log probs
48-
"rewards": "reward",
49-
"raw_reward": "reward",
50-
"advantages": "reward",
51-
"returns": "reward",
47+
"ref_log_probs": "policy_shift", # reference model log probs
48+
"rewards": "reward",
49+
"raw_reward": "reward",
50+
"advantages": "reward",
51+
"returns": "reward",
5252
}
5353

5454

@@ -533,7 +533,7 @@ def log_train_step(
533533
for full_key, val in log_dict_out.items():
534534
if not full_key.startswith(prefix):
535535
continue
536-
bare_key = full_key[len(prefix):]
536+
bare_key = full_key[len(prefix) :]
537537
if bare_key in _TRAIN_METRIC_GROUPS:
538538
for group in _TRAIN_METRIC_GROUPS[bare_key]:
539539
grouped_additions[f"{group}/{bare_key}"] = val

miles/backends/training_utils/loss.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -693,7 +693,9 @@ def policy_loss_function(
693693
if "rollout_log_probs" in batch and batch["rollout_log_probs"]:
694694
rollout_log_probs_cat = torch.cat(batch["rollout_log_probs"], dim=0)
695695
log_probs_batch_cat = torch.cat(batch["log_probs"], dim=0)
696-
train_rollout_logprob_abs_diff = sum_of_sample_mean((old_log_probs - rollout_log_probs_cat).abs()).clone().detach()
696+
train_rollout_logprob_abs_diff = (
697+
sum_of_sample_mean((old_log_probs - rollout_log_probs_cat).abs()).clone().detach()
698+
)
697699
# signed: log π(inf) − log π(fsdp rollout)
698700
train_rollout_logprob_diff = sum_of_sample_mean(rollout_log_probs_cat - log_probs_batch_cat).clone().detach()
699701

miles/ray/rollout.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1400,9 +1400,7 @@ def _compute_grouped_response_metrics(args, group: list[Sample], prefix: str) ->
14001400
}
14011401

14021402

1403-
def _compute_group_outcome_metrics(
1404-
args, all_samples: list[Sample], prefix: str = "reward"
1405-
) -> dict:
1403+
def _compute_group_outcome_metrics(args, all_samples: list[Sample], prefix: str = "reward") -> dict:
14061404
"""Fraction of prompt groups that are unanimously correct or incorrect. GRPO only."""
14071405
if args.advantage_estimator == "ppo":
14081406
return {}

miles/rollout/generate_utils/openai_endpoint_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,5 +247,5 @@ def _truncate_sample_output(sample: Sample, keep_tokens: int, tokenizer) -> None
247247
if sample.loss_mask is not None:
248248
sample.loss_mask = sample.loss_mask[:keep_tokens]
249249
if sample.rollout_routed_experts is not None:
250-
sample.rollout_routed_experts = sample.rollout_routed_experts[:len(sample.tokens) - 1]
250+
sample.rollout_routed_experts = sample.rollout_routed_experts[: len(sample.tokens) - 1]
251251
sample.status = Sample.Status.TRUNCATED

miles/rollout/session/linear_trajectory.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -340,10 +340,7 @@ def _evict_stale_sessions(self) -> None:
340340
if not self._session_last_access:
341341
return
342342
now = time.monotonic()
343-
stale = [
344-
sid for sid, ts in self._session_last_access.items()
345-
if now - ts > self._SESSION_TTL_SECS
346-
]
343+
stale = [sid for sid, ts in self._session_last_access.items() if now - ts > self._SESSION_TTL_SECS]
347344
for sid in stale:
348345
self.sessions.pop(sid, None)
349346
self._session_last_access.pop(sid, None)

miles/utils/replay_base.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,7 @@ def _get_replay_result(top_indices, scores, topk, *args, **kwargs):
123123
_, sorted_free = masked_scores.sort(dim=1, descending=True)
124124
# The k-th -1 slot in each row gets sorted_free[row, k].
125125
pad_cumsum = torch.cumsum(padding_mask.long(), dim=1) - 1
126-
fill_values = torch.gather(sorted_free, 1, pad_cumsum.clamp(min=0)).to(
127-
top_indices.dtype
128-
)
126+
fill_values = torch.gather(sorted_free, 1, pad_cumsum.clamp(min=0)).to(top_indices.dtype)
129127
top_indices = torch.where(padding_mask, fill_values, top_indices)
130128

131129
if return_probs:

train_async.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,5 +77,6 @@ async def train(args):
7777
args = parse_args()
7878
if getattr(args, "enable_r3_correctness_check", False):
7979
from miles.utils.replay_base import RoutingReplayManager
80+
8081
RoutingReplayManager.enable_check_replay_result = True
8182
asyncio.run(train(args))

0 commit comments

Comments
 (0)